QuAcc/merge_data.py

111 lines
3.5 KiB
Python
Raw Normal View History

2023-11-22 19:27:22 +01:00
import argparse
import os
import shutil
from pathlib import Path
import numpy as np
import pandas as pd
from quacc.evaluation.comp import CE
from quacc.evaluation.report import DatasetReport, DatasetReportInfo
def load_report_info(path: Path) -> DatasetReportInfo:
    """Load the report-info object stored in the report pickle at *path*.

    Delegates to ``DatasetReport.unpickle`` with ``report_info=True``.

    NOTE(review): if ``unpickle`` is backed by ``pickle`` (as its name
    suggests), only point this at trusted, locally produced report files.
    """
    report_info = DatasetReport.unpickle(path, report_info=True)
    return report_info
def list_reports(base_path: Path | str):
if isinstance(base_path, str):
base_path = Path(base_path)
if base_path.name == "plot":
return []
reports = []
for f in os.listdir(base_path):
fp = base_path / f
if fp.is_dir():
reports.extend(list_reports(fp))
elif fp.is_file():
if fp.suffix == ".pickle" and fp.stem == base_path.name:
reports.append(load_report_info(fp))
return reports
def playground():
    """Scratch experiment: join two frames that have 2-level column MultiIndexes.

    Builds two random frames that share the same 2-level row index:
      - ``a`` has columns ``{a, b} x {1, 2, 5}``;
      - ``b`` has columns ``{a, b} x {1, 2}``.

    The level-1 ``"2"`` columns are dropped from ``b``. Columns found only in
    the trimmed ``b`` are appended to ``a``, and columns found in both are
    overwritten with the trimmed ``b``'s values. The result is sorted by the
    top column level. The inputs, an index-equality check and the result are
    printed.

    Returns:
        The joined DataFrame. Earlier versions returned nothing, so callers
        that ignore the return value are unaffected.
    """
    data_a = np.random.random((4, 6))
    data_b = np.random.random((4, 4))
    _ind1 = pd.MultiIndex.from_product([["0.2", "0.8"], ["0", "1"]])
    _col1 = pd.MultiIndex.from_product([["a", "b"], ["1", "2", "5"]])
    _col2 = pd.MultiIndex.from_product([["a", "b"], ["1", "2"]])
    a = pd.DataFrame(data_a, index=_ind1, columns=_col1)
    b = pd.DataFrame(data_b, index=_ind1, columns=_col2)
    print(a)
    print(b)
    print((a.index == b.index).all())

    # Drop every column whose second level is "2". The former
    # ``b.drop(columns=[(slice(None), "2")])`` passed a slice as a column
    # label, which ``drop`` does not accept.
    _b = b.drop(columns="2", level=1)

    # Compute overlap/extra columns against the *trimmed* frame. Computing
    # them from ``b`` kept the dropped ("*", "2") columns in the update set,
    # and selecting those from ``_b`` raised KeyError.
    update_col = a.columns.intersection(_b.columns)
    col_to_join = _b.columns.difference(update_col)

    if len(col_to_join):
        _join = pd.concat([a, _b.loc[:, col_to_join]], axis=1)
    else:
        _join = a.copy()
    # Overwrite the shared columns with the trimmed b's values (index-aligned).
    for col in update_col:
        _join[col] = _b[col]
    _join = _join.sort_index(axis=1, level=0, sort_remaining=False)
    print(_join)
    return _join
def merge(dri1: DatasetReportInfo, dri2: DatasetReportInfo, path: Path):
    """Join ``dri2``'s report into ``dri1``'s and install the result in place.

    Steps, in this exact order:
      1. call ``dri1.dr.join(dri2.dr, estimators=CE.name.all)`` and pickle the
         merged report to ``path / <merged name> / <merged name>.pickle``;
      2. rename ``dri1``'s original pickle, appending ``.pre_<tag>`` to its
         suffix, where ``<tag>`` is the second-to-last ``/``-separated
         component of ``dri2.name``;
      3. copy the merged pickle to ``dri1``'s original pickle location;
      4. copy ``dri2``'s ``.log`` file next to it as ``<stem>_<tag>.log``.

    Args:
        dri1: report info whose pickle is replaced by the merged report.
        dri2: report info merged into ``dri1``.
        path: root directory under which the merged pickle is written.
    """
    merged = dri1.dr.join(dri2.dr, estimators=CE.name.all)

    # 1. persist the merged report under the output root
    merged_pickle = path / merged.name / f"{merged.name}.pickle"
    os.makedirs(merged_pickle.parent, exist_ok=True)
    merged.pickle(merged_pickle)

    # 2. move dri1's old pickle aside, tagged with dri2's second-to-last name part
    dri1_pickle = Path(dri1.name) / f"{dri1.name.split('/')[-1]}.pickle"
    tag = dri2.name.split("/")[-2]
    os.rename(dri1_pickle, dri1_pickle.with_suffix(f".pickle.pre_{tag}"))

    # 3. the merged pickle takes the place of dri1's
    shutil.copyfile(merged_pickle, dri1_pickle)

    # 4. bring dri2's log file into dri1's folder, tagged the same way
    dri2_pickle = Path(dri2.name) / f"{dri2.name.split('/')[-1]}.pickle"
    log_src = dri2_pickle.with_suffix(".log")
    log_dst = dri1_pickle.with_name(f"{dri1_pickle.stem}_{tag}.log")
    shutil.copyfile(log_src, log_dst)
def run():
    """Command-line entry point: list the available reports or merge two of them.

    Every report found under ``output`` is indexed by its ``name``.
      - With ``-l/--list``, each report is printed with an enumeration index:
        its name, or its full ``str`` form when ``-v/--verbose`` is also given.
      - Otherwise ``path1`` and ``path2`` must both be report names, and the
        pair is passed to ``merge`` with ``-o/--output`` as the output root.

    Raises:
        ValueError: if either positional name does not match a known report.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("path1", nargs="?", default=None)
    parser.add_argument("path2", nargs="?", default=None)
    parser.add_argument("-l", "--list", action="store_true", dest="list")
    parser.add_argument("-v", "--verbose", action="store_true", dest="verbose")
    parser.add_argument(
        "-o", "--output", action="store", dest="output", default="output/merge"
    )
    args = parser.parse_args()

    by_name = {info.name: info for info in list_reports("output")}

    if not args.list:
        first = by_name.get(args.path1, None)
        second = by_name.get(args.path2, None)
        if first is None or second is None:
            raise ValueError(
                f"({args.path1}, {args.path2}) is not a valid pair of paths"
            )
        merge(first, second, path=Path(args.output))
        return

    for idx, info in enumerate(by_name.values()):
        shown = info if args.verbose else info.name
        print(f"{idx}: {shown}")


if __name__ == "__main__":
    # Only act when executed as a script, never on import.
    run()