diff --git a/refactor/main.py b/refactor/main.py
index 45487b1..bea0067 100644
--- a/refactor/main.py
+++ b/refactor/main.py
@@ -15,7 +15,7 @@ def main(args):
     _DATASET = '/home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'
     EMBEDDINGS_PATH = '/home/andreapdr/gfun/embeddings'
     data = MultilingualDataset.load(_DATASET)
-    data.set_view(languages=['it'], categories=[0, 1])
+    # data.set_view(languages=['it'], categories=[0, 1])
     lX, ly = data.training()
     lXte, lyte = data.test()
 
@@ -28,7 +28,7 @@ def main(args):
     # gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
     # gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS)
    # gFun = WordClassGen(n_jobs=N_JOBS)
-    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=5, nepochs=100,
+    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256, nepochs=50,
                         gpus=args.gpus, n_jobs=N_JOBS)
     # gFun = BertGen(multilingualIndex, gpus=args.gpus, batch_size=128, n_jobs=N_JOBS)
 
diff --git a/refactor/models/pl_gru.py b/refactor/models/pl_gru.py
index 2e3ecf1..1ed8314 100644
--- a/refactor/models/pl_gru.py
+++ b/refactor/models/pl_gru.py
@@ -29,10 +29,11 @@ class RecurrentModel(pl.LightningModule):
         self.drop_embedding_range = drop_embedding_range
         self.drop_embedding_prop = drop_embedding_prop
         self.loss = torch.nn.BCEWithLogitsLoss()
-        self.microf1 = F1(num_classes=output_size, multilabel=True, average='micro')
-        self.macrof1 = F1(num_classes=output_size, multilabel=True, average='macro')
+        # self.microf1 = F1(num_classes=output_size, multilabel=True, average='micro')
+        # self.macrof1 = F1(num_classes=output_size, multilabel=True, average='macro')
         self.accuracy = Accuracy()
-        self.customMetrics = CustomF1(num_classes=output_size, device=self.gpus)
+        self.customMicroF1 = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
+        self.customMacroF1 = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
 
         self.lPretrained_embeddings = nn.ModuleDict()
         self.lLearnable_embeddings = nn.ModuleDict()
@@ -103,10 +104,12 @@ class RecurrentModel(pl.LightningModule):
         # Squashing logits through Sigmoid in order to get confidence score
         predictions = torch.sigmoid(logits) > 0.5
         accuracy = self.accuracy(predictions, ly)
-        custom = self.customMetrics(predictions, ly)
+        microF1 = self.customMicroF1(predictions, ly)
+        macroF1 = self.customMacroF1(predictions, ly)
         self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
         self.log('train-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
-        self.log('custom', custom, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+        self.log('microF1', microF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        self.log('macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
         return {'loss': loss}
 
     def validation_step(self, val_batch, batch_idx):
@@ -212,4 +215,17 @@
                 return (num / den).to(self.device)
             return torch.FloatTensor([1.]).to(self.device)
         if self.average == 'macro':
-            raise NotImplementedError
+            class_specific = []
+            for i in range(self.num_classes):
+                class_tp = self.true_positive[i]
+                # class_tn = self.true_negative[i]
+                class_fp = self.false_positive[i]
+                class_fn = self.false_negative[i]
+                num = 2.0 * class_tp
+                den = 2.0 * class_tp + class_fp + class_fn
+                if den > 0:
+                    class_specific.append(num / den)
+                else:
+                    class_specific.append(1.)
+            average = torch.sum(torch.Tensor(class_specific))/self.num_classes
+            return average.to(self.device)