refatored random_state

This commit is contained in:
Lorenzo Volpi 2023-11-26 16:33:11 +01:00
parent 4dbabacb0d
commit e05dfd4a16
1 changed files with 9 additions and 6 deletions

View File

@ -2,7 +2,7 @@ import math
import os import os
import pickle import pickle
import tarfile import tarfile
from typing import List from typing import List, Tuple
import numpy as np import numpy as np
import quapy as qp import quapy as qp
@ -11,6 +11,7 @@ from sklearn.conftest import fetch_rcv1
from sklearn.utils import Bunch from sklearn.utils import Bunch
from quacc import utils from quacc import utils
from quacc.environment import env
TRAIN_VAL_PROP = 0.5 TRAIN_VAL_PROP = 0.5
@ -190,7 +191,7 @@ class Dataset:
return all_train, test return all_train, test
def __train_test(self): def __train_test(self) -> Tuple[LabelledCollection, LabelledCollection]:
all_train, test = { all_train, test = {
"spambase": self.__spambase, "spambase": self.__spambase,
"imdb": self.__imdb, "imdb": self.__imdb,
@ -205,7 +206,7 @@ class Dataset:
all_train, test = self.__train_test() all_train, test = self.__train_test()
train, val = all_train.split_stratified( train, val = all_train.split_stratified(
train_prop=TRAIN_VAL_PROP, random_state=0 train_prop=TRAIN_VAL_PROP, random_state=env._R_SEED
) )
return DatasetSample(train, val, test) return DatasetSample(train, val, test)
@ -216,7 +217,9 @@ class Dataset:
# resample all_train set to have (0.5, 0.5) prevalence # resample all_train set to have (0.5, 0.5) prevalence
at_positives = np.sum(all_train.y) at_positives = np.sum(all_train.y)
all_train = all_train.sampling( all_train = all_train.sampling(
min(at_positives, len(all_train) - at_positives) * 2, 0.5, random_state=0 min(at_positives, len(all_train) - at_positives) * 2,
0.5,
random_state=env._R_SEED,
) )
# sample prevalences # sample prevalences
@ -228,9 +231,9 @@ class Dataset:
at_size = min(math.floor(len(all_train) * 0.5 / p) for p in prevs) at_size = min(math.floor(len(all_train) * 0.5 / p) for p in prevs)
datasets = [] datasets = []
for p in 1.0 - prevs: for p in 1.0 - prevs:
all_train_sampled = all_train.sampling(at_size, p, random_state=0) all_train_sampled = all_train.sampling(at_size, p, random_state=env._R_SEED)
train, validation = all_train_sampled.split_stratified( train, validation = all_train_sampled.split_stratified(
train_prop=TRAIN_VAL_PROP, random_state=0 train_prop=TRAIN_VAL_PROP, random_state=env._R_SEED
) )
datasets.append(DatasetSample(train, validation, test)) datasets.append(DatasetSample(train, validation, test))