import pickle
import tensorflow as tf
import pandas as pd
import numpy as np


# CONSTANTS
MAX_LENGTH = 40
# VOCABULARY_SIZE = 10000
BATCH_SIZE = 32
BUFFER_SIZE = 1000
EMBEDDING_DIM = 512
UNITS = 512


# LOADING DATA
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load vocabulary files that come from a trusted source.
# The context manager guarantees the file handle is closed after loading.
with open('saved_vocabulary/vocab_1.file', 'rb') as vocab_file:
    vocab = pickle.load(vocab_file)

# Maps caption strings to fixed-length (MAX_LENGTH) integer token sequences.
tokenizer = tf.keras.layers.TextVectorization(
    # max_tokens=VOCABULARY_SIZE,
    standardize=None,
    output_sequence_length=MAX_LENGTH,
    vocabulary=vocab
)

# Inverse lookup: integer token id -> vocabulary string.
idx2word = tf.keras.layers.StringLookup(
    mask_token="",
    vocabulary=tokenizer.get_vocabulary(),
    invert=True)


# MODEL
def CNN_Encoder():
    """Build the frozen InceptionV3 feature extractor.

    The classification head is dropped (``include_top=False``) and the
    backbone is marked non-trainable. The backbone output is reshaped so
    that every dimension between the batch axis and the channel axis is
    flattened into a single sequence axis.

    Returns:
        A ``tf.keras.Model`` mapping the InceptionV3 input to the reshaped
        feature sequence.
    """
    inception_v3 = tf.keras.applications.InceptionV3(
        include_top=False,
        weights='imagenet'
    )
    inception_v3.trainable = False

    output = inception_v3.output
    # Collapse the spatial grid: (..., channels) -> (-1, channels) per sample.
    output = tf.keras.layers.Reshape(
        (-1, output.shape[-1]))(output)

    cnn_model = tf.keras.models.Model(inception_v3.input, output)
    return cnn_model


class TransformerEncoderLayer(tf.keras.layers.Layer):
    """Single encoder block applied to the image feature sequence.

    The input is layer-normalised, projected to ``embed_dim`` with a ReLU
    dense layer, passed through unmasked self-attention, and finally
    normalised together with a residual connection.
    """

    def __init__(self, embed_dim, num_heads):
        """Create the sub-layers.

        Args:
            embed_dim: Output width of the dense projection; also used as
                the attention ``key_dim``.
            num_heads: Number of attention heads.
        """
        super().__init__()
        self.layer_norm_1 = tf.keras.layers.LayerNormalization()
        self.layer_norm_2 = tf.keras.layers.LayerNormalization()
        self.attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim)
        self.dense = tf.keras.layers.Dense(embed_dim, activation="relu")

    def call(self, x, training):
        """Encode ``x``.

        Args:
            x: Feature sequence, e.g. the output of ``CNN_Encoder``.
            training: Forwarded to the attention layer.

        Returns:
            The encoded sequence, with last dimension ``embed_dim``.
        """
        x = self.layer_norm_1(x)
        x = self.dense(x)

        attn_output = self.attention(
            query=x,
            value=x,
            key=x,
            attention_mask=None,
            training=training
        )

        # Residual connection around the self-attention, then normalise.
        x = self.layer_norm_2(x + attn_output)
        return x


class Embeddings(tf.keras.layers.Layer):
    """Sum of learned token embeddings and learned position embeddings."""

    def __init__(self, vocab_size, embed_dim, max_len):
        """Create the two embedding tables.

        Args:
            vocab_size: Number of rows in the token embedding table.
            embed_dim: Embedding width for both tables.
            max_len: Number of rows in the position embedding table.
        """
        super().__init__()
        self.token_embeddings = tf.keras.layers.Embedding(
            vocab_size, embed_dim)
        self.position_embeddings = tf.keras.layers.Embedding(
            max_len, embed_dim, input_shape=(None, max_len))

    def call(self, input_ids):
        """Embed ``input_ids`` and add position information.

        Args:
            input_ids: Token ids; the last axis is treated as the sequence
                axis.

        Returns:
            Token embeddings plus position embeddings (the position term
            has a leading axis of size 1 and is broadcast over the batch).
        """
        length = tf.shape(input_ids)[-1]
        position_ids = tf.range(start=0, limit=length, delta=1)
        position_ids = tf.expand_dims(position_ids, axis=0)

        token_embeddings = self.token_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        return token_embeddings + position_embeddings


