From 9ce0001047193d28e6355fa670787a90d5eff4de Mon Sep 17 00:00:00 2001 From: andreapdr Date: Mon, 12 Jun 2023 12:12:31 +0200 Subject: [PATCH] webis-unprocessed dataset --- dataManager/clsDataset.py | 101 +++++++++++++++++++++++++++++++------ dataManager/gFunDataset.py | 19 ++++++- dataManager/utils.py | 14 ++++- 3 files changed, 116 insertions(+), 18 deletions(-) diff --git a/dataManager/clsDataset.py b/dataManager/clsDataset.py index e81d126..12d0a09 100644 --- a/dataManager/clsDataset.py +++ b/dataManager/clsDataset.py @@ -1,5 +1,6 @@ import sys import os +import xml.etree.ElementTree as ET sys.path.append(os.getcwd()) @@ -8,13 +9,70 @@ import re from dataManager.multilingualDataset import MultilingualDataset CLS_PROCESSED_DATA_DIR = os.path.expanduser("~/datasets/cls-acl10-processed/") -LANGS = ["de", "en", "fr", "jp"] +CLS_UNPROCESSED_DATA_DIR = os.path.expanduser("~/datasets/cls-acl10-unprocessed/") +# LANGS = ["de", "en", "fr", "jp"] +LANGS = ["de", "en", "fr"] DOMAINS = ["books", "dvd", "music"] regex = r":\d+" subst = "" +def load_unprocessed_cls(reduce_target_space=False): + data = {} + for lang in LANGS: + data[lang] = {} + for domain in DOMAINS: + data[lang][domain] = {} + print(f"lang: {lang}, domain: {domain}") + for split in ["train", "test"]: + domain_data = [] + fdir = os.path.join( + CLS_UNPROCESSED_DATA_DIR, lang, domain, f"{split}.review" + ) + tree = ET.parse(fdir) + root = tree.getroot() + for child in root: + if reduce_target_space: + rating = np.zeros(3, dtype=int) + original_rating = int(float(child.find("rating").text)) + if original_rating < 3: + new_rating = 1 + elif original_rating > 3: + new_rating = 3 + else: + new_rating = 2 + rating[new_rating - 1] = 1 + else: + rating = np.zeros(5, dtype=int) + rating[int(float(child.find("rating").text)) - 1] = 1 + domain_data.append( + { + "asin": child.find("asin").text + if child.find("asin") is not None + else None, + "category": child.find("category").text + if child.find("category") is not None + else None, + # "rating": child.find("rating").text + # if child.find("rating") is not None + # else None, + "rating": rating, + "title": child.find("title").text + if child.find("title") is not None + else None, + "text": child.find("text").text + if child.find("text") is not None + else None, + "summary": child.find("summary").text + if child.find("summary") is not None + else None, + } + ) + data[lang][domain].update({split: domain_data}) + return data + + def load_cls(): data = {} for lang in LANGS: @@ -24,7 +82,7 @@ def load_cls(): train = ( open( os.path.join( - CLS_PROCESSED_DATA_DIR, lang, domain, "train.processed" + CLS_UNPROCESSED_DATA_DIR, lang, domain, "train.processed" ), "r", ) @@ -34,7 +92,7 @@ def load_cls(): test = ( open( os.path.join( - CLS_PROCESSED_DATA_DIR, lang, domain, "test.processed" + CLS_UNPROCESSED_DATA_DIR, lang, domain, "test.processed" ), "r", ) @@ -59,18 +117,29 @@ def process_data(line): if __name__ == "__main__": - print(f"datapath: {CLS_PROCESSED_DATA_DIR}") - data = load_cls() - multilingualDataset = MultilingualDataset(dataset_name="cls") - for lang in LANGS: - # TODO: just using book domain atm - Xtr = [text[0] for text in data[lang]["books"]["train"]] - # Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1) - Ytr = np.vstack([text[1] for text in data[lang]["books"]["train"]]) + print(f"datapath: {CLS_UNPROCESSED_DATA_DIR}") + # data = load_cls() + data = load_unprocessed_cls(reduce_target_space=True) + multilingualDataset = MultilingualDataset(dataset_name="webis-cls-unprocessed") - Xte = [text[0] for text in data[lang]["books"]["test"]] - # Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1) - Yte = np.vstack([text[1] for text in data[lang]["books"]["test"]]) + for lang in LANGS: + # Xtr = [text["summary"] for text in data[lang]["books"]["train"]] + Xtr = [text["text"] for text in data[lang]["books"]["train"]] + Ytr = np.vstack([text["rating"] for text in data[lang]["books"]["train"]]) + + # Xte = [text["summary"] for text in data[lang]["books"]["test"]] + Xte = [text["text"] for text in data[lang]["books"]["test"]] + Yte = np.vstack([text["rating"] for text in data[lang]["books"]["test"]]) + + # for lang in LANGS: + # # TODO: just using book domain atm + # Xtr = [text[0] for text in data[lang]["books"]["train"]] + # # Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1) + # Ytr = np.vstack([text[1] for text in data[lang]["books"]["train"]]) + + # Xte = [text[0] for text in data[lang]["books"]["test"]] + # # Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1) + # Yte = np.vstack([text[1] for text in data[lang]["books"]["test"]]) multilingualDataset.add( lang=lang, @@ -82,5 +151,7 @@ if __name__ == "__main__": te_ids=None, ) multilingualDataset.save( - os.path.expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl") + os.path.expanduser( + "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-book.pkl" + ) ) diff --git a/dataManager/gFunDataset.py b/dataManager/gFunDataset.py index 679c362..243593d 100644 --- a/dataManager/gFunDataset.py +++ b/dataManager/gFunDataset.py @@ -62,14 +62,29 @@ class gFunDataset: ) self.mlb = self.get_label_binarizer(self.labels) - elif "cls" in self.dataset_dir.lower(): - print(f"- Loading CLS dataset from {self.dataset_dir}") + # WEBIS-CLS (processed) + elif ( + "cls" in self.dataset_dir.lower() + and "unprocessed" not in self.dataset_dir.lower() + ): + print(f"- Loading WEBIS-CLS (processed) dataset from {self.dataset_dir}") self.dataset_name = "cls" self.dataset, self.labels, self.data_langs = self._load_multilingual( self.dataset_name, self.dataset_dir, self.nrows ) self.mlb = self.get_label_binarizer(self.labels) + # WEBIS-CLS (unprocessed) + elif ( + "cls" in self.dataset_dir.lower() + and "unprocessed" in self.dataset_dir.lower() + ): + print(f"- Loading WEBIS-CLS (unprocessed) dataset from {self.dataset_dir}") + self.dataset_name = "cls" + self.dataset, self.labels, self.data_langs = self._load_multilingual( + self.dataset_name, self.dataset_dir, self.nrows + ) + self.mlb = self.get_label_binarizer(self.labels) self.show_dimension() return diff --git a/dataManager/utils.py b/dataManager/utils.py index 0f870bd..b5ee50b 100644 --- a/dataManager/utils.py +++ b/dataManager/utils.py @@ -23,6 +23,7 @@ def get_dataset(dataset_name, args): "rcv1-2", "glami", "cls", + "webis", ], "dataset not supported" RCV_DATAPATH = expanduser( @@ -37,7 +38,9 @@ def get_dataset(dataset_name, args): GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset") - WEBIS_CLS = expanduser("~/dataset/cls-acl10-unprocessed") + WEBIS_CLS = expanduser( + "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-book.pkl" + ) if dataset_name == "multinews": # TODO: convert to gFunDataset @@ -93,6 +96,15 @@ def get_dataset(dataset_name, args): is_multilabel=False, nrows=args.nrows, ) + + elif dataset_name == "webis": + dataset = gFunDataset( + dataset_dir=WEBIS_CLS, + is_textual=True, + is_visual=False, + is_multilabel=False, + nrows=args.nrows, + ) else: raise NotImplementedError return dataset