# virtex-redcaps / app.py
# zamborg's picture
# added caching to loading
# 5d3b8a6
# raw history blame
# No virus
# 2.69 kB
import streamlit as st
import io
import sys
import time
import json
from model import *
sys.path.append("./virtex/")
def gen_show_caption(sub_prompt=None, cap_prompt=""):
    """Caption the current image with the VirTex model and render the result.

    Reads the module-level ``image_dict`` (the transformed current image) and
    ``virtexModel``; both must be assigned before this is called.

    Args:
        sub_prompt: subreddit passed to ``virtexModel.predict`` as
            ``sub_prompt``; None means no subreddit was specified.
        cap_prompt: caption prompt passed to ``virtexModel.predict`` as
            ``prompt``; the empty string means "no prompt".
    """
    with st.spinner("Generating Caption"):
        # Value comparison, not `is not ""` (identity with a literal is
        # unreliable and a SyntaxWarning). cap_prompt may be the output of
        # imageLoader.text_transform, whose type is not visible here, so any
        # non-str value counts as "a prompt was given" -- as it did before.
        has_prompt = not isinstance(cap_prompt, str) or cap_prompt != ""
        if sub_prompt is None and has_prompt:
            st.write("Without a specified subreddit, caption prompts will skip subreddit prediction")
        subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt=cap_prompt)
    st.header("Predicted Caption:\n\n")
    st.markdown(
f"""
r/{subreddit}
\t
**{caption}**
"""
    )
# Page header: the main title, plus a one-line model description in the sidebar.
st.title("Image Captioning Demo from Redcaps")
st.sidebar.markdown("\nImage Captioning Model from VirTex trained on Redcaps\n")
# Load the model, the image loader, the sample image list and the subreddit list.
with st.spinner("Loading Model"):
    virtexModel, imageLoader, sample_images, valid_subs = create_objects()

# get_rand_img returns an (index, image) pair (see the button handler below),
# so take only the image here; it is the fallback when no sample is selected.
random_image = get_rand_img(sample_images)[1]
rand_idx = 0  # the sample dropdown starts at its first entry by default

st.sidebar.title("Select a sample image")
if st.sidebar.button("Random Sample Image"):
    # Jump the dropdown to a randomly chosen sample.
    rand_idx, random_image = get_rand_img(sample_images)
sample_image = st.sidebar.selectbox(
    "",
    sample_images,
    index=rand_idx,
)
# Subreddit picker; the chosen entry of valid_subs is later handed to gen_show_caption.
st.sidebar.title("Select a Subreddit")
sub = st.sidebar.selectbox(label="Select None for a Predicted Subreddit", options=valid_subs)

# Optional free-text caption prompt (the label asks users to leave it blank
# for an unbiased caption).
st.sidebar.title("Write a Custom Prompt")
cap_prompt = st.sidebar.text_input(label="Leave this blank for an unbiased caption", value="")
# Optional user upload; the form clears its uploader after each submit.
uploaded_image = None
upload_failed = False
with st.sidebar.form("file-uploader-form", clear_on_submit=True):
    uploaded_file = st.file_uploader("Choose a file")
    submitted = st.form_submit_button("Submit")
    if uploaded_file is not None and submitted:
        try:
            uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
        except OSError:
            # PIL raises UnidentifiedImageError (an OSError subclass) for
            # files it cannot decode; report it instead of crashing the app.
            upload_failed = True

if upload_failed:
    st.write("The uploaded file could not be read as an image")
elif uploaded_image is None and submitted:
    st.write("Please select a file to upload")
else:
    # Prefer the uploaded image, then the selected sample, then the random fallback.
    image_file = sample_image if sample_image is not None else random_image
    image = uploaded_image if uploaded_image is not None else Image.open(image_file)
    try:
        # Module-level on purpose: gen_show_caption reads the global image_dict.
        image_dict = imageLoader.transform(image)
        show_image = imageLoader.show_resize(image)
        st.image(show_image, "Your Image")
        gen_show_caption(sub, imageLoader.text_transform(cap_prompt))
    finally:
        # Release the image even if transforming or captioning fails.
        image.close()
# from model import *
# sample_images = get_samples()
# v, il = VirTexModel(), ImageLoader()
# for s in sample_images:
# subreddit, caption = v.predict(il.load(s))
# print("=====================")
# print(subreddit)
# print(caption)