parallel code improved, env improved

This commit is contained in:
Lorenzo Volpi 2023-12-06 10:02:24 +01:00
parent 2e8af90543
commit c31de41e2f
2 changed files with 48 additions and 18 deletions

View File

@ -1,5 +1,7 @@
from contextlib import contextmanager
import numpy as np
import quapy as qp
import yaml
@ -42,8 +44,14 @@ class environ:
def __setdict(self, d: dict):
for k, v in d.items():
super().__setattr__(k, v)
match k:
case "SAMPLE_SIZE":
qp.environ["SAMPLE_SIZE"] = v
case "_R_SEED":
qp.environ["_R_SEED"] = v
np.random.seed(v)
def __getdict(self) -> dict:
def to_dict(self) -> dict:
return {k: self.__getattribute__(k) for k in environ._keys}
@property
@ -52,16 +60,22 @@ class environ:
@contextmanager
def load(self, conf):
__current = self.__getdict()
if conf is not None:
if isinstance(conf, dict):
__current = self.to_dict()
__np_random_state = np.random.get_state()
if conf is None:
conf = {}
if isinstance(conf, environ):
conf = conf.to_dict()
self.__setdict(conf)
elif isinstance(conf, environ):
self.__setdict(conf.__getdict())
try:
yield
finally:
self.__setdict(__current)
np.random.set_state(__np_random_state)
def load_confs(self):
for c in self.confs:

View File

@ -9,7 +9,8 @@ import pandas as pd
from joblib import Parallel, delayed
from tqdm import tqdm
from quacc.environment import env
from quacc import logger
from quacc.environment import env, environ
def combine_dataframes(dfs, df_index=[]) -> pd.DataFrame:
@ -77,16 +78,31 @@ def download_file(url: str, downloaded_path: Path):
urlretrieve(url, filename=downloaded_path, reporthook=t.update_to)
def parallel(func, args, n_jobs, seed=None):
"""
A wrapper of multiprocessing:
def parallel(
func,
f_args=None,
parallel: Parallel = None,
n_jobs=1,
verbose=0,
_env: environ | dict = None,
seed=None,
):
f_args = f_args or []
>>> Parallel(n_jobs=n_jobs)(
>>> delayed(func)(args_i) for args_i in args
>>> )
if _env is None:
_env = {}
elif isinstance(_env, environ):
_env = _env.to_dict()
that takes the `quapy.environ` variable as input silently.
Seeds the child processes to ensure reproducibility when n_jobs>1
"""
def wrapper(*args):
if seed is not None:
nonlocal _env
_env = _env | dict(_R_SEED=seed)
return Parallel(n_jobs=n_jobs, verbose=1)(delayed(func)(_args) for _args in args)
with env.load(_env):
return func(*args)
parallel = (
Parallel(n_jobs=n_jobs, verbose=verbose) if parallel is None else parallel
)
return parallel(delayed(wrapper)(*_args) for _args in f_args)