From d6f2f16de13990ca04885171f2a66e40e575e013 Mon Sep 17 00:00:00 2001 From: Alex Moreo Date: Sat, 2 May 2020 23:29:59 +0200 Subject: [PATCH] bug in log softmax --- src/model/classifiers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/model/classifiers.py b/src/model/classifiers.py index 69ab6e4..96a7165 100644 --- a/src/model/classifiers.py +++ b/src/model/classifiers.py @@ -53,7 +53,8 @@ class AuthorshipAttributionClassifier(nn.Module): logits = self.forward(xi) loss = criterion(logits, torch.as_tensor(yi).to(self.device)) losses.append(loss.item()) - prediction = tensor2numpy(torch.argmax(nn.functional.log_softmax(logits), dim=1).view(-1)) + logits = nn.functional.log_softmax(logits, dim=1) + prediction = tensor2numpy(torch.argmax(logits, dim=1).view(-1)) predictions.append(prediction) val_loss = np.mean(losses) predictions = np.concatenate(predictions) @@ -71,7 +72,8 @@ class AuthorshipAttributionClassifier(nn.Module): for xi in tqdm(batcher.epoch(x), desc='test'): xi = self.padder.transform(xi) logits = self.forward(xi) - prediction = tensor2numpy(nn.functional.log_softmax(torch.argmax(logits, dim=1).view(-1))) + logits = nn.functional.log_softmax(logits, dim=1) + prediction = tensor2numpy(torch.argmax(logits, dim=1).view(-1)) predictions.append(prediction) return np.concatenate(predictions)