bhavitvyamalik committed
Commit 6088947
Parents: f82fbe0, c1274fe

Merge branch 'main' of https://huggingface.co/spaces/flax-community/multilingual-image-captioning into main

.gitignore CHANGED
@@ -1,3 +1,2 @@
-*mic_env/*
-**__pycache__**
-*.pyc
+mic_env/*
+*.pyc

app.py CHANGED
@@ -44,9 +44,9 @@ code_to_name = {
 }

 @st.cache(persist=True)
-def generate_sequence(pixel_values, lang_code, num_beams):
+def generate_sequence(pixel_values, lang_code, num_beams, temperature, top_p):
     lang_code = language_mapping[lang_code]
-    output_ids = model.generate(input_ids=pixel_values, forced_bos_token_id=tokenizer.lang_code_to_id[lang_code], max_length=64, num_beams=num_beams)
+    output_ids = model.generate(input_ids=pixel_values, forced_bos_token_id=tokenizer.lang_code_to_id[lang_code], max_length=64, num_beams=num_beams, temperature=temperature, top_p=top_p)
     print(output_ids)
     output_sequence = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True, max_length=64)
     return output_sequence
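
The hunk above threads the sidebar's `temperature` and `top_p` into `model.generate` alongside `num_beams`. A minimal, self-contained sketch of the same decoding call, using the public text-only facebook/mbart-large-50 checkpoint as a stand-in for this Space's CLIP-vision/mBART hybrid (so the input is text rather than pixel_values; the language codes are illustrative):

# Sketch only, not this repo's code: a stock mBART-50 model stands in for
# the CLIP-vision/mBART hybrid, so we encode text instead of an image.
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration

tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")

tokenizer.src_lang = "en_XX"
inputs = tokenizer("A dog runs along the beach.", return_tensors="pt")
output_ids = model.generate(
    **inputs,
    forced_bos_token_id=tokenizer.lang_code_to_id["fr_XX"],  # pin the output language
    max_length=64,
    num_beams=4,
    temperature=1.0,  # reshapes next-token probabilities
    top_p=1.0,        # nucleus-sampling cutoff
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))

One caveat worth flagging: in stock transformers, `temperature` and `top_p` only change the output when sampling is enabled (`do_sample=True`); under pure beam search they are accepted but have no effect, so the new sliders matter only if the decoding mode is ever switched.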
@@ -56,13 +56,14 @@ def read_markdown(path, parent="./sections/"):
         return f.read()


-checkpoints = ["./ckpt/ckpt-22499"]  # TODO: Maybe add more checkpoints?
+checkpoints = ["./ckpt/ckpt-17499"]  # TODO: Maybe add more checkpoints?
 dummy_data = pd.read_csv("reference.tsv", sep="\t")

 st.set_page_config(
     page_title="Multilingual Image Captioning",
     layout="wide",
     initial_sidebar_state="collapsed",
+    page_icon="./misc/mic-logo.png",
 )

 st.title("Multilingual Image Captioning")
@@ -70,12 +71,33 @@ st.write(
     "[Bhavitvya Malik](https://huggingface.co/bhavitvyamalik), [Gunjan Chhablani](https://huggingface.co/gchhablani)"
 )

-st.sidebar.title("Settings")
+st.sidebar.title("Generation Parameters")
 num_beams = st.sidebar.number_input(label="Number of Beams", min_value=2, max_value=10, value=4, step=1, help="Number of beams to be used in beam search.")
+temperature = st.sidebar.select_slider(label="Temperature", options=list(np.arange(0.0, 1.1, step=0.1)), value=1.0, help="The value used to modulate the next-token probabilities.", format_func=lambda x: f"{x:.2f}")
+top_p = st.sidebar.select_slider(label="Top-P", options=list(np.arange(0.0, 1.1, step=0.1)), value=1.0, help="Nucleus sampling: if set to a float < 1, only the most probable tokens whose probabilities add up to top_p or higher are kept for generation.", format_func=lambda x: f"{x:.2f}")
+
+image_col, intro_col = st.beta_columns([3, 8])
+image_col.image("./misc/mic-logo.png", use_column_width="always")
+intro_col.write(read_markdown("intro.md"))

 with st.beta_expander("Usage"):
     st.markdown(read_markdown("usage.md"))

+with st.beta_expander("Article"):
+    st.write(read_markdown("abstract.md"))
+    st.write(read_markdown("caveats.md"))
+    # st.write("# Methodology")
+    # st.image(
+    #     "./misc/Multilingual-IC.png", caption="Seq2Seq model for Image-text Captioning."
+    # )
+    st.markdown(read_markdown("pretraining.md"))
+    st.write(read_markdown("challenges.md"))
+    st.write(read_markdown("social_impact.md"))
+    st.write(read_markdown("references.md"))
+    # st.write(read_markdown("checkpoints.md"))
+    st.write(read_markdown("acknowledgements.md"))
+
+
 first_index = 20
 # Init Session State
 if state.image_file is None:
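
A small note on the two new sliders in the hunk above: `np.arange` with a float step accumulates binary rounding error, which is presumably why each slider passes `format_func` to render clean two-decimal labels. A quick standalone illustration:

import numpy as np

# Float steps pick up rounding error, so raw option labels can read
# 0.30000000000000004; formatting to two decimals keeps the UI tidy.
options = list(np.arange(0.0, 1.1, step=0.1))
print(options[3])                     # 0.30000000000000004 on typical builds
print([f"{x:.2f}" for x in options])  # ['0.00', '0.10', ..., '1.00']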
@@ -87,9 +109,9 @@ if state.image_file is None:
     image = plt.imread(image_path)
     state.image = image

-col1, col2 = st.beta_columns([6, 4])
+# col1, col2 = st.beta_columns([6, 4])

-if col2.button("Get a random example"):
+if st.button("Get a random example", help="Get a random example from one of the seeded examples."):
     sample = dummy_data.sample(1).reset_index()
     state.image_file = sample.loc[0, "image_file"]
     state.caption = sample.loc[0, "caption"].strip("- ")
@@ -99,40 +121,42 @@ if col2.button("Get a random example"):
     image = plt.imread(image_path)
     state.image = image

-col2.write("OR")
+# col2.write("OR")

-uploaded_file = col2.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])
-if uploaded_file is not None:
-    state.image_file = os.path.join("images", uploaded_file.name)
-    state.image = np.array(Image.open(uploaded_file))
+# uploaded_file = col2.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])
+# if uploaded_file is not None:
+#     state.image_file = os.path.join("images", uploaded_file.name)
+#     state.image = np.array(Image.open(uploaded_file))

 transformed_image = get_transformed_image(state.image)

+new_col1, new_col2 = st.beta_columns([5, 5])
 # Display Image
-col1.image(state.image, use_column_width="auto")
+new_col1.image(state.image, use_column_width="always")
+

 # Display Reference Caption
-col2.write("**Reference Caption**: " + state.caption)
-col2.markdown(
+new_col2.write("**Reference Caption**: " + state.caption)
+new_col2.markdown(
     f"""**English Translation**: {state.caption if state.lang_id == "en" else translate(state.caption, 'en')}"""
 )

 # Select Language
 options = list(code_to_name.keys())
-lang_id = col2.selectbox(
+lang_id = new_col2.selectbox(
     "Language",
     index=options.index(state.lang_id),
     options=options,
     format_func=lambda x: code_to_name[x],
+    help="The language in which the caption is to be generated.",
 )
-# Display Top-5 Predictions
+
 with st.spinner("Loading model..."):
     model = load_model(checkpoints[0])
-
 sequence = ['']
-if col2.button("Generate Caption"):
+if new_col2.button("Generate Caption", help="Generate a caption in the specified language."):
     with st.spinner("Generating Sequence..."):
-        sequence = generate_sequence(transformed_image, lang_id, num_beams)
+        sequence = generate_sequence(transformed_image, lang_id, num_beams, temperature, top_p)
 # print(sequence)

 if sequence!=['']:
@@ -143,15 +167,3 @@ if sequence!=['']:
     st.write(
         "**English Translation**: "+ sequence[0] if lang_id=="en" else translate(sequence[0])
     )
-st.write(read_markdown("abstract.md"))
-st.write(read_markdown("caveats.md"))
-# st.write("# Methodology")
-# st.image(
-#     "./misc/Multilingual-IC.png", caption="Seq2Seq model for Image-text Captioning."
-# )
-st.markdown(read_markdown("pretraining.md"))
-st.write(read_markdown("challenges.md"))
-st.write(read_markdown("social_impact.md"))
-st.write(read_markdown("references.md"))
-# st.write(read_markdown("checkpoints.md"))
-st.write(read_markdown("acknowledgements.md"))
 
misc/Multilingual IC.svg ADDED
misc/mic-logo.png ADDED
model/flax_clip_vision_mbart/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (184 Bytes)
 
model/flax_clip_vision_mbart/__pycache__/configuration_clip_vision_mbart.cpython-38.pyc DELETED
Binary file (1.7 kB)
 
model/flax_clip_vision_mbart/__pycache__/generation_clip_vision_utils.cpython-38.pyc DELETED
Binary file (21.8 kB)
 
model/flax_clip_vision_mbart/__pycache__/modeling_clip_vision_mbart.cpython-38.pyc DELETED
Binary file (15.5 kB)
 
model/flax_clip_vision_mbart/__pycache__/modeling_clip_vision_utils.cpython-38.pyc DELETED
Binary file (16.6 kB)