refactored for multiclass

This commit is contained in:
Lorenzo Volpi 2023-12-21 16:47:35 +01:00
parent 5d0ecfda39
commit d783989ebc
13 changed files with 925 additions and 358 deletions

View File

@ -20,34 +20,91 @@ from quapy.data import LabelledCollection
# #
def _split_index_by_pred(pred_proba: np.ndarray) -> List[np.ndarray]:
    """Group row indexes by the column holding the highest score.

    :param pred_proba: 2-D array; the number of groups is taken from
        ``pred_proba.shape[1]`` and the group of a row from its argmax
        along axis 1.
    :return: a list with one index array per column; entry ``cl`` holds the
        indexes of the rows whose argmax is ``cl`` and is empty when no row
        selects that column.
    """
    _pred_label = np.argmax(pred_proba, axis=1)
    return [(_pred_label == cl).nonzero()[0] for cl in np.arange(pred_proba.shape[1])]
class ExtensionPolicy: class ExtensionPolicy:
def __init__(self, collapse_false=False): def __init__(self, collapse_false=False, group_false=False, dense=False):
self.collapse_false = collapse_false self.collapse_false = collapse_false
self.group_false = group_false
self.dense = dense
def qclasses(self, nbcl): def qclasses(self, nbcl):
if self.collapse_false: if self.collapse_false:
return np.arange(nbcl + 1) return np.arange(nbcl + 1)
else: elif self.group_false:
return np.arange(nbcl**2) return np.arange(nbcl * 2)
return np.arange(nbcl**2)
def eclasses(self, nbcl): def eclasses(self, nbcl):
return np.arange(nbcl**2) return np.arange(nbcl**2)
def tfp_classes(self, nbcl):
    """Return the label set used when labels are split by predicted class.

    :param nbcl: number of base classes.
    :return: ``np.arange(2)`` when ``group_false`` is set, otherwise
        ``np.arange(nbcl)``.
    """
    if self.group_false:
        return np.arange(2)
    else:
        return np.arange(nbcl)
def matrix_idx(self, nbcl): def matrix_idx(self, nbcl):
if self.collapse_false: if self.collapse_false:
_idxs = np.array([[i, i] for i in range(nbcl)] + [[0, 1]]).T _idxs = np.array([[i, i] for i in range(nbcl)] + [[0, 1]]).T
return tuple(_idxs) return tuple(_idxs)
elif self.group_false:
diag_idxs = np.diag_indices(nbcl)
sub_diag_idxs = tuple(
np.array([((i + 1) % nbcl, i) for i in range(nbcl)]).T
)
return tuple(np.concatenate(axis) for axis in zip(diag_idxs, sub_diag_idxs))
# def mask_fn(m, k):
# n = m.shape[0]
# d = np.diag(np.tile(1, n))
# d[tuple(zip(*[(i, (i + 1) % n) for i in range(n)]))] = 1
# return d
# _mi = np.mask_indices(nbcl, mask_func=mask_fn)
# print(_mi)
# return _mi
else: else:
_idxs = np.indices((nbcl, nbcl)) _idxs = np.indices((nbcl, nbcl))
return _idxs[0].flatten(), _idxs[1].flatten() return _idxs[0].flatten(), _idxs[1].flatten()
def ext_lbl(self, nbcl): def ext_lbl(self, nbcl):
if self.collapse_false: if self.collapse_false:
return np.vectorize(
lambda t, p: t if t == p else nbcl, signature="(),()->()" def cf_fun(t, p):
) return t if t == p else nbcl
return np.vectorize(cf_fun, signature="(),()->()")
elif self.group_false:
def gf_fun(t, p):
# if t < nbcl - 1:
# return t * 2 if t == p else (t * 2) + 1
# else:
# return t * 2 if t != p else (t * 2) + 1
return p if t == p else nbcl + p
return np.vectorize(gf_fun, signature="(),()->()")
else: else:
return np.vectorize(lambda t, p: t * nbcl + p, signature="(),()->()")
def default_fn(t, p):
return t * nbcl + p
return np.vectorize(default_fn, signature="(),()->()")
def true_lbl_from_pred(self, nbcl):
    """Return a vectorized ``(true, pred) -> label`` function.

    With ``group_false`` set, the produced label is 0 where the true label
    equals the predicted one and 1 otherwise; without it the true label is
    passed through unchanged.

    :param nbcl: number of base classes (not used by either branch).
    :return: an ``np.vectorize`` object taking two label arrays.
    """
    if self.group_false:
        # otypes is required for size-0 inputs: with a signature and no
        # otypes, np.vectorize raises ValueError when it never gets to call
        # the wrapped function. The outputs are Python ints either way.
        return np.vectorize(
            lambda t, p: 0 if t == p else 1,
            signature="(),()->()",
            otypes=[int],
        )
    else:
        # NOTE(review): the output dtype follows the input labels here, so no
        # otypes is forced; callers must not pass size-0 arrays to this one.
        return np.vectorize(lambda t, p: t, signature="(),()->()")
def can_f1(self, nbcl):
    """Return True when ``nbcl`` is 2 or neither false-grouping flag is set.

    :param nbcl: number of base classes.
    :return: ``True`` if ``nbcl == 2``, or if both ``collapse_false`` and
        ``group_false`` are disabled; ``False`` otherwise.
    """
    return nbcl == 2 or not (self.collapse_false or self.group_false)
class ExtendedData: class ExtendedData:
@ -75,10 +132,13 @@ class ExtendedData:
to_append = pred_proba to_append = pred_proba
if isinstance(instances, sp.csr_matrix): if isinstance(instances, sp.csr_matrix):
_to_append = sp.csr_matrix(to_append) if self.extpol.dense:
n_x = sp.hstack([instances, _to_append]) n_x = to_append
else:
n_x = sp.hstack([instances, sp.csr_matrix(to_append)], format="csr")
elif isinstance(instances, np.ndarray): elif isinstance(instances, np.ndarray):
n_x = np.concatenate((instances, to_append), axis=1) _concat = [instances, to_append] if not self.extpol.dense else [to_append]
n_x = np.concatenate(_concat, axis=1)
else: else:
raise ValueError("Unsupported matrix format") raise ValueError("Unsupported matrix format")
@ -88,30 +148,25 @@ class ExtendedData:
def X(self): def X(self):
return self.instances return self.instances
def __split_index_by_pred(self) -> List[np.ndarray]: @property
_pred_label = np.argmax(self.pred_proba_, axis=1) def nbcl(self):
return self.pred_proba_.shape[1]
return [ def split_by_pred(self, _indexes: List[np.ndarray] | None = None):
(_pred_label == cl).nonzero()[0]
for cl in np.arange(self.pred_proba_.shape[1])
]
def split_by_pred(self, return_indexes=False):
def _empty_matrix(): def _empty_matrix():
if isinstance(self.instances, np.ndarray): if isinstance(self.instances, np.ndarray):
return np.asarray([], dtype=int) return np.asarray([], dtype=int)
elif isinstance(self.instances, sp.csr_matrix): elif isinstance(self.instances, sp.csr_matrix):
return sp.csr_matrix(np.empty((0, 0), dtype=int)) return sp.csr_matrix(np.empty((0, 0), dtype=int))
_indexes = self.__split_index_by_pred() if _indexes is None:
_indexes = _split_index_by_pred(self.pred_proba_)
_instances = [ _instances = [
self.instances[ind] if ind.shape[0] > 0 else _empty_matrix() self.instances[ind] if ind.shape[0] > 0 else _empty_matrix()
for ind in _indexes for ind in _indexes
] ]
if return_indexes:
return _instances, _indexes
return _instances return _instances
def __len__(self): def __len__(self):
@ -142,41 +197,96 @@ class ExtendedLabels:
def __getitem__(self, idx): def __getitem__(self, idx):
return ExtendedLabels(self.true[idx], self.pred[idx], self.nbcl) return ExtendedLabels(self.true[idx], self.pred[idx], self.nbcl)
def split_by_pred(self, _indexes: List[np.ndarray]):
    """Split the labels into one group per entry of ``_indexes``.

    :param _indexes: index arrays; every array must select samples that all
        share the same predicted label (checked with an assertion).
    :return: tuple ``(labels, classes)`` where ``labels`` holds, per index
        array, the output of ``extpol.true_lbl_from_pred`` applied to the
        selected true/predicted labels, and ``classes`` is
        ``extpol.tfp_classes(nbcl)``.
    """
    # Build the labelling function once; it does not depend on the group.
    _to_tfp = self.extpol.true_lbl_from_pred(self.nbcl)
    _labels = []
    for ind in _indexes:
        _true, _pred = self.true[ind], self.pred[ind]
        if _pred.shape[0] == 0:
            # An empty group has nothing to label; a signature-based
            # np.vectorize without otypes raises ValueError on size-0 input.
            _labels.append(np.asarray([], dtype=int))
            continue
        assert (_pred == _pred[0]).all(), "index is selecting non uniform class"
        _labels.append(_to_tfp(_true, _pred))

    return _labels, self.extpol.tfp_classes(self.nbcl)
class ExtendedPrev: class ExtendedPrev:
def __init__( def __init__(
self, self,
flat: np.ndarray, flat: np.ndarray,
nbcl: int, nbcl: int,
q_classes: list, extpol: ExtensionPolicy = None,
extpol: ExtensionPolicy,
): ):
self.flat = flat self.flat = flat
self.nbcl = nbcl self.nbcl = nbcl
self.extpol = ExtensionPolicy() if extpol is None else extpol self.extpol = ExtensionPolicy() if extpol is None else extpol
self.__check_q_classes(q_classes) # self._matrix = self.__build_matrix()
self._matrix = self.__build_matrix()
def __check_q_classes(self, q_classes):
q_classes = np.array(q_classes)
_flat = np.zeros(self.extpol.qclasses(self.nbcl).shape)
_flat[q_classes] = self.flat
self.flat = _flat
def __build_matrix(self): def __build_matrix(self):
_matrix = np.zeros((self.nbcl, self.nbcl)) _matrix = np.zeros((self.nbcl, self.nbcl))
_matrix[self.extpol.matrix_idx(self.nbcl)] = self.flat _matrix[self.extpol.matrix_idx(self.nbcl)] = self.flat
return _matrix return _matrix
def can_f1(self):
    """Return ``extpol.can_f1`` evaluated for this object's ``nbcl``."""
    return self.extpol.can_f1(self.nbcl)
@property @property
def A(self): def A(self):
return self._matrix # return self._matrix
return self.__build_matrix()
@property @property
def classes(self): def classes(self):
return self.extpol.qclasses(self.nbcl) return self.extpol.qclasses(self.nbcl)
class ExtMulPrev(ExtendedPrev):
    """Extended prevalence built from a single flat prevalence vector.

    When ``q_classes`` is given, the values of ``flat`` are scattered into a
    zero vector shaped like ``extpol.qclasses(nbcl)`` at the positions listed
    in ``q_classes``; otherwise ``flat`` is stored as received.
    """

    def __init__(
        self,
        flat: np.ndarray,
        nbcl: int,
        q_classes: list | None = None,
        extpol: ExtensionPolicy | None = None,
    ):
        # The base initializer resolves the default policy, which
        # __check_q_classes needs; self.flat is then replaced below.
        super().__init__(flat, nbcl, extpol=extpol)
        self.flat = self.__check_q_classes(q_classes, flat)

    def __check_q_classes(self, q_classes, flat):
        """Expand ``flat`` to the policy's full class set when needed."""
        if q_classes is None:
            return flat
        q_classes = np.array(q_classes)
        _flat = np.zeros(self.extpol.qclasses(self.nbcl).shape)
        _flat[q_classes] = flat
        return _flat
class ExtBinPrev(ExtendedPrev):
    """Extended prevalence built from a list of per-group prevalence arrays.

    When ``q_classes`` is given, each array in ``flat`` is scattered into a
    zero vector shaped like ``extpol.tfp_classes(nbcl)`` at the positions of
    the matching ``q_classes`` entry. The arrays are then stacked, transposed
    and concatenated into the 1-D vector stored in ``self.flat``.
    """

    def __init__(
        self,
        flat: List[np.ndarray],
        nbcl: int,
        q_classes: List[List[int]] | None = None,
        extpol: ExtensionPolicy | None = None,
    ):
        # The base initializer resolves the default policy, which
        # __check_q_classes needs; self.flat is then replaced below.
        super().__init__(flat, nbcl, extpol=extpol)
        flat = self.__check_q_classes(q_classes, flat)
        self.flat = self.__build_flat(flat)

    def __check_q_classes(self, q_classes, flat):
        """Pad every per-group array to the policy's label set when needed."""
        if q_classes is None:
            return flat
        _flat = []
        for fl, qc in zip(flat, q_classes):
            qc = np.array(qc)
            _fl = np.zeros(self.extpol.tfp_classes(self.nbcl).shape)
            _fl[qc] = fl
            _flat.append(_fl)
        return np.array(_flat)

    def __build_flat(self, flat):
        """Flatten the stacked per-group arrays column by column."""
        # np.asarray: with q_classes=None ``flat`` is still the caller's
        # plain list, and a list has no ``.T`` attribute.
        return np.concatenate(np.asarray(flat).T)
