File size: 2,666 Bytes
6cae53c
5281471
7a0b2ad
41b10db
e1cdab4
7a0b2ad
49c0315
c838395
e1cdab4
ed768de
 
c838395
e1cdab4
 
 
 
c8a1a5f
 
 
 
 
 
4dab50d
c8a1a5f
 
5a0da41
c8a1a5f
 
e1cdab4
 
4dab50d
c8a1a5f
0674c7e
4dab50d
c8a1a5f
0674c7e
 
 
 
 
c8a1a5f
 
0674c7e
 
c8a1a5f
4dab50d
1cabc83
 
 
 
 
 
 
de79882
1cabc83
c838395
1cabc83
49c0315
54fa80d
c8a1a5f
 
 
 
 
712349e
e1cdab4
 
4dab50d
c8a1a5f
 
4dab50d
c8a1a5f
 
5281471
a4c3b59
5281471
c8a1a5f
5281471
f307fe5
a4c3b59
 
 
5281471
1cabc83
5281471
c8a1a5f
4dab50d
c8a1a5f
 
 
 
6cae53c
c8a1a5f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import streamlit as st
import io
import sys
import time
import json
sys.path.append("./virtex/")

def gen_show_caption(sub_prompt=None, cap_prompt=""):
    """Run the VirTex model on the current image and render the caption.

    Reads the module-level ``virtexModel`` and ``image_dict``. The caller must
    assign ``image_dict`` (via ``imageLoader.transform``) before calling.

    Args:
        sub_prompt: Subreddit passed to ``virtexModel.predict`` as
            ``sub_prompt``, or ``None`` to leave it unspecified (the sidebar
            labels this choice "Predicted Subreddit").
        cap_prompt: Text passed to ``virtexModel.predict`` as ``prompt``;
            ``""`` means no prompt.
    """
    with st.spinner("Generating Caption"):
        # `is not ""` tested object identity, not string equality, and is
        # flagged with a SyntaxWarning on Python 3.8+. Use `!=` so any
        # non-empty prompt reliably triggers the notice.
        if sub_prompt is None and cap_prompt != "":
            st.write("Without a specified subreddit, caption prompts will skip subreddit prediction")
        subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt=cap_prompt)
        st.header("Predicted Caption:\n\n")
        st.subheader(f"r/{subreddit}:\t{caption}\n")

# Main page heading, followed by a short model description in the sidebar.
st.title("Image Captioning Demo from Redcaps")
st.sidebar.markdown("\n    Image Captioning Model from VirTex trained on Redcaps\n    ")

with st.spinner("Loading Model"):
    # `model` supplies get_samples, VirTexModel, ImageLoader and
    # VALID_SUBREDDITS_PATH used below (and presumably the PIL `Image` name
    # used later in this script — TODO confirm and import it explicitly).
    from model import *
    sample_images = get_samples()
    virtexModel = VirTexModel()
    imageLoader = ImageLoader()
    # Use a context manager so the file handle is closed promptly instead of
    # leaking until garbage collection; JSON is specified as UTF-8.
    with open(VALID_SUBREDDITS_PATH, encoding="utf-8") as subs_file:
        valid_subs = json.load(subs_file)
    # Leading None option = "no subreddit selected" in the sidebar selectbox.
    valid_subs.insert(0, None)

# get_rand_img returns an (index, image) pair (see the button handler below).
# The original code stored the whole tuple here; keep only the image as the
# fallback, and default the selectbox to the first sample.
_, random_image = get_rand_img(sample_images)
rand_idx = 0

st.sidebar.title("Select a sample image")

# When clicked, preselect a randomly chosen sample in the selectbox below.
if st.sidebar.button("Random Sample Image"):
    rand_idx, random_image = get_rand_img(sample_images)

sample_image = st.sidebar.selectbox(
    "",
    sample_images,
    index=rand_idx,
)

# Subreddit to condition on; the None entry leaves it to be predicted.
st.sidebar.title("Select a Subreddit")
sub = st.sidebar.selectbox("Select None for a Predicted Subreddit", valid_subs)

# Optional free-text prompt, forwarded to the model as the caption prompt.
st.sidebar.title("Write a Custom Prompt")
cap_prompt = st.sidebar.text_input("Leave this blank for an unbiased caption", value="")
    
    
# Optional user upload. With clear_on_submit=True Streamlit resets the
# uploader after each submit.
uploaded_image = None
with st.sidebar.form("file-uploader-form", clear_on_submit=True):
    uploaded_file = st.file_uploader("Choose a file")
    submitted = st.form_submit_button("Submit")
    if uploaded_file is not None and submitted:
        uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))

if uploaded_image is None and submitted:
    # Submit was pressed without a file attached.
    st.write("Please select a file to upload")
else:
    # Prefer the uploaded image; otherwise use the selected (or random) sample.
    image_file = sample_image if sample_image is not None else random_image
    image = uploaded_image if uploaded_image is not None else Image.open(image_file)
    try:
        # Module-level name read by gen_show_caption.
        image_dict = imageLoader.transform(image)
        # Draw the image once with its caption, instead of drawing it and
        # then immediately replacing the element.
        st.image(image, caption="Your Image")
        gen_show_caption(sub, cap_prompt)
    finally:
        # Release the image's resources even if transform or prediction raises.
        image.close()

# from model import *
# download_files()
# sample_images = get_samples()
# v, il = VirTexModel(), ImageLoader()

# for s in sample_images:
#     subreddit, caption = v.predict(il.load(s))
#     print("=====================")
#     print(subreddit)
#     print(caption)