import numpy as np
from typing import Union, List
from collections.abc import Iterable
from dataclasses import dataclass
from scipy.stats import wilcoxon, ttest_ind_from_stats
import pandas as pd
import os
from pathlib import Path


@dataclass
class CellFormat:
    """Formatting options shared by the cells of a table.

    One instance is typically shared by all cells of a ``Table``; each field is
    described by the comment next to it.
    """
    mean_prec: int = 3  # number of decimal digits used to print the mean
    std_prec: int = 3  # number of decimal digits used to print the standard deviation
    show_std: bool = True  # if True, print " $\pm$ <std>" after the mean
    remove_zero: bool = False  # if True, drop the leading zero of the printed mean (e.g. "0.123" -> ".123")
    color: bool = True  # if True, append the group's \cellcolor command to every non-empty cell
    maxtone: int = 50  # tone reached by the best/worst cells in the xcolor "color!tone" mix


class Cell:
    """One table entry: a collection of numeric results and its LaTeX rendering.

    The cell prints the mean of its values (optionally followed by the standard
    deviation). It belongs to a ``CellGroup`` (in ``Table``, all the methods
    evaluated on the same benchmark). The group decides whether this cell is
    the best one (printed in bold), provides the p-value of the comparison
    against the best cell, and computes the background color.
    """

    def __init__(self, format: 'CellFormat', group: 'CellGroup'):
        """
        :param format: formatting options (precisions, std, leading-zero removal, color)
        :param group: group the cell competes in; the cell registers itself there
        """
        self.values = []
        self.format = format
        self.touch()
        self.group = group
        self.group.register_cell(self)

    def __len__(self):
        return len(self.values)

    def mean(self):
        """Return the (cached) mean of the values, or nan if the cell is empty."""
        if self.mean_ is None:
            # explicit nan: np.mean([]) also yields nan but emits a RuntimeWarning
            self.mean_ = np.nan if self.isEmpty() else np.mean(self.values)
        return self.mean_

    def std(self):
        """Return the (cached) population standard deviation, or nan if empty."""
        if self.std_ is None:
            self.std_ = np.nan if self.isEmpty() else np.std(self.values)
        return self.std_

    def touch(self):
        """Invalidate the cached statistics (called whenever values change)."""
        self.mean_ = None
        self.std_ = None

    def append(self, v: Union[float, Iterable]):
        """Add a single value, or every value of an iterable, to the cell.

        Strings/bytes are treated as single values, not as iterables of characters.
        """
        if isinstance(v, Iterable) and not isinstance(v, (str, bytes)):
            self.values.extend(v)
        else:
            # the original appended the iterable itself after extending with it
            self.values.append(v)
        self.touch()

    def isEmpty(self):
        return len(self) == 0

    def isBest(self):
        """True if this cell is the group's best one, or ties with it (np.isclose)."""
        best = self.group.best()
        if best is not None:
            return (best == self) or (np.isclose(best.mean(), self.mean()))
        return False

    def print_mean(self):
        """Return the mean formatted with ``format.mean_prec`` decimals ('' if empty)."""
        if self.isEmpty():
            return ''
        else:
            return f'{self.mean():.{self.format.mean_prec}f}'

    def print(self):
        """Return the LaTeX code of the cell ('' if empty).

        The best cell of the group is printed in bold. Any other cell gets a
        superscript depending on the p-value of the comparison against the best one:
        none if p < 0.005, dagger if 0.005 <= p < 0.05, double dagger if p >= 0.05.
        """
        if self.isEmpty():
            return ''

        # mean
        # ---------------------------------------------------
        mean = self.print_mean()
        if self.format.remove_zero:
            # strip only the leading zero of the integer part; a blanket
            # replace('0.', '.') would also turn e.g. "10.500" into "1.500"
            if mean.startswith('0.'):
                mean = mean[1:]
            elif mean.startswith('-0.'):
                mean = '-' + mean[2:]

        # std ?
        # ---------------------------------------------------
        if self.format.show_std:
            std = f' $\\pm$ {self.std():.{self.format.std_prec}f}'
        else:
            std = ''

        # bold or statistical test
        # ---------------------------------------------------
        if self.isBest():
            str_cell = f'\\textbf{{{mean}{std}}}'
        else:
            comp_symbol = ''
            pval = self.group.compare(self)
            if pval is not None:
                if 0.005 <= pval < 0.05:
                    comp_symbol = '$^{\\dag}$'
                elif pval >= 0.05:
                    # superscript, consistent with the dagger above
                    comp_symbol = '$^{\\ddag}$'
            str_cell = f'{mean}{comp_symbol}{std}'

        # color ?
        # ---------------------------------------------------
        if self.format.color:
            color_cmd = self.group.color(self)
            if color_cmd:
                str_cell += ' ' + color_cmd

        return str_cell


