import multiprocessing
import time
from traceback import print_exception as traceback
from typing import List

import pandas as pd
import quapy as qp

from quacc.dataset import Dataset
from quacc.environment import env
from quacc.evaluation import baseline, method
from quacc.evaluation.report import CompReport, DatasetReport, EvaluationReport
from quacc.evaluation.worker import estimate_worker
from quacc.logger import Logger

pd.set_option("display.float_format", "{:.4f}".format)
qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE


class CompEstimator:
    """Name-based lookup of estimator callables via subscription.

    ``CompEstimator["name"]`` returns the single callable registered under
    that name; ``CompEstimator[["a", "b"]]`` returns a list of callables.
    The registry is the union of ``method._methods`` and
    ``baseline._baselines`` (baselines win on a name clash, because they are
    the right-hand operand of the dict merge).
    """

    __dict = method._methods | baseline._baselines

    def __class_getitem__(cls, e: str | List[str]):
        """Resolve one estimator name, or a sequence of names, to callables.

        Args:
            e: A single estimator name, or a list/tuple of names.

        Returns:
            The callable for a single name. For a sequence, a list of
            callables in *registry* order (not request order) with
            duplicates collapsed.

        Raises:
            KeyError: If a requested name is not registered; the message
                reports the first unknown name.
            TypeError: If ``e`` is neither a string nor a list/tuple.
        """
        if isinstance(e, str):
            try:
                return cls.__dict[e]
            except KeyError:
                # `from None` hides the redundant inner KeyError from the
                # traceback; the exception type seen by callers is unchanged.
                raise KeyError(
                    f"Invalid estimator: estimator {e} does not exist"
                ) from None

        if isinstance(e, (list, tuple)):
            unknown = [k for k in e if k not in cls.__dict]
            if unknown:
                raise KeyError(
                    f"Invalid estimator: estimator {unknown[0]} does not exist"
                )
            requested = set(e)
            return [fun for k, fun in cls.__dict.items() if k in requested]

        # Previously this fell through and returned None, which only failed
        # later (and obscurely) when the caller tried to use the result.
        raise TypeError(
            "Invalid estimator key: expected str or list of str, "
            f"got {type(e).__name__}"
        )


CE = CompEstimator


def evaluate_comparison(
    dataset: Dataset, estimators: List[str] | None = None
) -> EvaluationReport:
    """Run the selected estimators on every sample of ``dataset`` in parallel.

    For each sample yielded by ``dataset()``, one ``estimate_worker`` task per
    estimator is submitted to a process pool. Worker results whose
    ``"result"`` entry is ``None`` and workers that raise are skipped (a
    warning is logged for the latter). The surviving results are wrapped in a
    ``CompReport`` that is accumulated into a ``DatasetReport``.

    Args:
        dataset: Dataset to evaluate; calling it yields samples exposing
            ``train``, ``validation``, ``test``, ``train_prev`` and
            ``validation_prev``.
        estimators: Names of the estimators to run. Defaults to
            ``["OUR_BIN_SLD", "OUR_MUL_SLD"]``. A single name given as a
            string is treated as a one-element list.

    Returns:
        The ``DatasetReport`` built for ``dataset.name``.
        NOTE(review): the signature is annotated ``EvaluationReport``; confirm
        how the two report types relate.

    Raises:
        KeyError: If any estimator name is unknown (raised before any worker
            process is started).
    """
    log = Logger.logger()

    if estimators is None:
        estimators = ["OUR_BIN_SLD", "OUR_MUL_SLD"]
    elif isinstance(estimators, str):
        estimators = [estimators]

    # Resolve names once: the lookup does not depend on the dataset sample,
    # and failing here avoids spawning a pool for an invalid request.
    estimator_funcs = CE[list(estimators)]

    # Pool() rejects a process count of 0, so keep at least one worker even
    # when no estimators were requested.
    with multiprocessing.Pool(max(1, len(estimator_funcs))) as pool:
        dr = DatasetReport(dataset.name)
        log.info(f"dataset {dataset.name}")
        for d in dataset():
            # Shared prefix of every per-sample log message.
            sample_desc = (
                f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name}"
            )
            log.info(f"{sample_desc} started")
            tstart = time.time()

            tasks = [
                (estim, d.train, d.validation, d.test) for estim in estimator_funcs
            ]
            pending = [
                pool.apply_async(
                    estimate_worker, t, {"_env": env, "q": Logger.queue()}
                )
                for t in tasks
            ]

            results_got = []
            for async_result in pending:
                try:
                    r = async_result.get()
                    # Workers report an unusable run with result=None.
                    if r["result"] is not None:
                        results_got.append(r)
                except Exception as e:
                    # Best effort: one failing estimator must not abort the
                    # remaining ones for this sample.
                    log.warning(f"{sample_desc} failed. Exception: {e}")

            tend = time.time()
            times = {r["name"]: r["time"] for r in results_got}
            times["tot"] = tend - tstart
            log.info(f"{sample_desc} finished [took {times['tot']:.4f}s]")

            try:
                cr = CompReport(
                    [r["result"] for r in results_got],
                    name=dataset.name,
                    train_prev=d.train_prev,
                    valid_prev=d.validation_prev,
                    times=times,
                )
            except Exception as e:
                log.warning(f"{sample_desc} failed. Exception: {e}")
                traceback(e)
                cr = None

            # NOTE(review): cr is None when report construction failed;
            # this relies on DatasetReport's += accepting None — confirm.
            dr += cr
    return dr