Spaces: Runtime error
sonamsherpa committed
Commit 27afece
Parent(s): dbebdf9
Update app.py

app.py CHANGED
@@ -1,3 +1,686 @@
-print('Hello World!')
# -*- coding: utf-8 -*-
"""Copy of Copy of Imagecaption_generator_AIML.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1Thp1MpIDt-AnhXifbSu-AeGQRI8iR3-E
"""

import os

# "!pip install wget" in the original Colab notebook; shelled out here so the
# script also runs outside IPython.
os.system("pip install wget")

import re
import numpy as np
import matplotlib.pyplot as plt
import requests
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import shutil
from tensorflow.keras.applications import efficientnet
import wget
from tensorflow.keras.layers import TextVectorization


seed = 111
np.random.seed(seed)
tf.random.set_seed(seed)

# Download and unpack the Flickr8k images and captions
# (originally "!wget", "!unzip" and "!rm" Colab shell magics).
os.system("wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip")
os.system("wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip")
os.system("unzip -qq Flickr8k_Dataset.zip")
os.system("unzip -qq Flickr8k_text.zip")
os.system("rm Flickr8k_Dataset.zip Flickr8k_text.zip")

# Desired image dimensions
image_size = (299, 299)

# Vocabulary size
vocabulary_size = 10000

# Fixed length allowed for any sequence
sequence_length = 25

# Dimension for the image embeddings and token embeddings
# Per-layer units in the feed-forward network
embedded_dimension = feed_forward_dimension = EMBED_DIM = 512

# Other training parameters
batch_size = 64
epochs = 30
autotune = tf.data.AUTOTUNE

def map_image_caption(filename):
    '''
    Loads the captions and maps each caption to its respective image.
    Returns: a dictionary of image name -> captions and a list containing all the captions.
    '''

    with open(filename) as caption_file:
        caption_data = caption_file.readlines()
        mapped_captions = {}
        text_data = []
        skip_these_images = set()

        for c_data in caption_data:
            # The image name and the caption are separated by a tab, so split them into separate variables
            image_name, caption = c_data.strip("\n").split("\t")
            caption = caption.strip()

            # There are 5 captions for each image and each image name carries the suffix '#(caption_number)',
            # so drop everything after '#' and strip any whitespace
            image_name = os.path.join('Flicker8k_Dataset', image_name.split("#")[0].strip())

            # Remove captions that are either too short or too long
            tokens = caption.strip().split()

            if len(tokens) < 5 or len(tokens) > sequence_length:
                skip_these_images.add(image_name)
                continue

            if image_name.endswith("jpg") and image_name not in skip_these_images:
                # Add start and end tags to mark the beginning and end of each caption
                text_data.append("<start> " + caption + " <end>")

                if image_name in mapped_captions:
                    mapped_captions[image_name].append(caption)
                else:
                    mapped_captions[image_name] = [caption]

        for image_name in skip_these_images:
            if image_name in mapped_captions:
                del mapped_captions[image_name]

        return mapped_captions, text_data

def train_val_split(caption_data):
    '''
    Split the data into training and validation sets.
    '''
    train_size = 0.8

    # Get the list of image names
    list_of_images = list(caption_data.keys())

    # Shuffle for randomness
    np.random.shuffle(list_of_images)

    # Split data into training and testing
    train_size = int(len(caption_data) * train_size)

    train_data = {
        name: caption_data[name] for name in list_of_images[:train_size]
    }
    test_data = {
        name: caption_data[name] for name in list_of_images[train_size:]
    }

    return train_data, test_data

# Load the dataset
captions_mapping, text_data = map_image_caption("Flickr8k.token.txt")

# Split the dataset into training and validation sets
training_data, validation_data = train_val_split(captions_mapping)
print("Number of training samples here: ", len(training_data))
print("Number of validation samples here: ", len(validation_data))

def standardize(input_string):
    # Lower-case the text and strip punctuation, keeping "<" and ">" so the <start>/<end> tokens survive
    strip_chars = "!\"#$%&'()*+,-./:;<=>?@[\]^_`{|}~".replace("<", "").replace(">", "")
    return tf.strings.regex_replace(tf.strings.lower(input_string), "[%s]" % re.escape(strip_chars), "")

vectorization = TextVectorization(
    max_tokens=vocabulary_size,
    output_mode="int",
    output_sequence_length=sequence_length,
    standardize=standardize,
)
vectorization.adapt(text_data)
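
# For example (illustrative): standardize(tf.constant("<start> A dog runs! <end>")) lower-cases the
# text and strips the punctuation while keeping the angle brackets, giving "<start> a dog runs <end>",
# which vectorization then maps to a fixed-length sequence of 25 integer token ids.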

# Data augmentation for image data
image_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.2),
        layers.RandomContrast(0.3),
    ]
)

def decoder_to_resizer(img_path):
    '''
    Decodes a JPEG, resizes it and converts it to float for processing.
    '''
    image = tf.io.read_file(img_path)
    decoded_image = tf.image.decode_jpeg(image, channels=3)
    resized_image = tf.image.resize(decoded_image, image_size)
    return tf.image.convert_image_dtype(resized_image, tf.float32)

def process_input(img_path, captions):
    '''
    Returns the decoded, resized image as floats together with the vectorized captions.
    '''
    return decoder_to_resizer(img_path), vectorization(captions)

def prepare_dataset(images, captions):
    dataset = tf.data.Dataset.from_tensor_slices((images, captions))
    return dataset.shuffle(batch_size * 8).map(process_input, num_parallel_calls=autotune).batch(batch_size).prefetch(autotune)


training_dataset = prepare_dataset(list(training_data.keys()), list(training_data.values()))
validation_dataset = prepare_dataset(list(validation_data.keys()), list(validation_data.values()))

training_dataset

validation_dataset
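
# Rough sketch of what these pipelines yield (shapes assume the settings above): each batch should be
# a pair (images, captions) with images of shape (batch_size, 299, 299, 3) and captions of shape
# (batch_size, 5, 25), i.e. five vectorized captions of length 25 per image.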

def prepare_cnn_model():
    base_model = efficientnet.EfficientNetB0(
        input_shape=(*image_size, 3), include_top=False, weights="imagenet",
    )
    # We freeze our feature extractor
    base_model.trainable = False
    base_model_out = base_model.output
    base_model_out = layers.Reshape((-1, base_model_out.shape[-1]))(base_model_out)
    cnn_model = keras.models.Model(base_model.input, base_model_out)
    return cnn_model

class EncoderClass(layers.Layer):
    ''' Encoder block: a Keras layer that applies self-attention to the image features.

    '''
    def __init__(self, embedded_dimension, dense_dimension, number_of_heads, **kwargs):
        super().__init__(**kwargs)
        self.embedded_dimension = embedded_dimension
        self.dense_dimension = dense_dimension
        self.number_of_heads = number_of_heads

        # A multi-headed self-attention layer with no dropout
        self.mh_attention_layer = layers.MultiHeadAttention(
            num_heads=number_of_heads,
            key_dim=embedded_dimension,
            dropout=0.0
        )

        # Normalization layers
        # These layers normalize the input; they can be compared to a StandardScaler in traditional machine learning
        self.normalization_layer_1 = layers.LayerNormalization()
        self.normalization_layer_2 = layers.LayerNormalization()

        # Dense layer with relu activation
        self.dense_layer = layers.Dense(embedded_dimension, activation="relu")

    def call(self, inputs, training):
        # The inputs are first normalized and projected, then fed through the multi-headed attention layer
        inputs = self.dense_layer(self.normalization_layer_1(inputs))

        attention_output_1 = self.mh_attention_layer(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=None,
            training=training,
        )

        # After applying attention to the input, the residual sum is passed through another normalization layer
        return self.normalization_layer_2(inputs + attention_output_1)
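
# Shape sketch (assuming the 299x299 inputs above): EfficientNetB0 without its top produces a
# 10x10x1280 feature map, which prepare_cnn_model() reshapes to (batch, 100, 1280); EncoderClass
# then projects it to (batch, 100, 512) and applies self-attention over the 100 spatial positions.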

class EmbedTokenAndPostionClass(layers.Layer):
    ''' This class embeds each token together with its position, giving every token both semantic and positional information.
    '''
    def __init__(self, sequence_length, vocabulary_size, embedded_dimension, **kwargs):
        super().__init__(**kwargs)

        # Embedding layer for the tokens: the input dimension is the vocabulary size and the output dimension is the embedding dimension.
        # This layer captures the semantic meaning of each token, which helps the model relate words to one another.
        self.token_embeddings = layers.Embedding(
            input_dim=vocabulary_size,
            output_dim=embedded_dimension
        )

        # Embedding layer for the positions: the input dimension is the sequence length and the output dimension is the embedding dimension.
        # This simply captures where in the sequence a particular token sits.
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length,
            output_dim=embedded_dimension
        )
        self.sequence_length = sequence_length
        self.vocabulary_size = vocabulary_size
        self.embedded_dimension = embedded_dimension

        # The square root of the embedding dimension, as float32.
        # The token embeddings are scaled by this factor (as in the original Transformer) before the positional embeddings are added.
        self.embedded_scale = tf.math.sqrt(tf.cast(embedded_dimension, tf.float32))

    def call(self, inputs):

        # Get all the positions
        positions = tf.range(start=0, limit=tf.shape(inputs)[-1], delta=1)

        # Pass the input through the token embedding.
        # This generates a continuous vector for each token.
        embedded_tokens = self.token_embeddings(inputs) * self.embedded_scale
        embedded_positions = self.position_embeddings(positions)

        # Combine the token vectors with their position vectors, capturing both the semantic meaning of the words and their order.
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        return tf.math.not_equal(inputs, 0)
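
# Tensor-shape sketch for this layer (with the defaults above): an int input of shape (batch, seq_len)
# becomes token embeddings of shape (batch, seq_len, 512), scaled by sqrt(512); position embeddings of
# shape (seq_len, 512) are broadcast over the batch and added, so the output is (batch, seq_len, 512).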

class DecoderClass(layers.Layer):
    '''This is the decoder component of the model. It decodes the embedded target sequence against the encoded image features,
    using self-attention and cross-attention together with a feed-forward layer to produce the output sequence.
    '''

    def __init__(self, embedded_dimension, feed_forward_dimension, number_of_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embedded_dimension
        self.feed_forward_dimension = feed_forward_dimension
        self.number_of_heads = number_of_heads

        self.first_attention_layer = layers.MultiHeadAttention(
            num_heads=number_of_heads,
            key_dim=embedded_dimension,
            dropout=0.1
        )

        self.second_attention_layer = layers.MultiHeadAttention(
            num_heads=number_of_heads,
            key_dim=embedded_dimension,
            dropout=0.1
        )

        self.first_feed_forward_layer = layers.Dense(feed_forward_dimension, activation="relu")
        self.second_feed_forward_layer = layers.Dense(embedded_dimension)

        self.first_normalization_layer = layers.LayerNormalization()
        self.second_normalization_layer = layers.LayerNormalization()
        self.third_normalization_layer = layers.LayerNormalization()

        self.embedding = EmbedTokenAndPostionClass(
            embedded_dimension=embedded_dimension,
            sequence_length=sequence_length,
            vocabulary_size=vocabulary_size
        )

        self.out = layers.Dense(vocabulary_size, activation="softmax")

        self.first_dropout_layer = layers.Dropout(0.3)
        self.second_dropout_layer = layers.Dropout(0.5)
        self.supports_masking = True

    def call(self, inputs, encoder_outputs, training, mask=None):
        inputs = self.embedding(inputs)
        causal_mask = self.get_causal_attention_mask(inputs)

        # Combine the padding mask with the causal mask; without a padding mask,
        # fall back to the causal mask alone.
        padding_mask = None
        combined_mask = causal_mask
        if mask is not None:
            padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
            combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
            combined_mask = tf.minimum(combined_mask, causal_mask)

        first_attention_output = self.first_attention_layer(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=combined_mask,
            training=training,
        )
        first_normalization_output = self.first_normalization_layer(inputs + first_attention_output)

        second_attention_output = self.second_attention_layer(
            query=first_normalization_output,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=padding_mask,
            training=training,
        )
        second_normalization_output = self.second_normalization_layer(first_normalization_output + second_attention_output)

        output = self.first_feed_forward_layer(second_normalization_output)
        output = self.first_dropout_layer(output, training=training)
        output = self.second_feed_forward_layer(output)

        output = self.third_normalization_layer(output + second_normalization_output, training=training)
        output = self.second_dropout_layer(output, training=training)
        return self.out(output)

    def get_causal_attention_mask(self, inputs):
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = tf.range(sequence_length)[:, tf.newaxis]
        j = tf.range(sequence_length)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = tf.concat(
            [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
            axis=0,
        )
        return tf.tile(mask, mult)
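
# For example, for a sequence of length 4 the causal mask computed above is (per batch element)
#   [[1, 0, 0, 0],
#    [1, 1, 0, 0],
#    [1, 1, 1, 0],
#    [1, 1, 1, 1]]
# so each position may attend only to itself and to earlier positions.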

class ImageCaptionClass(keras.Model):
    def __init__(
        self, efficient_net_model, encoder_class, decoder_class, image_augmentation=None,
    ):
        super().__init__()
        self.efficient_net_model = efficient_net_model
        self.encoder_class = encoder_class
        self.decoder_class = decoder_class
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.acc_tracker = keras.metrics.Mean(name="accuracy")
        self.caption_to_image_ratio = 5
        self.image_augmentation = image_augmentation

    def calculate_loss(self, y_actual_value, y_predicted_value, mask):
        loss = self.loss(y_actual_value, y_predicted_value)
        mask = tf.cast(mask, dtype=loss.dtype)
        loss *= mask
        return tf.reduce_sum(loss) / tf.reduce_sum(mask)

    def calculate_accuracy(self, y_actual_value, y_predicted_value, mask):
        accuracy = tf.equal(y_actual_value, tf.argmax(y_predicted_value, axis=2))
        accuracy = tf.math.logical_and(mask, accuracy)
        accuracy = tf.cast(accuracy, dtype=tf.float32)
        mask = tf.cast(mask, dtype=tf.float32)
        return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)

    def get_caption_loss_and_accuracy(self, image_embedded, batch_sequence, calculate_for_train=True):
        encoder_class_out = self.encoder_class(image_embedded, training=calculate_for_train)
        batch_sequence_input = batch_sequence[:, :-1]
        batch_sequence_actual = batch_sequence[:, 1:]
        mask = tf.math.not_equal(batch_sequence_actual, 0)
        batch_sequence_predicted = self.decoder_class(
            batch_sequence_input, encoder_class_out, training=calculate_for_train, mask=mask
        )
        loss = self.calculate_loss(batch_sequence_actual, batch_sequence_predicted, mask)
        acc = self.calculate_accuracy(batch_sequence_actual, batch_sequence_predicted, mask)
        return loss, acc

    def train_step(self, data):
        batch_image, batch_sequence = data
        batch_loss = 0
        batch_accuracy = 0

        if self.image_augmentation:
            batch_image = self.image_augmentation(batch_image)

        # 1. Get image embeddings
        image_embedded = self.efficient_net_model(batch_image)

        # 2. Pass each of the five captions one by one to the decoder_class
        # along with the encoder_class outputs and compute the loss as well as accuracy
        # for each caption.
        for i in range(self.caption_to_image_ratio):
            with tf.GradientTape() as gradient_tape:
                loss, acc = self.get_caption_loss_and_accuracy(
                    image_embedded, batch_sequence[:, i, :], calculate_for_train=True
                )

                # 3. Update loss and accuracy
                batch_loss += loss
                batch_accuracy += acc

            # 4. Get the list of all the trainable weights
            training_weights = (
                self.encoder_class.trainable_variables + self.decoder_class.trainable_variables
            )

            # 5. Get the gradients
            gradient_lists = gradient_tape.gradient(loss, training_weights)

            # 6. Update the trainable weights
            self.optimizer.apply_gradients(zip(gradient_lists, training_weights))

        # 7. Update the trackers
        batch_accuracy /= float(self.caption_to_image_ratio)
        self.loss_tracker.update_state(batch_loss)
        self.acc_tracker.update_state(batch_accuracy)

        # 8. Return the loss and accuracy values
        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}

    def test_step(self, data):
        batch_image, batch_sequence = data
        batch_loss = 0
        batch_accuracy = 0

        # 1. Get image embeddings
        image_embedded = self.efficient_net_model(batch_image)

        # 2. Pass each of the five captions one by one to the decoder_class
        # along with the encoder_class outputs and compute the loss as well as accuracy
        # for each caption.
        for i in range(self.caption_to_image_ratio):
            loss, acc = self.get_caption_loss_and_accuracy(
                image_embedded, batch_sequence[:, i, :], calculate_for_train=False
            )

            # 3. Update batch loss and batch accuracy
            batch_loss += loss
            batch_accuracy += acc

        batch_accuracy /= float(self.caption_to_image_ratio)

        # 4. Update the trackers
        self.loss_tracker.update_state(batch_loss)
        self.acc_tracker.update_state(batch_accuracy)

        # 5. Return the loss and accuracy values
        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}

    @property
    def metrics(self):
        # We need to list our metrics here so `reset_states()` can be
        # called automatically.
        return [self.loss_tracker, self.acc_tracker]

    def get_config(self):
        # Return a dictionary containing the configuration of the model.
        config = {
            "efficient_net_model": self.efficient_net_model,
            "encoder_class": self.encoder_class,
            "decoder_class": self.decoder_class,
            "caption_to_image_ratio": self.caption_to_image_ratio,
            "image_augmentation": self.image_augmentation,
        }
        return config


    def call(self, data):
        batch_image, batch_sequence = data
        batch_loss = 0
        batch_accuracy = 0

        if self.image_augmentation:
            batch_image = self.image_augmentation(batch_image)

        # 1. Get image embeddings
        image_embedded = self.efficient_net_model(batch_image)

        # 2. Pass each of the five captions one by one to the decoder_class
        # along with the encoder_class outputs and compute the loss as well as accuracy
        # for each caption.
        for i in range(self.caption_to_image_ratio):
            loss, acc = self.get_caption_loss_and_accuracy(
                image_embedded, batch_sequence[:, i, :], calculate_for_train=True
            )

            # 3. Update batch loss and batch accuracy
            batch_loss += loss
            batch_accuracy += acc

        batch_accuracy /= float(self.caption_to_image_ratio)

        # 4. Update the trackers
        self.loss_tracker.update_state(batch_loss)
        self.acc_tracker.update_state(batch_accuracy)

        # 5. Return the loss and accuracy values
        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}



cnn_model = prepare_cnn_model()
encoder = EncoderClass(embedded_dimension=embedded_dimension, dense_dimension=feed_forward_dimension, number_of_heads=1)
decoder = DecoderClass(embedded_dimension=embedded_dimension, feed_forward_dimension=feed_forward_dimension, number_of_heads=2)
caption_model = ImageCaptionClass(
    efficient_net_model=cnn_model, encoder_class=encoder, decoder_class=decoder, image_augmentation=image_augmentation,
)
caption_model

cross_entropy_loss_f = keras.losses.SparseCategoricalCrossentropy(reduction="none")

early_stopping_function = keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)

# A learning-rate schedule that subclasses Keras' LearningRateSchedule;
# it controls how quickly the model adjusts its parameters during the warm-up phase.

class LRSClass(keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, learning_rate_post_warmup, steps):
        super().__init__()
        self.learning_rate_post_warmup = learning_rate_post_warmup
        self.steps = steps

    def __call__(self, step):
        global_step = tf.cast(step, tf.float32)
        steps = tf.cast(self.steps, tf.float32)
        progress = global_step / steps
        learning_rate = self.learning_rate_post_warmup * progress
        return tf.cond(
            global_step < steps,
            lambda: learning_rate,
            lambda: self.learning_rate_post_warmup,
        )
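
# Worked example (illustrative numbers): with learning_rate_post_warmup=1e-4 and warmup_steps=100,
# this schedule gives 1e-6 at step 1, 5e-5 at step 50, and a constant 1e-4 from step 100 onwards.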

# Number of optimization steps required
num_train_steps = len(training_dataset) * epochs

# Number of steps over which the learning rate is gradually increased.
warmup_steps = num_train_steps // 15

lr_schedule = LRSClass(learning_rate_post_warmup=1e-4, steps=warmup_steps)

# Compile the model
caption_model.compile(optimizer=keras.optimizers.Adam(lr_schedule), loss=cross_entropy_loss_f)

# Fit the model
caption_model.fit(
    training_dataset,
    epochs=epochs,
    validation_data=validation_dataset,
    callbacks=[early_stopping_function],
)



caption_model.summary()


test_vocabulary = vectorization.get_vocabulary()
index_lookup = dict(zip(range(len(test_vocabulary)), test_vocabulary))
max_decoded_sentence_length = sequence_length - 1
valid_images = list(validation_data.keys())


def generate_caption():
    # Select a random image from the validation dataset
    validate_image = np.random.choice(valid_images)

    # Get the sample image and decode/resize it
    validate_image = decoder_to_resizer(validate_image)
    image = validate_image.numpy().clip(0, 255).astype(np.uint8)
    plt.imshow(image)
    plt.show()

    # Prepare the image and send it to the EfficientNet model
    image = tf.expand_dims(validate_image, 0)
    image = caption_model.efficient_net_model(image)

    # Pass the image features to the Transformer encoder
    encoded_img = caption_model.encoder_class(image, training=False)

    # Generate the caption using the Transformer decoder
    decoded_caption = "<start> "
    for i in range(max_decoded_sentence_length):
        tokenized_caption = vectorization([decoded_caption])[:, :-1]
        mask = tf.math.not_equal(tokenized_caption, 0)
        predictions = caption_model.decoder_class(
            tokenized_caption, encoded_img, training=False, mask=mask
        )
        sampled_token_index = np.argmax(predictions[0, i, :])
        sampled_token = index_lookup[sampled_token_index]
        if sampled_token == "<end>":
            break
        decoded_caption += " " + sampled_token

    decoded_caption = decoded_caption.replace("<start> ", "")
    decoded_caption = decoded_caption.replace(" <end>", "").strip()
    print("Predicted Caption: ", decoded_caption)

generate_caption()

generate_caption()

generate_caption()
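
# The loop inside generate_caption() is plain greedy decoding: at step i the caption generated so far
# is fed back through the decoder and the argmax token at position i is appended, stopping at "<end>"
# or after max_decoded_sentence_length tokens.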

def generate_caption_custom(img_path):
    # Caption a user-supplied image instead of one from the validation dataset

    validate_image = img_path
    print(validate_image)
    # Get the sample image and decode/resize it
    validate_image = decoder_to_resizer(validate_image)
    image = validate_image.numpy().clip(0, 255).astype(np.uint8)
    plt.imshow(image)
    plt.show()

    # Prepare the image and send it to the EfficientNet model
    image = tf.expand_dims(validate_image, 0)
    image = caption_model.efficient_net_model(image)

    # Pass the image features to the Transformer encoder
    encoded_img = caption_model.encoder_class(image, training=False)

    # Generate the caption using the Transformer decoder
    decoded_caption = "<start> "
    for i in range(max_decoded_sentence_length):
        tokenized_caption = vectorization([decoded_caption])[:, :-1]
        mask = tf.math.not_equal(tokenized_caption, 0)
        predictions = caption_model.decoder_class(
            tokenized_caption, encoded_img, training=False, mask=mask
        )
        sampled_token_index = np.argmax(predictions[0, i, :])
        sampled_token = index_lookup[sampled_token_index]
        if sampled_token == "<end>":
            break
        decoded_caption += " " + sampled_token

    decoded_caption = decoded_caption.replace("<start> ", "")
    decoded_caption = decoded_caption.replace(" <end>", "").strip()
    print("Predicted Caption: ", decoded_caption)

# generate_caption_custom("./image2.jpg")

# generate_caption_custom("./Document.jpeg")

def generate_with_link(url):
    file_name = wget.download(url)
    generate_caption_custom(file_name)

link = 'https://media.istockphoto.com/id/1346503960/photo/school-children-with-a-parachute.jpg?s=1024x1024&w=is&k=20&c=HNOFWi02yU4NB_98iIWKHbzpGlWPYcfQagnPthD2eOo='
generate_with_link(link)

caption_model.save('path/to/location', save_format='tf')

image_shape = (*image_size, 3)  # Assuming RGB images
caption_shape = (5, sequence_length)  # For 5 captions with max sequence length

caption_model.build(input_shape=[(None, *image_shape), (None, *caption_shape)])

# Save the model
path_to_save = 'path_to_save'
caption_model.save(path_to_save)
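
# Note (not part of the original commit): ImageCaptionClass is a subclassed Keras model, so these
# save() calls rely on the TensorFlow SavedModel format; reloading it later would typically use
# keras.models.load_model() with the custom classes importable, and may need a from_config()
# counterpart to get_config() above.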