diff --git a/quacc/__init__.py b/quacc/__init__.py index de5f84d..8a6c37b 100644 --- a/quacc/__init__.py +++ b/quacc/__init__.py @@ -3,3 +3,8 @@ import quacc.error as error import quacc.logger as logger import quacc.plot as plot import quacc.utils as utils +from quacc.environment import env + + +def _get_njobs(n_jobs): + return env.N_JOBS if n_jobs is None else n_jobs diff --git a/quacc/utils.py b/quacc/utils.py index 0310d69..c798884 100644 --- a/quacc/utils.py +++ b/quacc/utils.py @@ -1,10 +1,12 @@ import functools import os import shutil +from contextlib import ExitStack from pathlib import Path from urllib.request import urlretrieve import pandas as pd +from joblib import Parallel, delayed from tqdm import tqdm from quacc.environment import env @@ -73,3 +75,18 @@ def download_file(url: str, downloaded_path: Path): desc=downloaded_path.name, ) as t: urlretrieve(url, filename=downloaded_path, reporthook=t.update_to) + + +def parallel(func, args, n_jobs, seed=None): + """ + A wrapper around joblib's `Parallel`: + + >>> Parallel(n_jobs=n_jobs)( + >>> delayed(func)(args_i) for args_i in args + >>> ) + + run with `verbose=1`; no environment object is forwarded to the workers. + NOTE: `seed` is accepted but currently unused, so child processes are not seeded. + """ + + return Parallel(n_jobs=n_jobs, verbose=1)(delayed(func)(_args) for _args in args)