Wilcoxon statistic added (replaces the paired t-test in the stats table)

This commit is contained in:
Lorenzo Volpi 2023-11-28 09:09:17 +01:00
parent dddf8746e2
commit 2eaa5debd1
3 changed files with 30 additions and 22 deletions

View File

@@ -6,7 +6,7 @@ import panel as pn
from quacc.evaluation.estimators import CE from quacc.evaluation.estimators import CE
from quacc.evaluation.report import CompReport, DatasetReport from quacc.evaluation.report import CompReport, DatasetReport
from quacc.evaluation.stats import ttest_rel from quacc.evaluation.stats import wilcoxon
_plot_sizing_mode = "stretch_both" _plot_sizing_mode = "stretch_both"
valid_plot_modes = defaultdict(lambda: CompReport._default_modes) valid_plot_modes = defaultdict(lambda: CompReport._default_modes)
@@ -44,7 +44,7 @@ def create_plots(
) )
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
case ("avg", "stats_table"): case ("avg", "stats_table"):
_data = ttest_rel(dr, metric=metric, estimators=estimators) _data = wilcoxon(dr, metric=metric, estimators=estimators)
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
case ("avg", _ as plot_mode): case ("avg", _ as plot_mode):
_plot = dr.get_plots( _plot = dr.get_plots(
@@ -80,6 +80,10 @@ def create_plots(
.mean() .mean()
) )
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
case (_, "stats_table"):
cr = dr.crs[_prevs.index(int(plot_view))]
_data = wilcoxon(cr, metric=metric, estimators=estimators)
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
case (_, _ as plot_mode): case (_, _ as plot_mode):
cr = dr.crs[_prevs.index(int(plot_view))] cr = dr.crs[_prevs.index(int(plot_view))]
_plot = cr.get_plots( _plot = cr.get_plots(

View File

@@ -66,6 +66,7 @@ class CompReport:
"shift", "shift",
"shift_table", "shift_table",
"diagonal", "diagonal",
"stats_table",
] ]
def __init__( def __init__(

View File

@@ -4,34 +4,37 @@ import numpy as np
import pandas as pd import pandas as pd
from scipy import stats as sp_stats from scipy import stats as sp_stats
from quacc.evaluation.estimators import CE # from quacc.evaluation.estimators import CE
from quacc.evaluation.report import DatasetReport from quacc.evaluation.report import CompReport, DatasetReport
def ttest_rel( def shapiro(
dr: DatasetReport, metric: str = None, estimators: List[str] = None r: DatasetReport | CompReport, metric: str = None, estimators: List[str] = None
) -> pd.DataFrame: ) -> pd.DataFrame:
_data = dr.data(metric, estimators) _data = r.data(metric, estimators)
shapiro_data = np.array( shapiro_data = np.array(
[sp_stats.shapiro(_data.loc[:, e]) for e in _data.columns.unique(0)] [sp_stats.shapiro(_data.loc[:, e]) for e in _data.columns.unique(0)]
).T ).T
dr_index = ["shapiro_W", "shapiro_p"]
dr_columns = _data.columns.unique(0)
return pd.DataFrame(shapiro_data, columns=dr_columns, index=dr_index)
_ttest_rel = {}
for bs in np.intersect1d(CE.name.baselines, _data.columns.unique(0)): def wilcoxon(
_ttest_rel[f"ttr_{bs}"] = [ r: DatasetReport | CompReport, metric: str = None, estimators: List[str] = None
sp_stats.ttest_rel(_data.loc[:, bs], _data.loc[:, e]).statistic ) -> pd.DataFrame:
if e not in CE.name.baselines _data = r.data(metric, estimators)
else np.nan
_wilcoxon = {}
for est in _data.columns.unique(0):
_wilcoxon[est] = [
sp_stats.wilcoxon(_data.loc[:, est], _data.loc[:, e]).pvalue
if e != est
else 1.0
for e in _data.columns.unique(0) for e in _data.columns.unique(0)
] ]
ttr_data = np.array(list(_ttest_rel.values())) wilcoxon_data = np.array(list(_wilcoxon.values()))
dr_index = ["shapiro_W", "shapiro_p"] + list(_ttest_rel.keys()) dr_index = list(_wilcoxon.keys())
dr_columns = _data.columns.unique(0) dr_columns = _data.columns.unique(0)
dr_data = ( return pd.DataFrame(wilcoxon_data, columns=dr_columns, index=dr_index)
np.concatenate([shapiro_data, ttr_data], axis=0)
if ttr_data.shape[0] > 0
else shapiro_data
)
return pd.DataFrame(dr_data, columns=dr_columns, index=dr_index)