import os
from collections import defaultdict
from pathlib import Path

import panel as pn

from quacc.evaluation.comp import CE
from quacc.evaluation.report import DatasetReport

# Sizing policy handed to every Matplotlib pane built by create_plots.
_plot_sizing_mode = "stretch_both"

# Mode names offered by every view, in the order they are listed.
_shared_plot_modes = (
    "delta_train",
    "stdev_train",
    "train_table",
    "shift",
    "shift_table",
)


def _default_plot_modes():
    """Return a fresh list of the modes available for a non-"avg" view."""
    return [*_shared_plot_modes, "diagonal"]


# Maps a plot view to the ordered list of modes it supports. The first entry
# of a list is the mode create_plots falls back to when called with mode=None.
# Any view without an explicit entry gets the default list on first access.
valid_plot_modes = defaultdict(_default_plot_modes)
valid_plot_modes["avg"] = [
    *_shared_plot_modes,
    "delta_test",
    "stdev_test",
    "test_table",
]


def create_plots(
    dr: DatasetReport,
    mode="delta",
    metric="acc",
    estimators=None,
    plot_view=None,
):
    _prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs]
    estimators = CE.name[estimators]
    if mode is None:
        mode = valid_plot_modes[plot_view][0]
    _dpi = 112
    match (plot_view, mode):
        case ("avg", "train_table"):
            _data = (
                dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
            )
            return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
        case ("avg", "test_table"):
            _data = (
                dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
            )
            return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
        case ("avg", "shift_table"):
            _data = (
                dr.shift_data(metric=metric, estimators=estimators)
                .groupby(level=0)
                .mean()
            )
            return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
        case ("avg", _ as plot_mode):
            _plot = dr.get_plots(
                mode=mode,
                metric=metric,
                estimators=estimators,
                conf="panel",
                return_fig=True,
            )
            return (
                pn.pane.Matplotlib(
                    _plot,
                    tight=True,
                    format="png",
                    # sizing_mode="scale_height",
                    sizing_mode=_plot_sizing_mode,
                    # sizing_mode="scale_both",
                )
                if _plot is not None
                else None
            )
        case (_, "train_table"):
            cr = dr.crs[_prevs.index(int(plot_view))]
            _data = (
                cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
            )
            return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
        case (_, "shift_table"):
            cr = dr.crs[_prevs.index(int(plot_view))]
            _data = (
                cr.shift_data(metric=metric, estimators=estimators)
                .groupby(level=0)
                .mean()
            )
            return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
        case (_, _ as plot_mode):
            cr = dr.crs[_prevs.index(int(plot_view))]
            _plot = cr.get_plots(
                mode=plot_mode,
                metric=metric,
                estimators=estimators,
                conf="panel",
                return_fig=True,
            )
            return (
                pn.pane.Matplotlib(
                    _plot,
                    tight=True,
                    format="png",
                    sizing_mode=_plot_sizing_mode,
                    # sizing_mode="scale_height",
                    # sizing_mode="scale_both",
                )
                if _plot is not None
                else None
            )


def explore_datasets(root: Path | str):
    if isinstance(root, str):
        root = Path(root)

    if root.name == "plot":
        return []

    if not root.exists():
        return []

    drs = []
    for f in os.listdir(root):
        if (root / f).is_dir():
            drs += explore_datasets(root / f)
        elif f == f"{root.name}.pickle":
            drs.append(root / f)
            # drs.append((str(root),))

    return drs