"""Streamlit app: caption an uploaded image with a VGG16 encoder + LSTM decoder."""

import numpy as np
import streamlit as st
from PIL import Image
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import tokenizer_from_json


@st.cache_resource
def init_lstm_model():
    """Load the trained captioning (decoder) model once per Streamlit session."""
    return load_model("./best_model.h5")


@st.cache_resource
def init_vgg16_model():
    """Build a VGG16 feature extractor that outputs the penultimate (fc2) layer."""
    vgg_model = VGG16()
    # Drop the final softmax classifier; the fc2 activations serve as image features.
    return Model(inputs=vgg_model.inputs, outputs=vgg_model.layers[-2].output)


@st.cache_resource
def init_lstm_tokenizer():
    """Restore the tokenizer that was fitted on the training captions."""
    # Explicit encoding: the file is a JSON dump, so don't depend on the
    # platform's default text encoding.
    with open("./tokenizer.txt", encoding="utf-8") as rf:
        return tokenizer_from_json(rf.read())


vgg16_model = init_vgg16_model()
lstm_model = init_lstm_model()
lstm_tokenizer = init_lstm_tokenizer()
max_length = 34  # maximum caption length used during training


def idx_to_word(integer):
    """Map a predicted token index back to its word, or None if unknown.

    Uses the tokenizer's built-in reverse mapping (`index_word`) for an O(1)
    lookup instead of scanning `word_index` on every generated token.
    """
    # int() guards against np.int64 coming from np.argmax.
    return lstm_tokenizer.index_word.get(int(integer))


def predict_caption(image, max_length):
    """Greedy-decode a caption for a single image feature vector.

    Args:
        image: VGG16 feature array for one image (batch dimension included).
        max_length: maximum number of decoding steps / padded sequence length.

    Returns:
        The generated caption, still wrapped in 'startseq'/'endseq' tags.
    """
    # Seed the decoder with the start-of-sequence tag.
    in_text = 'startseq'
    for _ in range(max_length):
        # Encode the partial caption and pad it to the training length.
        sequence = lstm_tokenizer.texts_to_sequences([in_text])[0]
        sequence = pad_sequences([sequence], maxlen=max_length)
        # Predict the next-word distribution; greedy search takes the argmax.
        yhat = lstm_model.predict([image, sequence], verbose=0)
        word = idx_to_word(np.argmax(yhat))
        if word is None:
            # Predicted index is not in the vocabulary -- stop decoding.
            break
        # Append the word so it conditions the next prediction step.
        in_text += " " + word
        if word == 'endseq':
            break
    return in_text


def generate_caption(image_name):
    """Produce a cleaned caption for an image given by path or file-like object."""
    # Resize to VGG16's expected 224x224 input.
    image = load_img(image_name, target_size=(224, 224))
    image = img_to_array(image)
    # Add the batch dimension and apply VGG16's channel preprocessing.
    image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))
    image = preprocess_input(image)
    # Extract image features, then decode them into text.
    feature = vgg16_model.predict(image)
    y_pred = predict_caption(feature, max_length)
    # Strip the start/end sequence tags for display.
    return y_pred.replace("startseq", "").replace("endseq", "").strip()


st.title(""" Image Captioner. This app generates a caption for the input image. The results will be predicted from the basic cnn-rnn to advanced transformer based encoder-decoder models.""")

file_name = st.file_uploader("Upload an image to generate caption...")
if file_name is not None:
    col1, col2 = st.columns(2)
    image = Image.open(file_name)
    col1.image(image, use_column_width=True)
    prediction = generate_caption(file_name)
    col2.header("Predictions")
    col2.subheader("VGG16-LSTM : ")
    col2.text(prediction)