zabir-nabil committed on
Commit c913f51
1 Parent(s): 18970fe

Create CLIP_model.py

Files changed (1)
  1. CLIP_model.py +71 -0
CLIP_model.py ADDED
@@ -0,0 +1,71 @@
import torch
from torch import nn
from torchvision import models
from transformers import AutoModel
import config as CFG

class CLIPModel(nn.Module):
    """CLIP model for Bangla."""

    def __init__(self):
        super().__init__()
        # Image encoder: ImageNet-pretrained EfficientNet-B2.
        self.image_encoder = models.efficientnet_b2(weights=models.EfficientNet_B2_Weights.DEFAULT)
        # EfficientNet exposes its head as `classifier`, not `fc`; replacing it
        # with Identity makes the encoder return pooled features
        # (1408-d for B2, so CFG.image_embedding should be 1408).
        self.image_encoder.classifier = nn.Identity()

        # Projection head: image features -> shared 256-d embedding space.
        self.image_out = nn.Sequential(
            nn.Linear(CFG.image_embedding, 256), nn.ReLU(), nn.Linear(256, 256)
        )

        # Text encoder: the Hugging Face transformer named in the config
        # (a BERT-style model with 768-d hidden states).
        self.text_encoder = AutoModel.from_pretrained(CFG.text_encoder_model)
        # Use the first ([CLS]) token as the sequence representation.
        self.target_token_idx = 0

        # Projection head: [CLS] embedding -> shared 256-d embedding space.
        self.text_out = nn.Sequential(
            nn.Linear(768, 256), nn.ReLU(), nn.Linear(256, 256)
        )

    def forward(self, image, text, mask):
        # Encode and project the image batch.
        image_vec = self.image_encoder(image)
        image_vec = self.image_out(image_vec)

        # Encode the token ids; keyword arguments avoid relying on the
        # positional order of the transformer's forward().
        text_out = self.text_encoder(input_ids=text, attention_mask=mask)
        last_hidden_states = text_out.last_hidden_state

        # Keep only the [CLS] token embedding, then project it.
        last_hidden_states = last_hidden_states[:, self.target_token_idx, :]
        text_vec = self.text_out(last_hidden_states.view(-1, 768))

        return image_vec, text_vec

    def get_image_embeddings(self, image):
        """Inference helper: project a batch of images into the shared space."""
        image_vec = self.image_encoder(image)
        image_vec = self.image_out(image_vec)

        return image_vec

    def get_text_embeddings(self, text, mask):
        """Inference helper: project a batch of token ids into the shared space."""
        text_out = self.text_encoder(input_ids=text, attention_mask=mask)
        last_hidden_states = text_out.last_hidden_state

        last_hidden_states = last_hidden_states[:, self.target_token_idx, :]
        text_vec = self.text_out(last_hidden_states.view(-1, 768))

        return text_vec


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Smoke test with a dummy batch: 40 images and 40 "tokenized" captions
    # of length 200 (random ids stand in for real tokenizer output).
    images = torch.randn(40, 3, 224, 224).to(device)
    input_ids = torch.randint(5, 300, size=(40, 200)).to(device)
    attention_mask = torch.ones(40, 200).to(device)

    print("Building CLIP")
    clip_model = CLIPModel().to(device)
    print(clip_model)

    img_vec, text_vec = clip_model(images, input_ids, attention_mask)
    print(img_vec.shape)   # expected: torch.Size([40, 256])
    print(text_vec.shape)  # expected: torch.Size([40, 256])
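
For context: the commit defines the two encoders but no tokenizer usage or training objective. Below is a minimal usage sketch, assuming the tokenizer of the same CFG.text_encoder_model checkpoint and the standard CLIP-style symmetric InfoNCE loss; the clip_loss helper name and the temperature value are illustrative assumptions, not part of this file.

# Usage sketch, not part of this commit. Assumes the checkpoint named by
# CFG.text_encoder_model; `clip_loss` and `temperature` are illustrative.
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

import config as CFG
from CLIP_model import CLIPModel


def clip_loss(image_vec, text_vec, temperature=0.07):
    # Normalize so dot products are cosine similarities.
    image_vec = F.normalize(image_vec, dim=-1)
    text_vec = F.normalize(text_vec, dim=-1)
    # logits[i, j] = similarity of image i with caption j.
    logits = image_vec @ text_vec.t() / temperature
    # Matching pairs sit on the diagonal of the similarity matrix.
    targets = torch.arange(logits.size(0), device=logits.device)
    # Symmetric cross-entropy: image->text and text->image.
    return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets)) / 2


tokenizer = AutoTokenizer.from_pretrained(CFG.text_encoder_model)
batch = tokenizer(["একটি বিড়াল", "নীল আকাশ"], padding=True, return_tensors="pt")  # "a cat", "blue sky"
images = torch.randn(2, 3, 224, 224)

model = CLIPModel()
img_vec, txt_vec = model(images, batch["input_ids"], batch["attention_mask"])
print(clip_loss(img_vec, txt_vec))  # scalar loss for the batch

Normalizing before the dot product keeps the logit scale controlled by the temperature alone, which is what lets the batch-wise softmax act as CLIP's contrastive objective.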