tracking best val micro F1
This commit is contained in:
parent
8447a6e185
commit
f0b08278e4
10
src/main.py
10
src/main.py
|
@ -83,7 +83,7 @@ def main(opt):
|
|||
method = opt.name
|
||||
|
||||
# train
|
||||
cls.fit(Xtr, ytr,
|
||||
val_microf1 = cls.fit(Xtr, ytr,
|
||||
batch_size=opt.batchsize, epochs=opt.epochs, alpha=opt.alpha, lr=opt.lr,
|
||||
log=f'{opt.log}/{method}-{dataset_name}.csv',
|
||||
checkpointpath=opt.checkpoint
|
||||
|
@ -94,7 +94,7 @@ def main(opt):
|
|||
acc, macrof1, microf1 = evaluation(yte, yte_)
|
||||
|
||||
results = Results(opt.output)
|
||||
results.add(dataset_name, method, acc, macrof1, microf1)
|
||||
results.add(dataset_name, method, acc, macrof1, microf1, val_microf1)
|
||||
|
||||
|
||||
# verification
|
||||
|
@ -121,10 +121,10 @@ class Results:
|
|||
addheader = not os.path.exists(path)
|
||||
self.foo = open(path, 'at')
|
||||
if addheader:
|
||||
self.add('Dataset', 'Method', 'Accuracy', 'MacroF1', 'microF1')
|
||||
self.add('Dataset', 'Method', 'Accuracy', 'MacroF1', 'microF1', 'val_microF1')
|
||||
|
||||
def add(self, dataset, method, acc, macrof1, microf1):
|
||||
self.foo.write(f'{dataset}\t{method}\t{acc}\t{macrof1}\t{microf1}\n')
|
||||
def add(self, dataset, method, acc, macrof1, microf1, val_microF1):
|
||||
self.foo.write(f'{dataset}\t{method}\t{acc}\t{macrof1}\t{microf1}\t{val_microF1}\n')
|
||||
self.foo.flush()
|
||||
|
||||
def close(self):
|
||||
|
|
|
@ -100,6 +100,7 @@ class AuthorshipAttributionClassifier(nn.Module):
|
|||
break
|
||||
print(f'training ended; loading best model parameters in {checkpointpath} for epoch {early_stop.best_epoch}')
|
||||
self.load_state_dict(torch.load(checkpointpath))
|
||||
return early_stop.best_score
|
||||
|
||||
def predict(self, x, batch_size=100):
|
||||
self.eval()
|
||||
|
|
Loading…
Reference in New Issue