tiny_clip / src /loss.py
sachin's picture
Initial training code
6d1b6c6
raw
history blame
No virus
3.58 kB
import torch
from torch import nn
import torch.nn.functional as F
def get_similarity_matrix(
    image_features: torch.Tensor, text_features: torch.Tensor
) -> torch.Tensor:
    """Return the matrix of inner products between image and text features.

    Entry ``[i, j]`` is the dot product of ``image_features[i]`` with
    ``text_features[j]``. No normalization is applied here; inputs are
    presumably 2-D ``(batch, dim)`` and already L2-normalized by the caller
    if cosine similarity is wanted -- TODO confirm at call sites.
    """
    text_transposed = text_features.T
    return torch.matmul(image_features, text_transposed)
def contrastive_loss(logits, dim):
    """Cross-entropy where the i-th item's correct match is the i-th entry.

    Applies ``log_softmax`` along ``dim``, keeps the diagonal (the
    log-probability assigned to each matched pair) and returns the negated
    mean of those values.

    Args:
        logits: square matrix of pairwise scores.
        dim: axis over which the softmax normalizes.
    """
    log_probs = F.log_softmax(logits, dim=dim)
    matched_log_probs = log_probs.diagonal()
    return -matched_log_probs.mean()
def contrastive_sigmoid_loss(logits):
    """Pairwise sigmoid loss over a square matrix of logits.

    Every entry is scored as an independent binary classification: diagonal
    entries are matched pairs (target 1) and off-diagonal entries are
    unmatched pairs (target 0). Returns the mean binary cross-entropy over
    all entries.

    Args:
        logits: square matrix of pairwise logits; ``len(logits)`` rows.

    Returns:
        Scalar tensor with the mean loss, in the same dtype as ``logits``.
    """
    # Create the identity targets on the logits' device and in their dtype.
    # A bare torch.eye() lives on the CPU in the default dtype, which breaks
    # for CUDA tensors and for float64 / half-precision logits.
    targets = torch.eye(len(logits), device=logits.device, dtype=logits.dtype)
    return F.binary_cross_entropy_with_logits(logits, targets, reduction="mean")
class CLIPLoss(nn.Module):
    """Symmetric softmax contrastive (CLIP-style) loss with a learnable temperature.

    The temperature is ``sigmoid(logit_temperature)``, so it always lies in
    (0, 1). The similarity matrix is divided by it before the softmax losses.
    """

    def __init__(self, logit_temperature: float = -1.0):
        super().__init__()
        # Stored pre-sigmoid so the optimizer works on an unconstrained value.
        self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))

    def forward(self, similarity_matrix: torch.Tensor, *args):
        """Average the column-wise (dim=0) and row-wise (dim=1) contrastive losses.

        Any extra positional arguments (such as the image/text features the
        cyclic variants in this module consume) are accepted and ignored.
        """
        temperature = self.logit_temperature.sigmoid()
        caption_term, image_term = (
            contrastive_loss(similarity_matrix / temperature, dim=axis)
            for axis in (0, 1)
        )
        return 0.5 * (caption_term + image_term)
class CyCLIPLoss(nn.Module):
    """CLIP contrastive loss plus two cyclic-consistency regularizers.

    Total loss = mean of the two directional softmax contrastive losses
    + ``lambda_1`` * symmetry term + ``lambda_2`` * modality-difference term.

    Args:
        logit_temperature: initial pre-sigmoid temperature. The effective
            temperature is ``sigmoid(logit_temperature)``, in (0, 1).
        lambda_1: weight of the symmetry term (keyword-only, default 1.0).
        lambda_2: weight of the modality-difference term (keyword-only,
            default 1.0).
    """

    def __init__(
        self,
        logit_temperature: float = -1.0,
        *,
        lambda_1: float = 1.0,
        lambda_2: float = 1.0,
    ):
        super().__init__()
        # Stored pre-sigmoid so the optimizer works on an unconstrained value.
        self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
        self.lambda_1: float = lambda_1
        self.lambda_2: float = lambda_2

    def forward(
        self,
        similarity_matrix: torch.Tensor,
        image_features: torch.Tensor,
        text_features: torch.Tensor,
    ):
        """Return the weighted sum of the contrastive and cyclic terms."""
        temperature = self.logit_temperature.sigmoid()
        caption_loss = contrastive_loss(similarity_matrix / temperature, dim=0)
        image_loss = contrastive_loss(similarity_matrix / temperature, dim=1)
        # Penalize asymmetry of the cross-modal matrix: S[i, j] should equal
        # S[j, i]. Computed on the raw (not temperature-scaled) similarities.
        symmetry_loss = F.mse_loss(similarity_matrix, similarity_matrix.T)
        # Penalize differences between the image-image and text-text
        # inner-product matrices.
        modality_difference_loss = F.mse_loss(
            image_features @ image_features.T, text_features @ text_features.T
        )
        return (
            0.5 * (caption_loss + image_loss)
            + self.lambda_1 * symmetry_loss
            + self.lambda_2 * modality_difference_loss
        )