class CellGroup:
    """A set of cells competing with each other (e.g. all methods on one benchmark).

    The group determines the best/worst cell according to ``lower_is_better``.
    It runs the statistical test of any cell against the best one and computes
    each cell's LaTeX background color.
    """

    def __init__(self, lower_is_better=True, stat_test='wilcoxon', color_mode='local', color_global_min=None, color_global_max=None):
        """
        :param lower_is_better: if True, the cell with the smallest mean is the best
        :param stat_test: 'wilcoxon' (scipy.stats.wilcoxon on the raw values),
            'ttest' (scipy.stats.ttest_ind_from_stats on mean/std/n), or None (no test)
        :param color_mode: 'local' (colors scaled between the best and worst cell of
            the group) or 'global' (colors scaled between color_global_min and color_global_max)
        :param color_global_min: lower end of the color scale, only used in 'global' mode
        :param color_global_max: upper end of the color scale, only used in 'global' mode
        """
        assert stat_test in ['wilcoxon', 'ttest', None], \
            f"unknown {stat_test=}, valid ones are wilcoxon, ttest, or None"
        assert color_mode in ['local', 'global'], \
            f"unknown {color_mode=}, valid ones are local and global"
        if (color_global_min is not None or color_global_max is not None) and color_mode == 'local':
            # the bounds are ignored in 'local' mode (the message used to say the opposite)
            print('warning: color_global_min and color_global_max are only considered when color_mode==global')
        self.cells = []
        self.lower_is_better = lower_is_better
        self.stat_test = stat_test
        self.color_mode = color_mode
        self.color_global_min = color_global_min
        self.color_global_max = color_global_max

    def register_cell(self, cell: 'Cell'):
        self.cells.append(cell)

    def non_empty_cells(self):
        return [c for c in self.cells if not c.isEmpty()]

    def max(self):
        """Return the non-empty cell with the largest mean, or None if there is none."""
        cells = self.non_empty_cells()
        if len(cells) > 0:
            return cells[np.argmax([c.mean() for c in cells])]
        return None

    def min(self):
        """Return the non-empty cell with the smallest mean, or None if there is none."""
        cells = self.non_empty_cells()
        if len(cells) > 0:
            return cells[np.argmin([c.mean() for c in cells])]
        return None

    def best(self) -> 'Cell':
        """Return the best non-empty cell (None if the group has no results)."""
        return self.min() if self.lower_is_better else self.max()

    def worst(self) -> 'Cell':
        """Return the worst non-empty cell (None if the group has no results)."""
        return self.max() if self.lower_is_better else self.min()

    def isEmpty(self):
        return len(self.non_empty_cells()) == 0

    @staticmethod
    def _sample_std(cell: 'Cell'):
        # ttest_ind_from_stats expects the corrected (ddof=1) sample std;
        # a single observation has no spread, so 0 is used (as ddof=0 would give)
        return np.std(cell.values, ddof=1) if len(cell) > 1 else 0.0

    def compare(self, cell: 'Cell'):
        """Return the p-value of the test between the best cell and ``cell``.

        Returns None when no test is configured, when either cell (or the whole
        group) is empty, or when the Wilcoxon test raises ValueError (e.g. samples
        of different length).

        :raises ValueError: if ``stat_test`` holds an unknown value
        """
        best = self.best()
        if best is None or len(best) == 0 or len(cell) == 0:
            return None
        if self.stat_test == 'wilcoxon':
            try:
                _, p_val = wilcoxon(best.values, cell.values)
            except ValueError:
                p_val = None
            return p_val
        elif self.stat_test == 'ttest':
            _, p_val = ttest_ind_from_stats(
                best.mean(), self._sample_std(best), len(best),
                cell.mean(), self._sample_std(cell), len(cell)
            )
            return p_val
        elif self.stat_test is None:
            return None
        else:
            raise ValueError(f'unknown statistical test {self.stat_test}')

    def color(self, cell: 'Cell'):
        """Return the ``\\cellcolor`` command for ``cell``, or '' when no color applies.

        The cell mean is normalized to [0, 1] within the color range (and flipped
        when lower is better), then rescaled to [-1, 1]. Negative values map to
        red and positive ones to green, with tone ``|value| * format.maxtone``.
        """
        if cell.isEmpty():
            return ''

        if self.color_mode == 'local':
            best = self.best()
            worst = self.worst()
            # check for None before calling .mean() on them
            if best is None or worst is None:
                return ''
            best_mean = best.mean()
            worst_mean = worst.mean()
            if best_mean == worst_mean:
                return ''
            maxval = max(best_mean, worst_mean)
            minval = min(best_mean, worst_mean)
        else:
            maxval = self.color_global_max
            minval = self.color_global_min
            # undefined or degenerate range: no color instead of TypeError/ZeroDivision
            if maxval is None or minval is None or maxval == minval:
                return ''

        # normalize val in [0,1]
        normval = (cell.mean() - minval) / (maxval - minval)

        if self.lower_is_better:
            normval = 1 - normval

        normval = np.clip(normval, 0, 1)

        normval = normval * 2 - 1  # rescale to [-1,1]
        if normval < 0:
            colorname = 'red'
            tone = cell.format.maxtone * (-normval)
        else:
            colorname = 'green'
            tone = cell.format.maxtone * normval

        return f'\\cellcolor{{{colorname}!{int(tone)}}}'



