diff --git a/TODO.md b/TODO.md
index 154ff00..d559f56 100644
--- a/TODO.md
+++ b/TODO.md
@@ -61,3 +61,6 @@ multiclass:
- [x] aggiungere supporto a multiclass in dataset.py
- [x] aggiungere group_false in ExtensionPolicy
- [ ] modificare BQAE in modo che i quantifier si adattino alla casistica(binary/multi in base a group_false)
+
+fix:
+- [ ] make quantifiers predict 0 prevalence for classes for which we have 0 samples
diff --git a/quacc/evaluation/estimators.py b/quacc/evaluation/estimators.py
index d6dd81e..476584b 100644
--- a/quacc/evaluation/estimators.py
+++ b/quacc/evaluation/estimators.py
@@ -88,7 +88,8 @@ _renames = {
"d_bin_sld_rbf": "(2x2)_SLD_RBF",
"d_mul_sld_rbf": "(1x4)_SLD_RBF",
"d_m3w_sld_rbf": "(1x3)_SLD_RBF",
- "sld_lr_gs": "MS_SLD_LR",
+ # "sld_lr_gs": "MS_SLD_LR",
+ "sld_lr_gs": "QuAcc(SLD)",
"bin_kde_lr": "(2x2)_KDEy_LR",
"mul_kde_lr": "(1x4)_KDEy_LR",
"m3w_kde_lr": "(1x3)_KDEy_LR",
@@ -98,8 +99,10 @@ _renames = {
"bin_cc_lr": "(2x2)_CC_LR",
"mul_cc_lr": "(1x4)_CC_LR",
"m3w_cc_lr": "(1x3)_CC_LR",
- "kde_lr_gs": "MS_KDEy_LR",
- "cc_lr_gs": "MS_CC_LR",
+ # "kde_lr_gs": "MS_KDEy_LR",
+ "kde_lr_gs": "QuAcc(KDEy)",
+ # "cc_lr_gs": "MS_CC_LR",
+ "cc_lr_gs": "QuAcc(CC)",
"atc_mc": "ATC",
"doc": "DoC",
"mandoline": "Mandoline",
diff --git a/quacc/evaluation/method.py b/quacc/evaluation/method.py
index 79326be..3b0af41 100644
--- a/quacc/evaluation/method.py
+++ b/quacc/evaluation/method.py
@@ -493,9 +493,9 @@ __cc_lr_set = [
]
__ms_set = [
+ E("cc_lr_gs"),
E("sld_lr_gs"),
E("kde_lr_gs"),
- E("cc_lr_gs"),
E("QuAcc"),
]
diff --git a/quacc/evaluation/report.py b/quacc/evaluation/report.py
index f078296..a6ba71b 100644
--- a/quacc/evaluation/report.py
+++ b/quacc/evaluation/report.py
@@ -447,6 +447,7 @@ class DatasetReport:
"delta_test",
"stdev_test",
"test_table",
+ "diagonal",
"stats_table",
"fit_scores",
]
@@ -745,6 +746,25 @@ class DatasetReport:
base_path=base_path,
backend=backend,
)
+ elif mode == "diagonal":
+ f_data = self.data(metric=metric + "_score", estimators=estimators)
+ if f_data.empty:
+ return None
+
+ ref: pd.Series = f_data.loc[:, "ref"]
+ f_data.drop(columns=["ref"], inplace=True)
+ return plot.plot_diagonal(
+ reference=ref.to_numpy(),
+ columns=f_data.columns.to_numpy(),
+ data=f_data.T.to_numpy(),
+ metric=metric,
+ name=conf,
+ # train_prev=self.train_prev,
+ fixed_lim=True,
+ save_fig=save_fig,
+ base_path=base_path,
+ backend=backend,
+ )
def to_md(
self,
diff --git a/quacc/plot/base.py b/quacc/plot/base.py
index 04e9e8c..a44b219 100644
--- a/quacc/plot/base.py
+++ b/quacc/plot/base.py
@@ -17,6 +17,7 @@ class BasePlot:
title="default",
x_label="true",
y_label="estim.",
+ fixed_lim=False,
legend=True,
):
...
diff --git a/quacc/plot/plot.py b/quacc/plot/plot.py
index fa7c082..3eebf1e 100644
--- a/quacc/plot/plot.py
+++ b/quacc/plot/plot.py
@@ -77,6 +77,7 @@ def plot_diagonal(
metric="acc",
name="default",
train_prev=None,
+ fixed_lim=False,
legend=True,
save_fig=False,
base_path=None,
@@ -103,6 +104,7 @@ def plot_diagonal(
title=title,
x_label=x_label,
y_label=y_label,
+ fixed_lim=fixed_lim,
legend=legend,
)
diff --git a/quacc/plot/plotly.py b/quacc/plot/plotly.py
index 900be13..52a514d 100644
--- a/quacc/plot/plotly.py
+++ b/quacc/plot/plotly.py
@@ -184,14 +184,21 @@ class PlotlyPlot(BasePlot):
title="default",
x_label="true",
y_label="estim.",
+ fixed_lim=False,
legend=True,
) -> go.Figure:
fig = go.Figure()
x = reference
line_colors = self.get_colors(len(columns))
- _edges = (np.min([np.min(x), np.min(data)]), np.max([np.max(x), np.max(data)]))
- _lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]])
+ if fixed_lim:
+ _lims = np.array([[0.0, 1.0], [0.0, 1.0]])
+ else:
+ _edges = (
+ np.min([np.min(x), np.min(data)]),
+ np.max([np.max(x), np.max(data)]),
+ )
+ _lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]])
named_data = {c: d for c, d in zip(columns, data)}
r_columns = {c: r for c, r in zip(columns, self.rename_plots(columns))}
@@ -201,7 +208,7 @@ class PlotlyPlot(BasePlot):
r_name = r_columns[name]
color = next(line_colors)
slope, interc = np.polyfit(x, val, 1)
- y_lr = np.array([slope * _x + interc for _x in _lims[0]])
+ # y_lr = np.array([slope * _x + interc for _x in _lims[0]])
fig.add_traces(
[
go.Scatter(
@@ -210,17 +217,25 @@ class PlotlyPlot(BasePlot):
customdata=np.stack((val - x,), axis=-1),
mode="markers",
name=r_name,
- line=dict(color=self.hex_to_rgb(color, t=0.5)),
+ marker=dict(color=self.hex_to_rgb(color, t=0.5)),
hovertemplate="true acc: %{x:,.4f}
estim. acc: %{y:,.4f}
acc err.: %{customdata[0]:,.4f}",
+ # showlegend=False,
),
- go.Scatter(
- x=_lims[0],
- y=y_lr,
- mode="lines",
- name=name,
- line=dict(color=self.hex_to_rgb(color), width=3),
- showlegend=False,
- ),
+ # go.Scatter(
+ # x=[x[-1]],
+ # y=[val[-1]],
+ # mode="markers",
+ # marker=dict(color=self.hex_to_rgb(color), size=8),
+ # name=r_name,
+ # ),
+ # go.Scatter(
+ # x=_lims[0],
+ # y=y_lr,
+ # mode="lines",
+ # name=name,
+ # line=dict(color=self.hex_to_rgb(color), width=3),
+ # showlegend=False,
+ # ),
]
)
fig.add_trace(
@@ -235,7 +250,14 @@ class PlotlyPlot(BasePlot):
)
self.update_layout(fig, title, x_label, y_label)
- fig.update_layout(yaxis_scaleanchor="x", yaxis_scaleratio=1.0)
+ fig.update_layout(
+ autosize=False,
+ width=1300,
+ height=1000,
+ yaxis_scaleanchor="x",
+ yaxis_scaleratio=1.0,
+ yaxis_range=[-0.1, 1.1],
+ )
return fig
def plot_shift(