File size: 1,915 Bytes
6cae53c
5281471
7a0b2ad
 
49c0315
c8a1a5f
 
 
 
 
 
4dab50d
c8a1a5f
 
 
 
 
 
4dab50d
c8a1a5f
4dab50d
c8a1a5f
 
 
 
 
4dab50d
c8a1a5f
 
 
49c0315
c8a1a5f
 
 
 
 
 
4dab50d
c8a1a5f
 
4dab50d
c8a1a5f
 
5281471
c8a1a5f
5281471
c8a1a5f
5281471
c8a1a5f
5281471
c8a1a5f
 
 
 
 
5281471
c8a1a5f
4dab50d
c8a1a5f
 
 
 
6cae53c
c8a1a5f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import streamlit as st
import io
import sys
sys.path.append("./virtex/")

# Page header plus a one-line description shown in the sidebar.
st.title("Image Captioning Demo from Redcaps")
st.sidebar.markdown(
    """
    Image Captioning Model from VirTex trained on Redcaps
    """
)

# Everything in this block runs while the "Loading Model" spinner is shown.
with st.spinner("Loading Model"):
    # This file never imports glob, Image, download_files, VirTexModel,
    # ImageLoader or get_rand_img itself; presumably they all arrive through
    # this star-import -- verify against model.py before narrowing it.
    from model import *
    # NOTE(review): the sample list is globbed *before* download_files()
    # runs. If that helper is what fetches ./samples, the list would be empty
    # on a first run -- confirm the intended order.
    sample_images = glob.glob("./samples/*.jpg")
    download_files()
    # Module-level handles used further down for preprocessing and captioning.
    virtexModel = VirTexModel()
    imageLoader = ImageLoader()

# Fallback image, chosen each time the script runs (presumably a random pick
# from the sample list -- get_rand_img is defined outside this file).
random_image = get_rand_img(sample_images)

st.sidebar.title("Select a sample image")
# Dropdown over the globbed sample paths, created with an empty label.
sample_image = st.sidebar.selectbox(
    "",
    sample_images
)

if st.sidebar.button("Random Sample Image"):
    random_image = get_rand_img(sample_images)
    # Clearing the dropdown choice makes the later
    # `sample_image is not None` check fall through to random_image.
    sample_image = None
    
# Optional user upload. The file picker lives in a form so that a file is only
# consumed when the user presses "Submit"; clear_on_submit resets the picker
# after each submission. uploaded_image stays None unless a file was both
# chosen and submitted on this run.
uploaded_image = None
with st.sidebar.form("file-uploader-form", clear_on_submit=True):
    uploaded_file = st.file_uploader("Choose a file")
    submitted = st.form_submit_button("Submit")
    if uploaded_file is not None and submitted:
        # Streamlit's UploadedFile is a BytesIO subclass, so the raw bytes are
        # read with getvalue(); the previous get_values() call does not exist
        # and raised AttributeError on every upload.
        uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))

# Main panel: either complain about an empty submission, or caption an image.
if uploaded_image is None and submitted:
    # "Submit" was pressed without a file attached.
    st.write("Please select a file to upload")
else:
    # Source priority: uploaded file > sidebar dropdown choice > random sample.
    image_file = sample_image if sample_image is not None else random_image

    # Only open a sample from disk when there is no upload. The previous code
    # called Image.open() with no argument, which raised TypeError and left
    # image_file unused.
    image = uploaded_image if uploaded_image is not None else Image.open(image_file)

    try:
        image_dict = imageLoader.transform(image)

        # The previous line called an undefined name (`show.image`) around a
        # nested st.image call; the caption belongs to st.image itself.
        # NOTE(review): assumes image_dict["image"] is something st.image can
        # render -- confirm, otherwise display `image` instead.
        st.image(image_dict["image"], caption="Target Image")

        with st.spinner("Generating Caption"):
            subreddit, caption = virtexModel.predict(image_dict)
            st.header("Predicted Caption:\n\n")
            st.subheader(f"Subreddit: {subreddit}\n")
            st.subheader(f"Caption: {caption}\n")
    finally:
        # Release the image even if preprocessing or prediction raises.
        image.close()

# from model import *
# download_files()
# sample_images = get_samples()
# v, il = VirTexModel(), ImageLoader()

# for s in sample_images:
#     subreddit, caption = v.predict(il.load(s))
#     print("=====================")
#     print(subreddit)
#     print(caption)