class TransformerDecoderLayer(tf.keras.layers.Layer):
    """Single decoder block that predicts a token distribution per position.

    Pipeline: embeddings -> masked self-attention -> cross-attention over
    the encoder output -> feed-forward network -> softmax over the
    vocabulary. Vocabulary size and maximum length are read from the
    module-level ``tokenizer`` and ``MAX_LENGTH``.
    """

    def __init__(self, embed_dim, units, num_heads):
        """Create the sub-layers.

        Args:
            embed_dim: Embedding width; also the attention ``key_dim`` and
                the output width of the second feed-forward layer.
            units: Width of the first (ReLU) feed-forward layer.
            num_heads: Number of heads for both attention layers.
        """
        super().__init__()
        self.embedding = Embeddings(
            tokenizer.vocabulary_size(), embed_dim, MAX_LENGTH)

        self.attention_1 = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.attention_2 = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )

        self.layernorm_1 = tf.keras.layers.LayerNormalization()
        self.layernorm_2 = tf.keras.layers.LayerNormalization()
        self.layernorm_3 = tf.keras.layers.LayerNormalization()

        self.ffn_layer_1 = tf.keras.layers.Dense(units, activation="relu")
        self.ffn_layer_2 = tf.keras.layers.Dense(embed_dim)

        self.out = tf.keras.layers.Dense(
            tokenizer.vocabulary_size(), activation="softmax")

        self.dropout_1 = tf.keras.layers.Dropout(0.3)
        self.dropout_2 = tf.keras.layers.Dropout(0.5)

    def call(self, input_ids, encoder_output, training, mask=None):
        """Predict a vocabulary distribution for every input position.

        Args:
            input_ids: Token ids of the (shifted) caption.
            encoder_output: Encoded image features used as key/value in
                the cross-attention.
            training: Forwarded to the attention and dropout layers.
            mask: Optional per-token mask indexed as ``mask[batch, pos]``;
                non-zero marks real (non-padding) tokens.

        Returns:
            Softmax probabilities over the vocabulary for each position.
        """
        embeddings = self.embedding(input_ids)

        combined_mask = None
        padding_mask = None

        if mask is not None:
            causal_mask = self.get_causal_attention_mask(embeddings)
            # (batch, seq, 1): masks query rows in the cross-attention.
            padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
            # (batch, 1, seq): masks padded keys in the self-attention.
            combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
            # A position is attendable only if both masks allow it.
            combined_mask = tf.minimum(combined_mask, causal_mask)

        attn_output_1 = self.attention_1(
            query=embeddings,
            value=embeddings,
            key=embeddings,
            attention_mask=combined_mask,
            training=training
        )

        out_1 = self.layernorm_1(embeddings + attn_output_1)

        attn_output_2 = self.attention_2(
            query=out_1,
            value=encoder_output,
            key=encoder_output,
            attention_mask=padding_mask,
            training=training
        )

        out_2 = self.layernorm_2(out_1 + attn_output_2)

        ffn_out = self.ffn_layer_1(out_2)
        ffn_out = self.dropout_1(ffn_out, training=training)
        ffn_out = self.ffn_layer_2(ffn_out)

        ffn_out = self.layernorm_3(ffn_out + out_2)
        ffn_out = self.dropout_2(ffn_out, training=training)
        preds = self.out(ffn_out)
        return preds

    def get_causal_attention_mask(self, inputs):
        """Build a lower-triangular attention mask for ``inputs``.

        Args:
            inputs: Tensor whose first two axes are used as batch size and
                sequence length.

        Returns:
            An int32 tensor of shape (batch, seq, seq) where entry
            ``[b, i, j]`` is 1 when ``i >= j`` and 0 otherwise.
        """
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = tf.range(sequence_length)[:, tf.newaxis]
        j = tf.range(sequence_length)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
        # Tile multiples [batch_size, 1, 1]: replicate along the batch axis.
        mult = tf.concat(
            [tf.expand_dims(batch_size, -1),
             tf.constant([1, 1], dtype=tf.int32)],
            axis=0
        )
        return tf.tile(mask, mult)


