labels on plot shift added
This commit is contained in:
parent
29c871367e
commit
327cbdaf9e
18
TODO.html
18
TODO.html
|
@ -103,15 +103,6 @@ verbose=True).fit(V_tr)</li>
|
|||
<p><input class="task-list-item-checkbox" checked=""type="checkbox"> import baselines</p>
|
||||
</li>
|
||||
<li class="task-list-item enabled">
|
||||
<p><input class="task-list-item-checkbox" checked=""type="checkbox"> plot avg con train prevalence sull'asse x e media su test prevalecne</p>
|
||||
</li>
|
||||
<li class="task-list-item enabled">
|
||||
<p><input class="task-list-item-checkbox" checked=""type="checkbox"> realizzare grid search per task specifico partendo da GridSearchQ</p>
|
||||
</li>
|
||||
<li class="task-list-item enabled">
|
||||
<p><input class="task-list-item-checkbox" checked=""type="checkbox"> provare PACC come quantificatore</p>
|
||||
</li>
|
||||
<li class="task-list-item enabled">
|
||||
<p><input class="task-list-item-checkbox"type="checkbox"> importare mandoline</p>
|
||||
<ul>
|
||||
<li>mandoline può essere importato, ma richiedere uno slicing delle features a priori che devere essere realizzato ad hoc</li>
|
||||
|
@ -124,6 +115,15 @@ verbose=True).fit(V_tr)</li>
|
|||
</ul>
|
||||
</li>
|
||||
<li class="task-list-item enabled">
|
||||
<p><input class="task-list-item-checkbox" checked=""type="checkbox"> plot avg con train prevalence sull'asse x e media su test prevalecne</p>
|
||||
</li>
|
||||
<li class="task-list-item enabled">
|
||||
<p><input class="task-list-item-checkbox" checked=""type="checkbox"> realizzare grid search per task specifico partendo da GridSearchQ</p>
|
||||
</li>
|
||||
<li class="task-list-item enabled">
|
||||
<p><input class="task-list-item-checkbox" checked=""type="checkbox"> provare PACC come quantificatore</p>
|
||||
</li>
|
||||
<li class="task-list-item enabled">
|
||||
<p><input class="task-list-item-checkbox"type="checkbox"> aggiungere etichette in shift plot</p>
|
||||
</li>
|
||||
<li class="task-list-item enabled">
|
||||
|
|
6
TODO.md
6
TODO.md
|
@ -30,13 +30,13 @@
|
|||
- nel caso di bin fare media dei due best score
|
||||
- [x] import baselines
|
||||
|
||||
- [x] plot avg con train prevalence sull'asse x e media su test prevalecne
|
||||
- [x] realizzare grid search per task specifico partendo da GridSearchQ
|
||||
- [x] provare PACC come quantificatore
|
||||
- [ ] importare mandoline
|
||||
- mandoline può essere importato, ma richiedere uno slicing delle features a priori che devere essere realizzato ad hoc
|
||||
- [ ] sistemare vecchie iw baselines
|
||||
- non possono essere fixate perché dipendono da numpy
|
||||
- [x] plot avg con train prevalence sull'asse x e media su test prevalecne
|
||||
- [x] realizzare grid search per task specifico partendo da GridSearchQ
|
||||
- [x] provare PACC come quantificatore
|
||||
- [ ] aggiungere etichette in shift plot
|
||||
- [ ] sistemare exact_train quapy
|
||||
- [ ] testare anche su imbd
|
|
@ -151,4 +151,4 @@ main_conf: &main_conf
|
|||
- atc_ne
|
||||
- doc_feat
|
||||
|
||||
exec: *mc_conf
|
||||
exec: *debug_conf
|
151
quacc.log
151
quacc.log
|
@ -2850,3 +2850,154 @@
|
|||
05/11/23 14:16:15| INFO atc_mc finished [took 49.6779s]
|
||||
05/11/23 14:16:19| INFO mulmc_sld finished [took 61.0610s]
|
||||
05/11/23 14:16:22| INFO mulne_sld finished [took 62.2089s]
|
||||
05/11/23 14:19:02| INFO binmc_sld finished [took 225.5737s]
|
||||
05/11/23 14:19:03| INFO binne_sld finished [took 223.9017s]
|
||||
05/11/23 14:28:50| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 100.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': 'vs', 'confidence': 'entropy'} (score=0.00756) [took 806.7930s]
|
||||
05/11/23 14:29:32| INFO mul_sld_gs finished [took 848.7630s]
|
||||
05/11/23 14:36:02| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 100.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': 'vs', 'confidence': 'entropy'} (score=0.00781) [took 1240.9138s]
|
||||
05/11/23 14:39:04| INFO bin_sld_gs finished [took 1422.5520s]
|
||||
05/11/23 14:39:04| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs finished [took 1428.8824s]
|
||||
05/11/23 14:39:04| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs started
|
||||
05/11/23 14:39:58| INFO ref finished [took 45.7514s]
|
||||
05/11/23 14:40:02| INFO atc_mc finished [took 48.3888s]
|
||||
05/11/23 14:40:05| INFO mulmc_sld finished [took 59.0537s]
|
||||
05/11/23 14:40:09| INFO mulne_sld finished [took 60.9189s]
|
||||
05/11/23 14:42:42| INFO binne_sld finished [took 214.5464s]
|
||||
05/11/23 14:42:44| INFO binmc_sld finished [took 218.8429s]
|
||||
05/11/23 14:52:23| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 1000.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': 'vs', 'confidence': 'entropy'} (score=0.00984) [took 792.5474s]
|
||||
05/11/23 14:53:05| INFO mul_sld_gs finished [took 834.1824s]
|
||||
05/11/23 14:59:56| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 1.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': None, 'confidence': 'max_conf'} (score=0.01112) [took 1247.0092s]
|
||||
05/11/23 15:02:57| INFO bin_sld_gs finished [took 1427.5051s]
|
||||
05/11/23 15:02:57| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs finished [took 1432.9172s]
|
||||
05/11/23 15:02:57| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs started
|
||||
05/11/23 15:03:49| INFO ref finished [took 44.4148s]
|
||||
05/11/23 15:03:54| INFO atc_mc finished [took 47.7566s]
|
||||
05/11/23 15:04:00| INFO mulmc_sld finished [took 60.5480s]
|
||||
05/11/23 15:04:03| INFO mulne_sld finished [took 61.2226s]
|
||||
05/11/23 15:06:30| INFO binmc_sld finished [took 211.9647s]
|
||||
05/11/23 15:06:32| INFO binne_sld finished [took 211.4312s]
|
||||
05/11/23 15:16:00| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 10.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': 'vs', 'confidence': 'max_conf'} (score=0.00571) [took 776.6085s]
|
||||
05/11/23 15:16:42| INFO mul_sld_gs finished [took 817.9358s]
|
||||
05/11/23 15:23:24| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 1.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': 'vs', 'confidence': 'entropy'} (score=0.00653) [took 1221.6531s]
|
||||
05/11/23 15:26:23| INFO bin_sld_gs finished [took 1400.9688s]
|
||||
05/11/23 15:26:23| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs finished [took 1406.4620s]
|
||||
05/11/23 15:26:23| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs started
|
||||
05/11/23 15:27:16| INFO ref finished [took 44.3988s]
|
||||
05/11/23 15:27:21| INFO atc_mc finished [took 48.5589s]
|
||||
05/11/23 15:27:27| INFO mulmc_sld finished [took 61.4269s]
|
||||
05/11/23 15:27:29| INFO mulne_sld finished [took 61.8292s]
|
||||
05/11/23 15:29:55| INFO binmc_sld finished [took 210.1585s]
|
||||
05/11/23 15:29:59| INFO binne_sld finished [took 212.0930s]
|
||||
05/11/23 15:39:22| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 100.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': 'vs', 'confidence': 'max_conf'} (score=0.00616) [took 771.6071s]
|
||||
05/11/23 15:40:03| INFO mul_sld_gs finished [took 813.2905s]
|
||||
05/11/23 15:47:04| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 100.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': None, 'confidence': None} (score=0.00544) [took 1234.9832s]
|
||||
05/11/23 15:50:10| INFO bin_sld_gs finished [took 1421.7775s]
|
||||
05/11/23 15:50:10| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs finished [took 1427.0062s]
|
||||
05/11/23 15:50:10| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs started
|
||||
05/11/23 15:51:11| INFO ref finished [took 49.7682s]
|
||||
05/11/23 15:51:19| INFO atc_mc finished [took 54.2855s]
|
||||
05/11/23 15:51:22| INFO mulmc_sld finished [took 68.7688s]
|
||||
05/11/23 15:51:26| INFO mulne_sld finished [took 69.3711s]
|
||||
05/11/23 15:54:07| INFO binmc_sld finished [took 234.7962s]
|
||||
05/11/23 15:54:09| INFO binne_sld finished [took 234.6444s]
|
||||
05/11/23 16:03:51| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 0.1, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': 'bcts', 'confidence': 'entropy'} (score=0.00765) [took 811.6704s]
|
||||
05/11/23 16:04:34| INFO mul_sld_gs finished [took 854.8196s]
|
||||
05/11/23 16:11:10| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 0.1, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': 'vs', 'confidence': 'max_conf'} (score=0.01234) [took 1252.4784s]
|
||||
05/11/23 16:14:10| INFO bin_sld_gs finished [took 1431.7446s]
|
||||
05/11/23 16:14:10| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs finished [took 1439.1145s]
|
||||
05/11/23 16:14:10| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs started
|
||||
05/11/23 16:15:02| INFO ref finished [took 44.0970s]
|
||||
05/11/23 16:15:07| INFO atc_mc finished [took 48.2871s]
|
||||
05/11/23 16:15:13| INFO mulmc_sld finished [took 61.0461s]
|
||||
05/11/23 16:15:15| INFO mulne_sld finished [took 60.6375s]
|
||||
05/11/23 16:17:46| INFO binmc_sld finished [took 215.1734s]
|
||||
05/11/23 16:17:49| INFO binne_sld finished [took 215.7846s]
|
||||
05/11/23 16:27:15| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 1.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': 'vs', 'confidence': None} (score=0.00822) [took 778.5688s]
|
||||
05/11/23 16:27:56| INFO mul_sld_gs finished [took 819.2615s]
|
||||
05/11/23 16:34:16| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 1.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': None, 'confidence': 'entropy'} (score=0.00894) [took 1200.6639s]
|
||||
05/11/23 16:37:21| INFO bin_sld_gs finished [took 1385.9035s]
|
||||
05/11/23 16:37:21| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs finished [took 1391.5055s]
|
||||
05/11/23 16:37:21| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs started
|
||||
05/11/23 16:38:13| INFO ref finished [took 44.7046s]
|
||||
05/11/23 16:38:18| INFO atc_mc finished [took 48.7802s]
|
||||
05/11/23 16:38:21| INFO mulmc_sld finished [took 57.4163s]
|
||||
05/11/23 16:38:24| INFO mulne_sld finished [took 58.9847s]
|
||||
05/11/23 16:40:59| INFO binmc_sld finished [took 216.7311s]
|
||||
05/11/23 16:41:01| INFO binne_sld finished [took 216.5312s]
|
||||
05/11/23 16:50:06| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 100.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': None, 'confidence': 'max_conf'} (score=0.00808) [took 758.6896s]
|
||||
05/11/23 16:50:46| INFO mul_sld_gs finished [took 798.8038s]
|
||||
05/11/23 16:56:41| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 10.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': None, 'confidence': 'entropy'} (score=0.00604) [took 1154.7043s]
|
||||
05/11/23 16:59:39| INFO bin_sld_gs finished [took 1332.5521s]
|
||||
05/11/23 16:59:39| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs finished [took 1337.7947s]
|
||||
----------------------------------------------------------------------------------------------------
|
||||
05/11/23 20:08:46| ERROR estimate comparison failed. Exceprion: 'environ' object has no attribute 'OUT_PATH'
|
||||
----------------------------------------------------------------------------------------------------
|
||||
05/11/23 20:09:08| ERROR estimate comparison failed. Exceprion: 'environ' object has no attribute 'OUT_PATH'
|
||||
----------------------------------------------------------------------------------------------------
|
||||
05/11/23 20:09:27| INFO dataset imdb_3prevs
|
||||
05/11/23 20:09:34| INFO Dataset sample 0.20 of dataset imdb_3prevs started
|
||||
05/11/23 20:09:44| INFO ref finished [took 8.9550s]
|
||||
05/11/23 20:09:47| INFO atc_mc finished [took 11.8923s]
|
||||
05/11/23 20:09:56| INFO mulmc_sld finished [took 21.3196s]
|
||||
05/11/23 20:09:56| INFO Dataset sample 0.20 of dataset imdb_3prevs finished [took 21.7709s]
|
||||
05/11/23 20:09:56| INFO Dataset sample 0.50 of dataset imdb_3prevs started
|
||||
05/11/23 20:10:05| INFO ref finished [took 8.6116s]
|
||||
05/11/23 20:10:08| INFO atc_mc finished [took 11.6880s]
|
||||
05/11/23 20:10:16| INFO mulmc_sld finished [took 19.7793s]
|
||||
05/11/23 20:10:16| INFO Dataset sample 0.50 of dataset imdb_3prevs finished [took 20.3246s]
|
||||
05/11/23 20:10:16| INFO Dataset sample 0.80 of dataset imdb_3prevs started
|
||||
05/11/23 20:10:26| INFO ref finished [took 8.6654s]
|
||||
05/11/23 20:10:29| INFO atc_mc finished [took 11.6975s]
|
||||
05/11/23 20:10:35| INFO mulmc_sld finished [took 18.1478s]
|
||||
05/11/23 20:10:35| INFO Dataset sample 0.80 of dataset imdb_3prevs finished [took 18.7200s]
|
||||
----------------------------------------------------------------------------------------------------
|
||||
05/11/23 20:11:42| INFO dataset imdb_3prevs
|
||||
05/11/23 20:11:49| INFO Dataset sample 0.20 of dataset imdb_3prevs started
|
||||
05/11/23 20:11:58| INFO ref finished [took 8.7146s]
|
||||
05/11/23 20:12:02| INFO atc_mc finished [took 11.9672s]
|
||||
05/11/23 20:12:10| INFO mulmc_sld finished [took 20.7824s]
|
||||
05/11/23 20:12:10| INFO Dataset sample 0.20 of dataset imdb_3prevs finished [took 21.2293s]
|
||||
05/11/23 20:12:10| INFO Dataset sample 0.50 of dataset imdb_3prevs started
|
||||
05/11/23 20:12:19| INFO ref finished [took 8.5867s]
|
||||
05/11/23 20:12:23| INFO atc_mc finished [took 11.6542s]
|
||||
05/11/23 20:12:30| INFO mulmc_sld finished [took 19.6709s]
|
||||
05/11/23 20:12:30| INFO Dataset sample 0.50 of dataset imdb_3prevs finished [took 20.1802s]
|
||||
05/11/23 20:12:30| INFO Dataset sample 0.80 of dataset imdb_3prevs started
|
||||
05/11/23 20:12:40| INFO ref finished [took 8.7231s]
|
||||
05/11/23 20:12:43| INFO atc_mc finished [took 11.8244s]
|
||||
05/11/23 20:12:49| INFO mulmc_sld finished [took 18.0420s]
|
||||
05/11/23 20:12:49| INFO Dataset sample 0.80 of dataset imdb_3prevs finished [took 18.6102s]
|
||||
----------------------------------------------------------------------------------------------------
|
||||
05/11/23 20:14:32| INFO dataset imdb_3prevs
|
||||
05/11/23 20:14:39| INFO Dataset sample 0.20 of dataset imdb_3prevs started
|
||||
05/11/23 20:14:48| INFO ref finished [took 8.6247s]
|
||||
05/11/23 20:14:51| INFO atc_mc finished [took 11.6363s]
|
||||
05/11/23 20:15:00| INFO mulmc_sld finished [took 20.4634s]
|
||||
05/11/23 20:15:00| INFO Dataset sample 0.20 of dataset imdb_3prevs finished [took 20.9026s]
|
||||
05/11/23 20:15:00| INFO Dataset sample 0.50 of dataset imdb_3prevs started
|
||||
05/11/23 20:15:09| INFO ref finished [took 8.5219s]
|
||||
05/11/23 20:15:12| INFO atc_mc finished [took 11.6739s]
|
||||
05/11/23 20:15:20| INFO mulmc_sld finished [took 19.8454s]
|
||||
05/11/23 20:15:20| INFO Dataset sample 0.50 of dataset imdb_3prevs finished [took 20.3705s]
|
||||
05/11/23 20:15:20| INFO Dataset sample 0.80 of dataset imdb_3prevs started
|
||||
05/11/23 20:15:29| INFO ref finished [took 8.5948s]
|
||||
05/11/23 20:15:32| INFO atc_mc finished [took 11.7465s]
|
||||
05/11/23 20:15:39| INFO mulmc_sld finished [took 17.9276s]
|
||||
05/11/23 20:15:39| INFO Dataset sample 0.80 of dataset imdb_3prevs finished [took 18.4893s]
|
||||
----------------------------------------------------------------------------------------------------
|
||||
05/11/23 20:16:10| INFO dataset imdb_3prevs
|
||||
05/11/23 20:16:17| INFO Dataset sample 0.20 of dataset imdb_3prevs started
|
||||
05/11/23 20:16:26| INFO ref finished [took 8.3736s]
|
||||
05/11/23 20:16:29| INFO atc_mc finished [took 11.3995s]
|
||||
05/11/23 20:16:38| INFO mulmc_sld finished [took 20.4916s]
|
||||
05/11/23 20:16:38| INFO Dataset sample 0.20 of dataset imdb_3prevs finished [took 20.9187s]
|
||||
05/11/23 20:16:38| INFO Dataset sample 0.50 of dataset imdb_3prevs started
|
||||
05/11/23 20:16:47| INFO ref finished [took 8.4368s]
|
||||
05/11/23 20:16:50| INFO atc_mc finished [took 11.4889s]
|
||||
05/11/23 20:16:58| INFO mulmc_sld finished [took 19.6803s]
|
||||
05/11/23 20:16:58| INFO Dataset sample 0.50 of dataset imdb_3prevs finished [took 20.2091s]
|
||||
05/11/23 20:16:58| INFO Dataset sample 0.80 of dataset imdb_3prevs started
|
||||
05/11/23 20:17:08| INFO ref finished [took 8.9281s]
|
||||
05/11/23 20:17:11| INFO atc_mc finished [took 11.9333s]
|
||||
05/11/23 20:17:17| INFO mulmc_sld finished [took 18.2367s]
|
||||
05/11/23 20:17:17| INFO Dataset sample 0.80 of dataset imdb_3prevs finished [took 18.8309s]
|
||||
|
|
|
@ -182,22 +182,21 @@ class CompReport:
|
|||
train_prev=self.train_prev,
|
||||
)
|
||||
elif mode == "shift":
|
||||
shift_data = (
|
||||
self.shift_data(metric=metric, estimators=estimators)
|
||||
.groupby(level=0)
|
||||
.mean()
|
||||
)
|
||||
_shift_data = self.shift_data(metric=metric, estimators=estimators)
|
||||
shift_avg = _shift_data.groupby(level=0).mean()
|
||||
shift_counts = _shift_data.groupby(level=0).count()
|
||||
shift_prevs = np.around(
|
||||
[(1.0 - p, p) for p in np.sort(shift_data.index.unique(0))],
|
||||
[(1.0 - p, p) for p in np.sort(shift_avg.index.unique(0))],
|
||||
decimals=2,
|
||||
)
|
||||
return plot.plot_shift(
|
||||
shift_prevs=shift_prevs,
|
||||
columns=shift_data.columns.to_numpy(),
|
||||
data=shift_data.T.to_numpy(),
|
||||
columns=shift_avg.columns.to_numpy(),
|
||||
data=shift_avg.T.to_numpy(),
|
||||
metric=metric,
|
||||
name=conf,
|
||||
train_prev=self.train_prev,
|
||||
counts=shift_counts.T.to_numpy(),
|
||||
)
|
||||
|
||||
def to_md(self, conf="default", metric="acc", estimators=None, stdev=False) -> str:
|
||||
|
@ -374,6 +373,7 @@ class DatasetReport:
|
|||
res += "### avg dataset shift\n"
|
||||
|
||||
avg_shift = _shift_data.groupby(level=0).mean()
|
||||
count_shift = _shift_data.groupby(level=0).count()
|
||||
prevs_shift = np.sort(avg_shift.index.unique(0))
|
||||
|
||||
shift_op = plot.plot_shift(
|
||||
|
@ -383,6 +383,7 @@ class DatasetReport:
|
|||
metric=metric,
|
||||
name=conf,
|
||||
train_prev=None,
|
||||
counts=count_shift.T.to_numpy(),
|
||||
)
|
||||
res += f".as_posix()})\n"
|
||||
|
||||
|
|
|
@ -27,7 +27,6 @@ def plot_delta(
|
|||
metric="acc",
|
||||
name="default",
|
||||
train_prev=None,
|
||||
fit_scores=None,
|
||||
legend=True,
|
||||
avg=None,
|
||||
) -> Path:
|
||||
|
@ -75,14 +74,6 @@ def plot_delta(
|
|||
color=_cy["color"],
|
||||
alpha=0.25,
|
||||
)
|
||||
if fit_scores is not None and method in fit_scores:
|
||||
ax.plot(
|
||||
base_prevs,
|
||||
np.repeat(fit_scores[method], base_prevs.shape[0]),
|
||||
color=_cy["color"],
|
||||
linestyle="--",
|
||||
markersize=0,
|
||||
)
|
||||
|
||||
x_label = "test" if avg is None or avg == "train" else "train"
|
||||
ax.set(
|
||||
|
@ -188,11 +179,11 @@ def plot_shift(
|
|||
columns,
|
||||
data,
|
||||
*,
|
||||
counts=None,
|
||||
pos_class=1,
|
||||
metric="acc",
|
||||
name="default",
|
||||
train_prev=None,
|
||||
fit_scores=None,
|
||||
legend=True,
|
||||
) -> Path:
|
||||
if train_prev is not None:
|
||||
|
@ -223,15 +214,20 @@ def plot_shift(
|
|||
markersize=3,
|
||||
zorder=2,
|
||||
)
|
||||
|
||||
if fit_scores is not None and method in fit_scores:
|
||||
ax.plot(
|
||||
shift_prevs,
|
||||
np.repeat(fit_scores[method], shift_prevs.shape[0]),
|
||||
color=_cy["color"],
|
||||
linestyle="--",
|
||||
markersize=0,
|
||||
)
|
||||
if counts is not None:
|
||||
_col_idx = np.where(columns == method)[0]
|
||||
count = counts[_col_idx].flatten()
|
||||
for prev, shift, cnt in zip(shift_prevs, shifts, count):
|
||||
label = f"{cnt}"
|
||||
plt.annotate(
|
||||
label,
|
||||
(prev, shift),
|
||||
textcoords="offset points",
|
||||
xytext=(0, 10),
|
||||
ha="center",
|
||||
color=_cy["color"],
|
||||
fontsize=12.0,
|
||||
)
|
||||
|
||||
ax.set(xlabel="dataset shift", ylabel=metric, title=title)
|
||||
|
||||
|
|
Loading…
Reference in New Issue