1xn2 cont_table output fixed

This commit is contained in:
Lorenzo Volpi 2024-04-08 17:58:34 +02:00
parent 5cfd5d87dd
commit 8a087e3e2f
1 changed files with 54 additions and 10 deletions

View File

@ -4,7 +4,7 @@ from copy import deepcopy
import numpy as np
import quapy.functional as F
import scipy
from quapy.data.base import LabelledCollection
from quapy.data.base import LabelledCollection as LC
from quapy.method.aggregative import AggregativeQuantifier
from quapy.method.base import BaseQuantifier
from scipy.sparse import csr_matrix, issparse
@ -15,6 +15,49 @@ from quacc.models.base import ClassifierAccuracyPrediction
from quacc.models.utils import get_posteriors_from_h, max_conf, neg_entropy
class LabelledCollection(LC):
def empty_classes(self):
"""
Returns a np.ndarray of empty classes (classes present in self.classes_ but with
no positive instance). In case there is none, then an empty np.ndarray is returned
:return: np.ndarray
"""
idx = np.argwhere(self.counts() == 0).flatten()
return self.classes_[idx]
def non_empty_classes(self):
"""
Returns a np.ndarray of non-empty classes (classes present in self.classes_ but with
at least one positive instance). In case there is none, then an empty np.ndarray is returned
:return: np.ndarray
"""
idx = np.argwhere(self.counts() > 0).flatten()
return self.classes_[idx]
def has_empty_classes(self):
"""
Checks whether the collection has empty classes
:return: boolean
"""
return len(self.empty_classes()) > 0
def compact_classes(self):
"""
Generates a new LabelledCollection object with no empty classes. It also returns a np.ndarray of
indexes that correspond to the old indexes of the new self.classes_.
:return: (LabelledCollection, np.ndarray,)
"""
non_empty = self.non_empty_classes()
all_classes = self.classes_
old_pos = np.searchsorted(all_classes, non_empty)
non_empty_collection = LabelledCollection(*self.Xy, classes=non_empty)
return non_empty_collection, old_pos
class CAPContingencyTable(ClassifierAccuracyPrediction):
def __init__(self, h: BaseEstimator, acc: callable):
self.h = h
@ -304,9 +347,9 @@ class QuAcc1xN2(CAPContingencyTableQ, QuAcc):
pred_labels = self.h.predict(val.X)
true_labels = val.y
n = val.n_classes
classes_dot = np.arange(n**2)
ct_class_idx = classes_dot.reshape(n, n)
self.ncl = val.n_classes
classes_dot = np.arange(self.ncl**2)
ct_class_idx = classes_dot.reshape(self.ncl, self.ncl)
X_dot = self._get_X_dot(val.X)
y_dot = ct_class_idx[true_labels, pred_labels]
@ -315,7 +358,8 @@ class QuAcc1xN2(CAPContingencyTableQ, QuAcc):
def predict_ct(self, X, oracle_prev=None):
X_dot = self._get_X_dot(X)
return self.q.quantify(X_dot)
flat_ct = self.q.quantify(X_dot)
return flat_ct.reshape(self.ncl, self.ncl)
class QuAcc1xNp1(CAPContingencyTableQ, QuAcc):
@ -343,11 +387,11 @@ class QuAcc1xNp1(CAPContingencyTableQ, QuAcc):
pred_labels = self.h.predict(val.X)
true_labels = val.y
n = val.n_classes
classes_dot = np.arange(n + 1)
self.ncl = val.n_classes
classes_dot = np.arange(self.ncl + 1)
# ct_class_idx = classes_dot.reshape(n, n)
ct_class_idx = np.full((n, n), n)
ct_class_idx[np.diag_indices(n)] = np.arange(n)
ct_class_idx = np.full((self.ncl, self.ncl), self.ncl)
ct_class_idx[np.diag_indices(self.ncl)] = np.arange(self.ncl)
X_dot = self._get_X_dot(val.X)
y_dot = ct_class_idx[true_labels, pred_labels]
@ -364,7 +408,7 @@ class QuAcc1xNp1(CAPContingencyTableQ, QuAcc):
def predict_ct(self, X: LabelledCollection, oracle_prev=None):
X_dot = self._get_X_dot(X)
ct_compressed = self.q.quantify(X_dot)
return self._get_ct_hat(X.n_classes, ct_compressed)
return self._get_ct_hat(self.ncl, ct_compressed)
class QuAccNxN(CAPContingencyTableQ, QuAcc):