"""Streamlit demo: caption images with a VirTex model trained on RedCaps."""

import streamlit as st
import io
import sys
import time
import json

# `model` lives in ./virtex/, so it must be importable before the star import.
sys.path.append("./virtex/")
# Provides (at least) `Image`, `create_objects` and `get_rand_idx` used below.
from model import *


def gen_show_caption(sub_prompt=None, cap_prompt="", image_input=None):
    """Run the captioning model on the current image and render the result.

    Args:
        sub_prompt: Subreddit to condition the caption on, or ``None`` to let
            the model predict the subreddit itself.
        cap_prompt: Caption prefix to condition on (the caller passes it
            through ``imageLoader.text_transform`` first); ``""`` for none.
            Presumably a ``str`` -- confirm against ``text_transform``.
        image_input: Transformed image handed to ``virtexModel.predict``.
            Defaults to the module-level ``image_dict`` so older call sites
            that relied on the global keep working.
    """
    if image_input is None:
        image_input = image_dict

    with st.spinner("Generating Caption"):
        # `!=` instead of `is not`: identity comparison with a str literal is
        # implementation-dependent and raises a SyntaxWarning on Python 3.8+.
        if sub_prompt is None and cap_prompt != "":
            st.write(
                "Without a specified subreddit, caption prompts will skip "
                "subreddit prediction"
            )
        subreddit, caption = virtexModel.predict(
            image_input, sub_prompt=sub_prompt, prompt=cap_prompt
        )

    st.header("Predicted Caption:\n\n")
    st.markdown(f"### r/{subreddit}\n\n#### {caption}")


st.title("Image Captioning Demo from RedCaps")
st.sidebar.markdown(
    """
    ### Image Captioning Model from VirTex trained on RedCaps

    Use this page to caption your own images or try out some of our samples.
    You can also generate captions as if they are from specific subreddits,
    as if they start with a particular prompt, or even both.

    Feel free to share your results on twitter with #redcaps or with a friend.
    """
)

with st.spinner("Loading Model"):
    virtexModel, imageLoader, sample_images, valid_subs = create_objects()

staggered = st.sidebar.checkbox("Iteratively Generate Captions")
if staggered:
    # TODO: iterative caption generation is not implemented yet; render
    # nothing beyond the controls above until it is.
    st.stop()

# Index of a freshly chosen sample image. None means "no new pick this run",
# which lets the cached image in session_state be reused below.
select_idx = None

st.sidebar.title("Select a sample image")
if st.sidebar.button("Random Sample Image"):
    select_idx = get_rand_idx(sample_images)
sample_image = sample_images[0 if select_idx is None else select_idx]

uploaded_image = None
with st.sidebar.form("file-uploader-form", clear_on_submit=True):
    uploaded_file = st.file_uploader("Choose a file")
    submitted = st.form_submit_button("Submit")
    if uploaded_file is not None and submitted:
        uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
        select_idx = None  # set this to help rewrite the cache

st.sidebar.title("Select a Subreddit")
sub = st.sidebar.selectbox(
    "Select None for a Predicted Subreddit",
    valid_subs,
)

st.sidebar.title("Write a Custom Prompt")
cap_prompt = st.sidebar.text_input(
    "Leave this blank for an unbiased caption",
    value="",
)

# Clicking any widget reruns the script, so this button needs no handler:
# the rerun itself regenerates the caption.
st.sidebar.button("Regenerate Caption")

# TODO: expose advanced sampling options (e.g. nucleus size) in the sidebar.

if uploaded_image is None and submitted:
    st.write("Please select a file to upload")
else:
    # Pick the image to caption: a fresh upload wins; otherwise reuse the
    # cached image unless a new random sample was just chosen; otherwise
    # load the selected sample from disk.
    if uploaded_image is not None:
        image = uploaded_image.convert("RGB")
    elif select_idx is None and "image" in st.session_state:
        image = st.session_state["image"]  # already stored as RGB
    else:
        # Context manager closes the underlying file once pixels are loaded.
        with Image.open(sample_image) as opened:
            image = opened.convert("RGB")
    st.session_state["image"] = image

    image_dict = imageLoader.transform(image)
    show_image = imageLoader.show_resize(image)
    st.image(show_image, caption="Your Image")

    gen_show_caption(sub, imageLoader.text_transform(cap_prompt), image_dict)