going back to mse

commit b1160c5336 (parent 5b216fb958)
@@ -8,6 +8,7 @@ import torch
 import torch.nn as nn
 import numpy as np
+
 
 class SupConLoss(nn.Module):
     """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
     It also supports the unsupervised contrastive loss in SimCLR"""
@@ -128,22 +129,58 @@ class SupConLoss1View(nn.Module):
         upper_diag = torch.triu_indices(batch_size, batch_size, +1)
         cross_upper = cross[upper_diag[0], upper_diag[1]]
         mask_upper = mask[upper_diag[0], upper_diag[1]]
-        npos = int(mask_upper.sum().item())
+        # npos = int(mask_upper.sum().item())
         # weight = torch.from_numpy(np.asarray([1-pos, pos], dtype=float)).to(device)
         #return torch.nn.functional.binary_cross_entropy_with_logits(cross_upper, mask_upper)
         #print('mask min-max:', mask.min(), mask.max())
         #print('cross min-max:', cross.min(), cross.max())
         #return torch.norm(cross-mask, p='fro') # <-- diagonal signal (trivial) should be too strong
-        pos_loss = mse(cross_upper, mask_upper, label=1, k=-1)
-        neg_loss = mse(cross_upper, mask_upper, label=0, k=npos)
+        pos_loss = mse(cross_upper, mask_upper, label=1)
+        neg_loss = mse(cross_upper, mask_upper, label=0)
         # return frobenius_loss, neg_loss, pos_loss
         #return neg_loss, pos_loss
         # balanced_loss = pos_loss + neg_loss
         # return balanced_loss, neg_loss, pos_loss
         # loss = torch.nn.functional.binary_cross_entropy(cross_upper, mask_upper)
         # return loss, neg_loss, pos_loss
-        n=len(mask_upper)
-        return (1/n)*torch.sqrt(torch.sum((cross_upper-mask_upper)**2)), neg_loss, pos_loss
+        # n=len(mask_upper)
+        # return (1/n)*torch.sqrt(torch.sum((cross_upper-mask_upper)**2)), neg_loss, pos_loss
+        return mse(cross_upper, mask_upper), neg_loss, pos_loss
+
+
+class SupConLoss1ViewCrossEntropy(nn.Module):
+    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
+    It also supports the unsupervised contrastive loss in SimCLR"""
+    def __init__(self):
+        super(SupConLoss1ViewCrossEntropy, self).__init__()
+
+    def forward(self, features, labels):
+        device = (torch.device('cuda')
+                  if features.is_cuda
+                  else torch.device('cpu'))
+
+        if len(features.shape) != 2:
+            raise ValueError('`features` needs to be [bsz, ndim]')
+
+        batch_size = features.shape[0]
+        labels = labels.contiguous().view(-1, 1)
+        if labels.shape[0] != batch_size:
+            raise ValueError('Num of labels does not match num of features')
+
+        mask = torch.eq(labels, labels.T).float().to(device)
+        cross = torch.matmul(features, features.T)
+
+        upper_diag = torch.triu_indices(batch_size, batch_size, +1)
+        cross_upper = cross[upper_diag[0], upper_diag[1]]
+        mask_upper = mask[upper_diag[0], upper_diag[1]]
+
+        npos = int(mask_upper.sum().item())
+
+        pos_loss = bin_cross_entropy(cross_upper, mask_upper, label=1)
+        neg_loss = bin_cross_entropy(cross_upper, mask_upper, label=0)
+
+        loss = bin_cross_entropy(cross_upper, mask_upper)
+        return loss, neg_loss, pos_loss
+
+
 def choice(tensor, k):
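Note: the net effect of this hunk, and of the commit title, is that SupConLoss1View.forward now returns a plain mean squared error between the pairwise similarity matrix and the same-author mask, restricted to the upper triangle (the diagonal is trivially 1). A minimal self-contained sketch of that computation; the function name supcon_mse and the toy inputs are ours, and it assumes L2-normalized features, which the training loop below enforces:

    import torch

    def supcon_mse(features, labels):
        # features: [bsz, ndim], assumed L2-normalized; labels: [bsz] author ids
        batch_size = features.shape[0]
        mask = torch.eq(labels.view(-1, 1), labels.view(1, -1)).float()  # 1 = same author
        cross = features @ features.T                                    # cosine similarities
        iu = torch.triu_indices(batch_size, batch_size, 1)               # skip the trivial diagonal
        cross_upper = cross[iu[0], iu[1]]
        mask_upper = mask[iu[0], iu[1]]
        return torch.mean((cross_upper - mask_upper) ** 2)

    features = torch.nn.functional.normalize(torch.randn(8, 16), p=2, dim=-1)
    labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])
    print(supcon_mse(features, labels))  # scalar; 0 iff positives score 1 and negatives 0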
@@ -152,14 +189,27 @@ def choice(tensor, k):
     return tensor[idx]
 
 
-def mse(input, target, label, k=-1):
-    input = input[target==label]
+def mse(input, target, label=None, k=-1):
+    if label is not None:
+        index = target==label
+        input = input[index]
+        target = target[index]
     if k>-1:
         input = choice(input, k)
+        target = choice(target, k)
+
-    if label==0:
-        return torch.mean(input**2)
-    else:
-        return torch.mean((1-input)**2)
+    return torch.mean((input - target) ** 2)
+    # if label==0:
+    #     return torch.mean(input**2)
+    # else:
+    #     return torch.mean((1-input)**2)
     # index = target==label
     # return torch.mean((input[index] - target[index]) ** 2)
+
+
+def bin_cross_entropy(input, target, label=None):
+    if label is None:
+        return torch.nn.functional.binary_cross_entropy_with_logits(input, target)
+    index = target == label
+    return torch.nn.functional.binary_cross_entropy_with_logits(input[index], target[index])
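Note: a quick worked example of the rewritten mse helper on toy values. With label=1 or label=0 it compares only the selected pairs against their targets (rather than against hard-coded constants, as the old branches did), and with no label it yields exactly what the loss now returns. Also note that bin_cross_entropy wraps binary_cross_entropy_with_logits, so it treats the similarities as raw logits.

    import torch

    sims = torch.tensor([0.9, 0.1, 0.8, 0.2])  # upper-triangle similarities (toy values)
    mask = torch.tensor([1.0, 0.0, 1.0, 0.0])  # 1 = same-author pair

    pos = torch.mean((sims[mask == 1] - 1.0) ** 2)  # what mse(sims, mask, label=1) computes
    neg = torch.mean((sims[mask == 0] - 0.0) ** 2)  # what mse(sims, mask, label=0) computes
    full = torch.mean((sims - mask) ** 2)           # what mse(sims, mask) computes
    print(pos.item(), neg.item(), full.item())      # 0.025 0.025 0.025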
src/main.py (138 changed lines)

@@ -49,6 +49,26 @@ def load_dataset(opt):
     return dataset_name, dataset
 
 
+def instantiate_model(A, index, pad_index, device):
+    phi = Phi(
+        cnn=CNNProjection(
+            vocabulary_size=index.vocabulary_size(),
+            embedding_dim=opt.hidden,
+            channels_out=opt.chout,
+            kernel_sizes=opt.kernelsizes),
+        ff=FFProjection(input_size=len(opt.kernelsizes) * opt.chout,
+                        hidden_sizes=[],
+                        output_size=opt.repr,
+                        activation=nn.functional.relu,
+                        dropout=0.5,
+                        activate_last=True),
+    ).to(device)
+    cls = AuthorshipAttributionClassifier(
+        phi, num_authors=A.size, pad_index=pad_index, pad_length=opt.pad, device=device
+    )
+    cls.xavier_uniform()
+    print(cls)
+    return cls, phi
+
+
 def main(opt):
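Note: pulling model construction into instantiate_model lets the script build a second, freshly initialized network later for the baseline trained without the contrastive warm start (the helper also reads the global opt, uses hidden_sizes=[] where the inlined version used [1024], and re-enables the xavier_uniform() call). A generic, runnable sketch of the factory pattern, with hypothetical names:

    import torch
    import torch.nn as nn

    def make_model(din=16, dout=4):
        # Fresh weights on every call, so a later experiment cannot silently
        # inherit parameters from an earlier training phase.
        model = nn.Linear(din, dout)
        nn.init.xavier_uniform_(model.weight)
        return model

    m1 = make_model()
    m2 = make_model()
    print(torch.allclose(m1.weight, m2.weight))  # almost surely False: independent inits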
@@ -73,25 +93,8 @@ def main(opt):
 
     # attribution
     print('Attribution')
-    phi = Phi(
-        cnn=CNNProjection(
-            vocabulary_size=index.vocabulary_size(),
-            embedding_dim=opt.hidden,
-            channels_out=opt.chout,
-            kernel_sizes=opt.kernelsizes),
-        ff=FFProjection(input_size=len(opt.kernelsizes) * opt.chout,
-                        hidden_sizes=[1024],
-                        output_size=opt.repr,
-                        activation=nn.functional.relu,
-                        dropout=0.5,
-                        activate_last=True),
-    ).to(device)
-
-    cls = AuthorshipAttributionClassifier(
-        phi, num_authors=A.size, pad_index=pad_index, pad_length=opt.pad, device=device
-    )
-    # cls.xavier_uniform()
-    print(cls)
+    cls, phi = instantiate_model(A, index, pad_index, device)
 
     if opt.name == 'auto':
         method = f'{phi.__class__.__name__}_alpha{opt.alpha}'
@@ -99,52 +102,52 @@ def main(opt):
         method = opt.name
 
     with open(f'results_feb_{opt.mode}.txt', 'wt') as foo:
-        if opt.mode=='savlin':
-            Xtr_, Xval_, ytr_, yval_ = train_test_split(Xtr, ytr, test_size=0.1, stratify=ytr)
-            cls.supervised_contrastive_learning(Xtr_, ytr_, Xval_, yval_,
-                                                batch_size=opt.batchsize, epochs=opt.epochs, lr=opt.lr,
-                                                log=f'{opt.log}/{method}-{dataset_name}.csv',
-                                                checkpointpath=opt.checkpoint)
-            val_microf1 = cls.train_linear_classifier(Xtr_, ytr_, Xval_, yval_,
-                                                batch_size=opt.batchsize, epochs=opt.epochs, lr=opt.lr,
-                                                log=f'{opt.log}/{method}-{dataset_name}.csv',
-                                                checkpointpath=opt.checkpoint)
-            # test
-            yte_ = cls.predict(Xte)
-            print('sav(fix)-lin(trained) network prediction')
-            acc, macrof1, microf1 = evaluation(yte, yte_)
-            foo.write(f'sav(fix)-lin(trained) network prediction: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}\n')
-
-            val_microf1 = cls.fit(Xtr, ytr,
-                                  batch_size=opt.batchsize, epochs=opt.epochs, alpha=opt.alpha, lr=opt.lr,
-                                  log=f'{opt.log}/{method}-{dataset_name}.csv',
-                                  checkpointpath=opt.checkpoint
-                                  )
-            # test
-            yte_ = cls.predict(Xte)
-            print('end-to-end-finetuning network prediction')
-            acc, macrof1, microf1 = evaluation(yte, yte_)
-            foo.write(
-                f'end-to-end-finetuning network prediction: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}\n')
-
-            svm = GridSearchCV(LinearSVC(), param_grid={'C':np.logspace(-2,3,6), 'class_weight':['balanced',None]}, n_jobs=-1)
-            svm.fit(cls.project(Xtr), ytr)
-            yte_ = svm.predict(cls.project(Xte))
-            acc, macrof1, microf1 = evaluation(yte, yte_)
-            print(f'svm: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}')
-            foo.write(
-                f'svm network prediction: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}\n')
-        elif opt.mode=='attr':
-            # train
-            val_microf1 = cls.fit(Xtr, ytr,
-                                  batch_size=opt.batchsize, epochs=opt.epochs, alpha=opt.alpha, lr=opt.lr,
-                                  log=f'{opt.log}/{method}-{dataset_name}.csv',
-                                  checkpointpath=opt.checkpoint
-                                  )
-            # test
-            yte_ = cls.predict(Xte)
-            print('end-to-end network prediction')
-            acc, macrof1, microf1 = evaluation(yte, yte_)
+        Xtr_, Xval_, ytr_, yval_ = train_test_split(Xtr, ytr, test_size=0.1, stratify=ytr)
+        cls.supervised_contrastive_learning(Xtr_, ytr_, Xval_, yval_,
+                                            batch_size=opt.batchsize, epochs=opt.epochs, lr=opt.lr,
+                                            log=f'{opt.log}/{method}-{dataset_name}.csv',
+                                            checkpointpath=opt.checkpoint)
+        # svm_experiment(cls.project(Xtr), ytr, cls.project(Xte), yte, foo, 'svm-pre')
+        svm_experiment(cls.project_kernel(Xtr), ytr, cls.project_kernel(Xte), yte, foo, 'svm-kernel')
+        val_microf1 = cls.train_linear_classifier(Xtr_, ytr_, Xval_, yval_,
+                                            batch_size=opt.batchsize, epochs=opt.epochs, lr=opt.lr,
+                                            log=f'{opt.log}/{method}-{dataset_name}.csv',
+                                            checkpointpath=opt.checkpoint)
+        # test
+        yte_ = cls.predict(Xte)
+        print('sav(fix)-lin(trained) network prediction')
+        acc, macrof1, microf1 = evaluation(yte, yte_)
+        foo.write(f'sav(fix)-lin(trained) network prediction: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}\n')
+
+        val_microf1 = cls.fit(Xtr, ytr,
+                              batch_size=opt.batchsize, epochs=opt.epochs, alpha=opt.alpha, lr=opt.lr,
+                              log=f'{opt.log}/{method}-{dataset_name}.csv',
+                              checkpointpath=opt.checkpoint
+                              )
+        # test
+        yte_ = cls.predict(Xte)
+        print('end-to-end-finetuning network prediction')
+        acc, macrof1, microf1 = evaluation(yte, yte_)
+        foo.write(
+            f'end-to-end-finetuning network prediction: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}\n')
+
+        print('training end-to-end without self-supervision init')
+        cls, phi = instantiate_model(A, index, pad_index, device)
+        # train
+        val_microf1 = cls.fit(Xtr, ytr,
+                              batch_size=opt.batchsize, epochs=opt.epochs, alpha=opt.alpha, lr=opt.lr,
+                              log=f'{opt.log}/{method}-{dataset_name}.csv',
+                              checkpointpath=opt.checkpoint
+                              )
+        # test
+        yte_ = cls.predict(Xte)
+        print('end-to-end (w/o self-supervised initialization) network prediction')
+        acc, macrof1, microf1 = evaluation(yte, yte_)
 
     # results = Results(opt.output)
     # results.add(dataset_name, method, acc, macrof1, microf1, val_microf1)
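Note: the restructured pipeline drops the opt.mode branching and always runs every stage in sequence: contrastive pretraining, SVM on the kernel projection, linear probe, end-to-end finetuning, and a freshly re-instantiated end-to-end baseline. It holds out 10% of the training set, stratified by author, for validation. A toy illustration of what that stratified split guarantees:

    import numpy as np
    from sklearn.model_selection import train_test_split

    X = np.random.randn(50, 2)
    y = np.repeat(np.arange(5), 10)   # 5 authors, 10 documents each
    Xtr_, Xval_, ytr_, yval_ = train_test_split(X, y, test_size=0.1, stratify=y)
    print(np.bincount(yval_))         # [1 1 1 1 1]: every author reaches the val split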
@@ -183,13 +186,22 @@ class Results:
     def close(self):
         self.foo.close()
 
 
+def svm_experiment(Xtr, ytr, Xte, yte, foo, name):
+    svm = GridSearchCV(
+        LinearSVC(), param_grid={'C': np.logspace(-2, 3, 6), 'class_weight': ['balanced', None]}, n_jobs=-1
+    )
+    svm.fit(Xtr, ytr)
+    yte_ = svm.predict(Xte)
+    acc, macrof1, microf1 = evaluation(yte, yte_)
+    print(f'{name}: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}')
+    foo.write(f'{name} network prediction: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}\n')
+
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='CNN with KTA regularization')
     parser.add_argument('-H', '--hidden', help='Hidden/embedding size', type=int, default=32)
     parser.add_argument('-c', '--chout', help='Channels output size', type=int, default=128)
-    parser.add_argument('-r', '--repr', help='Projection size (phi)', type=int, default=1024)
+    parser.add_argument('-r', '--repr', help='Projection size (phi)', type=int, default=2048)
     parser.add_argument('-k', '--kernelsizes', help='Size of the convolutional kernels', nargs='+', default=[6,7,8])
     parser.add_argument('-p', '--pad', help='Pad length', type=int, default=3000)
     parser.add_argument('-b', '--batchsize', help='Batch size', type=int, default=250)
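Note: svm_experiment factors out the grid-searched linear SVM that previously lived inline in main(). A self-contained toy run of the same flow; the real evaluation helper is not shown in this diff, so a stand-in with the assumed (acc, macro-F1, micro-F1) contract is defined here:

    import sys
    import numpy as np
    from sklearn.model_selection import GridSearchCV
    from sklearn.svm import LinearSVC
    from sklearn.metrics import accuracy_score, f1_score

    def evaluation(y_true, y_pred):
        # stand-in for the project's evaluation() (not part of this diff)
        return (accuracy_score(y_true, y_pred),
                f1_score(y_true, y_pred, average='macro'),
                f1_score(y_true, y_pred, average='micro'))

    def svm_experiment(Xtr, ytr, Xte, yte, foo, name):
        svm = GridSearchCV(LinearSVC(),
                           param_grid={'C': np.logspace(-2, 3, 6),
                                       'class_weight': ['balanced', None]},
                           n_jobs=-1)
        svm.fit(Xtr, ytr)
        acc, macrof1, microf1 = evaluation(yte, svm.predict(Xte))
        foo.write(f'{name}: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}\n')

    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 16))
    y = rng.integers(0, 4, size=200)
    svm_experiment(X[:150], y[:150], X[150:], y[150:], sys.stdout, 'svm-toy')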
(third file; path not shown in this view)

@@ -7,7 +7,7 @@ from tqdm import tqdm
 import math
 from sklearn.model_selection import train_test_split
 
-from losses import SupConLoss1View
+from losses import SupConLoss1View, SupConLoss1ViewCrossEntropy
 from model.early_stop import EarlyStop
 from model.layers import FFProjection
 from torch.utils.data import DataLoader
@@ -20,6 +20,7 @@ class AuthorshipAttributionClassifier(nn.Module):
         self.ff = FFProjection(input_size=projector.output_size,
                                hidden_sizes=[],
                                output_size=num_authors).to(device)
+        self.linear_proj = nn.Linear(projector.output_size, 128).to(device)  # to train the kernel alignment
         self.pad_index = pad_index
         self.pad_length = pad_length
         self.device = device
@@ -132,10 +133,11 @@ class AuthorshipAttributionClassifier(nn.Module):
         self.load_state_dict(torch.load(checkpointpath))
         return early_stop.best_score
 
-    def supervised_contrastive_learning(self, X, y, Xval, yval, batch_size, epochs, patience=50, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
+    def supervised_contrastive_learning(self, X, y, Xval, yval, batch_size, epochs, patience=200, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
         early_stop = EarlyStop(patience, lower_is_better=True)
 
         criterion = SupConLoss1View().to(self.device)
+        # criterion = SupConLoss1ViewCrossEntropy().to(self.device)
         optim = torch.optim.Adam(self.parameters(), lr=lr)
 
         tr_data = IndexedDataset(X, y, self.pad_length, self.pad_index, self.device)
@@ -154,6 +156,8 @@ class AuthorshipAttributionClassifier(nn.Module):
             #while True:
                 optim.zero_grad()
                 phi = self.projector(xi)
+                phi = self.linear_proj(phi)
+                phi = F.normalize(phi, p=2, dim=-1)
                 #contrastive_loss = criterion(phi, torch.as_tensor(yi).to(self.device))
                 contrastive_loss, neg_loss, pos_loss = criterion(phi, torch.as_tensor(yi).to(self.device))
                 #contrastive_loss = neg_loss+pos_loss
@@ -191,7 +195,7 @@ class AuthorshipAttributionClassifier(nn.Module):
         self.load_state_dict(torch.load(checkpointpath))
         return early_stop.best_score
 
-    def train_linear_classifier(self, X, y, Xval, yval, batch_size, epochs, patience=10, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
+    def train_linear_classifier(self, X, y, Xval, yval, batch_size, epochs, patience=25, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
         early_stop = EarlyStop(patience)
 
         criterion = torch.nn.CrossEntropyLoss().to(self.device)
@@ -272,6 +276,19 @@ class AuthorshipAttributionClassifier(nn.Module):
                 predictions.append(phi)
         return np.concatenate(predictions)
 
+    def project_kernel(self, x, batch_size=100):
+        self.eval()
+        te_data = IndexedDataset(x, None, self.pad_length, self.pad_index, self.device)
+        predictions = []
+        with torch.no_grad():
+            for xi in te_data.asDataLoader(batch_size, shuffle=False):
+                phi = self.projector(xi)
+                phi = self.linear_proj(phi)
+                phi = F.normalize(phi, p=2, dim=-1)
+                phi = tensor2numpy(phi)
+                predictions.append(phi)
+        return np.concatenate(predictions)
+
     def forward(self, x):
         phi = self.projector(x)
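Note: the new linear_proj head plus L2 normalization means both the contrastive loss and project_kernel operate on unit vectors, so dot products are cosine similarities bounded in [-1, 1] and the 0/1 entries of the alignment mask are reachable targets. A minimal shape-level sketch; projector here is a hypothetical stand-in for the CNN+FF feature extractor Phi:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    projector = nn.Linear(300, 1024)    # stand-in for Phi, output size opt.repr
    linear_proj = nn.Linear(1024, 128)  # the new kernel-alignment head

    x = torch.randn(8, 300)
    phi = F.normalize(linear_proj(projector(x)), p=2, dim=-1)  # unit-norm rows
    sims = phi @ phi.T  # cosine similarity matrix, entries in [-1, 1]
    print(sims.shape, torch.allclose(sims.diagonal(), torch.ones(8)))  # torch.Size([8, 8]) True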