datasets added and adapted to refactoring

parent d4b0212a92
commit 9d2a6dcdec

quacc/dataset.py | 178
@@ -1,29 +1,31 @@
 import itertools
 import math
 import os
 import pickle
 import tarfile
-from typing import List, Tuple
+from typing import List
 
 import numpy as np
 import quapy as qp
 from quapy.data.base import LabelledCollection
-from sklearn.conftest import fetch_rcv1
+from quapy.data.datasets import fetch_lequa2022, fetch_UCIMulticlassLabelledCollection
+from sklearn.datasets import fetch_20newsgroups, fetch_rcv1
+from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.utils import Bunch
 
-from quacc import utils
-from quacc.environment import env
+from quacc.legacy.environment import env
+from quacc.utils import commons
+from quacc.utils.commons import save_json_file
 
 TRAIN_VAL_PROP = 0.5
 
 
 def fetch_cifar10() -> Bunch:
     URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
-    data_home = utils.get_quacc_home()
+    data_home = commons.get_quacc_home()
     unzipped_path = data_home / "cifar-10-batches-py"
     if not unzipped_path.exists():
         downloaded_path = data_home / URL.split("/")[-1]
-        utils.download_file(URL, downloaded_path)
+        commons.download_file(URL, downloaded_path)
         with tarfile.open(downloaded_path) as f:
             f.extractall(data_home)
         os.remove(downloaded_path)
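Note: the hunk above only renames the cache helpers; the fetcher's contract can be read off its callers later in this diff. A minimal usage sketch, assuming the returned Bunch exposes label_names plus train/test sub-bunches, as the cifar10() classmethod further down expects:

    # Sketch: how the rest of this diff consumes fetch_cifar10().
    # The Bunch layout (label_names, train.data, train.labels, test.*)
    # is inferred from the cifar10() classmethod below, not documented here.
    dataset = fetch_cifar10()
    print(dataset.label_names)       # CIFAR-10 class names
    print(dataset.train.data.shape)  # training instances
    print(dataset.test.labels[:10])  # integer labels, index-aligned with label_names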
@@ -58,11 +60,11 @@ def fetch_cifar10() -> Bunch:
 
 def fetch_cifar100():
     URL = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
-    data_home = utils.get_quacc_home()
+    data_home = commons.get_quacc_home()
     unzipped_path = data_home / "cifar-100-python"
     if not unzipped_path.exists():
         downloaded_path = data_home / URL.split("/")[-1]
-        utils.download_file(URL, downloaded_path)
+        commons.download_file(URL, downloaded_path)
         with tarfile.open(downloaded_path) as f:
             f.extractall(data_home)
         os.remove(downloaded_path)
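The utils.* to commons.* renames assume download and caching helpers living in quacc.utils.commons. A hypothetical sketch of such helpers, for orientation only; the real module may differ:

    # Hypothetical stand-ins for the commons helpers used above;
    # NOT the actual quacc.utils.commons implementation.
    from pathlib import Path
    import urllib.request

    def get_quacc_home() -> Path:
        # Local cache directory for downloaded datasets.
        home = Path.home() / "quacc_data"
        home.mkdir(parents=True, exist_ok=True)
        return home

    def download_file(url: str, destination: Path) -> None:
        # Download a remote archive into the cache.
        urllib.request.urlretrieve(url, str(destination))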
@@ -96,6 +98,23 @@ def fetch_cifar100():
     )
 
 
+def save_dataset_stats(path, test_prot, L, V):
+    test_prevs = [Ui.prevalence() for Ui in test_prot()]
+    shifts = [qp.error.ae(L.prevalence(), Ui_prev) for Ui_prev in test_prevs]
+    info = {
+        "n_classes": L.n_classes,
+        "n_train": len(L),
+        "n_val": len(V),
+        "train_prev": L.prevalence().tolist(),
+        "val_prev": V.prevalence().tolist(),
+        "test_prevs": [x.tolist() for x in test_prevs],
+        "shifts": [x.tolist() for x in shifts],
+        "sample_size": test_prot.sample_size,
+        "num_samples": test_prot.total(),
+    }
+    save_json_file(path, info)
+
+
 class DatasetSample:
     def __init__(
         self,
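The new save_dataset_stats records, for a train/validation pair and a test protocol, the prevalence of each test sample and its absolute-error shift from the training prevalence. A sketch of a call site, assuming quapy's APP protocol; the sample_size and repeats values are illustrative, not taken from this commit:

    # Illustrative call site for save_dataset_stats.
    import quapy as qp

    T, V, U = DatasetProvider.spambase()  # classmethod defined below in this diff
    test_prot = qp.protocol.APP(
        U, sample_size=100, repeats=25, random_state=0,
        return_type="labelled_collection",  # so iteration yields LabelledCollections
    )
    save_dataset_stats("spambase_stats.json", test_prot, T, V)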
@@ -121,49 +140,69 @@ class DatasetSample:
 
 
 class DatasetProvider:
-    def __spambase(self, **kwargs):
-        return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test
-
-    # try min_df=5
-    def __imdb(self, **kwargs):
-        return qp.datasets.fetch_reviews("imdb", tfidf=True, min_df=3).train_test
-
-    def __rcv1(self, target, **kwargs):
-        n_train = 23149
+    @classmethod
+    def _split_train(cls, train: LabelledCollection):
+        return train.split_stratified(0.5, random_state=0)
+
+    @classmethod
+    def _split_whole(cls, dataset: LabelledCollection):
+        train, U = dataset.split_stratified(train_prop=0.66, random_state=0)
+        T, V = train.split_stratified(train_prop=0.5, random_state=0)
+        return T, V, U
+
+    @classmethod
+    def spambase(cls):
+        train, U = qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test
+        T, V = cls._split_train(train)
+        return T, V, U
+
+    @classmethod
+    def imdb(cls):
+        train, U = qp.datasets.fetch_reviews(
+            "imdb", tfidf=True, min_df=10, pickle=True
+        ).train_test
+        T, V = cls._split_train(train)
+        return T, V, U
+
+    @classmethod
+    def rcv1(cls, target):
+        training = fetch_rcv1(subset="train")
+        test = fetch_rcv1(subset="test")
+
         available_targets = ["CCAT", "GCAT", "MCAT"]
         if target is None or target not in available_targets:
             raise ValueError(f"Invalid target {target}")
 
-        dataset = fetch_rcv1()
-        target_index = np.where(dataset.target_names == target)[0]
-        all_train_d = dataset.data[:n_train, :]
-        test_d = dataset.data[n_train:, :]
-        labels = dataset.target[:, target_index].toarray().flatten()
-        all_train_l, test_l = labels[:n_train], labels[n_train:]
-        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
-        test = LabelledCollection(test_d, test_l, classes=[0, 1])
-
-        return all_train, test
-
-    def __cifar10(self, target, **kwargs):
+        class_names = training.target_names.tolist()
+        class_idx = class_names.index(target)
+        tr_labels = training.target[:, class_idx].toarray().flatten()
+        te_labels = test.target[:, class_idx].toarray().flatten()
+        tr = LabelledCollection(training.data, tr_labels)
+        U = LabelledCollection(test.data, te_labels)
+        T, V = cls._split_train(tr)
+        return T, V, U
+
+    @classmethod
+    def cifar10(cls, target):
         dataset = fetch_cifar10()
         available_targets: list = dataset.label_names
 
-        if target is None or self._target not in available_targets:
+        if target is None or target not in available_targets:
             raise ValueError(f"Invalid target {target}")
 
-        target_index = available_targets.index(target)
-        all_train_d = dataset.train.data
-        all_train_l = (dataset.train.labels == target_index).astype(int)
+        target_idx = available_targets.index(target)
+        train_d = dataset.train.data
+        train_l = (dataset.train.labels == target_idx).astype(int)
         test_d = dataset.test.data
-        test_l = (dataset.test.labels == target_index).astype(int)
-        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
-        test = LabelledCollection(test_d, test_l, classes=[0, 1])
+        test_l = (dataset.test.labels == target_idx).astype(int)
+        train = LabelledCollection(train_d, train_l, classes=[0, 1])
+        U = LabelledCollection(test_d, test_l, classes=[0, 1])
+        T, V = cls._split_train(train)
 
-        return all_train, test
+        return T, V, U
 
-    def __cifar100(self, target, **kwargs):
+    @classmethod
+    def cifar100(cls, target):
         dataset = fetch_cifar100()
         available_targets: list = dataset.coarse_label_names
 
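Every refactored accessor now returns the same triple: a training set T, a validation set V (via _split_train's stratified 50/50 split), and an unseen test set U. For the CIFAR providers the task is binarized against the chosen target class; a short sketch, assuming label_names are decoded strings:

    # Sketch: the shared (T, V, U) contract of the new classmethods.
    # "dog" is one of CIFAR-10's label names; positives = target class, rest = 0.
    T, V, U = DatasetProvider.cifar10("dog")
    print(len(T), len(V), len(U))
    print(T.prevalence(), V.prevalence())  # binary prevalences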
@@ -171,31 +210,48 @@ class DatasetProvider:
             raise ValueError(f"Invalid target {target}")
 
         target_index = available_targets.index(target)
-        all_train_d = dataset.train.data
-        all_train_l = (dataset.train.coarse_labels == target_index).astype(int)
+        train_d = dataset.train.data
+        train_l = (dataset.train.coarse_labels == target_index).astype(int)
         test_d = dataset.test.data
         test_l = (dataset.test.coarse_labels == target_index).astype(int)
-        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
-        test = LabelledCollection(test_d, test_l, classes=[0, 1])
+        train = LabelledCollection(train_d, train_l, classes=[0, 1])
+        U = LabelledCollection(test_d, test_l, classes=[0, 1])
+        T, V = cls._split_train(train)
 
-        return all_train, test
+        return T, V, U
 
-    def __twitter_gasp(self, **kwargs):
-        return qp.datasets.fetch_twitter("gasp", min_df=3).train_test
+    @classmethod
+    def twitter(cls, dataset_name):
+        data = qp.datasets.fetch_twitter(dataset_name, min_df=3, pickle=True)
+        T, V = cls._split_train(data.training)
+        U = data.test
+        return T, V, U
 
-    def alltrain_test(
-        self, name: str, target: str | None
-    ) -> Tuple[LabelledCollection, LabelledCollection]:
-        all_train, test = {
-            "spambase": self.__spambase,
-            "imdb": self.__imdb,
-            "rcv1": self.__rcv1,
-            "cifar10": self.__cifar10,
-            "cifar100": self.__cifar100,
-            "twitter_gasp": self.__twitter_gasp,
-        }[name](target=target)
+    @classmethod
+    def uci_multiclass(cls, dataset_name):
+        dataset = fetch_UCIMulticlassLabelledCollection(dataset_name)
+        return cls._split_whole(dataset)
 
-        return all_train, test
+    @classmethod
+    def news20(cls):
+        train = fetch_20newsgroups(
+            subset="train", remove=("headers", "footers", "quotes")
+        )
+        test = fetch_20newsgroups(
+            subset="test", remove=("headers", "footers", "quotes")
+        )
+        tfidf = TfidfVectorizer(min_df=5, sublinear_tf=True)
+        Xtr = tfidf.fit_transform(train.data)
+        Xte = tfidf.transform(test.data)
+        train = LabelledCollection(instances=Xtr, labels=train.target)
+        U = LabelledCollection(instances=Xte, labels=test.target)
+        T, V = cls._split_train(train)
+        return T, V, U
+
+    @classmethod
+    def t1b_lequa2022(cls):
+        dataset, _, _ = fetch_lequa2022(task="T1B")
+        return cls._split_whole(dataset)
 
 
 class Dataset(DatasetProvider):
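The removed alltrain_test dispatcher has no one-to-one replacement in this file; callers are expected to invoke the classmethods directly. A hypothetical name-based dispatch under the new API, shown only to connect the old and new call sites:

    # Hypothetical replacement for the removed alltrain_test dispatch;
    # not part of this commit.
    def fetch_by_name(name: str, target: str | None = None):
        match name:
            case "spambase":
                return Dataset.spambase()
            case "imdb":
                return Dataset.imdb()
            case "rcv1" | "cifar10" | "cifar100":
                return getattr(Dataset, name)(target)
            case "twitter_gasp":
                return Dataset.twitter("gasp")
            case _:
                raise ValueError(f"Unknown dataset {name}")

    T, V, U = fetch_by_name("rcv1", target="CCAT")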