QuAcc/quacc/utils.py

109 lines
2.7 KiB
Python
Raw Normal View History

2023-11-08 17:26:44 +01:00
import functools
import os
import shutil
2023-12-02 02:12:20 +01:00
from contextlib import ExitStack
2023-11-08 17:26:44 +01:00
from pathlib import Path
from urllib.request import urlretrieve
2023-11-08 17:26:44 +01:00
import pandas as pd
2023-12-02 02:12:20 +01:00
from joblib import Parallel, delayed
from tqdm import tqdm
2023-11-08 17:26:44 +01:00
2023-12-06 10:02:24 +01:00
from quacc import logger
from quacc.environment import env, environ
2023-11-08 17:26:44 +01:00
def combine_dataframes(dfs, df_index=None) -> pd.DataFrame:
    """Left-join every DataFrame in ``dfs`` onto the first one.

    Each DataFrame after the first is indexed on the ``df_index`` columns and
    joined onto the accumulated result with ``DataFrame.join`` (pandas'
    default ``how="left"``), matching on the same ``df_index`` columns of the
    left-hand side. The rows of the result are therefore those of ``dfs[0]``.

    Args:
        dfs: Non-empty sequence of DataFrames.
        df_index: Column label(s) used as the join key. ``None`` (the default)
            is treated as an empty list; supply real key columns whenever
            ``dfs`` holds more than one DataFrame.

    Returns:
        ``dfs[0]`` itself when ``dfs`` has a single element, otherwise a new
        DataFrame carrying the columns of all inputs.

    Raises:
        ValueError: If ``dfs`` is empty.
    """
    # ``None`` sentinel instead of a mutable ``[]`` default shared across calls.
    if df_index is None:
        df_index = []
    if len(dfs) < 1:
        raise ValueError("combine_dataframes requires at least one DataFrame")
    if len(dfs) == 1:
        return dfs[0]
    df = dfs[0]
    for ndf in dfs[1:]:
        df = df.join(ndf.set_index(df_index), on=df_index)
    return df
def avg_group_report(df: pd.DataFrame) -> pd.DataFrame:
    """Average each non-"base" column over the rows between the first and last.

    The first and last rows of ``df`` are skipped. Every remaining row is read
    as a record keyed by a two-level ``(top, sub)`` column label. Values are
    summed per column and divided by the number of rows used. Columns whose
    top level is ``"base"`` are left out of the averages.

    NOTE(review): why the first and last rows are excluded is not visible
    here (presumably they are non-data summary rows) -- confirm with callers.

    Returns:
        A one-row DataFrame whose columns are ``df.columns`` minus
        ``("base", "T")`` and ``("base", "F")``. Any other ``"base"`` column
        still present would come out as NaN.

    Raises:
        TypeError: From ``functools.reduce`` when ``df`` has fewer than three
            rows (nothing left to sum).
    """
    records = df.to_dict(orient="records")[1:-1]

    def _accumulate(totals, record):
        # Unpacking the key enforces two-level (top, sub) column labels.
        return {
            (top, sub): value + record[(top, sub)]
            for (top, sub), value in totals.items()
        }

    summed = functools.reduce(_accumulate, records)
    columns = df.columns.drop([("base", "T"), ("base", "F")])
    count = len(records)

    averages = {}
    for (top, sub), total in summed.items():
        if top == "base":
            continue
        averages[(top, sub)] = total / count
    return pd.DataFrame([averages], columns=columns)
def fmt_line_md(s):
    """Format ``s`` as a ``"> "``-prefixed (Markdown blockquote) line ending in a newline."""
    line = f"> {s} \n"
    return line
2023-11-26 16:33:43 +01:00
def create_dataser_dir(dir_name, update=False):
    """Point ``env.OUT_DIR`` at ``<env.OUT_DIR_NAME>/<dir_name>`` and create it.

    NOTE: "dataser" is presumably a typo for "dataset"; the name is kept so
    existing callers keep working.

    Args:
        dir_name: Sub-directory name under ``env.OUT_DIR_NAME``.
        update: When true, an existing directory and its contents are kept.
            When false, any existing directory is removed first (removal
            errors are ignored), so creation then fails with
            ``FileExistsError`` if the old tree could not be fully removed.
    """
    target = Path(env.OUT_DIR_NAME) / dir_name
    env.OUT_DIR = target

    if not update:
        # Start from an empty directory: wipe whatever a previous run left.
        shutil.rmtree(target, ignore_errors=True)
        target.mkdir(parents=True)
        return

    target.mkdir(parents=True, exist_ok=True)
