plot report for multiclass fixed

This commit is contained in:
Lorenzo Volpi 2024-03-24 19:46:01 +01:00
parent 06efc90257
commit 50f4a833dc
3 changed files with 46 additions and 2 deletions

3
.gitignore vendored
View File

@ -14,6 +14,7 @@ tests/__pycache__/*
tests/*/__pycache__/*
tests/*/*/__pycache__/*
htmlcov/*
test*.py
*.coverage
.coverage
@ -22,4 +23,4 @@ scp_sync.py
out/*
output/*
!output/main/
!output/main/

View File

@ -415,6 +415,26 @@ kde_lr_gs_conf: &kde_lr_gs_conf
- m3w_kde_lr_gs
N_JOBS: -2
confs:
- DATASET_NAME: twitter_gasp
multiclass_conf: &multiclass_conf
global:
METRICS:
- acc
- f1
OUT_DIR_NAME: output/multiclass
DATASET_N_PREVS: 5
COMP_ESTIMATORS:
- bin_sld_lr_gs
- mul_sld_lr_gs
- bin_kde_lr_gs
- mul_kde_lr_gs
- atc_mc
- doc
N_JOBS: -2
confs: *main_confs
timing_conf: &timing_conf
@ -461,4 +481,4 @@ timing_gs_conf: &timing_gs_conf
confs: *main_confs
exec: *debug_conf
exec: *multiclass_conf

View File

@ -436,6 +436,16 @@ def _cr_data(cr: CompReport, metric=None, estimators=None):
return cr.data(metric, estimators)
# Sort-key helper: given an index, return for each entry the position that
# entry occupies once all entries are ordered by their last component.
# Intended to be passed as `key=` to `sort_index`.
# Assumes every entry is an equal-length numeric sequence (e.g. a tuple of
# floats) -- required by `x[-1]`, `len(idx[0])` and the float record dtype.
# TODO confirm against callers.
# NOTE(review): despite "reverse" in the name, `sorted` is called without
# `reverse=True`, so the ordering is ascending on the last component.
def _key_reverse_delta_train(idx):
idx = idx.to_numpy()
# Build a structured array of the entries ordered by last component. The dtype
# string is "float,float,...": one float field per component of an entry, with
# the trailing comma removed by the `[:-1]` slice.
sorted_idx = np.array(
sorted(list(idx), key=lambda x: x[-1]), dtype=("float," * len(idx[0]))[:-1]
)
# Compare every entry (trailing axis added) against every sorted record and
# keep the second axis of the matches, i.e. the sorted position of each entry
# in original order.
# NOTE(review): presumably each entry matches exactly one sorted record;
# duplicate entries would produce more positions than inputs. Also relies on
# numpy comparing the original entries with structured records element-wise --
# confirm on the numpy version in use.
nparr = np.nonzero(idx[:, None] == sorted_idx)[1]
return nparr
class DatasetReport:
_default_dr_modes = [
"delta_train",
@ -457,6 +467,16 @@ class DatasetReport:
self.name = name
self.crs: List[CompReport] = [] if crs is None else crs
def sort_delta_train_index(self, data):
    """Return ``data`` with its rows reordered by level 0 of the index.

    The ordering is delegated to ``sort_index`` with
    ``_key_reverse_delta_train`` as the sort key, so rows end up arranged
    by the position that helper assigns to each index entry.

    Args:
        data: object exposing a pandas-style ``sort_index`` method
            (presumably a DataFrame or Series -- confirm against callers).

    Returns:
        The sorted object produced by ``sort_index``; ``data`` itself is
        not modified because ``inplace`` is left at its default.
    """
    # The debug print of the resulting index was removed: it wrote to stdout
    # on every call and carried no information for the caller.
    return data.sort_index(
        axis=0,
        level=0,
        key=_key_reverse_delta_train,
    )
def join(self, other, estimators=None):
_crs = [
s_cr.join(o_cr, estimators=estimators)
@ -542,6 +562,7 @@ class DatasetReport:
_data.index = _idx
_data = _data.sort_index(axis=0, level=0, ascending=False, sort_remaining=False)
return _data
def shift_data(
@ -633,6 +654,8 @@ class DatasetReport:
avg_on_train = _data.groupby(level=1, sort=False).mean()
if avg_on_train.empty:
return None
# sort index in reverse order
avg_on_train = self.sort_delta_train_index(avg_on_train)
prevs_on_train = avg_on_train.index.unique(0)
return plot.plot_delta(
# base_prevs=np.around(