# Streamlit demo: caption an image with a VirTex model trained on RedCaps.
# Launch with `streamlit run <this file>`. Streamlit re-executes this script
# from top to bottom on every widget interaction, which is why the currently
# displayed image is kept in st.session_state (see "LOAD AND CACHE" below).
import streamlit as st
import io
import sys
import time
import json

# NOTE(review): this path is resolved relative to the current working
# directory, not to this file, so the app must be launched from the directory
# that contains `virtex/`.
sys.path.append("./virtex/")
# Presumably provides create_objects, get_rand_idx and PIL's Image (all used
# below without a local import) -- verify against model.py.
from model import *


def gen_show_caption(sub_prompt=None, cap_prompt="", image_input=None):
    """Predict a subreddit and caption for the current image and render them.

    Args:
        sub_prompt: Subreddit to condition the caption on, or None to let the
            model predict the subreddit itself.
        cap_prompt: Caption prompt as produced by
            ``imageLoader.text_transform`` (assumed to be a str -- TODO
            confirm); an empty string means "no prompt".
        image_input: Transformed image passed to ``virtexModel.predict``.
            Defaults to the module-level ``image_dict`` built by the script
            below, which is what this function always used before the
            parameter existed.
    """
    if image_input is None:
        image_input = image_dict
    with st.spinner("Generating Caption"):
        # Compare by value: the previous `is not ""` tested object identity,
        # which only worked via CPython string interning and triggers a
        # SyntaxWarning on Python 3.8+.
        if sub_prompt is None and cap_prompt != "":
            st.write(
                "Without a specified subreddit, caption prompts will skip "
                "subreddit prediction"
            )
        subreddit, caption = virtexModel.predict(
            image_input, sub_prompt=sub_prompt, prompt=cap_prompt
        )
        st.header("Predicted Caption:\n\n")
        st.markdown(f"### r/{subreddit}\n#### {caption}")


st.title("Image Captioning Demo from Redcaps")
st.sidebar.markdown(
    """
    Image Captioning Model from VirTex trained on Redcaps
    """
)

# NOTE(review): unless create_objects() caches internally (not visible here),
# the model is rebuilt on every rerun; consider st.cache_resource if the
# installed Streamlit version provides it.
with st.spinner("Loading Model"):
    virtexModel, imageLoader, sample_images, valid_subs = create_objects()

# Index of a random sample picked on THIS run, or None when the button was not
# pressed (in which case the image cached in session_state is reused below).
select_idx = None
st.sidebar.title("Select a sample image")
if st.sidebar.button("Random Sample Image"):
    select_idx = get_rand_idx(sample_images)
sample_image = sample_images[0 if select_idx is None else select_idx]

st.sidebar.title("Select a Subreddit")
# NOTE(review): the None check in gen_show_caption assumes valid_subs contains
# the None object (not the string "None") -- confirm in create_objects().
sub = st.sidebar.selectbox("Select None for a Predicted Subreddit", valid_subs)

st.sidebar.title("Write a Custom Prompt")
cap_prompt = st.sidebar.text_input(
    "Leave this blank for an unbiased caption", value=""
)

uploaded_image = None
with st.sidebar.form("file-uploader-form", clear_on_submit=True):
    uploaded_file = st.file_uploader("Choose a file")
    submitted = st.form_submit_button("Submit")
if uploaded_file is not None and submitted:
    uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
    select_idx = None  # an upload overrides the sample choice and re-caches

# Clicking this only triggers a rerun (which regenerates the caption); the
# returned bool is deliberately ignored.
_ = st.sidebar.button("Regenerate Caption")

if uploaded_image is None and submitted:
    # "Submit" was pressed without a file attached.
    st.write("Please select a file to upload")
else:
    # LOAD AND CACHE THE IMAGE: priority is a fresh upload, then the image
    # cached by a previous run (unless a new random sample was just picked),
    # then the selected sample file.
    if uploaded_image is not None:
        image = uploaded_image
    elif select_idx is None and "image" in st.session_state:
        image = st.session_state["image"]
    else:
        image = Image.open(sample_image)
    st.session_state["image"] = image

    image_dict = imageLoader.transform(image)
    show_image = imageLoader.show_resize(image)
    st.image(show_image, caption="Your Image")
    gen_show_caption(sub, imageLoader.text_transform(cap_prompt), image_dict)

# Offline smoke test of the model without the UI:
# from model import *
# sample_images = get_samples()
# v, il = VirTexModel(), ImageLoader()
# for s in sample_images:
#     subreddit, caption = v.predict(il.load(s))
#     print("=====================")
#     print(subreddit)
#     print(caption)