refatored random_state
This commit is contained in:
parent
4dbabacb0d
commit
e05dfd4a16
|
@ -2,7 +2,7 @@ import math
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import tarfile
|
import tarfile
|
||||||
from typing import List
|
from typing import List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import quapy as qp
|
import quapy as qp
|
||||||
|
@ -11,6 +11,7 @@ from sklearn.conftest import fetch_rcv1
|
||||||
from sklearn.utils import Bunch
|
from sklearn.utils import Bunch
|
||||||
|
|
||||||
from quacc import utils
|
from quacc import utils
|
||||||
|
from quacc.environment import env
|
||||||
|
|
||||||
TRAIN_VAL_PROP = 0.5
|
TRAIN_VAL_PROP = 0.5
|
||||||
|
|
||||||
|
@ -190,7 +191,7 @@ class Dataset:
|
||||||
|
|
||||||
return all_train, test
|
return all_train, test
|
||||||
|
|
||||||
def __train_test(self):
|
def __train_test(self) -> Tuple[LabelledCollection, LabelledCollection]:
|
||||||
all_train, test = {
|
all_train, test = {
|
||||||
"spambase": self.__spambase,
|
"spambase": self.__spambase,
|
||||||
"imdb": self.__imdb,
|
"imdb": self.__imdb,
|
||||||
|
@ -205,7 +206,7 @@ class Dataset:
|
||||||
all_train, test = self.__train_test()
|
all_train, test = self.__train_test()
|
||||||
|
|
||||||
train, val = all_train.split_stratified(
|
train, val = all_train.split_stratified(
|
||||||
train_prop=TRAIN_VAL_PROP, random_state=0
|
train_prop=TRAIN_VAL_PROP, random_state=env._R_SEED
|
||||||
)
|
)
|
||||||
|
|
||||||
return DatasetSample(train, val, test)
|
return DatasetSample(train, val, test)
|
||||||
|
@ -216,7 +217,9 @@ class Dataset:
|
||||||
# resample all_train set to have (0.5, 0.5) prevalence
|
# resample all_train set to have (0.5, 0.5) prevalence
|
||||||
at_positives = np.sum(all_train.y)
|
at_positives = np.sum(all_train.y)
|
||||||
all_train = all_train.sampling(
|
all_train = all_train.sampling(
|
||||||
min(at_positives, len(all_train) - at_positives) * 2, 0.5, random_state=0
|
min(at_positives, len(all_train) - at_positives) * 2,
|
||||||
|
0.5,
|
||||||
|
random_state=env._R_SEED,
|
||||||
)
|
)
|
||||||
|
|
||||||
# sample prevalences
|
# sample prevalences
|
||||||
|
@ -228,9 +231,9 @@ class Dataset:
|
||||||
at_size = min(math.floor(len(all_train) * 0.5 / p) for p in prevs)
|
at_size = min(math.floor(len(all_train) * 0.5 / p) for p in prevs)
|
||||||
datasets = []
|
datasets = []
|
||||||
for p in 1.0 - prevs:
|
for p in 1.0 - prevs:
|
||||||
all_train_sampled = all_train.sampling(at_size, p, random_state=0)
|
all_train_sampled = all_train.sampling(at_size, p, random_state=env._R_SEED)
|
||||||
train, validation = all_train_sampled.split_stratified(
|
train, validation = all_train_sampled.split_stratified(
|
||||||
train_prop=TRAIN_VAL_PROP, random_state=0
|
train_prop=TRAIN_VAL_PROP, random_state=env._R_SEED
|
||||||
)
|
)
|
||||||
datasets.append(DatasetSample(train, validation, test))
|
datasets.append(DatasetSample(train, validation, test))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue