gchhablani committed on
Commit
611eaf4
1 Parent(s): 888419c

Fix image display issues and add caching

Browse files
Files changed (1) hide show
  1. app.py +10 -18
app.py CHANGED
@@ -46,7 +46,7 @@ code_to_name = {
46
  @st.cache
47
  def generate_sequence(pixel_values, lang_code, num_beams, temperature, top_p):
48
  lang_code = language_mapping[lang_code]
49
- output_ids = model.generate(input_ids=pixel_values, forced_bos_token_id=tokenizer.lang_code_to_id[lang_code], max_length=64, num_beams=num_beams, temperature=temperature, top_p = top_p)
50
  print(output_ids)
51
  output_sequence = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True, max_length=64)
52
  return output_sequence
@@ -75,7 +75,8 @@ st.sidebar.title("Generation Parameters")
75
  num_beams = st.sidebar.number_input(label="Number of Beams", min_value=2, max_value=10, value=4, step=1, help="Number of beams to be used in beam search.")
76
  temperature = st.sidebar.select_slider(label="Temperature", options = list(np.arange(0.0,1.1, step=0.1)), value=1.0, help ="The value used to module the next token probabilities.", format_func=lambda x: f"{x:.2f}")
77
  top_p = st.sidebar.select_slider(label = "Top-P", options = list(np.arange(0.0,1.1, step=0.1)),value=1.0, help="Nucleus Sampling : If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are kept for generation.", format_func=lambda x: f"{x:.2f}")
78
-
 
79
  image_col, intro_col = st.beta_columns([3, 8])
80
  image_col.image("./misc/mic-logo.png", use_column_width="always")
81
  intro_col.write(read_markdown("intro.md"))
@@ -98,8 +99,11 @@ with st.beta_expander("Article"):
98
  # st.write(read_markdown("checkpoints.md"))
99
  st.write(read_markdown("acknowledgements.md"))
100
 
 
 
 
101
 
102
- first_index = 20
103
  # Init Session State
104
  if state.image_file is None:
105
  state.image_file = dummy_data.loc[first_index, "image_file"]
@@ -110,9 +114,10 @@ if state.image_file is None:
110
  image = plt.imread(image_path)
111
  state.image = image
112
 
113
- col1, col2 = st.beta_columns([5, 5])
114
 
115
- if col1.button("Get a random example", help="Get a random example from one of the seeded examples."):
 
116
  sample = dummy_data.sample(1).reset_index()
117
  state.image_file = sample.loc[0, "image_file"]
118
  state.caption = sample.loc[0, "caption"].strip("- ")
@@ -122,21 +127,10 @@ if col1.button("Get a random example", help="Get a random example from one of th
122
  image = plt.imread(image_path)
123
  state.image = image
124
 
125
- if col2.sidebar.button("Clear All Cache"):
126
- caching.clear_cache()
127
-
128
- # uploaded_file = col2.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])
129
- # if uploaded_file is not None:
130
- # state.image_file = os.path.join("images", uploaded_file.name)
131
- # state.image = np.array(Image.open(uploaded_file))
132
-
133
  transformed_image = get_transformed_image(state.image)
134
 
135
- new_col1, new_col2 = st.beta_columns([5,5])
136
  # Display Image
137
  new_col1.image(state.image, use_column_width="always")
138
-
139
-
140
  # Display Reference Caption
141
  new_col2.write("**Reference Caption**: " + state.caption)
142
  new_col2.markdown(
@@ -153,8 +147,6 @@ lang_id = new_col2.selectbox(
153
  help="The language in which caption is to be generated."
154
  )
155
 
156
- with st.spinner("Loading model..."):
157
- model = load_model(checkpoints[0])
158
  sequence = ['']
159
  if new_col2.button("Generate Caption", help="Generate a caption in the specified language."):
160
  with st.spinner("Generating Sequence..."):
 
46
  @st.cache
47
  def generate_sequence(pixel_values, lang_code, num_beams, temperature, top_p):
48
  lang_code = language_mapping[lang_code]
49
+ output_ids = state.model.generate(input_ids=pixel_values, forced_bos_token_id=tokenizer.lang_code_to_id[lang_code], max_length=64, num_beams=num_beams, temperature=temperature, top_p = top_p)
50
  print(output_ids)
51
  output_sequence = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True, max_length=64)
52
  return output_sequence
 
75
  num_beams = st.sidebar.number_input(label="Number of Beams", min_value=2, max_value=10, value=4, step=1, help="Number of beams to be used in beam search.")
76
  temperature = st.sidebar.select_slider(label="Temperature", options = list(np.arange(0.0,1.1, step=0.1)), value=1.0, help ="The value used to module the next token probabilities.", format_func=lambda x: f"{x:.2f}")
77
  top_p = st.sidebar.select_slider(label = "Top-P", options = list(np.arange(0.0,1.1, step=0.1)),value=1.0, help="Nucleus Sampling : If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are kept for generation.", format_func=lambda x: f"{x:.2f}")
78
+ if st.sidebar.button("Clear All Cache"):
79
+ caching.clear_cache()
80
  image_col, intro_col = st.beta_columns([3, 8])
81
  image_col.image("./misc/mic-logo.png", use_column_width="always")
82
  intro_col.write(read_markdown("intro.md"))
 
99
  # st.write(read_markdown("checkpoints.md"))
100
  st.write(read_markdown("acknowledgements.md"))
101
 
102
+ if state.model is None:
103
+ with st.spinner("Loading model..."):
104
+ state.model = load_model(checkpoints[0])
105
 
106
+ first_index = 25
107
  # Init Session State
108
  if state.image_file is None:
109
  state.image_file = dummy_data.loc[first_index, "image_file"]
 
114
  image = plt.imread(image_path)
115
  state.image = image
116
 
117
+ new_col1, new_col2 = st.beta_columns([5,5])
118
 
119
+
120
+ if new_col2.button("Get a random example", help="Get a random example from one of the seeded examples."):
121
  sample = dummy_data.sample(1).reset_index()
122
  state.image_file = sample.loc[0, "image_file"]
123
  state.caption = sample.loc[0, "caption"].strip("- ")
 
127
  image = plt.imread(image_path)
128
  state.image = image
129
 
 
 
 
 
 
 
 
 
130
  transformed_image = get_transformed_image(state.image)
131
 
 
132
  # Display Image
133
  new_col1.image(state.image, use_column_width="always")
 
 
134
  # Display Reference Caption
135
  new_col2.write("**Reference Caption**: " + state.caption)
136
  new_col2.markdown(
 
147
  help="The language in which caption is to be generated."
148
  )
149
 
 
 
150
  sequence = ['']
151
  if new_col2.button("Generate Caption", help="Generate a caption in the specified language."):
152
  with st.spinner("Generating Sequence..."):