parallelization improved, worker code refactored

This commit is contained in:
Lorenzo Volpi 2023-12-06 10:00:23 +01:00
parent 17693318d8
commit 06761da870
2 changed files with 63 additions and 75 deletions

View File

@ -1,4 +1,3 @@
import multiprocessing
import os
import time
from traceback import print_exception as traceback
@ -6,20 +5,70 @@ from traceback import print_exception as traceback
import pandas as pd
import quapy as qp
from joblib import Parallel, delayed
from quapy.protocol import APP
from sklearn.linear_model import LogisticRegression
from quacc import logger
from quacc.dataset import Dataset
from quacc.environment import env
from quacc.evaluation.estimators import CE
from quacc.evaluation.report import CompReport, DatasetReport
from quacc.evaluation.worker import WorkerArgs, estimate_worker
from quacc.logger import Logger
from quacc.utils import parallel
# from quacc.logger import logger, logger_manager
# from quacc.evaluation.worker import WorkerArgs, estimate_worker
# Render floats with four decimal places whenever pandas displays data.
pd.set_option("display.float_format", "{:.4f}".format)
# Copy the configured sample size into quapy's global environment.
qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE
# qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE
def estimate_worker(_estimate, train, validation, test, q=None):
    """Run a single estimator against an APP protocol built from ``test``.

    A ``LogisticRegression`` model is fitted on ``train`` and handed, together
    with ``validation`` and the protocol, to ``_estimate``.

    Args:
        _estimate: Callable invoked as ``_estimate(model, validation,
            protocol)``. Its ``name`` attribute is used in log messages.
        train: Collection whose ``Xy`` is unpacked into ``model.fit``.
        validation: Forwarded unchanged to ``_estimate``.
        test: Collection the ``APP`` protocol is built over.
        q: Passed to ``logger.setup_worker_logger``; defaults to ``None``.

    Returns:
        The object returned by ``_estimate`` with its ``time`` attribute set
        to the elapsed wall-clock seconds, or ``None`` if ``_estimate`` raised.
    """
    log = logger.setup_worker_logger(q)
    try:
        model = LogisticRegression()
        model.fit(*train.Xy)
        protocol = APP(
            test,
            n_prevalences=env.PROTOCOL_N_PREVS,
            repeats=env.PROTOCOL_REPEATS,
            return_type="labelled_collection",
            random_state=env._R_SEED,
        )
        start = time.time()
        try:
            result = _estimate(model, validation, protocol)
        except Exception as e:
            # Best effort: a failing estimator is reported and skipped so the
            # remaining estimators can still run.
            log.warning(f"Method {_estimate.name} failed. Exception: {e}")
            traceback(e)
            return None

        result.time = time.time() - start
        log.info(f"{_estimate.name} finished [took {result.time:.4f}s]")
        return result
    finally:
        # Pair every setup_worker_logger() with rm_worker(), including the
        # failure path, which previously returned without this cleanup.
        logger.logger_manager().rm_worker()
def split_tasks(estimators, train, validation, test, q):
    """Partition estimators into parallel and sequential worker tasks.

    Every task is the list ``[estim, train, validation, test]``. Estimators
    whose ``name`` ends with ``"_gs"`` are placed in the sequential list as
    is; all others go to the parallel list with ``q`` appended as a fifth
    element.

    Args:
        estimators: Iterable of objects exposing a string ``name`` attribute.
        train: Value stored in each task after the estimator.
        validation: Value stored in each task after ``train``.
        test: Value stored in each task after ``validation``.
        q: Extra trailing element added to parallel tasks only.

    Returns:
        A ``(parallel_tasks, sequential_tasks)`` tuple of lists, each
        preserving the input order of its estimators.
    """
    _par, _seq = [], []
    for estim in estimators:
        _task = [estim, train, validation, test]
        # NOTE(review): "_gs" presumably marks grid-search estimators that
        # should not run side by side with others -- confirm with the caller.
        if estim.name.endswith("_gs"):
            _seq.append(_task)
        else:
            _par.append(_task + [q])
    return _par, _seq
