File size: 3,644 Bytes
5281471
8d0e872
 
56c51f4
49c0315
7cc986f
 
 
 
8d0e872
 
e1cdab4
8d0e872
 
 
5d3b8a6
 
fd27ee1
8f84007
 
 
fd27ee1
a888578
710977a
8d0e872
 
 
 
c8a1a5f
5d3b8a6
8d0e872
1cabc83
15921a8
 
 
b58ad35
7332d54
15921a8
b58ad35
 
7332d54
b58ad35
7332d54
65193db
b58ad35
b108d42
7332d54
15921a8
 
 
7332d54
 
b58ad35
 
 
8d0e872
b58ad35
7332d54
b58ad35
8d0e872
b58ad35
 
7332d54
 
8f84007
8d0e872
 
 
 
15921a8
8d0e872
 
 
 
 
15921a8
 
5650fb4
7332d54
5650fb4
7332d54
5650fb4
 
 
8d0e872
 
5650fb4
 
7332d54
5650fb4
65193db
8d0e872
7332d54
65193db
5650fb4
eb66921
5650fb4
7332d54
eb66921
 
 
 
 
 
8f84007
 
 
 
 
eb66921
 
710977a
15921a8
5650fb4
15921a8
7332d54
8b7cae6
5650fb4
2fa0369
5650fb4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import io

import streamlit as st
from model import *

# TODO:
# - Reformat the model introduction
# - Make the iterative text generation


def gen_show_caption(sub_prompt=None, cap_prompt=""):
    """Predict one caption for the current image and render it on the page.

    Reads the module-level ``virtexModel`` and ``image_dict``; both must be
    assigned before this function is called.

    Args:
        sub_prompt: Forwarded to ``virtexModel.predict`` as ``sub_prompt``.
        cap_prompt: Forwarded to ``virtexModel.predict`` as ``prompt`` and
            echoed (in blue) in front of the generated caption text.
    """
    # Local import so this function stays self-contained.
    import html

    with st.spinner("Generating Caption"):
        subreddit, caption = virtexModel.predict(
            image_dict, sub_prompt=sub_prompt, prompt=cap_prompt
        )

        # The markup below is rendered with unsafe_allow_html=True, so any
        # "<", ">" or "&" in the user-typed prompt or in the model output
        # would otherwise be interpreted as HTML. Escape the values and keep
        # only the template's own tags as live markup.
        safe_subreddit = html.escape(str(subreddit), quote=False)
        safe_prompt = html.escape(str(cap_prompt), quote=False)
        safe_caption = html.escape(str(caption), quote=False)

        st.markdown(
            f"""
            <style>
                red {{ color:#c62828; font-size: 1.5rem }}
                blue {{ color:#2a72d5; font-size: 1.5rem }}
                remaining {{ color: black; font-size: 1.5rem }}
            </style>

            <red>r/{safe_subreddit}</red>: <blue> {safe_prompt} </blue><remaining> {safe_caption} </remaining>
            """,
            unsafe_allow_html=True,
        )

# Show a spinner while create_objects() runs, then unpack its four results
# into the module-level names used by the rest of the script.
with st.spinner("Loading Model"):
    (
        virtexModel,
        imageLoader,
        sample_images,
        valid_subs,
    ) = create_objects()


# ----------------------------------------------------------------------------
# Populate sidebar.
# ----------------------------------------------------------------------------
# Index into `sample_images` picked by the "Random Sample Image" button on
# this rerun; stays None when the button was not pressed.
select_idx = None

st.sidebar.title("Select or upload an image")
if st.sidebar.button("Random Sample Image"):
    select_idx = get_rand_idx(sample_images)

# Without a random pick, fall back to the first sample image.
sample_image = sample_images[0 if select_idx is None else select_idx]


uploaded_image = None
uploaded_file = st.sidebar.file_uploader("Choose a file")

if uploaded_file is not None:
    try:
        uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
    except OSError:
        # The uploader accepts any file type, so the bytes may not be an
        # image at all. Report it in the sidebar and leave `uploaded_image`
        # as None instead of aborting the script.
        # NOTE(review): assumes `Image` (from `model`) is PIL.Image, whose
        # UnidentifiedImageError derives from OSError — confirm.
        st.sidebar.error("Could not read the uploaded file as an image.")
    else:
        select_idx = None  # Set this to help rewrite the cache


st.sidebar.title("Select a Subreddit")
sub = st.sidebar.selectbox(
    "Type below to condition on a subreddit. Select None for a predicted subreddit",
    valid_subs,
)

st.sidebar.title("Write a Custom Prompt")
cap_prompt = st.sidebar.text_input("Write the start of your caption below", value="")

# The return value is unused; the button is kept because pressing it makes
# Streamlit rerun the script.
_ = st.sidebar.button("Regenerate Caption")


st.sidebar.title("Advanced Options")
num_captions = st.sidebar.select_slider(
    "Number of Captions to Predict", options=[1, 2, 3, 4, 5], value=1
)
nuc_size = st.sidebar.slider(
    "Nucleus Size:\nLarger values lead to more diverse captions",
    min_value=0.0,
    max_value=1.0,
    value=0.8,
    step=0.05,
)
# ----------------------------------------------------------------------------

# Push the nucleus size chosen in the sidebar onto the model's decoder.
virtexModel.model.decoder.nucleus_size = nuc_size

image_file = sample_image

# Choose the image for this rerun. An upload always wins. Otherwise the image
# kept in the session is reused, unless a sample index was picked on this
# rerun or nothing is stored yet, in which case the sample file is opened.
if uploaded_image is None:
    if select_idx is None and "image" in st.session_state:
        image = st.session_state["image"]
    else:
        image = Image.open(image_file)
else:
    image = uploaded_image

# Keep the RGB-converted image in the session so later reruns can reuse it.
image = image.convert("RGB")
st.session_state["image"] = image

# One form of the image is passed to the model, the other is displayed.
image_dict = imageLoader.transform(image)
show_image = imageLoader.show_resize(image)

st.title("Image Captioning with VirTex model trained on RedCaps")
st.markdown("""
Caption your own images or try out some of our sample images.
You can also generate captions as if they are from specific subreddits,
as if they start with a particular prompt, or even both.
Tweet your results with `#redcaps`!

**Note:** This model was not trained on images of people,
hence may not generate accurate captions describing humans.
For more details, visit [redcaps.xyz](https://redcaps.xyz) check out
our [NeurIPS 2021 paper](https://openreview.net/forum?id=VjJxBi1p9zh).
""")

# Three columns with a wide middle one; only the middle column is used, so
# the image and captions appear centred.
_, center, _ = st.columns([1, 10, 1])

with center:
    st.image(show_image)

    # Transform the prompt once and reuse it for the check and every caption.
    # NOTE(review): assumes text_transform is deterministic and free of side
    # effects, so one call is equivalent to the repeated calls — confirm.
    prompt_text = imageLoader.text_transform(cap_prompt)

    if sub is None and prompt_text != "":
        st.write("Without a specified subreddit we default to /r/pics")
    for _ in range(num_captions):
        gen_show_caption(sub, prompt_text)