datasets added and adapted to refactoring

This commit is contained in:
Lorenzo Volpi 2024-04-04 17:06:51 +02:00
parent d4b0212a92
commit 9d2a6dcdec
1 changed file with 117 additions and 61 deletions


@@ -1,29 +1,31 @@
+import itertools
 import math
 import os
 import pickle
 import tarfile
-from typing import List, Tuple
+from typing import List
 
 import numpy as np
 import quapy as qp
 from quapy.data.base import LabelledCollection
-from sklearn.conftest import fetch_rcv1
+from quapy.data.datasets import fetch_lequa2022, fetch_UCIMulticlassLabelledCollection
+from sklearn.datasets import fetch_20newsgroups, fetch_rcv1
+from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.utils import Bunch
 
-from quacc import utils
-from quacc.environment import env
+from quacc.legacy.environment import env
+from quacc.utils import commons
+from quacc.utils.commons import save_json_file
 
 TRAIN_VAL_PROP = 0.5
 
 
 def fetch_cifar10() -> Bunch:
     URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
-    data_home = utils.get_quacc_home()
+    data_home = commons.get_quacc_home()
     unzipped_path = data_home / "cifar-10-batches-py"
     if not unzipped_path.exists():
         downloaded_path = data_home / URL.split("/")[-1]
-        utils.download_file(URL, downloaded_path)
+        commons.download_file(URL, downloaded_path)
         with tarfile.open(downloaded_path) as f:
             f.extractall(data_home)
         os.remove(downloaded_path)
@@ -58,11 +60,11 @@ def fetch_cifar10() -> Bunch:
 def fetch_cifar100():
     URL = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
-    data_home = utils.get_quacc_home()
+    data_home = commons.get_quacc_home()
     unzipped_path = data_home / "cifar-100-python"
     if not unzipped_path.exists():
         downloaded_path = data_home / URL.split("/")[-1]
-        utils.download_file(URL, downloaded_path)
+        commons.download_file(URL, downloaded_path)
         with tarfile.open(downloaded_path) as f:
             f.extractall(data_home)
         os.remove(downloaded_path)
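Both fetchers cache their work: the tarball is downloaded into the quacc home directory, extracted once, and the archive is deleted, so later calls reuse the extracted folder. A minimal usage sketch, relying only on the Bunch fields that the cifar10/cifar100 classmethods below read (label_names, train.data, train.labels):

    dataset = fetch_cifar10()
    names = dataset.label_names                    # CIFAR-10 class names
    X, y = dataset.train.data, dataset.train.labels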
@@ -96,6 +98,23 @@ def fetch_cifar100():
     )
 
 
+def save_dataset_stats(path, test_prot, L, V):
+    test_prevs = [Ui.prevalence() for Ui in test_prot()]
+    shifts = [qp.error.ae(L.prevalence(), Ui_prev) for Ui_prev in test_prevs]
+    info = {
+        "n_classes": L.n_classes,
+        "n_train": len(L),
+        "n_val": len(V),
+        "train_prev": L.prevalence().tolist(),
+        "val_prev": V.prevalence().tolist(),
+        "test_prevs": [x.tolist() for x in test_prevs],
+        "shifts": [x.tolist() for x in shifts],
+        "sample_size": test_prot.sample_size,
+        "num_samples": test_prot.total(),
+    }
+    save_json_file(path, info)
+
+
 class DatasetSample:
     def __init__(
         self,
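The new save_dataset_stats helper expects test_prot to be a quapy protocol that yields LabelledCollection samples and exposes sample_size and total(). A sketch of a plausible call site, assuming quapy's APP protocol over a test collection U; the protocol configuration and the output path are illustrative, not set up by this commit:

    from quapy.protocol import APP

    test_prot = APP(U, sample_size=100, repeats=10, return_type="labelled_collection")
    save_dataset_stats("stats/spambase.json", test_prot, T, V)

Here T and V are the training and validation collections whose sizes and prevalences get recorded, and each test sample's shift is measured as the absolute error (qp.error.ae) between its prevalence and the training prevalence.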
@@ -121,49 +140,69 @@ class DatasetSample:
 class DatasetProvider:
-    def __spambase(self, **kwargs):
-        return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test
+    @classmethod
+    def _split_train(cls, train: LabelledCollection):
+        return train.split_stratified(0.5, random_state=0)
 
-    # try min_df=5
-    def __imdb(self, **kwargs):
-        return qp.datasets.fetch_reviews("imdb", tfidf=True, min_df=3).train_test
+    @classmethod
+    def _split_whole(cls, dataset: LabelledCollection):
+        train, U = dataset.split_stratified(train_prop=0.66, random_state=0)
+        T, V = train.split_stratified(train_prop=0.5, random_state=0)
+        return T, V, U
 
-    def __rcv1(self, target, **kwargs):
-        n_train = 23149
+    @classmethod
+    def spambase(cls):
+        train, U = qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test
+        T, V = cls._split_train(train)
+        return T, V, U
+
+    @classmethod
+    def imdb(cls):
+        train, U = qp.datasets.fetch_reviews(
+            "imdb", tfidf=True, min_df=10, pickle=True
+        ).train_test
+        T, V = cls._split_train(train)
+        return T, V, U
+
+    @classmethod
+    def rcv1(cls, target):
+        training = fetch_rcv1(subset="train")
+        test = fetch_rcv1(subset="test")
         available_targets = ["CCAT", "GCAT", "MCAT"]
 
         if target is None or target not in available_targets:
             raise ValueError(f"Invalid target {target}")
 
-        dataset = fetch_rcv1()
-        target_index = np.where(dataset.target_names == target)[0]
-        all_train_d = dataset.data[:n_train, :]
-        test_d = dataset.data[n_train:, :]
-        labels = dataset.target[:, target_index].toarray().flatten()
-        all_train_l, test_l = labels[:n_train], labels[n_train:]
-        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
-        test = LabelledCollection(test_d, test_l, classes=[0, 1])
-        return all_train, test
-
-    def __cifar10(self, target, **kwargs):
-        dataset = fetch_cifar10()
-        available_targets: list = dataset.label_names
-
-        if target is None or self._target not in available_targets:
-            raise ValueError(f"Invalid target {target}")
-
-        target_index = available_targets.index(target)
-        all_train_d = dataset.train.data
-        all_train_l = (dataset.train.labels == target_index).astype(int)
+        class_names = training.target_names.tolist()
+        class_idx = class_names.index(target)
+        tr_labels = training.target[:, class_idx].toarray().flatten()
+        te_labels = test.target[:, class_idx].toarray().flatten()
+        tr = LabelledCollection(training.data, tr_labels)
+        U = LabelledCollection(test.data, te_labels)
+        T, V = cls._split_train(tr)
+        return T, V, U
+
+    @classmethod
+    def cifar10(cls, target):
+        dataset = fetch_cifar10()
+        available_targets: list = dataset.label_names
+
+        if target is None or target not in available_targets:
+            raise ValueError(f"Invalid target {target}")
+
+        target_idx = available_targets.index(target)
+        train_d = dataset.train.data
+        train_l = (dataset.train.labels == target_idx).astype(int)
         test_d = dataset.test.data
-        test_l = (dataset.test.labels == target_index).astype(int)
-        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
-        test = LabelledCollection(test_d, test_l, classes=[0, 1])
-
-        return all_train, test
+        test_l = (dataset.test.labels == target_idx).astype(int)
+        train = LabelledCollection(train_d, train_l, classes=[0, 1])
+        U = LabelledCollection(test_d, test_l, classes=[0, 1])
+        T, V = cls._split_train(train)
+        return T, V, U
 
-    def __cifar100(self, target, **kwargs):
+    @classmethod
+    def cifar100(cls, target):
         dataset = fetch_cifar100()
         available_targets: list = dataset.coarse_label_names
@@ -171,31 +210,48 @@ class DatasetProvider:
             raise ValueError(f"Invalid target {target}")
 
         target_index = available_targets.index(target)
-        all_train_d = dataset.train.data
-        all_train_l = (dataset.train.coarse_labels == target_index).astype(int)
+        train_d = dataset.train.data
+        train_l = (dataset.train.coarse_labels == target_index).astype(int)
         test_d = dataset.test.data
         test_l = (dataset.test.coarse_labels == target_index).astype(int)
-        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
-        test = LabelledCollection(test_d, test_l, classes=[0, 1])
-
-        return all_train, test
+        train = LabelledCollection(train_d, train_l, classes=[0, 1])
+        U = LabelledCollection(test_d, test_l, classes=[0, 1])
+        T, V = cls._split_train(train)
+        return T, V, U
 
-    def __twitter_gasp(self, **kwargs):
-        return qp.datasets.fetch_twitter("gasp", min_df=3).train_test
+    @classmethod
+    def twitter(cls, dataset_name):
+        data = qp.datasets.fetch_twitter(dataset_name, min_df=3, pickle=True)
+        T, V = cls._split_train(data.training)
+        U = data.test
+        return T, V, U
 
-    def alltrain_test(
-        self, name: str, target: str | None
-    ) -> Tuple[LabelledCollection, LabelledCollection]:
-        all_train, test = {
-            "spambase": self.__spambase,
-            "imdb": self.__imdb,
-            "rcv1": self.__rcv1,
-            "cifar10": self.__cifar10,
-            "cifar100": self.__cifar100,
-            "twitter_gasp": self.__twitter_gasp,
-        }[name](target=target)
-
-        return all_train, test
+    @classmethod
+    def uci_multiclass(cls, dataset_name):
+        dataset = fetch_UCIMulticlassLabelledCollection(dataset_name)
+        return cls._split_whole(dataset)
+
+    @classmethod
+    def news20(cls):
+        train = fetch_20newsgroups(
+            subset="train", remove=("headers", "footers", "quotes")
+        )
+        test = fetch_20newsgroups(
+            subset="test", remove=("headers", "footers", "quotes")
+        )
+        tfidf = TfidfVectorizer(min_df=5, sublinear_tf=True)
+        Xtr = tfidf.fit_transform(train.data)
+        Xte = tfidf.transform(test.data)
+        train = LabelledCollection(instances=Xtr, labels=train.target)
+        U = LabelledCollection(instances=Xte, labels=test.target)
+        T, V = cls._split_train(train)
+        return T, V, U
+
+    @classmethod
+    def t1b_lequa2022(cls):
+        dataset, _, _ = fetch_lequa2022(task="T1B")
+        return cls._split_whole(dataset)
 
 
 class Dataset(DatasetProvider):
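After this refactoring, every dataset accessor is a DatasetProvider classmethod returning a (T, V, U) triple of LabelledCollections (train, validation, unseen test) instead of the old (all_train, test) pair: _split_train halves a ready-made training set, while _split_whole first carves off 66% of a single collection for training and then halves that into T and V. The string-keyed dispatch of alltrain_test is gone; callers invoke the classmethods directly. An illustrative call site (the UCI dataset name is an example, assuming it is among those supported by fetch_UCIMulticlassLabelledCollection):

    T, V, U = DatasetProvider.imdb()
    T, V, U = DatasetProvider.rcv1("CCAT")
    T, V, U = DatasetProvider.uci_multiclass("dry-bean")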