testing twoclassbatch

This commit is contained in:
Alejandro Moreo Fernandez 2020-05-03 04:00:34 +02:00
parent d6f2f16de1
commit cc49ffd152
1 changed files with 2 additions and 2 deletions

View File

@ -18,11 +18,11 @@ class AuthorshipAttributionClassifier(nn.Module):
self.device = device
def fit(self, X, y, batch_size, epochs, lr=0.001, val_prop=0.1, log='../log/tmp.csv'):
batcher = Batch(batch_size=batch_size, n_epochs=epochs)
#batcher = Batch(batch_size=batch_size, n_epochs=epochs)
batcher = TwoClassBatch(batch_size=batch_size, n_epochs=epochs, steps_per_epoch=X.shape[0]//batch_size)
batcher_val = Batch(batch_size=batch_size, n_epochs=epochs, shuffle=False)
criterion = torch.nn.CrossEntropyLoss().to(self.device)
optim = torch.optim.Adam(self.parameters(), lr=lr)
#optim = torch.optim.Adadelta(self.parameters(), lr=lr)
X, Xval, y, yval = train_test_split(X, y, test_size=val_prop, stratify=y)