class ExtendedCollection(LabelledCollection): class ExtendedCollection(LabelledCollection):
def __init__( def __init__(
self, self,
@ -233,19 +343,17 @@ class ExtendedCollection(LabelledCollection):
def n_classes(self): def n_classes(self):
return len(self.e_labels_.classes) return len(self.e_labels_.classes)
def counts(self): def e_prevalence(self) -> ExtendedPrev:
_counts = super().counts() _prev = self.prevalence()
if self.extpol.collapse_false: return ExtendedPrev(_prev, self.n_base_classes, extpol=self.extpol)
_counts = np.insert(_counts, 2, 0)
return _counts
def split_by_pred(self): def split_by_pred(self):
_ncl = self.pred_proba.shape[1] _indexes = _split_index_by_pred(self.pred_proba)
_instances, _indexes = self.e_data_.split_by_pred(return_indexes=True) _instances = self.e_data_.split_by_pred(_indexes)
_labels = [self.ey[ind] for ind in _indexes] # _labels = [self.ey[ind] for ind in _indexes]
_labels, _cls = self.e_labels_.split_by_pred(_indexes)
return [ return [
LabelledCollection(inst, lbl.true, classes=range(0, _ncl)) LabelledCollection(inst, lbl, classes=_cls)
for inst, lbl in zip(_instances, _labels) for inst, lbl in zip(_instances, _labels)
] ]

View File

@ -1,3 +1,4 @@
import itertools
import math import math
import os import os
import pickle import pickle
@ -119,35 +120,23 @@ class DatasetSample:
return {"train": self.train_prev, "validation": self.validation_prev} return {"train": self.train_prev, "validation": self.validation_prev}
class Dataset: class DatasetProvider:
def __init__(self, name, n_prevalences=9, prevs=None, target=None): def __spambase(self, **kwargs):
self._name = name
self._target = target
self.prevs = None
self.n_prevs = n_prevalences
if prevs is not None:
prevs = np.unique([p for p in prevs if p > 0.0 and p < 1.0])
if prevs.shape[0] > 0:
self.prevs = np.sort(prevs)
self.n_prevs = self.prevs.shape[0]
def __spambase(self):
return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test
# provare min_df=5 # provare min_df=5
def __imdb(self): def __imdb(self, **kwargs):
return qp.datasets.fetch_reviews("imdb", tfidf=True, min_df=3).train_test return qp.datasets.fetch_reviews("imdb", tfidf=True, min_df=3).train_test
def __rcv1(self): def __rcv1(self, target, **kwargs):
n_train = 23149 n_train = 23149
available_targets = ["CCAT", "GCAT", "MCAT"] available_targets = ["CCAT", "GCAT", "MCAT"]
if self._target is None or self._target not in available_targets: if target is None or target not in available_targets:
raise ValueError(f"Invalid target {self._target}") raise ValueError(f"Invalid target {target}")
dataset = fetch_rcv1() dataset = fetch_rcv1()
target_index = np.where(dataset.target_names == self._target)[0] target_index = np.where(dataset.target_names == target)[0]
all_train_d = dataset.data[:n_train, :] all_train_d = dataset.data[:n_train, :]
test_d = dataset.data[n_train:, :] test_d = dataset.data[n_train:, :]
labels = dataset.target[:, target_index].toarray().flatten() labels = dataset.target[:, target_index].toarray().flatten()
@ -157,14 +146,14 @@ class Dataset:
return all_train, test return all_train, test
def __cifar10(self): def __cifar10(self, target, **kwargs):
dataset = fetch_cifar10() dataset = fetch_cifar10()
available_targets: list = dataset.label_names available_targets: list = dataset.label_names
if self._target is None or self._target not in available_targets: if target is None or self._target not in available_targets:
raise ValueError(f"Invalid target {self._target}") raise ValueError(f"Invalid target {target}")
target_index = available_targets.index(self._target) target_index = available_targets.index(target)
all_train_d = dataset.train.data all_train_d = dataset.train.data
all_train_l = (dataset.train.labels == target_index).astype(int) all_train_l = (dataset.train.labels == target_index).astype(int)
test_d = dataset.test.data test_d = dataset.test.data
@ -174,14 +163,14 @@ class Dataset:
return all_train, test return all_train, test
def __cifar100(self): def __cifar100(self, target, **kwargs):
dataset = fetch_cifar100() dataset = fetch_cifar100()
available_targets: list = dataset.coarse_label_names available_targets: list = dataset.coarse_label_names
if self._target is None or self._target not in available_targets: if target is None or target not in available_targets:
raise ValueError(f"Invalid target {self._target}") raise ValueError(f"Invalid target {target}")
target_index = available_targets.index(self._target) target_index = available_targets.index(target)
all_train_d = dataset.train.data all_train_d = dataset.train.data
all_train_l = (dataset.train.coarse_labels == target_index).astype(int) all_train_l = (dataset.train.coarse_labels == target_index).astype(int)
test_d = dataset.test.data test_d = dataset.test.data
@ -191,68 +180,123 @@ class Dataset:
return all_train, test return all_train, test
def __train_test(self) -> Tuple[LabelledCollection, LabelledCollection]: def __twitter_gasp(self, **kwargs):
return qp.datasets.fetch_twitter("gasp", min_df=3).train_test
def alltrain_test(
self, name: str, target: str | None
) -> Tuple[LabelledCollection, LabelledCollection]:
all_train, test = { all_train, test = {
"spambase": self.__spambase, "spambase": self.__spambase,
"imdb": self.__imdb, "imdb": self.__imdb,
"rcv1": self.__rcv1, "rcv1": self.__rcv1,
"cifar10": self.__cifar10, "cifar10": self.__cifar10,
"cifar100": self.__cifar100, "cifar100": self.__cifar100,
}[self._name]() "twitter_gasp": self.__twitter_gasp,
}[name](target=target)
return all_train, test return all_train, test
def get_raw(self) -> DatasetSample:
all_train, test = self.__train_test()
train, val = all_train.split_stratified( class Dataset(DatasetProvider):
train_prop=TRAIN_VAL_PROP, random_state=env._R_SEED def __init__(self, name, n_prevalences=9, prevs=None, target=None):
) self._name = name
self._target = target
return DatasetSample(train, val, test) self.all_train, self.test = self.alltrain_test(self._name, self._target)
self.__resample_all_train()
def get(self) -> List[DatasetSample]: self.prevs = None
all_train, test = self.__train_test() self._n_prevs = n_prevalences
self.__check_prevs(prevs)
self.prevs = self.__build_prevs()
# resample all_train set to have (0.5, 0.5) prevalence def __resample_all_train(self):
at_positives = np.sum(all_train.y) tr_counts, tr_ncl = self.all_train.counts(), self.all_train.n_classes
all_train = all_train.sampling( _resample_prevs = np.full((tr_ncl,), fill_value=1.0 / tr_ncl)
min(at_positives, len(all_train) - at_positives) * 2, self.all_train = self.all_train.sampling(
0.5, np.min(tr_counts) * tr_ncl,
*_resample_prevs.tolist(),
random_state=env._R_SEED, random_state=env._R_SEED,
) )
# sample prevalences def __check_prevs(self, prevs):
try:
iter(prevs)
except TypeError:
return
if prevs is None or len(prevs) == 0:
return
def is_float_iterable(obj):
try:
it = iter(obj)
return all([isinstance(o, float) for o in it])
except TypeError:
return False
if not all([is_float_iterable(p) for p in prevs]):
return
if not all([len(p) == self.all_train.n_classes for p in prevs]):
return
if not all([sum(p) == 1.0 for p in prevs]):
return
self.prevs = np.unique(prevs, axis=0)
def __build_prevs(self):
if self.prevs is not None: if self.prevs is not None:
prevs = self.prevs return self.prevs
else:
prevs = np.linspace(0.0, 1.0, num=self.n_prevs + 1, endpoint=False)[1:]
at_size = min(math.floor(len(all_train) * 0.5 / p) for p in prevs) dim = self.all_train.n_classes
datasets = [] lspace = np.linspace(0.0, 1.0, num=self._n_prevs + 1, endpoint=False)[1:]
for p in 1.0 - prevs: mesh = np.array(np.meshgrid(*[lspace for _ in range(dim)])).T.reshape(-1, dim)
all_train_sampled = all_train.sampling(at_size, p, random_state=env._R_SEED) mesh = mesh[np.where(mesh.sum(axis=1) == 1.0)]
train, validation = all_train_sampled.split_stratified( return np.around(np.unique(mesh, axis=0), decimals=4)
train_prop=TRAIN_VAL_PROP, random_state=env._R_SEED
)
datasets.append(DatasetSample(train, validation, test))
return datasets def __build_sample(
self,
p: np.ndarray,
at_size: int,
):
all_train_sampled = self.all_train.sampling(
at_size, *(p[:-1]), random_state=env._R_SEED
)
train, validation = all_train_sampled.split_stratified(
train_prop=TRAIN_VAL_PROP, random_state=env._R_SEED
)
return DatasetSample(train, validation, self.test)
def get(self) -> List[DatasetSample]:
    """Build one dataset sample per prevalence vector in ``self.prevs``.

    A single sample size is used for every vector: the minimum, over all
    positive prevalence components ``p``, of
    ``floor(len(all_train) / n_classes / p)``.

    :return: the list produced by ``__build_sample`` for each row of
        ``self.prevs``.
    :raises ValueError: if ``self.prevs`` yields no positive component.
    """
    at_size = min(
        math.floor(len(self.all_train) * (1.0 / self.all_train.n_classes) / p)
        for _prev in self.prevs
        for p in _prev
        # A zero component puts no bound on the sample size; dividing by it
        # yields inf, which math.floor cannot convert.
        if p > 0
    )

    return [self.__build_sample(p, at_size) for p in self.prevs]
def __call__(self): def __call__(self):
return self.get() return self.get()
@property @property
def name(self): def name(self):
match (self._name, self.n_prevs): match (self._name, self._n_prevs):
case (("rcv1" | "cifar10" | "cifar100"), 9): case (("rcv1" | "cifar10" | "cifar100"), 9):
return f"{self._name}_{self._target}" return f"{self._name}_{self._target}"
case (("rcv1" | "cifar10" | "cifar100"), _): case (("rcv1" | "cifar10" | "cifar100"), _):
return f"{self._name}_{self._target}_{self.n_prevs}prevs" return f"{self._name}_{self._target}_{self._n_prevs}prevs"
case (_, 9): case (_, 9):
return f"{self._name}" return f"{self._name}"
case (_, _): case (_, _):
return f"{self._name}_{self.n_prevs}prevs" return f"{self._name}_{self._n_prevs}prevs"
@property
def nprevs(self):
    """Number of prevalence vectors, i.e. the first dimension of ``self.prevs``."""
    return self.prevs.shape[0]
# >>> fetch_rcv1().target_names # >>> fetch_rcv1().target_names

View File

@ -1,4 +1,10 @@
from functools import wraps
from typing import List
import numpy as np import numpy as np
import quapy as qp
from quacc.data import ExtendedPrev
def from_name(err_name): def from_name(err_name):
@ -21,30 +27,54 @@ def from_name(err_name):
# return 2 * (precision * recall) / (precision + recall) # return 2 * (precision * recall) / (precision + recall)
def f1(prev): def nae(prevs: np.ndarray, prevs_hat: np.ndarray) -> np.ndarray:
den = (2 * prev[3]) + prev[1] + prev[2] _ae = qp.error.ae(prevs, prevs_hat)
if den == 0: # _zae = (2.0 * (1.0 - prevs.min())) / prevs.shape[1]
return 0.0 _zae = 2.0 / prevs.shape[1]
return _ae / _zae
def f1(prev: np.ndarray | ExtendedPrev) -> float:
if isinstance(prev, ExtendedPrev):
prev = prev.A
def _score(idx):
_tp = prev[idx, idx]
_fn = prev[idx, :].sum() - _tp
_fp = prev[:, idx].sum() - _tp
_den = 2.0 * _tp + _fp + _fn
return 0.0 if _den == 0.0 else (2.0 * _tp) / _den
if prev.shape[0] == 2:
return _score(1)
else: else:
return (2 * prev[3]) / den _idxs = np.arange(prev.shape[0])
return np.array([_score(idx) for idx in _idxs]).mean()
def f1e(prev): def f1e(prev):
return 1 - f1(prev) return 1 - f1(prev)
def acc(prev: np.ndarray) -> float: def acc(prev: np.ndarray | ExtendedPrev) -> float:
return (prev[0] + prev[3]) / np.sum(prev) if isinstance(prev, ExtendedPrev):
prev = prev.A
return np.diag(prev).sum() / prev.sum()
def accd(true_prevs: np.ndarray, estim_prevs: np.ndarray) -> np.ndarray: def accd(
vacc = np.vectorize(acc, signature="(m)->()") true_prevs: List[np.ndarray | ExtendedPrev],
a_tp = vacc(true_prevs) estim_prevs: List[np.ndarray | ExtendedPrev],
a_ep = vacc(estim_prevs) ) -> np.ndarray:
a_tp = np.array([acc(tp) for tp in true_prevs])
a_ep = np.array([acc(ep) for ep in estim_prevs])
return np.abs(a_tp - a_ep) return np.abs(a_tp - a_ep)
def maccd(true_prevs: np.ndarray, estim_prevs: np.ndarray) -> float: def maccd(
true_prevs: List[np.ndarray | ExtendedPrev],
estim_prevs: List[np.ndarray | ExtendedPrev],
) -> float:
return accd(true_prevs, estim_prevs).mean() return accd(true_prevs, estim_prevs).mean()

View File

@ -1,34 +0,0 @@
from typing import Callable, Union
import numpy as np
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
import quacc as qc
from ..method.base import BaseAccuracyEstimator
def evaluate(
    estimator: BaseAccuracyEstimator,
    protocol: AbstractProtocol,
    error_metric: Union[Callable, str],
) -> float:
    """Score ``estimator`` over every sample generated by ``protocol``.

    :param estimator: object providing ``extend(sample)`` and
        ``estimate(instances)``.
    :param protocol: callable sample generator; its ``collator`` attribute is
        swapped for the "labelled_collection" collator during the run and
        restored afterwards.
    :param error_metric: callable ``(true_prevs, estim_prevs) -> value`` or a
        name resolved through ``qc.error.from_name``.
    :return: the value computed by ``error_metric`` on the stacked true and
        estimated prevalences.
    """
    if isinstance(error_metric, str):
        error_metric = qc.error.from_name(error_metric)

    collator_bck_ = protocol.collator
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")

    estim_prevs, true_prevs = [], []
    try:
        for sample in protocol():
            e_sample = estimator.extend(sample)
            estim_prev = estimator.estimate(e_sample.eX)
            estim_prevs.append(estim_prev)
            true_prevs.append(e_sample.prevalence())
    finally:
        # Restore the caller's collator even when estimation raises, so the
        # protocol object is never left in the swapped state.
        protocol.collator = collator_bck_

    true_prevs = np.array(true_prevs)
    estim_prevs = np.array(estim_prevs)

    return error_metric(true_prevs, estim_prevs)

View File

@ -43,6 +43,7 @@ def kfcv(
predict_method="predict", predict_method="predict",
): ):
c_model_predict = getattr(c_model, predict_method) c_model_predict = getattr(c_model, predict_method)
f1_average = "binary" if validation.n_classes == 2 else "macro"
scoring = ["accuracy", "f1_macro"] scoring = ["accuracy", "f1_macro"]
scores = cross_validate(c_model, validation.X, validation.y, scoring=scoring) scores = cross_validate(c_model, validation.X, validation.y, scoring=scoring)
@ -53,7 +54,9 @@ def kfcv(
for test in protocol(): for test in protocol():
test_preds = c_model_predict(test.X) test_preds = c_model_predict(test.X)
meta_acc = abs(acc_score - metrics.accuracy_score(test.y, test_preds)) meta_acc = abs(acc_score - metrics.accuracy_score(test.y, test_preds))
meta_f1 = abs(f1_score - metrics.f1_score(test.y, test_preds)) meta_f1 = abs(
f1_score - metrics.f1_score(test.y, test_preds, average=f1_average)
)
report.append_row( report.append_row(
test.prevalence(), test.prevalence(),
acc_score=acc_score, acc_score=acc_score,
@ -72,13 +75,15 @@ def ref(
protocol: AbstractStochasticSeededProtocol, protocol: AbstractStochasticSeededProtocol,
): ):
c_model_predict = getattr(c_model, "predict") c_model_predict = getattr(c_model, "predict")
f1_average = "binary" if validation.n_classes == 2 else "macro"
report = EvaluationReport(name="ref") report = EvaluationReport(name="ref")
for test in protocol(): for test in protocol():
test_preds = c_model_predict(test.X) test_preds = c_model_predict(test.X)
report.append_row( report.append_row(
test.prevalence(), test.prevalence(),
acc_score=metrics.accuracy_score(test.y, test_preds), acc_score=metrics.accuracy_score(test.y, test_preds),
f1_score=metrics.f1_score(test.y, test_preds), f1_score=metrics.f1_score(test.y, test_preds, average=f1_average),
) )
return report return report
@ -93,6 +98,7 @@ def atc_mc(
): ):
"""garg""" """garg"""
c_model_predict = getattr(c_model, predict_method) c_model_predict = getattr(c_model, predict_method)
f1_average = "binary" if validation.n_classes == 2 else "macro"
## Load ID validation data probs and labels ## Load ID validation data probs and labels
val_probs, val_labels = c_model_predict(validation.X), validation.y val_probs, val_labels = c_model_predict(validation.X), validation.y
@ -110,8 +116,12 @@ def atc_mc(
test_scores = atc.get_max_conf(test_probs) test_scores = atc.get_max_conf(test_probs)
atc_accuracy = atc.get_ATC_acc(atc_thres, test_scores) atc_accuracy = atc.get_ATC_acc(atc_thres, test_scores)
meta_acc = abs(atc_accuracy - metrics.accuracy_score(test.y, test_preds)) meta_acc = abs(atc_accuracy - metrics.accuracy_score(test.y, test_preds))
f1_score = atc.get_ATC_f1(atc_thres, test_scores, test_probs) f1_score = atc.get_ATC_f1(
meta_f1 = abs(f1_score - metrics.f1_score(test.y, test_preds)) atc_thres, test_scores, test_probs, average=f1_average
)
meta_f1 = abs(
f1_score - metrics.f1_score(test.y, test_preds, average=f1_average)
)
report.append_row( report.append_row(
test.prevalence(), test.prevalence(),
acc=meta_acc, acc=meta_acc,
@ -132,6 +142,7 @@ def atc_ne(
): ):
"""garg""" """garg"""
c_model_predict = getattr(c_model, predict_method) c_model_predict = getattr(c_model, predict_method)
f1_average = "binary" if validation.n_classes == 2 else "macro"
## Load ID validation data probs and labels ## Load ID validation data probs and labels
val_probs, val_labels = c_model_predict(validation.X), validation.y val_probs, val_labels = c_model_predict(validation.X), validation.y
@ -149,8 +160,12 @@ def atc_ne(
test_scores = atc.get_entropy(test_probs) test_scores = atc.get_entropy(test_probs)
atc_accuracy = atc.get_ATC_acc(atc_thres, test_scores) atc_accuracy = atc.get_ATC_acc(atc_thres, test_scores)
meta_acc = abs(atc_accuracy - metrics.accuracy_score(test.y, test_preds)) meta_acc = abs(atc_accuracy - metrics.accuracy_score(test.y, test_preds))
f1_score = atc.get_ATC_f1(atc_thres, test_scores, test_probs) f1_score = atc.get_ATC_f1(
meta_f1 = abs(f1_score - metrics.f1_score(test.y, test_preds)) atc_thres, test_scores, test_probs, average=f1_average
)
meta_f1 = abs(
f1_score - metrics.f1_score(test.y, test_preds, average=f1_average)
)
report.append_row( report.append_row(
test.prevalence(), test.prevalence(),
acc=meta_acc, acc=meta_acc,
@ -170,11 +185,14 @@ def doc(
predict_method="predict_proba", predict_method="predict_proba",
): ):
c_model_predict = getattr(c_model, predict_method) c_model_predict = getattr(c_model, predict_method)
f1_average = "binary" if validation.n_classes == 2 else "macro"
val1, val2 = validation.split_stratified(train_prop=0.5, random_state=env._R_SEED) val1, val2 = validation.split_stratified(train_prop=0.5, random_state=env._R_SEED)
val1_probs = c_model_predict(val1.X) val1_probs = c_model_predict(val1.X)
val1_mc = np.max(val1_probs, axis=-1) val1_mc = np.max(val1_probs, axis=-1)
val1_preds = np.argmax(val1_probs, axis=-1) val1_preds = np.argmax(val1_probs, axis=-1)
val1_acc = metrics.accuracy_score(val1.y, val1_preds) val1_acc = metrics.accuracy_score(val1.y, val1_preds)
val1_f1 = metrics.f1_score(val1.y, val1_preds, average=f1_average)
val2_protocol = APP( val2_protocol = APP(
val2, val2,
n_prevalences=21, n_prevalences=21,
@ -193,26 +211,44 @@ def doc(
val2_prot_y.append(v2.y) val2_prot_y.append(v2.y)
val_scores = np.array([doclib.get_doc(val1_mc, v2_mc) for v2_mc in val2_prot_mc]) val_scores = np.array([doclib.get_doc(val1_mc, v2_mc) for v2_mc in val2_prot_mc])
val_targets = np.array( val_targets_acc = np.array(
[ [
val1_acc - metrics.accuracy_score(v2_y, v2_preds) val1_acc - metrics.accuracy_score(v2_y, v2_preds)
for v2_y, v2_preds in zip(val2_prot_y, val2_prot_preds) for v2_y, v2_preds in zip(val2_prot_y, val2_prot_preds)
] ]
) )
reg = LinearRegression().fit( reg_acc = LinearRegression().fit(val_scores[:, np.newaxis], val_targets_acc)
val_scores.reshape((val_scores.shape[0], 1)), val_targets val_targets_f1 = np.array(
[
val1_f1 - metrics.f1_score(v2_y, v2_preds, average=f1_average)
for v2_y, v2_preds in zip(val2_prot_y, val2_prot_preds)
]
) )
reg_f1 = LinearRegression().fit(val_scores[:, np.newaxis], val_targets_f1)
report = EvaluationReport(name="doc") report = EvaluationReport(name="doc")
for test in protocol(): for test in protocol():
test_probs = c_model_predict(test.X) test_probs = c_model_predict(test.X)
test_preds = np.argmax(test_probs, axis=-1) test_preds = np.argmax(test_probs, axis=-1)
test_mc = np.max(test_probs, axis=-1) test_mc = np.max(test_probs, axis=-1)
score = ( acc_score = (
val1_acc - reg.predict(np.array([[doclib.get_doc(val1_mc, test_mc)]]))[0] val1_acc
- reg_acc.predict(np.array([[doclib.get_doc(val1_mc, test_mc)]]))[0]
)
f1_score = (
val1_f1 - reg_f1.predict(np.array([[doclib.get_doc(val1_mc, test_mc)]]))[0]
)
meta_acc = abs(acc_score - metrics.accuracy_score(test.y, test_preds))
meta_f1 = abs(
f1_score - metrics.f1_score(test.y, test_preds, average=f1_average)
)
report.append_row(
test.prevalence(),
acc=meta_acc,
acc_score=acc_score,
f1=meta_f1,
f1_score=f1_score,
) )
meta_acc = abs(score - metrics.accuracy_score(test.y, test_preds))
report.append_row(test.prevalence(), acc=meta_acc, acc_score=score)
return report return report

View File

@ -2,6 +2,7 @@ import os
import time import time
from traceback import print_exception as traceback from traceback import print_exception as traceback
import numpy as np
import pandas as pd import pandas as pd
import quapy as qp import quapy as qp
from joblib import Parallel, delayed from joblib import Parallel, delayed
@ -76,7 +77,8 @@ def evaluate_comparison(dataset: Dataset, estimators=None) -> DatasetReport:
log.info(f"dataset {dataset.name} [pool size: {__pool_size}]") log.info(f"dataset {dataset.name} [pool size: {__pool_size}]")
for d in dataset(): for d in dataset():
log.info( log.info(
f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} started" f"Dataset sample {np.around(d.train_prev, decimals=2)} "
f"of dataset {dataset.name} started"
) )
par_tasks, seq_tasks = split_tasks( par_tasks, seq_tasks = split_tasks(
CE.func[estimators], CE.func[estimators],
@ -93,7 +95,8 @@ def evaluate_comparison(dataset: Dataset, estimators=None) -> DatasetReport:
g_time = time.time() - tstart g_time = time.time() - tstart
log.info( log.info(
f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} finished " f"Dataset sample {np.around(d.train_prev, decimals=2)} "
f"of dataset {dataset.name} finished "
f"[took {g_time:.4f}s]" f"[took {g_time:.4f}s]"
) )
@ -108,7 +111,8 @@ def evaluate_comparison(dataset: Dataset, estimators=None) -> DatasetReport:
except Exception as e: except Exception as e:
log.warning( log.warning(
f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} failed. " f"Dataset sample {np.around(d.train_prev, decimals=2)} "
f"of dataset {dataset.name} failed. "
f"Exception: {e}" f"Exception: {e}"
) )
traceback(e) traceback(e)

View File

@ -0,0 +1,32 @@
from typing import Callable, Union
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
import quacc as qc
from quacc.method.base import BaseAccuracyEstimator
def evaluate(
    estimator: BaseAccuracyEstimator,
    protocol: AbstractProtocol,
    error_metric: Union[Callable, str],
) -> float:
    """Score ``estimator`` over every sample generated by ``protocol``.

    :param estimator: object providing ``extend(sample)`` and
        ``estimate(instances)``.
    :param protocol: callable sample generator; its ``collator`` attribute is
        swapped for the "labelled_collection" collator during the run and
        restored afterwards.
    :param error_metric: callable ``(true_prevs, estim_prevs) -> value`` or a
        name resolved through ``qc.error.from_name``.
    :return: the value computed by ``error_metric`` on the two lists of true
        and estimated prevalences.
    """
    if isinstance(error_metric, str):
        error_metric = qc.error.from_name(error_metric)

    collator_bck_ = protocol.collator
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")

    estim_prevs, true_prevs = [], []
    try:
        for sample in protocol():
            e_sample = estimator.extend(sample)
            estim_prev = estimator.estimate(e_sample.eX)
            estim_prevs.append(estim_prev)
            true_prevs.append(e_sample.e_prevalence())
    finally:
        # Restore the caller's collator even when estimation raises, so the
        # protocol object is never left in the swapped state.
        protocol.collator = collator_bck_

    # The prevalences are handed over as plain lists, not stacked arrays.
    return error_metric(true_prevs, estim_prevs)

View File

@ -1,11 +1,12 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import List from typing import Callable, List, Union
import numpy as np import numpy as np
from matplotlib.pylab import rand
from quapy.method.aggregative import PACC, SLD, BaseQuantifier from quapy.method.aggregative import PACC, SLD, BaseQuantifier
from quapy.protocol import UPP, AbstractProtocol from quapy.protocol import UPP, AbstractProtocol, OnLabelledCollectionProtocol
from sklearn.linear_model import LogisticRegression from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC from sklearn.svm import SVC, LinearSVC
import quacc as qc import quacc as qc
from quacc.environment import env from quacc.environment import env
@ -13,34 +14,51 @@ from quacc.evaluation.report import EvaluationReport
from quacc.method.base import BQAE, MCAE, BaseAccuracyEstimator from quacc.method.base import BQAE, MCAE, BaseAccuracyEstimator
from quacc.method.model_selection import ( from quacc.method.model_selection import (
GridSearchAE, GridSearchAE,
HalvingSearchAE,
RandomizedSearchAE,
SpiderSearchAE, SpiderSearchAE,
) )
from quacc.quantification import KDEy from quacc.quantification import KDEy
_param_grid = {
"sld": { def _param_grid(method, X_fit: np.ndarray):
"q__classifier__C": np.logspace(-3, 3, 7), match method:
"q__classifier__class_weight": [None, "balanced"], case "sld_lr":
"q__recalib": [None, "bcts"], return {
# "q__recalib": [None], "q__classifier__C": np.logspace(-3, 3, 7),
"confidence": [None, ["isoft"], ["max_conf", "entropy"]], "q__classifier__class_weight": [None, "balanced"],
}, "q__recalib": [None, "bcts"],
"pacc": { "confidence": [None, ["isoft"], ["max_conf", "entropy"]],
"q__classifier__C": np.logspace(-3, 3, 7), }
"q__classifier__class_weight": [None, "balanced"], case "sld_rbf":
"confidence": [None, ["isoft"], ["max_conf", "entropy"]], _scale = 1.0 / (X_fit.shape[1] * X_fit.var())
}, return {
"kde": { "q__classifier__C": np.logspace(-3, 3, 7),
"q__classifier__C": np.logspace(-3, 3, 7), "q__classifier__class_weight": [None, "balanced"],
"q__classifier__class_weight": [None, "balanced"], "q__classifier__gamma": _scale * np.logspace(-2, 2, 5),
# "q__classifier__class_weight": [None], "q__recalib": [None, "bcts"],
"q__bandwidth": np.linspace(0.01, 0.2, 20), "confidence": [None, ["isoft"], ["max_conf", "entropy"]],
"confidence": [None, ["isoft"]], }
# "confidence": [None], case "pacc":
}, return {
} "q__classifier__C": np.logspace(-3, 3, 7),
"q__classifier__class_weight": [None, "balanced"],
"confidence": [None, ["isoft"], ["max_conf", "entropy"]],
}
case "kde_lr":
return {
"q__classifier__C": np.logspace(-3, 3, 7),
"q__classifier__class_weight": [None, "balanced"],
"q__bandwidth": np.linspace(0.01, 0.2, 20),
"confidence": [None, ["isoft"]],
}
case "kde_rbf":
_scale = 1.0 / (X_fit.shape[1] * X_fit.var())
return {
"q__classifier__C": np.logspace(-3, 3, 7),
"q__classifier__class_weight": [None, "balanced"],
"q__classifier__gamma": _scale * np.logspace(-2, 2, 5),
"q__bandwidth": np.linspace(0.01, 0.2, 20),
"confidence": [None, ["isoft"]],
}
def evaluation_report( def evaluation_report(
@ -52,15 +70,19 @@ def evaluation_report(
try: try:
e_sample = estimator.extend(sample) e_sample = estimator.extend(sample)
estim_prev = estimator.estimate(e_sample.eX) estim_prev = estimator.estimate(e_sample.eX)
true_prev = e_sample.e_prevalence()
acc_score = qc.error.acc(estim_prev) acc_score = qc.error.acc(estim_prev)
f1_score = qc.error.f1(estim_prev) row = dict(
report.append_row(
sample.prevalence(),
acc_score=acc_score, acc_score=acc_score,
acc=abs(qc.error.acc(e_sample.prevalence()) - acc_score), acc=abs(qc.error.acc(true_prev) - acc_score),
f1_score=f1_score,
f1=abs(qc.error.f1(e_sample.prevalence()) - f1_score),
) )
if estim_prev.can_f1():
f1_score = qc.error.f1(estim_prev)
row = row | dict(
f1_score=f1_score,
f1=abs(qc.error.f1(true_prev) - f1_score),
)
report.append_row(sample.prevalence(), **row)
except Exception as e: except Exception as e:
print(f"sample prediction failed for method {method_name}: {e}") print(f"sample prediction failed for method {method_name}: {e}")
report.append_row( report.append_row(
@ -80,7 +102,9 @@ class EvaluationMethod:
q: BaseQuantifier q: BaseQuantifier
est_n: str est_n: str
conf: List[str] | str = None conf: List[str] | str = None
cf: bool = False cf: bool = False # collapse_false
gf: bool = False # group_false
d: bool = False # dense
def get_est(self, c_model): def get_est(self, c_model):
match self.est_n: match self.est_n:
@ -90,9 +114,17 @@ class EvaluationMethod:
self.q, self.q,
confidence=self.conf, confidence=self.conf,
collapse_false=self.cf, collapse_false=self.cf,
group_false=self.gf,
dense=self.d,
) )
case "bin": case "bin":
return BQAE(c_model, self.q, confidence=self.conf) return BQAE(
c_model,
self.q,
confidence=self.conf,
group_false=self.gf,
dense=self.d,
)
def __call__(self, c_model, validation, protocol) -> EvaluationReport: def __call__(self, c_model, validation, protocol) -> EvaluationReport:
est = self.get_est(c_model).fit(validation) est = self.get_est(c_model).fit(validation)
@ -109,22 +141,26 @@ class EvaluationMethodGridSearch(EvaluationMethod):
def get_search(self): def get_search(self):
match self.search: match self.search:
case "grid": case "grid":
return GridSearchAE return (GridSearchAE, {})
case "spider": case "spider" | "spider2":
return SpiderSearchAE return (SpiderSearchAE, dict(best_width=2))
case "spider3":
return (SpiderSearchAE, dict(best_width=3))
case _: case _:
return GridSearchAE return GridSearchAE
def __call__(self, c_model, validation, protocol) -> EvaluationReport: def __call__(self, c_model, validation, protocol) -> EvaluationReport:
v_train, v_val = validation.split_stratified(0.6, random_state=env._R_SEED) v_train, v_val = validation.split_stratified(0.6, random_state=env._R_SEED)
__grid = _param_grid.get(self.pg, {}) _model = self.get_est(c_model)
_search_class = self.get_search() _grid = _param_grid(self.pg, X_fit=_model.extend(v_train, prefit=True).X)
_search_class, _search_params = self.get_search()
est = _search_class( est = _search_class(
model=self.get_est(c_model), model=_model,
param_grid=__grid, param_grid=_grid,
refit=False, refit=False,
protocol=UPP(v_val, repeats=100), protocol=UPP(v_val, repeats=100),
verbose=False, verbose=False,
**_search_params,
).fit(v_train) ).fit(v_train)
return evaluation_report( return evaluation_report(
estimator=est, estimator=est,
@ -141,10 +177,18 @@ def __sld_lr():
return SLD(LogisticRegression()) return SLD(LogisticRegression())
def __sld_rbf():
    """Build an SLD quantifier wrapping an RBF-kernel SVC with probabilities enabled."""
    rbf_classifier = SVC(kernel="rbf", probability=True)
    return SLD(rbf_classifier)
def __kde_lr(): def __kde_lr():
return KDEy(LogisticRegression(), random_state=env._R_SEED) return KDEy(LogisticRegression(), random_state=env._R_SEED)
def __kde_rbf():
    """Build a KDEy quantifier wrapping an RBF-kernel SVC, seeded with ``env._R_SEED``."""
    rbf_classifier = SVC(kernel="rbf", probability=True)
    return KDEy(rbf_classifier, random_state=env._R_SEED)
def __sld_lsvc(): def __sld_lsvc():
return SLD(LinearSVC()) return SLD(LinearSVC())
@ -154,57 +198,212 @@ def __pacc_lr():
# fmt: off # fmt: off
__methods_set = [
# base sld
M("bin_sld", __sld_lr(), "bin" ),
M("mul_sld", __sld_lr(), "mul" ),
M("m3w_sld", __sld_lr(), "mul", cf=True),
# max_conf + entropy sld
M("binc_sld", __sld_lr(), "bin", conf=["max_conf", "entropy"] ),
M("mulc_sld", __sld_lr(), "mul", conf=["max_conf", "entropy"] ),
M("m3wc_sld", __sld_lr(), "mul", conf=["max_conf", "entropy"], cf=True),
# max_conf sld
M("binmc_sld", __sld_lr(), "bin", conf="max_conf", ),
M("mulmc_sld", __sld_lr(), "mul", conf="max_conf", ),
M("m3wmc_sld", __sld_lr(), "mul", conf="max_conf", cf=True),
# entropy sld
M("binne_sld", __sld_lr(), "bin", conf="entropy", ),
M("mulne_sld", __sld_lr(), "mul", conf="entropy", ),
M("m3wne_sld", __sld_lr(), "mul", conf="entropy", cf=True),
# inverse softmax sld
M("binis_sld", __sld_lr(), "bin", conf="isoft", ),
M("mulis_sld", __sld_lr(), "mul", conf="isoft", ),
M("m3wis_sld", __sld_lr(), "mul", conf="isoft", cf=True),
# gs sld
G("bin_sld_gs", __sld_lr(), "bin", pg="sld" ),
G("mul_sld_gs", __sld_lr(), "mul", pg="sld" ),
G("m3w_sld_gs", __sld_lr(), "mul", pg="sld", cf=True),
# base kde __sld_lr_set = [
M("bin_kde", __kde_lr(), "bin" ), M("bin_sld_lr", __sld_lr(), "bin" ),
M("mul_kde", __kde_lr(), "mul" ), M("bgf_sld_lr", __sld_lr(), "bin", gf=True),
M("m3w_kde", __kde_lr(), "mul", cf=True), M("mul_sld_lr", __sld_lr(), "mul" ),
# max_conf + entropy kde M("m3w_sld_lr", __sld_lr(), "mul", cf=True),
M("binc_kde", __kde_lr(), "bin", conf=["max_conf", "entropy"] ), M("mgf_sld_lr", __sld_lr(), "mul", gf=True),
M("mulc_kde", __kde_lr(), "mul", conf=["max_conf", "entropy"] ), # max_conf + entropy sld
M("m3wc_kde", __kde_lr(), "mul", conf=["max_conf", "entropy"], cf=True), M("bin_sld_lr_c", __sld_lr(), "bin", conf=["max_conf", "entropy"] ),
# max_conf kde M("bgf_sld_lr_c", __sld_lr(), "bin", conf=["max_conf", "entropy"], gf=True),
M("binmc_kde", __kde_lr(), "bin", conf="max_conf", ), M("mul_sld_lr_c", __sld_lr(), "mul", conf=["max_conf", "entropy"] ),
M("mulmc_kde", __kde_lr(), "mul", conf="max_conf", ), M("m3w_sld_lr_c", __sld_lr(), "mul", conf=["max_conf", "entropy"], cf=True),
M("m3wmc_kde", __kde_lr(), "mul", conf="max_conf", cf=True), M("mgf_sld_lr_c", __sld_lr(), "mul", conf=["max_conf", "entropy"], gf=True),
# entropy kde # max_conf sld
M("binne_kde", __kde_lr(), "bin", conf="entropy", ), M("bin_sld_lr_mc", __sld_lr(), "bin", conf="max_conf", ),
M("mulne_kde", __kde_lr(), "mul", conf="entropy", ), M("bgf_sld_lr_mc", __sld_lr(), "bin", conf="max_conf", gf=True),
M("m3wne_kde", __kde_lr(), "mul", conf="entropy", cf=True), M("mul_sld_lr_mc", __sld_lr(), "mul", conf="max_conf", ),
# inverse softmax kde M("m3w_sld_lr_mc", __sld_lr(), "mul", conf="max_conf", cf=True),
M("binis_kde", __kde_lr(), "bin", conf="isoft", ), M("mgf_sld_lr_mc", __sld_lr(), "mul", conf="max_conf", gf=True),
M("mulis_kde", __kde_lr(), "mul", conf="isoft", ), # entropy sld
M("m3wis_kde", __kde_lr(), "mul", conf="isoft", cf=True), M("bin_sld_lr_ne", __sld_lr(), "bin", conf="entropy", ),
# gs kde M("bgf_sld_lr_ne", __sld_lr(), "bin", conf="entropy", gf=True),
G("bin_kde_gs", __kde_lr(), "bin", pg="kde", search="spider" ), M("mul_sld_lr_ne", __sld_lr(), "mul", conf="entropy", ),
G("mul_kde_gs", __kde_lr(), "mul", pg="kde", search="spider" ), M("m3w_sld_lr_ne", __sld_lr(), "mul", conf="entropy", cf=True),
G("m3w_kde_gs", __kde_lr(), "mul", pg="kde", search="spider", cf=True), M("mgf_sld_lr_ne", __sld_lr(), "mul", conf="entropy", gf=True),
# inverse softmax sld
M("bin_sld_lr_is", __sld_lr(), "bin", conf="isoft", ),
M("bgf_sld_lr_is", __sld_lr(), "bin", conf="isoft", gf=True),
M("mul_sld_lr_is", __sld_lr(), "mul", conf="isoft", ),
M("m3w_sld_lr_is", __sld_lr(), "mul", conf="isoft", cf=True),
M("mgf_sld_lr_is", __sld_lr(), "mul", conf="isoft", gf=True),
# gs sld
G("bin_sld_lr_gs", __sld_lr(), "bin", pg="sld_lr" ),
G("bgf_sld_lr_gs", __sld_lr(), "bin", pg="sld_lr", gf=True),
G("mul_sld_lr_gs", __sld_lr(), "mul", pg="sld_lr" ),
G("m3w_sld_lr_gs", __sld_lr(), "mul", pg="sld_lr", cf=True),
G("mgf_sld_lr_gs", __sld_lr(), "mul", pg="sld_lr", gf=True),
] ]
__dense_sld_lr_set = [
M("d_bin_sld_lr", __sld_lr(), "bin", d=True, ),
M("d_bgf_sld_lr", __sld_lr(), "bin", d=True, gf=True),
M("d_mul_sld_lr", __sld_lr(), "mul", d=True, ),
M("d_m3w_sld_lr", __sld_lr(), "mul", d=True, cf=True),
M("d_mgf_sld_lr", __sld_lr(), "mul", d=True, gf=True),
# max_conf + entropy sld
M("d_bin_sld_lr_c", __sld_lr(), "bin", d=True, conf=["max_conf", "entropy"] ),
M("d_bgf_sld_lr_c", __sld_lr(), "bin", d=True, conf=["max_conf", "entropy"], gf=True),
M("d_mul_sld_lr_c", __sld_lr(), "mul", d=True, conf=["max_conf", "entropy"] ),
M("d_m3w_sld_lr_c", __sld_lr(), "mul", d=True, conf=["max_conf", "entropy"], cf=True),
M("d_mgf_sld_lr_c", __sld_lr(), "mul", d=True, conf=["max_conf", "entropy"], gf=True),
# max_conf sld
M("d_bin_sld_lr_mc", __sld_lr(), "bin", d=True, conf="max_conf", ),
M("d_bgf_sld_lr_mc", __sld_lr(), "bin", d=True, conf="max_conf", gf=True),
M("d_mul_sld_lr_mc", __sld_lr(), "mul", d=True, conf="max_conf", ),
M("d_m3w_sld_lr_mc", __sld_lr(), "mul", d=True, conf="max_conf", cf=True),
M("d_mgf_sld_lr_mc", __sld_lr(), "mul", d=True, conf="max_conf", gf=True),
# entropy sld
M("d_bin_sld_lr_ne", __sld_lr(), "bin", d=True, conf="entropy", ),
M("d_bgf_sld_lr_ne", __sld_lr(), "bin", d=True, conf="entropy", gf=True),
M("d_mul_sld_lr_ne", __sld_lr(), "mul", d=True, conf="entropy", ),
M("d_m3w_sld_lr_ne", __sld_lr(), "mul", d=True, conf="entropy", cf=True),
M("d_mgf_sld_lr_ne", __sld_lr(), "mul", d=True, conf="entropy", gf=True),
# inverse softmax sld
M("d_bin_sld_lr_is", __sld_lr(), "bin", d=True, conf="isoft", ),
M("d_bgf_sld_lr_is", __sld_lr(), "bin", d=True, conf="isoft", gf=True),
M("d_mul_sld_lr_is", __sld_lr(), "mul", d=True, conf="isoft", ),
M("d_m3w_sld_lr_is", __sld_lr(), "mul", d=True, conf="isoft", cf=True),
M("d_mgf_sld_lr_is", __sld_lr(), "mul", d=True, conf="isoft", gf=True),
# gs sld
G("d_bin_sld_lr_gs", __sld_lr(), "bin", d=True, pg="sld_lr" ),
G("d_bgf_sld_lr_gs", __sld_lr(), "bin", d=True, pg="sld_lr", gf=True),
G("d_mul_sld_lr_gs", __sld_lr(), "mul", d=True, pg="sld_lr" ),
G("d_m3w_sld_lr_gs", __sld_lr(), "mul", d=True, pg="sld_lr", cf=True),
G("d_mgf_sld_lr_gs", __sld_lr(), "mul", d=True, pg="sld_lr", gf=True),
]
__dense_sld_rbf_set = [
M("d_bin_sld_rbf", __sld_rbf(), "bin", d=True, ),
M("d_bgf_sld_rbf", __sld_rbf(), "bin", d=True, gf=True),
M("d_mul_sld_rbf", __sld_rbf(), "mul", d=True, ),
M("d_m3w_sld_rbf", __sld_rbf(), "mul", d=True, cf=True),
M("d_mgf_sld_rbf", __sld_rbf(), "mul", d=True, gf=True),
# max_conf + entropy sld
M("d_bin_sld_rbf_c", __sld_rbf(), "bin", d=True, conf=["max_conf", "entropy"] ),
M("d_bgf_sld_rbf_c", __sld_rbf(), "bin", d=True, conf=["max_conf", "entropy"], gf=True),
M("d_mul_sld_rbf_c", __sld_rbf(), "mul", d=True, conf=["max_conf", "entropy"] ),
M("d_m3w_sld_rbf_c", __sld_rbf(), "mul", d=True, conf=["max_conf", "entropy"], cf=True),
M("d_mgf_sld_rbf_c", __sld_rbf(), "mul", d=True, conf=["max_conf", "entropy"], gf=True),
# max_conf sld
M("d_bin_sld_rbf_mc", __sld_rbf(), "bin", d=True, conf="max_conf", ),
M("d_bgf_sld_rbf_mc", __sld_rbf(), "bin", d=True, conf="max_conf", gf=True),
M("d_mul_sld_rbf_mc", __sld_rbf(), "mul", d=True, conf="max_conf", ),
M("d_m3w_sld_rbf_mc", __sld_rbf(), "mul", d=True, conf="max_conf", cf=True),
M("d_mgf_sld_rbf_mc", __sld_rbf(), "mul", d=True, conf="max_conf", gf=True),
# entropy sld
M("d_bin_sld_rbf_ne", __sld_rbf(), "bin", d=True, conf="entropy", ),
M("d_bgf_sld_rbf_ne", __sld_rbf(), "bin", d=True, conf="entropy", gf=True),
M("d_mul_sld_rbf_ne", __sld_rbf(), "mul", d=True, conf="entropy", ),
M("d_m3w_sld_rbf_ne", __sld_rbf(), "mul", d=True, conf="entropy", cf=True),
M("d_mgf_sld_rbf_ne", __sld_rbf(), "mul", d=True, conf="entropy", gf=True),
# inverse softmax sld
M("d_bin_sld_rbf_is", __sld_rbf(), "bin", d=True, conf="isoft", ),
M("d_bgf_sld_rbf_is", __sld_rbf(), "bin", d=True, conf="isoft", gf=True),
M("d_mul_sld_rbf_is", __sld_rbf(), "mul", d=True, conf="isoft", ),
M("d_m3w_sld_rbf_is", __sld_rbf(), "mul", d=True, conf="isoft", cf=True),
M("d_mgf_sld_rbf_is", __sld_rbf(), "mul", d=True, conf="isoft", gf=True),
# gs sld
G("d_bin_sld_rbf_gs", __sld_rbf(), "bin", d=True, pg="sld_rbf", search="spider", ),
G("d_bgf_sld_rbf_gs", __sld_rbf(), "bin", d=True, pg="sld_rbf", search="spider", gf=True),
G("d_mul_sld_rbf_gs", __sld_rbf(), "mul", d=True, pg="sld_rbf", search="spider", ),
G("d_m3w_sld_rbf_gs", __sld_rbf(), "mul", d=True, pg="sld_rbf", search="spider", cf=True),
G("d_mgf_sld_rbf_gs", __sld_rbf(), "mul", d=True, pg="sld_rbf", search="spider", gf=True),
]
__kde_lr_set = [
# base kde
M("bin_kde_lr", __kde_lr(), "bin" ),
M("mul_kde_lr", __kde_lr(), "mul" ),
M("m3w_kde_lr", __kde_lr(), "mul", cf=True),
# max_conf + entropy kde
M("bin_kde_lr_c", __kde_lr(), "bin", conf=["max_conf", "entropy"] ),
M("mul_kde_lr_c", __kde_lr(), "mul", conf=["max_conf", "entropy"] ),
M("m3w_kde_lr_c", __kde_lr(), "mul", conf=["max_conf", "entropy"], cf=True),
# max_conf kde
M("bin_kde_lr_mc", __kde_lr(), "bin", conf="max_conf", ),
M("mul_kde_lr_mc", __kde_lr(), "mul", conf="max_conf", ),
M("m3w_kde_lr_mc", __kde_lr(), "mul", conf="max_conf", cf=True),
# entropy kde
M("bin_kde_lr_ne", __kde_lr(), "bin", conf="entropy", ),
M("mul_kde_lr_ne", __kde_lr(), "mul", conf="entropy", ),
M("m3w_kde_lr_ne", __kde_lr(), "mul", conf="entropy", cf=True),
# inverse softmax kde
M("bin_kde_lr_is", __kde_lr(), "bin", conf="isoft", ),
M("mul_kde_lr_is", __kde_lr(), "mul", conf="isoft", ),
M("m3w_kde_lr_is", __kde_lr(), "mul", conf="isoft", cf=True),
# gs kde
G("bin_kde_lr_gs", __kde_lr(), "bin", pg="kde_lr", search="spider" ),
G("mul_kde_lr_gs", __kde_lr(), "mul", pg="kde_lr", search="spider" ),
G("m3w_kde_lr_gs", __kde_lr(), "mul", pg="kde_lr", search="spider", cf=True),
]
__dense_kde_lr_set = [
# base kde
M("d_bin_kde_lr", __kde_lr(), "bin", d=True, ),
M("d_mul_kde_lr", __kde_lr(), "mul", d=True, ),
M("d_m3w_kde_lr", __kde_lr(), "mul", d=True, cf=True),
# max_conf + entropy kde
M("d_bin_kde_lr_c", __kde_lr(), "bin", d=True, conf=["max_conf", "entropy"] ),
M("d_mul_kde_lr_c", __kde_lr(), "mul", d=True, conf=["max_conf", "entropy"] ),
M("d_m3w_kde_lr_c", __kde_lr(), "mul", d=True, conf=["max_conf", "entropy"], cf=True),
# max_conf kde
M("d_bin_kde_lr_mc", __kde_lr(), "bin", d=True, conf="max_conf", ),
M("d_mul_kde_lr_mc", __kde_lr(), "mul", d=True, conf="max_conf", ),
M("d_m3w_kde_lr_mc", __kde_lr(), "mul", d=True, conf="max_conf", cf=True),
# entropy kde
M("d_bin_kde_lr_ne", __kde_lr(), "bin", d=True, conf="entropy", ),
M("d_mul_kde_lr_ne", __kde_lr(), "mul", d=True, conf="entropy", ),
M("d_m3w_kde_lr_ne", __kde_lr(), "mul", d=True, conf="entropy", cf=True),
# inverse softmax kde
M("d_bin_kde_lr_is", __kde_lr(), "bin", d=True, conf="isoft", ),
M("d_mul_kde_lr_is", __kde_lr(), "mul", d=True, conf="isoft", ),
M("d_m3w_kde_lr_is", __kde_lr(), "mul", d=True, conf="isoft", cf=True),
# gs kde
G("d_bin_kde_lr_gs", __kde_lr(), "bin", d=True, pg="kde_lr", search="spider" ),
G("d_mul_kde_lr_gs", __kde_lr(), "mul", d=True, pg="kde_lr", search="spider" ),
G("d_m3w_kde_lr_gs", __kde_lr(), "mul", d=True, pg="kde_lr", search="spider", cf=True),
]
__dense_kde_rbf_set = [
# base kde
M("d_bin_kde_rbf", __kde_rbf(), "bin", d=True, ),
M("d_mul_kde_rbf", __kde_rbf(), "mul", d=True, ),
M("d_m3w_kde_rbf", __kde_rbf(), "mul", d=True, cf=True),
# max_conf + entropy kde
M("d_bin_kde_rbf_c", __kde_rbf(), "bin", d=True, conf=["max_conf", "entropy"] ),
M("d_mul_kde_rbf_c", __kde_rbf(), "mul", d=True, conf=["max_conf", "entropy"] ),
M("d_m3w_kde_rbf_c", __kde_rbf(), "mul", d=True, conf=["max_conf", "entropy"], cf=True),
# max_conf kde
M("d_bin_kde_rbf_mc", __kde_rbf(), "bin", d=True, conf="max_conf", ),
M("d_mul_kde_rbf_mc", __kde_rbf(), "mul", d=True, conf="max_conf", ),
M("d_m3w_kde_rbf_mc", __kde_rbf(), "mul", d=True, conf="max_conf", cf=True),
# entropy kde
M("d_bin_kde_rbf_ne", __kde_rbf(), "bin", d=True, conf="entropy", ),
M("d_mul_kde_rbf_ne", __kde_rbf(), "mul", d=True, conf="entropy", ),
M("d_m3w_kde_rbf_ne", __kde_rbf(), "mul", d=True, conf="entropy", cf=True),
# inverse softmax kde
M("d_bin_kde_rbf_is", __kde_rbf(), "bin", d=True, conf="isoft", ),
M("d_mul_kde_rbf_is", __kde_rbf(), "mul", d=True, conf="isoft", ),
M("d_m3w_kde_rbf_is", __kde_rbf(), "mul", d=True, conf="isoft", cf=True),
# gs kde
G("d_bin_kde_rbf_gs", __kde_rbf(), "bin", d=True, pg="kde_rbf", search="spider" ),
G("d_mul_kde_rbf_gs", __kde_rbf(), "mul", d=True, pg="kde_rbf", search="spider" ),
G("d_m3w_kde_rbf_gs", __kde_rbf(), "mul", d=True, pg="kde_rbf", search="spider", cf=True),
]
# fmt: on # fmt: on
__methods_set = (
__sld_lr_set
+ __dense_sld_lr_set
+ __dense_sld_rbf_set
+ __kde_lr_set
+ __dense_kde_lr_set
+ __dense_kde_rbf_set
)
_methods = {m.name: m for m in __methods_set} _methods = {m.name: m for m in __methods_set}

View File

@ -1,11 +1,14 @@
import json import json
import pickle import pickle
from collections import defaultdict
from itertools import chain
from pathlib import Path from pathlib import Path
from typing import List, Tuple from typing import List, Tuple
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import quacc as qc
import quacc.plot as plot import quacc.plot as plot
from quacc.utils import fmt_line_md from quacc.utils import fmt_line_md
@ -22,14 +25,24 @@ def _get_estimators(estimators: List[str], cols: np.ndarray):
return estimators[np.isin(estimators, cols)] return estimators[np.isin(estimators, cols)]
def _get_shift(index: np.ndarray, train_prev: np.ndarray):
    """Compute the shift of each prevalence in ``index`` from ``train_prev``.

    Every entry of ``index`` is converted to an array and stacked row-wise,
    ``train_prev`` is repeated once per row, and the two are compared through
    ``qc.error.nae`` (presumably a normalized absolute error - confirm against
    ``quacc.error``).

    :param index: sequence of prevalence vectors, one per row.
    :param train_prev: prevalence vector every row is compared against.
    :return: the ``qc.error.nae`` result rounded to two decimals.
    """
    stacked_prevs = np.array([np.array(tp) for tp in index])
    n_rows = stacked_prevs.shape[0]
    tiled_train = np.tile(train_prev, (n_rows, 1))
    return np.around(qc.error.nae(stacked_prevs, tiled_train), decimals=2)
class EvaluationReport: class EvaluationReport:
def __init__(self, name=None): def __init__(self, name=None):
self.data: pd.DataFrame | None = None self.data: pd.DataFrame | None = None
self.fit_score = None
self.name = name if name is not None else "default" self.name = name if name is not None else "default"
self.time = 0.0
def append_row(self, basep: np.ndarray | Tuple, **row): def append_row(self, basep: np.ndarray | Tuple, **row):
bp = basep[1] # bp = basep[1]
bp = tuple(basep)
_keys, _values = zip(*row.items()) _keys, _values = zip(*row.items())
# _keys = list(row.keys()) # _keys = list(row.keys())
# _values = list(row.values()) # _values = list(row.values())
@ -89,7 +102,7 @@ class CompReport:
) )
.swaplevel(0, 1, axis=1) .swaplevel(0, 1, axis=1)
.sort_index(axis=1, level=0, sort_remaining=False) .sort_index(axis=1, level=0, sort_remaining=False)
.sort_index(axis=0, level=0) .sort_index(axis=0, level=0, ascending=False, sort_remaining=False)
) )
if times is None: if times is None:
@ -97,17 +110,13 @@ class CompReport:
else: else:
self.times = times self.times = times
self.times["tot"] = g_time self.times["tot"] = g_time if g_time is not None else 0.0
self.train_prev = train_prev self.train_prev = train_prev
self.valid_prev = valid_prev self.valid_prev = valid_prev
@property @property
def prevs(self) -> np.ndarray: def prevs(self) -> np.ndarray:
return np.sort(self._data.index.unique(0)) return self.data().index.unique(0)
@property
def np_prevs(self) -> np.ndarray:
return np.around([(1.0 - p, p) for p in self.prevs], decimals=2)
def join(self, other, how="update", estimators=None): def join(self, other, how="update", estimators=None):
if how not in ["update"]: if how not in ["update"]:
@ -160,16 +169,14 @@ class CompReport:
def shift_data( def shift_data(
self, metric: str = None, estimators: List[str] = None self, metric: str = None, estimators: List[str] = None
) -> pd.DataFrame: ) -> pd.DataFrame:
shift_idx_0 = np.around( shift_idx_0 = _get_shift(
np.abs( self._data.index.get_level_values(0).to_numpy(),
self._data.index.get_level_values(0).to_numpy() - self.train_prev[1] self.train_prev,
),
decimals=2,
) )
shift_idx_1 = np.empty(shape=shift_idx_0.shape, dtype="<i4") shift_idx_1 = np.zeros(shape=shift_idx_0.shape[0], dtype="<i4")
for _id in np.unique(shift_idx_0): for _id in np.unique(shift_idx_0):
_wh = np.where(shift_idx_0 == _id)[0] _wh = (shift_idx_0 == _id).nonzero()[0]
shift_idx_1[_wh] = np.arange(_wh.shape[0], dtype="<i4") shift_idx_1[_wh] = np.arange(_wh.shape[0], dtype="<i4")
shift_data = self._data.copy() shift_data = self._data.copy()
@ -191,26 +198,28 @@ class CompReport:
self, metric: str = None, estimators: List[str] = None self, metric: str = None, estimators: List[str] = None
) -> pd.DataFrame: ) -> pd.DataFrame:
f_dict = self.data(metric=metric, estimators=estimators) f_dict = self.data(metric=metric, estimators=estimators)
return f_dict.groupby(level=0).mean() return f_dict.groupby(level=0, sort=False).mean()
def stdev_by_prevs( def stdev_by_prevs(
self, metric: str = None, estimators: List[str] = None self, metric: str = None, estimators: List[str] = None
) -> pd.DataFrame: ) -> pd.DataFrame:
f_dict = self.data(metric=metric, estimators=estimators) f_dict = self.data(metric=metric, estimators=estimators)
return f_dict.groupby(level=0).std() return f_dict.groupby(level=0, sort=False).std()
def table(self, metric: str = None, estimators: List[str] = None) -> pd.DataFrame: def train_table(
self, metric: str = None, estimators: List[str] = None
) -> pd.DataFrame:
f_data = self.data(metric=metric, estimators=estimators) f_data = self.data(metric=metric, estimators=estimators)
avg_p = f_data.groupby(level=0).mean() avg_p = f_data.groupby(level=0, sort=False).mean()
avg_p.loc["avg", :] = f_data.mean() avg_p.loc["mean", :] = f_data.mean()
return avg_p return avg_p
def shift_table( def shift_table(
self, metric: str = None, estimators: List[str] = None self, metric: str = None, estimators: List[str] = None
) -> pd.DataFrame: ) -> pd.DataFrame:
f_data = self.shift_data(metric=metric, estimators=estimators) f_data = self.shift_data(metric=metric, estimators=estimators)
avg_p = f_data.groupby(level=0).mean() avg_p = f_data.groupby(level=0, sort=False).mean()
avg_p.loc["avg", :] = f_data.mean() avg_p.loc["mean", :] = f_data.mean()
return avg_p return avg_p
def get_plots( def get_plots(
@ -229,7 +238,7 @@ class CompReport:
return None return None
return plot.plot_delta( return plot.plot_delta(
base_prevs=self.np_prevs, base_prevs=self.prevs,
columns=avg_data.columns.to_numpy(), columns=avg_data.columns.to_numpy(),
data=avg_data.T.to_numpy(), data=avg_data.T.to_numpy(),
metric=metric, metric=metric,
@ -246,7 +255,7 @@ class CompReport:
st_data = self.stdev_by_prevs(metric=metric, estimators=estimators) st_data = self.stdev_by_prevs(metric=metric, estimators=estimators)
return plot.plot_delta( return plot.plot_delta(
base_prevs=self.np_prevs, base_prevs=self.prevs,
columns=avg_data.columns.to_numpy(), columns=avg_data.columns.to_numpy(),
data=avg_data.T.to_numpy(), data=avg_data.T.to_numpy(),
metric=metric, metric=metric,
@ -280,12 +289,13 @@ class CompReport:
if _shift_data.empty is True: if _shift_data.empty is True:
return None return None
shift_avg = _shift_data.groupby(level=0).mean() shift_avg = _shift_data.groupby(level=0, sort=False).mean()
shift_counts = _shift_data.groupby(level=0).count() shift_counts = _shift_data.groupby(level=0, sort=False).count()
shift_prevs = np.around( shift_prevs = shift_avg.index.unique(0)
[(1.0 - p, p) for p in np.sort(shift_avg.index.unique(0))], # shift_prevs = np.around(
decimals=2, # [(1.0 - p, p) for p in np.sort(shift_avg.index.unique(0))],
) # decimals=2,
# )
return plot.plot_shift( return plot.plot_shift(
shift_prevs=shift_prevs, shift_prevs=shift_prevs,
columns=shift_avg.columns.to_numpy(), columns=shift_avg.columns.to_numpy(),
@ -317,7 +327,10 @@ class CompReport:
res += "\n" res += "\n"
if "train_table" in modes: if "train_table" in modes:
res += "### table\n" res += "### table\n"
res += self.table(metric=metric, estimators=estimators).to_html() + "\n\n" res += (
self.train_table(metric=metric, estimators=estimators).to_html()
+ "\n\n"
)
if "shift_table" in modes: if "shift_table" in modes:
res += "### shift table\n" res += "### shift table\n"
res += ( res += (
@ -369,7 +382,7 @@ class DatasetReport:
def data(self, metric: str = None, estimators: List[str] = None) -> pd.DataFrame: def data(self, metric: str = None, estimators: List[str] = None) -> pd.DataFrame:
def _cr_train_prev(cr: CompReport): def _cr_train_prev(cr: CompReport):
return cr.train_prev[1] return tuple(np.around(cr.train_prev, decimals=2))
def _cr_data(cr: CompReport): def _cr_data(cr: CompReport):
return cr.data(metric, estimators) return cr.data(metric, estimators)
@ -381,11 +394,27 @@ class DatasetReport:
) )
_crs_train, _crs_data = zip(*_crs_sorted) _crs_train, _crs_data = zip(*_crs_sorted)
_data = pd.concat(_crs_data, axis=0, keys=np.around(_crs_train, decimals=2)) _data: pd.DataFrame = pd.concat(
_data = _data.sort_index(axis=0, level=0) _crs_data,
axis=0,
keys=_crs_train,
)
# The MultiIndex is recreated to make the outer-most level a tuple and not a
# sequence of values
_len_tr_idx = len(_crs_train[0])
_idx = _data.index.to_list()
_idx = pd.MultiIndex.from_tuples(
[tuple([midx[:_len_tr_idx]] + list(midx[_len_tr_idx:])) for midx in _idx]
)
_data.index = _idx
_data = _data.sort_index(axis=0, level=0, ascending=False, sort_remaining=False)
return _data return _data
def shift_data(self, metric: str = None, estimators: str = None) -> pd.DataFrame: def shift_data(
self, metric: str = None, estimators: List[str] = None
) -> pd.DataFrame:
_shift_data: pd.DataFrame = pd.concat( _shift_data: pd.DataFrame = pd.concat(
sorted( sorted(
[cr.shift_data(metric, estimators) for cr in self.crs], [cr.shift_data(metric, estimators) for cr in self.crs],
@ -423,6 +452,30 @@ class DatasetReport:
self.add(cr) self.add(cr)
return self return self
def train_table(
    self, metric: str = None, estimators: List[str] = None
) -> pd.DataFrame:
    """Average the report data over the second index level.

    :param metric: metric filter forwarded to ``self.data``.
    :param estimators: estimator filter forwarded to ``self.data``.
    :return: one row per distinct value of index level 1 (in order of first
        appearance) holding the column means of that group, followed by a
        ``"mean"`` row with the column means over all rows.
    """
    scores = self.data(metric=metric, estimators=estimators)
    table = scores.groupby(level=1, sort=False).mean()
    overall = scores.mean()
    table.loc["mean", :] = overall
    return table
def test_table(
    self, metric: str = None, estimators: List[str] = None
) -> pd.DataFrame:
    """Average the report data over the first index level.

    :param metric: metric filter forwarded to ``self.data``.
    :param estimators: estimator filter forwarded to ``self.data``.
    :return: one row per distinct value of index level 0 (in order of first
        appearance) holding the column means of that group, followed by a
        ``"mean"`` row with the column means over all rows.
    """
    scores = self.data(metric=metric, estimators=estimators)
    table = scores.groupby(level=0, sort=False).mean()
    overall = scores.mean()
    table.loc["mean", :] = overall
    return table
def shift_table(
    self, metric: str = None, estimators: List[str] = None
) -> pd.DataFrame:
    """Average the shift data over the first index level.

    :param metric: metric filter forwarded to ``self.shift_data``.
    :param estimators: estimator filter forwarded to ``self.shift_data``.
    :return: one row per distinct value of index level 0 (in order of first
        appearance) holding the column means of that group, followed by a
        ``"mean"`` row with the column means over all rows.
    """
    shifted = self.shift_data(metric=metric, estimators=estimators)
    table = shifted.groupby(level=0, sort=False).mean()
    overall = shifted.mean()
    table.loc["mean", :] = overall
    return table
def get_plots( def get_plots(
self, self,
data=None, data=None,
@ -436,14 +489,15 @@ class DatasetReport:
): ):
if mode == "delta_train": if mode == "delta_train":
_data = self.data(metric, estimators) if data is None else data _data = self.data(metric, estimators) if data is None else data
avg_on_train = _data.groupby(level=1).mean() avg_on_train = _data.groupby(level=1, sort=False).mean()
if avg_on_train.empty: if avg_on_train.empty:
return None return None
prevs_on_train = np.sort(avg_on_train.index.unique(0)) prevs_on_train = avg_on_train.index.unique(0)
return plot.plot_delta( return plot.plot_delta(
base_prevs=np.around( # base_prevs=np.around(
[(1.0 - p, p) for p in prevs_on_train], decimals=2 # [(1.0 - p, p) for p in prevs_on_train], decimals=2
), # ),
base_prevs=prevs_on_train,
columns=avg_on_train.columns.to_numpy(), columns=avg_on_train.columns.to_numpy(),
data=avg_on_train.T.to_numpy(), data=avg_on_train.T.to_numpy(),
metric=metric, metric=metric,
@ -456,15 +510,16 @@ class DatasetReport:
) )
elif mode == "stdev_train": elif mode == "stdev_train":
_data = self.data(metric, estimators) if data is None else data _data = self.data(metric, estimators) if data is None else data
avg_on_train = _data.groupby(level=1).mean() avg_on_train = _data.groupby(level=1, sort=False).mean()
if avg_on_train.empty: if avg_on_train.empty:
return None return None
prevs_on_train = np.sort(avg_on_train.index.unique(0)) prevs_on_train = avg_on_train.index.unique(0)
stdev_on_train = _data.groupby(level=1).std() stdev_on_train = _data.groupby(level=1, sort=False).std()
return plot.plot_delta( return plot.plot_delta(
base_prevs=np.around( # base_prevs=np.around(
[(1.0 - p, p) for p in prevs_on_train], decimals=2 # [(1.0 - p, p) for p in prevs_on_train], decimals=2
), # ),
base_prevs=prevs_on_train,
columns=avg_on_train.columns.to_numpy(), columns=avg_on_train.columns.to_numpy(),
data=avg_on_train.T.to_numpy(), data=avg_on_train.T.to_numpy(),
metric=metric, metric=metric,
@ -478,12 +533,13 @@ class DatasetReport:
) )
elif mode == "delta_test": elif mode == "delta_test":
_data = self.data(metric, estimators) if data is None else data _data = self.data(metric, estimators) if data is None else data
avg_on_test = _data.groupby(level=0).mean() avg_on_test = _data.groupby(level=0, sort=False).mean()
if avg_on_test.empty: if avg_on_test.empty:
return None return None
prevs_on_test = np.sort(avg_on_test.index.unique(0)) prevs_on_test = avg_on_test.index.unique(0)
return plot.plot_delta( return plot.plot_delta(
base_prevs=np.around([(1.0 - p, p) for p in prevs_on_test], decimals=2), # base_prevs=np.around([(1.0 - p, p) for p in prevs_on_test], decimals=2),
base_prevs=prevs_on_test,
columns=avg_on_test.columns.to_numpy(), columns=avg_on_test.columns.to_numpy(),
data=avg_on_test.T.to_numpy(), data=avg_on_test.T.to_numpy(),
metric=metric, metric=metric,
@ -496,13 +552,14 @@ class DatasetReport:
) )
elif mode == "stdev_test": elif mode == "stdev_test":
_data = self.data(metric, estimators) if data is None else data _data = self.data(metric, estimators) if data is None else data
avg_on_test = _data.groupby(level=0).mean() avg_on_test = _data.groupby(level=0, sort=False).mean()
if avg_on_test.empty: if avg_on_test.empty:
return None return None
prevs_on_test = np.sort(avg_on_test.index.unique(0)) prevs_on_test = avg_on_test.index.unique(0)
stdev_on_test = _data.groupby(level=0).std() stdev_on_test = _data.groupby(level=0, sort=False).std()
return plot.plot_delta( return plot.plot_delta(
base_prevs=np.around([(1.0 - p, p) for p in prevs_on_test], decimals=2), # base_prevs=np.around([(1.0 - p, p) for p in prevs_on_test], decimals=2),
base_prevs=prevs_on_test,
columns=avg_on_test.columns.to_numpy(), columns=avg_on_test.columns.to_numpy(),
data=avg_on_test.T.to_numpy(), data=avg_on_test.T.to_numpy(),
metric=metric, metric=metric,
@ -516,13 +573,14 @@ class DatasetReport:
) )
elif mode == "shift": elif mode == "shift":
_shift_data = self.shift_data(metric, estimators) if data is None else data _shift_data = self.shift_data(metric, estimators) if data is None else data
avg_shift = _shift_data.groupby(level=0).mean() avg_shift = _shift_data.groupby(level=0, sort=False).mean()
if avg_shift.empty: if avg_shift.empty:
return None return None
count_shift = _shift_data.groupby(level=0).count() count_shift = _shift_data.groupby(level=0, sort=False).count()
prevs_shift = np.sort(avg_shift.index.unique(0)) prevs_shift = avg_shift.index.unique(0)
return plot.plot_shift( return plot.plot_shift(
shift_prevs=np.around([(1.0 - p, p) for p in prevs_shift], decimals=2), # shift_prevs=np.around([(1.0 - p, p) for p in prevs_shift], decimals=2),
shift_prevs=prevs_shift,
columns=avg_shift.columns.to_numpy(), columns=avg_shift.columns.to_numpy(),
data=avg_shift.T.to_numpy(), data=avg_shift.T.to_numpy(),
metric=metric, metric=metric,
@ -551,7 +609,14 @@ class DatasetReport:
and str(round(cr.train_prev[1] * 100)) not in cr_prevs and str(round(cr.train_prev[1] * 100)) not in cr_prevs
): ):
continue continue
res += f"{cr.to_md(conf, metric=metric, estimators=estimators, modes=cr_modes, plot_path=plot_path)}\n\n" _md = cr.to_md(
conf,
metric=metric,
estimators=estimators,
modes=cr_modes,
plot_path=plot_path,
)
res += f"{_md}\n\n"
_data = self.data(metric=metric, estimators=estimators) _data = self.data(metric=metric, estimators=estimators)
_shift_data = self.shift_data(metric=metric, estimators=estimators) _shift_data = self.shift_data(metric=metric, estimators=estimators)
@ -562,7 +627,7 @@ class DatasetReport:
res += "### avg on train\n" res += "### avg on train\n"
if "train_table" in dr_modes: if "train_table" in dr_modes:
avg_on_train_tbl = _data.groupby(level=1).mean() avg_on_train_tbl = _data.groupby(level=1, sort=False).mean()
avg_on_train_tbl.loc["avg", :] = _data.mean() avg_on_train_tbl.loc["avg", :] = _data.mean()
res += avg_on_train_tbl.to_html() + "\n\n" res += avg_on_train_tbl.to_html() + "\n\n"
@ -576,7 +641,8 @@ class DatasetReport:
base_path=plot_path, base_path=plot_path,
save_fig=True, save_fig=True,
) )
res += f"![plot_delta]({delta_op.relative_to(delta_op.parents[1]).as_posix()})\n" _op = delta_op.relative_to(delta_op.parents[1]).as_posix()
res += f"![plot_delta]({_op})\n"
if "stdev_train" in dr_modes: if "stdev_train" in dr_modes:
_, delta_stdev_op = self.get_plots( _, delta_stdev_op = self.get_plots(
@ -588,13 +654,14 @@ class DatasetReport:
base_path=plot_path, base_path=plot_path,
save_fig=True, save_fig=True,
) )
res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(delta_stdev_op.parents[1]).as_posix()})\n" _op = delta_stdev_op.relative_to(delta_stdev_op.parents[1]).as_posix()
res += f"![plot_delta_stdev]({_op})\n"
######################## avg on test ######################## ######################## avg on test ########################
res += "### avg on test\n" res += "### avg on test\n"
if "test_table" in dr_modes: if "test_table" in dr_modes:
avg_on_test_tbl = _data.groupby(level=0).mean() avg_on_test_tbl = _data.groupby(level=0, sort=False).mean()
avg_on_test_tbl.loc["avg", :] = _data.mean() avg_on_test_tbl.loc["avg", :] = _data.mean()
res += avg_on_test_tbl.to_html() + "\n\n" res += avg_on_test_tbl.to_html() + "\n\n"
@ -608,7 +675,8 @@ class DatasetReport:
base_path=plot_path, base_path=plot_path,
save_fig=True, save_fig=True,
) )
res += f"![plot_delta]({delta_op.relative_to(delta_op.parents[1]).as_posix()})\n" _op = delta_op.relative_to(delta_op.parents[1]).as_posix()
res += f"![plot_delta]({_op})\n"
if "stdev_test" in dr_modes: if "stdev_test" in dr_modes:
_, delta_stdev_op = self.get_plots( _, delta_stdev_op = self.get_plots(
@ -620,13 +688,14 @@ class DatasetReport:
base_path=plot_path, base_path=plot_path,
save_fig=True, save_fig=True,
) )
res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(delta_stdev_op.parents[1]).as_posix()})\n" _op = delta_stdev_op.relative_to(delta_stdev_op.parents[1]).as_posix()
res += f"![plot_delta_stdev]({_op})\n"
######################## avg shift ######################## ######################## avg shift ########################
res += "### avg dataset shift\n" res += "### avg dataset shift\n"
if "shift_table" in dr_modes: if "shift_table" in dr_modes:
shift_on_train_tbl = _shift_data.groupby(level=0).mean() shift_on_train_tbl = _shift_data.groupby(level=0, sort=False).mean()
shift_on_train_tbl.loc["avg", :] = _shift_data.mean() shift_on_train_tbl.loc["avg", :] = _shift_data.mean()
res += shift_on_train_tbl.to_html() + "\n\n" res += shift_on_train_tbl.to_html() + "\n\n"
@ -640,7 +709,8 @@ class DatasetReport:
base_path=plot_path, base_path=plot_path,
save_fig=True, save_fig=True,
) )
res += f"![plot_shift]({shift_op.relative_to(shift_op.parents[1]).as_posix()})\n" _op = shift_op.relative_to(shift_op.parents[1]).as_posix()
res += f"![plot_shift]({_op})\n"
return res return res
@ -669,7 +739,10 @@ class DatasetReportInfo:
self.dr = dr self.dr = dr
self.name = str(path.parent) self.name = str(path.parent)
_data = dr.data() _data = dr.data()
self.columns = list(_data.columns.unique(1)) self.columns = defaultdict(list)
for metric, estim in _data.columns:
self.columns[estim].append(metric)
# self.columns = list(_data.columns.unique(1))
self.train_prevs = len(self.dr.crs) self.train_prevs = len(self.dr.crs)
self.test_prevs = len(_data.index.unique(1)) self.test_prevs = len(_data.index.unique(1))
self.repeats = len(_data.index.unique(2)) self.repeats = len(_data.index.unique(2))

View File

@ -9,7 +9,14 @@ from quapy.method.aggregative import BaseQuantifier
from sklearn.base import BaseEstimator from sklearn.base import BaseEstimator
import quacc.method.confidence as conf import quacc.method.confidence as conf
from quacc.data import ExtendedCollection, ExtendedData, ExtensionPolicy from quacc.data import (
ExtBinPrev,
ExtendedCollection,
ExtendedData,
ExtendedPrev,
ExtensionPolicy,
ExtMulPrev,
)
class BaseAccuracyEstimator(BaseQuantifier): class BaseAccuracyEstimator(BaseQuantifier):
@ -17,10 +24,11 @@ class BaseAccuracyEstimator(BaseQuantifier):
self, self,
classifier: BaseEstimator, classifier: BaseEstimator,
quantifier: BaseQuantifier, quantifier: BaseQuantifier,
dense=False,
): ):
self.__check_classifier(classifier) self.__check_classifier(classifier)
self.quantifier = quantifier self.quantifier = quantifier
self.extpol = ExtensionPolicy() self.extpol = ExtensionPolicy(dense=dense)
def __check_classifier(self, classifier): def __check_classifier(self, classifier):
if not hasattr(classifier, "predict_proba"): if not hasattr(classifier, "predict_proba"):
@ -46,9 +54,13 @@ class BaseAccuracyEstimator(BaseQuantifier):
... ...
@abstractmethod @abstractmethod
def estimate(self, instances, ext=False) -> np.ndarray: def estimate(self, instances, ext=False) -> ExtendedPrev:
... ...
@property
def dense(self):
return self.extpol.dense
class ConfidenceBasedAccuracyEstimator(BaseAccuracyEstimator): class ConfidenceBasedAccuracyEstimator(BaseAccuracyEstimator):
def __init__( def __init__(
@ -98,10 +110,21 @@ class ConfidenceBasedAccuracyEstimator(BaseAccuracyEstimator):
return np.concatenate([_conf_ext, _pred_ext], axis=1) return np.concatenate([_conf_ext, _pred_ext], axis=1)
def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection: def extend(
self, coll: LabelledCollection, pred_proba=None, prefit=False
) -> ExtendedCollection:
if pred_proba is None: if pred_proba is None:
pred_proba = self.classifier.predict_proba(coll.X) pred_proba = self.classifier.predict_proba(coll.X)
if prefit:
self._fit_confidence(coll.X, coll.y, pred_proba)
else:
if not hasattr(self, "confidence_metrics"):
raise AttributeError(
"Confidence metrics are not fit and cannot be computed."
"Consider setting prefit to True."
)
_ext = self.__get_ext(coll.X, pred_proba) _ext = self.__get_ext(coll.X, pred_proba)
return ExtendedCollection.from_lc( return ExtendedCollection.from_lc(
coll, pred_proba=pred_proba, ext=_ext, extpol=self.extpol coll, pred_proba=pred_proba, ext=_ext, extpol=self.extpol
@ -125,17 +148,23 @@ class MultiClassAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
quantifier: BaseQuantifier, quantifier: BaseQuantifier,
confidence: str = None, confidence: str = None,
collapse_false=False, collapse_false=False,
group_false=False,
dense=False,
): ):
super().__init__( super().__init__(
classifier=classifier, classifier=classifier,
quantifier=quantifier, quantifier=quantifier,
confidence=confidence, confidence=confidence,
) )
self.extpol = ExtensionPolicy(collapse_false=collapse_false) self.extpol = ExtensionPolicy(
collapse_false=collapse_false,
group_false=group_false,
dense=dense,
)
self.e_train = None self.e_train = None
def _get_pred_ext(self, pred_proba: np.ndarray): # def _get_pred_ext(self, pred_proba: np.ndarray):
return np.argmax(pred_proba, axis=1, keepdims=True) # return np.argmax(pred_proba, axis=1, keepdims=True)
def fit(self, train: LabelledCollection): def fit(self, train: LabelledCollection):
pred_proba = self.classifier.predict_proba(train.X) pred_proba = self.classifier.predict_proba(train.X)
@ -148,31 +177,27 @@ class MultiClassAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
def estimate( def estimate(
self, instances: ExtendedData | np.ndarray | sp.csr_matrix self, instances: ExtendedData | np.ndarray | sp.csr_matrix
) -> np.ndarray: ) -> ExtendedPrev:
e_inst = instances e_inst = instances
if not isinstance(e_inst, ExtendedData): if not isinstance(e_inst, ExtendedData):
e_inst = self._extend_instances(instances) e_inst = self._extend_instances(instances)
estim_prev = self.quantifier.quantify(e_inst.X) estim_prev = self.quantifier.quantify(e_inst.X)
estim_prev = self._check_prevalence_classes( return ExtMulPrev(
estim_prev, self.quantifier.classes_ estim_prev,
e_inst.nbcl,
q_classes=self.quantifier.classes_,
extpol=self.extpol,
) )
if self.extpol.collapse_false:
estim_prev = np.insert(estim_prev, 2, 0.0)
return estim_prev
def _check_prevalence_classes(self, estim_prev, estim_classes) -> np.ndarray:
true_classes = self.e_train.classes_
for _cls in true_classes:
if _cls not in estim_classes:
estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0)
return estim_prev
@property @property
def collapse_false(self): def collapse_false(self):
return self.extpol.collapse_false return self.extpol.collapse_false
@property
def group_false(self):
return self.extpol.group_false
class BinaryQuantifierAccuracyEstimator(ConfidenceBasedAccuracyEstimator): class BinaryQuantifierAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
def __init__( def __init__(
@ -180,6 +205,8 @@ class BinaryQuantifierAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
classifier: BaseEstimator, classifier: BaseEstimator,
quantifier: BaseAccuracyEstimator, quantifier: BaseAccuracyEstimator,
confidence: str = None, confidence: str = None,
group_false: bool = False,
dense: bool = False,
): ):
super().__init__( super().__init__(
classifier=classifier, classifier=classifier,
@ -187,6 +214,10 @@ class BinaryQuantifierAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
confidence=confidence, confidence=confidence,
) )
self.quantifiers = [] self.quantifiers = []
self.extpol = ExtensionPolicy(
group_false=group_false,
dense=dense,
)
def fit(self, train: LabelledCollection | ExtendedCollection): def fit(self, train: LabelledCollection | ExtendedCollection):
pred_proba = self.classifier.predict_proba(train.X) pred_proba = self.classifier.predict_proba(train.X)
@ -215,9 +246,14 @@ class BinaryQuantifierAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
norms = [s_i.shape[0] / len(e_inst) for s_i in s_inst] norms = [s_i.shape[0] / len(e_inst) for s_i in s_inst]
estim_prevs = self._quantify_helper(s_inst, norms) estim_prevs = self._quantify_helper(s_inst, norms)
# estim_prev = np.array([prev_row for prev_row in zip(*estim_prevs)]).flatten() # estim_prev = np.concatenate(estim_prevs.T)
estim_prev = np.concatenate(estim_prevs.T) # return ExtendedPrev(estim_prev, e_inst.nbcl, extpol=self.extpol)
return estim_prev return ExtBinPrev(
estim_prevs,
e_inst.nbcl,
q_classes=[quant.classes_ for quant in self.quantifiers],
extpol=self.extpol,
)
def _quantify_helper( def _quantify_helper(
self, self,
@ -229,9 +265,14 @@ class BinaryQuantifierAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
if inst.shape[0] > 0: if inst.shape[0] > 0:
estim_prevs.append(quant.quantify(inst) * norm) estim_prevs.append(quant.quantify(inst) * norm)
else: else:
estim_prevs.append(np.asarray([0.0, 0.0])) estim_prevs.append(np.zeros((len(quant.classes_),)))
return np.array(estim_prevs) # return np.array(estim_prevs)
return estim_prevs
@property
def group_false(self):
return self.extpol.group_false
BAE = BaseAccuracyEstimator BAE = BaseAccuracyEstimator

View File

@ -16,7 +16,7 @@ from quapy.protocol import (
import quacc as qc import quacc as qc
import quacc.error import quacc.error
from quacc.data import ExtendedCollection from quacc.data import ExtendedCollection
from quacc.evaluation import evaluate from quacc.evaluation.evaluate import evaluate
from quacc.logger import logger from quacc.logger import logger
from quacc.method.base import ( from quacc.method.base import (
BaseAccuracyEstimator, BaseAccuracyEstimator,
@ -194,14 +194,16 @@ class GridSearchAE(BaseAccuracyEstimator):
f"\tException: {e}", f"\tException: {e}",
level=1, level=1,
) )
raise e # raise e
score = None score = None
return params, score, model return params, score, model
def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection: def extend(
self, coll: LabelledCollection, pred_proba=None, prefit=False
) -> ExtendedCollection:
assert hasattr(self, "best_model_"), "quantify called before fit" assert hasattr(self, "best_model_"), "quantify called before fit"
return self.best_model().extend(coll, pred_proba=pred_proba) return self.best_model().extend(coll, pred_proba=pred_proba, prefit=prefit)
def estimate(self, instances): def estimate(self, instances):
"""Estimate class prevalence values using the best model found after calling the :meth:`fit` method. """Estimate class prevalence values using the best model found after calling the :meth:`fit` method.
@ -392,6 +394,19 @@ class SpiderSearchAE(GridSearchAE):
[(params, training) for params in _hyper], [(params, training) for params in _hyper],
parallel=parallel, parallel=parallel,
) )
# if all scores are None, select a new random batch
if all([s[1] is None for s in _iter_scores]):
rand_index = np.arange(len(_hyper_remaining))
np.random.shuffle(rand_index)
rand_index = rand_index[:batch_size]
remaining_index = np.setdiff1d(
np.arange(len(_hyper_remaining)), rand_index
)
_hyper = _hyper_remaining[rand_index]
_hyper_remaining = _hyper_remaining[remaining_index]
continue
_sorted_idx = np.argsort( _sorted_idx = np.argsort(
[1.0 if s is None else s for _, s, _ in _iter_scores] [1.0 if s is None else s for _, s, _ in _iter_scores]
) )

View File

@ -1,9 +1,11 @@
from pathlib import Path from pathlib import Path
from re import X
import matplotlib import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from cycler import cycler from cycler import cycler
from sklearn import base
from quacc import utils from quacc import utils
from quacc.plot.base import BasePlot from quacc.plot.base import BasePlot
@ -48,10 +50,15 @@ class MplPlot(BasePlot):
cm = plt.get_cmap("tab20") cm = plt.get_cmap("tab20")
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)]) cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
base_prevs = base_prevs[:, pos_class] # base_prevs = base_prevs[:, pos_class]
if isinstance(base_prevs[0], float):
base_prevs = np.around([(1 - bp, bp) for bp in base_prevs], decimals=4)
str_base_prevs = [str(tuple(bp)) for bp in base_prevs]
# xticks = [str(bp) for bp in base_prevs]
xticks = np.arange(len(base_prevs))
for method, deltas, _cy in zip(columns, data, cy): for method, deltas, _cy in zip(columns, data, cy):
ax.plot( ax.plot(
base_prevs, xticks,
deltas, deltas,
label=method, label=method,
color=_cy["color"], color=_cy["color"],
@ -67,7 +74,7 @@ class MplPlot(BasePlot):
np.where(deltas != np.nan)[0], np.where(deltas != np.nan)[0],
np.where(stdev != np.nan)[0], np.where(stdev != np.nan)[0],
) )
_bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx] _bps, _ds, _st = xticks[nn_idx], deltas[nn_idx], stdev[nn_idx]
ax.fill_between( ax.fill_between(
_bps, _bps,
_ds - _st, _ds - _st,
@ -76,6 +83,15 @@ class MplPlot(BasePlot):
alpha=0.25, alpha=0.25,
) )
def format_fn(tick_val, tick_pos):
if int(tick_val) in xticks:
return str_base_prevs[int(tick_val)]
return ""
ax.xaxis.set_major_locator(plt.MaxNLocator(nbins=6, integer=True, prune="both"))
ax.xaxis.set_major_formatter(format_fn)
ax.set( ax.set(
xlabel=f"{x_label} prevalence", xlabel=f"{x_label} prevalence",
ylabel=y_label, ylabel=y_label,
@ -187,7 +203,7 @@ class MplPlot(BasePlot):
cm = plt.get_cmap("tab20") cm = plt.get_cmap("tab20")
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)]) cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
shift_prevs = shift_prevs[:, pos_class] # shift_prevs = shift_prevs[:, pos_class]
for method, shifts, _cy in zip(columns, data, cy): for method, shifts, _cy in zip(columns, data, cy):
ax.plot( ax.plot(
shift_prevs, shift_prevs,

View File

@ -69,7 +69,9 @@ class PlotlyPlot(BasePlot):
legend=True, legend=True,
) -> go.Figure: ) -> go.Figure:
fig = go.Figure() fig = go.Figure()
x = base_prevs[:, pos_class] if isinstance(base_prevs[0], float):
base_prevs = np.around([(1 - bp, bp) for bp in base_prevs], decimals=4)
x = [str(tuple(bp)) for bp in base_prevs]
line_colors = self.get_colors(len(columns)) line_colors = self.get_colors(len(columns))
for name, delta in zip(columns, data): for name, delta in zip(columns, data):
color = next(line_colors) color = next(line_colors)
@ -177,7 +179,8 @@ class PlotlyPlot(BasePlot):
legend=True, legend=True,
) -> go.Figure: ) -> go.Figure:
fig = go.Figure() fig = go.Figure()
x = shift_prevs[:, pos_class] # x = shift_prevs[:, pos_class]
x = shift_prevs
line_colors = self.get_colors(len(columns)) line_colors = self.get_colors(len(columns))
for name, delta in zip(columns, data): for name, delta in zip(columns, data):
col_idx = (columns == name).nonzero()[0][0] col_idx = (columns == name).nonzero()[0][0]