QuAcc/quacc/evaluation/estimators.py

81 lines
2.2 KiB
Python

from typing import List
import numpy as np
from quacc.evaluation import baseline, method
class CompEstimatorFunc_:
def __init__(self, ce):
self.ce = ce
def __getitem__(self, e: str | List[str]):
if isinstance(e, str):
return list(self.ce._CompEstimator__get(e).values())[0]
elif isinstance(e, list):
return list(self.ce._CompEstimator__get(e).values())
class CompEstimatorName_:
def __init__(self, ce):
self.ce = ce
def __getitem__(self, e: str | List[str]):
if isinstance(e, str):
return list(self.ce._CompEstimator__get(e).keys())[0]
elif isinstance(e, list):
return list(self.ce._CompEstimator__get(e).keys())
def sort(self, e: List[str]):
return list(self.ce._CompEstimator__get(e, get_ref=False).keys())
@property
def all(self):
return list(self.ce._CompEstimator__get("__all").keys())
@property
def baselines(self):
return list(self.ce._CompEstimator__get("__baselines").keys())
class CompEstimator:
def __get(cls, e: str | List[str], get_ref=True):
_dict = method._methods | baseline._baselines
match e:
case "__all":
e = list(_dict.keys())
case "__baselines":
e = list(baseline._baselines.keys())
if isinstance(e, str):
try:
return {e: _dict[e]}
except KeyError:
raise KeyError(f"Invalid estimator: estimator {e} does not exist")
elif isinstance(e, list):
_subtr = np.setdiff1d(e, list(_dict.keys()))
if len(_subtr) > 0:
raise KeyError(
f"Invalid estimator: estimator {_subtr[0]} does not exist"
)
e_fun = {k: fun for k, fun in _dict.items() if k in e}
if get_ref and "ref" not in e:
e_fun["ref"] = _dict["ref"]
elif not get_ref and "ref" in e:
del e_fun["ref"]
return e_fun
@property
def name(self):
return CompEstimatorName_(self)
@property
def func(self):
return CompEstimatorFunc_(self)
# Module-level CompEstimator instance created at import time.
CE = CompEstimator()