Juliofc commited on
Commit
8c04c2d
·
1 Parent(s): 3781482

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +210 -0
app.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from torch.utils.data import DataLoader
3
+ from torchvision import transforms
4
+ from tqdm.notebook import tqdm
5
+ import torch
6
+ from torch.autograd import Variable
7
+ import torchvision
8
+ import pickle
9
+ from PIL import Image
10
+ import torch.nn as nn
11
+ import math
12
+ import random
13
+ import gradio as gr
14
+ device = "cpu"
15
+ max_seq_len=67
16
+ with open('index_to_word.pkl', 'rb') as handle:
17
+ index_to_word = pickle.load(handle)
18
+ with open('word_to_index.pkl', 'rb') as handle:
19
+ word_to_index = pickle.load(handle)
20
+
21
+ resnet18 = torchvision.models.resnet18(pretrained=True).to(device)
22
+ resnet18.eval()
23
+ resNet18Layer4 = resnet18._modules.get('layer4').to(device)
24
+
25
+ def create_df(img):
26
+ df = pd.DataFrame({"image": [img]})
27
+ return df
28
+
29
+ def get_vector(t_img):
30
+
31
+ t_img = Variable(t_img)
32
+ my_embedding = torch.zeros(1, 512, 7, 7)
33
+ def copy_data(m, i, o):
34
+ my_embedding.copy_(o.data)
35
+
36
+ h = resNet18Layer4.register_forward_hook(copy_data)
37
+ resnet18(t_img)
38
+
39
+ h.remove()
40
+ return my_embedding
41
+
42
+ class extractImageFeatureResNetDataSet():
43
+ from PIL import Image
44
+ def __init__(self, data):
45
+ self.data = data
46
+ self.scaler = transforms.Resize([224, 224])
47
+ self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
48
+ std=[0.229, 0.224, 0.225])
49
+ self.to_tensor = transforms.ToTensor()
50
+ def __len__(self):
51
+ return len(self.data)
52
+
53
+ def __getitem__(self, idx):
54
+
55
+ image_name = self.data.iloc[idx]['image']
56
+ img_loc = str(image_name) #os.getcwd()+'/imput_img/'+str(image_name)
57
+ img = Image.open(img_loc)
58
+ t_img = self.normalize(self.to_tensor(self.scaler(img)))
59
+
60
+ return image_name, t_img
61
+
62
+ def feature_exctractor(df):
63
+ extract_imgFtr_ResNet_input = {}
64
+ input_ImageDataset_ResNet = extractImageFeatureResNetDataSet(df[['image']])
65
+ input_ImageDataloader_ResNet = DataLoader(input_ImageDataset_ResNet, batch_size = 1, shuffle=False)
66
+ for image_name, t_img in tqdm(input_ImageDataloader_ResNet):
67
+ t_img = t_img.to("cpu")
68
+ embdg = get_vector(t_img)
69
+ extract_imgFtr_ResNet_input[image_name[0]] = embdg
70
+ return extract_imgFtr_ResNet_input
71
+
72
+ class PositionalEncoding(nn.Module):
73
+
74
+ def __init__(self, d_model, dropout=0.1, max_len=max_seq_len):
75
+ super(PositionalEncoding, self).__init__()
76
+ self.dropout = nn.Dropout(p=dropout)
77
+
78
+ pe = torch.zeros(max_len, d_model)
79
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
80
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
81
+ pe[:, 0::2] = torch.sin(position * div_term)
82
+ pe[:, 1::2] = torch.cos(position * div_term)
83
+ pe = pe.unsqueeze(0)
84
+ self.register_buffer('pe', pe)
85
+
86
+
87
+ def forward(self, x):
88
+ if self.pe.size(0) < x.size(0):
89
+ self.pe = self.pe.repeat(x.size(0), 1, 1).to(device)
90
+ self.pe = self.pe[:x.size(0), : , : ]
91
+
92
+ x = x + self.pe
93
+ return self.dropout(x)
94
+
95
+ class ImageCaptionModel(nn.Module):
96
+ def __init__(self, n_head, n_decoder_layer, vocab_size, embedding_size):
97
+ super(ImageCaptionModel, self).__init__()
98
+ self.pos_encoder = PositionalEncoding(embedding_size, 0.1)
99
+ self.TransformerDecoderLayer = nn.TransformerDecoderLayer(d_model = embedding_size, nhead = n_head)
100
+ self.TransformerDecoder = nn.TransformerDecoder(decoder_layer = self.TransformerDecoderLayer, num_layers = n_decoder_layer)
101
+ self.embedding_size = embedding_size
102
+ self.embedding = nn.Embedding(vocab_size , embedding_size)
103
+ self.last_linear_layer = nn.Linear(embedding_size, vocab_size)
104
+ self.init_weights()
105
+
106
+ def init_weights(self):
107
+ initrange = 0.1
108
+ self.embedding.weight.data.uniform_(-initrange, initrange)
109
+ self.last_linear_layer.bias.data.zero_()
110
+ self.last_linear_layer.weight.data.uniform_(-initrange, initrange)
111
+
112
+ def generate_Mask(self, size, decoder_inp):
113
+ decoder_input_mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
114
+ decoder_input_mask = decoder_input_mask.float().masked_fill(decoder_input_mask == 0, float('-inf')).masked_fill(decoder_input_mask == 1, float(0.0))
115
+
116
+ decoder_input_pad_mask = decoder_inp.float().masked_fill(decoder_inp == 0, float(0.0)).masked_fill(decoder_inp > 0, float(1.0))
117
+ decoder_input_pad_mask_bool = decoder_inp == 0
118
+
119
+ return decoder_input_mask, decoder_input_pad_mask, decoder_input_pad_mask_bool
120
+
121
+ def forward(self, encoded_image, decoder_inp):
122
+ encoded_image = encoded_image.permute(1,0,2)
123
+
124
+
125
+ decoder_inp_embed = self.embedding(decoder_inp)* math.sqrt(self.embedding_size)
126
+
127
+ decoder_inp_embed = self.pos_encoder(decoder_inp_embed)
128
+ decoder_inp_embed = decoder_inp_embed.permute(1,0,2)
129
+
130
+
131
+ decoder_input_mask, decoder_input_pad_mask, decoder_input_pad_mask_bool = self.generate_Mask(decoder_inp.size(1), decoder_inp)
132
+ decoder_input_mask = decoder_input_mask.to(device)
133
+ decoder_input_pad_mask = decoder_input_pad_mask.to(device)
134
+ decoder_input_pad_mask_bool = decoder_input_pad_mask_bool.to(device)
135
+
136
+
137
+ decoder_output = self.TransformerDecoder(tgt = decoder_inp_embed, memory = encoded_image, tgt_mask = decoder_input_mask, tgt_key_padding_mask = decoder_input_pad_mask_bool)
138
+
139
+ final_output = self.last_linear_layer(decoder_output)
140
+
141
+ return final_output, decoder_input_pad_mask
142
+
143
+
144
+ def generate_caption(K, img_nm, extract_imgFtr_ResNet_input):
145
+ from PIL import Image
146
+ img_loc = str(img_nm)#os.getcwd()+'/imput_img/'+
147
+ image = Image.open(img_loc).convert("RGB")
148
+ #plt.imshow(image)
149
+
150
+ model.eval()
151
+ img_embed = extract_imgFtr_ResNet_input[img_nm].to(device)
152
+
153
+
154
+ img_embed = img_embed.permute(0,2,3,1)
155
+ img_embed = img_embed.view(img_embed.size(0), -1, img_embed.size(3))
156
+
157
+
158
+ input_seq = [pad_token]*max_seq_len
159
+ input_seq[0] = start_token
160
+
161
+ input_seq = torch.tensor(input_seq).unsqueeze(0).to(device)
162
+ predicted_sentence = []
163
+ with torch.no_grad():
164
+ for eval_iter in range(0, max_seq_len):
165
+ output, padding_mask = model.forward(img_embed, input_seq)
166
+
167
+ output = output[eval_iter, 0, :]
168
+
169
+ values = torch.topk(output, K).values.tolist()
170
+ indices = torch.topk(output, K).indices.tolist()
171
+
172
+ next_word_index = random.choices(indices, values, k = 1)[0]
173
+
174
+ next_word = index_to_word[next_word_index]
175
+
176
+ input_seq[:, eval_iter+1] = next_word_index
177
+
178
+
179
+ if next_word == '<end>' :
180
+ break
181
+
182
+ predicted_sentence.append(next_word)
183
+ return " ".join(predicted_sentence + ["."])
184
+
185
+ device = torch.device('cpu')
186
+ model = torch.load('./BestModel_20000_Datos', map_location=device)
187
+ start_token = word_to_index['<start>']
188
+ end_token = word_to_index['<end>']
189
+ pad_token = word_to_index['<pad>']
190
+ max_seq_len = 67
191
+
192
+ def predict(inp):
193
+ device = "cpu"
194
+ max_seq_len=67
195
+ with open('index_to_word.pkl', 'rb') as handle:
196
+ index_to_word = pickle.load(handle)
197
+ with open('word_to_index.pkl', 'rb') as handle:
198
+ word_to_index = pickle.load(handle)
199
+
200
+ resnet18 = torchvision.models.resnet18(pretrained=True).to(device)
201
+ resnet18.eval()
202
+ resNet18Layer4 = resnet18._modules.get('layer4').to(device)
203
+ df = create_df(inp)
204
+ extract_imgFtr_ResNet_input = feature_exctractor(df)
205
+ prediction = generate_caption(1, inp, extract_imgFtr_ResNet_input)
206
+ return prediction
207
+
208
+ gr.Interface(fn=predict,
209
+ inputs=gr.Image(type="filepath"),
210
+ outputs=gr.Text()).launch(debug=True)