save and load datasets as pkl

parent 7d0d6ba1f6
commit 57918ec523
@@ -1,3 +1,5 @@
+import os
+
 from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
 from dataManager.glamiDataset import get_dataframe
 from dataManager.multilingualDataset import MultilingualDataset
@@ -22,7 +24,7 @@ class gFunDataset:
         self.labels = labels
         self.nrows = nrows
         self.dataset = {}
-        self.load_dataset()
+        self._load_dataset()

     def get_label_binarizer(self, labels):
         if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
@@ -35,7 +37,7 @@ class gFunDataset:
         mlb.fit(labels)
         return mlb

-    def load_dataset(self):
+    def _load_dataset(self):
         if "glami" in self.dataset_dir.lower():
             print(f"- Loading GLAMI dataset from {self.dataset_dir}")
             self.dataset_name = "glami"
@@ -205,6 +207,14 @@ class gFunDataset:
         else:
             return self.labels

+    def save_as_pickle(self, path):
+        import pickle
+
+        filepath = os.path.join(path, f"{self.dataset_name}_{self.nrows}.pkl")
+        with open(filepath, "wb") as f:
+            print(f"- saving dataset in {filepath}")
+            pickle.dump(self, f)
+

 if __name__ == "__main__":
     import os
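Note: save_as_pickle serializes the whole gFunDataset instance with pickle.dump(self, f), so a later pickle.load only works if the defining module is importable under the same path. A minimal sketch of the same pattern with a hypothetical stand-in class (not project code):

    import os
    import pickle

    class ToyDataset:
        # stand-in for gFunDataset: the object pickles itself to "<name>_<nrows>.pkl"
        def __init__(self, dataset_name, nrows):
            self.dataset_name = dataset_name
            self.nrows = nrows

        def save_as_pickle(self, path):
            filepath = os.path.join(path, f"{self.dataset_name}_{self.nrows}.pkl")
            with open(filepath, "wb") as f:
                pickle.dump(self, f)
            return filepath

    filepath = ToyDataset("glami", 100).save_as_pickle(".")
    with open(filepath, "rb") as f:
        restored = pickle.load(f)  # requires ToyDataset to be defined/importable here
    print(restored.dataset_name, restored.nrows)  # glami 100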
@@ -1,9 +1,21 @@
-from os.path import expanduser
+from os.path import expanduser, join
 from dataManager.gFunDataset import gFunDataset
 from dataManager.multiNewsDataset import MultiNewsDataset
 from dataManager.amazonDataset import AmazonDataset


+def load_from_pickle(path, dataset_name, nrows):
+    import pickle
+
+    filepath = join(path, f"{dataset_name}_{nrows}.pkl")
+
+    with open(filepath, "rb") as f:
+        loaded = pickle.load(f)
+        print(f"- Loaded dataset from {filepath}")
+        loaded.show_dimension()
+        return loaded
+
+
 def get_dataset(dataset_name, args):
     assert dataset_name in [
         "multinews",
@@ -58,13 +70,19 @@ def get_dataset(dataset_name, args):
             nrows=args.nrows,
         )
     elif dataset_name == "glami":
-        dataset = gFunDataset(
-            dataset_dir=GLAMI_DATAPATH,
-            is_textual=True,
-            is_visual=True,
-            is_multilabel=False,
-            nrows=args.nrows,
-        )
+        if args.save_dataset is False:
+            dataset = load_from_pickle(GLAMI_DATAPATH, dataset_name, args.nrows)
+        else:
+            dataset = gFunDataset(
+                dataset_dir=GLAMI_DATAPATH,
+                is_textual=True,
+                is_visual=True,
+                is_multilabel=False,
+                nrows=args.nrows,
+            )
+
+            dataset.save_as_pickle(GLAMI_DATAPATH)
+
     elif dataset_name == "cls":
         dataset = gFunDataset(
             dataset_dir=CLS_DATAPATH,
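With this change, running with --save_dataset rebuilds the GLAMI dataset and pickles it, while the default path expects a pickle written by a previous run. A usage sketch, assuming such a pickle already exists for the chosen nrows (the value 1000 is illustrative):

    # reload a GLAMI pickle written by an earlier --save_dataset run;
    # equivalent to get_dataset("glami", args) with args.save_dataset False
    dataset = load_from_pickle(GLAMI_DATAPATH, "glami", nrows=1000)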
@@ -167,6 +167,7 @@ class Trainer:
         )
         self.clf_type = classification_type
         self.vgf_name = vgf_name
+        self.scheduler_name = scheduler_name
         self.n_jobs = n_jobs
         self.monitored_metric = (
             "macro-F1" if self.clf_type == "multilabel" else "accuracy"
@@ -190,7 +191,9 @@ class Trainer:
             "model name": self.model.name_or_path,
             "epochs": epochs,
             "learning rate": self.optimizer.defaults["lr"],
-            "scheduler": "TODO",  # TODO: add scheduler name
+            "scheduler": self.scheduler_name,  # TODO: add scheduler params
+            "train size": len(train_dataloader.dataset),
+            "eval size": len(eval_dataloader.dataset),
             "train batch size": train_dataloader.batch_size,
             "eval batch size": eval_dataloader.batch_size,
             "max len": train_dataloader.dataset.X.shape[-1],
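The two new entries log dataset sizes in examples, complementing the per-step batch sizes already logged. A self-contained sketch of the distinction (generic PyTorch, not project code):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dl = DataLoader(TensorDataset(torch.zeros(10, 3)), batch_size=4)
    print(len(dl.dataset), dl.batch_size, len(dl))  # 10 examples, batch size 4, 3 batches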
main.py (4 changed lines)
@@ -1,6 +1,6 @@
 import os

-os.environ["CUDA_VISIBLE_DEVICES"] = "3"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"

 from argparse import ArgumentParser
 from time import time
@@ -12,6 +12,7 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
 """
 TODO:
     - [!] add support for mT5
+    - [!] log on wandb also the other VGF results + final results
     - [!] CLS dataset is loading only "books" domain data
     - [!] documents should be trimmed to the same length (?)
     - [!] overall gfun results logger
@@ -120,6 +121,7 @@ if __name__ == "__main__":
     parser.add_argument("--min_count", type=int, default=10)
     parser.add_argument("--max_labels", type=int, default=50)
     parser.add_argument("--clf_type", type=str, default="multilabel")
+    parser.add_argument("--save_dataset", action="store_true")
     # gFUN parameters ----------------------
     parser.add_argument("-p", "--posteriors", action="store_true")
     parser.add_argument("-m", "--multilingual", action="store_true")
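Because --save_dataset uses action="store_true", it defaults to False, so omitting the flag takes the load_from_pickle branch in get_dataset and passing it rebuilds and re-pickles the dataset. A small sketch of that default behaviour:

    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--save_dataset", action="store_true")
    print(parser.parse_args([]).save_dataset)                  # False -> load existing pickle
    print(parser.parse_args(["--save_dataset"]).save_dataset)  # True  -> rebuild and save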