from scipy.stats import wilcoxon, ttest_ind_from_stats
import numpy as np



class ResultSet:
    """Collect named results, rank them, and render each one as a LaTeX table cell.

    Results are stored in ``self.r`` as ``key -> dict`` (the dict returned by ``addfunc``).
    Ranking, colouring and significance testing are computed lazily: ``add`` and
    ``change_compare`` mark the state as stale and the accessors call ``update`` on demand.
    """

    VALID_TESTS = [None, "wilcoxon", "ttest_ind_from_stats"]
    TTEST_DIFF = 'different'
    TTEST_SIM = 'similar'
    TTEST_SAME = 'same'

    def __init__(self, name, addfunc, compare='mean', lower_is_better=True, show_std=True, test="wilcoxon",
                 remove_mean='', prec_mean=3, remove_std='', prec_std=3, maxtone=50, minval=None, maxval=None):
        """

        :param name: name of the result set (e.g., a Dataset)
        :param addfunc: a function which is called to process the result input in the "add" method. This function should
        return a dictionary containing any key-value (e.g., 'mean':0.89) of interest; it must contain the key "values"
        :param compare: the key (as generated by addfunc) that is to be compared in order to rank results
        :param lower_is_better: if True, lower values of the "compare" key will result in higher ranks
        :param show_std: whether or not to show the 'std' value (if True, the addfunc is expected to generate it)
        :param test: which test of statistical significance to use. If "wilcoxon" then scipy.stats.wilcoxon(x,y) will
        be computed where x,y are the values of the key "values" as computed by addfunc. If "ttest_ind_from_stats", then
        scipy.stats.ttest_ind_from_stats will be called on "mean", "std", "nobs" values (as computed by addfunc) for
        both samples being compared. If None, no test is carried out.
        :param remove_mean: if non-empty, every occurrence of this string in the formatted mean is replaced by '.'
        (e.g., '0.' turns '0.893' into '.893')
        :param prec_mean: number of decimal digits used to format the mean
        :param remove_std: if non-empty, every occurrence of this string in the formatted std is replaced by '.'
        :param prec_std: number of decimal digits used to format the std
        :param maxtone: maximum colour intensity, forwarded to color_red2green_01
        :param minval: lower bound of the colour scale; if None, the minimum of the compared values is used
        :param maxval: upper bound of the colour scale; if None, the maximum of the compared values is used
        """
        self.name = name
        self.addfunc = addfunc
        self.compare = compare
        self.lower_is_better = lower_is_better
        self.show_std = show_std
        assert test in self.VALID_TESTS, f'unknown test, valid are {self.VALID_TESTS}'
        self.test = test
        self.remove_mean = remove_mean
        self.prec_mean = prec_mean
        self.remove_std = remove_std
        self.prec_std = prec_std
        self.maxtone = maxtone
        self.minval = minval
        self.maxval = maxval

        self.r = dict()
        self.computed = False

    def add(self, key, *args):
        """
        Processes ``args`` with the addfunc and stores the outcome under ``key``.

        If the addfunc returns None nothing is stored. When the "values" entry is a numpy array, the entries
        "mean", "std" and "nobs" are (re)computed from it.
        :param key: the identifier of the result within this result set
        :param args: the arguments handed over to the addfunc
        """
        result = self.addfunc(*args)
        if result is None:
            return
        # callables such as functools.partial have no __name__; fall back to their repr
        func_name = getattr(self.addfunc, '__name__', repr(self.addfunc))
        assert 'values' in result, f'the add function {func_name} does not fill the "values" attribute'
        vals = result['values']
        if isinstance(vals, np.ndarray):
            result['mean'] = vals.mean()
            result['std'] = vals.std()
            result['nobs'] = len(vals)
        self.r[key] = result
        self.computed = False

    def update(self):
        """Recomputes ranks, colours and test outcomes only if something changed since the last computation."""
        if not self.computed:
            self.compute()

    def compute(self):
        """
        Ranks all results according to the "compare" key and annotates every result with the entries
        "best", "rank", "color" and, if a test is configured, "test" (the outcome of comparing against the best).
        """
        scores = {key: self.r[key][self.compare] for key in self.r}
        # sort the original keys (not a numpy copy of them) so that any hashable key type keeps working
        ranking = sorted(scores, key=scores.get)
        if not self.lower_is_better:
            ranking.reverse()

        self.range_minval = min(scores.values()) if self.minval is None else self.minval
        self.range_maxval = max(scores.values()) if self.maxval is None else self.maxval

        # keep track of statistical significance tests; if all are different, then the "phantom dags" will not be shown
        self.some_similar = False

        best = None
        for rank, key in enumerate(ranking, start=1):
            entry = self.r[key]
            isbest = rank == 1
            if isbest:
                best = entry
            entry['best'] = isbest
            entry['rank'] = rank
            entry['color'] = self.get_value_color(scores[key], minval=self.range_minval, maxval=self.range_maxval)

            if self.test is None:
                continue

            p_val = 0 if isbest else self._p_value(best, entry)
            if p_val <= 0.005:
                entry['test'] = ResultSet.TTEST_DIFF
            elif p_val <= 0.05:
                entry['test'] = ResultSet.TTEST_SIM
                self.some_similar = True
            else:
                # p_val > 0.05; a NaN p-value (no comparison is true for NaN) also ends up here and is
                # treated as "no detectable difference" instead of leaving the entry without an outcome
                entry['test'] = ResultSet.TTEST_SAME
                self.some_similar = True

        self.computed = True

    def _p_value(self, best, other):
        """Returns the p-value of the configured test when comparing the ``best`` result against ``other``."""
        if self.test == 'wilcoxon':
            _, p_val = wilcoxon(best['values'], other['values'])
        elif self.test == 'ttest_ind_from_stats':
            _, p_val = ttest_ind_from_stats(best['mean'], best['std'], best['nobs'],
                                            other['mean'], other['std'], other['nobs'])
        else:
            raise ValueError(f'unknown test {self.test!r}, valid are {self.VALID_TESTS}')
        return p_val

    def latex(self, key, missing='--', color=True):
        """
        Renders the result stored under ``key`` as the content of a LaTeX table cell.

        :param key: the identifier of the result
        :param missing: the string to return if there is no result for ``key``
        :param color: whether or not to append the cell colour command
        :return: a math-mode string with the mean (bold if best, otherwise optionally marked with the test
        outcome), optionally followed by the std and by the cell colour
        """
        if key not in self.r:
            return missing

        self.update()

        rd = self.r[key]
        s = f"{rd['mean']:.{self.prec_mean}f}"
        if self.remove_mean:
            s = s.replace(self.remove_mean, '.')
        if rd['best']:
            s = r"\textbf{" + s + "}"
        elif self.test is not None and self.some_similar:
            # the phantoms keep the width of the superscript constant across cells
            if rd['test'] == ResultSet.TTEST_SIM:
                s += r'^{\dag\phantom{\dag}}'
            elif rd['test'] == ResultSet.TTEST_SAME:
                s += r'^{\ddag}'
            elif rd['test'] == ResultSet.TTEST_DIFF:
                s += r'^{\phantom{\ddag}}'

        if self.show_std:
            std = f"{rd['std']:.{self.prec_std}f}"
            if self.remove_std:
                std = std.replace(self.remove_std, '.')
            s += rf" \pm {std}"

        s = f'$ {s} $'
        if color:
            s += ' ' + rd['color']

        return s

    def mean(self, attr='mean', required: int = None, missing=np.nan):
        """
        returns the mean value for the "attr" attribute
        :param attr: the attribute to average across results
        :param required: if specified, indicates the number of values that should be part of the mean; if this number
        is different, then the mean is not computed
        :param missing: the value to return in case the required condition is not satisfied or some result lacks
        the attribute
        :return: the mean of the "attr" attribute
        """
        vallist = [result.get(attr, None) for result in self.r.values()]
        # identity check: "None in vallist" would compare with == and fails on numpy arrays
        if any(val is None for val in vallist):
            return missing
        if required is not None and len(vallist) != required:
            return missing
        return np.mean(vallist)

    def get(self, key, attr, missing='--'):
        """
        Returns the attribute ``attr`` of the result stored under ``key`` (after bringing the computed
        annotations up to date), or ``missing`` if either the key or the attribute does not exist.
        """
        if key in self.r:
            self.update()
            if attr in self.r[key]:
                return self.r[key][attr]
        return missing

    def get_color(self, key):
        """Returns the cell colour command of the result stored under ``key``, or '' if there is no such result."""
        if key not in self.r:
            return ''
        self.update()
        return self.r[key]['color']

    def get_value_color(self, val, minval=None, maxval=None):
        """
        Maps a value onto the red-to-green colour scale.

        :param val: the value to colour
        :param minval: lower bound of the scale; if this or maxval is None, the range of the current results is used
        :param maxval: upper bound of the scale; if this or minval is None, the range of the current results is used
        :return: the colour command generated by color_red2green_01
        """
        if minval is None or maxval is None:
            self.update()
            minval = self.range_minval
            maxval = self.range_maxval
        span = maxval - minval
        if span == 0:
            # degenerate scale (single result or all values equal): use the neutral midpoint
            val = 0.5
        else:
            # values outside an explicitly given [minval, maxval] saturate at the ends of the scale
            val = float(np.clip((val - minval) / span, 0.0, 1.0))
        if self.lower_is_better:
            val = 1 - val
        return color_red2green_01(val, self.maxtone)

    def change_compare(self, attr):
        """Changes the key used for ranking; ranks are recomputed on the next access."""
        self.compare = attr
        self.computed = False



def color_red2green_01(val, maxtone=100):
    """
    Maps a value in [0,1] to a LaTeX cellcolor command ranging from red (0) to green (1).

    :param val: the value to colour; must lie in [0,1]
    :param maxtone: the colour intensity used at both ends of the scale (val=0 and val=1); the intensity
    decreases linearly towards 0 at val=0.5
    :return: the cellcolor command as a string, with the intensity truncated to an integer
    """
    assert 0 <= val <= 1, f'val {val} out of range [0,1]'

    # rescale to [-1,1]: the negative half is rendered in red, the non-negative half in green
    val = val * 2 - 1
    color = 'red' if val < 0 else 'green'
    tone = int(maxtone * abs(val))
    # raw string: '\c' is not a valid escape sequence in a regular string literal
    return r'\cellcolor{' + f'{color}!{tone}' + '}'