Upload folder using huggingface_hub
Browse files- model/best.pt +3 -0
- model/clip_model.py +102 -0
model/best.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bb622fc997c8e7a6c2881662d79a40b619cb5e83e563f8a4f5b5cece7fb73d1c
|
3 |
+
size 363250197
|
model/clip_model.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from transformers import DistilBertTokenizer, DistilBertModel,DistilBertConfig
|
3 |
+
import timm
|
4 |
+
|
5 |
+
def cross_entropy(preds, targets, reduction='none'):
    """Compute cross-entropy between logits and soft (non one-hot) targets.

    Args:
        preds: Unnormalised scores (logits); classes lie on the last axis.
        targets: Target weights, multiplied element-wise with the
            log-probabilities of ``preds``.
        reduction: ``'none'`` returns the un-averaged losses; any other
            value returns their mean.

    Returns:
        A tensor with the last axis of ``preds`` summed out, or a scalar
        tensor when the result is averaged.
    """
    log_probs = nn.functional.log_softmax(preds, dim=-1)
    # Reduce over the same axis the softmax normalises (the last one).
    # A hard-coded axis 1 only coincides with it for 2-D inputs.
    loss = (-targets * log_probs).sum(dim=-1)
    if reduction == 'none':
        return loss
    return loss.mean()
|
12 |
+
|
13 |
+
|
14 |
+
class TextEncoder(nn.Module):
    """DistilBERT wrapper that returns one hidden vector per input sequence.

    Args:
        model_name: Checkpoint identifier passed to
            ``DistilBertModel.from_pretrained`` when ``pretrained`` is true.
        pretrained: Load published weights if true; otherwise build the
            model from a default ``DistilBertConfig``.
        trainable: Value assigned to ``requires_grad`` on every parameter.
    """

    def __init__(self, model_name='distilbert-base-uncased', pretrained=True, trainable=True):
        super().__init__()
        # Without pretrained weights the default config is used, so
        # ``model_name`` is not consulted in that branch.
        self.model = (
            DistilBertModel.from_pretrained(model_name)
            if pretrained
            else DistilBertModel(config=DistilBertConfig())
        )
        self.model.requires_grad_(trainable)

        # Sequence position whose hidden state is returned by ``forward``
        # (index 0 -- presumably the [CLS] token; confirm with the tokenizer).
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        """Return the last-layer hidden state at ``target_token_idx``."""
        hidden_states = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        return hidden_states[:, self.target_token_idx, :]
|
31 |
+
|
32 |
+
|
33 |
+
class ImageEncoder(nn.Module):
    """Image backbone created through ``timm.create_model``.

    Args:
        model_name: timm architecture name.
        pretrained: Forwarded to timm as its ``pretrained`` argument.
        trainable: Value assigned to ``requires_grad`` on every parameter.
    """

    def __init__(self, model_name='resnet50', pretrained=True, trainable=True):
        super().__init__()
        # num_classes=0 together with global_pool="avg" requests the model
        # without a classification layer (see the timm documentation).
        self.model = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
            global_pool="avg",
        )
        self.model.requires_grad_(trainable)

    def forward(self, x):
        """Run ``x`` through the backbone and return its output unchanged."""
        return self.model(x)
|
44 |
+
|
45 |
+
|
46 |
+
class ProjectionHead(nn.Module):
    """Project encoder features to ``projection_dim`` with a residual block.

    Args:
        embedding_dim: Size of the incoming feature vectors.
        projection_dim: Size of the produced embeddings.
        dropout: Drop probability applied after the second linear layer.
    """

    def __init__(self, embedding_dim, projection_dim=256, dropout=0.1):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        """Return ``LayerNorm(dropout(fc(gelu(p))) + p)`` with ``p = projection(x)``."""
        # The first projection doubles as the skip connection.
        skip = self.projection(x)
        refined = self.dropout(self.fc(self.gelu(skip)))
        return self.layer_norm(refined + skip)
|
68 |
+
|
69 |
+
|
70 |
+
class CLIPModel(nn.Module):
    """Image/text dual encoder trained with a soft-target contrastive loss.

    Args:
        temperature: Divides the text-image logits and multiplies the
            averaged self-similarities used to build the soft targets.
        image_embedding: Feature size produced by the image encoder.
        text_embedding: Feature size produced by the text encoder.
    """

    def __init__(self, temperature=1.0, image_embedding=2048, text_embedding=768):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        """Return the scalar loss for one batch.

        Args:
            batch: Mapping that provides the ``'image'``, ``'input_ids'``
                and ``'attention_mask'`` entries.

        Returns:
            The mean over the batch of the averaged text-side and
            image-side cross-entropy losses.
        """
        image_features = self.image_encoder(batch['image'])
        text_features = self.text_encoder(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
        )
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # Rows index texts, columns index images.
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        image_similarity = image_embeddings @ image_embeddings.T
        text_similarity = text_embeddings @ text_embeddings.T

        # Soft targets from the averaged within-modality similarities.
        # ``/ 2 * temperature`` evaluates as ``(sum / 2) * temperature``.
        # ``nn.functional`` is spelled out because this module has no ``F``
        # alias in scope; referencing ``F`` raised NameError at call time.
        targets = nn.functional.softmax(
            (image_similarity + text_similarity) / 2 * self.temperature,
            dim=-1,
        )

        texts_loss = cross_entropy(logits, targets, reduction='none')
        image_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (image_loss + texts_loss) / 2.0
        return loss.mean()
|