datasets added and adapted to refactoring

Lorenzo Volpi 2024-04-04 17:06:51 +02:00
parent d4b0212a92
commit 9d2a6dcdec
1 changed file with 117 additions and 61 deletions


@@ -1,29 +1,31 @@
 import itertools
 import math
 import os
 import pickle
 import tarfile
-from typing import List, Tuple
+from typing import List

 import numpy as np
 import quapy as qp
 from quapy.data.base import LabelledCollection
-from sklearn.conftest import fetch_rcv1
+from quapy.data.datasets import fetch_lequa2022, fetch_UCIMulticlassLabelledCollection
+from sklearn.datasets import fetch_20newsgroups, fetch_rcv1
+from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.utils import Bunch

-from quacc import utils
-from quacc.environment import env
+from quacc.legacy.environment import env
+from quacc.utils import commons
+from quacc.utils.commons import save_json_file

 TRAIN_VAL_PROP = 0.5


 def fetch_cifar10() -> Bunch:
     URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
-    data_home = utils.get_quacc_home()
+    data_home = commons.get_quacc_home()
     unzipped_path = data_home / "cifar-10-batches-py"

     if not unzipped_path.exists():
         downloaded_path = data_home / URL.split("/")[-1]
-        utils.download_file(URL, downloaded_path)
+        commons.download_file(URL, downloaded_path)
         with tarfile.open(downloaded_path) as f:
             f.extractall(data_home)
         os.remove(downloaded_path)
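Note (editor's sketch, not part of the commit): the extracted cifar-10-batches-py directory follows the layout documented on the CIFAR page linked above — five pickled training batches, a test batch, and a batches.meta file carrying label_names — which is what the pickle import at the top of the file reads back. A minimal loader for one batch, with a hypothetical name:

import pickle

def load_cifar_batch(path):
    # Each batch file unpickles to a dict keyed by byte strings (b"data", b"labels").
    with open(path, "rb") as f:
        batch = pickle.load(f, encoding="bytes")
    # b"data" is a (10000, 3072) uint8 array: 1024 red, then green, then blue values per image.
    return batch[b"data"], batch[b"labels"]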
@@ -58,11 +60,11 @@ def fetch_cifar10() -> Bunch:
 def fetch_cifar100():
     URL = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
-    data_home = utils.get_quacc_home()
+    data_home = commons.get_quacc_home()
     unzipped_path = data_home / "cifar-100-python"

     if not unzipped_path.exists():
         downloaded_path = data_home / URL.split("/")[-1]
-        utils.download_file(URL, downloaded_path)
+        commons.download_file(URL, downloaded_path)
         with tarfile.open(downloaded_path) as f:
             f.extractall(data_home)
         os.remove(downloaded_path)
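Both fetchers repeat the same download/extract/cleanup sequence. A hypothetical refactor (not in the commit) that factors it out, using only calls already imported in this file:

def _fetch_archive(url: str, dirname: str):
    # Download the tarball into the quacc home, extract it, and drop the archive.
    data_home = commons.get_quacc_home()
    unzipped_path = data_home / dirname
    if not unzipped_path.exists():
        downloaded_path = data_home / url.split("/")[-1]
        commons.download_file(url, downloaded_path)
        with tarfile.open(downloaded_path) as f:
            f.extractall(data_home)
        os.remove(downloaded_path)
    return unzipped_path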
@@ -96,6 +98,23 @@ def fetch_cifar100():
     )


+def save_dataset_stats(path, test_prot, L, V):
+    test_prevs = [Ui.prevalence() for Ui in test_prot()]
+    shifts = [qp.error.ae(L.prevalence(), Ui_prev) for Ui_prev in test_prevs]
+    info = {
+        "n_classes": L.n_classes,
+        "n_train": len(L),
+        "n_val": len(V),
+        "train_prev": L.prevalence().tolist(),
+        "val_prev": V.prevalence().tolist(),
+        "test_prevs": [x.tolist() for x in test_prevs],
+        "shifts": [x.tolist() for x in shifts],
+        "sample_size": test_prot.sample_size,
+        "num_samples": test_prot.total(),
+    }
+    save_json_file(path, info)
+
+
 class DatasetSample:
     def __init__(
         self,
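The new save_dataset_stats assumes test_prot is a quapy protocol that, when called, yields LabelledCollection samples (so that Ui.prevalence() works) and exposes sample_size and total(). A usage sketch under that assumption — the protocol choice, sizes, and output path are illustrative, and the provider methods appear in the next hunk:

from quapy.protocol import UPP

T, V, U = DatasetProvider.spambase()  # (train, validation, test) triple, see below
test_prot = UPP(U, sample_size=500, repeats=100, return_type="labelled_collection")
save_dataset_stats("stats/spambase.json", test_prot, T, V)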
@@ -121,49 +140,69 @@ class DatasetSample:
 class DatasetProvider:
-    def __spambase(self, **kwargs):
-        return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test
+    @classmethod
+    def _split_train(cls, train: LabelledCollection):
+        return train.split_stratified(0.5, random_state=0)

-    # try min_df=5
-    def __imdb(self, **kwargs):
-        return qp.datasets.fetch_reviews("imdb", tfidf=True, min_df=3).train_test
+    @classmethod
+    def _split_whole(cls, dataset: LabelledCollection):
+        train, U = dataset.split_stratified(train_prop=0.66, random_state=0)
+        T, V = train.split_stratified(train_prop=0.5, random_state=0)
+        return T, V, U
+
+    @classmethod
+    def spambase(cls):
+        train, U = qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test
+        T, V = cls._split_train(train)
+        return T, V, U
+
+    @classmethod
+    def imdb(cls):
+        train, U = qp.datasets.fetch_reviews(
+            "imdb", tfidf=True, min_df=10, pickle=True
+        ).train_test
+        T, V = cls._split_train(train)
+        return T, V, U

+    @classmethod
+    def rcv1(cls, target):
+        training = fetch_rcv1(subset="train")
+        test = fetch_rcv1(subset="test")
-    def __rcv1(self, target, **kwargs):
-        n_train = 23149
         available_targets = ["CCAT", "GCAT", "MCAT"]

         if target is None or target not in available_targets:
             raise ValueError(f"Invalid target {target}")

+        class_names = training.target_names.tolist()
+        class_idx = class_names.index(target)
+        tr_labels = training.target[:, class_idx].toarray().flatten()
+        te_labels = test.target[:, class_idx].toarray().flatten()
+        tr = LabelledCollection(training.data, tr_labels)
+        U = LabelledCollection(test.data, te_labels)
+        T, V = cls._split_train(tr)
+        return T, V, U
+
+    @classmethod
+    def cifar10(cls, target):
+        dataset = fetch_cifar10()
+        available_targets: list = dataset.label_names
+        if target is None or target not in available_targets:
+            raise ValueError(f"Invalid target {target}")
-        dataset = fetch_rcv1()
-        target_index = np.where(dataset.target_names == target)[0]
-        all_train_d = dataset.data[:n_train, :]
-        test_d = dataset.data[n_train:, :]
-        labels = dataset.target[:, target_index].toarray().flatten()
-        all_train_l, test_l = labels[:n_train], labels[n_train:]
-        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
-        test = LabelledCollection(test_d, test_l, classes=[0, 1])
-        return all_train, test

-    def __cifar10(self, target, **kwargs):
-        dataset = fetch_cifar10()
-        available_targets: list = dataset.label_names
-        if target is None or self._target not in available_targets:
-            raise ValueError(f"Invalid target {target}")
-        target_index = available_targets.index(target)
-        all_train_d = dataset.train.data
-        all_train_l = (dataset.train.labels == target_index).astype(int)
+        target_idx = available_targets.index(target)
+        train_d = dataset.train.data
+        train_l = (dataset.train.labels == target_idx).astype(int)
         test_d = dataset.test.data
-        test_l = (dataset.test.labels == target_index).astype(int)
-        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
-        test = LabelledCollection(test_d, test_l, classes=[0, 1])
+        test_l = (dataset.test.labels == target_idx).astype(int)
+        train = LabelledCollection(train_d, train_l, classes=[0, 1])
+        U = LabelledCollection(test_d, test_l, classes=[0, 1])
+        T, V = cls._split_train(train)
-        return all_train, test
+        return T, V, U

-    def __cifar100(self, target, **kwargs):
+    @classmethod
+    def cifar100(cls, target):
         dataset = fetch_cifar100()
         available_targets: list = dataset.coarse_label_names
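Every provider classmethod now returns the same (T, V, U) triple: _split_train halves a dataset's native training split into training and validation (stratified, seeded), while _split_whole first carves off 34% as test and then halves the remaining 66%, i.e. roughly a 33/33/34 split. A sketch of the resulting uniform calling convention (the target is illustrative; it must be one of CIFAR-10's label_names):

T, V, U = DatasetProvider.cifar10("airplane")
# Binary task: the chosen class vs. the rest; T and V are stratified halves of CIFAR-10's train split.
print(len(T), len(V), len(U))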
@@ -171,31 +210,48 @@ class DatasetProvider:
             raise ValueError(f"Invalid target {target}")

         target_index = available_targets.index(target)
-        all_train_d = dataset.train.data
-        all_train_l = (dataset.train.coarse_labels == target_index).astype(int)
+        train_d = dataset.train.data
+        train_l = (dataset.train.coarse_labels == target_index).astype(int)
         test_d = dataset.test.data
         test_l = (dataset.test.coarse_labels == target_index).astype(int)
-        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
-        test = LabelledCollection(test_d, test_l, classes=[0, 1])
+        train = LabelledCollection(train_d, train_l, classes=[0, 1])
+        U = LabelledCollection(test_d, test_l, classes=[0, 1])
+        T, V = cls._split_train(train)
-        return all_train, test
+        return T, V, U

-    def __twitter_gasp(self, **kwargs):
-        return qp.datasets.fetch_twitter("gasp", min_df=3).train_test
+    @classmethod
+    def twitter(cls, dataset_name):
+        data = qp.datasets.fetch_twitter(dataset_name, min_df=3, pickle=True)
+        T, V = cls._split_train(data.training)
+        U = data.test
+        return T, V, U

-    def alltrain_test(
-        self, name: str, target: str | None
-    ) -> Tuple[LabelledCollection, LabelledCollection]:
-        all_train, test = {
-            "spambase": self.__spambase,
-            "imdb": self.__imdb,
-            "rcv1": self.__rcv1,
-            "cifar10": self.__cifar10,
-            "cifar100": self.__cifar100,
-            "twitter_gasp": self.__twitter_gasp,
-        }[name](target=target)
+    @classmethod
+    def uci_multiclass(cls, dataset_name):
+        dataset = fetch_UCIMulticlassLabelledCollection(dataset_name)
+        return cls._split_whole(dataset)

-        return all_train, test
+    @classmethod
+    def news20(cls):
+        train = fetch_20newsgroups(
+            subset="train", remove=("headers", "footers", "quotes")
+        )
+        test = fetch_20newsgroups(
+            subset="test", remove=("headers", "footers", "quotes")
+        )
+        tfidf = TfidfVectorizer(min_df=5, sublinear_tf=True)
+        Xtr = tfidf.fit_transform(train.data)
+        Xte = tfidf.transform(test.data)
+        train = LabelledCollection(instances=Xtr, labels=train.target)
+        U = LabelledCollection(instances=Xte, labels=test.target)
+        T, V = cls._split_train(train)
+        return T, V, U
+
+    @classmethod
+    def t1b_lequa2022(cls):
+        dataset, _, _ = fetch_lequa2022(task="T1B")
+        return cls._split_whole(dataset)


 class Dataset(DatasetProvider):
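uci_multiclass and t1b_lequa2022 both route through _split_whole, so their triples are roughly 33/33/34 stratified splits of the source collection; note that for LeQua T1B the official validation and test sample generators returned by fetch_lequa2022 are discarded ("dataset, _, _") in favour of this re-split. A usage sketch; the dataset name is illustrative and must be one of quapy's UCI multiclass dataset names:

T, V, U = DatasetProvider.uci_multiclass("dry-bean")
print(T.n_classes, len(T), len(V), len(U))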