|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
|
|
|
|
def get_similarity_matrix(
    image_features: torch.Tensor, text_features: torch.Tensor
) -> torch.Tensor:
    """Return the dot-product similarities between image and text features.

    For 2-D inputs, entry ``[i, j]`` is the dot product of image row ``i``
    with text row ``j``. No normalisation is applied here; the result is a
    cosine similarity only if the caller has already L2-normalised both
    feature sets (NOTE(review): not enforced in this module).
    """
    text_by_column = text_features.T
    return torch.matmul(image_features, text_by_column)
|
|
|
|
|
def contrastive_loss(logits, dim):
    """Cross-entropy of the matched pairs, taken along ``dim``.

    ``logits`` is log-softmax-normalised along ``dim`` and the diagonal
    entries (the positive pairs) are averaged and negated. For a square
    matrix this equals ``F.cross_entropy`` with target ``i`` for row/column
    ``i``.
    """
    log_probs = F.log_softmax(logits, dim=dim)
    positive_log_probs = torch.diag(log_probs)
    return positive_log_probs.mean().neg()
|
|
|
|
|
def contrastive_sigmoid_loss(logits):
    """Pairwise sigmoid loss: diagonal entries are positives, the rest negatives.

    Every entry of ``logits`` is treated as an independent binary
    classification. The target is 1 on the main diagonal (the matched pairs)
    and 0 everywhere else. With ``reduction="mean"`` the loss is averaged
    over all ``rows * cols`` entries, so negatives dominate for large
    batches. NOTE(review): the SigLIP paper normalises by batch size instead;
    confirm the mean reduction is intended.

    Args:
        logits: 2-D tensor of scaled similarities (rows need not equal cols).

    Returns:
        Scalar loss tensor.
    """
    # Build the identity targets on the same device and in the same dtype as
    # ``logits``; a default ``torch.eye`` lives on CPU in float32 and breaks
    # GPU / half- or double-precision training.
    targets = torch.eye(*logits.shape, device=logits.device, dtype=logits.dtype)
    return F.binary_cross_entropy_with_logits(logits, targets, reduction="mean")
|
|
|
|
|
class CLIPLoss(nn.Module):
    """Symmetric softmax contrastive (CLIP-style) loss with a learnable temperature.

    The effective temperature is ``sigmoid(logit_temperature)``, which lies in
    (0, 1). Dividing the similarity matrix by it therefore always scales the
    logits up.
    """

    def __init__(self, logit_temperature: float = -1.0):
        super().__init__()
        # Unconstrained learnable value; squashed through a sigmoid in forward.
        initial_value = torch.tensor(logit_temperature)
        self.logit_temperature = nn.Parameter(initial_value)

    def forward(self, similarity_matrix: torch.Tensor, *args):
        """Average the column-wise (dim=0) and row-wise (dim=1) contrastive losses.

        Any extra positional arguments are accepted and ignored, so this module
        can be called with the same arguments as the cyclic variants.
        """
        temperature = self.logit_temperature.sigmoid()
        caption_loss, image_loss = (
            contrastive_loss(similarity_matrix / temperature, dim=axis)
            for axis in (0, 1)
        )
        return 0.5 * (caption_loss + image_loss)
|
|
|
|
|
class CyCLIPLoss(nn.Module):
    """CLIP-style contrastive loss plus two cyclic-consistency regularisers.

    The total loss is::

        0.5 * (caption_loss + image_loss)
        + lambda_1 * symmetry_loss
        + lambda_2 * modality_difference_loss

    ``symmetry_loss`` is the MSE between the raw (unscaled) similarity
    matrix and its transpose. ``modality_difference_loss`` is the MSE
    between the image-image and text-text Gram matrices.

    Args:
        logit_temperature: Initial value of the learnable parameter whose
            sigmoid is used as the temperature.
        lambda_1: Weight of the cross-modal symmetry term (default 1.0).
        lambda_2: Weight of the in-modal consistency term (default 1.0).
    """

    def __init__(
        self,
        logit_temperature: float = -1.0,
        *,
        lambda_1: float = 1.0,
        lambda_2: float = 1.0,
    ):
        super().__init__()
        self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
        # Regulariser weights are plain floats (not learned).
        self.lambda_1: float = lambda_1
        self.lambda_2: float = lambda_2

    def forward(
        self,
        similarity_matrix: torch.Tensor,
        image_features: torch.Tensor,
        text_features: torch.Tensor,
    ):
        """Compute the combined contrastive + cyclic loss.

        Args:
            similarity_matrix: Image-by-text similarity scores.
            image_features: Image embeddings; presumably (batch, dim) — TODO confirm.
            text_features: Text embeddings; presumably (batch, dim) — TODO confirm.

        Returns:
            Scalar loss tensor.
        """
        temperature = self.logit_temperature.sigmoid()
        # Scale once and reuse for both directions.
        logits = similarity_matrix / temperature
        caption_loss = contrastive_loss(logits, dim=0)
        image_loss = contrastive_loss(logits, dim=1)

        # Regularisers act on the unscaled similarities / features.
        symmetry_loss = F.mse_loss(similarity_matrix, similarity_matrix.T)
        modality_difference_loss = F.mse_loss(
            image_features @ image_features.T, text_features @ text_features.T
        )

        return (
            0.5 * (caption_loss + image_loss)
            + self.lambda_1 * symmetry_loss
            + self.lambda_2 * modality_difference_loss
        )
