cleaning baselines

This commit is contained in:
Alejandro Moreo Fernandez 2024-05-02 16:53:44 +02:00
parent 07a86746c3
commit 8dfb109b41
1 changed files with 4 additions and 7 deletions
LeQua2024

View File

@ -12,7 +12,7 @@ from quapy.method.non_aggregative import MaximumLikelihoodPrevalenceEstimation a
import quapy.functional as F
# LeQua official baselines
# LeQua official baselines (under development!)
# =================================================================================
BINARY_TASKS = ['T1', 'T4']
@ -40,12 +40,6 @@ def baselines():
yield ACC(new_cls()), "ACC", q_params
yield PCC(new_cls()), "PCC", q_params
yield PACC(new_cls()), "PACC", q_params
# yield EMQ(CalibratedClassifierCV(new_cls())), "SLD-Platt", wrap_params(wrap_params(lr_params, 'estimator'), 'classifier')
# yield EMQ(new_cls()), "SLD", q_params
# yield EMQ(new_cls()), "SLD-BCTS", {**q_params, 'recalib': ['bcts'], 'val_split': [5]}
# yield MLPE(), "MLPE", None
# if args.task in BINARY_TASKS:
# yield MS2(new_cls()), "MedianSweep2", q_params
def main(args):
@ -60,6 +54,9 @@ def main(args):
train, gen_val, gen_test = fetch_lequa2024(task=args.task, data_home=args.datadir, merge_T3=True)
# gen_test is None, since the true prevalence vectors for the test samples will be released
# only after the competition ends
print(f'number of classes: {len(train.classes_)}')
print(f'number of training documents: {len(train)}')
print(f'training prevalence: {F.strprev(train.prevalence())}')