save and load datasets as pkl

Andrea Pedrotti 2023-03-10 12:40:26 +01:00
parent 7d0d6ba1f6
commit 57918ec523
4 changed files with 45 additions and 12 deletions

View File

@@ -1,3 +1,5 @@
+import os
+
 from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
 from dataManager.glamiDataset import get_dataframe
 from dataManager.multilingualDataset import MultilingualDataset
@@ -22,7 +24,7 @@ class gFunDataset:
         self.labels = labels
         self.nrows = nrows
         self.dataset = {}
-        self.load_dataset()
+        self._load_dataset()

     def get_label_binarizer(self, labels):
         if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
@ -35,7 +37,7 @@ class gFunDataset:
mlb.fit(labels) mlb.fit(labels)
return mlb return mlb
def load_dataset(self): def _load_dataset(self):
if "glami" in self.dataset_dir.lower(): if "glami" in self.dataset_dir.lower():
print(f"- Loading GLAMI dataset from {self.dataset_dir}") print(f"- Loading GLAMI dataset from {self.dataset_dir}")
self.dataset_name = "glami" self.dataset_name = "glami"
@@ -205,6 +207,14 @@ class gFunDataset:
         else:
             return self.labels

+    def save_as_pickle(self, path):
+        import pickle
+
+        filepath = os.path.join(path, f"{self.dataset_name}_{self.nrows}.pkl")
+        with open(filepath, "wb") as f:
+            print(f"- saving dataset in {filepath}")
+            pickle.dump(self, f)
+

 if __name__ == "__main__":
     import os
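Taken together with load_from_pickle below, the new method gives a simple round-trip cache for a built dataset. A minimal sketch of how the two ends of the round trip would be used outside of get_dataset; the value of GLAMI_DATAPATH and the dataManager module layout are assumptions, the constructor arguments follow the call shown later in this commit:

import os
import pickle

from dataManager.gFunDataset import gFunDataset

GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M")  # assumed location

# First run: build the dataset from the raw files and cache it as glami_<nrows>.pkl
dataset = gFunDataset(
    dataset_dir=GLAMI_DATAPATH,
    is_textual=True,
    is_visual=True,
    is_multilabel=False,
    nrows=10_000,
)
dataset.save_as_pickle(GLAMI_DATAPATH)

# Later runs: unpickle the cached object instead of re-parsing the raw data
filepath = os.path.join(GLAMI_DATAPATH, f"{dataset.dataset_name}_{dataset.nrows}.pkl")
with open(filepath, "rb") as f:
    cached = pickle.load(f)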

View File

@@ -1,9 +1,21 @@
-from os.path import expanduser
+from os.path import expanduser, join
 from dataManager.gFunDataset import gFunDataset
 from dataManager.multiNewsDataset import MultiNewsDataset
 from dataManager.amazonDataset import AmazonDataset


+def load_from_pickle(path, dataset_name, nrows):
+    import pickle
+
+    filepath = join(path, f"{dataset_name}_{nrows}.pkl")
+    with open(filepath, "rb") as f:
+        loaded = pickle.load(f)
+    print(f"- Loaded dataset from {filepath}")
+    loaded.show_dimension()
+    return loaded
+
+
 def get_dataset(dataset_name, args):
     assert dataset_name in [
         "multinews",
@@ -58,13 +70,19 @@ def get_dataset(dataset_name, args):
             nrows=args.nrows,
         )
     elif dataset_name == "glami":
-        dataset = gFunDataset(
-            dataset_dir=GLAMI_DATAPATH,
-            is_textual=True,
-            is_visual=True,
-            is_multilabel=False,
-            nrows=args.nrows,
-        )
+        if args.save_dataset is False:
+            dataset = load_from_pickle(GLAMI_DATAPATH, dataset_name, args.nrows)
+        else:
+            dataset = gFunDataset(
+                dataset_dir=GLAMI_DATAPATH,
+                is_textual=True,
+                is_visual=True,
+                is_multilabel=False,
+                nrows=args.nrows,
+            )
+            dataset.save_as_pickle(GLAMI_DATAPATH)
     elif dataset_name == "cls":
         dataset = gFunDataset(
             dataset_dir=CLS_DATAPATH,
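Note the flag semantics in the branch above: when --save_dataset is not passed (so args.save_dataset is False) the GLAMI dataset is read back from the pickle cache, and only when it is passed does get_dataset rebuild the dataset and overwrite the cache. A small sketch of both paths, assuming get_dataset lives in dataManager.utils and using argparse.Namespace to stand in for the parsed CLI arguments:

from argparse import Namespace

from dataManager.utils import get_dataset  # module path is an assumption

# First run: parse the raw GLAMI data and write glami_<nrows>.pkl under GLAMI_DATAPATH
args = Namespace(save_dataset=True, nrows=10_000)
dataset = get_dataset("glami", args)

# Subsequent runs: skip parsing and unpickle the cached dataset
args = Namespace(save_dataset=False, nrows=10_000)
dataset = get_dataset("glami", args)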

View File

@@ -167,6 +167,7 @@ class Trainer:
         )
         self.clf_type = classification_type
         self.vgf_name = vgf_name
+        self.scheduler_name = scheduler_name
         self.n_jobs = n_jobs
         self.monitored_metric = (
             "macro-F1" if self.clf_type == "multilabel" else "accuracy"
@@ -190,7 +191,9 @@ class Trainer:
             "model name": self.model.name_or_path,
             "epochs": epochs,
             "learning rate": self.optimizer.defaults["lr"],
-            "scheduler": "TODO",  # TODO: add scheduler name
+            "scheduler": self.scheduler_name,  # TODO: add scheduler params
+            "train size": len(train_dataloader.dataset),
+            "eval size": len(eval_dataloader.dataset),
             "train batch size": train_dataloader.batch_size,
             "eval batch size": eval_dataloader.batch_size,
             "max len": train_dataloader.dataset.X.shape[-1],

View File

@@ -1,6 +1,6 @@
 import os

-os.environ["CUDA_VISIBLE_DEVICES"] = "3"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"

 from argparse import ArgumentParser
 from time import time
@@ -12,6 +12,7 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
 """
 TODO:
     - [!] add support for mT5
+    - [!] log on wandb also the other VGF results + final results
     - [!] CLS dataset is loading only "books" domain data
     - [!] documents should be trimmed to the same length (?)
     - [!] overall gfun results logger
@ -120,6 +121,7 @@ if __name__ == "__main__":
parser.add_argument("--min_count", type=int, default=10) parser.add_argument("--min_count", type=int, default=10)
parser.add_argument("--max_labels", type=int, default=50) parser.add_argument("--max_labels", type=int, default=50)
parser.add_argument("--clf_type", type=str, default="multilabel") parser.add_argument("--clf_type", type=str, default="multilabel")
parser.add_argument("--save_dataset", action="store_true")
# gFUN parameters ---------------------- # gFUN parameters ----------------------
parser.add_argument("-p", "--posteriors", action="store_true") parser.add_argument("-p", "--posteriors", action="store_true")
parser.add_argument("-m", "--multilingual", action="store_true") parser.add_argument("-m", "--multilingual", action="store_true")