bugs fixed
This commit is contained in:
parent
aa4ab78eb5
commit
a0e2d8e71f
|
@ -64,25 +64,23 @@ class ExtendedData:
|
||||||
return self.instances
|
return self.instances
|
||||||
|
|
||||||
def __split_index_by_pred(self) -> List[np.ndarray]:
|
def __split_index_by_pred(self) -> List[np.ndarray]:
|
||||||
_pred_label = np.argmax(self.pred_proba_, axis=0)
|
_pred_label = np.argmax(self.pred_proba_, axis=1)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
(_pred_label == cl).nonzero()[0]
|
(_pred_label == cl).nonzero()[0]
|
||||||
for cl in np.arange(self.pred_proba_.shape[0])
|
for cl in np.arange(self.pred_proba_.shape[1])
|
||||||
]
|
]
|
||||||
|
|
||||||
def split_by_pred(self, return_indexes=False):
|
def split_by_pred(self, return_indexes=False):
|
||||||
_indexes = self.__split_index_by_pred()
|
def _empty_matrix():
|
||||||
if isinstance(self.instances, np.ndarray):
|
if isinstance(self.instances, np.ndarray):
|
||||||
_instances = [
|
return np.asarray([], dtype=int)
|
||||||
self.instances[ind] if ind.shape[0] > 0 else np.asarray([], dtype=int)
|
|
||||||
for ind in _indexes
|
|
||||||
]
|
|
||||||
elif isinstance(self.instances, sp.csr_matrix):
|
elif isinstance(self.instances, sp.csr_matrix):
|
||||||
|
return sp.csr_matrix(np.empty((0, 0), dtype=int))
|
||||||
|
|
||||||
|
_indexes = self.__split_index_by_pred()
|
||||||
_instances = [
|
_instances = [
|
||||||
self.instances[ind]
|
self.instances[ind] if ind.shape[0] > 0 else _empty_matrix()
|
||||||
if ind.shape[0] > 0
|
|
||||||
else sp.csr_matrix(np.empty((0, 0), dtype=int))
|
|
||||||
for ind in _indexes
|
for ind in _indexes
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -182,7 +180,7 @@ class ExtendedCollection(LabelledCollection):
|
||||||
return _counts
|
return _counts
|
||||||
|
|
||||||
def split_by_pred(self):
|
def split_by_pred(self):
|
||||||
_ncl = len(self.pred_proba)
|
_ncl = self.pred_proba.shape[1]
|
||||||
_instances, _indexes = self.e_data_.split_by_pred(return_indexes=True)
|
_instances, _indexes = self.e_data_.split_by_pred(return_indexes=True)
|
||||||
_labels = [self.ey[ind] for ind in _indexes]
|
_labels = [self.ey[ind] for ind in _indexes]
|
||||||
return [
|
return [
|
||||||
|
|
Loading…
Reference in New Issue