dataset fixed
This commit is contained in:
parent
5d82419ce8
commit
46c24d9fd8
|
@ -126,9 +126,7 @@ class DatasetProvider:
|
||||||
|
|
||||||
# provare min_df=5
|
# provare min_df=5
|
||||||
def __imdb(self, **kwargs):
|
def __imdb(self, **kwargs):
|
||||||
return qp.datasets.fetch_reviews(
|
return qp.datasets.fetch_reviews("imdb", tfidf=True, min_df=3).train_test
|
||||||
"imdb", data_home="./quapy_data", tfidf=True, min_df=3
|
|
||||||
).train_test
|
|
||||||
|
|
||||||
def __rcv1(self, target, **kwargs):
|
def __rcv1(self, target, **kwargs):
|
||||||
n_train = 23149
|
n_train = 23149
|
||||||
|
@ -137,7 +135,7 @@ class DatasetProvider:
|
||||||
if target is None or target not in available_targets:
|
if target is None or target not in available_targets:
|
||||||
raise ValueError(f"Invalid target {target}")
|
raise ValueError(f"Invalid target {target}")
|
||||||
|
|
||||||
dataset = fetch_rcv1(data_home="./scikit_learn_data")
|
dataset = fetch_rcv1()
|
||||||
target_index = np.where(dataset.target_names == target)[0]
|
target_index = np.where(dataset.target_names == target)[0]
|
||||||
all_train_d = dataset.data[:n_train, :]
|
all_train_d = dataset.data[:n_train, :]
|
||||||
test_d = dataset.data[n_train:, :]
|
test_d = dataset.data[n_train:, :]
|
||||||
|
|
Loading…
Reference in New Issue