From 1385d2ac95b303685a4f3fd8979c7cf011ed3038 Mon Sep 17 00:00:00 2001 From: Alex Moreo Date: Thu, 4 Feb 2021 18:40:45 +0100 Subject: [PATCH] increasing patience --- src/model/classifiers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/model/classifiers.py b/src/model/classifiers.py index 58a8426..f23d826 100644 --- a/src/model/classifiers.py +++ b/src/model/classifiers.py @@ -29,7 +29,7 @@ class AuthorshipAttributionClassifier(nn.Module): if p.dim() > 1 and p.requires_grad: nn.init.xavier_uniform_(p) - def fit(self, X, y, batch_size, epochs, patience=10, lr=0.001, val_prop=0.1, alpha=1., log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'): + def fit(self, X, y, batch_size, epochs, patience=50, lr=0.001, val_prop=0.1, alpha=1., log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'): assert 0 <= alpha <= 1, 'wrong range, alpha must be in [0,1]' early_stop = EarlyStop(patience) @@ -132,11 +132,11 @@ class AuthorshipAttributionClassifier(nn.Module): self.load_state_dict(torch.load(checkpointpath)) return early_stop.best_score - def supervised_contrastive_learning(self, X, y, Xval, yval, batch_size, epochs, patience=10, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'): + def supervised_contrastive_learning(self, X, y, Xval, yval, batch_size, epochs, patience=50, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'): early_stop = EarlyStop(patience) criterion = SupConLoss1View().to(self.device) - optim = torch.optim.Adam(self.projector.parameters(), lr=lr) + optim = torch.optim.Adam(self.parameters(), lr=lr) tr_data = IndexedDataset(X, y, self.pad_length, self.pad_index, self.device) val_data = IndexedDataset(Xval, yval, self.pad_length, self.pad_index, self.device) @@ -191,11 +191,11 @@ class AuthorshipAttributionClassifier(nn.Module): self.load_state_dict(torch.load(checkpointpath)) return early_stop.best_score - def train_linear_classifier(self, X, y, Xval, yval, batch_size, epochs, patience=10, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'): + def train_linear_classifier(self, X, y, Xval, yval, batch_size, epochs, patience=50, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'): early_stop = EarlyStop(patience) criterion = torch.nn.CrossEntropyLoss().to(self.device) - optim = torch.optim.Adam(self.ff.parameters(), lr=lr) + optim = torch.optim.Adam(self.parameters(), lr=lr) tr_data = IndexedDataset(X, y, self.pad_length, self.pad_index, self.device) val_data = IndexedDataset(Xval, yval, self.pad_length, self.pad_index, self.device) @@ -221,7 +221,7 @@ class AuthorshipAttributionClassifier(nn.Module): tr_loss = np.mean(losses) pbar.set_description(f'training epoch={epoch} ' f'loss={tr_loss:.5f} ' - f'val_loss={val_loss:.5f} val_acc={acc:.4f} macrof1={macrof1:.4f} microf1={microf1:.4f}' + f'val_loss={val_loss:.5f} val_acc={acc:.4f} macrof1={macrof1:.4f} microf1={microf1:.4f} ' f'patience={early_stop.patience}/{early_stop.patience_limit}') # validation