File size: 4,223 Bytes
6cae53c
5281471
7a0b2ad
41b10db
e1cdab4
7a0b2ad
56c51f4
49c0315
7cc986f
 
 
 
c838395
e1cdab4
5650fb4
5d3b8a6
 
fd27ee1
 
c366431
fd27ee1
5650fb4
 
 
fd27ee1
ec3e6af
fd27ee1
 
a888578
5650fb4
fd27ee1
4a4709d
e1cdab4
5650fb4
e1cdab4
5650fb4
 
c8a1a5f
 
65193db
 
 
 
 
 
5650fb4
c8a1a5f
 
5650fb4
4dab50d
c8a1a5f
5d3b8a6
defbed4
1cabc83
b58ad35
7332d54
b58ad35
7332d54
b58ad35
 
7332d54
b58ad35
7332d54
65193db
b58ad35
22f4deb
b108d42
239ba66
b108d42
22f4deb
 
65193db
b58ad35
 
 
7332d54
b58ad35
 
 
7332d54
b58ad35
 
 
 
 
 
7332d54
b58ad35
 
 
 
 
7332d54
b58ad35
 
 
 
 
 
 
7332d54
 
5650fb4
 
 
 
7332d54
5650fb4
7332d54
5650fb4
 
 
 
 
 
 
7332d54
5650fb4
65193db
5650fb4
7332d54
65193db
5650fb4
7332d54
5650fb4
7332d54
5650fb4
b58ad35
5650fb4
7332d54
5650fb4
 
2fa0369
5650fb4
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import streamlit as st
import io
import sys
import time
import json
sys.path.append("./virtex/")
from model import *

# # TODO:
# - Reformat the model introduction
# - Make the iterative text generation

def gen_show_caption(sub_prompt=None, cap_prompt = ""):
    """Generate one caption for the current image and render it as a heading.

    Relies on the module-level ``virtexModel`` and ``image_dict`` globals, so it
    must only be called after both have been created.

    Args:
        sub_prompt: Subreddit to condition the caption on, or None for no
            explicit subreddit.
        cap_prompt: Text the generated caption should start with ("" for none).
    """
    from html import escape  # local import keeps this block self-contained

    with st.spinner("Generating Caption"):
        subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt=cap_prompt)
        # The heading below is rendered with unsafe_allow_html=True, and
        # cap_prompt is derived from a free-text sidebar input. HTML-escape every
        # interpolated value so user (or model) text cannot inject markup into
        # the page. str() keeps the original f-string formatting for non-str values.
        safe_sub = escape(str(subreddit))
        safe_prompt = escape(str(cap_prompt))
        safe_caption = escape(str(caption))
        st.markdown(
            f"""
            <style>
                red{{
                    color:#c62828
                }}
                blue{{
                    color:#2a72d5
                }}
                mono{{
                    font-family: "Inconsolata";
                }}
            </style>

            ### <red> r/{safe_sub} </red> <blue> {safe_prompt} </blue> {safe_caption}
            """, 
            unsafe_allow_html=True)
    
# Page layout: three columns in a 1:8:1 ratio; only the wide middle one is used,
# and it later hosts the title, the image and the generated captions.
center = st.columns([1, 8, 1])[1]

with center:
    st.title("Image Captioning Demo from RedCaps")

# Introductory text shown at the top of the sidebar.
_SIDEBAR_INTRO = """
    ### Image Captioning Model from VirTex trained on RedCaps
    
    Use this page to caption your own images or try out some of our samples.
    You can also generate captions as if they are from specific subreddits,
    as if they start with a particular prompt, or even both.
    
    Share your results on twitter with #redcaps or with a friend*.
    """
st.sidebar.markdown(_SIDEBAR_INTRO)
# st.markdown(footer,unsafe_allow_html=True)

# Build the captioning model, image loader, sample-image list and subreddit
# choices, showing a spinner while this runs.
with st.spinner("Loading Model"):
    (virtexModel, imageLoader, sample_images, valid_subs) = create_objects()

