CompEstimator refatored
This commit is contained in:
parent
d7cbde7522
commit
cd8dc4d42e
|
@ -2,15 +2,13 @@ import multiprocessing
|
|||
import os
|
||||
import time
|
||||
from traceback import print_exception as traceback
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import quapy as qp
|
||||
|
||||
from quacc.dataset import Dataset
|
||||
from quacc.environment import env
|
||||
from quacc.evaluation import baseline, method
|
||||
from quacc.evaluation.estimators import CE
|
||||
from quacc.evaluation.report import CompReport, DatasetReport
|
||||
from quacc.evaluation.worker import WorkerArgs, estimate_worker
|
||||
from quacc.logger import Logger
|
||||
|
@ -19,67 +17,6 @@ pd.set_option("display.float_format", "{:.4f}".format)
|
|||
qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE
|
||||
|
||||
|
||||
class CompEstimatorName_:
|
||||
def __init__(self, ce):
|
||||
self.ce = ce
|
||||
|
||||
def __getitem__(self, e: str | List[str]):
|
||||
if isinstance(e, str):
|
||||
return self.ce._CompEstimator__get(e)[0]
|
||||
elif isinstance(e, list):
|
||||
return list(self.ce._CompEstimator__get(e).keys())
|
||||
|
||||
@property
|
||||
def all(self):
|
||||
all_keys = list(CompEstimator._CompEstimator__dict.keys())
|
||||
return self[all_keys]
|
||||
|
||||
|
||||
class CompEstimatorFunc_:
|
||||
def __init__(self, ce):
|
||||
self.ce = ce
|
||||
|
||||
def __getitem__(self, e: str | List[str]):
|
||||
if isinstance(e, str):
|
||||
return self.ce._CompEstimator__get(e)[1]
|
||||
elif isinstance(e, list):
|
||||
return list(self.ce._CompEstimator__get(e).values())
|
||||
|
||||
|
||||
class CompEstimator:
|
||||
__dict = method._methods | baseline._baselines
|
||||
|
||||
def __get(cls, e: str | List[str]):
|
||||
if isinstance(e, str):
|
||||
try:
|
||||
return (e, cls.__dict[e])
|
||||
except KeyError:
|
||||
raise KeyError(f"Invalid estimator: estimator {e} does not exist")
|
||||
elif isinstance(e, list):
|
||||
_subtr = np.setdiff1d(e, list(cls.__dict.keys()))
|
||||
if len(_subtr) > 0:
|
||||
raise KeyError(
|
||||
f"Invalid estimator: estimator {_subtr[0]} does not exist"
|
||||
)
|
||||
|
||||
e_fun = {k: fun for k, fun in cls.__dict.items() if k in e}
|
||||
if "ref" not in e:
|
||||
e_fun["ref"] = cls.__dict["ref"]
|
||||
|
||||
return e_fun
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return CompEstimatorName_(self)
|
||||
|
||||
@property
|
||||
def func(self):
|
||||
return CompEstimatorFunc_(self)
|
||||
|
||||
|
||||
CE = CompEstimator()
|
||||
|
||||
|
||||
def evaluate_comparison(dataset: Dataset, estimators=None) -> DatasetReport:
|
||||
log = Logger.logger()
|
||||
# with multiprocessing.Pool(1) as pool:
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from quacc.evaluation import baseline, method
|
||||
|
||||
|
||||
class CompEstimatorFunc_:
|
||||
def __init__(self, ce):
|
||||
self.ce = ce
|
||||
|
||||
def __getitem__(self, e: str | List[str]):
|
||||
if isinstance(e, str):
|
||||
return list(self.ce._CompEstimator__get(e).values())[0]
|
||||
elif isinstance(e, list):
|
||||
return list(self.ce._CompEstimator__get(e).values())
|
||||
|
||||
|
||||
class CompEstimatorName_:
|
||||
def __init__(self, ce):
|
||||
self.ce = ce
|
||||
|
||||
def __getitem__(self, e: str | List[str]):
|
||||
if isinstance(e, str):
|
||||
return list(self.ce._CompEstimator__get(e).keys())[0]
|
||||
elif isinstance(e, list):
|
||||
return list(self.ce._CompEstimator__get(e).keys())
|
||||
|
||||
@property
|
||||
def all(self):
|
||||
return list(self.ce._CompEstimator__get("__all").keys())
|
||||
|
||||
@property
|
||||
def baselines(self):
|
||||
return list(self.ce._CompEstimator__get("__baselines").keys())
|
||||
|
||||
|
||||
class CompEstimator:
|
||||
def __get(cls, e: str | List[str]):
|
||||
_dict = method._methods | baseline._baselines
|
||||
|
||||
match e:
|
||||
case "__all":
|
||||
e = list(_dict.keys())
|
||||
case "__baselines":
|
||||
e = list(baseline._baselines.keys())
|
||||
|
||||
if isinstance(e, str):
|
||||
try:
|
||||
return {e: _dict[e]}
|
||||
except KeyError:
|
||||
raise KeyError(f"Invalid estimator: estimator {e} does not exist")
|
||||
elif isinstance(e, list):
|
||||
_subtr = np.setdiff1d(e, list(_dict.keys()))
|
||||
if len(_subtr) > 0:
|
||||
raise KeyError(
|
||||
f"Invalid estimator: estimator {_subtr[0]} does not exist"
|
||||
)
|
||||
|
||||
e_fun = {k: fun for k, fun in _dict.items() if k in e}
|
||||
if "ref" not in e:
|
||||
e_fun["ref"] = _dict["ref"]
|
||||
|
||||
return e_fun
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return CompEstimatorName_(self)
|
||||
|
||||
@property
|
||||
def func(self):
|
||||
return CompEstimatorFunc_(self)
|
||||
|
||||
|
||||
CE = CompEstimator()
|
Loading…
Reference in New Issue