testing twoclassbatch
This commit is contained in:
parent
d6f2f16de1
commit
cc49ffd152
|
@ -18,11 +18,11 @@ class AuthorshipAttributionClassifier(nn.Module):
|
|||
self.device = device
|
||||
|
||||
def fit(self, X, y, batch_size, epochs, lr=0.001, val_prop=0.1, log='../log/tmp.csv'):
|
||||
batcher = Batch(batch_size=batch_size, n_epochs=epochs)
|
||||
#batcher = Batch(batch_size=batch_size, n_epochs=epochs)
|
||||
batcher = TwoClassBatch(batch_size=batch_size, n_epochs=epochs, steps_per_epoch=X.shape[0]//batch_size)
|
||||
batcher_val = Batch(batch_size=batch_size, n_epochs=epochs, shuffle=False)
|
||||
criterion = torch.nn.CrossEntropyLoss().to(self.device)
|
||||
optim = torch.optim.Adam(self.parameters(), lr=lr)
|
||||
#optim = torch.optim.Adadelta(self.parameters(), lr=lr)
|
||||
|
||||
X, Xval, y, yval = train_test_split(X, y, test_size=val_prop, stratify=y)
|
||||
|
||||
|
|
Loading…
Reference in New Issue