random_state updated
This commit is contained in:
parent
1806243d53
commit
8f903d96e2
|
@ -17,6 +17,7 @@ import baselines.impweight as iw
|
|||
import baselines.mandoline as mandolib
|
||||
import baselines.rca as rcalib
|
||||
from baselines.utils import clone_fit
|
||||
from quacc.environment import env
|
||||
|
||||
from .report import EvaluationReport
|
||||
|
||||
|
@ -169,7 +170,7 @@ def doc(
|
|||
predict_method="predict_proba",
|
||||
):
|
||||
c_model_predict = getattr(c_model, predict_method)
|
||||
val1, val2 = validation.split_stratified(train_prop=0.5, random_state=0)
|
||||
val1, val2 = validation.split_stratified(train_prop=0.5, random_state=env._R_SEED)
|
||||
val1_probs = c_model_predict(val1.X)
|
||||
val1_mc = np.max(val1_probs, axis=-1)
|
||||
val1_preds = np.argmax(val1_probs, axis=-1)
|
||||
|
@ -281,7 +282,7 @@ def rca_star(
|
|||
"""elsahar19"""
|
||||
c_model_predict = getattr(c_model, predict_method)
|
||||
validation1, validation2 = validation.split_stratified(
|
||||
train_prop=0.5, random_state=0
|
||||
train_prop=0.5, random_state=env._R_SEED
|
||||
)
|
||||
val1_pred = c_model_predict(validation1.X)
|
||||
c_model1 = clone_fit(c_model, validation1.X, val1_pred)
|
||||
|
@ -318,7 +319,7 @@ def gde(
|
|||
predict_method="predict",
|
||||
) -> EvaluationReport:
|
||||
c_model_predict = getattr(c_model, predict_method)
|
||||
val1, val2 = validation.split_stratified(train_prop=0.5, random_state=0)
|
||||
val1, val2 = validation.split_stratified(train_prop=0.5, random_state=env._R_SEED)
|
||||
c_model1 = clone_fit(c_model, val1.X, val1.y)
|
||||
c_model1_predict = getattr(c_model1, predict_method)
|
||||
c_model2 = clone_fit(c_model, val2.X, val2.y)
|
||||
|
|
|
@ -8,6 +8,7 @@ from sklearn.linear_model import LogisticRegression
|
|||
from sklearn.svm import LinearSVC
|
||||
|
||||
import quacc as qc
|
||||
from quacc.environment import env
|
||||
from quacc.evaluation.report import EvaluationReport
|
||||
from quacc.method.base import BQAE, MCAE, BaseAccuracyEstimator
|
||||
from quacc.method.model_selection import GridSearchAE
|
||||
|
@ -97,7 +98,7 @@ class EvaluationMethodGridSearch(EvaluationMethod):
|
|||
pg: str = "sld"
|
||||
|
||||
def __call__(self, c_model, validation, protocol) -> EvaluationReport:
|
||||
v_train, v_val = validation.split_stratified(0.6, random_state=0)
|
||||
v_train, v_val = validation.split_stratified(0.6, random_state=env._R_SEED)
|
||||
__grid = _param_grid.get(self.pg, {})
|
||||
est = GridSearchAE(
|
||||
model=self.get_est(c_model),
|
||||
|
@ -122,7 +123,7 @@ def __sld_lr():
|
|||
|
||||
|
||||
def __kde_lr():
|
||||
return KDEy(LogisticRegression())
|
||||
return KDEy(LogisticRegression(), random_state=env._R_SEED)
|
||||
|
||||
|
||||
def __sld_lsvc():
|
||||
|
|
|
@ -13,6 +13,7 @@ from sklearn.base import BaseEstimator
|
|||
import quacc as qc
|
||||
import quacc.error
|
||||
from quacc.data import ExtendedCollection, ExtendedData
|
||||
from quacc.environment import env
|
||||
from quacc.evaluation import evaluate
|
||||
from quacc.logger import SubLogger
|
||||
from quacc.method.base import (
|
||||
|
@ -251,7 +252,7 @@ class MCAEgsq(MultiClassAccuracyEstimator):
|
|||
|
||||
def fit(self, train: LabelledCollection):
|
||||
self.e_train = self.extend(train)
|
||||
t_train, t_val = self.e_train.split_stratified(0.6, random_state=0)
|
||||
t_train, t_val = self.e_train.split_stratified(0.6, random_state=env._R_SEED)
|
||||
self.quantifier = GridSearchQ(
|
||||
deepcopy(self.quantifier),
|
||||
param_grid=self.param_grid,
|
||||
|
@ -304,7 +305,7 @@ class BQAEgsq(BinaryQuantifierAccuracyEstimator):
|
|||
|
||||
self.quantifiers = []
|
||||
for e_train in self.e_trains:
|
||||
t_train, t_val = e_train.split_stratified(0.6, random_state=0)
|
||||
t_train, t_val = e_train.split_stratified(0.6, random_state=env._R_SEED)
|
||||
quantifier = GridSearchQ(
|
||||
model=deepcopy(self.quantifier),
|
||||
param_grid=self.param_grid,
|
||||
|
|
Loading…
Reference in New Issue