plot report for multiclass fixed
This commit is contained in:
parent
06efc90257
commit
50f4a833dc
|
@ -14,6 +14,7 @@ tests/__pycache__/*
|
|||
tests/*/__pycache__/*
|
||||
tests/*/*/__pycache__/*
|
||||
htmlcov/*
|
||||
test*.py
|
||||
|
||||
*.coverage
|
||||
.coverage
|
||||
|
@ -22,4 +23,4 @@ scp_sync.py
|
|||
|
||||
out/*
|
||||
output/*
|
||||
!output/main/
|
||||
!output/main/
|
||||
|
|
22
conf.yaml
22
conf.yaml
|
@ -415,6 +415,26 @@ kde_lr_gs_conf: &kde_lr_gs_conf
|
|||
- m3w_kde_lr_gs
|
||||
N_JOBS: -2
|
||||
|
||||
confs:
|
||||
- DATASET_NAME: twitter_gasp
|
||||
|
||||
|
||||
multiclass_conf: &multiclass_conf
|
||||
global:
|
||||
METRICS:
|
||||
- acc
|
||||
- f1
|
||||
OUT_DIR_NAME: output/multiclass
|
||||
DATASET_N_PREVS: 5
|
||||
COMP_ESTIMATORS:
|
||||
- bin_sld_lr_gs
|
||||
- mul_sld_lr_gs
|
||||
- bin_kde_lr_gs
|
||||
- mul_kde_lr_gs
|
||||
- atc_mc
|
||||
- doc
|
||||
N_JOBS: -2
|
||||
|
||||
confs: *main_confs
|
||||
|
||||
timing_conf: &timing_conf
|
||||
|
@ -461,4 +481,4 @@ timing_gs_conf: &timing_gs_conf
|
|||
|
||||
confs: *main_confs
|
||||
|
||||
exec: *debug_conf
|
||||
exec: *multiclass_conf
|
||||
|
|
|
@ -436,6 +436,16 @@ def _cr_data(cr: CompReport, metric=None, estimators=None):
|
|||
return cr.data(metric, estimators)
|
||||
|
||||
|
||||
def _key_reverse_delta_train(idx):
|
||||
idx = idx.to_numpy()
|
||||
sorted_idx = np.array(
|
||||
sorted(list(idx), key=lambda x: x[-1]), dtype=("float," * len(idx[0]))[:-1]
|
||||
)
|
||||
# get sorting index
|
||||
nparr = np.nonzero(idx[:, None] == sorted_idx)[1]
|
||||
return nparr
|
||||
|
||||
|
||||
class DatasetReport:
|
||||
_default_dr_modes = [
|
||||
"delta_train",
|
||||
|
@ -457,6 +467,16 @@ class DatasetReport:
|
|||
self.name = name
|
||||
self.crs: List[CompReport] = [] if crs is None else crs
|
||||
|
||||
def sort_delta_train_index(self, data):
|
||||
# data_ = data.sort_index(axis=0, level=0, ascending=True, sort_remaining=False)
|
||||
data_ = data.sort_index(
|
||||
axis=0,
|
||||
level=0,
|
||||
key=_key_reverse_delta_train,
|
||||
)
|
||||
print(data_.index)
|
||||
return data_
|
||||
|
||||
def join(self, other, estimators=None):
|
||||
_crs = [
|
||||
s_cr.join(o_cr, estimators=estimators)
|
||||
|
@ -542,6 +562,7 @@ class DatasetReport:
|
|||
_data.index = _idx
|
||||
|
||||
_data = _data.sort_index(axis=0, level=0, ascending=False, sort_remaining=False)
|
||||
|
||||
return _data
|
||||
|
||||
def shift_data(
|
||||
|
@ -633,6 +654,8 @@ class DatasetReport:
|
|||
avg_on_train = _data.groupby(level=1, sort=False).mean()
|
||||
if avg_on_train.empty:
|
||||
return None
|
||||
# sort index in reverse order
|
||||
avg_on_train = self.sort_delta_train_index(avg_on_train)
|
||||
prevs_on_train = avg_on_train.index.unique(0)
|
||||
return plot.plot_delta(
|
||||
# base_prevs=np.around(
|
||||
|
|
Loading…
Reference in New Issue