update in the frobenius norm, adding and squrt
This commit is contained in:
parent
38867371fa
commit
9f3567c4f8
|
@ -107,16 +107,6 @@ class SupConLoss1View(nn.Module):
|
|||
self.base_temperature = base_temperature
|
||||
|
||||
def forward(self, features, labels):
|
||||
"""Compute loss for model. If both `labels` and `mask` are None,
|
||||
it degenerates to SimCLR unsupervised loss:
|
||||
https://arxiv.org/pdf/2002.05709.pdf
|
||||
|
||||
Args:
|
||||
features: hidden vector of shape [bsz, ndim].
|
||||
labels: ground truth of shape [bsz].
|
||||
Returns:
|
||||
A loss scalar.
|
||||
"""
|
||||
device = (torch.device('cuda')
|
||||
if features.is_cuda
|
||||
else torch.device('cpu'))
|
||||
|
@ -146,7 +136,8 @@ class SupConLoss1View(nn.Module):
|
|||
#return neg_loss, pos_loss
|
||||
#balanced_loss = pos_loss + neg_loss
|
||||
#return balanced_loss
|
||||
return torch.mean((cross_upper-mask_upper)**2), neg_loss, pos_loss
|
||||
n=len(mask_upper)
|
||||
return (1/n)*torch.sqrt(torch.sum((cross_upper-mask_upper)**2)), neg_loss, pos_loss
|
||||
|
||||
|
||||
def mse(input, target, label):
|
||||
|
|
Loading…
Reference in New Issue