Implemented save/load for MT5ForSequenceClassification; moved torch Datasets to the dataManager module
parent 56faaf2615
commit 9d43ebb23b
dataManager/torchDataset.py
@@ -1,2 +1,66 @@
-class TorchMultiNewsDataset:
-    pass
+import torch
+from torch.utils.data import Dataset
+
+
+class MultilingualDatasetTorch(Dataset):
+    def __init__(self, lX, lY, split="train"):
+        self.lX = lX
+        self.lY = lY
+        self.split = split
+        self.langs = []
+        self.init()
+
+    def init(self):
+        self.X = torch.vstack([data.input_ids for data in self.lX.values()])
+        if self.split != "whole":
+            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
+        self.langs = sum(
+            [
+                v
+                for v in {
+                    lang: [lang] * len(data.input_ids) for lang, data in self.lX.items()
+                }.values()
+            ],
+            [],
+        )
+
+        return self
+
+    def __len__(self):
+        return len(self.X)
+
+    def __getitem__(self, index):
+        if self.split == "whole":
+            return self.X[index], self.langs[index]
+        return self.X[index], self.Y[index], self.langs[index]
+
+
+class MultimodalDatasetTorch(Dataset):
+    def __init__(self, lX, lY, split="train"):
+        self.lX = lX
+        self.lY = lY
+        self.split = split
+        self.langs = []
+        self.init()
+
+    def init(self):
+        self.X = torch.vstack([imgs for imgs in self.lX.values()])
+        if self.split != "whole":
+            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
+        self.langs = sum(
+            [
+                v
+                for v in {
+                    lang: [lang] * len(data) for lang, data in self.lX.items()
+                }.values()
+            ],
+            [],
+        )
+
+    def __len__(self):
+        return len(self.X)
+
+    def __getitem__(self, index):
+        if self.split == "whole":
+            return self.X[index], self.langs[index]
+        return self.X[index], self.Y[index], self.langs[index]
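Note: a minimal usage sketch of the moved MultilingualDatasetTorch (my sketch, not part of the commit). Following init(), it assumes lX maps language codes to a tokenizer output exposing .input_ids and lY maps them to per-document label rows; all names and shapes below are illustrative.

import torch
from torch.utils.data import DataLoader
from dataManager.torchDataset import MultilingualDatasetTorch

class FakeEncoding:
    # Stand-in for a HF BatchEncoding: init() only touches .input_ids.
    def __init__(self, input_ids):
        self.input_ids = input_ids

lX = {
    "en": FakeEncoding(torch.randint(0, 100, (4, 16))),
    "it": FakeEncoding(torch.randint(0, 100, (4, 16))),
}
lY = {"en": [[1, 0]] * 4, "it": [[0, 1]] * 4}  # binary one-hot labels

dataset = MultilingualDatasetTorch(lX, lY, split="train")
loader = DataLoader(dataset, batch_size=2, shuffle=True)
for X, y, langs in loader:
    # X: (2, 16) input ids, y: (2, 2) labels, langs: tuple of language codes
    print(X.shape, y.shape, langs)

Rows and language tags stay aligned because init() iterates the same dict (insertion-ordered in Python 3.7+) for both self.X and self.langs.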
gfun/vgfs/commons.py
@@ -278,7 +278,7 @@ class Trainer:
             loss = self.loss_fn(y_hat, y.to(self.device))
             loss.backward()
             self.optimizer.step()
-            batch_losses.append(loss.item())  # TODO: is this still on gpu?
+            batch_losses.append(loss.item())
             if (epoch + 1) % PRINT_ON_EPOCH == 0:
                 if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
                     print(
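The deleted TODO asked whether loss.item() is still on the GPU; it is not: .item() copies the scalar to host memory as a Python float, so batch_losses retains no GPU tensors. A quick check (illustrative):

import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
loss = torch.tensor(0.5, device=device, requires_grad=True) * 2
val = loss.item()  # detaches and copies the scalar to host memory
print(type(val), val)  # <class 'float'> 1.0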
gfun/vgfs/textualTransformerGen.py
@@ -9,13 +9,13 @@ import torch
 import torch.nn as nn
 import transformers
 from transformers import MT5EncoderModel
-from torch.utils.data import Dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 from transformers.modeling_outputs import ModelOutput
 
 from gfun.vgfs.commons import Trainer
 from gfun.vgfs.transformerGen import TransformerGen
 from gfun.vgfs.viewGen import ViewGen
+from dataManager.torchDataset import MultilingualDatasetTorch
 
 transformers.logging.set_verbosity_error()
 
@@ -44,11 +44,12 @@ class MT5ForSequenceClassification(nn.Module):
         return ModelOutput(logits=logits)
 
     def save_pretrained(self, checkpoint_dir):
-        pass  # TODO: implement
+        torch.save(self.state_dict(), checkpoint_dir + ".pt")
+        return
 
     def from_pretrained(self, checkpoint_dir):
-        # TODO: implement
-        return self
+        checkpoint_dir += ".pt"
+        return self.load_state_dict(torch.load(checkpoint_dir))
 
 
 class TextualTransformerGen(ViewGen, TransformerGen):
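Two things worth noting about the new implementation: nn.Module.load_state_dict returns a NamedTuple of missing/unexpected keys rather than the module, so from_pretrained hands back that tuple and callers should keep using the instance they called it on; and torch.save does not create directories, so the checkpoint directory must already exist. A round-trip sketch, assuming an already-constructed model (the path is illustrative):

import os
import torch

def roundtrip(model, stem="models/vgfs/transformer/google-mt5"):
    # save_pretrained/from_pretrained both append ".pt", so pass a bare stem.
    os.makedirs(os.path.dirname(stem), exist_ok=True)  # torch.save won't mkdir
    model.save_pretrained(stem)
    # load_state_dict (and thus from_pretrained) returns a NamedTuple of
    # missing/unexpected keys, not the module: keep using `model` itself.
    incompatible = model.from_pretrained(stem)
    assert not incompatible.missing_keys and not incompatible.unexpected_keys
    return model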
@@ -165,9 +166,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             shuffle=False,
         )
 
-        experiment_name = (
-            f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
-        )
+        experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
 
         trainer = Trainer(
             model=self.model,
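The rewritten one-liner additionally sanitizes Hugging Face-style model ids so the experiment name is filesystem-safe; for example (all values illustrative):

model_name = "google/mt5-small"  # illustrative HF model id
print(f"{model_name.replace('/', '-')}-100-32-glami")
# google-mt5-small-100-32-glami  (model-epochs-batch_size-dataset_name)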
@@ -179,12 +178,17 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             evaluate_step=self.evaluate_step,
             patience=self.patience,
             experiment_name=experiment_name,
-            checkpoint_path="models/vgfs/transformer",
+            checkpoint_path=os.path.join(
+                "models",
+                "vgfs",
+                "transformer",
+                self._format_model_name(self.model_name),
+            ),
             vgf_name="textual_trf",
             classification_type=self.clf_type,
             n_jobs=self.n_jobs,
-            # scheduler_name="ReduceLROnPlateau",
-            scheduler_name=None,
+            scheduler_name="ReduceLROnPlateau",
+            # scheduler_name=None,
         )
         trainer.train(
             train_dataloader=tra_dataloader,
@@ -259,39 +263,17 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         for param in self.model.parameters():
             param.requires_grad = False
 
+    def _format_model_name(self, model_name):
+        if "mt5" in model_name:
+            return "google-mt5"
+        elif "bert" in model_name:
+            if "multilingual" in model_name:
+                return "mbert"
+        elif "xlm" in model_name:
+            return "xlm"
+        else:
+            return model_name
+
     def __str__(self):
         str = f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
         return str
-
-
-class MultilingualDatasetTorch(Dataset):
-    def __init__(self, lX, lY, split="train"):
-        self.lX = lX
-        self.lY = lY
-        self.split = split
-        self.langs = []
-        self.init()
-
-    def init(self):
-        self.X = torch.vstack([data.input_ids for data in self.lX.values()])
-        if self.split != "whole":
-            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
-        self.langs = sum(
-            [
-                v
-                for v in {
-                    lang: [lang] * len(data.input_ids) for lang, data in self.lX.items()
-                }.values()
-            ],
-            [],
-        )
-
-        return self
-
-    def __len__(self):
-        return len(self.X)
-
-    def __getitem__(self, index):
-        if self.split == "whole":
-            return self.X[index], self.langs[index]
-        return self.X[index], self.Y[index], self.langs[index]
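One caveat with _format_model_name as committed: a model id containing "bert" but not "multilingual" (including "xlm-roberta", since "roberta" contains "bert" as a substring) matches the elif "bert" branch, falls through its inner if, and implicitly returns None, which would make the os.path.join checkpoint path above raise a TypeError. A defensive variant (my sketch, not the commit's code):

def format_model_name(model_name: str) -> str:
    # Same mapping as the committed _format_model_name, but every path
    # returns a string so os.path.join never receives None.
    if "mt5" in model_name:
        return "google-mt5"
    if "bert" in model_name and "multilingual" in model_name:
        return "mbert"
    if "xlm" in model_name:
        return "xlm"
    return model_name.replace("/", "-")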
gfun/vgfs/visualTransformerGen.py
@@ -4,12 +4,12 @@ import numpy as np
 import torch
 import transformers
 from PIL import Image
-from torch.utils.data import Dataset
 from transformers import AutoImageProcessor, AutoModelForImageClassification
 
 from gfun.vgfs.commons import Trainer
 from gfun.vgfs.transformerGen import TransformerGen
 from gfun.vgfs.viewGen import ViewGen
+from dataManager.torchDataset import MultimodalDatasetTorch
 
 transformers.logging.set_verbosity_error()
 
@@ -186,63 +186,3 @@ class VisualTransformerGen(ViewGen, TransformerGen):
     def __str__(self):
         str = f"[Visual Transformer VGF (v)]\n- model_name: {self.model_name}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
         return str
-
-
-class MultimodalDatasetTorch(Dataset):
-    def __init__(self, lX, lY, split="train"):
-        self.lX = lX
-        self.lY = lY
-        self.split = split
-        self.langs = []
-        self.init()
-
-    def init(self):
-        self.X = torch.vstack([imgs for imgs in self.lX.values()])
-        if self.split != "whole":
-            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
-        self.langs = sum(
-            [
-                v
-                for v in {
-                    lang: [lang] * len(data) for lang, data in self.lX.items()
-                }.values()
-            ],
-            [],
-        )
-
-    def __len__(self):
-        return len(self.X)
-
-    def __getitem__(self, index):
-        if self.split == "whole":
-            return self.X[index], self.langs[index]
-        return self.X[index], self.Y[index], self.langs[index]
-
-
-if __name__ == "__main__":
-    from os.path import expanduser
-
-    from dataManager.gFunDataset import gFunDataset
-
-    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
-    dataset = gFunDataset(
-        dataset_dir=GLAMI_DATAPATH,
-        is_textual=True,
-        is_visual=True,
-        is_multilabel=False,
-        nrows=50,
-    )
-
-    vg = VisualTransformerGen(
-        dataset_name=dataset.dataset_name,
-        model_name="vit",
-        device="cuda",
-        epochs=5,
-        evaluate_step=10,
-        patience=10,
-        probabilistic=True,
-    )
-    lX, lY = dataset.training()
-    vg.fit(lX, lY)
-    out = vg.transform(lX)
-    exit(0)
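Both dataset classes flatten the per-language tags with sum(list_of_lists, []), which is quadratic in the number of languages; an equivalent linear-time formulation (a sketch, not part of the commit, with stand-in inputs):

from itertools import chain

lX = {"en": range(3), "it": range(2)}  # stand-ins for per-language batches
langs = list(chain.from_iterable([lang] * len(data) for lang, data in lX.items()))
print(langs)  # ['en', 'en', 'en', 'it', 'it']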
|
@ -1,13 +1,13 @@
|
|||
beautifulsoup4==4.11.2
|
||||
joblib==1.2.0
|
||||
matplotlib==3.7.1
|
||||
numpy==1.24.2
|
||||
matplotlib==3.6.3
|
||||
numpy==1.24.1
|
||||
pandas==1.5.3
|
||||
Pillow==9.4.0
|
||||
requests==2.28.2
|
||||
scikit_learn==1.2.1
|
||||
scikit_learn==1.2.2
|
||||
scipy==1.10.1
|
||||
torch==1.13.1
|
||||
torchtext==0.14.1
|
||||
tqdm==4.65.0
|
||||
transformers==4.26.1
|
||||
tqdm==4.64.1
|
||||
transformers==4.26.0
|
||||
|
|