File size: 5,089 Bytes
405f2d4
 
 
 
 
 
 
 
 
 
 
 
571a3f6
 
405f2d4
 
 
 
fb3c77c
405f2d4
 
 
 
 
2c8f495
405f2d4
 
 
 
2c8f495
 
405f2d4
fb3c77c
 
 
 
2c8f495
 
 
4b29c6a
2c8f495
405f2d4
2c8f495
405f2d4
 
 
2c8f495
 
 
405f2d4
 
 
 
 
 
 
 
 
 
 
 
2c8f495
 
 
 
 
 
 
 
 
405f2d4
2c8f495
405f2d4
4b29c6a
0cb8576
4b29c6a
2c8f495
 
89ea6a7
 
 
 
 
95fff6e
405f2d4
 
 
 
2c8f495
 
 
 
 
405f2d4
2c8f495
405f2d4
2c8f495
405f2d4
89ea6a7
571a3f6
89ea6a7
571a3f6
 
 
 
2c8f495
405f2d4
 
 
 
f384719
405f2d4
 
 
 
2c8f495
405f2d4
 
 
2c8f495
405f2d4
 
 
 
 
 
2c8f495
405f2d4
2c8f495
405f2d4
 
 
 
8187482
 
 
 
 
 
 
 
 
 
405f2d4
 
 
 
 
2c8f495
405f2d4
2c8f495
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from .utils import (
    get_text_attributes,
    get_top_5_predictions,
    get_transformed_image,
    plotly_express_horizontal_bar_plot,
    translate_labels,
)

import streamlit as st
import numpy as np
import pandas as pd
import os
import requests
from PIL import Image
import matplotlib.pyplot as plt
import json

from mtranslate import translate
from .utils import read_markdown

from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
    FlaxCLIPVisionBertForSequenceClassification,
)


def softmax(logits):
    """Return the softmax of ``logits`` normalized along axis 0.

    The maximum along axis 0 is subtracted before exponentiating. Softmax is
    invariant to that shift, and it keeps ``np.exp`` from overflowing to
    ``inf``, which would otherwise turn the result into ``nan`` for large
    logits.

    Args:
        logits: array-like of scores, normalized over its first axis.

    Returns:
        ``np.ndarray`` with the same shape as ``logits``. Entries along axis 0
        are non-negative and sum to 1.
    """
    logits = np.asarray(logits)
    # keepdims keeps the reduced axis so broadcasting works for any rank.
    shifted = logits - np.max(logits, axis=0, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=0, keepdims=True)


def _load_dummy_sample(vqa_state, samples, index):
    """Copy row ``index`` of ``samples`` into ``vqa_state`` and load its image.

    Sets these attributes on ``vqa_state``:
    - ``vqa_image_file``
    - ``question`` (with surrounding "-" and spaces stripped)
    - ``answer_label``
    - ``question_lang_id`` and ``answer_lang_id`` (both from ``lang_id``)
    - ``vqa_image``, read with ``plt.imread`` from ``resized_images/``
    """
    vqa_state.vqa_image_file = samples.loc[index, "image_file"]
    vqa_state.question = samples.loc[index, "question"].strip("- ")
    vqa_state.answer_label = samples.loc[index, "answer_label"]
    vqa_state.question_lang_id = samples.loc[index, "lang_id"]
    vqa_state.answer_lang_id = samples.loc[index, "lang_id"]

    image_path = os.path.join("resized_images", vqa_state.vqa_image_file)
    vqa_state.vqa_image = plt.imread(image_path)


def _load_image_from_url(url, timeout=10):
    """Download ``url`` and return the image as a 3-channel RGB numpy array.

    Raises:
        requests.RequestException: on connection errors, timeouts, or HTTP
            error status codes.
        OSError: if Pillow cannot decode the payload as an image
            (``PIL.UnidentifiedImageError`` is an ``OSError`` subclass).
    """
    import io

    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # Decode from the fully downloaded body. ``response.raw`` bypasses
    # requests' gzip/deflate content decoding.
    image = Image.open(io.BytesIO(response.content))
    # Normalize grayscale, RGBA, and palette images to RGB so the array always
    # has 3 channels. get_transformed_image presumably expects that; TODO
    # confirm.
    return np.asarray(image.convert("RGB"))


