Add detailed captioning, increase `max_new_tokens`, and fix escape character

#8
by merve - opened
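The previous caption prompt ended in `\\n`, which Python reads as a literal backslash followed by the letter `n` rather than a newline, so the model saw two stray characters at the end of every captioning prompt. This PR fixes that escape, raises `max_new_tokens` from 40 to 50 so longer captions are less likely to be cut off, and adds a detailed-captioning mode toggled by a checkbox. A minimal sketch of the escape-character difference (plain Python, independent of this Space's code):

```python
old_prompt = "Generate a coco-style caption.\\n"  # ends in '\' + 'n': two characters
new_prompt = "Generate a coco-style caption.\n"   # ends in a single newline character

print(list(old_prompt)[-2:])  # ['\\', 'n']
print(list(new_prompt)[-1:])  # ['\n']
print(len(old_prompt) - len(new_prompt))  # 1
```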
Files changed (1)
  1. app.py +16 -8
app.py CHANGED
```diff
@@ -9,11 +9,13 @@ model_id = "adept/fuyu-8b"
 dtype = torch.bfloat16
 device = "cuda"
 
+
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = FuyuForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=dtype)
 processor = FuyuProcessor(image_processor=FuyuImageProcessor(), tokenizer=tokenizer)
 
-caption_prompt = "Generate a coco-style caption.\\n"
+CAPTION_PROMPT = "Generate a coco-style caption.\n"
+DETAILED_CAPTION_PROMPT = "What is happening in this image?"
 
 def resize_to_max(image, max_width=1080, max_height=1080):
     width, height = image.size
@@ -33,12 +35,16 @@ def predict(image, prompt):
     model_inputs = processor(text=prompt, images=[image])
     model_inputs = {k: v.to(dtype=dtype if torch.is_floating_point(v) else v.dtype, device=device) for k,v in model_inputs.items()}
 
-    generation_output = model.generate(**model_inputs, max_new_tokens=40)
+    generation_output = model.generate(**model_inputs, max_new_tokens=50)
     prompt_len = model_inputs["input_ids"].shape[-1]
     return tokenizer.decode(generation_output[0][prompt_len:], skip_special_tokens=True)
 
-def caption(image):
-    return predict(image, caption_prompt)
+def caption(image, detailed_captioning):
+    if detailed_captioning:
+        caption_prompt = DETAILED_CAPTION_PROMPT
+    else:
+        caption_prompt = CAPTION_PROMPT
+    return predict(image, caption_prompt).lstrip()
 
 def set_example_image(example: list) -> dict:
     return gr.Image.update(value=example[0])
@@ -88,20 +94,22 @@ with gr.Blocks(css=css) as demo:
 
     with gr.Tab("Image Captioning"):
         with gr.Row():
-            captioning_input = gr.Image(label="Upload your Image", type="pil")
+            with gr.Column():
+                captioning_input = gr.Image(label="Upload your Image", type="pil")
+                detailed_captioning_checkbox = gr.Checkbox(label="Enable detailed captioning")
             captioning_output = gr.Textbox(label="Output")
         captioning_btn = gr.Button("Generate Caption")
 
         gr.Examples(
-            [["assets/captioning_example_1.png"], ["assets/captioning_example_2.png"]],
-            inputs = [captioning_input],
+            [["assets/captioning_example_1.png", False], ["assets/captioning_example_2.png", True]],
+            inputs = [captioning_input, detailed_captioning_checkbox],
             outputs = [captioning_output],
             fn=caption,
             cache_examples=True,
             label='Click on any Examples below to get captioning results quickly πŸ‘‡'
         )
 
-    captioning_btn.click(fn=caption, inputs=captioning_input, outputs=captioning_output)
+    captioning_btn.click(fn=caption, inputs=[captioning_input, detailed_captioning_checkbox], outputs=captioning_output)
     vqa_btn.click(fn=predict, inputs=[image_input, text_input], outputs=vqa_output)
 
 
```
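For quick local testing, the new `caption` entry point can be exercised outside the Gradio UI, roughly like this (a sketch; it assumes the Space's model, processor, and functions above are already loaded, and uses one of the example images bundled with the Space):

```python
from PIL import Image

# One of the example images already shipped in the Space's assets folder.
image = Image.open("assets/captioning_example_1.png").convert("RGB")

# Short COCO-style caption vs. the new detailed-captioning mode.
print(caption(image, detailed_captioning=False))
print(caption(image, detailed_captioning=True))
```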