CHOPT / loss.py
sxtforreal's picture
Upload 3 files
d09e211 verified
raw
history blame
No virus
2.28 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import config
class CL_loss(nn.Module):
    """Supervised contrastive loss without weighting.

    Rows of ``feature_vectors`` are grouped by ``labels``: 0 marks the anchor,
    1 marks positives and 2 marks negatives. With ``s`` the cosine similarity
    and ``T`` the temperature, the loss is::

        -1/|P| * sum_{p in P} log( exp(s(a, p) / T)
                                   / sum_{k in P and N} exp(s(a, k) / T) )
    """

    def __init__(self, temperature=None):
        """Create the loss.

        Args:
            temperature: Softmax temperature. When ``None`` (the default),
                ``config.temperature`` is used, as before.
        """
        super().__init__()
        # config is only consulted when no explicit temperature is given.
        self.temperature = config.temperature if temperature is None else temperature

    def forward(self, feature_vectors, labels):
        """Compute the contrastive loss of the anchor against positives/negatives.

        Args:
            feature_vectors: Tensor with one embedding per row; each row is
                L2-normalized here before comparison.
            labels: Per-row labels (0 = anchor, 1 = positive, 2 = negative);
                assumed 1-D with one entry per row of ``feature_vectors``.

        Returns:
            Scalar tensor: minus the mean, over positives, of the log-softmax
            probability of each positive among all positives and negatives.

        Raises:
            ValueError: if no row is labeled as anchor or as positive.
        """
        normalized_features = F.normalize(feature_vectors, p=2, dim=1)

        anchor = normalized_features[labels == 0]
        positives = normalized_features[labels == 1]
        negatives = normalized_features[labels == 2]

        if anchor.shape[0] == 0:
            raise ValueError("CL_loss requires an anchor row (label 0).")
        pos_cardinal = positives.shape[0]
        if pos_cardinal == 0:
            # Previously surfaced as a bare ZeroDivisionError from -1 / 0.
            raise ValueError("CL_loss requires at least one positive row (label 1).")

        pos_and_neg = torch.cat([positives, negatives])

        # NOTE(review): the anchor is broadcast against every candidate row, which
        # presumably expects exactly one anchor per call -- confirm with callers.
        all_logits = F.cosine_similarity(anchor, pos_and_neg, dim=1) / self.temperature
        pos_logits = F.cosine_similarity(anchor, positives, dim=1) / self.temperature

        # log(exp(x) / sum(exp(all))) == x - logsumexp(all). logsumexp subtracts the
        # max before exponentiating, so small temperatures no longer overflow exp()
        # to inf, which used to turn the loss into nan.
        log_prob_pos = pos_logits - torch.logsumexp(all_logits, dim=0)
        return -log_prob_pos.sum() / pos_cardinal