Save and load datasets as pickle (.pkl) files

Andrea Pedrotti 2023-03-10 12:40:26 +01:00
parent 7d0d6ba1f6
commit 57918ec523
4 changed files with 45 additions and 12 deletions
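
In short, a built gFunDataset can now be pickled to disk once and restored on later runs instead of being rebuilt from the raw files. A minimal sketch of the roundtrip with the two helpers added here, assuming a dataset object has already been constructed and an illustrative data directory (the module hosting load_from_pickle is not named in this diff, so its import path is a guess):

from os.path import expanduser
from dataManager.utils import load_from_pickle  # import path is an assumption

GLAMI_DATAPATH = expanduser("~/datasets/GLAMI")  # illustrative location

# `dataset` is an already-built gFunDataset (e.g. returned by get_dataset);
# this caches the whole object as <GLAMI_DATAPATH>/glami_<nrows>.pkl
dataset.save_as_pickle(GLAMI_DATAPATH)

# later runs: restore it without re-reading the raw GLAMI files
# (the nrows argument must match the value used when saving)
dataset = load_from_pickle(GLAMI_DATAPATH, "glami", nrows=10000)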

View File

@@ -1,3 +1,5 @@
import os
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from dataManager.glamiDataset import get_dataframe
from dataManager.multilingualDataset import MultilingualDataset
@@ -22,7 +24,7 @@ class gFunDataset:
self.labels = labels
self.nrows = nrows
self.dataset = {}
self.load_dataset()
self._load_dataset()
def get_label_binarizer(self, labels):
if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
@@ -35,7 +37,7 @@ class gFunDataset:
mlb.fit(labels)
return mlb
def load_dataset(self):
def _load_dataset(self):
if "glami" in self.dataset_dir.lower():
print(f"- Loading GLAMI dataset from {self.dataset_dir}")
self.dataset_name = "glami"
@@ -205,6 +207,14 @@ class gFunDataset:
else:
return self.labels
def save_as_pickle(self, path):
import pickle
filepath = os.path.join(path, f"{self.dataset_name}_{self.nrows}.pkl")
with open(filepath, "wb") as f:
print(f"- saving dataset in {filepath}")
pickle.dump(self, f)
if __name__ == "__main__":
import os

View File

@@ -1,9 +1,21 @@
from os.path import expanduser
from os.path import expanduser, join
from dataManager.gFunDataset import gFunDataset
from dataManager.multiNewsDataset import MultiNewsDataset
from dataManager.amazonDataset import AmazonDataset
def load_from_pickle(path, dataset_name, nrows):
import pickle
filepath = join(path, f"{dataset_name}_{nrows}.pkl")
with open(filepath, "rb") as f:
loaded = pickle.load(f)
print(f"- Loaded dataset from {filepath}")
loaded.show_dimension()
return loaded
def get_dataset(dataset_name, args):
assert dataset_name in [
"multinews",
@@ -58,6 +70,9 @@ def get_dataset(dataset_name, args):
nrows=args.nrows,
)
elif dataset_name == "glami":
if args.save_dataset is False:
dataset = load_from_pickle(GLAMI_DATAPATH, dataset_name, args.nrows)
else:
dataset = gFunDataset(
dataset_dir=GLAMI_DATAPATH,
is_textual=True,
@@ -65,6 +80,9 @@ def get_dataset(dataset_name, args):
is_multilabel=False,
nrows=args.nrows,
)
dataset.save_as_pickle(GLAMI_DATAPATH)
elif dataset_name == "cls":
dataset = gFunDataset(
dataset_dir=CLS_DATAPATH,

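With this branching in place, the glami path of get_dataset either rebuilds and pickles the dataset or unpickles the cached copy, driven by args.save_dataset. A minimal sketch of exercising both modes programmatically, assuming only the fields touched by this branch need to be set (a plain Namespace stands in for the parsed CLI options):

from argparse import Namespace

# first run: build the GLAMI dataset and cache it as glami_10000.pkl (nrows value is illustrative)
dataset = get_dataset("glami", Namespace(nrows=10000, save_dataset=True))

# subsequent runs: skip the build and load the cached pickle
dataset = get_dataset("glami", Namespace(nrows=10000, save_dataset=False))
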
View File

@@ -167,6 +167,7 @@ class Trainer:
)
self.clf_type = classification_type
self.vgf_name = vgf_name
self.scheduler_name = scheduler_name
self.n_jobs = n_jobs
self.monitored_metric = (
"macro-F1" if self.clf_type == "multilabel" else "accuracy"
@@ -190,7 +191,9 @@ class Trainer:
"model name": self.model.name_or_path,
"epochs": epochs,
"learning rate": self.optimizer.defaults["lr"],
"scheduler": "TODO", # TODO: add scheduler name
"scheduler": self.scheduler_name, # TODO: add scheduler params
"train size": len(train_dataloader.dataset),
"eval size": len(eval_dataloader.dataset),
"train batch size": train_dataloader.batch_size,
"eval batch size": eval_dataloader.batch_size,
"max len": train_dataloader.dataset.X.shape[-1],

View File

@@ -1,6 +1,6 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from argparse import ArgumentParser
from time import time
@@ -12,6 +12,7 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
"""
TODO:
- [!] add support for mT5
- [!] also log the other VGF results + final results on wandb
- [!] CLS dataset is loading only "books" domain data
- [!] documents should be trimmed to the same length (?)
- [!] overall gfun results logger
@@ -120,6 +121,7 @@ if __name__ == "__main__":
parser.add_argument("--min_count", type=int, default=10)
parser.add_argument("--max_labels", type=int, default=50)
parser.add_argument("--clf_type", type=str, default="multilabel")
parser.add_argument("--save_dataset", action="store_true")
# gFUN parameters ----------------------
parser.add_argument("-p", "--posteriors", action="store_true")
parser.add_argument("-m", "--multilingual", action="store_true")