From f0b08278e417c9a1924aa2e21ec220c1791b7d97 Mon Sep 17 00:00:00 2001 From: Alex Moreo Date: Fri, 8 May 2020 16:51:22 +0200 Subject: [PATCH] tracking best val micro F1 --- src/main.py | 10 +++++----- src/model/classifiers.py | 1 + 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/main.py b/src/main.py index d7b2e98..5db4d23 100644 --- a/src/main.py +++ b/src/main.py @@ -83,7 +83,7 @@ def main(opt): method = opt.name # train - cls.fit(Xtr, ytr, + val_microf1 = cls.fit(Xtr, ytr, batch_size=opt.batchsize, epochs=opt.epochs, alpha=opt.alpha, lr=opt.lr, log=f'{opt.log}/{method}-{dataset_name}.csv', checkpointpath=opt.checkpoint @@ -94,7 +94,7 @@ def main(opt): acc, macrof1, microf1 = evaluation(yte, yte_) results = Results(opt.output) - results.add(dataset_name, method, acc, macrof1, microf1) + results.add(dataset_name, method, acc, macrof1, microf1, val_microf1) # verification @@ -121,10 +121,10 @@ class Results: addheader = not os.path.exists(path) self.foo = open(path, 'at') if addheader: - self.add('Dataset', 'Method', 'Accuracy', 'MacroF1', 'microF1') + self.add('Dataset', 'Method', 'Accuracy', 'MacroF1', 'microF1', 'val_microF1') - def add(self, dataset, method, acc, macrof1, microf1): - self.foo.write(f'{dataset}\t{method}\t{acc}\t{macrof1}\t{microf1}\n') + def add(self, dataset, method, acc, macrof1, microf1, val_microF1): + self.foo.write(f'{dataset}\t{method}\t{acc}\t{macrof1}\t{microf1}\t{val_microF1}\n') self.foo.flush() def close(self): diff --git a/src/model/classifiers.py b/src/model/classifiers.py index 6f277d5..ebdb57f 100644 --- a/src/model/classifiers.py +++ b/src/model/classifiers.py @@ -100,6 +100,7 @@ class AuthorshipAttributionClassifier(nn.Module): break print(f'training ended; loading best model parameters in {checkpointpath} for epoch {early_stop.best_epoch}') self.load_state_dict(torch.load(checkpointpath)) + return early_stop.best_score def predict(self, x, batch_size=100): self.eval()