data refactored, added ExtendedPrev
This commit is contained in:
parent
f1262b177f
commit
93de77b3d2
|
@ -24,6 +24,31 @@ class ExtensionPolicy:
|
|||
def __init__(self, collapse_false=False):
|
||||
self.collapse_false = collapse_false
|
||||
|
||||
def qclasses(self, nbcl):
|
||||
if self.collapse_false:
|
||||
return np.arange(nbcl + 1)
|
||||
else:
|
||||
return np.arange(nbcl**2)
|
||||
|
||||
def eclasses(self, nbcl):
|
||||
return np.arange(nbcl**2)
|
||||
|
||||
def matrix_idx(self, nbcl):
|
||||
if self.collapse_false:
|
||||
_idxs = np.array([[i, i] for i in range(nbcl)] + [[0, 1]]).T
|
||||
return tuple(_idxs)
|
||||
else:
|
||||
_idxs = np.indices((nbcl, nbcl))
|
||||
return _idxs[0].flatten(), _idxs[1].flatten()
|
||||
|
||||
def ext_lbl(self, nbcl):
|
||||
if self.collapse_false:
|
||||
return np.vectorize(
|
||||
lambda t, p: t if t == p else nbcl, signature="(),()->()"
|
||||
)
|
||||
else:
|
||||
return np.vectorize(lambda t, p: t * nbcl + p, signature="(),()->()")
|
||||
|
||||
|
||||
class ExtendedData:
|
||||
def __init__(
|
||||
|
@ -98,30 +123,58 @@ class ExtendedLabels:
|
|||
self,
|
||||
true: np.ndarray,
|
||||
pred: np.ndarray,
|
||||
ncl: np.ndarray,
|
||||
nbcl: np.ndarray,
|
||||
extpol: ExtensionPolicy = None,
|
||||
):
|
||||
self.extpol = ExtensionPolicy() if extpol is None else extpol
|
||||
self.true = true
|
||||
self.pred = pred
|
||||
self.ncl = ncl
|
||||
self.nbcl = nbcl
|
||||
|
||||
@property
|
||||
def y(self):
|
||||
if self.extpol.collapse_false:
|
||||
return self.true + self.pred
|
||||
else:
|
||||
return self.true * self.ncl + self.pred
|
||||
return self.extpol.ext_lbl(self.nbcl)(self.true, self.pred)
|
||||
|
||||
@property
|
||||
def classes(self):
|
||||
if self.extpol.collapse_false:
|
||||
return np.arange(self.ncl + 1)
|
||||
else:
|
||||
return np.arange(self.ncl**2)
|
||||
return self.extpol.qclasses(self.nbcl)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return ExtendedLabels(self.true[idx], self.pred[idx], self.ncl)
|
||||
return ExtendedLabels(self.true[idx], self.pred[idx], self.nbcl)
|
||||
|
||||
|
||||
class ExtendedPrev:
|
||||
def __init__(
|
||||
self,
|
||||
flat: np.ndarray,
|
||||
nbcl: int,
|
||||
q_classes: list,
|
||||
extpol: ExtensionPolicy,
|
||||
):
|
||||
self.flat = flat
|
||||
self.nbcl = nbcl
|
||||
self.extpol = ExtensionPolicy() if extpol is None else extpol
|
||||
self.__check_q_classes(q_classes)
|
||||
self._matrix = self.__build_matrix()
|
||||
|
||||
def __check_q_classes(self, q_classes):
|
||||
q_classes = np.array(q_classes)
|
||||
_flat = np.zeros(self.extpol.qclasses(self.nbcl).shape)
|
||||
_flat[q_classes] = self.flat
|
||||
self.flat = _flat
|
||||
|
||||
def __build_matrix(self):
|
||||
_matrix = np.zeros((self.nbcl, self.nbcl))
|
||||
_matrix[self.extpol.matrix_idx(self.nbcl)] = self.flat
|
||||
return _matrix
|
||||
|
||||
@property
|
||||
def A(self):
|
||||
return self._matrix
|
||||
|
||||
@property
|
||||
def classes(self):
|
||||
return self.extpol.qclasses(self.nbcl)
|
||||
|
||||
|
||||
class ExtendedCollection(LabelledCollection):
|
||||
|
@ -172,6 +225,14 @@ class ExtendedCollection(LabelledCollection):
|
|||
def ey(self):
|
||||
return self.e_labels_
|
||||
|
||||
@property
|
||||
def n_base_classes(self):
|
||||
return self.e_labels_.nbcl
|
||||
|
||||
@property
|
||||
def n_classes(self):
|
||||
return len(self.e_labels_.classes)
|
||||
|
||||
def counts(self):
|
||||
_counts = super().counts()
|
||||
if self.extpol.collapse_false:
|
||||
|
|
Loading…
Reference in New Issue