cwkuo committed
Commit 6b2ffd3
Parent: ef2dc13

tune some default params

Files changed (2)
  1. app.py +33 -41
  2. examples/diamond_head.jpg +0 -3
app.py CHANGED
@@ -159,7 +159,7 @@ def retrieve_knowledge(image):
 
 
 @torch.inference_mode()
-def generate(state: Conversation, temperature, top_p, max_new_tokens, add_knwl, do_sampling):
+def generate(state: Conversation, temperature, top_p, max_new_tokens):
     if state.skip_next: # This generate call is skipped due to invalid inputs
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 3 + knwl_unchange
         return
@@ -172,37 +172,33 @@ def generate(state: Conversation, temperature, top_p, max_new_tokens, add_knwl, do_sampling):
 
     # retrieve and visualize knowledge
     image = state.get_images(return_pil=True)[0]
-    if bool(add_knwl):
-        knwl_embd, knwl = retrieve_knowledge(image)
-        knwl_img, knwl_txt, idx = [None, ] * 15, ["", ] * 15, 0
-        for query_type, knwl_pos in (("whole", 1), ("five", 5), ("nine", 9)):
-            if query_type == "whole":
-                images = [image, ]
-            elif query_type == "five":
-                images = five_crop(image)
-            elif query_type == "nine":
-                images = nine_crop(image)
-
-            for pos in range(knwl_pos):
-                try:
-                    txt = ""
-                    for k, v in knwl[query_type][pos].items():
-                        v = ", ".join([vi.replace("_", " ") for vi in v])
-                        txt += f"**[{k.upper()}]:** {v}\n\n"
-                    knwl_txt[idx] += txt
-
-                    img = images[pos]
-                    img = query_trans.transforms[0](img)
-                    img = query_trans.transforms[1](img)
-                    img = query_trans.transforms[2](img)
-                    knwl_img[idx] = img
-                except KeyError:
-                    pass
-                idx += 1
-        knwl_vis = tuple(knwl_img + knwl_txt)
-    else:
-        knwl_embd = None
-        knwl_vis = knwl_none
+    knwl_embd, knwl = retrieve_knowledge(image)
+    knwl_img, knwl_txt, idx = [None, ] * 15, ["", ] * 15, 0
+    for query_type, knwl_pos in (("whole", 1), ("five", 5), ("nine", 9)):
+        if query_type == "whole":
+            images = [image, ]
+        elif query_type == "five":
+            images = five_crop(image)
+        elif query_type == "nine":
+            images = nine_crop(image)
+
+        for pos in range(knwl_pos):
+            try:
+                txt = ""
+                for k, v in knwl[query_type][pos].items():
+                    v = ", ".join([vi.replace("_", " ") for vi in v])
+                    txt += f"**[{k.upper()}]:** {v}\n\n"
+                knwl_txt[idx] += txt
+
+                img = images[pos]
+                img = query_trans.transforms[0](img)
+                img = query_trans.transforms[1](img)
+                img = query_trans.transforms[2](img)
+                knwl_img[idx] = img
+            except KeyError:
+                pass
+            idx += 1
+    knwl_vis = tuple(knwl_img + knwl_txt)
     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 3 + knwl_vis
 
     # generate output
@@ -217,7 +213,7 @@ def generate(state: Conversation, temperature, top_p, max_new_tokens, add_knwl, do_sampling):
         target=gptk_model.generate,
         kwargs=dict(
             samples=samples,
-            use_nucleus_sampling=bool(do_sampling),
+            use_nucleus_sampling=(temperature > 0.001),
             max_length=min(int(max_new_tokens), 1024),
             top_p=float(top_p),
             temperature=float(temperature),
@@ -270,7 +266,6 @@ def build_demo():
             gr.Examples(examples=[
                 ["examples/mona_lisa.jpg", "Discuss the historical impact and the significance of this painting in the art world."],
                 ["examples/mona_lisa_dog.jpg", "Describe this photo in detail."],
-                ["examples/diamond_head.jpg", "What is the name of this famous sight in the photo?"],
                 ["examples/horseshoe_bend.jpg", "What are the possible reasons of the formation of this sight?"],
             ], inputs=[imagebox, textbox])
 
@@ -286,10 +281,7 @@ def build_demo():
                 clear_btn = gr.Button(value="🗑️ Clear", interactive=False, scale=1)
 
             with gr.Accordion("Parameters", open=True):
-                with gr.Row():
-                    add_knwl = gr.Checkbox(value=True, interactive=True, label="Knowledge")
-                    do_sampling = gr.Checkbox(value=False, interactive=True, label="Sampling")
-                temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, interactive=True, label="Temperature",)
+                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True, label="Temperature",)
                 top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
                 max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
 
@@ -318,7 +310,7 @@ def build_demo():
            regenerate, [state], [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            generate,
-           [state, temperature, top_p, max_output_tokens, add_knwl, do_sampling],
+           [state, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list + knwl_vis
        )
 
@@ -330,7 +322,7 @@ def build_demo():
            add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            generate,
-           [state, temperature, top_p, max_output_tokens, add_knwl, do_sampling],
+           [state, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list + knwl_vis
        )
 
@@ -338,7 +330,7 @@ def build_demo():
            add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            generate,
-           [state, temperature, top_p, max_output_tokens, add_knwl, do_sampling],
+           [state, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list + knwl_vis
        )
 
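Note: the practical effect of the app.py changes is that knowledge retrieval always runs, the Knowledge/Sampling checkboxes are gone, and nucleus sampling is now inferred from the temperature slider, whose default drops to 0.0 (greedy decoding). A minimal sketch of that mapping follows; the helper name build_generation_kwargs is hypothetical and not part of app.py, only the expression `temperature > 0.001` and the slider defaults come from the diff.

# Hypothetical helper illustrating the new default decoding behaviour.
def build_generation_kwargs(temperature: float, top_p: float, max_new_tokens: int) -> dict:
    return dict(
        # temperature == 0.0 (the new slider default) keeps sampling off;
        # moving the slider above the 0.001 threshold enables nucleus sampling.
        use_nucleus_sampling=(temperature > 0.001),
        max_length=min(int(max_new_tokens), 1024),
        top_p=float(top_p),
        temperature=float(temperature),
    )

if __name__ == "__main__":
    # With the new UI defaults (temperature=0.0, top_p=0.7, max output tokens=512):
    print(build_generation_kwargs(0.0, 0.7, 512))
    # -> {'use_nucleus_sampling': False, 'max_length': 512, 'top_p': 0.7, 'temperature': 0.0}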
 
examples/diamond_head.jpg DELETED

Git LFS Details

  • SHA256: 33d2f8ebdcde47a8a3cef6af8baa13cbbfc148a25dc869c081f0c4bc4d5522b1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.13 MB
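For context, a self-contained sketch of how the trimmed-down Parameters panel wires into generate() after this commit. Only the three sliders and the input list [state, temperature, top_p, max_output_tokens] mirror app.py; the other component names are assumed for illustration, and the real demo also drives a chatbot, knowledge panels, and button states that are omitted here.

import gradio as gr

def generate(state, temperature, top_p, max_new_tokens):
    # Stand-in for the real generate(); just echoes the received parameters.
    return f"temperature={temperature}, top_p={top_p}, max_new_tokens={max_new_tokens}"

with gr.Blocks() as demo:
    state = gr.State([])
    textbox = gr.Textbox(label="Prompt")
    output = gr.Textbox(label="Output")

    with gr.Accordion("Parameters", open=True):
        # Same ranges and defaults as the commit: temperature is now 0.0-1.0, default 0.0.
        temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, label="Top P")
        max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, label="Max output tokens")

    submit_btn = gr.Button("Send")
    # Same chaining pattern as app.py: the click handler runs first, then
    # generate() receives the reduced input list (no add_knwl / do_sampling).
    submit_btn.click(
        lambda s, t: (s + [t], ""), [state, textbox], [state, textbox]
    ).then(
        generate,
        [state, temperature, top_p, max_output_tokens],
        [output],
    )

if __name__ == "__main__":
    demo.launch()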