From 06761da8704f03eb7c2d507367d9110243bd189d Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Wed, 6 Dec 2023 10:00:23 +0100 Subject: [PATCH] parallelization improved, worker code refactored --- quacc/evaluation/comp.py | 85 ++++++++++++++++++++++++++++---------- quacc/evaluation/worker.py | 53 ------------------------ 2 files changed, 63 insertions(+), 75 deletions(-) delete mode 100644 quacc/evaluation/worker.py diff --git a/quacc/evaluation/comp.py b/quacc/evaluation/comp.py index 956c77b..98f71ec 100644 --- a/quacc/evaluation/comp.py +++ b/quacc/evaluation/comp.py @@ -1,4 +1,3 @@ -import multiprocessing import os import time from traceback import print_exception as traceback @@ -6,20 +5,70 @@ from traceback import print_exception as traceback import pandas as pd import quapy as qp from joblib import Parallel, delayed +from quapy.protocol import APP +from sklearn.linear_model import LogisticRegression +from quacc import logger from quacc.dataset import Dataset from quacc.environment import env from quacc.evaluation.estimators import CE from quacc.evaluation.report import CompReport, DatasetReport -from quacc.evaluation.worker import WorkerArgs, estimate_worker -from quacc.logger import Logger +from quacc.utils import parallel + +# from quacc.logger import logger, logger_manager + +# from quacc.evaluation.worker import WorkerArgs, estimate_worker pd.set_option("display.float_format", "{:.4f}".format) -qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE +# qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE + + +def estimate_worker(_estimate, train, validation, test, q=None): + # qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE + log = logger.setup_worker_logger(q) + + model = LogisticRegression() + + model.fit(*train.Xy) + protocol = APP( + test, + n_prevalences=env.PROTOCOL_N_PREVS, + repeats=env.PROTOCOL_REPEATS, + return_type="labelled_collection", + random_state=env._R_SEED, + ) + start = time.time() + try: + result = _estimate(model, validation, protocol) + except Exception as e: + log.warning(f"Method {_estimate.name} failed. Exception: {e}") + traceback(e) + return None + + result.time = time.time() - start + log.info(f"{_estimate.name} finished [took {result.time:.4f}s]") + + logger.logger_manager().rm_worker() + + return result + + +def split_tasks(estimators, train, validation, test, q): + _par, _seq = [], [] + for estim in estimators: + _task = [estim, train, validation, test] + match estim.name: + case n if n.endswith("_gs"): + _seq.append(_task) + case _: + _par.append(_task + [q]) + + return _par, _seq def evaluate_comparison(dataset: Dataset, estimators=None) -> DatasetReport: - log = Logger.logger() + # log = Logger.logger() + log = logger.logger() # with multiprocessing.Pool(1) as pool: __pool_size = round(os.cpu_count() * 0.8) # with multiprocessing.Pool(__pool_size) as pool: @@ -29,26 +78,18 @@ def evaluate_comparison(dataset: Dataset, estimators=None) -> DatasetReport: log.info( f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} started" ) - tasks = [ - WorkerArgs( - _estimate=estim, - train=d.train, - validation=d.validation, - test=d.test, - _env=env, - q=Logger.queue(), - ) - for estim in CE.func[estimators] - ] + par_tasks, seq_tasks = split_tasks( + CE.func[estimators], + d.train, + d.validation, + d.test, + logger.logger_manager().q, + ) try: tstart = time.time() - results = Parallel(n_jobs=1)(delayed(estimate_worker)(t) for t in tasks) + results = parallel(estimate_worker, par_tasks, n_jobs=env.N_JOBS, _env=env) + results += parallel(estimate_worker, seq_tasks, n_jobs=1, _env=env) results = [r for r in results if r is not None] - # # r for r in pool.imap(estimate_worker, tasks) if r is not None - # r - # for r in map(estimate_worker, tasks) - # if r is not None - # ] g_time = time.time() - tstart log.info( diff --git a/quacc/evaluation/worker.py b/quacc/evaluation/worker.py deleted file mode 100644 index e36a248..0000000 --- a/quacc/evaluation/worker.py +++ /dev/null @@ -1,53 +0,0 @@ -import time -from dataclasses import dataclass -from multiprocessing import Queue -from traceback import print_exception as traceback - -import quapy as qp -from quapy.data import LabelledCollection -from quapy.protocol import APP -from sklearn.linear_model import LogisticRegression - -from quacc.environment import env, environ -from quacc.logger import Logger, SubLogger - - -@dataclass(frozen=True) -class WorkerArgs: - _estimate: callable - train: LabelledCollection - validation: LabelledCollection - test: LabelledCollection - _env: environ - q: Queue - - -def estimate_worker(args: WorkerArgs): - with env.load(args._env): - qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE - # SubLogger.setup(args.q) - # log = SubLogger.logger() - log = Logger.logger() - - model = LogisticRegression() - - model.fit(*args.train.Xy) - protocol = APP( - args.test, - n_prevalences=env.PROTOCOL_N_PREVS, - repeats=env.PROTOCOL_REPEATS, - return_type="labelled_collection", - random_state=env._R_SEED, - ) - start = time.time() - try: - result = args._estimate(model, args.validation, protocol) - except Exception as e: - log.warning(f"Method {args._estimate.name} failed. Exception: {e}") - traceback(e) - return None - - result.time = time.time() - start - log.info(f"{args._estimate.name} finished [took {result.time:.4f}s]") - - return result