diff --git a/src/losses.py b/src/losses.py index 2c40f22..f067f07 100644 --- a/src/losses.py +++ b/src/losses.py @@ -107,16 +107,6 @@ class SupConLoss1View(nn.Module): self.base_temperature = base_temperature def forward(self, features, labels): - """Compute loss for model. If both `labels` and `mask` are None, - it degenerates to SimCLR unsupervised loss: - https://arxiv.org/pdf/2002.05709.pdf - - Args: - features: hidden vector of shape [bsz, ndim]. - labels: ground truth of shape [bsz]. - Returns: - A loss scalar. - """ device = (torch.device('cuda') if features.is_cuda else torch.device('cpu')) @@ -146,7 +136,8 @@ class SupConLoss1View(nn.Module): #return neg_loss, pos_loss #balanced_loss = pos_loss + neg_loss #return balanced_loss - return torch.mean((cross_upper-mask_upper)**2), neg_loss, pos_loss + n=len(mask_upper) + return (1/n)*torch.sqrt(torch.sum((cross_upper-mask_upper)**2)), neg_loss, pos_loss def mse(input, target, label):