From 57918ec5237e0038b42a596bfac08d4dc1a66bde Mon Sep 17 00:00:00 2001
From: andreapdr
Date: Fri, 10 Mar 2023 12:40:26 +0100
Subject: [PATCH] save and load datasets as pkl

---
 dataManager/gFunDataset.py | 14 ++++++++++++--
 dataManager/utils.py       | 34 ++++++++++++++++++++++++++--------
 gfun/vgfs/commons.py       |  5 ++++-
 main.py                    |  4 +++-
 4 files changed, 45 insertions(+), 12 deletions(-)

diff --git a/dataManager/gFunDataset.py b/dataManager/gFunDataset.py
index a0040ec..10fcbb3 100644
--- a/dataManager/gFunDataset.py
+++ b/dataManager/gFunDataset.py
@@ -1,3 +1,5 @@
+import os
+
 from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
 from dataManager.glamiDataset import get_dataframe
 from dataManager.multilingualDataset import MultilingualDataset
@@ -22,7 +24,7 @@ class gFunDataset:
         self.labels = labels
         self.nrows = nrows
         self.dataset = {}
-        self.load_dataset()
+        self._load_dataset()
 
     def get_label_binarizer(self, labels):
         if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
@@ -35,7 +37,7 @@ class gFunDataset:
         mlb.fit(labels)
         return mlb
 
-    def load_dataset(self):
+    def _load_dataset(self):
         if "glami" in self.dataset_dir.lower():
             print(f"- Loading GLAMI dataset from {self.dataset_dir}")
             self.dataset_name = "glami"
@@ -205,6 +207,14 @@
         else:
             return self.labels
 
+    def save_as_pickle(self, path):
+        import pickle
+
+        filepath = os.path.join(path, f"{self.dataset_name}_{self.nrows}.pkl")
+        with open(filepath, "wb") as f:
+            print(f"- saving dataset in {filepath}")
+            pickle.dump(self, f)
+
 
 if __name__ == "__main__":
     import os
diff --git a/dataManager/utils.py b/dataManager/utils.py
index 4dfa953..d20bbbb 100644
--- a/dataManager/utils.py
+++ b/dataManager/utils.py
@@ -1,9 +1,21 @@
-from os.path import expanduser
+from os.path import expanduser, join
 from dataManager.gFunDataset import gFunDataset
 from dataManager.multiNewsDataset import MultiNewsDataset
 from dataManager.amazonDataset import AmazonDataset
 
 
+def load_from_pickle(path, dataset_name, nrows):
+    import pickle
+
+    filepath = join(path, f"{dataset_name}_{nrows}.pkl")
+
+    with open(filepath, "rb") as f:
+        loaded = pickle.load(f)
+        print(f"- Loaded dataset from {filepath}")
+        loaded.show_dimension()
+    return loaded
+
+
 def get_dataset(dataset_name, args):
     assert dataset_name in [
         "multinews",
@@ -58,13 +70,19 @@
             nrows=args.nrows,
         )
     elif dataset_name == "glami":
-        dataset = gFunDataset(
-            dataset_dir=GLAMI_DATAPATH,
-            is_textual=True,
-            is_visual=True,
-            is_multilabel=False,
-            nrows=args.nrows,
-        )
+        if args.save_dataset is False:
+            dataset = load_from_pickle(GLAMI_DATAPATH, dataset_name, args.nrows)
+        else:
+            dataset = gFunDataset(
+                dataset_dir=GLAMI_DATAPATH,
+                is_textual=True,
+                is_visual=True,
+                is_multilabel=False,
+                nrows=args.nrows,
+            )
+
+            dataset.save_as_pickle(GLAMI_DATAPATH)
+
     elif dataset_name == "cls":
         dataset = gFunDataset(
             dataset_dir=CLS_DATAPATH,
diff --git a/gfun/vgfs/commons.py b/gfun/vgfs/commons.py
index 6c4d59c..b1f01cd 100644
--- a/gfun/vgfs/commons.py
+++ b/gfun/vgfs/commons.py
@@ -167,6 +167,7 @@
         )
         self.clf_type = classification_type
         self.vgf_name = vgf_name
+        self.scheduler_name = scheduler_name
         self.n_jobs = n_jobs
         self.monitored_metric = (
             "macro-F1" if self.clf_type == "multilabel" else "accuracy"
@@ -190,7 +191,9 @@
             "model name": self.model.name_or_path,
             "epochs": epochs,
             "learning rate": self.optimizer.defaults["lr"],
-            "scheduler": "TODO",  # TODO: add scheduler name
+            "scheduler": self.scheduler_name,  # TODO: add scheduler params
+            "train size": len(train_dataloader.dataset),
+            "eval size": len(eval_dataloader.dataset),
             "train batch size": train_dataloader.batch_size,
             "eval batch size": eval_dataloader.batch_size,
             "max len": train_dataloader.dataset.X.shape[-1],
diff --git a/main.py b/main.py
index 05ef6ee..f376fa2 100644
--- a/main.py
+++ b/main.py
@@ -1,6 +1,6 @@
 import os
 
-os.environ["CUDA_VISIBLE_DEVICES"] = "3"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 
 from argparse import ArgumentParser
 from time import time
@@ -12,6 +12,7 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
 """
 TODO:
     - [!] add support for mT5
+    - [!] log on wandb also the other VGF results + final results
     - [!] CLS dataset is loading only "books" domain data
     - [!] documents should be trimmed to the same length (?)
     - [!] overall gfun results logger
@@ -120,6 +121,7 @@ if __name__ == "__main__":
     parser.add_argument("--min_count", type=int, default=10)
     parser.add_argument("--max_labels", type=int, default=50)
     parser.add_argument("--clf_type", type=str, default="multilabel")
+    parser.add_argument("--save_dataset", action="store_true")
     # gFUN parameters ----------------------
     parser.add_argument("-p", "--posteriors", action="store_true")
     parser.add_argument("-m", "--multilingual", action="store_true")
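
Usage sketch (illustrative only, not part of the patch): the caching round trip introduced above, assuming GLAMI_DATAPATH resolves to the GLAMI data directory configured in dataManager/utils.py and that nrows matches between the run that writes the pickle and the run that reads it.

    # Sketch of the pickle round trip added by this patch.
    # GLAMI_DATAPATH and nrows below are placeholders, not values from the repository.
    from dataManager.gFunDataset import gFunDataset
    from dataManager.utils import load_from_pickle

    GLAMI_DATAPATH = "/path/to/GLAMI-1M"  # placeholder; the real path is defined in dataManager/utils.py

    # First run (roughly what get_dataset() does when --save_dataset is passed):
    # build the dataset from the raw files and cache it as glami_<nrows>.pkl.
    dataset = gFunDataset(
        dataset_dir=GLAMI_DATAPATH,
        is_textual=True,
        is_visual=True,
        is_multilabel=False,
        nrows=10000,
    )
    dataset.save_as_pickle(GLAMI_DATAPATH)  # writes glami_10000.pkl

    # Later runs (no --save_dataset): skip parsing and unpickle the cached object.
    dataset = load_from_pickle(GLAMI_DATAPATH, "glami", nrows=10000)

Note that without --save_dataset the pickle must already exist, since get_dataset() then calls load_from_pickle() unconditionally for the GLAMI branch.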