# Forked from moreo/QuaPy (476 lines, 16 KiB, Python).
import os
import subprocess
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
from typing import Union, List

import numpy as np
import pandas as pd
from scipy.stats import wilcoxon, ttest_ind_from_stats


@dataclass()
class CellFormat:
    """Rendering options shared by the cells of a results table.

    Fields are declared in positional-constructor order.
    """

    # decimal digits used when printing the mean of a cell
    mean_prec: int = 3
    # decimal digits used when printing the standard deviation of a cell
    std_prec: int = 3
    # whether the standard deviation is printed after the mean
    show_std: bool = True
    # whether the zero before the decimal point of the mean is dropped (e.g. '0.123' -> '.123')
    remove_zero: bool = False
    # whether a background-colour command is appended to each cell
    color: bool = True
    # largest tint used for the background colour (the tone is scaled up to this value)
    maxtone: int = 50


class Cell:
    """A single table cell: collects raw values and renders them as LaTeX.

    Every cell registers itself in a CellGroup; the group decides which cell is
    the best one, runs the statistical comparison against it, and computes the
    background colour of the cell.
    """

    def __init__(self, format: 'CellFormat', group: 'CellGroup'):
        self.values = []
        self.format = format
        self.touch()
        self.group = group
        # the group needs to know all of its members to find the best/worst cell
        self.group.register_cell(self)

    def __len__(self):
        return len(self.values)

    def mean(self):
        """Return the (cached) mean of the values, or nan if the cell is empty."""
        if self.mean_ is None:
            # np.mean([]) is nan too, but also emits a RuntimeWarning; callers
            # (e.g. Table.get_method_means) rely on getting nan for empty cells
            self.mean_ = np.nan if self.isEmpty() else np.mean(self.values)
        return self.mean_

    def std(self):
        """Return the (cached) population std (ddof=0) of the values, or nan if empty."""
        if self.std_ is None:
            self.std_ = np.nan if self.isEmpty() else np.std(self.values)
        return self.std_

    def touch(self):
        """Invalidate the cached mean and std; must be called whenever values change."""
        self.mean_ = None
        self.std_ = None

    def append(self, v: Union[float, Iterable]):
        """Add one value, or every element of an iterable of values, to the cell."""
        if isinstance(v, Iterable):
            self.values.extend(v)
        else:
            self.values.append(v)
        self.touch()

    def isEmpty(self):
        return len(self) == 0

    def isBest(self):
        """True if this cell is the best of its group, or ties with it (np.isclose)."""
        best = self.group.best()
        if best is None:
            return False
        return best is self or bool(np.isclose(best.mean(), self.mean()))

    def print_mean(self):
        """Return the mean formatted with format.mean_prec decimals ('' if empty)."""
        if self.isEmpty():
            return ''
        return f'{self.mean():.{self.format.mean_prec}f}'

    @staticmethod
    def _strip_leading_zero(number: str) -> str:
        """Drop the zero before the decimal point: '0.123' -> '.123', '-0.5' -> '-.5'.

        Any other number (e.g. '10.500') is returned unchanged.
        """
        if number.startswith('0.'):
            return number[1:]
        if number.startswith('-0.'):
            return '-' + number[2:]
        return number

    def print(self):
        """Render the cell as LaTeX: mean [+- std], bold if best, test marks, colour."""
        if self.isEmpty():
            return ''

        # mean
        # ---------------------------------------------------
        mean = self.print_mean()
        if self.format.remove_zero:
            mean = self._strip_leading_zero(mean)

        # std ?
        # ---------------------------------------------------
        if self.format.show_std:
            std = f' $\\pm$ {self.std():.{self.format.std_prec}f}'
        else:
            std = ''

        # bold or statistical test
        # ---------------------------------------------------
        if self.isBest():
            str_cell = f'\\textbf{{{mean}{std}}}'
        else:
            comp_symbol = ''
            pval = self.group.compare(self)
            if pval is not None:
                if 0.005 > pval:
                    # difference w.r.t. the best cell is significant at 0.005: no mark
                    comp_symbol = ''
                elif 0.05 > pval >= 0.005:
                    # significant at 0.05 but not at 0.005
                    comp_symbol = '$^{\\dag}$'
                elif pval >= 0.05:
                    # not significantly different from the best cell at 0.05
                    comp_symbol = '$^{\\ddag}$'
            str_cell = f'{mean}{comp_symbol}{std}'

        # color ?
        # ---------------------------------------------------
        if self.format.color:
            str_cell += ' ' + self.group.color(self)

        return str_cell


