qcdash updated for plotting refactoring

This commit is contained in:
Lorenzo Volpi 2024-04-05 15:49:57 +02:00
parent 558c3231a3
commit 322f060a13
1 changed files with 21 additions and 13 deletions

View File

@ -76,57 +76,65 @@ def get_datasets(root: str | Path) -> List[DatasetReport]:
return {str(drp.parent): load_dataset(drp) for drp in dr_paths} return {str(drp.parent): load_dataset(drp) for drp in dr_paths}
def get_fig(rep: Report, dataset, metric, estimators, view, mode): def get_fig(rep: Report, cls_name, acc_name, dataset_name, estimators, view, mode):
match (view, mode): match (view, mode):
case ("avg", "diagonal"): case ("avg", "diagonal"):
true_accs, estim_accs = rep.diagonal_plot_data( true_accs, estim_accs = rep.diagonal_plot_data(
dataset_name=dataset, dataset_name=dataset_name,
method_names=estimators, method_names=estimators,
acc_name=metric, acc_name=acc_name,
) )
return plot_diagonal( return plot_diagonal(
method_names=estimators, method_names=estimators,
true_accs=true_accs, true_accs=true_accs,
estim_accs=estim_accs, estim_accs=estim_accs,
measure_name=metric, cls_name=cls_name,
acc_name=acc_name,
dataset_name=dataset_name,
) )
case ("avg", "delta_train"): case ("avg", "delta_train"):
prevs, acc_errs = rep.delta_train_plot_data( prevs, acc_errs = rep.delta_train_plot_data(
dataset_name=dataset, dataset_name=dataset_name,
method_names=estimators, method_names=estimators,
acc_name=metric, acc_name=acc_name,
) )
return plot_delta( return plot_delta(
method_names=estimators, method_names=estimators,
prevs=prevs, prevs=prevs,
acc_errs=acc_errs, acc_errs=acc_errs,
measure_name=metric, cls_name=cls_name,
acc_mame=acc_name,
dataset_name=dataset_name,
prev_name="Test", prev_name="Test",
) )
case ("avg", "stdev_train"): case ("avg", "stdev_train"):
prevs, acc_errs, stdevs = rep.delta_train_plot_data( prevs, acc_errs, stdevs = rep.delta_train_plot_data(
dataset_name=dataset, dataset_name=dataset_name,
method_names=estimators, method_names=estimators,
acc_name=metric, acc_name=acc_name,
stdev=True, stdev=True,
) )
return plot_delta( return plot_delta(
method_names=estimators, method_names=estimators,
prevs=prevs, prevs=prevs,
acc_errs=acc_errs, acc_errs=acc_errs,
measure_name=metric, cls_name=cls_name,
acc_mame=acc_name,
dataset_name=dataset_name,
prev_name="Test", prev_name="Test",
stdevs=stdevs, stdevs=stdevs,
) )
case ("avg", "shift"): case ("avg", "shift"):
prevs, acc_errs, counts = rep.shift_plot_data( prevs, acc_errs, counts = rep.shift_plot_data(
dataset_name=dataset, method_names=estimators, acc_name=metric dataset_name=dataset_name, method_names=estimators, acc_name=acc_name
) )
return plot_shift( return plot_shift(
method_names=estimators, method_names=estimators,
prevs=prevs, prevs=prevs,
acc_errs=acc_errs, acc_errs=acc_errs,
measure_name=metric, cls_name=cls_name,
acc_name=acc_name,
dataset_name=dataset_name,
counts=counts, counts=counts,
) )
case (_, _): case (_, _):
@ -543,7 +551,7 @@ def update_content(dataset, metric, estimators, view, mode, root):
case _: case _:
fig = get_fig( fig = get_fig(
dr=dr, dr=dr,
metric=metric, acc_name=metric,
estimators=estimators, estimators=estimators,
view=view, view=view,
mode=mode, mode=mode,