# CHOPT / loss.py
# Author: sxtforreal. Commit 7f4f2d3 "Create loss.py" (verified); 7.09 kB.
# Hugging Face file-view metadata: raw / history / blame; malware scan: no virus.
import torch
import torch.nn as nn
import torch.nn.functional as F
import config
class ContrastiveLoss_simcse(nn.Module):
    """SimCSE (InfoNCE) contrastive loss over one labelled group of embeddings.

    Each row of ``feature_vectors`` is an embedding and ``labels`` marks it as
    an anchor (0), a positive (1) or a negative (2). Rows are L2-normalised so
    their dot products are cosine similarities; these are divided by the
    temperature and scored with a softmax whose denominator pools every
    anchor-vs-(positive or negative) pair.
    """

    def __init__(self, temperature=None):
        """Create the loss.

        Args:
            temperature: softmax temperature. When omitted, falls back to
                ``config.temperature`` (the original config-driven behaviour).
        """
        super().__init__()
        self.temperature = (
            config.temperature if temperature is None else temperature
        )

    def forward(self, feature_vectors, labels):
        """Compute ``-log(exp(s_ap / t) / sum_{a,j} exp(s_aj / t))``.

        Args:
            feature_vectors: 2-D tensor with one embedding per row.
            labels: 1-D tensor aligned with the rows of ``feature_vectors``;
                0 = anchor, 1 = positive, 2 = negative.

        Returns:
            Tensor of shape ``(num_anchors, num_positives)`` holding the loss
            of every anchor/positive pair (``(1, 1)`` for a single anchor and
            a single positive).
        """
        # Unit-normalise every embedding (row) so matmul yields cosine
        # similarity. Normalising along dim=0 would instead rescale each
        # feature column across the batch, making a pair's similarity depend
        # on the other samples in the batch.
        normalized_features = F.normalize(feature_vectors, p=2, dim=1)

        anchor_indices = (labels == 0).nonzero().squeeze(dim=1)
        positive_indices = (labels == 1).nonzero().squeeze(dim=1)
        negative_indices = (labels == 2).nonzero().squeeze(dim=1)

        anchor = normalized_features[anchor_indices]
        positives = normalized_features[positive_indices]
        negatives = normalized_features[negative_indices]
        pos_and_neg = torch.cat([positives, negatives])

        all_logits = torch.matmul(anchor, pos_and_neg.T) / self.temperature
        pos_logits = torch.matmul(anchor, positives.T) / self.temperature

        # -log(exp(x) / sum(exp(y))) == logsumexp(y) - x. Staying in log space
        # avoids exp() overflow (inf / inf -> nan) at small temperatures.
        log_denominator = torch.logsumexp(all_logits.flatten(), dim=0)
        return log_denominator - pos_logits
class ContrastiveLoss_simcse_w(nn.Module):
    """SimCSE (InfoNCE) contrastive loss with per-sample denominator weights.

    Same setup as the unweighted SimCSE loss (rows are embeddings; labels
    0 = anchor, 1 = positive, 2 = negative), but every anchor-vs-candidate
    logit in the denominator is multiplied by a weight derived from
    ``scores``: raw scores for positives, and ``1 + l2-normalised scores`` for
    negatives. The numerator stays unweighted.
    """

    def __init__(self, temperature=None):
        """Create the loss.

        Args:
            temperature: softmax temperature. When omitted, falls back to
                ``config.temperature`` (the original config-driven behaviour).
        """
        super().__init__()
        self.temperature = (
            config.temperature if temperature is None else temperature
        )

    def forward(self, feature_vectors, labels, scores):
        """Compute the weighted SimCSE loss.

        Args:
            feature_vectors: 2-D tensor with one embedding per row.
            labels: 1-D tensor aligned with the rows of ``feature_vectors``;
                0 = anchor, 1 = positive, 2 = negative.
            scores: tensor indexed with the same row indices as ``labels``;
                supplies the denominator weights.

        Returns:
            Tensor of shape ``(num_anchors, num_positives)`` holding the loss
            of every anchor/positive pair.
        """
        # Unit-normalise every embedding (row) so matmul yields cosine
        # similarity (dim=0 would normalise feature columns across the batch).
        normalized_features = F.normalize(feature_vectors, p=2, dim=1)

        anchor_indices = (labels == 0).nonzero().squeeze(dim=1)
        positive_indices = (labels == 1).nonzero().squeeze(dim=1)
        negative_indices = (labels == 2).nonzero().squeeze(dim=1)

        # Denominator weights, ordered like ``pos_and_neg`` below.
        pos_scores = scores[positive_indices].float()
        neg_scores = F.normalize(scores[negative_indices].float(), p=2, dim=0) + 1
        weights = torch.cat([pos_scores, neg_scores])

        anchor = normalized_features[anchor_indices]
        positives = normalized_features[positive_indices]
        negatives = normalized_features[negative_indices]
        pos_and_neg = torch.cat([positives, negatives])

        weighted_logits = weights * (
            torch.matmul(anchor, pos_and_neg.T) / self.temperature
        )
        pos_logits = torch.matmul(anchor, positives.T) / self.temperature

        # -log(exp(x) / sum(exp(y))) == logsumexp(y) - x. Large weights times
        # small temperatures overflow exp() in float32, so stay in log space.
        log_denominator = torch.logsumexp(weighted_logits.flatten(), dim=0)
        return log_denominator - pos_logits
