update

commit a8848ceda2 (parent 38867371fa)
@@ -132,30 +132,46 @@ class SupConLoss1View(nn.Module):
         cross = torch.matmul(features, features.T)

         # frobenius_loss = torch.norm(mask-cross)

         upper_diag = torch.triu_indices(batch_size,batch_size,+1)
         cross_upper = cross[upper_diag[0], upper_diag[1]]
         mask_upper = mask[upper_diag[0], upper_diag[1]]
         #pos = mask_upper.sum()
+        npos = int(mask_upper.sum().item())
         # weight = torch.from_numpy(np.asarray([1-pos, pos], dtype=float)).to(device)
         #return torch.nn.functional.binary_cross_entropy_with_logits(cross_upper, mask_upper)
         #print('mask min-max:', mask.min(), mask.max())
         #print('cross min-max:', cross.min(), cross.max())
         #return torch.norm(cross-mask, p='fro') # <-- diagonal signal (trivial) would be too strong
-        pos_loss = mse(cross_upper, mask_upper, label=1)
-        neg_loss = mse(cross_upper, mask_upper, label=0)
+        pos_loss = mse(cross_upper, mask_upper, label=1, k=-1)
+        neg_loss = mse(cross_upper, mask_upper, label=0, k=npos)
         # return frobenius_loss, neg_loss, pos_loss
         #return neg_loss, pos_loss
         #balanced_loss = pos_loss + neg_loss
         #return balanced_loss
         # balanced_loss = pos_loss + neg_loss
         # return balanced_loss, neg_loss, pos_loss
         # loss = torch.nn.functional.binary_cross_entropy(cross_upper, mask_upper)
         # return loss, neg_loss, pos_loss
         return torch.mean((cross_upper-mask_upper)**2), neg_loss, pos_loss


-def mse(input, target, label):
+def choice(tensor, k):
+    perm = torch.randperm(tensor.size(0))
+    idx = perm[:k]
+    return tensor[idx]
+
+
+def mse(input, target, label, k=-1):
     input = input[target==label]
+    if k>-1:
+        input = choice(input, k)
     if label==0:
         return torch.mean(input**2)
     else:
         return torch.mean((1-input)**2)
     #return torch.mean((input[index] - target[index]) ** 2)
     # index = target==label
     # return torch.mean((input[index] - target[index]) ** 2)
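Editor's note: the change above balances the pair-wise MSE objective. The Gram matrix of the (presumably L2-normalized) features is regressed towards the same-author mask on the strict upper triangle: positive pairs are pulled towards similarity 1, negatives towards 0, and only `npos` randomly sampled negatives enter the negative term, so both terms average over the same number of pairs. A minimal self-contained sketch of the same computation, with made-up stand-in data (shapes and normalization are assumptions, not taken from this diff):

import torch

def choice(tensor, k):
    # sample k elements without replacement
    perm = torch.randperm(tensor.size(0))
    return tensor[perm[:k]]

def mse(input, target, label, k=-1):
    # keep only the similarities whose ground-truth pair label matches `label`
    input = input[target == label]
    if k > -1:
        input = choice(input, k)
    # negatives are pushed towards 0, positives towards 1
    return torch.mean(input ** 2) if label == 0 else torch.mean((1 - input) ** 2)

batch_size = 8
features = torch.nn.functional.normalize(torch.randn(batch_size, 16), dim=1)
labels = torch.randint(0, 3, (batch_size,))                  # 3 hypothetical authors
mask = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()  # 1 iff same author

cross = features @ features.T
iu = torch.triu_indices(batch_size, batch_size, 1)  # each pair once, diagonal excluded
cross_upper, mask_upper = cross[iu[0], iu[1]], mask[iu[0], iu[1]]

npos = int(mask_upper.sum().item())                       # assumes >= 1 positive pair in the batch
pos_loss = mse(cross_upper, mask_upper, label=1, k=-1)    # all positive pairs
neg_loss = mse(cross_upper, mask_upper, label=0, k=npos)  # subsample negatives to match
print(pos_loss.item(), neg_loss.item())

Without the subsampling, the negative term would dominate, since most of the roughly batch_size**2/2 pairs in a batch come from different authors.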
src/main.py (96 changed lines)

@@ -90,7 +90,7 @@ def main(opt):
     cls = AuthorshipAttributionClassifier(
         phi, num_authors=A.size, pad_index=pad_index, pad_length=opt.pad, device=device
     )
-    cls.xavier_uniform()
+    # cls.xavier_uniform()
     print(cls)

     if opt.name == 'auto':
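Editor's note: this change disables the explicit Xavier initialization, falling back to PyTorch's default per-layer init. The helper itself is not part of this diff; a typical implementation of such a method (an assumption for illustration, not the repository's code) would be:

import torch.nn as nn

def xavier_uniform(self):
    # apply Xavier-uniform init to every weight matrix; skip 1-D params (biases)
    for p in self.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)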
@@ -98,37 +98,56 @@ def main(opt):
     else:
         method = opt.name

-    if opt.mode=='savlin':
-        Xtr_, Xval_, ytr_, yval_ = train_test_split(Xtr, ytr, test_size=0.1, stratify=ytr)
-        cls.supervised_contrastive_learning(Xtr_, ytr_, Xval_, yval_,
-                                            batch_size=opt.batchsize, epochs=opt.epochs, lr=opt.lr,
-                                            log=f'{opt.log}/{method}-{dataset_name}.csv',
-                                            checkpointpath=opt.checkpoint)
-        val_microf1 = cls.train_linear_classifier(Xtr_, ytr_, Xval_, yval_,
-                                                  batch_size=opt.batchsize, epochs=opt.epochs, lr=opt.lr,
-                                                  log=f'{opt.log}/{method}-{dataset_name}.csv',
-                                                  checkpointpath=opt.checkpoint)
-        svm = GridSearchCV(LinearSVC(), param_grid={'C':np.logspace(-2,3,6), 'class_weight':['balanced',None]}, n_jobs=-1)
-        svm.fit(cls.project(Xtr), ytr)
-        yte_ = svm.predict(cls.project(Xte))
-        acc, macrof1, microf1 = evaluation(yte, yte_)
-        print(f'svm: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}')
-    elif opt.mode=='attr':
-        # train
-        val_microf1 = cls.fit(Xtr, ytr,
-                              batch_size=opt.batchsize, epochs=opt.epochs, alpha=opt.alpha, lr=opt.lr,
-                              log=f'{opt.log}/{method}-{dataset_name}.csv',
-                              checkpointpath=opt.checkpoint
-                              )
+    with open(f'results_feb_{opt.mode}.txt', 'wt') as foo:
+        if opt.mode=='savlin':
+            Xtr_, Xval_, ytr_, yval_ = train_test_split(Xtr, ytr, test_size=0.1, stratify=ytr)
+            cls.supervised_contrastive_learning(Xtr_, ytr_, Xval_, yval_,
+                                                batch_size=opt.batchsize, epochs=opt.epochs, lr=opt.lr,
+                                                log=f'{opt.log}/{method}-{dataset_name}.csv',
+                                                checkpointpath=opt.checkpoint)
+            val_microf1 = cls.train_linear_classifier(Xtr_, ytr_, Xval_, yval_,
+                                                      batch_size=opt.batchsize, epochs=opt.epochs, lr=opt.lr,
+                                                      log=f'{opt.log}/{method}-{dataset_name}.csv',
+                                                      checkpointpath=opt.checkpoint)
+            # test
+            yte_ = cls.predict(Xte)
+            print('sav(fix)-lin(trained) network prediction')
+            acc, macrof1, microf1 = evaluation(yte, yte_)
+            foo.write(f'sav(fix)-lin(trained) network prediction: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}\n')
+
+            val_microf1 = cls.fit(Xtr, ytr,
+                                  batch_size=opt.batchsize, epochs=opt.epochs, alpha=opt.alpha, lr=opt.lr,
+                                  log=f'{opt.log}/{method}-{dataset_name}.csv',
+                                  checkpointpath=opt.checkpoint
+                                  )
+            # test
+            yte_ = cls.predict(Xte)
+            print('end-to-end-finetuning network prediction')
+            acc, macrof1, microf1 = evaluation(yte, yte_)
+            foo.write(
+                f'end-to-end-finetuning network prediction: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}\n')
+
+            # test
+            yte_ = cls.predict(Xte)
+            print('network prediction')
+            acc, macrof1, microf1 = evaluation(yte, yte_)
+            svm = GridSearchCV(LinearSVC(), param_grid={'C':np.logspace(-2,3,6), 'class_weight':['balanced',None]}, n_jobs=-1)
+            svm.fit(cls.project(Xtr), ytr)
+            yte_ = svm.predict(cls.project(Xte))
+            acc, macrof1, microf1 = evaluation(yte, yte_)
+            print(f'svm: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}')
+            foo.write(
+                f'svm network prediction: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}\n')
+        elif opt.mode=='attr':
+            # train
+            val_microf1 = cls.fit(Xtr, ytr,
+                                  batch_size=opt.batchsize, epochs=opt.epochs, alpha=opt.alpha, lr=opt.lr,
+                                  log=f'{opt.log}/{method}-{dataset_name}.csv',
+                                  checkpointpath=opt.checkpoint
+                                  )
+            # test
+            yte_ = cls.predict(Xte)
+            print('end-to-end network prediction')
+            acc, macrof1, microf1 = evaluation(yte, yte_)

-    results = Results(opt.output)
-    results.add(dataset_name, method, acc, macrof1, microf1, val_microf1)
+    # results = Results(opt.output)
+    # results.add(dataset_name, method, acc, macrof1, microf1, val_microf1)


     # verification
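Editor's note: every branch above reports `acc`, `macrof1`, and `microf1` through `evaluation(yte, yte_)`, which is defined elsewhere in the repository. A plausible equivalent, shown only to document the assumed contract (this is not the repository's implementation):

from sklearn.metrics import accuracy_score, f1_score

def evaluation(y_true, y_pred):
    # accuracy plus macro- and micro-averaged F1, matching the reported triple
    acc = accuracy_score(y_true, y_pred)
    macrof1 = f1_score(y_true, y_pred, average='macro')
    microf1 = f1_score(y_true, y_pred, average='micro')
    return acc, macrof1, microf1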
@@ -173,7 +192,7 @@ if __name__ == '__main__':
     parser.add_argument('-r', '--repr', help='Projection size (phi)', type=int, default=1024)
     parser.add_argument('-k', '--kernelsizes', help='Size of the convolutional kernels', nargs='+', default=[6,7,8])
     parser.add_argument('-p', '--pad', help='Pad length', type=int, default=3000)
-    parser.add_argument('-b', '--batchsize', help='Batch size', type=int, default=50)
+    parser.add_argument('-b', '--batchsize', help='Batch size', type=int, default=250)
     parser.add_argument('-e', '--epochs', help='Max number of epochs', type=int, default=250)
     parser.add_argument('-A', '--authors', help='Number of authors (-1 to select all)', type=int, default=-1)
     parser.add_argument('-D', '--documents', help='Number of documents per author (-1 to select all)', type=int, default=-1)
@@ -184,7 +203,7 @@ if __name__ == '__main__':
                         'This parameter indicates a directory, the name of the pickle is '
                         'derived automatically.', default='../pickles')
     parser.add_argument('-a', '--alpha', help='Controls the loss as attr-loss(alpha) + sav-loss(1-alpha)', type=float, default=1.)
-    parser.add_argument('--lr', help='Learning rate', type=float, default=0.001)
+    parser.add_argument('--lr', help='Learning rate', type=float, default=0.01)
     parser.add_argument('--checkpoint', help='Path where to dump model parameters', default='../checkpoint/model.dat')
     parser.add_argument('-n', '--name', help='Name of the model', default='auto')
     requiredNamed = parser.add_argument_group('required named arguments')
@@ -201,3 +220,18 @@ if __name__ == '__main__':
     create_path_if_not_exists(opt.pickle)

     main(opt)
+
+"""
+python3 main.py -d imdb62 -m savlin -A 10 -e 100 -H 32 -c 64 --lr 0.01 -b 250
+svm: acc=0.915 macrof1=0.915 microf1=0.915
+network prediction (savlin)
+acc=90.9000%
+macro-f1=0.9086
+micro-f1=0.9090
+
+end-to-end network prediction (attr)
+acc=94.6000%
+macro-f1=0.9458
+micro-f1=0.9460
+"""
@@ -133,7 +133,7 @@ class AuthorshipAttributionClassifier(nn.Module):
         return early_stop.best_score

     def supervised_contrastive_learning(self, X, y, Xval, yval, batch_size, epochs, patience=50, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
-        early_stop = EarlyStop(patience)
+        early_stop = EarlyStop(patience, lower_is_better=True)

         criterion = SupConLoss1View().to(self.device)
         optim = torch.optim.Adam(self.parameters(), lr=lr)
@@ -191,7 +191,7 @@ class AuthorshipAttributionClassifier(nn.Module):
         self.load_state_dict(torch.load(checkpointpath))
         return early_stop.best_score

-    def train_linear_classifier(self, X, y, Xval, yval, batch_size, epochs, patience=50, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
+    def train_linear_classifier(self, X, y, Xval, yval, batch_size, epochs, patience=10, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
         early_stop = EarlyStop(patience)

         criterion = torch.nn.CrossEntropyLoss().to(self.device)
@@ -4,16 +4,17 @@ from time import time
 class EarlyStop:

-    def __init__(self, patience=20):
+    def __init__(self, patience=20, lower_is_better=False):
         self.patience_limit = patience
         self.patience = patience
         self.best_score = None
         self.best_epoch = None
+        self.better_than = lambda a,b: a<b if lower_is_better else a>b
         self.STOP = False
         self.IMPROVED = False

     def __call__(self, watch_score, epoch):
-        if self.best_score is None or watch_score >= self.best_score:
+        if self.best_score is None or self.better_than(watch_score, self.best_score):
             self.IMPROVED = True
             self.best_score = watch_score
             self.best_epoch = epoch
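Editor's note: with `lower_is_better=True`, the monitor can track a quantity that should decrease, which is what `supervised_contrastive_learning` needs now that it watches the contrastive validation loss rather than an accuracy. Note the comparison also changes from `>=` to strict `>` (or `<`), so ties no longer count as improvements. A minimal usage sketch with hypothetical loss values (assuming the unshown else-branch of `__call__` resets `IMPROVED` and decrements patience):

early_stop = EarlyStop(patience=3, lower_is_better=True)
for epoch, val_loss in enumerate([0.9, 0.7, 0.8, 0.75, 0.74]):
    early_stop(val_loss, epoch)
    if early_stop.IMPROVED:
        print(f'epoch {epoch}: new best score {early_stop.best_score}')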