From 4c6a0342ffa83bb783b71c3a11f3d563823ea31c Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Thu, 7 Mar 2024 19:33:32 +0100 Subject: [PATCH] tesi updated --- TODO.md | 3 +++ quacc/evaluation/estimators.py | 9 ++++--- quacc/evaluation/method.py | 2 +- quacc/evaluation/report.py | 20 ++++++++++++++ quacc/plot/base.py | 1 + quacc/plot/plot.py | 2 ++ quacc/plot/plotly.py | 48 +++++++++++++++++++++++++--------- 7 files changed, 68 insertions(+), 17 deletions(-) diff --git a/TODO.md b/TODO.md index 154ff00..d559f56 100644 --- a/TODO.md +++ b/TODO.md @@ -61,3 +61,6 @@ multiclass: - [x] aggiungere supporto a multiclass in dataset.py - [x] aggiungere group_false in ExtensionPolicy - [ ] modificare BQAE in modo che i quantifier si adattino alla casistica(binary/multi in base a group_false) + +fix: +- [ ] make quantifiers predict 0 prevalence for classes for which we have 0 samples diff --git a/quacc/evaluation/estimators.py b/quacc/evaluation/estimators.py index d6dd81e..476584b 100644 --- a/quacc/evaluation/estimators.py +++ b/quacc/evaluation/estimators.py @@ -88,7 +88,8 @@ _renames = { "d_bin_sld_rbf": "(2x2)_SLD_RBF", "d_mul_sld_rbf": "(1x4)_SLD_RBF", "d_m3w_sld_rbf": "(1x3)_SLD_RBF", - "sld_lr_gs": "MS_SLD_LR", + # "sld_lr_gs": "MS_SLD_LR", + "sld_lr_gs": "QuAcc(SLD)", "bin_kde_lr": "(2x2)_KDEy_LR", "mul_kde_lr": "(1x4)_KDEy_LR", "m3w_kde_lr": "(1x3)_KDEy_LR", @@ -98,8 +99,10 @@ _renames = { "bin_cc_lr": "(2x2)_CC_LR", "mul_cc_lr": "(1x4)_CC_LR", "m3w_cc_lr": "(1x3)_CC_LR", - "kde_lr_gs": "MS_KDEy_LR", - "cc_lr_gs": "MS_CC_LR", + # "kde_lr_gs": "MS_KDEy_LR", + "kde_lr_gs": "QuAcc(KDEy)", + # "cc_lr_gs": "MS_CC_LR", + "cc_lr_gs": "QuAcc(CC)", "atc_mc": "ATC", "doc": "DoC", "mandoline": "Mandoline", diff --git a/quacc/evaluation/method.py b/quacc/evaluation/method.py index 79326be..3b0af41 100644 --- a/quacc/evaluation/method.py +++ b/quacc/evaluation/method.py @@ -493,9 +493,9 @@ __cc_lr_set = [ ] __ms_set = [ + E("cc_lr_gs"), E("sld_lr_gs"), E("kde_lr_gs"), - E("cc_lr_gs"), E("QuAcc"), ] diff --git a/quacc/evaluation/report.py b/quacc/evaluation/report.py index f078296..a6ba71b 100644 --- a/quacc/evaluation/report.py +++ b/quacc/evaluation/report.py @@ -447,6 +447,7 @@ class DatasetReport: "delta_test", "stdev_test", "test_table", + "diagonal", "stats_table", "fit_scores", ] @@ -745,6 +746,25 @@ class DatasetReport: base_path=base_path, backend=backend, ) + elif mode == "diagonal": + f_data = self.data(metric=metric + "_score", estimators=estimators) + if f_data.empty: + return None + + ref: pd.Series = f_data.loc[:, "ref"] + f_data.drop(columns=["ref"], inplace=True) + return plot.plot_diagonal( + reference=ref.to_numpy(), + columns=f_data.columns.to_numpy(), + data=f_data.T.to_numpy(), + metric=metric, + name=conf, + # train_prev=self.train_prev, + fixed_lim=True, + save_fig=save_fig, + base_path=base_path, + backend=backend, + ) def to_md( self, diff --git a/quacc/plot/base.py b/quacc/plot/base.py index 04e9e8c..a44b219 100644 --- a/quacc/plot/base.py +++ b/quacc/plot/base.py @@ -17,6 +17,7 @@ class BasePlot: title="default", x_label="true", y_label="estim.", + fixed_lim=False, legend=True, ): ... diff --git a/quacc/plot/plot.py b/quacc/plot/plot.py index fa7c082..3eebf1e 100644 --- a/quacc/plot/plot.py +++ b/quacc/plot/plot.py @@ -77,6 +77,7 @@ def plot_diagonal( metric="acc", name="default", train_prev=None, + fixed_lim=False, legend=True, save_fig=False, base_path=None, @@ -103,6 +104,7 @@ def plot_diagonal( title=title, x_label=x_label, y_label=y_label, + fixed_lim=fixed_lim, legend=legend, ) diff --git a/quacc/plot/plotly.py b/quacc/plot/plotly.py index 900be13..52a514d 100644 --- a/quacc/plot/plotly.py +++ b/quacc/plot/plotly.py @@ -184,14 +184,21 @@ class PlotlyPlot(BasePlot): title="default", x_label="true", y_label="estim.", + fixed_lim=False, legend=True, ) -> go.Figure: fig = go.Figure() x = reference line_colors = self.get_colors(len(columns)) - _edges = (np.min([np.min(x), np.min(data)]), np.max([np.max(x), np.max(data)])) - _lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]]) + if fixed_lim: + _lims = np.array([[0.0, 1.0], [0.0, 1.0]]) + else: + _edges = ( + np.min([np.min(x), np.min(data)]), + np.max([np.max(x), np.max(data)]), + ) + _lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]]) named_data = {c: d for c, d in zip(columns, data)} r_columns = {c: r for c, r in zip(columns, self.rename_plots(columns))} @@ -201,7 +208,7 @@ class PlotlyPlot(BasePlot): r_name = r_columns[name] color = next(line_colors) slope, interc = np.polyfit(x, val, 1) - y_lr = np.array([slope * _x + interc for _x in _lims[0]]) + # y_lr = np.array([slope * _x + interc for _x in _lims[0]]) fig.add_traces( [ go.Scatter( @@ -210,17 +217,25 @@ class PlotlyPlot(BasePlot): customdata=np.stack((val - x,), axis=-1), mode="markers", name=r_name, - line=dict(color=self.hex_to_rgb(color, t=0.5)), + marker=dict(color=self.hex_to_rgb(color, t=0.5)), hovertemplate="true acc: %{x:,.4f}
estim. acc: %{y:,.4f}
acc err.: %{customdata[0]:,.4f}", + # showlegend=False, ), - go.Scatter( - x=_lims[0], - y=y_lr, - mode="lines", - name=name, - line=dict(color=self.hex_to_rgb(color), width=3), - showlegend=False, - ), + # go.Scatter( + # x=[x[-1]], + # y=[val[-1]], + # mode="markers", + # marker=dict(color=self.hex_to_rgb(color), size=8), + # name=r_name, + # ), + # go.Scatter( + # x=_lims[0], + # y=y_lr, + # mode="lines", + # name=name, + # line=dict(color=self.hex_to_rgb(color), width=3), + # showlegend=False, + # ), ] ) fig.add_trace( @@ -235,7 +250,14 @@ class PlotlyPlot(BasePlot): ) self.update_layout(fig, title, x_label, y_label) - fig.update_layout(yaxis_scaleanchor="x", yaxis_scaleratio=1.0) + fig.update_layout( + autosize=False, + width=1300, + height=1000, + yaxis_scaleanchor="x", + yaxis_scaleratio=1.0, + yaxis_range=[-0.1, 1.1], + ) return fig def plot_shift(