update in the frobenius norm, adding and squrt
This commit is contained in:
parent
38867371fa
commit
9f3567c4f8
|
@ -107,16 +107,6 @@ class SupConLoss1View(nn.Module):
|
||||||
self.base_temperature = base_temperature
|
self.base_temperature = base_temperature
|
||||||
|
|
||||||
def forward(self, features, labels):
|
def forward(self, features, labels):
|
||||||
"""Compute loss for model. If both `labels` and `mask` are None,
|
|
||||||
it degenerates to SimCLR unsupervised loss:
|
|
||||||
https://arxiv.org/pdf/2002.05709.pdf
|
|
||||||
|
|
||||||
Args:
|
|
||||||
features: hidden vector of shape [bsz, ndim].
|
|
||||||
labels: ground truth of shape [bsz].
|
|
||||||
Returns:
|
|
||||||
A loss scalar.
|
|
||||||
"""
|
|
||||||
device = (torch.device('cuda')
|
device = (torch.device('cuda')
|
||||||
if features.is_cuda
|
if features.is_cuda
|
||||||
else torch.device('cpu'))
|
else torch.device('cpu'))
|
||||||
|
@ -146,7 +136,8 @@ class SupConLoss1View(nn.Module):
|
||||||
#return neg_loss, pos_loss
|
#return neg_loss, pos_loss
|
||||||
#balanced_loss = pos_loss + neg_loss
|
#balanced_loss = pos_loss + neg_loss
|
||||||
#return balanced_loss
|
#return balanced_loss
|
||||||
return torch.mean((cross_upper-mask_upper)**2), neg_loss, pos_loss
|
n=len(mask_upper)
|
||||||
|
return (1/n)*torch.sqrt(torch.sum((cross_upper-mask_upper)**2)), neg_loss, pos_loss
|
||||||
|
|
||||||
|
|
||||||
def mse(input, target, label):
|
def mse(input, target, label):
|
||||||
|
|
Loading…
Reference in New Issue