From 5847c217ed36d659611c3e10ae36bfd4691d09f3 Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Thu, 30 Nov 2023 03:10:06 +0100 Subject: [PATCH] added check for empty table on plot generation --- quacc/evaluation/report.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/quacc/evaluation/report.py b/quacc/evaluation/report.py index 70a43d6..67cab60 100644 --- a/quacc/evaluation/report.py +++ b/quacc/evaluation/report.py @@ -436,6 +436,8 @@ class DatasetReport: if mode == "delta_train": _data = self.data(metric, estimators) if data is None else data avg_on_train = _data.groupby(level=1).mean() + if avg_on_train.empty: + return None prevs_on_train = np.sort(avg_on_train.index.unique(0)) return plot.plot_delta( base_prevs=np.around( @@ -454,6 +456,8 @@ class DatasetReport: elif mode == "stdev_train": _data = self.data(metric, estimators) if data is None else data avg_on_train = _data.groupby(level=1).mean() + if avg_on_train.empty: + return None prevs_on_train = np.sort(avg_on_train.index.unique(0)) stdev_on_train = _data.groupby(level=1).std() return plot.plot_delta( @@ -474,6 +478,8 @@ class DatasetReport: elif mode == "delta_test": _data = self.data(metric, estimators) if data is None else data avg_on_test = _data.groupby(level=0).mean() + if avg_on_test.empty: + return None prevs_on_test = np.sort(avg_on_test.index.unique(0)) return plot.plot_delta( base_prevs=np.around([(1.0 - p, p) for p in prevs_on_test], decimals=2), @@ -490,6 +496,8 @@ class DatasetReport: elif mode == "stdev_test": _data = self.data(metric, estimators) if data is None else data avg_on_test = _data.groupby(level=0).mean() + if avg_on_test.empty: + return None prevs_on_test = np.sort(avg_on_test.index.unique(0)) stdev_on_test = _data.groupby(level=0).std() return plot.plot_delta( @@ -508,6 +516,8 @@ class DatasetReport: elif mode == "shift": _shift_data = self.shift_data(metric, estimators) if data is None else data avg_shift = _shift_data.groupby(level=0).mean() + if avg_shift.empty: + return None count_shift = _shift_data.groupby(level=0).count() prevs_shift = np.sort(avg_shift.index.unique(0)) return plot.plot_shift(