From cd8dc4d42ec41a55caffdba9804156cb47c5f45c Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Mon, 27 Nov 2023 03:27:08 +0100 Subject: [PATCH] CompEstimator refatored --- quacc/evaluation/comp.py | 65 +---------------------------- quacc/evaluation/estimators.py | 75 ++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 64 deletions(-) create mode 100644 quacc/evaluation/estimators.py diff --git a/quacc/evaluation/comp.py b/quacc/evaluation/comp.py index 0cd71d0..0ea0eef 100644 --- a/quacc/evaluation/comp.py +++ b/quacc/evaluation/comp.py @@ -2,15 +2,13 @@ import multiprocessing import os import time from traceback import print_exception as traceback -from typing import List -import numpy as np import pandas as pd import quapy as qp from quacc.dataset import Dataset from quacc.environment import env -from quacc.evaluation import baseline, method +from quacc.evaluation.estimators import CE from quacc.evaluation.report import CompReport, DatasetReport from quacc.evaluation.worker import WorkerArgs, estimate_worker from quacc.logger import Logger @@ -19,67 +17,6 @@ pd.set_option("display.float_format", "{:.4f}".format) qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE -class CompEstimatorName_: - def __init__(self, ce): - self.ce = ce - - def __getitem__(self, e: str | List[str]): - if isinstance(e, str): - return self.ce._CompEstimator__get(e)[0] - elif isinstance(e, list): - return list(self.ce._CompEstimator__get(e).keys()) - - @property - def all(self): - all_keys = list(CompEstimator._CompEstimator__dict.keys()) - return self[all_keys] - - -class CompEstimatorFunc_: - def __init__(self, ce): - self.ce = ce - - def __getitem__(self, e: str | List[str]): - if isinstance(e, str): - return self.ce._CompEstimator__get(e)[1] - elif isinstance(e, list): - return list(self.ce._CompEstimator__get(e).values()) - - -class CompEstimator: - __dict = method._methods | baseline._baselines - - def __get(cls, e: str | List[str]): - if isinstance(e, str): - try: - return (e, cls.__dict[e]) - except KeyError: - raise KeyError(f"Invalid estimator: estimator {e} does not exist") - elif isinstance(e, list): - _subtr = np.setdiff1d(e, list(cls.__dict.keys())) - if len(_subtr) > 0: - raise KeyError( - f"Invalid estimator: estimator {_subtr[0]} does not exist" - ) - - e_fun = {k: fun for k, fun in cls.__dict.items() if k in e} - if "ref" not in e: - e_fun["ref"] = cls.__dict["ref"] - - return e_fun - - @property - def name(self): - return CompEstimatorName_(self) - - @property - def func(self): - return CompEstimatorFunc_(self) - - -CE = CompEstimator() - - def evaluate_comparison(dataset: Dataset, estimators=None) -> DatasetReport: log = Logger.logger() # with multiprocessing.Pool(1) as pool: diff --git a/quacc/evaluation/estimators.py b/quacc/evaluation/estimators.py new file mode 100644 index 0000000..d1d6f13 --- /dev/null +++ b/quacc/evaluation/estimators.py @@ -0,0 +1,75 @@ +from typing import List + +import numpy as np + +from quacc.evaluation import baseline, method + + +class CompEstimatorFunc_: + def __init__(self, ce): + self.ce = ce + + def __getitem__(self, e: str | List[str]): + if isinstance(e, str): + return list(self.ce._CompEstimator__get(e).values())[0] + elif isinstance(e, list): + return list(self.ce._CompEstimator__get(e).values()) + + +class CompEstimatorName_: + def __init__(self, ce): + self.ce = ce + + def __getitem__(self, e: str | List[str]): + if isinstance(e, str): + return list(self.ce._CompEstimator__get(e).keys())[0] + elif isinstance(e, list): + return list(self.ce._CompEstimator__get(e).keys()) + + @property + def all(self): + return list(self.ce._CompEstimator__get("__all").keys()) + + @property + def baselines(self): + return list(self.ce._CompEstimator__get("__baselines").keys()) + + +class CompEstimator: + def __get(cls, e: str | List[str]): + _dict = method._methods | baseline._baselines + + match e: + case "__all": + e = list(_dict.keys()) + case "__baselines": + e = list(baseline._baselines.keys()) + + if isinstance(e, str): + try: + return {e: _dict[e]} + except KeyError: + raise KeyError(f"Invalid estimator: estimator {e} does not exist") + elif isinstance(e, list): + _subtr = np.setdiff1d(e, list(_dict.keys())) + if len(_subtr) > 0: + raise KeyError( + f"Invalid estimator: estimator {_subtr[0]} does not exist" + ) + + e_fun = {k: fun for k, fun in _dict.items() if k in e} + if "ref" not in e: + e_fun["ref"] = _dict["ref"] + + return e_fun + + @property + def name(self): + return CompEstimatorName_(self) + + @property + def func(self): + return CompEstimatorFunc_(self) + + +CE = CompEstimator()