increasing patience
This commit is contained in:
parent
829d2150ec
commit
1385d2ac95
|
@ -29,7 +29,7 @@ class AuthorshipAttributionClassifier(nn.Module):
|
||||||
if p.dim() > 1 and p.requires_grad:
|
if p.dim() > 1 and p.requires_grad:
|
||||||
nn.init.xavier_uniform_(p)
|
nn.init.xavier_uniform_(p)
|
||||||
|
|
||||||
def fit(self, X, y, batch_size, epochs, patience=10, lr=0.001, val_prop=0.1, alpha=1., log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
|
def fit(self, X, y, batch_size, epochs, patience=50, lr=0.001, val_prop=0.1, alpha=1., log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
|
||||||
assert 0 <= alpha <= 1, 'wrong range, alpha must be in [0,1]'
|
assert 0 <= alpha <= 1, 'wrong range, alpha must be in [0,1]'
|
||||||
early_stop = EarlyStop(patience)
|
early_stop = EarlyStop(patience)
|
||||||
|
|
||||||
|
@ -132,11 +132,11 @@ class AuthorshipAttributionClassifier(nn.Module):
|
||||||
self.load_state_dict(torch.load(checkpointpath))
|
self.load_state_dict(torch.load(checkpointpath))
|
||||||
return early_stop.best_score
|
return early_stop.best_score
|
||||||
|
|
||||||
def supervised_contrastive_learning(self, X, y, Xval, yval, batch_size, epochs, patience=10, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
|
def supervised_contrastive_learning(self, X, y, Xval, yval, batch_size, epochs, patience=50, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
|
||||||
early_stop = EarlyStop(patience)
|
early_stop = EarlyStop(patience)
|
||||||
|
|
||||||
criterion = SupConLoss1View().to(self.device)
|
criterion = SupConLoss1View().to(self.device)
|
||||||
optim = torch.optim.Adam(self.projector.parameters(), lr=lr)
|
optim = torch.optim.Adam(self.parameters(), lr=lr)
|
||||||
|
|
||||||
tr_data = IndexedDataset(X, y, self.pad_length, self.pad_index, self.device)
|
tr_data = IndexedDataset(X, y, self.pad_length, self.pad_index, self.device)
|
||||||
val_data = IndexedDataset(Xval, yval, self.pad_length, self.pad_index, self.device)
|
val_data = IndexedDataset(Xval, yval, self.pad_length, self.pad_index, self.device)
|
||||||
|
@ -191,11 +191,11 @@ class AuthorshipAttributionClassifier(nn.Module):
|
||||||
self.load_state_dict(torch.load(checkpointpath))
|
self.load_state_dict(torch.load(checkpointpath))
|
||||||
return early_stop.best_score
|
return early_stop.best_score
|
||||||
|
|
||||||
def train_linear_classifier(self, X, y, Xval, yval, batch_size, epochs, patience=10, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
|
def train_linear_classifier(self, X, y, Xval, yval, batch_size, epochs, patience=50, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
|
||||||
early_stop = EarlyStop(patience)
|
early_stop = EarlyStop(patience)
|
||||||
|
|
||||||
criterion = torch.nn.CrossEntropyLoss().to(self.device)
|
criterion = torch.nn.CrossEntropyLoss().to(self.device)
|
||||||
optim = torch.optim.Adam(self.ff.parameters(), lr=lr)
|
optim = torch.optim.Adam(self.parameters(), lr=lr)
|
||||||
|
|
||||||
tr_data = IndexedDataset(X, y, self.pad_length, self.pad_index, self.device)
|
tr_data = IndexedDataset(X, y, self.pad_length, self.pad_index, self.device)
|
||||||
val_data = IndexedDataset(Xval, yval, self.pad_length, self.pad_index, self.device)
|
val_data = IndexedDataset(Xval, yval, self.pad_length, self.pad_index, self.device)
|
||||||
|
|
Loading…
Reference in New Issue