last training swipe on eval set is now performed on batch size equal to the training set batch size
This commit is contained in:
parent
ee2a9481de
commit
41647f974a
|
@ -246,7 +246,6 @@ class Trainer:
|
|||
# swapping model on gpu
|
||||
del self.model
|
||||
self.model = restored_model.to(self.device)
|
||||
|
||||
break
|
||||
|
||||
if self.scheduler is not None:
|
||||
|
@ -262,7 +261,14 @@ class Trainer:
|
|||
)
|
||||
|
||||
print(f"- last swipe on eval set")
|
||||
self.train_epoch(eval_dataloader, epoch=-1)
|
||||
self.train_epoch(
|
||||
DataLoader(
|
||||
eval_dataloader.dataset,
|
||||
batch_size=train_dataloader.batch_size,
|
||||
shuffle=True,
|
||||
),
|
||||
epoch=-1,
|
||||
)
|
||||
self.earlystopping.save_model(self.model)
|
||||
return self.model
|
||||
|
||||
|
@ -341,6 +347,7 @@ class EarlyStopping:
|
|||
|
||||
def __call__(self, validation, model, epoch):
|
||||
if validation >= self.best_score:
|
||||
wandb.log({"patience": self.patience - self.counter})
|
||||
if self.verbose:
|
||||
print(
|
||||
f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
|
||||
|
@ -352,6 +359,7 @@ class EarlyStopping:
|
|||
self.save_model(model)
|
||||
elif validation < (self.best_score + self.min_delta):
|
||||
self.counter += 1
|
||||
wandb.log({"patience": self.patience - self.counter})
|
||||
if self.verbose:
|
||||
print(
|
||||
f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}"
|
||||
|
|
Loading…
Reference in New Issue