datasets added and adapted to refactoring

parent d4b0212a92
commit 9d2a6dcdec

quacc/dataset.py | 178
@@ -1,29 +1,31 @@
-import itertools
 import math
 import os
 import pickle
 import tarfile
-from typing import List, Tuple
+from typing import List

 import numpy as np
 import quapy as qp
 from quapy.data.base import LabelledCollection
-from sklearn.conftest import fetch_rcv1
+from quapy.data.datasets import fetch_lequa2022, fetch_UCIMulticlassLabelledCollection
+from sklearn.datasets import fetch_20newsgroups, fetch_rcv1
+from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.utils import Bunch

-from quacc import utils
-from quacc.environment import env
+from quacc.legacy.environment import env
+from quacc.utils import commons
+from quacc.utils.commons import save_json_file

 TRAIN_VAL_PROP = 0.5


 def fetch_cifar10() -> Bunch:
     URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
-    data_home = utils.get_quacc_home()
+    data_home = commons.get_quacc_home()
     unzipped_path = data_home / "cifar-10-batches-py"
     if not unzipped_path.exists():
         downloaded_path = data_home / URL.split("/")[-1]
-        utils.download_file(URL, downloaded_path)
+        commons.download_file(URL, downloaded_path)
         with tarfile.open(downloaded_path) as f:
             f.extractall(data_home)
         os.remove(downloaded_path)
@@ -58,11 +60,11 @@ def fetch_cifar10() -> Bunch:

 def fetch_cifar100():
     URL = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
-    data_home = utils.get_quacc_home()
+    data_home = commons.get_quacc_home()
     unzipped_path = data_home / "cifar-100-python"
     if not unzipped_path.exists():
         downloaded_path = data_home / URL.split("/")[-1]
-        utils.download_file(URL, downloaded_path)
+        commons.download_file(URL, downloaded_path)
         with tarfile.open(downloaded_path) as f:
             f.extractall(data_home)
         os.remove(downloaded_path)
@@ -96,6 +98,23 @@ def fetch_cifar100():
     )


+def save_dataset_stats(path, test_prot, L, V):
+    test_prevs = [Ui.prevalence() for Ui in test_prot()]
+    shifts = [qp.error.ae(L.prevalence(), Ui_prev) for Ui_prev in test_prevs]
+    info = {
+        "n_classes": L.n_classes,
+        "n_train": len(L),
+        "n_val": len(V),
+        "train_prev": L.prevalence().tolist(),
+        "val_prev": V.prevalence().tolist(),
+        "test_prevs": [x.tolist() for x in test_prevs],
+        "shifts": [x.tolist() for x in shifts],
+        "sample_size": test_prot.sample_size,
+        "num_samples": test_prot.total(),
+    }
+    save_json_file(path, info)
+
+
 class DatasetSample:
     def __init__(
         self,
@@ -121,49 +140,69 @@ class DatasetSample:


 class DatasetProvider:
-    def __spambase(self, **kwargs):
-        return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test
-
-    # try min_df=5
-    def __imdb(self, **kwargs):
-        return qp.datasets.fetch_reviews("imdb", tfidf=True, min_df=3).train_test
-
-    def __rcv1(self, target, **kwargs):
-        n_train = 23149
+    @classmethod
+    def _split_train(cls, train: LabelledCollection):
+        return train.split_stratified(0.5, random_state=0)
+
+    @classmethod
+    def _split_whole(cls, dataset: LabelledCollection):
+        train, U = dataset.split_stratified(train_prop=0.66, random_state=0)
+        T, V = train.split_stratified(train_prop=0.5, random_state=0)
+        return T, V, U
+
+    @classmethod
+    def spambase(cls):
+        train, U = qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test
+        T, V = cls._split_train(train)
+        return T, V, U
+
+    @classmethod
+    def imdb(cls):
+        train, U = qp.datasets.fetch_reviews(
+            "imdb", tfidf=True, min_df=10, pickle=True
+        ).train_test
+        T, V = cls._split_train(train)
+        return T, V, U
+
+    @classmethod
+    def rcv1(cls, target):
+        training = fetch_rcv1(subset="train")
+        test = fetch_rcv1(subset="test")
         available_targets = ["CCAT", "GCAT", "MCAT"]
+        if target is None or target not in available_targets:
+            raise ValueError(f"Invalid target {target}")
+
+        class_names = training.target_names.tolist()
+        class_idx = class_names.index(target)
+        tr_labels = training.target[:, class_idx].toarray().flatten()
+        te_labels = test.target[:, class_idx].toarray().flatten()
+        tr = LabelledCollection(training.data, tr_labels)
+        U = LabelledCollection(test.data, te_labels)
+        T, V = cls._split_train(tr)
+        return T, V, U
+
+    @classmethod
+    def cifar10(cls, target):
+        dataset = fetch_cifar10()
+        available_targets: list = dataset.label_names
+
         if target is None or target not in available_targets:
             raise ValueError(f"Invalid target {target}")

-        dataset = fetch_rcv1()
-        target_index = np.where(dataset.target_names == target)[0]
-        all_train_d = dataset.data[:n_train, :]
-        test_d = dataset.data[n_train:, :]
-        labels = dataset.target[:, target_index].toarray().flatten()
-        all_train_l, test_l = labels[:n_train], labels[n_train:]
-        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
-        test = LabelledCollection(test_d, test_l, classes=[0, 1])
-
-        return all_train, test
-
-    def __cifar10(self, target, **kwargs):
-        dataset = fetch_cifar10()
-        available_targets: list = dataset.label_names
-
-        if target is None or self._target not in available_targets:
-            raise ValueError(f"Invalid target {target}")
-
-        target_index = available_targets.index(target)
-        all_train_d = dataset.train.data
-        all_train_l = (dataset.train.labels == target_index).astype(int)
+        target_idx = available_targets.index(target)
+        train_d = dataset.train.data
+        train_l = (dataset.train.labels == target_idx).astype(int)
         test_d = dataset.test.data
-        test_l = (dataset.test.labels == target_index).astype(int)
-        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
-        test = LabelledCollection(test_d, test_l, classes=[0, 1])
+        test_l = (dataset.test.labels == target_idx).astype(int)
+        train = LabelledCollection(train_d, train_l, classes=[0, 1])
+        U = LabelledCollection(test_d, test_l, classes=[0, 1])
+        T, V = cls._split_train(train)

-        return all_train, test
+        return T, V, U

-    def __cifar100(self, target, **kwargs):
+    @classmethod
+    def cifar100(cls, target):
         dataset = fetch_cifar100()
         available_targets: list = dataset.coarse_label_names

@@ -171,31 +210,48 @@ class DatasetProvider:
             raise ValueError(f"Invalid target {target}")

         target_index = available_targets.index(target)
-        all_train_d = dataset.train.data
-        all_train_l = (dataset.train.coarse_labels == target_index).astype(int)
+        train_d = dataset.train.data
+        train_l = (dataset.train.coarse_labels == target_index).astype(int)
         test_d = dataset.test.data
         test_l = (dataset.test.coarse_labels == target_index).astype(int)
-        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
-        test = LabelledCollection(test_d, test_l, classes=[0, 1])
+        train = LabelledCollection(train_d, train_l, classes=[0, 1])
+        U = LabelledCollection(test_d, test_l, classes=[0, 1])
+        T, V = cls._split_train(train)

-        return all_train, test
+        return T, V, U

-    def __twitter_gasp(self, **kwargs):
-        return qp.datasets.fetch_twitter("gasp", min_df=3).train_test
+    @classmethod
+    def twitter(cls, dataset_name):
+        data = qp.datasets.fetch_twitter(dataset_name, min_df=3, pickle=True)
+        T, V = cls._split_train(data.training)
+        U = data.test
+        return T, V, U

-    def alltrain_test(
-        self, name: str, target: str | None
-    ) -> Tuple[LabelledCollection, LabelledCollection]:
-        all_train, test = {
-            "spambase": self.__spambase,
-            "imdb": self.__imdb,
-            "rcv1": self.__rcv1,
-            "cifar10": self.__cifar10,
-            "cifar100": self.__cifar100,
-            "twitter_gasp": self.__twitter_gasp,
-        }[name](target=target)
-
-        return all_train, test
+    @classmethod
+    def uci_multiclass(cls, dataset_name):
+        dataset = fetch_UCIMulticlassLabelledCollection(dataset_name)
+        return cls._split_whole(dataset)
+
+    @classmethod
+    def news20(cls):
+        train = fetch_20newsgroups(
+            subset="train", remove=("headers", "footers", "quotes")
+        )
+        test = fetch_20newsgroups(
+            subset="test", remove=("headers", "footers", "quotes")
+        )
+        tfidf = TfidfVectorizer(min_df=5, sublinear_tf=True)
+        Xtr = tfidf.fit_transform(train.data)
+        Xte = tfidf.transform(test.data)
+        train = LabelledCollection(instances=Xtr, labels=train.target)
+        U = LabelledCollection(instances=Xte, labels=test.target)
+        T, V = cls._split_train(train)
+        return T, V, U
+
+    @classmethod
+    def t1b_lequa2022(cls):
+        dataset, _, _ = fetch_lequa2022(task="T1B")
+        return cls._split_whole(dataset)


 class Dataset(DatasetProvider):
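
After this refactoring, every dataset accessor on DatasetProvider is a classmethod returning a (T, V, U) triple: a 50/50 stratified train/validation split of the original training set (per TRAIN_VAL_PROP = 0.5 and _split_train) plus the untouched test set, replacing the old alltrain_test dispatch that returned an (all_train, test) pair. A minimal usage sketch, assuming the module is importable as quacc.dataset (the file path shown above) and using only method names visible in the diff:

    from quacc.dataset import DatasetProvider

    # each call yields three quapy LabelledCollection objects:
    # T (training), V (validation), U (unseen test)
    T, V, U = DatasetProvider.imdb()
    T, V, U = DatasetProvider.rcv1("CCAT")      # binary task: CCAT vs. rest
    T, V, U = DatasetProvider.twitter("gasp")   # replaces the old __twitter_gasp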
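The new save_dataset_stats helper assumes test_prot is a quapy sample-generation protocol whose iteration yields objects exposing .prevalence(), and which provides sample_size and total(). A sketch of a plausible call site, assuming quapy's APP protocol with return_type="labelled_collection"; the protocol and its parameters are quapy API, but this exact configuration is an assumption, not shown in the diff:

    from quapy.protocol import APP
    from quacc.dataset import DatasetProvider, save_dataset_stats

    T, V, U = DatasetProvider.spambase()
    # APP sweeps an artificial-prevalence grid over the test set U;
    # return_type="labelled_collection" (assumed here) makes iteration
    # yield LabelledCollection samples, as save_dataset_stats expects
    test_prot = APP(U, sample_size=100, repeats=1, random_state=0,
                    return_type="labelled_collection")
    save_dataset_stats("spambase_stats.json", test_prot, T, V)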