|
import numpy as np |
|
import streamlit as st |
|
from tensorflow.keras.preprocessing.sequence import pad_sequences |
|
from tensorflow.keras.models import load_model, Model |
|
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input |
|
from tensorflow.keras.preprocessing.text import tokenizer_from_json |
|
from tensorflow.keras.preprocessing.image import load_img, img_to_array |
|
from PIL import Image |
|
|
|
|
|
@st.cache_resource
def init_lstm_model():
    """Load the trained LSTM captioning decoder from disk.

    Cached by Streamlit so the model is loaded once and reused on reruns.
    """
    model_path = "./best_model.h5"
    return load_model(model_path)
|
|
|
@st.cache_resource
def init_vgg16_model():
    """Build a VGG16 feature extractor, cached across Streamlit reruns.

    Takes the stock ImageNet VGG16 and re-wires the output to the
    penultimate layer, so the model yields a feature vector instead of
    class probabilities.
    """
    base = VGG16()
    feature_extractor = Model(inputs=base.inputs, outputs=base.layers[-2].output)
    return feature_extractor
|
|
|
@st.cache_resource
def init_lstm_tokenizer():
    """Restore the Keras tokenizer used at training time from its JSON dump.

    The file is opened with an explicit UTF-8 encoding: without it,
    ``open()`` uses the platform default codec (e.g. cp1252 on Windows),
    which can corrupt or fail on non-ASCII vocabulary entries in the JSON.
    Cached by Streamlit so the file is read only once.
    """
    with open("./tokenizer.txt", encoding="utf-8") as rf:
        return tokenizer_from_json(rf.read())
|
|
|
|
|
|
|
# Build the (Streamlit-cached) models and tokenizer once at import time so
# every rerun of the script reuses the same instances.
vgg16_model = init_vgg16_model()

lstm_model = init_lstm_model()

lstm_tokenizer = init_lstm_tokenizer()

# Caption length (in tokens) used both as the decoding step limit and the
# pad_sequences target length — presumably the longest caption seen during
# training; must match what the LSTM model expects. TODO confirm.
max_length = 34
|
|
|
def idx_to_word(integer):
    """Map a predicted token index back to its word, or None if unknown.

    Uses the tokenizer's reverse ``index_word`` mapping (restored by
    ``tokenizer_from_json`` with integer keys) for an O(1) lookup instead
    of linearly scanning the whole ``word_index`` dict — this function runs
    once per generated token, so the scan was O(vocab) per decoding step.
    Returning None for an unmapped index matches the original behavior.
    """
    return lstm_tokenizer.index_word.get(integer)
|
|
|
|
|
def predict_caption(image, max_length):
    """Greedily decode a caption for a pre-extracted image feature vector.

    Starting from the 'startseq' marker, repeatedly feeds the partial
    caption plus the image features to the LSTM model and appends the
    highest-probability next word. Decoding stops when the model emits
    'endseq', when a predicted index cannot be mapped back to a word, or
    after `max_length` steps. Returns the raw caption, markers included.
    """
    caption = 'startseq'

    for _ in range(max_length):
        # Encode the partial caption and pad it to the model's input length.
        token_ids = lstm_tokenizer.texts_to_sequences([caption])[0]
        padded = pad_sequences([token_ids], max_length)

        # Distribution over the vocabulary for the next token.
        probs = lstm_model.predict([image, padded], verbose=0)
        next_word = idx_to_word(np.argmax(probs))

        # An unmapped index means we cannot continue decoding.
        if next_word is None:
            break

        caption = caption + " " + next_word

        # End-of-sequence marker: stop after appending it (stripped later).
        if next_word == 'endseq':
            break

    return caption
|
|
|
|
|
|
|
def generate_caption(image_name):
    """Run the full captioning pipeline for one image file/upload.

    Loads the image at VGG16's expected 224x224 resolution, extracts a
    feature vector with the truncated VGG16 model, decodes a caption with
    the LSTM decoder, and strips the start/end sentinel tokens before
    returning the text.
    """
    img = load_img(image_name, target_size=(224, 224))

    # To array, add a leading batch dimension, and apply VGG16 preprocessing.
    batch = np.expand_dims(img_to_array(img), axis=0)
    batch = preprocess_input(batch)

    features = vgg16_model.predict(batch)

    raw_caption = predict_caption(features, max_length)
    return raw_caption.replace("startseq", "").replace("endseq", "").strip()
|
|
|
|
|
# ---------------------------------------------------------------------------
# Streamlit page: title, uploader, and side-by-side image / prediction view.
# ---------------------------------------------------------------------------
st.title("""
Image Captioner.

This app generates a caption for the input image. The results will be predicted from the basic cnn-rnn to advanced transformer based encoder-decoder models.""")


file_name = st.file_uploader("Upload an image to generate caption...")

if file_name is not None:
    col1, col2 = st.columns(2)

    # Left column: the uploaded image as-is; right column: model output.
    image = Image.open(file_name)
    col1.image(image, use_column_width=True)
    prediction = generate_caption(file_name)

    col2.header("Predictions")
    # Plain string literal: the original used an f-string with no
    # placeholders (lint F541); the rendered text is unchanged.
    col2.subheader("VGG16-LSTM : ")
    col2.text(prediction)