ttengwang committed on
Commit
f1a2810
1 Parent(s): 89e01b9
app.py CHANGED
@@ -7,6 +7,7 @@ from gradio import processing_utils
7
 
8
  from packaging import version
9
  from PIL import Image, ImageDraw
 
10
 
11
  from caption_anything.model import CaptionAnything
12
  from caption_anything.utils.image_editing_utils import create_bubble_frame
@@ -22,7 +23,6 @@ from segment_anything import sam_model_registry
22
  args = parse_augment()
23
  args.segmenter = "huge"
24
  args.segmenter_checkpoint = "sam_vit_h_4b8939.pth"
25
-
26
  if args.segmenter_checkpoint is None:
27
  _, segmenter_checkpoint = prepare_segmenter(args.segmenter)
28
  else:
@@ -131,7 +131,7 @@ def chat_input_callback(*args):
131
  return state, state
132
 
133
  def upload_callback(image_input, state, visual_chatgpt=None):
134
-
135
  if isinstance(image_input, dict): # if upload from sketcher_input, input contains image and mask
136
  image_input, mask = image_input['image'], image_input['mask']
137
 
@@ -162,7 +162,8 @@ def upload_callback(image_input, state, visual_chatgpt=None):
162
  img_caption, _ = model.captioner.inference_seg(image_input)
163
  Human_prompt = f'\nHuman: provide a new figure with path {new_image_path}. The description is: {img_caption}. This information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
164
  AI_prompt = "Received."
165
- visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
 
166
  state = [(None, 'Received new image, resize it to width {} and height {}: '.format(image_input.size[0], image_input.size[1]))]
167
 
168
  return state, state, image_input, click_state, image_input, image_input, image_embedding, \
@@ -309,12 +310,16 @@ def inference_traject(sketcher_image, enable_wiki, language, sentiment, factuali
309
 
310
  yield state, state, refined_image_input, wiki
311
 
312
- def clear_chat_memory(visual_chatgpt):
313
  if visual_chatgpt is not None:
314
  visual_chatgpt.memory.clear()
315
- visual_chatgpt.current_image = None
316
  visual_chatgpt.point_prompt = ""
317
-
318
  def get_style():
319
  current_version = version.parse(gr.__version__)
320
  if current_version <= version.parse('3.24.1'):
@@ -465,6 +470,21 @@ def create_ui():
465
  modules_not_need_gpt,
466
  modules_not_need_gpt2, modules_not_need_gpt3, text_refiner, visual_chatgpt])
467
 
468
  clear_button_click.click(
469
  lambda x: ([[], [], []], x, ""),
470
  [origin_image],
@@ -472,6 +492,7 @@ def create_ui():
472
  queue=False,
473
  show_progress=False
474
  )
 
475
  clear_button_image.click(
476
  lambda: (None, [], [], [[], [], []], "", "", ""),
477
  [],
 
7
 
8
  from packaging import version
9
  from PIL import Image, ImageDraw
10
+ import functools
11
 
12
  from caption_anything.model import CaptionAnything
13
  from caption_anything.utils.image_editing_utils import create_bubble_frame
 
23
  args = parse_augment()
24
  args.segmenter = "huge"
25
  args.segmenter_checkpoint = "sam_vit_h_4b8939.pth"
 
26
  if args.segmenter_checkpoint is None:
27
  _, segmenter_checkpoint = prepare_segmenter(args.segmenter)
28
  else:
 
131
  return state, state
132
 
133
  def upload_callback(image_input, state, visual_chatgpt=None):
134
+
135
  if isinstance(image_input, dict): # if upload from sketcher_input, input contains image and mask
136
  image_input, mask = image_input['image'], image_input['mask']
137
 
 
162
  img_caption, _ = model.captioner.inference_seg(image_input)
163
  Human_prompt = f'\nHuman: provide a new figure with path {new_image_path}. The description is: {img_caption}. This information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
164
  AI_prompt = "Received."
165
+ visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt
166
+ visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt
167
  state = [(None, 'Received new image, resize it to width {} and height {}: '.format(image_input.size[0], image_input.size[1]))]
168
 
169
  return state, state, image_input, click_state, image_input, image_input, image_embedding, \
 
310
 
311
  yield state, state, refined_image_input, wiki
312
 
313
+ def clear_chat_memory(visual_chatgpt, keep_global=False):
314
  if visual_chatgpt is not None:
315
  visual_chatgpt.memory.clear()
 
316
  visual_chatgpt.point_prompt = ""
317
+ if keep_global:
318
+ visual_chatgpt.agent.memory.buffer = visual_chatgpt.global_prompt
319
+ else:
320
+ visual_chatgpt.current_image = None
321
+ visual_chatgpt.global_prompt = ""
322
+
323
  def get_style():
324
  current_version = version.parse(gr.__version__)
325
  if current_version <= version.parse('3.24.1'):
 
470
  modules_not_need_gpt,
471
  modules_not_need_gpt2, modules_not_need_gpt3, text_refiner, visual_chatgpt])
472
 
473
+ enable_chatGPT_button.click(
474
+ lambda: (None, [], [], [[], [], []], "", "", ""),
475
+ [],
476
+ [image_input, chatbot, state, click_state, wiki_output, origin_image],
477
+ queue=False,
478
+ show_progress=False
479
+ )
480
+ openai_api_key.submit(
481
+ lambda: (None, [], [], [[], [], []], "", "", ""),
482
+ [],
483
+ [image_input, chatbot, state, click_state, wiki_output, origin_image],
484
+ queue=False,
485
+ show_progress=False
486
+ )
487
+
488
  clear_button_click.click(
489
  lambda x: ([[], [], []], x, ""),
490
  [origin_image],
 
492
  queue=False,
493
  show_progress=False
494
  )
495
+ clear_button_click.click(functools.partial(clear_chat_memory, keep_global=True), inputs=[visual_chatgpt])
496
  clear_button_image.click(
497
  lambda: (None, [], [], [[], [], []], "", "", ""),
498
  [],
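
The app.py hunks above split chat-memory clearing into two modes: clear_chat_memory(visual_chatgpt, keep_global=True) wipes the conversation and point prompt but re-seeds the agent buffer with the image-context prompt that upload_callback stores in visual_chatgpt.global_prompt, while the default mode also drops current_image and global_prompt. The following is a minimal, self-contained sketch of that pattern; the SimpleNamespace stand-in for ConversationBot is hypothetical and only models the attributes the diff touches.

import functools
from types import SimpleNamespace

def clear_chat_memory(visual_chatgpt, keep_global=False):
    # Mirrors the updated helper in app.py: always wipe the conversation and the
    # point prompt, optionally re-seeding the agent buffer with the image context.
    if visual_chatgpt is not None:
        visual_chatgpt.memory.clear()
        visual_chatgpt.point_prompt = ""
        if keep_global:
            visual_chatgpt.agent.memory.buffer = visual_chatgpt.global_prompt
        else:
            visual_chatgpt.current_image = None
            visual_chatgpt.global_prompt = ""

# Hypothetical stand-in for ConversationBot, reduced to the attributes used above.
bot = SimpleNamespace(
    memory=SimpleNamespace(clear=lambda: None),
    agent=SimpleNamespace(memory=SimpleNamespace(buffer="Human: ...\nAI: ...")),
    point_prompt="[[10, 20, 1]]",
    current_image="image/abc.png",
    global_prompt="\nHuman: provide a new figure with path ... AI: Received.",
)

# Gradio passes only the declared input components positionally, so the commit
# pre-binds the keyword with functools.partial for the "Clear Clicks" button:
clear_clicks = functools.partial(clear_chat_memory, keep_global=True)
clear_clicks(bot)
assert bot.agent.memory.buffer == bot.global_prompt and bot.point_prompt == ""

Pre-binding keep_global=True keeps the clear_button_click wiring free of an extra hidden input component.
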
app_wo_langchain.py DELETED
@@ -1,588 +0,0 @@
1
- import os
2
- import json
3
- from typing import List
4
-
5
- import PIL
6
- import gradio as gr
7
- import numpy as np
8
- from gradio import processing_utils
9
-
10
- from packaging import version
11
- from PIL import Image, ImageDraw
12
-
13
- from caption_anything.model import CaptionAnything
14
- from caption_anything.utils.image_editing_utils import create_bubble_frame
15
- from caption_anything.utils.utils import mask_painter, seg_model_map, prepare_segmenter
16
- from caption_anything.utils.parser import parse_augment
17
- from caption_anything.captioner import build_captioner
18
- from caption_anything.text_refiner import build_text_refiner
19
- from caption_anything.segmenter import build_segmenter
20
- from caption_anything.utils.chatbot import ConversationBot, build_chatbot_tools, get_new_image_name
21
- from segment_anything import sam_model_registry
22
-
23
-
24
- args = parse_augment()
25
-
26
- args = parse_augment()
27
- if args.segmenter_checkpoint is None:
28
- _, segmenter_checkpoint = prepare_segmenter(args.segmenter)
29
- else:
30
- segmenter_checkpoint = args.segmenter_checkpoint
31
-
32
- shared_captioner = build_captioner(args.captioner, args.device, args)
33
- shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)
34
-
35
-
36
- class ImageSketcher(gr.Image):
37
- """
38
- Fix the bug of gradio.Image that cannot upload with tool == 'sketch'.
39
- """
40
-
41
- is_template = True # Magic to make this work with gradio.Block, don't remove unless you know what you're doing.
42
-
43
- def __init__(self, **kwargs):
44
- super().__init__(tool="sketch", **kwargs)
45
-
46
- def preprocess(self, x):
47
- if self.tool == 'sketch' and self.source in ["upload", "webcam"]:
48
- assert isinstance(x, dict)
49
- if x['mask'] is None:
50
- decode_image = processing_utils.decode_base64_to_image(x['image'])
51
- width, height = decode_image.size
52
- mask = np.zeros((height, width, 4), dtype=np.uint8)
53
- mask[..., -1] = 255
54
- mask = self.postprocess(mask)
55
-
56
- x['mask'] = mask
57
-
58
- return super().preprocess(x)
59
-
60
-
61
- def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, text_refiner=None,
62
- session_id=None):
63
- segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
64
- captioner = captioner
65
- if session_id is not None:
66
- print('Init caption anything for session {}'.format(session_id))
67
- return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, text_refiner=text_refiner)
68
-
69
-
70
- def init_openai_api_key(api_key=""):
71
- text_refiner = None
72
- if api_key and len(api_key) > 30:
73
- try:
74
- text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
75
- text_refiner.llm('hi') # test
76
- except:
77
- text_refiner = None
78
- openai_available = text_refiner is not None
79
- return gr.update(visible=openai_available), gr.update(visible=openai_available), gr.update(
80
- visible=openai_available), gr.update(visible=True), gr.update(visible=True), gr.update(
81
- visible=True), text_refiner
82
-
83
-
84
- def get_click_prompt(chat_input, click_state, click_mode):
85
- inputs = json.loads(chat_input)
86
- if click_mode == 'Continuous':
87
- points = click_state[0]
88
- labels = click_state[1]
89
- for input in inputs:
90
- points.append(input[:2])
91
- labels.append(input[2])
92
- elif click_mode == 'Single':
93
- points = []
94
- labels = []
95
- for input in inputs:
96
- points.append(input[:2])
97
- labels.append(input[2])
98
- click_state[0] = points
99
- click_state[1] = labels
100
- else:
101
- raise NotImplementedError
102
-
103
- prompt = {
104
- "prompt_type": ["click"],
105
- "input_point": click_state[0],
106
- "input_label": click_state[1],
107
- "multimask_output": "True",
108
- }
109
- return prompt
110
-
111
-
112
- def update_click_state(click_state, caption, click_mode):
113
- if click_mode == 'Continuous':
114
- click_state[2].append(caption)
115
- elif click_mode == 'Single':
116
- click_state[2] = [caption]
117
- else:
118
- raise NotImplementedError
119
-
120
-
121
- def chat_with_points(chat_input, click_state, chat_state, state, text_refiner, img_caption):
122
- if text_refiner is None:
123
- response = "Text refiner is not initilzed, please input openai api key."
124
- state = state + [(chat_input, response)]
125
- return state, state, chat_state
126
-
127
- points, labels, captions = click_state
128
- # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting!"
129
- suffix = '\nHuman: {chat_input}\nAI: '
130
- qa_template = '\nHuman: {q}\nAI: {a}'
131
- # # "The image is of width {width} and height {height}."
132
- point_chat_prompt = "I am an AI trained to chat with you about an image. I am greate at what is going on in any image based on the image information your provide. The overall image description is \"{img_caption}\". You will also provide me objects in the image in details, i.e., their location and visual descriptions. Here are the locations and descriptions of events that happen in the image: {points_with_caps} \nYou are required to use language instead of number to describe these positions. Now, let's chat!"
133
- prev_visual_context = ""
134
- pos_points = []
135
- pos_captions = []
136
-
137
- for i in range(len(points)):
138
- if labels[i] == 1:
139
- pos_points.append(f"(X:{points[i][0]}, Y:{points[i][1]})")
140
- pos_captions.append(captions[i])
141
- prev_visual_context = prev_visual_context + '\n' + 'There is an event described as \"{}\" locating at {}'.format(
142
- pos_captions[-1], ', '.join(pos_points))
143
-
144
- context_length_thres = 500
145
- prev_history = ""
146
- for i in range(len(chat_state)):
147
- q, a = chat_state[i]
148
- if len(prev_history) < context_length_thres:
149
- prev_history = prev_history + qa_template.format(**{"q": q, "a": a})
150
- else:
151
- break
152
- chat_prompt = point_chat_prompt.format(
153
- **{"img_caption": img_caption, "points_with_caps": prev_visual_context}) + prev_history + suffix.format(
154
- **{"chat_input": chat_input})
155
- print('\nchat_prompt: ', chat_prompt)
156
- response = text_refiner.llm(chat_prompt)
157
- state = state + [(chat_input, response)]
158
- chat_state = chat_state + [(chat_input, response)]
159
- return state, state, chat_state
160
-
161
-
162
- def upload_callback(image_input, state):
163
- if isinstance(image_input, dict): # if upload from sketcher_input, input contains image and mask
164
- image_input, mask = image_input['image'], image_input['mask']
165
-
166
- chat_state = []
167
- click_state = [[], [], []]
168
- res = 1024
169
- width, height = image_input.size
170
- ratio = min(1.0 * res / max(width, height), 1.0)
171
- if ratio < 1.0:
172
- image_input = image_input.resize((int(width * ratio), int(height * ratio)))
173
- print('Scaling input image to {}'.format(image_input.size))
174
- state = [] + [(None, 'Image size: ' + str(image_input.size))]
175
- model = build_caption_anything_with_models(
176
- args,
177
- api_key="",
178
- captioner=shared_captioner,
179
- sam_model=shared_sam_model,
180
- session_id=iface.app_id
181
- )
182
- model.segmenter.set_image(image_input)
183
- image_embedding = model.image_embedding
184
- original_size = model.original_size
185
- input_size = model.input_size
186
- img_caption, _ = model.captioner.inference_seg(image_input)
187
-
188
- return state, state, chat_state, image_input, click_state, image_input, image_input, image_embedding, \
189
- original_size, input_size, img_caption
190
-
191
-
192
- def inference_click(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
193
- length, image_embedding, state, click_state, original_size, input_size, text_refiner,
194
- evt: gr.SelectData):
195
- click_index = evt.index
196
-
197
- if point_prompt == 'Positive':
198
- coordinate = "[[{}, {}, 1]]".format(str(click_index[0]), str(click_index[1]))
199
- else:
200
- coordinate = "[[{}, {}, 0]]".format(str(click_index[0]), str(click_index[1]))
201
-
202
- prompt = get_click_prompt(coordinate, click_state, click_mode)
203
- input_points = prompt['input_point']
204
- input_labels = prompt['input_label']
205
-
206
- controls = {'length': length,
207
- 'sentiment': sentiment,
208
- 'factuality': factuality,
209
- 'language': language}
210
-
211
- model = build_caption_anything_with_models(
212
- args,
213
- api_key="",
214
- captioner=shared_captioner,
215
- sam_model=shared_sam_model,
216
- text_refiner=text_refiner,
217
- session_id=iface.app_id
218
- )
219
-
220
- model.setup(image_embedding, original_size, input_size, is_image_set=True)
221
-
222
- enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
223
- out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)
224
-
225
- state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
226
- state = state + [(None, "raw_caption: {}".format(out['generated_captions']['raw_caption']))]
227
- wiki = out['generated_captions'].get('wiki', "")
228
- update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
229
- text = out['generated_captions']['raw_caption']
230
- input_mask = np.array(out['mask'].convert('P'))
231
- image_input = mask_painter(np.array(image_input), input_mask)
232
- origin_image_input = image_input
233
- image_input = create_bubble_frame(image_input, text, (click_index[0], click_index[1]), input_mask,
234
- input_points=input_points, input_labels=input_labels)
235
- yield state, state, click_state, image_input, wiki
236
- if not args.disable_gpt and model.text_refiner:
237
- refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
238
- enable_wiki=enable_wiki)
239
- # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
240
- new_cap = refined_caption['caption']
241
- wiki = refined_caption['wiki']
242
- state = state + [(None, f"caption: {new_cap}")]
243
- refined_image_input = create_bubble_frame(origin_image_input, new_cap, (click_index[0], click_index[1]),
244
- input_mask,
245
- input_points=input_points, input_labels=input_labels)
246
- yield state, state, click_state, refined_image_input, wiki
247
-
248
-
249
- def get_sketch_prompt(mask: PIL.Image.Image, multi_mask=True):
250
- """
251
- Get the prompt for the sketcher.
252
- TODO: This is a temporary solution. We should cluster the sketch and get the bounding box of each cluster.
253
- """
254
-
255
- mask = np.array(np.asarray(mask)[..., 0])
256
- mask[mask > 0] = 1 # Refine the mask, let all nonzero values be 1
257
-
258
- if not multi_mask:
259
- y, x = np.where(mask == 1)
260
- x1, y1 = np.min(x), np.min(y)
261
- x2, y2 = np.max(x), np.max(y)
262
-
263
- prompt = {
264
- 'prompt_type': ['box'],
265
- 'input_boxes': [
266
- [x1, y1, x2, y2]
267
- ]
268
- }
269
-
270
- return prompt
271
-
272
- traversed = np.zeros_like(mask)
273
- groups = np.zeros_like(mask)
274
- max_group_id = 1
275
-
276
- # Iterate over all pixels
277
- for x in range(mask.shape[0]):
278
- for y in range(mask.shape[1]):
279
- if traversed[x, y] == 1:
280
- continue
281
-
282
- if mask[x, y] == 0:
283
- traversed[x, y] = 1
284
- else:
285
- # If pixel is part of mask
286
- groups[x, y] = max_group_id
287
- stack = [(x, y)]
288
- while stack:
289
- i, j = stack.pop()
290
- if traversed[i, j] == 1:
291
- continue
292
- traversed[i, j] = 1
293
- if mask[i, j] == 1:
294
- groups[i, j] = max_group_id
295
- for di, dj in [(1, 0), (-1, 0), (0, 1), (0, -1), (1, 1), (1, -1), (-1, 1), (-1, -1)]:
296
- ni, nj = i + di, j + dj
297
- traversed[i, j] = 1
298
- if 0 <= nj < mask.shape[1] and mask.shape[0] > ni >= 0 == traversed[ni, nj]:
299
- stack.append((i + di, j + dj))
300
- max_group_id += 1
301
-
302
- # get the bounding box of each group
303
- boxes = []
304
- for group in range(1, max_group_id):
305
- y, x = np.where(groups == group)
306
- x1, y1 = np.min(x), np.min(y)
307
- x2, y2 = np.max(x), np.max(y)
308
- boxes.append([x1, y1, x2, y2])
309
-
310
- prompt = {
311
- 'prompt_type': ['box'],
312
- 'input_boxes': boxes
313
- }
314
-
315
- return prompt
316
-
317
-
318
- def inference_traject(sketcher_image, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
319
- original_size, input_size, text_refiner):
320
- image_input, mask = sketcher_image['image'], sketcher_image['mask']
321
-
322
- prompt = get_sketch_prompt(mask, multi_mask=False)
323
- boxes = prompt['input_boxes']
324
-
325
- controls = {'length': length,
326
- 'sentiment': sentiment,
327
- 'factuality': factuality,
328
- 'language': language}
329
-
330
- model = build_caption_anything_with_models(
331
- args,
332
- api_key="",
333
- captioner=shared_captioner,
334
- sam_model=shared_sam_model,
335
- text_refiner=text_refiner,
336
- session_id=iface.app_id
337
- )
338
-
339
- model.setup(image_embedding, original_size, input_size, is_image_set=True)
340
-
341
- enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
342
- out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)
343
-
344
- # Update components and states
345
- state.append((f'Box: {boxes}', None))
346
- state.append((None, f'raw_caption: {out["generated_captions"]["raw_caption"]}'))
347
- wiki = out['generated_captions'].get('wiki', "")
348
- text = out['generated_captions']['raw_caption']
349
- input_mask = np.array(out['mask'].convert('P'))
350
- image_input = mask_painter(np.array(image_input), input_mask)
351
-
352
- origin_image_input = image_input
353
-
354
- fake_click_index = (int((boxes[0][0] + boxes[0][2]) / 2), int((boxes[0][1] + boxes[0][3]) / 2))
355
- image_input = create_bubble_frame(image_input, text, fake_click_index, input_mask)
356
-
357
- yield state, state, image_input, wiki
358
-
359
- if not args.disable_gpt and model.text_refiner:
360
- refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
361
- enable_wiki=enable_wiki)
362
-
363
- new_cap = refined_caption['caption']
364
- wiki = refined_caption['wiki']
365
- state = state + [(None, f"caption: {new_cap}")]
366
- refined_image_input = create_bubble_frame(origin_image_input, new_cap, fake_click_index, input_mask)
367
-
368
- yield state, state, refined_image_input, wiki
369
-
370
-
371
- def get_style():
372
- current_version = version.parse(gr.__version__)
373
- if current_version <= version.parse('3.24.1'):
374
- style = '''
375
- #image_sketcher{min-height:500px}
376
- #image_sketcher [data-testid="image"], #image_sketcher [data-testid="image"] > div{min-height: 500px}
377
- #image_upload{min-height:500px}
378
- #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 500px}
379
- '''
380
- elif current_version <= version.parse('3.27'):
381
- style = '''
382
- #image_sketcher{min-height:500px}
383
- #image_upload{min-height:500px}
384
- '''
385
- else:
386
- style = None
387
-
388
- return style
389
-
390
-
391
- def create_ui():
392
- title = """<p><h1 align="center">Caption-Anything</h1></p>
393
- """
394
- description = """<p>Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: <a href="https://github.com/ttengwang/Caption-Anything">https://github.com/ttengwang/Caption-Anything</a> <a href="https://huggingface.co/spaces/TencentARC/Caption-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
395
-
396
- examples = [
397
- ["test_images/img35.webp"],
398
- ["test_images/img2.jpg"],
399
- ["test_images/img5.jpg"],
400
- ["test_images/img12.jpg"],
401
- ["test_images/img14.jpg"],
402
- ["test_images/qingming3.jpeg"],
403
- ["test_images/img1.jpg"],
404
- ]
405
-
406
- with gr.Blocks(
407
- css=get_style()
408
- ) as iface:
409
- state = gr.State([])
410
- click_state = gr.State([[], [], []])
411
- chat_state = gr.State([])
412
- origin_image = gr.State(None)
413
- image_embedding = gr.State(None)
414
- text_refiner = gr.State(None)
415
- original_size = gr.State(None)
416
- input_size = gr.State(None)
417
- img_caption = gr.State(None)
418
-
419
- gr.Markdown(title)
420
- gr.Markdown(description)
421
-
422
- with gr.Row():
423
- with gr.Column(scale=1.0):
424
- with gr.Column(visible=False) as modules_not_need_gpt:
425
- with gr.Tab("Click"):
426
- image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
427
- example_image = gr.Image(type="pil", interactive=False, visible=False)
428
- with gr.Row(scale=1.0):
429
- with gr.Row(scale=0.4):
430
- point_prompt = gr.Radio(
431
- choices=["Positive", "Negative"],
432
- value="Positive",
433
- label="Point Prompt",
434
- interactive=True)
435
- click_mode = gr.Radio(
436
- choices=["Continuous", "Single"],
437
- value="Continuous",
438
- label="Clicking Mode",
439
- interactive=True)
440
- with gr.Row(scale=0.4):
441
- clear_button_click = gr.Button(value="Clear Clicks", interactive=True)
442
- clear_button_image = gr.Button(value="Clear Image", interactive=True)
443
- with gr.Tab("Trajectory (Beta)"):
444
- sketcher_input = ImageSketcher(type="pil", interactive=True, brush_radius=20,
445
- elem_id="image_sketcher")
446
- with gr.Row():
447
- submit_button_sketcher = gr.Button(value="Submit", interactive=True)
448
-
449
- with gr.Column(visible=False) as modules_need_gpt:
450
- with gr.Row(scale=1.0):
451
- language = gr.Dropdown(
452
- ['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"],
453
- value="English", label="Language", interactive=True)
454
- sentiment = gr.Radio(
455
- choices=["Positive", "Natural", "Negative"],
456
- value="Natural",
457
- label="Sentiment",
458
- interactive=True,
459
- )
460
- with gr.Row(scale=1.0):
461
- factuality = gr.Radio(
462
- choices=["Factual", "Imagination"],
463
- value="Factual",
464
- label="Factuality",
465
- interactive=True,
466
- )
467
- length = gr.Slider(
468
- minimum=10,
469
- maximum=80,
470
- value=10,
471
- step=1,
472
- interactive=True,
473
- label="Generated Caption Length",
474
- )
475
- enable_wiki = gr.Radio(
476
- choices=["Yes", "No"],
477
- value="No",
478
- label="Enable Wiki",
479
- interactive=True)
480
- with gr.Column(visible=True) as modules_not_need_gpt3:
481
- gr.Examples(
482
- examples=examples,
483
- inputs=[example_image],
484
- )
485
- with gr.Column(scale=0.5):
486
- openai_api_key = gr.Textbox(
487
- placeholder="Input openAI API key",
488
- show_label=False,
489
- label="OpenAI API Key",
490
- lines=1,
491
- type="password")
492
- with gr.Row(scale=0.5):
493
- enable_chatGPT_button = gr.Button(value="Run with ChatGPT", interactive=True, variant='primary')
494
- disable_chatGPT_button = gr.Button(value="Run without ChatGPT (Faster)", interactive=True,
495
- variant='primary')
496
- with gr.Column(visible=False) as modules_need_gpt2:
497
- wiki_output = gr.Textbox(lines=5, label="Wiki", max_lines=5)
498
- with gr.Column(visible=False) as modules_not_need_gpt2:
499
- chatbot = gr.Chatbot(label="Chat about Selected Object", ).style(height=550, scale=0.5)
500
- with gr.Column(visible=False) as modules_need_gpt3:
501
- chat_input = gr.Textbox(show_label=False, placeholder="Enter text and press Enter").style(
502
- container=False)
503
- with gr.Row():
504
- clear_button_text = gr.Button(value="Clear Text", interactive=True)
505
- submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
506
-
507
- openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key],
508
- outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt,
509
- modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
510
- enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key],
511
- outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3,
512
- modules_not_need_gpt,
513
- modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
514
- disable_chatGPT_button.click(init_openai_api_key,
515
- outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3,
516
- modules_not_need_gpt,
517
- modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
518
-
519
- clear_button_click.click(
520
- lambda x: ([[], [], []], x, ""),
521
- [origin_image],
522
- [click_state, image_input, wiki_output],
523
- queue=False,
524
- show_progress=False
525
- )
526
- clear_button_image.click(
527
- lambda: (None, [], [], [], [[], [], []], "", "", ""),
528
- [],
529
- [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image, img_caption],
530
- queue=False,
531
- show_progress=False
532
- )
533
- clear_button_text.click(
534
- lambda: ([], [], [[], [], [], []], []),
535
- [],
536
- [chatbot, state, click_state, chat_state],
537
- queue=False,
538
- show_progress=False
539
- )
540
- image_input.clear(
541
- lambda: (None, [], [], [], [[], [], []], "", "", ""),
542
- [],
543
- [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image, img_caption],
544
- queue=False,
545
- show_progress=False
546
- )
547
-
548
- image_input.upload(upload_callback, [image_input, state],
549
- [chatbot, state, chat_state, origin_image, click_state, image_input, sketcher_input,
550
- image_embedding, original_size, input_size, img_caption])
551
- sketcher_input.upload(upload_callback, [sketcher_input, state],
552
- [chatbot, state, chat_state, origin_image, click_state, image_input, sketcher_input,
553
- image_embedding, original_size, input_size, img_caption])
554
- chat_input.submit(chat_with_points, [chat_input, click_state, chat_state, state, text_refiner, img_caption],
555
- [chatbot, state, chat_state])
556
- chat_input.submit(lambda: "", None, chat_input)
557
- example_image.change(upload_callback, [example_image, state],
558
- [chatbot, state, chat_state, origin_image, click_state, image_input, sketcher_input,
559
- image_embedding, original_size, input_size, img_caption])
560
-
561
- # select coordinate
562
- image_input.select(
563
- inference_click,
564
- inputs=[
565
- origin_image, point_prompt, click_mode, enable_wiki, language, sentiment, factuality, length,
566
- image_embedding, state, click_state, original_size, input_size, text_refiner
567
- ],
568
- outputs=[chatbot, state, click_state, image_input, wiki_output],
569
- show_progress=False, queue=True
570
- )
571
-
572
- submit_button_sketcher.click(
573
- inference_traject,
574
- inputs=[
575
- sketcher_input, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
576
- original_size, input_size, text_refiner
577
- ],
578
- outputs=[chatbot, state, sketcher_input, wiki_output],
579
- show_progress=False, queue=True
580
- )
581
-
582
- return iface
583
-
584
-
585
- if __name__ == '__main__':
586
- iface = create_ui()
587
- iface.queue(concurrency_count=5, api_open=False, max_size=10)
588
- iface.launch(server_name="0.0.0.0", enable_queue=True, server_port=args.port, share=args.gradio_share)
 
caption_anything/captioner/base_captioner.py CHANGED
@@ -9,8 +9,10 @@ from typing import Union
9
  import time
10
  import clip
11
 
12
  def boundary(inputs):
13
-
14
  col = inputs.shape[1]
15
  inputs = inputs.reshape(-1)
16
  lens = len(inputs)
@@ -20,11 +22,11 @@ def boundary(inputs):
20
 
21
  top = start // col
22
  bottom = end // col
23
-
24
  return top, bottom
25
 
 
26
  def new_seg_to_box(seg_mask: Union[np.ndarray, Image.Image, str]):
27
-
28
  if type(seg_mask) == str:
29
  seg_mask = Image.open(seg_mask)
30
  elif type(seg_mask) == np.ndarray:
@@ -35,12 +37,13 @@ def new_seg_to_box(seg_mask: Union[np.ndarray, Image.Image, str]):
35
  left, right = boundary(seg_mask.T)
36
  return [left / size, top / size, right / size, bottom / size]
37
 
 
38
  def seg_to_box(seg_mask: Union[np.ndarray, Image.Image, str]):
39
  if type(seg_mask) == str:
40
  seg_mask = cv2.imread(seg_mask, cv2.IMREAD_GRAYSCALE)
41
  _, seg_mask = cv2.threshold(seg_mask, 127, 255, 0)
42
  elif type(seg_mask) == np.ndarray:
43
- assert seg_mask.ndim == 2 # only support single-channel segmentation mask
44
  seg_mask = seg_mask.astype('uint8')
45
  if seg_mask.dtype == 'bool':
46
  seg_mask = seg_mask * 255
@@ -49,25 +52,28 @@ def seg_to_box(seg_mask: Union[np.ndarray, Image.Image, str]):
49
  rect = cv2.minAreaRect(contours)
50
  box = cv2.boxPoints(rect)
51
  if rect[-1] >= 45:
52
- newstart = box.argmin(axis=0)[1] # leftmost
53
  else:
54
- newstart = box.argmax(axis=0)[0] # topmost
55
  box = np.concatenate([box[newstart:], box[:newstart]], axis=0)
56
  box = np.int0(box)
57
  return box
58
 
 
59
  def get_w_h(rect_points):
60
  w = np.linalg.norm(rect_points[0] - rect_points[1], ord=2).astype('int')
61
  h = np.linalg.norm(rect_points[0] - rect_points[3], ord=2).astype('int')
62
  return w, h
63
-
 
64
  def cut_box(img, rect_points):
65
  w, h = get_w_h(rect_points)
66
- dst_pts = np.array([[h, 0], [h, w], [0, w], [0, 0],], dtype="float32")
67
  transform = cv2.getPerspectiveTransform(rect_points.astype("float32"), dst_pts)
68
  cropped_img = cv2.warpPerspective(img, transform, (h, w))
69
  return cropped_img
70
-
 
71
  class BaseCaptioner:
72
  def __init__(self, device, enable_filter=False):
73
  print(f"Initializing ImageCaptioning to {device}")
@@ -82,18 +88,15 @@ class BaseCaptioner:
82
 
83
  @torch.no_grad()
84
  def filter_caption(self, image: Union[np.ndarray, Image.Image, str], caption: str):
85
-
86
- if type(image) == str: # input path
87
- image = Image.open(image)
88
- elif type(image) == np.ndarray:
89
- image = Image.fromarray(image)
90
-
91
- image = self.preprocess(image).unsqueeze(0).to(self.device) # (1, 3, 224, 224)
92
- text = clip.tokenize(caption).to(self.device) # (1, 77)
93
- image_features = self.filter.encode_image(image) # (1, 512)
94
- text_features = self.filter.encode_text(text) # (1, 512)
95
- image_features /= image_features.norm(dim = -1, keepdim = True)
96
- text_features /= text_features.norm(dim = -1, keepdim = True)
97
  similarity = torch.matmul(image_features, text_features.transpose(1, 0)).item()
98
  if similarity < self.threshold:
99
  print('There seems to be nothing where you clicked.')
@@ -103,24 +106,21 @@ class BaseCaptioner:
103
  print(f'Clip score of the caption is {similarity}')
104
  return out
105
 
106
-
107
- def inference(self, image: Union[np.ndarray, Image.Image, str], filter: bool=False):
108
  raise NotImplementedError()
109
-
110
- def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, filter: bool=False):
111
  raise NotImplementedError()
112
-
113
  def inference_box(self, image: Union[np.ndarray, Image.Image, str], box: Union[list, np.ndarray], filter=False):
114
- if type(image) == str: # input path
115
- image = Image.open(image)
116
- elif type(image) == np.ndarray:
117
- image = Image.fromarray(image)
118
 
119
- if np.array(box).size == 4: # [x0, y0, x1, y1], where (x0, y0), (x1, y1) represent top-left and bottom-right corners
 
120
  size = max(image.width, image.height)
121
  x1, y1, x2, y2 = box
122
- image_crop = np.array(image.crop((x1 * size, y1 * size, x2 * size, y2 * size)))
123
- elif np.array(box).size == 8: # four corners of an irregular rectangle
124
  image_crop = cut_box(np.array(image), box)
125
 
126
  crop_save_path = f'result/crop_{time.time()}.png'
@@ -128,24 +128,20 @@ class BaseCaptioner:
128
  print(f'croped image saved in {crop_save_path}')
129
  caption = self.inference(image_crop, filter)
130
  return caption, crop_save_path
131
-
132
 
133
- def inference_seg(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str]=None, crop_mode="w_bg", filter=False, disable_regular_box = False):
 
134
  if seg_mask is None:
135
  seg_mask = np.ones(image.size).astype(bool)
136
-
137
- if type(image) == str:
138
- image = Image.open(image)
139
- if type(seg_mask) == str:
140
- seg_mask = Image.open(seg_mask)
141
- elif type(seg_mask) == np.ndarray:
142
- seg_mask = Image.fromarray(seg_mask)
143
 
144
  seg_mask = seg_mask.resize(image.size)
145
  seg_mask = np.array(seg_mask) > 0
146
-
147
- if crop_mode=="wo_bg":
148
- image = np.array(image) * seg_mask[:,:,np.newaxis] + (1 - seg_mask[:,:,np.newaxis]) * 255
149
  image = np.uint8(image)
150
  else:
151
  image = np.array(image)
@@ -155,20 +151,17 @@ class BaseCaptioner:
155
  else:
156
  min_area_box = new_seg_to_box(seg_mask)
157
  return self.inference_box(image, min_area_box, filter)
158
-
159
-
160
- def generate_seg_cropped_image(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str], crop_mode="w_bg", disable_regular_box = False):
161
- if type(image) == str:
162
- image = Image.open(image)
163
- if type(seg_mask) == str:
164
- seg_mask = Image.open(seg_mask)
165
- elif type(seg_mask) == np.ndarray:
166
- seg_mask = Image.fromarray(seg_mask)
167
  seg_mask = seg_mask.resize(image.size)
168
  seg_mask = np.array(seg_mask) > 0
169
 
170
- if crop_mode=="wo_bg":
171
- image = np.array(image) * seg_mask[:,:,np.newaxis] + (1- seg_mask[:,:,np.newaxis]) * 255
172
  else:
173
  image = np.array(image)
174
 
@@ -176,24 +169,24 @@ class BaseCaptioner:
176
  box = seg_to_box(seg_mask)
177
  else:
178
  box = new_seg_to_box(seg_mask)
179
-
180
- if np.array(box).size == 4: # [x0, y0, x1, y1], where (x0, y0), (x1, y1) represent top-left and bottom-right corners
 
181
  size = max(image.shape[0], image.shape[1])
182
  x1, y1, x2, y2 = box
183
- image_crop = np.array(image.crop((x1 * size, y1 * size, x2 * size, y2 * size)))
184
- elif np.array(box).size == 8: # four corners of an irregular rectangle
185
  image_crop = cut_box(np.array(image), box)
186
  crop_save_path = f'result/crop_{time.time()}.png'
187
  Image.fromarray(image_crop).save(crop_save_path)
188
  print(f'croped image saved in {crop_save_path}')
189
  return crop_save_path
190
 
191
-
192
  if __name__ == '__main__':
193
  model = BaseCaptioner(device='cuda:0')
194
  image_path = 'test_images/img2.jpg'
195
- seg_mask = np.zeros((15,15))
196
  seg_mask[5:10, 5:10] = 1
197
  seg_mask = 'image/SAM/img10.jpg.raw_mask.png'
198
  print(model.inference_seg(image_path, seg_mask))
199
-
 
9
  import time
10
  import clip
11
 
12
+ from caption_anything.utils.utils import load_image
13
+
14
+
15
  def boundary(inputs):
 
16
  col = inputs.shape[1]
17
  inputs = inputs.reshape(-1)
18
  lens = len(inputs)
 
22
 
23
  top = start // col
24
  bottom = end // col
25
+
26
  return top, bottom
27
 
28
+
29
  def new_seg_to_box(seg_mask: Union[np.ndarray, Image.Image, str]):
 
30
  if type(seg_mask) == str:
31
  seg_mask = Image.open(seg_mask)
32
  elif type(seg_mask) == np.ndarray:
 
37
  left, right = boundary(seg_mask.T)
38
  return [left / size, top / size, right / size, bottom / size]
39
 
40
+
41
  def seg_to_box(seg_mask: Union[np.ndarray, Image.Image, str]):
42
  if type(seg_mask) == str:
43
  seg_mask = cv2.imread(seg_mask, cv2.IMREAD_GRAYSCALE)
44
  _, seg_mask = cv2.threshold(seg_mask, 127, 255, 0)
45
  elif type(seg_mask) == np.ndarray:
46
+ assert seg_mask.ndim == 2 # only support single-channel segmentation mask
47
  seg_mask = seg_mask.astype('uint8')
48
  if seg_mask.dtype == 'bool':
49
  seg_mask = seg_mask * 255
 
52
  rect = cv2.minAreaRect(contours)
53
  box = cv2.boxPoints(rect)
54
  if rect[-1] >= 45:
55
+ newstart = box.argmin(axis=0)[1] # leftmost
56
  else:
57
+ newstart = box.argmax(axis=0)[0] # topmost
58
  box = np.concatenate([box[newstart:], box[:newstart]], axis=0)
59
  box = np.int0(box)
60
  return box
61
 
62
+
63
  def get_w_h(rect_points):
64
  w = np.linalg.norm(rect_points[0] - rect_points[1], ord=2).astype('int')
65
  h = np.linalg.norm(rect_points[0] - rect_points[3], ord=2).astype('int')
66
  return w, h
67
+
68
+
69
  def cut_box(img, rect_points):
70
  w, h = get_w_h(rect_points)
71
+ dst_pts = np.array([[h, 0], [h, w], [0, w], [0, 0], ], dtype="float32")
72
  transform = cv2.getPerspectiveTransform(rect_points.astype("float32"), dst_pts)
73
  cropped_img = cv2.warpPerspective(img, transform, (h, w))
74
  return cropped_img
75
+
76
+
77
  class BaseCaptioner:
78
  def __init__(self, device, enable_filter=False):
79
  print(f"Initializing ImageCaptioning to {device}")
 
88
 
89
  @torch.no_grad()
90
  def filter_caption(self, image: Union[np.ndarray, Image.Image, str], caption: str):
91
+
92
+ image = load_image(image, return_type='pil')
93
+
94
+ image = self.preprocess(image).unsqueeze(0).to(self.device) # (1, 3, 224, 224)
95
+ text = clip.tokenize(caption).to(self.device) # (1, 77)
96
+ image_features = self.filter.encode_image(image) # (1, 512)
97
+ text_features = self.filter.encode_text(text) # (1, 512)
98
+ image_features /= image_features.norm(dim=-1, keepdim=True)
99
+ text_features /= text_features.norm(dim=-1, keepdim=True)
100
  similarity = torch.matmul(image_features, text_features.transpose(1, 0)).item()
101
  if similarity < self.threshold:
102
  print('There seems to be nothing where you clicked.')
 
106
  print(f'Clip score of the caption is {similarity}')
107
  return out
108
 
109
+ def inference(self, image: Union[np.ndarray, Image.Image, str], filter: bool = False):
 
110
  raise NotImplementedError()
111
+
112
+ def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, filter: bool = False):
113
  raise NotImplementedError()
114
+
115
  def inference_box(self, image: Union[np.ndarray, Image.Image, str], box: Union[list, np.ndarray], filter=False):
116
+ image = load_image(image, return_type="pil")
117
 
118
+ if np.array(box).size == 4:
119
+ # [x0, y0, x1, y1], where (x0, y0), (x1, y1) represent top-left and bottom-right corners
120
  size = max(image.width, image.height)
121
  x1, y1, x2, y2 = box
122
+ image_crop = np.array(image.crop((x1 * size, y1 * size, x2 * size, y2 * size)))
123
+ elif np.array(box).size == 8: # four corners of an irregular rectangle
124
  image_crop = cut_box(np.array(image), box)
125
 
126
  crop_save_path = f'result/crop_{time.time()}.png'
 
128
  print(f'croped image saved in {crop_save_path}')
129
  caption = self.inference(image_crop, filter)
130
  return caption, crop_save_path
 
131
 
132
+ def inference_seg(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str] = None,
133
+ crop_mode="w_bg", filter=False, disable_regular_box=False):
134
  if seg_mask is None:
135
  seg_mask = np.ones(image.size).astype(bool)
136
+
137
+ image = load_image(image, return_type="pil")
138
+ seg_mask = load_image(seg_mask, return_type="pil")
139
 
140
  seg_mask = seg_mask.resize(image.size)
141
  seg_mask = np.array(seg_mask) > 0
142
+
143
+ if crop_mode == "wo_bg":
144
+ image = np.array(image) * seg_mask[:, :, np.newaxis] + (1 - seg_mask[:, :, np.newaxis]) * 255
145
  image = np.uint8(image)
146
  else:
147
  image = np.array(image)
 
151
  else:
152
  min_area_box = new_seg_to_box(seg_mask)
153
  return self.inference_box(image, min_area_box, filter)
154
+
155
+ def generate_seg_cropped_image(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str],
156
+ crop_mode="w_bg", disable_regular_box=False):
157
+ image = load_image(image, return_type="pil")
158
+ seg_mask = load_image(seg_mask, return_type="pil")
159
+
160
  seg_mask = seg_mask.resize(image.size)
161
  seg_mask = np.array(seg_mask) > 0
162
 
163
+ if crop_mode == "wo_bg":
164
+ image = np.array(image) * seg_mask[:, :, np.newaxis] + (1 - seg_mask[:, :, np.newaxis]) * 255
165
  else:
166
  image = np.array(image)
167
 
 
169
  box = seg_to_box(seg_mask)
170
  else:
171
  box = new_seg_to_box(seg_mask)
172
+
173
+ if np.array(box).size == 4:
174
+ # [x0, y0, x1, y1], where (x0, y0), (x1, y1) represent top-left and bottom-right corners
175
  size = max(image.shape[0], image.shape[1])
176
  x1, y1, x2, y2 = box
177
+ image_crop = np.array(image.crop((x1 * size, y1 * size, x2 * size, y2 * size)))
178
+ elif np.array(box).size == 8: # four corners of an irregular rectangle
179
  image_crop = cut_box(np.array(image), box)
180
  crop_save_path = f'result/crop_{time.time()}.png'
181
  Image.fromarray(image_crop).save(crop_save_path)
182
  print(f'croped image saved in {crop_save_path}')
183
  return crop_save_path
184
 
185
+
186
  if __name__ == '__main__':
187
  model = BaseCaptioner(device='cuda:0')
188
  image_path = 'test_images/img2.jpg'
189
+ seg_mask = np.zeros((15, 15))
190
  seg_mask[5:10, 5:10] = 1
191
  seg_mask = 'image/SAM/img10.jpg.raw_mask.png'
192
  print(model.inference_seg(image_path, seg_mask))
 
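The base_captioner.py hunks replace the repeated type(image) == str / np.ndarray branches with a single load_image(image, return_type=...) call imported from caption_anything.utils.utils. That utility's implementation is not shown in this commit; the sketch below is only a plausible equivalent of the branches it replaces, with the signature inferred from the call sites above.

from typing import Union

import numpy as np
from PIL import Image

def load_image(image: Union[np.ndarray, Image.Image, str], return_type: str = "numpy"):
    # Hypothetical re-creation: normalize a path, array, or PIL image into one type.
    if isinstance(image, str):  # input path
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    if return_type == "pil":
        return image
    if return_type == "numpy":
        return np.array(image)
    raise ValueError(f"Unsupported return_type: {return_type}")
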
caption_anything/captioner/blip.py CHANGED
@@ -1,14 +1,13 @@
1
  import torch
2
- from PIL import Image, ImageDraw, ImageOps
3
  from transformers import BlipProcessor
 
  from .modeling_blip import BlipForConditionalGeneration
5
- import json
6
- import pdb
7
- import cv2
8
  import numpy as np
9
  from typing import Union
10
  from .base_captioner import BaseCaptioner
11
- import torchvision.transforms.functional as F
12
 
13
 
14
  class BLIPCaptioner(BaseCaptioner):
@@ -17,12 +16,12 @@ class BLIPCaptioner(BaseCaptioner):
17
  self.device = device
18
  self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
19
  self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
20
- self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=self.torch_dtype).to(self.device)
21
-
 
22
  @torch.no_grad()
23
  def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
24
- if type(image) == str: # input path
25
- image = Image.open(image)
26
  inputs = self.processor(image, return_tensors="pt").to(self.device, self.torch_dtype)
27
  out = self.model.generate(**inputs, max_new_tokens=50)
28
  captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
@@ -30,12 +29,13 @@ class BLIPCaptioner(BaseCaptioner):
30
  captions = self.filter_caption(image, captions)
31
  print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
32
  return captions
33
-
34
  @torch.no_grad()
35
- def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg", filter=False, disable_regular_box = False):
36
- crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode, disable_regular_box=disable_regular_box)
37
- if type(image) == str: # input path
38
- image = Image.open(image)
 
39
  inputs = self.processor(image, return_tensors="pt")
40
  pixel_values = inputs.pixel_values.to(self.device, self.torch_dtype)
41
  _, _, H, W = pixel_values.shape
@@ -56,11 +56,10 @@ if __name__ == '__main__':
56
  model = BLIPCaptioner(device='cuda:0')
57
  # image_path = 'test_images/img2.jpg'
58
  image_path = 'image/SAM/img10.jpg'
59
- seg_mask = np.zeros((15,15))
60
  seg_mask[5:10, 5:10] = 1
61
  seg_mask = 'test_images/img10.jpg.raw_mask.png'
62
  image_path = 'test_images/img2.jpg'
63
  seg_mask = 'test_images/img2.jpg.raw_mask.png'
64
  print(f'process image {image_path}')
65
  print(model.inference_with_reduced_tokens(image_path, seg_mask))
66
-
 
1
  import torch
2
+ from PIL import Image
3
  from transformers import BlipProcessor
4
+
5
+ from caption_anything.utils.utils import load_image
6
  from .modeling_blip import BlipForConditionalGeneration
7
  import numpy as np
8
  from typing import Union
9
  from .base_captioner import BaseCaptioner
10
+ import torchvision.transforms.functional as F
11
 
12
 
13
  class BLIPCaptioner(BaseCaptioner):
 
16
  self.device = device
17
  self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
18
  self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
19
+ self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large",
20
+ torch_dtype=self.torch_dtype).to(self.device)
21
+
22
  @torch.no_grad()
23
  def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
24
+ image = load_image(image, return_type="pil")
 
25
  inputs = self.processor(image, return_tensors="pt").to(self.device, self.torch_dtype)
26
  out = self.model.generate(**inputs, max_new_tokens=50)
27
  captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
 
29
  captions = self.filter_caption(image, captions)
30
  print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
31
  return captions
32
+
33
  @torch.no_grad()
34
+ def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg",
35
+ filter=False, disable_regular_box=False):
36
+ crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode,
37
+ disable_regular_box=disable_regular_box)
38
+ image = load_image(image, return_type="pil")
39
  inputs = self.processor(image, return_tensors="pt")
40
  pixel_values = inputs.pixel_values.to(self.device, self.torch_dtype)
41
  _, _, H, W = pixel_values.shape
 
56
  model = BLIPCaptioner(device='cuda:0')
57
  # image_path = 'test_images/img2.jpg'
58
  image_path = 'image/SAM/img10.jpg'
59
+ seg_mask = np.zeros((15, 15))
60
  seg_mask[5:10, 5:10] = 1
61
  seg_mask = 'test_images/img10.jpg.raw_mask.png'
62
  image_path = 'test_images/img2.jpg'
63
  seg_mask = 'test_images/img2.jpg.raw_mask.png'
64
  print(f'process image {image_path}')
65
  print(model.inference_with_reduced_tokens(image_path, seg_mask))
 
caption_anything/captioner/blip2.py CHANGED
@@ -4,7 +4,7 @@ import numpy as np
4
  from typing import Union
5
  from transformers import AutoProcessor, Blip2ForConditionalGeneration
6
 
7
- from caption_anything.utils.utils import is_platform_win
8
  from .base_captioner import BaseCaptioner
9
 
10
  class BLIP2Captioner(BaseCaptioner):
@@ -21,11 +21,10 @@ class BLIP2Captioner(BaseCaptioner):
21
 
22
  @torch.no_grad()
23
  def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
24
- if type(image) == str: # input path
25
- image = Image.open(image)
26
 
27
  if not self.dialogue:
28
- text_prompt = 'Question: what does the image show? Answer:'
29
  inputs = self.processor(image, text = text_prompt, return_tensors="pt").to(self.device, self.torch_dtype)
30
  out = self.model.generate(**inputs, max_new_tokens=50)
31
  captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
 
4
  from typing import Union
5
  from transformers import AutoProcessor, Blip2ForConditionalGeneration
6
 
7
+ from caption_anything.utils.utils import is_platform_win, load_image
8
  from .base_captioner import BaseCaptioner
9
 
10
  class BLIP2Captioner(BaseCaptioner):
 
21
 
22
  @torch.no_grad()
23
  def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
24
+ image = load_image(image, return_type="pil")
 
25
 
26
  if not self.dialogue:
27
+ text_prompt = 'The image shows'
28
  inputs = self.processor(image, text = text_prompt, return_tensors="pt").to(self.device, self.torch_dtype)
29
  out = self.model.generate(**inputs, max_new_tokens=50)
30
  captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
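
In blip2.py the non-dialogue prompt changes from the VQA-style 'Question: what does the image show? Answer:' to the declarative prefix 'The image shows', so BLIP-2 continues a caption instead of answering a question. Below is a minimal usage sketch of that prompting pattern; the checkpoint name is an assumption (the diff does not show which BLIP-2 weights are loaded), and any Blip2ForConditionalGeneration checkpoint is used the same way.

import torch
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if "cuda" in device else torch.float32

# Checkpoint is illustrative only; the commit does not pin a specific BLIP-2 model.
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", torch_dtype=torch_dtype).to(device)

image = Image.open("test_images/img2.jpg")
# Old prompt: 'Question: what does the image show? Answer:'  -> short VQA-style answer
# New prompt: 'The image shows'                               -> completed declarative caption
inputs = processor(image, text="The image shows", return_tensors="pt").to(device, torch_dtype)
out = model.generate(**inputs, max_new_tokens=50)
print(processor.decode(out[0], skip_special_tokens=True).strip())
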
caption_anything/captioner/git.py CHANGED
@@ -1,4 +1,6 @@
1
  from transformers import GitProcessor, AutoProcessor
2
  from .modeling_git import GitForCausalLM
3
  from PIL import Image
4
  import torch
@@ -15,11 +17,10 @@ class GITCaptioner(BaseCaptioner):
15
  self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
16
  self.processor = AutoProcessor.from_pretrained("microsoft/git-large")
17
  self.model = GitForCausalLM.from_pretrained("microsoft/git-large", torch_dtype=self.torch_dtype).to(self.device)
18
-
19
  @torch.no_grad()
20
  def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
21
- if type(image) == str: # input path
22
- image = Image.open(image)
23
  pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device, self.torch_dtype)
24
  generated_ids = self.model.generate(pixel_values=pixel_values, max_new_tokens=50)
25
  generated_caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
@@ -29,10 +30,11 @@ class GITCaptioner(BaseCaptioner):
29
  return generated_caption
30
 
31
  @torch.no_grad()
32
- def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg", filter=False, disable_regular_box = False):
33
- crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode, disable_regular_box=disable_regular_box)
34
- if type(image) == str: # input path
35
- image = Image.open(image)
 
36
  inputs = self.processor(images=image, return_tensors="pt")
37
  pixel_values = inputs.pixel_values.to(self.device, self.torch_dtype)
38
  _, _, H, W = pixel_values.shape
@@ -48,10 +50,11 @@ class GITCaptioner(BaseCaptioner):
48
  print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
49
  return captions, crop_save_path
50
 
 
51
  if __name__ == '__main__':
52
  model = GITCaptioner(device='cuda:2', enable_filter=False)
53
  image_path = 'test_images/img2.jpg'
54
- seg_mask = np.zeros((224,224))
55
  seg_mask[50:200, 50:200] = 1
56
  print(f'process image {image_path}')
57
- print(model.inference_with_reduced_tokens(image_path, seg_mask))
 
1
  from transformers import GitProcessor, AutoProcessor
2
+
3
+ from caption_anything.utils.utils import load_image
4
  from .modeling_git import GitForCausalLM
5
  from PIL import Image
6
  import torch
 
17
  self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
18
  self.processor = AutoProcessor.from_pretrained("microsoft/git-large")
19
  self.model = GitForCausalLM.from_pretrained("microsoft/git-large", torch_dtype=self.torch_dtype).to(self.device)
20
+
21
  @torch.no_grad()
22
  def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
23
+ image = load_image(image, return_type="pil")
 
24
  pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device, self.torch_dtype)
25
  generated_ids = self.model.generate(pixel_values=pixel_values, max_new_tokens=50)
26
  generated_caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
 
30
  return generated_caption
31
 
32
  @torch.no_grad()
33
+ def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg",
34
+ filter=False, disable_regular_box=False):
35
+ crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode,
36
+ disable_regular_box=disable_regular_box)
37
+ image = load_image(image, return_type="pil")
38
  inputs = self.processor(images=image, return_tensors="pt")
39
  pixel_values = inputs.pixel_values.to(self.device, self.torch_dtype)
40
  _, _, H, W = pixel_values.shape
 
50
  print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
51
  return captions, crop_save_path
52
 
53
+
54
  if __name__ == '__main__':
55
  model = GITCaptioner(device='cuda:2', enable_filter=False)
56
  image_path = 'test_images/img2.jpg'
57
+ seg_mask = np.zeros((224, 224))
58
  seg_mask[50:200, 50:200] = 1
59
  print(f'process image {image_path}')
60
+ print(model.inference_with_reduced_tokens(image_path, seg_mask))
caption_anything/model.py CHANGED
@@ -62,9 +62,12 @@ class CaptionAnything:
62
  print('OpenAI GPT is not available')
63
 
64
  def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False):
65
  # segment with prompt
66
  print("CA prompt: ", prompt, "CA controls", controls)
67
  seg_mask = self.segmenter.inference(image, prompt)[0, ...]
 
68
  if self.args.enable_morphologyex:
69
  seg_mask = 255 * seg_mask.astype(np.uint8)
70
  seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis=-1)
@@ -80,6 +83,7 @@ class CaptionAnything:
80
  seg_mask_img.save(mask_save_path)
81
  print('seg_mask path: ', mask_save_path)
82
  print("seg_mask.shape: ", seg_mask.shape)
 
83
  # captioning with mask
84
  if self.args.enable_reduce_tokens:
85
  caption, crop_save_path = self.captioner. \
@@ -92,6 +96,7 @@ class CaptionAnything:
92
  inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode,
93
  filter=self.args.clip_filter,
94
  disable_regular_box=self.args.disable_regular_box)
 
95
  # refining with TextRefiner
96
  context_captions = []
97
  if self.args.context_captions:
@@ -111,6 +116,7 @@ class CaptionAnything:
111
 
112
  if __name__ == "__main__":
113
  from caption_anything.utils.parser import parse_augment
 
114
  args = parse_augment()
115
  # image_path = 'test_images/img3.jpg'
116
  image_path = 'test_images/img1.jpg'
 
62
  print('OpenAI GPT is not available')
63
 
64
  def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False):
65
+ # TODO: Add support to multiple seg masks.
66
+
67
  # segment with prompt
68
  print("CA prompt: ", prompt, "CA controls", controls)
69
  seg_mask = self.segmenter.inference(image, prompt)[0, ...]
70
+
71
  if self.args.enable_morphologyex:
72
  seg_mask = 255 * seg_mask.astype(np.uint8)
73
  seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis=-1)
 
83
  seg_mask_img.save(mask_save_path)
84
  print('seg_mask path: ', mask_save_path)
85
  print("seg_mask.shape: ", seg_mask.shape)
86
+
87
  # captioning with mask
88
  if self.args.enable_reduce_tokens:
89
  caption, crop_save_path = self.captioner. \
 
96
  inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode,
97
  filter=self.args.clip_filter,
98
  disable_regular_box=self.args.disable_regular_box)
99
+
100
  # refining with TextRefiner
101
  context_captions = []
102
  if self.args.context_captions:
 
116
 
117
  if __name__ == "__main__":
118
  from caption_anything.utils.parser import parse_augment
119
+
120
  args = parse_augment()
121
  # image_path = 'test_images/img3.jpg'
122
  image_path = 'test_images/img1.jpg'
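The model.py changes above are cosmetic (a TODO and spacing); the inference flow itself is unchanged: segment, then caption, then refine. A condensed sketch of the segment-and-caption stages, using only calls visible in this diff; the `ca` argument stands in for a CaptionAnything instance and is an assumption:

    def segment_then_caption(ca, image, prompt):
        # 1) segment with the prompt; only the first mask is used for now
        #    (see the TODO about supporting multiple seg masks)
        seg_mask = ca.segmenter.inference(image, prompt)[0, ...]

        # 2) caption the masked region; enable_reduce_tokens selects the
        #    token-reduced captioner path shown in the captioner diff above
        captioner_fn = (ca.captioner.inference_with_reduced_tokens
                        if ca.args.enable_reduce_tokens
                        else ca.captioner.inference_seg)
        caption, crop_save_path = captioner_fn(
            image, seg_mask,
            crop_mode=ca.args.seg_crop_mode,
            filter=ca.args.clip_filter,
            disable_regular_box=ca.args.disable_regular_box)
        return caption, crop_save_path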
caption_anything/segmenter/base_segmenter.py CHANGED
@@ -5,7 +5,7 @@ from PIL import Image, ImageDraw, ImageOps
5
  import numpy as np
6
  from typing import Union
7
  from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
8
- from caption_anything.utils.utils import prepare_segmenter, seg_model_map
9
  import matplotlib.pyplot as plt
10
  import PIL
11
 
@@ -30,21 +30,9 @@ class BaseSegmenter:
30
  self.image_embedding = None
31
  self.image = None
32
 
33
- def read_image(self, image: Union[np.ndarray, Image.Image, str]):
34
- if type(image) == str: # input path
35
- image = Image.open(image)
36
- image = np.array(image)
37
- elif type(image) == Image.Image:
38
- image = np.array(image)
39
- elif type(image) == np.ndarray:
40
- image = image
41
- else:
42
- raise TypeError
43
- return image
44
-
45
  @torch.no_grad()
46
  def set_image(self, image: Union[np.ndarray, Image.Image, str]):
47
- image = self.read_image(image)
48
  self.image = image
49
  if self.reuse_feature:
50
  self.predictor.set_image(image)
@@ -57,7 +45,7 @@ class BaseSegmenter:
57
  SAM inference of image according to control.
58
  Args:
59
  image: str or PIL.Image or np.ndarray
60
- control:
61
  prompt_type:
62
  1. {control['prompt_type'] = ['everything']} to segment everything in the image.
63
  2. {control['prompt_type'] = ['click', 'box']} to segment according to click and box.
@@ -77,7 +65,7 @@ class BaseSegmenter:
77
  masks: np.ndarray of shape [num_masks, height, width]
78
 
79
  """
80
- image = self.read_image(image) # Turn image into np.ndarray
81
  if 'everything' in control['prompt_type']:
82
  masks = self.mask_generator.generate(image)
83
  new_masks = np.concatenate([mask["segmentation"][np.newaxis, :] for mask in masks])
 
5
  import numpy as np
6
  from typing import Union
7
  from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
8
+ from caption_anything.utils.utils import prepare_segmenter, seg_model_map, load_image
9
  import matplotlib.pyplot as plt
10
  import PIL
11
 
 
30
  self.image_embedding = None
31
  self.image = None
32
 
33
  @torch.no_grad()
34
  def set_image(self, image: Union[np.ndarray, Image.Image, str]):
35
+ image = load_image(image, return_type='numpy')
36
  self.image = image
37
  if self.reuse_feature:
38
  self.predictor.set_image(image)
 
45
  SAM inference of image according to control.
46
  Args:
47
  image: str or PIL.Image or np.ndarray
48
+ control: dict to control SAM.
49
  prompt_type:
50
  1. {control['prompt_type'] = ['everything']} to segment everything in the image.
51
  2. {control['prompt_type'] = ['click', 'box']} to segment according to click and box.
 
65
  masks: np.ndarray of shape [num_masks, height, width]
66
 
67
  """
68
+ image = load_image(image, return_type='numpy')
69
  if 'everything' in control['prompt_type']:
70
  masks = self.mask_generator.generate(image)
71
  new_masks = np.concatenate([mask["segmentation"][np.newaxis, :] for mask in masks])
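With this change, set_image and inference accept a path, PIL.Image, or numpy array uniformly via load_image. A small sketch of the 'everything' prompt mode documented in the docstring above; `segmenter` is assumed to be an already-constructed BaseSegmenter, and keys other than 'prompt_type' (e.g. click coordinates) are omitted because their names are not shown in this hunk:

    from PIL import Image

    image = Image.open('test_images/img1.jpg')      # str / PIL.Image / np.ndarray all work now
    segmenter.set_image(image)                      # normalized to numpy internally via load_image

    control = {'prompt_type': ['everything']}       # segment everything in the image
    masks = segmenter.inference(image, control)     # np.ndarray of shape [num_masks, height, width]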
caption_anything/utils/chatbot.py CHANGED
@@ -19,22 +19,11 @@ from PIL import Image, ImageDraw, ImageOps
19
  from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
20
 
21
  VISUAL_CHATGPT_PREFIX = """
22
- Caption Anything Chatbox (short as CATchat) is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. CATchat is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
23
 
24
- As a language model, CATchat can not directly read images, but it has a list of tools to finish different visual tasks. CATchat can invoke different tools to indirectly understand pictures.
25
 
26
- Visual ChatGPT has access to the following tools:"""
27
-
28
-
29
- # VISUAL_CHATGPT_PREFIX = """Visual ChatGPT is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. Visual ChatGPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
30
-
31
- # Visual ChatGPT is able to process and understand large amounts of text and images. As a language model, Visual ChatGPT can not directly read images, but it has a list of tools to finish different visual tasks. Each image will have a file name formed as "chat_image/xxx.png", and Visual ChatGPT can invoke different tools to indirectly understand pictures. When talking about images, Visual ChatGPT is very strict to the file name and will never fabricate nonexistent files. Visual ChatGPT is able to use tools in a sequence, and is loyal to the tool observation outputs rather than faking the image content and image file name.
32
-
33
- # Visual ChatGPT is aware of the coordinate of an object in the image, which is represented as a point (X, Y) on the object. Note that (0, 0) represents the bottom-left corner of the image.
34
-
35
- # Human may provide new figures to Visual ChatGPT with a description. The description helps Visual ChatGPT to understand this image, but Visual ChatGPT should use tools to finish following tasks, rather than directly imagine from the description.
36
-
37
- # Overall, Visual ChatGPT is a powerful visual dialogue assistant tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics.
38
 
39
 
40
  # TOOLS:
@@ -63,8 +52,7 @@ Previous conversation history:
63
  {chat_history}
64
 
65
  New input: {input}
66
- Since CATchat is a text language model, CATchat must use tools iteratively to observe images rather than imagination.
67
- The thoughts and observations are only visible for CATchat, CATchat should remember to repeat important information in the final response for Human.
68
 
69
  Thought: Do I need to use a tool? {agent_scratchpad} (You are strictly to use the aforementioned "Thought/Action/Action Input/Observation" format as the answer.)"""
70
 
@@ -111,9 +99,9 @@ class VisualQuestionAnswering:
111
  # "Salesforce/blip-vqa-capfilt-large", torch_dtype=self.torch_dtype).to(self.device)
112
 
113
  @prompts(name="Answer Question About The Image",
114
- description="useful when you need an answer for a question based on an image. "
115
- "like: what is the background color of the last image, how many cats in this figure, what is in this figure. "
116
- "The input to this tool should be a comma separated string of two, representing the image_path and the question")
117
  def inference(self, inputs):
118
  image_path, question = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
119
  raw_image = Image.open(image_path).convert('RGB')
@@ -151,12 +139,13 @@ def build_chatbot_tools(load_dict):
151
  class ConversationBot:
152
  def __init__(self, tools, api_key=""):
153
  # load_dict = {'VisualQuestionAnswering':'cuda:0', 'ImageCaptioning':'cuda:1',...}
154
- llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=api_key)
155
  self.llm = llm
156
  self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
157
  self.tools = tools
158
  self.current_image = None
159
  self.point_prompt = ""
 
160
  self.agent = initialize_agent(
161
  self.tools,
162
  self.llm,
@@ -212,7 +201,7 @@ if __name__ == '__main__':
212
  bot = ConversationBot(tools)
213
  with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
214
  with gr.Row():
215
- chatbot = gr.Chatbot(elem_id="chatbot", label="Visual ChatGPT").style(height=1000,scale=0.5)
216
  auxwindow = gr.Chatbot(elem_id="chatbot", label="Aux Window").style(height=1000,scale=0.5)
217
  state = gr.State([])
218
  aux_state = gr.State([])
 
19
  from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
20
 
21
  VISUAL_CHATGPT_PREFIX = """
22
+ I want you to act as Caption Anything Chatbox (CATchat for short), which is designed to assist with a wide range of text- and visual-related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. You are able to generate human-like text based on the input you receive, allowing you to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
23
 
24
+ As a language model, you cannot directly read images, but you can invoke the VQA tool to indirectly understand pictures by repeatedly asking questions about the objects and scene in the image. You should carefully ask informative questions to maximize your information about the image content. Each image will have a file name formed as "chat_image/xxx.png"; you are very strict about the file name and will never fabricate nonexistent files.
25
 
26
+ You have access to the following tools:"""
 
27
 
28
 
29
  # TOOLS:
 
52
  {chat_history}
53
 
54
  New input: {input}
55
+ As a language model, you must repeatedly use the VQA tool to observe images. Your response should be consistent with the outputs of the VQA tool rather than with your imagination. Do not ask the same question repeatedly.
 
56
 
57
  Thought: Do I need to use a tool? {agent_scratchpad} (You are strictly to use the aforementioned "Thought/Action/Action Input/Observation" format as the answer.)"""
58
 
 
99
  # "Salesforce/blip-vqa-capfilt-large", torch_dtype=self.torch_dtype).to(self.device)
100
 
101
  @prompts(name="Answer Question About The Image",
102
+ description="VQA tool is useful when you need an answer for a question based on an image. "
103
+ "like: what is the color of an object, how many cats are in this figure, where is the child sitting, what is the cat doing, why is he laughing. "
104
+ "The input to this tool should be a comma separated string of two, representing the image path and the question.")
105
  def inference(self, inputs):
106
  image_path, question = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
107
  raw_image = Image.open(image_path).convert('RGB')
 
139
  class ConversationBot:
140
  def __init__(self, tools, api_key=""):
141
  # load_dict = {'VisualQuestionAnswering':'cuda:0', 'ImageCaptioning':'cuda:1',...}
142
+ llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0.7, openai_api_key=api_key)
143
  self.llm = llm
144
  self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
145
  self.tools = tools
146
  self.current_image = None
147
  self.point_prompt = ""
148
+ self.global_prompt = ""
149
  self.agent = initialize_agent(
150
  self.tools,
151
  self.llm,
 
201
  bot = ConversationBot(tools)
202
  with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
203
  with gr.Row():
204
+ chatbot = gr.Chatbot(elem_id="chatbot", label="CATchat").style(height=1000,scale=0.5)
205
  auxwindow = gr.Chatbot(elem_id="chatbot", label="Aux Window").style(height=1000,scale=0.5)
206
  state = gr.State([])
207
  aux_state = gr.State([])
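Two behavioural changes in this file are easy to miss: the LLM temperature moves from 0 to 0.7, and ConversationBot gains a global_prompt attribute alongside point_prompt. A construction sketch based on the __main__ block above; the load_dict contents and the API key are placeholders:

    from caption_anything.utils.chatbot import build_chatbot_tools, ConversationBot  # assumed import path

    load_dict = {'VisualQuestionAnswering': 'cuda:0', 'ImageCaptioning': 'cuda:1'}
    tools = build_chatbot_tools(load_dict)

    # After this change the agent runs gpt-3.5-turbo at temperature=0.7 and
    # starts with empty current_image, point_prompt and global_prompt.
    bot = ConversationBot(tools, api_key='sk-...')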
caption_anything/utils/utils.py CHANGED
@@ -1,13 +1,41 @@
1
  import os
 
 
 
2
  import cv2
 
3
  import requests
4
  import numpy as np
 
 
 
5
  from PIL import Image
6
- import time
7
- import sys
8
- import urllib
9
  from tqdm import tqdm
10
- import hashlib
 
11
 
12
  def is_platform_win():
13
  return sys.platform == "win32"
@@ -114,7 +142,7 @@ def vis_add_mask(image, mask, color, alpha, kernel_size):
114
  mask = mask.astype('float').copy()
115
  mask = (cv2.GaussianBlur(mask, (kernel_size, kernel_size), kernel_size) / 255.) * (alpha)
116
  for i in range(3):
117
- image[:, :, i] = image[:, :, i] * (1-alpha+mask) + color[i] * (alpha-mask)
118
  return image
119
 
120
 
@@ -122,11 +150,12 @@ def vis_add_mask_wo_blur(image, mask, color, alpha):
122
  color = np.array(color)
123
  mask = mask.astype('float').copy()
124
  for i in range(3):
125
- image[:, :, i] = image[:, :, i] * (1-alpha+mask) + color[i] * (alpha-mask)
126
  return image
127
 
128
 
129
- def vis_add_mask_wo_gaussian(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha):
 
130
  background_color = np.array(background_color)
131
  contour_color = np.array(contour_color)
132
 
@@ -134,16 +163,17 @@ def vis_add_mask_wo_gaussian(image, background_mask, contour_mask, background_co
134
  # contour_mask = 1 - contour_mask
135
 
136
  for i in range(3):
137
- image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \
138
- + background_color[i] * (background_alpha-background_mask*background_alpha)
139
 
140
- image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \
141
- + contour_color[i] * (contour_alpha-contour_mask*contour_alpha)
142
 
143
  return image.astype('uint8')
144
 
145
 
146
- def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, background_color=0, paint_foreground=False):
 
147
  """
148
  add color mask to the background/foreground area
149
  input_image: numpy array (w, h, C)
@@ -163,23 +193,27 @@ def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_
163
  assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
164
 
165
  # 0: background, 1: foreground
166
- input_mask[input_mask>0] = 255
167
  if paint_foreground:
168
- painted_image = vis_add_mask(input_image, 255 - input_mask, color_list[background_color], background_alpha, background_blur_radius) # black for background
 
169
  else:
170
- # mask background
171
- painted_image = vis_add_mask(input_image, input_mask, color_list[background_color], background_alpha, background_blur_radius) # black for background
 
172
  # mask contour
173
  contour_mask = input_mask.copy()
174
- contour_mask = cv2.Canny(contour_mask, 100, 200) # contour extraction
175
  # widden contour
176
  kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (contour_width, contour_width))
177
  contour_mask = cv2.dilate(contour_mask, kernel)
178
- painted_image = vis_add_mask(painted_image, 255-contour_mask, color_list[contour_color], contour_alpha, contour_width)
 
179
  return painted_image
180
 
181
 
182
- def mask_painter_foreground_all(input_image, input_masks, background_alpha=0.7, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1):
 
183
  """
184
  paint color mask on the all foreground area
185
  input_image: numpy array with shape (w, h, C)
@@ -194,22 +228,24 @@ def mask_painter_foreground_all(input_image, input_masks, background_alpha=0.7,
194
  Output:
195
  painted_image: numpy array
196
  """
197
-
198
  for i, input_mask in enumerate(input_masks):
199
- input_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, background_color=i + 2, paint_foreground=True)
 
200
  return input_image
201
 
 
202
  def mask_generator_00(mask, background_radius, contour_radius):
203
  # no background width when '00'
204
  # distance map
205
  dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
206
- dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
207
  dist_map = dist_transform_fore - dist_transform_back
208
  # ...:::!!!:::...
209
  contour_radius += 2
210
  contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
211
  contour_mask = contour_mask / np.max(contour_mask)
212
- contour_mask[contour_mask>0.5] = 1.
213
 
214
  return mask, contour_mask
215
 
@@ -218,7 +254,7 @@ def mask_generator_01(mask, background_radius, contour_radius):
218
  # no background width when '00'
219
  # distance map
220
  dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
221
- dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
222
  dist_map = dist_transform_fore - dist_transform_back
223
  # ...:::!!!:::...
224
  contour_radius += 2
@@ -230,7 +266,7 @@ def mask_generator_01(mask, background_radius, contour_radius):
230
  def mask_generator_10(mask, background_radius, contour_radius):
231
  # distance map
232
  dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
233
- dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
234
  dist_map = dist_transform_fore - dist_transform_back
235
  # .....:::::!!!!!
236
  background_mask = np.clip(dist_map, -background_radius, background_radius)
@@ -240,14 +276,14 @@ def mask_generator_10(mask, background_radius, contour_radius):
240
  contour_radius += 2
241
  contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
242
  contour_mask = contour_mask / np.max(contour_mask)
243
- contour_mask[contour_mask>0.5] = 1.
244
  return background_mask, contour_mask
245
 
246
 
247
  def mask_generator_11(mask, background_radius, contour_radius):
248
  # distance map
249
  dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
250
- dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
251
  dist_map = dist_transform_fore - dist_transform_back
252
  # .....:::::!!!!!
253
  background_mask = np.clip(dist_map, -background_radius, background_radius)
@@ -260,7 +296,8 @@ def mask_generator_11(mask, background_radius, contour_radius):
260
  return background_mask, contour_mask
261
 
262
 
263
- def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'):
 
264
  """
265
  Input:
266
  input_image: numpy array
@@ -283,8 +320,8 @@ def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, back
283
  width, height = input_image.shape[0], input_image.shape[1]
284
  res = 1024
285
  ratio = min(1.0 * res / max(width, height), 1.0)
286
- input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
287
- input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
288
 
289
  # 0: background, 1: foreground
290
  msk = np.clip(input_mask, 0, 1)
@@ -292,23 +329,78 @@ def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, back
292
  # generate masks for background and contour pixels
293
  background_radius = (background_blur_radius - 1) // 2
294
  contour_radius = (contour_width - 1) // 2
295
- generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10, '11':mask_generator_11}
 
296
  background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
297
 
298
  # paint
299
  painted_image = vis_add_mask_wo_gaussian \
300
- (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background
 
301
 
302
  return painted_image
303
 
304
 
305
  if __name__ == '__main__':
306
 
307
- background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
308
- background_blur_radius = 31 # radius of background blur, must be odd number
309
- contour_width = 11 # contour width, must be odd number
310
- contour_color = 3 # id in color map, 0: black, 1: white, >1: others
311
- contour_alpha = 1 # transparency of background, 0: no contour highlighted
312
 
313
  # load input image and mask
314
  input_image = np.array(Image.open('./test_images/painter_input_image.jpg').convert('RGB'))
@@ -323,23 +415,28 @@ if __name__ == '__main__':
323
 
324
  for i in range(50):
325
  t2 = time.time()
326
- painted_image_00 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00')
 
327
  e2 = time.time()
328
 
329
  t3 = time.time()
330
- painted_image_10 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='10')
 
331
  e3 = time.time()
332
 
333
  t1 = time.time()
334
- painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha)
 
335
  e1 = time.time()
336
 
337
  t4 = time.time()
338
- painted_image_01 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01')
 
339
  e4 = time.time()
340
 
341
  t5 = time.time()
342
- painted_image_11 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11')
 
343
  e5 = time.time()
344
 
345
  overall_time_1 += (e1 - t1)
@@ -348,11 +445,11 @@ if __name__ == '__main__':
348
  overall_time_4 += (e4 - t4)
349
  overall_time_5 += (e5 - t5)
350
 
351
- print(f'average time w gaussian: {overall_time_1/50}')
352
- print(f'average time w/o gaussian00: {overall_time_2/50}')
353
- print(f'average time w/o gaussian10: {overall_time_3/50}')
354
- print(f'average time w/o gaussian01: {overall_time_4/50}')
355
- print(f'average time w/o gaussian11: {overall_time_5/50}')
356
 
357
  # save
358
  painted_image_00 = Image.fromarray(painted_image_00)
@@ -366,54 +463,3 @@ if __name__ == '__main__':
366
 
367
  painted_image_11 = Image.fromarray(painted_image_11)
368
  painted_image_11.save('./test_images/painter_output_image_11.png')
369
-
370
-
371
- seg_model_map = {
372
- 'base': 'vit_b',
373
- 'large': 'vit_l',
374
- 'huge': 'vit_h'
375
- }
376
- ckpt_url_map = {
377
- 'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
378
- 'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
379
- 'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
380
- }
381
- expected_sha256_map = {
382
- 'vit_b': 'ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912',
383
- 'vit_l': '3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622',
384
- 'vit_h': 'a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e'
385
- }
386
- def prepare_segmenter(segmenter = "huge", download_root: str = None):
387
- """
388
- Prepare segmenter model and download checkpoint if necessary.
389
-
390
- Returns: segmenter model name from 'vit_b', 'vit_l', 'vit_h'.
391
-
392
- """
393
-
394
- os.makedirs('result', exist_ok=True)
395
- seg_model_name = seg_model_map[segmenter]
396
- checkpoint_url = ckpt_url_map[seg_model_name]
397
- folder = download_root or os.path.expanduser("~/.cache/SAM")
398
- filename = os.path.basename(checkpoint_url)
399
- segmenter_checkpoint = download_checkpoint(checkpoint_url, folder, filename, expected_sha256_map[seg_model_name])
400
-
401
- return seg_model_name, segmenter_checkpoint
402
-
403
-
404
- def download_checkpoint(url, folder, filename, expected_sha256):
405
- os.makedirs(folder, exist_ok=True)
406
- download_target = os.path.join(folder, filename)
407
- if os.path.isfile(download_target):
408
- if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
409
- return download_target
410
-
411
- print(f'Download SAM checkpoint {url}, saving to {download_target} ...')
412
- with requests.get(url, stream=True) as response, open(download_target, "wb") as output:
413
- progress = tqdm(total=int(response.headers.get('content-length', 0)), unit='B', unit_scale=True)
414
- for data in response.iter_content(chunk_size=1024):
415
- size = output.write(data)
416
- progress.update(size)
417
- if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
418
- raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
419
- return download_target
 
1
  import os
2
+ import time
3
+ import sys
4
+
5
  import cv2
6
+ import hashlib
7
  import requests
8
  import numpy as np
9
+
10
+ from typing import Union
11
+
12
  from PIL import Image
 
 
 
13
  from tqdm import tqdm
14
+
15
+
16
+ def load_image(image: Union[np.ndarray, Image.Image, str], return_type='numpy'):
17
+ """
18
+ Load an image given as a file path, PIL.Image, or numpy.ndarray and convert it to the requested format ('pil' or 'numpy').
19
+ """
20
+
21
+ # Check if image is already in return_type
22
+ if isinstance(image, Image.Image) and return_type == 'pil' or \
23
+ isinstance(image, np.ndarray) and return_type == 'numpy':
24
+ return image
25
+
26
+ # PIL.Image as intermediate format
27
+ if isinstance(image, str):
28
+ image = Image.open(image)
29
+ elif isinstance(image, np.ndarray):
30
+ image = Image.fromarray(image)
31
+
32
+ if return_type == 'pil':
33
+ return image
34
+ elif return_type == 'numpy':
35
+ return np.asarray(image)
36
+ else:
37
+ raise NotImplementedError()
38
+
39
 
40
  def is_platform_win():
41
  return sys.platform == "win32"
 
142
  mask = mask.astype('float').copy()
143
  mask = (cv2.GaussianBlur(mask, (kernel_size, kernel_size), kernel_size) / 255.) * (alpha)
144
  for i in range(3):
145
+ image[:, :, i] = image[:, :, i] * (1 - alpha + mask) + color[i] * (alpha - mask)
146
  return image
147
 
148
 
 
150
  color = np.array(color)
151
  mask = mask.astype('float').copy()
152
  for i in range(3):
153
+ image[:, :, i] = image[:, :, i] * (1 - alpha + mask) + color[i] * (alpha - mask)
154
  return image
155
 
156
 
157
+ def vis_add_mask_wo_gaussian(image, background_mask, contour_mask, background_color, contour_color, background_alpha,
158
+ contour_alpha):
159
  background_color = np.array(background_color)
160
  contour_color = np.array(contour_color)
161
 
 
163
  # contour_mask = 1 - contour_mask
164
 
165
  for i in range(3):
166
+ image[:, :, i] = image[:, :, i] * (1 - background_alpha + background_mask * background_alpha) \
167
+ + background_color[i] * (background_alpha - background_mask * background_alpha)
168
 
169
+ image[:, :, i] = image[:, :, i] * (1 - contour_alpha + contour_mask * contour_alpha) \
170
+ + contour_color[i] * (contour_alpha - contour_mask * contour_alpha)
171
 
172
  return image.astype('uint8')
173
 
174
 
175
+ def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_radius=7, contour_width=3,
176
+ contour_color=3, contour_alpha=1, background_color=0, paint_foreground=False):
177
  """
178
  add color mask to the background/foreground area
179
  input_image: numpy array (w, h, C)
 
193
  assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
194
 
195
  # 0: background, 1: foreground
196
+ input_mask[input_mask > 0] = 255
197
  if paint_foreground:
198
+ painted_image = vis_add_mask(input_image, 255 - input_mask, color_list[background_color], background_alpha,
199
+ background_blur_radius) # black for background
200
  else:
201
+ # mask background
202
+ painted_image = vis_add_mask(input_image, input_mask, color_list[background_color], background_alpha,
203
+ background_blur_radius) # black for background
204
  # mask contour
205
  contour_mask = input_mask.copy()
206
+ contour_mask = cv2.Canny(contour_mask, 100, 200) # contour extraction
207
  # widden contour
208
  kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (contour_width, contour_width))
209
  contour_mask = cv2.dilate(contour_mask, kernel)
210
+ painted_image = vis_add_mask(painted_image, 255 - contour_mask, color_list[contour_color], contour_alpha,
211
+ contour_width)
212
  return painted_image
213
 
214
 
215
+ def mask_painter_foreground_all(input_image, input_masks, background_alpha=0.7, background_blur_radius=7,
216
+ contour_width=3, contour_color=3, contour_alpha=1):
217
  """
218
  paint color mask on the all foreground area
219
  input_image: numpy array with shape (w, h, C)
 
228
  Output:
229
  painted_image: numpy array
230
  """
231
+
232
  for i, input_mask in enumerate(input_masks):
233
+ input_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width,
234
+ contour_color, contour_alpha, background_color=i + 2, paint_foreground=True)
235
  return input_image
236
 
237
+
238
  def mask_generator_00(mask, background_radius, contour_radius):
239
  # no background width when '00'
240
  # distance map
241
  dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
242
+ dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
243
  dist_map = dist_transform_fore - dist_transform_back
244
  # ...:::!!!:::...
245
  contour_radius += 2
246
  contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
247
  contour_mask = contour_mask / np.max(contour_mask)
248
+ contour_mask[contour_mask > 0.5] = 1.
249
 
250
  return mask, contour_mask
251
 
 
254
  # no background width when '00'
255
  # distance map
256
  dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
257
+ dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
258
  dist_map = dist_transform_fore - dist_transform_back
259
  # ...:::!!!:::...
260
  contour_radius += 2
 
266
  def mask_generator_10(mask, background_radius, contour_radius):
267
  # distance map
268
  dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
269
+ dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
270
  dist_map = dist_transform_fore - dist_transform_back
271
  # .....:::::!!!!!
272
  background_mask = np.clip(dist_map, -background_radius, background_radius)
 
276
  contour_radius += 2
277
  contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
278
  contour_mask = contour_mask / np.max(contour_mask)
279
+ contour_mask[contour_mask > 0.5] = 1.
280
  return background_mask, contour_mask
281
 
282
 
283
  def mask_generator_11(mask, background_radius, contour_radius):
284
  # distance map
285
  dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
286
+ dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
287
  dist_map = dist_transform_fore - dist_transform_back
288
  # .....:::::!!!!!
289
  background_mask = np.clip(dist_map, -background_radius, background_radius)
 
296
  return background_mask, contour_mask
297
 
298
 
299
+ def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3,
300
+ contour_color=3, contour_alpha=1, mode='11'):
301
  """
302
  Input:
303
  input_image: numpy array
 
320
  width, height = input_image.shape[0], input_image.shape[1]
321
  res = 1024
322
  ratio = min(1.0 * res / max(width, height), 1.0)
323
+ input_image = cv2.resize(input_image, (int(height * ratio), int(width * ratio)))
324
+ input_mask = cv2.resize(input_mask, (int(height * ratio), int(width * ratio)))
325
 
326
  # 0: background, 1: foreground
327
  msk = np.clip(input_mask, 0, 1)
 
329
  # generate masks for background and contour pixels
330
  background_radius = (background_blur_radius - 1) // 2
331
  contour_radius = (contour_width - 1) // 2
332
+ generator_dict = {'00': mask_generator_00, '01': mask_generator_01, '10': mask_generator_10,
333
+ '11': mask_generator_11}
334
  background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
335
 
336
  # paint
337
  painted_image = vis_add_mask_wo_gaussian \
338
+ (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha,
339
+ contour_alpha) # black for background
340
 
341
  return painted_image
342
 
343
 
344
+ seg_model_map = {
345
+ 'base': 'vit_b',
346
+ 'large': 'vit_l',
347
+ 'huge': 'vit_h'
348
+ }
349
+ ckpt_url_map = {
350
+ 'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
351
+ 'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
352
+ 'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
353
+ }
354
+ expected_sha256_map = {
355
+ 'vit_b': 'ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912',
356
+ 'vit_l': '3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622',
357
+ 'vit_h': 'a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e'
358
+ }
359
+
360
+
361
+ def prepare_segmenter(segmenter="huge", download_root: str = None):
362
+ """
363
+ Prepare segmenter model and download checkpoint if necessary.
364
+
365
+ Returns: a (seg_model_name, segmenter_checkpoint) tuple, where seg_model_name is one of 'vit_b', 'vit_l', 'vit_h'.
366
+
367
+ """
368
+
369
+ os.makedirs('result', exist_ok=True)
370
+ seg_model_name = seg_model_map[segmenter]
371
+ checkpoint_url = ckpt_url_map[seg_model_name]
372
+ folder = download_root or os.path.expanduser("~/.cache/SAM")
373
+ filename = os.path.basename(checkpoint_url)
374
+ segmenter_checkpoint = download_checkpoint(checkpoint_url, folder, filename, expected_sha256_map[seg_model_name])
375
+
376
+ return seg_model_name, segmenter_checkpoint
377
+
378
+
379
+ def download_checkpoint(url, folder, filename, expected_sha256):
380
+ os.makedirs(folder, exist_ok=True)
381
+ download_target = os.path.join(folder, filename)
382
+ if os.path.isfile(download_target):
383
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
384
+ return download_target
385
+
386
+ print(f'Download SAM checkpoint {url}, saving to {download_target} ...')
387
+ with requests.get(url, stream=True) as response, open(download_target, "wb") as output:
388
+ progress = tqdm(total=int(response.headers.get('content-length', 0)), unit='B', unit_scale=True)
389
+ for data in response.iter_content(chunk_size=1024):
390
+ size = output.write(data)
391
+ progress.update(size)
392
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
393
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
394
+ return download_target
395
+
396
+
397
  if __name__ == '__main__':
398
 
399
+ background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
400
+ background_blur_radius = 31 # radius of background blur, must be odd number
401
+ contour_width = 11 # contour width, must be odd number
402
+ contour_color = 3 # id in color map, 0: black, 1: white, >1: others
403
+ contour_alpha = 1 # transparency of background, 0: no contour highlighted
404
 
405
  # load input image and mask
406
  input_image = np.array(Image.open('./test_images/painter_input_image.jpg').convert('RGB'))
 
415
 
416
  for i in range(50):
417
  t2 = time.time()
418
+ painted_image_00 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius,
419
+ contour_width, contour_color, contour_alpha, mode='00')
420
  e2 = time.time()
421
 
422
  t3 = time.time()
423
+ painted_image_10 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius,
424
+ contour_width, contour_color, contour_alpha, mode='10')
425
  e3 = time.time()
426
 
427
  t1 = time.time()
428
+ painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width,
429
+ contour_color, contour_alpha)
430
  e1 = time.time()
431
 
432
  t4 = time.time()
433
+ painted_image_01 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius,
434
+ contour_width, contour_color, contour_alpha, mode='01')
435
  e4 = time.time()
436
 
437
  t5 = time.time()
438
+ painted_image_11 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius,
439
+ contour_width, contour_color, contour_alpha, mode='11')
440
  e5 = time.time()
441
 
442
  overall_time_1 += (e1 - t1)
 
445
  overall_time_4 += (e4 - t4)
446
  overall_time_5 += (e5 - t5)
447
 
448
+ print(f'average time w gaussian: {overall_time_1 / 50}')
449
+ print(f'average time w/o gaussian00: {overall_time_2 / 50}')
450
+ print(f'average time w/o gaussian10: {overall_time_3 / 50}')
451
+ print(f'average time w/o gaussian01: {overall_time_4 / 50}')
452
+ print(f'average time w/o gaussian11: {overall_time_5 / 50}')
453
 
454
  # save
455
  painted_image_00 = Image.fromarray(painted_image_00)
 
463
 
464
  painted_image_11 = Image.fromarray(painted_image_11)
465
  painted_image_11.save('./test_images/painter_output_image_11.png')
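load_image is the helper this commit introduces to replace BaseSegmenter.read_image and the ad-hoc conversions in the captioners. A quick usage sketch, assuming the test image below exists locally:

    import numpy as np
    from PIL import Image
    from caption_anything.utils.utils import load_image

    path = 'test_images/img1.jpg'

    as_np = load_image(path, return_type='numpy')    # str -> np.ndarray (via PIL)
    as_pil = load_image(as_np, return_type='pil')    # np.ndarray -> PIL.Image
    same = load_image(as_pil, return_type='pil')     # already the requested type: returned as-is

    assert isinstance(as_np, np.ndarray)
    assert isinstance(as_pil, Image.Image)
    assert same is as_pil
    # Any other return_type raises NotImplementedError.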