testing twoclassbatch
This commit is contained in:
parent
d6f2f16de1
commit
cc49ffd152
|
@ -18,11 +18,11 @@ class AuthorshipAttributionClassifier(nn.Module):
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
def fit(self, X, y, batch_size, epochs, lr=0.001, val_prop=0.1, log='../log/tmp.csv'):
|
def fit(self, X, y, batch_size, epochs, lr=0.001, val_prop=0.1, log='../log/tmp.csv'):
|
||||||
batcher = Batch(batch_size=batch_size, n_epochs=epochs)
|
#batcher = Batch(batch_size=batch_size, n_epochs=epochs)
|
||||||
|
batcher = TwoClassBatch(batch_size=batch_size, n_epochs=epochs, steps_per_epoch=X.shape[0]//batch_size)
|
||||||
batcher_val = Batch(batch_size=batch_size, n_epochs=epochs, shuffle=False)
|
batcher_val = Batch(batch_size=batch_size, n_epochs=epochs, shuffle=False)
|
||||||
criterion = torch.nn.CrossEntropyLoss().to(self.device)
|
criterion = torch.nn.CrossEntropyLoss().to(self.device)
|
||||||
optim = torch.optim.Adam(self.parameters(), lr=lr)
|
optim = torch.optim.Adam(self.parameters(), lr=lr)
|
||||||
#optim = torch.optim.Adadelta(self.parameters(), lr=lr)
|
|
||||||
|
|
||||||
X, Xval, y, yval = train_test_split(X, y, test_size=val_prop, stratify=y)
|
X, Xval, y, yval = train_test_split(X, y, test_size=val_prop, stratify=y)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue