VictorGearhead committed on
Commit d7d0d10
1 parent: e78d325

Create app.py

Files changed (1): app.py (+251, -0)
app.py ADDED
@@ -0,0 +1,251 @@
+ import pandas as pd
+ import numpy as np
+ from sklearn.model_selection import train_test_split
+ import streamlit as st
+
+ import torchvision
+ from torchvision import transforms
+ import cv2
+ import math
+
+ from collections import Counter
+ from PIL import Image
+ import PIL
+
+ import zipfile
+ import io
+
+ import torch
+ import torch.nn as nn
+ from torch.autograd import Variable
+ from torch.utils.data import Dataset, DataLoader
+ import torch.nn.functional as fun
+
+ from tqdm.notebook import tqdm
+ import matplotlib.pyplot as plt
+
+ # import nltk
+ # import ssl
+
+ # try:
+ #     _create_unverified_https_context = ssl._create_unverified_context
+ # except AttributeError:
+ #     pass
+ # else:
+ #     ssl._create_default_https_context = _create_unverified_https_context
+
+ # nltk.download()
+ from nltk.tokenize import word_tokenize
+
+ import pickle
+ import gc
+ import random
+
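+ # Expected local artifacts (assumed, based on the paths used below): captions.txt with
+ # image/caption pairs, Images.zip holding the raw images, ImageCaptioning_Model.pkl with
+ # the trained decoder, and Image_Features_Embed_ResNet_Valid.pkl with pre-computed
+ # ResNet embeddings for the validation images.
+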
+ data = pd.read_csv("captions.txt", sep=',')
+
+ # Drop single-character tokens from a tokenized caption
+ def remove_single_char(caption_list):
+     filtered = []
+     for word in caption_list:
+         if len(word) > 1:
+             filtered.append(word)
+     return filtered
+
+ # Tokenize each caption into a list of words, then drop the single-character tokens
+ data['caption'] = data['caption'].apply(lambda caption: word_tokenize(caption))
+ data['caption'] = data['caption'].apply(lambda caption: remove_single_char(caption))
+
+ # Pad every caption to the same length with <cell> tokens and add <start>/<end> markers
+ lengths = data['caption'].apply(lambda caption: len(caption))
+
+ max_length = lengths.max()
+
+ data['caption'] = data['caption'].apply(lambda caption: ['<start>'] + caption + ['<cell>'] * (max_length - len(caption)) + ['<end>'])
+
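+ # After this step every caption has the same length: max_length words plus the <start> and
+ # <end> markers. The hard-coded max_seq_len = 34 used at inference time below is assumed to
+ # match this padded length for the captions file shipped with the app.
+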
+ # Show full (non-truncated) cell contents when the dataframe is displayed
+ pd.set_option('display.max_colwidth', None)
+
+ # Flatten all captions into one list of words
+ words = data['caption'].apply(lambda caption: " ".join(caption)).str.cat(sep=' ').split(' ')
+
+ # Build the vocabulary, ordered from most to least frequent word
+ word_counts = Counter(words)
+ word_dict = sorted(word_counts, key=word_counts.get, reverse=True)
+
+ dict_size = len(word_dict)
+ vocab_threshold = 5
+
+ # Encode each caption as a sequence of vocabulary indices
+ data['sequence'] = data['caption'].apply(lambda caption: [word_dict.index(word) for word in caption])
+ data = data.sort_values(by='image')
+
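+ # Standard sinusoidal positional encoding (Vaswani et al., "Attention Is All You Need"):
+ # PE[pos, 2i] = sin(pos / 10000^(2i/d_model)), PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model)).
+ # The encodings are precomputed once, stored as a buffer, and added to the token embeddings.
+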
+ class PositionalEncoding(nn.Module):
+
+     def __init__(self, d_model, dropout=0.1, max_len=(max_length + 2)):
+         super(PositionalEncoding, self).__init__()
+         self.dropout = nn.Dropout(p=dropout)
+
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0)
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         if self.pe.size(0) < x.size(0):
+             self.pe = self.pe.repeat(x.size(0), 1, 1)
+         self.pe = self.pe[:x.size(0), :, :]
+
+         x = x + self.pe
+         return self.dropout(x)
+
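+ # Decoder-only captioning model: pre-extracted image feature embeddings are fed to a
+ # TransformerDecoder as "memory", and the caption tokens come from an Embedding plus
+ # PositionalEncoding input, with a final Linear projection back to the vocabulary.
+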
+ class ImageCaptionModel(nn.Module):
+     def __init__(self, n_head, n_decoder_layer, vocab_size, embedding_size):
+         super(ImageCaptionModel, self).__init__()
+         self.pos_encoder = PositionalEncoding(embedding_size, 0.1)
+         self.TransformerDecoderLayer = nn.TransformerDecoderLayer(d_model=embedding_size, nhead=n_head)
+         self.TransformerDecoder = nn.TransformerDecoder(decoder_layer=self.TransformerDecoderLayer, num_layers=n_decoder_layer)
+         self.embedding_size = embedding_size
+         self.embedding = nn.Embedding(vocab_size, embedding_size)
+         self.last_linear_layer = nn.Linear(embedding_size, vocab_size)
+         self.init_weights()
+         self.n_head = n_head
+
+     def init_weights(self):
+         initrange = 0.1
+         self.embedding.weight.data.uniform_(-initrange, initrange)
+         self.last_linear_layer.bias.data.zero_()
+         self.last_linear_layer.weight.data.uniform_(-initrange, initrange)
+
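+     # generate_Mask builds two masks for the decoder: a causal mask (zeros on and below the
+     # diagonal, -inf above it) so each position can only attend to itself and earlier tokens,
+     # and a padding mask flagging positions whose token id is 0 so attention ignores them.
+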
+     def generate_Mask(self, size, decoder_inp):
+         decoder_input_mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
+         decoder_input_mask = decoder_input_mask.float().masked_fill(decoder_input_mask == 0, float('-inf')).masked_fill(decoder_input_mask == 1, float(0.0))
+
+         decoder_input_pad_mask = decoder_inp.float().masked_fill(decoder_inp == 0, float(0.0)).masked_fill(decoder_inp > 0, float(1.0))
+         decoder_input_pad_mask_bool = decoder_inp == 0
+
+         return decoder_input_mask, decoder_input_pad_mask, decoder_input_pad_mask_bool
+
+     def forward(self, encoded_image, decoder_inp):
+         # nn.TransformerDecoder expects (seq, batch, feature), so move the sequence dimension first
+         encoded_image = encoded_image.permute(1, 0, 2)
+
+         # Keep token ids inside the embedding table's range
+         decoder_inp = torch.clamp(decoder_inp, 0, self.embedding.num_embeddings - 1)
+
+         # Embed the caption tokens and add positional information
+         decoder_inp_embed = self.embedding(decoder_inp)
+         decoder_inp_embed = self.pos_encoder(decoder_inp_embed)
+         decoder_inp_embed = decoder_inp_embed.permute(1, 0, 2)
+
+         decoder_input_mask, decoder_input_pad_mask, decoder_input_pad_mask_bool = self.generate_Mask(decoder_inp.size(1), decoder_inp)
+
+         decoder_output = self.TransformerDecoder(tgt=decoder_inp_embed, memory=encoded_image, tgt_mask=decoder_input_mask, tgt_key_padding_mask=decoder_input_pad_mask_bool)
+
+         final_output = self.last_linear_layer(decoder_output)
+
+         return final_output, decoder_input_pad_mask
+
+
+ model = pd.read_pickle('ImageCaptioning_Model.pkl')
+ model.eval()
+ start_token = 2
+ end_token = 3
+ cell_token = 1
+ max_seq_len = 34
+
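+ # The hard-coded ids above are assumed to be the positions of <start>, <end> and <cell> in the
+ # frequency-sorted word_dict built from captions.txt; if the captions file changes, these ids
+ # (and max_seq_len) would need to be re-derived.
+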
+ validation = pd.read_pickle('Image_Features_Embed_ResNet_Valid.pkl')
+
+ def process_image_from_zip(zip_path, image_name):
+     with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+         with zip_ref.open(image_name) as file:
+             # Use BytesIO to keep the file-like object open
+             image_data = io.BytesIO(file.read())
+             image = Image.open(image_data)
+             image = image.convert("RGB")  # Convert to RGB if needed
+             return image
+
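+ # Caption generation: starting from <start>, the decoder is run once per position and the next
+ # token is drawn from the top-K logits (weighted by their values via random.choices), so K=1
+ # amounts to greedy decoding. The image and both captions are shown on the Streamlit page.
+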
+ def generate_caption(K, image_path):
+
+     model.eval()
+     image_zip_path = 'Images/' + image_path
+     # image = Image.open('Images/' + image_path).convert("RGB")
+     image = process_image_from_zip('Images.zip', image_zip_path)
+     st.image(image)
+
+     valid_img_df = validation[validation['image'] == image_path]
+     st.write("Actual Caption:")
+     actual_caption_list = valid_img_df['caption'].tolist()
+     filtered_caption_list = [word for word in actual_caption_list[0] if word not in ['<start>', '<end>', '<cell>']]
+     actual_caption = " ".join(filtered_caption_list)
+     st.write(actual_caption)
+
+     img_embed = valid_img_df['embedded'].tolist()
+     img_embed = torch.tensor(img_embed)
+
+     # Decoder input starts as <start> followed by <cell> padding
+     input_seq = [cell_token] * max_seq_len
+     input_seq[0] = start_token
+
+     input_seq = torch.tensor(input_seq).unsqueeze(0)
+     predicted_sentence = []
+
+     with torch.no_grad():
+         for eval_iter in range(0, max_seq_len):
+             output, padding_mask = model.forward(img_embed, input_seq)
+
+             # Logits for the position generated at this step
+             output = output[eval_iter, 0, :]
+
+             values = torch.topk(output, K).values.tolist()
+             indices = torch.topk(output, K).indices.tolist()
+
+             next_word_index = random.choices(indices, values, k=1)[0]
+
+             index_to_word = {index: word for index, word in enumerate(word_dict)}
+             next_word = index_to_word[next_word_index]
+
+             if eval_iter + 1 < max_seq_len:
+                 input_seq[:, eval_iter + 1] = next_word_index
+
+             if next_word == '<end>':
+                 break
+
+             predicted_sentence.append(next_word)
+
+     st.write("Predicted Caption:")
+
+     filtered_caption_list = [word for word in predicted_sentence if word not in ['<start>', '<end>', '<cell>']]
+     st.write(" ".join(filtered_caption_list))
+
+
+ st.title('Image Captioning')
+ st.write('Generate a caption for a random image')
+ generate_caption_button = st.button('Generate Caption')
+
+ if generate_caption_button:
+     try:
+         random_row = validation.sample()
+         random_image = random_row.iloc[0]['image']
+
+         generate_caption(1, random_image)
+
+     except RuntimeError as e:
+         st.error("TRY AGAIN")
+
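+ # To try the app locally (with the data files listed at the top in place): streamlit run app.py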