import os

os.environ["TOKENIZERS_PARALLELISM"] = "true"

from collections import defaultdict

import numpy as np
import torch
import transformers
from sklearn.model_selection import train_test_split
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from vgfs.learners.svms import FeatureSet2Posteriors
from evaluation.evaluate import evaluate, log_eval

transformers.logging.set_verbosity_error()


# TODO: add support to loggers


class TransformerGen:
    def __init__(
        self,
        model_name,
        epochs=10,
        lr=1e-5,
        batch_size=4,
        batch_size_eval=32,
        max_length=512,
        print_steps=50,
        device="cpu",
        probabilistic=False,
        n_jobs=-1,
        evaluate_step=10,
        verbose=False,
        patience=5,
    ):
        self.model_name = model_name
        self.device = device
        self.model = None
        self.lr = lr
        self.epochs = epochs
        self.tokenizer = None
        self.max_length = max_length
        self.batch_size = batch_size
        self.batch_size_eval = batch_size_eval
        self.print_steps = print_steps
        self.probabilistic = probabilistic
        self.n_jobs = n_jobs
        self.fitted = False
        self.datasets = {}
        self.evaluate_step = evaluate_step
        self.verbose = verbose
        self.patience = patience
        self._init()

    def _init(self):
        if self.probabilistic:
            self.feature2posterior_projector = FeatureSet2Posteriors(
                n_jobs=self.n_jobs, verbose=False
            )
        self.model_name = self._get_model_name(self.model_name)
        print(
            f"- init TransformerModel model_name: {self.model_name}, device: {self.device}"
        )

    def _get_model_name(self, name):
        # map short aliases to Hugging Face model identifiers
        if "bert" == name:
            name_model = "bert-base-uncased"
        elif "mbert" == name:
            name_model = "bert-base-multilingual-uncased"
        elif "xlm" == name:
            name_model = "xlm-roberta-base"
        else:
            raise NotImplementedError(f"Model alias '{name}' is not supported")
        return name_model

    def load_pretrained_model(self, model_name, num_labels):
        return AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels, output_hidden_states=True
        )

    def load_tokenizer(self, model_name):
        return AutoTokenizer.from_pretrained(model_name)

    def init_model(self, model_name, num_labels):
        return self.load_pretrained_model(model_name, num_labels), self.load_tokenizer(
            model_name
        )

    def get_train_val_data(self, lX, lY, split=0.2, seed=42):
        # split each language independently into train/validation portions
        tr_lX, tr_lY, val_lX, val_lY = {}, {}, {}, {}

        for lang in lX.keys():
            tr_X, val_X, tr_Y, val_Y = train_test_split(
                lX[lang], lY[lang], test_size=split, random_state=seed, shuffle=False
            )
            tr_lX[lang] = tr_X
            tr_lY[lang] = tr_Y
            val_lX[lang] = val_X
            val_lY[lang] = val_Y

        return tr_lX, tr_lY, val_lX, val_lY

    def build_dataloader(self, lX, lY, batch_size, split="train", shuffle=True):
        l_tokenized = {lang: self._tokenize(data) for lang, data in lX.items()}
        self.datasets[split] = MultilingualDatasetTorch(l_tokenized, lY, split=split)
        return DataLoader(self.datasets[split], batch_size=batch_size, shuffle=shuffle)

    def _tokenize(self, X):
        return self.tokenizer(
            X,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )

    def fit(self, lX, lY):
        if self.fitted:
            return self

        print("- fitting Transformer View Generating Function")
        _l = list(lX.keys())[0]
        self.num_labels = lY[_l].shape[-1]
        self.model, self.tokenizer = self.init_model(
            self.model_name, num_labels=self.num_labels
        )

        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
            lX, lY, split=0.2, seed=42
        )

        tra_dataloader = self.build_dataloader(
            tr_lX, tr_lY, self.batch_size, split="train", shuffle=True
        )
        val_dataloader = self.build_dataloader(
            val_lX, val_lY, self.batch_size_eval, split="val", shuffle=False
        )

        experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
        trainer = Trainer(
            model=self.model,
            optimizer_name="adamW",
            lr=self.lr,
            device=self.device,
            loss_fn=torch.nn.CrossEntropyLoss(),
            print_steps=self.print_steps,
            evaluate_step=self.evaluate_step,
            patience=self.patience,
            experiment_name=experiment_name,
        )
        trainer.train(
            train_dataloader=tra_dataloader,
            eval_dataloader=val_dataloader,
            epochs=self.epochs,
        )

        if self.probabilistic:
            self.feature2posterior_projector.fit(self.transform(lX), lY)

        self.fitted = True

        return self

    def transform(self, lX):
        _embeds = []
        l_embeds = defaultdict(list)

        dataloader = self.build_dataloader(
            lX, lY=None, batch_size=self.batch_size_eval, split="whole", shuffle=False
        )

        self.model.eval()
        with torch.no_grad():
            for input_ids, lang in dataloader:
                input_ids = input_ids.to(self.device)
                # last-layer hidden state of the first ([CLS]) token as document embedding
                out = self.model(input_ids).hidden_states[-1]
                batch_embeddings = out[:, 0, :].cpu().numpy()
                _embeds.append((batch_embeddings, lang))

        # regroup the batched embeddings by language
        for embed, lang in _embeds:
            for sample_embed, sample_lang in zip(embed, lang):
                l_embeds[sample_lang].append(sample_embed)

        if self.probabilistic and self.fitted:
            l_embeds = self.feature2posterior_projector.transform(l_embeds)

        return l_embeds

    def fit_transform(self, lX, lY):
        return self.fit(lX, lY).transform(lX)

    def save_vgf(self, model_id):
        import pickle
        from os import makedirs
        from os.path import join

        vgf_name = "transformerGen"
        _basedir = join("models", "vgfs", "transformer")
        makedirs(_basedir, exist_ok=True)
        _path = join(_basedir, f"{vgf_name}_{model_id}.pkl")
        with open(_path, "wb") as f:
            pickle.dump(self, f)
        return self

    def __str__(self):
        return (
            f"[Transformer VGF (t)]\n"
            f"- model_name: {self.model_name}\n"
            f"- max_length: {self.max_length}\n"
            f"- batch_size: {self.batch_size}\n"
            f"- batch_size_eval: {self.batch_size_eval}\n"
            f"- lr: {self.lr}\n"
            f"- epochs: {self.epochs}\n"
            f"- device: {self.device}\n"
            f"- print_steps: {self.print_steps}\n"
            f"- evaluate_step: {self.evaluate_step}\n"
            f"- patience: {self.patience}\n"
            f"- probabilistic: {self.probabilistic}\n"
        )


class MultilingualDatasetTorch(Dataset):
    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        self.X = torch.vstack([data.input_ids for data in self.lX.values()])
        if self.split != "whole":
            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
        # one language tag per sample, aligned with the stacking order of self.X
        self.langs = sum(
            [[lang] * len(data.input_ids) for lang, data in self.lX.items()], []
        )
        return self

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]


class Trainer:
    def __init__(
        self,
        model,
        optimizer_name,
        device,
        loss_fn,
        lr,
        print_steps,
        evaluate_step,
        patience,
        experiment_name,
    ):
        self.device = device
        self.model = model.to(device)
        self.optimizer = self.init_optimizer(optimizer_name, lr)
        self.evaluate_steps = evaluate_step
        self.loss_fn = loss_fn.to(device)
        self.print_steps = print_steps
        self.earlystopping = EarlyStopping(
            patience=patience,
            checkpoint_path="models/vgfs/transformers/",
            verbose=True,
            experiment_name=experiment_name,
        )

    def init_optimizer(self, optimizer_name, lr):
        if optimizer_name.lower() == "adamw":
            return AdamW(self.model.parameters(), lr=lr)
        else:
            raise ValueError(f"Optimizer {optimizer_name} not supported")

    def train(self, train_dataloader, eval_dataloader, epochs=10):
        print(
            f"""- Training params:
            - epochs: {epochs}
            - learning rate: {self.optimizer.defaults['lr']}
            - train batch size: {train_dataloader.batch_size}
            - eval batch size: {eval_dataloader.batch_size}
            - max len: {train_dataloader.dataset.X.shape[-1]}\n"""
        )

        for epoch in range(epochs):
            self.train_epoch(train_dataloader, epoch)

            # evaluate every `evaluate_step` epochs and early-stop on the watched metric
            if (epoch + 1) % self.evaluate_steps == 0:
                metric_watcher = self.evaluate(eval_dataloader)
                stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
                if stop:
                    break

        return self.model

    def train_epoch(self, dataloader, epoch):
        self.model.train()
        for b_idx, (x, y, lang) in enumerate(dataloader):
            self.optimizer.zero_grad()
            y_hat = self.model(x.to(self.device))
            loss = self.loss_fn(y_hat.logits, y.to(self.device))
            loss.backward()
            self.optimizer.step()
            if b_idx % self.print_steps == 0:
                print(f"Epoch: {epoch + 1} Step: {b_idx + 1} Loss: {loss:.4f}")
        return self

    def evaluate(self, dataloader):
        self.model.eval()

        lY = defaultdict(list)
        lY_hat = defaultdict(list)

        with torch.no_grad():
            for b_idx, (x, y, lang) in enumerate(dataloader):
                y_hat = self.model(x.to(self.device))
                predictions = predict(y_hat.logits, classification_type="multilabel")

                # group ground truth and predictions by language
                for l, _true, _pred in zip(lang, y, predictions):
                    lY[l].append(_true.detach().cpu().numpy())
                    lY_hat[l].append(_pred)

        for lang in lY:
            lY[lang] = np.vstack(lY[lang])
            lY_hat[lang] = np.vstack(lY_hat[lang])

        l_eval = evaluate(lY, lY_hat)
        average_metrics = log_eval(l_eval, phase="validation")
        return average_metrics[0]  # macro-F1


class EarlyStopping:
    def __init__(
        self,
        patience=5,
        min_delta=0,
        verbose=True,
        checkpoint_path="checkpoint.pt",
        experiment_name="experiment",
    ):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = 0
        self.best_epoch = None
        self.verbose = verbose
        self.checkpoint_path = checkpoint_path
        self.experiment_name = experiment_name

    def __call__(self, validation, model, epoch):
        if validation > self.best_score:
            print(
                f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
            )
            self.best_score = validation
            self.counter = 0
            # self.save_model(model)
        elif validation < (self.best_score + self.min_delta):
            self.counter += 1
            print(
                f"- earlystopping: Validation score did not improve ({validation:.3f} vs best {self.best_score:.3f}), current patience: {self.patience - self.counter}"
            )
            if self.counter >= self.patience:
                if self.verbose:
                    print(f"- earlystopping: Early stopping at epoch {epoch}")
                return True
        return False

    def save_model(self, model):
        _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
        print(f"- saving model to {_checkpoint_dir}")
        os.makedirs(_checkpoint_dir, exist_ok=True)
        model.save_pretrained(_checkpoint_dir)


def predict(logits, classification_type="multilabel"):
    """
    Converts soft predictions into hard predictions in {0, 1}.
    """
    if classification_type == "multilabel":
        prediction = torch.sigmoid(logits) > 0.5
    elif classification_type == "singlelabel":
        prediction = torch.argmax(logits, dim=1).view(-1, 1)
    else:
        raise ValueError(f"Unknown classification type: {classification_type}")
    return prediction.detach().cpu().numpy()
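

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes the input format used throughout this file: `lX` maps language
# codes to lists of raw documents, `lY` maps the same codes to binary label
# matrices of shape (n_docs, n_labels). Texts, labels, and hyperparameters
# below are made up; running it downloads the pretrained mBERT weights.
if __name__ == "__main__":
    lX = {
        "en": ["first english doc", "second english doc", "third english doc"],
        "it": ["primo documento", "secondo documento", "terzo documento"],
    }
    lY = {
        "en": np.array([[1, 0], [0, 1], [1, 1]], dtype=np.float32),
        "it": np.array([[1, 0], [1, 1], [0, 1]], dtype=np.float32),
    }
    vgf = TransformerGen(model_name="mbert", epochs=1, batch_size=2, device="cpu")
    l_embeds = vgf.fit_transform(lX, lY)  # {lang: list of [CLS] embeddings}
    print({lang: len(embeds) for lang, embeds in l_embeds.items()})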