all examples but 15 (qunfold) properly working

Alejandro Moreo Fernandez 2025-10-01 17:41:36 +02:00
parent edbc8bc201
commit 24ab704661
18 changed files with 168 additions and 78 deletions

View File

@ -1,5 +1,7 @@
Adapt examples; remaining: example 4-onwards
-not working: 4, 4b, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
+not working: 15 (qunfold)
+Solve the warnings issue; right now there is a warning ignore in method/__init__.py:
Add 'platt' to calib options in EMQ?

View File

@ -402,6 +402,10 @@ train, test_gen = qp.datasets.fetch_IFCB(for_model_selection=False, single_sampl
# ... train and evaluation
```
+See also [Automatic plankton quantification using deep features
+P González, A Castaño, EE Peacock, J Díez, JJ Del Coz, HM Sosik
+Journal of Plankton Research 41 (4), 449-463](https://par.nsf.gov/servlets/purl/10172325).

## Adding Custom Datasets

View File

@ -9,6 +9,11 @@ import numpy as np
""" """
In this example, we will create a quantifier for tweet sentiment analysis considering three classes: negative, neutral, In this example, we will create a quantifier for tweet sentiment analysis considering three classes: negative, neutral,
and positive. We will use a one-vs-all approach using a binary quantifier for demonstration purposes. and positive. We will use a one-vs-all approach using a binary quantifier for demonstration purposes.
Caveat: the one-vs-all approach is deemed inadequate under prior probability shift conditions. The reasons
are discussed in:
Donyavi, Z., Serapio, A., & Batista, G. (2023). MC-SQ: A highly accurate ensemble for multi-class quantifi-
cation. In: Proceedings of the 2023 SIAM International Conference on Data Mining (SDM), SIAM, pp. 622630
""" """
qp.environ['SAMPLE_SIZE'] = 100 qp.environ['SAMPLE_SIZE'] = 100
@ -40,11 +45,11 @@ param_grid = {
}

print('starting model selection')
model_selection = GridSearchQ(quantifier, param_grid, protocol=UPP(val), verbose=True, refit=False)
-quantifier = model_selection.fit(train_modsel).best_model()
+quantifier = model_selection.fit(*train_modsel.Xy).best_model()
print('training on the whole training set')
train, test = qp.datasets.fetch_twitter('hcr', for_model_selection=False, pickle=True).train_test
-quantifier.fit(train)
+quantifier.fit(*train.Xy)

# evaluation
mae = qp.evaluation.evaluate(quantifier, protocol=UPP(test), error_metric='mae')
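The edit above illustrates the API change this commit rolls out across all the examples: quantifiers no longer take a LabelledCollection in fit, but the (X, y) pair exposed by the collection's Xy property. A minimal sketch of the adapted calling convention (the dataset and quantifier are taken from other examples in this commit; hyperparameters are illustrative):

```python
import quapy as qp
from sklearn.linear_model import LogisticRegression
from quapy.method.aggregative import PACC
from quapy.protocol import UPP

qp.environ['SAMPLE_SIZE'] = 100

# train/test are LabelledCollection objects; the new API takes (X, y) instead
train, test = qp.datasets.fetch_UCIMulticlassDataset('molecular').train_test

quantifier = PACC(LogisticRegression())
quantifier.fit(*train.Xy)  # equivalent to quantifier.fit(train.X, train.y)

# evaluation protocols still consume the LabelledCollection directly
mae = qp.evaluation.evaluate(quantifier, protocol=UPP(test), error_metric='mae')
print(f'MAE = {mae:.4f}')
```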

View File

@ -23,8 +23,9 @@ qp.environ['SAMPLE_SIZE']=100
df = pd.DataFrame(columns=['method', 'dataset', 'MAE', 'MRAE', 'tr-time', 'te-time'])

+datasets = qp.datasets.UCI_BINARY_DATASETS
-for dataset_name in tqdm(qp.datasets.UCI_BINARY_DATASETS, total=len(qp.datasets.UCI_BINARY_DATASETS)):
+for dataset_name in tqdm(datasets, total=len(datasets), desc='datasets processed'):
    if dataset_name in ['acute.a', 'acute.b', 'balance.2', 'iris.1']:
        # these datasets tend to produce either too good or too bad results...
        continue
@ -32,23 +33,25 @@ for dataset_name in tqdm(qp.datasets.UCI_BINARY_DATASETS, total=len(qp.datasets.
    collection = qp.datasets.fetch_UCIBinaryLabelledCollection(dataset_name, verbose=False)
    train, test = collection.split_stratified()
+   Xtr, ytr = train.Xy

    # HDy............................................
    tinit = time()
-   hdy = HDy(LogisticRegression()).fit(train)
+   hdy = HDy(LogisticRegression()).fit(Xtr, ytr)
    t_hdy_train = time()-tinit

    tinit = time()
-   hdy_report = qp.evaluation.evaluation_report(hdy, APP(test), error_metrics=['mae', 'mrae']).mean()
+   hdy_report = qp.evaluation.evaluation_report(hdy, APP(test), error_metrics=['mae', 'mrae']).mean(numeric_only=True)
    t_hdy_test = time() - tinit
    df.loc[len(df)] = ['HDy', dataset_name, hdy_report['mae'], hdy_report['mrae'], t_hdy_train, t_hdy_test]

    # HDx............................................
    tinit = time()
-   hdx = DMx.HDx(n_jobs=-1).fit(train)
+   hdx = DMx.HDx(n_jobs=-1).fit(Xtr, ytr)
    t_hdx_train = time() - tinit

    tinit = time()
-   hdx_report = qp.evaluation.evaluation_report(hdx, APP(test), error_metrics=['mae', 'mrae']).mean()
+   hdx_report = qp.evaluation.evaluation_report(hdx, APP(test), error_metrics=['mae', 'mrae']).mean(numeric_only=True)
    t_hdx_test = time() - tinit
    df.loc[len(df)] = ['HDx', dataset_name, hdx_report['mae'], hdx_report['mrae'], t_hdx_train, t_hdx_test]
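The switch from .mean() to .mean(numeric_only=True) is presumably needed because recent pandas versions (2.0 onwards) no longer silently drop non-numeric columns when aggregating, and the report frame carries array-valued prevalence columns. A self-contained sketch of the pattern (column names are illustrative):

```python
import pandas as pd

# a report-like frame: numeric error columns plus an array-valued column
report = pd.DataFrame({
    'true-prev': [[0.3, 0.7], [0.5, 0.5]],  # non-numeric column (lists)
    'mae': [0.02, 0.04],
    'mrae': [0.10, 0.15],
})

# report.mean() raises a TypeError on the non-numeric column under pandas >= 2.0;
# numeric_only=True restricts the aggregation to the numeric columns
print(report.mean(numeric_only=True))  # mae 0.03, mrae 0.125
```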

View File

@ -3,14 +3,13 @@ from sklearn.linear_model import LogisticRegression
import quapy as qp
from quapy.method.aggregative import PACC
-from quapy.data import LabelledCollection
from quapy.protocol import AbstractStochasticSeededProtocol
import quapy.functional as F

"""
In this example, we create a custom protocol.
-The protocol generates samples of a Gaussian mixture model with random mixture parameter (the sample prevalence).
-Datapoints are univariate and we consider 2 classes only.
+The protocol generates synthetic samples of a Gaussian mixture model with random mixture parameter
+(the sample prevalence). Datapoints are univariate and we consider 2 classes only for simplicity.
"""

class GaussianMixProtocol(AbstractStochasticSeededProtocol):
    # We need to extend AbstractStochasticSeededProtocol if we want the samples to be replicable
@ -81,10 +80,9 @@ with qp.util.temp_seed(0):
    Xpos = np.random.normal(loc=mu_2, scale=std_2, size=100)
    X = np.concatenate([Xneg, Xpos]).reshape(-1,1)
    y = [0]*100 + [1]*100
-   training = LabelledCollection(X, y)

    pacc = PACC(LogisticRegression())
-   pacc.fit(training)
+   pacc.fit(X, y)

    mae = qp.evaluation.evaluate(pacc, protocol=gm, error_metric='mae', verbose=True)
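For context, a reduced sketch of what extending AbstractStochasticSeededProtocol involves; the method names samples_parameters, sample, and total are assumed from QuaPy's protocol interface, and the Gaussian parameters are illustrative:

```python
import numpy as np
from quapy.protocol import AbstractStochasticSeededProtocol

class TinyGaussianMixProtocol(AbstractStochasticSeededProtocol):
    def __init__(self, mu_neg=0.0, mu_pos=2.0, sample_size=50, repeats=10, random_state=0):
        super().__init__(random_state=random_state)
        self.mu_neg, self.mu_pos = mu_neg, mu_pos
        self.sample_size = sample_size
        self.repeats = repeats

    def samples_parameters(self):
        # one random positive-prevalence per sample; storing the parameters
        # is what makes the generated samples replicable
        return [np.random.uniform(0., 1.) for _ in range(self.repeats)]

    def sample(self, prev):
        # draws a univariate 2-class sample with the given positive prevalence
        n_pos = int(prev * self.sample_size)
        n_neg = self.sample_size - n_pos
        X = np.concatenate([
            np.random.normal(self.mu_neg, 1., n_neg),
            np.random.normal(self.mu_pos, 1., n_pos)]).reshape(-1, 1)
        prevalence = np.asarray([n_neg, n_pos]) / self.sample_size
        return X, prevalence

    def total(self):
        return self.repeats
```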

View File

@ -122,7 +122,7 @@ def get_random_forest() -> RandomForestClassifier:
def _get_estimate(estimator_class, training: LabelledCollection, test: np.ndarray) -> None:
    """Auxiliary method for running ACC and PACC."""
    estimator = estimator_class(get_random_forest())
-   estimator.fit(training)
+   estimator.fit(*training.Xy)
    return estimator.predict(test)
@ -130,7 +130,7 @@ def train_and_plot_bayesian_quantification(ax: plt.Axes, training: LabelledColle
"""Fits Bayesian quantification and plots posterior mean as well as individual samples""" """Fits Bayesian quantification and plots posterior mean as well as individual samples"""
print('training model Bayesian CC...', end='') print('training model Bayesian CC...', end='')
quantifier = BayesianCC(classifier=get_random_forest()) quantifier = BayesianCC(classifier=get_random_forest())
quantifier.fit(training) quantifier.fit(*training.Xy)
# Obtain mean prediction # Obtain mean prediction
mean_prediction = quantifier.predict(test.X) mean_prediction = quantifier.predict(test.X)

View File

@ -21,6 +21,7 @@ Let see one example:
# load some data
data = qp.datasets.fetch_UCIMulticlassDataset('molecular')
train, test = data.train_test
+Xtr, ytr = train.Xy

# by simply wrapping an aggregative quantifier within the AggregativeBootstrap class, we can obtain confidence
# intervals around the point estimate, in this case, at 95% of confidence
@ -29,7 +30,7 @@ pacc = AggregativeBootstrap(PACC(), n_test_samples=500, confidence_level=0.95)
with qp.util.temp_seed(0):
    # we train the quantifier the usual way
-   pacc.fit(train)
+   pacc.fit(Xtr, ytr)

    # let us simulate some shift in the test data
    random_prevalence = F.uniform_prevalence_sampling(n_classes=test.n_classes)
@ -53,7 +54,7 @@ with qp.util.temp_seed(0):
    print(f'point-estimate: {F.strprev(pred_prev)}')
    print(f'absolute error: {error:.3f}')
    print(f'Is the true value in the confidence region?: {conf_intervals.coverage(true_prev)==1}')
-   print(f'Proportion of simplex covered at {pacc.confidence_level*100:.1f}%: {conf_intervals.simplex_portion()*100:.2f}%')
+   print(f'Proportion of simplex covered at confidence level {pacc.confidence_level*100:.1f}%: {conf_intervals.simplex_portion()*100:.2f}%')

"""
Final remarks:

View File

@ -31,13 +31,13 @@ training, val_generator, test_generator = fetch_lequa2022(task=task)
Xtr, ytr = training.Xy

# define the quantifier
-quantifier = EMQ(classifier=LogisticRegression())
+quantifier = EMQ(classifier=LogisticRegression(), val_split=5)

# model selection
param_grid = {
    'classifier__C': np.logspace(-3, 3, 7),  # classifier-dependent: inverse of regularization strength
    'classifier__class_weight': ['balanced', None],  # classifier-dependent: weights of each class
-   # 'calib': ['bcts', None]  # quantifier-dependent: recalibration method (new in v0.1.7)
+   'calib': ['bcts', None]  # quantifier-dependent: recalibration method (new in v0.1.7)
}
model_selection = GridSearchQ(quantifier, param_grid, protocol=val_generator, error='mrae', refit=False, verbose=True)
quantifier = model_selection.fit(Xtr, ytr)
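Note the interplay between the two changes above: once 'calib' enters the grid, EMQ needs held-out classifier predictions to fit the calibrator, which is presumably why val_split=5 is now passed explicitly (see the check added to EMQ's aggregation fit later in this commit). A hedged sketch of the two configurations:

```python
from sklearn.linear_model import LogisticRegression
from quapy.method.aggregative import EMQ

# with calibration, held-out classifier predictions are required; val_split=5
# obtains them via 5-fold cross-validation on the training set
emq_calibrated = EMQ(classifier=LogisticRegression(), val_split=5, calib='bcts')

# without calibration (and with the default exact_train_prev=True),
# no held-out predictions are needed
emq_plain = EMQ(classifier=LogisticRegression())
```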

View File

@ -1,6 +1,6 @@
-import quapy as qp
import numpy as np
from sklearn.linear_model import LogisticRegression
+import quapy as qp
import quapy.functional as F
from quapy.data.datasets import LEQUA2024_SAMPLE_SIZE, fetch_lequa2024
from quapy.evaluation import evaluation_report
from quapy.evaluation import evaluation_report from quapy.evaluation import evaluation_report
@ -14,6 +14,7 @@ LeQua competition itself, check:
https://lequa2024.github.io/index (the site of the competition)
"""

# there are 4 tasks: T1 (binary), T2 (multiclass), T3 (ordinal), T4 (binary - covariate & prior shift)
task = 'T2'
@ -38,6 +39,7 @@ param_grid = {
    'classifier__class_weight': ['balanced', None],  # classifier-dependent: weights of each class
    'bandwidth': np.linspace(0.01, 0.2, 20)  # quantifier-dependent: bandwidth of the kernel
}

model_selection = GridSearchQ(quantifier, param_grid, protocol=val_generator, error='mrae', refit=False, verbose=True)
quantifier = model_selection.fit(Xtr, ytr)
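As in the previous example, the grid mixes classifier-dependent hyperparameters (prefixed with classifier__, routed to the wrapped scikit-learn estimator) with quantifier-dependent ones; for KDEy, the kernel bandwidth. A minimal sketch of that convention in isolation:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from quapy.method.aggregative import KDEyML

quantifier = KDEyML(LogisticRegression())
param_grid = {
    'classifier__C': np.logspace(-3, 3, 7),          # goes to LogisticRegression
    'classifier__class_weight': ['balanced', None],  # goes to LogisticRegression
    'bandwidth': np.linspace(0.01, 0.2, 20),         # hyperparameter of KDEyML itself
}
```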

View File

@ -1,4 +1,7 @@
from copy import deepcopy
+from pathlib import Path
+import pandas as pd
import quapy as qp
from sklearn.calibration import CalibratedClassifierCV
@ -15,6 +18,18 @@ import itertools
import argparse
import torch
import shutil
+from glob import glob

+"""
+This example shows how to generate experiments for the UCI ML repository binary datasets following the protocol
+proposed in "Pérez-Gállego, P., Quevedo, J. R., and del Coz, J. J. Using ensembles for problems with characterizable
+changes in data distribution: A case study on quantification. Information Fusion 34 (2017), 87-100."
+This example covers the most important steps in the experimentation pipeline, namely, the training and optimization
+of the hyperparameters of different quantifiers, and the evaluation of these quantifiers based on standard
+prevalence sampling protocols aimed at simulating different levels of prior probability shift.
+"""

N_JOBS = -1
@ -28,10 +43,6 @@ def newLR():
    return LogisticRegression(max_iter=1000, solver='lbfgs', n_jobs=-1)

-def calibratedLR():
-    return CalibratedClassifierCV(newLR())

__C_range = np.logspace(-3, 3, 7)
lr_params = {
    'classifier__C': __C_range,
@ -74,6 +85,13 @@ def result_path(path, dataset_name, model_name, run, optim_loss):
    return os.path.join(path, f'{dataset_name}-{model_name}-run{run}-{optim_loss}.pkl')

+def parse_result_path(path):
+    # Path(path).stem (rather than .name) drops the .pkl extension, so `metric` comes out clean
+    *dataset, method, run, metric = Path(path).stem.split('-')
+    dataset = '-'.join(dataset)
+    run = int(run.replace('run',''))
+    return dataset, method, run, metric
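A quick usage sketch of the two helpers above (dataset and method names are illustrative); note that parse_result_path recovers dataset names that themselves contain hyphens by re-joining the leading fragments:

```python
# hypothetical round-trip of result_path and parse_result_path
path = result_path('./results/uci_binary', 'pageblocks.5', 'PACC', run=0, optim_loss='mae')
print(path)                     # ./results/uci_binary/pageblocks.5-PACC-run0-mae.pkl
print(parse_result_path(path))  # ('pageblocks.5', 'PACC', 0, 'mae')
```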
def is_already_computed(dataset_name, model_name, run, optim_loss):
    return os.path.exists(result_path(args.results, dataset_name, model_name, run, optim_loss))
@ -130,10 +148,28 @@ def run(experiment):
        best_params)

+def show_results(result_folder):
+    result_data = []
+    for file in glob(os.path.join(result_folder, '*.pkl')):
+        true_prevalences, estim_prevalences, *_ = pickle.load(open(file, 'rb'))
+        dataset, method, run, metric = parse_result_path(file)
+        mae = qp.error.mae(true_prevalences, estim_prevalences)
+        result_data.append({
+            'dataset': dataset,
+            'method': method,
+            'run': run,
+            metric: mae
+        })
+    # note: this assumes all result files report the same metric (e.g., 'mae')
+    df = pd.DataFrame(result_data)
+    pd.set_option("display.max_columns", None)
+    pd.set_option("display.expand_frame_repr", False)
+    print(df.pivot_table(index='dataset', columns='method', values=metric))

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Run experiments for UCI binary datasets quantification')
    parser.add_argument('--results', metavar='RESULT_PATH', type=str,
-                       help='path to the directory where to store the results', default='./uci_results')
+                       help='path to the directory where to store the results', default='./results/uci_binary')
    parser.add_argument('--svmperfpath', metavar='SVMPERF_PATH', type=str, default='../svm_perf_quantification',
                        help='path to the directory with svmperf')
    parser.add_argument('--checkpointdir', metavar='PATH', type=str, default='./checkpoint',
@ -155,3 +191,5 @@ if __name__ == '__main__':
    qp.util.parallel(run, itertools.product(optim_losses, datasets, models), n_jobs=CUDA_N_JOBS)
    shutil.rmtree(args.checkpointdir, ignore_errors=True)

+   show_results(args.results)

View File

@ -1,4 +1,3 @@
-import pickle
import os
from time import time
from collections import defaultdict
@ -7,11 +6,16 @@ import numpy as np
from sklearn.linear_model import LogisticRegression
import quapy as qp
-from quapy.method.aggregative import PACC, EMQ
+from quapy.method.aggregative import PACC, EMQ, KDEyML
from quapy.model_selection import GridSearchQ
from quapy.protocol import UPP
from pathlib import Path

+"""
+This example is the analogous counterpart of example 7, but involving multiclass quantification problems
+using datasets from the UCI ML repository.
+"""

SEED = 1
@ -31,7 +35,7 @@ def wrap_hyper(classifier_hyper_grid:dict):
METHODS = [
    ('PACC', PACC(newLR()), wrap_hyper(logreg_grid)),
    ('EMQ', EMQ(newLR()), wrap_hyper(logreg_grid)),
-   # ('KDEy-ML', KDEyML(newLR()), {**wrap_hyper(logreg_grid), **{'bandwidth': np.linspace(0.01, 0.2, 20)}}),
+   ('KDEy-ML', KDEyML(newLR()), {**wrap_hyper(logreg_grid), **{'bandwidth': np.linspace(0.01, 0.2, 20)}}),
]
@ -43,6 +47,7 @@ def show_results(result_path):
    pv = df.pivot_table(index='Dataset', columns="Method", values=["MAE", "MRAE", "t_train"], margins=True)
    print(pv)

def load_timings(result_path):
    import pandas as pd
    timings = defaultdict(lambda: {})
@ -59,7 +64,7 @@ if __name__ == '__main__':
    qp.environ['N_JOBS'] = -1
    n_bags_val = 250
    n_bags_test = 1000
-   result_dir = f'results/ucimulti'
+   result_dir = f'results/uci_multiclass'

    os.makedirs(result_dir, exist_ok=True)
@ -100,7 +105,7 @@ if __name__ == '__main__':
        t_init = time()
        try:
-           modsel.fit(train)
+           modsel.fit(*train.Xy)

            print(f'best params {modsel.best_params_}')
            print(f'best score {modsel.best_score_}')
@ -108,7 +113,8 @@ if __name__ == '__main__':
            quantifier = modsel.best_model()
        except:
            print('something went wrong... trying to fit the default model')
-           quantifier.fit(train)
+           quantifier.fit(*train.Xy)

        timings[method_name][dataset] = time() - t_init

View File

@ -6,6 +6,18 @@ from sklearn.linear_model import LogisticRegression
from quapy.model_selection import GridSearchQ
from quapy.evaluation import evaluation_report

+"""
+This example shows a complete experiment using the IFCB Plankton dataset;
+see https://hlt-isti.github.io/QuaPy/manuals/datasets.html#ifcb-plankton-dataset
+Note that this dataset can be downloaded in two modes: for model selection or for evaluation.
+See also:
+Automatic plankton quantification using deep features
+P González, A Castaño, EE Peacock, J Díez, JJ Del Coz, HM Sosik
+Journal of Plankton Research 41 (4), 449-463
+"""

print('Quantifying the IFCB dataset with PACC\n')
@ -30,7 +42,7 @@ mod_sel = GridSearchQ(
    n_jobs=-1,
    verbose=True,
    raise_errors=True
-).fit(train)
+).fit(*train.Xy)

print(f'model selection chose hyperparameters: {mod_sel.best_params_}')
quantifier = mod_sel.best_model_
@ -42,7 +54,7 @@ print(f'\ttraining size={len(train)}, features={train.X.shape[1]}, classes={trai
print(f'\ttest samples={test_gen.total()}')

print('training on the whole dataset before test')
-quantifier.fit(train)
+quantifier.fit(*train.Xy)

print('testing...')
report = evaluation_report(quantifier, protocol=test_gen, error_metrics=['mae'], verbose=True)

View File

@ -11,13 +11,5 @@ rm $FILE
patch -s -p0 < svm-perf-quantification-ext.patch
mv svm_perf svm_perf_quantification
cd svm_perf_quantification
-make
+make CFLAGS="-O3 -Wall -Wno-unused-result -fcommon"  # -fcommon: needed with GCC >= 10, which defaults to -fno-common

View File

@ -1,5 +1,4 @@
"""QuaPy module for quantification""" """QuaPy module for quantification"""
from sklearn.linear_model import LogisticRegression
from quapy.data import datasets from quapy.data import datasets
from . import error from . import error
@ -16,6 +15,12 @@ import os
__version__ = '0.2.0'

+def _default_cls():
+    # factory building a fresh instance of the library-wide default classifier
+    from sklearn.linear_model import LogisticRegression
+    return LogisticRegression()

environ = {
    'SAMPLE_SIZE': None,
    'UNK_TOKEN': '[UNK]',
@ -24,7 +29,7 @@ environ = {
    'PAD_INDEX': 1,
    'SVMPERF_HOME': './svm_perf_quantification',
    'N_JOBS': int(os.getenv('N_JOBS', 1)),
-   'DEFAULT_CLS': LogisticRegression()
+   'DEFAULT_CLS': _default_cls()
}
@ -68,3 +73,5 @@ def _get_classifier(classifier):
    if classifier is None:
        raise ValueError('neither classifier nor qp.environ["DEFAULT_CLS"] have been specified')
    return classifier
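Assuming _get_classifier falls back on qp.environ['DEFAULT_CLS'] when classifier=None (raising only when both are unspecified, as the message above suggests), the factory keeps the default swappable. A sketch of the intended use:

```python
import quapy as qp
from sklearn.ensemble import RandomForestClassifier
from quapy.method.aggregative import PACC

# override the library-wide default classifier...
qp.environ['DEFAULT_CLS'] = RandomForestClassifier()

# ...so that quantifiers instantiated without an explicit classifier pick it up
pacc = PACC(classifier=None)
print(type(pacc.classifier).__name__)  # RandomForestClassifier
```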

View File

@ -33,27 +33,16 @@ class SVMperf(BaseEstimator, ClassifierMixin):
    valid_losses = {'01':0, 'f1':1, 'kld':12, 'nkld':13, 'q':22, 'qacc':23, 'qf1':24, 'qgm':25, 'mae':26, 'mrae':27}

    def __init__(self, svmperf_base, C=0.01, verbose=False, loss='01', host_folder=None):
-       assert exists(svmperf_base), f'path {svmperf_base} does not seem to point to a valid path'
+       assert exists(svmperf_base), \
+           (f'path {svmperf_base} does not seem to point to a valid path; '
+            f'did you install svm-perf? '
+            f'see instructions in https://hlt-isti.github.io/QuaPy/manuals/explicit-loss-minimization.html')
        self.svmperf_base = svmperf_base
        self.C = C
        self.verbose = verbose
        self.loss = loss
        self.host_folder = host_folder
-   # def set_params(self, **parameters):
-   #     """
-   #     Set the hyper-parameters for svm-perf. Currently, only the `C` and `loss` parameters are supported
-   #
-   #     :param parameters: a `**kwargs` dictionary `{'C': <float>}`
-   #     """
-   #     assert sorted(list(parameters.keys())) == ['C', 'loss'], \
-   #         'currently, only the C and loss parameters are supported'
-   #     self.C = parameters.get('C', self.C)
-   #     self.loss = parameters.get('loss', self.loss)
-   #
-   # def get_params(self, deep=True):
-   #     return {'C': self.C, 'loss': self.loss}
    def fit(self, X, y):
        """
        Trains the SVM for the multivariate performance loss

View File

@ -1,3 +1,7 @@
+import warnings
+from sklearn.exceptions import ConvergenceWarning
+warnings.simplefilter("ignore", ConvergenceWarning)

from . import confidence
from . import base
from . import aggregative
@ -63,3 +67,5 @@ QUANTIFICATION_METHODS = AGGREGATIVE_METHODS | NON_AGGREGATIVE_METHODS | META_ME
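The TODO at the top of this commit flags this module-wide filter as provisional. A more targeted alternative (a sketch, not what the commit does) would scope the suppression to the individual fit call:

```python
import warnings
from sklearn.exceptions import ConvergenceWarning

def fit_quietly(estimator, X, y):
    # suppress ConvergenceWarning only around this call, rather than module-wide
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', ConvergenceWarning)
        return estimator.fit(X, y)
```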

View File

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
+from argparse import ArgumentError
from copy import deepcopy
from typing import Callable, Literal, Union
import numpy as np
@ -19,6 +20,10 @@ from quapy.data import LabelledCollection
from quapy.method.base import BaseQuantifier, BinaryQuantifier, OneVsAllGeneric
from quapy.method import _bayesian

+# import warnings
+# from sklearn.exceptions import ConvergenceWarning
+# warnings.filterwarnings("ignore", category=ConvergenceWarning)

# Abstract classes
# ------------------------------------
@ -51,7 +56,11 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
        the training data be wasted.
        """

-   def __init__(self, classifier: Union[None,BaseEstimator], fit_classifier:bool=True, val_split:Union[int,float,tuple,None]=5):
+   def __init__(self,
+                classifier: Union[None,BaseEstimator],
+                fit_classifier:bool=True,
+                val_split:Union[int,float,tuple,None]=5):
        self.classifier = qp._get_classifier(classifier)
        self.fit_classifier = fit_classifier
        self.val_split = val_split
@ -63,6 +72,7 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
        assert isinstance(fit_classifier, bool), \
            f'unexpected type for {fit_classifier=}; must be True or False'

+       # val_split is indicated as a number of folds for cross-validation
        if isinstance(val_split, int):
            assert val_split > 1, \
                (f'when {val_split=} is indicated as an integer, it represents the number of folds in a kFCV '
@ -75,12 +85,14 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
            if val_split!=5:
                assert fit_classifier, (f'Parameter {val_split=} has been modified, but {fit_classifier=} '
                                        f'indicates the classifier should not be retrained.')

+       # val_split is indicated as a fraction of validation instances
        elif isinstance(val_split, float):
            assert 0 < val_split < 1, \
                (f'when {val_split=} is indicated as a float, it represents the fraction of training instances '
                 f'to be used for validation, and must thus be in the range (0,1)')
            assert fit_classifier, (f'when {val_split=} is indicated as a float (the fraction of training instances '
                                    f'to be used for validation), the parameter {fit_classifier=} must be True')

+       # val_split is indicated as a validation collection (X,y)
        elif isinstance(val_split, tuple):
            assert len(val_split) == 2, \
                (f'when {val_split=} is indicated as a tuple, it represents the collection (X,y) on which the '
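To summarize the three accepted forms of val_split documented by the assertions above (values are illustrative):

```python
from sklearn.linear_model import LogisticRegression
from quapy.method.aggregative import PACC

# (i) an integer > 1: number of folds for cross-validation over the training set
pacc_kfcv = PACC(LogisticRegression(), val_split=5)

# (ii) a float in (0, 1): fraction of training instances held out for validation
pacc_heldout = PACC(LogisticRegression(), val_split=0.3)

# (iii) a tuple (X, y): an explicit validation set; useful together with
# fit_classifier=False when the classifier comes already trained, e.g.:
# pacc_pretrained = PACC(trained_clf, fit_classifier=False, val_split=(Xval, yval))
```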
@ -674,26 +686,26 @@ class EMQ(AggregativeSoftQuantifier):
    :param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
        the one indicated in `qp.environ['DEFAULT_CLS']`
-   :param fit_classifier: whether to train the learner (default is True). Set to False if the
-       learner has been trained outside the quantifier.
-   :param val_split: specifies the data used for generating classifier predictions. This specification
-       can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
-       be extracted from the training set; or as an integer (default 5), indicating that the predictions
-       are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
-       for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
-       This hyperparameter is only meant to be used when the heuristics are to be applied, i.e., if a
-       calibration is required. The default value is None (meaning the calibration is not required). In
-       case this hyperparameter is set to a value other than None, but the calibration is not required
-       (calib=None), a warning message will be raised.
-   :param exact_train_prev: set to True (default) for using the true training prevalence as the initial observation;
-       set to False for computing the training prevalence as an estimate of it, i.e., as the expected
-       value of the posterior probabilities of the training instances.
+   :param fit_classifier: whether to train the classifier (default is True). Set to False if the
+       given classifier has already been trained.
+   :param val_split: specifies the data used for generating the classifier predictions on which the
+       aggregation function is to be trained. This specification can be made as float in (0, 1) indicating
+       the proportion of stratified held-out validation set to be extracted from the training set; or as
+       an integer (default 5), indicating that the predictions are to be generated in a `k`-fold
+       cross-validation manner (with this integer indicating the value for `k`); or as a tuple (X,y) defining
+       the specific set of data to use for validation. This hyperparameter is only meant to be used when
+       the heuristics are to be applied, i.e., if a calibration is required. The default value is None
+       (meaning the calibration is not required). In case this hyperparameter is set to a value other than
+       None, but the calibration is not required (calib=None), a warning message will be raised.
+   :param exact_train_prev: set to True (default) for using the true training prevalence as the initial
+       observation; set to False for computing the training prevalence as an estimate of it, i.e., as the
+       expected value of the posterior probabilities of the training instances.
    :param calib: a string indicating the method of calibration.
-       Available choices include "nbvs" (No-Bias Vector Scaling), "bcts" (Bias-Corrected Temperature Scaling,
-       default), "ts" (Temperature Scaling), and "vs" (Vector Scaling). Default is None (no calibration).
+       Available choices include "nbvs" (No-Bias Vector Scaling), "bcts" (Bias-Corrected Temperature Scaling),
+       "ts" (Temperature Scaling), and "vs" (Vector Scaling). Default is None (no calibration).
    :param on_calib_error: a string indicating the policy to follow in case the calibrator fails at runtime.
        Options include "raise" (default), in which case a RuntimeException is raised; and "backup", in which
@ -823,6 +835,19 @@ class EMQ(AggregativeSoftQuantifier):
""" """
P = classif_predictions P = classif_predictions
y = labels y = labels
requires_predictions = (self.calib is not None) or (not self.exact_train_prev)
if P is None and requires_predictions:
# classifier predictions were not generated because val_split=None
raise ArgumentError(self.val_split, self.__class__.__name__ +
": Classifier predictions for the aggregative fit were not generated because "
"val_split=None. This usually happens when you enable calibrations or heuristics "
"during model selection but left val_split set to its default value (None). "
"Please provide one of the following values for val_split: (i) an integer >1 "
"(e.g. val_split=5) for k-fold cross-validation; (ii) a float in (0,1) (e.g. "
"val_split=0.3) for a proportion split; or (iii) a tuple (X, y) with explicit "
"validation data")
if self.calib is not None: if self.calib is not None:
calibrator = { calibrator = {
'nbvs': NoBiasVectorScaling(), 'nbvs': NoBiasVectorScaling(),
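A sketch of the misconfiguration the new check reports, and of the configurations that avoid it (assuming val_split=None is accepted at construction time and only rejected when fitting):

```python
from sklearn.linear_model import LogisticRegression
from quapy.method.aggregative import EMQ

# requests calibration but provides no way of generating held-out classifier
# predictions: fit would now fail fast with the explanatory message above
emq_bad = EMQ(classifier=LogisticRegression(), val_split=None, calib='bcts')

# any of the following avoids the error:
emq_kfcv = EMQ(LogisticRegression(), val_split=5, calib='bcts')    # k-fold CV predictions
emq_frac = EMQ(LogisticRegression(), val_split=0.3, calib='bcts')  # held-out fraction
```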

View File

@ -86,14 +86,14 @@ class GridSearchQ(BaseQuantifier):
        self.n_jobs = qp._get_njobs(n_jobs)
        self.raise_errors = raise_errors
        self.verbose = verbose
-       self.__check_error(error)
+       self.__check_error_measure(error)
        assert isinstance(protocol, AbstractProtocol), 'unknown protocol'

    def _sout(self, msg):
        if self.verbose:
            print(f'[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}')

-   def __check_error(self, error):
+   def __check_error_measure(self, error):
        if error in qp.error.QUANTIFICATION_ERROR:
            self.error = error
        elif isinstance(error, str):