import torch
import torch.nn as nn
import torch.nn.functional as F


class CL_loss(nn.Module):
    """Supervised contrastive loss (InfoNCE-style) over a labeled feature group.

    Given a batch of feature vectors partitioned by ``labels`` into an anchor
    (label 0), positives (label 1), and negatives (label 2), computes

        L = -(1 / |P|) * sum_{p in P} log( exp(sim(a, p) / T)
                                           / sum_{k in P ∪ N} exp(sim(a, k) / T) )

    where ``sim`` is cosine similarity and ``T`` is the temperature.

    The ratio is evaluated in log-space via ``torch.logsumexp`` rather than as
    ``log(exp(.)/sum(exp(.)))`` — the two are algebraically identical, but the
    log-space form cannot overflow to Inf at small temperatures (the failure
    mode the previous implementation's debug prints were chasing).
    """

    def __init__(self, temperature=None):
        """Create the loss module.

        Args:
            temperature: Softmax temperature ``T``. If ``None`` (the default,
                preserving the original behavior), it is read from the project
                ``config`` module. The import is deferred so the class works
                without a ``config`` module when ``temperature`` is given
                explicitly.
        """
        super(CL_loss, self).__init__()
        if temperature is None:
            import config  # lazy: only needed for the legacy default

            temperature = config.temperature
        self.temperature = temperature

    def forward(self, feature_vectors, labels):
        """Compute the contrastive loss for one anchor group.

        Args:
            feature_vectors: ``(N, D)`` tensor of row embeddings; rows are
                L2-normalized internally.
            labels: ``(N,)`` tensor with 0 = anchor, 1 = positive,
                2 = negative.
                NOTE(review): the broadcast in ``F.cosine_similarity`` assumes
                exactly one anchor row — confirm callers guarantee a single
                label-0 entry.

        Returns:
            Scalar tensor holding the loss.

        Raises:
            ValueError: if no row is labeled as a positive (previously this
                surfaced as an opaque ``ZeroDivisionError``).
        """
        # Normalize each row to unit Euclidean norm.
        feats = F.normalize(feature_vectors, p=2, dim=1)

        anchor = feats[(labels == 0).nonzero().squeeze(dim=1)]
        positives = feats[(labels == 1).nonzero().squeeze(dim=1)]
        negatives = feats[(labels == 2).nonzero().squeeze(dim=1)]

        pos_count = positives.shape[0]
        if pos_count == 0:
            raise ValueError(
                "CL_loss requires at least one positive (label == 1) sample"
            )

        # Candidates are positives followed by negatives, so the first
        # pos_count entries of all_logits are exactly the positive scores.
        candidates = torch.cat([positives, negatives])
        all_logits = (
            F.cosine_similarity(anchor, candidates, dim=1) / self.temperature
        )
        pos_logits = all_logits[:pos_count]

        # log(exp(p) / sum_k exp(k)) == p - logsumexp(k): overflow-safe.
        log_denominator = torch.logsumexp(all_logits, dim=0)
        return -(pos_logits - log_denominator).sum() / pos_count