parallel code improved, env improved

This commit is contained in:
Lorenzo Volpi 2023-12-06 10:02:24 +01:00
parent 2e8af90543
commit c31de41e2f
2 changed files with 48 additions and 18 deletions

View File

@ -1,5 +1,7 @@
from contextlib import contextmanager from contextlib import contextmanager
import numpy as np
import quapy as qp
import yaml import yaml
@ -42,8 +44,14 @@ class environ:
def __setdict(self, d: dict): def __setdict(self, d: dict):
for k, v in d.items(): for k, v in d.items():
super().__setattr__(k, v) super().__setattr__(k, v)
match k:
case "SAMPLE_SIZE":
qp.environ["SAMPLE_SIZE"] = v
case "_R_SEED":
qp.environ["_R_SEED"] = v
np.random.seed(v)
def __getdict(self) -> dict: def to_dict(self) -> dict:
return {k: self.__getattribute__(k) for k in environ._keys} return {k: self.__getattribute__(k) for k in environ._keys}
@property @property
@ -52,16 +60,22 @@ class environ:
@contextmanager @contextmanager
def load(self, conf): def load(self, conf):
__current = self.__getdict() __current = self.to_dict()
if conf is not None: __np_random_state = np.random.get_state()
if isinstance(conf, dict):
if conf is None:
conf = {}
if isinstance(conf, environ):
conf = conf.to_dict()
self.__setdict(conf) self.__setdict(conf)
elif isinstance(conf, environ):
self.__setdict(conf.__getdict())
try: try:
yield yield
finally: finally:
self.__setdict(__current) self.__setdict(__current)
np.random.set_state(__np_random_state)
def load_confs(self): def load_confs(self):
for c in self.confs: for c in self.confs:

View File

@ -9,7 +9,8 @@ import pandas as pd
from joblib import Parallel, delayed from joblib import Parallel, delayed
from tqdm import tqdm from tqdm import tqdm
from quacc.environment import env from quacc import logger
from quacc.environment import env, environ
def combine_dataframes(dfs, df_index=[]) -> pd.DataFrame: def combine_dataframes(dfs, df_index=[]) -> pd.DataFrame:
@ -77,16 +78,31 @@ def download_file(url: str, downloaded_path: Path):
urlretrieve(url, filename=downloaded_path, reporthook=t.update_to) urlretrieve(url, filename=downloaded_path, reporthook=t.update_to)
def parallel(func, args, n_jobs, seed=None): def parallel(
""" func,
A wrapper of multiprocessing: f_args=None,
parallel: Parallel = None,
n_jobs=1,
verbose=0,
_env: environ | dict = None,
seed=None,
):
f_args = f_args or []
>>> Parallel(n_jobs=n_jobs)( if _env is None:
>>> delayed(func)(args_i) for args_i in args _env = {}
>>> ) elif isinstance(_env, environ):
_env = _env.to_dict()
that takes the `quapy.environ` variable as input silently. def wrapper(*args):
Seeds the child processes to ensure reproducibility when n_jobs>1 if seed is not None:
""" nonlocal _env
_env = _env | dict(_R_SEED=seed)
return Parallel(n_jobs=n_jobs, verbose=1)(delayed(func)(_args) for _args in args) with env.load(_env):
return func(*args)
parallel = (
Parallel(n_jobs=n_jobs, verbose=verbose) if parallel is None else parallel
)
return parallel(delayed(wrapper)(*_args) for _args in f_args)