ttengwang committed on
Commit ccb14a3
1 Parent(s): cd2f644

support "segment everything in a paragraph"

app.py CHANGED
@@ -1,6 +1,5 @@
1
  import os
2
  import json
3
- import PIL
4
  import gradio as gr
5
  import numpy as np
6
  from gradio import processing_utils
@@ -11,7 +10,7 @@ import functools
11
 
12
  from caption_anything.model import CaptionAnything
13
  from caption_anything.utils.image_editing_utils import create_bubble_frame
14
- from caption_anything.utils.utils import mask_painter, seg_model_map, prepare_segmenter
15
  from caption_anything.utils.parser import parse_augment
16
  from caption_anything.captioner import build_captioner
17
  from caption_anything.text_refiner import build_text_refiner
@@ -23,6 +22,7 @@ from segment_anything import sam_model_registry
23
  args = parse_augment()
24
  args.segmenter = "huge"
25
  args.segmenter_checkpoint = "sam_vit_h_4b8939.pth"
 
26
  if args.segmenter_checkpoint is None:
27
  _, segmenter_checkpoint = prepare_segmenter(args.segmenter)
28
  else:
@@ -53,9 +53,7 @@ class ImageSketcher(gr.Image):
53
  mask = np.zeros((height, width, 4), dtype=np.uint8)
54
  mask[..., -1] = 255
55
  mask = self.postprocess(mask)
56
-
57
  x['mask'] = mask
58
-
59
  return super().preprocess(x)
60
 
61
 
@@ -74,16 +72,19 @@ def init_openai_api_key(api_key=""):
74
  if api_key and len(api_key) > 30:
75
  try:
76
  text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
77
- text_refiner.llm('hi') # test
78
  visual_chatgpt = ConversationBot(shared_chatbot_tools, api_key)
79
  except:
80
  text_refiner = None
81
  visual_chatgpt = None
82
  openai_available = text_refiner is not None
83
- return gr.update(visible=openai_available), gr.update(visible=openai_available), gr.update(
84
- visible=openai_available), gr.update(visible=True), gr.update(visible=True), gr.update(
85
- visible=True), text_refiner, visual_chatgpt
86
-
87
 
88
  def get_click_prompt(chat_input, click_state, click_mode):
89
  inputs = json.loads(chat_input)
@@ -130,18 +131,15 @@ def chat_input_callback(*args):
130
  state = state + [(chat_input, response)]
131
  return state, state
132
 
 
 
133
  def upload_callback(image_input, state, visual_chatgpt=None):
134
 
135
  if isinstance(image_input, dict): # if upload from sketcher_input, input contains image and mask
136
  image_input, mask = image_input['image'], image_input['mask']
137
 
138
  click_state = [[], [], []]
139
- res = 1024
140
- width, height = image_input.size
141
- ratio = min(1.0 * res / max(width, height), 1.0)
142
- if ratio < 1.0:
143
- image_input = image_input.resize((int(width * ratio), int(height * ratio)))
144
- print('Scaling input image to {}'.format(image_input.size))
145
 
146
  model = build_caption_anything_with_models(
147
  args,
@@ -159,8 +157,8 @@ def upload_callback(image_input, state, visual_chatgpt=None):
159
  new_image_path = get_new_image_name('chat_image', func_name='upload')
160
  image_input.save(new_image_path)
161
  visual_chatgpt.current_image = new_image_path
162
- img_caption, _ = model.captioner.inference_seg(image_input)
163
- Human_prompt = f'\nHuman: provide a new figure with path {new_image_path}. The description is: {img_caption}. This information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
164
  AI_prompt = "Received."
165
  visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt
166
  visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt
@@ -201,11 +199,10 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
201
  model.setup(image_embedding, original_size, input_size, is_image_set=True)
202
 
203
  enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
204
- out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)
205
 
206
  state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
207
  state = state + [(None, "raw_caption: {}".format(out['generated_captions']['raw_caption']))]
208
- wiki = out['generated_captions'].get('wiki', "")
209
  update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
210
  text = out['generated_captions']['raw_caption']
211
  input_mask = np.array(out['mask'].convert('P'))
@@ -221,21 +218,22 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
221
  point_prompt = f'You should primarly use tools on the selected regional image (description: {text}, path: {new_crop_save_path}), which is a part of the whole image (path: {visual_chatgpt.current_image}). If human mentioned some objects not in the selected region, you can use tools on the whole image.'
222
  visual_chatgpt.point_prompt = point_prompt
223
 
224
- yield state, state, click_state, image_input, wiki
225
  if not args.disable_gpt and model.text_refiner:
226
  refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
227
  enable_wiki=enable_wiki)
228
  # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
229
  new_cap = refined_caption['caption']
230
- wiki = refined_caption['wiki']
 
231
  state = state + [(None, f"caption: {new_cap}")]
232
  refined_image_input = create_bubble_frame(origin_image_input, new_cap, (click_index[0], click_index[1]),
233
  input_mask,
234
  input_points=input_points, input_labels=input_labels)
235
- yield state, state, click_state, refined_image_input, wiki
236
 
237
 
238
- def get_sketch_prompt(mask: PIL.Image.Image):
239
  """
240
  Get the prompt for the sketcher.
241
  TODO: This is a temporary solution. We should cluster the sketch and get the bounding box of each cluster.
@@ -282,12 +280,11 @@ def inference_traject(sketcher_image, enable_wiki, language, sentiment, factuali
282
  model.setup(image_embedding, original_size, input_size, is_image_set=True)
283
 
284
  enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
285
- out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)
286
 
287
  # Update components and states
288
  state.append((f'Box: {boxes}', None))
289
  state.append((None, f'raw_caption: {out["generated_captions"]["raw_caption"]}'))
290
- wiki = out['generated_captions'].get('wiki', "")
291
  text = out['generated_captions']['raw_caption']
292
  input_mask = np.array(out['mask'].convert('P'))
293
  image_input = mask_painter(np.array(image_input), input_mask)
@@ -297,18 +294,19 @@ def inference_traject(sketcher_image, enable_wiki, language, sentiment, factuali
297
  fake_click_index = (int((boxes[0][0] + boxes[0][2]) / 2), int((boxes[0][1] + boxes[0][3]) / 2))
298
  image_input = create_bubble_frame(image_input, text, fake_click_index, input_mask)
299
 
300
- yield state, state, image_input, wiki
301
 
302
  if not args.disable_gpt and model.text_refiner:
303
  refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
304
  enable_wiki=enable_wiki)
305
 
306
  new_cap = refined_caption['caption']
307
- wiki = refined_caption['wiki']
 
308
  state = state + [(None, f"caption: {new_cap}")]
309
  refined_image_input = create_bubble_frame(origin_image_input, new_cap, fake_click_index, input_mask)
310
 
311
- yield state, state, refined_image_input, wiki
312
 
313
  def clear_chat_memory(visual_chatgpt, keep_global=False):
314
  if visual_chatgpt is not None:
@@ -319,7 +317,26 @@ def clear_chat_memory(visual_chatgpt, keep_global=False):
319
  else:
320
  visual_chatgpt.current_image = None
321
  visual_chatgpt.global_prompt = ""
322
-
323
  def get_style():
324
  current_version = version.parse(gr.__version__)
325
  if current_version <= version.parse('3.24.1'):
@@ -400,7 +417,7 @@ def create_ui():
400
  with gr.Row():
401
  submit_button_sketcher = gr.Button(value="Submit", interactive=True)
402
 
403
- with gr.Column(visible=False) as modules_need_gpt:
404
  with gr.Row(scale=1.0):
405
  language = gr.Dropdown(
406
  ['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"],
@@ -431,26 +448,31 @@ def create_ui():
431
  value="No",
432
  label="Enable Wiki",
433
  interactive=True)
434
- with gr.Column(visible=True) as modules_not_need_gpt3:
435
- gr.Examples(
436
- examples=examples,
437
- inputs=[example_image],
438
- )
439
  with gr.Column(scale=0.5):
440
- openai_api_key = gr.Textbox(
441
- placeholder="Input openAI API key",
442
- show_label=False,
443
- label="OpenAI API Key",
444
- lines=1,
445
- type="password")
446
- with gr.Row(scale=0.5):
447
- enable_chatGPT_button = gr.Button(value="Run with ChatGPT", interactive=True, variant='primary')
448
- disable_chatGPT_button = gr.Button(value="Run without ChatGPT (Faster)", interactive=True,
449
- variant='primary')
450
- with gr.Column(visible=False) as modules_need_gpt2:
451
- wiki_output = gr.Textbox(lines=5, label="Wiki", max_lines=5)
452
- with gr.Column(visible=False) as modules_not_need_gpt2:
453
- chatbot = gr.Chatbot(label="Chat about Selected Object", ).style(height=550, scale=0.5)
454
  with gr.Column(visible=False) as modules_need_gpt3:
455
  chat_input = gr.Textbox(show_label=False, placeholder="Enter text and press Enter").style(
456
  container=False)
@@ -459,36 +481,38 @@ def create_ui():
459
  submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
460
 
461
  openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key],
462
- outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt,
463
- modules_not_need_gpt2, modules_not_need_gpt3, text_refiner, visual_chatgpt])
464
  enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key],
465
- outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3,
466
  modules_not_need_gpt,
467
- modules_not_need_gpt2, modules_not_need_gpt3, text_refiner, visual_chatgpt])
468
- disable_chatGPT_button.click(init_openai_api_key,
469
- outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3,
470
  modules_not_need_gpt,
471
- modules_not_need_gpt2, modules_not_need_gpt3, text_refiner, visual_chatgpt])
472
-
473
  enable_chatGPT_button.click(
474
  lambda: (None, [], [], [[], [], []], "", "", ""),
475
  [],
476
- [image_input, chatbot, state, click_state, wiki_output, origin_image],
477
  queue=False,
478
  show_progress=False
479
  )
480
  openai_api_key.submit(
481
  lambda: (None, [], [], [[], [], []], "", "", ""),
482
  [],
483
- [image_input, chatbot, state, click_state, wiki_output, origin_image],
484
  queue=False,
485
  show_progress=False
486
  )
 
 
487
 
488
  clear_button_click.click(
489
- lambda x: ([[], [], []], x, ""),
490
  [origin_image],
491
- [click_state, image_input, wiki_output],
492
  queue=False,
493
  show_progress=False
494
  )
@@ -496,7 +520,7 @@ def create_ui():
496
  clear_button_image.click(
497
  lambda: (None, [], [], [[], [], []], "", "", ""),
498
  [],
499
- [image_input, chatbot, state, click_state, wiki_output, origin_image],
500
  queue=False,
501
  show_progress=False
502
  )
@@ -513,7 +537,7 @@ def create_ui():
513
  image_input.clear(
514
  lambda: (None, [], [], [[], [], []], "", "", ""),
515
  [],
516
- [image_input, chatbot, state, click_state, wiki_output, origin_image],
517
  queue=False,
518
  show_progress=False
519
  )
@@ -544,7 +568,7 @@ def create_ui():
544
  origin_image, point_prompt, click_mode, enable_wiki, language, sentiment, factuality, length,
545
  image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt
546
  ],
547
- outputs=[chatbot, state, click_state, image_input, wiki_output],
548
  show_progress=False, queue=True
549
  )
550
 
@@ -554,7 +578,7 @@ def create_ui():
554
  sketcher_input, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
555
  original_size, input_size, text_refiner
556
  ],
557
- outputs=[chatbot, state, sketcher_input, wiki_output],
558
  show_progress=False, queue=True
559
  )
560
 
 
1
  import os
2
  import json
 
3
  import gradio as gr
4
  import numpy as np
5
  from gradio import processing_utils
 
10
 
11
  from caption_anything.model import CaptionAnything
12
  from caption_anything.utils.image_editing_utils import create_bubble_frame
13
+ from caption_anything.utils.utils import mask_painter, seg_model_map, prepare_segmenter, image_resize
14
  from caption_anything.utils.parser import parse_augment
15
  from caption_anything.captioner import build_captioner
16
  from caption_anything.text_refiner import build_text_refiner
 
22
  args = parse_augment()
23
  args.segmenter = "huge"
24
  args.segmenter_checkpoint = "sam_vit_h_4b8939.pth"
25
+
26
  if args.segmenter_checkpoint is None:
27
  _, segmenter_checkpoint = prepare_segmenter(args.segmenter)
28
  else:
 
53
  mask = np.zeros((height, width, 4), dtype=np.uint8)
54
  mask[..., -1] = 255
55
  mask = self.postprocess(mask)
 
56
  x['mask'] = mask
 
57
  return super().preprocess(x)
58
 
59
 
 
72
  if api_key and len(api_key) > 30:
73
  try:
74
  text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
75
+ assert len(text_refiner.llm('hi')) > 0 # test
76
  visual_chatgpt = ConversationBot(shared_chatbot_tools, api_key)
77
  except:
78
  text_refiner = None
79
  visual_chatgpt = None
80
  openai_available = text_refiner is not None
81
+ if openai_available:
82
+ return [gr.update(visible=True)]*6 + [gr.update(visible=False)]*2 + [text_refiner, visual_chatgpt, None]
83
+ else:
84
+ return [gr.update(visible=False)]*6 + [gr.update(visible=True)]*2 + [text_refiner, visual_chatgpt, 'Your OpenAI API Key is not available']
85
+
86
+ def init_wo_openai_api_key():
87
+ return [gr.update(visible=False)]*4 + [gr.update(visible=True)]*2 + [gr.update(visible=False)]*2 + [None, None, None]
88
 
89
  def get_click_prompt(chat_input, click_state, click_mode):
90
  inputs = json.loads(chat_input)
 
131
  state = state + [(chat_input, response)]
132
  return state, state
133
 
134
+
135
+
136
  def upload_callback(image_input, state, visual_chatgpt=None):
137
 
138
  if isinstance(image_input, dict): # if upload from sketcher_input, input contains image and mask
139
  image_input, mask = image_input['image'], image_input['mask']
140
 
141
  click_state = [[], [], []]
142
+ image_input = image_resize(image_input, res=1024)
143
 
144
  model = build_caption_anything_with_models(
145
  args,
 
157
  new_image_path = get_new_image_name('chat_image', func_name='upload')
158
  image_input.save(new_image_path)
159
  visual_chatgpt.current_image = new_image_path
160
+ img_caption = model.captioner.inference(image_input, filter=False, args={'text_prompt':''})['caption']
161
+ Human_prompt = f'\nHuman: The description of the image with path {new_image_path} is: {img_caption}. This information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
162
  AI_prompt = "Received."
163
  visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt
164
  visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt
 
199
  model.setup(image_embedding, original_size, input_size, is_image_set=True)
200
 
201
  enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
202
+ out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki, verbose=True, args={'clip_filter': False})[0]
203
 
204
  state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
205
  state = state + [(None, "raw_caption: {}".format(out['generated_captions']['raw_caption']))]
 
206
  update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
207
  text = out['generated_captions']['raw_caption']
208
  input_mask = np.array(out['mask'].convert('P'))
 
218
  point_prompt = f'You should primarly use tools on the selected regional image (description: {text}, path: {new_crop_save_path}), which is a part of the whole image (path: {visual_chatgpt.current_image}). If human mentioned some objects not in the selected region, you can use tools on the whole image.'
219
  visual_chatgpt.point_prompt = point_prompt
220
 
221
+ yield state, state, click_state, image_input
222
  if not args.disable_gpt and model.text_refiner:
223
  refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
224
  enable_wiki=enable_wiki)
225
  # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
226
  new_cap = refined_caption['caption']
227
+ if refined_caption['wiki']:
228
+ state = state + [(None, "Wiki: {}".format(refined_caption['wiki']))]
229
  state = state + [(None, f"caption: {new_cap}")]
230
  refined_image_input = create_bubble_frame(origin_image_input, new_cap, (click_index[0], click_index[1]),
231
  input_mask,
232
  input_points=input_points, input_labels=input_labels)
233
+ yield state, state, click_state, refined_image_input
234
 
235
 
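A note on the handler shape used by inference_click and inference_traject: they are generator functions, and because they are registered with queue=True, Gradio sends the first yield (the raw caption and painted mask) to the browser immediately and the second yield (the GPT-refined caption) once the text refiner returns. A minimal, self-contained sketch of that pattern with hypothetical component names, not code from this commit:

    import gradio as gr

    def describe(image):
        # First update: fast, local result shown right away.
        yield "raw caption: a quick draft"
        # ... a slow refinement step would run here ...
        # Second update: replaces the first once refinement finishes.
        yield "refined caption: a polished description"

    with gr.Blocks() as demo:
        img = gr.Image(type="pil")
        out = gr.Textbox()
        btn = gr.Button("Describe")
        btn.click(describe, inputs=[img], outputs=[out], queue=True)

    demo.queue().launch()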
236
+ def get_sketch_prompt(mask: Image.Image):
237
  """
238
  Get the prompt for the sketcher.
239
  TODO: This is a temporary solution. We should cluster the sketch and get the bounding box of each cluster.
 
280
  model.setup(image_embedding, original_size, input_size, is_image_set=True)
281
 
282
  enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
283
+ out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)[0]
284
 
285
  # Update components and states
286
  state.append((f'Box: {boxes}', None))
287
  state.append((None, f'raw_caption: {out["generated_captions"]["raw_caption"]}'))
 
288
  text = out['generated_captions']['raw_caption']
289
  input_mask = np.array(out['mask'].convert('P'))
290
  image_input = mask_painter(np.array(image_input), input_mask)
 
294
  fake_click_index = (int((boxes[0][0] + boxes[0][2]) / 2), int((boxes[0][1] + boxes[0][3]) / 2))
295
  image_input = create_bubble_frame(image_input, text, fake_click_index, input_mask)
296
 
297
+ yield state, state, image_input
298
 
299
  if not args.disable_gpt and model.text_refiner:
300
  refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
301
  enable_wiki=enable_wiki)
302
 
303
  new_cap = refined_caption['caption']
304
+ if refined_caption['wiki']:
305
+ state = state + [(None, "Wiki: {}".format(refined_caption['wiki']))]
306
  state = state + [(None, f"caption: {new_cap}")]
307
  refined_image_input = create_bubble_frame(origin_image_input, new_cap, fake_click_index, input_mask)
308
 
309
+ yield state, state, refined_image_input
310
 
311
  def clear_chat_memory(visual_chatgpt, keep_global=False):
312
  if visual_chatgpt is not None:
 
317
  else:
318
  visual_chatgpt.current_image = None
319
  visual_chatgpt.global_prompt = ""
320
+
321
+ def cap_everything(image_input, visual_chatgpt, text_refiner):
322
+
323
+ model = build_caption_anything_with_models(
324
+ args,
325
+ api_key="",
326
+ captioner=shared_captioner,
327
+ sam_model=shared_sam_model,
328
+ text_refiner=text_refiner,
329
+ session_id=iface.app_id
330
+ )
331
+ paragraph = model.inference_cap_everything(image_input, verbose=True)
332
+ # state = state + [(None, f"Caption Everything: {paragraph}")]
333
+ Human_prompt = f'\nThe description of the image with path {visual_chatgpt.current_image} is:\n{paragraph}\nThis information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
334
+ AI_prompt = "Received."
335
+ visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt
336
+ visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt
337
+ return paragraph
338
+
339
+
340
  def get_style():
341
  current_version = version.parse(gr.__version__)
342
  if current_version <= version.parse('3.24.1'):
 
417
  with gr.Row():
418
  submit_button_sketcher = gr.Button(value="Submit", interactive=True)
419
 
420
+ with gr.Column(visible=False) as modules_need_gpt1:
421
  with gr.Row(scale=1.0):
422
  language = gr.Dropdown(
423
  ['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"],
 
448
  value="No",
449
  label="Enable Wiki",
450
  interactive=True)
451
+ # with gr.Column(visible=True) as modules_not_need_gpt3:
452
+ gr.Examples(
453
+ examples=examples,
454
+ inputs=[example_image],
455
+ )
456
  with gr.Column(scale=0.5):
457
+ with gr.Column(visible=True) as module_key_input:
458
+ openai_api_key = gr.Textbox(
459
+ placeholder="Input openAI API key",
460
+ show_label=False,
461
+ label="OpenAI API Key",
462
+ lines=1,
463
+ type="password")
464
+ with gr.Row(scale=0.5):
465
+ enable_chatGPT_button = gr.Button(value="Run with ChatGPT", interactive=True, variant='primary')
466
+ disable_chatGPT_button = gr.Button(value="Run without ChatGPT (Faster)", interactive=True,
467
+ variant='primary')
468
+ with gr.Column(visible=False) as module_notification_box:
469
+ notification_box = gr.Textbox(lines=1, label="Notification", max_lines=5, show_label=False)
470
+ with gr.Column(visible=False) as modules_need_gpt2:
471
+ paragraph_output = gr.Textbox(lines=7, label="Describe Everything", max_lines=7)
472
+ with gr.Column(visible=False) as modules_need_gpt0:
473
+ cap_everything_button = gr.Button(value="Caption Everything in a Paragraph", interactive=True)
474
+ with gr.Column(visible=False) as modules_not_need_gpt2:
475
+ chatbot = gr.Chatbot(label="Chatbox", ).style(height=550, scale=0.5)
476
  with gr.Column(visible=False) as modules_need_gpt3:
477
  chat_input = gr.Textbox(show_label=False, placeholder="Enter text and press Enter").style(
478
  container=False)
 
481
  submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
482
 
483
  openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key],
484
+ outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt,
485
+ modules_not_need_gpt2, module_key_input, module_notification_box, text_refiner, visual_chatgpt, notification_box])
486
  enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key],
487
+ outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3,
488
  modules_not_need_gpt,
489
+ modules_not_need_gpt2, module_key_input, module_notification_box, text_refiner, visual_chatgpt, notification_box])
490
+ disable_chatGPT_button.click(init_wo_openai_api_key,
491
+ outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3,
492
  modules_not_need_gpt,
493
+ modules_not_need_gpt2, module_key_input, module_notification_box, text_refiner, visual_chatgpt, notification_box])
494
+
495
  enable_chatGPT_button.click(
496
  lambda: (None, [], [], [[], [], []], "", "", ""),
497
  [],
498
+ [image_input, chatbot, state, click_state, paragraph_output, origin_image],
499
  queue=False,
500
  show_progress=False
501
  )
502
  openai_api_key.submit(
503
  lambda: (None, [], [], [[], [], []], "", "", ""),
504
  [],
505
+ [image_input, chatbot, state, click_state, paragraph_output, origin_image],
506
  queue=False,
507
  show_progress=False
508
  )
509
+
510
+ cap_everything_button.click(cap_everything, [origin_image, visual_chatgpt, text_refiner], [paragraph_output])
511
 
512
  clear_button_click.click(
513
+ lambda x: ([[], [], []], x),
514
  [origin_image],
515
+ [click_state, image_input],
516
  queue=False,
517
  show_progress=False
518
  )
 
520
  clear_button_image.click(
521
  lambda: (None, [], [], [[], [], []], "", "", ""),
522
  [],
523
+ [image_input, chatbot, state, click_state, paragraph_output, origin_image],
524
  queue=False,
525
  show_progress=False
526
  )
 
537
  image_input.clear(
538
  lambda: (None, [], [], [[], [], []], "", "", ""),
539
  [],
540
+ [image_input, chatbot, state, click_state, paragraph_output, origin_image],
541
  queue=False,
542
  show_progress=False
543
  )
 
568
  origin_image, point_prompt, click_mode, enable_wiki, language, sentiment, factuality, length,
569
  image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt
570
  ],
571
+ outputs=[chatbot, state, click_state, image_input],
572
  show_progress=False, queue=True
573
  )
574
 
 
578
  sketcher_input, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
579
  original_size, input_size, text_refiner
580
  ],
581
+ outputs=[chatbot, state, sketcher_input],
582
  show_progress=False, queue=True
583
  )
584
 
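For readers unfamiliar with the pattern used by init_openai_api_key and init_wo_openai_api_key: a handler that returns a flat list of gr.update(...) objects updates the components listed in outputs=[...] positionally, entry i going to component i. A minimal sketch of that contract with made-up component names (only the length-check heuristic and the notification text are taken from the diff above):

    import gradio as gr

    def check_key(api_key):
        ok = api_key is not None and len(api_key) > 30
        # Entry i of the returned list is applied to outputs[i] below.
        return [gr.update(visible=ok),          # GPT-only controls
                gr.update(visible=not ok),      # key-input box
                gr.update(visible=not ok,
                          value=None if ok else "Your OpenAI API Key is not available")]

    with gr.Blocks() as demo:
        key = gr.Textbox(type="password", show_label=False, placeholder="Input openAI API key")
        gpt_controls = gr.Textbox(label="GPT-only control", visible=False)
        notification = gr.Textbox(label="Notification", visible=False)
        key.submit(check_key, inputs=[key], outputs=[gpt_controls, key, notification])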
caption_anything/captioner/base_captioner.py CHANGED
@@ -5,7 +5,7 @@ import json
5
  import pdb
6
  import cv2
7
  import numpy as np
8
- from typing import Union
9
  import time
10
  import clip
11
 
@@ -16,13 +16,10 @@ def boundary(inputs):
16
  col = inputs.shape[1]
17
  inputs = inputs.reshape(-1)
18
  lens = len(inputs)
19
-
20
  start = np.argmax(inputs)
21
  end = lens - 1 - np.argmax(np.flip(inputs))
22
-
23
  top = start // col
24
  bottom = end // col
25
-
26
  return top, bottom
27
 
28
 
@@ -84,27 +81,27 @@ class BaseCaptioner:
84
  self.enable_filter = enable_filter
85
  if enable_filter:
86
  self.filter, self.preprocess = clip.load('ViT-B/32', device)
87
- self.threshold = 0.2
88
 
89
  @torch.no_grad()
90
- def filter_caption(self, image: Union[np.ndarray, Image.Image, str], caption: str):
91
-
92
  image = load_image(image, return_type='pil')
93
-
94
  image = self.preprocess(image).unsqueeze(0).to(self.device) # (1, 3, 224, 224)
95
- text = clip.tokenize(caption).to(self.device) # (1, 77)
 
 
 
96
  image_features = self.filter.encode_image(image) # (1, 512)
97
- text_features = self.filter.encode_text(text) # (1, 512)
98
  image_features /= image_features.norm(dim=-1, keepdim=True)
99
  text_features /= text_features.norm(dim=-1, keepdim=True)
100
- similarity = torch.matmul(image_features, text_features.transpose(1, 0)).item()
101
- if similarity < self.threshold:
102
- print('There seems to be nothing where you clicked.')
103
- out = ""
104
  else:
105
- out = caption
106
  print(f'Clip score of the caption is {similarity}')
107
- return out
108
 
109
  def inference(self, image: Union[np.ndarray, Image.Image, str], filter: bool = False):
110
  raise NotImplementedError()
@@ -112,7 +109,7 @@ class BaseCaptioner:
112
  def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, filter: bool = False):
113
  raise NotImplementedError()
114
 
115
- def inference_box(self, image: Union[np.ndarray, Image.Image, str], box: Union[list, np.ndarray], filter=False):
116
  image = load_image(image, return_type="pil")
117
 
118
  if np.array(box).size == 4:
@@ -123,23 +120,31 @@ class BaseCaptioner:
123
  elif np.array(box).size == 8: # four corners of an irregular rectangle
124
  image_crop = cut_box(np.array(image), box)
125
 
126
- crop_save_path = f'result/crop_{time.time()}.png'
127
- Image.fromarray(image_crop).save(crop_save_path)
128
- print(f'croped image saved in {crop_save_path}')
129
- caption = self.inference(image_crop, filter)
130
- return caption, crop_save_path
131
-
132
- def inference_seg(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str] = None,
133
- crop_mode="w_bg", filter=False, disable_regular_box=False):
 
134
  if seg_mask is None:
135
  seg_mask = np.ones(image.size).astype(bool)
136
-
137
  image = load_image(image, return_type="pil")
138
  seg_mask = load_image(seg_mask, return_type="pil")
139
 
140
  seg_mask = seg_mask.resize(image.size)
141
  seg_mask = np.array(seg_mask) > 0
142
-
143
  if crop_mode == "wo_bg":
144
  image = np.array(image) * seg_mask[:, :, np.newaxis] + (1 - seg_mask[:, :, np.newaxis]) * 255
145
  image = np.uint8(image)
@@ -150,10 +155,13 @@ class BaseCaptioner:
150
  min_area_box = seg_to_box(seg_mask)
151
  else:
152
  min_area_box = new_seg_to_box(seg_mask)
153
- return self.inference_box(image, min_area_box, filter)
154
 
155
- def generate_seg_cropped_image(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str],
156
- crop_mode="w_bg", disable_regular_box=False):
157
  image = load_image(image, return_type="pil")
158
  seg_mask = load_image(seg_mask, return_type="pil")
159
 
 
5
  import pdb
6
  import cv2
7
  import numpy as np
8
+ from typing import Any, Union, List
9
  import time
10
  import clip
11
 
 
16
  col = inputs.shape[1]
17
  inputs = inputs.reshape(-1)
18
  lens = len(inputs)
 
19
  start = np.argmax(inputs)
20
  end = lens - 1 - np.argmax(np.flip(inputs))
 
21
  top = start // col
22
  bottom = end // col
 
23
  return top, bottom
24
 
25
 
 
81
  self.enable_filter = enable_filter
82
  if enable_filter:
83
  self.filter, self.preprocess = clip.load('ViT-B/32', device)
 
84
 
85
  @torch.no_grad()
86
+ def filter_caption(self, image: Union[np.ndarray, Image.Image, str], caption: str, reference_caption: List[str]=[]):
 
87
  image = load_image(image, return_type='pil')
 
88
  image = self.preprocess(image).unsqueeze(0).to(self.device) # (1, 3, 224, 224)
89
+ captions = [caption]
90
+ if len(reference_caption):
91
+ captions.extend(reference_caption)
92
+ text = clip.tokenize(captions).to(self.device) # (>1, 77)
93
  image_features = self.filter.encode_image(image) # (1, 512)
94
+ text_features = self.filter.encode_text(text) # # (>1, 512)
95
  image_features /= image_features.norm(dim=-1, keepdim=True)
96
  text_features /= text_features.norm(dim=-1, keepdim=True)
97
+
98
+ if len(reference_caption):
99
+ similarity = torch.matmul(image_features, text_features.transpose(1, 0)) / 0.07
100
+ similarity = similarity.softmax(dim=1)[0, 0].item()
101
  else:
102
+ similarity = torch.matmul(image_features, text_features.transpose(1, 0)).item()
103
  print(f'Clip score of the caption is {similarity}')
104
+ return similarity
105
 
106
  def inference(self, image: Union[np.ndarray, Image.Image, str], filter: bool = False):
107
  raise NotImplementedError()
 
109
  def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, filter: bool = False):
110
  raise NotImplementedError()
111
 
112
+ def inference_box(self, image: Union[np.ndarray, Image.Image, str], box: Union[list, np.ndarray], filter=False, verbose=False, caption_args={}):
113
  image = load_image(image, return_type="pil")
114
 
115
  if np.array(box).size == 4:
 
120
  elif np.array(box).size == 8: # four corners of an irregular rectangle
121
  image_crop = cut_box(np.array(image), box)
122
 
123
+ crop_save_path = None
124
+ if verbose:
125
+ crop_save_path = f'result/crop_{time.time()}.png'
126
+ Image.fromarray(image_crop).save(crop_save_path)
127
+ print(f'cropped image saved in {crop_save_path}')
128
+ caption = self.inference(image_crop, filter, caption_args)
129
+ caption.update({'crop_save_path': crop_save_path})
130
+ return caption
131
+
132
+ def inference_seg(self,
133
+ image: Union[np.ndarray, str],
134
+ seg_mask: Union[np.ndarray, Image.Image, str] = None,
135
+ crop_mode="w_bg",
136
+ filter=False,
137
+ disable_regular_box=False,
138
+ verbose=False,
139
+ caption_args={}):
140
  if seg_mask is None:
141
  seg_mask = np.ones(image.size).astype(bool)
142
+
143
  image = load_image(image, return_type="pil")
144
  seg_mask = load_image(seg_mask, return_type="pil")
145
 
146
  seg_mask = seg_mask.resize(image.size)
147
  seg_mask = np.array(seg_mask) > 0
 
148
  if crop_mode == "wo_bg":
149
  image = np.array(image) * seg_mask[:, :, np.newaxis] + (1 - seg_mask[:, :, np.newaxis]) * 255
150
  image = np.uint8(image)
 
155
  min_area_box = seg_to_box(seg_mask)
156
  else:
157
  min_area_box = new_seg_to_box(seg_mask)
158
+ return self.inference_box(image, min_area_box, filter, verbose, caption_args)
159
 
160
+ def generate_seg_cropped_image(self,
161
+ image: Union[np.ndarray, str],
162
+ seg_mask: Union[np.ndarray, Image.Image, str],
163
+ crop_mode="w_bg",
164
+ disable_regular_box=False):
165
  image = load_image(image, return_type="pil")
166
  seg_mask = load_image(seg_mask, return_type="pil")
167
 
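The reworked filter_caption above now returns a CLIP similarity score (softmax-normalized against reference captions when they are given) and leaves any thresholding to the caller, whereas the old version blanked out captions below a fixed 0.2 threshold. A standalone sketch of the same scoring idea using the openai/CLIP package; the image path and the 0.3 cut-off are illustrative assumptions, not values from this file:

    import clip
    import torch
    from PIL import Image

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device)

    @torch.no_grad()
    def clip_score(image: Image.Image, caption: str) -> float:
        image_feat = model.encode_image(preprocess(image).unsqueeze(0).to(device))
        text_feat = model.encode_text(clip.tokenize([caption]).to(device))
        image_feat /= image_feat.norm(dim=-1, keepdim=True)
        text_feat /= text_feat.norm(dim=-1, keepdim=True)
        return (image_feat @ text_feat.T).item()   # cosine similarity in [-1, 1]

    score = clip_score(Image.open("result/crop_example.png"), "a dog lying on the grass")
    keep_caption = score >= 0.3   # caller-side threshold, tuned per application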
caption_anything/captioner/blip.py CHANGED
@@ -20,19 +20,24 @@ class BLIPCaptioner(BaseCaptioner):
20
  torch_dtype=self.torch_dtype).to(self.device)
21
 
22
  @torch.no_grad()
23
- def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
24
  image = load_image(image, return_type="pil")
25
  inputs = self.processor(image, return_tensors="pt").to(self.device, self.torch_dtype)
26
  out = self.model.generate(**inputs, max_new_tokens=50)
27
  captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
 
 
28
  if self.enable_filter and filter:
29
- captions = self.filter_caption(image, captions)
 
 
30
  print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
31
- return captions
32
 
33
  @torch.no_grad()
34
  def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg",
35
  filter=False, disable_regular_box=False):
 
36
  crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode,
37
  disable_regular_box=disable_regular_box)
38
  image = load_image(image, return_type="pil")
@@ -47,9 +52,11 @@ class BLIPCaptioner(BaseCaptioner):
47
  out = self.model.generate(pixel_values=pixel_values, pixel_masks=pixel_masks, max_new_tokens=50)
48
  captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
49
  if self.enable_filter and filter:
50
- captions = self.filter_caption(image, captions)
 
 
51
  print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
52
- return captions, crop_save_path
53
 
54
 
55
  if __name__ == '__main__':
 
20
  torch_dtype=self.torch_dtype).to(self.device)
21
 
22
  @torch.no_grad()
23
+ def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False, args={}):
24
  image = load_image(image, return_type="pil")
25
  inputs = self.processor(image, return_tensors="pt").to(self.device, self.torch_dtype)
26
  out = self.model.generate(**inputs, max_new_tokens=50)
27
  captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
28
+
29
+ result = {}
30
  if self.enable_filter and filter:
31
+ clip_score = self.filter_caption(image, captions)
32
+ result['clip_score'] = clip_score
33
+ result.update({'caption':captions})
34
  print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
35
+ return {'caption': captions}
36
 
37
  @torch.no_grad()
38
  def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg",
39
  filter=False, disable_regular_box=False):
40
+ result = {}
41
  crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode,
42
  disable_regular_box=disable_regular_box)
43
  image = load_image(image, return_type="pil")
 
52
  out = self.model.generate(pixel_values=pixel_values, pixel_masks=pixel_masks, max_new_tokens=50)
53
  captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
54
  if self.enable_filter and filter:
55
+ clip_score = self.filter_caption(image, captions)
56
+ result['clip_score'] = clip_score
57
+ result.update({'caption':captions, 'crop_save_path':crop_save_path})
58
  print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
59
+ return result
60
 
61
 
62
  if __name__ == '__main__':
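Across the captioners touched by this commit, inference now returns a dict with a 'caption' key plus optional 'clip_score', 'ppl_score' and 'crop_save_path' entries, instead of a bare string or a (caption, crop_save_path) tuple. A hedged sketch of how a caller reads that dict, reusing the repo's own factory and a test image path that appears elsewhere in the repo:

    from PIL import Image
    from caption_anything.captioner import build_captioner
    from caption_anything.utils.parser import parse_augment

    args = parse_augment()
    captioner = build_captioner(args.captioner, args.device, args)

    result = captioner.inference(Image.open('test_images/img1.jpg'),
                                 filter=True, args={'text_prompt': ''})
    caption = result['caption']                # always present
    clip_score = result.get('clip_score')      # present only when CLIP filtering ran
    crop_path = result.get('crop_save_path')   # set by the seg/box variants when verbose=True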
caption_anything/captioner/blip2.py CHANGED
@@ -20,18 +20,31 @@ class BLIP2Captioner(BaseCaptioner):
20
  self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map='sequential', load_in_8bit=True)
21
 
22
  @torch.no_grad()
23
- def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
24
  image = load_image(image, return_type="pil")
25
-
26
  if not self.dialogue:
27
- text_prompt = 'The image shows'
28
- inputs = self.processor(image, text = text_prompt, return_tensors="pt").to(self.device, self.torch_dtype)
29
- out = self.model.generate(**inputs, max_new_tokens=50)
30
- captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
31
  if self.enable_filter and filter:
32
- captions = self.filter_caption(image, captions)
33
- print(f"\nProcessed ImageCaptioning by BLIP2Captioner, Output Text: {captions}")
34
- return captions
35
  else:
36
  context = []
37
  template = "Question: {} Answer: {}."
@@ -44,8 +57,8 @@ class BLIP2Captioner(BaseCaptioner):
44
  out = self.model.generate(**inputs, max_new_tokens=50)
45
  captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
46
  context.append((input_texts, captions))
47
-
48
- return captions
49
 
50
  if __name__ == '__main__':
51
 
 
20
  self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map='sequential', load_in_8bit=True)
21
 
22
  @torch.no_grad()
23
+ def inference(self,
24
+ image: Union[np.ndarray, Image.Image, str],
25
+ filter=False,
26
+ args={}):
27
+ args['return_ppl'] = args.get('return_ppl', False)
28
+ args['text_prompt'] = args.get('text_prompt', 'Question: what does the image show? Answer:')
29
+ args['reference_caption'] = args.get('reference_caption', [])
30
+
31
  image = load_image(image, return_type="pil")
32
+ result = {}
33
  if not self.dialogue:
34
+ inputs = self.processor(image, text = args['text_prompt'], return_tensors="pt").to(self.device, self.torch_dtype)
35
+ out = self.model.generate(**inputs, return_dict_in_generate=True, output_scores=True, max_new_tokens=50)
36
+ captions = self.processor.batch_decode(out.sequences, skip_special_tokens=True)
37
+ caption = [caption.strip() for caption in captions][0]
38
  if self.enable_filter and filter:
39
+ print('reference caption: {}, caption: {}'.format(args['reference_caption'], caption))
40
+ clip_score = self.filter_caption(image, caption, args['reference_caption'])
41
+ result['clip_score'] = clip_score
42
+ if args['return_ppl']:
43
+ ppl_score = torch.stack(out.scores, dim=1).softmax(dim=2).log().max(dim=2)[0].sum(dim=1)[0]
44
+ result['ppl_score'] = ppl_score.item()
45
+ print(f"\nProcessed ImageCaptioning by BLIP2Captioner, Output Text: {caption}")
46
+ result['caption'] = caption
47
+ return result
48
  else:
49
  context = []
50
  template = "Question: {} Answer: {}."
 
57
  out = self.model.generate(**inputs, max_new_tokens=50)
58
  captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
59
  context.append((input_texts, captions))
60
+ result['caption'] = captions
61
+ return result
62
 
63
  if __name__ == '__main__':
64
 
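The ppl_score added to BLIP2Captioner.inference is, as written, the sum over generation steps of the log-probability of the greedily chosen token, computed from the per-step logits exposed by generate(..., return_dict_in_generate=True, output_scores=True). An isolated sketch of that computation with fake logits, to make the tensor shapes explicit (the helper name is hypothetical):

    import torch

    def greedy_logprob(scores):
        # scores: tuple of per-step logits from generate(..., output_scores=True),
        # each tensor of shape (batch, vocab_size).
        logits = torch.stack(scores, dim=1)                         # (batch, steps, vocab)
        step_logprobs = logits.softmax(dim=2).log().max(dim=2)[0]   # (batch, steps)
        return step_logprobs.sum(dim=1)[0].item()                   # summed log-prob, first item

    fake_scores = tuple(torch.randn(1, 10) for _ in range(3))       # 3 steps, 10-token vocab
    total = greedy_logprob(fake_scores)
    per_token = total / len(fake_scores)   # the dense-caption filter later normalizes by caption length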
caption_anything/captioner/git.py CHANGED
@@ -19,19 +19,24 @@ class GITCaptioner(BaseCaptioner):
19
  self.model = GitForCausalLM.from_pretrained("microsoft/git-large", torch_dtype=self.torch_dtype).to(self.device)
20
 
21
  @torch.no_grad()
22
- def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
23
  image = load_image(image, return_type="pil")
24
  pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device, self.torch_dtype)
25
  generated_ids = self.model.generate(pixel_values=pixel_values, max_new_tokens=50)
26
- generated_caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
 
 
27
  if self.enable_filter and filter:
28
- captions = self.filter_caption(image, captions)
29
- print(f"\nProcessed ImageCaptioning by GITCaptioner, Output Text: {generated_caption}")
30
- return generated_caption
 
 
31
 
32
  @torch.no_grad()
33
  def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg",
34
  filter=False, disable_regular_box=False):
 
35
  crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode,
36
  disable_regular_box=disable_regular_box)
37
  image = load_image(image, return_type="pil")
@@ -46,9 +51,11 @@ class GITCaptioner(BaseCaptioner):
46
  out = self.model.generate(pixel_values=pixel_values, pixel_masks=pixel_masks, max_new_tokens=50)
47
  captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
48
  if self.enable_filter and filter:
49
- captions = self.filter_caption(image, captions)
 
50
  print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
51
- return captions, crop_save_path
 
52
 
53
 
54
  if __name__ == '__main__':
 
19
  self.model = GitForCausalLM.from_pretrained("microsoft/git-large", torch_dtype=self.torch_dtype).to(self.device)
20
 
21
  @torch.no_grad()
22
+ def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False, args={}):
23
  image = load_image(image, return_type="pil")
24
  pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device, self.torch_dtype)
25
  generated_ids = self.model.generate(pixel_values=pixel_values, max_new_tokens=50)
26
+ captions = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
27
+
28
+ result = {}
29
  if self.enable_filter and filter:
30
+ clip_score = self.filter_caption(image, captions)
31
+ result['clip_score'] = clip_score
32
+ result.update({'caption':captions})
33
+ print(f"\nProcessed ImageCaptioning by GITCaptioner, Output Text: {captions}")
34
+ return {'caption': captions}
35
 
36
  @torch.no_grad()
37
  def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg",
38
  filter=False, disable_regular_box=False):
39
+ result = {}
40
  crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode,
41
  disable_regular_box=disable_regular_box)
42
  image = load_image(image, return_type="pil")
 
51
  out = self.model.generate(pixel_values=pixel_values, pixel_masks=pixel_masks, max_new_tokens=50)
52
  captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
53
  if self.enable_filter and filter:
54
+ clip_score = self.filter_caption(image, captions)
55
+ result['clip_score'] = clip_score
56
  print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
57
+ result.update({'caption':captions, 'crop_save_path':crop_save_path})
58
+ return result
59
 
60
 
61
  if __name__ == '__main__':
caption_anything/model.py CHANGED
@@ -5,24 +5,33 @@ import time
5
  from PIL import Image
6
  import cv2
7
  import numpy as np
8
  from caption_anything.captioner import build_captioner, BaseCaptioner
9
- from caption_anything.segmenter import build_segmenter
10
  from caption_anything.text_refiner import build_text_refiner
11
-
12
-
 
 
13
  class CaptionAnything:
14
  def __init__(self, args, api_key="", captioner=None, segmenter=None, text_refiner=None):
15
  self.args = args
16
  self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
17
  self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
18
-
19
  self.text_refiner = None
20
  if not args.disable_gpt:
21
  if text_refiner is not None:
22
  self.text_refiner = text_refiner
23
- else:
24
  self.init_refiner(api_key)
25
-
 
26
  @property
27
  def image_embedding(self):
28
  return self.segmenter.image_embedding
@@ -61,65 +70,195 @@ class CaptionAnything:
61
  self.text_refiner = None
62
  print('OpenAI GPT is not available')
63
 
64
- def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False):
65
- # TODO: Add support to multiple seg masks.
66
-
67
  # segment with prompt
68
  print("CA prompt: ", prompt, "CA controls", controls)
69
- seg_mask = self.segmenter.inference(image, prompt)[0, ...]
70
-
71
- if self.args.enable_morphologyex:
72
- seg_mask = 255 * seg_mask.astype(np.uint8)
73
- seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis=-1)
74
- seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_OPEN, kernel=np.ones((6, 6), np.uint8))
75
- seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_CLOSE, kernel=np.ones((6, 6), np.uint8))
76
- seg_mask = seg_mask[:, :, 0] > 0
77
- mask_save_path = f'result/mask_{time.time()}.png'
78
- if not os.path.exists(os.path.dirname(mask_save_path)):
79
- os.makedirs(os.path.dirname(mask_save_path))
80
- seg_mask_img = Image.fromarray(seg_mask.astype('int') * 255.)
81
- if seg_mask_img.mode != 'RGB':
82
- seg_mask_img = seg_mask_img.convert('RGB')
83
- seg_mask_img.save(mask_save_path)
84
- print('seg_mask path: ', mask_save_path)
85
- print("seg_mask.shape: ", seg_mask.shape)
86
-
87
- # captioning with mask
88
- if self.args.enable_reduce_tokens:
89
- caption, crop_save_path = self.captioner. \
90
- inference_with_reduced_tokens(image, seg_mask,
91
- crop_mode=self.args.seg_crop_mode,
92
- filter=self.args.clip_filter,
93
- disable_regular_box=self.args.disable_regular_box)
94
- else:
95
- caption, crop_save_path = self.captioner. \
96
- inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode,
97
- filter=self.args.clip_filter,
98
- disable_regular_box=self.args.disable_regular_box)
99
-
100
- # refining with TextRefiner
101
- context_captions = []
102
- if self.args.context_captions:
103
- context_captions.append(self.captioner.inference(image))
104
- if not disable_gpt and self.text_refiner is not None:
105
- refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions,
106
- enable_wiki=enable_wiki)
107
  else:
108
- refined_caption = {'raw_caption': caption}
109
- out = {'generated_captions': refined_caption,
110
- 'crop_save_path': crop_save_path,
111
- 'mask_save_path': mask_save_path,
112
- 'mask': seg_mask_img,
113
- 'context_captions': context_captions}
114
- return out
117
  if __name__ == "__main__":
118
  from caption_anything.utils.parser import parse_augment
119
-
120
  args = parse_augment()
121
- # image_path = 'test_images/img3.jpg'
122
- image_path = 'test_images/img1.jpg'
123
  prompts = [
124
  {
125
  "prompt_type": ["click"],
@@ -127,12 +266,12 @@ if __name__ == "__main__":
127
  "input_label": [1, 0],
128
  "multimask_output": "True",
129
  },
130
- {
131
- "prompt_type": ["click"],
132
- "input_point": [[300, 800]],
133
- "input_label": [1],
134
- "multimask_output": "True",
135
- }
136
  ]
137
  controls = {
138
  "length": "30",
@@ -143,11 +282,11 @@ if __name__ == "__main__":
143
  }
144
 
145
  model = CaptionAnything(args, os.environ['OPENAI_API_KEY'])
146
- for prompt in prompts:
147
- print('*' * 30)
148
- print('Image path: ', image_path)
149
- image = Image.open(image_path)
150
- print(image)
151
- print('Visual controls (SAM prompt):\n', prompt)
152
- print('Language controls:\n', controls)
153
- out = model.inference(image_path, prompt, controls)
 
5
  from PIL import Image
6
  import cv2
7
  import numpy as np
8
+ from PIL import Image
9
+ import easyocr
10
+ import copy
11
  from caption_anything.captioner import build_captioner, BaseCaptioner
12
+ from caption_anything.segmenter import build_segmenter, build_segmenter_densecap
13
  from caption_anything.text_refiner import build_text_refiner
14
+ from caption_anything.utils.utils import prepare_segmenter, seg_model_map, load_image, get_image_shape
15
+ from caption_anything.utils.utils import mask_painter_foreground_all, mask_painter, xywh_to_x1y1x2y2, image_resize
16
+ from caption_anything.utils.densecap_painter import draw_bbox
17
+
18
  class CaptionAnything:
19
  def __init__(self, args, api_key="", captioner=None, segmenter=None, text_refiner=None):
20
  self.args = args
21
  self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
22
  self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
23
+ self.segmenter_densecap = build_segmenter_densecap(args.segmenter, args.device, args, model=self.segmenter.model)
24
+
25
+ self.lang = ["ch_tra", "en"]
26
+ self.reader = easyocr.Reader(self.lang)
27
  self.text_refiner = None
28
  if not args.disable_gpt:
29
  if text_refiner is not None:
30
  self.text_refiner = text_refiner
31
+ elif api_key != "":
32
  self.init_refiner(api_key)
33
+ self.require_caption_prompt = args.captioner == 'blip2'
34
+
35
  @property
36
  def image_embedding(self):
37
  return self.segmenter.image_embedding
 
70
  self.text_refiner = None
71
  print('OpenAI GPT is not available')
72
 
73
+ def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False, verbose=False, is_densecap=False, args={}):
 
 
74
  # segment with prompt
75
  print("CA prompt: ", prompt, "CA controls", controls)
76
+ is_seg_everything = 'everything' in prompt['prompt_type']
77
+
78
+ args['seg_crop_mode'] = args.get('seg_crop_mode', self.args.seg_crop_mode)
79
+ args['clip_filter'] = args.get('clip_filter', self.args.clip_filter)
80
+ args['disable_regular_box'] = args.get('disable_regular_box', self.args.disable_regular_box)
81
+ args['context_captions'] = args.get('context_captions', self.args.context_captions)
82
+ args['enable_reduce_tokens'] = args.get('enable_reduce_tokens', self.args.enable_reduce_tokens)
83
+ args['enable_morphologyex'] = args.get('enable_morphologyex', self.args.enable_morphologyex)
84
+ args['topN'] = args.get('topN', 10) if is_seg_everything else 1
85
+ args['min_mask_area'] = args.get('min_mask_area', 0)
86
+
87
+ if not is_densecap:
88
+ seg_results = self.segmenter.inference(image, prompt)
89
  else:
90
+ seg_results = self.segmenter_densecap.inference(image, prompt)
91
+
92
+ seg_masks, seg_bbox, seg_area = seg_results if is_seg_everything else (seg_results, None, None)
93
+
94
+ if args['topN'] > 1: # sort by area
95
+ samples = list(zip(*[seg_masks, seg_bbox, seg_area]))
96
+ # top_samples = sorted(samples, key=lambda x: x[2], reverse=True)
97
+ # seg_masks, seg_bbox, seg_area = list(zip(*top_samples))
98
+ samples = list(filter(lambda x: x[2] > args['min_mask_area'], samples))
99
+ samples = samples[:args['topN']]
100
+ seg_masks, seg_bbox, seg_area = list(zip(*samples))
101
 
102
+ out_list = []
103
+ for i, seg_mask in enumerate(seg_masks):
104
+ if args['enable_morphologyex']:
105
+ seg_mask = 255 * seg_mask.astype(np.uint8)
106
+ seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis=-1)
107
+ seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_OPEN, kernel=np.ones((6, 6), np.uint8))
108
+ seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_CLOSE, kernel=np.ones((6, 6), np.uint8))
109
+ seg_mask = seg_mask[:, :, 0] > 0
110
 
111
+ seg_mask_img = Image.fromarray(seg_mask.astype('int') * 255.)
112
+ mask_save_path = None
113
+
114
+ if verbose:
115
+ mask_save_path = f'result/mask_{time.time()}.png'
116
+ if not os.path.exists(os.path.dirname(mask_save_path)):
117
+ os.makedirs(os.path.dirname(mask_save_path))
118
+
119
+ if seg_mask_img.mode != 'RGB':
120
+ seg_mask_img = seg_mask_img.convert('RGB')
121
+ seg_mask_img.save(mask_save_path)
122
+ print('seg_mask path: ', mask_save_path)
123
+ print("seg_mask.shape: ", seg_mask.shape)
124
+
125
+
126
+ # captioning with mask
127
+ if args['enable_reduce_tokens']:
128
+ result = self.captioner.inference_with_reduced_tokens(image, seg_mask,
129
+ crop_mode=args['seg_crop_mode'],
130
+ filter=args['clip_filter'],
131
+ disable_regular_box=args['disable_regular_box'],
132
+ verbose=verbose,
133
+ caption_args=args)
134
+ else:
135
+ result = self.captioner.inference_seg(image, seg_mask,
136
+ crop_mode=args['seg_crop_mode'],
137
+ filter=args['clip_filter'],
138
+ disable_regular_box=args['disable_regular_box'],
139
+ verbose=verbose,
140
+ caption_args=args)
141
+ caption = result.get('caption', None)
142
+ crop_save_path = result.get('crop_save_path', None)
143
+
144
+ # refining with TextRefiner
145
+ context_captions = []
146
+ if args['context_captions']:
147
+ context_captions.append(self.captioner.inference(image)['caption'])
148
+ if not disable_gpt and self.text_refiner is not None:
149
+ refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions,
150
+ enable_wiki=enable_wiki)
151
+ else:
152
+ refined_caption = {'raw_caption': caption}
153
+ out = {'generated_captions': refined_caption,
154
+ 'crop_save_path': crop_save_path,
155
+ 'mask_save_path': mask_save_path,
156
+ 'mask': seg_mask_img,
157
+ 'bbox': seg_bbox[i] if seg_bbox is not None else None,
158
+ 'area': seg_area[i] if seg_area is not None else None,
159
+ 'context_captions': context_captions,
160
+ 'ppl_score': result.get('ppl_score', -100.),
161
+ 'clip_score': result.get('clip_score', 0.)
162
+ }
163
+ out_list.append(out)
164
+ return out_list
165
+
166
+ def parse_dense_caption(self, image, topN=10, reference_caption=[], verbose=False):
167
+ width, height = get_image_shape(image)
168
+ prompt = {'prompt_type': ['everything']}
169
+ densecap_args = {
170
+ 'return_ppl': True,
171
+ 'clip_filter': True,
172
+ 'reference_caption': reference_caption,
173
+ 'text_prompt': "", # 'Question: what does the image show? Answer:'
174
+ 'seg_crop_mode': 'w_bg',
175
+ # 'text_prompt': "",
176
+ # 'seg_crop_mode': 'wo_bg',
177
+ 'disable_regular_box': False,
178
+ 'topN': topN,
179
+ 'min_ppl_score': -1.8,
180
+ 'min_clip_score': 0.30,
181
+ 'min_mask_area': 2500,
182
+ }
183
+
184
+ dense_captions = self.inference(image, prompt,
185
+ controls=None,
186
+ disable_gpt=True,
187
+ verbose=verbose,
188
+ is_densecap=True,
189
+ args=densecap_args)
190
+ print('Process Dense Captioning: \n', dense_captions)
191
+ dense_captions = list(filter(lambda x: x['ppl_score'] / (1+len(x['generated_captions']['raw_caption'].split())) >= densecap_args['min_ppl_score'], dense_captions))
192
+ dense_captions = list(filter(lambda x: x['clip_score'] >= densecap_args['min_clip_score'], dense_captions))
193
+ dense_cap_prompt = []
194
+ for cap in dense_captions:
195
+ x, y, w, h = cap['bbox']
196
+ cx, cy = x + w/2, (y + h/2)
197
+ dense_cap_prompt.append("({}: X:{:.0f}, Y:{:.0f}, Width:{:.0f}, Height:{:.0f})".format(cap['generated_captions']['raw_caption'], cx, cy, w, h))
198
+
199
+ if verbose:
200
+ all_masks = [np.array(item['mask'].convert('P')) for item in dense_captions]
201
+ new_image = mask_painter_foreground_all(np.array(image), all_masks, background_alpha=0.4)
202
+ save_path = 'result/dense_caption_mask.png'
203
+ Image.fromarray(new_image).save(save_path)
204
+ print(f'Dense captioning mask saved in {save_path}')
205
+
206
+ vis_path = 'result/dense_caption_vis_{}.png'.format(time.time())
207
+ dense_cap_painter_input = [{'bbox': xywh_to_x1y1x2y2(cap['bbox']),
208
+ 'caption': cap['generated_captions']['raw_caption']} for cap in dense_captions]
209
+ draw_bbox(load_image(image, return_type='numpy'), vis_path, dense_cap_painter_input, show_caption=True)
210
+ print(f'Dense Captioning visualization saved in {vis_path}')
211
+ return ','.join(dense_cap_prompt)
212
+
213
+    def parse_ocr(self, image, thres=0.2):
+        width, height = get_image_shape(image)
+        image = load_image(image, return_type='numpy')
+        bounds = self.reader.readtext(image)
+        bounds = [bound for bound in bounds if bound[2] > thres]
+        print('Process OCR Text:\n', bounds)
+
+        ocr_prompt = []
+        for box, text, conf in bounds:
+            p0, p1, p2, p3 = box
+            ocr_prompt.append('(\"{}\": X:{:.0f}, Y:{:.0f})'.format(text, (p0[0] + p1[0] + p2[0] + p3[0]) / 4, (p0[1] + p1[1] + p2[1] + p3[1]) / 4))
+        ocr_prompt = '\n'.join(ocr_prompt)
+
+        # ocr_prompt = self.text_refiner.llm(f'The image have some scene texts with their locations: {ocr_prompt}. Please group these individual words into one or several phrase based on their relative positions (only give me your answer, do not show explanination)').strip()
+
+        # ocr_prefix1 = f'The image have some scene texts with their locations: {ocr_prompt}. Please group these individual words into one or several phrase based on their relative positions (only give me your answer, do not show explanination)'
+        # ocr_prefix2 = f'Please group these individual words into 1-3 phrases, given scene texts with their locations: {ocr_prompt}. You return is one or several strings and infer their locations. (only give me your answer like (“man working”, X: value, Y: value), do not show explanination)'
+        # ocr_prefix4 = f'summarize the individual scene text words detected by OCR tools into a fluent sentence based on their positions and distances. You should strictly describe all of the given scene text words. Do not miss any given word. Do not create non-exist words. Do not appear numeric positions. The individual words are given:\n{ocr_prompt}\n'
+        # ocr_prefix3 = f'combine the individual scene text words detected by OCR tools into one/several fluent phrases/sentences based on their positions and distances. You should strictly copy or correct all of the given scene text words. Do not miss any given word. Do not create non-exist words. The response is several strings seperate with their location (X, Y), each of which represents a phrase. The individual words are given:\n{ocr_prompt}\n'
+        # response = self.text_refiner.llm(ocr_prefix3).strip() if len(ocr_prompt) else ""
+        return ocr_prompt
+
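(A sketch of the data flow inside parse_ocr, not part of the diff; the boxes, words and scores below are made-up examples. easyocr's readtext returns (corner_points, text, confidence) triples; each entry above `thres` becomes one '("text": X, Y)' line, where X and Y are the mean of the four corner coordinates.)

    # bounds as returned by self.reader.readtext(image), illustrative values:
    # [([[10, 12], [120, 12], [120, 40], [10, 40]], 'HELLO', 0.93),
    #  ([[10, 44], [130, 44], [130, 72], [10, 72]], 'WORLD', 0.88)]
    ocr_prompt = model.parse_ocr('image/ocr/Untitled.png', thres=0.2)
    print(ocr_prompt)
    # ("HELLO": X:65, Y:26)
    # ("WORLD": X:70, Y:58)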
+    def inference_cap_everything(self, image, verbose=False):
+        image = load_image(image, return_type='pil')
+        image = image_resize(image, res=1024)
+        width, height = get_image_shape(image)
+        other_args = {'text_prompt': ""} if self.require_caption_prompt else {}
+        img_caption = self.captioner.inference(image, filter=False, args=other_args)['caption']
+        dense_caption_prompt = self.parse_dense_caption(image, topN=10, verbose=verbose, reference_caption=[])
+        scene_text_prompt = self.parse_ocr(image, thres=0.2)
+        # scene_text_prompt = "N/A"
+
+        # the summarize_prompt is modified from https://github.com/JialianW/GRiT and https://github.com/showlab/Image2Paragraph
+        summarize_prompt = "Imagine you are a blind but intelligent image captioner. You should generate a descriptive, coherent and human-like paragraph based on the given information (a,b,c,d) instead of imagination:\na) Image Resolution: {image_size}\nb) Image Caption:{image_caption}\nc) Dense Caption: {dense_caption}\nd) Scene Text: {scene_text}\nThere are some rules for your response: Show objects with their attributes (e.g. position, color, size, shape, texture).\nPrimarily describe common objects with large size.\nProvide context of the image.\nShow relative position between objects.\nLess than 6 sentences.\nDo not appear number.\nDo not describe any individual letter.\nDo not show the image resolution.\nIgnore the white background."
+        prompt = summarize_prompt.format(**{
+            "image_size": "width {} height {}".format(width, height),
+            "image_caption": img_caption,
+            "dense_caption": dense_caption_prompt,
+            "scene_text": scene_text_prompt})
+        print(f'caption everything prompt: {prompt}')
+        response = self.text_refiner.llm(prompt).strip()
+        # chinese_response = self.text_refiner.llm('Translate it into Chinese: {}'.format(response)).strip()
+        return response
+
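(Illustration only, not part of the diff: the rough shape of the filled summarization prompt for a hypothetical 1024x768 photo. All captions, boxes and scene text below are invented placeholders; the LLM's reply to this prompt is what inference_cap_everything returns.)

    # Imagine you are a blind but intelligent image captioner. ...
    # a) Image Resolution: width 1024 height 768
    # b) Image Caption:a man riding a bicycle down a city street
    # c) Dense Caption: (a man wearing a helmet: X:512, Y:300, Width:200, Height:350),(a blue bicycle: X:500, Y:520, Width:260, Height:180)
    # d) Scene Text: ("STOP": X:80, Y:40)
    # There are some rules for your response: ...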
 if __name__ == "__main__":
     from caption_anything.utils.parser import parse_augment
 
     args = parse_augment()
+    image_path = 'image/ocr/Untitled.png'
+    image = Image.open(image_path)
     prompts = [
         {
             "prompt_type": ["click"],
 
             "input_label": [1, 0],
             "multimask_output": "True",
         },
+        # {
+        #     "prompt_type": ["click"],
+        #     "input_point": [[300, 800]],
+        #     "input_label": [1],
+        #     "multimask_output": "True",
+        # }
     ]
     controls = {
         "length": "30",
 
     }
 
     model = CaptionAnything(args, os.environ['OPENAI_API_KEY'])
+    img_dir = 'test_images/memes'
+    for image_file in os.listdir(img_dir):
+        image_path = os.path.join(img_dir, image_file)
+        print('image_path:', image_path)
+        paragraph = model.inference_cap_everything(image_path, verbose=True)
+        print('Caption Everything:\n', paragraph)
+        ocr = model.parse_ocr(image_path)
+        print('OCR', ocr)
caption_anything/segmenter/__init__.py CHANGED
@@ -1,5 +1,14 @@
 from .base_segmenter import BaseSegmenter
 from caption_anything.utils.utils import seg_model_map
+import copy
 
-def build_segmenter(model_name, device, args=None, model=None):
-    return BaseSegmenter(device, args.segmenter_checkpoint, model_name, reuse_feature=not args.disable_reuse_features, model=model)
+def build_segmenter(model_name, device, args, model=None):
+    return BaseSegmenter(device, args.segmenter_checkpoint, model_name, reuse_feature=not args.disable_reuse_features, model=model, args=args)
+
+def build_segmenter_densecap(model_name, device, args, model=None):
+    args_for_densecap = copy.deepcopy(args)
+    args_for_densecap.pred_iou_thresh = 0.88
+    args_for_densecap.min_mask_region_area = 400
+    args_for_densecap.stability_score_thresh = 0.95
+    args_for_densecap.box_nms_thresh = 0.3
+    return BaseSegmenter(device, args.segmenter_checkpoint, model_name, reuse_feature=not args.disable_reuse_features, model=model, args=args_for_densecap)
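(A possible call site for the two builders, a sketch rather than part of the diff; the device string and the use of args.segmenter as the SAM model name are assumptions.)

    from caption_anything.utils.parser import parse_augment
    from caption_anything.segmenter import build_segmenter, build_segmenter_densecap

    args = parse_augment()
    segmenter = build_segmenter(args.segmenter, 'cuda:0', args)                    # interactive click/box prompts
    densecap_segmenter = build_segmenter_densecap(args.segmenter, 'cuda:0', args)  # 'everything' mode with stricter mask filtering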
caption_anything/segmenter/base_segmenter.py CHANGED
@@ -11,7 +11,7 @@ import PIL
 
 
 class BaseSegmenter:
-    def __init__(self, device, checkpoint, model_name='huge', reuse_feature=True, model=None):
+    def __init__(self, device, checkpoint, model_name='huge', reuse_feature=True, model=None, args=None):
         print(f"Initializing BaseSegmenter to {device}")
         self.device = device
         self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
@@ -26,7 +26,10 @@ class BaseSegmenter:
         self.model = model
         self.reuse_feature = reuse_feature
         self.predictor = SamPredictor(self.model)
-        self.mask_generator = SamAutomaticMaskGenerator(self.model)
+
+        sam_generator_keys = ['pred_iou_thresh', 'min_mask_region_area', 'stability_score_thresh', 'box_nms_thresh']
+        generator_args = {k: v for k, v in vars(args).items() if k in sam_generator_keys}
+        self.mask_generator = SamAutomaticMaskGenerator(model=self.model, **generator_args)
         self.image_embedding = None
         self.image = None
 
@@ -69,7 +72,9 @@ class BaseSegmenter:
         if 'everything' in control['prompt_type']:
             masks = self.mask_generator.generate(image)
             new_masks = np.concatenate([mask["segmentation"][np.newaxis, :] for mask in masks])
-            return new_masks
+            bbox = np.array([mask["bbox"] for mask in masks])
+            area = np.array([mask["area"] for mask in masks])
+            return new_masks, bbox, area
         else:
             if not self.reuse_feature or self.image_embedding is None:
                 self.set_image(image)
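(With the parser defaults added below, the generator built above is equivalent to the explicit construction in this sketch; `sam` stands for the already-loaded SAM model. Note also that callers of the 'everything' prompt type now receive a (masks, bbox, area) triple instead of a bare mask array.)

    from segment_anything import SamAutomaticMaskGenerator
    mask_generator = SamAutomaticMaskGenerator(
        model=sam,
        pred_iou_thresh=0.88,
        min_mask_region_area=0,
        stability_score_thresh=0.95,
        box_nms_thresh=0.7,
    )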
caption_anything/utils/densecap_painter.py ADDED
@@ -0,0 +1,64 @@
+import cv2
+import json
+import numpy as np
+from typing import List
+import random
+from typing import Union
+
+def draw_bbox(img: Union[np.ndarray, str], save_name: str, bbox: List[dict], show_caption: bool = False):
+    """
+    bbox: [{'image_id': str, 'bbox': [x1, y1, x2, y2], 'caption': str}, ...]
+    """
+    if isinstance(img, str):
+        img = cv2.imread(img)
+
+    RGB = [0, 50, 100, 150, 200, 250]
+    for box in bbox:
+        box['bbox'] = [int(_) for _ in box['bbox']]
+        x1, y1, x2, y2 = box['bbox']
+        caption = box['caption']
+        box_color = random.choices(RGB, k=3)
+        (text_width, text_height), _ = cv2.getTextSize(caption, cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.5, thickness=2)
+        cv2.rectangle(img, (x1, y1), (x2, y2), color=box_color, thickness=2)
+        if show_caption:
+            cv2.putText(img, caption, (x1, y1 + text_height), cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.5, color=box_color, thickness=2)
+
+    cv2.imwrite(save_name, img)
+    # cv2.imshow('visualise', img)
+    # cv2.waitKey(0)
+
+def parse_bbox(anno, image_id: int = None):
+
+    with open(anno, 'r') as f:
+        predictions = json.load(f)
+
+    if image_id is None:
+        image_id = next(iter(predictions))
+
+    return predictions[image_id]
+
+def gt_bbox(anno, img_name: str = None):
+
+    with open(anno, 'r') as f:
+        annotations = json.load(f)
+    annotations = annotations['annotations']
+
+    gt = []
+    img_name = int(img_name[:-4])
+    for annotation in annotations:
+        if annotation['image_id'] == img_name:
+            x1, y1, w, h = annotation['bbox']
+            gt.append({'bbox': [x1, y1, x1 + w, y1 + h], 'caption': annotation['caption']})
+    return gt
+
+if __name__ == '__main__':
+
+    img_name = '63.jpg'
+    show_caption = True
+    anno = 'vg_dense_captioning_blip2_top48_0.88_1000_0.96_debugTrue_predictions_shard_all.json'
+
+    img = cv2.imread(img_name)
+    examp_bbox = parse_bbox(anno)
+    ground_truth_bbox = gt_bbox('test.json', img_name)
+    draw_bbox(img, 'GT.jpg', ground_truth_bbox, show_caption)
+    draw_bbox(img, 'Pred.jpg', examp_bbox, show_caption)
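(A minimal standalone usage sketch for draw_bbox, not part of the diff; the image path, output name, boxes and captions are invented.)

    from caption_anything.utils.densecap_painter import draw_bbox

    boxes = [
        {'bbox': [40, 60, 220, 300], 'caption': 'a person wearing a red jacket'},
        {'bbox': [250, 120, 480, 330], 'caption': 'a wooden bench'},
    ]
    draw_bbox('test_images/memes/example.jpg', 'densecap_boxes_vis.jpg', boxes, show_caption=True)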
caption_anything/utils/parser.py CHANGED
@@ -22,6 +22,12 @@ def parse_augment():
     parser.add_argument('--disable_reuse_features', action="store_true", default=False)
     parser.add_argument('--enable_morphologyex', action="store_true", default=False)
     parser.add_argument('--chat_tools_dict', type=str, default='VisualQuestionAnswering_cuda:0', help='Visual ChatGPT tools, only useful when running gradio applications')
+
+    parser.add_argument('--pred_iou_thresh', type=float, default=0.88, help='SAM post-processing')
+    parser.add_argument('--min_mask_region_area', type=int, default=0, help='SAM post-processing')
+    parser.add_argument('--stability_score_thresh', type=float, default=0.95, help='SAM post-processing')
+    parser.add_argument('--box_nms_thresh', type=float, default=0.7, help='SAM post-processing')
+
     args = parser.parse_args()
 
     if args.debug:
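(A quick sketch showing how the new flags reach parse_augment; the command-line values are illustrative and sys.argv is patched only for the demonstration.)

    import sys
    from caption_anything.utils.parser import parse_augment

    sys.argv = ['app.py', '--pred_iou_thresh', '0.9', '--box_nms_thresh', '0.5']
    args = parse_augment()
    print(args.pred_iou_thresh, args.min_mask_region_area, args.stability_score_thresh, args.box_nms_thresh)
    # 0.9 0 0.95 0.5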
caption_anything/utils/utils.py CHANGED
@@ -29,6 +29,9 @@ def load_image(image: Union[np.ndarray, Image.Image, str], return_type='numpy'):
     elif isinstance(image, np.ndarray):
         image = Image.fromarray(image)
 
+    if image.mode == "RGBA":
+        image = image.convert("RGB")
+
     if return_type == 'pil':
         return image
     elif return_type == 'numpy':
@@ -37,6 +40,34 @@ def load_image(image: Union[np.ndarray, Image.Image, str], return_type='numpy'):
         raise NotImplementedError()
 
 
+def image_resize(image: Image.Image, res=1024):
+    width, height = org_size = image.size
+    ratio = min(1.0 * res / max(width, height), 1.0)
+    if ratio < 1.0:
+        image = image.resize((int(width * ratio), int(height * ratio)))
+        print('Scaling image from {} to {}'.format(org_size, image.size))
+    return image
+
+def xywh_to_x1y1x2y2(bbox):
+    x, y, w, h = bbox
+    return x, y, x + w, y + h
+
+
+def x1y1x2y2_to_xywh(bbox):
+    x1, y1, x2, y2 = bbox
+    return x1, y1, x2 - x1, y2 - y1
+
+
+def get_image_shape(image):
+    if isinstance(image, str):
+        return Image.open(image).size
+    elif isinstance(image, np.ndarray):
+        return image.shape
+    elif isinstance(image, Image.Image):
+        return image.size
+    else:
+        raise NotImplementedError
+
 def is_platform_win():
     return sys.platform == "win32"
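(Quick sanity check of the new helpers, a sketch with illustrative values. Note that get_image_shape returns (width, height) for PIL images and file paths but the raw array.shape for numpy inputs.)

    from PIL import Image
    from caption_anything.utils.utils import image_resize, get_image_shape, xywh_to_x1y1x2y2, x1y1x2y2_to_xywh

    img = Image.new('RGB', (2048, 1536))
    img = image_resize(img, res=1024)
    print(get_image_shape(img))                 # (1024, 768)
    print(xywh_to_x1y1x2y2([10, 20, 30, 40]))   # (10, 20, 40, 60)
    print(x1y1x2y2_to_xywh([10, 20, 40, 60]))   # (10, 20, 30, 40)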
 
requirements.txt CHANGED
@@ -17,4 +17,7 @@ onnxruntime
 onnx
 https://gradio-builds.s3.amazonaws.com/3e68e5e882a6790ac5b457bd33f4edf9b695af90/gradio-3.24.1-py3-none-any.whl
 accelerate
-bitsandbytes
+bitsandbytes
+packaging~=23.1
+easyocr
+tensorboardX
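(The easyocr dependency backs the new scene-text step; presumably the `self.reader` used by parse_ocr is an easyocr Reader, along the lines of this sketch.)

    import easyocr
    reader = easyocr.Reader(['en'])  # downloads detection/recognition weights on first use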