import argparse
import os
import shutil
from pathlib import Path

import numpy as np
import pandas as pd

from quacc.evaluation.estimators import CE
from quacc.evaluation.report import DatasetReport, DatasetReportInfo


def load_report_info(path: Path) -> DatasetReportInfo:
    return DatasetReport.unpickle(path, report_info=True)


def list_reports(base_path: Path | str):
    if isinstance(base_path, str):
        base_path = Path(base_path)

    if base_path.name == "plot":
        return []

    reports = []
    for f in os.listdir(base_path):
        fp = base_path / f
        if fp.is_dir():
            reports.extend(list_reports(fp))
        elif fp.is_file():
            if fp.suffix == ".pickle" and fp.stem == base_path.name:
                reports.append(load_report_info(fp))

    return reports


def playground():
    data_a = np.array(np.random.random((4, 6)))
    data_b = np.array(np.random.random((4, 4)))
    _ind1 = pd.MultiIndex.from_product([["0.2", "0.8"], ["0", "1"]])
    _col1 = pd.MultiIndex.from_product([["a", "b"], ["1", "2", "5"]])
    _col2 = pd.MultiIndex.from_product([["a", "b"], ["1", "2"]])
    a = pd.DataFrame(data_a, index=_ind1, columns=_col1)
    b = pd.DataFrame(data_b, index=_ind1, columns=_col2)
    print(a)
    print(b)
    print((a.index == b.index).all())
    update_col = a.columns.intersection(b.columns)
    col_to_join = b.columns.difference(update_col)
    _b = b.drop(columns=[(slice(None), "2")])
    _join = pd.concat([a, _b.loc[:, col_to_join]], axis=1)
    _join.loc[:, update_col.to_list()] = _b.loc[:, update_col.to_list()]
    _join.sort_index(axis=1, level=0, sort_remaining=False, inplace=True)

    print(_join)


def merge(dri1: DatasetReportInfo, dri2: DatasetReportInfo, path: Path):
    drm = dri1.dr.join(dri2.dr, estimators=CE.name.all)

    # save merged dr
    _path = path / drm.name / f"{drm.name}.pickle"
    os.makedirs(_path.parent, exist_ok=True)
    drm.pickle(_path)

    # rename dri1 pickle
    dri1_bp = Path(dri1.name) / f"{dri1.name.split('/')[-1]}.pickle"
    os.rename(dri1_bp, dri1_bp.with_suffix(f".pickle.pre_{dri2.name.split('/')[-2]}"))

    # copy merged pickle in place of old dri1 one
    shutil.copyfile(_path, dri1_bp)

    # copy dri2 log file inside dri1 folder
    dri2_bp = Path(dri2.name) / f"{dri2.name.split('/')[-1]}.pickle"
    shutil.copyfile(
        dri2_bp.with_suffix(".log"),
        dri1_bp.with_name(f"{dri1_bp.stem}_{dri2.name.split('/')[-2]}.log"),
    )


def run():
    parser = argparse.ArgumentParser()
    parser.add_argument("path1", nargs="?", default=None)
    parser.add_argument("path2", nargs="?", default=None)
    parser.add_argument("-l", "--list", action="store_true", dest="list")
    parser.add_argument("-v", "--verbose", action="store_true", dest="verbose")
    parser.add_argument(
        "-o", "--output", action="store", dest="output", default="output/merge"
    )
    args = parser.parse_args()

    reports = list_reports("output")
    reports = {r.name: r for r in reports}

    if args.list:
        for i, r in enumerate(reports.values()):
            if args.verbose:
                print(f"{i}: {r}")
            else:
                print(f"{i}: {r.name}")
    else:
        dri1, dri2 = reports.get(args.path1, None), reports.get(args.path2, None)
        if dri1 is None or dri2 is None:
            raise ValueError(
                f"({args.path1}, {args.path2}) is not a valid pair of paths"
            )
        merge(dri1, dri2, path=Path(args.output))


if __name__ == "__main__":
    run()