class CellGroup:
    """A set of cells competing with each other (e.g. all methods on one benchmark).

    The group finds the best/worst cell, tests each cell against the best one,
    and computes the background colour of each cell.

    :param lower_is_better: if True, the cell with the lowest mean is the best one
    :param stat_test: 'wilcoxon' (paired test on the raw values), 'ttest'
        (independent t-test from summary statistics) or None (no test)
    :param color_mode: 'local' (colour scaled between this group's best and worst
        means) or 'global' (scaled between color_global_min and color_global_max)
    :param color_global_min: lower end of the colour scale in 'global' mode
    :param color_global_max: upper end of the colour scale in 'global' mode
    """

    def __init__(self, lower_is_better=True, stat_test='wilcoxon', color_mode='local', color_global_min=None, color_global_max=None):
        assert stat_test in ['wilcoxon', 'ttest', None], \
            f"unknown {stat_test=}, valid ones are wilcoxon, ttest, or None"
        assert color_mode in ['local', 'global'], \
            f"unknown {color_mode=}, valid ones are local and global"
        if (color_global_min is not None or color_global_max is not None) and color_mode == 'local':
            print('warning: color_global_min and color_global_max are only considered when color_mode==global')
        self.cells = []
        self.lower_is_better = lower_is_better
        self.stat_test = stat_test
        self.color_mode = color_mode
        self.color_global_min = color_global_min
        self.color_global_max = color_global_max

    def register_cell(self, cell: 'Cell'):
        self.cells.append(cell)

    def non_empty_cells(self):
        return [c for c in self.cells if not c.isEmpty()]

    def max(self):
        """Return the non-empty cell with the highest mean, or None if there is none."""
        cells = self.non_empty_cells()
        if len(cells) > 0:
            return cells[np.argmax([c.mean() for c in cells])]
        return None

    def min(self):
        """Return the non-empty cell with the lowest mean, or None if there is none."""
        cells = self.non_empty_cells()
        if len(cells) > 0:
            return cells[np.argmin([c.mean() for c in cells])]
        return None

    def best(self) -> 'Cell':
        """Return the best non-empty cell (None if all cells are empty)."""
        return self.min() if self.lower_is_better else self.max()

    def worst(self) -> 'Cell':
        """Return the worst non-empty cell (None if all cells are empty)."""
        return self.max() if self.lower_is_better else self.min()

    def isEmpty(self):
        return len(self.non_empty_cells()) == 0

    @staticmethod
    def _sample_std(values) -> float:
        """Bessel-corrected (ddof=1) standard deviation; 0.0 for fewer than two values."""
        if len(values) < 2:
            return 0.0
        return float(np.std(values, ddof=1))

    def compare(self, cell: 'Cell'):
        """Return the p-value of the test comparing `cell` against the best cell.

        Returns None when no test is configured, when there is no best cell, when
        either cell is empty, or when scipy's wilcoxon raises ValueError (e.g.
        value lists of different length).
        """
        best = self.best()
        if best is None or len(best) == 0 or len(cell) == 0:
            return None

        if self.stat_test == 'wilcoxon':
            try:
                _, p_val = wilcoxon(best.values, cell.values)
            except ValueError:
                p_val = None
            return p_val
        elif self.stat_test == 'ttest':
            # ttest_ind_from_stats expects the corrected sample std (ddof=1),
            # whereas Cell.std() is the population std (ddof=0)
            _, p_val = ttest_ind_from_stats(
                best.mean(), self._sample_std(best.values), len(best),
                cell.mean(), self._sample_std(cell.values), len(cell),
            )
            return p_val
        elif self.stat_test is None:
            return None
        else:
            raise ValueError(f'unknown statistical test {self.stat_test}')

    @staticmethod
    def _cellcolor(normval: float, maxtone) -> str:
        """Map normval in [0,1] (1 = best) to a LaTeX cell-colour command.

        Values below 0.5 get a red tint, values from 0.5 up a green tint; the
        tint grows linearly up to maxtone at both ends.
        """
        normval = normval * 2 - 1  # rescale to [-1,1]
        if normval < 0:
            color, tone = 'red', maxtone * (-normval)
        else:
            color, tone = 'green', maxtone * normval
        return f'\\cellcolor{{{color}!{int(tone)}}}'

    def color(self, cell: 'Cell'):
        """Return the colour command for `cell`, or '' when no colour can be computed.

        :raises ValueError: in 'global' mode if color_global_min/max are not set
        """
        if cell.isEmpty():
            return ''

        if self.color_mode == 'local':
            best, worst = self.best(), self.worst()
            if best is None or worst is None:
                return ''
            best_mean, worst_mean = best.mean(), worst.mean()
            if best_mean == worst_mean:
                return ''
            maxval = max(best_mean, worst_mean)
            minval = min(best_mean, worst_mean)
        else:
            maxval, minval = self.color_global_max, self.color_global_min
            if maxval is None or minval is None:
                raise ValueError('color_mode==global requires color_global_min and color_global_max')
            if maxval == minval:
                return ''

        # normalize the mean in [0,1] so that 1 is the best end of the scale
        normval = (cell.mean() - minval) / (maxval - minval)
        if self.lower_is_better:
            normval = 1 - normval
        normval = np.clip(normval, 0, 1)

        return self._cellcolor(normval, cell.format.maxtone)