class ImageCaptioningModel(tf.keras.Model):
    """Keras model wiring a CNN, an encoder and a decoder for captioning.

    Provides custom ``train_step`` / ``test_step`` implementations that
    report a masked loss and a masked token accuracy.
    """

    def __init__(self, cnn_model, encoder, decoder, image_aug=None):
        """Store the sub-models and create the metric trackers.

        Args:
            cnn_model: Image feature extractor (not updated by
                ``train_step``).
            encoder: Encoder applied to the CNN features.
            decoder: Decoder producing per-token vocabulary distributions.
            image_aug: Optional callable applied to images in
                ``train_step`` only.
        """
        super().__init__()
        self.cnn_model = cnn_model
        self.encoder = encoder
        self.decoder = decoder
        self.image_aug = image_aug
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.acc_tracker = tf.keras.metrics.Mean(name="accuracy")

    def calculate_loss(self, y_true, y_pred, mask):
        """Return the mean of ``self.loss`` over unmasked positions.

        NOTE(review): assumes ``self.loss`` returns one value per token
        (i.e. no reduction) so it can be multiplied by ``mask`` -- confirm
        against the ``compile`` call.
        """
        loss = self.loss(y_true, y_pred)
        mask = tf.cast(mask, dtype=loss.dtype)
        loss *= mask
        return tf.reduce_sum(loss) / tf.reduce_sum(mask)

    def calculate_accuracy(self, y_true, y_pred, mask):
        """Return the fraction of unmasked positions predicted correctly.

        A position counts as correct when the argmax of ``y_pred`` along
        axis 2 equals ``y_true``; ``mask`` must be boolean.
        """
        accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
        accuracy = tf.math.logical_and(mask, accuracy)
        accuracy = tf.cast(accuracy, dtype=tf.float32)
        mask = tf.cast(mask, dtype=tf.float32)
        return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)

    def compute_loss_and_acc(self, img_embed, captions, training=True):
        """Run encoder and decoder on a batch and score the predictions.

        The decoder receives ``captions[:, :-1]`` and is scored against
        ``captions[:, 1:]``; positions whose target id is 0 are excluded
        from both metrics.

        Args:
            img_embed: Output of ``self.cnn_model`` for the batch images.
            captions: Integer token ids, one row per caption.
            training: Forwarded to the encoder and decoder so that dropout
                is active only when training.

        Returns:
            A ``(loss, accuracy)`` tuple of scalar tensors.
        """
        # Forward the caller's flag; hard-coding True here would keep
        # dropout switched on during evaluation (test_step).
        encoder_output = self.encoder(img_embed, training=training)
        y_input = captions[:, :-1]
        y_true = captions[:, 1:]
        mask = (y_true != 0)
        y_pred = self.decoder(
            y_input, encoder_output, training=training, mask=mask
        )
        loss = self.calculate_loss(y_true, y_pred, mask)
        acc = self.calculate_accuracy(y_true, y_pred, mask)
        return loss, acc

    def train_step(self, batch):
        """Run one optimisation step on ``batch`` = ``(imgs, captions)``.

        Only the encoder and decoder variables are updated; the CNN
        forward pass happens outside the gradient tape.

        Returns:
            Dict with the running ``"loss"`` and ``"acc"`` values.
        """
        imgs, captions = batch

        if self.image_aug:
            imgs = self.image_aug(imgs)

        img_embed = self.cnn_model(imgs)

        with tf.GradientTape() as tape:
            loss, acc = self.compute_loss_and_acc(
                img_embed, captions
            )

        train_vars = (
            self.encoder.trainable_variables
            + self.decoder.trainable_variables
        )
        grads = tape.gradient(loss, train_vars)
        self.optimizer.apply_gradients(zip(grads, train_vars))
        self.loss_tracker.update_state(loss)
        self.acc_tracker.update_state(acc)

        return {"loss": self.loss_tracker.result(),
                "acc": self.acc_tracker.result()}

    def test_step(self, batch):
        """Evaluate ``batch`` = ``(imgs, captions)`` without weight updates.

        Returns:
            Dict with the running ``"loss"`` and ``"acc"`` values.
        """
        imgs, captions = batch

        img_embed = self.cnn_model(imgs)

        loss, acc = self.compute_loss_and_acc(
            img_embed, captions, training=False
        )

        self.loss_tracker.update_state(loss)
        self.acc_tracker.update_state(acc)

        return {"loss": self.loss_tracker.result(),
                "acc": self.acc_tracker.result()}

    @property
    def metrics(self):
        """Return the loss and accuracy trackers."""
        return [self.loss_tracker, self.acc_tracker]


