reworked data model
This commit is contained in:
parent
c9df56329a
commit
e5f631d4bc
220
quacc/data.py
220
quacc/data.py
|
@ -1,11 +1,9 @@
|
|||
import math
|
||||
from typing import List, Optional
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
from quapy.data import LabelledCollection
|
||||
|
||||
|
||||
# Extended classes
|
||||
#
|
||||
# 0 ~ True 0
|
||||
|
@ -20,32 +18,54 @@ from quapy.data import LabelledCollection
|
|||
# | False 0 | True 1 |
|
||||
# |__________|__________|
|
||||
#
|
||||
class ExClassManager:
|
||||
@staticmethod
|
||||
def get_ex(n_classes: int, true_class: int, pred_class: int) -> int:
|
||||
return true_class * n_classes + pred_class
|
||||
|
||||
@staticmethod
|
||||
def get_pred(n_classes: int, ex_class: int) -> int:
|
||||
return ex_class % n_classes
|
||||
|
||||
@staticmethod
|
||||
def get_true(n_classes: int, ex_class: int) -> int:
|
||||
return ex_class // n_classes
|
||||
|
||||
|
||||
class ExtendedCollection(LabelledCollection):
|
||||
class ExtendedData:
|
||||
def __init__(
|
||||
self,
|
||||
instances: np.ndarray | sp.csr_matrix,
|
||||
labels: np.ndarray,
|
||||
classes: Optional[List] = None,
|
||||
pred_proba: np.ndarray,
|
||||
ext: np.ndarray = None,
|
||||
):
|
||||
super().__init__(instances, labels, classes=classes)
|
||||
self.b_instances_ = instances
|
||||
self.pred_proba_ = pred_proba
|
||||
self.ext_ = ext
|
||||
self.instances = self.__extend_instances(instances, pred_proba, ext=ext)
|
||||
|
||||
def split_by_pred(self):
|
||||
_ncl = int(math.sqrt(self.n_classes))
|
||||
_indexes = ExtendedCollection._split_index_by_pred(_ncl, self.instances)
|
||||
def __extend_instances(
|
||||
self,
|
||||
instances: np.ndarray | sp.csr_matrix,
|
||||
pred_proba: np.ndarray,
|
||||
ext: np.ndarray = None,
|
||||
) -> np.ndarray | sp.csr_matrix:
|
||||
to_append = pred_proba
|
||||
if ext is not None:
|
||||
to_append = np.concatenate([ext, pred_proba], axis=1)
|
||||
|
||||
if isinstance(instances, sp.csr_matrix):
|
||||
_to_append = sp.csr_matrix(to_append)
|
||||
n_x = sp.hstack([instances, _to_append])
|
||||
elif isinstance(instances, np.ndarray):
|
||||
n_x = np.concatenate((instances, to_append), axis=1)
|
||||
else:
|
||||
raise ValueError("Unsupported matrix format")
|
||||
|
||||
return n_x
|
||||
|
||||
@property
|
||||
def X(self):
|
||||
return self.instances
|
||||
|
||||
def __split_index_by_pred(self) -> List[np.ndarray]:
|
||||
_pred_label = np.argmax(self.pred_proba_, axis=0)
|
||||
|
||||
return [
|
||||
(_pred_label == cl).nonzero()[0]
|
||||
for cl in np.arange(self.pred_proba_.shape[0])
|
||||
]
|
||||
|
||||
def split_by_pred(self, return_indexes=False):
|
||||
_indexes = self.__split_index_by_pred()
|
||||
if isinstance(self.instances, np.ndarray):
|
||||
_instances = [
|
||||
self.instances[ind] if ind.shape[0] > 0 else np.asarray([], dtype=int)
|
||||
|
@ -58,93 +78,95 @@ class ExtendedCollection(LabelledCollection):
|
|||
else sp.csr_matrix(np.empty((0, 0), dtype=int))
|
||||
for ind in _indexes
|
||||
]
|
||||
_labels = [
|
||||
np.asarray(
|
||||
[
|
||||
ExClassManager.get_true(_ncl, lbl)
|
||||
for lbl in (self.labels[ind] if len(ind) > 0 else [])
|
||||
],
|
||||
dtype=int,
|
||||
)
|
||||
for ind in _indexes
|
||||
]
|
||||
return [
|
||||
ExtendedCollection(inst, lbl, classes=range(0, _ncl))
|
||||
for (inst, lbl) in zip(_instances, _labels)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def split_inst_by_pred(
|
||||
cls, n_classes: int, instances: np.ndarray | sp.csr_matrix
|
||||
) -> (List[np.ndarray | sp.csr_matrix], List[float]):
|
||||
_indexes = cls._split_index_by_pred(n_classes, instances)
|
||||
if isinstance(instances, np.ndarray):
|
||||
_instances = [
|
||||
instances[ind] if ind.shape[0] > 0 else np.asarray([], dtype=int)
|
||||
for ind in _indexes
|
||||
]
|
||||
elif isinstance(instances, sp.csr_matrix):
|
||||
_instances = [
|
||||
instances[ind]
|
||||
if ind.shape[0] > 0
|
||||
else sp.csr_matrix(np.empty((0, 0), dtype=int))
|
||||
for ind in _indexes
|
||||
]
|
||||
norms = [inst.shape[0] / instances.shape[0] for inst in _instances]
|
||||
return _instances, norms
|
||||
if return_indexes:
|
||||
return _instances, _indexes
|
||||
|
||||
@classmethod
|
||||
def _split_index_by_pred(
|
||||
cls, n_classes: int, instances: np.ndarray | sp.csr_matrix
|
||||
) -> List[np.ndarray]:
|
||||
if isinstance(instances, np.ndarray):
|
||||
_pred_label = [np.argmax(inst[-n_classes:], axis=0) for inst in instances]
|
||||
elif isinstance(instances, sp.csr_matrix):
|
||||
_pred_label = [
|
||||
np.argmax(inst[:, -n_classes:].toarray().flatten(), axis=0)
|
||||
for inst in instances
|
||||
]
|
||||
else:
|
||||
raise ValueError("Unsupported matrix format")
|
||||
return _instances
|
||||
|
||||
return [
|
||||
np.asarray([j for (j, x) in enumerate(_pred_label) if x == i], dtype=int)
|
||||
for i in range(0, n_classes)
|
||||
]
|
||||
def __len__(self):
|
||||
return self.instances.shape[0]
|
||||
|
||||
@classmethod
|
||||
def extend_instances(
|
||||
cls, instances: np.ndarray | sp.csr_matrix, pred_proba: np.ndarray
|
||||
) -> np.ndarray | sp.csr_matrix:
|
||||
if isinstance(instances, sp.csr_matrix):
|
||||
_pred_proba = sp.csr_matrix(pred_proba)
|
||||
n_x = sp.hstack([instances, _pred_proba])
|
||||
elif isinstance(instances, np.ndarray):
|
||||
n_x = np.concatenate((instances, pred_proba), axis=1)
|
||||
else:
|
||||
raise ValueError("Unsupported matrix format")
|
||||
|
||||
return n_x
|
||||
class ExtendedLabels:
|
||||
def __init__(self, true: np.ndarray, pred: np.ndarray, ncl: np.ndarray):
|
||||
self.true = true
|
||||
self.pred = pred
|
||||
self.ncl = ncl
|
||||
|
||||
@classmethod
|
||||
def extend_collection(
|
||||
cls,
|
||||
base: LabelledCollection,
|
||||
pred_proba: np.ndarray,
|
||||
@property
|
||||
def y(self):
|
||||
return self.true * self.ncl + self.pred
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return ExtendedLabels(self.true[idx], self.pred[idx], self.ncl)
|
||||
|
||||
|
||||
class ExtendedCollection(LabelledCollection):
|
||||
def __init__(
|
||||
self,
|
||||
instances: np.ndarray | sp.csr_matrix,
|
||||
labels: np.ndarray,
|
||||
pred_proba: np.ndarray = None,
|
||||
ext: np.ndarray = None,
|
||||
):
|
||||
n_classes = base.n_classes
|
||||
e_data, e_labels, _classes = self.__extend_collection(
|
||||
instances=instances,
|
||||
labels=labels,
|
||||
pred_proba=pred_proba,
|
||||
ext=ext,
|
||||
)
|
||||
self.e_data_ = e_data
|
||||
self.e_labels_ = e_labels
|
||||
super().__init__(e_data.X, e_labels.y, classes=_classes)
|
||||
|
||||
@classmethod
|
||||
def from_lc(
|
||||
cls,
|
||||
lc: LabelledCollection,
|
||||
predict_proba: np.ndarray,
|
||||
ext: np.ndarray = None,
|
||||
):
|
||||
return ExtendedCollection(lc.X, lc.y, pred_proba=predict_proba, ext=ext)
|
||||
|
||||
@property
|
||||
def pred_proba(self):
|
||||
return self.e_data_.pred_proba_
|
||||
|
||||
@property
|
||||
def ext(self):
|
||||
return self.e_data_.ext_
|
||||
|
||||
@property
|
||||
def eX(self):
|
||||
return self.e_data_
|
||||
|
||||
@property
|
||||
def ey(self):
|
||||
return self.e_labels_
|
||||
|
||||
def split_by_pred(self):
|
||||
_ncl = len(self.pred_proba)
|
||||
_instances, _indexes = self.e_data_.split_by_pred(return_indexes=True)
|
||||
_labels = [self.ey[ind] for ind in _indexes]
|
||||
return [
|
||||
LabelledCollection(inst, lbl.true, classes=range(0, _ncl))
|
||||
for inst, lbl in zip(_instances, _labels)
|
||||
]
|
||||
|
||||
def __extend_collection(
|
||||
self,
|
||||
instances: sp.csr_matrix | np.ndarray,
|
||||
labels: np.ndarray,
|
||||
pred_proba: np.ndarray,
|
||||
ext: np.ndarray = None,
|
||||
) -> Tuple[ExtendedData, ExtendedLabels, np.ndarray]:
|
||||
n_classes = np.unique(labels).shape[0]
|
||||
# n_X = [ X | predicted probs. ]
|
||||
n_x = cls.extend_instances(base.X, pred_proba)
|
||||
e_instances = ExtendedData(instances, pred_proba, ext=ext)
|
||||
|
||||
# n_y = (exptected y, predicted y)
|
||||
pred_proba = pred_proba[:, -n_classes:]
|
||||
preds = np.argmax(pred_proba, axis=-1)
|
||||
n_y = np.asarray(
|
||||
[
|
||||
ExClassManager.get_ex(n_classes, true_class, pred_class)
|
||||
for (true_class, pred_class) in zip(base.y, preds)
|
||||
]
|
||||
)
|
||||
e_labels = ExtendedLabels(labels, preds, n_classes)
|
||||
|
||||
return ExtendedCollection(n_x, n_y, classes=[*range(0, n_classes * n_classes)])
|
||||
return e_instances, e_labels, np.arange(n_classes**2)
|
||||
|
|
|
@ -22,7 +22,7 @@ def evaluate(
|
|||
estim_prevs, true_prevs = [], []
|
||||
for sample in protocol():
|
||||
e_sample = estimator.extend(sample)
|
||||
estim_prev = estimator.estimate(e_sample.X, ext=True)
|
||||
estim_prev = estimator.estimate(e_sample.eX)
|
||||
estim_prevs.append(estim_prev)
|
||||
true_prevs.append(e_sample.prevalence())
|
||||
|
||||
|
|
|
@ -1,15 +1,14 @@
|
|||
import math
|
||||
from abc import abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
from quapy.data import LabelledCollection
|
||||
from quapy.method.aggregative import BaseQuantifier
|
||||
from scipy.sparse import csr_matrix
|
||||
from sklearn.base import BaseEstimator
|
||||
|
||||
from quacc.data import ExtendedCollection
|
||||
from quacc.data import ExtendedCollection, ExtendedData
|
||||
|
||||
|
||||
class BaseAccuracyEstimator(BaseQuantifier):
|
||||
|
@ -17,11 +16,9 @@ class BaseAccuracyEstimator(BaseQuantifier):
|
|||
self,
|
||||
classifier: BaseEstimator,
|
||||
quantifier: BaseQuantifier,
|
||||
confidence=None,
|
||||
):
|
||||
self.__check_classifier(classifier)
|
||||
self.quantifier = quantifier
|
||||
self.confidence = confidence
|
||||
|
||||
def __check_classifier(self, classifier):
|
||||
if not hasattr(classifier, "predict_proba"):
|
||||
|
@ -30,6 +27,45 @@ class BaseAccuracyEstimator(BaseQuantifier):
|
|||
)
|
||||
self.classifier = classifier
|
||||
|
||||
def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection:
|
||||
if pred_proba is None:
|
||||
pred_proba = self.classifier.predict_proba(coll.X)
|
||||
|
||||
return ExtendedCollection.from_lc(coll, pred_proba=pred_proba)
|
||||
|
||||
def _extend_instances(self, instances: np.ndarray | sp.csr_matrix, pred_proba=None):
|
||||
if pred_proba is None:
|
||||
pred_proba = self.classifier.predict_proba(instances)
|
||||
|
||||
return ExtendedData(instances, pred_proba=pred_proba)
|
||||
|
||||
@abstractmethod
|
||||
def fit(self, train: LabelledCollection | ExtendedCollection):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def estimate(self, instances, ext=False) -> np.ndarray:
|
||||
...
|
||||
|
||||
|
||||
class ConfidenceBasedAccuracyEstimator(BaseAccuracyEstimator):
|
||||
def __init__(
|
||||
self,
|
||||
classifier: BaseEstimator,
|
||||
quantifier: BaseQuantifier,
|
||||
confidence=None,
|
||||
):
|
||||
super().__init__(classifier, quantifier)
|
||||
self.__check_confidence(confidence)
|
||||
|
||||
def __check_confidence(self, confidence):
|
||||
if isinstance(confidence, str):
|
||||
self.confidence = [confidence]
|
||||
elif isinstance(confidence, list):
|
||||
self.confidence = confidence
|
||||
else:
|
||||
self.confidence = None
|
||||
|
||||
def __get_confidence(self):
|
||||
def max_conf(probas):
|
||||
_mc = np.max(probas, axis=-1)
|
||||
|
@ -42,47 +78,49 @@ class BaseAccuracyEstimator(BaseQuantifier):
|
|||
return _ent
|
||||
|
||||
if self.confidence is None:
|
||||
return None
|
||||
return []
|
||||
|
||||
__confs = {
|
||||
"max_conf": max_conf,
|
||||
"entropy": entropy,
|
||||
}
|
||||
return __confs.get(self.confidence, None)
|
||||
return [__confs.get(c, None) for c in self.confidence]
|
||||
|
||||
def __get_ext(self, pred_proba):
|
||||
_ext = pred_proba
|
||||
_f_conf = self.__get_confidence()
|
||||
if _f_conf is not None:
|
||||
_confs = _f_conf(pred_proba).reshape((len(pred_proba), 1))
|
||||
_ext = np.concatenate((_confs, pred_proba), axis=1)
|
||||
def __get_ext(self, pred_proba: np.ndarray) -> np.ndarray:
|
||||
__confidence = self.__get_confidence()
|
||||
|
||||
return _ext
|
||||
if __confidence is None or len(__confidence) == 0:
|
||||
return None
|
||||
|
||||
return np.concatenate(
|
||||
[
|
||||
_f_conf(pred_proba).reshape((len(pred_proba), 1))
|
||||
for _f_conf in __confidence
|
||||
if _f_conf is not None
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection:
|
||||
if pred_proba is None:
|
||||
pred_proba = self.classifier.predict_proba(coll.X)
|
||||
|
||||
_ext = self.__get_ext(pred_proba)
|
||||
return ExtendedCollection.extend_collection(coll, pred_proba=_ext)
|
||||
return ExtendedCollection.from_lc(coll, pred_proba=pred_proba, ext=_ext)
|
||||
|
||||
def _extend_instances(self, instances: np.ndarray | csr_matrix, pred_proba=None):
|
||||
def _extend_instances(
|
||||
self,
|
||||
instances: np.ndarray | sp.csr_matrix,
|
||||
pred_proba=None,
|
||||
) -> ExtendedData:
|
||||
if pred_proba is None:
|
||||
pred_proba = self.classifier.predict_proba(instances)
|
||||
|
||||
_ext = self.__get_ext(pred_proba)
|
||||
return ExtendedCollection.extend_instances(instances, _ext)
|
||||
|
||||
@abstractmethod
|
||||
def fit(self, train: LabelledCollection | ExtendedCollection):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def estimate(self, instances, ext=False) -> np.ndarray:
|
||||
...
|
||||
return ExtendedData(instances, pred_proba=pred_proba, ext=_ext)
|
||||
|
||||
|
||||
class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
|
||||
class MultiClassAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
|
||||
def __init__(
|
||||
self,
|
||||
classifier: BaseEstimator,
|
||||
|
@ -103,10 +141,14 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
|
|||
|
||||
return self
|
||||
|
||||
def estimate(self, instances, ext=False) -> np.ndarray:
|
||||
e_inst = instances if ext else self._extend_instances(instances)
|
||||
def estimate(
|
||||
self, instances: ExtendedData | np.ndarray | sp.csr_matrix
|
||||
) -> np.ndarray:
|
||||
e_inst = instances
|
||||
if not isinstance(e_inst, ExtendedData):
|
||||
e_inst = self._extend_instances(instances)
|
||||
|
||||
estim_prev = self.quantifier.quantify(e_inst)
|
||||
estim_prev = self.quantifier.quantify(e_inst.X)
|
||||
return self._check_prevalence_classes(estim_prev, self.quantifier.classes_)
|
||||
|
||||
def _check_prevalence_classes(self, estim_prev, estim_classes) -> np.ndarray:
|
||||
|
@ -117,7 +159,7 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
|
|||
return estim_prev
|
||||
|
||||
|
||||
class BinaryQuantifierAccuracyEstimator(BaseAccuracyEstimator):
|
||||
class BinaryQuantifierAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
|
||||
def __init__(
|
||||
self,
|
||||
classifier: BaseEstimator,
|
||||
|
@ -130,28 +172,30 @@ class BinaryQuantifierAccuracyEstimator(BaseAccuracyEstimator):
|
|||
confidence=confidence,
|
||||
)
|
||||
self.quantifiers = []
|
||||
self.e_trains = []
|
||||
|
||||
def fit(self, train: LabelledCollection | ExtendedCollection):
|
||||
self.e_train = self.extend(train)
|
||||
|
||||
self.n_classes = self.e_train.n_classes
|
||||
self.e_trains = self.e_train.split_by_pred()
|
||||
e_trains = self.e_train.split_by_pred()
|
||||
|
||||
self.quantifiers = []
|
||||
for train in self.e_trains:
|
||||
for train in e_trains:
|
||||
quant = deepcopy(self.quantifier)
|
||||
quant.fit(train)
|
||||
self.quantifiers.append(quant)
|
||||
|
||||
return self
|
||||
|
||||
def estimate(self, instances, ext=False):
|
||||
# TODO: test
|
||||
e_inst = instances if ext else self._extend_instances(instances)
|
||||
def estimate(
|
||||
self, instances: ExtendedData | np.ndarray | sp.csr_matrix
|
||||
) -> np.ndarray:
|
||||
e_inst = instances
|
||||
if not isinstance(e_inst, ExtendedData):
|
||||
e_inst = self._extend_instances(instances)
|
||||
|
||||
_ncl = int(math.sqrt(self.n_classes))
|
||||
s_inst, norms = ExtendedCollection.split_inst_by_pred(_ncl, e_inst)
|
||||
s_inst = e_inst.split_by_pred()
|
||||
norms = [s_i.shape[0] / len(e_inst) for s_i in s_inst]
|
||||
estim_prevs = self._quantify_helper(s_inst, norms)
|
||||
|
||||
estim_prev = np.array([prev_row for prev_row in zip(*estim_prevs)]).flatten()
|
||||
|
@ -159,7 +203,7 @@ class BinaryQuantifierAccuracyEstimator(BaseAccuracyEstimator):
|
|||
|
||||
def _quantify_helper(
|
||||
self,
|
||||
s_inst: List[np.ndarray | csr_matrix],
|
||||
s_inst: List[np.ndarray | sp.csr_matrix],
|
||||
norms: List[float],
|
||||
):
|
||||
estim_prevs = []
|
||||
|
|
|
@ -2,8 +2,8 @@ import itertools
|
|||
from copy import deepcopy
|
||||
from time import time
|
||||
from typing import Callable, Union
|
||||
import numpy as np
|
||||
|
||||
import numpy as np
|
||||
import quapy as qp
|
||||
from quapy.data import LabelledCollection
|
||||
from quapy.model_selection import GridSearchQ
|
||||
|
@ -12,7 +12,7 @@ from sklearn.base import BaseEstimator
|
|||
|
||||
import quacc as qc
|
||||
import quacc.error
|
||||
from quacc.data import ExtendedCollection
|
||||
from quacc.data import ExtendedCollection, ExtendedData
|
||||
from quacc.evaluation import evaluate
|
||||
from quacc.logger import SubLogger
|
||||
from quacc.method.base import (
|
||||
|
@ -182,7 +182,7 @@ class GridSearchAE(BaseAccuracyEstimator):
|
|||
assert hasattr(self, "best_model_"), "quantify called before fit"
|
||||
return self.best_model().extend(coll, pred_proba=pred_proba)
|
||||
|
||||
def estimate(self, instances, ext=False):
|
||||
def estimate(self, instances):
|
||||
"""Estimate class prevalence values using the best model found after calling the :meth:`fit` method.
|
||||
|
||||
:param instances: sample contanining the instances
|
||||
|
@ -191,7 +191,7 @@ class GridSearchAE(BaseAccuracyEstimator):
|
|||
"""
|
||||
|
||||
assert hasattr(self, "best_model_"), "estimate called before fit"
|
||||
return self.best_model().estimate(instances, ext=ext)
|
||||
return self.best_model().estimate(instances)
|
||||
|
||||
def set_params(self, **parameters):
|
||||
"""Sets the hyper-parameters to explore.
|
||||
|
@ -220,7 +220,6 @@ class GridSearchAE(BaseAccuracyEstimator):
|
|||
raise ValueError("best_model called before fit")
|
||||
|
||||
|
||||
|
||||
class MCAEgsq(MultiClassAccuracyEstimator):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -257,10 +256,15 @@ class MCAEgsq(MultiClassAccuracyEstimator):
|
|||
|
||||
return self
|
||||
|
||||
def estimate(self, instances, ext=False) -> np.ndarray:
|
||||
e_inst = instances if ext else self._extend_instances(instances)
|
||||
estim_prev = self.quantifier.quantify(e_inst)
|
||||
return self._check_prevalence_classes(estim_prev, self.quantifier.best_model().classes_)
|
||||
def estimate(self, instances) -> np.ndarray:
|
||||
e_inst = instances
|
||||
if not isinstance(e_inst, ExtendedData):
|
||||
e_inst = self._extend_instances(instances)
|
||||
|
||||
estim_prev = self.quantifier.quantify(e_inst.X)
|
||||
return self._check_prevalence_classes(
|
||||
estim_prev, self.quantifier.best_model().classes_
|
||||
)
|
||||
|
||||
|
||||
class BQAEgsq(BinaryQuantifierAccuracyEstimator):
|
||||
|
|
|
@ -1,48 +1,8 @@
|
|||
import pytest
|
||||
from quacc.data import ExClassManager as ECM, ExtendedCollection
|
||||
import numpy as np
|
||||
import pytest
|
||||
import scipy.sparse as sp
|
||||
|
||||
|
||||
class TestExClassManager:
|
||||
@pytest.mark.parametrize(
|
||||
"true_class,pred_class,result",
|
||||
[
|
||||
(0, 0, 0),
|
||||
(0, 1, 1),
|
||||
(1, 0, 2),
|
||||
(1, 1, 3),
|
||||
],
|
||||
)
|
||||
def test_get_ex(self, true_class, pred_class, result):
|
||||
ncl = 2
|
||||
assert ECM.get_ex(ncl, true_class, pred_class) == result
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ex_class,result",
|
||||
[
|
||||
(0, 0),
|
||||
(1, 1),
|
||||
(2, 0),
|
||||
(3, 1),
|
||||
],
|
||||
)
|
||||
def test_get_pred(self, ex_class, result):
|
||||
ncl = 2
|
||||
assert ECM.get_pred(ncl, ex_class) == result
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ex_class,result",
|
||||
[
|
||||
(0, 0),
|
||||
(1, 0),
|
||||
(2, 1),
|
||||
(3, 1),
|
||||
],
|
||||
)
|
||||
def test_get_true(self, ex_class, result):
|
||||
ncl = 2
|
||||
assert ECM.get_true(ncl, ex_class) == result
|
||||
from quacc.data import ExtendedCollection
|
||||
|
||||
|
||||
class TestExtendedCollection:
|
||||
|
|
Loading…
Reference in New Issue