Soumyapro committed
Commit 2aea542
1 Parent(s): 27a180e

upload_model

Files changed (2)
  1. model/best.pt +3 -0
  2. model/clip_model.py +102 -0
model/best.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bb622fc997c8e7a6c2881662d79a40b619cb5e83e563f8a4f5b5cece7fb73d1c
+ size 363250197
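
The weights file is tracked with Git LFS, so a plain clone without LFS support checks out only the three-line pointer stub above rather than the ~363 MB checkpoint. A minimal sketch for resolving it with huggingface_hub; the repo id below is a placeholder, since this diff does not name the hosting repo, and repo_type may need to be 'space' if this commit belongs to a Space:

from huggingface_hub import hf_hub_download

# repo_id is hypothetical -- substitute the repo this commit actually lives in.
checkpoint_path = hf_hub_download(repo_id='Soumyapro/placeholder-repo', filename='model/best.pt')
print(checkpoint_path)  # local cache path of the resolved LFS object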
model/clip_model.py ADDED
@@ -0,0 +1,102 @@
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import DistilBertTokenizer, DistilBertModel, DistilBertConfig
+ import timm
+
+
+ def cross_entropy(preds, targets, reduction='none'):
+     # Cross-entropy against soft targets: targets is a probability
+     # distribution over the batch, not a class index.
+     log_softmax = nn.LogSoftmax(dim=-1)
+     loss = (-targets * log_softmax(preds)).sum(1)
+     if reduction == 'none':
+         return loss
+     else:
+         return loss.mean()
+
+
+ class TextEncoder(nn.Module):
+     def __init__(self, model_name='distilbert-base-uncased', pretrained=True, trainable=True):
+         super().__init__()
+         if pretrained:
+             self.model = DistilBertModel.from_pretrained(model_name)
+         else:
+             self.model = DistilBertModel(config=DistilBertConfig())
+
+         for p in self.model.parameters():
+             p.requires_grad = trainable
+
+         # Use the embedding of the first ([CLS]) token as the sentence representation.
+         self.target_token_idx = 0
+
+     def forward(self, input_ids, attention_mask):
+         output = self.model(input_ids=input_ids, attention_mask=attention_mask)
+         last_hidden_state = output.last_hidden_state
+         return last_hidden_state[:, self.target_token_idx, :]
+
+
+ class ImageEncoder(nn.Module):
+     def __init__(self, model_name='resnet50', pretrained=True, trainable=True):
+         super().__init__()
+         # num_classes=0 drops the classifier head; global average pooling
+         # yields a (batch, 2048) feature vector for ResNet-50.
+         self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=0, global_pool='avg')
+
+         for p in self.model.parameters():
+             p.requires_grad = trainable
+
+     def forward(self, x):
+         return self.model(x)
+
+
+ class ProjectionHead(nn.Module):
+     def __init__(self, embedding_dim, projection_dim=256, dropout=0.1):
+         super().__init__()
+         self.projection = nn.Linear(embedding_dim, projection_dim)
+         self.gelu = nn.GELU()
+         self.fc = nn.Linear(projection_dim, projection_dim)
+         self.dropout = nn.Dropout(p=dropout)
+         self.layer_norm = nn.LayerNorm(projection_dim)
+
+     def forward(self, x):
+         # Project into the shared embedding space, then refine with a
+         # residual MLP block followed by layer normalization.
+         projected = self.projection(x)
+         x = self.gelu(projected)
+         x = self.fc(x)
+         x = self.dropout(x)
+         x = self.layer_norm(x + projected)
+         return x
+
+
+ class CLIPModel(nn.Module):
+     def __init__(self, temperature=1.0, image_embedding=2048, text_embedding=768):
+         super().__init__()
+         self.image_encoder = ImageEncoder()
+         self.text_encoder = TextEncoder()
+         self.image_projection = ProjectionHead(embedding_dim=image_embedding)
+         self.text_projection = ProjectionHead(embedding_dim=text_embedding)
+         self.temperature = temperature
+
+     def forward(self, batch):
+         image_features = self.image_encoder(batch['image'])
+         text_features = self.text_encoder(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
+         image_embeddings = self.image_projection(image_features)
+         text_embeddings = self.text_projection(text_features)
+
+         # Symmetric contrastive loss: the image-text similarities are the
+         # logits, and the softmaxed average of the within-modality
+         # similarities (on the same temperature scale) is the soft target.
+         logits = (text_embeddings @ image_embeddings.T) / self.temperature
+         image_similarity = image_embeddings @ image_embeddings.T
+         text_similarity = text_embeddings @ text_embeddings.T
+         targets = F.softmax((image_similarity + text_similarity) / (2 * self.temperature), dim=-1)
+         texts_loss = cross_entropy(logits, targets, reduction='none')
+         image_loss = cross_entropy(logits.T, targets.T, reduction='none')
+         loss = (image_loss + texts_loss) / 2.0
+         return loss.mean()
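
A minimal smoke test for the module above, as a sketch: the 224x224 input size, the import path, and the checkpoint layout of model/best.pt are assumptions, not facts established by this diff.

import torch
from transformers import DistilBertTokenizer
from model.clip_model import CLIPModel  # assumes the repo root is on sys.path

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = CLIPModel()
# model.load_state_dict(torch.load('model/best.pt', map_location='cpu'))  # assumes a plain state_dict
model.eval()

encoded = tokenizer(['a photo of a dog', 'a photo of a cat'],
                    padding=True, truncation=True, return_tensors='pt')
batch = {
    'image': torch.randn(2, 3, 224, 224),  # dummy images; 224x224 is the usual ResNet-50 input
    'input_ids': encoded['input_ids'],
    'attention_mask': encoded['attention_mask'],
}
with torch.no_grad():
    loss = model(batch)
print(loss.item())  # contrastive loss over the 2x2 similarity matrix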