parallelization added

This commit is contained in:
Lorenzo Volpi 2023-12-02 02:11:30 +01:00
parent 8705f2b3c0
commit 27a384c1a1
1 changed files with 15 additions and 8 deletions

View File

@ -15,7 +15,7 @@ import quacc.error
from quacc.data import ExtendedCollection, ExtendedData
from quacc.environment import env
from quacc.evaluation import evaluate
from quacc.logger import SubLogger
from quacc.logger import Logger, SubLogger
from quacc.method.base import (
BaseAccuracyEstimator,
BinaryQuantifierAccuracyEstimator,
@ -32,7 +32,7 @@ class GridSearchAE(BaseAccuracyEstimator):
error: Union[Callable, str] = qc.error.maccd,
refit=True,
# timeout=-1,
# n_jobs=None,
n_jobs=None,
verbose=False,
):
self.model = model
@ -40,7 +40,7 @@ class GridSearchAE(BaseAccuracyEstimator):
self.protocol = protocol
self.refit = refit
# self.timeout = timeout
# self.n_jobs = qp._get_njobs(n_jobs)
self.n_jobs = qc._get_njobs(n_jobs)
self.verbose = verbose
self.__check_error(error)
assert isinstance(protocol, AbstractProtocol), "unknown protocol"
@ -92,10 +92,16 @@ class GridSearchAE(BaseAccuracyEstimator):
dict(zip(params_keys, val)) for val in itertools.product(*params_values)
]
# self._sout(f"starting model selection with {self.n_jobs =}")
self._sout("starting model selection")
self._sout(f"starting model selection with {self.n_jobs =}")
# self._sout("starting model selection")
scores = [self.__params_eval(params, training) for params in hyper]
# scores = [self.__params_eval((params, training)) for params in hyper]
scores = qc.utils.parallel(
self._params_eval,
((params, training) for params in hyper),
seed=env._R_SEED,
n_jobs=self.n_jobs,
)
for params, score, model in scores:
if score is not None:
@ -118,7 +124,7 @@ class GridSearchAE(BaseAccuracyEstimator):
level=1,
)
log = SubLogger.logger()
log = Logger.logger()
log.debug(
f"[{self.model.__class__.__name__}] "
f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) "
@ -137,7 +143,8 @@ class GridSearchAE(BaseAccuracyEstimator):
return self
def __params_eval(self, params, training):
def _params_eval(self, args):
params, training = args
protocol = self.protocol
error = self.error