# Spaces:
# Runtime error
# Runtime error
import gradio as gr | |
import numpy as np | |
import pickle | |
import re | |
import tensorflow as tf | |
from tensorflow import keras | |
from tensorflow.keras import layers | |
from tensorflow.keras.applications import efficientnet | |
from tensorflow.keras.layers import TextVectorization | |
#warning ignorer | |
import warnings | |
warnings.filterwarnings("ignore") | |
# --- Hyper-parameters shared by the preprocessing code and the model ---

# Size every input image is resized to before it reaches the CNN
IMAGE_SIZE = (299, 299)

# Upper bound on the number of distinct tokens kept by the text vectorizer
VOCAB_SIZE = 10000

# Every caption is padded/truncated to exactly this many tokens
SEQ_LENGTH = 25

# Width of both the image embeddings and the token embeddings
EMBED_DIM = 512

# Number of units per layer in the feed-forward network
FF_DIM = 512
# Load the pickled text data that the text vectorizer is adapted on.
# The context manager guarantees the file handle is released even if
# unpickling raises.
# NOTE(security): pickle can execute arbitrary code while loading, so
# `text.pkl` must be a trusted file shipped with the application.
with open("text.pkl", "rb") as text_file:
    text_data = pickle.load(text_file)
# text preprocessing
def custom_standardization(input_string):
    """Lower-case the text and delete every character found in `strip_chars`.

    `strip_chars` is a module-level string that is looked up when this
    function runs, so it only needs to exist by the time the function is
    first called.

    Args:
        input_string: String tensor handed over by the `TextVectorization`
            layer.

    Returns:
        The lower-cased string tensor with the listed characters removed.
    """
    # Character class built from the (regex-escaped) characters to drop
    removal_pattern = f"[{re.escape(strip_chars)}]"
    lowered = tf.strings.lower(input_string)
    return tf.strings.regex_replace(lowered, removal_pattern, "")
# Characters removed by `custom_standardization`. The backslash is written
# as `\\` because `\]` is not a valid escape sequence and triggers a
# warning on recent Python versions; the resulting string is unchanged.
strip_chars = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
# "<" and ">" are kept out of the strip set so that the "<start>" and
# "<end>" markers used during caption generation survive standardization.
strip_chars = strip_chars.replace("<", "").replace(">", "")

# Maps raw caption strings to fixed-length sequences of integer token ids.
vectorization = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode="int",
    output_sequence_length=SEQ_LENGTH,
    standardize=custom_standardization,
)
# Build the vocabulary from the loaded text data.
vectorization.adapt(text_data)
# image preprocessing
def decode_and_resize(img_path):
    """Read a JPEG file from disk and return it resized to `IMAGE_SIZE`.

    Args:
        img_path: Path of the image file to read.

    Returns:
        A float32 image tensor with three channels, resized to `IMAGE_SIZE`.
    """
    raw_bytes = tf.io.read_file(img_path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(decoded, IMAGE_SIZE)
    return tf.image.convert_image_dtype(resized, tf.float32)
# Random transformations applied to image batches during training:
# horizontal flip, rotation (factor 0.2) and contrast change (factor 0.3).
_augmentation_layers = [
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.2),
    layers.RandomContrast(0.3),
]
image_augmentation = keras.Sequential(_augmentation_layers)
# model building
def get_cnn_model():
    """Build the frozen EfficientNetB0 feature extractor.

    Returns:
        A Keras model that maps an image batch to the backbone's output with
        all leading non-batch dimensions flattened into a single axis, while
        the last (channel) dimension is kept.
    """
    backbone = efficientnet.EfficientNetB0(
        include_top=False,
        weights="imagenet",
        input_shape=IMAGE_SIZE + (3,),
    )
    # The backbone is only used as a fixed feature extractor
    backbone.trainable = False

    feature_map = backbone.output
    channels = feature_map.shape[-1]
    flattened = layers.Reshape((-1, channels))(feature_map)
    return keras.models.Model(backbone.input, flattened)
