gchhablani committed on
Commit
69e32d1
1 Parent(s): 7de02ed

Rearrange app

app.py CHANGED
@@ -4,6 +4,7 @@ import pandas as pd
4
  import json
5
  import os
6
  import numpy as np
 
7
  from model.flax_clip_vision_bert.modeling_clip_vision_bert import FlaxCLIPVisionBertForSequenceClassification
8
  from utils import get_transformed_image, get_text_attributes, get_top_5_predictions, plotly_express_horizontal_bar_plot, translate_labels
9
  import matplotlib.pyplot as plt
@@ -15,15 +16,30 @@ from session import _get_state
15
 
16
  state = _get_state()
17
 
18
- @st.cache
19
  def load_model(ckpt):
20
  return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)
21
 
22
  def softmax(logits):
23
  return np.exp(logits)/np.sum(np.exp(logits), axis=0)
24
 
25
  checkpoints = ['./ckpt/ckpt-60k-5999'] # TODO: Maybe add more checkpoints?
26
  dummy_data = pd.read_csv('dummy_vqa_multilingual.tsv', sep='\t')
27
  with open('answer_reverse_mapping.json') as f:
28
  answer_reverse_mapping = json.load(f)
29
 
@@ -36,69 +52,41 @@ st.set_page_config(
36
  )
37
 
38
  st.title("Multilingual Visual Question Answering")
 
39
 
40
 
41
 
42
- with st.beta_expander("About"):
43
- st.write("This project is focused on Mutilingual Visual Question Answering. Most of the existing datasets and models on this task work with English-only image-text pairs. Our intention here is to provide a Proof-of-Concept with our simple ViT+BERT model which can be trained on multilingual text checkpoints with pre-trained image encoders and well enough. Due to lack of good-quality multilingual data, we translate subsets of the Conceptual 12M dataset into English (already in English), French, German and Spanish using the Marian models. We achieved 0.49 accuracy on the multilingual validation set we created. With better captions, and hyperparameter-tuning, we expect to see higher performance.")
44
- with st.beta_expander("Method"):
45
- col1, col2 = st.beta_columns([5,4])
46
- col1.image("./misc/Multilingual-VQA.png")
47
- col2.markdown("""
48
- ## Pretraining
49
- We follow an approach similar to [VisualBERT](https://arxiv.org/abs/1908.03557). Instead of using a FasterRCNN to get image features, we use a ViT encoder.
50
- The task is text-only MLM (Masked Language Modeling). We mask only the text tokens and try to predict the masked tokens. The VisualBERT authors also use a sentence-image matching task where two captions are matched against an image, but we skip this for the sake of simplicity.
51
- ### Dataset
52
- The dataset we use for pre-training is a cleaned version of [Conceptual 12M](https://github.com/google-research-datasets/conceptual-12m). The dataset is downloaded and then broken images are removed which gives us about 10M images. Then we use the MBart50 `
53
- mbart-large-50-one-to-many-mmt` checkpoint to translate the dataset into four different languages - English, French, German, and Spanish, keeping 2.5 million examples of each language.
54
- """)
55
-
56
- st.markdown("""
57
- ### Model
58
- The model is shown in the image above.We create a custom model in Flax which integerates the ViT model inside BERT embeddings. We also use custom configs and modules in order to accomodate for these changes, and allow loading from BERT and ViT checkpoints. The image is fed to the ViT encoder and the text is fed to the word-embedding layers of BERT model. We use the `bert-base-multilingual-uncased` and `openai/clip-vit-base-patch32` checkpoints for BERT and ViT (actually CLIPVision) models, respectively. All our code is available on [GitHub](https://github.com/gchhablani/multilingual-vqa).
59
- ## Fine-tuning
60
-
61
- ### Dataset
62
- For fine-tuning, we use the [VQA 2.0](https://visualqa.org/) dataset - particularly, the `train` and `validation` sets. We translate all the questions into the four languages specified above using language-specific MarianMT models. This is because MarianMT models return better labels and are faster, hence, are better for fine-tuning. We get 4x the number of examples in each subset.
63
- ### Model
64
- We use the `SequenceClassification` model as reference to create our own sequence classification model. 3129 answer labels are chosen, as is the convention for the English VQA task, which can be found [here](https://github.com/gchhablani/multilingual-vqa/blob/main/answer_mapping.json). These are the same labels used in fine-tuning of the VisualBERT models. The outputs shown here have been translated using the [`mtranslate`](https://github.com/mouuff/mtranslate) Google Translate API library. Then we use various pre-trained checkpoints and train the sequence classification model for various steps.
65
-
66
- Checkpoints:
67
- - Pre-trained checkpoint: [multilingual-vqa](https://huggingface.co/flax-community/multilingual-vqa)
68
- - Fine-tuned on 45k pretrained checkpoint: [multilingual-vqa-pt-45k-ft](https://huggingface.co/flax-community/multilingual-vqa-pt-45k-ft)
69
- - Fine-tuned on 45k pretrained checkpoint with AdaFactor (others use AdamW): [multilingual-vqa-pt-45k-ft-adf](https://huggingface.co/flax-community/multilingual-vqa-pt-45k-ft-adf)
70
- - Fine-tuned on 60k pretrained checkpoint: [multilingual-vqa-pt-60k-ft](https://huggingface.co/flax-community/multilingual-vqa-pt-60k-ft)
71
- - Fine-tuned on 70k pretrained checkpoint: [multilingual-vqa-pt-60k-ft](https://huggingface.co/flax-community/multilingual-vqa-pt-70k-ft)
72
- - From scratch (without pre-training) model: [multilingual-vqa-ft](https://huggingface.co/flax-community/multilingual-vqa-ft)
73
-
74
- **Caveat**: The best fine-tuned model only achieves 0.49 accuracy on the multilingual validation data that we create. This could be because of not-so-great quality translations, sub-optimal hyperparameters and lack of ample training. In future, we hope to improve this model by addressing such concerns.
75
- """)
76
-
77
- with st.beta_expander("Cherry-Picked Results"):
78
- pass
79
-
80
- with st.beta_expander("Conclusion"):
81
- pass
82
 
83
  with st.beta_expander("Usage"):
84
- pass
85
 
86
  # Init Session State
87
  if state.image_file is None:
88
- state.image_file = dummy_data.loc[0,'image_file']
89
- state.question = dummy_data.loc[0,'question'].strip('- ')
90
- state.answer_label = dummy_data.loc[0,'answer_label']
91
- state.question_lang_id = dummy_data.loc[0, 'lang_id']
92
- state.answer_lang_id = dummy_data.loc[0, 'lang_id']
93
 
94
  image_path = os.path.join('images',state.image_file)
95
  image = plt.imread(image_path)
96
  state.image = image
97
 
98
- col1, col2 = st.beta_columns([5,5])
99
-
100
- # Display Image
101
- col1.image(state.image, use_column_width='always')
102
 
103
  if col2.button('Get a random example'):
104
  sample = dummy_data.sample(1).reset_index()
@@ -112,7 +100,7 @@ if col2.button('Get a random example'):
112
  image = plt.imread(image_path)
113
  state.image = image
114
 
115
- st.write("OR")
116
 
117
  uploaded_file = col2.file_uploader('Upload your image', type=['png','jpg','jpeg'])
118
  if uploaded_file is not None:
@@ -122,20 +110,22 @@ if uploaded_file is not None:
122
 
123
  transformed_image = get_transformed_image(state.image)
124
 
125
  # Display Question
126
- question = st.text_input(label="Question", value=state.question)
127
- st.markdown(f"""**English Translation**: {question if state.question_lang_id == "en" else translate(question, 'en')}""")
128
  question_inputs = get_text_attributes(question)
129
 
130
  # Select Language
131
  options = ['en', 'de', 'es', 'fr']
132
- state.answer_lang_id = st.selectbox('Answer Language', index=options.index(state.answer_lang_id), options=options)
133
  # Display Top-5 Predictions
134
  with st.spinner('Loading model...'):
135
  model = load_model(checkpoints[0])
136
  with st.spinner('Predicting...'):
137
- predictions = model(pixel_values = transformed_image, **question_inputs)
138
- logits = np.array(predictions[0][0])
139
  logits = softmax(logits)
140
  labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
141
  translated_labels = translate_labels(labels, state.answer_lang_id)
 
4
  import json
5
  import os
6
  import numpy as np
7
+ from streamlit.elements import markdown
8
  from model.flax_clip_vision_bert.modeling_clip_vision_bert import FlaxCLIPVisionBertForSequenceClassification
9
  from utils import get_transformed_image, get_text_attributes, get_top_5_predictions, plotly_express_horizontal_bar_plot, translate_labels
10
  import matplotlib.pyplot as plt
 
16
 
17
  state = _get_state()
18
 
19
+ @st.cache(persist=True)
20
  def load_model(ckpt):
21
  return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)
22
 
23
+ @st.cache(persist=True)
24
+ def predict(model, transformed_image, question_inputs):
25
+ return np.array(model(pixel_values = transformed_image, **question_inputs)[0][0])
26
+
27
  def softmax(logits):
28
  return np.exp(logits)/np.sum(np.exp(logits), axis=0)
29
 
30
+ def read_markdown(path, parent="./sections/"):
31
+ with open(os.path.join(parent,path)) as f:
32
+ return f.read()
33
+
34
  checkpoints = ['./ckpt/ckpt-60k-5999'] # TODO: Maybe add more checkpoints?
35
  dummy_data = pd.read_csv('dummy_vqa_multilingual.tsv', sep='\t')
36
+ code_to_name = {
37
+ "en": "English",
38
+ "fr": "French",
39
+ "de": "German",
40
+ "es": "Spanish",
41
+ }
42
+
43
  with open('answer_reverse_mapping.json') as f:
44
  answer_reverse_mapping = json.load(f)
45
 
 
52
  )
53
 
54
  st.title("Multilingual Visual Question Answering")
55
+ st.write("[Gunjan Chhablani](https://huggingface.co/gchhablani), [Bhavitvya Malik](https://huggingface.co/bhavitvyamalik)")
56
 
57
 
58
+ st.sidebar.write(read_markdown("about.md"))
59
+ st.sidebar.write(read_markdown("caveats.md"))
60
+ st.sidebar.write(read_markdown("challenges.md"))
61
+ st.sidebar.write(read_markdown("social_impact.md"))
62
+ st.sidebar.write(read_markdown("checkpoints.md"))
63
+ st.sidebar.write(read_markdown("acknowledgements.md"))
64
 
65
 
66
  with st.beta_expander("Usage"):
67
+ st.markdown(read_markdown("usage.md"))
68
+
69
+ with st.beta_expander("Method"):
70
+ st.image("./misc/Multilingual-VQA.png")
71
+ st.markdown(read_markdown("pretraining.md"))
72
+ st.markdown(read_markdown("finetuning.md"))
73
+
74
 
75
+
76
+ first_index = 20
77
  # Init Session State
78
  if state.image_file is None:
79
+ state.image_file = dummy_data.loc[first_index,'image_file']
80
+ state.question = dummy_data.loc[first_index,'question'].strip('- ')
81
+ state.answer_label = dummy_data.loc[first_index,'answer_label']
82
+ state.question_lang_id = dummy_data.loc[first_index, 'lang_id']
83
+ state.answer_lang_id = dummy_data.loc[first_index, 'lang_id']
84
 
85
  image_path = os.path.join('images',state.image_file)
86
  image = plt.imread(image_path)
87
  state.image = image
88
 
89
+ col1, col2 = st.beta_columns([6,4])
90
 
91
  if col2.button('Get a random example'):
92
  sample = dummy_data.sample(1).reset_index()
 
100
  image = plt.imread(image_path)
101
  state.image = image
102
 
103
+ col2.write("OR")
104
 
105
  uploaded_file = col2.file_uploader('Upload your image', type=['png','jpg','jpeg'])
106
  if uploaded_file is not None:
 
110
 
111
  transformed_image = get_transformed_image(state.image)
112
 
113
+ # Display Image
114
+ col1.image(state.image, use_column_width='always')
115
+
116
  # Display Question
117
+ question = col2.text_input(label="Question", value=state.question)
118
+ col2.markdown(f"""**English Translation**: {question if state.question_lang_id == "en" else translate(question, 'en')}""")
119
  question_inputs = get_text_attributes(question)
120
 
121
  # Select Language
122
  options = ['en', 'de', 'es', 'fr']
123
+ state.answer_lang_id = col2.selectbox('Answer Language', index=options.index(state.answer_lang_id), options=options, format_func = lambda x: code_to_name[x])
124
  # Display Top-5 Predictions
125
  with st.spinner('Loading model...'):
126
  model = load_model(checkpoints[0])
127
  with st.spinner('Predicting...'):
128
+ logits = predict(model, transformed_image, dict(question_inputs))
 
129
  logits = softmax(logits)
130
  labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
131
  translated_labels = translate_labels(labels, state.answer_lang_id)
sections/about.md ADDED
@@ -0,0 +1,2 @@
1
+ # About
2
+ This project is focused on Multilingual Visual Question Answering. Most of the existing datasets and models for this task work with English-only image-text pairs. Our intention here is to provide a proof of concept with our simple ViT+BERT model, which can be trained from multilingual text checkpoints and pre-trained image encoders and still perform well enough. Due to the lack of good-quality multilingual data, we translate subsets of the Conceptual 12M dataset into French, German and Spanish (the original captions are already in English) using the Marian models. We achieved 0.49 accuracy on the multilingual validation set we created. With better captions and hyperparameter tuning, we expect to see higher performance.
sections/acknowledgements.md ADDED
@@ -0,0 +1,2 @@
1
+ # Acknowledgements
2
+ We thank [Nilakshan Kunananthaseelan](https://huggingface.co/knilakshan20) for helping us whenever he got a chance. We also thank [Abheesht Sharma](https://huggingface.co/abheesht) for helping with discussions in the initial phases. Lastly, we are very grateful to [Luke Melas](https://github.com/lukemelas), who helped us get the CC-12M data onto our TPU-VMs.
sections/caveats.md ADDED
@@ -0,0 +1 @@
1
+ **Caveat**: The best fine-tuned model only achieves 0.49 accuracy on the multilingual validation data that we created. This could be due to lower-quality translations, sub-optimal hyperparameters, and a lack of sufficient training. In the future, we hope to improve this model by addressing these concerns.
sections/challenges.md ADDED
@@ -0,0 +1,14 @@
1
+ # Challenges and Technical Difficulties
2
+ We faced challenges at every step of the way, despite the 🤗 team having some example scripts and models ready in Flax.
3
+
4
+ - The dataset we used, Conceptual 12M, took 2-3 days to translate using MBart (since we didn't have Marian at the time). The major bottleneck was implementing the translation efficiently. We tried `mtranslate` first, but it turned out to be too slow, even with multiprocessing (a rough MBart50 translation sketch follows this list).
5
+
6
+ - Translations from deep learning models aren't as "perfect" as those from translation APIs like Google and Yandex, which could lead to poorer performance.
7
+
8
+ - We prepared the model and config classes from scratch, basing them on the `ViT` and `BERT` implementations in Flax. The ViT embeddings had to be used inside the BERT embeddings class, which was the major challenge here.
9
+
10
+ - We prepared training scripts for text-only MLM on image-text pairs and for sequence classification, based on the hybrid CLIP, masked LM, and text classification examples.
11
+
12
+ - We were only able to get around 1.5 days of training time on TPUs due to the above-mentioned challenges, and we were unable to perform hyperparameter tuning. The loss curves of the pre-trained model show that training hasn't converged, and we could see further improvement in MLM accuracy.
13
+
14
+ - The VQA dataset, despite having many examples (and 4x as many after translation), is still small, and the model overfits. To address this, we need more multilingual data and lighter models, both of which are major challenges right now.
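As a rough illustration of the MBart50 translation step mentioned above, the sketch below uses the `mbart-large-50-one-to-many-mmt` checkpoint via `transformers`, following its standard documented usage; the project's actual batching, data loading, and language handling are not shown here and remain assumptions.

```python
# Rough sketch of translating a caption with the MBart50 one-to-many checkpoint;
# not the project's actual translation script.
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration

model_name = "facebook/mbart-large-50-one-to-many-mmt"
tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang="en_XX")
model = MBartForConditionalGeneration.from_pretrained(model_name)

captions = ["a dog playing in a park"]
inputs = tokenizer(captions, return_tensors="pt", padding=True)
# Force the decoder to start with the French language code to translate en -> fr.
generated = model.generate(**inputs, forced_bos_token_id=tokenizer.lang_code_to_id["fr_XX"])
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```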
sections/checkpoints.md ADDED
@@ -0,0 +1,7 @@
1
+ **Checkpoints**:
2
+ - Pre-trained checkpoint: [multilingual-vqa](https://huggingface.co/flax-community/multilingual-vqa)
3
+ - Fine-tuned on 45k pretrained checkpoint: [multilingual-vqa-pt-45k-ft](https://huggingface.co/flax-community/multilingual-vqa-pt-45k-ft)
4
+ - Fine-tuned on 45k pretrained checkpoint with AdaFactor (others use AdamW): [multilingual-vqa-pt-45k-ft-adf](https://huggingface.co/flax-community/multilingual-vqa-pt-45k-ft-adf)
5
+ - Fine-tuned on 60k pretrained checkpoint: [multilingual-vqa-pt-60k-ft](https://huggingface.co/flax-community/multilingual-vqa-pt-60k-ft)
6
+ - Fine-tuned on 70k pretrained checkpoint: [multilingual-vqa-pt-70k-ft](https://huggingface.co/flax-community/multilingual-vqa-pt-70k-ft)
7
+ - From scratch (without pre-training) model: [multilingual-vqa-ft](https://huggingface.co/flax-community/multilingual-vqa-ft)
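Any of the fine-tuned checkpoints above can, in principle, be loaded with the custom class used in `app.py`; a minimal sketch, assuming the `model` package from this repository is importable and the Hub repo contains a compatible Flax checkpoint:

```python
# Minimal sketch: load one of the fine-tuned checkpoints listed above with the
# project's custom Flax class (import path taken from app.py in this repo).
from model.flax_clip_vision_bert.modeling_clip_vision_bert import (
    FlaxCLIPVisionBertForSequenceClassification,
)

model = FlaxCLIPVisionBertForSequenceClassification.from_pretrained(
    "flax-community/multilingual-vqa-pt-60k-ft"
)
```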
sections/finetuning.md ADDED
@@ -0,0 +1,5 @@
1
+ ## Fine-tuning
2
+ ### Dataset
3
+ For fine-tuning, we use the [VQA 2.0](https://visualqa.org/) dataset - particularly, the `train` and `validation` sets. We translate all the questions into the four languages specified above using language-specific MarianMT models, since these produce better translations and are faster, making them better suited for fine-tuning. This gives us 4x the number of examples in each subset.
4
+ ### Model
5
+ We use the `SequenceClassification` model as a reference to create our own sequence classification model, in which a classification layer is attached on top of the pre-trained BERT model in order to perform multi-class classification. 3129 answer labels are chosen, as is the convention for the English VQA task; the mapping can be found [here](https://github.com/gchhablani/multilingual-vqa/blob/main/answer_mapping.json). These are the same labels used in fine-tuning the VisualBERT models. The outputs shown here have been translated using the [`mtranslate`](https://github.com/mouuff/mtranslate) Google Translate API library. We then take various pre-trained checkpoints and train the sequence classification model for a varying number of steps.
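To make the classification setup concrete, here is a heavily simplified, hypothetical sketch of a 3129-way head over a pooled multimodal representation in Flax; the module name, dropout rate, and layer layout are illustrative assumptions, not the internals of the repository's `FlaxCLIPVisionBertForSequenceClassification`.

```python
# Hypothetical sketch of a sequence-classification head; 3129 matches the VQA
# answer vocabulary described above.
import flax.linen as nn


class AnswerClassificationHead(nn.Module):
    num_labels: int = 3129
    dropout_rate: float = 0.1  # assumed value, not taken from the repo

    @nn.compact
    def __call__(self, pooled_output, deterministic: bool = True):
        x = nn.Dropout(rate=self.dropout_rate)(pooled_output, deterministic=deterministic)
        return nn.Dense(features=self.num_labels)(x)  # logits over the answer labels
```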
sections/pretraining.md ADDED
@@ -0,0 +1,6 @@
1
+ ## Pretraining
2
+ We follow an approach similar to [VisualBERT](https://arxiv.org/abs/1908.03557). Instead of using a FasterRCNN to get image features, we use a ViT encoder. The pre-training task is text-only MLM (Masked Language Modeling). We mask only the text tokens and try to predict the masked tokens. The VisualBERT authors also use a sentence-image matching task where two captions are matched against an image, but we skip this for the sake of simplicity.
3
+ ### Dataset
4
+ The dataset we use for pre-training is a cleaned version of [Conceptual 12M](https://github.com/google-research-datasets/conceptual-12m). The dataset is downloaded and then broken images are removed, which gives us about 10M images. We then use the MBart50 `mbart-large-50-one-to-many-mmt` checkpoint to translate the dataset into four different languages - English, French, German, and Spanish, keeping 2.5 million examples of each language.
5
+ ### Model
6
+ The model is shown in the image above. The `Dummy MLM Head` is actually combined with the MLM head, but it never contributes to the MLM loss, hence the name (the predictions on these tokens are ignored). We create a custom model in Flax which integrates the ViT model inside the BERT embeddings. We also use custom configs and modules in order to accommodate these changes and allow loading from BERT and ViT checkpoints. The image is fed to the ViT encoder and the text is fed to the word-embedding layers of the BERT model. We use the `bert-base-multilingual-uncased` and `openai/clip-vit-base-patch32` checkpoints for the BERT and ViT (actually CLIPVision) models, respectively. All our code is available on [GitHub](https://github.com/gchhablani/multilingual-vqa).
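To make the "ViT inside the BERT embeddings" idea concrete, below is a conceptual sketch of how projected visual features can be concatenated with word embeddings before the BERT encoder; the function and parameter names are illustrative and are not taken from the repository's modules.

```python
# Conceptual sketch only: concatenate projected CLIP-ViT patch embeddings with
# BERT word embeddings along the sequence axis so the encoder attends to both.
import jax.numpy as jnp


def combine_embeddings(word_embeds, patch_embeds, visual_projection):
    """word_embeds: (batch, text_len, hidden); patch_embeds: (batch, num_patches, vision_hidden)."""
    visual_embeds = visual_projection(patch_embeds)  # project to BERT's hidden size
    return jnp.concatenate([word_embeds, visual_embeds], axis=1)
```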
sections/references.md ADDED
@@ -0,0 +1 @@
1
+ # References
sections/social_impact.md ADDED
@@ -0,0 +1,2 @@
1
+ # Social Impact
2
+ Multilingual Visual Question Answering has not received a lot of attention. There are very few multilingual VQA datasets, and that is what we wanted to address here. Our initial plan was to include 4 high-resource and 4 low-resource languages in our training data. However, existing translation models do not perform as well for such languages, so we would have received poor labels and needed longer training time. We hope to improve this in the future by using better translators (e.g., the Google Translate API) to get more multilingual data, especially for low-resource languages. Regardless, our aim with this project was to provide a pipeline approach for multilingual visuo-linguistic pretraining and multilingual Visual Question Answering.
sections/usage.md ADDED
@@ -0,0 +1,13 @@
1
+ - This demo loads the `FlaxCLIPVisionBertForSequenceClassification` model present in the `model` directory of this repository. The checkpoint is loaded from `ckpt/ckpt-60k-5999`, which was pre-trained for 60k steps and fine-tuned for 5999 steps. 100 random examples from `dummy_vqa_multilingual.tsv` are provided, with their respective images in the `images/val2014` directory.
2
+
3
+ - You can also upload your own image using the `Upload your image` file uploader and type in a question of your choosing.
4
+
5
+ - We provide an `English Translation` of the question for users who are not acquainted with the other languages. This is done using `mtranslate` to keep things flexible; it needs an internet connection, as it uses the Google Translate API.
6
+
7
+ - The model predicts the answer from a list of 3129 answers, whose labels are present in `answer_reverse_mapping.json`.
8
+
9
+ - Lastly, one can choose the `Answer Language`; the translated answer options come from a saved dictionary created with the `mtranslate` library for the 3129 answers.
10
+
11
+ - The top-5 predictions are displayed below, and their respective confidence scores are shown in the form of a bar plot (a short sketch of this post-processing step follows this list).
12
+
13
+
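The post-processing described in the last bullet (softmax over the 3129 logits, then a top-5 lookup in `answer_reverse_mapping.json`) can be sketched as follows; treating the mapping's keys as stringified label indices is an assumption based on the JSON format.

```python
# Sketch of the demo's top-5 step: the softmax from app.py followed by an
# argsort, mapping label indices back to answer strings.
import numpy as np


def top_5_answers(logits, answer_reverse_mapping):
    probs = np.exp(logits) / np.sum(np.exp(logits), axis=0)
    top = np.argsort(probs)[::-1][:5]
    return [(answer_reverse_mapping[str(i)], float(probs[i])) for i in top]
```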