going back to mse
This commit is contained in:
parent
5b216fb958
commit
b1160c5336
|
@ -8,6 +8,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SupConLoss(nn.Module):
|
||||
"""Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
|
||||
It also supports the unsupervised contrastive loss in SimCLR"""
|
||||
|
@ -128,22 +129,58 @@ class SupConLoss1View(nn.Module):
|
|||
upper_diag = torch.triu_indices(batch_size,batch_size,+1)
|
||||
cross_upper = cross[upper_diag[0], upper_diag[1]]
|
||||
mask_upper = mask[upper_diag[0], upper_diag[1]]
|
||||
npos = int(mask_upper.sum().item())
|
||||
# npos = int(mask_upper.sum().item())
|
||||
# weight = torch.from_numpy(np.asarray([1-pos, pos], dtype=float)).to(device)
|
||||
#return torch.nn.functional.binary_cross_entropy_with_logits(cross_upper, mask_upper)
|
||||
#print('mask min-max:', mask.min(), mask.max())
|
||||
#print('cross min-max:', cross.min(), cross.max())
|
||||
#return torch.norm(cross-mask, p='fro') # <-- diagonal signal (trivial) should be too strong
|
||||
pos_loss = mse(cross_upper, mask_upper, label=1, k=-1)
|
||||
neg_loss = mse(cross_upper, mask_upper, label=0, k=npos)
|
||||
pos_loss = mse(cross_upper, mask_upper, label=1)
|
||||
neg_loss = mse(cross_upper, mask_upper, label=0)
|
||||
# return frobenius_loss, neg_loss, pos_loss
|
||||
#return neg_loss, pos_loss
|
||||
# balanced_loss = pos_loss + neg_loss
|
||||
# return balanced_loss, neg_loss, pos_loss
|
||||
# loss = torch.nn.functional.binary_cross_entropy(cross_upper, mask_upper)
|
||||
# return loss, neg_loss, pos_loss
|
||||
n=len(mask_upper)
|
||||
return (1/n)*torch.sqrt(torch.sum((cross_upper-mask_upper)**2)), neg_loss, pos_loss
|
||||
# n=len(mask_upper)
|
||||
# return (1/n)*torch.sqrt(torch.sum((cross_upper-mask_upper)**2)), neg_loss, pos_loss
|
||||
return mse(cross_upper, mask_upper), neg_loss, pos_loss
|
||||
|
||||
|
||||
class SupConLoss1ViewCrossEntropy(nn.Module):
|
||||
"""Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
|
||||
It also supports the unsupervised contrastive loss in SimCLR"""
|
||||
def __init__(self):
|
||||
super(SupConLoss1ViewCrossEntropy, self).__init__()
|
||||
|
||||
def forward(self, features, labels):
|
||||
device = (torch.device('cuda')
|
||||
if features.is_cuda
|
||||
else torch.device('cpu'))
|
||||
|
||||
if len(features.shape) != 2:
|
||||
raise ValueError('`features` needs to be [bsz, ndim]')
|
||||
|
||||
batch_size = features.shape[0]
|
||||
labels = labels.contiguous().view(-1, 1)
|
||||
if labels.shape[0] != batch_size:
|
||||
raise ValueError('Num of labels does not match num of features')
|
||||
|
||||
mask = torch.eq(labels, labels.T).float().to(device)
|
||||
cross = torch.matmul(features, features.T)
|
||||
|
||||
upper_diag = torch.triu_indices(batch_size,batch_size,+1)
|
||||
cross_upper = cross[upper_diag[0], upper_diag[1]]
|
||||
mask_upper = mask[upper_diag[0], upper_diag[1]]
|
||||
|
||||
npos = int(mask_upper.sum().item())
|
||||
|
||||
pos_loss = bin_cross_entropy(cross_upper, mask_upper, label=1)
|
||||
neg_loss = bin_cross_entropy(cross_upper, mask_upper, label=0)
|
||||
|
||||
loss = bin_cross_entropy(cross_upper, mask_upper)
|
||||
return loss, neg_loss, pos_loss
|
||||
|
||||
|
||||
def choice(tensor, k):
|
||||
|
@ -152,14 +189,27 @@ def choice(tensor, k):
|
|||
return tensor[idx]
|
||||
|
||||
|
||||
def mse(input, target, label, k=-1):
|
||||
input = input[target==label]
|
||||
def mse(input, target, label=None, k=-1):
|
||||
if label is not None:
|
||||
index = target==label
|
||||
input = input[index]
|
||||
target = target[index]
|
||||
if k>-1:
|
||||
input = choice(input, k)
|
||||
target = choice(target, k)
|
||||
|
||||
if label==0:
|
||||
return torch.mean(input**2)
|
||||
else:
|
||||
return torch.mean((1-input)**2)
|
||||
return torch.mean((input - target) ** 2)
|
||||
|
||||
# if label==0:
|
||||
# return torch.mean(input**2)
|
||||
# else:
|
||||
# return torch.mean((1-input)**2)
|
||||
# index = target==label
|
||||
# return torch.mean((input[index] - target[index]) ** 2)
|
||||
|
||||
|
||||
def bin_cross_entropy(input, target, label=None):
|
||||
if label is None:
|
||||
return torch.nn.functional.binary_cross_entropy_with_logits(input, target)
|
||||
index = target == label
|
||||
return torch.nn.functional.binary_cross_entropy_with_logits(input[index], target[index])
|
||||
|
|
138
src/main.py
138
src/main.py
|
@ -49,6 +49,26 @@ def load_dataset(opt):
|
|||
return dataset_name, dataset
|
||||
|
||||
|
||||
def instantiate_model(A, index, pad_index, device):
|
||||
phi = Phi(
|
||||
cnn=CNNProjection(
|
||||
vocabulary_size=index.vocabulary_size(),
|
||||
embedding_dim=opt.hidden,
|
||||
channels_out=opt.chout,
|
||||
kernel_sizes=opt.kernelsizes),
|
||||
ff=FFProjection(input_size=len(opt.kernelsizes) * opt.chout,
|
||||
hidden_sizes=[],
|
||||
output_size=opt.repr,
|
||||
activation=nn.functional.relu,
|
||||
dropout=0.5,
|
||||
activate_last=True),
|
||||
).to(device)
|
||||
cls = AuthorshipAttributionClassifier(
|
||||
phi, num_authors=A.size, pad_index=pad_index, pad_length=opt.pad, device=device
|
||||
)
|
||||
cls.xavier_uniform()
|
||||
print(cls)
|
||||
return cls, phi
|
||||
|
||||
def main(opt):
|
||||
|
||||
|
@ -73,25 +93,8 @@ def main(opt):
|
|||
|
||||
# attribution
|
||||
print('Attribution')
|
||||
phi = Phi(
|
||||
cnn=CNNProjection(
|
||||
vocabulary_size=index.vocabulary_size(),
|
||||
embedding_dim=opt.hidden,
|
||||
channels_out=opt.chout,
|
||||
kernel_sizes=opt.kernelsizes),
|
||||
ff=FFProjection(input_size=len(opt.kernelsizes) * opt.chout,
|
||||
hidden_sizes=[1024],
|
||||
output_size=opt.repr,
|
||||
activation=nn.functional.relu,
|
||||
dropout=0.5,
|
||||
activate_last=True),
|
||||
).to(device)
|
||||
|
||||
cls = AuthorshipAttributionClassifier(
|
||||
phi, num_authors=A.size, pad_index=pad_index, pad_length=opt.pad, device=device
|
||||
)
|
||||
# cls.xavier_uniform()
|
||||
print(cls)
|
||||
cls, phi = instantiate_model(A, index, pad_index, device)
|
||||
|
||||
if opt.name == 'auto':
|
||||
method = f'{phi.__class__.__name__}_alpha{opt.alpha}'
|
||||
|
@ -99,52 +102,52 @@ def main(opt):
|
|||
method = opt.name
|
||||
|
||||
with open(f'results_feb_{opt.mode}.txt', 'wt') as foo:
|
||||
if opt.mode=='savlin':
|
||||
Xtr_, Xval_, ytr_, yval_ = train_test_split(Xtr, ytr, test_size=0.1, stratify=ytr)
|
||||
cls.supervised_contrastive_learning(Xtr_, ytr_, Xval_, yval_,
|
||||
batch_size=opt.batchsize, epochs=opt.epochs, lr=opt.lr,
|
||||
log=f'{opt.log}/{method}-{dataset_name}.csv',
|
||||
checkpointpath=opt.checkpoint)
|
||||
val_microf1 = cls.train_linear_classifier(Xtr_, ytr_, Xval_, yval_,
|
||||
batch_size=opt.batchsize, epochs=opt.epochs, lr=opt.lr,
|
||||
log=f'{opt.log}/{method}-{dataset_name}.csv',
|
||||
checkpointpath=opt.checkpoint)
|
||||
# test
|
||||
yte_ = cls.predict(Xte)
|
||||
print('sav(fix)-lin(trained) network prediction')
|
||||
acc, macrof1, microf1 = evaluation(yte, yte_)
|
||||
foo.write(f'sav(fix)-lin(trained) network prediction: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}\n')
|
||||
Xtr_, Xval_, ytr_, yval_ = train_test_split(Xtr, ytr, test_size=0.1, stratify=ytr)
|
||||
|
||||
val_microf1 = cls.fit(Xtr, ytr,
|
||||
batch_size=opt.batchsize, epochs=opt.epochs, alpha=opt.alpha, lr=opt.lr,
|
||||
log=f'{opt.log}/{method}-{dataset_name}.csv',
|
||||
checkpointpath=opt.checkpoint
|
||||
)
|
||||
# test
|
||||
yte_ = cls.predict(Xte)
|
||||
print('end-to-end-finetuning network prediction')
|
||||
acc, macrof1, microf1 = evaluation(yte, yte_)
|
||||
foo.write(
|
||||
f'end-to-end-finetuning network prediction: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}\n')
|
||||
cls.supervised_contrastive_learning(Xtr_, ytr_, Xval_, yval_,
|
||||
batch_size=opt.batchsize, epochs=opt.epochs, lr=opt.lr,
|
||||
log=f'{opt.log}/{method}-{dataset_name}.csv',
|
||||
checkpointpath=opt.checkpoint)
|
||||
|
||||
svm = GridSearchCV(LinearSVC(), param_grid={'C':np.logspace(-2,3,6), 'class_weight':['balanced',None]}, n_jobs=-1)
|
||||
svm.fit(cls.project(Xtr), ytr)
|
||||
yte_ = svm.predict(cls.project(Xte))
|
||||
acc, macrof1, microf1 = evaluation(yte, yte_)
|
||||
print(f'svm: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}')
|
||||
foo.write(
|
||||
f'svm network prediction: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}\n')
|
||||
elif opt.mode=='attr':
|
||||
# train
|
||||
val_microf1 = cls.fit(Xtr, ytr,
|
||||
batch_size=opt.batchsize, epochs=opt.epochs, alpha=opt.alpha, lr=opt.lr,
|
||||
log=f'{opt.log}/{method}-{dataset_name}.csv',
|
||||
checkpointpath=opt.checkpoint
|
||||
)
|
||||
# test
|
||||
yte_ = cls.predict(Xte)
|
||||
print('end-to-end network prediction')
|
||||
acc, macrof1, microf1 = evaluation(yte, yte_)
|
||||
# svm_experiment(cls.project(Xtr), ytr, cls.project(Xte), yte, foo, 'svm-pre')
|
||||
svm_experiment(cls.project_kernel(Xtr), ytr, cls.project_kernel(Xte), yte, foo, 'svm-kernel')
|
||||
|
||||
|
||||
val_microf1 = cls.train_linear_classifier(Xtr_, ytr_, Xval_, yval_,
|
||||
batch_size=opt.batchsize, epochs=opt.epochs, lr=opt.lr,
|
||||
log=f'{opt.log}/{method}-{dataset_name}.csv',
|
||||
checkpointpath=opt.checkpoint)
|
||||
# test
|
||||
yte_ = cls.predict(Xte)
|
||||
print('sav(fix)-lin(trained) network prediction')
|
||||
|
||||
acc, macrof1, microf1 = evaluation(yte, yte_)
|
||||
foo.write(f'sav(fix)-lin(trained) network prediction: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}\n')
|
||||
|
||||
val_microf1 = cls.fit(Xtr, ytr,
|
||||
batch_size=opt.batchsize, epochs=opt.epochs, alpha=opt.alpha, lr=opt.lr,
|
||||
log=f'{opt.log}/{method}-{dataset_name}.csv',
|
||||
checkpointpath=opt.checkpoint
|
||||
)
|
||||
# test
|
||||
yte_ = cls.predict(Xte)
|
||||
print('end-to-end-finetuning network prediction')
|
||||
acc, macrof1, microf1 = evaluation(yte, yte_)
|
||||
foo.write(
|
||||
f'end-to-end-finetuning network prediction: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}\n')
|
||||
|
||||
print('training end-to-end without self-supervision init')
|
||||
cls, phi = instantiate_model(A, index, pad_index, device)
|
||||
# train
|
||||
val_microf1 = cls.fit(Xtr, ytr,
|
||||
batch_size=opt.batchsize, epochs=opt.epochs, alpha=opt.alpha, lr=opt.lr,
|
||||
log=f'{opt.log}/{method}-{dataset_name}.csv',
|
||||
checkpointpath=opt.checkpoint
|
||||
)
|
||||
# test
|
||||
yte_ = cls.predict(Xte)
|
||||
print('end-to-end (w/o self-supervised initialization) network prediction')
|
||||
acc, macrof1, microf1 = evaluation(yte, yte_)
|
||||
|
||||
# results = Results(opt.output)
|
||||
# results.add(dataset_name, method, acc, macrof1, microf1, val_microf1)
|
||||
|
@ -183,13 +186,22 @@ class Results:
|
|||
def close(self):
|
||||
self.foo.close()
|
||||
|
||||
def svm_experiment(Xtr, ytr, Xte, yte, foo, name):
|
||||
svm = GridSearchCV(
|
||||
LinearSVC(), param_grid={'C': np.logspace(-2, 3, 6), 'class_weight': ['balanced', None]}, n_jobs=-1
|
||||
)
|
||||
svm.fit(Xtr, ytr)
|
||||
yte_ = svm.predict(Xte)
|
||||
acc, macrof1, microf1 = evaluation(yte, yte_)
|
||||
print(f'{name}: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}')
|
||||
foo.write(f'{name} network prediction: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}\n')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='CNN with KTA regularization')
|
||||
parser.add_argument('-H', '--hidden', help='Hidden/embedding size', type=int, default=32)
|
||||
parser.add_argument('-c', '--chout', help='Channels output size', type=int, default=128)
|
||||
parser.add_argument('-r', '--repr', help='Projection size (phi)', type=int, default=1024)
|
||||
parser.add_argument('-r', '--repr', help='Projection size (phi)', type=int, default=2048)
|
||||
parser.add_argument('-k', '--kernelsizes', help='Size of the convolutional kernels', nargs='+', default=[6,7,8])
|
||||
parser.add_argument('-p', '--pad', help='Pad length', type=int, default=3000)
|
||||
parser.add_argument('-b', '--batchsize', help='Batch size', type=int, default=250)
|
||||
|
|
|
@ -7,7 +7,7 @@ from tqdm import tqdm
|
|||
import math
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
from losses import SupConLoss1View
|
||||
from losses import SupConLoss1View, SupConLoss1ViewCrossEntropy
|
||||
from model.early_stop import EarlyStop
|
||||
from model.layers import FFProjection
|
||||
from torch.utils.data import DataLoader
|
||||
|
@ -20,6 +20,7 @@ class AuthorshipAttributionClassifier(nn.Module):
|
|||
self.ff = FFProjection(input_size=projector.output_size,
|
||||
hidden_sizes=[],
|
||||
output_size=num_authors).to(device)
|
||||
self.linear_proj = nn.Linear(projector.output_size, 128).to(device) # to train the kernel alignment
|
||||
self.pad_index = pad_index
|
||||
self.pad_length = pad_length
|
||||
self.device = device
|
||||
|
@ -132,10 +133,11 @@ class AuthorshipAttributionClassifier(nn.Module):
|
|||
self.load_state_dict(torch.load(checkpointpath))
|
||||
return early_stop.best_score
|
||||
|
||||
def supervised_contrastive_learning(self, X, y, Xval, yval, batch_size, epochs, patience=50, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
|
||||
def supervised_contrastive_learning(self, X, y, Xval, yval, batch_size, epochs, patience=200, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
|
||||
early_stop = EarlyStop(patience, lower_is_better=True)
|
||||
|
||||
criterion = SupConLoss1View().to(self.device)
|
||||
# criterion = SupConLoss1ViewCrossEntropy().to(self.device)
|
||||
optim = torch.optim.Adam(self.parameters(), lr=lr)
|
||||
|
||||
tr_data = IndexedDataset(X, y, self.pad_length, self.pad_index, self.device)
|
||||
|
@ -154,6 +156,8 @@ class AuthorshipAttributionClassifier(nn.Module):
|
|||
#while True:
|
||||
optim.zero_grad()
|
||||
phi = self.projector(xi)
|
||||
phi = self.linear_proj(phi)
|
||||
phi = F.normalize(phi, p=2, dim=-1)
|
||||
#contrastive_loss = criterion(phi, torch.as_tensor(yi).to(self.device))
|
||||
contrastive_loss, neg_loss, pos_loss = criterion(phi, torch.as_tensor(yi).to(self.device))
|
||||
#contrastive_loss = neg_loss+pos_loss
|
||||
|
@ -191,7 +195,7 @@ class AuthorshipAttributionClassifier(nn.Module):
|
|||
self.load_state_dict(torch.load(checkpointpath))
|
||||
return early_stop.best_score
|
||||
|
||||
def train_linear_classifier(self, X, y, Xval, yval, batch_size, epochs, patience=10, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
|
||||
def train_linear_classifier(self, X, y, Xval, yval, batch_size, epochs, patience=25, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
|
||||
early_stop = EarlyStop(patience)
|
||||
|
||||
criterion = torch.nn.CrossEntropyLoss().to(self.device)
|
||||
|
@ -272,6 +276,19 @@ class AuthorshipAttributionClassifier(nn.Module):
|
|||
predictions.append(phi)
|
||||
return np.concatenate(predictions)
|
||||
|
||||
def project_kernel(self, x, batch_size=100):
|
||||
self.eval()
|
||||
te_data = IndexedDataset(x, None, self.pad_length, self.pad_index, self.device)
|
||||
predictions = []
|
||||
with torch.no_grad():
|
||||
for xi in te_data.asDataLoader(batch_size, shuffle=False):
|
||||
phi = self.projector(xi)
|
||||
phi = self.linear_proj(phi)
|
||||
phi = F.normalize(phi, p=2, dim=-1)
|
||||
phi = tensor2numpy(phi)
|
||||
predictions.append(phi)
|
||||
return np.concatenate(predictions)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
phi = self.projector(x)
|
||||
|
|
Loading…
Reference in New Issue