class TransformerEncoderBlock(layers.Layer):
    """Encoder block: layer norm, dense projection, then self-attention.

    Args:
        embed_dim: Output width of the dense projection and key size of the
            attention heads.
        dense_dim: Stored on the instance; not used by any layer here.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        # Sub-layers are created in a fixed order and keep their attribute
        # names so previously saved weights still map onto them.
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embed_dim,
            dropout=0.0,
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.dense_1 = layers.Dense(embed_dim, activation="relu")

    def call(self, inputs, training, mask=None):
        """Encode `inputs`; `mask` is accepted but not used."""
        normalized = self.layernorm_1(inputs)
        projected = self.dense_1(normalized)
        attended = self.attention_1(
            query=projected,
            value=projected,
            key=projected,
            attention_mask=None,
            training=training,
        )
        # Residual connection around the attention, followed by layer norm
        return self.layernorm_2(projected + attended)
class PositionalEmbedding(layers.Layer):
    """Sum of scaled token embeddings and learned position embeddings.

    Args:
        sequence_length: Number of positions the position table holds.
        vocab_size: Number of rows in the token embedding table.
        embed_dim: Width of both embedding tables.
    """

    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size,
            output_dim=embed_dim,
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length,
            output_dim=embed_dim,
        )
        # Token embeddings are multiplied by sqrt(embed_dim) before the
        # position embeddings are added.
        self.embed_scale = tf.math.sqrt(tf.cast(embed_dim, tf.float32))

    def call(self, inputs):
        """Embed a batch of token ids and add the position embeddings."""
        seq_len = tf.shape(inputs)[-1]
        position_ids = tf.range(start=0, limit=seq_len, delta=1)
        scaled_tokens = self.token_embeddings(inputs) * self.embed_scale
        return scaled_tokens + self.position_embeddings(position_ids)

    def compute_mask(self, inputs, mask=None):
        """Mark every non-zero token id as a valid (unmasked) position."""
        return tf.math.not_equal(inputs, 0)
class TransformerDecoderBlock(layers.Layer):
    """Decoder block: masked self-attention, cross-attention, feed-forward
    network and a softmax projection onto the vocabulary.

    Args:
        embed_dim: Key size of the attention heads and output width of the
            feed-forward network.
        ff_dim: Hidden width of the feed-forward network.
        num_heads: Number of heads in both attention layers.

    Note:
        The token embedding and the output projection are sized from the
        module-level `EMBED_DIM`, `SEQ_LENGTH` and `VOCAB_SIZE` constants,
        not from the constructor arguments.
    """

    def __init__(self, embed_dim, ff_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.ff_dim = ff_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.ffn_layer_1 = layers.Dense(ff_dim, activation="relu")
        self.ffn_layer_2 = layers.Dense(embed_dim)

        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()

        self.embedding = PositionalEmbedding(
            embed_dim=EMBED_DIM, sequence_length=SEQ_LENGTH, vocab_size=VOCAB_SIZE
        )
        self.out = layers.Dense(VOCAB_SIZE, activation="softmax")

        self.dropout_1 = layers.Dropout(0.3)
        self.dropout_2 = layers.Dropout(0.5)
        self.supports_masking = True

    def call(self, inputs, encoder_outputs, training, mask=None):
        """Predict a vocabulary distribution for every input position.

        Args:
            inputs: Batch of token ids for the caption generated so far.
            encoder_outputs: Image features produced by the encoder.
            training: Whether dropout is active.
            mask: Optional rank-2 (batch, sequence) mask of valid token
                positions. When omitted, only the causal mask is applied.

        Returns:
            Softmax probabilities over the vocabulary for each position.
        """
        inputs = self.embedding(inputs)
        causal_mask = self.get_causal_attention_mask(inputs)

        # Defaults for the mask-less case; previously these names were only
        # bound inside the `if`, so calling without a mask raised an error.
        combined_mask = causal_mask
        padding_mask = None
        if mask is not None:
            padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
            combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
            # A position is attendable only if it is both valid and not in
            # the future.
            combined_mask = tf.minimum(combined_mask, causal_mask)

        # Masked self-attention over the caption tokens
        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=combined_mask,
            training=training,
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        # Cross-attention from the caption tokens to the image features
        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=padding_mask,
            training=training,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        # Position-wise feed-forward network with a residual connection
        ffn_out = self.ffn_layer_1(out_2)
        ffn_out = self.dropout_1(ffn_out, training=training)
        ffn_out = self.ffn_layer_2(ffn_out)

        ffn_out = self.layernorm_3(ffn_out + out_2, training=training)
        ffn_out = self.dropout_2(ffn_out, training=training)
        preds = self.out(ffn_out)
        return preds

    def get_causal_attention_mask(self, inputs):
        """Return an int32 lower-triangular mask tiled over the batch.

        Entry [b, i, j] is 1 when position i may attend to position j
        (i >= j) and 0 otherwise.
        """
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = tf.range(sequence_length)[:, tf.newaxis]
        j = tf.range(sequence_length)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
        # Repeat the single mask once per batch element
        mult = tf.concat(
            [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
            axis=0,
        )
        return tf.tile(mask, mult)
class ImageCaptioningModel(keras.Model):
    """CNN feature extractor + Transformer encoder/decoder captioning model.

    Args:
        cnn_model: Model that turns an image batch into feature sequences.
        encoder: Layer applied to the CNN features.
        decoder: Layer that predicts caption tokens from the encoder output.
        num_captions_per_image: Number of reference captions per image in a
            training/evaluation batch.
        image_aug: Optional callable applied to the image batch in
            `train_step` only.
    """

    def __init__(
        self, cnn_model, encoder, decoder, num_captions_per_image=5, image_aug=None,
    ):
        super().__init__()
        self.cnn_model = cnn_model
        self.encoder = encoder
        self.decoder = decoder
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.acc_tracker = keras.metrics.Mean(name="accuracy")
        self.num_captions_per_image = num_captions_per_image
        self.image_aug = image_aug

    def calculate_loss(self, y_true, y_pred, mask):
        """Return the compiled loss averaged over the unmasked positions."""
        loss = self.loss(y_true, y_pred)
        mask = tf.cast(mask, dtype=loss.dtype)
        loss *= mask
        return tf.reduce_sum(loss) / tf.reduce_sum(mask)

    def calculate_accuracy(self, y_true, y_pred, mask):
        """Return the fraction of unmasked positions predicted correctly."""
        accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
        accuracy = tf.math.logical_and(mask, accuracy)
        accuracy = tf.cast(accuracy, dtype=tf.float32)
        mask = tf.cast(mask, dtype=tf.float32)
        return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)

    def _compute_caption_loss_and_acc(self, img_embed, batch_seq, training=True):
        """Run encoder and decoder for one caption per image.

        The decoder input is the caption without its last token and the
        target is the caption without its first token; positions whose
        target id is 0 are excluded from loss and accuracy.
        """
        encoder_out = self.encoder(img_embed, training=training)
        batch_seq_inp = batch_seq[:, :-1]
        batch_seq_true = batch_seq[:, 1:]
        mask = tf.math.not_equal(batch_seq_true, 0)
        batch_seq_pred = self.decoder(
            batch_seq_inp, encoder_out, training=training, mask=mask
        )
        loss = self.calculate_loss(batch_seq_true, batch_seq_pred, mask)
        acc = self.calculate_accuracy(batch_seq_true, batch_seq_pred, mask)
        return loss, acc

    def train_step(self, batch_data):
        """Perform one optimizer update per caption of every image."""
        batch_img, batch_seq = batch_data
        batch_loss = 0
        batch_acc = 0

        if self.image_aug:
            batch_img = self.image_aug(batch_img)

        # 1. Get image embeddings
        img_embed = self.cnn_model(batch_img)

        # 2. Pass each of the captions one by one to the decoder
        # along with the encoder outputs and compute the loss as well as
        # accuracy for each caption.
        for i in range(self.num_captions_per_image):
            with tf.GradientTape() as tape:
                loss, acc = self._compute_caption_loss_and_acc(
                    img_embed, batch_seq[:, i, :], training=True
                )

                # 3. Update loss and accuracy
                batch_loss += loss
                batch_acc += acc

            # 4. Get the list of all the trainable weights
            # (the CNN's variables are not part of the update)
            train_vars = (
                self.encoder.trainable_variables + self.decoder.trainable_variables
            )

            # 5. Get the gradients
            grads = tape.gradient(loss, train_vars)

            # 6. Update the trainable weights
            self.optimizer.apply_gradients(zip(grads, train_vars))

        # 7. Update the trackers
        batch_acc /= float(self.num_captions_per_image)
        self.loss_tracker.update_state(batch_loss)
        self.acc_tracker.update_state(batch_acc)

        # 8. Return the loss and accuracy values
        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}

    def test_step(self, batch_data):
        """Evaluate loss and accuracy over every caption of every image."""
        batch_img, batch_seq = batch_data
        batch_loss = 0
        batch_acc = 0

        # 1. Get image embeddings
        img_embed = self.cnn_model(batch_img)

        # 2. Pass each of the captions one by one to the decoder
        # along with the encoder outputs and compute the loss as well as
        # accuracy for each caption.
        for i in range(self.num_captions_per_image):
            loss, acc = self._compute_caption_loss_and_acc(
                img_embed, batch_seq[:, i, :], training=False
            )

            # 3. Update batch loss and batch accuracy
            batch_loss += loss
            batch_acc += acc

        batch_acc /= float(self.num_captions_per_image)

        # 4. Update the trackers
        self.loss_tracker.update_state(batch_loss)
        self.acc_tracker.update_state(batch_acc)

        # 5. Return the loss and accuracy values
        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}

    @property
    def metrics(self):
        """Metric objects tracked by this model.

        Keras reads `Model.metrics` as an attribute, so this override must
        be a property; listing the trackers here lets Keras reset them
        automatically.
        """
        return [self.loss_tracker, self.acc_tracker]
# Assemble the captioning models from their three building blocks.
cnn_model = get_cnn_model()
encoder = TransformerEncoderBlock(embed_dim=EMBED_DIM, dense_dim=FF_DIM, num_heads=1)
decoder = TransformerDecoderBlock(embed_dim=EMBED_DIM, ff_dim=FF_DIM, num_heads=2)

# Both model objects are built from the very same sub-model instances.
_model_parts = dict(
    cnn_model=cnn_model,
    encoder=encoder,
    decoder=decoder,
    image_aug=image_augmentation,
)
caption_model = ImageCaptioningModel(**_model_parts)
loaded_model = ImageCaptioningModel(**_model_parts)

# Restore the saved weights; the model is flagged as built first so that
# `load_weights` can be called without running the model beforehand.
loaded_model.built = True
loaded_model.load_weights('cap_model')

# Token id -> token string lookup used when decoding predictions
vocab = vectorization.get_vocabulary()
index_lookup = dict(enumerate(vocab))

# One position less than SEQ_LENGTH, matching the tokenized caption after
# its last position is sliced off in `generate_caption`.
max_decoded_sentence_length = SEQ_LENGTH - 1
def generate_caption(image):
    """Generate a caption for a single image with greedy decoding.

    Args:
        image: Either the path of a JPEG file (`str` or `bytes`) or an
            already decoded image array, which is what Gradio's image
            input passes by default.

    Returns:
        The generated caption with the "<start>" / "<end>" markers removed,
        or an empty string when no image was supplied.
    """
    if image is None:
        return ""

    if isinstance(image, (str, bytes)):
        # Read the image from the disk
        sample_img = decode_and_resize(image)
    else:
        # Already decoded pixels: apply the same resize / dtype steps that
        # `decode_and_resize` applies after decoding.
        sample_img = tf.image.resize(tf.convert_to_tensor(image), IMAGE_SIZE)
        sample_img = tf.image.convert_image_dtype(sample_img, tf.float32)

    # Pass the image to the CNN (as a batch of one)
    img = tf.expand_dims(sample_img, 0)
    img = loaded_model.cnn_model(img)

    # Pass the image features to the Transformer encoder
    encoded_img = loaded_model.encoder(img, training=False)

    # Generate the caption using the Transformer decoder, one token at a
    # time, always taking the most probable next token.
    decoded_caption = "<start> "
    for i in range(max_decoded_sentence_length):
        tokenized_caption = vectorization([decoded_caption])[:, :-1]
        mask = tf.math.not_equal(tokenized_caption, 0)
        predictions = loaded_model.decoder(
            tokenized_caption, encoded_img, training=False, mask=mask
        )
        sampled_token_index = int(np.argmax(predictions[0, i, :]))
        sampled_token = index_lookup[sampled_token_index]
        # Vocabulary tokens never contain whitespace, so the end marker
        # must be compared without a leading space.
        if sampled_token == "<end>":
            break
        decoded_caption += " " + sampled_token

    decoded_caption = decoded_caption.replace("<start> ", "")
    decoded_caption = decoded_caption.replace(" <end>", "").strip()
    print(decoded_caption)
    # The returned string is what the Gradio output component displays
    return decoded_caption
# Build the input/output components. The `gr.inputs` / `gr.outputs`
# namespaces are deprecated and no longer exist in Gradio 4, so the
# top-level component classes are preferred and the legacy namespaces are
# used only when those classes are unavailable.
if hasattr(gr, "Image") and hasattr(gr, "Textbox"):
    inputs = [gr.Image(label="Original Image")]
    outputs = [gr.Textbox(label='Caption')]
else:
    inputs = [gr.inputs.Image(label="Original Image")]
    outputs = [gr.outputs.Textbox(label='Caption')]

title = "Image Captioning using CNN and a transformer + "
description = "Implementing an image captioning model using a pretrained CNN model of Efficient Net and transformer to generate Image Caption for the uploaded image. Flickr8K Dataset was used for training."
article = " "

# Create the web UI around `generate_caption` and start serving it.
gr.Interface(
    generate_caption,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=[["pic 1.jpg"], ["pic 2.jpg"], ["pic 3.jpg"], ["pic 4.jpg"]],
).launch()