gchhablani committed on
Commit
690384a
1 Parent(s): 0e71038
Files changed (4) hide show
  1. app.py +63 -35
  2. requirements.txt +2 -1
  3. translate_answer_mapping.py +9 -3
  4. utils.py +16 -5
app.py CHANGED
@@ -5,8 +5,17 @@ import json
5
  import os
6
  import numpy as np
7
  from streamlit.elements import markdown
8
- from model.flax_clip_vision_bert.modeling_clip_vision_bert import FlaxCLIPVisionBertForSequenceClassification
9
- from utils import get_transformed_image, get_text_attributes, get_top_5_predictions, plotly_express_horizontal_bar_plot, translate_labels
 
 
 
 
 
 
 
 
 
10
  import matplotlib.pyplot as plt
11
  from mtranslate import translate
12
  from PIL import Image
@@ -16,23 +25,30 @@ from session import _get_state
16
 
17
  state = _get_state()
18
 
 
19
  @st.cache(persist=True)
20
  def load_model(ckpt):
21
  return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)
22
 
 
23
  @st.cache(persist=True)
24
  def predict(transformed_image, question_inputs):
25
- return np.array(model(pixel_values = transformed_image, **question_inputs)[0][0])
26
-
 
27
  def softmax(logits):
28
- return np.exp(logits)/np.sum(np.exp(logits), axis=0)
 
29
 
30
  def read_markdown(path, parent="./sections/"):
31
- with open(os.path.join(parent,path)) as f:
32
  return f.read()
33
 
34
- checkpoints = ['./ckpt/ckpt-60k-5999'] # TODO: Maybe add more checkpoints?
35
- dummy_data = pd.read_csv('dummy_vqa_multilingual.tsv', sep='\t')
 
 
 
36
  code_to_name = {
37
  "en": "English",
38
  "fr": "French",
@@ -40,7 +56,7 @@ code_to_name = {
40
  "es": "Spanish",
41
  }
42
 
43
- with open('answer_reverse_mapping.json') as f:
44
  answer_reverse_mapping = json.load(f)
45
 
46
 
@@ -52,7 +68,9 @@ st.set_page_config(
52
  )
53
 
54
  st.title("Multilingual Visual Question Answering")
55
- st.write("[Gunjan Chhablani](https://huggingface.co/gchhablani), [Bhavitvya Malik](https://huggingface.co/bhavitvyamalik)")
 
 
56
 
57
  with st.beta_expander("Usage"):
58
  st.markdown(read_markdown("usage.md"))
@@ -60,67 +78,77 @@ with st.beta_expander("Usage"):
60
  first_index = 20
61
  # Init Session State
62
  if state.image_file is None:
63
- state.image_file = dummy_data.loc[first_index,'image_file']
64
- state.question = dummy_data.loc[first_index,'question'].strip('- ')
65
- state.answer_label = dummy_data.loc[first_index,'answer_label']
66
- state.question_lang_id = dummy_data.loc[first_index, 'lang_id']
67
- state.answer_lang_id = dummy_data.loc[first_index, 'lang_id']
68
-
69
- image_path = os.path.join('images',state.image_file)
70
  image = plt.imread(image_path)
71
  state.image = image
72
 
73
- col1, col2 = st.beta_columns([6,4])
74
 
75
- if col2.button('Get a random example'):
76
  sample = dummy_data.sample(1).reset_index()
77
- state.image_file = sample.loc[0,'image_file']
78
- state.question = sample.loc[0,'question'].strip('- ')
79
- state.answer_label = sample.loc[0,'answer_label']
80
- state.question_lang_id = sample.loc[0, 'lang_id']
81
- state.answer_lang_id = sample.loc[0, 'lang_id']
82
 
83
- image_path = os.path.join('images',state.image_file)
84
  image = plt.imread(image_path)
85
  state.image = image
86
 
87
  col2.write("OR")
88
 
89
- uploaded_file = col2.file_uploader('Upload your image', type=['png','jpg','jpeg'])
90
  if uploaded_file is not None:
91
- state.image_file = os.path.join('images/val2014',uploaded_file.name)
92
  state.image = np.array(Image.open(uploaded_file))
93
 
94
 
 
95
  transformed_image = get_transformed_image(state.image)
96
 
97
  # Display Image
98
- col1.image(state.image, use_column_width='always')
99
 
100
  # Display Question
101
  question = col2.text_input(label="Question", value=state.question)
102
- col2.markdown(f"""**English Translation**: {question if state.question_lang_id == "en" else translate(question, 'en')}""")
 
 
103
  question_inputs = get_text_attributes(question)
104
 
105
  # Select Language
106
- options = ['en', 'de', 'es', 'fr']
107
- state.answer_lang_id = col2.selectbox('Answer Language', index=options.index(state.answer_lang_id), options=options, format_func = lambda x: code_to_name[x])
 
 
 
 
 
108
  # Display Top-5 Predictions
109
- with st.spinner('Loading model...'):
110
  model = load_model(checkpoints[0])
111
- with st.spinner('Predicting...'):
112
  logits = predict(transformed_image, dict(question_inputs))
113
  logits = softmax(logits)
114
  labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
115
  translated_labels = translate_labels(labels, state.answer_lang_id)
116
  fig = plotly_express_horizontal_bar_plot(values, translated_labels)
117
- st.plotly_chart(fig, use_container_width = True)
118
 
119
 
120
  st.write(read_markdown("abstract.md"))
121
  st.write(read_markdown("caveats.md"))
122
  st.write("# Methodology")
123
- st.image("./misc/Multilingual-VQA.png", caption="Masked LM model for Image-text Pretraining.")
 
 
124
  st.markdown(read_markdown("pretraining.md"))
125
  st.markdown(read_markdown("finetuning.md"))
126
  st.write(read_markdown("challenges.md"))
 
5
  import os
6
  import numpy as np
7
  from streamlit.elements import markdown
8
+ import cv2
9
+ from model.flax_clip_vision_bert.modeling_clip_vision_bert import (
10
+ FlaxCLIPVisionBertForSequenceClassification,
11
+ )
12
+ from utils import (
13
+ get_transformed_image,
14
+ get_text_attributes,
15
+ get_top_5_predictions,
16
+ plotly_express_horizontal_bar_plot,
17
+ translate_labels,
18
+ )
19
  import matplotlib.pyplot as plt
20
  from mtranslate import translate
21
  from PIL import Image
 
25
 
26
  state = _get_state()
27
 
28
+
29
  @st.cache(persist=True)
30
  def load_model(ckpt):
31
  return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)
