bugs fixed
This commit is contained in:
parent
aa4ab78eb5
commit
a0e2d8e71f
|
@ -64,27 +64,25 @@ class ExtendedData:
|
|||
return self.instances
|
||||
|
||||
def __split_index_by_pred(self) -> List[np.ndarray]:
|
||||
_pred_label = np.argmax(self.pred_proba_, axis=0)
|
||||
_pred_label = np.argmax(self.pred_proba_, axis=1)
|
||||
|
||||
return [
|
||||
(_pred_label == cl).nonzero()[0]
|
||||
for cl in np.arange(self.pred_proba_.shape[0])
|
||||
for cl in np.arange(self.pred_proba_.shape[1])
|
||||
]
|
||||
|
||||
def split_by_pred(self, return_indexes=False):
|
||||
def _empty_matrix():
|
||||
if isinstance(self.instances, np.ndarray):
|
||||
return np.asarray([], dtype=int)
|
||||
elif isinstance(self.instances, sp.csr_matrix):
|
||||
return sp.csr_matrix(np.empty((0, 0), dtype=int))
|
||||
|
||||
_indexes = self.__split_index_by_pred()
|
||||
if isinstance(self.instances, np.ndarray):
|
||||
_instances = [
|
||||
self.instances[ind] if ind.shape[0] > 0 else np.asarray([], dtype=int)
|
||||
for ind in _indexes
|
||||
]
|
||||
elif isinstance(self.instances, sp.csr_matrix):
|
||||
_instances = [
|
||||
self.instances[ind]
|
||||
if ind.shape[0] > 0
|
||||
else sp.csr_matrix(np.empty((0, 0), dtype=int))
|
||||
for ind in _indexes
|
||||
]
|
||||
_instances = [
|
||||
self.instances[ind] if ind.shape[0] > 0 else _empty_matrix()
|
||||
for ind in _indexes
|
||||
]
|
||||
|
||||
if return_indexes:
|
||||
return _instances, _indexes
|
||||
|
@ -182,7 +180,7 @@ class ExtendedCollection(LabelledCollection):
|
|||
return _counts
|
||||
|
||||
def split_by_pred(self):
|
||||
_ncl = len(self.pred_proba)
|
||||
_ncl = self.pred_proba.shape[1]
|
||||
_instances, _indexes = self.e_data_.split_by_pred(return_indexes=True)
|
||||
_labels = [self.ey[ind] for ind in _indexes]
|
||||
return [
|
||||
|
|
Loading…
Reference in New Issue