parallalization utilities added
This commit is contained in:
parent
116d961313
commit
ad7a8f04a3
|
@ -3,3 +3,8 @@ import quacc.error as error
|
||||||
import quacc.logger as logger
|
import quacc.logger as logger
|
||||||
import quacc.plot as plot
|
import quacc.plot as plot
|
||||||
import quacc.utils as utils
|
import quacc.utils as utils
|
||||||
|
from quacc.environment import env
|
||||||
|
|
||||||
|
|
||||||
|
def _get_njobs(n_jobs):
|
||||||
|
return env.N_JOBS if n_jobs is None else n_jobs
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
import functools
|
import functools
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
from contextlib import ExitStack
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from urllib.request import urlretrieve
|
from urllib.request import urlretrieve
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from joblib import Parallel, delayed
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from quacc.environment import env
|
from quacc.environment import env
|
||||||
|
@ -73,3 +75,18 @@ def download_file(url: str, downloaded_path: Path):
|
||||||
desc=downloaded_path.name,
|
desc=downloaded_path.name,
|
||||||
) as t:
|
) as t:
|
||||||
urlretrieve(url, filename=downloaded_path, reporthook=t.update_to)
|
urlretrieve(url, filename=downloaded_path, reporthook=t.update_to)
|
||||||
|
|
||||||
|
|
||||||
|
def parallel(func, args, n_jobs, seed=None):
|
||||||
|
"""
|
||||||
|
A wrapper of multiprocessing:
|
||||||
|
|
||||||
|
>>> Parallel(n_jobs=n_jobs)(
|
||||||
|
>>> delayed(func)(args_i) for args_i in args
|
||||||
|
>>> )
|
||||||
|
|
||||||
|
that takes the `quapy.environ` variable as input silently.
|
||||||
|
Seeds the child processes to ensure reproducibility when n_jobs>1
|
||||||
|
"""
|
||||||
|
|
||||||
|
return Parallel(n_jobs=n_jobs, verbose=1)(delayed(func)(_args) for _args in args)
|
||||||
|
|
Loading…
Reference in New Issue