import os
from collections import defaultdict
from pathlib import Path

import panel as pn

from quacc.evaluation.estimators import CE
from quacc.evaluation.report import CompReport, DatasetReport
from quacc.evaluation.stats import wilcoxon

# Panel ``sizing_mode`` applied to the Matplotlib pane built by ``create_plot``.
_plot_sizing_mode = "stretch_both"
# Maps a plot view to the sequence of modes available for it.  The "avg" view
# uses the dataset-level modes; any other key falls back (via the defaultdict
# factory) to the per-report modes.  The first entry of the looked-up sequence
# is used as the default mode when a caller passes ``mode=None``.
valid_plot_modes = defaultdict(lambda: CompReport._default_modes)
valid_plot_modes["avg"] = DatasetReport._default_dr_modes


def create_plot(
    dr: DatasetReport,
    mode="delta",
    metric="acc",
    estimators=None,
    plot_view=None,
):
    """Build a Panel Matplotlib pane for the requested plot.

    Args:
        dr: Dataset report providing ``get_plots`` and the per-prevalence
            reports in ``dr.crs``.
        mode: Plot mode forwarded to ``get_plots``.  When ``None`` the first
            entry of ``valid_plot_modes[plot_view]`` is used.
        metric: Metric name forwarded to ``get_plots``.
        estimators: Key used to index ``CE.name``; the result is forwarded to
            ``get_plots``.
        plot_view: ``"avg"`` plots from the dataset report itself.  Any other
            value is converted with ``int()`` and matched against the rounded
            ``train_prev[1] * 100`` of each report in ``dr.crs``.

    Returns:
        A ``pn.pane.Matplotlib`` wrapping the figure, or ``None`` when
        ``get_plots`` returns ``None``.
    """
    train_prevs = [round(report.train_prev[1] * 100) for report in dr.crs]
    estimators = CE.name[estimators]
    if mode is None:
        mode = valid_plot_modes[plot_view][0]

    # Pick the report that owns the plot: the whole dataset for the "avg"
    # view, otherwise the single report whose training prevalence matches.
    if plot_view == "avg":
        source = dr
    else:
        source = dr.crs[train_prevs.index(int(plot_view))]

    figure = source.get_plots(
        mode=mode,
        metric=metric,
        estimators=estimators,
        conf="panel",
        save_fig=False,
    )
    if figure is None:
        return None

    pane_options = {
        "tight": True,
        "format": "png",
        "sizing_mode": _plot_sizing_mode,
        "styles": {"margin": "0"},
    }
    return pn.pane.Matplotlib(figure, **pane_options)


def create_table(
    dr: DatasetReport,
    mode="delta",
    metric="acc",
    estimators=None,
    plot_view=None,
):
    _prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs]
    estimators = CE.name[estimators]
    if mode is None:
        mode = valid_plot_modes[plot_view][0]
    match (plot_view, mode):
        case ("avg", "train_table"):
            _data = (
                dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
            )
        case ("avg", "test_table"):
            _data = (
                dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
            )
        case ("avg", "shift_table"):
            _data = (
                dr.shift_data(metric=metric, estimators=estimators)
                .groupby(level=0)
                .mean()
            )
        case ("avg", "stats_table"):
            _data = wilcoxon(dr, metric=metric, estimators=estimators)
        case (_, "train_table"):
            cr = dr.crs[_prevs.index(int(plot_view))]
            _data = (
                cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
            )
        case (_, "shift_table"):
            cr = dr.crs[_prevs.index(int(plot_view))]
            _data = (
                cr.shift_data(metric=metric, estimators=estimators)
                .groupby(level=0)
                .mean()
            )
        case (_, "stats_table"):
            cr = dr.crs[_prevs.index(int(plot_view))]
            _data = wilcoxon(cr, metric=metric, estimators=estimators)

    return (
        pn.Column(
            pn.pane.DataFrame(
                _data,
                align="center",
                float_format=lambda v: f"{v:6e}",
                styles={"font-size-adjust": "0.62"},
            ),
            sizing_mode="stretch_both",
            # scroll=True,
        )
        if not _data.empty
        else None
    )


def create_result(
    dr: DatasetReport,
    mode="delta",
    metric="acc",
    estimators=None,
    plot_view=None,
):
    """Build either a table or a plot, depending on ``mode``.

    Modes whose name ends with ``"table"`` are rendered by ``create_table``;
    every other mode is rendered by ``create_plot``.

    Args:
        dr: Dataset report forwarded to the selected builder.
        mode: Requested mode.  When ``None`` the first entry of
            ``valid_plot_modes[plot_view]`` is used, mirroring the default
            resolution done by ``create_plot`` and ``create_table``.
        metric: Metric name forwarded to the selected builder.
        estimators: Estimator selection forwarded to the selected builder.
        plot_view: View selector forwarded to the selected builder.

    Returns:
        Whatever the selected builder returns (a Panel object or ``None``).
    """
    # Resolve the default before dispatching: calling ``endswith`` on ``None``
    # would raise AttributeError even though both builders accept ``None``.
    if mode is None:
        mode = valid_plot_modes[plot_view][0]

    if mode.endswith("table"):
        return create_table(dr, mode, metric, estimators, plot_view)
    return create_plot(dr, mode, metric, estimators, plot_view)


def explore_datasets(root: Path | str):
    if isinstance(root, str):
        root = Path(root)

    if root.name == "plot":
        return []

    if not root.exists():
        return []

    drs = []
    for f in os.listdir(root):
        if (root / f).is_dir():
            drs += explore_datasets(root / f)
        elif f == f"{root.name}.pickle":
            drs.append(root / f)
            # drs.append((str(root),))

    return drs