random_state updated

This commit is contained in:
Lorenzo Volpi 2023-11-26 16:42:35 +01:00
parent 1806243d53
commit 8f903d96e2
3 changed files with 10 additions and 7 deletions

View File

@ -17,6 +17,7 @@ import baselines.impweight as iw
import baselines.mandoline as mandolib
import baselines.rca as rcalib
from baselines.utils import clone_fit
from quacc.environment import env
from .report import EvaluationReport
@ -169,7 +170,7 @@ def doc(
predict_method="predict_proba",
):
c_model_predict = getattr(c_model, predict_method)
val1, val2 = validation.split_stratified(train_prop=0.5, random_state=0)
val1, val2 = validation.split_stratified(train_prop=0.5, random_state=env._R_SEED)
val1_probs = c_model_predict(val1.X)
val1_mc = np.max(val1_probs, axis=-1)
val1_preds = np.argmax(val1_probs, axis=-1)
@ -281,7 +282,7 @@ def rca_star(
"""elsahar19"""
c_model_predict = getattr(c_model, predict_method)
validation1, validation2 = validation.split_stratified(
train_prop=0.5, random_state=0
train_prop=0.5, random_state=env._R_SEED
)
val1_pred = c_model_predict(validation1.X)
c_model1 = clone_fit(c_model, validation1.X, val1_pred)
@ -318,7 +319,7 @@ def gde(
predict_method="predict",
) -> EvaluationReport:
c_model_predict = getattr(c_model, predict_method)
val1, val2 = validation.split_stratified(train_prop=0.5, random_state=0)
val1, val2 = validation.split_stratified(train_prop=0.5, random_state=env._R_SEED)
c_model1 = clone_fit(c_model, val1.X, val1.y)
c_model1_predict = getattr(c_model1, predict_method)
c_model2 = clone_fit(c_model, val2.X, val2.y)

View File

@ -8,6 +8,7 @@ from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
import quacc as qc
from quacc.environment import env
from quacc.evaluation.report import EvaluationReport
from quacc.method.base import BQAE, MCAE, BaseAccuracyEstimator
from quacc.method.model_selection import GridSearchAE
@ -97,7 +98,7 @@ class EvaluationMethodGridSearch(EvaluationMethod):
pg: str = "sld"
def __call__(self, c_model, validation, protocol) -> EvaluationReport:
v_train, v_val = validation.split_stratified(0.6, random_state=0)
v_train, v_val = validation.split_stratified(0.6, random_state=env._R_SEED)
__grid = _param_grid.get(self.pg, {})
est = GridSearchAE(
model=self.get_est(c_model),
@ -122,7 +123,7 @@ def __sld_lr():
def __kde_lr():
return KDEy(LogisticRegression())
return KDEy(LogisticRegression(), random_state=env._R_SEED)
def __sld_lsvc():

View File

@ -13,6 +13,7 @@ from sklearn.base import BaseEstimator
import quacc as qc
import quacc.error
from quacc.data import ExtendedCollection, ExtendedData
from quacc.environment import env
from quacc.evaluation import evaluate
from quacc.logger import SubLogger
from quacc.method.base import (
@ -251,7 +252,7 @@ class MCAEgsq(MultiClassAccuracyEstimator):
def fit(self, train: LabelledCollection):
self.e_train = self.extend(train)
t_train, t_val = self.e_train.split_stratified(0.6, random_state=0)
t_train, t_val = self.e_train.split_stratified(0.6, random_state=env._R_SEED)
self.quantifier = GridSearchQ(
deepcopy(self.quantifier),
param_grid=self.param_grid,
@ -304,7 +305,7 @@ class BQAEgsq(BinaryQuantifierAccuracyEstimator):
self.quantifiers = []
for e_train in self.e_trains:
t_train, t_val = e_train.split_stratified(0.6, random_state=0)
t_train, t_val = e_train.split_stratified(0.6, random_state=env._R_SEED)
quantifier = GridSearchQ(
model=deepcopy(self.quantifier),
param_grid=self.param_grid,