# adapted from https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py
import inspect
import os
from time import time

import torch

from util.file import create_if_not_exist


class EarlyStopping:
    """Track the best validation score and signal when training should stop.

    The object is called once per evaluation with the score being watched.
    Higher scores are better, and a score equal to the current best also
    counts as an improvement. On every improvement the model is saved to
    ``checkpoint`` (when one is configured) and the patience counter is
    reset; otherwise the counter is decremented, and ``STOP`` is set to
    ``True`` once it reaches exactly zero.

    Attributes:
        patience_limit: value the patience counter is reset to.
        patience: remaining number of non-improving calls before stopping.
        verbose: whether progress messages are printed.
        best_score: best score seen so far (``None`` before the first call).
        best_epoch: epoch passed alongside the best score.
        stop_time: ``time()`` timestamp taken when the best score was recorded.
        checkpoint: destination of the saved model; falsy disables saving.
        model: the object handed to ``torch.save`` on improvement.
        STOP: ``True`` once patience has been exhausted.
    """

    def __init__(self, model, patience=20, verbose=True, checkpoint='./checkpoint.pt'):
        """Initialise the tracker.

        Args:
            model: object to be saved whenever the watched score improves.
            patience: number of consecutive non-improving calls tolerated.
                Set patience to 0 or -1 to avoid stopping, but still keeping
                track of the best value and model parameters.
            verbose: print progress messages when true.
            checkpoint: where to save the model; a falsy value disables saving.
        """
        self.patience_limit = patience
        self.patience = patience
        self.verbose = verbose
        self.best_score = None
        self.best_epoch = None
        self.stop_time = None
        self.checkpoint = checkpoint
        self.model = model
        self.STOP = False

    def __call__(self, watch_score, epoch):
        """Register a new score obtained at ``epoch`` and update the state.

        Does nothing once ``STOP`` has been set (until ``reinit_counter`` is
        called).
        """
        if self.STOP:
            return  # done

        if self.best_score is None or watch_score >= self.best_score:
            self.best_score = watch_score
            self.best_epoch = epoch
            self.stop_time = time()
            if self.checkpoint:
                self.print(f'[early-stop] improved, saving model in {self.checkpoint}')
                self._save_checkpoint()
            else:
                self.print('[early-stop] improved')
            self.patience = self.patience_limit
            return

        self.patience -= 1
        if self.patience == 0:
            self.STOP = True
            self.print('[early-stop] patience exhausted')
        elif self.patience > 0:
            # if negative, then early-stop is ignored
            self.print(f'[early-stop] patience={self.patience}')

    def _save_checkpoint(self):
        """Save the model to ``self.checkpoint``, creating its folder if needed."""
        # Only path-like checkpoints have a parent directory to create; any
        # other kind of target is passed to torch.save untouched.
        if isinstance(self.checkpoint, (str, os.PathLike)):
            parent = os.path.dirname(os.fspath(self.checkpoint))
            if parent:
                os.makedirs(parent, exist_ok=True)
        torch.save(self.model, self.checkpoint)

    def reinit_counter(self):
        """Clear the stop flag and reset the patience counter to its limit."""
        self.STOP = False
        self.patience = self.patience_limit

    def restore_checkpoint(self):
        """Load and return the object stored at ``self.checkpoint``.

        Raises:
            ValueError: if no checkpoint destination was configured.
        """
        if not self.checkpoint:
            raise ValueError('no checkpoint was configured, so there is no saved model to restore')

        # The checkpoint holds a fully pickled model object (see
        # _save_checkpoint), which recent PyTorch releases refuse to load
        # unless weights_only is explicitly disabled. Older releases do not
        # know the keyword at all, hence the signature check.
        # NOTE: unpickling can execute arbitrary code; only restore
        # checkpoints that come from a trusted source.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            return torch.load(self.checkpoint, weights_only=False)
        return torch.load(self.checkpoint)

    def print(self, msg):
        """Print ``msg`` only when verbose mode is enabled."""
        if self.verbose:
            print(msg)