File size: 5,769 Bytes
405f2d4
 
 
 
 
2c8f495
405f2d4
 
 
 
 
 
 
2c8f495
fb3c77c
3f280c5
 
405f2d4
 
 
 
2c8f495
405f2d4
 
 
2c8f495
 
ffe19d9
405f2d4
fb3c77c
 
37c757a
fb3c77c
2c8f495
405f2d4
4b29c6a
 
2c8f495
 
 
 
 
405f2d4
 
 
7fdcddd
 
405f2d4
 
f15eef4
2c8f495
 
 
405f2d4
f384719
2c8f495
f384719
875bee5
f384719
2c8f495
 
405f2d4
a4ce24c
405f2d4
2c8f495
405f2d4
4b29c6a
0cb8576
 
4b29c6a
405f2d4
89ea6a7
 
 
 
 
 
95fff6e
405f2d4
 
 
 
2c8f495
405f2d4
f384719
2c8f495
f384719
875bee5
f384719
2c8f495
 
405f2d4
a4ce24c
405f2d4
2c8f495
405f2d4
89ea6a7
 
 
3f280c5
 
 
 
 
 
2c8f495
405f2d4
 
 
 
f384719
405f2d4
 
 
3b5bd24
405f2d4
2c8f495
405f2d4
 
 
62c196c
 
 
3b5bd24
f384719
 
 
 
 
3b5bd24
f384719
3b5bd24
405f2d4
 
 
2c8f495
 
 
0ad09b6
def21a2
4bb9586
2c8f495
405f2d4
2c8f495
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from .utils import (
    get_text_attributes,
    get_top_5_predictions,
    get_transformed_image,
    plotly_express_horizontal_bar_plot,
    bert_tokenizer,
)

import streamlit as st
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from mtranslate import translate
from .utils import read_markdown
import requests
from PIL import Image
from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
    FlaxCLIPVisionBertForMaskedLM,
)


def softmax(logits, axis=0):
    """Return the softmax of ``logits`` along ``axis`` (default 0).

    The per-slice maximum is subtracted before exponentiating so that large
    logits cannot overflow to ``inf`` (which would yield ``nan``
    probabilities); the shift cancels out, so the result is mathematically
    unchanged.

    Args:
        logits: array-like of raw scores.
        axis: axis along which the probabilities sum to 1.

    Returns:
        A float ndarray with the same shape as ``logits``.
    """
    logits = np.asarray(logits)
    if logits.size == 0:
        # np.max has no identity for empty input; keep the old "empty in, empty out".
        return np.exp(logits)
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)

def _predict_mask_scores(model, transformed_image, caption_inputs):
    """Return the model's logits at the first [MASK] position of ``caption_inputs``.

    Every entry of ``caption_inputs`` (which must include "input_ids") is
    forwarded to ``model`` as a keyword argument alongside ``pixel_values``.
    """
    outputs = model(pixel_values=transformed_image, **caption_inputs)
    # np.where(...)[1][0]: column (sequence) index of the first mask match;
    # assumes input_ids is 2-D (batch, seq) with a single row -- TODO confirm.
    mask_position = np.where(caption_inputs["input_ids"] == bert_tokenizer.mask_token_id)[1][0]
    return np.array(outputs.logits[0][mask_position])


# @st.cache(persist=False)
def _load_model(ckpt):
    """Load the CLIP-Vision-BERT masked-LM checkpoint ``ckpt``."""
    return FlaxCLIPVisionBertForMaskedLM.from_pretrained(ckpt)


def _load_example(mlm_state, row):
    """Load the image/caption pair ``row`` into ``mlm_state`` with one random token masked.

    ``row`` is one row of the example TSV (needs "image_file", "caption" and
    "lang_id"). Sets mlm_image_file, unmasked_caption, currently_masked_token,
    caption (masked), caption_lang_id and mlm_image on ``mlm_state``.
    """
    mlm_state.mlm_image_file = row["image_file"]
    caption = row["caption"].strip("- ")
    mlm_state.unmasked_caption = caption
    ids = bert_tokenizer.encode(caption)
    # Never mask the special tokens encode() adds at both ends (they are also
    # dropped by the ids[1:-1] slice below).
    mask_index = np.random.randint(1, len(ids) - 1)
    mlm_state.currently_masked_token = bert_tokenizer.convert_ids_to_tokens([ids[mask_index]])[0]
    ids[mask_index] = bert_tokenizer.mask_token_id
    mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
    mlm_state.caption_lang_id = row["lang_id"]

    image_path = os.path.join("cc12m_data/resized_images_vqa", mlm_state.mlm_image_file)
    mlm_state.mlm_image = plt.imread(image_path)