|
|
|
|
|
class SigLIPLoss(nn.Module):
    """Pairwise sigmoid contrastive loss with a learnable temperature.

    The similarity matrix is divided by ``sigmoid(logit_temperature)`` and
    passed to ``contrastive_sigmoid_loss``. NOTE(review): unlike the
    published SigLIP formulation, no learnable bias term is added; confirm
    this is intended.
    """

    def __init__(self, logit_temperature: float = -1.0):
        super().__init__()
        starting_value = torch.tensor(logit_temperature)
        self.logit_temperature = nn.Parameter(starting_value)

    def forward(self, similarity_matrix: torch.Tensor, *args):
        """Return the sigmoid loss; extra positional arguments are ignored."""
        scaled_logits = similarity_matrix / torch.sigmoid(self.logit_temperature)
        return contrastive_sigmoid_loss(scaled_logits)
|
|
|
|
|
class CySigLIPLoss(nn.Module):
    """Pairwise sigmoid contrastive loss plus two cyclic-consistency regularisers.

    The total loss is::

        sigmoid_loss
        + lambda_1 * symmetry_loss
        + lambda_2 * modality_difference_loss

    ``symmetry_loss`` is the MSE between the raw (unscaled) similarity
    matrix and its transpose. ``modality_difference_loss`` is the MSE
    between the image-image and text-text Gram matrices.

    Args:
        logit_temperature: Initial value of the learnable parameter whose
            sigmoid is used as the temperature.
        lambda_1: Weight of the cross-modal symmetry term (default 1.0).
        lambda_2: Weight of the in-modal consistency term (default 1.0).
    """

    def __init__(
        self,
        logit_temperature: float = -1.0,
        *,
        lambda_1: float = 1.0,
        lambda_2: float = 1.0,
    ):
        super().__init__()
        self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
        # Regulariser weights are plain floats (not learned).
        self.lambda_1: float = lambda_1
        self.lambda_2: float = lambda_2

    def forward(
        self,
        similarity_matrix: torch.Tensor,
        image_features: torch.Tensor,
        text_features: torch.Tensor,
    ):
        """Compute the combined sigmoid + cyclic loss.

        Args:
            similarity_matrix: Image-by-text similarity scores.
            image_features: Image embeddings; presumably (batch, dim) — TODO confirm.
            text_features: Text embeddings; presumably (batch, dim) — TODO confirm.

        Returns:
            Scalar loss tensor.
        """
        temperature = self.logit_temperature.sigmoid()
        loss = contrastive_sigmoid_loss(similarity_matrix / temperature)

        # Regularisers act on the unscaled similarities / features.
        symmetry_loss = F.mse_loss(similarity_matrix, similarity_matrix.T)
        modality_difference_loss = F.mse_loss(
            image_features @ image_features.T, text_features @ text_features.T
        )

        return (
            loss
            + self.lambda_1 * symmetry_loss
            + self.lambda_2 * modality_difference_loss
        )
|
|
|
|
|
def get_loss(loss_type: str):
    """Construct the loss module registered under ``loss_type``.

    Args:
        loss_type: One of ``"clip"``, ``"cyclip"``, ``"sigmoid"`` or
            ``"cyclic_sigmoid"``.

    Returns:
        A freshly constructed loss module with default hyper-parameters.

    Raises:
        ValueError: If ``loss_type`` is not one of the registered names.
    """
    # Map names to classes rather than instances, so only the requested
    # module (and its learnable temperature parameter) is ever constructed.
    loss_classes = {
        "clip": CLIPLoss,
        "cyclip": CyCLIPLoss,
        "sigmoid": SigLIPLoss,
        "cyclic_sigmoid": CySigLIPLoss,
    }
    try:
        loss_cls = loss_classes[loss_type]
    except KeyError:
        valid = ", ".join(sorted(loss_classes))
        raise ValueError(
            f"Invalid loss type: {loss_type!r} (expected one of: {valid})"
        ) from None
    return loss_cls()
|
|