Manikandan-Alagu's picture
Update app.py
1025657 verified
raw
history blame
13.3 kB
import base64
import html
import io
import os
import pickle
import re

import nltk
import numpy as np
import pandas as pd
import requests
import streamlit as st
import tensorflow as tf
from nltk.corpus import stopwords
from PIL import Image
from poetpy import get_poetry
# NLTK stopwords (used by extract_important_term). Only download when the
# corpus is missing, so re-running this script does not hit the download
# server every time.
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')

# CONSTANTS
# Fixed token length produced by `tokenizer`; also the decoder's
# positional-embedding length.
MAX_LENGTH = 40
# NOTE(review): BATCH_SIZE and BUFFER_SIZE are not used by the code visible
# in this file — presumably left over from training.
BATCH_SIZE = 32
BUFFER_SIZE = 1000
# Transformer width and feed-forward hidden size used by get_caption_model().
EMBEDDING_DIM = 512
UNITS = 512

# LOADING DATA
# The vocabulary size comes from the pickled vocabulary itself (no max_tokens cap).
# NOTE: pickle.load can execute arbitrary code — only ship a trusted file here.
with open('saved_vocabulary/vocab_coco.file', 'rb') as vocab_file:
    vocab = pickle.load(vocab_file)

# standardize=None keeps the text as-is, so marker tokens such as '[start]'
# and '[end]' are looked up unchanged.
tokenizer = tf.keras.layers.TextVectorization(
    standardize=None,
    output_sequence_length=MAX_LENGTH,
    vocabulary=vocab,
)
# Inverse mapping from token id back to its string (used to decode captions).
idx2word = tf.keras.layers.StringLookup(
    mask_token="",
    vocabulary=tokenizer.get_vocabulary(),
    invert=True,
)
# MODEL
def CNN_Encoder():
    """Return an InceptionV3 feature extractor with a flattened spatial output.

    The backbone uses ImageNet weights and has no classification head. Its
    final feature map is reshaped with ``Reshape((-1, C))`` into
    (batch, positions, channels), so each spatial position becomes one
    input vector for the transformer encoder.
    """
    backbone = tf.keras.applications.InceptionV3(include_top=False, weights='imagenet')
    channels = backbone.output.shape[-1]
    flattened = tf.keras.layers.Reshape((-1, channels))(backbone.output)
    return tf.keras.models.Model(backbone.input, flattened)
class TransformerEncoderLayer(tf.keras.layers.Layer):
    """Single transformer encoder block applied to the CNN image features.

    The input is layer-normalised, projected by a ReLU dense layer, fed
    through unmasked self-attention, and the attention output is added back
    to the projection before a second layer normalisation.

    Args:
        embed_dim: output width of the projection (also the per-head key_dim).
        num_heads: number of self-attention heads.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Sub-layers are created in the same order as before; weight tracking
        # (and hence checkpoint loading) presumably depends on it.
        self.layer_norm_1 = tf.keras.layers.LayerNormalization()
        self.layer_norm_2 = tf.keras.layers.LayerNormalization()
        self.attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim)
        self.dense = tf.keras.layers.Dense(embed_dim, activation="relu")

    def call(self, x, training):
        projected = self.dense(self.layer_norm_1(x))
        attended = self.attention(
            query=projected,
            value=projected,
            key=projected,
            attention_mask=None,
            training=training,
        )
        # Residual connection around the attention, then normalise.
        return self.layer_norm_2(projected + attended)
class Embeddings(tf.keras.layers.Layer):
    """Sum of learned token embeddings and learned absolute position embeddings.

    Args:
        vocab_size: number of rows in the token-embedding table.
        embed_dim: width of both embedding tables.
        max_len: number of rows in the position-embedding table.
    """

    def __init__(self, vocab_size, embed_dim, max_len):
        super().__init__()
        # Token table is created before the position table (order unchanged).
        self.token_embeddings = tf.keras.layers.Embedding(
            vocab_size, embed_dim)
        self.position_embeddings = tf.keras.layers.Embedding(
            max_len, embed_dim, input_shape=(None, max_len))

    def call(self, input_ids):
        seq_len = tf.shape(input_ids)[-1]
        # Positions 0..seq_len-1 with a leading axis of size 1 so they
        # broadcast across the batch when added to the token embeddings.
        positions = tf.range(seq_len)[tf.newaxis, :]
        return self.token_embeddings(input_ids) + self.position_embeddings(positions)
class TransformerDecoderLayer(tf.keras.layers.Layer):
    """Transformer decoder block mapping caption tokens to next-token probabilities.

    ``call`` runs: token + position embeddings -> self-attention (causal only
    when a mask is given) -> cross-attention over ``encoder_output`` -> ReLU
    feed-forward -> softmax over the vocabulary. Each attention and
    feed-forward stage is followed by a residual add and layer normalisation.

    Args:
        embed_dim: model width; also the per-head ``key_dim`` of both
            attention layers and the output width of ``ffn_layer_2``.
        units: hidden width of the feed-forward sub-layer.
        num_heads: number of heads in both attention layers.

    The vocabulary size is read from the module-level ``tokenizer`` and the
    maximum sequence length from ``MAX_LENGTH``.
    """
    def __init__(self, embed_dim, units, num_heads):
        super().__init__()
        # NOTE(review): trained weights are loaded from an .h5 file; legacy .h5
        # loading matches weights by order, so this creation order is
        # presumably load-bearing — confirm before reordering anything here.
        self.embedding = Embeddings(
            tokenizer.vocabulary_size(), embed_dim, MAX_LENGTH)
        # Self-attention over the caption tokens.
        self.attention_1 = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        # Cross-attention: caption positions (queries) attend to the encoder output.
        self.attention_2 = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.layernorm_1 = tf.keras.layers.LayerNormalization()
        self.layernorm_2 = tf.keras.layers.LayerNormalization()
        self.layernorm_3 = tf.keras.layers.LayerNormalization()
        # Position-wise feed-forward: expand to `units` with ReLU, project back to embed_dim.
        self.ffn_layer_1 = tf.keras.layers.Dense(units, activation="relu")
        self.ffn_layer_2 = tf.keras.layers.Dense(embed_dim)
        # Output head: softmax probabilities over the whole vocabulary.
        self.out = tf.keras.layers.Dense(tokenizer.vocabulary_size(), activation="softmax")
        self.dropout_1 = tf.keras.layers.Dropout(0.3)
        self.dropout_2 = tf.keras.layers.Dropout(0.5)

    def call(self, input_ids, encoder_output, training, mask=None):
        """Decode ``input_ids`` against ``encoder_output``.

        Args:
            input_ids: caption token ids; the callers in this file pass a
                (batch, seq) tensor — TODO confirm for any other caller.
            encoder_output: output of the transformer encoder.
            training: forwarded to the attention and dropout layers.
            mask: optional per-token mask with the same shape as
                ``input_ids``, nonzero/true for real (non-padding) tokens.
                When ``None``, no attention mask is applied at all — not
                even the causal one.

        Returns:
            Softmax probabilities over the vocabulary for every input position.
        """
        embeddings = self.embedding(input_ids)
        combined_mask = None
        padding_mask = None
        if mask is not None:
            # (batch, seq, seq) int32 lower-triangular mask: position i may attend to j <= i.
            causal_mask = self.get_causal_attention_mask(embeddings)
            # (batch, seq, 1): used in cross-attention to mask padded query
            # positions; relies on MultiHeadAttention broadcasting it over the
            # encoder positions — TODO confirm on Keras upgrades.
            padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
            # (batch, 1, seq): masks padded key positions. Element-wise minimum
            # with the causal mask acts as a logical AND for 0/1 values.
            combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
            combined_mask = tf.minimum(combined_mask, causal_mask)
        attn_output_1 = self.attention_1(
            query=embeddings,
            value=embeddings,
            key=embeddings,
            attention_mask=combined_mask,
            training=training
        )
        out_1 = self.layernorm_1(embeddings + attn_output_1)
        attn_output_2 = self.attention_2(
            query=out_1,
            value=encoder_output,
            key=encoder_output,
            attention_mask=padding_mask,
            training=training
        )
        out_2 = self.layernorm_2(out_1 + attn_output_2)
        ffn_out = self.ffn_layer_1(out_2)
        ffn_out = self.dropout_1(ffn_out, training=training)
        ffn_out = self.ffn_layer_2(ffn_out)
        ffn_out = self.layernorm_3(ffn_out + out_2)
        # NOTE: this dropout sits after the final layer norm, just before the
        # output head — unusual, but presumably how the checkpoint was trained.
        ffn_out = self.dropout_2(ffn_out, training=training)
        preds = self.out(ffn_out)
        return preds

    def get_causal_attention_mask(self, inputs):
        """Return an int32 (batch, seq, seq) mask that is 1 where row i >= column j.

        ``batch`` and ``seq`` are taken from the first two dimensions of
        ``inputs``; the single (1, seq, seq) lower-triangular mask is tiled
        ``batch`` times along axis 0.
        """
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = tf.range(sequence_length)[:, tf.newaxis]
        j = tf.range(sequence_length)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
        # Tile multiples: [batch_size, 1, 1].
        mult = tf.concat(
            [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
            axis=0
        )
        return tf.tile(mask, mult)
class ImageCaptioningModel(tf.keras.Model):
    """End-to-end captioner: CNN features -> transformer encoder -> transformer decoder.

    Args:
        cnn_model: image feature extractor (built by ``CNN_Encoder`` in this file).
        encoder: transformer encoder applied to the CNN features.
        decoder: transformer decoder producing next-token probabilities.
        image_aug: optional callable applied to the images in ``train_step``
            only; ``None`` disables augmentation.

    Loss and accuracy are accumulated in ``Mean`` trackers exposed through
    ``metrics``. ``calculate_loss`` reads ``self.loss``, which is expected to
    be set via ``compile()`` — TODO confirm against the training script.
    """

    def __init__(self, cnn_model, encoder, decoder, image_aug=None):
        super().__init__()
        # NOTE(review): keep this assignment order — legacy .h5 weight loading
        # is order-based, not name-based (confirm before reordering).
        self.cnn_model = cnn_model
        self.encoder = encoder
        self.decoder = decoder
        self.image_aug = image_aug
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.acc_tracker = tf.keras.metrics.Mean(name="accuracy")

    def calculate_loss(self, y_true, y_pred, mask):
        """Return the loss averaged over the positions where ``mask`` is set.

        Assumes ``self.loss`` returns one value per token (i.e. no reduction)
        so that masking is meaningful — TODO confirm.
        """
        loss = self.loss(y_true, y_pred)
        mask = tf.cast(mask, dtype=loss.dtype)
        loss *= mask
        return tf.reduce_sum(loss) / tf.reduce_sum(mask)

    def calculate_accuracy(self, y_true, y_pred, mask):
        """Return the fraction of masked positions whose argmax prediction equals ``y_true``."""
        accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
        accuracy = tf.math.logical_and(mask, accuracy)
        accuracy = tf.cast(accuracy, dtype=tf.float32)
        mask = tf.cast(mask, dtype=tf.float32)
        return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)

    def compute_loss_and_acc(self, img_embed, captions, training=True):
        """Run encoder and decoder with teacher forcing and return ``(loss, acc)``.

        ``captions[:, :-1]`` is the decoder input and ``captions[:, 1:]`` the
        target. Target positions equal to 0 (padding) are excluded through the
        mask. ``training`` is forwarded to the encoder and decoder so their
        dropout is only active when training (``test_step`` passes False).
        """
        encoder_output = self.encoder(img_embed, training=training)
        y_input = captions[:, :-1]
        y_true = captions[:, 1:]
        mask = (y_true != 0)
        y_pred = self.decoder(
            y_input, encoder_output, training=training, mask=mask
        )
        loss = self.calculate_loss(y_true, y_pred, mask)
        acc = self.calculate_accuracy(y_true, y_pred, mask)
        return loss, acc

    def train_step(self, batch):
        """Run one optimisation step on ``(images, captions)`` and return running metrics."""
        imgs, captions = batch
        if self.image_aug is not None:
            imgs = self.image_aug(imgs)
        # The CNN runs outside the tape and is not in train_vars, so only the
        # encoder and decoder are updated.
        img_embed = self.cnn_model(imgs)
        with tf.GradientTape() as tape:
            loss, acc = self.compute_loss_and_acc(
                img_embed, captions, training=True
            )
        train_vars = (
            self.encoder.trainable_variables + self.decoder.trainable_variables
        )
        grads = tape.gradient(loss, train_vars)
        self.optimizer.apply_gradients(zip(grads, train_vars))
        self.loss_tracker.update_state(loss)
        self.acc_tracker.update_state(acc)
        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}

    def test_step(self, batch):
        """Evaluate ``(images, captions)`` in inference mode and return running metrics."""
        imgs, captions = batch
        img_embed = self.cnn_model(imgs)
        loss, acc = self.compute_loss_and_acc(
            img_embed, captions, training=False
        )
        self.loss_tracker.update_state(loss)
        self.acc_tracker.update_state(acc)
        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}

    @property
    def metrics(self):
        return [self.loss_tracker, self.acc_tracker]
def load_image_from_path(img_path):
    """Read an image file and preprocess it for the InceptionV3 backbone.

    Args:
        img_path: path of the image file. The app passes ``tmp.jpg``, which
            it writes with PIL just before captioning.

    Returns:
        A float32 tensor of shape (299, 299, 3) that has been passed through
        ``tf.keras.applications.inception_v3.preprocess_input``.
    """
    img = tf.io.read_file(img_path)
    img = tf.io.decode_jpeg(img, channels=3)
    img = tf.cast(img, tf.float32)
    # Bilinear resize (the default of both tf.image.resize and the Keras
    # Resizing layer) without constructing a new Keras layer on every call.
    img = tf.image.resize(img, (299, 299))
    img = tf.keras.applications.inception_v3.preprocess_input(img)
    return img
def generate_caption(img, caption_model, add_noise=False):
    """Generate a caption for an image with greedy (argmax) decoding.

    Args:
        img: path to an image file (loaded via ``load_image_from_path``) or an
            already-preprocessed image tensor — assumed (299, 299, 3); TODO
            confirm for tensor inputs.
        caption_model: an ``ImageCaptioningModel`` with loaded weights.
        add_noise: if true, add Gaussian noise scaled by 0.1 and min-max
            rescale the image to [0, 1] before encoding.

    Returns:
        The generated words joined by single spaces, without the ``[start]``
        marker. Decoding stops at ``[end]`` or after ``MAX_LENGTH - 1``
        tokens; an immediate ``[end]`` yields an empty string.
    """
    if isinstance(img, str):
        img = load_image_from_path(img)
    if add_noise:
        noise = tf.random.normal(img.shape) * 0.1
        img = img + noise
        img_min = tf.reduce_min(img)
        img = (img - img_min) / (tf.reduce_max(img) - img_min)
    img = tf.expand_dims(img, axis=0)
    img_embed = caption_model.cnn_model(img)
    img_encoded = caption_model.encoder(img_embed, training=False)

    words = []
    for i in range(MAX_LENGTH - 1):
        # Re-tokenize the full prefix each step; drop the last slot so the
        # decoder input has MAX_LENGTH - 1 positions.
        y_inp = ' '.join(['[start]', *words])
        tokenized = tokenizer([y_inp])[:, :-1]
        mask = tf.cast(tokenized != 0, tf.int32)
        pred = caption_model.decoder(
            tokenized, img_encoded, training=False, mask=mask)
        # The prediction at position i is the token following the i-th input token.
        pred_idx = np.argmax(pred[0, i, :])
        pred_word = idx2word(pred_idx).numpy().decode('utf-8')
        if pred_word == '[end]':
            break
        words.append(pred_word)
    # Only generated words are joined, so '[start]' never leaks into the result.
    return ' '.join(words)
def get_caption_model():
    """Build the captioning model and load its trained weights.

    The model and its sub-layers are called once on dummy inputs so that all
    variables exist before ``load_weights``. Weights are read from
    ``saved_models/image_captioning_coco_weights.h5`` if that file exists,
    otherwise from ``Image-Captioning/saved_models/...``.

    Returns:
        The ``ImageCaptioningModel`` with weights loaded.
    """
    encoder = TransformerEncoderLayer(EMBEDDING_DIM, 1)
    decoder = TransformerDecoderLayer(EMBEDDING_DIM, UNITS, 8)
    cnn_model = CNN_Encoder()
    caption_model = ImageCaptioningModel(
        cnn_model=cnn_model, encoder=encoder, decoder=decoder, image_aug=None,
    )

    # Identity `call` so that calling the model on a sample batch presumably
    # marks the Keras Model as built without a real forward pass.
    def call_fn(batch, training=True):
        return batch

    caption_model.call = call_fn

    # Dummy inputs: one random 299x299 RGB image and one all-zero caption of
    # MAX_LENGTH tokens (must not exceed the decoder's positional-embedding length).
    sample_x = tf.random.normal((1, 299, 299, 3))
    sample_y = tf.zeros((1, MAX_LENGTH))
    caption_model((sample_x, sample_y))
    sample_img_embed = caption_model.cnn_model(sample_x)
    sample_enc_out = caption_model.encoder(sample_img_embed, training=False)
    caption_model.decoder(sample_y, sample_enc_out, training=False)

    weight_paths = (
        'saved_models/image_captioning_coco_weights.h5',
        'Image-Captioning/saved_models/image_captioning_coco_weights.h5',
    )
    # Use the first path that exists; if none does, fall through to the last
    # one so the loader raises its usual missing-file error.
    weights_path = next(
        (path for path in weight_paths if os.path.exists(path)),
        weight_paths[-1],
    )
    caption_model.load_weights(weights_path)
    return caption_model
# ---------------------------------------------------------------------------
# Part 2: Streamlit app (model loading, keyword extraction, poem lookup, UI)
# ---------------------------------------------------------------------------
@st.cache_resource
def get_model():
    """Return the captioning model, building it only on the first call.

    ``st.cache_resource`` memoises the returned object, so Streamlit reruns
    reuse the same model instead of rebuilding it and reloading the weights.
    """
    model = get_caption_model()
    return model


caption_model = get_model()
@st.cache_data
def extract_important_term(caption):
    """Pick one keyword from a caption to seed the poem search.

    The caption is lower-cased and split on whitespace, NLTK English
    stopwords are removed, and the longest remaining word is returned (ties
    go to the earliest word, since ``max`` keeps the first maximum).

    If every word is a stopword, the longest word of the whole caption is
    used instead. An empty caption returns ``''`` rather than raising.
    """
    stop_words = set(stopwords.words('english'))
    words = caption.lower().split()
    filtered_words = [word for word in words if word not in stop_words]
    return max(filtered_words or words, key=len, default='')
def generate_poem(word, num_lines):
    """Collect up to ``num_lines`` poem lines from PoetryDB that contain ``word``.

    Args:
        word: search term; lines match by case-insensitive substring.
        num_lines: maximum number of lines to return.

    Returns:
        A list of at most ``num_lines`` strings, in API order; empty when
        ``word`` is empty, ``num_lines <= 0`` or nothing matches.
    """
    if not word or num_lines <= 0:
        return []
    # Retrieve poems with lines containing the given word.
    poetry_lines = get_poetry('lines', word)
    # NOTE(review): on a miss the API appears to answer with a status object
    # (a dict) instead of a list of poems — confirm against poetpy. Iterating
    # it would crash, so treat it as "no matches".
    if isinstance(poetry_lines, dict):
        return []
    needle = word.lower()
    selected_lines = []
    for poem in poetry_lines:
        if not isinstance(poem, dict):
            continue
        for line in poem.get('lines', []):
            if needle in line.lower():
                selected_lines.append(line)
                if len(selected_lines) >= num_lines:
                    return selected_lines
    return selected_lines
def predict(term_col, poem_col):
    """Caption ``tmp.jpg``, pick a keyword from the caption and render a matching poem.

    Args:
        term_col: Streamlit container passed by the caller; currently not
            written to (kept so the call signature stays unchanged).
        poem_col: Streamlit container that receives the poem.
    """
    pred_caption = generate_caption('tmp.jpg', caption_model)
    # Extract the important term
    important_term = extract_important_term(pred_caption)
    # Look up matching poem lines via poetpy
    poem_lines = generate_poem(important_term, num_lines=10)

    poem_col.markdown('#### Generated Poem:')
    if not poem_lines:
        poem_col.write(f'No poem lines found for "{important_term}".')
        return

    line_style = (
        "color: black; background-color: lightgrey; padding: 5px; "
        "margin-bottom: 5px; font-family: 'Palatino Linotype', "
        "'Book Antiqua', Palatino, serif;"
    )
    # Poem text comes from an external API, so escape it before embedding it
    # in HTML rendered with unsafe_allow_html. The container and its lines are
    # emitted in one markdown call: each st.markdown call renders as its own
    # element, so a <div> opened in one call cannot wrap later calls.
    rendered_lines = ''.join(
        f'<div class="poem-line" style="{line_style}">{html.escape(line)}</div>'
        for line in poem_lines
    )
    poem_col.markdown(
        f'<div class="poem-container">{rendered_lines}</div>',
        unsafe_allow_html=True,
    )
st.markdown('<h1 style="text-align:center; font-family:Comic Sans MS; width:fit-content; font-size:3em; color:green; text-shadow: 2px 2px 4px #000000;">AUTO POEM GENERATOR</h1>', unsafe_allow_html=True)
col1, col2 = st.columns(2)

# Image URL input
img_url = st.text_input(label='Enter Image URL')
# Separator between the two alternative inputs (URL or upload).
st.markdown('<center style="opacity: 70%">OR</center>', unsafe_allow_html=True)
# Image upload input
img_upload = st.file_uploader(label='Upload Image', type=['jpg', 'png', 'jpeg'])

# Resolve the input image; a URL takes precedence over an upload.
# NOTE(review): the URL is fetched server-side; if deployed publicly,
# consider restricting which hosts may be requested.
img = None
if img_url:
    try:
        response = requests.get(img_url, timeout=10)
        response.raise_for_status()
        img = Image.open(io.BytesIO(response.content)).convert('RGB')
    except (requests.RequestException, OSError) as err:
        st.error(f'Could not load image from URL: {err}')
elif img_upload:
    try:
        img = Image.open(io.BytesIO(img_upload.getvalue())).convert('RGB')
    except OSError as err:
        st.error(f'Could not read uploaded image: {err}')

# Process image and generate poem
if img is not None:
    col1.image(img, caption="Input Image", use_column_width=True)
    img.save('tmp.jpg')
    try:
        predict(col1, col2)
    finally:
        # Remove temporary image file even if captioning fails.
        os.remove('tmp.jpg')