QuaPy/MultiLabel/util/early_stop.py

55 lines
1.8 KiB
Python
Executable File

#adapted from https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py
import inspect
import os
from time import time

import torch

from util.file import create_if_not_exist
class EarlyStopping:
    """Track a validation score and signal when training should stop.

    Higher scores are treated as better: a call whose score is greater than or
    equal to the best seen so far counts as an improvement, resets the patience
    counter and, if a checkpoint path is set, saves the model. Any other call
    decrements the counter, and ``STOP`` is set once it reaches zero.

    Set ``patience`` to 0 or a negative number to disable stopping while still
    tracking the best score/epoch (and checkpointing the model, if enabled).
    """

    def __init__(self, model, patience=20, verbose=True, checkpoint='./checkpoint.pt'):
        """
        :param model: object saved with ``torch.save`` on every improvement
            (presumably a ``torch.nn.Module`` -- confirm against callers).
        :param patience: number of consecutive non-improving calls tolerated
            before stopping; 0 or negative disables stopping.
        :param verbose: if True, progress messages are printed.
        :param checkpoint: path where the model is saved on improvement; a
            falsy value disables checkpointing.
        """
        # set patience to 0 or -1 to avoid stopping, but still keeping track of the best value and model parameters
        self.patience_limit = patience
        self.patience = patience
        self.verbose = verbose
        self.best_score = None
        self.best_epoch = None
        self.stop_time = None  # time() of the last improvement
        self.checkpoint = checkpoint
        self.model = model
        self.STOP = False

    def __call__(self, watch_score, epoch):
        """Register the score obtained at ``epoch`` and update the stopping state.

        Has no effect once ``STOP`` is set, until :meth:`reinit_counter` is called.
        """
        if self.STOP:
            return  # done
        if self.best_score is None or watch_score >= self.best_score:
            self.best_score = watch_score
            self.best_epoch = epoch
            self.stop_time = time()
            if self.checkpoint:
                self.print(f'[early-stop] improved, saving model in {self.checkpoint}')
                self._save_checkpoint()
            else:
                self.print(f'[early-stop] improved')
            self.patience = self.patience_limit
        else:
            self.patience -= 1
            if self.patience == 0:
                self.STOP = True
                self.print(f'[early-stop] patience exhausted')
            elif self.patience > 0:  # if negative, then early-stop is ignored
                self.print(f'[early-stop] patience={self.patience}')

    def _save_checkpoint(self):
        """Save the whole model object to ``self.checkpoint``, creating its parent directory if needed."""
        if isinstance(self.checkpoint, (str, os.PathLike)):
            parent = os.path.dirname(self.checkpoint)
            if parent:
                # torch.save does not create missing directories
                os.makedirs(parent, exist_ok=True)
        torch.save(self.model, self.checkpoint)

    def reinit_counter(self):
        """Clear the stop flag and reset patience; the best score/epoch are kept."""
        self.STOP = False
        self.patience = self.patience_limit

    def restore_checkpoint(self):
        """Load and return the model last saved to ``self.checkpoint``.

        :raises ValueError: if checkpointing is disabled (falsy ``checkpoint``).
        """
        if not self.checkpoint:
            raise ValueError('[early-stop] checkpointing is disabled, there is no checkpoint to restore')
        # torch.save pickled the whole model object (not a state_dict), so it must be
        # loaded with weights_only=False: torch>=2.6 defaults to weights_only=True and
        # rejects such files. Older torch versions do not accept the argument at all.
        # NOTE: this unpickles arbitrary objects -- only load checkpoints written by this class.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            return torch.load(self.checkpoint, weights_only=False)
        return torch.load(self.checkpoint)

    def print(self, msg):
        """Print ``msg`` only when ``verbose`` is enabled."""
        if self.verbose:
            print(msg)