32
 
33
+
34
  @st.cache(persist=True)
35
  def predict(transformed_image, question_inputs):
36
+ return np.array(model(pixel_values=transformed_image, **question_inputs)[0][0])
37
+
38
+
39
  def softmax(logits):
40
+ return np.exp(logits) / np.sum(np.exp(logits), axis=0)
41
+
42
 
43
  def read_markdown(path, parent="./sections/"):
44
+ with open(os.path.join(parent, path)) as f:
45
  return f.read()
46
 
47
+ def resize_height(image, new_height):
48
+ h, w, c = image.shape
49
+
50
+ checkpoints = ["./ckpt/ckpt-60k-5999"] # TODO: Maybe add more checkpoints?
51
+ dummy_data = pd.read_csv("dummy_vqa_multilingual.tsv", sep="\t")
52
  code_to_name = {
53
  "en": "English",
54
  "fr": "French",
 
56
  "es": "Spanish",
57
  }
58
 
59
+ with open("answer_reverse_mapping.json") as f:
60
  answer_reverse_mapping = json.load(f)
61
 
62
 
 
68
  )
69
 
70
  st.title("Multilingual Visual Question Answering")
71
+ st.write(
72
+ "[Gunjan Chhablani](https://huggingface.co/gchhablani), [Bhavitvya Malik](https://huggingface.co/bhavitvyamalik)"
73
+ )
74
 
75
  with st.beta_expander("Usage"):
76
  st.markdown(read_markdown("usage.md"))
 
78
  first_index = 20
79
  # Init Session State
80
  if state.image_file is None:
81
+ state.image_file = dummy_data.loc[first_index, "image_file"]
82
+ state.question = dummy_data.loc[first_index, "question"].strip("- ")
83
+ state.answer_label = dummy_data.loc[first_index, "answer_label"]
84
+ state.question_lang_id = dummy_data.loc[first_index, "lang_id"]
85
+ state.answer_lang_id = dummy_data.loc[first_index, "lang_id"]
86
+
87
+ image_path = os.path.join("images", state.image_file)
88
  image = plt.imread(image_path)
89
  state.image = image
90
 
91
+ col1, col2 = st.beta_columns([6, 4])
92
 
