parallelization improved, worker code refactored

This commit is contained in:
Lorenzo Volpi 2023-12-06 10:00:23 +01:00
parent 17693318d8
commit 06761da870
2 changed files with 63 additions and 75 deletions

View File

@ -1,4 +1,3 @@
import multiprocessing
import os import os
import time import time
from traceback import print_exception as traceback from traceback import print_exception as traceback
@ -6,20 +5,70 @@ from traceback import print_exception as traceback
import pandas as pd import pandas as pd
import quapy as qp import quapy as qp
from joblib import Parallel, delayed from joblib import Parallel, delayed
from quapy.protocol import APP
from sklearn.linear_model import LogisticRegression
from quacc import logger
from quacc.dataset import Dataset from quacc.dataset import Dataset
from quacc.environment import env from quacc.environment import env
from quacc.evaluation.estimators import CE from quacc.evaluation.estimators import CE
from quacc.evaluation.report import CompReport, DatasetReport from quacc.evaluation.report import CompReport, DatasetReport
from quacc.evaluation.worker import WorkerArgs, estimate_worker from quacc.utils import parallel
from quacc.logger import Logger
# from quacc.logger import logger, logger_manager
# from quacc.evaluation.worker import WorkerArgs, estimate_worker
pd.set_option("display.float_format", "{:.4f}".format) pd.set_option("display.float_format", "{:.4f}".format)
qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE # qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE
def estimate_worker(_estimate, train, validation, test, q=None):
    """Fit a classifier and run one accuracy-estimation method over a protocol.

    Trains a LogisticRegression on ``train``, builds an APP protocol over
    ``test`` from the environment settings, and invokes ``_estimate`` with the
    model, ``validation`` and the protocol.

    :param _estimate: callable estimation method; must expose a ``.name`` attribute
    :param train: labelled collection used to fit the classifier (exposes ``.Xy``)
    :param validation: labelled collection passed through to ``_estimate``
    :param test: labelled collection sampled by the APP protocol
    :param q: optional multiprocessing queue for the worker logger
    :return: the result produced by ``_estimate`` (with ``.time`` set to the
        elapsed seconds), or ``None`` if the method raised
    """
    log = logger.setup_worker_logger(q)
    try:
        model = LogisticRegression()
        model.fit(*train.Xy)
        protocol = APP(
            test,
            n_prevalences=env.PROTOCOL_N_PREVS,
            repeats=env.PROTOCOL_REPEATS,
            return_type="labelled_collection",
            random_state=env._R_SEED,
        )
        start = time.time()
        try:
            result = _estimate(model, validation, protocol)
        except Exception as e:
            log.warning(f"Method {_estimate.name} failed. Exception: {e}")
            traceback(e)
            return None
        result.time = time.time() - start
        log.info(f"{_estimate.name} finished [took {result.time:.4f}s]")
        return result
    finally:
        # BUGFIX: previously rm_worker() ran only on the success path, so a
        # failing _estimate leaked the worker registration in the logger
        # manager. finally guarantees deregistration on every exit path.
        logger.logger_manager().rm_worker()
def split_tasks(estimators, train, validation, test, q):
    """Partition estimation tasks into parallel and sequential batches.

    Estimators whose ``name`` ends with ``"_gs"`` (presumably grid-search
    variants — confirm with caller) are queued for sequential execution;
    every other estimator is queued for parallel execution with the logging
    queue ``q`` appended to its argument list.

    :return: tuple ``(parallel_tasks, sequential_tasks)`` of argument lists
    """
    parallel_tasks = []
    sequential_tasks = []
    for estimator in estimators:
        task_args = [estimator, train, validation, test]
        if estimator.name.endswith("_gs"):
            sequential_tasks.append(task_args)
        else:
            parallel_tasks.append(task_args + [q])
    return parallel_tasks, sequential_tasks
