bugs fixed

This commit is contained in:
Lorenzo Volpi 2023-11-16 01:36:41 +01:00
parent aa4ab78eb5
commit a0e2d8e71f
1 changed files with 13 additions and 15 deletions

View File

@ -64,27 +64,25 @@ class ExtendedData:
return self.instances
def __split_index_by_pred(self) -> List[np.ndarray]:
_pred_label = np.argmax(self.pred_proba_, axis=0)
_pred_label = np.argmax(self.pred_proba_, axis=1)
return [
(_pred_label == cl).nonzero()[0]
for cl in np.arange(self.pred_proba_.shape[0])
for cl in np.arange(self.pred_proba_.shape[1])
]
def split_by_pred(self, return_indexes=False):
def _empty_matrix():
if isinstance(self.instances, np.ndarray):
return np.asarray([], dtype=int)
elif isinstance(self.instances, sp.csr_matrix):
return sp.csr_matrix(np.empty((0, 0), dtype=int))
_indexes = self.__split_index_by_pred()
if isinstance(self.instances, np.ndarray):
_instances = [
self.instances[ind] if ind.shape[0] > 0 else np.asarray([], dtype=int)
for ind in _indexes
]
elif isinstance(self.instances, sp.csr_matrix):
_instances = [
self.instances[ind]
if ind.shape[0] > 0
else sp.csr_matrix(np.empty((0, 0), dtype=int))
for ind in _indexes
]
_instances = [
self.instances[ind] if ind.shape[0] > 0 else _empty_matrix()
for ind in _indexes
]
if return_indexes:
return _instances, _indexes
@ -182,7 +180,7 @@ class ExtendedCollection(LabelledCollection):
return _counts
def split_by_pred(self):
_ncl = len(self.pred_proba)
_ncl = self.pred_proba.shape[1]
_instances, _indexes = self.e_data_.split_by_pred(return_indexes=True)
_labels = [self.ey[ind] for ind in _indexes]
return [