Commit 6b2ffd3
cwkuo committed · Parent(s): ef2dc13

tune some default params

Files changed:
- app.py (+33, -41)
- examples/diamond_head.jpg (+0, -3)

app.py CHANGED
@@ -159,7 +159,7 @@ def retrieve_knowledge(image):
 
 
 @torch.inference_mode()
-def generate(state: Conversation, temperature, top_p, max_new_tokens
+def generate(state: Conversation, temperature, top_p, max_new_tokens):
     if state.skip_next:  # This generate call is skipped due to invalid inputs
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 3 + knwl_unchange
         return
@@ -172,37 +172,33 @@ def generate(state: Conversation, temperature, top_p, max_new_tokens, add_knwl,
 
     # retrieve and visualize knowledge
     image = state.get_images(return_pil=True)[0]
-   [deleted lines 175-201 are not shown in the rendered diff]
-        knwl_vis = tuple(knwl_img + knwl_txt)
-    else:
-        knwl_embd = None
-        knwl_vis = knwl_none
+    knwl_embd, knwl = retrieve_knowledge(image)
+    knwl_img, knwl_txt, idx = [None, ] * 15, ["", ] * 15, 0
+    for query_type, knwl_pos in (("whole", 1), ("five", 5), ("nine", 9)):
+        if query_type == "whole":
+            images = [image, ]
+        elif query_type == "five":
+            images = five_crop(image)
+        elif query_type == "nine":
+            images = nine_crop(image)
+
+        for pos in range(knwl_pos):
+            try:
+                txt = ""
+                for k, v in knwl[query_type][pos].items():
+                    v = ", ".join([vi.replace("_", " ") for vi in v])
+                    txt += f"**[{k.upper()}]:** {v}\n\n"
+                knwl_txt[idx] += txt
+
+                img = images[pos]
+                img = query_trans.transforms[0](img)
+                img = query_trans.transforms[1](img)
+                img = query_trans.transforms[2](img)
+                knwl_img[idx] = img
+            except KeyError:
+                pass
+            idx += 1
+    knwl_vis = tuple(knwl_img + knwl_txt)
     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 3 + knwl_vis
 
     # generate output
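The new visualization loop queries knowledge at three granularities: the whole image, a five-way crop, and a nine-way crop, filling 15 image/text slots for the knowledge panel; each crop is then passed through the first three `query_trans` transforms. The `five_crop`/`nine_crop` helpers are defined elsewhere in the Space and are not part of this diff, so the sketch below is only an assumed reading of `nine_crop` (a 3×3 grid of equal tiles), for orientation:

```python
from PIL import Image


def nine_crop(image: Image.Image) -> list[Image.Image]:
    """Assumed behavior: split the image into a 3x3 grid of equal tiles,
    returned row by row (9 crops total). Not the Space's actual helper."""
    w, h = image.size
    tile_w, tile_h = w // 3, h // 3
    crops = []
    for row in range(3):
        for col in range(3):
            left, top = col * tile_w, row * tile_h
            crops.append(image.crop((left, top, left + tile_w, top + tile_h)))
    return crops
```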
@@ -217,7 +213,7 @@ def generate(state: Conversation, temperature, top_p, max_new_tokens, add_knwl,
         target=gptk_model.generate,
         kwargs=dict(
             samples=samples,
-            use_nucleus_sampling=
+            use_nucleus_sampling=(temperature > 0.001),
             max_length=min(int(max_new_tokens), 1024),
             top_p=float(top_p),
             temperature=float(temperature),
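The sampling switch is now derived from the temperature slider instead of the removed "Sampling" checkbox: `use_nucleus_sampling=(temperature > 0.001)`. With the new default temperature of 0.0 the demo therefore decodes deterministically, and moving the slider above zero re-enables nucleus sampling with the chosen `top_p`. A rough sketch of the resulting kwargs; `build_generate_kwargs` is a hypothetical helper name used only for illustration:

```python
def build_generate_kwargs(samples, temperature: float, top_p: float, max_new_tokens: int) -> dict:
    # Hypothetical helper mirroring the kwargs assembled in this commit.
    # A (near-)zero temperature is treated as "no sampling"; top_p and
    # temperature are still forwarded but only matter when sampling is on.
    return dict(
        samples=samples,
        use_nucleus_sampling=(temperature > 0.001),
        max_length=min(int(max_new_tokens), 1024),
        top_p=float(top_p),
        temperature=float(temperature),
    )
```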
@@ -270,7 +266,6 @@ def build_demo():
         gr.Examples(examples=[
             ["examples/mona_lisa.jpg", "Discuss the historical impact and the significance of this painting in the art world."],
             ["examples/mona_lisa_dog.jpg", "Describe this photo in detail."],
-            ["examples/diamond_head.jpg", "What is the name of this famous sight in the photo?"],
             ["examples/horseshoe_bend.jpg", "What are the possible reasons of the formation of this sight?"],
         ], inputs=[imagebox, textbox])
 
@@ -286,10 +281,7 @@ def build_demo():
             clear_btn = gr.Button(value="🗑️ Clear", interactive=False, scale=1)
 
         with gr.Accordion("Parameters", open=True):
-
-            add_knwl = gr.Checkbox(value=True, interactive=True, label="Knowledge")
-            do_sampling = gr.Checkbox(value=False, interactive=True, label="Sampling")
-            temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, interactive=True, label="Temperature",)
+            temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True, label="Temperature",)
             top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
             max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
 
@@ -318,7 +310,7 @@ def build_demo():
         regenerate, [state], [state, chatbot, textbox, imagebox] + btn_list
     ).then(
         generate,
-        [state, temperature, top_p, max_output_tokens
+        [state, temperature, top_p, max_output_tokens],
         [state, chatbot] + btn_list + knwl_vis
     )
 
@@ -330,7 +322,7 @@ def build_demo():
         add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
     ).then(
         generate,
-        [state, temperature, top_p, max_output_tokens
+        [state, temperature, top_p, max_output_tokens],
         [state, chatbot] + btn_list + knwl_vis
     )
 
@@ -338,7 +330,7 @@ def build_demo():
         add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
     ).then(
         generate,
-        [state, temperature, top_p, max_output_tokens
+        [state, temperature, top_p, max_output_tokens],
         [state, chatbot] + btn_list + knwl_vis
     )
 
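All three event chains now feed only `[state, temperature, top_p, max_output_tokens]` into `generate`, since the `add_knwl` and `do_sampling` controls were dropped. Below is a self-contained sketch of this Gradio `.then()` chaining pattern; the `add_text` and `generate` handlers here are simplified stand-ins, not the Space's implementations:

```python
import gradio as gr


def add_text(history, text):
    # Record the user turn and clear the textbox.
    history = (history or []) + [[text, None]]
    return history, ""


def generate(history, temperature, top_p, max_new_tokens):
    # Stand-in for the model call; streams a reply like the Space's generator does.
    history[-1][1] = f"(decoding with temperature={temperature}, top_p={top_p}, max={max_new_tokens})"
    yield history


with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    textbox = gr.Textbox(placeholder="Ask about the image...")
    temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Temperature")
    top_p = gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="Top P")
    max_output_tokens = gr.Slider(0, 1024, value=512, step=64, label="Max output tokens")

    # Same chaining pattern as the commit: first update the conversation,
    # then stream the generation with the decoding sliders as extra inputs.
    textbox.submit(
        add_text, [chatbot, textbox], [chatbot, textbox]
    ).then(
        generate, [chatbot, temperature, top_p, max_output_tokens], [chatbot]
    )

demo.queue().launch()
```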
examples/diamond_head.jpg DELETED (binary file stored with Git LFS)