fix added for len of a LabelledCollection
parent 34c60e0870
commit 51c3d54aa5
quapy
@@ -1,17 +1,17 @@
 import itertools
 import signal
 from copy import deepcopy
-from typing import Union, Callable
+from time import time
+from typing import Callable, Union
 
 import numpy as np
 from sklearn import clone
 
 import quapy as qp
 from quapy import evaluation
-from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
 from quapy.data.base import LabelledCollection
 from quapy.method.aggregative import BaseQuantifier
-from time import time
+from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
 
 
 class GridSearchQ(BaseQuantifier):
@@ -34,7 +34,8 @@ class GridSearchQ(BaseQuantifier):
     :param verbose: set to True to get information through the stdout
     """
 
-    def __init__(self,
+    def __init__(
+        self,
                  model: BaseQuantifier,
                  param_grid: dict,
                  protocol: AbstractProtocol,
@@ -42,8 +43,8 @@ class GridSearchQ(BaseQuantifier):
                  refit=True,
                  timeout=-1,
                  n_jobs=None,
-                 verbose=False):
+                 verbose=False,
+    ):
         self.model = model
         self.param_grid = param_grid
         self.protocol = protocol
@@ -52,25 +53,27 @@ class GridSearchQ(BaseQuantifier):
         self.n_jobs = qp._get_njobs(n_jobs)
         self.verbose = verbose
         self.__check_error(error)
-        assert isinstance(protocol, AbstractProtocol), 'unknown protocol'
+        assert isinstance(protocol, AbstractProtocol), "unknown protocol"
 
     def _sout(self, msg):
         if self.verbose:
-            print(f'[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}')
+            print(f"[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}")
 
     def __check_error(self, error):
         if error in qp.error.QUANTIFICATION_ERROR:
             self.error = error
         elif isinstance(error, str):
             self.error = qp.error.from_name(error)
-        elif hasattr(error, '__call__'):
+        elif hasattr(error, "__call__"):
             self.error = error
         else:
-            raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
-                             f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')
+            raise ValueError(
+                f"unexpected error type; must either be a callable function or a str representing\n"
+                f"the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}"
+            )
 
     def fit(self, training: LabelledCollection):
-        """ Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
+        """Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
         the error metric.
 
         :param training: the training set on which to optimize the hyperparameters
@@ -86,14 +89,17 @@ class GridSearchQ(BaseQuantifier):
 
         tinit = time()
 
-        hyper = [dict({k: val[i] for i, k in enumerate(params_keys)}) for val in itertools.product(*params_values)]
-        self._sout(f'starting model selection with {self.n_jobs =}')
-        #pass a seed to parallel so it is set in clild processes
+        hyper = [
+            dict({k: val[i] for i, k in enumerate(params_keys)})
+            for val in itertools.product(*params_values)
+        ]
+        self._sout(f"starting model selection with {self.n_jobs =}")
+        # pass a seed to parallel so it is set in clild processes
         scores = qp.util.parallel(
             self._delayed_eval,
             ((params, training) for params in hyper),
-            seed=qp.environ.get('_R_SEED', None),
-            n_jobs=self.n_jobs
+            seed=qp.environ.get("_R_SEED", None),
+            n_jobs=self.n_jobs,
         )
 
         for params, score, model in scores:
@@ -104,23 +110,27 @@ class GridSearchQ(BaseQuantifier):
                     self.best_model_ = model
                 self.param_scores_[str(params)] = score
             else:
-                self.param_scores_[str(params)] = 'timeout'
+                self.param_scores_[str(params)] = "timeout"
 
-        tend = time()-tinit
+        tend = time() - tinit
 
         if self.best_score_ is None:
-            raise TimeoutError('no combination of hyperparameters seem to work')
+            raise TimeoutError("no combination of hyperparameters seem to work")
 
-        self._sout(f'optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) '
-                   f'[took {tend:.4f}s]')
+        self._sout(
+            f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) "
+            f"[took {tend:.4f}s]"
+        )
 
         if self.refit:
             if isinstance(protocol, OnLabelledCollectionProtocol):
-                self._sout(f'refitting on the whole development set')
+                self._sout(f"refitting on the whole development set")
                 self.best_model_.fit(training + protocol.get_labelled_collection())
             else:
-                raise RuntimeWarning(f'"refit" was requested, but the protocol does not '
-                                     f'implement the {OnLabelledCollectionProtocol.__name__} interface')
+                raise RuntimeWarning(
+                    f'"refit" was requested, but the protocol does not '
+                    f"implement the {OnLabelledCollectionProtocol.__name__} interface"
+                )
 
         return self
 
@@ -131,6 +141,7 @@ class GridSearchQ(BaseQuantifier):
         error = self.error
 
         if self.timeout > 0:
+
             def handler(signum, frame):
                 raise TimeoutError()
 
@@ -148,25 +159,26 @@ class GridSearchQ(BaseQuantifier):
             model.fit(training)
             score = evaluation.evaluate(model, protocol=protocol, error_metric=error)
 
-            ttime = time()-tinit
-            self._sout(f'hyperparams={params}\t got {error.__name__} score {score:.5f} [took {ttime:.4f}s]')
+            ttime = time() - tinit
+            self._sout(
+                f"hyperparams={params}\t got {error.__name__} score {score:.5f} [took {ttime:.4f}s]"
+            )
 
             if self.timeout > 0:
                 signal.alarm(0)
         except TimeoutError:
-            self._sout(f'timeout ({self.timeout}s) reached for config {params}')
+            self._sout(f"timeout ({self.timeout}s) reached for config {params}")
             score = None
         except ValueError as e:
-            self._sout(f'the combination of hyperparameters {params} is invalid')
+            self._sout(f"the combination of hyperparameters {params} is invalid")
             raise e
         except Exception as e:
-            self._sout(f'something went wrong for config {params}; skipping:')
-            self._sout(f'\tException: {e}')
+            self._sout(f"something went wrong for config {params}; skipping:")
+            self._sout(f"\tException: {e}")
             score = None
 
         return params, score, model
 
     def quantify(self, instances):
         """Estimate class prevalence values using the best model found after calling the :meth:`fit` method.
 
@@ -174,7 +186,7 @@ class GridSearchQ(BaseQuantifier):
         :return: a ndarray of shape `(n_classes)` with class prevalence estimates as according to the best model found
             by the model selection process.
         """
-        assert hasattr(self, 'best_model_'), 'quantify called before fit'
+        assert hasattr(self, "best_model_"), "quantify called before fit"
         return self.best_model().quantify(instances)
 
     def set_params(self, **parameters):
@@ -199,14 +211,14 @@ class GridSearchQ(BaseQuantifier):
 
         :return: a trained quantifier
         """
-        if hasattr(self, 'best_model_'):
+        if hasattr(self, "best_model_"):
             return self.best_model_
-        raise ValueError('best_model called before fit')
+        raise ValueError("best_model called before fit")
 
 
-def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0):
+def cross_val_predict(
+    quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0
+):
     """
     Akin to `scikit-learn's cross_val_predict <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html>`_
     but for quantification.
@@ -223,9 +235,7 @@ def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfol
     for train, test in data.kFCV(nfolds=nfolds, random_state=random_state):
         quantifier.fit(train)
         fold_prev = quantifier.quantify(test.X)
-        rel_size = len(test.X)/len(data)
-        total_prev += fold_prev*rel_size
+        rel_size = 1.0 * len(test) / len(data)
+        total_prev += fold_prev * rel_size
 
     return total_prev
-
-
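Note on the functional change: besides the quote-style and line-wrapping edits (apparently applied by an autoformatter), the fix named in the commit title is in the last hunk of cross_val_predict, which now weights each fold by len(test), the LabelledCollection itself, instead of len(test.X), its instance matrix. Below is a minimal sketch of why that matters; it assumes that LabelledCollection(instances, labels) is a valid constructor call and that LabelledCollection defines __len__ as the number of instances (the toy data is hypothetical and not part of this commit).

import numpy as np
from scipy.sparse import csr_matrix

from quapy.data.base import LabelledCollection

# hypothetical toy data: 10 instances with 5 features, stored as a sparse matrix
X = csr_matrix(np.random.rand(10, 5))
y = np.random.randint(0, 2, size=10)
test = LabelledCollection(X, y)  # assumed constructor: (instances, labels)

print(len(test))        # 10: number of instances, assuming LabelledCollection defines __len__

try:
    print(len(test.X))  # len() of a scipy sparse matrix raises TypeError (length is ambiguous)
except TypeError as err:
    print(f"len(test.X) failed: {err}")

# the corrected fold weighting from the diff (data is the full LabelledCollection):
# rel_size = 1.0 * len(test) / len(data)

With a dense numpy array the old expression happens to work, since len() of a 2-D array is its number of rows, which is presumably why the problem only surfaces when the instances are stored as a sparse matrix.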