93
+ if col2.button("Get a random example"):
94
  sample = dummy_data.sample(1).reset_index()
95
+ state.image_file = sample.loc[0, "image_file"]
96
+ state.question = sample.loc[0, "question"].strip("- ")
97
+ state.answer_label = sample.loc[0, "answer_label"]
98
+ state.question_lang_id = sample.loc[0, "lang_id"]
99
+ state.answer_lang_id = sample.loc[0, "lang_id"]
100
 
101
+ image_path = os.path.join("images", state.image_file)
102
  image = plt.imread(image_path)
103
  state.image = image
104
 
105
  col2.write("OR")
106
 
107
+ uploaded_file = col2.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])
108
  if uploaded_file is not None:
109
+ state.image_file = os.path.join("images/val2014", uploaded_file.name)
110
  state.image = np.array(Image.open(uploaded_file))
111
 
112
 
113
+ state.image =
114
  transformed_image = get_transformed_image(state.image)
115
 
116
  # Display Image
117
+ col1.image(state.image, use_column_width="always")
118
 
119
  # Display Question
120
  question = col2.text_input(label="Question", value=state.question)
121
+ col2.markdown(
122
+ f"""**English Translation**: {question if state.question_lang_id == "en" else translate(question, 'en')}"""
123
+ )
124
  question_inputs = get_text_attributes(question)
125
 
126
  # Select Language
127
+ options = ["en", "de", "es", "fr"]
128
+ state.answer_lang_id = col2.selectbox(
129
+ "Answer Language",
130
+ index=options.index(state.answer_lang_id),
131
+ options=options,
132
+ format_func=lambda x: code_to_name[x],
133
+ )
134
  # Display Top-5 Predictions
135
+ with st.spinner("Loading model..."):
136
  model = load_model(checkpoints[0])
137
+ with st.spinner("Predicting..."):
138
  logits = predict(transformed_image, dict(question_inputs))
139
  logits = softmax(logits)
140
  labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
141
  translated_labels = translate_labels(labels, state.answer_lang_id)
142
  fig = plotly_express_horizontal_bar_plot(values, translated_labels)
143
+ st.plotly_chart(fig, use_container_width=True)
144
 
145
 
146
  st.write(read_markdown("abstract.md"))
147
  st.write(read_markdown("caveats.md"))
148
  st.write("# Methodology")
149
+ st.image(
150
+ "./misc/Multilingual-VQA.png", caption="Masked LM model for Image-text Pretraining."
151
+ )
152
  st.markdown(read_markdown("pretraining.md"))
153
  st.markdown(read_markdown("finetuning.md"))
154
  st.write(read_markdown("challenges.md"))
requirements.txt CHANGED
@@ -4,4 +4,5 @@ git+https://github.com/huggingface/transformers.git
4
  torchvision==0.10.0
5
  mtranslate==1.8
6
  black==21.7b0
7
- flax==0.3.4
 
 
4
  torchvision==0.10.0
5
  mtranslate==1.8
6
  black==21.7b0
7
+ flax==0.3.4
8
+ opencv-python==4.5.3
translate_answer_mapping.py CHANGED
@@ -4,6 +4,7 @@ from tqdm import tqdm
4
  import ray
5
  from asyncio import Event
6
  from ray.actor import ActorHandle
 
7
  ray.init()
8
  from typing import Tuple
9
 
@@ -48,6 +49,7 @@ class ProgressBarActor:
48
  """
49
  return self.counter
50
 
 
51
  class ProgressBar:
52
  progress_actor: ActorHandle
53
  total: int
@@ -89,14 +91,16 @@ class ProgressBar:
89
  with open("answer_reverse_mapping.json") as f:
90
  answer_reverse_mapping = json.load(f)
91
 
 
92
  @ray.remote
93
  def translate_answer(value, pba):
94
  temp = {}
95
  for lang in ["fr", "es", "de"]:
96
- temp.update({lang: translate(value, lang, 'en')})
97
  pba.update.remote(1)
98
  return temp
99
 
 
100
  translation_dicts = []
101
  pb = ProgressBar(len(answer_reverse_mapping.values()))
102
  actor = pb.actor
@@ -104,8 +108,10 @@ for value in answer_reverse_mapping.values():
104
  translation_dicts.append(translate_answer.remote(value, actor))
105
 
106
  pb.print_until_done()
107
- translation_dict = dict(zip(answer_reverse_mapping.values(),ray.get(translation_dicts)))
 
 
108
 
109
 
110
  with open("translation_dict.json", "w") as f:
111
- json.dump(translation_dict, f)
 
4
  import ray
5
  from asyncio import Event
6
  from ray.actor import ActorHandle
7
+
8
  ray.init()
9
  from typing import Tuple
10
 
 
49
  """
