diff --git a/src/model/classifiers.py b/src/model/classifiers.py index 96a7165..de31070 100644 --- a/src/model/classifiers.py +++ b/src/model/classifiers.py @@ -18,11 +18,11 @@ class AuthorshipAttributionClassifier(nn.Module): self.device = device def fit(self, X, y, batch_size, epochs, lr=0.001, val_prop=0.1, log='../log/tmp.csv'): - batcher = Batch(batch_size=batch_size, n_epochs=epochs) + #batcher = Batch(batch_size=batch_size, n_epochs=epochs) + batcher = TwoClassBatch(batch_size=batch_size, n_epochs=epochs, steps_per_epoch=X.shape[0]//batch_size) batcher_val = Batch(batch_size=batch_size, n_epochs=epochs, shuffle=False) criterion = torch.nn.CrossEntropyLoss().to(self.device) optim = torch.optim.Adam(self.parameters(), lr=lr) - #optim = torch.optim.Adadelta(self.parameters(), lr=lr) X, Xval, y, yval = train_test_split(X, y, test_size=val_prop, stratify=y)