1xn2 cont_table output fixed
This commit is contained in:
parent
5cfd5d87dd
commit
8a087e3e2f
|
@ -4,7 +4,7 @@ from copy import deepcopy
|
|||
import numpy as np
|
||||
import quapy.functional as F
|
||||
import scipy
|
||||
from quapy.data.base import LabelledCollection
|
||||
from quapy.data.base import LabelledCollection as LC
|
||||
from quapy.method.aggregative import AggregativeQuantifier
|
||||
from quapy.method.base import BaseQuantifier
|
||||
from scipy.sparse import csr_matrix, issparse
|
||||
|
@ -15,6 +15,49 @@ from quacc.models.base import ClassifierAccuracyPrediction
|
|||
from quacc.models.utils import get_posteriors_from_h, max_conf, neg_entropy
|
||||
|
||||
|
||||
class LabelledCollection(LC):
|
||||
def empty_classes(self):
|
||||
"""
|
||||
Returns a np.ndarray of empty classes (classes present in self.classes_ but with
|
||||
no positive instance). In case there is none, then an empty np.ndarray is returned
|
||||
|
||||
:return: np.ndarray
|
||||
"""
|
||||
idx = np.argwhere(self.counts() == 0).flatten()
|
||||
return self.classes_[idx]
|
||||
|
||||
def non_empty_classes(self):
|
||||
"""
|
||||
Returns a np.ndarray of non-empty classes (classes present in self.classes_ but with
|
||||
at least one positive instance). In case there is none, then an empty np.ndarray is returned
|
||||
|
||||
:return: np.ndarray
|
||||
"""
|
||||
idx = np.argwhere(self.counts() > 0).flatten()
|
||||
return self.classes_[idx]
|
||||
|
||||
def has_empty_classes(self):
|
||||
"""
|
||||
Checks whether the collection has empty classes
|
||||
|
||||
:return: boolean
|
||||
"""
|
||||
return len(self.empty_classes()) > 0
|
||||
|
||||
def compact_classes(self):
|
||||
"""
|
||||
Generates a new LabelledCollection object with no empty classes. It also returns a np.ndarray of
|
||||
indexes that correspond to the old indexes of the new self.classes_.
|
||||
|
||||
:return: (LabelledCollection, np.ndarray,)
|
||||
"""
|
||||
non_empty = self.non_empty_classes()
|
||||
all_classes = self.classes_
|
||||
old_pos = np.searchsorted(all_classes, non_empty)
|
||||
non_empty_collection = LabelledCollection(*self.Xy, classes=non_empty)
|
||||
return non_empty_collection, old_pos
|
||||
|
||||
|
||||
class CAPContingencyTable(ClassifierAccuracyPrediction):
|
||||
def __init__(self, h: BaseEstimator, acc: callable):
|
||||
self.h = h
|
||||
|
@ -304,9 +347,9 @@ class QuAcc1xN2(CAPContingencyTableQ, QuAcc):
|
|||
pred_labels = self.h.predict(val.X)
|
||||
true_labels = val.y
|
||||
|
||||
n = val.n_classes
|
||||
classes_dot = np.arange(n**2)
|
||||
ct_class_idx = classes_dot.reshape(n, n)
|
||||
self.ncl = val.n_classes
|
||||
classes_dot = np.arange(self.ncl**2)
|
||||
ct_class_idx = classes_dot.reshape(self.ncl, self.ncl)
|
||||
|
||||
X_dot = self._get_X_dot(val.X)
|
||||
y_dot = ct_class_idx[true_labels, pred_labels]
|
||||
|
@ -315,7 +358,8 @@ class QuAcc1xN2(CAPContingencyTableQ, QuAcc):
|
|||
|
||||
def predict_ct(self, X, oracle_prev=None):
|
||||
X_dot = self._get_X_dot(X)
|
||||
return self.q.quantify(X_dot)
|
||||
flat_ct = self.q.quantify(X_dot)
|
||||
return flat_ct.reshape(self.ncl, self.ncl)
|
||||
|
||||
|
||||
class QuAcc1xNp1(CAPContingencyTableQ, QuAcc):
|
||||
|
@ -343,11 +387,11 @@ class QuAcc1xNp1(CAPContingencyTableQ, QuAcc):
|
|||
pred_labels = self.h.predict(val.X)
|
||||
true_labels = val.y
|
||||
|
||||
n = val.n_classes
|
||||
classes_dot = np.arange(n + 1)
|
||||
self.ncl = val.n_classes
|
||||
classes_dot = np.arange(self.ncl + 1)
|
||||
# ct_class_idx = classes_dot.reshape(n, n)
|
||||
ct_class_idx = np.full((n, n), n)
|
||||
ct_class_idx[np.diag_indices(n)] = np.arange(n)
|
||||
ct_class_idx = np.full((self.ncl, self.ncl), self.ncl)
|
||||
ct_class_idx[np.diag_indices(self.ncl)] = np.arange(self.ncl)
|
||||
|
||||
X_dot = self._get_X_dot(val.X)
|
||||
y_dot = ct_class_idx[true_labels, pred_labels]
|
||||
|
@ -364,7 +408,7 @@ class QuAcc1xNp1(CAPContingencyTableQ, QuAcc):
|
|||
def predict_ct(self, X: LabelledCollection, oracle_prev=None):
|
||||
X_dot = self._get_X_dot(X)
|
||||
ct_compressed = self.q.quantify(X_dot)
|
||||
return self._get_ct_hat(X.n_classes, ct_compressed)
|
||||
return self._get_ct_hat(self.ncl, ct_compressed)
|
||||
|
||||
|
||||
class QuAccNxN(CAPContingencyTableQ, QuAcc):
|
||||
|
|
Loading…
Reference in New Issue