bug in log sofmax
This commit is contained in:
parent
0fbbd64b05
commit
d6f2f16de1
|
@ -53,7 +53,8 @@ class AuthorshipAttributionClassifier(nn.Module):
|
||||||
logits = self.forward(xi)
|
logits = self.forward(xi)
|
||||||
loss = criterion(logits, torch.as_tensor(yi).to(self.device))
|
loss = criterion(logits, torch.as_tensor(yi).to(self.device))
|
||||||
losses.append(loss.item())
|
losses.append(loss.item())
|
||||||
prediction = tensor2numpy(torch.argmax(nn.functional.log_softmax(logits), dim=1).view(-1))
|
logits = nn.functional.log_softmax(logits, dim=1)
|
||||||
|
prediction = tensor2numpy(torch.argmax(logits, dim=1).view(-1))
|
||||||
predictions.append(prediction)
|
predictions.append(prediction)
|
||||||
val_loss = np.mean(losses)
|
val_loss = np.mean(losses)
|
||||||
predictions = np.concatenate(predictions)
|
predictions = np.concatenate(predictions)
|
||||||
|
@ -71,7 +72,8 @@ class AuthorshipAttributionClassifier(nn.Module):
|
||||||
for xi in tqdm(batcher.epoch(x), desc='test'):
|
for xi in tqdm(batcher.epoch(x), desc='test'):
|
||||||
xi = self.padder.transform(xi)
|
xi = self.padder.transform(xi)
|
||||||
logits = self.forward(xi)
|
logits = self.forward(xi)
|
||||||
prediction = tensor2numpy(nn.functional.log_softmax(torch.argmax(logits, dim=1).view(-1)))
|
logits = nn.functional.log_softmax(logits, dim=1)
|
||||||
|
prediction = tensor2numpy(torch.argmax(logits, dim=1).view(-1))
|
||||||
predictions.append(prediction)
|
predictions.append(prediction)
|
||||||
return np.concatenate(predictions)
|
return np.concatenate(predictions)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue