diff --git a/quapy/data/datasets.py b/quapy/data/datasets.py index d2060a5..2fe08c4 100644 --- a/quapy/data/datasets.py +++ b/quapy/data/datasets.py @@ -12,6 +12,7 @@ from quapy.data.base import Dataset, LabelledCollection from quapy.data.preprocessing import text2tfidf, reduce_columns from quapy.data.reader import * from quapy.util import download_file_if_not_exists, download_file, get_quapy_home, pickled_resource +from sklearn.preprocessing import StandardScaler REVIEWS_SENTIMENT_DATASETS = ['hp', 'kindle', 'imdb'] @@ -21,11 +22,13 @@ TWITTER_SENTIMENT_DATASETS_TEST = [ 'semeval13', 'semeval14', 'semeval15', 'semeval16', 'sst', 'wa', 'wb', ] + TWITTER_SENTIMENT_DATASETS_TRAIN = [ 'gasp', 'hcr', 'omd', 'sanders', 'semeval', 'semeval16', 'sst', 'wa', 'wb', ] + UCI_BINARY_DATASETS = [ #'acute.a', 'acute.b', 'balance.1', @@ -230,7 +233,7 @@ def fetch_twitter(dataset_name, for_model_selection=False, min_df=None, data_hom return data -def fetch_UCIBinaryDataset(dataset_name, data_home=None, test_split=0.3, verbose=False) -> Dataset: +def fetch_UCIBinaryDataset(dataset_name, data_home=None, test_split=0.3, standardize=True, verbose=False) -> Dataset: """ Loads a UCI dataset as an instance of :class:`quapy.data.base.Dataset`, as used in `Pérez-Gállego, P., Quevedo, J. R., & del Coz, J. J. (2017). @@ -248,14 +251,20 @@ def fetch_UCIBinaryDataset(dataset_name, data_home=None, test_split=0.3, verbose :param data_home: specify the quapy home directory where collections will be dumped (leave empty to use the default ~/quay_data/ directory) :param test_split: proportion of documents to be included in the test set. The rest conforms the training set + :param standardize: indicates whether the covariates should be standardized or not (default is True). If requested, + standardization applies after the LabelledCollection is split, that is, the mean an std are computed only on the + training portion of the data. :param verbose: set to True (default is False) to get information (from the UCI ML repository) about the datasets :return: a :class:`quapy.data.base.Dataset` instance """ data = fetch_UCIBinaryLabelledCollection(dataset_name, data_home, verbose) - return Dataset(*data.split_stratified(1 - test_split, random_state=0), name=dataset_name) + dataset = Dataset(*data.split_stratified(1 - test_split, random_state=0), name=dataset_name) + if standardize: + dataset = qp.data.preprocessing.standardize(dataset) + return dataset -def fetch_UCIBinaryLabelledCollection(dataset_name, data_home=None, verbose=False) -> LabelledCollection: +def fetch_UCIBinaryLabelledCollection(dataset_name, data_home=None, standardize=True, verbose=False) -> LabelledCollection: """ Loads a UCI collection as an instance of :class:`quapy.data.base.LabelledCollection`, as used in `Pérez-Gállego, P., Quevedo, J. R., & del Coz, J. J. (2017). @@ -279,6 +288,7 @@ def fetch_UCIBinaryLabelledCollection(dataset_name, data_home=None, verbose=Fals :param dataset_name: a dataset name :param data_home: specify the quapy home directory where collections will be dumped (leave empty to use the default ~/quay_data/ directory) + :param standardize: indicates whether the covariates should be standardized or not (default is True). :param verbose: set to True (default is False) to get information (from the UCI ML repository) about the datasets :return: a :class:`quapy.data.base.LabelledCollection` instance """ @@ -568,6 +578,10 @@ def fetch_UCIBinaryLabelledCollection(dataset_name, data_home=None, verbose=Fals data = pickled_resource(file, download, identifier, dataset_group) data = binarize_data(dataset_name, data) + if standardize: + stds = StandardScaler() + data.instances = stds.fit_transform(data.instances) + if verbose: data.stats() @@ -580,6 +594,7 @@ def fetch_UCIMulticlassDataset( min_test_split=0.3, max_train_instances=25000, min_class_support=100, + standardize=True, verbose=False) -> Dataset: """ Loads a UCI multiclass dataset as an instance of :class:`quapy.data.base.Dataset`. @@ -610,6 +625,9 @@ def fetch_UCIMulticlassDataset( set to -1 or None to avoid this check :param min_class_support: minimum number of istances per class. Classes with fewer instances are discarded (deafult is 100) + :param standardize: indicates whether the covariates should be standardized or not (default is True). If requested, + standardization applies after the LabelledCollection is split, that is, the mean an std are computed only on the + training portion of the data. :param verbose: set to True (default is False) to get information (stats) about the dataset :return: a :class:`quapy.data.base.Dataset` instance """ @@ -622,10 +640,15 @@ def fetch_UCIMulticlassDataset( if n_train > max_train_instances: train_prop = (max_train_instances / n) - return Dataset(*data.split_stratified(train_prop, random_state=0)) + data = Dataset(*data.split_stratified(train_prop, random_state=0)) + + if standardize: + data = qp.data.preprocessing.standardize(data) + + return data -def fetch_UCIMulticlassLabelledCollection(dataset_name, data_home=None, min_class_support=100, verbose=False) -> LabelledCollection: +def fetch_UCIMulticlassLabelledCollection(dataset_name, data_home=None, min_class_support=100, standardize=True, verbose=False) -> LabelledCollection: """ Loads a UCI multiclass collection as an instance of :class:`quapy.data.base.LabelledCollection`. @@ -649,6 +672,7 @@ def fetch_UCIMulticlassLabelledCollection(dataset_name, data_home=None, min_clas ~/quay_data/ directory) :param min_class_support: minimum number of istances per class. Classes with fewer instances are discarded (deafult is 100) + :param standardize: indicates whether the covariates should be standardized or not (default is True). :param verbose: set to True (default is False) to get information (stats) about the dataset :return: a :class:`quapy.data.base.LabelledCollection` instance """ @@ -755,6 +779,10 @@ def fetch_UCIMulticlassLabelledCollection(dataset_name, data_home=None, min_clas f'is no longer multiclass. Try a reducing this value.' ) + if standardize: + stds = StandardScaler() + data.instances = stds.fit_transform(data.instances) + if verbose: data.stats()