Switched worker dispatch from pool.apply_async to pool.imap

This commit is contained in:
Lorenzo Volpi 2023-11-26 16:29:28 +01:00
parent 72ff1334ac
commit afceb6ee1c
2 changed files with 65 additions and 60 deletions

View File

@ -12,7 +12,7 @@ from quacc.dataset import Dataset
from quacc.environment import env
from quacc.evaluation import baseline, method
from quacc.evaluation.report import CompReport, DatasetReport
from quacc.evaluation.worker import estimate_worker
from quacc.evaluation.worker import WorkerArgs, estimate_worker
from quacc.logger import Logger
pd.set_option("display.float_format", "{:.4f}".format)
@ -91,45 +91,42 @@ def evaluate_comparison(dataset: Dataset, estimators=None) -> DatasetReport:
log.info(
f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} started"
)
tstart = time.time()
tasks = [
(estim, d.train, d.validation, d.test) for estim in CE.func[estimators]
WorkerArgs(
_estimate=estim,
train=d.train,
validation=d.validation,
test=d.test,
_env=env,
q=Logger.queue(),
)
for estim in CE.func[estimators]
]
results = [
pool.apply_async(estimate_worker, t, {"_env": env, "q": Logger.queue()})
for t in tasks
]
results_got = []
for _r in results:
try:
r = _r.get()
if r["result"] is not None:
results_got.append(r)
except Exception as e:
log.warning(
f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} failed. Exception: {e}"
)
tend = time.time()
times = {r["name"]: r["time"] for r in results_got}
times["tot"] = tend - tstart
log.info(
f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} finished [took {times['tot']:.4f}s]"
)
try:
tstart = time.time()
results = [
r for r in pool.imap(estimate_worker, tasks) if r is not None
]
g_time = time.time() - tstart
log.info(
f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} finished "
f"[took {g_time:.4f}s]"
)
cr = CompReport(
[r["result"] for r in results_got],
results,
name=dataset.name,
train_prev=d.train_prev,
valid_prev=d.validation_prev,
times=times,
g_time=g_time,
)
dr += cr
except Exception as e:
log.warning(
f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} failed. Exception: {e}"
f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} failed. "
f"Exception: {e}"
)
traceback(e)
cr = None
dr += cr
return dr

View File

@ -1,44 +1,52 @@
import time
from dataclasses import dataclass
from multiprocessing import Queue
from traceback import print_exception as traceback
import quapy as qp
from quapy.data import LabelledCollection
from quapy.protocol import APP
from sklearn.linear_model import LogisticRegression
from quacc.environment import env, environ
from quacc.logger import SubLogger
def estimate_worker(_estimate, train, validation, test, _env=None, q=None):
qp.environ["SAMPLE_SIZE"] = _env.SAMPLE_SIZE
SubLogger.setup(q)
log = SubLogger.logger()
@dataclass(frozen=True)
class WorkerArgs:
    """Immutable bundle of the inputs for a single ``estimate_worker`` call.

    ``estimate_worker`` accepts exactly one argument of this type, so a list
    of ``WorkerArgs`` can be mapped over directly with ``pool.imap``.
    ``frozen=True`` makes instances read-only after construction.
    """

    # Estimation method, invoked by the worker as
    # ``_estimate(model, validation, protocol)``.
    # NOTE(review): ``callable`` is the builtin, used here as an informal
    # annotation; dataclasses only store it and do not enforce it.
    _estimate: callable
    # Collection whose ``Xy`` pair the worker uses to fit its classifier.
    train: LabelledCollection
    # Collection handed to ``_estimate`` as its second argument.
    validation: LabelledCollection
    # Collection the worker wraps in the APP sampling protocol.
    test: LabelledCollection
    # Environment object the worker loads with ``env.load`` before running.
    _env: environ
    # Queue passed to ``SubLogger.setup`` in the worker.
    q: Queue
model = LogisticRegression()
model.fit(*train.Xy)
protocol = APP(
test,
n_prevalences=_env.PROTOCOL_N_PREVS,
repeats=_env.PROTOCOL_REPEATS,
return_type="labelled_collection",
)
start = time.time()
try:
result = _estimate(model, validation, protocol)
except Exception as e:
log.warning(f"Method {_estimate.__name__} failed. Exception: {e}")
traceback(e)
return {
"name": _estimate.__name__,
"result": None,
"time": 0,
}
def estimate_worker(args: WorkerArgs):
with env.load(args._env):
qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE
SubLogger.setup(args.q)
log = SubLogger.logger()
end = time.time()
log.info(f"{_estimate.__name__} finished [took {end-start:.4f}s]")
model = LogisticRegression()
return {
"name": _estimate.__name__,
"result": result,
"time": end - start,
}
model.fit(*args.train.Xy)
protocol = APP(
args.test,
n_prevalences=env.PROTOCOL_N_PREVS,
repeats=env.PROTOCOL_REPEATS,
return_type="labelled_collection",
random_state=env._R_SEED,
)
start = time.time()
try:
result = args._estimate(model, args.validation, protocol)
except Exception as e:
log.warning(f"Method {args._estimate.name} failed. Exception: {e}")
traceback(e)
return None
result.time = time.time() - start
log.info(f"{args._estimate.name} finished [took {result.time:.4f}s]")
return result