liuyizhang committed
Commit a71406a
1 Parent(s): 9403943

update files

automatic_label_demo.py CHANGED
@@ -43,20 +43,23 @@ def load_image(image_path):
     return image_pil, image
 
 
-def generate_caption(raw_image):
+def generate_caption(raw_image, device):
     # unconditional image captioning
-    inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
+    if device == "cuda":
+        inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
+    else:
+        inputs = processor(raw_image, return_tensors="pt")
     out = blip_model.generate(**inputs)
     caption = processor.decode(out[0], skip_special_tokens=True)
     return caption
 
 
-def generate_tags(caption, max_tokens=100, model="gpt-3.5-turbo"):
+def generate_tags(caption, split=',', max_tokens=100, model="gpt-3.5-turbo"):
     prompt = [
         {
             'role': 'system',
-            'content': 'Extrat the unique nouns in the caption. Remove all the adjectives. ' + \
-                       'List the nouns in singular form. Split them by ".". ' + \
+            'content': 'Extract the unique nouns in the caption. Remove all the adjectives. ' + \
+                       f'List the nouns in singular form. Split them by "{split} ". ' + \
                        f'Caption: {caption}.'
         }
     ]
@@ -197,6 +200,7 @@ if __name__ == "__main__":
         "--sam_checkpoint", type=str, required=True, help="path to checkpoint file"
     )
     parser.add_argument("--input_image", type=str, required=True, help="path to image file")
+    parser.add_argument("--split", default=",", type=str, help="split for text prompt")
     parser.add_argument("--openai_key", type=str, required=True, help="key for chatgpt")
     parser.add_argument("--openai_proxy", default=None, type=str, help="proxy for chatgpt")
     parser.add_argument(
@@ -215,6 +219,7 @@ if __name__ == "__main__":
     grounded_checkpoint = args.grounded_checkpoint # change the path of the model
     sam_checkpoint = args.sam_checkpoint
     image_path = args.input_image
+    split = args.split
    openai_key = args.openai_key
     openai_proxy = args.openai_proxy
     output_dir = args.output_dir
@@ -242,9 +247,14 @@ if __name__ == "__main__":
     # https://huggingface.co/spaces/xinyu1205/Tag2Text
     # but there are some bugs...
     processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
-    blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
-    caption = generate_caption(image_pil)
-    text_prompt = generate_tags(caption)
+    if device == "cuda":
+        blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
+    else:
+        blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
+    caption = generate_caption(image_pil, device=device)
+    # Currently ", " is better for detecting single tags
+    # while ". " is a little worse in some case
+    text_prompt = generate_tags(caption, split=split)
     print(f"Caption: {caption}")
     print(f"Tags: {text_prompt}")
 
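As a side note, here is a minimal sketch (not part of the commit) of the tag-extraction message that generate_tags() now builds; it only illustrates how the new split argument changes the instruction sent to gpt-3.5-turbo, using the exact prompt text shown in the diff above:

def build_tag_prompt(caption, split=','):
    # mirrors the system message assembled in generate_tags()
    content = (
        'Extract the unique nouns in the caption. Remove all the adjectives. '
        f'List the nouns in singular form. Split them by "{split} ". '
        f'Caption: {caption}.'
    )
    return [{'role': 'system', 'content': content}]

print(build_tag_prompt("two dogs running on a sandy beach"))
# With split=',' the model is asked for something like "dog, beach" rather than "dog. beach",
# which the new in-code comment says currently works better for single-tag detection.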
gradio_app.py CHANGED
@@ -1,11 +1,13 @@
+import os
+# os.system('pip install v0.1.0-alpha2.tar.gz')
 import gradio as gr
 
 import argparse
-import os
 import copy
 
 import numpy as np
 import torch
+import torchvision
 from PIL import Image, ImageDraw, ImageFont
 
 # Grounding DINO
@@ -30,6 +32,10 @@ from io import BytesIO
 from diffusers import StableDiffusionInpaintPipeline
 from huggingface_hub import hf_hub_download
 
+# BLIP
+from transformers import BlipProcessor, BlipForConditionalGeneration
+
+
 def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
     args = SLConfig.fromfile(model_config_path)
     model = build_model(args)
@@ -42,6 +48,13 @@ def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
     _ = model.eval()
     return model
 
+def generate_caption(processor, blip_model, raw_image):
+    # unconditional image captioning
+    inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
+    out = blip_model.generate(**inputs)
+    caption = processor.decode(out[0], skip_special_tokens=True)
+    return caption
+
 def plot_boxes_to_image(image_pil, tgt):
     H, W = tgt["size"]
     boxes = tgt["boxes"]
@@ -135,14 +148,16 @@ def get_grounding_output(model, image, caption, box_threshold, text_threshold, w
     tokenized = tokenlizer(caption)
     # build pred
     pred_phrases = []
+    scores = []
     for logit, box in zip(logits_filt, boxes_filt):
         pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
         if with_logits:
             pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
         else:
             pred_phrases.append(pred_phrase)
+        scores.append(logit.max().item())
 
-    return boxes_filt, pred_phrases
+    return boxes_filt, torch.Tensor(scores), pred_phrases
 
 def show_mask(mask, ax, random_color=False):
     if random_color:
@@ -164,12 +179,11 @@ def show_box(box, ax, label):
 config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
 ckpt_repo_id = "ShilongLiu/GroundingDINO"
 ckpt_filenmae = "groundingdino_swint_ogc.pth"
-sam_checkpoint='/home/ecs-user/download/sam_vit_h_4b8939.pth'
+sam_checkpoint='sam_vit_h_4b8939.pth'
 output_dir="outputs"
 device="cuda"
 
-def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold):
-    assert text_prompt, 'text_prompt is not found!'
+def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode):
 
     # make dir
     os.makedirs(output_dir, exist_ok=True)
@@ -177,18 +191,29 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
     image_pil, image = load_image(image_path.convert("RGB"))
     # load model
     model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
+    # model = load_model(config_file, ckpt_filenmae, device=device)
 
     # visualize raw image
     image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
 
+    if task_type == 'automatic':
+        # generate caption and tags
+        # use Tag2Text can generate better captions
+        # https://huggingface.co/spaces/xinyu1205/Tag2Text
+        # but there are some bugs...
+        processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+        blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
+        text_prompt = generate_caption(processor, blip_model, image_pil)
+        print(f"Caption: {text_prompt}")
+
     # run grounding dino model
-    boxes_filt, pred_phrases = get_grounding_output(
+    boxes_filt, scores, pred_phrases = get_grounding_output(
         model, image, text_prompt, box_threshold, text_threshold, device=device
     )
 
     size = image_pil.size
 
-    if task_type == 'seg' or task_type == 'inpainting':
+    if task_type == 'seg' or task_type == 'inpainting' or task_type == 'automatic':
         # initialize SAM
         predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
         image = np.array(image_path)
@@ -201,6 +226,16 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
             boxes_filt[i][2:] += boxes_filt[i][:2]
 
         boxes_filt = boxes_filt.cpu()
+
+        if task_type == 'automatic':
+            # use NMS to handle overlapped boxes
+            print(f"Before NMS: {boxes_filt.shape[0]} boxes")
+            nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
+            boxes_filt = boxes_filt[nms_idx]
+            pred_phrases = [pred_phrases[idx] for idx in nms_idx]
+            print(f"After NMS: {boxes_filt.shape[0]} boxes")
+            print(f"Revise caption with number: {text_prompt}")
+
         transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
 
         masks, _, _ = predictor.predict_torch(
@@ -224,7 +259,7 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
         image_with_box.save(image_path)
         image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
         return image_result
-    elif task_type == 'seg':
+    elif task_type == 'seg' or task_type == 'automatic':
         assert sam_checkpoint, 'sam_checkpoint is not found!'
 
         # draw output image
@@ -234,6 +269,8 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
             show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
         for box, label in zip(boxes_filt, pred_phrases):
            show_box(box.numpy(), plt.gca(), label)
+        if task_type == 'automatic':
+            plt.title(text_prompt)
         plt.axis('off')
         image_path = os.path.join(output_dir, "grounding_dino_output.jpg")
         plt.savefig(image_path, bbox_inches="tight")
@@ -242,16 +279,24 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
     elif task_type == 'inpainting':
         assert inpaint_prompt, 'inpaint_prompt is not found!'
         # inpainting pipeline
-        mask = masks[0][0].cpu().numpy() # simply choose the first mask, which will be refine in the future release
+        if inpaint_mode == 'merge':
+            masks = torch.sum(masks, dim=0).unsqueeze(0)
+            masks = torch.where(masks > 0, True, False)
+        else:
+            mask = masks[0][0].cpu().numpy() # simply choose the first mask, which will be refine in the future release
         mask_pil = Image.fromarray(mask)
-        image_pil = Image.fromarray(image)
 
         pipe = StableDiffusionInpaintPipeline.from_pretrained(
             "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
         )
         pipe = pipe.to("cuda")
 
+        image_pil = image_pil.resize((512, 512))
+        mask_pil = mask_pil.resize((512, 512))
+
         image = pipe(prompt=inpaint_prompt, image=image_pil, mask_image=mask_pil).images[0]
+        image = image.resize(size)
+
         image_path = os.path.join(output_dir, "grounded_sam_inpainting_output.jpg")
         image.save(image_path)
         image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
@@ -264,15 +309,16 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
     parser.add_argument("--debug", action="store_true", help="using debug mode")
     parser.add_argument("--share", action="store_true", help="share the app")
+    parser.add_argument('--port', type=int, default=7589, help='port to run the server')
     args = parser.parse_args()
 
     block = gr.Blocks().queue()
     with block:
         with gr.Row():
             with gr.Column():
-                input_image = gr.Image(source='upload', type="pil")
-                text_prompt = gr.Textbox(label="Detection Prompt")
-                task_type = gr.Textbox(label="task type: det/seg/inpainting")
+                input_image = gr.Image(source='upload', type="pil", value="assets/demo1.jpg")
+                task_type = gr.Dropdown(["det", "seg", "inpainting", "automatic"], value="automatic", label="task_type")
+                text_prompt = gr.Textbox(label="Text Prompt")
                 inpaint_prompt = gr.Textbox(label="Inpaint Prompt")
                 run_button = gr.Button(label="Run")
                 with gr.Accordion("Advanced options", open=False):
@@ -282,6 +328,10 @@ if __name__ == "__main__":
                     text_threshold = gr.Slider(
                         label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
                    )
+                    iou_threshold = gr.Slider(
+                        label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.001
+                    )
+                    inpaint_mode = gr.Dropdown(["merge", "first"], value="merge", label="inpaint_mode")
 
             with gr.Column():
                 gallery = gr.outputs.Image(
@@ -289,7 +339,7 @@ if __name__ == "__main__":
                 ).style(full_width=True, full_height=True)
 
     run_button.click(fn=run_grounded_sam, inputs=[
-                    input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold], outputs=[gallery])
+                    input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode], outputs=[gallery])
 
 
-    block.launch(server_name='0.0.0.0', server_port=7589, debug=args.debug, share=args.share)
+    block.launch(server_name='0.0.0.0', server_port=args.port, debug=args.debug, share=args.share)
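For reference, a toy sketch (not from the commit) of the NMS filtering that the new 'automatic' branch applies to the Grounding DINO boxes; the boxes, scores and phrases below are made up, but the torchvision.ops.nms call and the index-based filtering are the ones used in the diff above:

import torch
import torchvision

boxes = torch.tensor([[0., 0., 100., 100.],
                      [5., 5., 105., 105.],     # heavily overlaps the first box
                      [200., 200., 300., 300.]])
scores = torch.tensor([0.9, 0.8, 0.7])
pred_phrases = ['dog(0.9)', 'dog(0.8)', 'beach(0.7)']
iou_threshold = 0.5

nms_idx = torchvision.ops.nms(boxes, scores, iou_threshold).numpy().tolist()
boxes = boxes[nms_idx]                              # keeps boxes 0 and 2, drops the duplicate
pred_phrases = [pred_phrases[idx] for idx in nms_idx]
print(nms_idx, pred_phrases)                        # [0, 2] ['dog(0.9)', 'beach(0.7)']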
gradio_auto_label.py ADDED
@@ -0,0 +1,392 @@
 
 
1
+ import gradio as gr
2
+ import json
3
+ import argparse
4
+ import os
5
+ import copy
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torchvision
10
+ from PIL import Image, ImageDraw, ImageFont
11
+ import openai
12
+ # Grounding DINO
13
+ import GroundingDINO.groundingdino.datasets.transforms as T
14
+ from GroundingDINO.groundingdino.models import build_model
15
+ from GroundingDINO.groundingdino.util import box_ops
16
+ from GroundingDINO.groundingdino.util.slconfig import SLConfig
17
+ from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
18
+ from transformers import BlipProcessor, BlipForConditionalGeneration
19
+ # segment anything
20
+ from segment_anything import build_sam, SamPredictor
21
+ from segment_anything.utils.amg import remove_small_regions
22
+ import cv2
23
+ import numpy as np
24
+ import matplotlib.pyplot as plt
25
+
26
+
27
+ # diffusers
28
+ import PIL
29
+ import requests
30
+ import torch
31
+ from io import BytesIO
32
+ from huggingface_hub import hf_hub_download
33
+ from sys import platform
34
+
35
+ #macos
36
+ if platform == 'darwin':
37
+ import matplotlib
38
+ matplotlib.use('agg')
39
+
40
+ def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
41
+ args = SLConfig.fromfile(model_config_path)
42
+ model = build_model(args)
43
+ args.device = device
44
+
45
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
46
+ checkpoint = torch.load(cache_file, map_location='cpu')
47
+ log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
48
+ print("Model loaded from {} \n => {}".format(cache_file, log))
49
+ _ = model.eval()
50
+ return model
51
+
52
+ def plot_boxes_to_image(image_pil, tgt):
53
+ H, W = tgt["size"]
54
+ boxes = tgt["boxes"]
55
+ labels = tgt["labels"]
56
+ assert len(boxes) == len(labels), "boxes and labels must have same length"
57
+
58
+ draw = ImageDraw.Draw(image_pil)
59
+ mask = Image.new("L", image_pil.size, 0)
60
+ mask_draw = ImageDraw.Draw(mask)
61
+
62
+ # draw boxes and masks
63
+ for box, label in zip(boxes, labels):
64
+ # from 0..1 to 0..W, 0..H
65
+ box = box * torch.Tensor([W, H, W, H])
66
+ # from xywh to xyxy
67
+ box[:2] -= box[2:] / 2
68
+ box[2:] += box[:2]
69
+ # random color
70
+ color = tuple(np.random.randint(0, 255, size=3).tolist())
71
+ # draw
72
+ x0, y0, x1, y1 = box
73
+ x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
74
+
75
+ draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
76
+ # draw.text((x0, y0), str(label), fill=color)
77
+
78
+ font = ImageFont.load_default()
79
+ if hasattr(font, "getbbox"):
80
+ bbox = draw.textbbox((x0, y0), str(label), font)
81
+ else:
82
+ w, h = draw.textsize(str(label), font)
83
+ bbox = (x0, y0, w + x0, y0 + h)
84
+ # bbox = draw.textbbox((x0, y0), str(label))
85
+ draw.rectangle(bbox, fill=color)
86
+ draw.text((x0, y0), str(label), fill="white")
87
+
88
+ mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)
89
+
90
+ return image_pil, mask
91
+
92
+ def load_image(image_path):
93
+ # # load image
94
+ # image_pil = Image.open(image_path).convert("RGB") # load image
95
+ image_pil = image_path
96
+
97
+ transform = T.Compose(
98
+ [
99
+ T.RandomResize([800], max_size=1333),
100
+ T.ToTensor(),
101
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
102
+ ]
103
+ )
104
+ image, _ = transform(image_pil, None) # 3, h, w
105
+ return image_pil, image
106
+
107
+
108
+ def load_model(model_config_path, model_checkpoint_path, device):
109
+ args = SLConfig.fromfile(model_config_path)
110
+ args.device = device
111
+ model = build_model(args)
112
+ checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
113
+ load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
114
+ _ = model.eval()
115
+ return model
116
+
117
+
118
+ def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
119
+ caption = caption.lower()
120
+ caption = caption.strip()
121
+ if not caption.endswith("."):
122
+ caption = caption + "."
123
+ model = model.to(device)
124
+ image = image.to(device)
125
+ with torch.no_grad():
126
+ outputs = model(image[None], captions=[caption])
127
+ logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
128
+ boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
129
+ logits.shape[0]
130
+
131
+ # filter output
132
+ logits_filt = logits.clone()
133
+ boxes_filt = boxes.clone()
134
+ filt_mask = logits_filt.max(dim=1)[0] > box_threshold
135
+ logits_filt = logits_filt[filt_mask] # num_filt, 256
136
+ boxes_filt = boxes_filt[filt_mask] # num_filt, 4
137
+ logits_filt.shape[0]
138
+
139
+ # get phrase
140
+ tokenlizer = model.tokenizer
141
+ tokenized = tokenlizer(caption)
142
+ # build pred
143
+ pred_phrases = []
144
+ scores = []
145
+ for logit, box in zip(logits_filt, boxes_filt):
146
+ pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
147
+ if with_logits:
148
+ pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
149
+ else:
150
+ pred_phrases.append(pred_phrase)
151
+ scores.append(logit.max().item())
152
+
153
+ return boxes_filt, torch.Tensor(scores), pred_phrases
154
+
155
+ def show_mask(mask, ax, random_color=False):
156
+ if random_color:
157
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
158
+ else:
159
+ color = np.array([30/255, 144/255, 255/255, 0.6])
160
+ h, w = mask.shape[-2:]
161
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
162
+ ax.imshow(mask_image)
163
+
164
+ def save_mask_data(output_dir, mask_list, box_list, label_list):
165
+ value = 0 # 0 for background
166
+
167
+ mask_img = torch.zeros(mask_list.shape[-2:])
168
+ for idx, mask in enumerate(mask_list):
169
+ mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1
170
+ plt.figure(figsize=(10, 10))
171
+ plt.imshow(mask_img.numpy())
172
+ plt.axis('off')
173
+ mask_img_path = os.path.join(output_dir, 'mask.jpg')
174
+ plt.savefig(mask_img_path, bbox_inches="tight", dpi=300, pad_inches=0.0)
175
+
176
+ json_data = [{
177
+ 'value': value,
178
+ 'label': 'background'
179
+ }]
180
+ for label, box in zip(label_list, box_list):
181
+ value += 1
182
+ name, logit = label.split('(')
183
+ logit = logit[:-1] # the last is ')'
184
+ json_data.append({
185
+ 'value': value,
186
+ 'label': name,
187
+ 'logit': float(logit),
188
+ 'box': box.numpy().tolist(),
189
+ })
190
+
191
+ mask_json_path = os.path.join(output_dir, 'mask.json')
192
+ with open(mask_json_path, 'w') as f:
193
+ json.dump(json_data, f)
194
+
195
+ return mask_img_path, mask_json_path
196
+
197
+ def show_box(box, ax, label):
198
+ x0, y0 = box[0], box[1]
199
+ w, h = box[2] - box[0], box[3] - box[1]
200
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
201
+ ax.text(x0, y0, label)
202
+
203
+ config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
204
+ ckpt_repo_id = "ShilongLiu/GroundingDINO"
205
+ ckpt_filenmae = "groundingdino_swint_ogc.pth"
206
+ sam_checkpoint='sam_vit_h_4b8939.pth'
207
+ output_dir="outputs"
208
+ device="cpu"
209
+
210
+ processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
211
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
212
+
213
+ def generate_caption(raw_image):
214
+ # unconditional image captioning
215
+ inputs = processor(raw_image, return_tensors="pt")
216
+ out = blip_model.generate(**inputs)
217
+ caption = processor.decode(out[0], skip_special_tokens=True)
218
+ return caption
219
+
220
+
221
+ def generate_tags(caption, split=',', max_tokens=100, model="gpt-3.5-turbo", openai_key=''):
222
+ openai.api_key = openai_key
223
+ prompt = [
224
+ {
225
+ 'role': 'system',
226
+ 'content': 'Extract the unique nouns in the caption. Remove all the adjectives. ' + \
227
+ f'List the nouns in singular form. Split them by "{split} ". ' + \
228
+ f'Caption: {caption}.'
229
+ }
230
+ ]
231
+ response = openai.ChatCompletion.create(model=model, messages=prompt, temperature=0.6, max_tokens=max_tokens)
232
+ reply = response['choices'][0]['message']['content']
233
+ # sometimes return with "noun: xxx, xxx, xxx"
234
+ tags = reply.split(':')[-1].strip()
235
+ return tags
236
+
237
+ def check_caption(caption, pred_phrases, max_tokens=100, model="gpt-3.5-turbo"):
238
+ object_list = [obj.split('(')[0] for obj in pred_phrases]
239
+ object_num = []
240
+ for obj in set(object_list):
241
+ object_num.append(f'{object_list.count(obj)} {obj}')
242
+ object_num = ', '.join(object_num)
243
+ print(f"Correct object number: {object_num}")
244
+
245
+ prompt = [
246
+ {
247
+ 'role': 'system',
248
+ 'content': 'Revise the number in the caption if it is wrong. ' + \
249
+ f'Caption: {caption}. ' + \
250
+ f'True object number: {object_num}. ' + \
251
+ 'Only give the revised caption: '
252
+ }
253
+ ]
254
+ response = openai.ChatCompletion.create(model=model, messages=prompt, temperature=0.6, max_tokens=max_tokens)
255
+ reply = response['choices'][0]['message']['content']
256
+ # sometimes return with "Caption: xxx, xxx, xxx"
257
+ caption = reply.split(':')[-1].strip()
258
+ return caption
259
+
260
+ def run_grounded_sam(image_path, openai_key, box_threshold, text_threshold, iou_threshold, area_threshold):
261
+ assert openai_key, 'Openai key is not found!'
262
+
263
+ # make dir
264
+ os.makedirs(output_dir, exist_ok=True)
265
+ # load image
266
+ image_pil, image = load_image(image_path.convert("RGB"))
267
+ # load model
268
+ model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
269
+
270
+ # visualize raw image
271
+ image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
272
+
273
+ caption = generate_caption(image_pil)
274
+ # Currently ", " is better for detecting single tags
275
+ # while ". " is a little worse in some case
276
+ split = ','
277
+ tags = generate_tags(caption, split=split, openai_key=openai_key)
278
+
279
+ # run grounding dino model
280
+ boxes_filt, scores, pred_phrases = get_grounding_output(
281
+ model, image, tags, box_threshold, text_threshold, device=device
282
+ )
283
+
284
+ size = image_pil.size
285
+
286
+ # initialize SAM
287
+ predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
288
+ image = np.array(image_path)
289
+ predictor.set_image(image)
290
+
291
+ H, W = size[1], size[0]
292
+ for i in range(boxes_filt.size(0)):
293
+ boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
294
+ boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
295
+ boxes_filt[i][2:] += boxes_filt[i][:2]
296
+
297
+ boxes_filt = boxes_filt.cpu()
298
+ # use NMS to handle overlapped boxes
299
+ print(f"Before NMS: {boxes_filt.shape[0]} boxes")
300
+ nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
301
+ boxes_filt = boxes_filt[nms_idx]
302
+ pred_phrases = [pred_phrases[idx] for idx in nms_idx]
303
+ print(f"After NMS: {boxes_filt.shape[0]} boxes")
304
+ caption = check_caption(caption, pred_phrases)
305
+ print(f"Revise caption with number: {caption}")
306
+
307
+ transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
308
+
309
+ masks, _, _ = predictor.predict_torch(
310
+ point_coords = None,
311
+ point_labels = None,
312
+ boxes = transformed_boxes,
313
+ multimask_output = False,
314
+ )
315
+ # area threshold: remove the mask when area < area_thresh (in pixels)
316
+ new_masks = []
317
+ for mask in masks:
318
+ # reshape to be used in remove_small_regions()
319
+ mask = mask.cpu().numpy().squeeze()
320
+ mask, _ = remove_small_regions(mask, area_threshold, mode="holes")
321
+ mask, _ = remove_small_regions(mask, area_threshold, mode="islands")
322
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
323
+
324
+ masks = torch.stack(new_masks, dim=0)
325
+ # masks: [1, 1, 512, 512]
326
+ assert sam_checkpoint, 'sam_checkpoint is not found!'
327
+
328
+ # draw output image
329
+ plt.figure(figsize=(10, 10))
330
+ plt.imshow(image)
331
+ for mask in masks:
332
+ show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
333
+ for box, label in zip(boxes_filt, pred_phrases):
334
+ show_box(box.numpy(), plt.gca(), label)
335
+ plt.axis('off')
336
+ image_path = os.path.join(output_dir, "grounding_dino_output.jpg")
337
+ plt.savefig(image_path, bbox_inches="tight")
338
+ image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
339
+
340
+ mask_img_path, _ = save_mask_data('./outputs', masks, boxes_filt, pred_phrases)
341
+
342
+ mask_img = cv2.cvtColor(cv2.imread(mask_img_path), cv2.COLOR_BGR2RGB)
343
+
344
+ return image_result, mask_img, caption, tags
345
+
346
+ if __name__ == "__main__":
347
+
348
+ parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
349
+ parser.add_argument("--debug", action="store_true", help="using debug mode")
350
+ parser.add_argument("--share", action="store_true", help="share the app")
351
+ args = parser.parse_args()
352
+
353
+ block = gr.Blocks().queue()
354
+ with block:
355
+ with gr.Row():
356
+ with gr.Column():
357
+ input_image = gr.Image(source='upload', type="pil")
358
+ openai_key = gr.Textbox(label="OpenAI key")
359
+
360
+ run_button = gr.Button(label="Run")
361
+ with gr.Accordion("Advanced options", open=False):
362
+ box_threshold = gr.Slider(
363
+ label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
364
+ )
365
+ text_threshold = gr.Slider(
366
+ label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
367
+ )
368
+ iou_threshold = gr.Slider(
369
+ label="IoU Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.001
370
+ )
371
+ area_threshold = gr.Slider(
372
+ label="Area Threshold", minimum=0.0, maximum=2500, value=100, step=10
373
+ )
374
+
375
+ with gr.Column():
376
+ image_caption = gr.Textbox(label="Image Caption")
377
+ identified_labels = gr.Textbox(label="Key objects extracted by ChatGPT")
378
+ gallery = gr.outputs.Image(
379
+ type="pil",
380
+ ).style(full_width=True, full_height=True)
381
+
382
+ mask_gallary = gr.outputs.Image(
383
+ type="pil",
384
+ ).style(full_width=True, full_height=True)
385
+
386
+
387
+ run_button.click(fn=run_grounded_sam, inputs=[
388
+ input_image, openai_key, box_threshold, text_threshold, iou_threshold, area_threshold],
389
+ outputs=[gallery, mask_gallary, image_caption, identified_labels])
390
+
391
+
392
+ block.launch(server_name='0.0.0.0', server_port=7589, debug=args.debug, share=args.share)
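For orientation, the mask.json written by save_mask_data() in the new gradio_auto_label.py has the shape sketched below; the labels, logits and boxes here are illustrative values, not real output:

example_mask_json = [
    {"value": 0, "label": "background"},
    {"value": 1, "label": "dog",   "logit": 0.42, "box": [34.0, 50.0, 210.0, 330.0]},
    {"value": 2, "label": "beach", "logit": 0.38, "box": [0.0, 120.0, 512.0, 512.0]},
]
# 'value' is also the pixel value assigned to that instance in the accompanying mask.jpg,
# with 0 reserved for background.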
grounded_sam.ipynb CHANGED
@@ -53,12 +53,21 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 187,
+    "execution_count": null,
     "metadata": {},
     "outputs": [],
     "source": [
-     "import os\n",
+     "import os, sys\n",
      "\n",
+     "sys.path.append(os.path.join(os.getcwd(), \"GroundingDINO\"))"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 187,
+    "metadata": {},
+    "outputs": [],
+    "source": [
      "# If you have multiple GPUs, you can set the GPU to use here.\n",
      "# The default is to use the first GPU, which is usually GPU 0.\n",
      "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
@@ -85,7 +94,7 @@
     "from GroundingDINO.groundingdino.util import box_ops\n",
     "from GroundingDINO.groundingdino.util.slconfig import SLConfig\n",
     "from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap\n",
-    "from groundingdino.util.inference import annotate, load_image, predict\n",
+    "from GroundingDINO.groundingdino.util.inference import annotate, load_image, predict\n",
     "\n",
     "import supervision as sv\n",
     "\n",
grounded_sam_inpainting_demo.py CHANGED
@@ -125,6 +125,7 @@ if __name__ == "__main__":
125
 
126
  parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
127
  parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
 
128
  parser.add_argument("--device", type=str, default="cpu", help="running on cpu only!, default=False")
129
  args = parser.parse_args()
130
 
@@ -138,6 +139,7 @@ if __name__ == "__main__":
138
  output_dir = args.output_dir
139
  box_threshold = args.box_threshold
140
  text_threshold = args.box_threshold
 
141
  device = args.device
142
 
143
  # make dir
@@ -181,7 +183,11 @@ if __name__ == "__main__":
181
  # masks: [1, 1, 512, 512]
182
 
183
  # inpainting pipeline
184
- mask = masks[0][0].cpu().numpy() # simply choose the first mask, which will be refine in the future release
 
 
 
 
185
  mask_pil = Image.fromarray(mask)
186
  image_pil = Image.fromarray(image)
187
 
@@ -190,8 +196,11 @@ if __name__ == "__main__":
190
  )
191
  pipe = pipe.to("cuda")
192
 
 
 
193
  # prompt = "A sofa, high quality, detailed"
194
  image = pipe(prompt=inpaint_prompt, image=image_pil, mask_image=mask_pil).images[0]
 
195
  image.save(os.path.join(output_dir, "grounded_sam_inpainting_output.jpg"))
196
 
197
  # draw output image
 
125
 
126
  parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
127
  parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
128
+ parser.add_argument("--inpaint_mode", type=str, default="first", help="inpaint mode")
129
  parser.add_argument("--device", type=str, default="cpu", help="running on cpu only!, default=False")
130
  args = parser.parse_args()
131
 
 
139
  output_dir = args.output_dir
140
  box_threshold = args.box_threshold
141
  text_threshold = args.box_threshold
142
+ inpaint_mode = args.inpaint_mode
143
  device = args.device
144
 
145
  # make dir
 
183
  # masks: [1, 1, 512, 512]
184
 
185
  # inpainting pipeline
186
+ if inpaint_mode == 'merge':
187
+ masks = torch.sum(masks, dim=0).unsqueeze(0)
188
+ masks = torch.where(masks > 0, True, False)
189
+ else:
190
+ mask = masks[0][0].cpu().numpy() # simply choose the first mask, which will be refine in the future release
191
  mask_pil = Image.fromarray(mask)
192
  image_pil = Image.fromarray(image)
193
 
 
196
  )
197
  pipe = pipe.to("cuda")
198
 
199
+ image_pil = image_pil.resize((512, 512))
200
+ mask_pil = mask_pil.resize((512, 512))
201
  # prompt = "A sofa, high quality, detailed"
202
  image = pipe(prompt=inpaint_prompt, image=image_pil, mask_image=mask_pil).images[0]
203
+ image = image.resize(size)
204
  image.save(os.path.join(output_dir, "grounded_sam_inpainting_output.jpg"))
205
 
206
  # draw output image
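A toy sketch (not part of the commit) of what the new 'merge' inpaint mode does with the SAM masks, using the same torch.sum / torch.where pattern as the diff above, applied here to a fake mask tensor:

import torch

masks = torch.zeros((3, 1, 8, 8), dtype=torch.bool)    # three fake instance masks
masks[0, 0, :4, :4] = True
masks[1, 0, 2:6, 2:6] = True
masks[2, 0, 6:, 6:] = True

merged = torch.sum(masks, dim=0).unsqueeze(0)           # per-pixel instance count, shape (1, 1, 8, 8)
merged = torch.where(merged > 0, True, False)           # boolean union of all masks
print(merged.shape, merged.sum().item())                # torch.Size([1, 1, 8, 8]) 32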
grounded_sam_whisper_demo.py ADDED
@@ -0,0 +1,258 @@
 
 
1
+ import argparse
2
+ import os
3
+ import copy
4
+
5
+ import numpy as np
6
+ import json
7
+ import torch
8
+ import torchvision
9
+ from PIL import Image, ImageDraw, ImageFont
10
+
11
+ # Grounding DINO
12
+ import GroundingDINO.groundingdino.datasets.transforms as T
13
+ from GroundingDINO.groundingdino.models import build_model
14
+ from GroundingDINO.groundingdino.util import box_ops
15
+ from GroundingDINO.groundingdino.util.slconfig import SLConfig
16
+ from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
17
+
18
+ # segment anything
19
+ from segment_anything import build_sam, SamPredictor
20
+ import cv2
21
+ import numpy as np
22
+ import matplotlib.pyplot as plt
23
+
24
+ # whisper
25
+ import whisper
26
+
27
+
28
+ def load_image(image_path):
29
+ # load image
30
+ image_pil = Image.open(image_path).convert("RGB") # load image
31
+
32
+ transform = T.Compose(
33
+ [
34
+ T.RandomResize([800], max_size=1333),
35
+ T.ToTensor(),
36
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
37
+ ]
38
+ )
39
+ image, _ = transform(image_pil, None) # 3, h, w
40
+ return image_pil, image
41
+
42
+
43
+ def load_model(model_config_path, model_checkpoint_path, device):
44
+ args = SLConfig.fromfile(model_config_path)
45
+ args.device = device
46
+ model = build_model(args)
47
+ checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
48
+ load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
49
+ print(load_res)
50
+ _ = model.eval()
51
+ return model
52
+
53
+
54
+ def get_grounding_output(model, image, caption, box_threshold, text_threshold,device="cpu"):
55
+ caption = caption.lower()
56
+ caption = caption.strip()
57
+ if not caption.endswith("."):
58
+ caption = caption + "."
59
+ model = model.to(device)
60
+ image = image.to(device)
61
+ with torch.no_grad():
62
+ outputs = model(image[None], captions=[caption])
63
+ logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
64
+ boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
65
+ logits.shape[0]
66
+
67
+ # filter output
68
+ logits_filt = logits.clone()
69
+ boxes_filt = boxes.clone()
70
+ filt_mask = logits_filt.max(dim=1)[0] > box_threshold
71
+ logits_filt = logits_filt[filt_mask] # num_filt, 256
72
+ boxes_filt = boxes_filt[filt_mask] # num_filt, 4
73
+ logits_filt.shape[0]
74
+
75
+ # get phrase
76
+ tokenlizer = model.tokenizer
77
+ tokenized = tokenlizer(caption)
78
+ # build pred
79
+ pred_phrases = []
80
+ scores = []
81
+ for logit, box in zip(logits_filt, boxes_filt):
82
+ pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
83
+ pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
84
+ scores.append(logit.max().item())
85
+
86
+ return boxes_filt, torch.Tensor(scores), pred_phrases
87
+
88
+ def show_mask(mask, ax, random_color=False):
89
+ if random_color:
90
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
91
+ else:
92
+ color = np.array([30/255, 144/255, 255/255, 0.6])
93
+ h, w = mask.shape[-2:]
94
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
95
+ ax.imshow(mask_image)
96
+
97
+
98
+ def show_box(box, ax, label):
99
+ x0, y0 = box[0], box[1]
100
+ w, h = box[2] - box[0], box[3] - box[1]
101
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
102
+ ax.text(x0, y0, label)
103
+
104
+
105
+ def save_mask_data(output_dir, mask_list, box_list, label_list):
106
+ value = 0 # 0 for background
107
+
108
+ mask_img = torch.zeros(mask_list.shape[-2:])
109
+ for idx, mask in enumerate(mask_list):
110
+ mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1
111
+ plt.figure(figsize=(10, 10))
112
+ plt.imshow(mask_img.numpy())
113
+ plt.axis('off')
114
+ plt.savefig(os.path.join(output_dir, 'mask.jpg'), bbox_inches="tight", dpi=300, pad_inches=0.0)
115
+
116
+ json_data = [{
117
+ 'value': value,
118
+ 'label': 'background'
119
+ }]
120
+ for label, box in zip(label_list, box_list):
121
+ value += 1
122
+ name, logit = label.split('(')
123
+ logit = logit[:-1] # the last is ')'
124
+ json_data.append({
125
+ 'value': value,
126
+ 'label': name,
127
+ 'logit': float(logit),
128
+ 'box': box.numpy().tolist(),
129
+ })
130
+ with open(os.path.join(output_dir, 'mask.json'), 'w') as f:
131
+ json.dump(json_data, f)
132
+
133
+
134
+ def speech_recognition(speech_file, model):
135
+ # whisper
136
+ # load audio and pad/trim it to fit 30 seconds
137
+ audio = whisper.load_audio(speech_file)
138
+ audio = whisper.pad_or_trim(audio)
139
+
140
+ # make log-Mel spectrogram and move to the same device as the model
141
+ mel = whisper.log_mel_spectrogram(audio).to(model.device)
142
+
143
+ # detect the spoken language
144
+ _, probs = model.detect_language(mel)
145
+ speech_language = max(probs, key=probs.get)
146
+
147
+ # decode the audio
148
+ options = whisper.DecodingOptions()
149
+ result = whisper.decode(model, mel, options)
150
+
151
+ # print the recognized text
152
+ speech_text = result.text
153
+ return speech_text, speech_language
154
+
155
+ if __name__ == "__main__":
156
+
157
+ parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True)
158
+ parser.add_argument("--config", type=str, required=True, help="path to config file")
159
+ parser.add_argument(
160
+ "--grounded_checkpoint", type=str, required=True, help="path to checkpoint file"
161
+ )
162
+ parser.add_argument(
163
+ "--sam_checkpoint", type=str, required=True, help="path to checkpoint file"
164
+ )
165
+ parser.add_argument("--input_image", type=str, required=True, help="path to image file")
166
+ parser.add_argument("--speech_file", type=str, required=True, help="speech file")
167
+ parser.add_argument(
168
+ "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
169
+ )
170
+
171
+ parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
172
+ parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
173
+ parser.add_argument("--iou_threshold", type=float, default=0.5, help="iou threshold")
174
+
175
+ parser.add_argument("--device", type=str, default="cpu", help="running on cpu only!, default=False")
176
+ args = parser.parse_args()
177
+
178
+ # cfg
179
+ config_file = args.config # change the path of the model config file
180
+ grounded_checkpoint = args.grounded_checkpoint # change the path of the model
181
+ sam_checkpoint = args.sam_checkpoint
182
+ image_path = args.input_image
183
+ output_dir = args.output_dir
184
+ box_threshold = args.box_threshold
185
+ text_threshold = args.text_threshold
186
+ iou_threshold = args.iou_threshold
187
+ device = args.device
188
+
189
+ # load speech
190
+ whisper_model = whisper.load_model("base")
191
+ speech_text, speech_language = speech_recognition(args.speech_file, whisper_model)
192
+ print(f"speech_text: {speech_text}")
193
+ print(f"speech_language: {speech_language}")
194
+
195
+ # make dir
196
+ os.makedirs(output_dir, exist_ok=True)
197
+ # load image
198
+ image_pil, image = load_image(image_path)
199
+ # load model
200
+ model = load_model(config_file, grounded_checkpoint, device=device)
201
+
202
+ # visualize raw image
203
+ image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
204
+
205
+ # run grounding dino model
206
+ text_prompt = speech_text
207
+ boxes_filt, scores, pred_phrases = get_grounding_output(
208
+ model, image, text_prompt, box_threshold, text_threshold, device=device
209
+ )
210
+
211
+ # initialize SAM
212
+ predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(args.device))
213
+ image = cv2.imread(image_path)
214
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
215
+ predictor.set_image(image)
216
+
217
+ size = image_pil.size
218
+ H, W = size[1], size[0]
219
+ for i in range(boxes_filt.size(0)):
220
+ boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
221
+ boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
222
+ boxes_filt[i][2:] += boxes_filt[i][:2]
223
+
224
+ boxes_filt = boxes_filt.cpu()
225
+ # use NMS to handle overlapped boxes
226
+ print(f"Before NMS: {boxes_filt.shape[0]} boxes")
227
+ nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
228
+ boxes_filt = boxes_filt[nms_idx]
229
+ pred_phrases = [pred_phrases[idx] for idx in nms_idx]
230
+ print(f"After NMS: {boxes_filt.shape[0]} boxes")
231
+
232
+ transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
233
+
234
+ masks, _, _ = predictor.predict_torch(
235
+ point_coords = None,
236
+ point_labels = None,
237
+ boxes = transformed_boxes.to(args.device),
238
+ multimask_output = False,
239
+ )
240
+
241
+ # draw output image
242
+ plt.figure(figsize=(10, 10))
243
+ plt.imshow(image)
244
+ for mask in masks:
245
+ show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
246
+ for box, label in zip(boxes_filt, pred_phrases):
247
+ show_box(box.numpy(), plt.gca(), label)
248
+
249
+ plt.title(speech_text)
250
+ plt.axis('off')
251
+ plt.savefig(
252
+ os.path.join(output_dir, "grounded_sam_whisper_output.jpg"),
253
+ bbox_inches="tight", dpi=300, pad_inches=0.0
254
+ )
255
+
256
+
257
+ save_mask_data(output_dir, masks, boxes_filt, pred_phrases)
258
+
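For context, a minimal sketch of the speech-to-prompt step that grounded_sam_whisper_demo.py performs; it uses exactly the whisper calls from speech_recognition() above, and the audio path is a placeholder rather than a file shipped with the commit:

import whisper

model = whisper.load_model("base")
audio = whisper.pad_or_trim(whisper.load_audio("demo_speech.wav"))   # placeholder path
mel = whisper.log_mel_spectrogram(audio).to(model.device)

_, probs = model.detect_language(mel)                   # per-language probabilities
result = whisper.decode(model, mel, whisper.DecodingOptions())
print(max(probs, key=probs.get), result.text)           # detected language and the text prompt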
grounded_dino_sam_inpainting_demo.py → grounded_sam_whisper_inpainting_demo.py RENAMED
@@ -1,6 +1,6 @@
1
  import argparse
2
  import os
3
- import copy
4
 
5
  import numpy as np
6
  import torch
@@ -27,45 +27,12 @@ import torch
27
  from io import BytesIO
28
  from diffusers import StableDiffusionInpaintPipeline
29
 
30
- def plot_boxes_to_image(image_pil, tgt):
31
- H, W = tgt["size"]
32
- boxes = tgt["boxes"]
33
- labels = tgt["labels"]
34
- assert len(boxes) == len(labels), "boxes and labels must have same length"
35
-
36
- draw = ImageDraw.Draw(image_pil)
37
- mask = Image.new("L", image_pil.size, 0)
38
- mask_draw = ImageDraw.Draw(mask)
39
-
40
- # draw boxes and masks
41
- for box, label in zip(boxes, labels):
42
- # from 0..1 to 0..W, 0..H
43
- box = box * torch.Tensor([W, H, W, H])
44
- # from xywh to xyxy
45
- box[:2] -= box[2:] / 2
46
- box[2:] += box[:2]
47
- # random color
48
- color = tuple(np.random.randint(0, 255, size=3).tolist())
49
- # draw
50
- x0, y0, x1, y1 = box
51
- x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
52
-
53
- draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
54
- # draw.text((x0, y0), str(label), fill=color)
55
-
56
- font = ImageFont.load_default()
57
- if hasattr(font, "getbbox"):
58
- bbox = draw.textbbox((x0, y0), str(label), font)
59
- else:
60
- w, h = draw.textsize(str(label), font)
61
- bbox = (x0, y0, w + x0, y0 + h)
62
- # bbox = draw.textbbox((x0, y0), str(label))
63
- draw.rectangle(bbox, fill=color)
64
- draw.text((x0, y0), str(label), fill="white")
65
 
66
- mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)
 
67
 
68
- return image_pil, mask
69
 
70
  def load_image(image_path):
71
  # load image
@@ -143,6 +110,48 @@ def show_box(box, ax, label):
143
  w, h = box[2] - box[0], box[3] - box[1]
144
  ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
145
  ax.text(x0, y0, label)
 
 
 
 
 
 
146
 
147
 
148
  if __name__ == "__main__":
@@ -153,36 +162,38 @@ if __name__ == "__main__":
153
  "--grounded_checkpoint", type=str, required=True, help="path to checkpoint file"
154
  )
155
  parser.add_argument(
156
- "--sam_checkpoint", type=str, required=False, help="path to checkpoint file"
157
  )
158
- parser.add_argument("--task_type", type=str, required=True, help="select task")
159
  parser.add_argument("--input_image", type=str, required=True, help="path to image file")
160
- parser.add_argument("--text_prompt", type=str, required=True, help="text prompt")
161
- parser.add_argument("--inpaint_prompt", type=str, required=False, help="inpaint prompt")
162
  parser.add_argument(
163
  "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
164
  )
165
-
 
 
 
 
 
 
166
  parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
167
  parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
 
168
  parser.add_argument("--device", type=str, default="cpu", help="running on cpu only!, default=False")
 
169
  args = parser.parse_args()
170
 
171
  # cfg
172
  config_file = args.config # change the path of the model config file
173
  grounded_checkpoint = args.grounded_checkpoint # change the path of the model
174
  sam_checkpoint = args.sam_checkpoint
175
- task_type = args.task_type
176
  image_path = args.input_image
177
- text_prompt = args.text_prompt
178
- inpaint_prompt = args.inpaint_prompt
179
  output_dir = args.output_dir
180
  box_threshold = args.box_threshold
181
  text_threshold = args.box_threshold
 
182
  device = args.device
183
 
184
- assert text_prompt, 'text_prompt is not found!'
185
-
186
  # make dir
187
  os.makedirs(output_dir, exist_ok=True)
188
  # load image
@@ -192,87 +203,79 @@ if __name__ == "__main__":
192
 
193
  # visualize raw image
194
  image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
195
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  # run grounding dino model
197
  boxes_filt, pred_phrases = get_grounding_output(
198
- model, image, text_prompt, box_threshold, text_threshold, device=device
199
  )
200
 
 
 
 
 
 
 
201
  size = image_pil.size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
- if task_type == 'seg' or task_type == 'inpainting':
204
- # initialize SAM
205
- predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
206
- image = cv2.imread(image_path)
207
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
208
- predictor.set_image(image)
209
-
210
- H, W = size[1], size[0]
211
- for i in range(boxes_filt.size(0)):
212
- boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
213
- boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
214
- boxes_filt[i][2:] += boxes_filt[i][:2]
215
-
216
- boxes_filt = boxes_filt.cpu()
217
- transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
218
-
219
- masks, _, _ = predictor.predict_torch(
220
- point_coords = None,
221
- point_labels = None,
222
- boxes = transformed_boxes,
223
- multimask_output = False,
224
- )
225
-
226
- # masks: [1, 1, 512, 512]
227
-
228
- if task_type == 'det':
229
- assert grounded_checkpoint, 'grounded_checkpoint is not found!'
230
- pred_dict = {
231
- "boxes": boxes_filt,
232
- "size": [size[1], size[0]], # H,W
233
- "labels": pred_phrases,
234
- }
235
- # import ipdb; ipdb.set_trace()
236
- image_with_box = plot_boxes_to_image(image_pil, pred_dict)[0]
237
- image_with_box.save(os.path.join(output_dir, "grounding_dino_output.jpg"))
238
- elif task_type == 'seg':
239
- assert sam_checkpoint, 'sam_checkpoint is not found!'
240
-
241
- # draw output image
242
- plt.figure(figsize=(10, 10))
243
- plt.imshow(image)
244
- for mask in masks:
245
- show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
246
- for box, label in zip(boxes_filt, pred_phrases):
247
- show_box(box.numpy(), plt.gca(), label)
248
- plt.axis('off')
249
- plt.savefig(os.path.join(output_dir, "grounded_sam_output.jpg"), bbox_inches="tight")
250
-
251
- elif task_type == 'inpainting':
252
- assert inpaint_prompt, 'inpaint_prompt is not found!'
253
- # inpainting pipeline
254
- mask = masks[0][0].cpu().numpy() # simply choose the first mask, which will be refine in the future release
255
- mask_pil = Image.fromarray(mask)
256
- image_pil = Image.fromarray(image)
257
-
258
- pipe = StableDiffusionInpaintPipeline.from_pretrained(
259
- "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
260
- )
261
- pipe = pipe.to("cuda")
262
-
263
- # prompt = "A sofa, high quality, detailed"
264
- image = pipe(prompt=inpaint_prompt, image=image_pil, mask_image=mask_pil).images[0]
265
- image.save(os.path.join(output_dir, "grounded_sam_inpainting_output.jpg"))
266
-
267
- # draw output image
268
- # plt.figure(figsize=(10, 10))
269
- # plt.imshow(image)
270
- # for mask in masks:
271
- # show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
272
- # for box, label in zip(boxes_filt, pred_phrases):
273
- # show_box(box.numpy(), plt.gca(), label)
274
- # plt.axis('off')
275
- # plt.savefig(os.path.join(output_dir, "grounded_sam_output.jpg"), bbox_inches="tight")
276
  else:
277
- print("task_type:{} error!".format(task_type))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
 
 
1
  import argparse
2
  import os
3
+ from warnings import warn
4
 
5
  import numpy as np
6
  import torch
 
27
  from io import BytesIO
28
  from diffusers import StableDiffusionInpaintPipeline
29
 
30
+ # whisper
31
+ import whisper
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ # ChatGPT
34
+ import openai
35
 
 
36
 
37
  def load_image(image_path):
38
  # load image
 
110
  w, h = box[2] - box[0], box[3] - box[1]
111
  ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
112
  ax.text(x0, y0, label)
113
+
114
+
115
+ def speech_recognition(speech_file, model):
116
+ # whisper
117
+ # load audio and pad/trim it to fit 30 seconds
118
+ audio = whisper.load_audio(speech_file)
119
+ audio = whisper.pad_or_trim(audio)
120
+
121
+ # make log-Mel spectrogram and move to the same device as the model
122
+ mel = whisper.log_mel_spectrogram(audio).to(model.device)
123
+
124
+ # detect the spoken language
125
+ _, probs = model.detect_language(mel)
126
+ speech_language = max(probs, key=probs.get)
127
+
128
+ # decode the audio
129
+ options = whisper.DecodingOptions()
130
+ result = whisper.decode(model, mel, options)
131
+
132
+ # print the recognized text
133
+ speech_text = result.text
134
+ return speech_text, speech_language
135
+
136
+
137
+ def filter_prompts_with_chatgpt(caption, max_tokens=100, model="gpt-3.5-turbo"):
138
+ prompt = [
139
+ {
140
+ 'role': 'system',
141
+ 'content': f"Extract the main object to be replaced and marked it as 'main_object', " + \
142
+ f"Extract the remaining part as 'other prompt' " + \
143
+ f"Return (main_object, other prompt)" + \
144
+ f'Given caption: {caption}.'
145
+ }
146
+ ]
147
+ response = openai.ChatCompletion.create(model=model, messages=prompt, temperature=0.6, max_tokens=max_tokens)
148
+ reply = response['choices'][0]['message']['content']
149
+ try:
150
+ det_prompt, inpaint_prompt = reply.split('\n')[0].split(':')[-1].strip(), reply.split('\n')[1].split(':')[-1].strip()
151
+ except:
152
+ warn(f"Failed to extract tags from caption") # use caption as det_prompt, inpaint_prompt
153
+ det_prompt, inpaint_prompt = caption, caption
154
+ return det_prompt, inpaint_prompt
155
 
156
 
157
  if __name__ == "__main__":
 
162
  "--grounded_checkpoint", type=str, required=True, help="path to checkpoint file"
163
  )
164
  parser.add_argument(
165
+ "--sam_checkpoint", type=str, required=True, help="path to checkpoint file"
166
  )
 
167
  parser.add_argument("--input_image", type=str, required=True, help="path to image file")
 
 
168
  parser.add_argument(
169
  "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
170
  )
171
+ parser.add_argument("--det_speech_file", type=str, help="grounding speech file")
172
+ parser.add_argument("--inpaint_speech_file", type=str, help="inpaint speech file")
173
+ parser.add_argument("--prompt_speech_file", type=str, help="prompt speech file, no need to provide det_speech_file")
174
+ parser.add_argument("--enable_chatgpt", action="store_true", help="enable chatgpt")
175
+ parser.add_argument("--openai_key", type=str, help="key for chatgpt")
176
+ parser.add_argument("--openai_proxy", default=None, type=str, help="proxy for chatgpt")
177
+ parser.add_argument("--whisper_model", type=str, default="small", help="whisper model version: tiny, base, small, medium, large")
178
  parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
179
  parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
180
+ parser.add_argument("--inpaint_mode", type=str, default="first", help="inpaint mode")
181
  parser.add_argument("--device", type=str, default="cpu", help="running on cpu only!, default=False")
182
+ parser.add_argument("--prompt_extra", type=str, default=" high resolution, real scene", help="extra prompt for inpaint")
183
  args = parser.parse_args()
184
 
185
  # cfg
186
  config_file = args.config # change the path of the model config file
187
  grounded_checkpoint = args.grounded_checkpoint # change the path of the model
188
  sam_checkpoint = args.sam_checkpoint
 
189
  image_path = args.input_image
190
+
 
191
  output_dir = args.output_dir
192
  box_threshold = args.box_threshold
193
  text_threshold = args.box_threshold
194
+ inpaint_mode = args.inpaint_mode
195
  device = args.device
196
 
 
 
197
  # make dir
198
  os.makedirs(output_dir, exist_ok=True)
199
  # load image
 
203
 
204
  # visualize raw image
205
  image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
206
+
207
+ # recognize speech
208
+ whisper_model = whisper.load_model(args.whisper_model)
209
+
210
+ if args.enable_chatgpt:
211
+ openai.api_key = args.openai_key
212
+ if args.openai_proxy:
213
+ openai.proxy = {"http": args.openai_proxy, "https": args.openai_proxy}
214
+ speech_text, _ = speech_recognition(args.prompt_speech_file, whisper_model)
215
+ det_prompt, inpaint_prompt = filter_prompts_with_chatgpt(speech_text)
216
+ inpaint_prompt += args.prompt_extra
217
+ print(f"det_prompt: {det_prompt}, inpaint_prompt: {inpaint_prompt}")
218
+ else:
219
+ det_prompt, det_speech_language = speech_recognition(args.det_speech_file, whisper_model)
220
+ inpaint_prompt, inpaint_speech_language = speech_recognition(args.inpaint_speech_file, whisper_model)
221
+ print(f"det_prompt: {det_prompt}, using language: {det_speech_language}")
222
+ print(f"inpaint_prompt: {inpaint_prompt}, using language: {inpaint_speech_language}")
223
+
224
  # run grounding dino model
225
  boxes_filt, pred_phrases = get_grounding_output(
226
+ model, image, det_prompt, box_threshold, text_threshold, device=device
227
  )
228
 
229
+ # initialize SAM
230
+ predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
231
+ image = cv2.imread(image_path)
232
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
233
+ predictor.set_image(image)
234
+
235
  size = image_pil.size
236
+ H, W = size[1], size[0]
237
+ for i in range(boxes_filt.size(0)):
238
+ boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
239
+ boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
240
+ boxes_filt[i][2:] += boxes_filt[i][:2]
241
+
242
+ boxes_filt = boxes_filt.cpu()
243
+ transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
244
+
245
+ masks, _, _ = predictor.predict_torch(
246
+ point_coords = None,
247
+ point_labels = None,
248
+ boxes = transformed_boxes,
249
+ multimask_output = False,
250
+ )
251
 
252
+ # masks: [1, 1, 512, 512]
253
+
254
+ # inpainting pipeline
255
+ if inpaint_mode == 'merge':
256
+ masks = torch.sum(masks, dim=0).unsqueeze(0)
257
+ masks = torch.where(masks > 0, True, False)
 
 
 
 
 
258
  else:
259
+ mask = masks[0][0].cpu().numpy() # simply choose the first mask, which will be refine in the future release
260
+ mask_pil = Image.fromarray(mask)
261
+ image_pil = Image.fromarray(image)
262
+
263
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
264
+ "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
265
+ )
266
+ pipe = pipe.to("cuda")
267
+
268
+ # prompt = "A sofa, high quality, detailed"
269
+ image = pipe(prompt=inpaint_prompt, image=image_pil, mask_image=mask_pil).images[0]
270
+ image.save(os.path.join(output_dir, "grounded_sam_inpainting_output.jpg"))
271
+
272
+ # draw output image
273
+ # plt.figure(figsize=(10, 10))
274
+ # plt.imshow(image)
275
+ # for mask in masks:
276
+ # show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
277
+ # for box, label in zip(boxes_filt, pred_phrases):
278
+ # show_box(box.numpy(), plt.gca(), label)
279
+ # plt.axis('off')
280
+ # plt.savefig(os.path.join(output_dir, "grounded_sam_output.jpg"), bbox_inches="tight")
281
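Finally, a toy sketch (the reply string is invented) of how filter_prompts_with_chatgpt() in the renamed grounded_sam_whisper_inpainting_demo.py splits a ChatGPT reply into a detection prompt and an inpainting prompt:

reply = "main_object: sofa\nother prompt: a red leather couch in a bright living room"
det_prompt = reply.split('\n')[0].split(':')[-1].strip()
inpaint_prompt = reply.split('\n')[1].split(':')[-1].strip()
print(det_prompt)      # sofa
print(inpaint_prompt)  # a red leather couch in a bright living room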