File size: 2,258 Bytes
6cae53c
5281471
7a0b2ad
41b10db
7a0b2ad
49c0315
c8a1a5f
 
 
 
 
 
4dab50d
c8a1a5f
41b10db
 
c8a1a5f
41b10db
5a0da41
41b10db
c8a1a5f
41b10db
ce8d2b3
c8a1a5f
 
41b10db
4dab50d
c8a1a5f
4dab50d
c8a1a5f
 
 
 
 
4dab50d
c8a1a5f
 
 
49c0315
c8a1a5f
 
 
 
 
712349e
4dab50d
c8a1a5f
 
4dab50d
c8a1a5f
 
5281471
a4c3b59
5281471
c8a1a5f
5281471
f307fe5
a4c3b59
 
 
5281471
c8a1a5f
 
 
 
 
5281471
c8a1a5f
4dab50d
c8a1a5f
 
 
 
6cae53c
c8a1a5f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import io
import logging
import sys
import time

import streamlit as st
sys.path.append("./virtex/")

st.title("Image Captioning Demo from Redcaps")
st.sidebar.markdown(
    """
    Image Captioning Model from VirTex trained on Redcaps
    """
)

# Load-time diagnostics go to the server log rather than the rendered page,
# so visitors no longer see debug output above the demo.
logger = logging.getLogger(__name__)

# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so the import, download and model construction below run on
# each rerun. Consider moving them behind the resource cache offered by the
# installed Streamlit version — TODO confirm which caching API is available.
with st.spinner("Loading Model"):
    start = time.perf_counter()
    # The star import appears to be the only source of get_samples,
    # download_files, VirTexModel, ImageLoader, get_rand_img and Image,
    # which are used later in this script.
    from model import *
    logger.info("Import time: %.2fs", time.perf_counter() - start)

    sample_images = get_samples()

    start = time.perf_counter()
    download_files()
    logger.info("Download time: %.2fs", time.perf_counter() - start)

    start = time.perf_counter()
    virtexModel = VirTexModel()
    imageLoader = ImageLoader()
    logger.info("Model load time: %.2fs", time.perf_counter() - start)

st.sidebar.title("Select a sample image")
sample_image = st.sidebar.selectbox(
    "",
    sample_images
)

# Pressing the button clears the selection; the image-resolution step below
# then falls back to a randomly drawn sample.
if st.sidebar.button("Random Sample Image"):
    sample_image = None

uploaded_image = None
with st.sidebar.form("file-uploader-form", clear_on_submit=True):
    uploaded_file = st.file_uploader("Choose a file")
    submitted = st.form_submit_button("Submit")
    if uploaded_file is not None and submitted:
        uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))

if uploaded_image is None and submitted:
    # The form was submitted without a file attached.
    st.write("Please select a file to upload")
else:
    # Priority: uploaded file, then the selected sample, then a random sample.
    # The random sample is drawn only when actually needed.
    if uploaded_image is not None:
        image = uploaded_image
    elif sample_image is not None:
        image = Image.open(sample_image)
    else:
        image = Image.open(get_rand_img(sample_images))

    # Always release the image, even if preprocessing or prediction raises.
    try:
        image_dict = imageLoader.transform(image)

        st.image(image, caption="Your Image")

        with st.spinner("Generating Caption"):
            subreddit, caption = virtexModel.predict(image_dict)
            st.header("Predicted Caption:\n\n")
            st.subheader(f"Subreddit: {subreddit}\n")
            st.subheader(f"Caption: {caption}\n")
    finally:
        image.close()

# from model import *
# download_files()
# sample_images = get_samples()
# v, il = VirTexModel(), ImageLoader()

# for s in sample_images:
#     subreddit, caption = v.predict(il.load(s))
#     print("=====================")
#     print(subreddit)
#     print(caption)