GOKULSINGHSHAH123 commited on
Commit
d4c25b8
1 Parent(s): 9997505

Upload 3 files

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. app.py +53 -0
  3. model.py +357 -0
  4. vocab_coco.file +3 -0
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  Image_caption/vocab_coco.file filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  Image_caption/vocab_coco.file filter=lfs diff=lfs merge=lfs -text
37
+ vocab_coco.file filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import streamlit as st
4
+ import requests
5
+ from PIL import Image
6
+ from model import get_caption_model, generate_caption
7
+
8
+
9
+ @st.cache(allow_output_mutation=True)
10
+ def get_model():
11
+ return get_caption_model()
12
+
13
+ caption_model = get_model()
14
+
15
+
16
+ def predict():
17
+ captions = []
18
+ pred_caption = generate_caption('tmp.jpg', caption_model)
19
+
20
+ st.markdown('#### Predicted Captions:')
21
+ captions.append(pred_caption)
22
+
23
+ for _ in range(4):
24
+ pred_caption = generate_caption('tmp.jpg', caption_model, add_noise=True)
25
+ if pred_caption not in captions:
26
+ captions.append(pred_caption)
27
+
28
+ for c in captions:
29
+ st.write(c)
30
+
31
+ st.title('Image Captioner')
32
+ img_url = st.text_input(label='Enter Image URL')
33
+
34
+ if (img_url != "") and (img_url != None):
35
+ img = Image.open(requests.get(img_url, stream=True).raw)
36
+ img = img.convert('RGB')
37
+ st.image(img)
38
+ img.save('tmp.jpg')
39
+ predict()
40
+ os.remove('tmp.jpg')
41
+
42
+
43
+ st.markdown('<center style="opacity: 70%">OR</center>', unsafe_allow_html=True)
44
+ img_upload = st.file_uploader(label='Upload Image', type=['jpg', 'png', 'jpeg'])
45
+
46
+ if img_upload != None:
47
+ img = img_upload.read()
48
+ img = Image.open(io.BytesIO(img))
49
+ img = img.convert('RGB')
50
+ img.save('tmp.jpg')
51
+ st.image(img)
52
+ predict()
53
+ os.remove('tmp.jpg')
model.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ from PIL import Image
3
+ import tensorflow as tf
4
+ import requests
5
+ import numpy as np
6
+
7
+ vocab = pickle.load(open('vocab_coco.file', 'rb'))
8
+
9
+ word = "cat"
10
+ MAX_LENGTH = 40
11
+ VOCABULARY_SIZE = 15000
12
+ BATCH_SIZE = 64
13
+ BUFFER_SIZE = 1000
14
+ EMBEDDING_DIM = 512
15
+ UNITS = 512
16
+
17
+ # Tokenize the word using the adapted TextVectorization layer
18
+ tokenizer = tf.keras.layers.TextVectorization(
19
+ standardize=None,
20
+ output_sequence_length=40,
21
+ vocabulary=vocab)
22
+
23
+ # Convert the tokenized word to a numpy array
24
+ tokenized_word = tokenizer([word])
25
+ tokenized_word = tokenized_word.numpy()
26
+
27
+ # Print the tokenized word
28
+ print("Tokenized word:", tokenized_word)
29
+
30
+ idx2word = tf.keras.layers.StringLookup(
31
+ mask_token="",
32
+ vocabulary=tokenizer.get_vocabulary(),
33
+ invert=True)
34
+
35
+ def load_image_from_path(img_path):
36
+ img = tf.io.read_file(img_path)
37
+ img = tf.io.decode_jpeg(img, channels=3)
38
+ img = tf.keras.layers.Resizing(299, 299)(img)
39
+ img = tf.cast(img, tf.float32) # Convert to tf.float32
40
+ img = tf.keras.applications.inception_v3.preprocess_input(img)
41
+ return img
42
+
43
+ image_augmentation = tf.keras.Sequential(
44
+ [
45
+ tf.keras.layers.RandomFlip("horizontal"),
46
+ tf.keras.layers.RandomRotation(0.2),
47
+ tf.keras.layers.RandomContrast(0.3),
48
+ ]
49
+ )
50
+
51
+ def CNN_Encoder():
52
+ inception_v3 = tf.keras.applications.InceptionV3(
53
+ include_top=False,
54
+ weights='imagenet'
55
+ )
56
+
57
+ output = inception_v3.output
58
+ output = tf.keras.layers.Reshape(
59
+ (-1, output.shape[-1]))(output)
60
+
61
+ cnn_model = tf.keras.models.Model(inception_v3.input, output)
62
+ return cnn_model
63
+
64
+
65
+ class TransformerEncoderLayer(tf.keras.layers.Layer):
66
+
67
+ def __init__(self, embed_dim, num_heads):
68
+ super().__init__()
69
+ self.layer_norm_1 = tf.keras.layers.LayerNormalization()
70
+ self.layer_norm_2 = tf.keras.layers.LayerNormalization()
71
+ self.attention = tf.keras.layers.MultiHeadAttention(
72
+ num_heads=num_heads, key_dim=embed_dim)
73
+ self.dense = tf.keras.layers.Dense(embed_dim, activation="relu")
74
+
75
+
76
+ def call(self, x, training):
77
+ x = self.layer_norm_1(x)
78
+ x = self.dense(x)
79
+
80
+ attn_output = self.attention(
81
+ query=x,
82
+ value=x,
83
+ key=x,
84
+ attention_mask=None,
85
+ training=training
86
+ )
87
+
88
+ x = self.layer_norm_2(x + attn_output)
89
+ return x
90
+
91
+
92
+ class Embeddings(tf.keras.layers.Layer):
93
+
94
+ def __init__(self, vocab_size, embed_dim, max_len):
95
+ super().__init__()
96
+ self.token_embeddings = tf.keras.layers.Embedding(
97
+ vocab_size, embed_dim)
98
+ self.position_embeddings = tf.keras.layers.Embedding(
99
+ max_len, embed_dim, input_shape=(None, max_len))
100
+
101
+
102
+ def call(self, input_ids):
103
+ length = tf.shape(input_ids)[-1]
104
+ position_ids = tf.range(start=0, limit=length, delta=1)
105
+ position_ids = tf.expand_dims(position_ids, axis=0)
106
+
107
+ token_embeddings = self.token_embeddings(input_ids)
108
+ position_embeddings = self.position_embeddings(position_ids)
109
+
110
+ return token_embeddings + position_embeddings
111
+
112
+ class TransformerDecoderLayer(tf.keras.layers.Layer):
113
+
114
+ def __init__(self, embed_dim, units, num_heads):
115
+ super().__init__()
116
+ self.embedding = Embeddings(
117
+ tokenizer.vocabulary_size(), embed_dim, MAX_LENGTH)
118
+
119
+ self.attention_1 = tf.keras.layers.MultiHeadAttention(
120
+ num_heads=num_heads, key_dim=embed_dim, dropout=0.1
121
+ )
122
+ self.attention_2 = tf.keras.layers.MultiHeadAttention(
123
+ num_heads=num_heads, key_dim=embed_dim, dropout=0.1
124
+ )
125
+
126
+ self.layernorm_1 = tf.keras.layers.LayerNormalization()
127
+ self.layernorm_2 = tf.keras.layers.LayerNormalization()
128
+ self.layernorm_3 = tf.keras.layers.LayerNormalization()
129
+
130
+ self.ffn_layer_1 = tf.keras.layers.Dense(units, activation="relu")
131
+ self.ffn_layer_2 = tf.keras.layers.Dense(embed_dim)
132
+
133
+ self.out = tf.keras.layers.Dense(tokenizer.vocabulary_size(), activation="softmax")
134
+
135
+ self.dropout_1 = tf.keras.layers.Dropout(0.3)
136
+ self.dropout_2 = tf.keras.layers.Dropout(0.5)
137
+
138
+
139
+ def call(self, input_ids, encoder_output, training, mask=None):
140
+ embeddings = self.embedding(input_ids)
141
+
142
+ combined_mask = None
143
+ padding_mask = None
144
+
145
+ if mask is not None:
146
+ causal_mask = self.get_causal_attention_mask(embeddings)
147
+ padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
148
+ combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
149
+ combined_mask = tf.minimum(combined_mask, causal_mask)
150
+
151
+ attn_output_1 = self.attention_1(
152
+ query=embeddings,
153
+ value=embeddings,
154
+ key=embeddings,
155
+ attention_mask=combined_mask,
156
+ training=training
157
+ )
158
+
159
+ out_1 = self.layernorm_1(embeddings + attn_output_1)
160
+
161
+ attn_output_2 = self.attention_2(
162
+ query=out_1,
163
+ value=encoder_output,
164
+ key=encoder_output,
165
+ attention_mask=padding_mask,
166
+ training=training
167
+ )
168
+
169
+ out_2 = self.layernorm_2(out_1 + attn_output_2)
170
+
171
+ ffn_out = self.ffn_layer_1(out_2)
172
+ ffn_out = self.dropout_1(ffn_out, training=training)
173
+ ffn_out = self.ffn_layer_2(ffn_out)
174
+
175
+ ffn_out = self.layernorm_3(ffn_out + out_2)
176
+ ffn_out = self.dropout_2(ffn_out, training=training)
177
+ preds = self.out(ffn_out)
178
+ return preds
179
+
180
+
181
+ def get_causal_attention_mask(self, inputs):
182
+ input_shape = tf.shape(inputs)
183
+ batch_size, sequence_length = input_shape[0], input_shape[1]
184
+ i = tf.range(sequence_length)[:, tf.newaxis]
185
+ j = tf.range(sequence_length)
186
+ mask = tf.cast(i >= j, dtype="int32")
187
+ mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
188
+ mult = tf.concat(
189
+ [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
190
+ axis=0
191
+ )
192
+ return tf.tile(mask, mult)
193
+
194
+ class ImageCaptioningModel(tf.keras.Model):
195
+
196
+ def __init__(self, cnn_model, encoder, decoder, image_aug=None):
197
+ super().__init__()
198
+ self.cnn_model = cnn_model
199
+ self.encoder = encoder
200
+ self.decoder = decoder
201
+ self.image_aug = image_aug
202
+ self.loss_tracker = tf.keras.metrics.Mean(name="loss")
203
+ self.acc_tracker = tf.keras.metrics.Mean(name="accuracy")
204
+
205
+
206
+ def calculate_loss(self, y_true, y_pred, mask):
207
+ loss = self.loss(y_true, y_pred)
208
+ mask = tf.cast(mask, dtype=loss.dtype)
209
+ loss *= mask
210
+ return tf.reduce_sum(loss) / tf.reduce_sum(mask)
211
+
212
+
213
+ def calculate_accuracy(self, y_true, y_pred, mask):
214
+ accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
215
+ accuracy = tf.math.logical_and(mask, accuracy)
216
+ accuracy = tf.cast(accuracy, dtype=tf.float32)
217
+ mask = tf.cast(mask, dtype=tf.float32)
218
+ return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)
219
+
220
+
221
+ def compute_loss_and_acc(self, img_embed, captions, training=True):
222
+ encoder_output = self.encoder(img_embed, training=True)
223
+ y_input = captions[:, :-1]
224
+ y_true = captions[:, 1:]
225
+ mask = (y_true != 0)
226
+ y_pred = self.decoder(
227
+ y_input, encoder_output, training=True, mask=mask
228
+ )
229
+ loss = self.calculate_loss(y_true, y_pred, mask)
230
+ acc = self.calculate_accuracy(y_true, y_pred, mask)
231
+ return loss, acc
232
+
233
+
234
+ def train_step(self, batch):
235
+ imgs, captions = batch
236
+
237
+ if self.image_aug:
238
+ imgs = self.image_aug(imgs)
239
+
240
+ img_embed = self.cnn_model(imgs)
241
+
242
+ with tf.GradientTape() as tape:
243
+ loss, acc = self.compute_loss_and_acc(
244
+ img_embed, captions
245
+ )
246
+
247
+ train_vars = (
248
+ self.encoder.trainable_variables + self.decoder.trainable_variables
249
+ )
250
+ grads = tape.gradient(loss, train_vars)
251
+ self.optimizer.apply_gradients(zip(grads, train_vars))
252
+ self.loss_tracker.update_state(loss)
253
+ self.acc_tracker.update_state(acc)
254
+
255
+ return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}
256
+
257
+
258
+ def test_step(self, batch):
259
+ imgs, captions = batch
260
+
261
+ img_embed = self.cnn_model(imgs)
262
+
263
+ loss, acc = self.compute_loss_and_acc(
264
+ img_embed, captions, training=False
265
+ )
266
+
267
+ self.loss_tracker.update_state(loss)
268
+ self.acc_tracker.update_state(acc)
269
+
270
+ return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}
271
+
272
+ @property
273
+ def metrics(self):
274
+ return [self.loss_tracker, self.acc_tracker]
275
+
276
+ encoder = TransformerEncoderLayer(EMBEDDING_DIM, 1)
277
+ decoder = TransformerDecoderLayer(EMBEDDING_DIM, UNITS, 8)
278
+
279
+ cnn_model = CNN_Encoder()
280
+ caption_model = ImageCaptioningModel(
281
+ cnn_model=cnn_model, encoder=encoder, decoder=decoder, image_aug=image_augmentation,
282
+ )
283
+
284
+ def get_caption_model():
285
+ encoder = TransformerEncoderLayer(EMBEDDING_DIM, 1)
286
+ decoder = TransformerDecoderLayer(EMBEDDING_DIM, UNITS, 8)
287
+
288
+ cnn_model = CNN_Encoder()
289
+
290
+ caption_model = ImageCaptioningModel(
291
+ cnn_model=cnn_model, encoder=encoder, decoder=decoder, image_aug=None,
292
+ )
293
+
294
+ def call_fn(batch, training=False):
295
+ return batch
296
+
297
+ caption_model.call = call_fn
298
+ sample_x, sample_y = tf.random.normal((1, 299, 299, 3)), tf.zeros((1, 40))
299
+
300
+ caption_model((sample_x, sample_y))
301
+
302
+ sample_img_embed = caption_model.cnn_model(sample_x)
303
+ sample_enc_out = caption_model.encoder(sample_img_embed, training=False)
304
+ caption_model.decoder(sample_y, sample_enc_out, training=False)
305
+
306
+ try:
307
+ caption_model.load_weights('model.h5')
308
+ except FileNotFoundError:
309
+ caption_model.load_weights('model.h5')
310
+
311
+ return caption_model
312
+
313
+
314
+ def get_model():
315
+ return get_caption_model()
316
+
317
+ caption_model = get_model()
318
+
319
+
320
+ def load_image_from_path(img_path):
321
+ img = tf.io.read_file(img_path)
322
+ img = tf.io.decode_jpeg(img, channels=3)
323
+ img = tf.keras.layers.Resizing(299, 299)(img)
324
+ img = tf.cast(img, tf.float32) # Convert to tf.float32
325
+ img = tf.keras.applications.inception_v3.preprocess_input(img)
326
+ return img
327
+
328
+ def generate_caption(img_path,caption_model, add_noise=False):
329
+ img = load_image_from_path(img_path)
330
+
331
+ if add_noise:
332
+ noise = tf.random.normal(img.shape)*0.1
333
+ img = img + noise
334
+ img = (img - tf.reduce_min(img))/(tf.reduce_max(img) - tf.reduce_min(img))
335
+
336
+ img = tf.expand_dims(img, axis=0)
337
+ img_embed = caption_model.cnn_model(img)
338
+ img_encoded = caption_model.encoder(img_embed, training=False)
339
+
340
+ y_inp = '[start]'
341
+ for i in range(MAX_LENGTH-1):
342
+ tokenized = tokenizer([y_inp])[:, :-1]
343
+ mask = tf.cast(tokenized != 0, tf.int32)
344
+ pred = caption_model.decoder(
345
+ tokenized, img_encoded, training=False, mask=mask)
346
+
347
+ pred_idx = np.argmax(pred[0, i, :])
348
+ pred_idx = tf.convert_to_tensor(pred_idx)
349
+ pred_word = idx2word(pred_idx).numpy().decode('utf-8')
350
+ if pred_word == '[end]':
351
+ break
352
+
353
+ y_inp += ' ' + pred_word
354
+
355
+ y_inp = y_inp.replace('[start] ', '')
356
+ return y_inp
357
+
vocab_coco.file ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db3679ac5eae9c774916e24b87704dad600dcd230808a5935a09b7abf189495b
3
+ size 1350141