File size: 2,488 Bytes
6cae53c
5281471
7a0b2ad
41b10db
e1cdab4
7a0b2ad
49c0315
1cabc83
e1cdab4
b80df5c
e1cdab4
 
 
 
c8a1a5f
 
 
 
 
 
4dab50d
c8a1a5f
 
5a0da41
c8a1a5f
 
e1cdab4
 
4dab50d
c8a1a5f
4dab50d
c8a1a5f
 
 
 
 
4dab50d
1cabc83
 
 
 
 
 
 
de79882
1cabc83
de79882
1cabc83
 
54fa80d
 
 
c8a1a5f
 
 
49c0315
54fa80d
c8a1a5f
 
 
 
 
712349e
e1cdab4
 
4dab50d
c8a1a5f
 
4dab50d
c8a1a5f
 
5281471
a4c3b59
5281471
c8a1a5f
5281471
f307fe5
a4c3b59
 
 
5281471
1cabc83
5281471
c8a1a5f
4dab50d
c8a1a5f
 
 
 
6cae53c
c8a1a5f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import streamlit as st
import io
import sys
import time
import json
sys.path.append("./virtex/")

def gen_show_caption(sub_prompt=None, cap_prompt=None):
    """Predict a subreddit and caption for the current image and display them.

    Uses the module-level ``virtexModel`` and ``image_dict`` set up by the
    script, so it must only be called after ``image_dict`` has been assigned.

    Args:
        sub_prompt: Subreddit chosen in the sidebar, forwarded to
            ``virtexModel.predict`` as ``sub_prompt``. ``None`` is the
            "Select None for a Predicted Subreddit" option.
        cap_prompt: Custom caption prompt from the sidebar.
            TODO(review): currently accepted but NOT passed to the model.
    """
    spinner = st.spinner("Generating Caption")
    with spinner:
        prediction = virtexModel.predict(image_dict, sub_prompt=sub_prompt)
        predicted_sub, predicted_caption = prediction
        headline = "Predicted Caption:\n\n"
        st.header(headline)
        result_line = f"r/{predicted_sub}:\t{predicted_caption}\n"
        st.subheader(result_line)

st.title("Image Captioning Demo from Redcaps")
st.sidebar.markdown(
    """
    Image Captioning Model from VirTex trained on Redcaps
    """
)

# Load the captioning model and its helpers.
# NOTE(review): Streamlit reruns the whole script on every widget interaction,
# so this block reloads the model each time. Consider caching it
# (st.cache_resource / st.cache, depending on the installed Streamlit version).
with st.spinner("Loading Model"):
    # get_samples, VirTexModel, ImageLoader, VALID_SUBREDDITS_PATH,
    # get_rand_img and Image are not imported anywhere else in this file.
    # They presumably all come from this star import.
    from model import *
    sample_images = get_samples()
    virtexModel = VirTexModel()
    imageLoader = ImageLoader()
    # Use a context manager so the file handle is closed deterministically
    # (json.load(open(...)) leaked it). JSON text is UTF-8 per RFC 8259.
    with open(VALID_SUBREDDITS_PATH, encoding="utf-8") as subs_file:
        valid_subs = json.load(subs_file)
    # Prepend None so the subreddit selector offers the
    # "Select None for a Predicted Subreddit" choice.
    valid_subs.insert(0, None)

# A fresh random sample is drawn on every rerun. It is only displayed when
# sample_image is None (e.g. after "Random Sample Image" is pressed below).
random_image = get_rand_img(sample_images)

st.sidebar.title("Select a sample image")
sample_image = st.sidebar.selectbox(
    "",
    sample_images
)

st.sidebar.title("Select a Subreddit")
sub = st.sidebar.selectbox(
    "Select None for a Predicted Subreddit",
    valid_subs
)

st.sidebar.title("Write a Custom Prompt")
# NOTE(review): value=None assumes the installed Streamlit accepts None here.
# gen_show_caption does not currently forward this prompt to the model,
# so confirm the intended behavior.
cap_prompt = st.sidebar.text_input(
    "Leave this blank for an unbiased caption",
    value=None
)

if st.sidebar.button("Random Sample Image"):
    # Draw a new random sample and clear the explicit selection so the
    # random image is the one captioned below.
    random_image = get_rand_img(sample_images)
    sample_image = None
    
# Optional user upload. The form clears itself after submission, and the
# upload is only decoded on the rerun where "Submit" was pressed.
uploaded_image = None
with st.sidebar.form("file-uploader-form", clear_on_submit=True):
    uploaded_file = st.file_uploader("Choose a file")
    submitted = st.form_submit_button("Submit")
    if uploaded_file is not None and submitted:
        # Decode the uploaded bytes in memory.
        uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))

if uploaded_image is None and submitted:
    # "Submit" was pressed without a file attached.
    st.write("Please select a file to upload")

else:
    # Image priority: uploaded image > sample selected in the sidebar >
    # random sample.
    image_file = sample_image if sample_image is not None else random_image

    image = uploaded_image if uploaded_image is not None else Image.open(image_file)

    try:
        # image_dict stays a module-level name: gen_show_caption reads it
        # as a global.
        image_dict = imageLoader.transform(image)

        # Render once with the caption. Previously the image was rendered
        # and then immediately replaced by a captioned copy.
        st.image(image, caption="Your Image")

        gen_show_caption(sub, cap_prompt)
    finally:
        # Release the image even if transform or prediction raises.
        image.close()

# from model import *
# download_files()
# sample_images = get_samples()
# v, il = VirTexModel(), ImageLoader()

# for s in sample_images:
#     subreddit, caption = v.predict(il.load(s))
#     print("=====================")
#     print(subreddit)
#     print(caption)