File size: 3,145 Bytes
ec5e5fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import streamlit as st
import requests
import io


# Page heading, followed by the sidebar blurb that introduces the model.
st.title("🖼️ Image Captioning Demo 📝")

# Markdown source for the sidebar blurb. The text (leading indentation and the
# trailing "\n" escape included) is passed to Streamlit exactly as written.
about_markdown = """
    An image captioning model by combining ViT model with GPT2 model.
    The encoder (ViT) and decoder (GPT2) are combined using Hugging Face transformers' [Vision-To-Text Encoder-Decoder
    framework](https://huggingface.co/transformers/master/model_doc/visionencoderdecoder.html).
    The pretrained weights of both models are loaded, with a set of randomly initialized cross-attention weights.
    The model is trained on the COCO 2017 dataset for about 6900 steps (batch_size=256).
    [Follow-up work of [Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).]\n
    """
st.sidebar.markdown(about_markdown)

# Import the model module while a spinner is displayed. The star import is
# presumably what provides get_random_image_id, sample_image_ids and the other
# helper names used further down -- confirm against model.py.
with st.spinner("Loading and compiling ViT-GPT2 model ..."):
    from model import *

# Fallback image id, used whenever the selectbox value is the string "None".
random_image_id = get_random_image_id()

st.sidebar.title("Select a sample image")
sample_image_id = st.sidebar.selectbox("Please choose a sample image", sample_image_ids)

# Pressing the button draws a fresh random id and discards the selectbox
# choice for this run by resetting it to the "None" sentinel string.
random_requested = st.sidebar.button("Random COCO 2017 (val) images")
if random_requested:
    random_image_id, sample_image_id = get_random_image_id(), "None"

# Sidebar upload form. `bytes_data` ends up as an in-memory stream holding the
# uploaded file's bytes only when the form was submitted with a file attached;
# in every other case it is None.
with st.sidebar.form("file-uploader-form", clear_on_submit=True):
    uploaded_file = st.file_uploader("Choose a file")
    submitted = st.form_submit_button("Upload")
    has_upload = submitted and uploaded_file is not None
    bytes_data = io.BytesIO(uploaded_file.getvalue()) if has_upload else None

def _fit_display_size(width, height, max_width=512, max_height=384):
    """Return ``(width, height)`` shrunk to fit inside the display box.

    The height is capped at ``max_height`` first and the width at
    ``max_width`` second. Both steps scale the other side by the same factor,
    so the aspect ratio is preserved. Sizes already inside the box are
    returned unchanged, and neither side is ever reduced below 1 pixel.
    """
    if height > max_height:
        width = max(1, int(width / height * max_height))
        height = max_height
    if width > max_width:
        # Compute the new height from the *current* width before overwriting
        # it; doing it the other way round leaves the height unscaled and
        # distorts wide images.
        height = max(1, int(height / width * max_width))
        width = max_width
    return width, height


if (bytes_data is None) and submitted:

    # The form was submitted without a file attached.
    st.write("No file is selected to upload")

else:

    # An explicit sample choice takes precedence over the random id.
    image_id = random_image_id
    if sample_image_id != "None":
        if not isinstance(sample_image_id, int):
            raise TypeError(
                f"sample image id must be an int, got {type(sample_image_id).__name__}"
            )
        image_id = sample_image_id

    file_name = f"{str(image_id).zfill(12)}.jpg"
    image_url = f"http://images.cocodataset.org/val2017/{file_name}"
    sample_path = os.path.join(sample_dir, f"COCO_val2017_{file_name}")

    # Source priority: uploaded bytes, then a local sample file, then the
    # COCO image server.
    if bytes_data is not None:
        image = Image.open(bytes_data)
    elif os.path.isfile(sample_path):
        image = Image.open(sample_path)
    else:
        # Bound the wait and fail loudly on HTTP errors instead of handing an
        # error page to the image decoder.
        response = requests.get(image_url, timeout=10)
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))

    try:
        # A single resize to the final display size; `resized` is a separate
        # image object, so closing it leaves `image` usable for prediction.
        resized = image.resize(size=_fit_display_size(*image.size))
        try:
            # Only COCO images get a link back to their source.
            if bytes_data is None:
                st.markdown(f"[{file_name}]({image_url})")
            # NOTE(review): the image is drawn and then re-issued with a
            # caption on the returned element; kept as in the original --
            # confirm whether a single captioned st.image call would suffice.
            show = st.image(resized)
            show.image(resized, '\n\nSelected Image')
        finally:
            resized.close()

        # For newline
        st.sidebar.write('\n')

        with st.spinner('Generating image caption ...'):
            caption = predict(image)
            st.header('Predicted caption:\n\n')
            st.subheader(caption)

        st.sidebar.header("ViT-GPT2 predicts: ")
        st.sidebar.write(f"{caption}")
    finally:
        # Release the image even when display or prediction fails.
        image.close()