diff --git a/quacc/method/model_selection.py b/quacc/method/model_selection.py index c7c49c9..cd8d35b 100644 --- a/quacc/method/model_selection.py +++ b/quacc/method/model_selection.py @@ -15,7 +15,7 @@ import quacc.error from quacc.data import ExtendedCollection, ExtendedData from quacc.environment import env from quacc.evaluation import evaluate -from quacc.logger import SubLogger +from quacc.logger import Logger, SubLogger from quacc.method.base import ( BaseAccuracyEstimator, BinaryQuantifierAccuracyEstimator, @@ -32,7 +32,7 @@ class GridSearchAE(BaseAccuracyEstimator): error: Union[Callable, str] = qc.error.maccd, refit=True, # timeout=-1, - # n_jobs=None, + n_jobs=None, verbose=False, ): self.model = model @@ -40,7 +40,7 @@ class GridSearchAE(BaseAccuracyEstimator): self.protocol = protocol self.refit = refit # self.timeout = timeout - # self.n_jobs = qp._get_njobs(n_jobs) + self.n_jobs = qc._get_njobs(n_jobs) self.verbose = verbose self.__check_error(error) assert isinstance(protocol, AbstractProtocol), "unknown protocol" @@ -92,10 +92,16 @@ class GridSearchAE(BaseAccuracyEstimator): dict(zip(params_keys, val)) for val in itertools.product(*params_values) ] - # self._sout(f"starting model selection with {self.n_jobs =}") - self._sout("starting model selection") + self._sout(f"starting model selection with {self.n_jobs =}") + # self._sout("starting model selection") - scores = [self.__params_eval(params, training) for params in hyper] + # scores = [self.__params_eval((params, training)) for params in hyper] + scores = qc.utils.parallel( + self._params_eval, + ((params, training) for params in hyper), + seed=env._R_SEED, + n_jobs=self.n_jobs, + ) for params, score, model in scores: if score is not None: @@ -118,7 +124,7 @@ class GridSearchAE(BaseAccuracyEstimator): level=1, ) - log = SubLogger.logger() + log = Logger.logger() log.debug( f"[{self.model.__class__.__name__}] " f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) " @@ -137,7 +143,8 @@ class GridSearchAE(BaseAccuracyEstimator): return self - def __params_eval(self, params, training): + def _params_eval(self, args): + params, training = args protocol = self.protocol error = self.error