diff --git a/quacc/models/cont_table.py b/quacc/models/cont_table.py index 8db240f..e507408 100644 --- a/quacc/models/cont_table.py +++ b/quacc/models/cont_table.py @@ -4,7 +4,7 @@ from copy import deepcopy import numpy as np import quapy.functional as F import scipy -from quapy.data.base import LabelledCollection +from quapy.data.base import LabelledCollection as LC from quapy.method.aggregative import AggregativeQuantifier from quapy.method.base import BaseQuantifier from scipy.sparse import csr_matrix, issparse @@ -15,6 +15,49 @@ from quacc.models.base import ClassifierAccuracyPrediction from quacc.models.utils import get_posteriors_from_h, max_conf, neg_entropy +class LabelledCollection(LC): + def empty_classes(self): + """ + Returns a np.ndarray of empty classes (classes present in self.classes_ but with + no positive instance). In case there is none, then an empty np.ndarray is returned + + :return: np.ndarray + """ + idx = np.argwhere(self.counts() == 0).flatten() + return self.classes_[idx] + + def non_empty_classes(self): + """ + Returns a np.ndarray of non-empty classes (classes present in self.classes_ but with + at least one positive instance). In case there is none, then an empty np.ndarray is returned + + :return: np.ndarray + """ + idx = np.argwhere(self.counts() > 0).flatten() + return self.classes_[idx] + + def has_empty_classes(self): + """ + Checks whether the collection has empty classes + + :return: boolean + """ + return len(self.empty_classes()) > 0 + + def compact_classes(self): + """ + Generates a new LabelledCollection object with no empty classes. It also returns a np.ndarray of + indexes that correspond to the old indexes of the new self.classes_. + + :return: (LabelledCollection, np.ndarray,) + """ + non_empty = self.non_empty_classes() + all_classes = self.classes_ + old_pos = np.searchsorted(all_classes, non_empty) + non_empty_collection = LabelledCollection(*self.Xy, classes=non_empty) + return non_empty_collection, old_pos + + class CAPContingencyTable(ClassifierAccuracyPrediction): def __init__(self, h: BaseEstimator, acc: callable): self.h = h @@ -304,9 +347,9 @@ class QuAcc1xN2(CAPContingencyTableQ, QuAcc): pred_labels = self.h.predict(val.X) true_labels = val.y - n = val.n_classes - classes_dot = np.arange(n**2) - ct_class_idx = classes_dot.reshape(n, n) + self.ncl = val.n_classes + classes_dot = np.arange(self.ncl**2) + ct_class_idx = classes_dot.reshape(self.ncl, self.ncl) X_dot = self._get_X_dot(val.X) y_dot = ct_class_idx[true_labels, pred_labels] @@ -315,7 +358,8 @@ class QuAcc1xN2(CAPContingencyTableQ, QuAcc): def predict_ct(self, X, oracle_prev=None): X_dot = self._get_X_dot(X) - return self.q.quantify(X_dot) + flat_ct = self.q.quantify(X_dot) + return flat_ct.reshape(self.ncl, self.ncl) class QuAcc1xNp1(CAPContingencyTableQ, QuAcc): @@ -343,11 +387,11 @@ class QuAcc1xNp1(CAPContingencyTableQ, QuAcc): pred_labels = self.h.predict(val.X) true_labels = val.y - n = val.n_classes - classes_dot = np.arange(n + 1) + self.ncl = val.n_classes + classes_dot = np.arange(self.ncl + 1) # ct_class_idx = classes_dot.reshape(n, n) - ct_class_idx = np.full((n, n), n) - ct_class_idx[np.diag_indices(n)] = np.arange(n) + ct_class_idx = np.full((self.ncl, self.ncl), self.ncl) + ct_class_idx[np.diag_indices(self.ncl)] = np.arange(self.ncl) X_dot = self._get_X_dot(val.X) y_dot = ct_class_idx[true_labels, pred_labels] @@ -364,7 +408,7 @@ class QuAcc1xNp1(CAPContingencyTableQ, QuAcc): def predict_ct(self, X: LabelledCollection, oracle_prev=None): X_dot = self._get_X_dot(X) ct_compressed = self.q.quantify(X_dot) - return self._get_ct_hat(X.n_classes, ct_compressed) + return self._get_ct_hat(self.ncl, ct_compressed) class QuAccNxN(CAPContingencyTableQ, QuAcc):