tracking best val micro F1

This commit is contained in:
Alejandro Moreo Fernandez 2020-05-08 16:51:22 +02:00
parent 8447a6e185
commit f0b08278e4
2 changed files with 6 additions and 5 deletions

View File

@ -83,7 +83,7 @@ def main(opt):
method = opt.name method = opt.name
# train # train
cls.fit(Xtr, ytr, val_microf1 = cls.fit(Xtr, ytr,
batch_size=opt.batchsize, epochs=opt.epochs, alpha=opt.alpha, lr=opt.lr, batch_size=opt.batchsize, epochs=opt.epochs, alpha=opt.alpha, lr=opt.lr,
log=f'{opt.log}/{method}-{dataset_name}.csv', log=f'{opt.log}/{method}-{dataset_name}.csv',
checkpointpath=opt.checkpoint checkpointpath=opt.checkpoint
@ -94,7 +94,7 @@ def main(opt):
acc, macrof1, microf1 = evaluation(yte, yte_) acc, macrof1, microf1 = evaluation(yte, yte_)
results = Results(opt.output) results = Results(opt.output)
results.add(dataset_name, method, acc, macrof1, microf1) results.add(dataset_name, method, acc, macrof1, microf1, val_microf1)
# verification # verification
@ -121,10 +121,10 @@ class Results:
addheader = not os.path.exists(path) addheader = not os.path.exists(path)
self.foo = open(path, 'at') self.foo = open(path, 'at')
if addheader: if addheader:
self.add('Dataset', 'Method', 'Accuracy', 'MacroF1', 'microF1') self.add('Dataset', 'Method', 'Accuracy', 'MacroF1', 'microF1', 'val_microF1')
def add(self, dataset, method, acc, macrof1, microf1): def add(self, dataset, method, acc, macrof1, microf1, val_microF1):
self.foo.write(f'{dataset}\t{method}\t{acc}\t{macrof1}\t{microf1}\n') self.foo.write(f'{dataset}\t{method}\t{acc}\t{macrof1}\t{microf1}\t{val_microF1}\n')
self.foo.flush() self.foo.flush()
def close(self): def close(self):

View File

@ -100,6 +100,7 @@ class AuthorshipAttributionClassifier(nn.Module):
break break
print(f'training ended; loading best model parameters in {checkpointpath} for epoch {early_stop.best_epoch}') print(f'training ended; loading best model parameters in {checkpointpath} for epoch {early_stop.best_epoch}')
self.load_state_dict(torch.load(checkpointpath)) self.load_state_dict(torch.load(checkpointpath))
return early_stop.best_score
def predict(self, x, batch_size=100): def predict(self, x, batch_size=100):
self.eval() self.eval()