Spaces:
Runtime error
Runtime error
File size: 3,544 Bytes
6cae53c 5281471 7a0b2ad 41b10db e1cdab4 7a0b2ad 56c51f4 49c0315 c838395 e1cdab4 ed768de c838395 e1cdab4 ee12c5f 5d3b8a6 3f08c7f 5d3b8a6 e1cdab4 c8a1a5f 4dab50d c8a1a5f 5d3b8a6 defbed4 1cabc83 7332d54 4dab50d 7332d54 c8a1a5f 5281471 7332d54 defbed4 7332d54 4dab50d 7332d54 6cae53c 7332d54 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import streamlit as st
import io
import sys
import time
import json
sys.path.append("./virtex/")
from model import *
def gen_show_caption(sub_prompt=None, cap_prompt=""):
    """Generate a caption for the current image and render it on the page.

    Reads the module-level globals ``virtexModel`` and ``image_dict``, so it
    must only be called after the script has assigned both.

    Args:
        sub_prompt: Subreddit to condition on, or None to have the model
            predict the subreddit itself.
        cap_prompt: Prompt passed to ``virtexModel.predict`` as ``prompt``;
            an empty string means no custom prompt.
    """
    with st.spinner("Generating Caption"):
        # Compare by value: `is not ""` tests object identity, which is not
        # reliable for strings (and is a SyntaxWarning on Python 3.8+).
        if sub_prompt is None and cap_prompt != "":
            st.write("Without a specified subreddit, caption prompts will skip subreddit prediction")
        subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt=cap_prompt)

    st.header("Predicted Caption:\n\n")
    # Same markdown text as before: subreddit as an h3, caption as an h4.
    st.markdown(f"\n### r/{subreddit}\n#### {caption}\n")
# Sidebar blurb shown under the page title (kept as a module constant so the
# layout calls below stay short).
_SIDEBAR_INTRO = """
Image Captioning Model from VirTex trained on Redcaps
"""

st.title("Image Captioning Demo from Redcaps")
st.sidebar.markdown(_SIDEBAR_INTRO)

# create_objects() returns the model, the image loader, the list of sample
# images and the list of selectable subreddits; a spinner is shown meanwhile.
with st.spinner("Loading Model"):
    _loaded_objects = create_objects()
    virtexModel, imageLoader, sample_images, valid_subs = _loaded_objects
staggered = st.sidebar.checkbox("Iteratively Generate Captions")
if staggered:
    # Iterative caption generation has no implementation yet; stop the run so
    # nothing below is rendered in this mode.
    st.stop()

# Sample-image selection. select_idx stays None unless "Random Sample Image"
# was clicked on this run; a None index also lets the image-loading step reuse
# the image cached in st.session_state.
select_idx = None
st.sidebar.title("Select a sample image")
if st.sidebar.button("Random Sample Image"):
    select_idx = get_rand_idx(sample_images)
sample_image = sample_images[0 if select_idx is None else select_idx]
# Conditioning controls. Per the widget label, choosing None for the
# subreddit leaves subreddit prediction to the model.
_sidebar = st.sidebar

_sidebar.title("Select a Subreddit")
sub = _sidebar.selectbox("Select None for a Predicted Subreddit", options=valid_subs)

# Free-text prompt; the empty default means an unprompted caption.
_sidebar.title("Write a Custom Prompt")
cap_prompt = _sidebar.text_input("Leave this blank for an unbiased caption", value="")
# Optional user upload. A file submitted through the form replaces the sample
# image for this run.
uploaded_image = None
with st.sidebar.form("file-uploader-form", clear_on_submit=True):
    uploaded_file = st.file_uploader("Choose a file")
    submitted = st.form_submit_button("Submit")
if uploaded_file is not None and submitted:
    uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
    select_idx = None  # set this to help rewrite the cache

# The return value is unused: clicking only triggers a rerun of the script,
# which regenerates the caption.
st.sidebar.button("Regenerate Caption")

if uploaded_image is None and submitted:
    st.write("Please select a file to upload")
else:
    # Pick the image to caption: a fresh upload wins; otherwise reuse the
    # session-cached image unless a new random sample was requested; as a
    # last resort open the selected sample image. The choice is then cached.
    if uploaded_image is not None:
        image = uploaded_image
    elif select_idx is None and "image" in st.session_state:
        image = st.session_state["image"]
    else:
        image = Image.open(sample_image)
    st.session_state["image"] = image

    # gen_show_caption reads image_dict as a module-level global.
    image_dict = imageLoader.transform(image)
    show_image = imageLoader.show_resize(image)
    # Render once with the caption instead of drawing an uncaptioned image
    # and immediately replacing it.
    st.image(show_image, caption="Your Image")
    gen_show_caption(sub, imageLoader.text_transform(cap_prompt))
# from model import *
# sample_images = get_samples()
# v, il = VirTexModel(), ImageLoader()
# for s in sample_images:
# subreddit, caption = v.predict(il.load(s))
# print("=====================")
# print(subreddit)
# print(caption)
|