|
from torch import nn |
|
import torch.nn.functional as F |
|
|
|
from image_encoder import ImageEncoder |
|
from text_encoder import TextEncoder |
|
from projection_head import ProjectionHead |
|
from configuration import CFG |
|
|
|
|
|
class CLIPModel(nn.Module):
    """Dual-encoder image/text model trained with a soft-target contrastive loss.

    Images and texts each go through their own encoder and then their own
    ``ProjectionHead``. Calling the module returns a scalar loss computed from
    the two sets of projected embeddings.

    Args:
        temperature: Divisor applied to the text-image similarity logits. It is
            also the multiplier applied to the averaged intra-modal
            similarities that form the targets. Defaults to
            ``CFG.temperature``, evaluated once when this class is defined.
        image_embedding: ``embedding_dim`` handed to the image projection head.
        text_embedding: ``embedding_dim`` handed to the text projection head.
    """

    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        super().__init__()
        # Submodules are created in a fixed order (encoders, then heads).
        # Changing it would change parameter registration order and the
        # random-initialisation sequence.
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        """Encode *batch* and return the mean symmetric contrastive loss.

        Args:
            batch: Mapping with ``"image"``, ``"input_ids"`` and
                ``"attention_mask"`` entries. These are presumably tensors
                from the data loader; confirm against the caller.

        Returns:
            The per-sample losses averaged with ``.mean()``.
        """
        # Run both encoders first and both heads second, in case any submodule
        # is stochastic (e.g. dropout) and the RNG order matters.
        image_out = self.image_encoder(batch["image"])
        text_out = self.text_encoder(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )
        img_emb = self.image_projection(image_out)
        txt_emb = self.text_projection(text_out)
        return self._symmetric_loss(img_emb, txt_emb)

    def _symmetric_loss(self, img_emb, txt_emb):
        """Average the text->image and image->text soft-target cross-entropies."""
        # Assumes embeddings are 2-D (batch, dim); TODO confirm. `.T` on
        # higher-rank tensors reverses every axis.
        scaled_logits = (txt_emb @ img_emb.T) / self.temperature
        # The soft targets are the row-wise softmax of the mean of the two
        # intra-modal similarity matrices.
        # NOTE(review): logits are divided by temperature but this term is
        # multiplied by it. That is harmless at temperature == 1.0; confirm
        # the intent for other values.
        blended = (img_emb @ img_emb.T + txt_emb @ txt_emb.T) / 2 * self.temperature
        soft_targets = F.softmax(blended, dim=-1)
        # Rows of `scaled_logits` index texts. Transposing both operands swaps
        # the roles of images and texts.
        per_text = cross_entropy(scaled_logits, soft_targets, reduction='none')
        per_image = cross_entropy(scaled_logits.T, soft_targets.T, reduction='none')
        return ((per_image + per_text) / 2.0).mean()
|
|
|
|
|
def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between logits *preds* and soft target distributions *targets*.

    For each row, this computes ``-sum(targets * log_softmax(preds))`` over the
    last axis. The last axis is the same one the softmax is taken over, so
    every target row is paired with the matching log-probability row.

    Args:
        preds: Unnormalised scores (logits).
        targets: Target probabilities, broadcast-compatible with *preds*.
        reduction: ``"none"`` returns the per-row losses, ``"mean"`` their
            mean, and ``"sum"`` their sum.

    Returns:
        A tensor of per-row losses, or a scalar tensor for ``"mean"``/``"sum"``.

    Raises:
        ValueError: If *reduction* is not one of the supported values. The
            previous version silently returned ``None`` in that case.
    """
    # F.log_softmax gives the same result as a freshly built nn.LogSoftmax
    # module, without allocating one on every call.
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(dim=-1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(
        f"unsupported reduction {reduction!r}; expected 'none', 'mean' or 'sum'"
    )