forked from moreo/QuaPy
205 lines
7.6 KiB
Python
205 lines
7.6 KiB
Python
from scipy.stats import wilcoxon, ttest_ind_from_stats
|
|
import numpy as np
|
|
|
|
|
|
|
|
class ResultSet:
    """
    Collects the results obtained by several methods (identified by a key) on one
    benchmark, ranks them according to a chosen attribute, optionally runs a test of
    statistical significance of every result against the best one, and renders each
    result as a LaTeX table cell.
    """

    VALID_TESTS = [None, "wilcoxon", "ttest_ind_from_stats"]
    TTEST_DIFF = 'different'
    TTEST_SIM = 'similar'
    TTEST_SAME = 'same'

    def __init__(self, name, addfunc, compare='mean', lower_is_better=True, show_std=True, test="wilcoxon",
                 remove_mean='', prec_mean=3, remove_std='', prec_std=3, maxtone=50, minval=None, maxval=None):
        """
        :param name: name of the result set (e.g., a Dataset)
        :param addfunc: a function which is called to process the result input in the "add" method. This function
            should return a dictionary containing any key-value (e.g., 'mean':0.89) of interest
        :param compare: the key (as generated by addfunc) that is to be compared in order to rank results
        :param lower_is_better: if True, lower values of the "compare" key will result in higher ranks
        :param show_std: whether or not to show the 'std' value (if True, the addfunc is expected to generate it)
        :param test: which test of statistical significance to use. If "wilcoxon" then scipy.stats.wilcoxon(x,y)
            will be computed where x,y are the values of the key "values" as computed by addfunc. If
            "ttest_ind_from_stats", then scipy.stats.ttest_ind_from_stats will be called on "mean", "std", "nobs"
            values (as computed by addfunc) for both samples being compared. If None, no test is run.
        :param remove_mean: if specified, removes the string from the mean (e.g., useful to remove the '0.')
        :param prec_mean: number of decimal digits used to format the mean
        :param remove_std: if specified, removes the string from the std (e.g., useful to remove the '0.')
        :param prec_std: number of decimal digits used to format the std
        :param maxtone: maximum color intensity, forwarded to color_red2green_01
        :param minval: if specified, lower end of the color scale (otherwise the minimum compared value is used)
        :param maxval: if specified, upper end of the color scale (otherwise the maximum compared value is used)
        """
        assert test in self.VALID_TESTS, f'unknown test, valid are {self.VALID_TESTS}'
        self.name = name
        self.addfunc = addfunc
        self.compare = compare
        self.lower_is_better = lower_is_better
        self.show_std = show_std
        self.test = test
        self.remove_mean = remove_mean
        self.prec_mean = prec_mean
        self.remove_std = remove_std
        self.prec_std = prec_std
        self.maxtone = maxtone
        self.minval = minval
        self.maxval = maxval

        self.r = dict()  # key -> dictionary of attributes describing that result
        self.computed = False  # True once ranks/colors/tests are in sync with self.r

    def add(self, key, *args):
        """
        Processes a new result through addfunc and stores it under the given key.

        :param key: identifier of the result (e.g., the name of a method)
        :param args: positional arguments forwarded to addfunc; if addfunc returns None the result is ignored
        """
        result = self.addfunc(*args)
        if result is None:
            return
        # not every callable exposes __name__ (e.g., functools.partial objects)
        addfunc_name = getattr(self.addfunc, '__name__', repr(self.addfunc))
        assert 'values' in result, f'the add function {addfunc_name} does not fill the "values" attribute'
        self.r[key] = result
        vals = result['values']
        if isinstance(vals, np.ndarray):
            # summary statistics are derived directly from the raw values
            result['mean'] = vals.mean()
            result['std'] = vals.std()
            result['nobs'] = len(vals)
        self.computed = False

    def update(self):
        """
        Recomputes ranks, colors, and significance tests only if something changed since the last computation.
        """
        if not self.computed:
            self.compute()

    def compute(self):
        """
        Ranks all stored results according to the "compare" attribute and annotates each of them with the
        attributes 'best', 'rank', 'color' and, if a test was requested, 'test' (one of TTEST_DIFF, TTEST_SIM,
        TTEST_SAME) describing how the result compares against the best one.
        """
        keys = list(self.r.keys())
        vallist = [self.r[key][self.compare] for key in keys]
        # only the values are sorted with numpy; keys are kept as native Python objects so that any hashable
        # key (mixed types, tuples, ...) can still be used to index self.r afterwards
        keylist = [keys[i] for i in np.argsort(vallist)]

        self.range_minval = min(vallist) if self.minval is None else self.minval
        self.range_maxval = max(vallist) if self.maxval is None else self.maxval
        if not self.lower_is_better:
            keylist = keylist[::-1]

        # keep track of statistical significance tests; if all are different, then the "phantom dags" will not be shown
        self.some_similar = False

        best = None
        for rank, key in enumerate(keylist, start=1):
            entry = self.r[key]
            isbest = rank == 1
            if isbest:
                best = entry
            entry['best'] = isbest
            entry['rank'] = rank

            # color
            entry['color'] = self.get_value_color(
                entry[self.compare], minval=self.range_minval, maxval=self.range_maxval
            )

            if self.test is None:
                continue

            # the best result is not tested against itself
            p_val = 0 if isbest else self._p_value(best, entry)

            if p_val <= 0.005:
                entry['test'] = ResultSet.TTEST_DIFF
            elif p_val <= 0.05:
                entry['test'] = ResultSet.TTEST_SIM
                self.some_similar = True
            else:
                # also reached when the p-value is NaN, so that the 'test' attribute is always refreshed
                entry['test'] = ResultSet.TTEST_SAME
                self.some_similar = True

        self.computed = True

    def _p_value(self, best, other):
        """
        Returns the p-value of the configured significance test between the best result and another one.

        :param best: the attribute dictionary of the best-ranked result
        :param other: the attribute dictionary of the result being compared
        :return: the p-value returned by the scipy test selected in self.test
        """
        if self.test == 'wilcoxon':
            _, p_val = wilcoxon(best['values'], other['values'])
        elif self.test == 'ttest_ind_from_stats':
            _, p_val = ttest_ind_from_stats(
                best['mean'], best['std'], best['nobs'],
                other['mean'], other['std'], other['nobs']
            )
        return p_val

    def latex(self, key, missing='--', color=True):
        """
        Renders the result stored under key as the content of a LaTeX table cell.

        :param key: identifier of the result
        :param missing: the string to return if there is no result for the key
        :param color: whether or not to append the cell color command
        :return: a string with the formatted mean (in bold if best, with significance marks otherwise),
            optionally followed by the std and by the cell color
        """
        if key not in self.r:
            return missing

        self.update()

        rd = self.r[key]
        s = f"{rd['mean']:.{self.prec_mean}f}"
        if self.remove_mean:
            # NOTE(review): str.replace acts on every occurrence, so removing '0.' turns '10.5' into '1.5'
            s = s.replace(self.remove_mean, '.')
        if rd['best']:
            s = r"\textbf{" + s + "}"
        elif self.test is not None and self.some_similar:
            if rd['test'] == ResultSet.TTEST_SIM:
                s += r'^{\dag\phantom{\dag}}'
            elif rd['test'] == ResultSet.TTEST_SAME:
                s += r'^{\ddag}'
            elif rd['test'] == ResultSet.TTEST_DIFF:
                # invisible mark that keeps the columns aligned
                s += r'^{\phantom{\ddag}}'

        if self.show_std:
            std = f"{rd['std']:.{self.prec_std}f}"
            if self.remove_std:
                std = std.replace(self.remove_std, '.')
            s += rf" \pm {std}"

        s = f'$ {s} $'
        if color:
            s += ' ' + rd['color']

        return s

    def mean(self, attr='mean', required: int = None, missing=np.nan):
        """
        returns the mean value for the "attr" attribute

        :param attr: the attribute to average across results
        :param required: if specified, indicates the number of values that should be part of the mean; if this number
            is different, then the mean is not computed
        :param missing: the value to return in case the required condition is not satisfied, or in case some
            result lacks the attribute
        :return: the mean of the "attr" attribute
        """
        vallist = [result.get(attr, None) for result in self.r.values()]
        # identity check: "None in vallist" would rely on ==, which is not safe for array-like values
        if any(val is None for val in vallist):
            return missing
        if required is not None and len(vallist) != required:
            return missing
        return np.mean(vallist)

    def get(self, key, attr, missing='--'):
        """
        Returns one attribute of one result.

        :param key: identifier of the result
        :param attr: name of the attribute to retrieve (computed attributes such as 'rank' are available too)
        :param missing: the value to return if the key or the attribute do not exist
        :return: the value of the attribute, or the "missing" value
        """
        if key in self.r:
            self.update()
            if attr in self.r[key]:
                return self.r[key][attr]
        return missing

    def get_color(self, key):
        """
        Returns the cell color command of one result.

        :param key: identifier of the result
        :return: the color string, or the empty string if there is no result for the key
        """
        if key not in self.r:
            return ''
        self.update()
        return self.r[key]['color']

    def get_value_color(self, val, minval=None, maxval=None):
        """
        Maps a value onto the red-to-green color scale.

        :param val: the value to colorize
        :param minval: lower end of the scale; if this or maxval is None, the range computed for the result set
            is used instead
        :param maxval: upper end of the scale
        :return: the color string generated by color_red2green_01
        """
        if minval is None or maxval is None:
            self.update()
            minval = self.range_minval
            maxval = self.range_maxval
        span = maxval - minval
        if span == 0:
            # degenerate scale (single result, or all results tied): use the neutral midpoint
            # instead of dividing by zero
            val = 0.5
        else:
            # values outside a user-defined [minval, maxval] saturate at the ends of the scale
            val = np.clip((val - minval) / span, 0.0, 1.0)
        if self.lower_is_better:
            val = 1 - val
        return color_red2green_01(val, self.maxtone)

    def change_compare(self, attr):
        """
        Changes the attribute used to rank the results; ranks are recomputed on the next access.

        :param attr: the new attribute to compare
        """
        self.compare = attr
        self.computed = False
|
|
|
|
|
|
|
|
def color_red2green_01(val, maxtone=100):
    """
    Maps a value in [0,1] onto a LaTeX cell color command ranging from red (0) to green (1).

    :param val: a value in [0,1]; values below 0.5 are painted red and the rest green, with an intensity
        that grows with the distance from 0.5
    :param maxtone: the color intensity reached at both ends of the range
    :return: a string of the form '\\cellcolor{<color>!<tone>}' where tone is truncated to an integer
    """
    assert 0 <= val <= 1, f'val {val} out of range [0,1]'

    # rescale to [-1,1]: the sign selects the color and the magnitude selects the tone
    scaled = val * 2 - 1
    color = 'red' if scaled < 0 else 'green'
    tone = maxtone * abs(scaled)
    # raw string: '\c' is not a valid escape sequence in a regular string literal
    return r'\cellcolor{' + color + f'!{int(tone)}' + '}'
|
|
|