def evaluate_comparison(dataset: Dataset, estimators=None) -> DatasetReport: def evaluate_comparison(dataset: Dataset, estimators=None) -> DatasetReport:
log = Logger.logger() # log = Logger.logger()
log = logger.logger()
# with multiprocessing.Pool(1) as pool: # with multiprocessing.Pool(1) as pool:
__pool_size = round(os.cpu_count() * 0.8) __pool_size = round(os.cpu_count() * 0.8)
# with multiprocessing.Pool(__pool_size) as pool: # with multiprocessing.Pool(__pool_size) as pool:
@ -29,26 +78,18 @@ def evaluate_comparison(dataset: Dataset, estimators=None) -> DatasetReport:
log.info( log.info(
f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} started" f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} started"
) )
tasks = [ par_tasks, seq_tasks = split_tasks(
WorkerArgs( CE.func[estimators],
_estimate=estim, d.train,
train=d.train, d.validation,
validation=d.validation, d.test,
test=d.test, logger.logger_manager().q,
_env=env, )
q=Logger.queue(),
)
for estim in CE.func[estimators]
]
try: try:
tstart = time.time() tstart = time.time()
results = Parallel(n_jobs=1)(delayed(estimate_worker)(t) for t in tasks) results = parallel(estimate_worker, par_tasks, n_jobs=env.N_JOBS, _env=env)
results += parallel(estimate_worker, seq_tasks, n_jobs=1, _env=env)
results = [r for r in results if r is not None] results = [r for r in results if r is not None]
# # r for r in pool.imap(estimate_worker, tasks) if r is not None
# r
# for r in map(estimate_worker, tasks)
# if r is not None
# ]
g_time = time.time() - tstart g_time = time.time() - tstart
log.info( log.info(

View File

@ -1,53 +0,0 @@
import time
from dataclasses import dataclass
from multiprocessing import Queue
from traceback import print_exception as traceback
import quapy as qp
from quapy.data import LabelledCollection
from quapy.protocol import APP
from sklearn.linear_model import LogisticRegression
from quacc.environment import env, environ
from quacc.logger import Logger, SubLogger
@dataclass(frozen=True)
class WorkerArgs:
    """Immutable bundle of arguments shipped to an estimation worker.

    Frozen so instances can be passed safely across process boundaries
    without accidental mutation.
    """

    # estimation method to invoke; expected to expose a ``.name`` attribute
    _estimate: callable
    # data the worker fits the classifier on
    train: LabelledCollection
    # data forwarded to the estimation method
    validation: LabelledCollection
    # data the evaluation protocol samples from
    test: LabelledCollection
    # snapshot of the environment to load inside the worker process
    _env: environ
    # multiprocessing queue for worker-side logging
    q: Queue
def estimate_worker(args: WorkerArgs):
    """Run one estimation task inside a worker process.

    Loads the environment snapshot carried by ``args``, fits a
    LogisticRegression on the training data, builds an APP protocol over the
    test data, and invokes ``args._estimate`` on the model, validation data
    and protocol.

    :param args: frozen argument bundle (see :class:`WorkerArgs`)
    :return: the estimation result with ``.time`` set to elapsed seconds,
        or ``None`` if the estimation method raised
    """
    # Re-apply the parent's environment inside this (possibly spawned) process.
    with env.load(args._env):
        qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE
        # SubLogger.setup(args.q)
        # log = SubLogger.logger()
        log = Logger.logger()
        model = LogisticRegression()
        model.fit(*args.train.Xy)
        protocol = APP(
            args.test,
            n_prevalences=env.PROTOCOL_N_PREVS,
            repeats=env.PROTOCOL_REPEATS,
            return_type="labelled_collection",
            random_state=env._R_SEED,
        )
        start = time.time()
        try:
            result = args._estimate(model, args.validation, protocol)
        except Exception as e:
            # Log and swallow: a failed method yields None so the caller can
            # filter it out without aborting the other tasks.
            log.warning(f"Method {args._estimate.name} failed. Exception: {e}")
            traceback(e)
            return None
        result.time = time.time() - start
        log.info(f"{args._estimate.name} finished [took {result.time:.4f}s]")
        return result