tracking best val micro F1

This commit is contained in:
Alejandro Moreo Fernandez 2020-05-08 16:51:22 +02:00
parent 8447a6e185
commit f0b08278e4
2 changed files with 6 additions and 5 deletions

View File

@ -83,7 +83,7 @@ def main(opt):
method = opt.name
# train
cls.fit(Xtr, ytr,
val_microf1 = cls.fit(Xtr, ytr,
batch_size=opt.batchsize, epochs=opt.epochs, alpha=opt.alpha, lr=opt.lr,
log=f'{opt.log}/{method}-{dataset_name}.csv',
checkpointpath=opt.checkpoint
@ -94,7 +94,7 @@ def main(opt):
acc, macrof1, microf1 = evaluation(yte, yte_)
results = Results(opt.output)
results.add(dataset_name, method, acc, macrof1, microf1)
results.add(dataset_name, method, acc, macrof1, microf1, val_microf1)
# verification
@ -121,10 +121,10 @@ class Results:
addheader = not os.path.exists(path)
self.foo = open(path, 'at')
if addheader:
self.add('Dataset', 'Method', 'Accuracy', 'MacroF1', 'microF1')
self.add('Dataset', 'Method', 'Accuracy', 'MacroF1', 'microF1', 'val_microF1')
def add(self, dataset, method, acc, macrof1, microf1):
self.foo.write(f'{dataset}\t{method}\t{acc}\t{macrof1}\t{microf1}\n')
def add(self, dataset, method, acc, macrof1, microf1, val_microF1):
self.foo.write(f'{dataset}\t{method}\t{acc}\t{macrof1}\t{microf1}\t{val_microF1}\n')
self.foo.flush()
def close(self):

View File

@ -100,6 +100,7 @@ class AuthorshipAttributionClassifier(nn.Module):
break
print(f'training ended; loading best model parameters in {checkpointpath} for epoch {early_stop.best_epoch}')
self.load_state_dict(torch.load(checkpointpath))
return early_stop.best_score
def predict(self, x, batch_size=100):
self.eval()