# --- Sample image selection -------------------------------------------------
st.sidebar.title("Select a sample image")

# select_idx stays None unless a random sample was requested on this run; a
# None index later allows the image cached in session state to be reused.
select_idx = get_rand_idx(sample_images) if st.sidebar.button("Random Sample Image") else None
sample_image = sample_images[select_idx if select_idx is not None else 0]

# --- Optional user upload ---------------------------------------------------
uploaded_file = st.sidebar.file_uploader("Choose a file")
if uploaded_file is None:
    uploaded_image = None
else:
    uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
    # An upload takes precedence over any random pick, so drop the index.
    select_idx = None

# class OnChange():
#     def __init__(self, idx):
#         self.idx = idx

#     def __call__(self):
#         st.write(f"the idx is: {self.idx}")
#         st.write(f"the sample_image is {sample_image}")

# sample_image = st.sidebar.selectbox(
#     "",
#     sample_images,
#     index = 0 if select_idx is None else select_idx,
#     on_change=OnChange(0 if select_idx is None else select_idx)
# )

# --- Generation controls ----------------------------------------------------
st.sidebar.title("Select a Subreddit")
sub = st.sidebar.selectbox(
    "Type below to condition on a subreddit. Select None for a predicted subreddit",
    valid_subs
)

st.sidebar.title("Write a Custom Prompt")
cap_prompt = st.sidebar.text_input(
    "Write the start of your caption below", 
    value=""
)

# The return value is unused: clicking the button reruns the script (Streamlit's
# execution model), which regenerates the captions further down.
_ = st.sidebar.button("Regenerate Caption")


st.sidebar.write("Advanced Options:")
num_captions = st.sidebar.select_slider("Number of Captions to Predict", options=[1,2,3,4,5], value=1)
nuc_size = st.sidebar.slider("Nucleus Size:\nLarger values lead to more diverse captions", min_value=0.0, max_value=1.0, value=0.8, step=0.05)
# Pushed into the decoder on every run so generation always uses the current slider value.
virtexModel.model.decoder.nucleus_size = nuc_size

image_file = sample_image

# LOAD AND CACHE THE IMAGE
# Precedence: a fresh upload, then the image cached in session state (only when
# no new sample was picked this run), and finally the selected sample file.
reuse_cached_image = select_idx is None and "image" in st.session_state
if uploaded_image is not None:
    image = uploaded_image
elif reuse_cached_image:
    image = st.session_state["image"]
else:
    image = Image.open(image_file)

# Normalise to RGB and remember the result so later reruns can reuse it.
image = st.session_state["image"] = image.convert("RGB")

# Model input consumed by gen_show_caption.
image_dict = imageLoader.transform(image)
# Version of the image that is displayed in the center column.
show_image = imageLoader.show_resize(image)

with center:
    # A single st.image call renders the picture; re-rendering the same image
    # into the returned placeholder is unnecessary.
    st.image(show_image)

    # Transform the user prompt once and reuse it for the notice and every caption
    # (assumes text_transform is deterministic for a given input — TODO confirm).
    transformed_prompt = imageLoader.text_transform(cap_prompt)
    # Compare by value: `is not ""` tests object identity, which Python does not
    # guarantee for equal strings (and it emits a SyntaxWarning on 3.8+).
    if sub is None and transformed_prompt != "":
        st.write("Without a specified subreddit we default to /r/pics")
    for _ in range(num_captions):
        gen_show_caption(sub, transformed_prompt)

# Disclaimer and credits shown at the bottom of the sidebar; the leading "*"
# is the footnote marker referenced by the sidebar introduction.
_SIDEBAR_CREDITS = """
*Please note that this model was explicitly not trained on images of people, and as a result is not designed to caption images with humans.

This demo accompanies our paper RedCaps.

Created by Karan Desai, Gaurav Kaul, Zubin Aysola, Justin Johnson
    """
st.sidebar.markdown(_SIDEBAR_CREDITS)