From 9d2a6dcdec46f81a8a21540e2616ac7c7b61918a Mon Sep 17 00:00:00 2001
From: Lorenzo Volpi
Date: Thu, 4 Apr 2024 17:06:51 +0200
Subject: [PATCH] datasets added and adapted to refactoring

---
 quacc/dataset.py | 178 +++++++++++++++++++++++++++++++----------------
 1 file changed, 117 insertions(+), 61 deletions(-)

diff --git a/quacc/dataset.py b/quacc/dataset.py
index ea0a058..45df048 100644
--- a/quacc/dataset.py
+++ b/quacc/dataset.py
@@ -1,29 +1,31 @@
-import itertools
 import math
 import os
 import pickle
 import tarfile
-from typing import List, Tuple
+from typing import List
 
 import numpy as np
 import quapy as qp
 from quapy.data.base import LabelledCollection
-from sklearn.conftest import fetch_rcv1
+from quapy.data.datasets import fetch_lequa2022, fetch_UCIMulticlassLabelledCollection
+from sklearn.datasets import fetch_20newsgroups, fetch_rcv1
+from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.utils import Bunch
 
-from quacc import utils
-from quacc.environment import env
+from quacc.legacy.environment import env
+from quacc.utils import commons
+from quacc.utils.commons import save_json_file
 
 TRAIN_VAL_PROP = 0.5
 
 
 def fetch_cifar10() -> Bunch:
     URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
-    data_home = utils.get_quacc_home()
+    data_home = commons.get_quacc_home()
     unzipped_path = data_home / "cifar-10-batches-py"
     if not unzipped_path.exists():
         downloaded_path = data_home / URL.split("/")[-1]
-        utils.download_file(URL, downloaded_path)
+        commons.download_file(URL, downloaded_path)
         with tarfile.open(downloaded_path) as f:
             f.extractall(data_home)
         os.remove(downloaded_path)
@@ -58,11 +60,11 @@ def fetch_cifar10() -> Bunch:
 
 def fetch_cifar100():
     URL = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
-    data_home = utils.get_quacc_home()
+    data_home = commons.get_quacc_home()
     unzipped_path = data_home / "cifar-100-python"
     if not unzipped_path.exists():
         downloaded_path = data_home / URL.split("/")[-1]
-        utils.download_file(URL, downloaded_path)
+        commons.download_file(URL, downloaded_path)
         with tarfile.open(downloaded_path) as f:
             f.extractall(data_home)
         os.remove(downloaded_path)
@@ -96,6 +98,23 @@ def fetch_cifar100():
     )
 
 
+def save_dataset_stats(path, test_prot, L, V):
+    test_prevs = [Ui.prevalence() for Ui in test_prot()]
+    shifts = [qp.error.ae(L.prevalence(), Ui_prev) for Ui_prev in test_prevs]
+    info = {
+        "n_classes": L.n_classes,
+        "n_train": len(L),
+        "n_val": len(V),
+        "train_prev": L.prevalence().tolist(),
+        "val_prev": V.prevalence().tolist(),
+        "test_prevs": [x.tolist() for x in test_prevs],
+        "shifts": [x.tolist() for x in shifts],
+        "sample_size": test_prot.sample_size,
+        "num_samples": test_prot.total(),
+    }
+    save_json_file(path, info)
+
+
 class DatasetSample:
     def __init__(
         self,
@@ -121,49 +140,69 @@ class DatasetSample:
 
 
 class DatasetProvider:
-    def __spambase(self, **kwargs):
-        return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test
+    @classmethod
+    def _split_train(cls, train: LabelledCollection):
+        return train.split_stratified(0.5, random_state=0)
 
-    # provare min_df=5
-    def __imdb(self, **kwargs):
-        return qp.datasets.fetch_reviews("imdb", tfidf=True, min_df=3).train_test
+    @classmethod
+    def _split_whole(cls, dataset: LabelledCollection):
+        train, U = dataset.split_stratified(train_prop=0.66, random_state=0)
+        T, V = train.split_stratified(train_prop=0.5, random_state=0)
+        return T, V, U
+
+    @classmethod
+    def spambase(cls):
+        train, U = qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test
+        T, V = cls._split_train(train)
+        return T, V, U
+
+    @classmethod
+    def imdb(cls):
+        train, U = qp.datasets.fetch_reviews(
+            "imdb", tfidf=True, min_df=10, pickle=True
+        ).train_test
+        T, V = cls._split_train(train)
+        return T, V, U
+
+    @classmethod
+    def rcv1(cls, target):
+        training = fetch_rcv1(subset="train")
+        test = fetch_rcv1(subset="test")
 
-    def __rcv1(self, target, **kwargs):
-        n_train = 23149
         available_targets = ["CCAT", "GCAT", "MCAT"]
+        if target is None or target not in available_targets:
+            raise ValueError(f"Invalid target {target}")
+
+        class_names = training.target_names.tolist()
+        class_idx = class_names.index(target)
+        tr_labels = training.target[:, class_idx].toarray().flatten()
+        te_labels = test.target[:, class_idx].toarray().flatten()
+        tr = LabelledCollection(training.data, tr_labels)
+        U = LabelledCollection(test.data, te_labels)
+        T, V = cls._split_train(tr)
+        return T, V, U
+
+    @classmethod
+    def cifar10(cls, target):
+        dataset = fetch_cifar10()
+        available_targets: list = dataset.label_names
 
         if target is None or target not in available_targets:
             raise ValueError(f"Invalid target {target}")
 
-        dataset = fetch_rcv1()
-        target_index = np.where(dataset.target_names == target)[0]
-        all_train_d = dataset.data[:n_train, :]
-        test_d = dataset.data[n_train:, :]
-        labels = dataset.target[:, target_index].toarray().flatten()
-        all_train_l, test_l = labels[:n_train], labels[n_train:]
-        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
-        test = LabelledCollection(test_d, test_l, classes=[0, 1])
-
-        return all_train, test
-
-    def __cifar10(self, target, **kwargs):
-        dataset = fetch_cifar10()
-        available_targets: list = dataset.label_names
-
-        if target is None or self._target not in available_targets:
-            raise ValueError(f"Invalid target {target}")
-
-        target_index = available_targets.index(target)
-        all_train_d = dataset.train.data
-        all_train_l = (dataset.train.labels == target_index).astype(int)
+        target_idx = available_targets.index(target)
+        train_d = dataset.train.data
+        train_l = (dataset.train.labels == target_idx).astype(int)
         test_d = dataset.test.data
-        test_l = (dataset.test.labels == target_index).astype(int)
-        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
-        test = LabelledCollection(test_d, test_l, classes=[0, 1])
+        test_l = (dataset.test.labels == target_idx).astype(int)
+        train = LabelledCollection(train_d, train_l, classes=[0, 1])
+        U = LabelledCollection(test_d, test_l, classes=[0, 1])
+        T, V = cls._split_train(train)
 
-        return all_train, test
+        return T, V, U
 
-    def __cifar100(self, target, **kwargs):
+    @classmethod
+    def cifar100(cls, target):
         dataset = fetch_cifar100()
         available_targets: list = dataset.coarse_label_names
 
@@ -171,31 +210,48 @@ class DatasetProvider:
             raise ValueError(f"Invalid target {target}")
 
         target_index = available_targets.index(target)
-        all_train_d = dataset.train.data
-        all_train_l = (dataset.train.coarse_labels == target_index).astype(int)
+        train_d = dataset.train.data
+        train_l = (dataset.train.coarse_labels == target_index).astype(int)
         test_d = dataset.test.data
         test_l = (dataset.test.coarse_labels == target_index).astype(int)
-        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
-        test = LabelledCollection(test_d, test_l, classes=[0, 1])
+        train = LabelledCollection(train_d, train_l, classes=[0, 1])
+        U = LabelledCollection(test_d, test_l, classes=[0, 1])
+        T, V = cls._split_train(train)
 
-        return all_train, test
+        return T, V, U
 
-    def __twitter_gasp(self, **kwargs):
-        return qp.datasets.fetch_twitter("gasp", min_df=3).train_test
+    @classmethod
+    def twitter(cls, dataset_name):
+        data = qp.datasets.fetch_twitter(dataset_name, min_df=3, pickle=True)
+        T, V = cls._split_train(data.training)
+        U = data.test
+        return T, V, U
 
-    def alltrain_test(
-        self, name: str, target: str | None
-    ) -> Tuple[LabelledCollection, LabelledCollection]:
-        all_train, test = {
-            "spambase": self.__spambase,
-            "imdb": self.__imdb,
-            "rcv1": self.__rcv1,
-            "cifar10": self.__cifar10,
-            "cifar100": self.__cifar100,
-            "twitter_gasp": self.__twitter_gasp,
-        }[name](target=target)
+    @classmethod
+    def uci_multiclass(cls, dataset_name):
+        dataset = fetch_UCIMulticlassLabelledCollection(dataset_name)
+        return cls._split_whole(dataset)
 
-        return all_train, test
+    @classmethod
+    def news20(cls):
+        train = fetch_20newsgroups(
+            subset="train", remove=("headers", "footers", "quotes")
+        )
+        test = fetch_20newsgroups(
+            subset="test", remove=("headers", "footers", "quotes")
+        )
+        tfidf = TfidfVectorizer(min_df=5, sublinear_tf=True)
+        Xtr = tfidf.fit_transform(train.data)
+        Xte = tfidf.transform(test.data)
+        train = LabelledCollection(instances=Xtr, labels=train.target)
+        U = LabelledCollection(instances=Xte, labels=test.target)
+        T, V = cls._split_train(train)
+        return T, V, U
+
+    @classmethod
+    def t1b_lequa2022(cls):
+        dataset, _, _ = fetch_lequa2022(task="T1B")
+        return cls._split_whole(dataset)
 
 
 class Dataset(DatasetProvider):
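
Reviewer note (not part of the patch): a minimal usage sketch of the refactored API. Every provider now returns a (T, V, U) triple (training, validation, and the held-out test pool) instead of the old (all_train, test) pair, and save_dataset_stats dumps split statistics to JSON. The protocol configuration, sample_size, and output path below are illustrative assumptions; since save_dataset_stats iterates test_prot() and calls Ui.prevalence() on each sample, the sketch assumes a QuaPy protocol built with return_type="labelled_collection".

    # Illustrative sketch only -- assumes QuaPy's APP protocol and the
    # refactored quacc modules introduced by this patch.
    from quapy.protocol import APP

    from quacc.dataset import DatasetProvider, save_dataset_stats

    # Each provider returns (T, V, U): training set, validation set, and
    # the pool from which test samples are drawn.
    T, V, U = DatasetProvider.imdb()

    # APP draws samples from U over a grid of artificial prevalences.
    # return_type="labelled_collection" makes each yielded sample Ui a
    # LabelledCollection, which save_dataset_stats relies on when it
    # computes Ui.prevalence().
    test_prot = APP(
        U,
        sample_size=500,  # assumed value, not taken from the patch
        n_prevalences=21,
        repeats=10,
        random_state=0,
        return_type="labelled_collection",
    )

    # Writes n_classes, split sizes, train/val/test prevalences, and the
    # prevalence shifts w.r.t. T to a JSON file via save_json_file.
    save_dataset_stats("stats/imdb.json", test_prot, T, V)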