def app(state):
    """Render the multilingual Visual Question Answering page.

    ``state`` is the shared session-state object. This page reads and writes
    these attributes on it:
    - ``vqa_image_file``, ``vqa_image``, ``question``, ``answer_label``
    - ``question_lang_id``, ``answer_lang_id``
    - ``vqa_model``, which is loaded lazily on first use.
    """
    vqa_state = state

    with st.beta_expander("Usage"):
        st.write(read_markdown("vqa_usage.md"))
    st.write(read_markdown("vqa_intro.md"))

    def predict(transformed_image, question_inputs):
        # Take the first model output and its first row. Presumably these are
        # the answer logits for the single (image, question) pair; TODO
        # confirm.
        return np.array(
            vqa_state.vqa_model(pixel_values=transformed_image, **question_inputs)[0][0]
        )

    def load_model(ckpt):
        return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)

    vqa_checkpoints = [
        "flax-community/clip-vision-bert-vqa-ft-6k"
    ]  # TODO: Maybe add more checkpoints?
    dummy_data = pd.read_csv("dummy_vqa_multilingual.tsv", sep="\t")
    code_to_name = {
        "en": "English",
        "fr": "French",
        "de": "German",
        "es": "Spanish",
    }

    # Maps stringified answer-label ids to answer text.
    with open("answer_reverse_mapping.json", encoding="utf-8") as f:
        answer_reverse_mapping = json.load(f)

    first_index = 20
    # On first visit, seed the session state with a fixed example.
    if vqa_state.vqa_image_file is None:
        _load_dummy_sample(vqa_state, dummy_data, first_index)

    if vqa_state.vqa_model is None:
        with st.spinner("Loading model..."):
            vqa_state.vqa_model = load_model(vqa_checkpoints[0])

    # Image source: a user-supplied URL or a random bundled example.
    query1 = st.text_input(
        "Enter a URL to an image",
        value="http://images.cocodataset.org/val2017/000000039769.jpg",
    )
    col1, col2, col3 = st.beta_columns([2, 1, 2])
    if col1.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        sample = dummy_data.sample(1).reset_index()
        _load_dummy_sample(vqa_state, sample, 0)

    col2.write("OR")

    if col3.button("Use above URL"):
        try:
            # Store in vqa_image, the attribute this page displays and runs
            # prediction on. The original wrote to mlm_image instead.
            vqa_state.vqa_image = _load_image_from_url(query1)
        except (requests.RequestException, OSError) as err:
            st.error(f"Could not load an image from the URL: {err}")
        else:
            # The dataset's ground-truth answer does not apply to a user image.
            vqa_state.answer_label = None

    transformed_image = get_transformed_image(vqa_state.vqa_image)

    new_col1, new_col2 = st.beta_columns([5, 5])

    # Display image
    new_col1.image(vqa_state.vqa_image, use_column_width="auto")

    # Display question
    question = new_col2.text_input(
        label="Question",
        value=vqa_state.question,
        help="Type your question regarding the image above in one of the four languages.",
    )
    # Skip translation only when the question is the untouched English sample.
    # A user-edited question may be in any language.
    if question == vqa_state.question and vqa_state.question_lang_id == "en":
        english_question = question
    else:
        english_question = translate(question, "en")
    new_col2.markdown(f"""**English Translation**: {english_question}""")

    question_inputs = get_text_attributes(question)

    # Select answer language
    options = ["en", "de", "es", "fr"]
    vqa_state.answer_lang_id = new_col2.selectbox(
        "Answer Language",
        index=options.index(vqa_state.answer_lang_id),
        options=options,
        format_func=lambda x: code_to_name[x],
        help="The language to be used to show the top-5 labels.",
    )
    # Show the ground truth only for the unmodified sample question, and only
    # while the image still comes from the dataset.
    if question == vqa_state.question and vqa_state.answer_label is not None:
        actual_answer = answer_reverse_mapping[str(vqa_state.answer_label)]
        new_col2.markdown(
            "**Actual Answer**: "
            + translate_labels([actual_answer], vqa_state.answer_lang_id)[0]
            + " ("
            + actual_answer
            + ")"
        )

    with st.spinner("Predicting..."):
        logits = predict(transformed_image, dict(question_inputs))
    probabilities = softmax(logits)
    labels, values = get_top_5_predictions(probabilities, answer_reverse_mapping)
    translated_labels = translate_labels(labels, vqa_state.answer_lang_id)
    fig = plotly_express_horizontal_bar_plot(values, translated_labels)
    st.plotly_chart(fig, use_container_width=True)