switched implementation to pool.imap

This commit is contained in:
Lorenzo Volpi 2023-11-26 16:29:28 +01:00
parent 72ff1334ac
commit afceb6ee1c
2 changed files with 65 additions and 60 deletions

View File

@ -12,7 +12,7 @@ from quacc.dataset import Dataset
from quacc.environment import env from quacc.environment import env
from quacc.evaluation import baseline, method from quacc.evaluation import baseline, method
from quacc.evaluation.report import CompReport, DatasetReport from quacc.evaluation.report import CompReport, DatasetReport
from quacc.evaluation.worker import estimate_worker from quacc.evaluation.worker import WorkerArgs, estimate_worker
from quacc.logger import Logger from quacc.logger import Logger
pd.set_option("display.float_format", "{:.4f}".format) pd.set_option("display.float_format", "{:.4f}".format)
@ -91,45 +91,42 @@ def evaluate_comparison(dataset: Dataset, estimators=None) -> DatasetReport:
log.info( log.info(
f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} started" f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} started"
) )
tstart = time.time()
tasks = [ tasks = [
(estim, d.train, d.validation, d.test) for estim in CE.func[estimators] WorkerArgs(
_estimate=estim,
train=d.train,
validation=d.validation,
test=d.test,
_env=env,
q=Logger.queue(),
)
for estim in CE.func[estimators]
] ]
try:
tstart = time.time()
results = [ results = [
pool.apply_async(estimate_worker, t, {"_env": env, "q": Logger.queue()}) r for r in pool.imap(estimate_worker, tasks) if r is not None
for t in tasks
] ]
results_got = [] g_time = time.time() - tstart
for _r in results: log.info(
try: f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} finished "
r = _r.get() f"[took {g_time:.4f}s]"
if r["result"] is not None:
results_got.append(r)
except Exception as e:
log.warning(
f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} failed. Exception: {e}"
) )
tend = time.time()
times = {r["name"]: r["time"] for r in results_got}
times["tot"] = tend - tstart
log.info(
f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} finished [took {times['tot']:.4f}s]"
)
try:
cr = CompReport( cr = CompReport(
[r["result"] for r in results_got], results,
name=dataset.name, name=dataset.name,
train_prev=d.train_prev, train_prev=d.train_prev,
valid_prev=d.validation_prev, valid_prev=d.validation_prev,
times=times, g_time=g_time,
) )
dr += cr
except Exception as e: except Exception as e:
log.warning( log.warning(
f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} failed. Exception: {e}" f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} failed. "
f"Exception: {e}"
) )
traceback(e) traceback(e)
cr = None
dr += cr
return dr return dr

View File

@ -1,44 +1,52 @@
import time import time
from dataclasses import dataclass
from multiprocessing import Queue
from traceback import print_exception as traceback from traceback import print_exception as traceback
import quapy as qp import quapy as qp
from quapy.data import LabelledCollection
from quapy.protocol import APP from quapy.protocol import APP
from sklearn.linear_model import LogisticRegression from sklearn.linear_model import LogisticRegression
from quacc.environment import env, environ
from quacc.logger import SubLogger from quacc.logger import SubLogger
def estimate_worker(_estimate, train, validation, test, _env=None, q=None): @dataclass(frozen=True)
qp.environ["SAMPLE_SIZE"] = _env.SAMPLE_SIZE class WorkerArgs:
SubLogger.setup(q) _estimate: callable
train: LabelledCollection
validation: LabelledCollection
test: LabelledCollection
_env: environ
q: Queue
def estimate_worker(args: WorkerArgs):
with env.load(args._env):
qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE
SubLogger.setup(args.q)
log = SubLogger.logger() log = SubLogger.logger()
model = LogisticRegression() model = LogisticRegression()
model.fit(*train.Xy) model.fit(*args.train.Xy)
protocol = APP( protocol = APP(
test, args.test,
n_prevalences=_env.PROTOCOL_N_PREVS, n_prevalences=env.PROTOCOL_N_PREVS,
repeats=_env.PROTOCOL_REPEATS, repeats=env.PROTOCOL_REPEATS,
return_type="labelled_collection", return_type="labelled_collection",
random_state=env._R_SEED,
) )
start = time.time() start = time.time()
try: try:
result = _estimate(model, validation, protocol) result = args._estimate(model, args.validation, protocol)
except Exception as e: except Exception as e:
log.warning(f"Method {_estimate.__name__} failed. Exception: {e}") log.warning(f"Method {args._estimate.name} failed. Exception: {e}")
traceback(e) traceback(e)
return { return None
"name": _estimate.__name__,
"result": None,
"time": 0,
}
end = time.time() result.time = time.time() - start
log.info(f"{_estimate.__name__} finished [took {end-start:.4f}s]") log.info(f"{args._estimate.name} finished [took {result.time:.4f}s]")
return { return result
"name": _estimate.__name__,
"result": result,
"time": end - start,
}