def load_image_from_path(img_path):
    """Read a JPEG from ``img_path`` as a 299x299 RGB tensor scaled by 1/255.

    Args:
        img_path: Path of the image file to read.

    Returns:
        The decoded, resized and rescaled image tensor (3 channels).
    """
    img = tf.io.read_file(img_path)
    img = tf.io.decode_jpeg(img, channels=3)
    img = tf.keras.layers.Resizing(299, 299)(img)
    img = img / 255.
    return img


def generate_caption(img, caption_model, add_noise=False):
    """Greedily decode a caption for ``img``.

    Starting from ``'[start]'``, the most probable next word is appended
    at each step until ``'[end]'`` is predicted or ``MAX_LENGTH - 1`` steps
    have been taken.

    Args:
        img: Either a path string (loaded with ``load_image_from_path``)
            or an already prepared image tensor without a batch axis.
        caption_model: Object exposing ``cnn_model``, ``encoder`` and
            ``decoder``, e.g. the result of ``get_caption_model``.
        add_noise: When truthy, add Gaussian noise (std 0.1) to the image
            and min-max rescale it to [0, 1] before encoding.

    Returns:
        The generated caption without the ``'[start]'`` / ``'[end]'``
        markers; an empty string if ``'[end]'`` is predicted first.
    """
    if isinstance(img, str):
        img = load_image_from_path(img)

    if add_noise:
        noise = tf.random.normal(img.shape)*0.1
        img = (img + noise)
        img = (img - tf.reduce_min(img))/(tf.reduce_max(img) - tf.reduce_min(img))

    img = tf.expand_dims(img, axis=0)
    img_embed = caption_model.cnn_model(img)
    img_encoded = caption_model.encoder(img_embed, training=False)

    y_inp = '[start]'
    for i in range(MAX_LENGTH-1):
        # Drop the last position so the input length matches training.
        tokenized = tokenizer([y_inp])[:, :-1]
        mask = tf.cast(tokenized != 0, tf.int32)
        pred = caption_model.decoder(
            tokenized, img_encoded, training=False, mask=mask)

        # Position i holds the prediction for the word following y_inp.
        pred_idx = np.argmax(pred[0, i, :])
        pred_word = idx2word(pred_idx).numpy().decode('utf-8')
        if pred_word == '[end]':
            break

        y_inp += ' ' + pred_word

    caption = y_inp.replace('[start] ', '')
    # When no word was appended there is no trailing space after the start
    # marker, so the replace above leaves it in place; return an empty
    # caption instead of leaking the control token.
    if caption == '[start]':
        return ''
    return caption


def get_caption_model():
    """Build the captioning model and load its pretrained weights.

    The sub-models are called once on dummy inputs so their variables
    exist before ``load_weights`` runs. Weights are read from
    ``saved_models/`` or, if that file is not found, from
    ``Image-Captioning/saved_models/``.

    Returns:
        The ``ImageCaptioningModel`` with weights loaded.

    Raises:
        FileNotFoundError: If the fallback weights path is missing too.
    """
    encoder = TransformerEncoderLayer(EMBEDDING_DIM, 1)
    decoder = TransformerDecoderLayer(EMBEDDING_DIM, UNITS, 8)

    cnn_model = CNN_Encoder()

    caption_model = ImageCaptioningModel(
        cnn_model=cnn_model, encoder=encoder, decoder=decoder, image_aug=None,
    )

    # The model defines no ``call``; install an identity one so the model
    # object itself can be invoked below.
    def call_fn(batch, training):
        return batch

    caption_model.call = call_fn
    sample_x, sample_y = tf.random.normal((1, 299, 299, 3)), tf.zeros((1, 40))

    caption_model((sample_x, sample_y))

    # Run each sub-model once on the dummy batch to create its variables.
    sample_img_embed = caption_model.cnn_model(sample_x)
    sample_enc_out = caption_model.encoder(sample_img_embed, training=False)
    caption_model.decoder(sample_y, sample_enc_out, training=False)

    try:
        caption_model.load_weights(
            'saved_models/image_captioning_transformer_weights_2.h5')
    except FileNotFoundError:
        caption_model.load_weights(
            'Image-Captioning/saved_models/image_captioning_transformer_weights_2.h5')

    return caption_model