gchhablani committed
Commit
405f2d4
•
1 Parent(s): 61c3dfa

Add MLM task

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. app.py +42 -173
  2. apps/mlm.py +109 -0
  3. {model → apps/model}/__init__.py +0 -0
  4. {model → apps/model}/flax_clip_vision_bert/__init__.py +0 -0
  5. {model → apps/model}/flax_clip_vision_bert/configuration_clip_vision_bert.py +0 -0
  6. {model → apps/model}/flax_clip_vision_bert/modeling_clip_vision_bert.py +0 -0
  7. utils.py → apps/utils.py +6 -5
  8. apps/vqa.py +131 -0
  9. cc12m_data/.DS_Store +0 -0
  10. cc12m_data/images_vqa/.DS_Store +0 -0
  11. cc12m_data/images_vqa/00212055---Wax_cylinder_in_Dictaphone.jpg +0 -0
  12. cc12m_data/images_vqa/00315853---041bdd212f5b5d3d30cbc4ccf523f1a3.jpg +0 -0
  13. cc12m_data/images_vqa/00328633---Metal+chips+fly+in+a+high+speed+turning+operation+performed+on+a+computer+numerical+control+turning+center+%28photo+courtesy+of+Cincinnati+Milacron%29..jpg +0 -0
  14. cc12m_data/images_vqa/00491934---I6FTIDWLJRFPHAK4ZSZH4RQGDA.jpg +0 -0
  15. cc12m_data/images_vqa/00507360---MushroomRisotto1.jpg +0 -0
  16. cc12m_data/images_vqa/00602376---%20essay-example-writing-comparison-compare-contrast-how-to-write-poem-examples-of%20-1024x768.jpg +0 -0
  17. cc12m_data/images_vqa/00606341---dog-coloring-book-detailed-dogs-page2.jpg +0 -0
  18. cc12m_data/images_vqa/00697411---dream-house-swimming-pool-large-133359636.jpg +0 -0
  19. cc12m_data/images_vqa/00923733---white-commercial-van-road-motion-blurred-d-illustration-custom-designed-brandless-87900010.jpg +0 -0
  20. cc12m_data/images_vqa/01023838---fundraising-photo.jpg +0 -0
  21. cc12m_data/images_vqa/01053356---522a16b60d3f226fff652671cdde6011.jpg +0 -0
  22. cc12m_data/images_vqa/01157077---female-fruit-picker-worker-basket-woodcut-illustration-wearing-bandana-holding-viewed-side-set-white-61675986.jpg +0 -0
  23. cc12m_data/images_vqa/01275377---Young-the-Giant.jpg +0 -0
  24. cc12m_data/images_vqa/01327794---40250345161_452dc56b11_z.jpg +0 -0
  25. cc12m_data/images_vqa/01648721---170420062908YDYA.jpg +0 -0
  26. cc12m_data/images_vqa/01760795---The-Size-of-the-buildings-in-Shekou-are-in-direct-relation-to-the-time-it-takes-to-accomplish-tasks.jpg +0 -0
  27. cc12m_data/images_vqa/01761366---fresh-salad-flying-vegetables-ingredients-isolated-white-background-48747892.jpg +0 -0
  28. cc12m_data/images_vqa/01772764---business-woman-winner-standing-first-600w-254762824.jpg +0 -0
  29. cc12m_data/images_vqa/01813337---cd4df5cb43d087533e89b12c9805409e.jpg +0 -0
  30. cc12m_data/images_vqa/02034916---XKC6GGK5NDECNBAD5WAQUWOO5U.jpg +0 -0
  31. cc12m_data/images_vqa/02175876---DL2-4i4.jpg +0 -0
  32. cc12m_data/images_vqa/02217469---mount-macedon-victoria-australia-macedon-regional-park-region-photographed-by-karen-robinson-_march-29-2020_042-1.jpg +0 -0
  33. cc12m_data/images_vqa/02243845---heritage-heritage-matte-stainless-steel-sink-undermount-5_2048x.jpg +0 -0
  34. cc12m_data/images_vqa/02335328---margaret-and-alexander-potters-houses-1948.jpg +0 -0
  35. cc12m_data/images_vqa/02520451---Gower-1.jpg +0 -0
  36. cc12m_data/images_vqa/02912250---a-black-panther-has-been-spotted-in-weald-park-brentwood-essex-britain-shutterstock-editorial-618335e.jpg +0 -0
  37. cc12m_data/images_vqa/03257347---looking-farther-afield-article-size.jpg +0 -0
  38. cc12m_data/images_vqa/03271226---beneath-the-borealis-092517-a-very-bear-y-summer-kennicott-valley-virga.jpg +0 -0
  39. cc12m_data/images_vqa/03307717---tumblr_m9d4xkRM5n1rypkpio1_1280.jpg +0 -0
  40. cc12m_data/images_vqa/03360735---Warm-Bacon-Dip-EasyLowCarb-2.jpg +0 -0
  41. cc12m_data/images_vqa/03394023---m_5e36e15f2169682519441e34.jpg +0 -0
  42. cc12m_data/images_vqa/03401066---160328-capitol-police-mn-1530_bd68b01f1d7f1c3ab99eafa503930569.fit-760w.jpg +0 -0
  43. cc12m_data/images_vqa/03598306---20400805522_fba017bc51_b.jpg +0 -0
  44. cc12m_data/images_vqa/03618296---A+pink+and+grey+woven+baskets+sits+on+top+of+a+clear+side+table.jpg +0 -0
  45. cc12m_data/images_vqa/04331097---108_1504859395_24.jpg +0 -0
  46. cc12m_data/images_vqa/04334412---Pants-All-match-Professional-Harlan-Women-s-Loose-Skinny-High-Waist-New-2019-Suit-Summer-Leisure-Pants-2077.jpg +0 -0
  47. cc12m_data/images_vqa/04358571---41-Travelex.jpg +0 -0
  48. cc12m_data/images_vqa/04361362---square-stone-benches-around-fire-pit-outside-residential-building-sunny-day-pathways-plants-can-also-be-seen-homes-171086572.jpg +0 -0
  49. cc12m_data/images_vqa/04530023---49305383277_29d4a34f37_h.jpg +0 -0
  50. cc12m_data/images_vqa/04749808---thinkstockphotos-1858212351.jpg +0 -0
app.py CHANGED
@@ -1,183 +1,52 @@
-import json
 import os
-from io import BytesIO
-
-import matplotlib.pyplot as plt
-import numpy as np
-import pandas as pd
 import streamlit as st
-from mtranslate import translate
-from PIL import Image
-from streamlit.elements import markdown
-
-from model.flax_clip_vision_bert.modeling_clip_vision_bert import (
-    FlaxCLIPVisionBertForSequenceClassification,
-)
-from session import _get_state
-from utils import (
-    get_text_attributes,
-    get_top_5_predictions,
-    get_transformed_image,
-    plotly_express_horizontal_bar_plot,
-    translate_labels,
-)
-
-state = _get_state()
-
-
-@st.cache(persist=True)
-def load_model(ckpt):
-    return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)
-
-
-@st.cache(persist=True)
-def predict(transformed_image, question_inputs):
-    return np.array(model(pixel_values=transformed_image, **question_inputs)[0][0])
-
-
-def softmax(logits):
-    return np.exp(logits) / np.sum(np.exp(logits), axis=0)
-
+from apps import mlm, vqa
+from multiapp import MultiApp
 
 def read_markdown(path, parent="./sections/"):
     with open(os.path.join(parent, path)) as f:
         return f.read()
 
-checkpoints = ["./ckpt/vqa/ckpt-60k-5999"]  # TODO: Maybe add more checkpoints?
-dummy_data = pd.read_csv("dummy_vqa_multilingual.tsv", sep="\t")
-code_to_name = {
-    "en": "English",
-    "fr": "French",
-    "de": "German",
-    "es": "Spanish",
-}
-
-with open("answer_reverse_mapping.json") as f:
-    answer_reverse_mapping = json.load(f)
-
-
-st.set_page_config(
-    page_title="Multilingual VQA",
-    layout="wide",
-    initial_sidebar_state="collapsed",
-    page_icon="./misc/mvqa-logo-3-white.png",
-)
-
-st.title("Multilingual Visual Question Answering")
-st.write(
-    "[Gunjan Chhablani](https://huggingface.co/gchhablani), [Bhavitvya Malik](https://huggingface.co/bhavitvyamalik)"
-)
-
-image_col, intro_col = st.beta_columns([3, 8])
-image_col.image("./misc/mvqa-logo-3-white.png", use_column_width="always")
-intro_col.write(read_markdown("intro.md"))
-with st.beta_expander("Usage"):
-    st.write(read_markdown("usage.md"))
-
-with st.beta_expander("Article"):
-    st.write(read_markdown("abstract.md"))
-    st.write(read_markdown("caveats.md"))
-    st.write("## Methodology")
-    col1, col2 = st.beta_columns([1,1])
-    col1.image(
-        "./misc/article/resized/Multilingual-VQA.png",
-        caption="Masked LM model for Image-text Pretraining.",
-    )
-    col2.markdown(read_markdown("pretraining.md"))
-    st.markdown(read_markdown("finetuning.md"))
-    st.write(read_markdown("challenges.md"))
-    st.write(read_markdown("social_impact.md"))
-    st.write(read_markdown("references.md"))
-    st.write(read_markdown("checkpoints.md"))
-    st.write(read_markdown("acknowledgements.md"))
-
-first_index = 20
-# Init Session State
-if state.image_file is None:
-    state.image_file = dummy_data.loc[first_index, "image_file"]
-    state.question = dummy_data.loc[first_index, "question"].strip("- ")
-    state.answer_label = dummy_data.loc[first_index, "answer_label"]
-    state.question_lang_id = dummy_data.loc[first_index, "lang_id"]
-    state.answer_lang_id = dummy_data.loc[first_index, "lang_id"]
-
-    image_path = os.path.join("resized_images", state.image_file)
-    image = plt.imread(image_path)
-    state.image = image
-
-# col1, col2, col3 = st.beta_columns([3,3,3])
-
-if st.button(
-    "Get a random example",
-    help="Get a random example from the 100 `seeded` image-text pairs.",
-):
-    sample = dummy_data.sample(1).reset_index()
-    state.image_file = sample.loc[0, "image_file"]
-    state.question = sample.loc[0, "question"].strip("- ")
-    state.answer_label = sample.loc[0, "answer_label"]
-    state.question_lang_id = sample.loc[0, "lang_id"]
-    state.answer_lang_id = sample.loc[0, "lang_id"]
-
-    image_path = os.path.join("resized_images", state.image_file)
-    image = plt.imread(image_path)
-    state.image = image
-
-# col2.write("OR")
-
-# uploaded_file = col2.file_uploader(
-#     "Upload your image",
-#     type=["png", "jpg", "jpeg"],
-#     help="Upload a file of your choosing.",
-# )
-# if uploaded_file is not None:
-#     state.image_file = os.path.join("images/val2014", uploaded_file.name)
-#     state.image = np.array(Image.open(uploaded_file))
-
-transformed_image = get_transformed_image(state.image)
-
-new_col1, new_col2 = st.beta_columns([5, 5])
-
-# Display Image
-new_col1.image(state.image, use_column_width="always")
-
-
-# Display Question
-question = new_col2.text_input(
-    label="Question",
-    value=state.question,
-    help="Type your question regarding the image above in one of the four languages.",
-)
-new_col2.markdown(
-    f"""**English Translation**: {question if state.question_lang_id == "en" else translate(question, 'en')}"""
-)
-
-question_inputs = get_text_attributes(question)
-
-# Select Language
-options = ["en", "de", "es", "fr"]
-state.answer_lang_id = new_col2.selectbox(
-    "Answer Language",
-    index=options.index(state.answer_lang_id),
-    options=options,
-    format_func=lambda x: code_to_name[x],
-    help="The language to be used to show the top-5 labels.",
-)
-
-actual_answer = answer_reverse_mapping[str(state.answer_label)]
-new_col2.markdown(
-    "**Actual Answer**: "
-    + translate_labels([actual_answer], state.answer_lang_id)[0]
-    + " ("
-    + actual_answer
-    + ")"
-)
-
-# Display Top-5 Predictions
-with st.spinner("Loading model..."):
-    model = load_model(checkpoints[0])
-with st.spinner("Predicting..."):
-    logits = predict(transformed_image, dict(question_inputs))
-    logits = softmax(logits)
-    labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
-    translated_labels = translate_labels(labels, state.answer_lang_id)
-    fig = plotly_express_horizontal_bar_plot(values, translated_labels)
-    st.plotly_chart(fig, use_container_width=True)
+def main():
+    st.set_page_config(
+        page_title="Multilingual VQA",
+        layout="wide",
+        initial_sidebar_state="collapsed",
+        page_icon="./misc/mvqa-logo-3-white.png",
+    )
+
+    st.title("Multilingual Visual Question Answering")
+    st.write(
+        "[Gunjan Chhablani](https://huggingface.co/gchhablani), [Bhavitvya Malik](https://huggingface.co/bhavitvyamalik)"
+    )
+
+    image_col, intro_col = st.beta_columns([3, 8])
+    image_col.image("./misc/mvqa-logo-3-white.png", use_column_width="always")
+    intro_col.write(read_markdown("intro.md"))
+    with st.beta_expander("Usage"):
+        st.write(read_markdown("usage.md"))
+
+    with st.beta_expander("Article"):
+        st.write(read_markdown("abstract.md"))
+        st.write(read_markdown("caveats.md"))
+        st.write("## Methodology")
+        col1, col2 = st.beta_columns([1,1])
+        col1.image(
+            "./misc/article/Multilingual-VQA.png",
+            caption="Masked LM model for Image-text Pretraining.",
+        )
+        col2.markdown(read_markdown("pretraining.md"))
+        st.markdown(read_markdown("finetuning.md"))
+        st.write(read_markdown("challenges.md"))
+        st.write(read_markdown("social_impact.md"))
+        st.write(read_markdown("references.md"))
+        st.write(read_markdown("checkpoints.md"))
+        st.write(read_markdown("acknowledgements.md"))
+
+    app = MultiApp()
+    app.add_app("Visual Question Answering", vqa.app)
+    app.add_app("Mask Filling", mlm.app)
+    app.run()
+
+if __name__ == "__main__":
+    main()
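The rewritten app.py depends on `from multiapp import MultiApp`, and multiapp.py is not among the 50 files shown in this truncated view. A minimal sketch consistent with the calls used above (`add_app(name, func)` and `run()`) might look like the following; the selectbox-based task picker is an assumption, not part of this commit:

```python
# Hypothetical multiapp.py sketch; only the MultiApp API used by app.py
# (add_app and run) is grounded in this commit, the rest is illustrative.
import streamlit as st


class MultiApp:
    def __init__(self):
        self.apps = []  # (title, render function) pairs

    def add_app(self, title, func):
        self.apps.append((title, func))

    def run(self):
        # Let the user pick a task, then render the matching sub-app.
        titles = [title for title, _ in self.apps]
        choice = st.selectbox("Task", titles)
        for title, func in self.apps:
            if title == choice:
                func()
```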
apps/mlm.py ADDED
@@ -0,0 +1,109 @@
+
+from .utils import (
+    get_text_attributes,
+    get_top_5_predictions,
+    get_transformed_image,
+    plotly_express_horizontal_bar_plot,
+    translate_labels,
+    bert_tokenizer,
+)
+
+import streamlit as st
+import numpy as np
+import pandas as pd
+import os
+import matplotlib.pyplot as plt
+
+from session import _get_state
+
+
+from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
+    FlaxCLIPVisionBertForMaskedLM,
+)
+
+def softmax(logits):
+    return np.exp(logits) / np.sum(np.exp(logits), axis=0)
+
+def app():
+    state = _get_state()
+
+    @st.cache(persist=False)
+    def predict(transformed_image, caption_inputs):
+        outputs = state.model(pixel_values=transformed_image, **caption_inputs)
+        indices = np.where(caption_inputs['input_ids'] == bert_tokenizer.mask_token_id)
+        preds = outputs.logits[indices][0]
+        sorted_indices = np.argsort(preds)[::-1]  # Get reverse sorted scores
+        top_5_indices = sorted_indices[:5]
+        top_5_tokens = bert_tokenizer.convert_ids_to_tokens(top_5_indices)
+        top_5_scores = np.array(preds[top_5_indices])
+        return top_5_tokens, top_5_scores
+
+
+    @st.cache(persist=False)
+    def load_model(ckpt):
+        return FlaxCLIPVisionBertForMaskedLM.from_pretrained(ckpt)
+
+    mlm_checkpoints = ['flax-community/clip-vision-bert-cc12m-70k']
+    dummy_data = pd.read_csv("cc12m_data/vqa_val.tsv", sep="\t")
+
+    first_index = 20
+    # Init Session State
+    if state.image_file is None:
+        state.image_file = dummy_data.loc[first_index, "image_file"]
+        caption = dummy_data.loc[first_index, "caption"].strip("- ")
+        ids = bert_tokenizer(caption)
+        ids[np.random.randint(0, len(ids))] = bert_tokenizer.mask_token_id
+        state.caption = bert_tokenizer.decode(ids)
+        state.caption_lang_id = dummy_data.loc[first_index, "lang_id"]
+
+        image_path = os.path.join("cc12m_data/images_vqa", state.image_file)
+        image = plt.imread(image_path)
+        state.image = image
+
+    if state.model is None:
+        # Display Top-5 Predictions
+        with st.spinner("Loading model..."):
+            state.model = load_model(mlm_checkpoints[0])
+
+    if st.button(
+        "Get a random example",
+        help="Get a random example from the 100 `seeded` image-text pairs.",
+    ):
+        sample = dummy_data.sample(1).reset_index()
+        state.image_file = sample.loc[0, "image_file"]
+        caption = sample.loc[0, "caption"].strip("- ")
+        ids = bert_tokenizer(caption)
+        ids[np.random.randint(0, len(ids))] = bert_tokenizer.mask_token_id
+        state.caption = bert_tokenizer.decode(ids)
+        state.caption_lang_id = sample.loc[0, "lang_id"]
+
+        image_path = os.path.join("cc12m_data/images_vqa", state.image_file)
+        image = plt.imread(image_path)
+        state.image = image
+
+    transformed_image = get_transformed_image(state.image)
+
+    new_col1, new_col2 = st.beta_columns([5, 5])
+
+    # Display Image
+    new_col1.image(state.image, use_column_width="always")
+
+
+    # Display caption
+    new_col2.write("Write your text with exactly one [MASK] token.")
+    caption = new_col2.text_input(
+        label="Text",
+        value=state.caption,
+        help="Type your masked caption regarding the image above in one of the four languages.",
+    )
+
+    caption_inputs = get_text_attributes(caption)
+
+    # Display Top-5 Predictions
+
+    with st.spinner("Predicting..."):
+        logits = predict(transformed_image, dict(caption_inputs))
+        logits = softmax(logits)
+        labels, values = get_top_5_predictions(logits)
+        fig = plotly_express_horizontal_bar_plot(values, labels)
+        st.plotly_chart(fig, use_container_width=True)
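The mask-filling demo above seeds the text box by replacing one random token of the sampled caption with `[MASK]`. A standalone sketch of that step; indexing `input_ids` explicitly and skipping the `[CLS]`/`[SEP]` specials are assumptions on top of the committed code:

```python
# Minimal sketch of the caption-masking step, assuming the same
# bert-base-multilingual-uncased tokenizer as apps/utils.py.
import numpy as np
from transformers import BertTokenizerFast

bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-uncased")

caption = "A dog runs across the park"  # illustrative caption
ids = bert_tokenizer(caption)["input_ids"]
# Mask one random non-special position ([CLS] is first, [SEP] is last).
position = np.random.randint(1, len(ids) - 1)
ids[position] = bert_tokenizer.mask_token_id
print(bert_tokenizer.decode(ids))  # e.g. "[CLS] a dog runs across the [MASK] [SEP]"
```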
{model → apps/model}/__init__.py RENAMED
File without changes
{model → apps/model}/flax_clip_vision_bert/__init__.py RENAMED
File without changes
{model → apps/model}/flax_clip_vision_bert/configuration_clip_vision_bert.py RENAMED
File without changes
{model → apps/model}/flax_clip_vision_bert/modeling_clip_vision_bert.py RENAMED
File without changes
utils.py → apps/utils.py RENAMED
@@ -3,8 +3,7 @@ import json
 import numpy as np
 import plotly.express as px
 import torch
-from PIL import Image
-from torchvision.io import ImageReadMode, read_image
+from torchvision.io import read_image
 from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
 from torchvision.transforms.functional import InterpolationMode
 from transformers import BertTokenizerFast
@@ -41,15 +40,17 @@ def get_transformed_image(image):
 
 bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-uncased")
 
-
 def get_text_attributes(text):
     return bert_tokenizer([text], return_token_type_ids=True, return_tensors="np")
 
 
-def get_top_5_predictions(logits, answer_reverse_mapping):
+def get_top_5_predictions(logits, answer_reverse_mapping=None):
     indices = np.argsort(logits)[-5:]
     values = logits[indices]
-    labels = [answer_reverse_mapping[str(i)] for i in indices]
+    if answer_reverse_mapping is not None:
+        labels = [answer_reverse_mapping[str(i)] for i in indices]
+    else:
+        labels = bert_tokenizer.convert_ids_to_tokens(indices)
     return labels, values
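With the new `answer_reverse_mapping=None` default, `get_top_5_predictions` serves both apps: given a mapping it returns VQA answer labels, otherwise it decodes the top indices as wordpiece tokens for the MLM app. A usage sketch; the random logits are an illustrative stand-in:

```python
# Illustrative call modes of the updated apps/utils.py helper.
import numpy as np
from apps.utils import bert_tokenizer, get_top_5_predictions

mlm_logits = np.random.rand(bert_tokenizer.vocab_size)  # stand-in MLM logits
tokens, scores = get_top_5_predictions(mlm_logits)  # no mapping: wordpiece tokens

# VQA mode: indices are answer-class ids, mapped through the JSON file.
# with open("answer_reverse_mapping.json") as f:
#     answer_reverse_mapping = json.load(f)
# labels, values = get_top_5_predictions(vqa_logits, answer_reverse_mapping)
```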
apps/vqa.py ADDED
@@ -0,0 +1,131 @@
+
+from .utils import (
+    get_text_attributes,
+    get_top_5_predictions,
+    get_transformed_image,
+    plotly_express_horizontal_bar_plot,
+    translate_labels,
+)
+
+import streamlit as st
+import numpy as np
+import pandas as pd
+import os
+import matplotlib.pyplot as plt
+import json
+
+from mtranslate import translate
+from session import _get_state
+
+
+from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
+    FlaxCLIPVisionBertForSequenceClassification,
+)
+
+def softmax(logits):
+    return np.exp(logits) / np.sum(np.exp(logits), axis=0)
+
+def app():
+    state = _get_state()
+
+    @st.cache(persist=True)
+    def predict(transformed_image, question_inputs):
+        return np.array(state.model(pixel_values=transformed_image, **question_inputs)[0][0])
+
+
+    @st.cache(persist=True)
+    def load_model(ckpt):
+        return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)
+
+    vqa_checkpoints = ["flax-community/clip-vision-bert-vqa-ft-6k"]  # TODO: Maybe add more checkpoints?
+    dummy_data = pd.read_csv("dummy_vqa_multilingual.tsv", sep="\t")
+    code_to_name = {
+        "en": "English",
+        "fr": "French",
+        "de": "German",
+        "es": "Spanish",
+    }
+
+
+    with open("answer_reverse_mapping.json") as f:
+        answer_reverse_mapping = json.load(f)
+
+    first_index = 20
+    # Init Session State
+    if state.image_file is None:
+        state.image_file = dummy_data.loc[first_index, "image_file"]
+        state.question = dummy_data.loc[first_index, "question"].strip("- ")
+        state.answer_label = dummy_data.loc[first_index, "answer_label"]
+        state.question_lang_id = dummy_data.loc[first_index, "lang_id"]
+        state.answer_lang_id = dummy_data.loc[first_index, "lang_id"]
+
+        image_path = os.path.join("resized_images", state.image_file)
+        image = plt.imread(image_path)
+        state.image = image
+
+    if state.model is None:
+        # Display Top-5 Predictions
+        with st.spinner("Loading model..."):
+            state.model = load_model(vqa_checkpoints[0])
+
+    if st.button(
+        "Get a random example",
+        help="Get a random example from the 100 `seeded` image-text pairs.",
+    ):
+        sample = dummy_data.sample(1).reset_index()
+        state.image_file = sample.loc[0, "image_file"]
+        state.question = sample.loc[0, "question"].strip("- ")
+        state.answer_label = sample.loc[0, "answer_label"]
+        state.question_lang_id = sample.loc[0, "lang_id"]
+        state.answer_lang_id = sample.loc[0, "lang_id"]
+
+        image_path = os.path.join("resized_images", state.image_file)
+        image = plt.imread(image_path)
+        state.image = image
+
+    transformed_image = get_transformed_image(state.image)
+
+    new_col1, new_col2 = st.beta_columns([5, 5])
+
+    # Display Image
+    new_col1.image(state.image, use_column_width="always")
+
+
+    # Display Question
+    question = new_col2.text_input(
+        label="Question",
+        value=state.question,
+        help="Type your question regarding the image above in one of the four languages.",
+    )
+    new_col2.markdown(
+        f"""**English Translation**: {question if state.question_lang_id == "en" else translate(question, 'en')}"""
+    )
+
+    question_inputs = get_text_attributes(question)
+
+    # Select Language
+    options = ["en", "de", "es", "fr"]
+    state.answer_lang_id = new_col2.selectbox(
+        "Answer Language",
+        index=options.index(state.answer_lang_id),
+        options=options,
+        format_func=lambda x: code_to_name[x],
+        help="The language to be used to show the top-5 labels.",
+    )
+
+    actual_answer = answer_reverse_mapping[str(state.answer_label)]
+    new_col2.markdown(
+        "**Actual Answer**: "
+        + translate_labels([actual_answer], state.answer_lang_id)[0]
+        + " ("
+        + actual_answer
+        + ")"
+    )
+
+    with st.spinner("Predicting..."):
+        logits = predict(transformed_image, dict(question_inputs))
+        logits = softmax(logits)
+        labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
+        translated_labels = translate_labels(labels, state.answer_lang_id)
+        fig = plotly_express_horizontal_bar_plot(values, translated_labels)
+        st.plotly_chart(fig, use_container_width=True)
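Both apps/mlm.py and apps/vqa.py define the same `softmax` helper using the standard formulation. For large logits, a numerically stable variant (a sketch, not part of this commit) subtracts the maximum before exponentiating, which leaves the result unchanged but avoids overflow:

```python
import numpy as np

def stable_softmax(logits):
    # exp(x - max) / sum(exp(x - max)) equals exp(x) / sum(exp(x)),
    # but never feeds large values into np.exp.
    shifted = np.exp(logits - np.max(logits))
    return shifted / np.sum(shifted, axis=0)
```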
cc12m_data/.DS_Store ADDED
Binary file (6.15 kB).
 
cc12m_data/images_vqa/.DS_Store ADDED
Binary file (6.15 kB).
 
cc12m_data/images_vqa/00212055---Wax_cylinder_in_Dictaphone.jpg ADDED
cc12m_data/images_vqa/00315853---041bdd212f5b5d3d30cbc4ccf523f1a3.jpg ADDED
cc12m_data/images_vqa/00328633---Metal+chips+fly+in+a+high+speed+turning+operation+performed+on+a+computer+numerical+control+turning+center+%28photo+courtesy+of+Cincinnati+Milacron%29..jpg ADDED
cc12m_data/images_vqa/00491934---I6FTIDWLJRFPHAK4ZSZH4RQGDA.jpg ADDED
cc12m_data/images_vqa/00507360---MushroomRisotto1.jpg ADDED
cc12m_data/images_vqa/00602376---%20essay-example-writing-comparison-compare-contrast-how-to-write-poem-examples-of%20-1024x768.jpg ADDED
cc12m_data/images_vqa/00606341---dog-coloring-book-detailed-dogs-page2.jpg ADDED
cc12m_data/images_vqa/00697411---dream-house-swimming-pool-large-133359636.jpg ADDED
cc12m_data/images_vqa/00923733---white-commercial-van-road-motion-blurred-d-illustration-custom-designed-brandless-87900010.jpg ADDED
cc12m_data/images_vqa/01023838---fundraising-photo.jpg ADDED
cc12m_data/images_vqa/01053356---522a16b60d3f226fff652671cdde6011.jpg ADDED
cc12m_data/images_vqa/01157077---female-fruit-picker-worker-basket-woodcut-illustration-wearing-bandana-holding-viewed-side-set-white-61675986.jpg ADDED
cc12m_data/images_vqa/01275377---Young-the-Giant.jpg ADDED
cc12m_data/images_vqa/01327794---40250345161_452dc56b11_z.jpg ADDED
cc12m_data/images_vqa/01648721---170420062908YDYA.jpg ADDED
cc12m_data/images_vqa/01760795---The-Size-of-the-buildings-in-Shekou-are-in-direct-relation-to-the-time-it-takes-to-accomplish-tasks.jpg ADDED
cc12m_data/images_vqa/01761366---fresh-salad-flying-vegetables-ingredients-isolated-white-background-48747892.jpg ADDED
cc12m_data/images_vqa/01772764---business-woman-winner-standing-first-600w-254762824.jpg ADDED
cc12m_data/images_vqa/01813337---cd4df5cb43d087533e89b12c9805409e.jpg ADDED
cc12m_data/images_vqa/02034916---XKC6GGK5NDECNBAD5WAQUWOO5U.jpg ADDED
cc12m_data/images_vqa/02175876---DL2-4i4.jpg ADDED
cc12m_data/images_vqa/02217469---mount-macedon-victoria-australia-macedon-regional-park-region-photographed-by-karen-robinson-_march-29-2020_042-1.jpg ADDED
cc12m_data/images_vqa/02243845---heritage-heritage-matte-stainless-steel-sink-undermount-5_2048x.jpg ADDED
cc12m_data/images_vqa/02335328---margaret-and-alexander-potters-houses-1948.jpg ADDED
cc12m_data/images_vqa/02520451---Gower-1.jpg ADDED
cc12m_data/images_vqa/02912250---a-black-panther-has-been-spotted-in-weald-park-brentwood-essex-britain-shutterstock-editorial-618335e.jpg ADDED
cc12m_data/images_vqa/03257347---looking-farther-afield-article-size.jpg ADDED
cc12m_data/images_vqa/03271226---beneath-the-borealis-092517-a-very-bear-y-summer-kennicott-valley-virga.jpg ADDED
cc12m_data/images_vqa/03307717---tumblr_m9d4xkRM5n1rypkpio1_1280.jpg ADDED
cc12m_data/images_vqa/03360735---Warm-Bacon-Dip-EasyLowCarb-2.jpg ADDED
cc12m_data/images_vqa/03394023---m_5e36e15f2169682519441e34.jpg ADDED
cc12m_data/images_vqa/03401066---160328-capitol-police-mn-1530_bd68b01f1d7f1c3ab99eafa503930569.fit-760w.jpg ADDED
cc12m_data/images_vqa/03598306---20400805522_fba017bc51_b.jpg ADDED
cc12m_data/images_vqa/03618296---A+pink+and+grey+woven+baskets+sits+on+top+of+a+clear+side+table.jpg ADDED
cc12m_data/images_vqa/04331097---108_1504859395_24.jpg ADDED
cc12m_data/images_vqa/04334412---Pants-All-match-Professional-Harlan-Women-s-Loose-Skinny-High-Waist-New-2019-Suit-Summer-Leisure-Pants-2077.jpg ADDED
cc12m_data/images_vqa/04358571---41-Travelex.jpg ADDED
cc12m_data/images_vqa/04361362---square-stone-benches-around-fire-pit-outside-residential-building-sunny-day-pathways-plants-can-also-be-seen-homes-171086572.jpg ADDED
cc12m_data/images_vqa/04530023---49305383277_29d4a34f37_h.jpg ADDED
cc12m_data/images_vqa/04749808---thinkstockphotos-1858212351.jpg ADDED