save and load datasets as pkl

parent 7d0d6ba1f6
commit 57918ec523
dataManager/gFunDataset.py

@@ -1,3 +1,5 @@
+import os
+
 from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
 from dataManager.glamiDataset import get_dataframe
 from dataManager.multilingualDataset import MultilingualDataset
@@ -22,7 +24,7 @@ class gFunDataset:
         self.labels = labels
         self.nrows = nrows
         self.dataset = {}
-        self.load_dataset()
+        self._load_dataset()

     def get_label_binarizer(self, labels):
         if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
@@ -35,7 +37,7 @@ class gFunDataset:
         mlb.fit(labels)
         return mlb

-    def load_dataset(self):
+    def _load_dataset(self):
        if "glami" in self.dataset_dir.lower():
            print(f"- Loading GLAMI dataset from {self.dataset_dir}")
            self.dataset_name = "glami"
@@ -205,6 +207,14 @@ class gFunDataset:
         else:
             return self.labels

+    def save_as_pickle(self, path):
+        import pickle
+
+        filepath = os.path.join(path, f"{self.dataset_name}_{self.nrows}.pkl")
+        with open(filepath, "wb") as f:
+            print(f"- saving dataset in {filepath}")
+            pickle.dump(self, f)
+

 if __name__ == "__main__":
     import os
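Note that pickle.dump(self, f) serializes the entire gFunDataset instance, so a later load only works in an environment where dataManager.gFunDataset is importable under the same module path. A minimal sketch of the save side, assuming a hypothetical instance ds built with dataset_name "glami" and nrows=None (the target directory is illustrative):

    # Hypothetical: ds is a fully loaded gFunDataset; the output filename is
    # derived from dataset_name and nrows, so nrows=None yields "glami_None.pkl".
    ds.save_as_pickle("datasets/GLAMI")  # writes datasets/GLAMI/glami_None.pkl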
@@ -1,9 +1,21 @@
-from os.path import expanduser
+from os.path import expanduser, join
 from dataManager.gFunDataset import gFunDataset
 from dataManager.multiNewsDataset import MultiNewsDataset
 from dataManager.amazonDataset import AmazonDataset


+def load_from_pickle(path, dataset_name, nrows):
+    import pickle
+
+    filepath = join(path, f"{dataset_name}_{nrows}.pkl")
+
+    with open(filepath, "rb") as f:
+        loaded = pickle.load(f)
+        print(f"- Loaded dataset from {filepath}")
+        loaded.show_dimension()
+        return loaded
+
+
 def get_dataset(dataset_name, args):
     assert dataset_name in [
         "multinews",
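A hedged usage sketch of the new helper; the path and row count below are illustrative, and nrows must match the value used at save time because it is baked into the filename:

    # Reads datasets/GLAMI/glami_1000.pkl, prints its dimensions via
    # show_dimension(), and returns the unpickled gFunDataset object.
    dataset = load_from_pickle("datasets/GLAMI", "glami", 1000)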
@@ -58,6 +70,9 @@ def get_dataset(dataset_name, args):
             nrows=args.nrows,
         )
     elif dataset_name == "glami":
+        if args.save_dataset is False:
+            dataset = load_from_pickle(GLAMI_DATAPATH, dataset_name, args.nrows)
+        else:
             dataset = gFunDataset(
                 dataset_dir=GLAMI_DATAPATH,
                 is_textual=True,
@@ -65,6 +80,9 @@ def get_dataset(dataset_name, args):
                 is_multilabel=False,
                 nrows=args.nrows,
             )
+
+            dataset.save_as_pickle(GLAMI_DATAPATH)
+
     elif dataset_name == "cls":
         dataset = gFunDataset(
             dataset_dir=CLS_DATAPATH,
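Taken together, the two hunks above give the glami branch a simple cache policy: omit the flag and reuse the pickle, pass it and rebuild from the raw data, then write the cache. Condensed into one view (names exactly as in the diff; GLAMI_DATAPATH is assumed to be defined elsewhere in this module, and one constructor argument falling between the two hunks is not shown here):

    if args.save_dataset is False:  # default: reuse the cached pickle
        dataset = load_from_pickle(GLAMI_DATAPATH, dataset_name, args.nrows)
    else:                           # --save_dataset: rebuild, then cache
        dataset = gFunDataset(
            dataset_dir=GLAMI_DATAPATH,
            is_textual=True,
            is_multilabel=False,
            nrows=args.nrows,
        )
        dataset.save_as_pickle(GLAMI_DATAPATH)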
@@ -167,6 +167,7 @@ class Trainer:
         )
         self.clf_type = classification_type
         self.vgf_name = vgf_name
+        self.scheduler_name = scheduler_name
         self.n_jobs = n_jobs
         self.monitored_metric = (
             "macro-F1" if self.clf_type == "multilabel" else "accuracy"
@@ -190,7 +191,9 @@ class Trainer:
             "model name": self.model.name_or_path,
             "epochs": epochs,
             "learning rate": self.optimizer.defaults["lr"],
-            "scheduler": "TODO",  # TODO: add scheduler name
+            "scheduler": self.scheduler_name,  # TODO: add scheduler params
+            "train size": len(train_dataloader.dataset),
+            "eval size": len(eval_dataloader.dataset),
             "train batch size": train_dataloader.batch_size,
             "eval batch size": eval_dataloader.batch_size,
             "max len": train_dataloader.dataset.X.shape[-1],

main.py (4 changed lines)
@@ -1,6 +1,6 @@
 import os

-os.environ["CUDA_VISIBLE_DEVICES"] = "3"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"

 from argparse import ArgumentParser
 from time import time
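The device mask only takes effect because this assignment runs before torch (or anything else that initializes CUDA) is imported; once the CUDA context exists, changing the variable has no effect. A slightly more flexible variant, sketched under that same before-import assumption, lets a value exported from the shell take precedence over the hardcoded default:

    import os

    # setdefault keeps an externally exported CUDA_VISIBLE_DEVICES intact
    # and only falls back to GPU 0 when the variable is unset.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")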
@@ -12,6 +12,7 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
 """
 TODO:
     - [!] add support for mT5
+    - [!] log on wandb also the other VGF results + final results
     - [!] CLS dataset is loading only "books" domain data
     - [!] documents should be trimmed to the same length (?)
     - [!] overall gfun results logger
@@ -120,6 +121,7 @@ if __name__ == "__main__":
     parser.add_argument("--min_count", type=int, default=10)
     parser.add_argument("--max_labels", type=int, default=50)
     parser.add_argument("--clf_type", type=str, default="multilabel")
+    parser.add_argument("--save_dataset", action="store_true")
     # gFUN parameters ----------------------
     parser.add_argument("-p", "--posteriors", action="store_true")
     parser.add_argument("-m", "--multilingual", action="store_true")
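Because --save_dataset is declared with action="store_true", its value is False whenever the flag is omitted, which is exactly what routes the glami branch to load_from_pickle by default. A small standalone sketch of that argparse behavior (illustrative parser only):

    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--save_dataset", action="store_true")

    assert parser.parse_args([]).save_dataset is False                 # flag omitted
    assert parser.parse_args(["--save_dataset"]).save_dataset is True  # flag given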