data refactored, added ExtendedPrev

This commit is contained in:
Lorenzo Volpi 2023-12-08 00:44:44 +01:00
parent f1262b177f
commit 93de77b3d2
1 changed files with 72 additions and 11 deletions

View File

@ -24,6 +24,31 @@ class ExtensionPolicy:
def __init__(self, collapse_false=False):
self.collapse_false = collapse_false
def qclasses(self, nbcl):
if self.collapse_false:
return np.arange(nbcl + 1)
else:
return np.arange(nbcl**2)
def eclasses(self, nbcl):
return np.arange(nbcl**2)
def matrix_idx(self, nbcl):
if self.collapse_false:
_idxs = np.array([[i, i] for i in range(nbcl)] + [[0, 1]]).T
return tuple(_idxs)
else:
_idxs = np.indices((nbcl, nbcl))
return _idxs[0].flatten(), _idxs[1].flatten()
def ext_lbl(self, nbcl):
if self.collapse_false:
return np.vectorize(
lambda t, p: t if t == p else nbcl, signature="(),()->()"
)
else:
return np.vectorize(lambda t, p: t * nbcl + p, signature="(),()->()")
class ExtendedData:
def __init__(
@ -98,30 +123,58 @@ class ExtendedLabels:
self,
true: np.ndarray,
pred: np.ndarray,
ncl: np.ndarray,
nbcl: np.ndarray,
extpol: ExtensionPolicy = None,
):
self.extpol = ExtensionPolicy() if extpol is None else extpol
self.true = true
self.pred = pred
self.ncl = ncl
self.nbcl = nbcl
@property
def y(self):
if self.extpol.collapse_false:
return self.true + self.pred
else:
return self.true * self.ncl + self.pred
return self.extpol.ext_lbl(self.nbcl)(self.true, self.pred)
@property
def classes(self):
if self.extpol.collapse_false:
return np.arange(self.ncl + 1)
else:
return np.arange(self.ncl**2)
return self.extpol.qclasses(self.nbcl)
def __getitem__(self, idx):
return ExtendedLabels(self.true[idx], self.pred[idx], self.ncl)
return ExtendedLabels(self.true[idx], self.pred[idx], self.nbcl)
class ExtendedPrev:
def __init__(
self,
flat: np.ndarray,
nbcl: int,
q_classes: list,
extpol: ExtensionPolicy,
):
self.flat = flat
self.nbcl = nbcl
self.extpol = ExtensionPolicy() if extpol is None else extpol
self.__check_q_classes(q_classes)
self._matrix = self.__build_matrix()
def __check_q_classes(self, q_classes):
q_classes = np.array(q_classes)
_flat = np.zeros(self.extpol.qclasses(self.nbcl).shape)
_flat[q_classes] = self.flat
self.flat = _flat
def __build_matrix(self):
_matrix = np.zeros((self.nbcl, self.nbcl))
_matrix[self.extpol.matrix_idx(self.nbcl)] = self.flat
return _matrix
@property
def A(self):
return self._matrix
@property
def classes(self):
return self.extpol.qclasses(self.nbcl)
class ExtendedCollection(LabelledCollection):
@ -172,6 +225,14 @@ class ExtendedCollection(LabelledCollection):
def ey(self):
return self.e_labels_
@property
def n_base_classes(self):
return self.e_labels_.nbcl
@property
def n_classes(self):
return len(self.e_labels_.classes)
def counts(self):
_counts = super().counts()
if self.extpol.collapse_false: