diff --git a/.gitignore b/.gitignore index 406234c..6d4bcb5 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ tests/__pycache__/* tests/*/__pycache__/* tests/*/*/__pycache__/* htmlcov/* +test*.py *.coverage .coverage @@ -22,4 +23,4 @@ scp_sync.py out/* output/* -!output/main/ \ No newline at end of file +!output/main/ diff --git a/conf.yaml b/conf.yaml index dfa4762..17f9600 100644 --- a/conf.yaml +++ b/conf.yaml @@ -415,6 +415,26 @@ kde_lr_gs_conf: &kde_lr_gs_conf - m3w_kde_lr_gs N_JOBS: -2 + confs: + - DATASET_NAME: twitter_gasp + + +multiclass_conf: &multiclass_conf + global: + METRICS: + - acc + - f1 + OUT_DIR_NAME: output/multiclass + DATASET_N_PREVS: 5 + COMP_ESTIMATORS: + - bin_sld_lr_gs + - mul_sld_lr_gs + - bin_kde_lr_gs + - mul_kde_lr_gs + - atc_mc + - doc + N_JOBS: -2 + confs: *main_confs timing_conf: &timing_conf @@ -461,4 +481,4 @@ timing_gs_conf: &timing_gs_conf confs: *main_confs -exec: *debug_conf +exec: *multiclass_conf diff --git a/quacc/evaluation/report.py b/quacc/evaluation/report.py index a6ba71b..13a18ba 100644 --- a/quacc/evaluation/report.py +++ b/quacc/evaluation/report.py @@ -436,6 +436,16 @@ def _cr_data(cr: CompReport, metric=None, estimators=None): return cr.data(metric, estimators) +def _key_reverse_delta_train(idx): + idx = idx.to_numpy() + sorted_idx = np.array( + sorted(list(idx), key=lambda x: x[-1]), dtype=("float," * len(idx[0]))[:-1] + ) + # get sorting index + nparr = np.nonzero(idx[:, None] == sorted_idx)[1] + return nparr + + class DatasetReport: _default_dr_modes = [ "delta_train", @@ -457,6 +467,16 @@ class DatasetReport: self.name = name self.crs: List[CompReport] = [] if crs is None else crs + def sort_delta_train_index(self, data): + # data_ = data.sort_index(axis=0, level=0, ascending=True, sort_remaining=False) + data_ = data.sort_index( + axis=0, + level=0, + key=_key_reverse_delta_train, + ) + print(data_.index) + return data_ + def join(self, other, estimators=None): _crs = [ s_cr.join(o_cr, estimators=estimators) @@ -542,6 +562,7 @@ class DatasetReport: _data.index = _idx _data = _data.sort_index(axis=0, level=0, ascending=False, sort_remaining=False) + return _data def shift_data( @@ -633,6 +654,8 @@ class DatasetReport: avg_on_train = _data.groupby(level=1, sort=False).mean() if avg_on_train.empty: return None + # sort index in reverse order + avg_on_train = self.sort_delta_train_index(avg_on_train) prevs_on_train = avg_on_train.index.unique(0) return plot.plot_delta( # base_prevs=np.around(