forked from moreo/QuaPy
55 lines
1.8 KiB
Python
Executable File
55 lines
1.8 KiB
Python
Executable File
#adapted from https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py
|
|
import torch
|
|
from time import time
|
|
from util.file import create_if_not_exist
|
|
|
|
|
|
class EarlyStopping:
|
|
|
|
def __init__(self, model, patience=20, verbose=True, checkpoint='./checkpoint.pt'):
|
|
# set patience to 0 or -1 to avoid stopping, but still keeping track of the best value and model parameters
|
|
self.patience_limit = patience
|
|
self.patience = patience
|
|
self.verbose = verbose
|
|
self.best_score = None
|
|
self.best_epoch = None
|
|
self.stop_time = None
|
|
self.checkpoint = checkpoint
|
|
self.model = model
|
|
self.STOP = False
|
|
|
|
def __call__(self, watch_score, epoch):
|
|
|
|
if self.STOP:
|
|
return #done
|
|
|
|
if self.best_score is None or watch_score >= self.best_score:
|
|
self.best_score = watch_score
|
|
self.best_epoch = epoch
|
|
self.stop_time = time()
|
|
if self.checkpoint:
|
|
self.print(f'[early-stop] improved, saving model in {self.checkpoint}')
|
|
torch.save(self.model, self.checkpoint)
|
|
else:
|
|
self.print(f'[early-stop] improved')
|
|
self.patience = self.patience_limit
|
|
else:
|
|
self.patience -= 1
|
|
if self.patience == 0:
|
|
self.STOP = True
|
|
self.print(f'[early-stop] patience exhausted')
|
|
else:
|
|
if self.patience>0: # if negative, then early-stop is ignored
|
|
self.print(f'[early-stop] patience={self.patience}')
|
|
|
|
def reinit_counter(self):
|
|
self.STOP = False
|
|
self.patience=self.patience_limit
|
|
|
|
def restore_checkpoint(self):
|
|
return torch.load(self.checkpoint)
|
|
|
|
def print(self, msg):
|
|
if self.verbose:
|
|
print(msg)
|