gchhablani committed on
Commit f4963f2
1 Parent(s): d823ba7

Add about and method in app

Files changed (3):
  1. app.py +75 -31
  2. session.py +89 -0
  3. utils.py +2 -2
app.py CHANGED
@@ -11,6 +11,10 @@ from mtranslate import translate
 from PIL import Image
 
 
+from session import _get_state
+
+state = _get_state()
+
 @st.cache
 def load_model(ckpt):
     return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)
@@ -24,7 +28,6 @@ with open('answer_reverse_mapping.json') as f:
     answer_reverse_mapping = json.load(f)
 
 
-
 st.set_page_config(
     page_title="Multilingual VQA",
     layout="wide",
@@ -34,58 +37,99 @@ st.set_page_config(
 
 st.title("Multilingual Visual Question Answering")
 
+
+
 with st.beta_expander("About"):
-    pass
+    st.write("This project is focused on Multilingual Visual Question Answering. Most of the existing datasets and models for this task work with English-only image-text pairs. Our intention here is to provide a Proof-of-Concept with our simple ViT+BERT model, which can be trained on multilingual text checkpoints with pre-trained image encoders and still work well enough. Due to the lack of good-quality multilingual data, we translate subsets of the Conceptual 12M dataset into English (already in English), French, German and Spanish using the Marian models. We achieved 0.49 accuracy on the multilingual validation set we created. With better captions and hyperparameter tuning, we expect to see higher performance.")
 with st.beta_expander("Method"):
-    st.image("./misc/Multilingual-VQA.png")
-with st.beta_expander("Results"):
+    col1, col2 = st.beta_columns([5,4])
+    col1.image("./misc/Multilingual-VQA.png")
+    col2.markdown("""
+## Pretraining
+We follow an approach similar to [VisualBERT](https://arxiv.org/abs/1908.03557). Instead of using a Faster R-CNN to get image features, we use a ViT encoder.
+The task is text-only MLM (Masked Language Modeling): we mask only the text tokens and try to predict the masked tokens. The VisualBERT authors also use a sentence-image matching task where two captions are matched against an image, but we skip this for the sake of simplicity.
+### Dataset
+The dataset we use for pre-training is a cleaned version of [Conceptual 12M](https://github.com/google-research-datasets/conceptual-12m). The dataset is downloaded and broken images are removed, which gives us about 10M images. We then use the MBart50 `mbart-large-50-one-to-many-mmt` checkpoint to translate the dataset into four different languages - English, French, German, and Spanish - keeping 2.5 million examples of each language.
+""")
+
+    st.markdown("""
+### Model
+The model is shown in the image above. We create a custom model in Flax which integrates the ViT model inside the BERT embeddings. We also use custom configs and modules in order to accommodate these changes and to allow loading from BERT and ViT checkpoints. The image is fed to the ViT encoder and the text is fed to the word-embedding layers of the BERT model. We use the `bert-base-multilingual-uncased` and `openai/clip-vit-base-patch32` checkpoints for the BERT and ViT (actually CLIPVision) models, respectively. All our code is available on [GitHub](https://github.com/gchhablani/multilingual-vqa).
+## Fine-tuning
+
+### Dataset
+For fine-tuning, we use the [VQA 2.0](https://visualqa.org/) dataset - particularly, the `train` and `validation` sets. We translate all the questions into the four languages specified above using language-specific MarianMT models, since the MarianMT models return better labels and are faster, which makes them well suited for fine-tuning. This gives us 4x the number of examples in each subset.
+### Model
+We use the `SequenceClassification` model as a reference to create our own sequence classification model. 3129 answer labels are chosen, as is the convention for the English VQA task; they can be found [here](https://github.com/gchhablani/multilingual-vqa/blob/main/answer_mapping.json). These are the same labels used in fine-tuning of the VisualBERT models. The outputs shown here have been translated using the [`mtranslate`](https://github.com/mouuff/mtranslate) Google Translate API library. We then take various pre-trained checkpoints and train the sequence classification model for various numbers of steps.

+Checkpoints:
+- Pre-trained checkpoint: [multilingual-vqa](https://huggingface.co/flax-community/multilingual-vqa)
+- Fine-tuned on the 45k pre-training checkpoint: [multilingual-vqa-pt-45k-ft](https://huggingface.co/flax-community/multilingual-vqa-pt-45k-ft)
+- Fine-tuned on the 45k pre-training checkpoint with AdaFactor (the others use AdamW): [multilingual-vqa-pt-45k-ft-adf](https://huggingface.co/flax-community/multilingual-vqa-pt-45k-ft-adf)
+- Fine-tuned on the 60k pre-training checkpoint: [multilingual-vqa-pt-60k-ft](https://huggingface.co/flax-community/multilingual-vqa-pt-60k-ft)
+- Fine-tuned on the 70k pre-training checkpoint: [multilingual-vqa-pt-70k-ft](https://huggingface.co/flax-community/multilingual-vqa-pt-70k-ft)
+- Trained from scratch (without pre-training): [multilingual-vqa-ft](https://huggingface.co/flax-community/multilingual-vqa-ft)
+
+**Caveat**: The best fine-tuned model only achieves 0.49 accuracy on the multilingual validation data that we created. This could be because of not-so-great quality translations, sub-optimal hyperparameters, and lack of ample training. In the future, we hope to improve this model by addressing these concerns.
+""")
+
+with st.beta_expander("Cherry-Picked Results"):
+    pass
+
+with st.beta_expander("Conclusion"):
+    pass
+
+with st.beta_expander("Usage"):
     pass
 
 # Init Session State
-if 'image_file' not in st.session_state:
-    st.session_state.image_file = dummy_data.loc[0,'image_file']
-    st.session_state.question = dummy_data.loc[0,'question']
-    st.session_state.answer_label = dummy_data.loc[0,'answer_label']
-    st.session_state.question_lang_id = dummy_data.loc[0, 'lang_id']
-    st.session_state.answer_lang_id = dummy_data.loc[0, 'lang_id']
+if state.image_file is None:
+    state.image_file = dummy_data.loc[0,'image_file']
+    state.question = dummy_data.loc[0,'question'].strip('- ')
+    state.answer_label = dummy_data.loc[0,'answer_label']
+    state.question_lang_id = dummy_data.loc[0, 'lang_id']
+    state.answer_lang_id = dummy_data.loc[0, 'lang_id']
 
-image_path = os.path.join('images',st.session_state.image_file)
+image_path = os.path.join('images',state.image_file)
 image = plt.imread(image_path)
-st.session_state.image = image
+state.image = image
 
 col1, col2 = st.beta_columns([5,5])
-if col1.button('Get a Random Example'):
+
+# Display Image
+col1.image(state.image, use_column_width='always')
+
+if col2.button('Get a random example'):
     sample = dummy_data.sample(1).reset_index()
-    st.session_state.image_file = sample.loc[0,'image_file']
-    st.session_state.question = sample.loc[0,'question']
-    st.session_state.answer_label = sample.loc[0,'answer_label']
-    st.session_state.question_lang_id = sample.loc[0, 'lang_id']
-    st.session_state.answer_lang_id = sample.loc[0, 'lang_id']
+    state.image_file = sample.loc[0,'image_file']
+    state.question = sample.loc[0,'question'].strip('- ')
+    state.answer_label = sample.loc[0,'answer_label']
+    state.question_lang_id = sample.loc[0, 'lang_id']
+    state.answer_lang_id = sample.loc[0, 'lang_id']
 
-    image_path = os.path.join('images',st.session_state.image_file)
+    image_path = os.path.join('images',state.image_file)
     image = plt.imread(image_path)
-    st.session_state.image = image
+    state.image = image
 
+st.write("OR")
 
 uploaded_file = col2.file_uploader('Upload your image', type=['png','jpg','jpeg'])
 if uploaded_file is not None:
-    st.session_state.image_file = os.path.join('images/val2014',uploaded_file.name)
-    st.session_state.image = np.array(Image.open(uploaded_file))
-
+    state.image_file = os.path.join('images/val2014',uploaded_file.name)
+    state.image = np.array(Image.open(uploaded_file))
 
-transformed_image = get_transformed_image(st.session_state.image)
 
-# Display Image
-st.image(st.session_state.image, use_column_width='always')
+transformed_image = get_transformed_image(state.image)
 
 # Display Question
-question = st.text_input(label="Question", value=st.session_state.question)
-st.markdown(f"""**English Translation**: {question if st.session_state.question_lang_id == "en" else translate(question, 'en')}""")
+question = st.text_input(label="Question", value=state.question)
+st.markdown(f"""**English Translation**: {question if state.question_lang_id == "en" else translate(question, 'en')}""")
 question_inputs = get_text_attributes(question)
 
 # Select Language
 options = ['en', 'de', 'es', 'fr']
-st.session_state.answer_lang_id = st.selectbox('Answer Language', index=options.index(st.session_state.answer_lang_id), options=options)
+state.answer_lang_id = st.selectbox('Answer Language', index=options.index(state.answer_lang_id), options=options)
 # Display Top-5 Predictions
 with st.spinner('Loading model...'):
     model = load_model(checkpoints[0])
@@ -94,6 +138,6 @@ with st.spinner('Predicting...'):
     logits = np.array(predictions[0][0])
     logits = softmax(logits)
    labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
-    translated_labels = translate_labels(labels, st.session_state.answer_lang_id)
+    translated_labels = translate_labels(labels, state.answer_lang_id)
     fig = plotly_express_horizontal_bar_plot(values, translated_labels)
-    st.plotly_chart(fig)
+    st.plotly_chart(fig, use_container_width = True)
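
As a side note for readers of this diff: the Method text lists several fine-tuned checkpoints, and app.py loads one of them through FlaxCLIPVisionBertForSequenceClassification.from_pretrained, then shows non-English questions next to an English translation via mtranslate. A minimal sketch of those two steps outside Streamlit is given below; the import path of the model class is an assumption (the class is defined in the project repository, not in this diff), and the checkpoint name is simply one of those listed above.

from mtranslate import translate

# Import path assumed; the custom class lives in the project repo, not in this commit.
from models.flax_clip_vision_bert.modeling_clip_vision_bert import (
    FlaxCLIPVisionBertForSequenceClassification,
)

# Load one of the fine-tuned checkpoints named in the Method section.
ckpt = "flax-community/multilingual-vqa-pt-45k-ft"
model = FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)

# app.py shows each question together with an English translation.
question = "Wie viele Katzen sind auf dem Bild?"
print(translate(question, "en"))  # e.g. "How many cats are in the picture?"
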
session.py ADDED
@@ -0,0 +1,89 @@
+#
+# Code for managing session state, which is needed for multi-input forms
+# See https://github.com/streamlit/streamlit/issues/1557
+#
+# This code is taken from
+# https://gist.github.com/okld/0aba4869ba6fdc8d49132e6974e2e662
+#
+from streamlit.hashing import _CodeHasher
+from streamlit.report_thread import get_report_ctx
+from streamlit.server.server import Server
+
+
+class _SessionState:
+    def __init__(self, session, hash_funcs):
+        """Initialize SessionState instance."""
+        self.__dict__["_state"] = {
+            "data": {},
+            "hash": None,
+            "hasher": _CodeHasher(hash_funcs),
+            "is_rerun": False,
+            "session": session,
+        }
+
+    def __call__(self, **kwargs):
+        """Initialize state data once."""
+        for item, value in kwargs.items():
+            if item not in self._state["data"]:
+                self._state["data"][item] = value
+
+    def __getitem__(self, item):
+        """Return a saved state value, None if item is undefined."""
+        return self._state["data"].get(item, None)
+
+    def __getattr__(self, item):
+        """Return a saved state value, None if item is undefined."""
+        return self._state["data"].get(item, None)
+
+    def __setitem__(self, item, value):
+        """Set state value."""
+        self._state["data"][item] = value
+
+    def __setattr__(self, item, value):
+        """Set state value."""
+        self._state["data"][item] = value
+
+    def clear(self):
+        """Clear session state and request a rerun."""
+        self._state["data"].clear()
+        self._state["session"].request_rerun()
+
+    def sync(self):
+        """
+        Rerun the app with all state values up to date from the beginning to
+        fix rollbacks.
+        """
+        data_to_bytes = self._state["hasher"].to_bytes(self._state["data"], None)
+
+        # Ensure to rerun only once to avoid infinite loops
+        # caused by a constantly changing state value at each run.
+        #
+        # Example: state.value += 1
+        if self._state["is_rerun"]:
+            self._state["is_rerun"] = False
+
+        elif self._state["hash"] is not None:
+            if self._state["hash"] != data_to_bytes:
+                self._state["is_rerun"] = True
+                self._state["session"].request_rerun()
+
+        self._state["hash"] = data_to_bytes
+
+
+def _get_session():
+    session_id = get_report_ctx().session_id
+    session_info = Server.get_current()._get_session_info(session_id)
+
+    if session_info is None:
+        raise RuntimeError("Couldn't get your Streamlit Session object.")
+
+    return session_info.session
+
+
+def _get_state(hash_funcs=None):
+    session = _get_session()
+
+    if not hasattr(session, "_custom_session_state"):
+        session._custom_session_state = _SessionState(session, hash_funcs)
+
+    return session._custom_session_state
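
session.py is the session-state helper from the okld gist linked in its header. Because __getattr__ returns None for attributes that were never set, app.py can bootstrap defaults with a plain "if state.image_file is None:" check. A minimal usage sketch following the gist's pattern is below; the trailing state.sync() call is part of that upstream pattern, not something shown in this commit's visible app.py hunks.

import streamlit as st
from session import _get_state

state = _get_state()

# Unset attributes read as None, so defaults can be set on the first run.
if state.counter is None:
    state.counter = 0

if st.button("Increment"):
    state.counter += 1

st.write("Counter:", state.counter)

# Per the upstream gist, sync() is called at the end of the script so widgets
# pick up state changes on the next rerun.
state.sync()
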
utils.py CHANGED
@@ -45,7 +45,7 @@ def get_text_attributes(text):
 
 def get_top_5_predictions(logits, answer_reverse_mapping):
     indices = np.argsort(logits)[-5:]
-    values = np.round(logits[indices], decimals=2)
+    values = logits[indices]
     labels = [answer_reverse_mapping[str(i)] for i in indices]
     return labels, values
 
@@ -65,5 +65,5 @@ def translate_labels(labels, lang_id):
 
 
 def plotly_express_horizontal_bar_plot(values, labels):
-    fig = px.bar(x=values, y=labels, text = values, title="Top-5 Predictions", labels={"x": "Scores", "y":"Answers"}, orientation="h")
+    fig = px.bar(x=values, y=labels, text = [format(value, ".3%") for value in values], title="Top-5 Predictions", labels={"x": "Scores", "y":"Answers"}, orientation="h")
     return fig
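
The two utils.py changes work together: get_top_5_predictions now returns the raw softmax probabilities instead of values rounded to two decimals, and the bar plot formats them as percentages for display. A small self-contained sketch of that post-processing follows; the logits and the answer mapping here are made up purely for illustration.

import numpy as np
import plotly.express as px

# Toy logits and answer mapping, for illustration only.
logits = np.array([2.0, 1.0, 0.5, 0.2, 3.0, -1.0])
answer_reverse_mapping = {str(i): f"answer_{i}" for i in range(len(logits))}

probs = np.exp(logits) / np.exp(logits).sum()  # softmax

indices = np.argsort(probs)[-5:]
values = probs[indices]                        # raw probabilities, no rounding
labels = [answer_reverse_mapping[str(i)] for i in indices]

fig = px.bar(
    x=values,
    y=labels,
    text=[format(value, ".3%") for value in values],  # e.g. "12.345%"
    title="Top-5 Predictions",
    labels={"x": "Scores", "y": "Answers"},
    orientation="h",
)
fig.show()
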