bhavitvyamalik committed on
Commit
54dc7b4
1 Parent(s): 8ab0ff2
app.py CHANGED
@@ -1,187 +1,200 @@
1
- from io import BytesIO
2
  import streamlit as st
3
- import pandas as pd
4
- import json
5
- import os
6
- import numpy as np
7
- from streamlit import caching
8
- from PIL import Image
9
- from model.flax_clip_vision_mbart.modeling_clip_vision_mbart import (
10
- FlaxCLIPVisionMBartForConditionalGeneration,
11
- )
12
- from transformers import MBart50TokenizerFast
13
- from utils import (
14
- get_transformed_image,
15
- )
16
- import matplotlib.pyplot as plt
17
- from mtranslate import translate
18
-
19
-
20
  from session import _get_state
21
-
22
- state = _get_state()
23
-
24
-
25
- @st.cache
26
- def load_model(ckpt):
27
- return FlaxCLIPVisionMBartForConditionalGeneration.from_pretrained(ckpt)
28
-
29
-
30
- tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
31
-
32
- language_mapping = {
33
- "en": "en_XX",
34
- "de": "de_DE",
35
- "fr": "fr_XX",
36
- "es": "es_XX"
37
- }
38
-
39
- code_to_name = {
40
- "en": "English",
41
- "fr": "French",
42
- "de": "German",
43
- "es": "Spanish",
44
- }
45
-
46
- @st.cache
47
- def generate_sequence(pixel_values, lang_code, num_beams, temperature, top_p, do_sample, top_k, max_length):
48
- lang_code = language_mapping[lang_code]
49
- output_ids = state.model.generate(input_ids=pixel_values, forced_bos_token_id=tokenizer.lang_code_to_id[lang_code], max_length=max_length, num_beams=num_beams, temperature=temperature, top_p = top_p, top_k=top_k, do_sample=do_sample)
50
- print(output_ids)
51
- output_sequence = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True, max_length=max_length)
52
- return output_sequence
53
-
54
- def read_markdown(path, parent="./sections/"):
55
- with open(os.path.join(parent, path)) as f:
56
- return f.read()
57
-
58
-
59
- checkpoints = ["./ckpt/ckpt-51999"] # TODO: Maybe add more checkpoints?
60
- dummy_data = pd.read_csv("reference.tsv", sep="\t")
61
-
62
- st.set_page_config(
63
- page_title="Multilingual Image Captioning",
64
- layout="wide",
65
- initial_sidebar_state="collapsed",
66
- page_icon="./misc/mic-logo.png",
67
- )
68
-
69
- st.title("Multilingual Image Captioning")
70
- st.write(
71
- "[Bhavitvya Malik](https://huggingface.co/bhavitvyamalik), [Gunjan Chhablani](https://huggingface.co/gchhablani)"
72
- )
73
-
74
- st.sidebar.title("Generation Parameters")
75
- # max_length = st.sidebar.number_input("Max Length", min_value=16, max_value=128, value=64, step=1, help="The maximum length of sequence to be generated.")
76
- max_length = 64
77
- do_sample = st.sidebar.checkbox("Sample", value=False, help="Sample from the model instead of using beam search.")
78
- top_k = st.sidebar.number_input("Top K", min_value=10, max_value=200, value=50, step=1, help="The number of highest probability vocabulary tokens to keep for top-k-filtering.")
79
- num_beams = st.sidebar.number_input(label="Number of Beams", min_value=2, max_value=10, value=4, step=1, help="Number of beams to be used in beam search.")
80
- temperature = st.sidebar.select_slider(label="Temperature", options = list(np.arange(0.0,1.1, step=0.1)), value=1.0, help ="The value used to module the next token probabilities.", format_func=lambda x: f"{x:.2f}")
81
- top_p = st.sidebar.select_slider(label = "Top-P", options = list(np.arange(0.0,1.1, step=0.1)),value=1.0, help="Nucleus Sampling : If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are kept for generation.", format_func=lambda x: f"{x:.2f}")
82
- if st.sidebar.button("Clear All Cache"):
83
- caching.clear_cache()
84
- image_col, intro_col = st.beta_columns([3, 8])
85
- image_col.image("./misc/mic-logo.png", use_column_width="always")
86
- intro_col.write(read_markdown("intro.md"))
87
-
88
- with st.beta_expander("Usage"):
89
- st.markdown(read_markdown("usage.md"))
90
-
91
- with st.beta_expander("Article"):
92
- st.write(read_markdown("abstract.md"))
93
- st.write(read_markdown("caveats.md"))
94
- st.write("## Methodology")
95
- st.image(
96
- "./misc/Multilingual-IC.png"
97
- )
98
- st.markdown(read_markdown("pretraining.md"))
99
- st.write(read_markdown("challenges.md"))
100
- st.write(read_markdown("social_impact.md"))
101
- st.write(read_markdown("bias.md"))
102
-
103
- col1, col2, col3, col4 = st.beta_columns([0.5,2.5,2.5,0.5])
104
- with col2:
105
- st.image("./misc/examples/female_dev_1.jpg", width=350, caption = 'German Caption: <PERSON> arbeitet an einem Computer.', use_column_width='always')
106
- with col3:
107
- st.image("./misc/examples/female_doctor.jpg", width=350, caption = 'English Caption: A portrait of <PERSON>, a doctor who specializes in health care.', use_column_width='always')
108
-
109
- col1, col2, col3, col4 = st.beta_columns([0.5,2.5,2.5,0.5])
110
- with col2:
111
- st.image("./misc/examples/female_doctor_1.jpg", width=350, caption = 'Spanish Caption: El Dr. <PERSON> es un estudiante de posgrado.', use_column_width='always')
112
- with col3:
113
- st.image("./misc/examples/women_cricket.jpg", width=350, caption = 'English Caption: <PERSON> of India bats against <PERSON> of Australia during the first Twenty20 match between India and Australia at Indian Bowl Stadium in New Delhi on Friday. - PTI', use_column_width='always')
114
-
115
- col1, col2, col3, col4 = st.beta_columns([0.5,2.5,2.5,0.5])
116
- with col2:
117
- st.image("./misc/examples/female_dev_2.jpg", width=350, caption = "French Caption: Un écran d'ordinateur avec un écran d'ordinateur ouvert.", use_column_width='always')
118
- with col3:
119
- st.image("./misc/examples/female_biker_resized.jpg", width=350, caption = 'German Caption: <PERSON> auf dem Motorrad von <PERSON>.', use_column_width='always')
120
-
121
- st.write(read_markdown("future_scope.md"))
122
- st.write(read_markdown("references.md"))
123
- # st.write(read_markdown("checkpoints.md"))
124
- st.write(read_markdown("acknowledgements.md"))
125
-
126
- if state.model is None:
127
- with st.spinner("Loading model..."):
128
- state.model = load_model(checkpoints[0])
129
-
130
- first_index = 25
131
- # Init Session State
132
- if state.image_file is None:
133
- state.image_file = dummy_data.loc[first_index, "image_file"]
134
- state.caption = dummy_data.loc[first_index, "caption"].strip("- ")
135
- state.lang_id = dummy_data.loc[first_index, "lang_id"]
136
-
137
- image_path = os.path.join("images", state.image_file)
138
- image = plt.imread(image_path)
139
- state.image = image
140
-
141
- if st.button("Get a random example", help="Get a random example from one of the seeded examples."):
142
- sample = dummy_data.sample(1).reset_index()
143
- state.image_file = sample.loc[0, "image_file"]
144
- state.caption = sample.loc[0, "caption"].strip("- ")
145
- state.lang_id = sample.loc[0, "lang_id"]
146
-
147
- image_path = os.path.join("images", state.image_file)
148
- image = plt.imread(image_path)
149
- state.image = image
150
-
151
- transformed_image = get_transformed_image(state.image)
152
-
153
- new_col1, new_col2 = st.beta_columns([5,5])
154
-
155
- # Display Image
156
- new_col1.image(state.image, use_column_width="always")
157
- # Display Reference Caption
158
- with new_col1.beta_expander("Reference Caption"):
159
- st.write("**Reference Caption**: " + state.caption)
160
- st.markdown(
161
- f"""**English Translation**: {state.caption if state.lang_id == "en" else translate(state.caption, 'en')}"""
162
  )
163
 
164
- # Select Language
165
- options = list(code_to_name.keys())
166
- lang_id = new_col2.selectbox(
167
- "Language",
168
- index=options.index(state.lang_id),
169
- options=options,
170
- format_func=lambda x: code_to_name[x],
171
- help="The language in which caption is to be generated."
172
- )
173
-
174
- sequence = ['']
175
- if new_col2.button("Generate Caption", help="Generate a caption in the specified language."):
176
- with st.spinner("Generating Sequence..."):
177
- sequence = generate_sequence(transformed_image, lang_id, num_beams, temperature, top_p, do_sample, top_k, max_length)
178
- # print(sequence)
179
-
180
- if sequence!=['']:
181
- new_col2.write(
182
- "**Generated Caption**: "+sequence[0]
183
  )
184
 
185
- new_col2.write(
186
- "**English Translation**: "+ sequence[0] if lang_id=="en" else translate(sequence[0])
187
- )
1
+ from apps import article, mic
2
  import streamlit as st
3
  from session import _get_state
4
+ from multiapp import MultiApp
5
+
6
+ # from io import BytesIO
7
+ # from apps.utils import read_markdown
8
+ # from apps import article
9
+ # import streamlit as st
10
+ # import pandas as pd
11
+ # import os
12
+ # import numpy as np
13
+ # from streamlit import caching
14
+ # from PIL import Image
15
+ # from model.flax_clip_vision_mbart.modeling_clip_vision_mbart import (
16
+ # FlaxCLIPVisionMBartForConditionalGeneration,
17
+ # )
18
+ # import matplotlib.pyplot as plt
19
+ # from mtranslate import translate
20
+
21
+
22
+ # from session import _get_state
23
+
24
+ # state = _get_state()
25
+
26
+
27
+ # @st.cache
28
+ # def load_model(ckpt):
29
+ # return FlaxCLIPVisionMBartForConditionalGeneration.from_pretrained(ckpt)
30
+
31
+
32
+ # tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
33
+
34
+ # language_mapping = {
35
+ # "en": "en_XX",
36
+ # "de": "de_DE",
37
+ # "fr": "fr_XX",
38
+ # "es": "es_XX"
39
+ # }
40
+
41
+ # code_to_name = {
42
+ # "en": "English",
43
+ # "fr": "French",
44
+ # "de": "German",
45
+ # "es": "Spanish",
46
+ # }
47
+
48
+ # @st.cache
49
+ # def generate_sequence(pixel_values, lang_code, num_beams, temperature, top_p, do_sample, top_k, max_length):
50
+ # lang_code = language_mapping[lang_code]
51
+ # output_ids = state.model.generate(input_ids=pixel_values, forced_bos_token_id=tokenizer.lang_code_to_id[lang_code], max_length=max_length, num_beams=num_beams, temperature=temperature, top_p = top_p, top_k=top_k, do_sample=do_sample)
52
+ # print(output_ids)
53
+ # output_sequence = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True, max_length=max_length)
54
+ # return output_sequence
55
+
56
+
57
+ # checkpoints = ["./ckpt/ckpt-51999"] # TODO: Maybe add more checkpoints?
58
+ # dummy_data = pd.read_csv("reference.tsv", sep="\t")
59
+
60
+
61
+ # st.sidebar.title("Generation Parameters")
62
+ # # max_length = st.sidebar.number_input("Max Length", min_value=16, max_value=128, value=64, step=1, help="The maximum length of sequence to be generated.")
63
+ # max_length = 64
64
+ # do_sample = st.sidebar.checkbox("Sample", value=False, help="Sample from the model instead of using beam search.")
65
+ # top_k = st.sidebar.number_input("Top K", min_value=10, max_value=200, value=50, step=1, help="The number of highest probability vocabulary tokens to keep for top-k-filtering.")
66
+ # num_beams = st.sidebar.number_input(label="Number of Beams", min_value=2, max_value=10, value=4, step=1, help="Number of beams to be used in beam search.")
67
+ # temperature = st.sidebar.select_slider(label="Temperature", options = list(np.arange(0.0,1.1, step=0.1)), value=1.0, help ="The value used to module the next token probabilities.", format_func=lambda x: f"{x:.2f}")
68
+ # top_p = st.sidebar.select_slider(label = "Top-P", options = list(np.arange(0.0,1.1, step=0.1)),value=1.0, help="Nucleus Sampling : If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are kept for generation.", format_func=lambda x: f"{x:.2f}")
69
+ # if st.sidebar.button("Clear All Cache"):
70
+ # caching.clear_cache()
71
+ # image_col, intro_col = st.beta_columns([3, 8])
72
+ # image_col.image("./misc/mic-logo.png", use_column_width="always")
73
+ # intro_col.write(read_markdown("intro.md"))
74
+
75
+ # with st.beta_expander("Usage"):
76
+ # st.markdown(read_markdown("usage.md"))
77
+
78
+ # with st.beta_expander("Article"):
79
+ # st.write(read_markdown("abstract.md"))
80
+ # st.write("## Methodology")
81
+ # st.image(
82
+ # "./misc/Multilingual-IC.png"
83
+ # )
84
+ # st.markdown(read_markdown("pretraining.md"))
85
+ # st.write(read_markdown("challenges.md"))
86
+ # st.write(read_markdown("social_impact.md"))
87
+ # st.write(read_markdown("bias.md"))
88
+
89
+ # col1, col2, col3, col4 = st.beta_columns([0.5,2.5,2.5,0.5])
90
+ # with col2:
91
+ # st.image("./misc/examples/female_dev_1.jpg", width=350, caption = 'German Caption: <PERSON> arbeitet an einem Computer.', use_column_width='always')
92
+ # with col3:
93
+ # st.image("./misc/examples/female_doctor.jpg", width=350, caption = 'English Caption: A portrait of <PERSON>, a doctor who specializes in health care.', use_column_width='always')
94
+
95
+ # col1, col2, col3, col4 = st.beta_columns([0.5,2.5,2.5,0.5])
96
+ # with col2:
97
+ # st.image("./misc/examples/female_doctor_1.jpg", width=350, caption = 'Spanish Caption: El Dr. <PERSON> es un estudiante de posgrado.', use_column_width='always')
98
+ # with col3:
99
+ # st.image("./misc/examples/women_cricket.jpg", width=350, caption = 'English Caption: <PERSON> of India bats against <PERSON> of Australia during the first Twenty20 match between India and Australia at Indian Bowl Stadium in New Delhi on Friday. - PTI', use_column_width='always')
100
+
101
+ # col1, col2, col3, col4 = st.beta_columns([0.5,2.5,2.5,0.5])
102
+ # with col2:
103
+ # st.image("./misc/examples/female_dev_2.jpg", width=350, caption = "French Caption: Un écran d'ordinateur avec un écran d'ordinateur ouvert.", use_column_width='always')
104
+ # with col3:
105
+ # st.image("./misc/examples/female_biker_resized.jpg", width=350, caption = 'German Caption: <PERSON> auf dem Motorrad von <PERSON>.', use_column_width='always')
106
+
107
+ # st.write(read_markdown("future_scope.md"))
108
+ # st.write(read_markdown("references.md"))
109
+ # # st.write(read_markdown("checkpoints.md"))
110
+ # st.write(read_markdown("acknowledgements.md"))
111
+
112
+ # if state.model is None:
113
+ # with st.spinner("Loading model..."):
114
+ # state.model = load_model(checkpoints[0])
115
+
116
+ # first_index = 25
117
+ # # Init Session State
118
+ # if state.image_file is None:
119
+ # state.image_file = dummy_data.loc[first_index, "image_file"]
120
+ # state.caption = dummy_data.loc[first_index, "caption"].strip("- ")
121
+ # state.lang_id = dummy_data.loc[first_index, "lang_id"]
122
+
123
+ # image_path = os.path.join("images", state.image_file)
124
+ # image = plt.imread(image_path)
125
+ # state.image = image
126
+
127
+ # if st.button("Get a random example", help="Get a random example from one of the seeded examples."):
128
+ # sample = dummy_data.sample(1).reset_index()
129
+ # state.image_file = sample.loc[0, "image_file"]
130
+ # state.caption = sample.loc[0, "caption"].strip("- ")
131
+ # state.lang_id = sample.loc[0, "lang_id"]
132
+
133
+ # image_path = os.path.join("images", state.image_file)
134
+ # image = plt.imread(image_path)
135
+ # state.image = image
136
+
137
+ # transformed_image = get_transformed_image(state.image)
138
+
139
+ # new_col1, new_col2 = st.beta_columns([5,5])
140
+
141
+ # # Display Image
142
+ # new_col1.image(state.image, use_column_width="always")
143
+ # # Display Reference Caption
144
+ # with new_col1.beta_expander("Reference Caption"):
145
+ # st.write("**Reference Caption**: " + state.caption)
146
+ # st.markdown(
147
+ # f"""**English Translation**: {state.caption if state.lang_id == "en" else translate(state.caption, 'en')}"""
148
+ # )
149
+
150
+ # # Select Language
151
+ # options = list(code_to_name.keys())
152
+ # lang_id = new_col2.selectbox(
153
+ # "Language",
154
+ # index=options.index(state.lang_id),
155
+ # options=options,
156
+ # format_func=lambda x: code_to_name[x],
157
+ # help="The language in which caption is to be generated."
158
+ # )
159
+
160
+ # sequence = ['']
161
+ # if new_col2.button("Generate Caption", help="Generate a caption in the specified language."):
162
+ # with st.spinner("Generating Sequence..."):
163
+ # sequence = generate_sequence(transformed_image, lang_id, num_beams, temperature, top_p, do_sample, top_k, max_length)
164
+ # # print(sequence)
165
+
166
+ # if sequence!=['']:
167
+ # new_col2.write(
168
+ # "**Generated Caption**: "+sequence[0]
169
+ # )
170
+
171
+ # new_col2.write(
172
+ # "**English Translation**: "+ sequence[0] if lang_id=="en" else translate(sequence[0])
173
+ # )
174
+
175
+ def main():
176
+ state = _get_state()
177
+ st.set_page_config(
178
+ page_title="Multilingual Image Captioning",
179
+ layout="wide",
180
+ initial_sidebar_state="collapsed",
181
+ page_icon="./misc/mic-logo.png",
182
  )
183
 
184
+ st.title("Multilingual Image Captioning")
185
+ st.write(
186
+ "[Bhavitvya Malik](https://huggingface.co/bhavitvyamalik), [Gunjan Chhablani](https://huggingface.co/gchhablani)"
187
  )
188
 
189
+ st.sidebar.title("Multilingual Image Captioning")
190
+ logo = st.sidebar.image("./misc/mic-logo.png")
191
+ st.sidebar.write("Multilingual Image Captioning addresses the challenge of caption generation for an image in a multilingual setting. Here, we fuse the CLIP Vision transformer into mBART50 and train on a translated version of the Conceptual 12M dataset. Please use the radio buttons below to navigate.")
192
+ app = MultiApp(state)
193
+ app.add_app("Article", article.app)
194
+ app.add_app("Multilingual Image Captioning", mic.app)
195
+ # app.add_app("Mask Filling", mlm.app)
196
+ app.run()
197
+ state.sync()
198
+
199
+ if __name__ == "__main__":
200
+ main()
apps/article.py ADDED
@@ -0,0 +1,73 @@
1
+ import streamlit as st
2
+ from apps.utils import read_markdown
3
+ # from .streamlit_tensorboard import st_tensorboard, kill_tensorboard
4
+ from .utils import Toc
5
+
6
+ def app(state=None):
7
+ #kill_tensorboard()
8
+ toc = Toc()
9
+ st.info("Welcome to our Multilingual Image Captioning demo. Please use the navigation sidebar to move to our demo, or scroll below to read all about our project. 🤗 In case the sidebar isn't properly rendered, please change to a smaller window size and back to full screen.")
10
+
11
+ st.header("Table of contents")
12
+ toc.placeholder()
13
+
14
+ toc.header("Introduction and Motivation")
15
+ st.write(read_markdown("intro/intro.md"))
16
+ toc.subheader("Novel Contributions")
17
+ st.write(read_markdown("intro/contributions.md"))
18
+
19
+ toc.header("Methodology")
20
+
21
+ toc.subheader("Pre-training")
22
+ st.write(read_markdown("pretraining/intro.md"))
23
+ toc.subsubheader("Dataset")
24
+ st.write(read_markdown("pretraining/dataset.md"))
25
+ _, col2, _ = st.beta_columns([1,3,1])
26
+ with col2:
27
+ st.image("./misc/Multilingual-IC.png", use_column_width='always')
28
+ toc.subsubheader("Model")
29
+ st.write(read_markdown("pretraining/model.md"))
30
+ # toc.subsubheader("MLM Training Logs")
31
+ # st.info("In case the TensorBoard logs are not displayed, please visit this link: https://huggingface.co/flax-community/multilingual-vqa-pt-ckpts/tensorboard")
32
+ # st_tensorboard(logdir='./logs/pretrain_logs', port=6006)
33
+ st.write(read_markdown("bias.md"))
34
+
35
+ _, col2, col3, _ = st.beta_columns([0.5,2.5,2.5,0.5])
36
+ with col2:
37
+ st.image("./misc/examples/female_dev_1.jpg", width=350, caption = 'German Caption: <PERSON> arbeitet an einem Computer.', use_column_width='always')
38
+ with col3:
39
+ st.image("./misc/examples/female_doctor.jpg", width=350, caption = 'English Caption: A portrait of <PERSON>, a doctor who specializes in health care.', use_column_width='always')
40
+
41
+ _, col2, col3, _ = st.beta_columns([0.5,2.5,2.5,0.5])
42
+ with col2:
43
+ st.image("./misc/examples/female_doctor_1.jpg", width=350, caption = 'Spanish Caption: El Dr. <PERSON> es un estudiante de posgrado.', use_column_width='always')
44
+ with col3:
45
+ st.image("./misc/examples/women_cricket.jpg", width=350, caption = 'English Caption: <PERSON> of India bats against <PERSON> of Australia during the first Twenty20 match between India and Australia at Indian Bowl Stadium in New Delhi on Friday. - PTI', use_column_width='always')
46
+
47
+ _, col2, col3, _ = st.beta_columns([0.5,2.5,2.5,0.5])
48
+ with col2:
49
+ st.image("./misc/examples/female_dev_2.jpg", width=350, caption = "French Caption: Un écran d'ordinateur avec un écran d'ordinateur ouvert.", use_column_width='always')
50
+ with col3:
51
+ st.image("./misc/examples/female_biker_resized.jpg", width=350, caption = 'German Caption: <PERSON> auf dem Motorrad von <PERSON>.', use_column_width='always')
52
+
53
+ toc.header("Challenges and Technical Difficulties")
54
+ st.write(read_markdown("challenges.md"))
55
+
56
+ toc.header("Limitations")
57
+ st.write(read_markdown("limitations.md"))
58
+
59
+ toc.header("Conclusion, Future Work, and Social Impact")
60
+ toc.subheader("Conclusion")
61
+ st.write(read_markdown("conclusion_future_work/conclusion.md"))
62
+ toc.subheader("Future Work")
63
+ st.write(read_markdown("conclusion_future_work/future_scope.md"))
64
+ toc.subheader("Social Impact")
65
+ st.write(read_markdown("conclusion_future_work/social_impact.md"))
66
+
67
+ toc.header("References")
68
+ st.write(read_markdown("references.md"))
69
+
70
+
71
+ toc.header("Acknowledgements")
72
+ st.write(read_markdown("acknowledgements.md"))
73
+ toc.generate()
apps/mic.py ADDED
@@ -0,0 +1,142 @@
1
+ from .utils import get_transformed_image
2
+
3
+ import streamlit as st
4
+ import numpy as np
5
+ import pandas as pd
6
+ import os
7
+ import matplotlib.pyplot as plt
8
+ from mtranslate import translate
9
+ from .utils import (
10
+ read_markdown,
11
+ tokenizer,
12
+ language_mapping,
13
+ code_to_name
14
+ )
15
+ import requests
16
+ from PIL import Image
17
+ from .model.flax_clip_vision_mbart.modeling_clip_vision_mbart import (
18
+ FlaxCLIPVisionMBartForConditionalGeneration,
19
+ )
20
+ from streamlit import caching
21
+
22
+
23
+ def app(state):
24
+ mic_state = state
25
+ with st.beta_expander("Usage"):
26
+ st.write(read_markdown("usage.md"))
27
+ st.write("\n")
28
+ st.write(read_markdown("intro.md"))
29
+
30
+ with st.beta_expander("Generation Parameters"):
31
+ do_sample = st.sidebar.checkbox("Sample", value=False, help="Sample from the model instead of using beam search.")
32
+ top_k = st.sidebar.number_input("Top K", min_value=10, max_value=200, value=50, step=1, help="The number of highest probability vocabulary tokens to keep for top-k-filtering.")
33
+ num_beams = st.sidebar.number_input(label="Number of Beams", min_value=2, max_value=10, value=4, step=1, help="Number of beams to be used in beam search.")
34
+ temperature = st.sidebar.select_slider(label="Temperature", options = list(np.arange(0.0,1.1, step=0.1)), value=1.0, help ="The value used to module the next token probabilities.", format_func=lambda x: f"{x:.2f}")
35
+ top_p = st.sidebar.select_slider(label = "Top-P", options = list(np.arange(0.0,1.1, step=0.1)),value=1.0, help="Nucleus Sampling : If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are kept for generation.", format_func=lambda x: f"{x:.2f}")
36
+
37
+ if st.sidebar.button("Clear All Cache"):
38
+ caching.clear_cache()
39
+
40
+ max_length = 64
41
+
42
+ @st.cache
43
+ def load_model(ckpt):
44
+ return FlaxCLIPVisionMBartForConditionalGeneration.from_pretrained(ckpt)
45
+
46
+ @st.cache
47
+ def generate_sequence(pixel_values, lang_code, num_beams, temperature, top_p, do_sample, top_k, max_length):
48
+ lang_code = language_mapping[lang_code]
49
+ output_ids = mic_state.model.generate(input_ids=pixel_values, forced_bos_token_id=tokenizer.lang_code_to_id[lang_code], max_length=max_length, num_beams=num_beams, temperature=temperature, top_p = top_p, top_k=top_k, do_sample=do_sample)
50
+ print(output_ids)
51
+ output_sequence = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True, max_length=max_length)
52
+ return output_sequence
53
+
54
+ mic_checkpoints = ["flax-community/clip-vit-base-patch32_mbart-large-50"] # TODO: Maybe add more checkpoints?
55
+ dummy_data = pd.read_csv("reference.tsv", sep="\t")
56
+
57
+ first_index = 25
58
+ # Init Session State
59
+ if mic_state.image_file is None:
60
+ mic_state.image_file = dummy_data.loc[first_index, "image_file"]
61
+ mic_state.caption = dummy_data.loc[first_index, "caption"].strip("- ")
62
+ mic_state.lang_id = dummy_data.loc[first_index, "lang_id"]
63
+
64
+ image_path = os.path.join("images", mic_state.image_file)
65
+ image = plt.imread(image_path)
66
+ mic_state.image = image
67
+
68
+ if mic_state.model is None:
69
+ # Display Top-5 Predictions
70
+ with st.spinner("Loading model..."):
71
+ mic_state.model = load_model(mic_checkpoints[0])
72
+
73
+ query1 = st.text_input(
74
+ "Enter a URL to an image",
75
+ value="http://images.cocodataset.org/val2017/000000039769.jpg",
76
+ )
77
+
78
+ col1, col2, col3 = st.beta_columns([2,1, 2])
79
+ if col1.button(
80
+ "Get a random example",
81
+ help="Get a random example from the 100 `seeded` image-text pairs.",
82
+ ):
83
+ sample = dummy_data.sample(1).reset_index()
84
+ mic_state.image_file = sample.loc[0, "image_file"]
85
+ mic_state.caption = sample.loc[0, "caption"].strip("- ")
86
+ mic_state.lang_id = sample.loc[0, "lang_id"]
87
+
88
+ image_path = os.path.join("images", mic_state.image_file)
89
+ image = plt.imread(image_path)
90
+ mic_state.image = image
91
+
92
+ col2.write("OR")
93
+
94
+ if col3.button("Use above URL"):
95
+ image_data = requests.get(query1, stream=True).raw
96
+ image = np.asarray(Image.open(image_data))
97
+ mic_state.image = image
98
+
99
+ transformed_image = get_transformed_image(mic_state.image)
100
+
101
+ new_col1, new_col2 = st.beta_columns([5,5])
102
+
103
+ # Display Image
104
+ new_col1.image(mic_state.image, use_column_width="always")
105
+ # Display Reference Caption
106
+ with new_col1.beta_expander("Reference Caption"):
107
+ st.write("**Reference Caption**: " + mic_state.caption)
108
+ st.markdown(
109
+ f"""**English Translation**: {mic_state.caption if mic_state.lang_id == "en" else translate(mic_state.caption, 'en')}"""
110
+ )
111
+
112
+ # Select Language
113
+ options = list(code_to_name.keys())
114
+ lang_id = new_col2.selectbox(
115
+ "Language",
116
+ index=options.index(mic_state.lang_id),
117
+ options=options,
118
+ format_func=lambda x: code_to_name[x],
119
+ help="The language in which caption is to be generated."
120
+ )
121
+
122
+ sequence = ['']
123
+ if new_col2.button("Generate Caption", help="Generate a caption in the specified language."):
124
+ with st.spinner("Generating Sequence..."):
125
+ sequence = generate_sequence(transformed_image, lang_id, num_beams, temperature, top_p, do_sample, top_k, max_length)
126
+ # print(sequence)
127
+
128
+ if sequence!=['']:
129
+ new_col2.write(
130
+ "**Generated Caption**: "+sequence[0]
131
+ )
132
+
133
+ new_col2.write(
134
+ "**English Translation**: " + (sequence[0] if lang_id == "en" else translate(sequence[0], 'en'))
135
+ )
136
+
137
+
138
+
139
+
140
+ # image_col, intro_col = st.beta_columns([3, 8])
141
+ # image_col.image("./misc/mic-logo.png", use_column_width="always")
142
+ # intro_col.write(read_markdown("intro.md"))
{model → apps/model}/__init__.py RENAMED
File without changes
{model → apps/model}/flax_clip_vision_mbart/__init__.py RENAMED
File without changes
{model → apps/model}/flax_clip_vision_mbart/configuration_clip_vision_mbart.py RENAMED
File without changes
{model → apps/model}/flax_clip_vision_mbart/generation_clip_vision_utils.py RENAMED
File without changes
{model → apps/model}/flax_clip_vision_mbart/modeling_clip_vision_mbart.py RENAMED
File without changes
{model → apps/model}/flax_clip_vision_mbart/modeling_clip_vision_utils.py RENAMED
File without changes
utils.py → apps/utils.py RENAMED
@@ -4,7 +4,41 @@ import numpy as np
4
  from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
5
  from torchvision.transforms.functional import InterpolationMode
6
  from PIL import Image
7
 
8
 
9
  class Transform(torch.nn.Module):
10
  def __init__(self, image_size):
@@ -31,4 +65,22 @@ def get_transformed_image(image):
31
  if isinstance(image, np.ndarray) and image.shape[-1] == 3:
32
  image = image.transpose(2, 0, 1)
33
  image = torch.tensor(image)
34
- return transform(image).unsqueeze(0).permute(0, 2, 3, 1).numpy()
4
  from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
5
  from torchvision.transforms.functional import InterpolationMode
6
  from PIL import Image
7
+ import os
8
+ import streamlit as st
9
+ from transformers import MBart50TokenizerFast
10
 
11
+ class Toc:
12
+ def __init__(self):
13
+ self._items = []
14
+ self._placeholder = None
15
+
16
+ def title(self, text):
17
+ self._markdown(text, "h1")
18
+
19
+ def header(self, text):
20
+ self._markdown(text, "h2", " " * 2)
21
+
22
+ def subheader(self, text):
23
+ self._markdown(text, "h3", " " * 4)
24
+
25
+ def subsubheader(self, text):
26
+ self._markdown(text, "h4", " " * 8)
27
+
28
+ def placeholder(self, sidebar=False):
29
+ self._placeholder = st.sidebar.empty() if sidebar else st.empty()
30
+
31
+ def generate(self):
32
+ if self._placeholder:
33
+ self._placeholder.markdown("\n".join(self._items), unsafe_allow_html=True)
34
+
35
+ def _markdown(self, text, level, space=""):
36
+ key = "".join(filter(str.isalnum, text)).lower()
37
+
38
+ st.markdown(f"<{level} id='{key}'>{text}</{level}>", unsafe_allow_html=True)
39
+ self._items.append(f"{space}* <a href='#{key}'>{text}</a>")
40
+
41
+ tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
42
 
43
  class Transform(torch.nn.Module):
44
  def __init__(self, image_size):
65
  if isinstance(image, np.ndarray) and image.shape[-1] == 3:
66
  image = image.transpose(2, 0, 1)
67
  image = torch.tensor(image)
68
+ return transform(image).unsqueeze(0).permute(0, 2, 3, 1).numpy()
69
+
70
+ def read_markdown(path, parent="./sections/"):
71
+ with open(os.path.join(parent, path)) as f:
72
+ return f.read()
73
+
74
+ language_mapping = {
75
+ "en": "en_XX",
76
+ "de": "de_DE",
77
+ "fr": "fr_XX",
78
+ "es": "es_XX"
79
+ }
80
+
81
+ code_to_name = {
82
+ "en": "English",
83
+ "fr": "French",
84
+ "de": "German",
85
+ "es": "Spanish",
86
+ }
multiapp.py ADDED
@@ -0,0 +1,16 @@
1
+ import streamlit as st
2
+
3
+ class MultiApp:
4
+ def __init__(self, state):
5
+ self.apps = []
6
+ self.state = state
7
+
8
+ def add_app(self, title, func):
9
+ self.apps.append({"title": title, "function": func})
10
+
11
+ def run(self):
12
+ st.sidebar.header("Go To:")
13
+ app = st.sidebar.radio(
14
+ "", self.apps, format_func=lambda app: app["title"]
15
+ )
16
+ app["function"](self.state)
sections/acknowledgements.md CHANGED
@@ -1,4 +1,3 @@
1
- ## Acknowledgements
2
  We'd like to thank [Abheesht Sharma](https://huggingface.co/abheesht) for helping in the discussions in the initial phases. [Luke Melas](https://github.com/lukemelas) helped us get the cleaned CC-12M data on our TPU-VMs and we are very grateful to him.
3
 
4
  This project would not be possible without the help of [Patrick](https://huggingface.co/patrickvonplaten) and [Suraj](https://huggingface.co/valhalla), who met with us, helped us review our approach, and guided us throughout the project. We especially thank Patrick for going out of his way and allowing us extra TPU time so that we could work on this project.
 
1
  We'd like to thank [Abheesht Sharma](https://huggingface.co/abheesht) for helping in the discussions in the initial phases. [Luke Melas](https://github.com/lukemelas) helped us get the cleaned CC-12M data on our TPU-VMs and we are very grateful to him.
2
 
3
  This project would not be possible without the help of [Patrick](https://huggingface.co/patrickvonplaten) and [Suraj](https://huggingface.co/valhalla), who met with us, helped us review our approach, and guided us throughout the project. We especially thank Patrick for going out of his way and allowing us extra TPU time so that we could work on this project.
sections/challenges.md CHANGED
@@ -1,4 +1,3 @@
1
- ## Challenges and Technical Difficulties
2
  Training an image captioning model, and a multilingual one at that, was a difficult task, and we faced challenges at almost every step of the process.
3
 
4
  - Dataset: Our initial plan was to translate Conceptual 12M using mTranslate or Yandex, but they turned out to be too slow even with multiprocessing. Not having proper translations could lead to poor performance of the trained image-captioning model. We translated the whole dataset using mBART50 for all languages, which took around 3-4 days. Later on, we realised that the mBART captions were not that good and the model was not converging because of them, which led us to re-translate our captions with [Marian](https://huggingface.co/transformers/model_doc/marian.html).
 
1
  Training an image captioning model, and a multilingual one at that, was a difficult task, and we faced challenges at almost every step of the process.
2
 
3
  - Dataset: Our initial plan was to translate Conceptual 12M using mTranslate or Yandex, but they turned out to be too slow even with multiprocessing. Not having proper translations could lead to poor performance of the trained image-captioning model. We translated the whole dataset using mBART50 for all languages, which took around 3-4 days. Later on, we realised that the mBART captions were not that good and the model was not converging because of them, which led us to re-translate our captions with [Marian](https://huggingface.co/transformers/model_doc/marian.html).
sections/{caveats.md → conclusion_future_work/conclusion.md} RENAMED
File without changes
sections/{future_scope.md → conclusion_future_work/future_scope.md} RENAMED
@@ -1,4 +1,3 @@
1
- ## Future scope of work
2
  We hope to improve this project in the future by using:
3
  - Better options for data translation: Translation has a huge impact on how the final model performs. Better translators (e.g., the Google Translate API) and language-specific seq2seq translation models can generate better data, both for high-resource and low-resource languages.
4
  - Accessibility: Make the model deployable on hand-held devices to make it more accessible. Currently, our model is too large to fit on mobile/edge devices, so not many people will be able to access it. However, our final goal is to ensure that everyone can access it without any computation barriers. JAX has an experimental converter, `jax2tf`, for converting JAX functions to TensorFlow. Hopefully we'll be able to support TFLite for our model as well in the future.
 
1
  We hope to improve this project in the future by using:
2
  - Better options for data translation: Translation has a huge impact on how the final model performs. Better translators (e.g., the Google Translate API) and language-specific seq2seq translation models can generate better data, both for high-resource and low-resource languages.
3
  - Accessibility: Make the model deployable on hand-held devices to make it more accessible. Currently, our model is too large to fit on mobile/edge devices, so not many people will be able to access it. However, our final goal is to ensure that everyone can access it without any computation barriers. JAX has an experimental converter, `jax2tf`, for converting JAX functions to TensorFlow. Hopefully we'll be able to support TFLite for our model as well in the future.
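The `jax2tf` route mentioned in the accessibility point could start out roughly like the sketch below. Converting the full captioning model (parameters, input shapes, TFLite op coverage) would need considerably more work, so treat this only as an illustration of the mechanism, not part of the repository.

```python
# Sketch: converting a JAX function to a TensorFlow function with jax2tf.
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def jax_fn(x):
    # Stand-in for a JAX/Flax model's apply function.
    return jnp.tanh(x) * 2.0

tf_fn = tf.function(jax2tf.convert(jax_fn), autograph=False)
print(tf_fn(tf.constant([0.1, 0.2, 0.3])))
```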
sections/{social_impact.md → conclusion_future_work/social_impact.md} RENAMED
@@ -1,4 +1,3 @@
1
- ## Social Impact
2
  Being able to automatically describe the content of an image using properly formed sentences in any language is a challenging task, but it could have great impact by helping visually impaired people better understand their surroundings.
3
 
4
  Our initial plan was to include 4 high-resource and 4 low-resource languages (Marathi, Bengali, Urdu, Telugu) in our training data. However, the existing translation models do not perform as well for these languages, and we would have ended up with poor labels, not to mention a longer training time.
 
1
  Being able to automatically describe the content of an image using properly formed sentences in any language is a challenging task, but it could have great impact by helping visually impaired people better understand their surroundings.
2
 
3
  Our initial plan was to include 4 high-resource and 4 low-resource languages (Marathi, Bengali, Urdu, Telugu) in our training data. However, the existing translation models do not perform as well for these languages, and we would have ended up with poor labels, not to mention a longer training time.
sections/intro/contributions.md ADDED
File without changes
sections/intro/intro.md ADDED
File without changes
sections/limitations.md ADDED
File without changes
sections/pretraining.md CHANGED
@@ -1,4 +1,3 @@
1
- ### Pretraining
2
  We follow an encoder-decoder approach for image captioning, where the image encoder is the CLIP Vision model (a ViT transformer). The pre-training task is image-to-text generation. We take the input tokens and shift them one position to the right using an `<eos>` token in order to create the decoder inputs for our model, while the original input tokens become the labels. The model is trained on the dataset in an end-to-end fashion.
3
 
4
  **Dataset**
 
1
  We follow an encoder-decoder approach for image captioning, where the image encoder is the CLIP Vision model (a ViT transformer). The pre-training task is image-to-text generation. We take the input tokens and shift them one position to the right using an `<eos>` token in order to create the decoder inputs for our model, while the original input tokens become the labels. The model is trained on the dataset in an end-to-end fashion.
2
 
3
  **Dataset**
sections/pretraining/dataset.md ADDED
@@ -0,0 +1 @@
 
1
+ The dataset we use for pre-training is a cleaned version of Conceptual 12M. The dataset is downloaded and broken images are removed, which gives us about 10M images. To save time, we use 2.5M of these image-text pairs. We then use the MarianMT `Helsinki-NLP/opus-mt-{src}-{tgt}` checkpoints to translate the dataset into four languages - English, French, German, and Spanish - keeping approximately 2.5M examples of each language.
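As a rough illustration of the translation step described above, the snippet below shows how one of the `Helsinki-NLP/opus-mt-{src}-{tgt}` MarianMT checkpoints can be used to translate a batch of captions (here en-de). This is only a sketch of the idea, not the project's actual pipeline; the concrete checkpoint name, batching, and decoding settings are assumptions.

```python
# Sketch: translating English captions to German with a MarianMT checkpoint.
from transformers import MarianMTModel, MarianTokenizer

model_name = "Helsinki-NLP/opus-mt-en-de"  # {src}-{tgt} instantiated as en-de
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)

captions = ["a dog playing in the snow", "a plate of food on a wooden table"]
batch = tokenizer(captions, return_tensors="pt", padding=True, truncation=True)
translated_ids = model.generate(**batch)
print(tokenizer.batch_decode(translated_ids, skip_special_tokens=True))
```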
sections/pretraining/intro.md ADDED
@@ -0,0 +1 @@
 
1
+ We follow an encoder-decoder approach for image captioning, where the image encoder is the CLIP Vision model (a ViT transformer). The pre-training task is image-to-text generation. We take the input tokens and shift them one position to the right using an `<eos>` token in order to create the decoder inputs for our model, while the original input tokens become the labels. The model is trained on the dataset in an end-to-end fashion.
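A minimal sketch of the token-shifting step described above is given below. It assumes a simple BART-style shift, where the decoder start token is placed at position 0 and the unshifted ids serve as the labels; the exact mBART convention used in the training code may differ, and the token ids shown are hypothetical.

```python
import numpy as np

def shift_tokens_right(labels: np.ndarray, decoder_start_token_id: int) -> np.ndarray:
    # Build decoder inputs by shifting the label ids one position to the right
    # and placing the decoder start token (an <eos>-style id here) at position 0.
    shifted = np.zeros_like(labels)
    shifted[:, 0] = decoder_start_token_id
    shifted[:, 1:] = labels[:, :-1]
    return shifted

labels = np.array([[250004, 47, 23, 9, 2]])  # hypothetical token ids ending in </s>
decoder_input_ids = shift_tokens_right(labels, decoder_start_token_id=2)
```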
sections/pretraining/model.md ADDED
@@ -0,0 +1,10 @@
1
+ The model is shown in the image above. We create a custom model in Flax which integrates the CLIP Vision model as an encoder inside the mBART model. We also use custom configs and modules to accommodate these changes and to allow loading from mBART and CLIP Vision checkpoints. The image is fed to the CLIP Vision encoder and the shifted token ids are fed to the mBART decoder. We use the `facebook/mbart-large-50` and `openai/clip-vit-base-patch32` checkpoints for the mBART and CLIP Vision models, respectively. All our code is available on [GitHub](https://github.com/gchhablani/multilingual-image-captioning).
2
+
3
+ Our model reached an **eval loss of ~2.6** at around 70K steps. Here are the BLEU scores (out of 1) for different languages:
4
+
5
+ |Language |BLEU-1|BLEU-2|BLEU-3|BLEU-4|
6
+ |--------------------------|------|------|------|------|
7
+ |English | 0.13083| 0.08887| 0.06681 | 0.04899|
8
+ |Spanish | 0.15981| 0.09858| 0.06918| 0.04776|
9
+ |German | 0.14234| 0.09817| 0.07405| 0.0515|
10
+ |French | 0.13021| 0.08862| 0.06598| 0.04647|
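For reference, inference with the released checkpoint looks roughly like the snippet below, which mirrors the `generate` call in `apps/mic.py`. The dummy `pixel_values` array stands in for the output of `get_transformed_image` in `apps/utils.py`, and the output shape and indexing follow that code rather than a documented API, so treat this as a sketch.

```python
import numpy as np
from transformers import MBart50TokenizerFast
from apps.model.flax_clip_vision_mbart.modeling_clip_vision_mbart import (
    FlaxCLIPVisionMBartForConditionalGeneration,
)

model = FlaxCLIPVisionMBartForConditionalGeneration.from_pretrained(
    "flax-community/clip-vit-base-patch32_mbart-large-50"
)
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")

# Stand-in for get_transformed_image(...): a (batch, height, width, 3) float array.
pixel_values = np.zeros((1, 224, 224, 3), dtype=np.float32)

output_ids = model.generate(
    input_ids=pixel_values,
    forced_bos_token_id=tokenizer.lang_code_to_id["de_DE"],  # German caption
    max_length=64,
    num_beams=4,
)
caption = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True)[0]
print(caption)
```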
sections/references.md CHANGED
@@ -1,5 +1,3 @@
1
- ## References
2
-
3
  ```
4
  @inproceedings{wolf-etal-2020-transformers,
5
  title = "Transformers: State-of-the-Art Natural Language Processing",
 
 
1
  ```
2
  @inproceedings{wolf-etal-2020-transformers,
3
  title = "Transformers: State-of-the-Art Natural Language Processing",