class Table:
    """A (benchmark x method) table of ``Cell`` objects that renders to LaTeX.

    Every benchmark owns a ``CellGroup``, so the best cell, the statistical
    tests and the colors are computed among the methods evaluated on that
    benchmark. Benchmarks and methods are registered on first use, in
    insertion order.
    """

    def __init__(self,
                 name,
                 benchmarks=None,
                 methods=None,
                 format: 'CellFormat' = None,
                 lower_is_better=True,
                 stat_test='wilcoxon',
                 color_mode='local',
                 with_mean=True
                 ):
        """
        :param name: table name; ``Document`` derives the tabular file name from it
        :param benchmarks: optional initial benchmark names (the list is copied)
        :param methods: optional initial method names (the list is copied)
        :param format: cell formatting options; defaults to ``CellFormat()``
        :param lower_is_better: if True, the lowest mean is the best result
        :param stat_test: 'wilcoxon', 'ttest', or None (passed to every ``CellGroup``)
        :param color_mode: 'local' or 'global'; in 'global' mode colors span [0, 1]
        :param with_mean: if True, add an "Ave." row (or column, when transposed)
        """
        self.name = name
        # copy the lists: get() appends unseen benchmarks/methods, and that must
        # not silently modify the caller's lists
        self.benchmarks = [] if benchmarks is None else list(benchmarks)
        self.methods = [] if methods is None else list(methods)
        self.format = format if format is not None else CellFormat()
        self.lower_is_better = lower_is_better
        self.stat_test = stat_test
        self.color_mode = color_mode
        self.with_mean = with_mean
        self.only_full_mean = True  # if False, compute the mean of partially empty methods also

        if self.color_mode == 'global':
            self.color_global_min = 0
            self.color_global_max = 1
        else:
            self.color_global_min = None
            self.color_global_max = None

        self.T = {}  # (benchmark index, method index) -> Cell
        self.groups = {}  # benchmark -> CellGroup

    def add(self, benchmark, method, v):
        """Add a value (or an iterable of values) to the (benchmark, method) cell."""
        cell = self.get(benchmark, method)
        cell.append(v)

    def get_benchmarks(self):
        return self.benchmarks

    def get_methods(self):
        return self.methods

    def n_benchmarks(self):
        return len(self.benchmarks)

    def n_methods(self):
        return len(self.methods)

    def _new_group(self):
        """Create a CellGroup configured like this table."""
        return CellGroup(self.lower_is_better, self.stat_test, color_mode=self.color_mode,
                         color_global_max=self.color_global_max, color_global_min=self.color_global_min)

    def get(self, benchmark, method) -> 'Cell':
        """Return the (benchmark, method) cell, creating it (and registering names) if needed."""
        if benchmark not in self.benchmarks:
            self.benchmarks.append(benchmark)
        if benchmark not in self.groups:
            self.groups[benchmark] = self._new_group()
        if method not in self.methods:
            self.methods.append(method)
        idx = (self.benchmarks.index(benchmark), self.methods.index(method))
        if idx not in self.T:
            self.T[idx] = Cell(self.format, group=self.groups[benchmark])
        return self.T[idx]

    def get_value(self, benchmark, method) -> float:
        return self.get(benchmark, method).mean()

    def get_benchmark(self, benchmark):
        """Return the non-empty cells of a benchmark, in method order."""
        cells = [self.get(benchmark, method=m) for m in self.get_methods()]
        return [c for c in cells if not c.isEmpty()]

    def get_method(self, method):
        """Return the non-empty cells of a method, in benchmark order."""
        cells = [self.get(benchmark=b, method=method) for b in self.get_benchmarks()]
        return [c for c in cells if not c.isEmpty()]

    def get_method_means(self, method_order, benchmark_order=None):
        """Return one ``Cell`` per method (in ``method_order``) holding its per-benchmark means.

        The returned cells share a fresh ``CellGroup``, so the best average is
        highlighted and tested like any other group.

        :param method_order: methods to average, in output order
        :param benchmark_order: benchmarks to average over (default: all benchmarks)
        :return: list of ``Cell``. When ``only_full_mean`` is True, a method
            lacking a result on any of the benchmarks gets an empty cell.
        """
        if benchmark_order is None:
            benchmark_order = self.get_benchmarks()
        mean_group = self._new_group()
        cells = []
        for method in method_order:
            method_mean = Cell(self.format, group=mean_group)
            bench_means = []
            for bench in benchmark_order:
                cell = self.get(benchmark=bench, method=method)
                bench_means.append(np.nan if cell.isEmpty() else cell.mean())
            valid = [m for m in bench_means if not np.isnan(m)]
            # honor only_full_mean (previously declared but ignored)
            if not (self.only_full_mean and len(valid) < len(bench_means)):
                for value in valid:
                    method_mean.append(value)
            cells.append(method_mean)
        return cells

    def get_benchmark_values(self, benchmark):
        return np.asarray([c.mean() for c in self.get_benchmark(benchmark)])

    def get_method_values(self, method):
        return np.asarray([c.mean() for c in self.get_method(method)])

    def all_mean(self):
        """Return the mean of the means of all non-empty cells."""
        values = [c.mean() for c in self.T.values() if not c.isEmpty()]
        return np.mean(values)

    def print(self):  # todo: missing method names?
        """Print the table of means to stdout as plain text (via pandas)."""
        data_dict = {'Benchmark': list(self.get_benchmarks())}
        for method in self.get_methods():
            data_dict[method] = [self.get(bench, method).print_mean() for bench in self.get_benchmarks()]
        df = pd.DataFrame(data_dict)
        # scoped options instead of pd.set_option, which changed them process-wide
        with pd.option_context('display.max_columns', None, 'display.max_rows', None):
            print(df.to_string(index=False))

    def tabular(self, path=None, benchmark_replace=None, method_replace=None, benchmark_order=None, method_order=None, transpose=False):
        """Render the table as a LaTeX tabular environment.

        :param path: if given, the code is also written there (parent dirs are created)
        :param benchmark_replace: optional dict mapping benchmark names to display names
        :param method_replace: optional dict mapping method names to display names
        :param benchmark_order: benchmarks to show, in order (default: all); the
            average is computed over these benchmarks only
        :param method_order: methods to show, in order (default: all)
        :param transpose: if False benchmarks are rows and methods columns; if True, the opposite
        :return: the LaTeX code as a string
        """
        if benchmark_replace is None:
            benchmark_replace = {}
        if method_replace is None:
            method_replace = {}
        if benchmark_order is None:
            benchmark_order = self.get_benchmarks()
        if method_order is None:
            method_order = self.get_methods()

        if transpose:
            row_order, row_replace = method_order, method_replace
            col_order, col_replace = benchmark_order, benchmark_replace
        else:
            row_order, row_replace = benchmark_order, benchmark_replace
            col_order, col_replace = method_order, method_replace

        n_cols = len(col_order)
        add_mean_col = self.with_mean and transpose
        add_mean_row = self.with_mean and not transpose
        # column 1 holds the row names; \cline spans the data (and mean) columns
        last_col_idx = n_cols + 2 if add_mean_col else n_cols + 1

        mean_cells = self.get_method_means(method_order, benchmark_order) if self.with_mean else []

        lines = ['\\begin{tabular}{|c' + '|c' * n_cols + ('||c' if add_mean_col else '') + "|}"]
        lines.append(f'\\cline{{2-{last_col_idx}}}')

        header = '\\multicolumn{1}{c|}{} & '
        header += ' & '.join(str(col_replace.get(col, col)) for col in col_order)
        if add_mean_col:
            header += ' & Ave.'
        header += ' \\\\\\hline'
        lines.append(header)

        for i, row in enumerate(row_order):
            cells = []
            for col in col_order:
                benchmark, method = (col, row) if transpose else (row, col)
                cells.append(self.get(benchmark=benchmark, method=method).print())
            line = str(row_replace.get(row, row)) + ' & ' + ' & '.join(cells)
            if add_mean_col:
                line += ' & ' + mean_cells[i].print()
            line += ' \\\\\\hline'
            lines.append(line)

        if add_mean_row:
            lines.append('\\hline')
            line = 'Ave. & ' + ' & '.join(mean_cell.print() for mean_cell in mean_cells)
            line += ' \\\\\\hline'
            lines.append(line)

        lines.append('\\end{tabular}')

        tabular_tex = '\n'.join(lines)

        if path is not None:
            Path(path).parent.mkdir(parents=True, exist_ok=True)
            with open(path, 'wt', encoding='utf-8') as fout:
                fout.write(tabular_tex)

        return tabular_tex

    def table(self, tabular_path, benchmark_replace=None, method_replace=None, resizebox=True, caption=None, label=None, benchmark_order=None, method_order=None, transpose=False, tabular_dir='tables'):
        """Return the LaTeX code of a table float wrapping the tabular.

        :param tabular_path: file the tabular is written to and then included with
            an input command from ``tabular_dir``; if None, the tabular is inlined
        :param resizebox: if True, scale the tabular to the text width
        :param caption: caption text; defaults to the tabular path (or the table
            name when there is no path) with underscores escaped
        :param label: optional LaTeX label
        :param tabular_dir: directory, relative to the including document, in which
            the tabular file is looked up
        The remaining parameters are forwarded to ``tabular``.
        """
        lines = ['\\begin{table}', '\\centering']
        if resizebox:
            lines.append('\\resizebox{\\textwidth}{!}{%')

        tabular_str = self.tabular(tabular_path, benchmark_replace, method_replace, benchmark_order, method_order, transpose)
        if tabular_path is None:
            lines.append(tabular_str)
        else:
            # honor tabular_dir instead of the hardcoded "tables/" prefix;
            # LaTeX expects forward slashes
            input_path = Path(tabular_dir, Path(tabular_path).name).as_posix()
            lines.append(f'\\input{{{input_path}}}')

        if resizebox:
            lines.append('}%')
        if caption is None:
            caption_src = self.name if tabular_path is None else tabular_path
            caption = str(caption_src).replace('_', '\\_')
        lines.append(f'\\caption{{{caption}}}')
        if label is not None:
            lines.append(f'\\label{{{label}}}')
        lines.append('\\end{table}')

        return '\n'.join(lines)

    def document(self, tex_path, tabular_dir='tables', *args, **kwargs):
        """Write a standalone LaTeX document containing only this table."""
        Table.Document(tex_path, tables=[self], tabular_dir=tabular_dir, *args, **kwargs)

    def latexPDF(self, pdf_path, tabular_dir='tables', *args, **kwargs):
        """Compile a standalone PDF containing only this table."""
        return Table.LatexPDF(pdf_path, tables=[self], tabular_dir=tabular_dir, *args, **kwargs)

    @classmethod
    def Document(cls, tex_path, tables: List['Table'], tabular_dir='tables', landscape=True, *args, **kwargs):
        """Write a LaTeX document with one table per page and return its code.

        Each tabular is written to ``<dir of tex_path>/<tabular_dir>/<name>_table.tex``
        and included from the document.

        :param tex_path: output .tex file (parent directories are created)
        :param tables: tables to include, in order
        :param tabular_dir: sub-directory (relative to the document) for the tabulars
        :param landscape: if True, use landscape page geometry
        Extra arguments are forwarded to ``Table.table``.
        """
        lines = [
            '\\documentclass[10pt,a4paper]{article}',
            '\\usepackage[utf8]{inputenc}',
            '\\usepackage{amsmath}',
            '\\usepackage{amsfonts}',
            '\\usepackage{amssymb}',
            '\\usepackage{graphicx}',
            '\\usepackage{xcolor}',
            '\\usepackage{colortbl}',
        ]
        if landscape:
            lines.append('\\usepackage[landscape]{geometry}')
        lines.append('')
        lines.append('\\begin{document}')
        tex_dir = Path(tex_path).parent
        for table in tables:
            tabular_path = os.path.join(tex_dir, tabular_dir, table.name + '_table.tex')
            lines.append('')
            lines.append(table.table(tabular_path, *args, tabular_dir=tabular_dir, **kwargs))
            lines.append('\n\\newpage\n')
        lines.append('\\end{document}')

        document = '\n'.join(lines)

        tex_dir.mkdir(parents=True, exist_ok=True)
        # utf-8 matches the inputenc declaration above
        with open(tex_path, 'wt', encoding='utf-8') as fout:
            fout.write(document)

        return document

    @classmethod
    def LatexPDF(cls, pdf_path: str, tables: List['Table'], tabular_dir: str = 'tables', *args, **kwargs):
        """Write the LaTeX document next to ``pdf_path`` and compile it with pdflatex.

        The .aux and .log files produced by pdflatex are removed afterwards.
        If pdflatex is not installed, a warning is printed instead of raising.
        """
        import subprocess

        assert pdf_path.endswith('.pdf'), f'{pdf_path=} does not seem a valid name for a pdf file'
        # swap only the final extension (str.replace would touch every ".pdf")
        tex_path = pdf_path[:-len('.pdf')] + '.tex'

        cls.Document(tex_path, tables, tabular_dir, *args, **kwargs)

        work_dir = Path(pdf_path).parent
        tex_name = Path(tex_path).name
        basename = Path(tex_path).stem

        print('currently in', os.getcwd())
        print("[Tables Done] runing latex")
        # run inside work_dir without os.chdir (which was not restored on errors),
        # and with an argument list so names with spaces/metacharacters are safe
        try:
            subprocess.run(['pdflatex', tex_name], cwd=work_dir)
        except FileNotFoundError:
            print('warning: pdflatex executable not found; the pdf was not generated')
        for ext in ('.aux', '.log'):
            (work_dir / (basename + ext)).unlink(missing_ok=True)
        print('[Done]')