From a4584b79dbf30517f86effccc6208ed28c36d396 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pablo=20Gonz=C3=A1lez?= Date: Mon, 11 Jul 2022 16:27:02 +0200 Subject: [PATCH] changing gridsearchQ to ensure reproducibility --- quapy/model_selection.py | 3 ++- quapy/util.py | 19 ++++++++++++++----- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/quapy/model_selection.py b/quapy/model_selection.py index d627649..41a7a19 100644 --- a/quapy/model_selection.py +++ b/quapy/model_selection.py @@ -83,7 +83,8 @@ class GridSearchQ(BaseQuantifier): tinit = time() hyper = [dict({k: values[i] for i, k in enumerate(params_keys)}) for values in itertools.product(*params_values)] - scores = qp.util.parallel(self._delayed_eval, ((params, training) for params in hyper), n_jobs=self.n_jobs) + #pass a seed to parallel so it is set in clild processes + scores = qp.util.parallel(self._delayed_eval, ((params, training) for params in hyper), seed=qp.environ.get('_R_SEED', None), n_jobs=self.n_jobs) for params, score, model in scores: if score is not None: diff --git a/quapy/util.py b/quapy/util.py index 2ccf06d..94187e6 100644 --- a/quapy/util.py +++ b/quapy/util.py @@ -5,6 +5,7 @@ import os import pickle import urllib from pathlib import Path +from contextlib import ExitStack import quapy as qp import numpy as np @@ -36,7 +37,7 @@ def map_parallel(func, args, n_jobs): return list(itertools.chain.from_iterable(results)) -def parallel(func, args, n_jobs): +def parallel(func, args, n_jobs, seed = None): """ A wrapper of multiprocessing: @@ -44,14 +45,20 @@ def parallel(func, args, n_jobs): >>> delayed(func)(args_i) for args_i in args >>> ) - that takes the `quapy.environ` variable as input silently + that takes the `quapy.environ` variable as input silently. + Seeds the child processes to ensure reproducibility when n_jobs>1 """ - def func_dec(environ, *args): + def func_dec(environ, seed, *args): qp.environ = environ.copy() qp.environ['N_JOBS'] = 1 - return func(*args) + #set a context with a temporal seed to ensure results are reproducibles in parallel + with ExitStack() as stack: + if seed is not None: + stack.enter_context(qp.util.temp_seed(seed)) + return func(*args) + return Parallel(n_jobs=n_jobs)( - delayed(func_dec)(qp.environ, args_i) for args_i in args + delayed(func_dec)(qp.environ, None if seed is None else seed+i, args_i) for i, args_i in enumerate(args) ) @@ -66,6 +73,8 @@ def temp_seed(random_state): :param random_state: the seed to set within the "with" context """ state = np.random.get_state() + #save the seed just in case is needed (for instance for setting the seed to child processes) + qp.environ['_R_SEED'] = random_state np.random.seed(random_state) try: yield