parallel code improved, env improved
This commit is contained in:
parent
2e8af90543
commit
c31de41e2f
|
@ -1,5 +1,7 @@
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import quapy as qp
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
@ -42,8 +44,14 @@ class environ:
|
||||||
def __setdict(self, d: dict):
|
def __setdict(self, d: dict):
|
||||||
for k, v in d.items():
|
for k, v in d.items():
|
||||||
super().__setattr__(k, v)
|
super().__setattr__(k, v)
|
||||||
|
match k:
|
||||||
|
case "SAMPLE_SIZE":
|
||||||
|
qp.environ["SAMPLE_SIZE"] = v
|
||||||
|
case "_R_SEED":
|
||||||
|
qp.environ["_R_SEED"] = v
|
||||||
|
np.random.seed(v)
|
||||||
|
|
||||||
def __getdict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
return {k: self.__getattribute__(k) for k in environ._keys}
|
return {k: self.__getattribute__(k) for k in environ._keys}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -52,16 +60,22 @@ class environ:
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def load(self, conf):
|
def load(self, conf):
|
||||||
__current = self.__getdict()
|
__current = self.to_dict()
|
||||||
if conf is not None:
|
__np_random_state = np.random.get_state()
|
||||||
if isinstance(conf, dict):
|
|
||||||
self.__setdict(conf)
|
if conf is None:
|
||||||
elif isinstance(conf, environ):
|
conf = {}
|
||||||
self.__setdict(conf.__getdict())
|
|
||||||
|
if isinstance(conf, environ):
|
||||||
|
conf = conf.to_dict()
|
||||||
|
|
||||||
|
self.__setdict(conf)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
self.__setdict(__current)
|
self.__setdict(__current)
|
||||||
|
np.random.set_state(__np_random_state)
|
||||||
|
|
||||||
def load_confs(self):
|
def load_confs(self):
|
||||||
for c in self.confs:
|
for c in self.confs:
|
||||||
|
|
|
@ -9,7 +9,8 @@ import pandas as pd
|
||||||
from joblib import Parallel, delayed
|
from joblib import Parallel, delayed
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from quacc.environment import env
|
from quacc import logger
|
||||||
|
from quacc.environment import env, environ
|
||||||
|
|
||||||
|
|
||||||
def combine_dataframes(dfs, df_index=[]) -> pd.DataFrame:
|
def combine_dataframes(dfs, df_index=[]) -> pd.DataFrame:
|
||||||
|
@ -77,16 +78,31 @@ def download_file(url: str, downloaded_path: Path):
|
||||||
urlretrieve(url, filename=downloaded_path, reporthook=t.update_to)
|
urlretrieve(url, filename=downloaded_path, reporthook=t.update_to)
|
||||||
|
|
||||||
|
|
||||||
def parallel(func, args, n_jobs, seed=None):
|
def parallel(
|
||||||
"""
|
func,
|
||||||
A wrapper of multiprocessing:
|
f_args=None,
|
||||||
|
parallel: Parallel = None,
|
||||||
|
n_jobs=1,
|
||||||
|
verbose=0,
|
||||||
|
_env: environ | dict = None,
|
||||||
|
seed=None,
|
||||||
|
):
|
||||||
|
f_args = f_args or []
|
||||||
|
|
||||||
>>> Parallel(n_jobs=n_jobs)(
|
if _env is None:
|
||||||
>>> delayed(func)(args_i) for args_i in args
|
_env = {}
|
||||||
>>> )
|
elif isinstance(_env, environ):
|
||||||
|
_env = _env.to_dict()
|
||||||
|
|
||||||
that takes the `quapy.environ` variable as input silently.
|
def wrapper(*args):
|
||||||
Seeds the child processes to ensure reproducibility when n_jobs>1
|
if seed is not None:
|
||||||
"""
|
nonlocal _env
|
||||||
|
_env = _env | dict(_R_SEED=seed)
|
||||||
|
|
||||||
return Parallel(n_jobs=n_jobs, verbose=1)(delayed(func)(_args) for _args in args)
|
with env.load(_env):
|
||||||
|
return func(*args)
|
||||||
|
|
||||||
|
parallel = (
|
||||||
|
Parallel(n_jobs=n_jobs, verbose=verbose) if parallel is None else parallel
|
||||||
|
)
|
||||||
|
return parallel(delayed(wrapper)(*_args) for _args in f_args)
|
||||||
|
|
Loading…
Reference in New Issue