
fix added for cross_val_predict

Lorenzo Volpi 2023-11-06 01:58:36 +01:00
parent 51c3d54aa5
commit 13fe531e12
1 changed file with 46 additions and 56 deletions


@@ -34,8 +34,7 @@ class GridSearchQ(BaseQuantifier):
:param verbose: set to True to get information through the stdout
"""
def __init__(
self,
def __init__(self,
model: BaseQuantifier,
param_grid: dict,
protocol: AbstractProtocol,
@@ -43,8 +42,8 @@ class GridSearchQ(BaseQuantifier):
refit=True,
timeout=-1,
n_jobs=None,
verbose=False,
):
verbose=False):
self.model = model
self.param_grid = param_grid
self.protocol = protocol
@@ -53,24 +52,22 @@ class GridSearchQ(BaseQuantifier):
self.n_jobs = qp._get_njobs(n_jobs)
self.verbose = verbose
self.__check_error(error)
assert isinstance(protocol, AbstractProtocol), "unknown protocol"
assert isinstance(protocol, AbstractProtocol), 'unknown protocol'
def _sout(self, msg):
if self.verbose:
print(f"[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}")
print(f'[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}')
def __check_error(self, error):
if error in qp.error.QUANTIFICATION_ERROR:
self.error = error
elif isinstance(error, str):
self.error = qp.error.from_name(error)
elif hasattr(error, "__call__"):
elif hasattr(error, '__call__'):
self.error = error
else:
raise ValueError(
f"unexpected error type; must either be a callable function or a str representing\n"
f"the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}"
)
raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')
def fit(self, training: LabelledCollection):
""" Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
@@ -89,17 +86,14 @@ class GridSearchQ(BaseQuantifier):
tinit = time()
hyper = [
dict({k: val[i] for i, k in enumerate(params_keys)})
for val in itertools.product(*params_values)
]
self._sout(f"starting model selection with {self.n_jobs =}")
hyper = [dict({k: val[i] for i, k in enumerate(params_keys)}) for val in itertools.product(*params_values)]
self._sout(f'starting model selection with {self.n_jobs =}')
# pass a seed to parallel so it is set in child processes
scores = qp.util.parallel(
self._delayed_eval,
((params, training) for params in hyper),
seed=qp.environ.get("_R_SEED", None),
n_jobs=self.n_jobs,
seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs
)
for params, score, model in scores:
@@ -110,27 +104,23 @@ class GridSearchQ(BaseQuantifier):
self.best_model_ = model
self.param_scores_[str(params)] = score
else:
self.param_scores_[str(params)] = "timeout"
self.param_scores_[str(params)] = 'timeout'
tend = time()-tinit
if self.best_score_ is None:
raise TimeoutError("no combination of hyperparameters seem to work")
raise TimeoutError('no combination of hyperparameters seem to work')
self._sout(
f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) "
f"[took {tend:.4f}s]"
)
self._sout(f'optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) '
f'[took {tend:.4f}s]')
if self.refit:
if isinstance(protocol, OnLabelledCollectionProtocol):
self._sout(f"refitting on the whole development set")
self._sout(f'refitting on the whole development set')
self.best_model_.fit(training + protocol.get_labelled_collection())
else:
raise RuntimeWarning(
f'"refit" was requested, but the protocol does not '
f"implement the {OnLabelledCollectionProtocol.__name__} interface"
)
raise RuntimeWarning(f'"refit" was requested, but the protocol does not '
f'implement the {OnLabelledCollectionProtocol.__name__} interface')
return self
@@ -141,7 +131,6 @@ class GridSearchQ(BaseQuantifier):
error = self.error
if self.timeout > 0:
def handler(signum, frame):
raise TimeoutError()
@@ -160,25 +149,24 @@ class GridSearchQ(BaseQuantifier):
score = evaluation.evaluate(model, protocol=protocol, error_metric=error)
ttime = time()-tinit
self._sout(
f"hyperparams={params}\t got {error.__name__} score {score:.5f} [took {ttime:.4f}s]"
)
self._sout(f'hyperparams={params}\t got {error.__name__} score {score:.5f} [took {ttime:.4f}s]')
if self.timeout > 0:
signal.alarm(0)
except TimeoutError:
self._sout(f"timeout ({self.timeout}s) reached for config {params}")
self._sout(f'timeout ({self.timeout}s) reached for config {params}')
score = None
except ValueError as e:
self._sout(f"the combination of hyperparameters {params} is invalid")
self._sout(f'the combination of hyperparameters {params} is invalid')
raise e
except Exception as e:
self._sout(f"something went wrong for config {params}; skipping:")
self._sout(f"\tException: {e}")
self._sout(f'something went wrong for config {params}; skipping:')
self._sout(f'\tException: {e}')
score = None
return params, score, model
def quantify(self, instances):
"""Estimate class prevalence values using the best model found after calling the :meth:`fit` method.
@@ -186,7 +174,7 @@ class GridSearchQ(BaseQuantifier):
:return: an ndarray of shape `(n_classes)` with class prevalence estimates according to the best model found
by the model selection process.
"""
assert hasattr(self, "best_model_"), "quantify called before fit"
assert hasattr(self, 'best_model_'), 'quantify called before fit'
return self.best_model().quantify(instances)
def set_params(self, **parameters):
@@ -211,14 +199,14 @@ class GridSearchQ(BaseQuantifier):
:return: a trained quantifier
"""
if hasattr(self, "best_model_"):
if hasattr(self, 'best_model_'):
return self.best_model_
raise ValueError("best_model called before fit")
raise ValueError('best_model called before fit')
def cross_val_predict(
quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0
):
def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0):
"""
Akin to `scikit-learn's cross_val_predict <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html>`_
but for quantification.
@@ -235,7 +223,9 @@ def cross_val_predict(
for train, test in data.kFCV(nfolds=nfolds, random_state=random_state):
quantifier.fit(train)
fold_prev = quantifier.quantify(test.X)
rel_size = 1.0 * len(test) / len(data)
rel_size = 1. * len(test) / len(data)
total_prev += fold_prev*rel_size
return total_prev
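
For readers unfamiliar with the aggregation that cross_val_predict performs above, the sketch below re-implements the same fold-weighted prevalence averaging outside the library: a quantifier is fit on each training fold, its prevalence estimate on the held-out fold is weighted by that fold's relative size, and the weighted estimates are summed. The ToyQuantifier class, the synthetic X/y arrays and the use of scikit-learn's StratifiedKFold are illustrative assumptions standing in for the quantifiers, LabelledCollection and kFCV used in the diff; only the accumulation total_prev += fold_prev * rel_size mirrors the code touched by this commit.

# Illustrative sketch; not part of this commit.
import numpy as np
from sklearn.model_selection import StratifiedKFold

class ToyQuantifier:
    """Hypothetical quantifier: estimates class prevalence as the training label frequencies."""
    def fit(self, X, y):
        self.classes_ = np.unique(y)
        self.prev_ = np.array([(y == c).mean() for c in self.classes_])
        return self

    def quantify(self, X):
        # a real quantifier would exploit X; this toy one just returns the training prevalence
        return self.prev_

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))                       # synthetic instances
y = (rng.random(100) > 0.7).astype(int)             # synthetic binary labels (~30% positives)

nfolds = 3
total_prev = np.zeros(2)
skf = StratifiedKFold(n_splits=nfolds, shuffle=True, random_state=0)
for train_idx, test_idx in skf.split(X, y):
    q = ToyQuantifier().fit(X[train_idx], y[train_idx])
    fold_prev = q.quantify(X[test_idx])             # prevalence estimated on the held-out fold
    rel_size = len(test_idx) / len(y)               # relative size of the fold
    total_prev += fold_prev * rel_size              # weighted accumulation, as in cross_val_predict

print(total_prev)                                   # aggregated estimate, shape (n_classes,)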