import torch.nn as nn
import torch.nn.functional as F

from transformers import DistilBertTokenizer, DistilBertModel, DistilBertConfig

import timm

def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy against soft targets (one probability distribution per row)."""
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == 'none':
        return loss
    else:
        return loss.mean()

class TextEncoder(nn.Module):
    def __init__(self, model_name='distilbert-base-uncased', pretrained=True, trainable=True):
        super().__init__()
        if pretrained:
            self.model = DistilBertModel.from_pretrained(model_name)
        else:
            self.model = DistilBertModel(config=DistilBertConfig())

        for p in self.model.parameters():
            p.requires_grad = trainable

        # Use the hidden state of the [CLS] token (position 0) as the
        # sentence-level representation.
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]

class ImageEncoder(nn.Module):
    def __init__(self, model_name='resnet50', pretrained=True, trainable=True):
        super().__init__()
        # num_classes=0 removes the classification head; global average pooling
        # leaves a (batch, 2048) feature vector for ResNet-50.
        self.model = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0, global_pool="avg"
        )

        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, x):
        return self.model(x)

class ProjectionHead(nn.Module):
    def __init__(self, embedding_dim, projection_dim=256, dropout=0.1):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        # Residual connection around the fc block, followed by layer norm.
        x = self.layer_norm(x + projected)
        return x

class CLIPModel(nn.Module):
    def __init__(self, temperature=1.0, image_embedding=2048, text_embedding=768):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        # Encode images and texts, then project both into the shared space.
        image_features = self.image_encoder(batch['image'])
        text_features = self.text_encoder(
            input_ids=batch['input_ids'], attention_mask=batch['attention_mask']
        )
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # Pairwise text-image similarities, scaled by the temperature.
        logits = (text_embeddings @ image_embeddings.T) / self.temperature

        # Soft targets: the average of the image-image and text-text
        # self-similarities, softmaxed with the same temperature scaling as the
        # logits. Note the parentheses around (2 * self.temperature); without
        # them the expression would multiply by the temperature instead of
        # dividing by it.
        image_similarity = image_embeddings @ image_embeddings.T
        text_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (image_similarity + text_similarity) / (2 * self.temperature), dim=-1
        )

        # Symmetric cross-entropy over both directions of the logits matrix.
        texts_loss = cross_entropy(logits, targets, reduction='none')
        image_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (image_loss + texts_loss) / 2.0
        return loss.mean()
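
# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original module): run a random
# batch through the model and check that it returns a scalar loss. The shapes
# are assumptions: 224x224 RGB images for ResNet-50, and token IDs drawn below
# DistilBERT's vocabulary size of 30522. Running this downloads pretrained
# weights on first use.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    batch = {
        'image': torch.randn(4, 3, 224, 224),
        'input_ids': torch.randint(0, 30522, (4, 25)),
        'attention_mask': torch.ones(4, 25, dtype=torch.long),
    }
    model = CLIPModel()
    loss = model(batch)
    print(loss)  # a single scalar with grad_fn, ready for loss.backward()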