10% test 10% val
This commit is contained in:
parent
efe9d90f89
commit
727dda6167
|
@ -7,7 +7,7 @@ from data.AuthorshipDataset import AuthorshipDataset, LabelledCorpus
|
||||||
|
|
||||||
class Imdb62(AuthorshipDataset):
|
class Imdb62(AuthorshipDataset):
|
||||||
|
|
||||||
TEST_SIZE = 0.30
|
TEST_SIZE = 0.10
|
||||||
NUM_AUTHORS = 62
|
NUM_AUTHORS = 62
|
||||||
NUM_DOCS_BY_AUTHOR = int(1000-(1000*TEST_SIZE))
|
NUM_DOCS_BY_AUTHOR = int(1000-(1000*TEST_SIZE))
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ class AuthorshipAttributionClassifier(nn.Module):
|
||||||
self.padder = Padding(pad_index=pad_index, max_length=pad_length, dynamic=True, pad_at_end=False, device=device)
|
self.padder = Padding(pad_index=pad_index, max_length=pad_length, dynamic=True, pad_at_end=False, device=device)
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
def fit(self, X, y, batch_size, epochs, lr=0.001, val_prop=0.2, log='../log/tmp.csv'):
|
def fit(self, X, y, batch_size, epochs, lr=0.001, val_prop=0.1, log='../log/tmp.csv'):
|
||||||
batcher = Batch(batch_size=batch_size, n_epochs=epochs)
|
batcher = Batch(batch_size=batch_size, n_epochs=epochs)
|
||||||
criterion = torch.nn.CrossEntropyLoss().to(self.device)
|
criterion = torch.nn.CrossEntropyLoss().to(self.device)
|
||||||
optim = torch.optim.Adam(self.parameters(), lr=lr)
|
optim = torch.optim.Adam(self.parameters(), lr=lr)
|
||||||
|
|
Loading…
Reference in New Issue