def app(state):
    """Render the visuo-linguistic mask-filling demo page.

    ``state`` is presumably the session-state object kept across Streamlit
    reruns; this page reads and writes its image, caption and model
    attributes (mlm_image_file, mlm_image, mlm_model, caption, ...).
    """
    mlm_state = state
    st.header("Visuo-linguistic Mask Filling Demo")

    with st.beta_expander("Usage"):
        st.write(read_markdown("mlm_usage.md"))
    st.info(read_markdown("mlm_intro.md"))

    mlm_checkpoints = ["flax-community/clip-vision-bert-cc12m-70k"]
    # Local alternative: mlm_checkpoints = ["./ckpt/mlm/ckpt-60k"]
    dummy_data = pd.read_csv("cc12m_data/vqa_val.tsv", sep="\t")

    first_index = 15
    # First run of the session: seed the page with a fixed example.
    if mlm_state.mlm_image_file is None:
        _load_example(mlm_state, dummy_data.loc[first_index])

    if mlm_state.mlm_model is None:
        with st.spinner("Loading model..."):
            mlm_state.mlm_model = _load_model(mlm_checkpoints[0])

    query1 = st.text_input(
        "Enter a URL to an image",
        value="http://images.cocodataset.org/val2017/000000039769.jpg",
    )

    col1, col2, col3 = st.beta_columns([2, 1, 2])
    if col1.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        _load_example(mlm_state, dummy_data.sample(1).iloc[0])

    col2.write("OR")

    if col3.button("Use above URL"):
        try:
            response = requests.get(query1, stream=True, timeout=10)
            response.raise_for_status()
            # Have urllib3 undo any gzip/deflate transfer encoding before PIL reads the stream.
            response.raw.decode_content = True
            mlm_state.mlm_image = np.asarray(Image.open(response.raw))
        except (requests.RequestException, OSError) as err:
            # Bad URL / network failure / non-image payload: keep the current image.
            st.error(f"Could not load an image from the URL: {err}")

    transformed_image = get_transformed_image(mlm_state.mlm_image)

    new_col1, new_col2 = st.beta_columns([5, 5])

    # Display image
    new_col1.image(mlm_state.mlm_image, use_column_width="auto")

    # Display (editable) caption
    new_col2.write("Write your text with exactly one [MASK] token.")
    mlm_state.caption = new_col2.text_input(
        label="Text",
        value=mlm_state.caption,
        help="Type your masked caption regarding the image above in one of the four languages.",
    )

    if mlm_state.unmasked_caption == mlm_state.caption.replace("[MASK]", mlm_state.currently_masked_token):
        # Caption is still the unedited example: reveal the hidden token and the full sentence.
        new_col2.markdown("**Masked Token**: " + mlm_state.currently_masked_token)
        english = (
            mlm_state.unmasked_caption
            if mlm_state.caption_lang_id == "en"
            else translate(mlm_state.unmasked_caption, "en")
        )
    else:
        english = (
            mlm_state.caption
            if mlm_state.caption_lang_id == "en"
            else translate(mlm_state.caption, "en")
        )
    new_col2.markdown(f"**English Translation**: {english}")

    caption_inputs = get_text_attributes(mlm_state.caption)
    # Prediction needs exactly one mask token; without one the mask lookup would raise IndexError.
    input_ids = np.asarray(caption_inputs["input_ids"])
    if np.count_nonzero(input_ids == bert_tokenizer.mask_token_id) != 1:
        st.warning("Please write your text with exactly one [MASK] token.")
        return

    # Display top-5 predictions
    with st.spinner("Predicting..."):
        scores = _predict_mask_scores(mlm_state.mlm_model, transformed_image, dict(caption_inputs))
    scores = softmax(scores)
    labels, values = get_top_5_predictions(scores)
    # NOTE(review): labels[-1] is used as the best prediction, presumably because
    # get_top_5_predictions returns ascending order for the bar plot -- confirm.
    filled_sentence = mlm_state.caption.replace("[MASK]", labels[-1])
    st.write("**Filled Sentence**: " + filled_sentence)
    st.write(f"**English Translation**: {translate(filled_sentence, 'en')}")
    fig = plotly_express_horizontal_bar_plot(values, labels)
    st.dataframe(
        pd.DataFrame(
            {
                "Tokens": labels,
                "English Translation": [translate(label, "en") for label in labels],
            }
        ).T
    )
    st.plotly_chart(fig, use_container_width=True)