Spaces:
Runtime error
Runtime error
File size: 3,644 Bytes
5281471 8d0e872 56c51f4 49c0315 7cc986f 8d0e872 e1cdab4 8d0e872 5d3b8a6 fd27ee1 8f84007 fd27ee1 a888578 710977a 8d0e872 c8a1a5f 5d3b8a6 8d0e872 1cabc83 15921a8 b58ad35 7332d54 15921a8 b58ad35 7332d54 b58ad35 7332d54 65193db b58ad35 b108d42 7332d54 15921a8 7332d54 b58ad35 8d0e872 b58ad35 7332d54 b58ad35 8d0e872 b58ad35 7332d54 8f84007 8d0e872 15921a8 8d0e872 15921a8 5650fb4 7332d54 5650fb4 7332d54 5650fb4 8d0e872 5650fb4 7332d54 5650fb4 65193db 8d0e872 7332d54 65193db 5650fb4 eb66921 5650fb4 7332d54 eb66921 8f84007 eb66921 710977a 15921a8 5650fb4 15921a8 7332d54 8b7cae6 5650fb4 2fa0369 5650fb4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import io
import streamlit as st
from model import *
# # TODO:
# - Reformat the model introduction
# - Make the iterative text generation
def gen_show_caption(sub_prompt=None, cap_prompt=""):
    """Generate one caption for the current image and render it on the page.

    Calls ``virtexModel.predict`` on the module-level ``image_dict`` and writes
    the result as colour-coded HTML: the subreddit in red, the prompt in blue
    and the returned caption in black.

    Args:
        sub_prompt: Value forwarded to the model as ``sub_prompt``; ``None``
            is passed through unchanged.
        cap_prompt: Caption prefix forwarded to the model as ``prompt`` and
            echoed back in the rendered output.
    """
    # Function-scope import so the block does not depend on the file header.
    import html

    def _as_text(value):
        # The markup below is rendered with unsafe_allow_html=True, so any
        # "<", ">" or "&" in the prompt or model output must be escaped;
        # otherwise it is parsed as HTML instead of being displayed.
        return html.escape(str(value), quote=False)

    with st.spinner("Generating Caption"):
        subreddit, caption = virtexModel.predict(
            image_dict, sub_prompt=sub_prompt, prompt=cap_prompt
        )
        st.markdown(
            f"""
<style>
red {{ color:#c62828; font-size: 1.5rem }}
blue {{ color:#2a72d5; font-size: 1.5rem }}
remaining {{ color: black; font-size: 1.5rem }}
</style>
<red>r/{_as_text(subreddit)}</red>: <blue> {_as_text(cap_prompt)} </blue><remaining> {_as_text(caption)} </remaining>
""",
            unsafe_allow_html=True,
        )
# Create the objects the rest of the script relies on; the spinner is shown
# for as long as create_objects() is running.
with st.spinner("Loading Model"):
    (
        virtexModel,
        imageLoader,
        sample_images,
        valid_subs,
    ) = create_objects()
# ----------------------------------------------------------------------------
# Populate sidebar.
# ----------------------------------------------------------------------------
# Index of the sample image picked during this run; None means the
# "Random Sample Image" button was not pressed.
select_idx = None

st.sidebar.title("Select or upload an image")
if st.sidebar.button("Random Sample Image"):
    select_idx = get_rand_idx(sample_images)

# Without a random pick, fall back to the first sample image.
sample_image = sample_images[0 if select_idx is None else select_idx]

uploaded_image = None
uploaded_file = st.sidebar.file_uploader("Choose a file")
if uploaded_file is not None:
    try:
        uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
    except OSError:
        # The uploader accepts any file type, so the bytes may not decode as
        # an image. NOTE(review): assumes Image is PIL.Image, whose
        # UnidentifiedImageError subclasses OSError - confirm in model.py.
        # Report the problem and carry on as if nothing had been uploaded
        # instead of crashing the whole script run.
        st.sidebar.error("Could not read the uploaded file as an image.")
    else:
        select_idx = None  # Set this to help rewrite the cache

st.sidebar.title("Select a Subreddit")
sub = st.sidebar.selectbox(
    "Type below to condition on a subreddit. Select None for a predicted subreddit",
    valid_subs,
)

st.sidebar.title("Write a Custom Prompt")
cap_prompt = st.sidebar.text_input("Write the start of your caption below", value="")

# The return value is deliberately discarded: pressing a Streamlit button
# reruns the script, which is all that is needed to regenerate captions.
_ = st.sidebar.button("Regenerate Caption")

st.sidebar.title("Advanced Options")
num_captions = st.sidebar.select_slider(
    "Number of Captions to Predict", options=[1, 2, 3, 4, 5], value=1
)
nuc_size = st.sidebar.slider(
    "Nucleus Size:\nLarger values lead to more diverse captions",
    min_value=0.0,
    max_value=1.0,
    value=0.8,
    step=0.05,
)
# ----------------------------------------------------------------------------
# Push the nucleus size chosen in the sidebar onto the caption decoder.
virtexModel.model.decoder.nucleus_size = nuc_size
image_file = sample_image

# Pick the image for this run. An upload always wins; otherwise the image kept
# in session state is reused unless a new random sample was requested or no
# image has been stored yet, in which case the sample file is opened.
if uploaded_image is None:
    if select_idx is not None or "image" not in st.session_state:
        image = Image.open(image_file)
    else:
        image = st.session_state["image"]
else:
    image = uploaded_image

image = image.convert("RGB")
st.session_state["image"] = image  # kept so later reruns can reuse it

# Two derived forms of the image: one passed to the model, one for display.
image_dict = imageLoader.transform(image)
show_image = imageLoader.show_resize(image)
st.title("Image Captioning with VirTex model trained on RedCaps")
st.markdown("""
Caption your own images or try out some of our sample images.
You can also generate captions as if they are from specific subreddits,
as if they start with a particular prompt, or even both.
Tweet your results with `#redcaps`!
**Note:** This model was not trained on images of people,
hence may not generate accurate captions describing humans.
For more details, visit [redcaps.xyz](https://redcaps.xyz) check out
our [NeurIPS 2021 paper](https://openreview.net/forum?id=VjJxBi1p9zh).
""")

# The outer columns act as margins; only the wide middle column is used.
_, center, _ = st.columns([1, 10, 1])

# Transform the prompt once instead of once per caption plus once for the
# check below. Assumes text_transform returns the same output for the same
# input - TODO confirm against its definition.
clean_prompt = imageLoader.text_transform(cap_prompt)

# NOTE(review): the notice and captions are placed inside the centre column so
# they line up with the image - confirm this matches the intended layout.
with center:
    st.image(show_image)

    # A prompt without a subreddit: tell the user which subreddit is assumed.
    if sub is None and clean_prompt != "":
        st.write("Without a specified subreddit we default to /r/pics")

    for _ in range(num_captions):
        gen_show_caption(sub, clean_prompt)
|