2023-11-16 01:36:18 +01:00
def get_quacc_home():
    """Return ``~/quacc_home`` as an absolute ``Path``, creating it if missing.

    Missing parent directories are created as well; an already existing
    directory is left as is.
    """
    quacc_home = Path("~", "quacc_home").expanduser()
    quacc_home.mkdir(parents=True, exist_ok=True)
    return quacc_home
class TqdmUpTo(tqdm):
    """``tqdm`` progress bar that can be driven by a ``reporthook`` callback.

    ``update_to`` has the ``(block_count, block_size, total_size)`` signature
    that ``urllib.request.urlretrieve`` passes to its ``reporthook``.
    """

    def update_to(self, b=1, bsize=1, tsize=None):
        """Move the bar to ``b * bsize`` units completed.

        Args:
            b: Number of blocks transferred so far.
            bsize: Size of each block, in the bar's units.
            tsize: Total size, if known; when given it replaces ``self.total``.
        """
        if tsize is not None:
            self.total = tsize
        done = b * bsize
        # ``update`` takes an increment, so pass the delta from the current position.
        self.update(done - self.n)
def download_file(url: str, downloaded_path: Path):
    """Download ``url`` to ``downloaded_path``, showing a byte-count progress bar.

    Args:
        url: Address handed to ``urllib.request.urlretrieve``.
        downloaded_path: Destination file. Any ``str`` or path-like value is
            accepted; the bar is labelled with its file name.
    """
    # Normalise to Path so plain string paths work too (``.name`` needs a Path).
    target = Path(downloaded_path)
    with TqdmUpTo(
        unit="B",
        unit_scale=True,
        unit_divisor=1024,  # scale byte counts by powers of 1024
        miniters=1,
        desc=target.name,
    ) as bar:
        urlretrieve(url, filename=target, reporthook=bar.update_to)
2023-12-02 02:12:20 +01:00
2023-12-06 10:02:24 +01:00
def parallel(
    func,
    f_args=None,
    parallel: Parallel = None,
    n_jobs=1,
    verbose=0,
    _env: environ | dict = None,
    seed=None,
):
    """Run ``func(*args)`` for every ``args`` in ``f_args`` through joblib.

    Each call runs inside ``env.load(<env dict>)``. Presumably that applies
    the given settings for the duration of the call -- see
    ``quacc.environment`` to confirm.

    Args:
        func: Callable to run for each task.
        f_args: Iterable of argument sequences; each is unpacked as
            ``func(*args)``. ``None`` (or any falsy value) means no tasks.
        parallel: Pre-built ``joblib.Parallel`` to use. When ``None``, one is
            created from ``n_jobs`` and ``verbose``.
        n_jobs: Passed to ``Parallel`` only when ``parallel`` is ``None``.
        verbose: Passed to ``Parallel`` only when ``parallel`` is ``None``.
        _env: Settings to load around every call. Either an ``environ``
            (converted with ``to_dict()``) or a plain dict; ``None`` means ``{}``.
        seed: If not ``None``, added to the settings under ``_R_SEED``. Every
            task receives the same seed.

    Returns:
        Whatever the ``Parallel`` instance returns for the generated tasks;
        with joblib's defaults, a list of results in ``f_args`` order.
    """
    f_args = f_args or []
    if _env is None:
        _env = {}
    elif isinstance(_env, environ):
        _env = _env.to_dict()
    # Merge the seed once before dispatching. Each task used to rebind the
    # enclosing ``_env`` via ``nonlocal`` -- redundant (the result is the same
    # for every task) and racy under a threading backend. ``|`` builds a new
    # dict, so a caller-supplied dict is never mutated.
    if seed is not None:
        _env = _env | dict(_R_SEED=seed)

    def wrapper(*args):
        with env.load(_env):
            return func(*args)

    parallel = (
        Parallel(n_jobs=n_jobs, verbose=verbose) if parallel is None else parallel
    )
    return parallel(delayed(wrapper)(*_args) for _args in f_args)