sonamsherpa committed
Commit
27afece
1 Parent(s): dbebdf9

Update app.py

Files changed (1)
  1. app.py +685 -2
app.py CHANGED
@@ -1,3 +1,686 @@
- import pandas as pd
-
- print('Hello World!')
+ # -*- coding: utf-8 -*-
+ """Copy of Copy of Imagecaption_generator_AIML.ipynb
+
+ Automatically generated by Colaboratory.
+
+ Original file is located at
+     https://colab.research.google.com/drive/1Thp1MpIDt-AnhXifbSu-AeGQRI8iR3-E
+ """
+
+ !pip install wget
+
+ import os
+ import re
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import requests
+ import tensorflow as tf
+ from tensorflow import keras
+ from tensorflow.keras import layers
+ import shutil
+ from tensorflow.keras.applications import efficientnet
+ import wget
+ from tensorflow.keras.layers import TextVectorization
+
+
+ seed = 111
+ np.random.seed(seed)
+ tf.random.set_seed(seed)
+
+ !wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip
+ !wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip
+ !unzip -qq Flickr8k_Dataset.zip
+ !unzip -qq Flickr8k_text.zip
+ !rm Flickr8k_Dataset.zip Flickr8k_text.zip
+
+ # Desired image dimensions
+ image_size = (299, 299)
+
+ # Vocabulary size
+ vocabulary_size = 10000
+
+ # Fixed length allowed for any sequence
+ sequence_length = 25
+
+ # Dimension for the image embeddings and token embeddings,
+ # also used as the per-layer units in the feed-forward network
+ embedded_dimension = feed_forward_dimension = EMBED_DIM = 512
+
+ # Other training parameters
+ batch_size = 64
+ epochs = 30
+ autotune = tf.data.AUTOTUNE
+
+ def map_image_caption(filename):
+     '''
+     Loads captions and maps each caption to its respective image.
+     Returns: a dictionary mapping each image name to its captions, and a list containing all the captions.
+     '''
+
+     with open(filename) as caption_file:
+         caption_data = caption_file.readlines()
+         mapped_captions = {}
+         text_data = []
+         skip_these_images = set()
+
+         for c_data in caption_data:
+             # The image name and its caption are separated by a tab, so split them into separate variables
+             image_name, caption = c_data.strip("\n").split("\t")
+             caption = caption.strip()
+
+             # There are 5 captions for each image, and each image name carries the suffix '#(caption_number)', so remove everything after '#' and strip any whitespace
+             image_name = os.path.join('Flicker8k_Dataset', image_name.split("#")[0].strip())
+
+             # We will drop captions that are either too short or too long
+             tokens = caption.strip().split()
+
+             if len(tokens) < 5 or len(tokens) > sequence_length:
+                 skip_these_images.add(image_name)
+                 continue
+
+             if image_name.endswith("jpg") and image_name not in skip_these_images:
+                 # Add start and end tags to mark the beginning and end of each caption
+                 text_data.append("<start> " + caption + " <end>")
+
+                 if image_name in mapped_captions:
+                     mapped_captions[image_name].append(caption)
+                 else:
+                     mapped_captions[image_name] = [caption]
+
+         for image_name in skip_these_images:
+             if image_name in mapped_captions:
+                 del mapped_captions[image_name]
+
+         return mapped_captions, text_data
+
+ def train_val_split(caption_data):
+     '''
+     Splits the caption data into training and validation sets.
+     '''
+     train_size = 0.8
+
+     # Get the list of image names
+     list_of_images = list(caption_data.keys())
+
+     # Shuffle for randomness
+     np.random.shuffle(list_of_images)
+
+     # Split the data into training and validation sets
+     train_size = int(len(caption_data) * train_size)
+
+     train_data = {
+         name: caption_data[name] for name in list_of_images[:train_size]
+     }
+     test_data = {
+         name: caption_data[name] for name in list_of_images[train_size:]
+     }
+
+     return train_data, test_data
+
+ # Load the dataset
+ captions_mapping, text_data = map_image_caption("Flickr8k.token.txt")
+
+ # Split the dataset into training and validation sets
+ training_data, validation_data = train_val_split(captions_mapping)
+ print("Number of training samples: ", len(training_data))
+ print("Number of validation samples: ", len(validation_data))
+
+ def standardize(input_string):
+     strip_chars = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~".replace("<", "").replace(">", "")
+     return tf.strings.regex_replace(tf.strings.lower(input_string), "[%s]" % re.escape(strip_chars), "")
+
+ vectorization = TextVectorization(
+     max_tokens=vocabulary_size,
+     output_mode="int",
+     output_sequence_length=sequence_length,
+     standardize=standardize,
+ )
+ vectorization.adapt(text_data)
+
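+ # Note: `standardize` lowercases each caption and strips punctuation, but the two
+ # replace() calls drop "<" and ">" from the strip list so the "<start>" and "<end>"
+ # tokens survive vectorization intact.
+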
+ # Data augmentation for image data
+ image_augmentation = keras.Sequential(
+     [
+         layers.RandomFlip("horizontal"),
+         layers.RandomRotation(0.2),
+         layers.RandomContrast(0.3),
+     ]
+ )
+
+ def decoder_to_resizer(img_path):
+     '''
+     Decodes a JPEG image, resizes it to the standard size, and converts it to float32 for processing.
+     '''
+     image = tf.io.read_file(img_path)
+     decoded_image = tf.image.decode_jpeg(image, channels=3)
+     resized_image = tf.image.resize(decoded_image, image_size)
+     return tf.image.convert_image_dtype(resized_image, tf.float32)
+
+ def process_input(img_path, captions):
+     '''
+     Returns the decoded and resized image as float32, together with the vectorized captions.
+     '''
+     return decoder_to_resizer(img_path), vectorization(captions)
+
+ def prepare_dataset(images, captions):
+     dataset = tf.data.Dataset.from_tensor_slices((images, captions))
+     return dataset.shuffle(batch_size * 8).map(process_input, num_parallel_calls=autotune).batch(batch_size).prefetch(autotune)
+
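+ # Note: prepare_dataset builds a tf.data pipeline that shuffles the (image, captions)
+ # pairs, decodes the images and vectorizes the captions in parallel, batches them,
+ # and prefetches the next batch while the current one is being processed.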
+
+
+ training_dataset = prepare_dataset(list(training_data.keys()), list(training_data.values()))
+ validation_dataset = prepare_dataset(list(validation_data.keys()), list(validation_data.values()))
+
+ training_dataset
+
+ validation_dataset
+
+ def prepare_cnn_model():
+     base_model = efficientnet.EfficientNetB0(
+         input_shape=(*image_size, 3), include_top=False, weights="imagenet",
+     )
+     # We freeze our feature extractor
+     base_model.trainable = False
+     base_model_out = base_model.output
+     base_model_out = layers.Reshape((-1, base_model_out.shape[-1]))(base_model_out)
+     cnn_model = keras.models.Model(base_model.input, base_model_out)
+     return cnn_model
+
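+ # Note: the Reshape above flattens the EfficientNetB0 feature map into a sequence of
+ # feature vectors, one per spatial location, which the Transformer encoder below
+ # treats as its input tokens.
+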
+ class EncoderClass(layers.Layer):
+     '''
+     Transformer encoder block, implemented as a Keras layer.
+     '''
+     def __init__(self, embedded_dimension, dense_dimension, number_of_heads, **kwargs):
+         super().__init__(**kwargs)
+         self.embedded_dimension = embedded_dimension
+         self.dense_dimension = dense_dimension
+         self.number_of_heads = number_of_heads
+
+         # A multi-headed self-attention layer with no dropout
+         self.mh_attention_layer = layers.MultiHeadAttention(
+             num_heads=number_of_heads,
+             key_dim=embedded_dimension,
+             dropout=0.0
+         )
+
+         # Normalization layers
+         # These layers normalize their inputs, comparable to a standard scaler in traditional machine learning
+         self.normalization_layer_1 = layers.LayerNormalization()
+         self.normalization_layer_2 = layers.LayerNormalization()
+
+         # Dense layer with ReLU activation
+         self.dense_layer = layers.Dense(embedded_dimension, activation="relu")
+
+     def call(self, inputs, training):
+         # The inputs are normalized and projected through the dense layer before the multi-headed attention layer
+         inputs = self.dense_layer(self.normalization_layer_1(inputs))
+
+         attention_output_1 = self.mh_attention_layer(
+             query=inputs,
+             value=inputs,
+             key=inputs,
+             attention_mask=None,
+             training=training,
+         )
+
+         # After applying attention to the inputs, the residual sum is passed through another normalization layer
+         return self.normalization_layer_2(inputs + attention_output_1)
+
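+ # Note: the encoder normalizes and projects the CNN image features, applies a single block
+ # of multi-headed self-attention, and adds the result back to its input (a residual
+ # connection) before a final normalization.
+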
+ class EmbedTokenAndPostionClass(layers.Layer):
+     '''Embeds each token together with its position, giving every token both semantic and positional information.
+     '''
+     def __init__(self, sequence_length, vocabulary_size, embedded_dimension, **kwargs):
+         super().__init__(**kwargs)
+
+         # Embedding layer for the tokens: the input dimension is the vocabulary size and the output dimension is the embedding dimension.
+         # This layer captures the semantic meaning of each token, which helps model the meaning of words and their relationships.
+         self.token_embeddings = layers.Embedding(
+             input_dim=vocabulary_size,
+             output_dim=embedded_dimension
+         )
+
+         # Embedding layer for the positions: the input dimension is the sequence length and the output dimension is the embedding dimension.
+         # This captures where a particular token sits in the sequence.
+         self.position_embeddings = layers.Embedding(
+             input_dim=sequence_length,
+             output_dim=embedded_dimension
+         )
+         self.sequence_length = sequence_length
+         self.vocabulary_size = vocabulary_size
+         self.embedded_dimension = embedded_dimension
+
+         # Square root of the embedding dimension, cast to float32.
+         # The token embeddings are scaled by this factor, as in the original Transformer, so they are not dominated by the position embeddings.
+         self.embedded_scale = tf.math.sqrt(tf.cast(embedded_dimension, tf.float32))
+
+     def call(self, inputs):
+
+         # Get all the positions
+         positions = tf.range(start=0, limit=tf.shape(inputs)[-1], delta=1)
+
+         # Pass the input through the token embedding.
+         # This generates a continuous vector for each token.
+         embedded_tokens = self.token_embeddings(inputs) * self.embedded_scale
+         embedded_positions = self.position_embeddings(positions)
+
+         # Combine the token vectors and their positions, capturing both the semantic meaning of the words and their context
+         return embedded_tokens + embedded_positions
+
+     def compute_mask(self, inputs, mask=None):
+         return tf.math.not_equal(inputs, 0)
+
+ class DecoderClass(layers.Layer):
+     '''The decoder component of the model. It decodes the encoded image features together with the position-embedded caption tokens.
+     It uses self-attention and cross-attention along with a feed-forward network to produce the output sequence.
+     '''
+
+     def __init__(self, embedded_dimension, feed_forward_dimension, number_of_heads, **kwargs):
+         super().__init__(**kwargs)
+         self.embed_dim = embedded_dimension
+         self.feed_forward_dimension = feed_forward_dimension
+         self.number_of_heads = number_of_heads
+
+         self.first_attention_layer = layers.MultiHeadAttention(
+             num_heads=number_of_heads,
+             key_dim=embedded_dimension,
+             dropout=0.1
+         )
+
+         self.second_attention_layer = layers.MultiHeadAttention(
+             num_heads=number_of_heads,
+             key_dim=embedded_dimension,
+             dropout=0.1
+         )
+
+         self.first_feed_forward_layer = layers.Dense(feed_forward_dimension, activation="relu")
+         self.second_feed_forward_layer = layers.Dense(embedded_dimension)
+
+         self.first_normalization_layer = layers.LayerNormalization()
+         self.second_normalization_layer = layers.LayerNormalization()
+         self.third_normalization_layer = layers.LayerNormalization()
+
+         self.embedding = EmbedTokenAndPostionClass(
+             embedded_dimension=embedded_dimension,
+             sequence_length=sequence_length,
+             vocabulary_size=vocabulary_size
+         )
+
+         self.out = layers.Dense(vocabulary_size, activation="softmax")
+
+         self.first_dropout_layer = layers.Dropout(0.3)
+         self.second_dropout_layer = layers.Dropout(0.5)
+         self.supports_masking = True
+
+     def call(self, inputs, encoder_outputs, training, mask=None):
+         inputs = self.embedding(inputs)
+         causal_mask = self.get_causal_attention_mask(inputs)
+
+         if mask is not None:
+             padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
+             combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
+             combined_mask = tf.minimum(combined_mask, causal_mask)
+
+         first_attention_output = self.first_attention_layer(
+             query=inputs,
+             value=inputs,
+             key=inputs,
+             attention_mask=combined_mask,
+             training=training,
+         )
+         first_normalization_output = self.first_normalization_layer(inputs + first_attention_output)
+
+         second_attention_output = self.second_attention_layer(
+             query=first_normalization_output,
+             value=encoder_outputs,
+             key=encoder_outputs,
+             attention_mask=padding_mask,
+             training=training,
+         )
+         second_normalization_output = self.second_normalization_layer(first_normalization_output + second_attention_output)
+
+         output = self.first_feed_forward_layer(second_normalization_output)
+         output = self.first_dropout_layer(output, training=training)
+         output = self.second_feed_forward_layer(output)
+
+         output = self.third_normalization_layer(output + second_normalization_output, training=training)
+         output = self.second_dropout_layer(output, training=training)
+         return self.out(output)
+
+     def get_causal_attention_mask(self, inputs):
+         input_shape = tf.shape(inputs)
+         batch_size, sequence_length = input_shape[0], input_shape[1]
+         i = tf.range(sequence_length)[:, tf.newaxis]
+         j = tf.range(sequence_length)
+         mask = tf.cast(i >= j, dtype="int32")
+         mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
+         mult = tf.concat(
+             [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
+             axis=0,
+         )
+         return tf.tile(mask, mult)
+
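+ # Note: get_causal_attention_mask builds a (batch, sequence, sequence) lower-triangular
+ # mask, so each caption position can only attend to itself and to earlier positions;
+ # combined with the padding mask, this lets the decoder be trained to predict the next token.
+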
+ class ImageCaptionClass(keras.Model):
+     def __init__(
+         self, efficient_net_model, encoder_class, decoder_class, image_augmentation=None,
+     ):
+         super().__init__()
+         self.efficient_net_model = efficient_net_model
+         self.encoder_class = encoder_class
+         self.decoder_class = decoder_class
+         self.loss_tracker = keras.metrics.Mean(name="loss")
+         self.acc_tracker = keras.metrics.Mean(name="accuracy")
+         self.caption_to_image_ration = 5
+         self.image_augmentation = image_augmentation
+
+     def calculate_loss(self, y_actual_value, y_predicted_vaue, mask):
+         loss = self.loss(y_actual_value, y_predicted_vaue)
+         mask = tf.cast(mask, dtype=loss.dtype)
+         loss *= mask
+         return tf.reduce_sum(loss) / tf.reduce_sum(mask)
+
+     def calculate_accuracy(self, y_actual_value, y_predicted_vaue, mask):
+         accuracy = tf.equal(y_actual_value, tf.argmax(y_predicted_vaue, axis=2))
+         accuracy = tf.math.logical_and(mask, accuracy)
+         accuracy = tf.cast(accuracy, dtype=tf.float32)
+         mask = tf.cast(mask, dtype=tf.float32)
+         return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)
+
+     def get_caption_loss_and_accuracy(self, image_embedded, batch_sequence, calculate_for_train=True):
+         encoder_class_out = self.encoder_class(image_embedded, training=calculate_for_train)
+         batch_sequence_input = batch_sequence[:, :-1]
+         batch_sequence_actual = batch_sequence[:, 1:]
+         mask = tf.math.not_equal(batch_sequence_actual, 0)
+         batch_sequence_predicted = self.decoder_class(
+             batch_sequence_input, encoder_class_out, training=calculate_for_train, mask=mask
+         )
+         loss = self.calculate_loss(batch_sequence_actual, batch_sequence_predicted, mask)
+         acc = self.calculate_accuracy(batch_sequence_actual, batch_sequence_predicted, mask)
+         return loss, acc
+
+     def train_step(self, data):
+         batch_image, batch_sequence = data
+         batch_loss = 0
+         batch_accuracy = 0
+
+         if self.image_augmentation:
+             batch_image = self.image_augmentation(batch_image)
+
+         # 1. Get image embeddings
+         image_embedded = self.efficient_net_model(batch_image)
+
+         # 2. Pass each of the five captions one by one to the decoder_class
+         # along with the encoder_class outputs and compute the loss as well as accuracy
+         # for each caption.
+         for i in range(self.caption_to_image_ration):
+             with tf.GradientTape() as gradient_tape:
+                 loss, acc = self.get_caption_loss_and_accuracy(
+                     image_embedded, batch_sequence[:, i, :], calculate_for_train=True
+                 )
+
+                 # 3. Update loss and accuracy
+                 batch_loss += loss
+                 batch_accuracy += acc
+
+             # 4. Get the list of all the trainable weights
+             training_weights = (
+                 self.encoder_class.trainable_variables + self.decoder_class.trainable_variables
+             )
+
+             # 5. Get the gradients
+             gradient_lists = gradient_tape.gradient(loss, training_weights)
+
+             # 6. Update the trainable weights
+             self.optimizer.apply_gradients(zip(gradient_lists, training_weights))
+
+         # 7. Update the trackers
+         batch_accuracy /= float(self.caption_to_image_ration)
+         self.loss_tracker.update_state(batch_loss)
+         self.acc_tracker.update_state(batch_accuracy)
+
+         # 8. Return the loss and accuracy values
+         return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}
+
+     def test_step(self, data):
+         batch_image, batch_sequence = data
+         batch_loss = 0
+         batch_accuracy = 0
+
+         # 1. Get image embeddings
+         image_embedded = self.efficient_net_model(batch_image)
+
+         # 2. Pass each of the five captions one by one to the decoder_class
+         # along with the encoder_class outputs and compute the loss as well as accuracy
+         # for each caption.
+         for i in range(self.caption_to_image_ration):
+             loss, acc = self.get_caption_loss_and_accuracy(
+                 image_embedded, batch_sequence[:, i, :], calculate_for_train=False
+             )
+
+             # 3. Update batch loss and batch accuracy
+             batch_loss += loss
+             batch_accuracy += acc
+
+         batch_accuracy /= float(self.caption_to_image_ration)
+
+         # 4. Update the trackers
+         self.loss_tracker.update_state(batch_loss)
+         self.acc_tracker.update_state(batch_accuracy)
+
+         # 5. Return the loss and accuracy values
+         return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}
+
+     @property
+     def metrics(self):
+         # We need to list our metrics here so that `reset_states()` can be
+         # called automatically.
+         return [self.loss_tracker, self.acc_tracker]
+
+     def get_config(self):
+         # Return a dictionary containing the configuration of the model.
+         config = {
+             "efficient_net_model": self.efficient_net_model,
+             "encoder_class": self.encoder_class,
+             "decoder_class": self.decoder_class,
+             "caption_to_image_ration": self.caption_to_image_ration,
+             "image_augmentation": self.image_augmentation,
+         }
+         return config
+
+
+     def call(self, data):
+         batch_image, batch_sequence = data
+         batch_loss = 0
+         batch_accuracy = 0
+
+         if self.image_augmentation:
+             batch_image = self.image_augmentation(batch_image)
+
+         # 1. Get image embeddings
+         image_embedded = self.efficient_net_model(batch_image)
+
+         # 2. Pass each of the five captions one by one to the decoder_class
+         # along with the encoder_class outputs and compute the loss as well as accuracy
+         # for each caption.
+         for i in range(self.caption_to_image_ration):
+             loss, acc = self.get_caption_loss_and_accuracy(
+                 image_embedded, batch_sequence[:, i, :], calculate_for_train=True
+             )
+
+             # 3. Update batch loss and batch accuracy
+             batch_loss += loss
+             batch_accuracy += acc
+
+         batch_accuracy /= float(self.caption_to_image_ration)
+
+         # 4. Update the trackers
+         self.loss_tracker.update_state(batch_loss)
+         self.acc_tracker.update_state(batch_accuracy)
+
+         # 5. Return the loss and accuracy values
+         return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}
+
+
+
+
+
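+ # Note: `call` mirrors the evaluation logic so that this subclassed keras.Model has a
+ # defined forward pass; Keras uses `call` when the model is built or traced for saving.
+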
+ cnn_model = prepare_cnn_model()
+ encoder = EncoderClass(embedded_dimension=embedded_dimension, dense_dimension=feed_forward_dimension, number_of_heads=1)
+ decoder = DecoderClass(embedded_dimension=embedded_dimension, feed_forward_dimension=feed_forward_dimension, number_of_heads=2)
+ caption_model = ImageCaptionClass(
+     efficient_net_model=cnn_model, encoder_class=encoder, decoder_class=decoder, image_augmentation=image_augmentation,
+ )
+ caption_model
+
+ cross_entropy_loss_f = keras.losses.SparseCategoricalCrossentropy(reduction="none")
+
+ early_stopping_function = keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)
+
+ # A learning-rate schedule that inherits from the Keras LearningRateSchedule class;
+ # it determines how quickly the optimizer adjusts the model parameters during training.
+
+ class LRSClass(keras.optimizers.schedules.LearningRateSchedule):
+     def __init__(self, learning_rate_post_warmup, steps):
+         super().__init__()
+         self.learning_rate_post_warmup = learning_rate_post_warmup
+         self.steps = steps
+
+     def __call__(self, step):
+         global_step = tf.cast(step, tf.float32)
+         steps = tf.cast(self.steps, tf.float32)
+         progress = global_step / steps
+         learning_rate = self.learning_rate_post_warmup * progress
+         return tf.cond(
+             global_step < steps,
+             lambda: learning_rate,
+             lambda: self.learning_rate_post_warmup,
+         )
+
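+ # Note: the schedule above increases the learning rate linearly from 0 to
+ # learning_rate_post_warmup over the first `steps` optimization steps (warmup)
+ # and then holds it constant.
+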
+ # Number of optimization steps required
+ num_train_steps = len(training_dataset) * epochs
+
+ # No. of steps where the learning rate is gradually increased.
+ warmup_steps = num_train_steps // 15
+
+ lr_schedule = LRSClass(learning_rate_post_warmup=1e-4, steps=warmup_steps)
+
+ # Compile the model
+ caption_model.compile(optimizer=keras.optimizers.Adam(lr_schedule), loss=cross_entropy_loss_f)
+
+ # Fit the model
+ caption_model.fit(
+     training_dataset,
+     epochs=epochs,
+     validation_data=validation_dataset,
+     callbacks=[early_stopping_function],
+ )
+
+
+
+ caption_model.summary()
+
+
+
+ test_vocabulary = vectorization.get_vocabulary()
+ index_lookup = dict(zip(range(len(test_vocabulary)), test_vocabulary))
+ max_decoded_sentence_length = sequence_length - 1
+ valid_images = list(validation_data.keys())
+
+
+ def generate_caption():
+     # Select a random image from the validation dataset
+     validate_image = np.random.choice(valid_images)
+
+     # Read the sample image, then decode and resize it
+     validate_image = decoder_to_resizer(validate_image)
+     image = validate_image.numpy().clip(0, 255).astype(np.uint8)
+     plt.imshow(image)
+     plt.show()
+
+     # Prepare the image and pass it to the EfficientNet model
+     image = tf.expand_dims(validate_image, 0)
+     image = caption_model.efficient_net_model(image)
+
+     # Pass the image features to the Transformer encoder
+     encoded_img = caption_model.encoder_class(image, training=False)
+
+     # Generate the caption using the Transformer decoder
+     decoded_caption = "<start> "
+     for i in range(max_decoded_sentence_length):
+         tokenized_caption = vectorization([decoded_caption])[:, :-1]
+         mask = tf.math.not_equal(tokenized_caption, 0)
+         predictions = caption_model.decoder_class(
+             tokenized_caption, encoded_img, training=False, mask=mask
+         )
+         sampled_token_index = np.argmax(predictions[0, i, :])
+         sampled_token = index_lookup[sampled_token_index]
+         if sampled_token == "<end>":
+             break
+         decoded_caption += " " + sampled_token
+
+     decoded_caption = decoded_caption.replace("<start> ", "")
+     decoded_caption = decoded_caption.replace(" <end>", "").strip()
+     print("Predicted Caption: ", decoded_caption)
+
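+ # Note: decoding is greedy; at each step the argmax token is appended to the caption
+ # and fed back in, until "<end>" is produced or the maximum caption length is reached.
+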
+ generate_caption()
+
+ generate_caption()
+
+ generate_caption()
+
+ def generate_caption_custom(img_path):
+     # Generate a caption for the image at the given path
+
+
+     validate_image = img_path
+     print(validate_image)
+     # Read the sample image, then decode and resize it
+     validate_image = decoder_to_resizer(validate_image)
+     image = validate_image.numpy().clip(0, 255).astype(np.uint8)
+     plt.imshow(image)
+     plt.show()
+
+     # Prepare the image and pass it to the EfficientNet model
+     image = tf.expand_dims(validate_image, 0)
+     image = caption_model.efficient_net_model(image)
+
+     # Pass the image features to the Transformer encoder
+     encoded_img = caption_model.encoder_class(image, training=False)
+
+     # Generate the caption using the Transformer decoder
+     decoded_caption = "<start> "
+     for i in range(max_decoded_sentence_length):
+         tokenized_caption = vectorization([decoded_caption])[:, :-1]
+         mask = tf.math.not_equal(tokenized_caption, 0)
+         predictions = caption_model.decoder_class(
+             tokenized_caption, encoded_img, training=False, mask=mask
+         )
+         sampled_token_index = np.argmax(predictions[0, i, :])
+         sampled_token = index_lookup[sampled_token_index]
+         if sampled_token == "<end>":
+             break
+         decoded_caption += " " + sampled_token
+
+     decoded_caption = decoded_caption.replace("<start> ", "")
+     decoded_caption = decoded_caption.replace(" <end>", "").strip()
+     print("Predicted Caption: ", decoded_caption)
+
+ # generate_caption_custom("./image2.jpg")
+
+ # generate_caption_custom("./Document.jpeg")
+
+ def generate_with_link(url):
+     file_name = wget.download(url)
+     generate_caption_custom(file_name)
+
+ link = 'https://media.istockphoto.com/id/1346503960/photo/school-children-with-a-parachute.jpg?s=1024x1024&w=is&k=20&c=HNOFWi02yU4NB_98iIWKHbzpGlWPYcfQagnPthD2eOo='
+ generate_with_link(link)
+
+ caption_model.save('path/to/location', save_format='tf')
+
+ image_shape = (*image_size, 3)  # Assuming RGB images
+ caption_shape = (5, sequence_length)  # For 5 captions with max sequence length
+
+ caption_model.build(input_shape=[(None, *image_shape), (None, *caption_shape)])
+
+ # Save the model
+ path_to_save = 'path_to_save'
+ caption_model.save(path_to_save)