50
  return self.counter
51
 
52
+
53
  class ProgressBar:
54
  progress_actor: ActorHandle
55
  total: int
 
91
  with open("answer_reverse_mapping.json") as f:
92
  answer_reverse_mapping = json.load(f)
93
 
94
+
95
  @ray.remote
96
  def translate_answer(value, pba):
97
  temp = {}
98
  for lang in ["fr", "es", "de"]:
99
+ temp.update({lang: translate(value, lang, "en")})
100
  pba.update.remote(1)
101
  return temp
102
 
103
+
104
  translation_dicts = []
105
  pb = ProgressBar(len(answer_reverse_mapping.values()))
106
  actor = pb.actor
 
108
  translation_dicts.append(translate_answer.remote(value, actor))
109
 
110
  pb.print_until_done()
111
+ translation_dict = dict(
112
+ zip(answer_reverse_mapping.values(), ray.get(translation_dicts))
113
+ )
114
 
115
 
116
  with open("translation_dict.json", "w") as f:
117
+ json.dump(translation_dict, f)
utils.py CHANGED
@@ -7,6 +7,8 @@ from transformers import BertTokenizerFast
7
  import plotly.express as px
8
  import json
9
  from PIL import Image
 
 
10
  class Transform(torch.nn.Module):
11
  def __init__(self, image_size):
12
  super().__init__()
@@ -31,7 +33,7 @@ transform = Transform(224)
31
 
32
  def get_transformed_image(image):
33
  if image.shape[-1] == 3 and isinstance(image, np.ndarray):
34
- image = image.transpose(2,0,1)
35
  image = torch.tensor(image)
36
  return transform(image).unsqueeze(0).permute(0, 2, 3, 1).numpy()
37
 
@@ -49,13 +51,15 @@ def get_top_5_predictions(logits, answer_reverse_mapping):
49
  labels = [answer_reverse_mapping[str(i)] for i in indices]
50
  return labels, values
51
 
52
- with open('translation_dict.json') as f:
 
53
  translate_dict = json.load(f)
54
 
 
55
  def translate_labels(labels, lang_id):
56
  translated_labels = []
57
  for label in labels:
58
- if label=="<unk>":
59
  translated_labels.append("<unk>")
60
  elif lang_id == "en":
61
  translated_labels.append(label)
@@ -65,5 +69,12 @@ def translate_labels(labels, lang_id):
65
 
66
 
67
  def plotly_express_horizontal_bar_plot(values, labels):
68
- fig = px.bar(x=values, y=labels, text = [format(value, ".3%") for value in values], title="Top-5 Predictions", labels={"x": "Scores", "y":"Answers"}, orientation="h")
69
- return fig
 
 
 
 
 
 
 
 
7
  import plotly.express as px
8
  import json
9
  from PIL import Image
10
+
11
+
12
  class Transform(torch.nn.Module):
13
  def __init__(self, image_size):
14
  super().__init__()
 
33
 
34
  def get_transformed_image(image):
35
  if image.shape[-1] == 3 and isinstance(image, np.ndarray):
36
+ image = image.transpose(2, 0, 1)
37
  image = torch.tensor(image)
38
  return transform(image).unsqueeze(0).permute(0, 2, 3, 1).numpy()
39
 
 
51
  labels = [answer_reverse_mapping[str(i)] for i in indices]
52
  return labels, values
53
 
54
+
55
+ with open("translation_dict.json") as f:
56
  translate_dict = json.load(f)
57
 
58
+
59
  def translate_labels(labels, lang_id):
60
  translated_labels = []
61
  for label in labels:
62
+ if label == "<unk>":
63
  translated_labels.append("<unk>")
64
  elif lang_id == "en":
65
  translated_labels.append(label)
 
69
 
70
 
71
  def plotly_express_horizontal_bar_plot(values, labels):
72
+ fig = px.bar(
73
+ x=values,
74
+ y=labels,
75
+ text=[format(value, ".3%") for value in values],
76
+ title="Top-5 Predictions",
77
+ labels={"x": "Scores", "y": "Answers"},
78
+ orientation="h",
79
+ )
80
+ return fig