commit 5b216fb958
merge
@@ -122,77 +122,44 @@ class SupConLoss1View(nn.Module):
         cross = torch.matmul(features, features.T)
 
+        # frobenius_loss = torch.norm(mask-cross)
 
         upper_diag = torch.triu_indices(batch_size, batch_size, +1)
         cross_upper = cross[upper_diag[0], upper_diag[1]]
         mask_upper = mask[upper_diag[0], upper_diag[1]]
-        #pos = mask_upper.sum()
+        npos = int(mask_upper.sum().item())
         # weight = torch.from_numpy(np.asarray([1-pos, pos], dtype=float)).to(device)
         #return torch.nn.functional.binary_cross_entropy_with_logits(cross_upper, mask_upper)
         #print('mask min-max:', mask.min(), mask.max())
         #print('cross min-max:', cross.min(), cross.max())
         #return torch.norm(cross-mask, p='fro') # <-- diagonal signal (trivial) should be too strong
-        pos_loss = mse(cross_upper, mask_upper, label=1)
-        neg_loss = mse(cross_upper, mask_upper, label=0)
+        pos_loss = mse(cross_upper, mask_upper, label=1, k=-1)
+        neg_loss = mse(cross_upper, mask_upper, label=0, k=npos)
+        # return frobenius_loss, neg_loss, pos_loss
         #return neg_loss, pos_loss
-        #balanced_loss = pos_loss + neg_loss
-        #return balanced_loss
+        # balanced_loss = pos_loss + neg_loss
+        # return balanced_loss, neg_loss, pos_loss
+        # loss = torch.nn.functional.binary_cross_entropy(cross_upper, mask_upper)
+        # return loss, neg_loss, pos_loss
         n = len(mask_upper)
         return (1/n)*torch.sqrt(torch.sum((cross_upper-mask_upper)**2)), neg_loss, pos_loss
 
 
-def mse(input, target, label):
+def choice(tensor, k):
+    perm = torch.randperm(tensor.size(0))
+    idx = perm[:k]
+    return tensor[idx]
+
+
+def mse(input, target, label, k=-1):
     input = input[target==label]
+    if k>-1:
+        input = choice(input, k)
 
     if label==0:
         return torch.mean(input**2)
     else:
         return torch.mean((1-input)**2)
-    #return torch.mean((input[index] - target[index]) ** 2)
+    # index = target==label
+    # return torch.mean((input[index] - target[index]) ** 2)
 
-        # # compute logits
-        # anchor_dot_contrast = torch.div(torch.matmul(features, features.T),self.temperature)
-        # # for numerical stability
-        # # logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
-        # # logits = anchor_dot_contrast - logits_max.detach()
-        # logits = anchor_dot_contrast
-        #
-        # # mask-out self-contrast cases
-        # # logits_mask = torch.scatter(
-        # #     torch.ones_like(mask),
-        # #     1,
-        # #     torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
-        # #     0
-        # # )
-        # # mask = mask * logits_mask
-        # logits_mask = torch.ones_like(mask)
-        # logits_mask.fill_diagonal_(0)
-        # mask.fill_diagonal_(0)
-        #
-        # # compute log_prob
-        # exp_logits = torch.exp(logits) * logits_mask
-        # log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
-        #
-        # # compute mean of log-likelihood over positive
-        # div = mask.sum(1)
-        # div=torch.clamp(div, min=1)
-        # mean_log_prob_pos = (mask * log_prob).sum(1) / div
-        #
-        # # loss
-        # loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
-        # # loss = loss.view(anchor_count, batch_size).mean()
-        # loss = loss.view(-1, batch_size).mean()
-        #
-        # return loss
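Note: the revised forward pass above regresses pairwise embedding similarities toward the binary same-author mask, restricted to the upper triangle so the trivial self-similarity diagonal carries no signal. A minimal self-contained sketch of the computation, assuming `features` is a batch of L2-normalized embeddings and `labels` holds author ids (all names here are illustrative, not the repo's API):

    import torch

    def supcon_mse_loss(features, labels):
        # features: (B, D), L2-normalized; labels: (B,) author ids
        b = features.size(0)
        mask = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()  # 1 iff same author
        cross = torch.matmul(features, features.T)                   # pairwise similarities
        iu = torch.triu_indices(b, b, 1)                             # skip the diagonal
        sims, targets = cross[iu[0], iu[1]], mask[iu[0], iu[1]]

        # main loss: RMSE-style gap between similarities and the 0/1 mask
        n = targets.numel()
        loss = torch.sqrt(torch.sum((sims - targets) ** 2)) / n

        # diagnostics: positives should approach 1, negatives 0; negatives
        # are subsampled to npos so both terms use supports of equal size
        npos = int(targets.sum().item())
        pos = sims[targets == 1]
        neg = sims[targets == 0]
        neg = neg[torch.randperm(neg.numel())[:npos]]
        return loss, torch.mean(neg ** 2), torch.mean((1 - pos) ** 2)

Subsampling the negatives down to the positive count (the k=npos now passed to mse) keeps that diagnostic from being dominated by the far more numerous negative pairs.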

src/main.py | 50 changed lines
@@ -90,7 +90,7 @@ def main(opt):
     cls = AuthorshipAttributionClassifier(
         phi, num_authors=A.size, pad_index=pad_index, pad_length=opt.pad, device=device
     )
-    cls.xavier_uniform()
+    # cls.xavier_uniform()
     print(cls)
 
     if opt.name == 'auto':
@@ -98,6 +98,7 @@ def main(opt):
     else:
         method = opt.name
 
+    with open(f'results_feb_{opt.mode}.txt', 'wt') as foo:
     if opt.mode=='savlin':
         Xtr_, Xval_, ytr_, yval_ = train_test_split(Xtr, ytr, test_size=0.1, stratify=ytr)
         cls.supervised_contrastive_learning(Xtr_, ytr_, Xval_, yval_,
@@ -108,11 +109,31 @@ def main(opt):
             batch_size=opt.batchsize, epochs=opt.epochs, lr=opt.lr,
             log=f'{opt.log}/{method}-{dataset_name}.csv',
             checkpointpath=opt.checkpoint)
+        # test
+        yte_ = cls.predict(Xte)
+        print('sav(fix)-lin(trained) network prediction')
+        acc, macrof1, microf1 = evaluation(yte, yte_)
+        foo.write(f'sav(fix)-lin(trained) network prediction: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}\n')
+
+        val_microf1 = cls.fit(Xtr, ytr,
+            batch_size=opt.batchsize, epochs=opt.epochs, alpha=opt.alpha, lr=opt.lr,
+            log=f'{opt.log}/{method}-{dataset_name}.csv',
+            checkpointpath=opt.checkpoint
+        )
+        # test
+        yte_ = cls.predict(Xte)
+        print('end-to-end-finetuning network prediction')
+        acc, macrof1, microf1 = evaluation(yte, yte_)
+        foo.write(
+            f'end-to-end-finetuning network prediction: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}\n')
 
         svm = GridSearchCV(LinearSVC(), param_grid={'C':np.logspace(-2,3,6), 'class_weight':['balanced',None]}, n_jobs=-1)
         svm.fit(cls.project(Xtr), ytr)
         yte_ = svm.predict(cls.project(Xte))
         acc, macrof1, microf1 = evaluation(yte, yte_)
         print(f'svm: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}')
+        foo.write(
+            f'svm network prediction: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}\n')
     elif opt.mode=='attr':
         # train
         val_microf1 = cls.fit(Xtr, ytr,
@@ -120,15 +141,13 @@ def main(opt):
             log=f'{opt.log}/{method}-{dataset_name}.csv',
             checkpointpath=opt.checkpoint
         )
 
         # test
         yte_ = cls.predict(Xte)
-        print('network prediction')
+        print('end-to-end network prediction')
         acc, macrof1, microf1 = evaluation(yte, yte_)
 
-        results = Results(opt.output)
-        results.add(dataset_name, method, acc, macrof1, microf1, val_microf1)
+        # results = Results(opt.output)
+        # results.add(dataset_name, method, acc, macrof1, microf1, val_microf1)
 
     # verification
@ -173,7 +192,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('-r', '--repr', help='Projection size (phi)', type=int, default=1024)
|
parser.add_argument('-r', '--repr', help='Projection size (phi)', type=int, default=1024)
|
||||||
parser.add_argument('-k', '--kernelsizes', help='Size of the convolutional kernels', nargs='+', default=[6,7,8])
|
parser.add_argument('-k', '--kernelsizes', help='Size of the convolutional kernels', nargs='+', default=[6,7,8])
|
||||||
parser.add_argument('-p', '--pad', help='Pad length', type=int, default=3000)
|
parser.add_argument('-p', '--pad', help='Pad length', type=int, default=3000)
|
||||||
parser.add_argument('-b', '--batchsize', help='Batch size', type=int, default=50)
|
parser.add_argument('-b', '--batchsize', help='Batch size', type=int, default=250)
|
||||||
parser.add_argument('-e', '--epochs', help='Max number of epochs', type=int, default=250)
|
parser.add_argument('-e', '--epochs', help='Max number of epochs', type=int, default=250)
|
||||||
parser.add_argument('-A', '--authors', help='Number of authors (-1 to select all)', type=int, default=-1)
|
parser.add_argument('-A', '--authors', help='Number of authors (-1 to select all)', type=int, default=-1)
|
||||||
parser.add_argument('-D', '--documents', help='Number of documents per author (-1 to select all)', type=int, default=-1)
|
parser.add_argument('-D', '--documents', help='Number of documents per author (-1 to select all)', type=int, default=-1)
|
||||||
|
@@ -184,7 +203,7 @@ if __name__ == '__main__':
                         'This parameter indicates a directory, the name of the pickle is '
                         'derived automatically.', default='../pickles')
     parser.add_argument('-a', '--alpha', help='Controls the loss as attr-loss(alpha) + sav-loss(1-alpha)', type=float, default=1.)
-    parser.add_argument('--lr', help='Learning rate', type=float, default=0.001)
+    parser.add_argument('--lr', help='Learning rate', type=float, default=0.01)
     parser.add_argument('--checkpoint', help='Path where to dump model parameters', default='../checkpoint/model.dat')
     parser.add_argument('-n', '--name', help='Name of the model', default='auto')
     requiredNamed = parser.add_argument_group('required named arguments')
@@ -201,3 +220,18 @@ if __name__ == '__main__':
     create_path_if_not_exists(opt.pickle)
 
     main(opt)
+
+"""
+python3 main.py -d imdb62 -m savlin -A 10 -e 100 -H 32 -c 64 --lr 0.01 -b 250
+svm: acc=0.915 macrof1=0.915 microf1=0.915
+network prediction (savlin)
+acc=90.9000%
+macro-f1=0.9086
+micro-f1=0.9090
+
+end-to-end network prediction (attr)
+acc=94.6000%
+macro-f1=0.9458
+micro-f1=0.9460
+"""
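Note: the savlin branch ends with a grid-searched linear-SVM probe over the learned projections. A runnable miniature with random stand-in features, in case the probe is useful in isolation (in the repo, cls.project(Xtr) and cls.project(Xte) would supply the real inputs):

    import numpy as np
    from sklearn.model_selection import GridSearchCV
    from sklearn.svm import LinearSVC

    rng = np.random.default_rng(0)
    Ztr, ytr = rng.normal(size=(200, 64)), rng.integers(0, 10, 200)  # stand-ins for cls.project(Xtr), ytr
    Zte = rng.normal(size=(50, 64))                                  # stand-in for cls.project(Xte)

    svm = GridSearchCV(LinearSVC(),
                       param_grid={'C': np.logspace(-2, 3, 6), 'class_weight': ['balanced', None]},
                       n_jobs=-1)
    svm.fit(Ztr, ytr)            # default 5-fold CV over the 12 configurations
    yte_ = svm.predict(Zte)
    print('best params:', svm.best_params_)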
@@ -133,7 +133,7 @@ class AuthorshipAttributionClassifier(nn.Module):
         return early_stop.best_score
 
     def supervised_contrastive_learning(self, X, y, Xval, yval, batch_size, epochs, patience=50, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
-        early_stop = EarlyStop(patience)
+        early_stop = EarlyStop(patience, lower_is_better=True)
 
         criterion = SupConLoss1View().to(self.device)
         optim = torch.optim.Adam(self.parameters(), lr=lr)
@@ -191,7 +191,7 @@ class AuthorshipAttributionClassifier(nn.Module):
         self.load_state_dict(torch.load(checkpointpath))
         return early_stop.best_score
 
-    def train_linear_classifier(self, X, y, Xval, yval, batch_size, epochs, patience=50, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
+    def train_linear_classifier(self, X, y, Xval, yval, batch_size, epochs, patience=10, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
         early_stop = EarlyStop(patience)
 
         criterion = torch.nn.CrossEntropyLoss().to(self.device)
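Note: after this change the two trainers monitor in opposite directions: supervised_contrastive_learning early-stops on the validation contrastive loss (lower is better), while train_linear_classifier still watches a score to be maximized, now with a tighter patience of 10. A hedged sketch of the monitoring loop both methods presumably wrap around their epochs (helper names are hypothetical):

    early_stop = EarlyStop(patience, lower_is_better=watching_a_loss)
    for epoch in range(epochs):
        train_one_epoch()                  # hypothetical: one pass over X, y
        score = validate()                 # the watched quantity: a loss, or e.g. micro-F1
        early_stop(score, epoch)
        if early_stop.IMPROVED:
            torch.save(self.state_dict(), checkpointpath)    # checkpoint the best model
        if early_stop.STOP:
            break
    self.load_state_dict(torch.load(checkpointpath))         # restore best, as in the hunk above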
@@ -4,16 +4,17 @@ from time import time
 
 class EarlyStop:
 
-    def __init__(self, patience=20):
+    def __init__(self, patience=20, lower_is_better=False):
         self.patience_limit = patience
         self.patience = patience
         self.best_score = None
         self.best_epoch = None
+        self.better_than = lambda a,b: a<b if lower_is_better else a>b
         self.STOP = False
         self.IMPROVED = False
 
     def __call__(self, watch_score, epoch):
-        if self.best_score is None or watch_score >= self.best_score:
+        if self.best_score is None or self.better_than(watch_score, self.best_score):
             self.IMPROVED = True
             self.best_score = watch_score
             self.best_epoch = epoch
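Note: better_than is strict (a<b or a>b), whereas the old test accepted ties via watch_score >= self.best_score; a score that merely equals the current best no longer registers as an improvement. A small illustration, assuming the unshown remainder of the class decrements patience on non-improving calls:

    es = EarlyStop(patience=2, lower_is_better=True)
    es(0.50, epoch=1)    # first score: becomes best_score
    es(0.50, epoch=2)    # tie: not an improvement under strict '<'
    es(0.49, epoch=3)    # strictly lower: improvement
    print(es.best_score, es.best_epoch)   # -> 0.49 3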