Spaces:
Runtime error
Runtime error
gchhablani
commited on
Commit
·
f384719
1
Parent(s):
571a3f6
Add auto scaling image
Browse files- apps/mlm.py +18 -6
- apps/vqa.py +1 -1
apps/mlm.py
CHANGED
@@ -50,8 +50,11 @@ def app(state):
|
|
50 |
if mlm_state.mlm_image_file is None:
|
51 |
mlm_state.mlm_image_file = dummy_data.loc[first_index, "image_file"]
|
52 |
caption = dummy_data.loc[first_index, "caption"].strip("- ")
|
|
|
53 |
ids = bert_tokenizer.encode(caption)
|
54 |
-
|
|
|
|
|
55 |
mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
|
56 |
mlm_state.caption_lang_id = dummy_data.loc[first_index, "lang_id"]
|
57 |
|
@@ -72,8 +75,11 @@ def app(state):
|
|
72 |
sample = dummy_data.sample(1).reset_index()
|
73 |
mlm_state.mlm_image_file = sample.loc[0, "image_file"]
|
74 |
caption = sample.loc[0, "caption"].strip("- ")
|
|
|
75 |
ids = bert_tokenizer.encode(caption)
|
76 |
-
|
|
|
|
|
77 |
mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
|
78 |
mlm_state.caption_lang_id = sample.loc[0, "lang_id"]
|
79 |
|
@@ -99,7 +105,7 @@ def app(state):
|
|
99 |
new_col1, new_col2 = st.beta_columns([5, 5])
|
100 |
|
101 |
# Display Image
|
102 |
-
new_col1.image(mlm_state.mlm_image, use_column_width="
|
103 |
|
104 |
# Display caption
|
105 |
new_col2.write("Write your text with exactly one [MASK] token.")
|
@@ -109,9 +115,14 @@ def app(state):
|
|
109 |
help="Type your masked caption regarding the image above in one of the four languages.",
|
110 |
)
|
111 |
|
112 |
-
|
113 |
-
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
115 |
caption_inputs = get_text_attributes(caption)
|
116 |
|
117 |
# Display Top-5 Predictions
|
@@ -119,6 +130,7 @@ def app(state):
|
|
119 |
scores = predict(transformed_image, dict(caption_inputs))
|
120 |
scores = softmax(scores)
|
121 |
labels, values = get_top_5_predictions(scores)
|
|
|
122 |
# newer_col1, newer_col2 = st.beta_columns([6,4])
|
123 |
fig = plotly_express_horizontal_bar_plot(values, labels)
|
124 |
st.dataframe(pd.DataFrame({"Tokens":labels, "English Translation": list(map(lambda x: translate(x),labels))}).T)
|
|
|
50 |
if mlm_state.mlm_image_file is None:
|
51 |
mlm_state.mlm_image_file = dummy_data.loc[first_index, "image_file"]
|
52 |
caption = dummy_data.loc[first_index, "caption"].strip("- ")
|
53 |
+
mlm_state.unmasked_caption = caption
|
54 |
ids = bert_tokenizer.encode(caption)
|
55 |
+
mask_index = np.random.randint(1, len(ids) - 1)
|
56 |
+
mlm_state.currently_masked_token = ids[mask_index]
|
57 |
+
ids[mask_index] = bert_tokenizer.mask_token_id
|
58 |
mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
|
59 |
mlm_state.caption_lang_id = dummy_data.loc[first_index, "lang_id"]
|
60 |
|
|
|
75 |
sample = dummy_data.sample(1).reset_index()
|
76 |
mlm_state.mlm_image_file = sample.loc[0, "image_file"]
|
77 |
caption = sample.loc[0, "caption"].strip("- ")
|
78 |
+
mlm_state.unmasked_caption = caption
|
79 |
ids = bert_tokenizer.encode(caption)
|
80 |
+
mask_index = np.random.randint(1, len(ids) - 1)
|
81 |
+
mlm_state.currently_masked_token = ids[mask_index]
|
82 |
+
ids[mask_index] = bert_tokenizer.mask_token_id
|
83 |
mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
|
84 |
mlm_state.caption_lang_id = sample.loc[0, "lang_id"]
|
85 |
|
|
|
105 |
new_col1, new_col2 = st.beta_columns([5, 5])
|
106 |
|
107 |
# Display Image
|
108 |
+
new_col1.image(mlm_state.mlm_image, use_column_width="auto")
|
109 |
|
110 |
# Display caption
|
111 |
new_col2.write("Write your text with exactly one [MASK] token.")
|
|
|
115 |
help="Type your masked caption regarding the image above in one of the four languages.",
|
116 |
)
|
117 |
|
118 |
+
if caption == mlm_state.caption:
|
119 |
+
new_col2.markdown("**Masked Token**: "+mlm_state.currently_masked_token)
|
120 |
+
new_col2.markdown("**English Translation: " + mlm_state.unmasked_caption if mlm_state.caption_lang_id == "en" else translate(mlm_state.unmasked_caption, 'en'))
|
121 |
+
|
122 |
+
else:
|
123 |
+
new_col2.markdown(
|
124 |
+
f"""**English Translation**: {caption if mlm_state.caption_lang_id == "en" else translate(caption, 'en')}"""
|
125 |
+
)
|
126 |
caption_inputs = get_text_attributes(caption)
|
127 |
|
128 |
# Display Top-5 Predictions
|
|
|
130 |
scores = predict(transformed_image, dict(caption_inputs))
|
131 |
scores = softmax(scores)
|
132 |
labels, values = get_top_5_predictions(scores)
|
133 |
+
print(labels)
|
134 |
# newer_col1, newer_col2 = st.beta_columns([6,4])
|
135 |
fig = plotly_express_horizontal_bar_plot(values, labels)
|
136 |
st.dataframe(pd.DataFrame({"Tokens":labels, "English Translation": list(map(lambda x: translate(x),labels))}).T)
|
apps/vqa.py
CHANGED
@@ -109,7 +109,7 @@ def app(state):
|
|
109 |
new_col1, new_col2 = st.beta_columns([5, 5])
|
110 |
|
111 |
# Display Image
|
112 |
-
new_col1.image(vqa_state.vqa_image, use_column_width="
|
113 |
|
114 |
# Display Question
|
115 |
question = new_col2.text_input(
|
|
|
109 |
new_col1, new_col2 = st.beta_columns([5, 5])
|
110 |
|
111 |
# Display Image
|
112 |
+
new_col1.image(vqa_state.vqa_image, use_column_width="auto")
|
113 |
|
114 |
# Display Question
|
115 |
question = new_col2.text_input(
|