diff --git a/quacc/dataset.py b/quacc/dataset.py
index fe91e92..ce97aec 100644
--- a/quacc/dataset.py
+++ b/quacc/dataset.py
@@ -1,14 +1,124 @@
 import math
+import os
+import pickle
+import tarfile
 from typing import List
 
 import numpy as np
 import quapy as qp
 from quapy.data.base import LabelledCollection
 from sklearn.conftest import fetch_rcv1
+from sklearn.utils import Bunch
+
+from quacc import utils
 
 TRAIN_VAL_PROP = 0.5
 
 
+def _download_and_extract(url: str, extracted_name: str):
+    """Return the extracted dataset directory, downloading it on first use.
+
+    The archive at ``url`` is saved under ``utils.get_quacc_home()``, unpacked
+    there and then deleted. Nothing is downloaded when ``extracted_name``
+    already exists in that directory.
+    """
+    data_home = utils.get_quacc_home()
+    extracted_path = data_home / extracted_name
+    if not extracted_path.exists():
+        archive_path = data_home / url.split("/")[-1]
+        utils.download_file(url, archive_path)
+        with tarfile.open(archive_path) as archive:
+            # Prefer the "data" extraction filter, which refuses members
+            # that would land outside data_home; interpreters without
+            # extraction filters fall back to the plain call.
+            if hasattr(tarfile, "data_filter"):
+                archive.extractall(data_home, filter="data")
+            else:
+                archive.extractall(data_home)
+        os.remove(archive_path)
+
+    return extracted_path
+
+
+def _load_batch(path):
+    """Unpickle one CIFAR batch or meta file; its keys are ``bytes``."""
+    # NOTE(review): pickle can run arbitrary code while loading; this is only
+    # acceptable because the archives come from a fixed, trusted URL.
+    with open(path, "rb") as file:
+        return pickle.load(file, encoding="bytes")
+
+
+def fetch_cifar10() -> Bunch:
+    """Load CIFAR-10, downloading and caching it when missing.
+
+    Returns a Bunch with ``train`` and ``test`` (each holding ``data`` and
+    ``labels`` arrays concatenated over the batch files) and ``label_names``
+    decoded to ``str``.
+    """
+    URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
+    unzipped_path = _download_and_extract(URL, "cifar-10-batches-py")
+
+    # sorted() keeps the batch order, and so the sample order, deterministic
+    file_names = sorted(os.listdir(unzipped_path))
+    datas = [
+        _load_batch(unzipped_path / f) for f in file_names if f.startswith("data")
+    ]
+    tests = [
+        _load_batch(unzipped_path / f) for f in file_names if f.startswith("test")
+    ]
+    meta = _load_batch(unzipped_path / "batches.meta")
+
+    return Bunch(
+        train=Bunch(
+            data=np.concatenate([d[b"data"] for d in datas], axis=0),
+            labels=np.concatenate([d[b"labels"] for d in datas]),
+        ),
+        test=Bunch(
+            data=np.concatenate([d[b"data"] for d in tests], axis=0),
+            labels=np.concatenate([d[b"labels"] for d in tests]),
+        ),
+        label_names=[cs.decode("utf-8") for cs in meta[b"label_names"]],
+    )
+
+
+def fetch_cifar100() -> Bunch:
+    """Load CIFAR-100, downloading and caching it when missing.
+
+    Returns a Bunch with ``train`` and ``test`` (each holding ``data``,
+    ``fine_labels`` and ``coarse_labels``) plus ``fine_label_names`` and
+    ``coarse_label_names`` decoded to ``str``.
+    """
+    URL = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
+    unzipped_path = _download_and_extract(URL, "cifar-100-python")
+
+    train_d = _load_batch(unzipped_path / "train")
+    test_d = _load_batch(unzipped_path / "test")
+    meta_d = _load_batch(unzipped_path / "meta")
+
+    train_bunch = Bunch(
+        data=train_d[b"data"],
+        fine_labels=np.array(train_d[b"fine_labels"]),
+        coarse_labels=np.array(train_d[b"coarse_labels"]),
+    )
+
+    test_bunch = Bunch(
+        data=test_d[b"data"],
+        fine_labels=np.array(test_d[b"fine_labels"]),
+        coarse_labels=np.array(test_d[b"coarse_labels"]),
+    )
+
+    # The pickle stores names as bytes; decode them (as fetch_cifar10 does)
+    # so callers can look targets up by plain str.
+    return Bunch(
+        train=train_bunch,
+        test=test_bunch,
+        fine_label_names=[n.decode("utf-8") for n in meta_d[b"fine_label_names"]],
+        coarse_label_names=[
+            n.decode("utf-8") for n in meta_d[b"coarse_label_names"]
+        ],
+    )
+
+
 class DatasetSample:
     def __init__(
         self,
@@ -71,13 +181,65 @@ class Dataset:
 
         return all_train, test
 
-    def get_raw(self) -> DatasetSample:
+    def __cifar10(self):
+        """Return CIFAR-10 as binary (all_train, test) collections.
+
+        Samples of class ``self._target`` get label 1 and all others 0;
+        an unset or unknown target raises ValueError.
+        """
+        dataset = fetch_cifar10()
+        available_targets: list = dataset.label_names
+
+        if self._target is None or self._target not in available_targets:
+            raise ValueError(f"Invalid target {self._target}")
+
+        target_index = available_targets.index(self._target)
+        all_train_d = dataset.train.data
+        all_train_l = (dataset.train.labels == target_index).astype(int)
+        test_d = dataset.test.data
+        test_l = (dataset.test.labels == target_index).astype(int)
+        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
+        test = LabelledCollection(test_d, test_l, classes=[0, 1])
+
+        return all_train, test
+
+    def __cifar100(self):
+        """Return CIFAR-100 as binary (all_train, test) collections.
+
+        Samples of coarse class ``self._target`` get label 1 and all others 0;
+        an unset or unknown target raises ValueError.
+        """
+        dataset = fetch_cifar100()
+        available_targets: list = dataset.coarse_label_names
+
+        if self._target is None or self._target not in available_targets:
+            raise ValueError(f"Invalid target {self._target}")
+
+        target_index = available_targets.index(self._target)
+        all_train_d = dataset.train.data
+        all_train_l = (dataset.train.coarse_labels == target_index).astype(int)
+        test_d = dataset.test.data
+        test_l = (dataset.test.coarse_labels == target_index).astype(int)
+        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
+        test = LabelledCollection(test_d, test_l, classes=[0, 1])
+
+        return all_train, test
+
+    def __train_test(self):
+        """Return the (all_train, test) pair of the dataset named ``self._name``."""
         all_train, test = {
             "spambase": self.__spambase,
             "imdb": self.__imdb,
             "rcv1": self.__rcv1,
+            "cifar10": self.__cifar10,
+            "cifar100": self.__cifar100,
         }[self._name]()
 
+        return all_train, test
+
+    def get_raw(self) -> DatasetSample:
+        all_train, test = self.__train_test()
+
         train, val = all_train.split_stratified(
             train_prop=TRAIN_VAL_PROP, random_state=0
         )
@@ -85,11 +247,7 @@ class Dataset:
         return DatasetSample(train, val, test)
 
     def get(self) -> List[DatasetSample]:
-        (all_train, test) = {
-            "spambase": self.__spambase,
-            "imdb": self.__imdb,
-            "rcv1": self.__rcv1,
-        }[self._name]()
+        all_train, test = self.__train_test()
 
         # resample all_train set to have (0.5, 0.5) prevalence
         at_positives = np.sum(all_train.y)
@@ -119,11 +277,20 @@ class Dataset:
 
     @property
     def name(self):
-        return (
-            f"{self._name}_{self._target}_{self.n_prevs}prevs"
-            if self._name == "rcv1"
-            else f"{self._name}_{self.n_prevs}prevs"
-        )
+        """Return the identifier of this dataset configuration.
+
+        rcv1 and CIFAR names include the target; the ``_{n_prevs}prevs``
+        suffix is omitted when ``n_prevs`` is 9.
+        """
+        match (self._name, self.n_prevs):
+            case (("rcv1" | "cifar10" | "cifar100"), 9):
+                return f"{self._name}_{self._target}"
+            case (("rcv1" | "cifar10" | "cifar100"), _):
+                return f"{self._name}_{self._target}_{self.n_prevs}prevs"
+            case (_, 9):
+                return f"{self._name}"
+            case (_, _):
+                return f"{self._name}_{self.n_prevs}prevs"
 
 
 # >>> fetch_rcv1().target_names
@@ -168,4 +335,4 @@ def rcv1_info():
 
 
 if __name__ == "__main__":
-    rcv1_info()
+    fetch_cifar100()