File size: 4,226 Bytes
7173b22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import torch.nn as nn
from transformers import DistilBertTokenizer, DistilBertModel,DistilBertConfig
import timm

def cross_entropy(preds,targets,reduction='none'):
    """Cross entropy between logits and a (possibly soft) target distribution.

    Args:
        preds: Unnormalised scores; log-softmax is taken over the last
            dimension.
        targets: Target weights multiplied element-wise with the
            log-probabilities.
        reduction: ``'none'`` returns the per-row losses; any other value
            returns their mean.

    Returns:
        The loss summed over dimension 1 (one value per row for 2-D input),
        or the mean of those values when ``reduction`` is not ``'none'``.
    """
    log_probs = nn.functional.log_softmax(preds, dim=-1)
    per_row = (-targets * log_probs).sum(1)
    # Every reduction other than 'none' collapses to the mean.
    if reduction != 'none':
        return per_row.mean()
    return per_row
    

class TextEncoder(nn.Module):
    """DistilBERT wrapper that returns one hidden vector per input sequence.

    Args:
        model_name: Checkpoint name passed to ``from_pretrained``; only used
            when ``pretrained`` is true.
        pretrained: Load pretrained weights when true; otherwise build a
            DistilBERT with a default ``DistilBertConfig``.
        trainable: Value assigned to ``requires_grad`` on every parameter of
            the wrapped model.
    """

    def __init__(self,model_name='distilbert-base-uncased',pretrained=True,trainable=True):
        super().__init__()
        self.model = (
            DistilBertModel.from_pretrained(model_name)
            if pretrained
            else DistilBertModel(config=DistilBertConfig())
        )

        # Freeze or unfreeze the whole backbone in one pass.
        for weight in self.model.parameters():
            weight.requires_grad = trainable

        # Sequence position whose hidden state represents the sentence
        # (presumably the [CLS] token at index 0 -- confirm with tokenizer).
        self.target_token_idx = 0

    def forward(self,input_ids,attention_mask):
        """Return the last-layer hidden state at ``target_token_idx``."""
        encoded = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return encoded.last_hidden_state[:, self.target_token_idx, :]
    

class ImageEncoder(nn.Module):
    """Image backbone created through ``timm`` with no classification head.

    Args:
        model_name: Architecture name handed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model`` to request pretrained
            weights.
        trainable: Value assigned to ``requires_grad`` on every backbone
            parameter.
    """

    def __init__(self,model_name='resnet50',pretrained=True,trainable=True):
        super().__init__()
        # num_classes=0 with average global pooling is requested so the model
        # emits pooled features instead of class logits.
        self.model = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
            global_pool="avg",
        )

        for weight in self.model.parameters():
            weight.requires_grad = trainable

    def forward(self,x):
        """Run ``x`` through the backbone and return its output unchanged."""
        features = self.model(x)
        return features
    

class ProjectionHead(nn.Module):
    """Project encoder features into the shared embedding space.

    The head is a linear projection followed by a GELU -> linear -> dropout
    branch whose output is added back to the projection and layer-normalised.

    Args:
        embedding_dim: Size of the incoming feature vectors.
        projection_dim: Size of the produced embeddings.
        dropout: Dropout probability applied to the branch output.
    """

    def __init__(self,embedding_dim,projection_dim=256,dropout=0.1):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self,x):
        """Return the normalised sum of the projection and its refined branch."""
        residual = self.projection(x)
        refined = self.dropout(self.fc(self.gelu(residual)))
        # Skip connection around the GELU/linear/dropout branch.
        return self.layer_norm(refined + residual)


class CLIPModel(nn.Module):
    """Image/text contrastive model that returns its training loss.

    An ``ImageEncoder`` and a ``TextEncoder`` each feed a ``ProjectionHead``;
    ``forward`` compares the two sets of projected embeddings and returns a
    scalar loss.

    Args:
        temperature: Divides the text-image logits and multiplies the
            averaged similarity used to build the soft targets.
        image_embedding: Feature size of the image encoder output, used as
            the input size of the image projection head.
        text_embedding: Feature size of the text encoder output, used as the
            input size of the text projection head.
    """

    def __init__(self,temperature=1.0,image_embedding=2048,text_embedding=768):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self,batch):
        """Compute the symmetric contrastive loss for one batch.

        Args:
            batch: Mapping with the keys ``'image'``, ``'input_ids'`` and
                ``'attention_mask'``.

        Returns:
            Scalar tensor: the mean over the batch of the averaged
            text-side and image-side cross-entropy losses.
        """
        image_features = self.image_encoder(batch['image'])
        text_features = self.text_encoder(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
        )
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # Pairwise text-vs-image similarity scores, scaled by the temperature.
        logits = (text_embeddings @ image_embeddings.T) / self.temperature

        # Soft targets come from the average of the within-modality
        # similarity matrices.
        image_similarity = image_embeddings @ image_embeddings.T
        text_similarity = text_embeddings @ text_embeddings.T
        # `nn.functional` is used because this file imports only
        # `torch.nn as nn`; the previous bare `F.softmax` raised NameError.
        # NOTE(review): `/ 2 * self.temperature` multiplies by the
        # temperature (operator precedence); confirm this is intended rather
        # than `/ (2 * self.temperature)`. Kept unchanged here.
        targets = nn.functional.softmax(
            (image_similarity + text_similarity) / 2 * self.temperature,
            dim=-1,
        )

        texts_loss = cross_entropy(logits, targets, reduction='none')
        image_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (image_loss + texts_loss) / 2.0
        return loss.mean()