diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py
index 10c2333..7ab0c2b 100644
--- a/evaluation/evaluate.py
+++ b/evaluation/evaluate.py
@@ -1,46 +1,96 @@
 from joblib import Parallel, delayed
+from collections import defaultdict
 
 from evaluation.metrics import *
+from sklearn.metrics import accuracy_score, top_k_accuracy_score, f1_score
 
 
-def evaluation_metrics(y, y_):
-    if len(y.shape) == len(y_.shape) == 1 and len(np.unique(y)) > 2:  # single-label
-        raise NotImplementedError()
-    else:
+def evaluation_metrics(y, y_, clf_type):
+    if clf_type == "singlelabel":
+        return (
+            accuracy_score(y, y_),
+            # TODO: top-k accuracy requires the logits:
+            # top_k_accuracy_score(y, y_, k=5),
+            # top_k_accuracy_score(y, y_, k=10),
+            f1_score(y, y_, average="macro", zero_division=1),
+            f1_score(y, y_, average="micro"),
+        )
+    elif clf_type == "multilabel":
         return (
             macroF1(y, y_),
             microF1(y, y_),
             macroK(y, y_),
             microK(y, y_),
-            # macroAcc(y, y_),
         )
+    else:
+        raise ValueError("clf_type must be either 'singlelabel' or 'multilabel'")
 
 
-def evaluate(ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1):
+def evaluate(
+    ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1, clf_type="multilabel"
+):
     if n_jobs == 1:
-        return {lang: metrics(ly_true[lang], ly_pred[lang]) for lang in ly_true.keys()}
+        return {
+            lang: metrics(ly_true[lang], ly_pred[lang], clf_type)
+            for lang in ly_true.keys()
+        }
     else:
         langs = list(ly_true.keys())
        evals = Parallel(n_jobs=n_jobs)(
-            delayed(metrics)(ly_true[lang], ly_pred[lang]) for lang in langs
+            delayed(metrics)(ly_true[lang], ly_pred[lang], clf_type) for lang in langs
         )
         return {lang: evals[i] for i, lang in enumerate(langs)}
 
 
-def log_eval(l_eval, phase="training", verbose=True):
+def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):
     if verbose:
         print(f"\n[Results {phase}]")
     metrics = []
-    for lang in l_eval.keys():
-        macrof1, microf1, macrok, microk = l_eval[lang]
-        metrics.append([macrof1, microf1, macrok, microk])
-        if phase != "validation":
-            print(f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}")
-    averages = np.mean(np.array(metrics), axis=0)
-    if verbose:
-        print(
-            "Averages: MF1, mF1, MK, mK",
-            np.round(averages, 3),
-            "\n",
-        )
-    return averages
+
+    if clf_type == "multilabel":
+        _metrics = ["macro-F1", "micro-F1", "macro-K", "micro-K"]
+        lang_metrics = defaultdict(dict)
+        for lang in l_eval.keys():
+            macrof1, microf1, macrok, microk = l_eval[lang]
+            metrics.append([macrof1, microf1, macrok, microk])
+            for m, v in zip(_metrics, l_eval[lang]):
+                lang_metrics[m][lang] = v
+            if phase != "validation":
+                print(f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}")
+        averages = np.mean(np.array(metrics), axis=0)
+        if verbose:
+            print(
+                "Averages: MF1, mF1, MK, mK",
+                np.round(averages, 3),
+                "\n",
+            )
+        # return dicts (average and per-language) so that callers such as
+        # Trainer.evaluate can unpack both classification types uniformly
+        avg_metrics = dict(zip(_metrics, averages))
+        return avg_metrics, lang_metrics
+
+    elif clf_type == "singlelabel":
+        lang_metrics = defaultdict(dict)
+        _metrics = [
+            "accuracy",
+            # "acc5",  # "accuracy-at-5",
+            # "acc10",  # "accuracy-at-10",
+            "MF1",  # "macro-F1",
+            "mF1",  # "micro-F1",
+        ]
+        for lang in l_eval.keys():
+            # acc, top5, top10, macrof1, microf1 = l_eval[lang]
+            acc, macrof1, microf1 = l_eval[lang]
+            # metrics.append([acc, top5, top10, macrof1, microf1])
+            metrics.append([acc, macrof1, microf1])
+
+            for m, v in zip(_metrics, l_eval[lang]):
+                lang_metrics[m][lang] = v
+
+            if phase != "validation":
+                print(
+                    # f"Lang {lang}: acc = {acc:.3f} acc-top5 = {top5:.3f} acc-top10 = {top10:.3f} macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}"
+                    f"Lang {lang}: acc = {acc:.3f} macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}"
+                )
+        averages = np.mean(np.array(metrics), axis=0)
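+        # unweighted mean over languages: each language counts equally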
+        if verbose:
+            print(
+                # "Averages: Acc, Acc-5, Acc-10, MF1, mF1",
+                "Averages: Acc, MF1, mF1",
+                np.round(averages, 3),
+                "\n",
+            )
+        avg_metrics = dict(zip(_metrics, averages))
+        return avg_metrics, lang_metrics
diff --git a/gfun/generalizedFunnelling.py b/gfun/generalizedFunnelling.py
index 3899107..a969596 100644
--- a/gfun/generalizedFunnelling.py
+++ b/gfun/generalizedFunnelling.py
@@ -1,17 +1,14 @@
 import os
-import sys
-
-# sys.path.append(os.path.join(os.getcwd(), "gfun"))
-
 import pickle
 
 import numpy as np
-from gfun.vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator
+
+from gfun.vgfs.commons import AttentionAggregator, TfidfVectorizerMultilingual, predict
 from gfun.vgfs.learners.svms import MetaClassifier, get_learner
 from gfun.vgfs.multilingualGen import MultilingualGen
 from gfun.vgfs.textualTransformerGen import TextualTransformerGen
-from gfun.vgfs.visualTransformerGen import VisualTransformerGen
 from gfun.vgfs.vanillaFun import VanillaFunGen
+from gfun.vgfs.visualTransformerGen import VisualTransformerGen
 from gfun.vgfs.wceGen import WceGen
 
@@ -25,12 +22,14 @@ class GeneralizedFunnelling:
         visual_transformer,
         langs,
         num_labels,
+        classification_type,
         embed_dir,
         n_jobs,
         batch_size,
         eval_batch_size,
         max_length,
-        lr,
+        textual_lr,
+        visual_lr,
         epochs,
         patience,
         evaluate_step,
@@ -52,6 +51,7 @@ class GeneralizedFunnelling:
         self.visual_trf_vgf = visual_transformer
         self.probabilistic = probabilistic
         self.num_labels = num_labels
+        self.clf_type = classification_type
         # ------------------------
         self.langs = langs
         self.embed_dir = embed_dir
@@ -59,7 +59,8 @@ class GeneralizedFunnelling:
         # Textual Transformer VGF params ----------
         self.textual_trf_name = textual_transformer_name
         self.epochs = epochs
-        self.lr_transformer = lr
+        self.txt_trf_lr = textual_lr
+        self.vis_trf_lr = visual_lr
         self.batch_size_trf = batch_size
         self.eval_batch_size_trf = eval_batch_size
         self.max_length = max_length
@@ -114,7 +115,7 @@ class GeneralizedFunnelling:
             self.attn_aggregator = AttentionAggregator(
                 embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
                 out_dim=self.num_labels,
-                lr=self.lr_transformer,
+                lr=self.txt_trf_lr,
                 patience=self.patience,
                 num_heads=1,
                 device=self.device,
@@ -148,7 +149,7 @@ class GeneralizedFunnelling:
             transformer_vgf = TextualTransformerGen(
                 dataset_name=self.dataset_name,
                 model_name=self.textual_trf_name,
-                lr=self.lr_transformer,
+                lr=self.txt_trf_lr,
                 epochs=self.epochs,
                 batch_size=self.batch_size_trf,
                 batch_size_eval=self.eval_batch_size_trf,
@@ -159,6 +160,7 @@ class GeneralizedFunnelling:
                 verbose=True,
                 patience=self.patience,
                 device=self.device,
+                classification_type=self.clf_type,
             )
             self.first_tier_learners.append(transformer_vgf)
 
@@ -166,7 +168,7 @@ class GeneralizedFunnelling:
             visual_trasformer_vgf = VisualTransformerGen(
                 dataset_name=self.dataset_name,
                 model_name="vit",
-                lr=self.lr_transformer,
+                lr=self.vis_trf_lr,
                 epochs=self.epochs,
                 batch_size=self.batch_size_trf,
                 batch_size_eval=self.eval_batch_size_trf,
@@ -174,6 +176,7 @@ class GeneralizedFunnelling:
                 evaluate_step=self.evaluate_step,
                 patience=self.patience,
                 device=self.device,
+                classification_type=self.clf_type,
             )
             self.first_tier_learners.append(visual_trasformer_vgf)
 
@@ -182,7 +185,7 @@ class GeneralizedFunnelling:
             self.attn_aggregator = AttentionAggregator(
                 embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
                 out_dim=self.num_labels,
-                lr=self.lr_transformer,
+                lr=self.txt_trf_lr,
                 patience=self.patience,
                 num_heads=1,
                 device=self.device,
@@ -255,10 +258,9 @@ class GeneralizedFunnelling:
         projections.append(l_posteriors)
         agg = self.aggregate(projections)
         l_out = self.metaclassifier.predict_proba(agg)
-        # converting to binary predictions
-        # if self.dataset_name in ["cls"]:  # TODO: better way to do this
-        #     for lang, preds in l_out.items():
-        #         l_out[lang] = np.expand_dims(np.argmax(preds, axis=1), axis=1)
+        if self.clf_type == "singlelabel":
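+            # collapse the meta-classifier's per-language posterior
+            # probabilities into one-hot hard predictions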
+            for lang, preds in l_out.items():
+                l_out[lang] = predict(preds, clf_type=self.clf_type)
         return l_out
 
     def fit_transform(self, lX, lY):
diff --git a/gfun/vgfs/commons.py b/gfun/vgfs/commons.py
index effbf9d..6399d85 100644
--- a/gfun/vgfs/commons.py
+++ b/gfun/vgfs/commons.py
@@ -9,6 +9,7 @@
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import normalize
 from torch.optim import AdamW
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.utils.data import DataLoader, Dataset
 from transformers.modeling_outputs import ModelOutput
 
@@ -22,6 +23,21 @@ def _normalize(lX, l2=True):
     return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX
 
 
+def verbosity_eval(epoch, print_eval):
+    if (epoch + 1) % print_eval == 0 and epoch != 0:
+        return True
+    else:
+        return False
+
+
+def format_langkey_wandb(lang_dict):
+    log_dict = {}
+    for metric, l_dict in lang_dict.items():
+        for lang, value in l_dict.items():
+            log_dict[f"language metric/{metric}/{lang}"] = value
+    return log_dict
+
+
 def XdotM(X, M, sif):
     E = X.dot(M)
     if sif:
@@ -58,18 +74,23 @@ def compute_pc(X, npc=1):
     return svd.components_
 
 
-def predict(logits, classification_type="multilabel"):
+def predict(logits, clf_type="multilabel"):
     """
-    Converts soft precictions to hard predictions [0,1]
+    Converts soft predictions to hard predictions in {0, 1}
     """
-    if classification_type == "multilabel":
+    if clf_type == "multilabel":
         prediction = torch.sigmoid(logits) > 0.5
-    elif classification_type == "singlelabel":
-        prediction = torch.argmax(logits, dim=1).view(-1, 1)
+        return prediction.detach().cpu().numpy()
+    elif clf_type == "singlelabel":
+        if not isinstance(logits, torch.Tensor):
+            logits = torch.tensor(logits)
+        prediction = torch.softmax(logits, dim=1)
+        prediction = prediction.detach().cpu().numpy()
+        _argmaxs = prediction.argmax(axis=1)
+        prediction = np.eye(prediction.shape[1])[_argmaxs]
+        return prediction
     else:
-        print("unknown classification type")
-
-    return prediction.detach().cpu().numpy()
+        raise NotImplementedError()
 
 
 class TfidfVectorizerMultilingual:
@@ -115,36 +136,54 @@ class Trainer:
         patience,
         experiment_name,
         checkpoint_path,
+        classification_type,
         vgf_name,
+        n_jobs,
+        scheduler_name=None,
     ):
         self.device = device
         self.model = model.to(device)
-        self.optimizer = self.init_optimizer(optimizer_name, lr)
+        self.optimizer, self.scheduler = self.init_optimizer(
+            optimizer_name, lr, scheduler_name
+        )
         self.evaluate_steps = evaluate_step
         self.loss_fn = loss_fn.to(device)
         self.print_steps = print_steps
         self.experiment_name = experiment_name
         self.patience = patience
-        self.print_eval = evaluate_step
+        self.print_eval = 10
         self.earlystopping = EarlyStopping(
             patience=patience,
             checkpoint_path=checkpoint_path,
             verbose=False,
             experiment_name=experiment_name,
         )
+        self.clf_type = classification_type
         self.vgf_name = vgf_name
+        self.n_jobs = n_jobs
+        self.monitored_metric = (
+            "macro-F1" if self.clf_type == "multilabel" else "accuracy"
+        )  # TODO: make this configurable
 
-    def init_optimizer(self, optimizer_name, lr):
+    def init_optimizer(self, optimizer_name, lr, scheduler_name):
         if optimizer_name.lower() == "adamw":
-            return AdamW(self.model.parameters(), lr=lr)
+            optim = AdamW(self.model.parameters(), lr=lr)
         else:
             raise ValueError(f"Optimizer {optimizer_name} not supported")
+        if scheduler_name is None:
+            scheduler = None
+        elif scheduler_name == "ReduceLROnPlateau":
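+            # "max" mode: halve the LR once the monitored validation metric plateaus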
+            scheduler = ReduceLROnPlateau(optim, "max", factor=0.5, min_lr=1e-5)
+        else:
+            raise ValueError(f"Scheduler {scheduler_name} not supported")
+        return optim, scheduler
 
     def get_config(self, train_dataloader, eval_dataloader, epochs):
         return {
             "model name": self.model.name_or_path,
             "epochs": epochs,
             "learning rate": self.optimizer.defaults["lr"],
+            "scheduler": self.scheduler.__class__.__name__
+            if self.scheduler is not None
+            else "none",
             "train batch size": train_dataloader.batch_size,
             "eval batch size": eval_dataloader.batch_size,
             "max len": train_dataloader.dataset.X.shape[-1],
@@ -152,6 +191,7 @@ class Trainer:
             "evaluate every": self.evaluate_steps,
             "print eval every": self.print_eval,
             "print train steps": self.print_steps,
+            "classification type": self.clf_type,
         }
 
     def train(self, train_dataloader, eval_dataloader, epochs=10):
@@ -168,23 +208,23 @@
         for epoch in range(epochs):
             train_loss = self.train_epoch(train_dataloader, epoch)
-            wandb_logger.log({f"{self.vgf_name}_train_loss": train_loss})
-
-            if (epoch + 1) % self.evaluate_steps == 0:
-                print_eval = (epoch + 1) % self.print_eval == 0
+            if (epoch + 1) % self.evaluate_steps == 0 or (epoch + 1) == 1:
+                print_eval = verbosity_eval(epoch, self.print_eval)
                 with torch.no_grad():
-                    eval_loss, metric_watcher = self.evaluate(
-                        eval_dataloader, epoch, print_eval=print_eval
+                    eval_loss, avg_metrics, lang_metrics = self.evaluate(
+                        eval_dataloader,
+                        print_eval=print_eval,
+                        n_jobs=self.n_jobs,
                     )
                 wandb_logger.log(
-                    {
-                        f"{self.vgf_name}_eval_loss": eval_loss,
-                        f"{self.vgf_name}_eval_metric": metric_watcher,
-                    }
+                    {"loss/val": eval_loss, **format_langkey_wandb(lang_metrics)},
+                    commit=False,
                 )
-                stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
+                stop = self.earlystopping(
+                    avg_metrics[self.monitored_metric], self.model, epoch + 1
+                )
                 if stop:
                     print(
                         f"- restoring best model from epoch {self.earlystopping.best_epoch} with best metric: {self.earlystopping.best_score:.3f}"
                     )
                     self.model.load_state_dict(
@@ -194,6 +234,16 @@
                     )
                     break
 
+            if self.scheduler is not None:
+                self.scheduler.step(avg_metrics[self.monitored_metric])
+
+            wandb_logger.log(
+                {
+                    "loss/train": train_loss,
+                    "learning rate": self.optimizer.param_groups[0]["lr"],
+                }
+            )
+
         print(f"- last swipe on eval set")
         self.train_epoch(eval_dataloader, epoch=-1)
         self.earlystopping.save_model(self.model)
@@ -201,6 +251,7 @@
 
     def train_epoch(self, dataloader, epoch):
         self.model.train()
+        epoch_losses = []
         for b_idx, (x, y, lang) in enumerate(dataloader):
             self.optimizer.zero_grad()
             y_hat = self.model(x.to(self.device))
@@ -210,38 +261,47 @@
                 loss = self.loss_fn(y_hat, y.to(self.device))
             loss.backward()
             self.optimizer.step()
+            epoch_losses.append(loss.item())
             if (epoch + 1) % PRINT_ON_EPOCH == 0:
                 if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
-                    print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
-        return loss.item()
+                    print(
+                        f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {np.mean(epoch_losses):.4f}"
+                    )
+        return np.mean(epoch_losses)
 
-    def evaluate(self, dataloader, epoch, print_eval=True):
+    def evaluate(self, dataloader, print_eval=True, n_jobs=-1):
         self.model.eval()
+        eval_losses = []
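+        # collect per-batch losses; evaluate() returns their mean rather than
+        # the loss of the last batch only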
 
-        lY = defaultdict(list)
-        lY_hat = defaultdict(list)
+        lY_true = defaultdict(list)
+        lY_pred = defaultdict(list)
 
         for b_idx, (x, y, lang) in enumerate(dataloader):
-            y_hat = self.model(x.to(self.device))
-            if isinstance(y_hat, ModelOutput):
-                loss = self.loss_fn(y_hat.logits, y.to(self.device))
-                predictions = predict(y_hat.logits, classification_type="multilabel")
+            y_pred = self.model(x.to(self.device))
+            if isinstance(y_pred, ModelOutput):
+                loss = self.loss_fn(y_pred.logits, y.to(self.device))
+                predictions = predict(y_pred.logits, clf_type=self.clf_type)
             else:
-                loss = self.loss_fn(y_hat, y.to(self.device))
-                predictions = predict(y_hat, classification_type="multilabel")
+                loss = self.loss_fn(y_pred, y.to(self.device))
+                predictions = predict(y_pred, clf_type=self.clf_type)
+
+            eval_losses.append(loss.item())
 
             for l, _true, _pred in zip(lang, y, predictions):
-                lY[l].append(_true.detach().cpu().numpy())
-                lY_hat[l].append(_pred)
+                lY_true[l].append(_true.detach().cpu().numpy())
+                lY_pred[l].append(_pred)
 
-        for lang in lY:
-            lY[lang] = np.vstack(lY[lang])
-            lY_hat[lang] = np.vstack(lY_hat[lang])
+        for lang in lY_true:
+            lY_true[lang] = np.vstack(lY_true[lang])
+            lY_pred[lang] = np.vstack(lY_pred[lang])
 
-        l_eval = evaluate(lY, lY_hat)
-        average_metrics = log_eval(l_eval, phase="validation", verbose=print_eval)
+        l_eval = evaluate(lY_true, lY_pred, clf_type=self.clf_type, n_jobs=n_jobs)
 
-        return loss.item(), average_metrics[0]  # macro-F1
+        avg_metrics, lang_metrics = log_eval(
+            l_eval, phase="validation", clf_type=self.clf_type, verbose=print_eval
+        )
+
+        return np.mean(eval_losses), avg_metrics, lang_metrics
 
 
 class EarlyStopping:
@@ -279,7 +339,7 @@ class EarlyStopping:
             print(
                 f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}"
             )
-        if self.counter >= self.patience:
+        if self.counter >= self.patience and self.patience != -1:
             print(f"- earlystopping: Early stopping at epoch {epoch}")
             return True
diff --git a/gfun/vgfs/textualTransformerGen.py b/gfun/vgfs/textualTransformerGen.py
index 1c91b93..8f8d661 100644
--- a/gfun/vgfs/textualTransformerGen.py
+++ b/gfun/vgfs/textualTransformerGen.py
@@ -7,11 +7,9 @@ from collections import defaultdict
 import numpy as np
 import torch
 import transformers
-
-# from sklearn.model_selection import train_test_split
-# from torch.optim import AdamW
 from torch.utils.data import Dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
 from gfun.vgfs.commons import Trainer
 from gfun.vgfs.transformerGen import TransformerGen
 from gfun.vgfs.viewGen import ViewGen
@@ -19,9 +17,6 @@ from gfun.vgfs.viewGen import ViewGen
 transformers.logging.set_verbosity_error()
 
 
-# TODO: add support to loggers
-
-
 class TextualTransformerGen(ViewGen, TransformerGen):
     def __init__(
         self,
@@ -39,6 +34,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         evaluate_step=10,
         verbose=False,
         patience=5,
+        classification_type="multilabel",
     ):
         super().__init__(
             self._validate_model_name(model_name),
@@ -56,6 +52,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             verbose,
             patience,
         )
+        self.clf_type = classification_type
         self.fitted = False
         print(
             f"- init Textual TransformerModel model_name: {self.model_name}, device: {self.device}"
         )
@@ -143,6 +140,9 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             experiment_name=experiment_name,
             checkpoint_path="models/vgfs/transformer",
             vgf_name="textual_trf",
+            classification_type=self.clf_type,
+            n_jobs=self.n_jobs,
+            scheduler_name="ReduceLROnPlateau",
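+            # NOTE: only the textual VGF requests a scheduler here; the visual
+            # VGF below creates its Trainer without a scheduler_name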
         )
         trainer.train(
             train_dataloader=tra_dataloader,
diff --git a/gfun/vgfs/visualTransformerGen.py b/gfun/vgfs/visualTransformerGen.py
index f8b6e6e..6c4d3b1 100644
--- a/gfun/vgfs/visualTransformerGen.py
+++ b/gfun/vgfs/visualTransformerGen.py
@@ -27,6 +27,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
         device="cpu",
         probabilistic=False,
         patience=5,
+        classification_type="multilabel",
     ):
         super().__init__(
             model_name,
@@ -40,6 +41,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
             patience=patience,
             probabilistic=probabilistic,
         )
+        self.clf_type = classification_type
         self.fitted = False
         print(
             f"- init Visual TransformerModel model_name: {self.model_name}, device: {self.device}"
         )
@@ -113,6 +115,8 @@ class VisualTransformerGen(ViewGen, TransformerGen):
             experiment_name=experiment_name,
             checkpoint_path="models/vgfs/transformer",
             vgf_name="visual_trf",
+            classification_type=self.clf_type,
+            n_jobs=self.n_jobs,
         )
 
         trainer.train(
diff --git a/main.py b/main.py
index 254b068..9efce33 100644
--- a/main.py
+++ b/main.py
@@ -1,3 +1,7 @@
+import os
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
 from argparse import ArgumentParser
 from time import time
 
@@ -7,6 +11,7 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
 
 """
 TODO:
+    - [!] LR scheduler
     - [!] CLS dataset is loading only "books" domain data
     - [!] documents should be trimmed to the same length (?)
     - [!] overall gfun results logger
@@ -42,6 +47,7 @@ def main(args):
         dataset_name=args.dataset,
         langs=dataset.langs(),
         num_labels=dataset.num_labels(),
+        classification_type=args.clf_type,
         # Posterior VGF params ----------------
         posterior=args.posteriors,
         # Multilingual VGF params -------------
@@ -55,7 +61,8 @@ def main(args):
         batch_size=args.batch_size,
         eval_batch_size=args.eval_batch_size,
         epochs=args.epochs,
-        lr=args.lr,
+        textual_lr=args.textual_lr,
+        visual_lr=args.visual_lr,
         max_length=args.max_length,
         patience=args.patience,
         evaluate_step=args.evaluate_step,
@@ -93,8 +100,8 @@ def main(args):
     print(f"- training completed in {timetr - tinit:.2f} seconds")
 
     gfun_preds = gfun.transform(lX_te)
-    test_eval = evaluate(lY_te, gfun_preds)
-    log_eval(test_eval, phase="test")
+    test_eval = evaluate(lY_te, gfun_preds, clf_type=args.clf_type, n_jobs=args.n_jobs)
+    log_eval(test_eval, phase="test", clf_type=args.clf_type)
 
     timeval = time()
     print(f"- testing completed in {timeval - timetr:.2f} seconds")
@@ -112,6 +119,9 @@ if __name__ == "__main__":
     parser.add_argument("--nrows", type=int, default=None)
     parser.add_argument("--min_count", type=int, default=10)
     parser.add_argument("--max_labels", type=int, default=50)
+    parser.add_argument(
+        "--clf_type", type=str, default="multilabel", choices=["multilabel", "singlelabel"]
+    )
     # gFUN parameters ----------------------
     parser.add_argument("-p", "--posteriors", action="store_true")
     parser.add_argument("-m", "--multilingual", action="store_true")
@@ -127,7 +135,8 @@ if __name__ == "__main__":
     parser.add_argument("--batch_size", type=int, default=32)
     parser.add_argument("--eval_batch_size", type=int, default=128)
     parser.add_argument("--epochs", type=int, default=100)
-    parser.add_argument("--lr", type=float, default=1e-5)
+    parser.add_argument("--textual_lr", type=float, default=1e-5)
+    parser.add_argument("--visual_lr", type=float, default=1e-5)
     parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--patience", type=int, default=5)
     parser.add_argument("--evaluate_step", type=int, default=10)
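
Usage sketch (illustrative, not part of the diff): the new single-label path in
gfun/vgfs/commons.py turns raw logits, or the meta-classifier's posterior
probabilities, into one-hot hard predictions.

    import numpy as np
    from gfun.vgfs.commons import predict

    logits = np.array([[2.0, 0.5, -1.0],
                       [0.1, 0.3, 0.2]])
    print(predict(logits, clf_type="singlelabel"))
    # [[1. 0. 0.]
    #  [0. 1. 0.]]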