"""Streamlit demo: caption images with a VirTex model trained on RedCaps.

The page lets a user pick a sample image or upload their own, optionally
condition the caption on a subreddit and/or a caption prefix, choose how many
captions to generate, and tune the decoder's nucleus size.
"""
import io

import streamlit as st

from model import *

# TODO:
# - Reformat the model introduction
# - Make the iterative text generation


def gen_show_caption(sub_prompt=None, cap_prompt=""):
    """Generate one caption for the current image and render it as markdown.

    Reads the module-level ``virtexModel`` and ``image_dict`` globals, so it
    must only be called after both have been assigned later in this script.

    Args:
        sub_prompt: Subreddit to condition on, or ``None`` to let the model
            predict one.
        cap_prompt: Text the generated caption should start with.
    """
    with st.spinner("Generating Caption"):
        subreddit, caption = virtexModel.predict(
            image_dict, sub_prompt=sub_prompt, prompt=cap_prompt
        )
    # The markdown below contains no HTML, so unsafe_allow_html is left off:
    # any markup in the user's prompt or the model's output is then shown as
    # text instead of being injected into the page as raw HTML.
    st.markdown(f"### r/{subreddit}\n\n{cap_prompt} {caption}")


_, center, _ = st.columns([1, 8, 1])
with center:
    st.title("Image Captioning Demo from RedCaps")

st.sidebar.markdown(
    """
    ### Image Captioning Model from VirTex trained on RedCaps

    Use this page to caption your own images or try out some of our samples.
    You can also generate captions as if they are from specific subreddits,
    as if they start with a particular prompt, or even both.

    Share your results on twitter with #redcaps or with a friend*.
    """
)

# NOTE(review): this runs on every Streamlit rerun; presumably create_objects()
# is cached inside model.py -- confirm, otherwise the model reloads each time.
with st.spinner("Loading Model"):
    virtexModel, imageLoader, sample_images, valid_subs = create_objects()

# select_idx stays None unless a new random sample is requested (or is reset
# by an upload); None means "reuse the image cached in session state" below.
select_idx = None
st.sidebar.title("Select a sample image")
if st.sidebar.button("Random Sample Image"):
    select_idx = get_rand_idx(sample_images)

sample_image = sample_images[0 if select_idx is None else select_idx]

uploaded_image = None
uploaded_file = st.sidebar.file_uploader("Choose a file")
if uploaded_file is not None:
    uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
    select_idx = None  # set this to help rewrite the cache

st.sidebar.title("Select a Subreddit")
sub = st.sidebar.selectbox(
    "Type below to condition on a subreddit. Select None for a predicted subreddit",
    valid_subs,
)

st.sidebar.title("Write a Custom Prompt")
cap_prompt = st.sidebar.text_input("Write the start of your caption below", value="")

# The return value is unused: clicking the button reruns the script, which
# regenerates the captions below.
_ = st.sidebar.button("Regenerate Caption")

st.sidebar.write("Advanced Options:")
num_captions = st.sidebar.select_slider(
    "Number of Captions to Predict", options=[1, 2, 3, 4, 5], value=1
)
nuc_size = st.sidebar.slider(
    "Nucleus Size:\nLarger values lead to more diverse captions",
    min_value=0.0,
    max_value=1.0,
    value=0.8,
    step=0.05,
)
virtexModel.model.decoder.nucleus_size = nuc_size

# LOAD AND CACHE THE IMAGE
# Priority: an uploaded file, then the image cached from a previous rerun
# (unless a new random sample was just requested), then the sample image.
if uploaded_image is not None:
    image = uploaded_image
elif select_idx is None and "image" in st.session_state:
    image = st.session_state["image"]
else:
    image = Image.open(sample_image)
image = image.convert("RGB")
st.session_state["image"] = image

image_dict = imageLoader.transform(image)
show_image = imageLoader.show_resize(image)

with center:
    st.image(show_image)

# The transformed prompt is the same for every caption, so compute it once.
caption_prefix = imageLoader.text_transform(cap_prompt)
if sub is None and caption_prefix != "":
    st.write("Without a specified subreddit we default to /r/pics")

for _ in range(num_captions):
    gen_show_caption(sub, caption_prefix)

st.sidebar.markdown(
    """
    *Please note that this model was explicitly not trained on images of people,
    and as a result is not designed to caption images with humans.

    This demo accompanies our paper RedCaps.

    Created by Karan Desai, Gaurav Kaul, Zubin Aysola, Justin Johnson
    """
)