2023-11-22 19:27:22 +01:00
|
|
|
import argparse
|
|
|
|
import os
|
|
|
|
import shutil
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import pandas as pd
|
|
|
|
|
2023-11-27 03:27:43 +01:00
|
|
|
from quacc.evaluation.estimators import CE
|
2023-11-22 19:27:22 +01:00
|
|
|
from quacc.evaluation.report import DatasetReport, DatasetReportInfo
|
|
|
|
|
|
|
|
|
|
|
|
def load_report_info(path: Path) -> DatasetReportInfo:
    """Load the report summary stored in the pickle file at ``path``.

    Delegates to ``DatasetReport.unpickle`` with ``report_info=True`` and
    returns whatever that call produces.

    Args:
        path: location of the pickled report.
    """
    info = DatasetReport.unpickle(path, report_info=True)
    return info
|
|
|
|
|
|
|
|
|
|
|
|
def list_reports(base_path: Path | str):
|
|
|
|
if isinstance(base_path, str):
|
|
|
|
base_path = Path(base_path)
|
|
|
|
|
|
|
|
if base_path.name == "plot":
|
|
|
|
return []
|
|
|
|
|
|
|
|
reports = []
|
|
|
|
for f in os.listdir(base_path):
|
|
|
|
fp = base_path / f
|
|
|
|
if fp.is_dir():
|
|
|
|
reports.extend(list_reports(fp))
|
|
|
|
elif fp.is_file():
|
|
|
|
if fp.suffix == ".pickle" and fp.stem == base_path.name:
|
|
|
|
reports.append(load_report_info(fp))
|
|
|
|
|
|
|
|
return reports
|
|
|
|
|
|
|
|
|
|
|
|
def playground():
    """Scratch experiment: join two random multi-indexed frames and print them.

    Builds a 4x6 frame and a 4x4 frame that share the same row index, drops
    the ``"2"`` sub-columns from the second one, appends the columns found
    only in the second frame, overwrites the shared columns with the second
    frame's values and prints the column-sorted result.
    """
    values_left = np.array(np.random.random((4, 6)))
    values_right = np.array(np.random.random((4, 4)))

    row_index = pd.MultiIndex.from_product([["0.2", "0.8"], ["0", "1"]])
    wide_columns = pd.MultiIndex.from_product([["a", "b"], ["1", "2", "5"]])
    narrow_columns = pd.MultiIndex.from_product([["a", "b"], ["1", "2"]])

    a = pd.DataFrame(values_left, index=row_index, columns=wide_columns)
    b = pd.DataFrame(values_right, index=row_index, columns=narrow_columns)
    print(a)
    print(b)
    print((a.index == b.index).all())

    # Columns present in both frames are updated; the remaining columns of
    # ``b`` are appended.
    shared = a.columns.intersection(b.columns)
    only_in_b = b.columns.difference(shared)

    # NOTE(review): ``shared`` is computed before this drop, so it still lists
    # the ("a", "2") / ("b", "2") labels removed here -- confirm the update
    # below behaves as intended with the installed pandas version.
    trimmed = b.drop(columns=[(slice(None), "2")])

    joined = pd.concat([a, trimmed.loc[:, only_in_b]], axis=1)
    shared_labels = shared.to_list()
    joined.loc[:, shared_labels] = trimmed.loc[:, shared_labels]

    # Order by the outer column level only, keeping the inner order as is.
    joined = joined.sort_index(axis=1, level=0, sort_remaining=False)

    print(joined)
|
|
|
|
|
|
|
|
|
|
|
|
def merge(dri1: DatasetReportInfo, dri2: DatasetReportInfo, path: Path):
    """Join two dataset reports and put the result in place of the first one.

    Steps, in order:

    1. join ``dri1.dr`` with ``dri2.dr`` over ``CE.name.all`` and pickle the
       result to ``path / <merged name> / <merged name>.pickle``;
    2. rename the pickle of ``dri1`` by appending ``.pre_<tag>`` to it;
    3. copy the merged pickle to the original location of the ``dri1`` pickle;
    4. copy the ``.log`` file of ``dri2`` next to it as ``<stem>_<tag>.log``.

    ``<tag>`` is the second-to-last ``/``-separated component of
    ``dri2.name``.

    Args:
        dri1: info of the report that receives the merge.
        dri2: info of the report merged into ``dri1``.
        path: base directory under which the merged report is saved.
    """
    merged = dri1.dr.join(dri2.dr, estimators=CE.name.all)

    # save merged dr
    merged_pickle = path / merged.name / f"{merged.name}.pickle"
    merged_pickle.parent.mkdir(parents=True, exist_ok=True)
    merged.pickle(merged_pickle)

    # rename dri1 pickle
    first_pickle = Path(dri1.name) / f"{dri1.name.split('/')[-1]}.pickle"
    tag = dri2.name.split('/')[-2]
    first_pickle.rename(first_pickle.with_suffix(f".pickle.pre_{tag}"))

    # copy merged pickle in place of old dri1 one
    shutil.copyfile(merged_pickle, first_pickle)

    # copy dri2 log file inside dri1 folder
    second_pickle = Path(dri2.name) / f"{dri2.name.split('/')[-1]}.pickle"
    log_source = second_pickle.with_suffix(".log")
    log_target = first_pickle.with_name(f"{first_pickle.stem}_{tag}.log")
    shutil.copyfile(log_source, log_target)
|
|
|
|
|
|
|
|
|
|
|
|
def run():
    """Command-line entry point: list the available reports or merge two.

    All reports found under the ``output`` directory are indexed by name.
    With ``-l/--list`` they are printed one per line (the whole info object
    with ``-v/--verbose``, only the name otherwise). Without it, ``path1``
    and ``path2`` are looked up among the report names and the two reports
    are merged into the directory given by ``-o/--output``.

    Raises:
        ValueError: if ``path1`` or ``path2`` is not the name of a known
            report.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("path1", nargs="?", default=None)
    cli.add_argument("path2", nargs="?", default=None)
    cli.add_argument("-l", "--list", action="store_true", dest="list")
    cli.add_argument("-v", "--verbose", action="store_true", dest="verbose")
    cli.add_argument(
        "-o", "--output", action="store", dest="output", default="output/merge"
    )
    args = cli.parse_args()

    by_name = {info.name: info for info in list_reports("output")}

    if args.list:
        for position, info in enumerate(by_name.values()):
            line = f"{position}: {info}" if args.verbose else f"{position}: {info.name}"
            print(line)
        return

    dri1 = by_name.get(args.path1, None)
    dri2 = by_name.get(args.path2, None)
    if dri1 is None or dri2 is None:
        raise ValueError(
            f"({args.path1}, {args.path2}) is not a valid pair of paths"
        )
    merge(dri1, dri2, path=Path(args.output))
|
|
|
|
|
|
|
|
|
|
|
|
# Start the command-line interface only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    run()
|