gchhablani committed
Commit e289356
1 Parent(s): 7a89f67
Files changed (3)
  1. app.py +41 -21
  2. translate_answer_mapping.py +4 -3
  3. utils.py +6 -5
app.py CHANGED
@@ -1,26 +1,26 @@
-from io import BytesIO
-import streamlit as st
-import pandas as pd
 import json
 import os
+from io import BytesIO
+
+import matplotlib.pyplot as plt
 import numpy as np
-from streamlit.elements import markdown
+import pandas as pd
+import streamlit as st
+from mtranslate import translate
 from PIL import Image
+from streamlit.elements import markdown
+
 from model.flax_clip_vision_bert.modeling_clip_vision_bert import (
     FlaxCLIPVisionBertForSequenceClassification,
 )
+from session import _get_state
 from utils import (
-    get_transformed_image,
     get_text_attributes,
     get_top_5_predictions,
+    get_transformed_image,
     plotly_express_horizontal_bar_plot,
     translate_labels,
 )
-import matplotlib.pyplot as plt
-from mtranslate import translate
-
-
-from session import _get_state
 
 state = _get_state()
 
@@ -74,9 +74,9 @@ st.write(
     "[Gunjan Chhablani](https://huggingface.co/gchhablani), [Bhavitvya Malik](https://huggingface.co/bhavitvyamalik)"
 )
 
-image_col, intro_col = st.beta_columns([3,8])
-image_col.image("./misc/mvqa-logo-white.png", use_column_width='always')
-intro_col.write(read_markdown('intro.md'))
+image_col, intro_col = st.beta_columns([3, 8])
+image_col.image("./misc/mvqa-logo-white.png", use_column_width="always")
+intro_col.write(read_markdown("intro.md"))
 with st.beta_expander("Usage"):
     st.write(read_markdown("usage.md"))
 
@@ -85,7 +85,8 @@ with st.beta_expander("Article"):
     st.write(read_markdown("caveats.md"))
     st.write("## Methodology")
     st.image(
-        "./misc/Multilingual-VQA.png", caption="Masked LM model for Image-text Pretraining."
+        "./misc/Multilingual-VQA.png",
+        caption="Masked LM model for Image-text Pretraining.",
     )
     st.markdown(read_markdown("pretraining.md"))
     st.markdown(read_markdown("finetuning.md"))
@@ -110,7 +111,10 @@ if state.image_file is None:
 
 col1, col2 = st.beta_columns([6, 4])
 
-if col2.button("Get a random example", help="Get a random example from the 100 `seeded` image-text pairs."):
+if col2.button(
+    "Get a random example",
+    help="Get a random example from the 100 `seeded` image-text pairs.",
+):
     sample = dummy_data.sample(1).reset_index()
     state.image_file = sample.loc[0, "image_file"]
     state.question = sample.loc[0, "question"].strip("- ")
@@ -124,9 +128,15 @@ if col2.button("Get a random example", help="Get a random example from the 100 `seeded` image-text pairs."):
 
 col2.write("OR")
 
-uploaded_file = col2.file_uploader("Upload your image", type=["png", "jpg", "jpeg"], help="Upload a file of your choosing.")
+uploaded_file = col2.file_uploader(
+    "Upload your image",
+    type=["png", "jpg", "jpeg"],
+    help="Upload a file of your choosing.",
+)
 if uploaded_file is not None:
-    st.error("Uploading files does not work on HuggingFace spaces. This app only supports random examples for now.")
+    st.error(
+        "Uploading files does not work on HuggingFace spaces. This app only supports random examples for now."
+    )
     # state.image_file = os.path.join("images/val2014", uploaded_file.name)
     # state.image = np.array(Image.open(uploaded_file))
 
@@ -135,9 +145,13 @@ transformed_image = get_transformed_image(state.image)
 # Display Image
 col1.image(state.image, use_column_width="auto")
 
-new_col1, new_col2 = st.beta_columns([5,5])
+new_col1, new_col2 = st.beta_columns([5, 5])
 # Display Question
-question = new_col1.text_input(label="Question", value=state.question, help="Type your question regarding the image above in one of the four languages.")
+question = new_col1.text_input(
+    label="Question",
+    value=state.question,
+    help="Type your question regarding the image above in one of the four languages.",
+)
 new_col1.markdown(
     f"""**English Translation**: {question if state.question_lang_id == "en" else translate(question, 'en')}"""
 )
@@ -151,11 +165,17 @@ state.answer_lang_id = new_col2.selectbox(
     index=options.index(state.answer_lang_id),
     options=options,
     format_func=lambda x: code_to_name[x],
-    help="The language to be used to show the top-5 labels."
+    help="The language to be used to show the top-5 labels.",
 )
 
 actual_answer = answer_reverse_mapping[str(state.answer_label)]
-new_col2.markdown("**Actual Answer**: " + translate_labels([actual_answer], state.answer_lang_id)[0]+" ("+actual_answer+")")
+new_col2.markdown(
+    "**Actual Answer**: "
+    + translate_labels([actual_answer], state.answer_lang_id)[0]
+    + " ("
+    + actual_answer
+    + ")"
+)
 
 # Display Top-5 Predictions
 with st.spinner("Loading model..."):
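The `session` import regrouped above provides the `state` object that app.py threads through every widget. A minimal sketch of the pattern, assuming `_get_state()` follows the community SessionState recipe (attributes persist across Streamlit reruns and read as `None` until set; the placeholder path below is hypothetical):

```python
# Minimal sketch of the session-state pattern app.py relies on
# (assumption: _get_state() returns an object whose attributes persist
# across Streamlit reruns and read as None until set).
import streamlit as st

from session import _get_state

state = _get_state()

if state.image_file is None:  # first run: nothing cached yet
    state.image_file = "images/val2014/sample.jpg"  # hypothetical placeholder

st.write(state.image_file)  # later reruns see the cached value
```

This mirrors the `if state.image_file is None:` guard visible in the hunk context above.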
translate_answer_mapping.py CHANGED
@@ -1,9 +1,10 @@
-from mtranslate.core import translate
 import json
-from tqdm import tqdm
-import ray
 from asyncio import Event
+
+import ray
+from mtranslate.core import translate
 from ray.actor import ActorHandle
+from tqdm import tqdm
 
 ray.init()
 from typing import Tuple
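The regrouped imports here (`ray`, `ActorHandle`, `asyncio.Event`, `tqdm`, `mtranslate.core.translate`) suggest the answer labels are translated by fanning work out as Ray tasks behind a progress bar. A minimal sketch of that fan-out under those assumptions; the sample labels and target language are illustrative, and the actor-based progress reporting the `ActorHandle`/`Event` imports hint at is omitted:

```python
# Minimal sketch of the Ray fan-out these imports suggest
# (assumption: each answer label becomes one Ray task; the labels
# and target language below are illustrative only).
import ray
from mtranslate.core import translate
from tqdm import tqdm

ray.init()


@ray.remote
def translate_one(text: str, lang: str) -> str:
    # mtranslate wraps Google Translate: translate(text, to_language)
    return translate(text, lang)


answers = ["yes", "no", "two"]  # hypothetical labels
refs = [translate_one.remote(answer, "fr") for answer in answers]
print([ray.get(ref) for ref in tqdm(refs)])
```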
utils.py CHANGED
@@ -1,12 +1,13 @@
-from torchvision.io import read_image, ImageReadMode
-import torch
+import json
+
 import numpy as np
+import plotly.express as px
+import torch
+from PIL import Image
+from torchvision.io import ImageReadMode, read_image
 from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
 from torchvision.transforms.functional import InterpolationMode
 from transformers import BertTokenizerFast
-import plotly.express as px
-import json
-from PIL import Image
 
 
 class Transform(torch.nn.Module):
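The regrouped `torchvision` imports feed the `Transform` module whose definition opens right after this hunk. A minimal sketch of what such a module typically looks like, assuming the standard CLIP-style preprocessing pipeline (the normalization constants are CLIP's published channel means and stds; treat the exact body as an assumption rather than the file's actual code):

```python
# Minimal sketch of a preprocessing module matching the imports above
# (assumption: standard CLIP-style resize, center-crop, dtype
# conversion, and normalization).
import torch
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode


class Transform(torch.nn.Module):
    def __init__(self, image_size: int):
        super().__init__()
        self.transforms = torch.nn.Sequential(
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),  # uint8 [0, 255] -> float [0, 1]
            Normalize(
                (0.48145466, 0.4578275, 0.40821073),  # CLIP channel means
                (0.26862954, 0.26130258, 0.27577711),  # CLIP channel stds
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return self.transforms(x)
```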