import os
import subprocess
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
from typing import List, Union

import numpy as np
import pandas as pd
from scipy.stats import ttest_ind_from_stats, wilcoxon


@dataclass
class CellFormat:
    """Formatting options applied when a cell is rendered as text/LaTeX."""

    mean_prec: int = 3  # number of decimals printed for the mean
    std_prec: int = 3  # number of decimals printed for the standard deviation
    show_std: bool = True  # whether to print "$\pm$ std" after the mean
    remove_zero: bool = False  # whether to drop the leading zero ("0.123" -> ".123")
    color: bool = True  # whether to append a \cellcolor command to the cell
    maxtone: int = 50  # maximum color intensity passed to \cellcolor


class Cell:
    """Values collected for one table entry, plus the logic to render them.

    The mean and standard deviation are computed lazily and cached; the cache
    is reset by `touch()` every time new values are appended. Each cell
    registers itself in the `CellGroup` it is compared against.
    """

    def __init__(self, format: CellFormat, group: 'CellGroup'):
        """Create an empty cell using `format` and register it in `group`."""
        self.values = []
        self.format = format
        self.touch()
        self.group = group
        self.group.register_cell(self)

    def __len__(self):
        """Return the number of values stored in the cell."""
        return len(self.values)

    def mean(self):
        """Return the (cached) mean of the stored values."""
        if self.mean_ is None:
            self.mean_ = np.mean(self.values)
        return self.mean_

    def std(self):
        """Return the (cached) standard deviation (`np.std` default, ddof=0)."""
        if self.std_ is None:
            self.std_ = np.std(self.values)
        return self.std_

    def touch(self):
        """Invalidate the cached mean and standard deviation."""
        self.mean_ = None
        self.std_ = None

    def append(self, v: Union[float, Iterable]):
        """Add a single value, or every value of an iterable, to the cell."""
        if isinstance(v, Iterable):
            self.values.extend(v)
        else:
            # Only scalars are appended as-is; appending the iterable itself
            # after extending would store a nested sequence among the values.
            self.values.append(v)
        self.touch()

    def isEmpty(self):
        """Return True if the cell holds no values."""
        return len(self) == 0

    def isBest(self):
        """Return True if this cell is (or ties with) the best cell of its group."""
        best = self.group.best()
        if best is not None:
            return (best == self) or (np.isclose(best.mean(), self.mean()))
        return False

    def print_mean(self):
        """Return the mean formatted with `mean_prec` decimals ('' if empty)."""
        if self.isEmpty():
            return ''
        return f'{self.mean():.{self.format.mean_prec}f}'

    @staticmethod
    def _strip_leading_zero(text: str) -> str:
        """Drop the zero preceding the decimal point ('0.12' -> '.12')."""
        if text.startswith('0.'):
            return text[1:]
        if text.startswith('-0.'):
            return '-' + text[2:]
        return text

    def print(self):
        """Return the LaTeX representation of the cell ('' if empty).

        The best cell of the group is typeset in bold; any other cell may get
        a significance marker derived from the p-value of the group's
        statistical test against the best cell. A color command is appended
        when `format.color` is set.
        """
        if self.isEmpty():
            return ''

        # mean
        # ---------------------------------------------------
        mean = self.print_mean()
        if self.format.remove_zero:
            # Only the leading zero is removed; a plain replace('0.', '.')
            # would also turn e.g. '10.500' into '1.500'.
            mean = self._strip_leading_zero(mean)

        # std ?
        # ---------------------------------------------------
        if self.format.show_std:
            std = f' $\\pm$ {self.std():.{self.format.std_prec}f}'
        else:
            std = ''

        # bold or statistical test
        # ---------------------------------------------------
        if self.isBest():
            str_cell = f'\\textbf{{{mean}{std}}}'
        else:
            comp_symbol = ''
            pval = self.group.compare(self)
            if pval is not None:
                if 0.005 > pval:
                    comp_symbol = ''
                elif 0.05 > pval >= 0.005:
                    comp_symbol = '$^{\\dag}$'
                elif pval >= 0.05:
                    # NOTE(review): unlike the dag marker, this one has no '^'
                    # (not superscripted) -- confirm whether that is intended.
                    comp_symbol = '${\\ddag}$'
            str_cell = f'{mean}{comp_symbol}{std}'

        # color ?
        # ---------------------------------------------------
        if self.format.color:
            str_cell += ' ' + self.group.color(self)

        return str_cell


