import os

import requests
import streamlit as st
from PIL import Image


# Designing the interface
st.title("🖼️ Image Captioning Demo 📝")
st.write("[Yih-Dar SHIEH](https://huggingface.co/ydshieh)")

st.sidebar.markdown(
    """
    An image captioning model built by combining a ViT encoder with a GPT2 decoder.
    The encoder (ViT) and decoder (GPT2) are combined using Hugging Face transformers' [Vision Encoder-Decoder
    framework](https://huggingface.co/transformers/master/model_doc/visionencoderdecoder.html).
    The pretrained weights of both models are loaded, while the cross-attention weights are randomly initialized.
    The model was then trained on the COCO 2017 dataset for about 6900 steps (batch_size=256).
    [Follow-up work from the [Hugging Face JAX/Flax community event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).]\n
    """
)
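
# The sidebar text above summarizes the architecture. As a rough, non-authoritative
# sketch (the real construction lives in model.py, which is not part of this file),
# combining the two pretrained checkpoints with transformers typically looks like:
#
#     from transformers import FlaxVisionEncoderDecoderModel
#     model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
#         "google/vit-base-patch16-224-in21k", "gpt2"
#     )
#
# after which the randomly initialized cross-attention weights are fine-tuned on
# COCO 2017 captions. The checkpoint names here are illustrative assumptions.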

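# `from model import *` is assumed to provide the helpers used below
# (predict, get_random_image_id, sample_image_ids, sample_dir); the import is
# deferred to this point so the spinner stays visible while the Flax model is
# loaded and compiled.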
with st.spinner('Loading and compiling ViT-GPT2 model ...'):
    from model import *

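# Pick a random COCO 2017 validation image id to use as the default.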
random_image_id = get_random_image_id()

st.sidebar.title("Select a sample image")
sample_image_id = st.sidebar.selectbox(
    "Please choose a sample image",
    sample_image_ids
)

if st.sidebar.button("Random COCO 2017 (val) images"):
    random_image_id = get_random_image_id()
    sample_image_id = "None"

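# An image chosen from the dropdown takes precedence; the random id is used only
# after the button above resets the dropdown value to "None".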
image_id = random_image_id
if sample_image_id != "None":
    assert isinstance(sample_image_id, int)
    image_id = sample_image_id


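# Use a locally bundled copy of the image when available, otherwise download it
# from the COCO val2017 server. File names follow the COCO_val2017_<12-digit id>.jpg
# convention.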
sample_name = f"COCO_val2017_{str(image_id).zfill(12)}.jpg"
sample_path = os.path.join(sample_dir, sample_name)

if os.path.isfile(sample_path):
    image = Image.open(sample_path)
else:
    url = f"http://images.cocodataset.org/val2017/{str(image_id).zfill(12)}.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

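# Downscale a copy of the image for display only (the original full-resolution
# image is what gets passed to predict() below): cap the height at 384 px, then
# the width at 512 px, preserving the aspect ratio.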
width, height = image.size
resized = image
if height > 384:
    width = int(width / height * 384)
    height = 384
    resized = resized.resize(size=(width, height))
if width > 512:
    height = int(height / width * 512)
    width = 512
    resized = resized.resize(size=(width, height))


st.markdown(f"[{str(image_id).zfill(12)}.jpg](http://images.cocodataset.org/val2017/{str(image_id).zfill(12)}.jpg)")
show = st.image(resized)
show.image(resized, 'Selected Image')
if resized is not image:
    resized.close()

# For newline
st.sidebar.write('\n')

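# predict() (from model.py) is assumed to run the ViT feature extractor and the
# GPT2 decoder's generate step on the full-resolution image, returning an
# English caption string.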
with st.spinner('Generating image caption ...'):

    caption = predict(image)

    caption_en = caption
    st.header('Predicted caption:')
    st.subheader(caption_en)

st.sidebar.header("ViT-GPT2 predicts:")
st.sidebar.write(f"**English**: {caption}")

image.close()