parallelization added
This commit is contained in:
parent
8705f2b3c0
commit
27a384c1a1
|
@ -15,7 +15,7 @@ import quacc.error
|
|||
from quacc.data import ExtendedCollection, ExtendedData
|
||||
from quacc.environment import env
|
||||
from quacc.evaluation import evaluate
|
||||
from quacc.logger import SubLogger
|
||||
from quacc.logger import Logger, SubLogger
|
||||
from quacc.method.base import (
|
||||
BaseAccuracyEstimator,
|
||||
BinaryQuantifierAccuracyEstimator,
|
||||
|
@ -32,7 +32,7 @@ class GridSearchAE(BaseAccuracyEstimator):
|
|||
error: Union[Callable, str] = qc.error.maccd,
|
||||
refit=True,
|
||||
# timeout=-1,
|
||||
# n_jobs=None,
|
||||
n_jobs=None,
|
||||
verbose=False,
|
||||
):
|
||||
self.model = model
|
||||
|
@ -40,7 +40,7 @@ class GridSearchAE(BaseAccuracyEstimator):
|
|||
self.protocol = protocol
|
||||
self.refit = refit
|
||||
# self.timeout = timeout
|
||||
# self.n_jobs = qp._get_njobs(n_jobs)
|
||||
self.n_jobs = qc._get_njobs(n_jobs)
|
||||
self.verbose = verbose
|
||||
self.__check_error(error)
|
||||
assert isinstance(protocol, AbstractProtocol), "unknown protocol"
|
||||
|
@ -92,10 +92,16 @@ class GridSearchAE(BaseAccuracyEstimator):
|
|||
dict(zip(params_keys, val)) for val in itertools.product(*params_values)
|
||||
]
|
||||
|
||||
# self._sout(f"starting model selection with {self.n_jobs =}")
|
||||
self._sout("starting model selection")
|
||||
self._sout(f"starting model selection with {self.n_jobs =}")
|
||||
# self._sout("starting model selection")
|
||||
|
||||
scores = [self.__params_eval(params, training) for params in hyper]
|
||||
# scores = [self.__params_eval((params, training)) for params in hyper]
|
||||
scores = qc.utils.parallel(
|
||||
self._params_eval,
|
||||
((params, training) for params in hyper),
|
||||
seed=env._R_SEED,
|
||||
n_jobs=self.n_jobs,
|
||||
)
|
||||
|
||||
for params, score, model in scores:
|
||||
if score is not None:
|
||||
|
@ -118,7 +124,7 @@ class GridSearchAE(BaseAccuracyEstimator):
|
|||
level=1,
|
||||
)
|
||||
|
||||
log = SubLogger.logger()
|
||||
log = Logger.logger()
|
||||
log.debug(
|
||||
f"[{self.model.__class__.__name__}] "
|
||||
f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) "
|
||||
|
@ -137,7 +143,8 @@ class GridSearchAE(BaseAccuracyEstimator):
|
|||
|
||||
return self
|
||||
|
||||
def __params_eval(self, params, training):
|
||||
def _params_eval(self, args):
|
||||
params, training = args
|
||||
protocol = self.protocol
|
||||
error = self.error
|
||||
|
||||
|
|
Loading…
Reference in New Issue