class SigLIPLoss(nn.Module):
    """Pairwise sigmoid contrastive loss with a learnable temperature.

    The effective temperature is ``sigmoid(logit_temperature)``, in (0, 1).
    """

    def __init__(self, logit_temperature: float = -1.0):
        super().__init__()
        # Stored pre-sigmoid so the optimizer works on an unconstrained value.
        self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))

    def forward(self, similarity_matrix: torch.Tensor, *args):
        """Scale the similarities by the temperature and apply the sigmoid loss.

        Extra positional arguments are accepted and ignored.
        """
        logits = similarity_matrix / torch.sigmoid(self.logit_temperature)
        return contrastive_sigmoid_loss(logits)
class CySigLIPLoss(nn.Module):
    """Sigmoid contrastive loss plus two cyclic-consistency regularizers.

    Total loss = pairwise sigmoid loss + ``lambda_1`` * symmetry term
    + ``lambda_2`` * modality-difference term.

    Args:
        logit_temperature: initial pre-sigmoid temperature. The effective
            temperature is ``sigmoid(logit_temperature)``, in (0, 1).
        lambda_1: weight of the symmetry term (keyword-only, default 1.0).
        lambda_2: weight of the modality-difference term (keyword-only,
            default 1.0).
    """

    def __init__(
        self,
        logit_temperature: float = -1.0,
        *,
        lambda_1: float = 1.0,
        lambda_2: float = 1.0,
    ):
        super().__init__()
        # Stored pre-sigmoid so the optimizer works on an unconstrained value.
        self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
        self.lambda_1: float = lambda_1
        self.lambda_2: float = lambda_2

    def forward(
        self,
        similarity_matrix: torch.Tensor,
        image_features: torch.Tensor,
        text_features: torch.Tensor,
    ):
        """Return the weighted sum of the sigmoid and cyclic terms."""
        temperature = self.logit_temperature.sigmoid()
        loss = contrastive_sigmoid_loss(similarity_matrix / temperature)
        # Penalize asymmetry of the cross-modal matrix: S[i, j] should equal
        # S[j, i]. Computed on the raw (not temperature-scaled) similarities.
        symmetry_loss = F.mse_loss(similarity_matrix, similarity_matrix.T)
        # Penalize differences between the image-image and text-text
        # inner-product matrices.
        modality_difference_loss = F.mse_loss(
            image_features @ image_features.T, text_features @ text_features.T
        )
        return loss + self.lambda_1 * symmetry_loss + self.lambda_2 * modality_difference_loss
def get_loss(loss_type: str) -> nn.Module:
    """Construct the loss module registered under ``loss_type``.

    Valid names: ``"clip"``, ``"cyclip"``, ``"sigmoid"``, ``"cyclic_sigmoid"``.
    Each call returns a freshly constructed module with default settings.

    Raises:
        ValueError: if ``loss_type`` is not one of the registered names.
    """
    # Map names to classes rather than instances so only the requested loss
    # is constructed (each one allocates its own temperature Parameter).
    loss_classes = {
        "clip": CLIPLoss,
        "cyclip": CyCLIPLoss,
        "sigmoid": SigLIPLoss,
        "cyclic_sigmoid": CySigLIPLoss,
    }
    try:
        loss_cls = loss_classes[loss_type]
    except KeyError:
        raise ValueError(
            f"Invalid loss type: {loss_type!r} "
            f"(expected one of {sorted(loss_classes)})"
        ) from None
    return loss_cls()