diff --git a/quacc/evaluation/comp.py b/quacc/evaluation/comp.py
index 344b9fd..0cd71d0 100644
--- a/quacc/evaluation/comp.py
+++ b/quacc/evaluation/comp.py
@@ -12,7 +12,7 @@ from quacc.dataset import Dataset
 from quacc.environment import env
 from quacc.evaluation import baseline, method
 from quacc.evaluation.report import CompReport, DatasetReport
-from quacc.evaluation.worker import estimate_worker
+from quacc.evaluation.worker import WorkerArgs, estimate_worker
 from quacc.logger import Logger
 
 pd.set_option("display.float_format", "{:.4f}".format)
@@ -91,45 +91,42 @@ def evaluate_comparison(dataset: Dataset, estimators=None) -> DatasetReport:
         log.info(
             f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} started"
         )
-        tstart = time.time()
         tasks = [
-            (estim, d.train, d.validation, d.test) for estim in CE.func[estimators]
+            WorkerArgs(
+                _estimate=estim,
+                train=d.train,
+                validation=d.validation,
+                test=d.test,
+                _env=env,
+                q=Logger.queue(),
+            )
+            for estim in CE.func[estimators]
         ]
-        results = [
-            pool.apply_async(estimate_worker, t, {"_env": env, "q": Logger.queue()})
-            for t in tasks
-        ]
-
-        results_got = []
-        for _r in results:
-            try:
-                r = _r.get()
-                if r["result"] is not None:
-                    results_got.append(r)
-            except Exception as e:
-                log.warning(
-                    f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} failed. Exception: {e}"
-                )
-
-        tend = time.time()
-        times = {r["name"]: r["time"] for r in results_got}
-        times["tot"] = tend - tstart
-        log.info(
-            f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} finished [took {times['tot']:.4f}s]"
-        )
         try:
+            tstart = time.time()
+            results = [
+                r for r in pool.imap(estimate_worker, tasks) if r is not None
+            ]
+
+            g_time = time.time() - tstart
+            log.info(
+                f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} finished "
+                f"[took {g_time:.4f}s]"
+            )
+
             cr = CompReport(
-                [r["result"] for r in results_got],
+                results,
                 name=dataset.name,
                 train_prev=d.train_prev,
                 valid_prev=d.validation_prev,
-                times=times,
+                g_time=g_time,
             )
+            dr += cr
+
         except Exception as e:
             log.warning(
-                f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} failed. Exception: {e}"
+                f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} failed. "
+                f"Exception: {e}"
             )
             traceback(e)
-            cr = None
-        dr += cr
     return dr
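The comp.py side of the change swaps per-task `pool.apply_async` calls, each returning a future whose `.get()` had to be wrapped in its own try/except, for a single `pool.imap` over picklable `WorkerArgs` objects, with failed tasks signalled by `None`. A minimal self-contained sketch of that idiom, using illustrative names (`Args`, `work`) rather than quacc's own:

```python
# Standalone sketch (illustrative names, not quacc code) of the idiom the
# diff adopts: one pool.imap over picklable dataclass args, with failures
# reported as None and filtered out by the parent.
import time
from dataclasses import dataclass
from multiprocessing import Pool


@dataclass(frozen=True)
class Args:
    num: int
    den: int


def work(args: Args):
    # Mirror the worker contract in the diff: return None on failure so the
    # parent can drop failed tasks with a single comprehension instead of
    # wrapping each future's .get() in its own try/except.
    try:
        return args.num / args.den
    except ZeroDivisionError:
        return None


if __name__ == "__main__":
    tasks = [Args(num=10, den=d) for d in range(4)]  # den=0 will fail
    with Pool(2) as pool:
        tstart = time.time()
        results = [r for r in pool.imap(work, tasks) if r is not None]
    print(results, f"took {time.time() - tstart:.4f}s")
```

Because the parent now only sees the yielded results, per-method timings can no longer be collected here, which is why the diff moves timing into the worker (`result.time`) and keeps only the overall `g_time` in `evaluate_comparison`.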
" + f"Exception: {e}" ) traceback(e) - cr = None - dr += cr return dr diff --git a/quacc/evaluation/worker.py b/quacc/evaluation/worker.py index e1d02b0..681c597 100644 --- a/quacc/evaluation/worker.py +++ b/quacc/evaluation/worker.py @@ -1,44 +1,52 @@ import time +from dataclasses import dataclass +from multiprocessing import Queue from traceback import print_exception as traceback import quapy as qp +from quapy.data import LabelledCollection from quapy.protocol import APP from sklearn.linear_model import LogisticRegression +from quacc.environment import env, environ from quacc.logger import SubLogger -def estimate_worker(_estimate, train, validation, test, _env=None, q=None): - qp.environ["SAMPLE_SIZE"] = _env.SAMPLE_SIZE - SubLogger.setup(q) - log = SubLogger.logger() +@dataclass(frozen=True) +class WorkerArgs: + _estimate: callable + train: LabelledCollection + validation: LabelledCollection + test: LabelledCollection + _env: environ + q: Queue - model = LogisticRegression() - model.fit(*train.Xy) - protocol = APP( - test, - n_prevalences=_env.PROTOCOL_N_PREVS, - repeats=_env.PROTOCOL_REPEATS, - return_type="labelled_collection", - ) - start = time.time() - try: - result = _estimate(model, validation, protocol) - except Exception as e: - log.warning(f"Method {_estimate.__name__} failed. Exception: {e}") - traceback(e) - return { - "name": _estimate.__name__, - "result": None, - "time": 0, - } +def estimate_worker(args: WorkerArgs): + with env.load(args._env): + qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE + SubLogger.setup(args.q) + log = SubLogger.logger() - end = time.time() - log.info(f"{_estimate.__name__} finished [took {end-start:.4f}s]") + model = LogisticRegression() - return { - "name": _estimate.__name__, - "result": result, - "time": end - start, - } + model.fit(*args.train.Xy) + protocol = APP( + args.test, + n_prevalences=env.PROTOCOL_N_PREVS, + repeats=env.PROTOCOL_REPEATS, + return_type="labelled_collection", + random_state=env._R_SEED, + ) + start = time.time() + try: + result = args._estimate(model, args.validation, protocol) + except Exception as e: + log.warning(f"Method {args._estimate.name} failed. Exception: {e}") + traceback(e) + return None + + result.time = time.time() - start + log.info(f"{args._estimate.name} finished [took {result.time:.4f}s]") + + return result