forked from moreo/QuaPy
renaming functions to match the app and npp nomenclature; adding npp as an option for GridSearchQ
This commit is contained in:
parent f28a84242f
commit be2f54de9c

TODO.txt | 11
@@ -4,6 +4,15 @@ Documentation with sphinx
 Document methods with paper references
 unit-tests
 
+Refactor:
+==========================================
+Unify ThresholdOptimization methods, as an extension of PACC (and not ACC), the fit methods are almost identical and
+use a prob classifier (take into account that PACC uses pcc internally, whereas the threshold methods use cc
+instead). The fit method of ACC and PACC has a block for estimating the validation estimates that should be unified
+as well...
+Rename APP NPP
+Add NPP as an option for GridSearchQ
+
 New features:
 ==========================================
 Add NAE, NRAE
@@ -21,6 +30,7 @@ Add automatic reindex of class labels in LabelledCollection (currently, class in
 OVR I believe is currently tied to aggregative methods. We should provide a general interface also for general quantifiers
 Currently, being "binary" only adds one checker; we should figure out how to impose the check to be automatically performed
 Add random seed management to support replicability (see temp_seed in util.py).
+GridSearchQ is not trully parallelized. It only parallelizes on the predictions.
 
 Improvements:
 ==========================================
@@ -34,6 +44,7 @@ We might want to think of (improving and) adding the class Tabular (it is define
 to generate tables is typically a bad idea, but in this specific case we do have pretty good control of what an
 experiment looks like. (Do we want to abstract experimental results? this could be useful not only for tables but
 also for plots).
+Add proper logging system. Currently we use print
 
 Checks:
 ==========================================
@@ -88,12 +88,12 @@ class LabelledCollection:
 
         return indexes_sample
 
-    # def uniform_sampling_index(self, size):
-    #     return np.random.choice(len(self), size, replace=False)
+    def uniform_sampling_index(self, size):
+        return np.random.choice(len(self), size, replace=False)
 
-    # def uniform_sampling(self, size):
-    #     unif_index = self.uniform_sampling_index(size)
-    #     return self.sampling_from_index(unif_index)
+    def uniform_sampling(self, size):
+        unif_index = self.uniform_sampling_index(size)
+        return self.sampling_from_index(unif_index)
 
     def sampling(self, size, *prevs, shuffle=True):
         prev_index = self.sampling_index(size, *prevs, shuffle=shuffle)
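The two sampling helpers above were previously commented out and are re-enabled by this commit. A minimal usage sketch (not part of the diff; it assumes `data` is an already loaded LabelledCollection):

    sample = data.uniform_sampling(500)          # a LabelledCollection of 500 instances drawn uniformly
    idx = data.uniform_sampling_index(500)       # or obtain only the indexes...
    same_sample = data.sampling_from_index(idx)  # ...and materialize the sample later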
@@ -1,7 +1,6 @@
 from typing import Union, Callable, Iterable
 
 import numpy as np
-from joblib import Parallel, delayed
 from tqdm import tqdm
 
 import quapy as qp
@@ -12,8 +11,7 @@ import quapy.functional as F
 import pandas as pd
 
 
-
-def artificial_sampling_prediction(
+def artificial_prevalence_prediction(
         model: BaseQuantifier,
         test: LabelledCollection,
         sample_size,
@@ -51,7 +49,7 @@ def artificial_sampling_prediction(
     return _predict_from_indexes(indexes, model, test, n_jobs, verbose)
 
 
-def natural_sampling_prediction(
+def natural_prevalence_prediction(
         model: BaseQuantifier,
         test: LabelledCollection,
         sample_size,
@@ -117,7 +115,7 @@ def _predict_from_indexes(
     return true_prevalences, estim_prevalences
 
 
-def artificial_sampling_report(
+def artificial_prevalence_report(
         model: BaseQuantifier,
         test: LabelledCollection,
         sample_size,
@@ -129,13 +127,13 @@ def artificial_sampling_report(
         error_metrics:Iterable[Union[str,Callable]]='mae',
         verbose=False):
 
-    true_prevs, estim_prevs = artificial_sampling_prediction(
+    true_prevs, estim_prevs = artificial_prevalence_prediction(
         model, test, sample_size, n_prevpoints, n_repetitions, eval_budget, n_jobs, random_seed, verbose
     )
-    return _sampling_report(true_prevs, estim_prevs, error_metrics)
+    return _prevalence_report(true_prevs, estim_prevs, error_metrics)
 
 
-def natural_sampling_report(
+def natural_prevalence_report(
         model: BaseQuantifier,
         test: LabelledCollection,
         sample_size,
@@ -145,13 +143,13 @@ def natural_sampling_report(
         error_metrics:Iterable[Union[str,Callable]]='mae',
         verbose=False):
 
-    true_prevs, estim_prevs = natural_sampling_prediction(
+    true_prevs, estim_prevs = natural_prevalence_prediction(
         model, test, sample_size, n_repetitions, n_jobs, random_seed, verbose
     )
-    return _sampling_report(true_prevs, estim_prevs, error_metrics)
+    return _prevalence_report(true_prevs, estim_prevs, error_metrics)
 
 
-def _sampling_report(
+def _prevalence_report(
         true_prevs,
         estim_prevs,
         error_metrics: Iterable[Union[str, Callable]] = 'mae'):
@@ -173,7 +171,8 @@ def _sampling_report(
 
     return df
 
-def artificial_sampling_eval(
+
+def artificial_prevalence_protocol(
         model: BaseQuantifier,
         test: LabelledCollection,
         sample_size,
@@ -190,14 +189,14 @@ def artificial_sampling_eval(
 
     assert hasattr(error_metric, '__call__'), 'invalid error function'
 
-    true_prevs, estim_prevs = artificial_sampling_prediction(
+    true_prevs, estim_prevs = artificial_prevalence_prediction(
         model, test, sample_size, n_prevpoints, n_repetitions, eval_budget, n_jobs, random_seed, verbose
     )
 
     return error_metric(true_prevs, estim_prevs)
 
 
-def natural_sampling_eval(
+def natural_prevalence_protocol(
         model: BaseQuantifier,
         test: LabelledCollection,
         sample_size,
@@ -212,7 +211,7 @@ def natural_sampling_eval(
 
     assert hasattr(error_metric, '__call__'), 'invalid error function'
 
-    true_prevs, estim_prevs = natural_sampling_prediction(
+    true_prevs, estim_prevs = natural_prevalence_prediction(
         model, test, sample_size, n_repetitions, n_jobs, random_seed, verbose
     )
 
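Taken together, these hunks rename artificial_sampling_* to artificial_prevalence_* and natural_sampling_* to natural_prevalence_* (with *_eval becoming *_protocol). A hedged sketch of the renamed API in use (not from the commit; `model` is assumed to be a fitted quantifier, `test` a LabelledCollection, and the keyword values are illustrative):

    true_prevs, estim_prevs = artificial_prevalence_prediction(
        model, test, sample_size=500, n_prevpoints=21, n_repetitions=10)
    report = natural_prevalence_report(
        model, test, sample_size=500, n_repetitions=100, error_metrics=['mae', 'mrae'])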
@@ -513,6 +513,170 @@ class SVMRAE(ELM):
         super(SVMRAE, self).__init__(svmperf_base, loss='mrae', **kwargs)
 
 
+class ThresholdOptimization(AggregativeQuantifier, BinaryQuantifier):
+
+    def __init__(self, learner: BaseEstimator, val_split=0.4):
+        self.learner = learner
+        self.val_split = val_split
+
+    @abstractmethod
+    def optimize_threshold(self, y, probabilities):
+        ...
+
+    def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, int, LabelledCollection] = None):
+        BinaryQuantifier._check_binary(data, "Threshold Optimization")
+
+        if val_split is None:
+            val_split = self.val_split
+        if isinstance(val_split, int):
+            # kFCV estimation of parameters
+            y, probabilities = [], []
+            kfcv = StratifiedKFold(n_splits=val_split)
+            pbar = tqdm(kfcv.split(*data.Xy), total=val_split)
+            for k, (training_idx, validation_idx) in enumerate(pbar):
+                pbar.set_description(f'{self.__class__.__name__} fitting fold {k}')
+                training = data.sampling_from_index(training_idx)
+                validation = data.sampling_from_index(validation_idx)
+                learner, val_data = training_helper(self.learner, training, fit_learner, val_split=validation)
+                probabilities.append(learner.predict_proba(val_data.instances))
+                y.append(val_data.labels)
+
+            y = np.concatenate(y)
+            probabilities = np.concatenate(probabilities)
+
+            # fit the learner on all data
+            self.learner, _ = training_helper(self.learner, data, fit_learner, val_split=None)
+
+        else:
+            self.learner, val_data = training_helper(self.learner, data, fit_learner, val_split=val_split)
+            probabilities = self.learner.predict_proba(val_data.instances)
+            y = val_data.labels
+
+        self.cc = CC(self.learner)
+
+        self.tpr, self.fpr = self.optimize_threshold(y, probabilities)
+
+        return self
+
+    @abstractmethod
+    def _condition(self, tpr, fpr) -> float:
+        """
+        Implements the criterion according to which the threshold should be selected.
+        This function should return a (float) score to be minimized.
+        """
+        ...
+
+    def optimize_threshold(self, y, probabilities):
+        best_candidate_threshold_score = None
+        best_tpr = 0
+        best_fpr = 0
+        candidate_thresholds = np.unique(probabilities[:, 1])
+        for candidate_threshold in candidate_thresholds:
+            y_ = [self.classes_[1] if p > candidate_threshold else self.classes_[0] for p in probabilities[:, 1]]
+            TP, FP, FN, TN = self.compute_table(y, y_)
+            tpr = self.compute_tpr(TP, FP)
+            fpr = self.compute_fpr(FP, TN)
+            condition_score = self._condition(tpr, fpr)
+            if best_candidate_threshold_score is None or condition_score < best_candidate_threshold_score:
+                best_candidate_threshold_score = condition_score
+                best_tpr = tpr
+                best_fpr = fpr
+
+        return best_tpr, best_fpr
+
+    def aggregate(self, classif_predictions):
+        prevs_estim = self.cc.aggregate(classif_predictions)
+        if self.tpr - self.fpr == 0:
+            return prevs_estim
+        adjusted_prevs_estim = np.clip((prevs_estim[1] - self.fpr) / (self.tpr - self.fpr), 0, 1)
+        adjusted_prevs_estim = np.array((1 - adjusted_prevs_estim, adjusted_prevs_estim))
+        return adjusted_prevs_estim
+
+    def compute_table(self, y, y_):
+        TP = np.logical_and(y == y_, y == self.classes_[1]).sum()
+        FP = np.logical_and(y != y_, y == self.classes_[0]).sum()
+        FN = np.logical_and(y != y_, y == self.classes_[1]).sum()
+        TN = np.logical_and(y == y_, y == self.classes_[0]).sum()
+        return TP, FP, FN, TN
+
+    def compute_tpr(self, TP, FP):
+        if TP + FP == 0:
+            return 0
+        return TP / (TP + FP)
+
+    def compute_fpr(self, FP, TN):
+        if FP + TN == 0:
+            return 0
+        return FP / (FP + TN)
+
+
+class T50(ThresholdOptimization):
+
+    def __init__(self, learner: BaseEstimator, val_split=0.4):
+        super().__init__(learner, val_split)
+
+    def _condition(self, tpr, fpr) -> float:
+        return abs(tpr - 0.5)
+
+
+class MAX(ThresholdOptimization):
+
+    def __init__(self, learner: BaseEstimator, val_split=0.4):
+        super().__init__(learner, val_split)
+
+    def _condition(self, tpr, fpr) -> float:
+        # MAX strives to maximize (tpr - fpr), which is equivalent to minimize (fpr - tpr)
+        return (fpr - tpr)
+
+
+class X(ThresholdOptimization):
+
+    def __init__(self, learner: BaseEstimator, val_split=0.4):
+        super().__init__(learner, val_split)
+
+    def _condition(self, tpr, fpr) -> float:
+        return abs(1 - (tpr + fpr))
+
+
+class MS(ThresholdOptimization):
+
+    def __init__(self, learner: BaseEstimator, val_split=0.4):
+        super().__init__(learner, val_split)
+
+    def optimize_threshold(self, y, probabilities):
+        tprs = []
+        fprs = []
+        candidate_thresholds = np.unique(probabilities[:, 1])
+        for candidate_threshold in candidate_thresholds:
+            y_ = [self.classes_[1] if p > candidate_threshold else self.classes_[0] for p in probabilities[:, 1]]
+            TP, FP, FN, TN = self.compute_table(y, y_)
+            tpr = self.compute_tpr(TP, FP)
+            fpr = self.compute_fpr(FP, TN)
+            tprs.append(tpr)
+            fprs.append(fpr)
+        return np.median(tprs), np.median(fprs)
+
+
+class MS2(MS):
+
+    def __init__(self, learner: BaseEstimator, val_split=0.4):
+        super().__init__(learner, val_split)
+
+    def optimize_threshold(self, y, probabilities):
+        tprs = [0, 1]
+        fprs = [0, 1]
+        candidate_thresholds = np.unique(probabilities[:, 1])
+        for candidate_threshold in candidate_thresholds:
+            y_ = [self.classes_[1] if p > candidate_threshold else self.classes_[0] for p in probabilities[:, 1]]
+            TP, FP, FN, TN = self.compute_table(y, y_)
+            tpr = self.compute_tpr(TP, FP)
+            fpr = self.compute_fpr(FP, TN)
+            if (tpr - fpr) > 0.25:
+                tprs.append(tpr)
+                fprs.append(fpr)
+        return np.median(tprs), np.median(fprs)
+
+
 ClassifyAndCount = CC
 AdjustedClassifyAndCount = ACC
 ProbabilisticClassifyAndCount = PCC
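The classes added above implement Forman-style threshold-selection policies (T50, MAX, X, median sweep and its MS2 variant) on top of classify-and-count: each picks a decision threshold on the validation posteriors, estimates tpr and fpr at that threshold, and aggregate() then corrects the CC estimate as (cc_prevalence - fpr) / (tpr - fpr), clipped to [0, 1]. A minimal usage sketch (not part of the diff; it assumes binary LabelledCollections `train` and `test` and a scikit-learn probabilistic classifier):

    from sklearn.linear_model import LogisticRegression

    quantifier = MS(LogisticRegression(max_iter=1000), val_split=0.4)   # median sweep
    quantifier.fit(train)                          # fits the classifier and estimates tpr/fpr on the validation split
    estim_prev = quantifier.quantify(test.instances)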
@@ -520,6 +684,8 @@ ProbabilisticAdjustedClassifyAndCount = PACC
 ExpectationMaximizationQuantifier = EMQ
 HellingerDistanceY = HDy
 ExplicitLossMinimisation = ELM
+MedianSweep = MS
+MedianSweep2 = MS2
 
 
 class OneVsAll(AggregativeQuantifier):
@@ -5,7 +5,7 @@ from typing import Union, Callable
 
 import quapy as qp
 from quapy.data.base import LabelledCollection
-from quapy.evaluation import artificial_sampling_prediction
+from quapy.evaluation import artificial_prevalence_prediction, natural_prevalence_prediction
 from quapy.method.aggregative import BaseQuantifier
 
@@ -15,6 +15,7 @@ class GridSearchQ(BaseQuantifier):
                  model: BaseQuantifier,
                  param_grid: dict,
                  sample_size: int,
+                 protocol='app',
                  n_prevpoints: int = None,
                  n_repetitions: int = 1,
                  eval_budget: int = None,
@@ -29,15 +30,15 @@ class GridSearchQ(BaseQuantifier):
         Optimizes the hyperparameters of a quantification method, based on an evaluation method and on an evaluation
         protocol for quantification.
         :param model: the quantifier to optimize
-        :param training: the training set on which to optimize the hyperparameters
-        :param validation: either a LabelledCollection on which to test the performance of the different settings, or
-        a float in [0,1] indicating the proportion of labelled data to extract from the training set
         :param param_grid: a dictionary with keys the parameter names and values the list of values to explore for
-        that particular parameter
         :param sample_size: the size of the samples to extract from the validation set
+        that particular parameter
+        :param protocol: either 'app' for the artificial prevalence protocol, or 'npp' for the natural prevalence
+        protocol
         :param n_prevpoints: if specified, indicates the number of equally distant point to extract from the interval
         [0,1] in order to define the prevalences of the samples; e.g., if n_prevpoints=5, then the prevalences for
-        each class will be explored in [0.00, 0.25, 0.50, 0.75, 1.00]. If not specified, then eval_budget is requested
+        each class will be explored in [0.00, 0.25, 0.50, 0.75, 1.00]. If not specified, then eval_budget is requested.
+        Ignored if protocol='npp'.
         :param n_repetitions: the number of repetitions for each combination of prevalences. This parameter is ignored
         if eval_budget is set and is lower than the number of combinations that would be generated using the value
        assigned to n_prevpoints (for the current number of classes and n_repetitions)
@@ -45,10 +46,13 @@ class GridSearchQ(BaseQuantifier):
         combination. For example, if there are 3 classes, n_repetitions=1 and eval_budget=20, then n_prevpoints will be
         set to 5, since this will generate 15 different prevalences:
         [0, 0, 1], [0, 0.25, 0.75], [0, 0.5, 0.5] ... [1, 0, 0]
+        Ignored if protocol='npp'.
         :param error: an error function (callable) or a string indicating the name of an error function (valid ones
         are those in qp.error.QUANTIFICATION_ERROR
         :param refit: whether or not to refit the model on the whole labelled collection (training+validation) with
         the best chosen hyperparameter combination
+        :param val_split: either a LabelledCollection on which to test the performance of the different settings, or
+        a float in [0,1] indicating the proportion of labelled data to extract from the training set
         :param n_jobs: number of parallel jobs
         :param random_seed: set the seed of the random generator to replicate experiments
         :param timeout: establishes a timer (in seconds) for each of the hyperparameters configurations being tested.
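As an aside (not part of the commit), the "15 different prevalences" in the docstring example follow from simple counting: with n_prevpoints equally spaced values and c classes, the number of prevalence vectors summing to 1 is comb(n_prevpoints - 1 + c - 1, c - 1). A quick check using only the standard library:

    from itertools import product
    from math import comb

    n_prevpoints, n_classes = 5, 3
    grid = [i / (n_prevpoints - 1) for i in range(n_prevpoints)]     # [0.0, 0.25, 0.5, 0.75, 1.0]
    valid = [p for p in product(grid, repeat=n_classes) if abs(sum(p) - 1.0) < 1e-8]
    print(len(valid))                                                # 15
    print(comb(n_prevpoints - 1 + n_classes - 1, n_classes - 1))     # 15, closed form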
@@ -59,6 +63,7 @@ class GridSearchQ(BaseQuantifier):
         self.model = model
         self.param_grid = param_grid
         self.sample_size = sample_size
+        self.protocol = protocol.lower()
         self.n_prevpoints = n_prevpoints
         self.n_repetitions = n_repetitions
         self.eval_budget = eval_budget
@@ -69,6 +74,19 @@ class GridSearchQ(BaseQuantifier):
         self.timeout = timeout
         self.verbose = verbose
         self.__check_error(error)
+        assert self.protocol in {'app', 'npp'}, \
+            'unknown protocol; valid ones are "app" or "npp" for the "artificial" or the "natural" prevalence protocols'
+        if self.protocol == 'npp':
+            if self.n_repetitions is None or self.n_repetitions == 1:
+                if self.eval_budget is not None:
+                    print(f'[warning] when protocol=="npp" the parameter n_repetitions should be indicated '
+                          f'(and not eval_budget). Setting n_repetitions={self.eval_budget}...')
+                    self.n_repetitions = self.eval_budget
+                else:
+                    raise ValueError(f'when protocol=="npp" the parameter n_repetitions should be indicated '
+                                     f'(and should be >1).')
+            if self.n_prevpoints is not None:
+                print('[warning] n_prevpoints has been set along with the npp protocol, and will be ignored')
 
     def sout(self, msg):
         if self.verbose:
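A usage sketch of the new option (illustrative only, not from the commit): with protocol='npp' the grid search evaluates each configuration on n_repetitions samples drawn at the validation split's natural prevalence, so n_prevpoints and eval_budget play no role. The wrapped quantifier and the param_grid key below are assumptions (a PACC quantifier around a scikit-learn LogisticRegression whose C is explored):

    from sklearn.linear_model import LogisticRegression
    import quapy as qp

    model = qp.method.aggregative.PACC(LogisticRegression(max_iter=1000))
    grid = GridSearchQ(
        model,
        param_grid={'C': [0.1, 1.0, 10.0]},   # forwarded to the underlying classifier via set_params
        sample_size=500,
        protocol='npp',                       # natural prevalence protocol
        n_repetitions=100,                    # required (>1) when protocol='npp'
        error='mae',
        refit=True,
        verbose=True)
    grid.fit(training, val_split=0.4)         # training: a LabelledCollection
    estim_prev = grid.quantify(test.instances)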
@@ -83,7 +101,7 @@ class GridSearchQ(BaseQuantifier):
             return training, validation
         else:
             raise ValueError(f'"validation" must either be a LabelledCollection or a float in (0,1) indicating the'
-                             f'proportion of training documents to extract (found) {type(validation)}')
+                             f'proportion of training documents to extract (type found: {type(validation)})')
 
     def __check_error(self, error):
         if error in qp.error.QUANTIFICATION_ERROR:
@@ -96,6 +114,27 @@ class GridSearchQ(BaseQuantifier):
             raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
                              f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')
 
+    def __generate_predictions(self, model, val_split):
+        commons = {
+            'n_repetitions': self.n_repetitions,
+            'n_jobs': self.n_jobs,
+            'random_seed': self.random_seed,
+            'verbose': False
+        }
+        if self.protocol == 'app':
+            return artificial_prevalence_prediction(
+                model, val_split, self.sample_size,
+                n_prevpoints=self.n_prevpoints,
+                eval_budget=self.eval_budget,
+                **commons
+            )
+        elif self.protocol == 'npp':
+            return natural_prevalence_prediction(
+                model, val_split, self.sample_size,
+                **commons)
+        else:
+            raise ValueError('unknown protocol')
+
     def fit(self, training: LabelledCollection, val_split: Union[LabelledCollection, float] = None):
         """
         :param training: the training set on which to optimize the hyperparameters
@@ -134,16 +173,7 @@ class GridSearchQ(BaseQuantifier):
                 # overrides default parameters with the parameters being explored at this iteration
                 model.set_params(**params)
                 model.fit(training)
-                true_prevalences, estim_prevalences = artificial_sampling_prediction(
-                    model, val_split, self.sample_size,
-                    n_prevpoints=self.n_prevpoints,
-                    n_repetitions=self.n_repetitions,
-                    eval_budget=self.eval_budget,
-                    n_jobs=n_jobs,
-                    random_seed=self.random_seed,
-                    verbose=False
-                )
+                true_prevalences, estim_prevalences = self.__generate_predictions(model, val_split)
 
                 score = self.error(true_prevalences, estim_prevalences)
                 self.sout(f'checking hyperparams={params} got {self.error.__name__} score {score:.5f}')
                 if self.best_score_ is None or score < self.best_score_:
@@ -173,6 +203,7 @@ class GridSearchQ(BaseQuantifier):
         return self
 
     def quantify(self, instances):
+        assert hasattr(self, 'best_model_'), 'quantify called before fit'
         return self.best_model_.quantify(instances)
 
     @property