qcdash updated for plotting refactoring
This commit is contained in:
parent
558c3231a3
commit
322f060a13
|
@ -76,57 +76,65 @@ def get_datasets(root: str | Path) -> List[DatasetReport]:
|
||||||
return {str(drp.parent): load_dataset(drp) for drp in dr_paths}
|
return {str(drp.parent): load_dataset(drp) for drp in dr_paths}
|
||||||
|
|
||||||
|
|
||||||
def get_fig(rep: Report, dataset, metric, estimators, view, mode):
|
def get_fig(rep: Report, cls_name, acc_name, dataset_name, estimators, view, mode):
|
||||||
match (view, mode):
|
match (view, mode):
|
||||||
case ("avg", "diagonal"):
|
case ("avg", "diagonal"):
|
||||||
true_accs, estim_accs = rep.diagonal_plot_data(
|
true_accs, estim_accs = rep.diagonal_plot_data(
|
||||||
dataset_name=dataset,
|
dataset_name=dataset_name,
|
||||||
method_names=estimators,
|
method_names=estimators,
|
||||||
acc_name=metric,
|
acc_name=acc_name,
|
||||||
)
|
)
|
||||||
return plot_diagonal(
|
return plot_diagonal(
|
||||||
method_names=estimators,
|
method_names=estimators,
|
||||||
true_accs=true_accs,
|
true_accs=true_accs,
|
||||||
estim_accs=estim_accs,
|
estim_accs=estim_accs,
|
||||||
measure_name=metric,
|
cls_name=cls_name,
|
||||||
|
acc_name=acc_name,
|
||||||
|
dataset_name=dataset_name,
|
||||||
)
|
)
|
||||||
case ("avg", "delta_train"):
|
case ("avg", "delta_train"):
|
||||||
prevs, acc_errs = rep.delta_train_plot_data(
|
prevs, acc_errs = rep.delta_train_plot_data(
|
||||||
dataset_name=dataset,
|
dataset_name=dataset_name,
|
||||||
method_names=estimators,
|
method_names=estimators,
|
||||||
acc_name=metric,
|
acc_name=acc_name,
|
||||||
)
|
)
|
||||||
return plot_delta(
|
return plot_delta(
|
||||||
method_names=estimators,
|
method_names=estimators,
|
||||||
prevs=prevs,
|
prevs=prevs,
|
||||||
acc_errs=acc_errs,
|
acc_errs=acc_errs,
|
||||||
measure_name=metric,
|
cls_name=cls_name,
|
||||||
|
acc_mame=acc_name,
|
||||||
|
dataset_name=dataset_name,
|
||||||
prev_name="Test",
|
prev_name="Test",
|
||||||
)
|
)
|
||||||
case ("avg", "stdev_train"):
|
case ("avg", "stdev_train"):
|
||||||
prevs, acc_errs, stdevs = rep.delta_train_plot_data(
|
prevs, acc_errs, stdevs = rep.delta_train_plot_data(
|
||||||
dataset_name=dataset,
|
dataset_name=dataset_name,
|
||||||
method_names=estimators,
|
method_names=estimators,
|
||||||
acc_name=metric,
|
acc_name=acc_name,
|
||||||
stdev=True,
|
stdev=True,
|
||||||
)
|
)
|
||||||
return plot_delta(
|
return plot_delta(
|
||||||
method_names=estimators,
|
method_names=estimators,
|
||||||
prevs=prevs,
|
prevs=prevs,
|
||||||
acc_errs=acc_errs,
|
acc_errs=acc_errs,
|
||||||
measure_name=metric,
|
cls_name=cls_name,
|
||||||
|
acc_mame=acc_name,
|
||||||
|
dataset_name=dataset_name,
|
||||||
prev_name="Test",
|
prev_name="Test",
|
||||||
stdevs=stdevs,
|
stdevs=stdevs,
|
||||||
)
|
)
|
||||||
case ("avg", "shift"):
|
case ("avg", "shift"):
|
||||||
prevs, acc_errs, counts = rep.shift_plot_data(
|
prevs, acc_errs, counts = rep.shift_plot_data(
|
||||||
dataset_name=dataset, method_names=estimators, acc_name=metric
|
dataset_name=dataset_name, method_names=estimators, acc_name=acc_name
|
||||||
)
|
)
|
||||||
return plot_shift(
|
return plot_shift(
|
||||||
method_names=estimators,
|
method_names=estimators,
|
||||||
prevs=prevs,
|
prevs=prevs,
|
||||||
acc_errs=acc_errs,
|
acc_errs=acc_errs,
|
||||||
measure_name=metric,
|
cls_name=cls_name,
|
||||||
|
acc_name=acc_name,
|
||||||
|
dataset_name=dataset_name,
|
||||||
counts=counts,
|
counts=counts,
|
||||||
)
|
)
|
||||||
case (_, _):
|
case (_, _):
|
||||||
|
@ -543,7 +551,7 @@ def update_content(dataset, metric, estimators, view, mode, root):
|
||||||
case _:
|
case _:
|
||||||
fig = get_fig(
|
fig = get_fig(
|
||||||
dr=dr,
|
dr=dr,
|
||||||
metric=metric,
|
acc_name=metric,
|
||||||
estimators=estimators,
|
estimators=estimators,
|
||||||
view=view,
|
view=view,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
|
|
Loading…
Reference in New Issue