import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.softmax in CLIPModel.forward
from transformers import DistilBertTokenizer, DistilBertModel, DistilBertConfig
import timm
def cross_entropy(preds, targets, reduction='none'):
    # Cross-entropy against soft targets (probability distributions), rather
    # than hard class indices.
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == 'none':
        return loss
    return loss.mean()
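
# Example usage (hypothetical values), kept as a comment so the module has no
# import-time side effects:
#   preds   = torch.tensor([[2.0, 0.5], [0.5, 2.0]])
#   targets = torch.tensor([[0.9, 0.1], [0.1, 0.9]])
#   cross_entropy(preds, targets)          # per-sample losses, shape (2,)
#   cross_entropy(preds, targets, 'mean')  # scalar mean loss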
class TextEncoder(nn.Module):
    def __init__(self, model_name='distilbert-base-uncased', pretrained=True, trainable=True):
        super().__init__()
        if pretrained:
            self.model = DistilBertModel.from_pretrained(model_name)
        else:
            self.model = DistilBertModel(config=DistilBertConfig())
        for p in self.model.parameters():
            p.requires_grad = trainable
        # Use the hidden state of the [CLS] token (position 0) as the sentence embedding.
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]
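
# Example (hypothetical shapes): for a tokenized batch of 4 captions,
#   TextEncoder()(input_ids, attention_mask).shape  ->  torch.Size([4, 768])
# since DistilBERT's hidden size is 768 (hence text_embedding=768 below).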
class ImageEncoder(nn.Module):
    def __init__(self, model_name='resnet50', pretrained=True, trainable=True):
        super().__init__()
        # num_classes=0 with average pooling makes timm return the pooled backbone features.
        self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=0, global_pool="avg")
        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, x):
        return self.model(x)
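
# Example (hypothetical shapes): ResNet-50 pooled features are 2048-dimensional, so
#   ImageEncoder()(torch.randn(4, 3, 224, 224)).shape  ->  torch.Size([4, 2048])
# (hence image_embedding=2048 below).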
class ProjectionHead(nn.Module):
    def __init__(self, embedding_dim, projection_dim=256, dropout=0.1):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        # x: (batch_size, embedding_dim)
        projected = self.projection(x)      # (batch_size, projection_dim)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        # Residual connection around the fc block, followed by layer normalization.
        x = self.layer_norm(x + projected)  # (batch_size, projection_dim)
        return x
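
# Example (hypothetical): project 2048-d image features into the shared 256-d space.
#   head = ProjectionHead(embedding_dim=2048)
#   head(torch.randn(4, 2048)).shape  ->  torch.Size([4, 256])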
class CLIPModel(nn.Module):
    def __init__(self, temperature=1.0, image_embedding=2048, text_embedding=768):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        # Encode both modalities, then project into the shared embedding space.
        image_features = self.image_encoder(batch['image'])
        text_features = self.text_encoder(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
        image_embeddings = self.image_projection(image_features)  # (batch_size, 256)
        text_embeddings = self.text_projection(text_features)     # (batch_size, 256)

        # Calculating the loss: temperature-scaled text-to-image similarity matrix.
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        image_similarity = image_embeddings @ image_embeddings.T
        text_similarity = text_embeddings @ text_embeddings.T
        # Soft targets from the averaged self-similarities, divided by the same
        # temperature that scales the logits (note the parentheses: the sum is
        # divided by 2 * temperature, not divided by 2 and multiplied by it).
        targets = F.softmax((image_similarity + text_similarity) / (2 * self.temperature), dim=-1)
        texts_loss = cross_entropy(logits, targets, reduction='none')
        image_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (image_loss + texts_loss) / 2.0
        return loss.mean()
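
# A minimal smoke test (not part of the original file; the captions, batch size,
# and image size are made-up values). Running it downloads DistilBERT and
# ResNet-50 weights on first use.
if __name__ == "__main__":
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    captions = ["a dog playing in the park", "a plate of pasta on a table"]
    encoded = tokenizer(captions, padding=True, truncation=True, return_tensors="pt")
    batch = {
        'image': torch.randn(2, 3, 224, 224),  # dummy image batch
        'input_ids': encoded['input_ids'],
        'attention_mask': encoded['attention_mask'],
    }
    model = CLIPModel()
    loss = model(batch)
    print("contrastive loss:", loss.item())  # a single scalar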