# saichandrapandraju's picture
# Update app.py
# 914cf7d
import numpy as np
import streamlit as st
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.preprocessing.text import tokenizer_from_json
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from PIL import Image
@st.cache_resource
def init_lstm_model():
    """Load the trained caption-decoder LSTM from disk (cached across reruns)."""
    model_path = "./best_model.h5"
    return load_model(model_path)
@st.cache_resource
def init_vgg16_model():
    """Build a VGG16 feature extractor truncated at the penultimate layer.

    The classification head is dropped; the second-to-last layer's output
    serves as the image embedding fed to the caption decoder.
    """
    backbone = VGG16()
    feature_output = backbone.layers[-2].output
    return Model(inputs=backbone.inputs, outputs=feature_output)
@st.cache_resource
def init_lstm_tokenizer():
    """Rebuild the Keras tokenizer from its serialized JSON (cached)."""
    tokenizer_path = "./tokenizer.txt"
    with open(tokenizer_path) as handle:
        serialized = handle.read()
    return tokenizer_from_json(serialized)
# Instantiate the cached resources once at module import time; Streamlit's
# cache_resource decorators make repeated script reruns reuse these objects.
vgg16_model = init_vgg16_model()
lstm_model = init_lstm_model()
lstm_tokenizer = init_lstm_tokenizer()
# Decode/padding length limit — presumably the max caption length seen
# during training; TODO confirm against the training pipeline.
max_length = 34
def idx_to_word(integer):
    """Map a predicted token index back to its word.

    Returns None when the index is not in the tokenizer's vocabulary
    (same contract as the original linear scan).
    """
    # The Keras Tokenizer keeps a ready-made int->word table (index_word),
    # so lookup is O(1) instead of scanning word_index for every token
    # generated during decoding.
    return lstm_tokenizer.index_word.get(integer)
def predict_caption(image, max_length):
    """Greedily decode a caption for an image feature vector.

    Starting from the 'startseq' token, repeatedly predicts the most
    likely next word until 'endseq' appears, an unknown index is
    produced, or max_length words have been generated. The returned
    string still contains the start/end sentinel tokens.
    """
    caption = 'startseq'
    for _ in range(max_length):
        # Encode the caption-so-far and pad it to the decoder's input length.
        token_ids = lstm_tokenizer.texts_to_sequences([caption])[0]
        padded = pad_sequences([token_ids], max_length)
        # Pick the highest-probability next token (greedy search).
        probabilities = lstm_model.predict([image, padded], verbose=0)
        next_word = idx_to_word(np.argmax(probabilities))
        if next_word is None:
            # Index outside the vocabulary — stop decoding.
            break
        caption = caption + " " + next_word
        if next_word == 'endseq':
            break
    return caption
def generate_caption(image_name):
    """Run the full pipeline: load image -> VGG16 features -> LSTM caption.

    Accepts anything load_img accepts (path or file-like object) and
    returns the generated caption with the start/end sentinels removed.
    """
    img = load_img(image_name, target_size=(224, 224))
    pixels = img_to_array(img)
    # Add a leading batch axis of size 1 for the CNN.
    batch = pixels.reshape((1,) + pixels.shape)
    batch = preprocess_input(batch)
    feature = vgg16_model.predict(batch)
    raw_caption = predict_caption(feature, max_length)
    # Strip the decoding sentinel tokens before showing the result.
    return raw_caption.replace("startseq", "").replace("endseq", "").strip()
# --- Streamlit UI ---------------------------------------------------------
st.title("""
Image Captioner.
This app generates a caption for the input image. The results will be predicted from the basic cnn-rnn to advanced transformer based encoder-decoder models.""")
file_name = st.file_uploader("Upload an image to generate caption...")
if file_name is not None:
    col1, col2 = st.columns(2)
    # Show the uploaded image on the left, predictions on the right.
    image = Image.open(file_name)
    col1.image(image, use_column_width=True)
    prediction = generate_caption(file_name)
    col2.header("Predictions")
    # Fix: the original used an f-string with no placeholders; the plain
    # literal is byte-identical at runtime.
    col2.subheader("VGG16-LSTM : ")
    col2.text(prediction)