From 8dbe48ff7a454aed9c28757757089ed6634a1418 Mon Sep 17 00:00:00 2001 From: andrea Date: Wed, 20 Jan 2021 11:47:51 +0100 Subject: [PATCH] Implemented custom micro F1 in pl (cpu and gpu) --- refactor/data/datamodule.py | 13 +++- refactor/main.py | 5 +- refactor/models/pl_gru.py | 150 ++++++++++++------------------------ refactor/util/common.py | 9 +++ refactor/view_generators.py | 14 ++-- 5 files changed, 83 insertions(+), 108 deletions(-) diff --git a/refactor/data/datamodule.py b/refactor/data/datamodule.py index 67a83d6..29020dc 100644 --- a/refactor/data/datamodule.py +++ b/refactor/data/datamodule.py @@ -88,7 +88,7 @@ class RecurrentDataset(Dataset): return index_list -class GfunDataModule(pl.LightningDataModule): +class RecurrentDataModule(pl.LightningDataModule): def __init__(self, multilingualIndex, batchsize=64): """ Pytorch-lightning DataModule: https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html @@ -105,9 +105,18 @@ class GfunDataModule(pl.LightningDataModule): def setup(self, stage=None): if stage == 'fit' or stage is None: l_train_index, l_train_target = self.multilingualIndex.l_train() + + # l_train_index = {l: train[:50] for l, train in l_train_index.items()} + # l_train_target = {l: target[:50] for l, target in l_train_target.items()} + self.training_dataset = RecurrentDataset(l_train_index, l_train_target, lPad_index=self.multilingualIndex.l_pad()) + l_val_index, l_val_target = self.multilingualIndex.l_val() + + # l_val_index = {l: train[:50] for l, train in l_val_index.items()} + # l_val_target = {l: target[:50] for l, target in l_val_target.items()} + self.val_dataset = RecurrentDataset(l_val_index, l_val_target, lPad_index=self.multilingualIndex.l_pad()) if stage == 'test' or stage is None: @@ -128,7 +137,7 @@ class GfunDataModule(pl.LightningDataModule): collate_fn=self.test_dataset.collate_fn) -class BertDataModule(GfunDataModule): +class BertDataModule(RecurrentDataModule): def __init__(self, multilingualIndex, batchsize=64, max_len=512): super().__init__(multilingualIndex, batchsize) self.max_len = max_len diff --git a/refactor/main.py b/refactor/main.py index 42ef9c9..45487b1 100644 --- a/refactor/main.py +++ b/refactor/main.py @@ -15,7 +15,7 @@ def main(args): _DATASET = '/home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle' EMBEDDINGS_PATH = '/home/andreapdr/gfun/embeddings' data = MultilingualDataset.load(_DATASET) - data.set_view(languages=['it'], categories=[0,1]) + data.set_view(languages=['it'], categories=[0, 1]) lX, ly = data.training() lXte, lyte = data.test() @@ -28,7 +28,8 @@ def main(args): # gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS) # gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS) # gFun = WordClassGen(n_jobs=N_JOBS) - gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=512, gpus=args.gpus, n_jobs=N_JOBS) + gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=5, nepochs=100, + gpus=args.gpus, n_jobs=N_JOBS) # gFun = BertGen(multilingualIndex, gpus=args.gpus, batch_size=128, n_jobs=N_JOBS) gFun.fit(lX, ly) diff --git a/refactor/models/pl_gru.py b/refactor/models/pl_gru.py index 268a694..2e3ecf1 100644 --- a/refactor/models/pl_gru.py +++ b/refactor/models/pl_gru.py @@ -5,12 +5,10 @@ from transformers import AdamW import torch.nn.functional as F from torch.autograd import Variable import pytorch_lightning as pl -from pytorch_lightning.metrics import F1, Accuracy, Metric +from pytorch_lightning.metrics import Metric, F1, Accuracy from torch.optim.lr_scheduler import StepLR -from typing import Any, Optional, Tuple -from pytorch_lightning.metrics.utils import _input_format_classification_one_hot, class_reduce from models.helpers import init_embeddings -import numpy as np +from util.common import is_true, is_false from util.evaluation import evaluate @@ -20,8 +18,9 @@ class RecurrentModel(pl.LightningModule): """ def __init__(self, lPretrained, langs, output_size, hidden_size, lVocab_size, learnable_length, - drop_embedding_range, drop_embedding_prop, lMuse_debug=None, multilingual_index_debug=None): + drop_embedding_range, drop_embedding_prop, gpus=None): super().__init__() + self.gpus = gpus self.langs = langs self.lVocab_size = lVocab_size self.learnable_length = learnable_length @@ -33,7 +32,7 @@ class RecurrentModel(pl.LightningModule): self.microf1 = F1(num_classes=output_size, multilabel=True, average='micro') self.macrof1 = F1(num_classes=output_size, multilabel=True, average='macro') self.accuracy = Accuracy() - self.customMetrics = CustomMetrics(num_classes=output_size, multilabel=True, average='micro') + self.customMetrics = CustomF1(num_classes=output_size, device=self.gpus) self.lPretrained_embeddings = nn.ModuleDict() self.lLearnable_embeddings = nn.ModuleDict() @@ -42,10 +41,6 @@ class RecurrentModel(pl.LightningModule): self.n_directions = 1 self.dropout = nn.Dropout(0.6) - # TODO: debug setting - self.lMuse = lMuse_debug - self.multilingual_index_debug = multilingual_index_debug - lstm_out = 256 ff1 = 512 ff2 = 256 @@ -111,7 +106,7 @@ class RecurrentModel(pl.LightningModule): custom = self.customMetrics(predictions, ly) self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True) self.log('train-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True) - self.log('custom', custom, on_step=False, on_epoch=True, prog_bar=True, logger=True) + self.log('custom', custom, on_step=True, on_epoch=True, prog_bar=True, logger=True) return {'loss': loss} def validation_step(self, val_batch, batch_idx): @@ -139,7 +134,6 @@ class RecurrentModel(pl.LightningModule): accuracy = self.accuracy(predictions, ly) self.log('test-accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True) return - # return {'pred': predictions, 'target': ly} def embed(self, X, lang): input_list = [] @@ -166,98 +160,56 @@ class RecurrentModel(pl.LightningModule): return [optimizer], [scheduler] -class CustomMetrics(Metric): - def __init__( - self, - num_classes: int, - beta: float = 1.0, - threshold: float = 0.5, - average: str = "micro", - multilabel: bool = False, - compute_on_step: bool = True, - dist_sync_on_step: bool = False, - process_group: Optional[Any] = None, - ): - super().__init__( - compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, - ) - +class CustomF1(Metric): + def __init__(self, num_classes, device, average='micro'): + """ + Custom F1 metric. + Scikit learn provides a full set of evaluation metrics, but they treat special cases differently. + I.e., when the number of true positives, false positives, and false negatives amount to 0, all + affected metrics (precision, recall, and thus f1) output 0 in Scikit learn. + We adhere to the common practice of outputting 1 in this case since the classifier has correctly + classified all examples as negatives. + :param num_classes: + :param device: + :param average: + """ + super().__init__() self.num_classes = num_classes - self.beta = beta - self.threshold = threshold self.average = average - self.multilabel = multilabel + self.device = 'cuda' if device else 'cpu' + self.add_state('true_positive', default=torch.zeros(self.num_classes)) + self.add_state('true_negative', default=torch.zeros(self.num_classes)) + self.add_state('false_positive', default=torch.zeros(self.num_classes)) + self.add_state('false_negative', default=torch.zeros(self.num_classes)) - allowed_average = ("micro", "macro", "weighted", None) - if self.average not in allowed_average: - raise ValueError('Argument `average` expected to be one of the following:' - f' {allowed_average} but got {self.average}') + def update(self, preds, target): + true_positive, true_negative, false_positive, false_negative = self._update(preds, target) - self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") - self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") - self.add_state("actual_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") + self.true_positive += true_positive + self.true_negative += true_negative + self.false_positive += false_positive + self.false_negative += false_negative - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. + def _update(self, pred, target): + assert pred.shape == target.shape + # preparing preds and targets for count + true_pred = is_true(pred, self.device) + false_pred = is_false(pred, self.device) + true_target = is_true(target, self.device) + false_target = is_false(target, self.device) - Args: - preds: Predictions from model - target: Ground truth values - """ - true_positives, predicted_positives, actual_positives = _fbeta_update( - preds, target, self.num_classes, self.threshold, self.multilabel - ) - - self.true_positives += true_positives - self.predicted_positives += predicted_positives - self.actual_positives += actual_positives + tp = torch.sum(true_pred * true_target, dim=0) + tn = torch.sum(false_pred * false_target, dim=0) + fp = torch.sum(true_pred * false_target, dim=0) + fn = torch.sum(false_pred * target, dim=0) + return tp, tn, fp, fn def compute(self): - """ - Computes metrics over state. - """ - return _fbeta_compute(self.true_positives, self.predicted_positives, - self.actual_positives, self.beta, self.average) - - -def _fbeta_update( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: int, - threshold: float = 0.5, - multilabel: bool = False -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - preds, target = _input_format_classification_one_hot( - num_classes, preds, target, threshold, multilabel - ) - true_positives = torch.sum(preds * target, dim=1) - predicted_positives = torch.sum(preds, dim=1) - actual_positives = torch.sum(target, dim=1) - return true_positives, predicted_positives, actual_positives - - -def _fbeta_compute( - true_positives: torch.Tensor, - predicted_positives: torch.Tensor, - actual_positives: torch.Tensor, - beta: float = 1.0, - average: str = "micro" -) -> torch.Tensor: - if average == "micro": - precision = true_positives.sum().float() / predicted_positives.sum() - recall = true_positives.sum().float() / actual_positives.sum() - else: - precision = true_positives.float() / predicted_positives - recall = true_positives.float() / actual_positives - - num = (1 + beta ** 2) * precision * recall - denom = beta ** 2 * precision + recall - new_num = 2 * true_positives - new_fp = predicted_positives - true_positives - new_fn = actual_positives - true_positives - new_den = 2 * true_positives + new_fp + new_fn - if new_den.sum() == 0: - # whats is the correct return type ? TODO - return class_reduce(new_num, new_den, weights=actual_positives, class_reduction=average) - return class_reduce(num, denom, weights=actual_positives, class_reduction=average) + if self.average == 'micro': + num = 2.0 * self.true_positive.sum() + den = 2.0 * self.true_positive.sum() + self.false_positive.sum() + self.false_negative.sum() + if den > 0: + return (num / den).to(self.device) + return torch.FloatTensor([1.]).to(self.device) + if self.average == 'macro': + raise NotImplementedError diff --git a/refactor/util/common.py b/refactor/util/common.py index 4bd0c20..c6f6610 100644 --- a/refactor/util/common.py +++ b/refactor/util/common.py @@ -327,3 +327,12 @@ def index(data, vocab, known_words, analyzer, unk_index, out_of_vocabulary): # pbar.set_description(f'[unk = {unk_count}/{knw_count}={(100.*unk_count/knw_count):.2f}%]' # f'[out = {out_count}/{knw_count}={(100.*out_count/knw_count):.2f}%]') return indexes + + +def is_true(tensor, device): + return torch.where(tensor == 1, torch.Tensor([1]).to(device), torch.Tensor([0]).to(device)) + + +def is_false(tensor, device): + return torch.where(tensor == 0, torch.Tensor([1]).to(device), torch.Tensor([0]).to(device)) + diff --git a/refactor/view_generators.py b/refactor/view_generators.py index abe2442..9ea91fa 100644 --- a/refactor/view_generators.py +++ b/refactor/view_generators.py @@ -22,7 +22,7 @@ from models.pl_gru import RecurrentModel from models.pl_bert import BertModel from models.lstm_class import RNNMultilingualClassifier from pytorch_lightning import Trainer -from data.datamodule import GfunDataModule, BertDataModule +from data.datamodule import RecurrentDataModule, BertDataModule from pytorch_lightning.loggers import TensorBoardLogger import torch @@ -144,7 +144,8 @@ class RecurrentGen(ViewGen): # TODO: save model https://forums.pytorchlightning.ai/t/how-to-save-hparams-when-not-provided-as-argument-apparently-assigning-to-hparams-is-not-recomended/339/5 # Problem: we are passing lPretrained to init the RecurrentModel -> incredible slow at saving (checkpoint). # if we do not save it is impossible to init RecurrentModel by calling RecurrentModel.load_from_checkpoint() - def __init__(self, multilingualIndex, pretrained_embeddings, wce, batch_size=512, gpus=0, n_jobs=-1, stored_path=None): + def __init__(self, multilingualIndex, pretrained_embeddings, wce, batch_size=512, nepochs=50, + gpus=0, n_jobs=-1, stored_path=None): """ generates document embedding by means of a Gated Recurrent Units. The model can be initialized with different (multilingual/aligned) word representations (e.g., MUSE, WCE, ecc.,). @@ -162,6 +163,7 @@ class RecurrentGen(ViewGen): self.gpus = gpus self.n_jobs = n_jobs self.stored_path = stored_path + self.nepochs = nepochs # EMBEDDINGS to be deployed self.pretrained = pretrained_embeddings @@ -193,7 +195,8 @@ class RecurrentGen(ViewGen): lVocab_size=lvocab_size, learnable_length=learnable_length, drop_embedding_range=self.multilingualIndex.sup_range, - drop_embedding_prop=0.5 + drop_embedding_prop=0.5, + gpus=self.gpus ) def fit(self, lX, ly): @@ -204,8 +207,9 @@ class RecurrentGen(ViewGen): :param ly: :return: """ - recurrentDataModule = GfunDataModule(self.multilingualIndex, batchsize=self.batch_size) - trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=50, checkpoint_callback=False) + recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size) + trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs, + checkpoint_callback=False) # vanilla_torch_model = torch.load( # '/home/andreapdr/funneling_pdr/checkpoint/gru_viewgen_-jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle')