import streamlit as st
import pickle
import tensorflow as tf
import cv2
import numpy as np
from PIL import Image, ImageOps
import imageio.v3 as iio
import time
from textwrap import wrap
import matplotlib.pylab as plt
from tensorflow.keras import Input
from tensorflow.keras.layers import (
    GRU,
    Add,
    AdditiveAttention,
    Attention,
    Concatenate,
    Dense,
    Embedding,
    LayerNormalization,
    Reshape,
    StringLookup,
    TextVectorization,
)
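
# Note: only a subset of the imports above is used in the inference path of this
# script (streamlit, pickle, tensorflow, numpy, PIL.Image, StringLookup). The
# rest (cv2, imageio, time, wrap, matplotlib, GRU, Attention, Embedding, ...)
# presumably mirror the training-time setup and are not required by the code below.
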
MAX_CAPTION_LEN = 64
MINIMUM_SENTENCE_LENGTH = 5
IMG_HEIGHT = 299
IMG_WIDTH = 299
IMG_CHANNELS = 3
ATTENTION_DIM = 512  # size of dense layer in Attention
VOCAB_SIZE = 20000
FEATURES_SHAPE = (8, 8, 1536)
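
# FEATURES_SHAPE of (8, 8, 1536) for a 299x299x3 input matches the final
# convolutional feature map of an InceptionResNetV2-style backbone; this is an
# inference from the constants above, not something stated in the script.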


def load_image_model():
    # Full image-captioning model; defined here but not called elsewhere in this script.
    image_model = tf.keras.models.load_model('./image_caption_model.h5')
    return image_model


def load_decoder_model():
    # Decoder used step by step at prediction time (one token per call).
    decoder_model = tf.keras.models.load_model('./decoder_pred_model.h5')
    return decoder_model


def load_encoder_model():
    # CNN encoder that turns a preprocessed image into a spatial feature map.
    encoder = tf.keras.models.load_model('./encoder_model.h5')
    return encoder
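

# The loaders above re-read their .h5 files on every call (predict_caption below
# calls them once per generated caption). A common Streamlit pattern is to cache
# such resources so each model is loaded only once per process. Optional sketch:
# it assumes a Streamlit version that provides st.cache_resource, and
# cached_models() is a hypothetical helper that the app does not currently use.
@st.cache_resource
def cached_models():
    # Load the encoder and decoder once and reuse them across reruns.
    return load_encoder_model(), load_decoder_model()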


st.title(":blue[Nishant Guvvada's] :red[AI Journey] Image Caption Generation")
st.write("""
# Multi-Modal Machine Learning
""")

file = st.file_uploader("Upload an image to generate captions!", type=['png', 'jpg'])


# Override the default standardization of TextVectorization so that the "<" and
# ">" characters survive, keeping the <start> and <end> tokens intact.
def standardize(inputs):
    inputs = tf.strings.lower(inputs)
    return tf.strings.regex_replace(
        inputs, r"[!\"#$%&\(\)\*\+.,-/:;=?@\[\\\]^_`{|}~]?", ""
    )
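

# standardize() is not called directly in this script; it matches the custom
# standardization a TextVectorization layer would have used when the vocabulary
# loaded below was built at training time. A minimal sketch under that
# assumption (this layer is not needed for the inference path below):
tokenizer_sketch = TextVectorization(
    max_tokens=VOCAB_SIZE,
    standardize=standardize,
    output_sequence_length=MAX_CAPTION_LEN,
)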


# Load the tokenizer vocabulary (the most frequent words, with punctuation
# already stripped at training time) from the pickled file.
with open('./tokenizer_vocab.txt', 'rb') as vocab:
    tokenizer = pickle.load(vocab)

# Lookup table: Word -> Index
word_to_index = StringLookup(
    mask_token="", vocabulary=tokenizer
)

# Lookup table: Index -> Word
index_to_word = StringLookup(
    mask_token="", vocabulary=tokenizer, invert=True
)
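
# Usage note: word_to_index("<start>") returns the integer id of the "<start>"
# token and index_to_word(id) inverts the mapping; `tokenizer` itself is also
# indexed like a list of words, which is how predict_caption() turns predicted
# ids back into strings below.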


## Probabilistic prediction using the trained model
def predict_caption(file):
    # Preprocess the uploaded image: convert to RGB, resize to the encoder's
    # input size, and scale pixel values to [0, 1].
    image = Image.open(file).convert('RGB')
    image = np.array(image)
    resized = tf.image.resize(image, (IMG_HEIGHT, IMG_WIDTH))
    img = resized / 255.0

    # Encode the image into a spatial feature map.
    encoder = load_encoder_model()
    features = encoder(tf.expand_dims(img, axis=0))

    # Start decoding from the <start> token with a zero GRU state.
    gru_state = tf.zeros((1, ATTENTION_DIM))
    dec_input = tf.expand_dims([word_to_index("<start>")], 1)
    result = []
    decoder_pred_model = load_decoder_model()

    for i in range(MAX_CAPTION_LEN):
        predictions, gru_state = decoder_pred_model(
            [dec_input, gru_state, features]
        )

        # Sample the next token from the top-10 scores; tf.random.categorical
        # treats the selected scores as unnormalized log-probabilities.
        top_probs, top_idxs = tf.math.top_k(
            input=predictions[0][0], k=10, sorted=False
        )
        chosen_id = tf.random.categorical([top_probs], 1)[0].numpy()
        predicted_id = top_idxs.numpy()[chosen_id][0]
        result.append(tokenizer[predicted_id])

        # Stop as soon as the <end> token is produced.
        if predicted_id == word_to_index("<end>"):
            return img, result

        # Feed the predicted token back in as the next decoder input.
        dec_input = tf.expand_dims([predicted_id], 1)

    return img, result
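

# Because the loop above samples from the top-10 scores, repeated calls can
# return different captions for the same image (which is why on_click() below
# generates five of them). A deterministic greedy variant would replace the
# sampling step with an argmax; sketch only, not used by the app:
def greedy_next_token(predictions):
    # Pick the single highest-scoring token id instead of sampling.
    return int(tf.argmax(predictions[0][0]).numpy())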


def on_click():
    if file is None:
        st.text("Please upload an image file")
    else:
        image = Image.open(file)
        st.image(image, use_column_width=True)
        # Sampling makes each pass different, so show five caption variations.
        for i in range(5):
            image, caption = predict_caption(file)
            # Drop the final token (normally <end>) and end with a period.
            st.write(" ".join(caption[:-1]) + ".")


st.button('Generate', on_click=on_click)
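
# To serve the app locally (assuming this file is saved as app.py, alongside the
# three .h5 model files and tokenizer_vocab.txt):
#   streamlit run app.py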