last training swipe on eval set is now performed on batch size equal to the training set batch size
This commit is contained in:
parent
ee2a9481de
commit
41647f974a
|
@ -246,7 +246,6 @@ class Trainer:
|
||||||
# swapping model on gpu
|
# swapping model on gpu
|
||||||
del self.model
|
del self.model
|
||||||
self.model = restored_model.to(self.device)
|
self.model = restored_model.to(self.device)
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
if self.scheduler is not None:
|
if self.scheduler is not None:
|
||||||
|
@ -262,7 +261,14 @@ class Trainer:
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"- last swipe on eval set")
|
print(f"- last swipe on eval set")
|
||||||
self.train_epoch(eval_dataloader, epoch=-1)
|
self.train_epoch(
|
||||||
|
DataLoader(
|
||||||
|
eval_dataloader.dataset,
|
||||||
|
batch_size=train_dataloader.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
),
|
||||||
|
epoch=-1,
|
||||||
|
)
|
||||||
self.earlystopping.save_model(self.model)
|
self.earlystopping.save_model(self.model)
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
|
@ -341,6 +347,7 @@ class EarlyStopping:
|
||||||
|
|
||||||
def __call__(self, validation, model, epoch):
|
def __call__(self, validation, model, epoch):
|
||||||
if validation >= self.best_score:
|
if validation >= self.best_score:
|
||||||
|
wandb.log({"patience": self.patience - self.counter})
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(
|
print(
|
||||||
f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
|
f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
|
||||||
|
@ -352,6 +359,7 @@ class EarlyStopping:
|
||||||
self.save_model(model)
|
self.save_model(model)
|
||||||
elif validation < (self.best_score + self.min_delta):
|
elif validation < (self.best_score + self.min_delta):
|
||||||
self.counter += 1
|
self.counter += 1
|
||||||
|
wandb.log({"patience": self.patience - self.counter})
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(
|
print(
|
||||||
f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}"
|
f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}"
|
||||||
|
|
Loading…
Reference in New Issue