parallalization utilities added
This commit is contained in:
parent
116d961313
commit
ad7a8f04a3
|
@ -3,3 +3,8 @@ import quacc.error as error
|
|||
import quacc.logger as logger
|
||||
import quacc.plot as plot
|
||||
import quacc.utils as utils
|
||||
from quacc.environment import env
|
||||
|
||||
|
||||
def _get_njobs(n_jobs):
|
||||
return env.N_JOBS if n_jobs is None else n_jobs
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
import functools
|
||||
import os
|
||||
import shutil
|
||||
from contextlib import ExitStack
|
||||
from pathlib import Path
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
import pandas as pd
|
||||
from joblib import Parallel, delayed
|
||||
from tqdm import tqdm
|
||||
|
||||
from quacc.environment import env
|
||||
|
@ -73,3 +75,18 @@ def download_file(url: str, downloaded_path: Path):
|
|||
desc=downloaded_path.name,
|
||||
) as t:
|
||||
urlretrieve(url, filename=downloaded_path, reporthook=t.update_to)
|
||||
|
||||
|
||||
def parallel(func, args, n_jobs, seed=None):
|
||||
"""
|
||||
A wrapper of multiprocessing:
|
||||
|
||||
>>> Parallel(n_jobs=n_jobs)(
|
||||
>>> delayed(func)(args_i) for args_i in args
|
||||
>>> )
|
||||
|
||||
that takes the `quapy.environ` variable as input silently.
|
||||
Seeds the child processes to ensure reproducibility when n_jobs>1
|
||||
"""
|
||||
|
||||
return Parallel(n_jobs=n_jobs, verbose=1)(delayed(func)(_args) for _args in args)
|
||||
|
|
Loading…
Reference in New Issue