diff --git a/qcpanel/util.py b/qcpanel/util.py index 3820d71..6d1473d 100644 --- a/qcpanel/util.py +++ b/qcpanel/util.py @@ -6,7 +6,7 @@ import panel as pn from quacc.evaluation.estimators import CE from quacc.evaluation.report import CompReport, DatasetReport -from quacc.evaluation.stats import ttest_rel +from quacc.evaluation.stats import wilcoxon _plot_sizing_mode = "stretch_both" valid_plot_modes = defaultdict(lambda: CompReport._default_modes) @@ -44,7 +44,7 @@ def create_plots( ) return pn.pane.DataFrame(_data, align="center") if not _data.empty else None case ("avg", "stats_table"): - _data = ttest_rel(dr, metric=metric, estimators=estimators) + _data = wilcoxon(dr, metric=metric, estimators=estimators) return pn.pane.DataFrame(_data, align="center") if not _data.empty else None case ("avg", _ as plot_mode): _plot = dr.get_plots( @@ -80,6 +80,10 @@ def create_plots( .mean() ) return pn.pane.DataFrame(_data, align="center") if not _data.empty else None + case (_, "stats_table"): + cr = dr.crs[_prevs.index(int(plot_view))] + _data = wilcoxon(cr, metric=metric, estimators=estimators) + return pn.pane.DataFrame(_data, align="center") if not _data.empty else None case (_, _ as plot_mode): cr = dr.crs[_prevs.index(int(plot_view))] _plot = cr.get_plots( diff --git a/quacc/evaluation/report.py b/quacc/evaluation/report.py index e636980..e61e709 100644 --- a/quacc/evaluation/report.py +++ b/quacc/evaluation/report.py @@ -66,6 +66,7 @@ class CompReport: "shift", "shift_table", "diagonal", + "stats_table", ] def __init__( diff --git a/quacc/evaluation/stats.py b/quacc/evaluation/stats.py index d0fc3a3..1bf12d6 100644 --- a/quacc/evaluation/stats.py +++ b/quacc/evaluation/stats.py @@ -4,34 +4,37 @@ import numpy as np import pandas as pd from scipy import stats as sp_stats -from quacc.evaluation.estimators import CE -from quacc.evaluation.report import DatasetReport +# from quacc.evaluation.estimators import CE +from quacc.evaluation.report import CompReport, DatasetReport -def ttest_rel( - dr: DatasetReport, metric: str = None, estimators: List[str] = None +def shapiro( + r: DatasetReport | CompReport, metric: str = None, estimators: List[str] = None ) -> pd.DataFrame: - _data = dr.data(metric, estimators) - + _data = r.data(metric, estimators) shapiro_data = np.array( [sp_stats.shapiro(_data.loc[:, e]) for e in _data.columns.unique(0)] ).T + dr_index = ["shapiro_W", "shapiro_p"] + dr_columns = _data.columns.unique(0) + return pd.DataFrame(shapiro_data, columns=dr_columns, index=dr_index) - _ttest_rel = {} - for bs in np.intersect1d(CE.name.baselines, _data.columns.unique(0)): - _ttest_rel[f"ttr_{bs}"] = [ - sp_stats.ttest_rel(_data.loc[:, bs], _data.loc[:, e]).statistic - if e not in CE.name.baselines - else np.nan + +def wilcoxon( + r: DatasetReport | CompReport, metric: str = None, estimators: List[str] = None +) -> pd.DataFrame: + _data = r.data(metric, estimators) + + _wilcoxon = {} + for est in _data.columns.unique(0): + _wilcoxon[est] = [ + sp_stats.wilcoxon(_data.loc[:, est], _data.loc[:, e]).pvalue + if e != est + else 1.0 for e in _data.columns.unique(0) ] - ttr_data = np.array(list(_ttest_rel.values())) + wilcoxon_data = np.array(list(_wilcoxon.values())) - dr_index = ["shapiro_W", "shapiro_p"] + list(_ttest_rel.keys()) + dr_index = list(_wilcoxon.keys()) dr_columns = _data.columns.unique(0) - dr_data = ( - np.concatenate([shapiro_data, ttr_data], axis=0) - if ttr_data.shape[0] > 0 - else shapiro_data - ) - return pd.DataFrame(dr_data, columns=dr_columns, index=dr_index) + return pd.DataFrame(wilcoxon_data, columns=dr_columns, index=dr_index)