CompEstimator refatored

This commit is contained in:
Lorenzo Volpi 2023-11-27 03:27:08 +01:00
parent d7cbde7522
commit cd8dc4d42e
2 changed files with 76 additions and 64 deletions

View File

@ -2,15 +2,13 @@ import multiprocessing
import os
import time
from traceback import print_exception as traceback
from typing import List
import numpy as np
import pandas as pd
import quapy as qp
from quacc.dataset import Dataset
from quacc.environment import env
from quacc.evaluation import baseline, method
from quacc.evaluation.estimators import CE
from quacc.evaluation.report import CompReport, DatasetReport
from quacc.evaluation.worker import WorkerArgs, estimate_worker
from quacc.logger import Logger
@ -19,67 +17,6 @@ pd.set_option("display.float_format", "{:.4f}".format)
qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE
class CompEstimatorName_:
def __init__(self, ce):
self.ce = ce
def __getitem__(self, e: str | List[str]):
if isinstance(e, str):
return self.ce._CompEstimator__get(e)[0]
elif isinstance(e, list):
return list(self.ce._CompEstimator__get(e).keys())
@property
def all(self):
all_keys = list(CompEstimator._CompEstimator__dict.keys())
return self[all_keys]
class CompEstimatorFunc_:
def __init__(self, ce):
self.ce = ce
def __getitem__(self, e: str | List[str]):
if isinstance(e, str):
return self.ce._CompEstimator__get(e)[1]
elif isinstance(e, list):
return list(self.ce._CompEstimator__get(e).values())
class CompEstimator:
__dict = method._methods | baseline._baselines
def __get(cls, e: str | List[str]):
if isinstance(e, str):
try:
return (e, cls.__dict[e])
except KeyError:
raise KeyError(f"Invalid estimator: estimator {e} does not exist")
elif isinstance(e, list):
_subtr = np.setdiff1d(e, list(cls.__dict.keys()))
if len(_subtr) > 0:
raise KeyError(
f"Invalid estimator: estimator {_subtr[0]} does not exist"
)
e_fun = {k: fun for k, fun in cls.__dict.items() if k in e}
if "ref" not in e:
e_fun["ref"] = cls.__dict["ref"]
return e_fun
@property
def name(self):
return CompEstimatorName_(self)
@property
def func(self):
return CompEstimatorFunc_(self)
CE = CompEstimator()
def evaluate_comparison(dataset: Dataset, estimators=None) -> DatasetReport:
log = Logger.logger()
# with multiprocessing.Pool(1) as pool:

View File

@ -0,0 +1,75 @@
from typing import List
import numpy as np
from quacc.evaluation import baseline, method
class CompEstimatorFunc_:
def __init__(self, ce):
self.ce = ce
def __getitem__(self, e: str | List[str]):
if isinstance(e, str):
return list(self.ce._CompEstimator__get(e).values())[0]
elif isinstance(e, list):
return list(self.ce._CompEstimator__get(e).values())
class CompEstimatorName_:
def __init__(self, ce):
self.ce = ce
def __getitem__(self, e: str | List[str]):
if isinstance(e, str):
return list(self.ce._CompEstimator__get(e).keys())[0]
elif isinstance(e, list):
return list(self.ce._CompEstimator__get(e).keys())
@property
def all(self):
return list(self.ce._CompEstimator__get("__all").keys())
@property
def baselines(self):
return list(self.ce._CompEstimator__get("__baselines").keys())
class CompEstimator:
def __get(cls, e: str | List[str]):
_dict = method._methods | baseline._baselines
match e:
case "__all":
e = list(_dict.keys())
case "__baselines":
e = list(baseline._baselines.keys())
if isinstance(e, str):
try:
return {e: _dict[e]}
except KeyError:
raise KeyError(f"Invalid estimator: estimator {e} does not exist")
elif isinstance(e, list):
_subtr = np.setdiff1d(e, list(_dict.keys()))
if len(_subtr) > 0:
raise KeyError(
f"Invalid estimator: estimator {_subtr[0]} does not exist"
)
e_fun = {k: fun for k, fun in _dict.items() if k in e}
if "ref" not in e:
e_fun["ref"] = _dict["ref"]
return e_fun
@property
def name(self):
return CompEstimatorName_(self)
@property
def func(self):
return CompEstimatorFunc_(self)
CE = CompEstimator()