class CellGroup:
    """A set of cells that are ranked, tested and colored against each other."""

    def __init__(self, lower_is_better=True, stat_test='wilcoxon', color_mode='local',
                 color_global_min=None, color_global_max=None):
        """
        :param lower_is_better: if True the best cell is the one with lowest mean
        :param stat_test: 'wilcoxon', 'ttest', or None (no test)
        :param color_mode: 'local' (range taken from the group's best and worst
            means) or 'global' (range given by color_global_min/max)
        :param color_global_min: lower end of the color range in 'global' mode
        :param color_global_max: upper end of the color range in 'global' mode
        """
        assert stat_test in ['wilcoxon', 'ttest', None], \
            f"unknown {stat_test=}, valid ones are wilcoxon, ttest, or None"
        assert color_mode in ['local', 'global'], \
            f"unknown {color_mode=}, valid ones are local and global"
        if (color_global_min is not None or color_global_max is not None) and color_mode == 'local':
            # The bounds are ignored in 'local' mode (the message used to say the opposite).
            print('warning: color_global_min and color_global_max are only considered when color_mode==global')
        self.cells = []
        self.lower_is_better = lower_is_better
        self.stat_test = stat_test
        self.color_mode = color_mode
        self.color_global_min = color_global_min
        self.color_global_max = color_global_max

    def register_cell(self, cell: Cell):
        """Add `cell` to the group."""
        self.cells.append(cell)

    def non_empty_cells(self):
        """Return the registered cells that contain at least one value."""
        return [c for c in self.cells if not c.isEmpty()]

    def max(self):
        """Return the non-empty cell with the highest mean, or None."""
        cells = self.non_empty_cells()
        if len(cells) > 0:
            return cells[np.argmax([c.mean() for c in cells])]
        return None

    def min(self):
        """Return the non-empty cell with the lowest mean, or None."""
        cells = self.non_empty_cells()
        if len(cells) > 0:
            return cells[np.argmin([c.mean() for c in cells])]
        return None

    def best(self) -> Cell:
        """Return the best cell according to `lower_is_better` (None if all are empty)."""
        return self.min() if self.lower_is_better else self.max()

    def worst(self) -> Cell:
        """Return the worst cell according to `lower_is_better` (None if all are empty)."""
        return self.max() if self.lower_is_better else self.min()

    def isEmpty(self):
        """Return True if no registered cell contains values."""
        return len(self.non_empty_cells()) == 0

    def compare(self, cell: Cell):
        """Return the p-value of the configured test between the best cell and `cell`.

        Returns None when no test is configured, when either cell is empty
        (or the group has no best cell), or when the Wilcoxon test raises
        ValueError for the two samples.
        """
        best = self.best()
        if best is None:
            return None
        best_n = len(best)
        cell_n = len(cell)
        if best_n > 0 and cell_n > 0:
            if self.stat_test == 'wilcoxon':
                try:
                    _, p_val = wilcoxon(best.values, cell.values)
                except ValueError:
                    p_val = None
                return p_val
            elif self.stat_test == 'ttest':
                # NOTE(review): Cell.std() uses np.std's default (ddof=0); the
                # scipy docs describe these arguments as corrected (ddof=1)
                # sample standard deviations -- confirm this is intended.
                best_mean, best_std = best.mean(), best.std()
                cell_mean, cell_std = cell.mean(), cell.std()
                _, p_val = ttest_ind_from_stats(best_mean, best_std, best_n, cell_mean, cell_std, cell_n)
                return p_val
            elif self.stat_test is None:
                return None
            else:
                raise ValueError(f'unknown statistical test {self.stat_test}')
        else:
            return None

    def color(self, cell: Cell):
        """Return the \\cellcolor command for `cell` ('' when no color applies).

        The cell mean is normalized within the color range (the group's best
        and worst means in 'local' mode, the global bounds otherwise), flipped
        when lower is better, and rescaled to [-1, 1]; negative values map to
        red tones and the rest to green tones, scaled by `format.maxtone`.
        """
        if cell.isEmpty():
            return ''

        if self.color_mode == 'local':
            best = self.best()
            worst = self.worst()
            # Check for None before asking for the means.
            if best is None or worst is None:
                return ''
            best_mean = best.mean()
            worst_mean = worst.mean()
            if best_mean == worst_mean:
                return ''
            # normalize val in [0,1]
            maxval = max(best_mean, worst_mean)
            minval = min(best_mean, worst_mean)
        else:
            maxval = self.color_global_max
            minval = self.color_global_min
            # Without a proper range there is nothing to normalize against.
            if maxval is None or minval is None or maxval == minval:
                return ''

        normval = (cell.mean() - minval) / (maxval - minval)

        if self.lower_is_better:
            normval = 1 - normval

        normval = np.clip(normval, 0, 1)

        normval = normval * 2 - 1  # rescale to [-1,1]
        if normval < 0:
            color = 'red'
            tone = cell.format.maxtone * (-normval)
        else:
            color = 'green'
            tone = cell.format.maxtone * normval

        return f'\\cellcolor{{{color}!{int(tone)}}}'


