diff --git a/dataManager/torchDataset.py b/dataManager/torchDataset.py
index f5c60d8..76ec388 100644
--- a/dataManager/torchDataset.py
+++ b/dataManager/torchDataset.py
@@ -1,2 +1,66 @@
-class TorchMultiNewsDataset:
-    pass
+import torch
+from torch.utils.data import Dataset
+
+
+class MultilingualDatasetTorch(Dataset):
+    def __init__(self, lX, lY, split="train"):
+        self.lX = lX
+        self.lY = lY
+        self.split = split
+        self.langs = []
+        self.init()
+
+    def init(self):
+        self.X = torch.vstack([data.input_ids for data in self.lX.values()])
+        if self.split != "whole":
+            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
+        self.langs = sum(
+            [
+                v
+                for v in {
+                    lang: [lang] * len(data.input_ids) for lang, data in self.lX.items()
+                }.values()
+            ],
+            [],
+        )
+
+        return self
+
+    def __len__(self):
+        return len(self.X)
+
+    def __getitem__(self, index):
+        if self.split == "whole":
+            return self.X[index], self.langs[index]
+        return self.X[index], self.Y[index], self.langs[index]
+
+
+class MultimodalDatasetTorch(Dataset):
+    def __init__(self, lX, lY, split="train"):
+        self.lX = lX
+        self.lY = lY
+        self.split = split
+        self.langs = []
+        self.init()
+
+    def init(self):
+        self.X = torch.vstack([imgs for imgs in self.lX.values()])
+        if self.split != "whole":
+            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
+        self.langs = sum(
+            [
+                v
+                for v in {
+                    lang: [lang] * len(data) for lang, data in self.lX.items()
+                }.values()
+            ],
+            [],
+        )
+
+    def __len__(self):
+        return len(self.X)
+
+    def __getitem__(self, index):
+        if self.split == "whole":
+            return self.X[index], self.langs[index]
+        return self.X[index], self.Y[index], self.langs[index]
diff --git a/gfun/vgfs/commons.py b/gfun/vgfs/commons.py
index 617b2fd..979704f 100644
--- a/gfun/vgfs/commons.py
+++ b/gfun/vgfs/commons.py
@@ -278,7 +278,7 @@ class Trainer:
             loss = self.loss_fn(y_hat, y.to(self.device))
             loss.backward()
             self.optimizer.step()
-            batch_losses.append(loss.item())  # TODO: is this still on gpu?
+            batch_losses.append(loss.item())
             if (epoch + 1) % PRINT_ON_EPOCH == 0:
                 if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
                     print(
diff --git a/gfun/vgfs/textualTransformerGen.py b/gfun/vgfs/textualTransformerGen.py
index ed5a92c..5bfb5c1 100644
--- a/gfun/vgfs/textualTransformerGen.py
+++ b/gfun/vgfs/textualTransformerGen.py
@@ -9,13 +9,13 @@ import torch
 import torch.nn as nn
 import transformers
 from transformers import MT5EncoderModel
-from torch.utils.data import Dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 from transformers.modeling_outputs import ModelOutput
 
 from gfun.vgfs.commons import Trainer
 from gfun.vgfs.transformerGen import TransformerGen
 from gfun.vgfs.viewGen import ViewGen
+from dataManager.torchDataset import MultilingualDatasetTorch
 
 transformers.logging.set_verbosity_error()
 
@@ -44,11 +44,12 @@ class MT5ForSequenceClassification(nn.Module):
         return ModelOutput(logits=logits)
 
     def save_pretrained(self, checkpoint_dir):
-        pass  # TODO: implement
+        torch.save(self.state_dict(), checkpoint_dir + ".pt")
+        return
 
     def from_pretrained(self, checkpoint_dir):
-        # TODO: implement
-        return self
+        checkpoint_dir += ".pt"
+        return self.load_state_dict(torch.load(checkpoint_dir))
 
 
 class TextualTransformerGen(ViewGen, TransformerGen):
@@ -165,9 +166,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             shuffle=False,
         )
 