class Table:
    """A benchmarks x methods table of Cells that can be rendered as LaTeX.

    Values are added with add(benchmark, method, value); benchmarks and methods
    are registered in order of first appearance. All the cells of one benchmark
    share one CellGroup, so the best method, the statistical tests and the
    colours are computed per benchmark.
    """

    def __init__(self,
                 name,
                 benchmarks=None,
                 methods=None,
                 format: 'CellFormat' = None,
                 lower_is_better=True,
                 stat_test='wilcoxon',
                 color_mode='local',
                 with_mean=True
                 ):
        self.name = name
        self.benchmarks = [] if benchmarks is None else benchmarks
        self.methods = [] if methods is None else methods
        self.format = format if format is not None else CellFormat()
        self.lower_is_better = lower_is_better
        self.stat_test = stat_test
        self.color_mode = color_mode
        self.with_mean = with_mean
        # NOTE(review): only_full_mean is not read anywhere in this class; the
        # "Ave." cells always average the non-empty benchmarks of each method
        self.only_full_mean = True  # if False, compute the mean of partially empty methods also

        if self.color_mode == 'global':
            self.color_global_min = 0
            self.color_global_max = 1
        else:
            self.color_global_min = None
            self.color_global_max = None

        self.T = {}       # (benchmark index, method index) -> Cell
        self.groups = {}  # benchmark -> CellGroup shared by that benchmark's cells

    def add(self, benchmark, method, v):
        """Add a value (or an iterable of values) to the (benchmark, method) cell."""
        cell = self.get(benchmark, method)
        cell.append(v)

    def get_benchmarks(self):
        return self.benchmarks

    def get_methods(self):
        return self.methods

    def n_benchmarks(self):
        return len(self.benchmarks)

    def n_methods(self):
        return len(self.methods)

    def _new_group(self):
        return CellGroup(self.lower_is_better, self.stat_test, color_mode=self.color_mode,
                         color_global_max=self.color_global_max, color_global_min=self.color_global_min)

    def get(self, benchmark, method) -> 'Cell':
        """Return the (benchmark, method) cell, creating it and registering the
        benchmark/method if they were not seen before."""
        if benchmark not in self.benchmarks:
            self.benchmarks.append(benchmark)
        if benchmark not in self.groups:
            self.groups[benchmark] = self._new_group()
        if method not in self.methods:
            self.methods.append(method)
        idx = (self.benchmarks.index(benchmark), self.methods.index(method))
        if idx not in self.T:
            self.T[idx] = Cell(self.format, group=self.groups[benchmark])
        return self.T[idx]

    def get_value(self, benchmark, method) -> float:
        return self.get(benchmark, method).mean()

    def get_benchmark(self, benchmark):
        """Return the non-empty cells of a benchmark, in method order."""
        cells = [self.get(benchmark, method=m) for m in self.get_methods()]
        return [c for c in cells if not c.isEmpty()]

    def get_method(self, method):
        """Return the non-empty cells of a method, in benchmark order."""
        cells = [self.get(benchmark=b, method=method) for b in self.get_benchmarks()]
        return [c for c in cells if not c.isEmpty()]

    def get_method_means(self, method_order):
        """Return one Cell per method holding its per-benchmark means (empty ones skipped).

        All these cells share a fresh group, so the "Ave." row/column gets its own
        best cell, statistical tests and colours.
        """
        mean_group = self._new_group()
        cells = []
        for method in method_order:
            method_mean = Cell(self.format, group=mean_group)
            for bench in self.get_benchmarks():
                mean_value = self.get_value(benchmark=bench, method=method)
                if not np.isnan(mean_value):
                    method_mean.append(mean_value)
            cells.append(method_mean)
        return cells

    def get_benchmark_values(self, benchmark):
        return np.asarray([c.mean() for c in self.get_benchmark(benchmark)])

    def get_method_values(self, method):
        return np.asarray([c.mean() for c in self.get_method(method)])

    def all_mean(self):
        values = [c.mean() for c in self.T.values() if not c.isEmpty()]
        return np.mean(values)

    def print(self):  # todo: missing method names?
        """Print the table of means to stdout as plain text (via pandas)."""
        data_dict = {'Benchmark': list(self.get_benchmarks())}
        for method in self.get_methods():
            data_dict[method] = [self.get(bench, method).print_mean() for bench in self.get_benchmarks()]
        df = pd.DataFrame(data_dict)
        pd.set_option('display.max_columns', None)
        pd.set_option('display.max_rows', None)
        print(df.to_string(index=False))

    @staticmethod
    def _write_text(path, text):
        """Write `text` to `path` as UTF-8, creating the parent directory if needed."""
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        # the generated documents declare utf8 inputenc, so do not depend on the
        # platform's default encoding
        with open(path, 'wt', encoding='utf-8') as f:
            f.write(text)

    def tabular(self, path=None, benchmark_replace=None, method_replace=None, benchmark_order=None, method_order=None, transpose=False):
        """Render the table as a LaTeX tabular environment.

        :param path: if not None, the tabular code is also written to this file
        :param benchmark_replace: optional dict mapping benchmark names to display names
        :param method_replace: optional dict mapping method names to display names
        :param benchmark_order: benchmarks to show, in order (default: all, insertion order)
        :param method_order: methods to show, in order (default: all, insertion order)
        :param transpose: if True methods are rows and benchmarks columns
        :return: the LaTeX code
        """
        benchmark_replace = {} if benchmark_replace is None else benchmark_replace
        method_replace = {} if method_replace is None else method_replace
        benchmark_order = self.get_benchmarks() if benchmark_order is None else benchmark_order
        method_order = self.get_methods() if method_order is None else method_order

        if transpose:
            row_order, row_replace = method_order, method_replace
            col_order, col_replace = benchmark_order, benchmark_replace
        else:
            row_order, row_replace = benchmark_order, benchmark_replace
            col_order, col_replace = method_order, method_replace

        n_cols = len(col_order)
        # the per-method averages form a column when methods are rows, a row otherwise
        add_mean_col = self.with_mean and transpose
        add_mean_row = self.with_mean and not transpose
        last_col_idx = n_cols + 2 if add_mean_col else n_cols + 1

        mean_cells = self.get_method_means(method_order) if self.with_mean else []

        lines = ['\\begin{tabular}{|c' + '|c' * n_cols + ('||c' if add_mean_col else '') + '|}']

        lines.append(f'\\cline{{2-{last_col_idx}}}')
        header = '\\multicolumn{1}{c|}{} & ' + ' & '.join(col_replace.get(col, col) for col in col_order)
        if add_mean_col:
            header += ' & Ave.'
        lines.append(header + ' \\\\\\hline')

        for i, row in enumerate(row_order):
            cells = [
                self.get(benchmark=col if transpose else row, method=row if transpose else col).print()
                for col in col_order
            ]
            line = row_replace.get(row, row) + ' & ' + ' & '.join(cells)
            if add_mean_col:
                # rows follow method_order here, and so do mean_cells
                line += ' & ' + mean_cells[i].print()
            lines.append(line + ' \\\\\\hline')

        if add_mean_row:
            lines.append('\\hline')
            lines.append('Ave. & ' + ' & '.join(mean_cell.print() for mean_cell in mean_cells) + ' \\\\\\hline')

        lines.append('\\end{tabular}')

        tabular_tex = '\n'.join(lines)

        if path is not None:
            self._write_text(path, tabular_tex)

        return tabular_tex

    def table(self, tabular_path, benchmark_replace=None, method_replace=None, resizebox=True, caption=None, label=None, benchmark_order=None, method_order=None, transpose=False, tabular_dir='tables'):
        """Render the table as a LaTeX table float.

        :param tabular_path: file where the tabular code is written; the float then
            inputs it from `tabular_dir`. If None, the tabular code is inlined.
        :param resizebox: wrap the tabular in a resizebox of text width
        :param caption: caption text; defaults to tabular_path (or the table name when
            tabular_path is None) with underscores escaped
        :param label: optional LaTeX label
        :param tabular_dir: directory, relative to the main .tex document, from which
            the tabular file is input (default 'tables')
        :return: the LaTeX code of the float
        """
        if benchmark_replace is None:
            benchmark_replace = {}
        if method_replace is None:
            method_replace = {}

        lines = ['\\begin{table}', '\\center']
        if resizebox:
            lines.append('\\resizebox{\\textwidth}{!}{%')

        tabular_str = self.tabular(tabular_path, benchmark_replace, method_replace, benchmark_order, method_order, transpose)
        if tabular_path is None:
            lines.append(tabular_str)
        else:
            input_path = Path(tabular_dir, Path(tabular_path).name).as_posix()
            lines.append(f'\\input{{{input_path}}}')

        if resizebox:
            lines.append('}%')
        if caption is None:
            caption_source = self.name if tabular_path is None else tabular_path
            caption = str(caption_source).replace('_', '\\_')
        lines.append(f'\\caption{{{caption}}}')
        if label is not None:
            lines.append(f'\\label{{{label}}}')
        lines.append('\\end{table}')

        return '\n'.join(lines)

    def document(self, tex_path, tabular_dir='tables', *args, **kwargs):
        """Write a LaTeX document containing only this table; returns its source."""
        # positional forwarding: passing tables= by keyword before *args would make
        # any extra positional argument collide with the `tables` parameter
        return Table.Document(tex_path, [self], tabular_dir, *args, **kwargs)

    def latexPDF(self, pdf_path, tabular_dir='tables', *args, **kwargs):
        """Write and compile a PDF containing only this table."""
        return Table.LatexPDF(pdf_path, [self], tabular_dir, *args, **kwargs)

    @classmethod
    def Document(cls, tex_path, tables: List['Table'], tabular_dir='tables', *args, **kwargs):
        """Write a LaTeX document with one table float per element of `tables`.

        Each tabular goes to <dir of tex_path>/<tabular_dir>/<table name>_table.tex
        and is input from the float. Extra args/kwargs are forwarded to Table.table.

        :return: the document source (also written to tex_path)
        """
        lines = [
            '\\documentclass[10pt,a4paper]{article}',
            '\\usepackage[utf8]{inputenc}',
            '\\usepackage{amsmath}',
            '\\usepackage{amsfonts}',
            '\\usepackage{amssymb}',
            '\\usepackage{graphicx}',
            '\\usepackage{xcolor}',
            '\\usepackage{colortbl}',
            '',
            '\\begin{document}',
        ]
        for table in tables:
            lines.append('')
            tabular_path = os.path.join(Path(tex_path).parent, tabular_dir, table.name + '_table.tex')
            lines.append(table.table(tabular_path, *args, tabular_dir=tabular_dir, **kwargs))
        lines.append('\\end{document}')

        document = '\n'.join(lines)

        cls._write_text(tex_path, document)

        return document

    @staticmethod
    def _tex_path_for(pdf_path) -> str:
        """Replace the final '.pdf' of pdf_path (assumed present) with '.tex'."""
        return pdf_path[:-len('.pdf')] + '.tex'

    @classmethod
    def LatexPDF(cls, pdf_path: str, tables: List['Table'], tabular_dir: str = 'tables', *args, **kwargs):
        """Write the LaTeX document for `tables` next to pdf_path and compile it with pdflatex.

        pdflatex runs in the directory of pdf_path; the auxiliary files it leaves
        there (.aux, .bbl, .blg, .log, .out, .dvi) are deleted afterwards.
        """
        assert pdf_path.endswith('.pdf'), f'{pdf_path=} does not seem a valid name for a pdf file'
        tex_path = cls._tex_path_for(pdf_path)

        cls.Document(tex_path, tables, tabular_dir, *args, **kwargs)

        work_dir = Path(pdf_path).parent
        pwd = os.getcwd()

        print('currently in', pwd)
        print("[Tables Done] runing latex")
        tex_name = Path(tex_path).name
        try:
            # list form: no shell, so file names with spaces/metacharacters are safe;
            # cwd= avoids chdir, so the caller's working directory is never changed
            subprocess.run(['pdflatex', tex_name], cwd=work_dir)
        except FileNotFoundError:
            # keep the best-effort behaviour: the .tex is written even without pdflatex
            print('pdflatex not found; the .tex file was written but not compiled')
        basename = Path(tex_name).stem
        for ext in ('aux', 'bbl', 'blg', 'log', 'out', 'dvi'):
            (work_dir / f'{basename}.{ext}').unlink(missing_ok=True)