class ContrastiveLoss_samp(nn.Module):
    """Supervised contrastive loss (multiple positives) without weighting.

    Rows of ``feature_vectors`` are embeddings labelled 0 = anchor,
    1 = positive, 2 = negative. The loss averages, over the positives,
    ``-log softmax`` of each anchor/positive logit against a denominator that
    pools every anchor-vs-(positive or negative) pair.
    """

    def __init__(self, temperature=None):
        """Create the loss.

        Args:
            temperature: softmax temperature. When omitted, falls back to
                ``config.temperature`` (the original config-driven behaviour).
        """
        super().__init__()
        self.temperature = (
            config.temperature if temperature is None else temperature
        )

    def forward(self, feature_vectors, labels):
        """Compute ``-(1/P) * sum_{a,p} log(exp(s_ap / t) / sum_{a,j} exp(s_aj / t))``.

        Args:
            feature_vectors: 2-D tensor with one embedding per row.
            labels: 1-D tensor aligned with the rows of ``feature_vectors``;
                0 = anchor, 1 = positive, 2 = negative.

        Returns:
            Scalar (0-d) loss tensor.

        Raises:
            ValueError: if ``labels`` contains no positive (label 1), since the
                loss is normalised by the number of positives.
        """
        anchor_indices = (labels == 0).nonzero().squeeze(dim=1)
        positive_indices = (labels == 1).nonzero().squeeze(dim=1)
        negative_indices = (labels == 2).nonzero().squeeze(dim=1)

        pos_cardinal = positive_indices.numel()
        if pos_cardinal == 0:
            raise ValueError(
                "ContrastiveLoss_samp needs at least one positive (label 1)"
            )

        # Unit-normalise every embedding (row) so matmul yields cosine
        # similarity (dim=0 would normalise feature columns across the batch).
        normalized_features = F.normalize(feature_vectors, p=2, dim=1)
        anchor = normalized_features[anchor_indices]
        positives = normalized_features[positive_indices]
        negatives = normalized_features[negative_indices]
        pos_and_neg = torch.cat([positives, negatives])

        all_logits = torch.matmul(anchor, pos_and_neg.T) / self.temperature
        pos_logits = torch.matmul(anchor, positives.T) / self.temperature

        # log(exp(x) / sum(exp(y))) == x - logsumexp(y): avoids exp() overflow.
        log_denominator = torch.logsumexp(all_logits.flatten(), dim=0)
        sum_log_ent = torch.sum(pos_logits - log_denominator)
        return -sum_log_ent / pos_cardinal
class ContrastiveLoss_samp_w(nn.Module):
    """Supervised contrastive loss (multiple positives) with weighting.

    Like the unweighted supervised loss, but every anchor-vs-candidate logit
    in the denominator is multiplied by a weight derived from ``scores``:
    raw scores for positives, and ``1 + l2-normalised scores`` for negatives.
    """

    def __init__(self, temperature=None):
        """Create the loss.

        Args:
            temperature: softmax temperature. When omitted, falls back to
                ``config.temperature`` (the original config-driven behaviour).
        """
        super().__init__()
        self.temperature = (
            config.temperature if temperature is None else temperature
        )

    def forward(self, feature_vectors, labels, scores):
        """Compute the weighted supervised contrastive loss.

        Args:
            feature_vectors: 2-D tensor with one embedding per row.
            labels: 1-D tensor aligned with the rows of ``feature_vectors``;
                0 = anchor, 1 = positive, 2 = negative.
            scores: tensor indexed with the same row indices as ``labels``;
                supplies the denominator weights.

        Returns:
            Scalar (0-d) loss tensor.

        Raises:
            ValueError: if ``labels`` contains no positive (label 1).
        """
        anchor_indices = (labels == 0).nonzero().squeeze(dim=1)
        positive_indices = (labels == 1).nonzero().squeeze(dim=1)
        negative_indices = (labels == 2).nonzero().squeeze(dim=1)

        pos_cardinal = positive_indices.numel()
        if pos_cardinal == 0:
            raise ValueError(
                "ContrastiveLoss_samp_w needs at least one positive (label 1)"
            )

        # Select score weights by label rather than by position: positional
        # slicing (scores[:P], scores[P+1:]) only works when the anchor sits
        # exactly at index P, and otherwise mixes the anchor's score into the
        # positive weights.
        pos_scores = scores[positive_indices].float()
        neg_scores = F.normalize(scores[negative_indices].float(), p=2, dim=0) + 1
        weights = torch.cat([pos_scores, neg_scores])

        # Unit-normalise every embedding (row) so matmul yields cosine
        # similarity (dim=0 would normalise feature columns across the batch).
        normalized_features = F.normalize(feature_vectors, p=2, dim=1)
        anchor = normalized_features[anchor_indices]
        positives = normalized_features[positive_indices]
        negatives = normalized_features[negative_indices]
        pos_and_neg = torch.cat([positives, negatives])

        weighted_logits = weights * (
            torch.matmul(anchor, pos_and_neg.T) / self.temperature
        )
        pos_logits = torch.matmul(anchor, positives.T) / self.temperature

        # log(exp(x) / sum(exp(y))) == x - logsumexp(y): avoids exp() overflow
        # when large weights meet small temperatures.
        log_denominator = torch.logsumexp(weighted_logits.flatten(), dim=0)
        sum_log_ent = torch.sum(pos_logits - log_denominator)
        return -sum_log_ent / pos_cardinal