SeemG committed
Commit befabd8
1 Parent(s): 2147186

Upload 3 files

Files changed (3)
  1. best.pt +3 -0
  2. clip_model.py +330 -0
  3. flickr8k.zip +3 -0
best.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:82190c1a92ab132cd94422395dc2b671cedccefad5598c1da753431e7cab9575
+ size 363250624
clip_model.py ADDED
@@ -0,0 +1,330 @@
+ import os
+ import cv2
+ import gc
+ import numpy as np
+ import pandas as pd
+ import itertools
+ from tqdm.autonotebook import tqdm
+ import albumentations as A
+ import matplotlib.pyplot as plt
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ import timm
+ from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
+
+
+ class CFG:
+     debug = False
+     image_path = "/content/flickr30k_images/flickr30k_images"
+     captions_path = "/content"
+     batch_size = 32
+     num_workers = 2
+     head_lr = 1e-3
+     image_encoder_lr = 1e-4
+     text_encoder_lr = 1e-5
+     weight_decay = 1e-3
+     patience = 1
+     factor = 0.8
+     epochs = 4
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     model_name = 'resnet50'
+     image_embedding = 2048
+     text_encoder_model = "distilbert-base-uncased"
+     text_embedding = 768
+     text_tokenizer = "distilbert-base-uncased"
+     max_length = 200
+
+     pretrained = True  # for both image encoder and text encoder
+     trainable = True  # for both image encoder and text encoder
+     temperature = 1.0
+
+     # image size
+     size = 224
+
+     # for projection head; used for both image and text encoders
+     num_projection_layers = 1
+     projection_dim = 256
+     dropout = 0.1
+
+
+ class AvgMeter:
+     def __init__(self, name="Metric"):
+         self.name = name
+         self.reset()
+
+     def reset(self):
+         self.avg, self.sum, self.count = [0] * 3
+
+     def update(self, val, count=1):
+         self.count += count
+         self.sum += val * count
+         self.avg = self.sum / self.count
+
+     def __repr__(self):
+         text = f"{self.name}: {self.avg:.4f}"
+         return text
+
+
+ def get_lr(optimizer):
+     for param_group in optimizer.param_groups:
+         return param_group["lr"]
+
+
+ class CLIPDataset(torch.utils.data.Dataset):
+     def __init__(self, image_filenames, captions, tokenizer, transforms):
+         """
+         image_filenames and captions must have the same length; so, if there are
+         multiple captions for each image, the image_filenames must have repetitive
+         file names
+         """
+         self.image_filenames = image_filenames
+         self.captions = list(captions)
+         self.encoded_captions = tokenizer(
+             list(captions), padding=True, truncation=True, max_length=CFG.max_length
+         )
+         self.transforms = transforms
+
+     def __getitem__(self, idx):
+         item = {
+             key: torch.tensor(values[idx])
+             for key, values in self.encoded_captions.items()
+         }
+
+         image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[idx]}")
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         image = self.transforms(image=image)['image']
+         item['image'] = torch.tensor(image).permute(2, 0, 1).float()
+         item['caption'] = self.captions[idx]
+
+         return item
+
+     def __len__(self):
+         return len(self.captions)
+
+
+ def get_transforms(mode="train"):
+     if mode == "train":
+         return A.Compose(
+             [
+                 A.Resize(CFG.size, CFG.size, always_apply=True),
+                 A.Normalize(max_pixel_value=255.0, always_apply=True),
+             ]
+         )
+     else:
+         return A.Compose(
+             [
+                 A.Resize(CFG.size, CFG.size, always_apply=True),
+                 A.Normalize(max_pixel_value=255.0, always_apply=True),
+             ]
+         )
+
+
+ class ImageEncoder(nn.Module):
+     """
+     Encode images to a fixed size vector
+     """
+
+     def __init__(
+         self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
+     ):
+         super().__init__()
+         self.model = timm.create_model(
+             model_name, pretrained, num_classes=0, global_pool="avg"
+         )
+         for p in self.model.parameters():
+             p.requires_grad = trainable
+
+     def forward(self, x):
+         return self.model(x)
+
+
+ class TextEncoder(nn.Module):
+     def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
+         super().__init__()
+         if pretrained:
+             self.model = DistilBertModel.from_pretrained(model_name)
+         else:
+             self.model = DistilBertModel(config=DistilBertConfig())
+
+         for p in self.model.parameters():
+             p.requires_grad = trainable
+
+         # we are using the CLS token hidden representation as the sentence's embedding
+         self.target_token_idx = 0
+
+     def forward(self, input_ids, attention_mask):
+         output = self.model(input_ids=input_ids, attention_mask=attention_mask)
+         last_hidden_state = output.last_hidden_state
+         return last_hidden_state[:, self.target_token_idx, :]
+
+
+ class ProjectionHead(nn.Module):
+     def __init__(
+         self,
+         embedding_dim,
+         projection_dim=CFG.projection_dim,
+         dropout=CFG.dropout
+     ):
+         super().__init__()
+         self.projection = nn.Linear(embedding_dim, projection_dim)
+         self.gelu = nn.GELU()
+         self.fc = nn.Linear(projection_dim, projection_dim)
+         self.dropout = nn.Dropout(dropout)
+         self.layer_norm = nn.LayerNorm(projection_dim)
+
+     def forward(self, x):
+         projected = self.projection(x)
+         x = self.gelu(projected)
+         x = self.fc(x)
+         x = self.dropout(x)
+         x = x + projected
+         x = self.layer_norm(x)
+         return x
+
+
+ class CLIPModel(nn.Module):
+     def __init__(
+         self,
+         temperature=CFG.temperature,
+         image_embedding=CFG.image_embedding,
+         text_embedding=CFG.text_embedding,
+     ):
+         super().__init__()
+         self.image_encoder = ImageEncoder()
+         self.text_encoder = TextEncoder()
+         self.image_projection = ProjectionHead(embedding_dim=image_embedding)
+         self.text_projection = ProjectionHead(embedding_dim=text_embedding)
+         self.temperature = temperature
+
+     def forward(self, batch):
+         # Getting Image and Text Features
+         image_features = self.image_encoder(batch["image"])
+         text_features = self.text_encoder(
+             input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
+         )
+         # Getting Image and Text Embeddings (with same dimension)
+         image_embeddings = self.image_projection(image_features)
+         text_embeddings = self.text_projection(text_features)
+
+         # Calculating the Loss
+         logits = (text_embeddings @ image_embeddings.T) / self.temperature
+         images_similarity = image_embeddings @ image_embeddings.T
+         texts_similarity = text_embeddings @ text_embeddings.T
+         targets = F.softmax(
+             (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
+         )
+         texts_loss = cross_entropy(logits, targets, reduction='none')
+         images_loss = cross_entropy(logits.T, targets.T, reduction='none')
+         loss = (images_loss + texts_loss) / 2.0  # shape: (batch_size)
+         return loss.mean()
+
+
+ def cross_entropy(preds, targets, reduction='none'):
+     log_softmax = nn.LogSoftmax(dim=-1)
+     loss = (-targets * log_softmax(preds)).sum(1)
+     if reduction == "none":
+         return loss
+     elif reduction == "mean":
+         return loss.mean()
+
+
+ def make_train_valid_dfs():
+     dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
+     max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
+     image_ids = np.arange(0, max_id)
+     np.random.seed(42)
+     valid_ids = np.random.choice(
+         image_ids, size=int(0.2 * len(image_ids)), replace=False
+     )
+     train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
+     train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
+     valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
+     return train_dataframe, valid_dataframe
+
+
+ def build_loaders(dataframe, tokenizer, mode):
+     transforms = get_transforms(mode=mode)
+     dataset = CLIPDataset(
+         dataframe["image"].values,
+         dataframe["caption"].values,
+         tokenizer=tokenizer,
+         transforms=transforms,
+     )
+     dataloader = torch.utils.data.DataLoader(
+         dataset,
+         batch_size=CFG.batch_size,
+         num_workers=CFG.num_workers,
+         shuffle=True if mode == "train" else False,
+     )
+     return dataloader
+
+
+ def get_image_embeddings(valid_df, model_path):
+     tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
+     valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
+
+     model = CLIPModel().to(CFG.device)
+     model.load_state_dict(torch.load(model_path, map_location=CFG.device))
+     model.eval()
+
+     valid_image_embeddings = []
+     with torch.no_grad():
+         for batch in tqdm(valid_loader):
+             image_features = model.image_encoder(batch["image"].to(CFG.device))
+             image_embeddings = model.image_projection(image_features)
+             valid_image_embeddings.append(image_embeddings)
+     return model, torch.cat(valid_image_embeddings)
+
+
+ def find_matches(model, image_embeddings, query, image_filenames, n=9):
+     tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
+     encoded_query = tokenizer([query])
+     batch = {
+         key: torch.tensor(values).to(CFG.device)
+         for key, values in encoded_query.items()
+     }
+     with torch.no_grad():
+         text_features = model.text_encoder(
+             input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
+         )
+         text_embeddings = model.text_projection(text_features)
+
+     image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
+     text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
+     dot_similarity = text_embeddings_n @ image_embeddings_n.T
+
+     # take the top n*5 and keep every 5th index: each image appears once per caption,
+     # so this skips repeated hits on the same image
+     values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
+     matches = [image_filenames[idx] for idx in indices[::5]]
+
+     _, axes = plt.subplots(3, 3, figsize=(10, 10))
+
+     results = []
+     for match, ax in zip(matches, axes.flatten()):
+         image = cv2.imread(f"{CFG.image_path}/{match}")
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         # ax.imshow(image)
+         # ax.axis("off")
+         results.append(image)
+     return results
flickr8k.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0677f83ccb736b08e73ddd1219ba1ba12bc72a8742a71df4edaaaa4abc64d42b
+ size 1112971163
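
The three files are meant to be used together: clip_model.py defines the model and the retrieval helpers, best.pt appears to hold trained CLIPModel weights (get_image_embeddings loads it via torch.load and load_state_dict), and flickr8k.zip the image/caption data. Below is a minimal usage sketch, assuming the zip extracts to an image folder plus a captions.csv with id/image/caption columns; the paths are placeholders, not part of the commit.

# Hypothetical usage sketch; adjust the placeholder paths to wherever flickr8k.zip was extracted.
from clip_model import CFG, make_train_valid_dfs, get_image_embeddings, find_matches

CFG.image_path = "flickr8k/Images"   # assumed image folder (placeholder)
CFG.captions_path = "flickr8k"       # assumed to contain captions.csv (placeholder)

# 80/20 split over image ids, then embed the validation images with the checkpoint
_, valid_df = make_train_valid_dfs()
model, image_embeddings = get_image_embeddings(valid_df, "best.pt")

# Retrieve the top-n images for a free-text query (returned as RGB numpy arrays)
images = find_matches(
    model,
    image_embeddings,
    query="a dog running on the beach",
    image_filenames=valid_df["image"].values,
    n=9,
)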