def evaluate_comparison(dataset: Dataset, estimators=None) -> DatasetReport:
log = Logger.logger()
# log = Logger.logger()
log = logger.logger()
# with multiprocessing.Pool(1) as pool:
__pool_size = round(os.cpu_count() * 0.8)
# with multiprocessing.Pool(__pool_size) as pool:
@ -29,26 +78,18 @@ def evaluate_comparison(dataset: Dataset, estimators=None) -> DatasetReport:
log.info(
f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} started"
)
tasks = [
WorkerArgs(
_estimate=estim,
train=d.train,
validation=d.validation,
test=d.test,
_env=env,
q=Logger.queue(),
)
for estim in CE.func[estimators]
]
par_tasks, seq_tasks = split_tasks(
CE.func[estimators],
d.train,
d.validation,
d.test,
logger.logger_manager().q,
)
try:
tstart = time.time()
results = Parallel(n_jobs=1)(delayed(estimate_worker)(t) for t in tasks)
results = parallel(estimate_worker, par_tasks, n_jobs=env.N_JOBS, _env=env)
results += parallel(estimate_worker, seq_tasks, n_jobs=1, _env=env)
results = [r for r in results if r is not None]
# # r for r in pool.imap(estimate_worker, tasks) if r is not None
# r
# for r in map(estimate_worker, tasks)
# if r is not None
# ]
g_time = time.time() - tstart
log.info(

View File

@ -1,53 +0,0 @@
import time
from dataclasses import dataclass
from multiprocessing import Queue
from traceback import print_exception as traceback
import quapy as qp
from quapy.data import LabelledCollection
from quapy.protocol import APP
from sklearn.linear_model import LogisticRegression
from quacc.environment import env, environ
from quacc.logger import Logger, SubLogger
@dataclass(frozen=True)
class WorkerArgs:
    """Immutable bundle of the arguments consumed by ``estimate_worker``."""

    # Estimator invoked as ``_estimate(model, validation, protocol)``; its
    # ``name`` attribute is used in log messages.
    # NOTE(review): ``callable`` is the builtin predicate, not a typing
    # construct -- ``typing.Callable`` is probably what was meant.
    _estimate: callable
    # Training collection; its ``Xy`` is unpacked into ``LogisticRegression.fit``.
    train: LabelledCollection
    # Forwarded unchanged to ``_estimate``.
    validation: LabelledCollection
    # Collection the ``APP`` protocol is built over.
    test: LabelledCollection
    # Environment passed to ``env.load`` at the start of the worker.
    _env: environ
    # Queue carried for logging; its only visible use (``SubLogger.setup``)
    # is commented out in ``estimate_worker``.
    q: Queue
def estimate_worker(args: WorkerArgs):
    """Evaluate the estimator described by ``args`` and time the run.

    The environment in ``args._env`` is loaded, a ``LogisticRegression`` is
    fitted on ``args.train`` and the estimator is called with that model,
    ``args.validation`` and an ``APP`` protocol over ``args.test``.

    Returns:
        The estimator's result with ``time`` set to the elapsed seconds, or
        ``None`` when the estimator raised.
    """
    with env.load(args._env):
        qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE
        worker_log = Logger.logger()

        classifier = LogisticRegression()
        classifier.fit(*args.train.Xy)

        protocol_options = dict(
            n_prevalences=env.PROTOCOL_N_PREVS,
            repeats=env.PROTOCOL_REPEATS,
            return_type="labelled_collection",
            random_state=env._R_SEED,
        )
        sampling = APP(args.test, **protocol_options)

        estimator = args._estimate
        began = time.time()
        try:
            report = estimator(classifier, args.validation, sampling)
        except Exception as err:
            # A failing estimator is logged and skipped rather than re-raised.
            worker_log.warning(f"Method {estimator.name} failed. Exception: {err}")
            traceback(err)
            return None

        report.time = time.time() - began
        worker_log.info(f"{estimator.name} finished [took {report.time:.4f}s]")
        return report