From c31de41e2f6ac471fdd564c730870725378a8434 Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Wed, 6 Dec 2023 10:02:24 +0100 Subject: [PATCH] parallel code improved, env improved --- quacc/environment.py | 28 +++++++++++++++++++++------- quacc/utils.py | 38 +++++++++++++++++++++++++++----------- 2 files changed, 48 insertions(+), 18 deletions(-) diff --git a/quacc/environment.py b/quacc/environment.py index f1505c8..09e586a 100644 --- a/quacc/environment.py +++ b/quacc/environment.py @@ -1,5 +1,7 @@ from contextlib import contextmanager +import numpy as np +import quapy as qp import yaml @@ -42,8 +44,14 @@ class environ: def __setdict(self, d: dict): for k, v in d.items(): super().__setattr__(k, v) + match k: + case "SAMPLE_SIZE": + qp.environ["SAMPLE_SIZE"] = v + case "_R_SEED": + qp.environ["_R_SEED"] = v + np.random.seed(v) - def __getdict(self) -> dict: + def to_dict(self) -> dict: return {k: self.__getattribute__(k) for k in environ._keys} @property @@ -52,16 +60,22 @@ class environ: @contextmanager def load(self, conf): - __current = self.__getdict() - if conf is not None: - if isinstance(conf, dict): - self.__setdict(conf) - elif isinstance(conf, environ): - self.__setdict(conf.__getdict()) + __current = self.to_dict() + __np_random_state = np.random.get_state() + + if conf is None: + conf = {} + + if isinstance(conf, environ): + conf = conf.to_dict() + + self.__setdict(conf) + try: yield finally: self.__setdict(__current) + np.random.set_state(__np_random_state) def load_confs(self): for c in self.confs: diff --git a/quacc/utils.py b/quacc/utils.py index c798884..88809e9 100644 --- a/quacc/utils.py +++ b/quacc/utils.py @@ -9,7 +9,8 @@ import pandas as pd from joblib import Parallel, delayed from tqdm import tqdm -from quacc.environment import env +from quacc import logger +from quacc.environment import env, environ def combine_dataframes(dfs, df_index=[]) -> pd.DataFrame: @@ -77,16 +78,31 @@ def download_file(url: str, downloaded_path: Path): urlretrieve(url, filename=downloaded_path, reporthook=t.update_to) -def parallel(func, args, n_jobs, seed=None): - """ - A wrapper of multiprocessing: +def parallel( + func, + f_args=None, + parallel: Parallel = None, + n_jobs=1, + verbose=0, + _env: environ | dict = None, + seed=None, +): + f_args = f_args or [] - >>> Parallel(n_jobs=n_jobs)( - >>> delayed(func)(args_i) for args_i in args - >>> ) + if _env is None: + _env = {} + elif isinstance(_env, environ): + _env = _env.to_dict() - that takes the `quapy.environ` variable as input silently. - Seeds the child processes to ensure reproducibility when n_jobs>1 - """ + def wrapper(*args): + if seed is not None: + nonlocal _env + _env = _env | dict(_R_SEED=seed) - return Parallel(n_jobs=n_jobs, verbose=1)(delayed(func)(_args) for _args in args) + with env.load(_env): + return func(*args) + + parallel = ( + Parallel(n_jobs=n_jobs, verbose=verbose) if parallel is None else parallel + ) + return parallel(delayed(wrapper)(*_args) for _args in f_args)