virtex-redcaps / app.py
zamborg's picture
Did y'all know that Python is an interpreted language? And that sometimes you need to run import statements in order?
56c51f4
raw history blame
No virus
2.69 kB
import streamlit as st
import io
import sys
import time
import json
sys.path.append("./virtex/")
from model import *
def gen_show_caption(sub_prompt=None, cap_prompt=""):
    """Caption the current image with the VirTex model and render the result.

    Reads the module-level ``virtexModel`` and ``image_dict`` globals, which the
    script assigns before calling this function.

    Args:
        sub_prompt: Subreddit forwarded to ``virtexModel.predict`` as
            ``sub_prompt``; ``None`` means no subreddit was specified.
        cap_prompt: Caption prompt forwarded to ``virtexModel.predict`` as
            ``prompt``; ``""`` means no custom prompt.
    """
    with st.spinner("Generating Caption"):
        # Compare by value, not identity: `is not ""` depends on string
        # interning (a freshly built empty string would pass it) and is a
        # SyntaxWarning on Python 3.8+.
        if sub_prompt is None and cap_prompt != "":
            st.write("Without a specified subreddit, caption prompts will skip subreddit prediction")
        subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt=cap_prompt)
        st.header("Predicted Caption:\n\n")
        st.markdown(
            f"""
r/{subreddit}
\t
**{caption}**
"""
        )
# Page header: the main-area title, then a short description in the sidebar.
_SIDEBAR_BLURB = """
Image Captioning Model from VirTex trained on Redcaps
"""
st.title("Image Captioning Demo from Redcaps")
st.sidebar.markdown(_SIDEBAR_BLURB)
# Build the model, image loader, sample list, and subreddit list once per run.
with st.spinner("Loading Model"):
    virtexModel, imageLoader, sample_images, valid_subs = create_objects()

# get_rand_img yields an (index, image) pair, as the button handler below
# shows, so unpack it here as well. Assigning the pair directly would leave a
# tuple in `random_image`, which is later used as a fallback image path.
_, random_image = get_rand_img(sample_images)
# The selectbox starts on the first sample unless the button picks another.
rand_idx = 0

st.sidebar.title("Select a sample image")
if st.sidebar.button("Random Sample Image"):
    rand_idx, random_image = get_rand_img(sample_images)
sample_image = st.sidebar.selectbox(
    "",
    sample_images,
    index=rand_idx,
)
# Optional conditioning inputs: a subreddit choice and a custom caption prompt.
st.sidebar.title("Select a Subreddit")
sub = st.sidebar.selectbox(
    label="Select None for a Predicted Subreddit",
    options=valid_subs,
)

st.sidebar.title("Write a Custom Prompt")
cap_prompt = st.sidebar.text_input(label="Leave this blank for an unbiased caption", value="")
# Optional user upload; when present it takes precedence over the sample image.
uploaded_image = None
with st.sidebar.form("file-uploader-form", clear_on_submit=True):
    uploaded_file = st.file_uploader("Choose a file")
    submitted = st.form_submit_button("Submit")
    if uploaded_file is not None and submitted:
        uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))

if uploaded_image is None and submitted:
    # The form was submitted without a file: prompt instead of captioning.
    st.write("Please select a file to upload")
else:
    image_file = sample_image if sample_image is not None else random_image
    image = uploaded_image if uploaded_image is not None else Image.open(image_file)
    try:
        # `image_dict` stays module-level: gen_show_caption reads it as a global.
        image_dict = imageLoader.transform(image)
        show_image = imageLoader.show_resize(image)
        # A single call renders the captioned image. The old pattern drew it,
        # then replaced it via the returned element.
        st.image(show_image, "Your Image")
        gen_show_caption(sub, imageLoader.text_transform(cap_prompt))
    finally:
        # Release the image even if transforming or captioning raises.
        image.close()
# from model import *
# sample_images = get_samples()
# v, il = VirTexModel(), ImageLoader()
# for s in sample_images:
# subreddit, caption = v.predict(il.load(s))
# print("=====================")
# print(subreddit)
# print(caption)