import torch
import torch.nn as nn
import torch.nn.functional as F

import config

# Label values that tag each row of ``feature_vectors``.
_ANCHOR_LABEL = 0
_POSITIVE_LABEL = 1
_NEGATIVE_LABEL = 2


def _rows_with_label(labels, value):
    """Return the indices of the entries of ``labels`` equal to ``value``."""
    return (labels == value).nonzero().squeeze(dim=1)


def _positive_log_probs(feature_vectors, labels, temperature, scores=None):
    """Compute the log-probability of every anchor/positive pair.

    Rows of ``feature_vectors`` are selected by ``labels`` (0 = anchor,
    1 = positive, 2 = negative). For each anchor/positive pair the result is

        logit(anchor, positive) - log(sum(exp(candidate logits)))

    where the candidates are the positives followed by the negatives and a
    logit is a dot product divided by ``temperature``. The sum runs over
    every anchor/candidate logit.

    Args:
        feature_vectors: Embeddings, indexed by sample along dim 0
            (assumes shape ``(num_samples, dim)`` -- TODO confirm).
        labels: Per-sample labels (assumes a 1-D tensor aligned with
            ``feature_vectors`` -- TODO confirm).
        temperature: Divisor applied to the dot products.
        scores: Optional per-sample scores aligned with ``labels``. When
            given, each candidate logit in the denominator is multiplied by
            a weight: the raw score for positives, and ``1 +`` the
            L2-normalized score for negatives. The numerator logits are
            never weighted.

    Returns:
        A ``(log_probs, num_positives)`` tuple, where ``log_probs`` holds
        one entry per anchor/positive pair.
    """
    # NOTE(review): dim=0 scales every feature column to unit norm across
    # the samples, so the dot products below are not cosine similarities;
    # per-embedding normalization would be dim=1. Kept as in the original
    # to avoid changing training numerics -- confirm the intent.
    normalized = F.normalize(feature_vectors, p=2, dim=0)

    anchor_idx = _rows_with_label(labels, _ANCHOR_LABEL)
    positive_idx = _rows_with_label(labels, _POSITIVE_LABEL)
    negative_idx = _rows_with_label(labels, _NEGATIVE_LABEL)

    anchor = normalized[anchor_idx]
    positives = normalized[positive_idx]
    negatives = normalized[negative_idx]
    candidates = torch.cat([positives, negatives])

    positive_logits = torch.div(
        torch.matmul(anchor, torch.transpose(positives, 0, 1)), temperature
    )
    candidate_logits = torch.div(
        torch.matmul(anchor, torch.transpose(candidates, 0, 1)), temperature
    )

    if scores is not None:
        # Pick scores by label, exactly like the features above, so every
        # weight lines up with its column of ``candidate_logits`` whatever
        # the order of the samples in the batch.
        positive_weights = scores[positive_idx].float()
        negative_weights = (
            F.normalize(scores[negative_idx].float(), p=2, dim=0) + 1
        )
        weights = torch.cat([positive_weights, negative_weights])
        candidate_logits = weights * candidate_logits

    # log(sum(exp(x))) computed stably: exponentiating first overflows to
    # inf (and then nan after the division) for large logits.
    log_denominator = torch.logsumexp(candidate_logits.reshape(-1), dim=0)

    # log(exp(a) / d) == a - log(d)
    return positive_logits - log_denominator, positives.shape[0]


class ContrastiveLoss_simcse(nn.Module):
    """SimCSE loss.

    The temperature is read from ``config.temperature`` at construction.
    """

    def __init__(self):
        super().__init__()
        self.temperature = config.temperature

    def forward(self, feature_vectors, labels):
        """Return the negative log-probability of each anchor/positive pair.

        Args:
            feature_vectors: Embeddings, indexed by sample along dim 0.
            labels: Per-sample labels: 0 = anchor, 1 = positive,
                2 = negative.

        Returns:
            A tensor with one loss entry per anchor/positive pair (not
            reduced to a scalar).
        """
        log_probs, _ = _positive_log_probs(
            feature_vectors, labels, self.temperature
        )
        return -log_probs


class ContrastiveLoss_simcse_w(nn.Module):
    """SimCSE loss with weighting.

    The temperature is read from ``config.temperature`` at construction.
    """

    def __init__(self):
        super().__init__()
        self.temperature = config.temperature

    def forward(self, feature_vectors, labels, scores):
        """Return the score-weighted loss of each anchor/positive pair.

        Args:
            feature_vectors: Embeddings, indexed by sample along dim 0.
            labels: Per-sample labels: 0 = anchor, 1 = positive,
                2 = negative.
            scores: Per-sample scores aligned with ``labels``; they weight
                the denominator logits only.

        Returns:
            A tensor with one loss entry per anchor/positive pair (not
            reduced to a scalar).
        """
        log_probs, _ = _positive_log_probs(
            feature_vectors, labels, self.temperature, scores
        )
        return -log_probs


class ContrastiveLoss_samp(nn.Module):
    """Supervised contrastive loss without weighting.

    The temperature is read from ``config.temperature`` at construction.
    """

    def __init__(self):
        super().__init__()
        self.temperature = config.temperature

    def forward(self, feature_vectors, labels):
        """Return the loss summed over pairs, divided by the positive count.

        Args:
            feature_vectors: Embeddings, indexed by sample along dim 0.
            labels: Per-sample labels: 0 = anchor, 1 = positive,
                2 = negative.

        Returns:
            A scalar tensor.

        Raises:
            ZeroDivisionError: If no sample is labelled as positive.
        """
        log_probs, pos_cardinal = _positive_log_probs(
            feature_vectors, labels, self.temperature
        )
        scale = -1 / pos_cardinal
        return scale * torch.sum(log_probs)


class ContrastiveLoss_samp_w(nn.Module):
    """Supervised contrastive loss with weighting.

    The temperature is read from ``config.temperature`` at construction.
    """

    def __init__(self):
        super().__init__()
        self.temperature = config.temperature

    def forward(self, feature_vectors, labels, scores):
        """Return the weighted loss summed over pairs, per positive.

        Scores are selected by label (as the features are) rather than by
        position, so the result does not depend on where the anchor sits
        in the batch.

        Args:
            feature_vectors: Embeddings, indexed by sample along dim 0.
            labels: Per-sample labels: 0 = anchor, 1 = positive,
                2 = negative.
            scores: Per-sample scores aligned with ``labels``; they weight
                the denominator logits only.

        Returns:
            A scalar tensor.

        Raises:
            ZeroDivisionError: If no sample is labelled as positive.
        """
        log_probs, pos_cardinal = _positive_log_probs(
            feature_vectors, labels, self.temperature, scores
        )
        scale = -1 / pos_cardinal
        return scale * torch.sum(log_probs)