parallelization improved, worker code refactored
parent 17693318d8
commit 06761da870
@@ -1,4 +1,3 @@
-import multiprocessing
 import os
 import time
 from traceback import print_exception as traceback
@@ -6,20 +5,70 @@ from traceback import print_exception as traceback
 import pandas as pd
 import quapy as qp
 from joblib import Parallel, delayed
+from quapy.protocol import APP
+from sklearn.linear_model import LogisticRegression
 
+from quacc import logger
 from quacc.dataset import Dataset
 from quacc.environment import env
 from quacc.evaluation.estimators import CE
 from quacc.evaluation.report import CompReport, DatasetReport
-from quacc.evaluation.worker import WorkerArgs, estimate_worker
-from quacc.logger import Logger
+from quacc.utils import parallel
+
+# from quacc.logger import logger, logger_manager
+
+# from quacc.evaluation.worker import WorkerArgs, estimate_worker
 
 pd.set_option("display.float_format", "{:.4f}".format)
-qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE
+# qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE
 
 
+def estimate_worker(_estimate, train, validation, test, q=None):
+    # qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE
+    log = logger.setup_worker_logger(q)
+
+    model = LogisticRegression()
+
+    model.fit(*train.Xy)
+    protocol = APP(
+        test,
+        n_prevalences=env.PROTOCOL_N_PREVS,
+        repeats=env.PROTOCOL_REPEATS,
+        return_type="labelled_collection",
+        random_state=env._R_SEED,
+    )
+    start = time.time()
+    try:
+        result = _estimate(model, validation, protocol)
+    except Exception as e:
+        log.warning(f"Method {_estimate.name} failed. Exception: {e}")
+        traceback(e)
+        return None
+
+    result.time = time.time() - start
+    log.info(f"{_estimate.name} finished [took {result.time:.4f}s]")
+
+    logger.logger_manager().rm_worker()
+
+    return result
+
+
+def split_tasks(estimators, train, validation, test, q):
+    _par, _seq = [], []
+    for estim in estimators:
+        _task = [estim, train, validation, test]
+        match estim.name:
+            case n if n.endswith("_gs"):
+                _seq.append(_task)
+            case _:
+                _par.append(_task + [q])
+
+    return _par, _seq
+
+
 def evaluate_comparison(dataset: Dataset, estimators=None) -> DatasetReport:
-    log = Logger.logger()
+    # log = Logger.logger()
+    log = logger.logger()
     # with multiprocessing.Pool(1) as pool:
     __pool_size = round(os.cpu_count() * 0.8)
     # with multiprocessing.Pool(__pool_size) as pool:
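
Note on the new parallel helper: the code above imports parallel from quacc.utils, whose implementation is not part of this commit. Judging from the call sites (parallel(estimate_worker, par_tasks, n_jobs=env.N_JOBS, _env=env)) and the joblib import that is kept, a minimal sketch of such a helper could look like the following; the function body is an assumption, not the actual quacc code.

# Hypothetical sketch of a joblib-based `parallel` helper (assumed, not the
# real quacc.utils implementation). Each task is a list of positional
# arguments that is unpacked into `func`; `_env` stands for whatever
# per-process environment setup the real helper performs before running.
from joblib import Parallel, delayed


def parallel(func, args_list, n_jobs=1, _env=None):
    # the real helper presumably re-applies `_env` inside each spawned
    # worker so that module-level configuration survives process creation
    return Parallel(n_jobs=n_jobs)(delayed(func)(*args) for args in args_list)
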
@@ -29,26 +78,18 @@ def evaluate_comparison(dataset: Dataset, estimators=None) -> DatasetReport:
         log.info(
             f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} started"
         )
-        tasks = [
-            WorkerArgs(
-                _estimate=estim,
-                train=d.train,
-                validation=d.validation,
-                test=d.test,
-                _env=env,
-                q=Logger.queue(),
-            )
-            for estim in CE.func[estimators]
-        ]
+        par_tasks, seq_tasks = split_tasks(
+            CE.func[estimators],
+            d.train,
+            d.validation,
+            d.test,
+            logger.logger_manager().q,
+        )
         try:
             tstart = time.time()
-            results = Parallel(n_jobs=1)(delayed(estimate_worker)(t) for t in tasks)
+            results = parallel(estimate_worker, par_tasks, n_jobs=env.N_JOBS, _env=env)
+            results += parallel(estimate_worker, seq_tasks, n_jobs=1, _env=env)
             results = [r for r in results if r is not None]
-            # # r for r in pool.imap(estimate_worker, tasks) if r is not None
-            # r
-            # for r in map(estimate_worker, tasks)
-            # if r is not None
-            # ]
 
             g_time = time.time() - tstart
             log.info(
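
Note on worker logging and the parallel/sequential split: the refactored worker receives the logging queue from logger.logger_manager().q and registers itself via logger.setup_worker_logger(q); grid-search estimators (names ending in "_gs") are kept in the sequential batch, presumably because they already spread their own grid search over multiple jobs. quacc's logger module is not shown in this commit; the sketch below only illustrates the standard QueueHandler/QueueListener pattern such a setup typically builds on, with illustrative names rather than quacc's API.

# Illustrative queue-based logging for worker processes (standard library
# only; assumed pattern, not quacc.logger's actual code).
import logging
import logging.handlers
import multiprocessing


def setup_worker_logger(q) -> logging.Logger:
    # inside a worker: forward every record to the parent through the queue
    log = logging.getLogger("worker")
    log.setLevel(logging.INFO)
    log.addHandler(logging.handlers.QueueHandler(q))
    return log


def start_listener(q) -> logging.handlers.QueueListener:
    # in the parent: drain the queue and emit records with local handlers
    listener = logging.handlers.QueueListener(q, logging.StreamHandler())
    listener.start()
    return listener


if __name__ == "__main__":
    q = multiprocessing.Manager().Queue()
    listener = start_listener(q)
    setup_worker_logger(q).info("worker online")
    listener.stop()
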
@@ -1,53 +0,0 @@
-import time
-from dataclasses import dataclass
-from multiprocessing import Queue
-from traceback import print_exception as traceback
-
-import quapy as qp
-from quapy.data import LabelledCollection
-from quapy.protocol import APP
-from sklearn.linear_model import LogisticRegression
-
-from quacc.environment import env, environ
-from quacc.logger import Logger, SubLogger
-
-
-@dataclass(frozen=True)
-class WorkerArgs:
-    _estimate: callable
-    train: LabelledCollection
-    validation: LabelledCollection
-    test: LabelledCollection
-    _env: environ
-    q: Queue
-
-
-def estimate_worker(args: WorkerArgs):
-    with env.load(args._env):
-        qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE
-        # SubLogger.setup(args.q)
-        # log = SubLogger.logger()
-        log = Logger.logger()
-
-        model = LogisticRegression()
-
-        model.fit(*args.train.Xy)
-        protocol = APP(
-            args.test,
-            n_prevalences=env.PROTOCOL_N_PREVS,
-            repeats=env.PROTOCOL_REPEATS,
-            return_type="labelled_collection",
-            random_state=env._R_SEED,
-        )
-        start = time.time()
-        try:
-            result = args._estimate(model, args.validation, protocol)
-        except Exception as e:
-            log.warning(f"Method {args._estimate.name} failed. Exception: {e}")
-            traceback(e)
-            return None
-
-        result.time = time.time() - start
-        log.info(f"{args._estimate.name} finished [took {result.time:.4f}s]")
-
-        return result
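
Note on the deleted worker module: the file removed above is presumably quacc/evaluation/worker.py, given the import dropped at the top of the diff. Its job of shipping the parent's settings into each spawned process (with env.load(args._env):) now falls to the parallel helper through its _env=env argument. The sketch below captures that propagation idea only; field defaults are placeholders and the class makes no claim to match quacc.environment's real environ.

# Sketch of context-manager-based environment propagation (an assumption
# about the idea behind quacc.environment.environ.load, not its real code).
from contextlib import contextmanager
from dataclasses import dataclass, fields


@dataclass
class Environ:
    SAMPLE_SIZE: int = 100      # placeholder default
    PROTOCOL_N_PREVS: int = 9   # placeholder default
    PROTOCOL_REPEATS: int = 1   # placeholder default
    _R_SEED: int = 0            # placeholder default

    @contextmanager
    def load(self, other: "Environ"):
        # copy the parent's values into this per-process instance, so code
        # reading the module-level `env` sees the parent's configuration
        for f in fields(self):
            setattr(self, f.name, getattr(other, f.name))
        yield self


env = Environ()
# worker side (old style): with env.load(parent_env): ... run the task ...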