# TODO: this should be an instance of an abstract MultilingualDataset
from abc import ABC, abstractmethod
from os.path import expanduser
import pickle
import re

import numpy as np
from scipy.sparse import issparse, csr_matrix
from tqdm import tqdm


class NewMultilingualDataset(ABC):
    """Sketch of the abstract interface referenced by the TODO above."""

    @abstractmethod
    def get_training(self):
        pass

    @abstractmethod
    def get_validation(self):
        pass

    @abstractmethod
    def get_test(self):
        pass

    @abstractmethod
    def mask_numbers(self):
        pass

    @abstractmethod
    def save(self):
        pass

    @abstractmethod
    def load(self):
        pass


# class RcvMultilingualDataset(MultilingualDataset):
class RcvMultilingualDataset:
    def __init__(self, run="0"):
        self.dataset_name = "rcv1-2"
        self.dataset_path = expanduser(
            f"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run{run}.pickle"
        )

    def load(self):
        # Keep a reference to the unpickled object; previously it was read and discarded.
        with open(self.dataset_path, "rb") as f:
            self.data = pickle.load(f)
        return self


class MultilingualDataset:
    """
    A multilingual dataset is a dictionary of training and test documents indexed by language code.
    Train and test sets are represented as tuples of the type (X, Y, ids), where X is a matrix representation of the
    documents (e.g., a document-by-term sparse csr_matrix), Y is a document-by-label binary np.array indicating the
    labels of each document, and ids is a list of document identifiers from the original collection.
    """

    def __init__(self, dataset_name):
        self.dataset_name = dataset_name
        self.multiling_dataset = {}
        print(f"[Init Multilingual Dataset: {self.dataset_name}]")

    def add(self, lang, Xtr, Ytr, Xte, Yte, tr_ids=None, te_ids=None):
        self.multiling_dataset[lang] = ((Xtr, Ytr, tr_ids), (Xte, Yte, te_ids))

    def save(self, file):
        self.sort_indexes()
        with open(file, "wb") as f:
            pickle.dump(self, f, pickle.HIGHEST_PROTOCOL)
        return self

    def __getitem__(self, item):
        if item in self.langs():
            return self.multiling_dataset[item]
        return None

    @classmethod
    def load(cls, file):
        with open(file, "rb") as f:
            data = pickle.load(f)
        data.sort_indexes()
        return data

    @classmethod
    def load_ids(cls, file):
        with open(file, "rb") as f:
            data = pickle.load(f)
        tr_ids = {
            lang: tr_ids
            for lang, ((_, _, tr_ids), _) in data.multiling_dataset.items()
        }
        te_ids = {
            lang: te_ids
            for lang, (_, (_, _, te_ids)) in data.multiling_dataset.items()
        }
        return tr_ids, te_ids

    def sort_indexes(self):
        for lang, ((Xtr, _, _), (Xte, _, _)) in self.multiling_dataset.items():
            if issparse(Xtr):
                Xtr.sort_indices()
            if issparse(Xte):
                Xte.sort_indices()

    def set_view(self, categories=None, languages=None):
        if categories is not None:
            if isinstance(categories, int):
                categories = np.array([categories])
            elif isinstance(categories, list):
                categories = np.array(categories)
            self.categories_view = categories
        if languages is not None:
            self.languages_view = languages

    def training(self, mask_numbers=False, target_as_csr=False):
        return self.lXtr(mask_numbers), self.lYtr(as_csr=target_as_csr)

    def test(self, mask_numbers=False, target_as_csr=False):
        return self.lXte(mask_numbers), self.lYte(as_csr=target_as_csr)

    def lXtr(self, mask_numbers=False):
        proc = lambda x: _mask_numbers(x) if mask_numbers else x
        return {
            lang: proc(Xtr)
            for lang, ((Xtr, _, _), _) in self.multiling_dataset.items()
            if lang in self.langs()
        }

    def lXte(self, mask_numbers=False):
        proc = lambda x: _mask_numbers(x) if mask_numbers else x
        return {
            lang: proc(Xte)
            for lang, (_, (Xte, _, _)) in self.multiling_dataset.items()
            if lang in self.langs()
        }

    def lYtr(self, as_csr=False):
        lY = {
            lang: self.cat_view(Ytr)
            for lang, ((_, Ytr, _), _) in self.multiling_dataset.items()
            if lang in self.langs()
        }
        if as_csr:
            lY = {l: csr_matrix(Y) for l, Y in lY.items()}
        return lY

    def lYte(self, as_csr=False):
        lY = {
            lang: self.cat_view(Yte)
            for lang, (_, (_, Yte, _)) in self.multiling_dataset.items()
            if lang in self.langs()
        }
        if as_csr:
            lY = {l: csr_matrix(Y) for l, Y in lY.items()}
        return lY

    def cat_view(self, Y):
        if hasattr(self, "categories_view"):
            return Y[:, self.categories_view]
        return Y

    def langs(self):
        if hasattr(self, "languages_view"):
            return self.languages_view
        return sorted(self.multiling_dataset.keys())

    def num_categories(self):
        return self.lYtr()[self.langs()[0]].shape[1]

    def show_dimensions(self):
        def shape(X):
            return X.shape if hasattr(X, "shape") else len(X)

        for lang, (
            (Xtr, Ytr, IDtr),
            (Xte, Yte, IDte),
        ) in self.multiling_dataset.items():
            if lang not in self.langs():
                continue
            print(
                "Lang {}, Xtr={}, ytr={}, Xte={}, yte={}".format(
                    lang,
                    shape(Xtr),
                    self.cat_view(Ytr).shape,
                    shape(Xte),
                    self.cat_view(Yte).shape,
                )
            )

    def show_category_prevalences(self):
        nC = self.num_categories()
        # np.int was removed from NumPy; the builtin int is the correct dtype here.
        accum_tr = np.zeros(nC, dtype=int)
        accum_te = np.zeros(nC, dtype=int)
        # count languages with at least one positive example (per category)
        in_langs = np.zeros(nC, dtype=int)
        for lang, (
            (Xtr, Ytr, IDtr),
            (Xte, Yte, IDte),
        ) in self.multiling_dataset.items():
            if lang not in self.langs():
                continue
            prev_train = np.sum(self.cat_view(Ytr), axis=0)
            prev_test = np.sum(self.cat_view(Yte), axis=0)
            accum_tr += prev_train
            accum_te += prev_test
            in_langs += (prev_train > 0) * 1
            print(lang + "-train", prev_train)
            print(lang + "-test", prev_test)
        print("all-train", accum_tr)
        print("all-test", accum_te)
        return accum_tr, accum_te, in_langs

    def set_labels(self, labels):
        self.labels = labels

    def reduce_data(self, langs=("it", "en"), maxn=50):
        # A tuple default avoids the mutable-default-argument pitfall.
        print(f"- Reducing data: {langs} with max {maxn} documents...\n")
        self.set_view(languages=langs)
        self.multiling_dataset = {
            lang: self._reduce(data, maxn)
            for lang, data in self.multiling_dataset.items()
            if lang in langs
        }
        return self

    def _reduce(self, lang_data, maxn):
        # Truncate both splits (train, test) of a single language to the first maxn documents.
        new_data = []
        for docs, labels, ids in lang_data:
            new_data.append((docs[:maxn], labels[:maxn], ids[:maxn]))
        return new_data


def _mask_numbers(data):
    mask_moredigit = re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b")
    mask_4digit = re.compile(r"\s[\+-]?\d{4}([\.,]\d*)*\b")
    mask_3digit = re.compile(r"\s[\+-]?\d{3}([\.,]\d*)*\b")
    mask_2digit = re.compile(r"\s[\+-]?\d{2}([\.,]\d*)*\b")
    mask_1digit = re.compile(r"\s[\+-]?\d{1}([\.,]\d*)*\b")
    masked = []
    for text in tqdm(data, desc="masking numbers"):
        text = " " + text
        text = mask_moredigit.sub(" MoreDigitMask", text)
        text = mask_4digit.sub(" FourDigitMask", text)
        text = mask_3digit.sub(" ThreeDigitMask", text)
        text = mask_2digit.sub(" TwoDigitMask", text)
        text = mask_1digit.sub(" OneDigitMask", text)
        masked.append(text.replace(".", "").replace(",", "").strip())
    return masked
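# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): builds a toy
# MultilingualDataset with random data to illustrate the (X, Y, ids) layout
# described in the class docstring. The language codes, shapes, and the
# helper name _toy_example are illustrative assumptions only.
def _toy_example():
    rng = np.random.default_rng(0)
    toy = MultilingualDataset("toy")
    for lang in ("en", "it"):
        Xtr = csr_matrix(rng.random((10, 5)))          # 10 training docs, 5 terms
        Ytr = (rng.random((10, 3)) > 0.5).astype(int)  # 3 binary labels per doc
        Xte = csr_matrix(rng.random((4, 5)))           # 4 test docs
        Yte = (rng.random((4, 3)) > 0.5).astype(int)
        toy.add(lang, Xtr, Ytr, Xte, Yte, tr_ids=list(range(10)), te_ids=list(range(4)))
    lXtr, lYtr = toy.training()  # per-language {lang: X} and {lang: Y} dicts
    toy.show_dimensions()  # e.g. "Lang en, Xtr=(10, 5), ytr=(10, 3), Xte=(4, 5), yte=(4, 3)"
    return lXtr, lYtr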
if __name__ == "__main__":
    DATAPATH = expanduser(
        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
    )
    print(DATAPATH)
    # load() is a classmethod and __init__ requires a dataset name, so the
    # dataset is loaded directly from the class rather than from an instance.
    dataset = MultilingualDataset.load(DATAPATH)
    # show_dimensions() prints its report and returns None, so its result is not printed.
    dataset.show_dimensions()
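# ---------------------------------------------------------------------------
# Worked example of _mask_numbers (not part of the original module). The
# sample sentence is invented; the expected output follows from the regexes
# defined above:
#
#   _mask_numbers(["He paid 1500 euros on 3 May"])
#   -> ["He paid FourDigitMask euros on OneDigitMask May"]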