merge_data CLI module added
parent 846ef2ba53
commit 907bd3a770

@@ -0,0 +1,110 @@
import argparse
import os
import shutil
from pathlib import Path

import numpy as np
import pandas as pd

from quacc.evaluation.comp import CE
from quacc.evaluation.report import DatasetReport, DatasetReportInfo

def load_report_info(path: Path) -> DatasetReportInfo:
    return DatasetReport.unpickle(path, report_info=True)

def list_reports(base_path: Path | str):
    if isinstance(base_path, str):
        base_path = Path(base_path)

    if base_path.name == "plot":
        return []

    reports = []
    for f in os.listdir(base_path):
        fp = base_path / f
        if fp.is_dir():
            reports.extend(list_reports(fp))
        elif fp.is_file():
            if fp.suffix == ".pickle" and fp.stem == base_path.name:
                reports.append(load_report_info(fp))

    return reports

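# The recursive walk above collects a pickle only when its stem matches the
# name of the enclosing directory, and skips any directory named "plot".
# A sketch of the layout it expects (directory and report names below are
# illustrative only, inferred from that matching rule):
#
#   output/
#       some_report/
#           some_report.pickle   <- collected (stem == parent dir name)
#           some_report.log
#           plot/                <- ignored
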
def playground():
    # scratch experiment with merging DataFrames over partially overlapping
    # MultiIndex columns; not invoked by the CLI entry point below
    data_a = np.array(np.random.random((4, 6)))
    data_b = np.array(np.random.random((4, 4)))
    _ind1 = pd.MultiIndex.from_product([["0.2", "0.8"], ["0", "1"]])
    _col1 = pd.MultiIndex.from_product([["a", "b"], ["1", "2", "5"]])
    _col2 = pd.MultiIndex.from_product([["a", "b"], ["1", "2"]])
    a = pd.DataFrame(data_a, index=_ind1, columns=_col1)
    b = pd.DataFrame(data_b, index=_ind1, columns=_col2)
    print(a)
    print(b)
    print((a.index == b.index).all())
    update_col = a.columns.intersection(b.columns)
    col_to_join = b.columns.difference(update_col)
    _b = b.drop(columns=[(slice(None), "2")])
    _join = pd.concat([a, _b.loc[:, col_to_join]], axis=1)
    _join.loc[:, update_col.to_list()] = _b.loc[:, update_col.to_list()]
    _join.sort_index(axis=1, level=0, sort_remaining=False, inplace=True)

    print(_join)

def merge(dri1: DatasetReportInfo, dri2: DatasetReportInfo, path: Path):
    drm = dri1.dr.join(dri2.dr, estimators=CE.name.all)

    # save merged dr
    _path = path / drm.name / f"{drm.name}.pickle"
    os.makedirs(_path.parent, exist_ok=True)
    drm.pickle(_path)

    # rename dri1 pickle
    dri1_bp = Path(dri1.name) / f"{dri1.name.split('/')[-1]}.pickle"
    os.rename(dri1_bp, dri1_bp.with_suffix(f".pickle.pre_{dri2.name.split('/')[-2]}"))

    # copy merged pickle in place of old dri1 one
    shutil.copyfile(_path, dri1_bp)

    # copy dri2 log file inside dri1 folder
    dri2_bp = Path(dri2.name) / f"{dri2.name.split('/')[-1]}.pickle"
    shutil.copyfile(
        dri2_bp.with_suffix(".log"),
        dri1_bp.with_name(f"{dri1_bp.stem}_{dri2.name.split('/')[-2]}.log"),
    )

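# Net effect of merge(), inferred from the copy/rename calls above (names in
# angle brackets are placeholders for the actual report paths):
#
#   <output>/<merged>/<merged>.pickle        the merged report
#   <dri1 dir>/<dri1>.pickle                 overwritten with the merged report
#   <dri1 dir>/<dri1>.pickle.pre_<dri2>      backup of dri1's original pickle
#   <dri1 dir>/<dri1>_<dri2>.log             copy of dri2's log file
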
def run():
    parser = argparse.ArgumentParser()
    parser.add_argument("path1", nargs="?", default=None)
    parser.add_argument("path2", nargs="?", default=None)
    parser.add_argument("-l", "--list", action="store_true", dest="list")
    parser.add_argument("-v", "--verbose", action="store_true", dest="verbose")
    parser.add_argument(
        "-o", "--output", action="store", dest="output", default="output/merge"
    )
    args = parser.parse_args()

    reports = list_reports("output")
    reports = {r.name: r for r in reports}

    if args.list:
        for i, r in enumerate(reports.values()):
            if args.verbose:
                print(f"{i}: {r}")
            else:
                print(f"{i}: {r.name}")
    else:
        dri1, dri2 = reports.get(args.path1, None), reports.get(args.path2, None)
        if dri1 is None or dri2 is None:
            raise ValueError(
                f"({args.path1}, {args.path2}) is not a valid pair of paths"
            )
        merge(dri1, dri2, path=Path(args.output))

if __name__ == "__main__":
    run()
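# Example invocations (a sketch; the actual entry point depends on how the
# package exposes this module, here assumed to be run directly as a script):
#
#   python merge_data.py --list --verbose                          # enumerate reports found under output/
#   python merge_data.py <report name 1> <report name 2> -o output/merge
#
# path1 and path2 must be report names exactly as printed by --list.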