class Table:
    """A benchmarks x methods table of result cells that can be rendered in LaTeX."""

    def __init__(self,
                 name,
                 benchmarks=None,
                 methods=None,
                 format: CellFormat = None,
                 lower_is_better=True,
                 stat_test='wilcoxon',
                 color_mode='local',
                 with_mean=True
                 ):
        """
        :param name: table name (used for the generated tabular file name)
        :param benchmarks: initial list of benchmark names (extended on demand)
        :param methods: initial list of method names (extended on demand)
        :param format: the CellFormat shared by all cells (default CellFormat())
        :param lower_is_better: passed to every CellGroup
        :param stat_test: passed to every CellGroup
        :param color_mode: passed to every CellGroup; in 'global' mode the
            color range is set to [0, 1]
        :param with_mean: whether to add an "Ave." row/column to the tabular
        """
        self.name = name
        self.benchmarks = [] if benchmarks is None else benchmarks
        self.methods = [] if methods is None else methods
        self.format = format if format is not None else CellFormat()
        self.lower_is_better = lower_is_better
        self.stat_test = stat_test
        self.color_mode = color_mode
        self.with_mean = with_mean
        self.only_full_mean = True  # if False, compute the mean of partially empty methods also
        if self.color_mode == 'global':
            self.color_global_min = 0
            self.color_global_max = 1
        else:
            self.color_global_min = None
            self.color_global_max = None
        self.T = {}  # maps (benchmark index, method index) -> Cell
        self.groups = {}  # maps benchmark -> CellGroup

    def add(self, benchmark, method, v):
        """Append value(s) `v` to the cell of (`benchmark`, `method`)."""
        cell = self.get(benchmark, method)
        cell.append(v)

    def get_benchmarks(self):
        """Return the list of benchmark names."""
        return self.benchmarks

    def get_methods(self):
        """Return the list of method names."""
        return self.methods

    def n_benchmarks(self):
        """Return the number of benchmarks."""
        return len(self.benchmarks)

    def n_methods(self):
        """Return the number of methods."""
        return len(self.methods)

    def _new_group(self):
        """Create a CellGroup configured with the table's settings."""
        return CellGroup(self.lower_is_better, self.stat_test, color_mode=self.color_mode,
                         color_global_max=self.color_global_max, color_global_min=self.color_global_min)

    def get(self, benchmark, method) -> Cell:
        """Return the cell of (`benchmark`, `method`), creating it if needed.

        Unknown benchmarks and methods are registered as a side effect.
        """
        if benchmark not in self.benchmarks:
            self.benchmarks.append(benchmark)
        if benchmark not in self.groups:
            self.groups[benchmark] = self._new_group()
        if method not in self.methods:
            self.methods.append(method)
        b_idx = self.benchmarks.index(benchmark)
        m_idx = self.methods.index(method)
        idx = (b_idx, m_idx)
        if idx not in self.T:
            self.T[idx] = Cell(self.format, group=self.groups[benchmark])
        return self.T[idx]

    def get_value(self, benchmark, method) -> float:
        """Return the mean of the cell of (`benchmark`, `method`)."""
        return self.get(benchmark, method).mean()

    def get_benchmark(self, benchmark):
        """Return the non-empty cells of `benchmark`, in method order."""
        cells = [self.get(benchmark, method=m) for m in self.get_methods()]
        return [c for c in cells if not c.isEmpty()]

    def get_method(self, method):
        """Return the non-empty cells of `method`, in benchmark order."""
        cells = [self.get(benchmark=b, method=method) for b in self.get_benchmarks()]
        return [c for c in cells if not c.isEmpty()]

    def get_method_means(self, method_order):
        """Return one cell per method holding its per-benchmark means.

        The returned cells share a fresh group, so they are ranked and colored
        against each other. Empty cells and NaN means are skipped.
        """
        mean_group = self._new_group()
        cells = []
        for method in method_order:
            method_mean = Cell(self.format, group=mean_group)
            for bench in self.get_benchmarks():
                cell = self.get(benchmark=bench, method=method)
                if cell.isEmpty():
                    # np.mean of an empty list is NaN (with a RuntimeWarning); skip it directly
                    continue
                mean_value = cell.mean()
                if not np.isnan(mean_value):
                    method_mean.append(mean_value)
            cells.append(method_mean)
        return cells

    def get_benchmark_values(self, benchmark):
        """Return the means of the non-empty cells of `benchmark` as an array."""
        return np.asarray([c.mean() for c in self.get_benchmark(benchmark)])

    def get_method_values(self, method):
        """Return the means of the non-empty cells of `method` as an array."""
        return np.asarray([c.mean() for c in self.get_method(method)])

    def all_mean(self):
        """Return the mean of the means of all non-empty cells."""
        values = [c.mean() for c in self.T.values() if not c.isEmpty()]
        return np.mean(values)

    def print(self):
        """Print the table of means to stdout using a pandas DataFrame."""
        # todo: missing method names?
        data_dict = {}
        data_dict['Benchmark'] = [b for b in self.get_benchmarks()]
        for method in self.get_methods():
            data_dict[method] = [self.get(bench, method).print_mean() for bench in self.get_benchmarks()]
        df = pd.DataFrame(data_dict)
        pd.set_option('display.max_columns', None)
        pd.set_option('display.max_rows', None)
        print(df.to_string(index=False))

    def tabular(self, path=None, benchmark_replace=None, method_replace=None, benchmark_order=None,
                method_order=None, transpose=False):
        """Return the LaTeX tabular environment, optionally writing it to `path`.

        :param path: file to write the tabular to (parent dirs are created); None to skip
        :param benchmark_replace: dict mapping benchmark names to display names
        :param method_replace: dict mapping method names to display names
        :param benchmark_order: benchmarks to show, in order (default: all)
        :param method_order: methods to show, in order (default: all)
        :param transpose: if True, methods are rows and benchmarks are columns
        :return: the tabular as a string
        """
        if benchmark_replace is None:
            benchmark_replace = {}
        if method_replace is None:
            method_replace = {}
        if benchmark_order is None:
            benchmark_order = self.get_benchmarks()
        if method_order is None:
            method_order = self.get_methods()

        if transpose:
            row_order, row_replace = method_order, method_replace
            col_order, col_replace = benchmark_order, benchmark_replace
        else:
            row_order, row_replace = benchmark_order, benchmark_replace
            col_order, col_replace = method_order, method_replace

        n_cols = len(col_order)
        # the averages are per method: a column when methods are rows, a row otherwise
        add_mean_col = self.with_mean and transpose
        add_mean_row = self.with_mean and not transpose
        last_col_idx = n_cols + 2 if add_mean_col else n_cols + 1

        if self.with_mean:
            mean_cells = self.get_method_means(method_order)

        lines = []
        lines.append('\\begin{tabular}{|c' + '|c' * n_cols + ('||c' if add_mean_col else '') + "|}")

        lines.append(f'\\cline{{2-{last_col_idx}}}')
        line = '\\multicolumn{1}{c|}{} & '
        line += ' & '.join([col_replace.get(col, col) for col in col_order])
        if add_mean_col:
            line += ' & Ave.'
        line += ' \\\\\\hline'
        lines.append(line)

        for i, row in enumerate(row_order):
            rowname = row_replace.get(row, row)
            line = rowname + ' & '
            line += ' & '.join([
                self.get(benchmark=col if transpose else row, method=row if transpose else col).print()
                for col in col_order
            ])
            if add_mean_col:
                line += ' & ' + mean_cells[i].print()
            line += ' \\\\\\hline'
            lines.append(line)

        if add_mean_row:
            lines.append('\\hline')
            line = 'Ave. & '
            line += ' & '.join([mean_cell.print() for mean_cell in mean_cells])
            line += ' \\\\\\hline'
            lines.append(line)

        lines.append('\\end{tabular}')

        tabular_tex = '\n'.join(lines)

        if path is not None:
            parent = Path(path).parent
            if parent:
                os.makedirs(parent, exist_ok=True)
            with open(path, 'wt') as foo:
                foo.write(tabular_tex)

        return tabular_tex

    def table(self, tabular_path, benchmark_replace=None, method_replace=None, resizebox=True, caption=None,
              label=None, benchmark_order=None, method_order=None, transpose=False, tabular_dir='tables'):
        """Return a LaTeX table environment wrapping the tabular.

        :param tabular_path: file to write the tabular to; the table then
            includes it via an input command. If None, the tabular is inlined.
        :param benchmark_replace: see `tabular`
        :param method_replace: see `tabular`
        :param resizebox: whether to wrap the tabular in a resizebox of text width
        :param caption: caption text; defaults to `tabular_path` (or the table
            name when `tabular_path` is None) with underscores escaped
        :param label: optional label
        :param benchmark_order: see `tabular`
        :param method_order: see `tabular`
        :param transpose: see `tabular`
        :param tabular_dir: directory, as written in the input command, where
            the tabular file lives relative to the main document
        :return: the table as a string
        """
        if benchmark_replace is None:
            benchmark_replace = {}
        if method_replace is None:
            method_replace = {}

        lines = []
        lines.append('\\begin{table}')
        lines.append('\\center')
        if resizebox:
            lines.append('\\resizebox{\\textwidth}{!}{%')

        tabular_str = self.tabular(tabular_path, benchmark_replace, method_replace, benchmark_order, method_order,
                                   transpose)
        if tabular_path is None:
            lines.append(tabular_str)
        else:
            input_path = (Path(tabular_dir) / Path(tabular_path).name).as_posix()
            lines.append(f'\\input{{{input_path}}}')

        if resizebox:
            lines.append('}%')
        if caption is None:
            # there is no path to derive a caption from when the tabular is inlined
            caption_source = str(self.name) if tabular_path is None else str(tabular_path)
            caption = caption_source.replace('_', '\\_')
        lines.append(f'\\caption{{{caption}}}')
        if label is not None:
            lines.append(f'\\label{{{label}}}')
        lines.append('\\end{table}')

        table_tex = '\n'.join(lines)

        return table_tex

    def document(self, tex_path, tabular_dir='tables', *args, **kwargs):
        """Write a LaTeX document containing only this table; see `Document`."""
        return Table.Document(tex_path, [self], tabular_dir, *args, **kwargs)

    def latexPDF(self, pdf_path, tabular_dir='tables', *args, **kwargs):
        """Compile a PDF containing only this table; see `LatexPDF`."""
        return Table.LatexPDF(pdf_path, tables=[self], tabular_dir=tabular_dir, *args, **kwargs)

    @classmethod
    def Document(cls, tex_path, tables: List['Table'], tabular_dir='tables', landscape=True, *args, **kwargs):
        """Write a LaTeX document with one table per page and return its source.

        Each tabular is written to `<tex dir>/<tabular_dir>/<table name>_table.tex`.

        :param tex_path: path of the .tex file to write (parent dirs are created)
        :param tables: the tables to include
        :param tabular_dir: directory (relative to the .tex file) for the tabular files
        :param landscape: whether to use a landscape page geometry
        :param args: extra positional arguments forwarded to `Table.table`
        :param kwargs: extra keyword arguments forwarded to `Table.table`
        :return: the document source as a string
        """
        lines = []
        lines.append('\\documentclass[10pt,a4paper]{article}')
        lines.append('\\usepackage[utf8]{inputenc}')
        lines.append('\\usepackage{amsmath}')
        lines.append('\\usepackage{amsfonts}')
        lines.append('\\usepackage{amssymb}')
        lines.append('\\usepackage{graphicx}')
        lines.append('\\usepackage{xcolor}')
        lines.append('\\usepackage{colortbl}')
        if landscape:
            lines.append('\\usepackage[landscape]{geometry}')
        lines.append('')
        lines.append('\\begin{document}')
        for table in tables:
            lines.append('')
            tabular_path = os.path.join(Path(tex_path).parent, tabular_dir, table.name + '_table.tex')
            lines.append(table.table(tabular_path, *args, tabular_dir=tabular_dir, **kwargs))
            lines.append('\n\\newpage\n')
        lines.append('\\end{document}')

        document = '\n'.join(lines)

        parent = Path(tex_path).parent
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(tex_path, 'wt') as foo:
            foo.write(document)

        return document

    @classmethod
    def LatexPDF(cls, pdf_path: str, tables: List['Table'], tabular_dir: str = 'tables', *args, **kwargs):
        """Write the LaTeX document next to `pdf_path` and compile it with pdflatex.

        The .aux and .log files left by pdflatex are removed afterwards.

        :param pdf_path: path of the PDF to generate; must end with '.pdf'
        :param tables: the tables to include
        :param tabular_dir: directory (relative to the document) for the tabular files
        :param args: extra positional arguments forwarded to `Document`
        :param kwargs: extra keyword arguments forwarded to `Document`
        """
        assert pdf_path.endswith('.pdf'), f'{pdf_path=} does not seem a valid name for a pdf file'
        # swap only the extension; str.replace would also hit a '.pdf' elsewhere in the path
        tex_path = pdf_path[:-len('.pdf')] + '.tex'
        cls.Document(tex_path, tables, tabular_dir, *args, **kwargs)

        workdir = Path(pdf_path).parent
        tex_name = Path(tex_path).name
        print('currently in', os.getcwd())
        print("[Tables Done] runing latex")
        # Run in the document's directory via cwd= (no process-wide chdir to
        # restore) and pass the arguments as a list (no shell quoting issues).
        try:
            subprocess.run(['pdflatex', tex_name], cwd=workdir, check=False)
        except FileNotFoundError:
            print('warning: pdflatex was not found; the pdf could not be generated')
        basename = Path(tex_path).stem
        for extension in ('.aux', '.log'):
            (workdir / (basename + extension)).unlink(missing_ok=True)
        print('[Done]')