File size: 2,843 Bytes
6cae53c
5281471
7a0b2ad
41b10db
e1cdab4
7a0b2ad
56c51f4
49c0315
c838395
e1cdab4
ed768de
 
c838395
e1cdab4
ee12c5f
5d3b8a6
 
3f08c7f
 
 
5d3b8a6
 
e1cdab4
 
c8a1a5f
 
 
 
 
 
4dab50d
c8a1a5f
5d3b8a6
defbed4
 
 
4dab50d
 
c8a1a5f
0674c7e
 
defbed4
0674c7e
c8a1a5f
 
0674c7e
defbed4
c8a1a5f
4dab50d
1cabc83
 
 
 
 
 
 
de79882
1cabc83
c838395
1cabc83
49c0315
54fa80d
c8a1a5f
 
 
 
 
712349e
defbed4
e1cdab4
 
4dab50d
c8a1a5f
 
4dab50d
c8a1a5f
70e1d6c
5281471
defbed4
 
 
4871a34
 
defbed4
 
5281471
4871a34
70e1d6c
c8a1a5f
5281471
79c7b01
a4c3b59
79c7b01
 
5281471
ab30850
5281471
c8a1a5f
4dab50d
c8a1a5f
 
 
6cae53c
c8a1a5f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import streamlit as st
import io
import sys
import time
import json
sys.path.append("./virtex/")
from model import *

def gen_show_caption(sub_prompt=None, cap_prompt = ""):
    """Run the captioning model on the current image and render the result.

    Reads the module-level ``virtexModel`` and ``image_dict`` (set by the
    page script before this is called) and writes a header plus the
    predicted subreddit/caption to the Streamlit page.

    Args:
        sub_prompt: Subreddit passed to ``virtexModel.predict`` as
            ``sub_prompt``; ``None`` means no subreddit was chosen.
        cap_prompt: Text passed to ``virtexModel.predict`` as ``prompt``;
            ``""`` means no caption prompt.
    """
    with st.spinner("Generating Caption"):
        # Use equality, not identity: `is not ""` relies on CPython interning
        # and emits a SyntaxWarning on Python 3.8+.
        if sub_prompt is None and cap_prompt != "":
            st.write("Without a specified subreddit, caption prompts will skip subreddit prediction")
        subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt=cap_prompt)
        st.header("Predicted Caption:\n\n")
        st.markdown(
            f"""
            ### r/{subreddit}
            
            #### {caption}
            """
        )
    

# --- Page header -------------------------------------------------------------
st.title("Image Captioning Demo from Redcaps")

# Short description of the model, shown at the top of the sidebar.
st.sidebar.markdown("""
    Image Captioning Model from VirTex trained on Redcaps
    """)

# Build the model, the image loader, the sample-image list and the list of
# selectable subreddits via `create_objects` (presumably provided by the
# star-import from `model`), with a spinner visible while it runs.
with st.spinner("Loading Model"):
    (virtexModel, imageLoader, sample_images, valid_subs) = create_objects()
    
    
# --- Sidebar: sample image picker --------------------------------------------
st.sidebar.title("Select a sample image")

# `select_idx` holds a sample index only on the rerun where the random button
# was pressed; otherwise it is None.
if st.sidebar.button("Random Sample Image"):
    select_idx = get_rand_idx(sample_images)
else:
    select_idx = None

sample_image = st.sidebar.selectbox(
    "",
    sample_images,
    index=select_idx if select_idx is not None else 0,
)

# --- Sidebar: subreddit to condition on --------------------------------------
# The label tells users to pick None to have the subreddit predicted; this
# assumes `valid_subs` contains None — TODO confirm in `create_objects`.
st.sidebar.title("Select a Subreddit")
sub = st.sidebar.selectbox("Select None for a Predicted Subreddit", valid_subs)

# --- Sidebar: optional caption prompt ----------------------------------------
st.sidebar.title("Write a Custom Prompt")
cap_prompt = st.sidebar.text_input("Leave this blank for an unbiased caption", value="")
    
    
# --- Sidebar: optional image upload ------------------------------------------
# `uploaded_image` stays None unless "Submit" was pressed with a file attached
# on this rerun; the form clears itself after submission.
uploaded_image = None
with st.sidebar.form("file-uploader-form", clear_on_submit=True):
    uploaded_file = st.file_uploader("Choose a file")
    submitted = st.form_submit_button("Submit")
    if submitted and uploaded_file is not None:
        # Decode the uploaded bytes into a PIL image.
        uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
        # Drop any random-sample pick from this rerun so the upload is what
        # ends up in the session cache.
        select_idx = None
        


if uploaded_image is None and submitted:
    # "Submit" was pressed without a file attached.
    st.write("Please select a file to upload")

else:
    image_file = sample_image

    # Choose the image to caption, in priority order:
    #   1. a freshly uploaded image;
    #   2. the image cached in session state by a previous rerun, unless the
    #      random-sample button was just pressed;
    #   3. the currently selected sample file.
    # Falling back to (3) when nothing is cached yet avoids the AttributeError
    # that `st.session_state.image` raised on the very first run.
    if uploaded_image is not None:
        image = uploaded_image
    elif select_idx is None and "image" in st.session_state:
        image = st.session_state.image
    else:
        image = Image.open(image_file)

    # Cache an independent copy: `image` is closed below, and PIL refuses to
    # operate on a closed image, so caching `image` itself would break the
    # next rerun that reads it back.
    st.session_state.image = image.copy()

    try:
        image_dict = imageLoader.transform(image)

        show_image = imageLoader.show_resize(image)
        st.image(show_image, "Your Image")

        gen_show_caption(sub, imageLoader.text_transform(cap_prompt))
    finally:
        # Release the working image even if captioning fails; the cached
        # copy stays usable.
        image.close()

# from model import *
# sample_images = get_samples()
# v, il = VirTexModel(), ImageLoader()

# for s in sample_images:
#     subreddit, caption = v.predict(il.load(s))
#     print("=====================")
#     print(subreddit)
#     print(caption)