plot report for multiclass fixed
This commit is contained in:
parent
06efc90257
commit
50f4a833dc
|
@ -14,6 +14,7 @@ tests/__pycache__/*
|
||||||
tests/*/__pycache__/*
|
tests/*/__pycache__/*
|
||||||
tests/*/*/__pycache__/*
|
tests/*/*/__pycache__/*
|
||||||
htmlcov/*
|
htmlcov/*
|
||||||
|
test*.py
|
||||||
|
|
||||||
*.coverage
|
*.coverage
|
||||||
.coverage
|
.coverage
|
||||||
|
@ -22,4 +23,4 @@ scp_sync.py
|
||||||
|
|
||||||
out/*
|
out/*
|
||||||
output/*
|
output/*
|
||||||
!output/main/
|
!output/main/
|
||||||
|
|
22
conf.yaml
22
conf.yaml
|
@ -415,6 +415,26 @@ kde_lr_gs_conf: &kde_lr_gs_conf
|
||||||
- m3w_kde_lr_gs
|
- m3w_kde_lr_gs
|
||||||
N_JOBS: -2
|
N_JOBS: -2
|
||||||
|
|
||||||
|
confs:
|
||||||
|
- DATASET_NAME: twitter_gasp
|
||||||
|
|
||||||
|
|
||||||
|
multiclass_conf: &multiclass_conf
|
||||||
|
global:
|
||||||
|
METRICS:
|
||||||
|
- acc
|
||||||
|
- f1
|
||||||
|
OUT_DIR_NAME: output/multiclass
|
||||||
|
DATASET_N_PREVS: 5
|
||||||
|
COMP_ESTIMATORS:
|
||||||
|
- bin_sld_lr_gs
|
||||||
|
- mul_sld_lr_gs
|
||||||
|
- bin_kde_lr_gs
|
||||||
|
- mul_kde_lr_gs
|
||||||
|
- atc_mc
|
||||||
|
- doc
|
||||||
|
N_JOBS: -2
|
||||||
|
|
||||||
confs: *main_confs
|
confs: *main_confs
|
||||||
|
|
||||||
timing_conf: &timing_conf
|
timing_conf: &timing_conf
|
||||||
|
@ -461,4 +481,4 @@ timing_gs_conf: &timing_gs_conf
|
||||||
|
|
||||||
confs: *main_confs
|
confs: *main_confs
|
||||||
|
|
||||||
exec: *debug_conf
|
exec: *multiclass_conf
|
||||||
|
|
|
@ -436,6 +436,16 @@ def _cr_data(cr: CompReport, metric=None, estimators=None):
|
||||||
return cr.data(metric, estimators)
|
return cr.data(metric, estimators)
|
||||||
|
|
||||||
|
|
||||||
|
def _key_reverse_delta_train(idx):
|
||||||
|
idx = idx.to_numpy()
|
||||||
|
sorted_idx = np.array(
|
||||||
|
sorted(list(idx), key=lambda x: x[-1]), dtype=("float," * len(idx[0]))[:-1]
|
||||||
|
)
|
||||||
|
# get sorting index
|
||||||
|
nparr = np.nonzero(idx[:, None] == sorted_idx)[1]
|
||||||
|
return nparr
|
||||||
|
|
||||||
|
|
||||||
class DatasetReport:
|
class DatasetReport:
|
||||||
_default_dr_modes = [
|
_default_dr_modes = [
|
||||||
"delta_train",
|
"delta_train",
|
||||||
|
@ -457,6 +467,16 @@ class DatasetReport:
|
||||||
self.name = name
|
self.name = name
|
||||||
self.crs: List[CompReport] = [] if crs is None else crs
|
self.crs: List[CompReport] = [] if crs is None else crs
|
||||||
|
|
||||||
|
def sort_delta_train_index(self, data):
|
||||||
|
# data_ = data.sort_index(axis=0, level=0, ascending=True, sort_remaining=False)
|
||||||
|
data_ = data.sort_index(
|
||||||
|
axis=0,
|
||||||
|
level=0,
|
||||||
|
key=_key_reverse_delta_train,
|
||||||
|
)
|
||||||
|
print(data_.index)
|
||||||
|
return data_
|
||||||
|
|
||||||
def join(self, other, estimators=None):
|
def join(self, other, estimators=None):
|
||||||
_crs = [
|
_crs = [
|
||||||
s_cr.join(o_cr, estimators=estimators)
|
s_cr.join(o_cr, estimators=estimators)
|
||||||
|
@ -542,6 +562,7 @@ class DatasetReport:
|
||||||
_data.index = _idx
|
_data.index = _idx
|
||||||
|
|
||||||
_data = _data.sort_index(axis=0, level=0, ascending=False, sort_remaining=False)
|
_data = _data.sort_index(axis=0, level=0, ascending=False, sort_remaining=False)
|
||||||
|
|
||||||
return _data
|
return _data
|
||||||
|
|
||||||
def shift_data(
|
def shift_data(
|
||||||
|
@ -633,6 +654,8 @@ class DatasetReport:
|
||||||
avg_on_train = _data.groupby(level=1, sort=False).mean()
|
avg_on_train = _data.groupby(level=1, sort=False).mean()
|
||||||
if avg_on_train.empty:
|
if avg_on_train.empty:
|
||||||
return None
|
return None
|
||||||
|
# sort index in reverse order
|
||||||
|
avg_on_train = self.sort_delta_train_index(avg_on_train)
|
||||||
prevs_on_train = avg_on_train.index.unique(0)
|
prevs_on_train = avg_on_train.index.unique(0)
|
||||||
return plot.plot_delta(
|
return plot.plot_delta(
|
||||||
# base_prevs=np.around(
|
# base_prevs=np.around(
|
||||||
|
|
Loading…
Reference in New Issue