-        experiment_name = (
-            f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
-        )
+        experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
 
         trainer = Trainer(
             model=self.model,
@@ -179,12 +178,17 @@
             evaluate_step=self.evaluate_step,
             patience=self.patience,
             experiment_name=experiment_name,
-            checkpoint_path="models/vgfs/transformer",
+            checkpoint_path=os.path.join(
+                "models",
+                "vgfs",
+                "transformer",
+                self._format_model_name(self.model_name),
+            ),
             vgf_name="textual_trf",
             classification_type=self.clf_type,
             n_jobs=self.n_jobs,
-            # scheduler_name="ReduceLROnPlateau",
-            scheduler_name=None,
+            scheduler_name="ReduceLROnPlateau",
+            # scheduler_name=None,
         )
         trainer.train(
             train_dataloader=tra_dataloader,
@@ -259,39 +263,17 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         for param in self.model.parameters():
             param.requires_grad = False
 
+    def _format_model_name(self, model_name):
+        if "mt5" in model_name:
+            return "google-mt5"
+        elif "bert" in model_name:
+            if "multilingual" in model_name:
+                return "mbert"
+        elif "xlm" in model_name:
+            return "xlm"
+        else:
+            return model_name
+
     def __str__(self):
         str = f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
         return str
-
-
-class MultilingualDatasetTorch(Dataset):
-    def __init__(self, lX, lY, split="train"):
-        self.lX = lX
-        self.lY = lY
-        self.split = split
-        self.langs = []
-        self.init()
-
-    def init(self):
-        self.X = torch.vstack([data.input_ids for data in self.lX.values()])
-        if self.split != "whole":
-            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
-        self.langs = sum(
-            [
-                v
-                for v in {
-                    lang: [lang] * len(data.input_ids) for lang, data in self.lX.items()
-                }.values()
-            ],
-            [],
-        )
-
-        return self
-
-    def __len__(self):
-        return len(self.X)
-
-    def __getitem__(self, index):
-        if self.split == "whole":
-            return self.X[index], self.langs[index]
-        return self.X[index], self.Y[index], self.langs[index]
diff --git a/gfun/vgfs/visualTransformerGen.py b/gfun/vgfs/visualTransformerGen.py
index 6c4d3b1..ae8b914 100644
--- a/gfun/vgfs/visualTransformerGen.py
+++ b/gfun/vgfs/visualTransformerGen.py
@@ -4,12 +4,12 @@ import numpy as np
 import torch
 import transformers
 from PIL import Image
-from torch.utils.data import Dataset
 from transformers import AutoImageProcessor, AutoModelForImageClassification
 
 from gfun.vgfs.commons import Trainer
 from gfun.vgfs.transformerGen import TransformerGen
 from gfun.vgfs.viewGen import ViewGen
+from dataManager.torchDataset import MultilingualDatasetTorch
 
 transformers.logging.set_verbosity_error()
 
@@ -186,63 +186,3 @@ class VisualTransformerGen(ViewGen, TransformerGen):
     def __str__(self):
         str = f"[Visual Transformer VGF (v)]\n- model_name: {self.model_name}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
         return str
-
-
-class MultimodalDatasetTorch(Dataset):
-    def __init__(self, lX, lY, split="train"):
-        self.lX = lX
-        self.lY = lY
-        self.split = split
-        self.langs = []
-        self.init()
-
-    def init(self):
-        self.X = torch.vstack([imgs for imgs in self.lX.values()])
-        if self.split != "whole":
-            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
-        self.langs = sum(
-            [
-                v
-                for v in {
-                    lang: [lang] * len(data) for lang, data in self.lX.items()
-                }.values()
-            ],
-            [],
-        )
-
-    def __len__(self):
-        return len(self.X)
-
-    def __getitem__(self, index):
-        if self.split == "whole":
-            return self.X[index], self.langs[index]
-        return self.X[index], self.Y[index], self.langs[index]
-
-
-if __name__ == "__main__":
-    from os.path import expanduser
-
-    from dataManager.gFunDataset import gFunDataset
-
-    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
-    dataset = gFunDataset(
-        dataset_dir=GLAMI_DATAPATH,
-        is_textual=True,
-        is_visual=True,
-        is_multilabel=False,
-        nrows=50,
-    )
-
-    vg = VisualTransformerGen(
-        dataset_name=dataset.dataset_name,
-        model_name="vit",
-        device="cuda",
-        epochs=5,
-        evaluate_step=10,
-        patience=10,
-        probabilistic=True,
-    )
-    lX, lY = dataset.training()
-    vg.fit(lX, lY)
-    out = vg.transform(lX)
-    exit(0)
diff --git a/requirements.txt b/requirements.txt
index 3f39887..3bdb99d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,13 +1,13 @@
 beautifulsoup4==4.11.2
 joblib==1.2.0
-matplotlib==3.7.1
-numpy==1.24.2
+matplotlib==3.6.3
+numpy==1.24.1
 pandas==1.5.3
 Pillow==9.4.0
 requests==2.28.2
-scikit_learn==1.2.1
+scikit_learn==1.2.2
 scipy==1.10.1
 torch==1.13.1
 torchtext==0.14.1
-tqdm==4.65.0
-transformers==4.26.1
+tqdm==4.64.1
+transformers==4.26.0