import torch.nn as nn
import torch.nn.functional as F
from transformers import DistilBertTokenizer, DistilBertModel, DistilBertConfig
import timm


def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy against soft targets (one probability distribution per row)."""
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == 'none':
        return loss
    else:
        return loss.mean()


class TextEncoder(nn.Module):
    def __init__(self, model_name='distilbert-base-uncased', pretrained=True, trainable=True):
        super().__init__()
        if pretrained:
            self.model = DistilBertModel.from_pretrained(model_name)
        else:
            self.model = DistilBertModel(config=DistilBertConfig())
        for p in self.model.parameters():
            p.requires_grad = trainable
        # Use the hidden state of the first ([CLS]) token as the sentence embedding.
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]


class ImageEncoder(nn.Module):
    def __init__(self, model_name='resnet50', pretrained=True, trainable=True):
        super().__init__()
        # num_classes=0 drops the classification head; global average pooling
        # yields a (batch, 2048) feature vector for ResNet-50.
        self.model = timm.create_model(model_name, pretrained=pretrained,
                                       num_classes=0, global_pool='avg')
        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, x):
        return self.model(x)


class ProjectionHead(nn.Module):
    """Projects encoder features into the shared embedding space,
    with a residual connection and layer normalization."""

    def __init__(self, embedding_dim, projection_dim=256, dropout=0.1):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)      # (batch, projection_dim)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = self.layer_norm(x + projected)  # residual connection
        return x


class CLIPModel(nn.Module):
    def __init__(self, temperature=1.0, image_embedding=2048, text_embedding=768):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        image_features = self.image_encoder(batch['image'])
        text_features = self.text_encoder(input_ids=batch['input_ids'],
                                          attention_mask=batch['attention_mask'])
        image_embeddings = self.image_projection(image_features)  # (batch, 256)
        text_embeddings = self.text_projection(text_features)     # (batch, 256)

        # Calculating the loss: pairwise text-image similarities, scaled by temperature.
        logits = (text_embeddings @ image_embeddings.T) / self.temperature

        # Soft targets: average of the within-modality similarity matrices.
        # Note the parentheses around (2 * self.temperature); the unparenthesized
        # "/ 2 * self.temperature" would multiply by the temperature instead of
        # dividing, breaking the symmetry with the scaling applied to the logits.
        image_similarity = image_embeddings @ image_embeddings.T
        text_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax((image_similarity + text_similarity) / (2 * self.temperature), dim=-1)

        texts_loss = cross_entropy(logits, targets, reduction='none')
        image_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (image_loss + texts_loss) / 2.0
        return loss.mean()
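

# --- Sanity check (an illustrative sketch, not part of the original module) ---
# Runs a dummy batch through CLIPModel to confirm the shapes line up and the
# forward pass returns a scalar loss. The batch size of 4, the 224x224 input
# resolution, and the example captions are assumptions chosen for the demo;
# the tokenizer matches the 'distilbert-base-uncased' default of TextEncoder.
if __name__ == "__main__":
    import torch

    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    captions = ["a dog on the beach", "a red car", "two people hiking", "a bowl of fruit"]
    encoded = tokenizer(captions, padding=True, truncation=True,
                        max_length=32, return_tensors='pt')
    batch = {
        'image': torch.randn(4, 3, 224, 224),  # ResNet-50 expects 3-channel images
        'input_ids': encoded['input_ids'],
        'attention_mask': encoded['attention_mask'],
    }

    model = CLIPModel()
    model.eval()
    with torch.no_grad():
        loss = model(batch)
    print("contrastive loss:", loss.item())  # a single scalar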