Tobias Cornille committed on
Commit
27a9b54
1 Parent(s): d334f4b

Fix device

Browse files
Files changed (1) hide show
  1. app.py +15 -17
app.py CHANGED
@@ -3,9 +3,6 @@ import subprocess, os, sys
3
  result = subprocess.run(["pip", "install", "-e", "GroundingDINO"], check=True)
4
  print(f"pip install GroundingDINO = {result}")
5
 
6
- result = subprocess.run(["pip", "list"], check=True)
7
- print(f"pip list = {result}")
8
-
9
  sys.path.insert(0, "./GroundingDINO")
10
 
11
  if not os.path.exists("./sam_vit_h_4b8939.pth"):
@@ -42,9 +39,8 @@ from GroundingDINO.groundingdino.util import box_ops
42
  from GroundingDINO.groundingdino.util.slconfig import SLConfig
43
  from GroundingDINO.groundingdino.util.utils import (
44
  clean_state_dict,
45
- get_phrases_from_posmap,
46
  )
47
- from GroundingDINO.groundingdino.util.inference import annotate, load_image, predict
48
 
49
  # segment anything
50
  from segment_anything import build_sam, SamPredictor
@@ -63,6 +59,7 @@ def load_model_hf(model_config_path, repo_id, filename, device):
63
  log = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
64
  print("Model loaded from {} \n => {}".format(cache_file, log))
65
  _ = model.eval()
 
66
  return model
67
 
68
 
@@ -99,6 +96,7 @@ def dino_detection(
99
  caption=detection_prompt,
100
  box_threshold=box_threshold,
101
  text_threshold=text_threshold,
 
102
  )
103
  category_ids = [category_name_to_id[phrase] for phrase in phrases]
104
 
@@ -113,7 +111,7 @@ def dino_detection(
113
  return boxes, category_ids
114
 
115
 
116
- def sam_masks_from_dino_boxes(predictor, image_array, boxes):
117
  # box: normalized box xywh -> unnormalized xyxy
118
  H, W, _ = image_array.shape
119
  boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
@@ -284,14 +282,6 @@ def generate_panoptic_mask(
284
  image = image.convert("RGB")
285
  image_array = np.asarray(image)
286
 
287
- if device != "cpu":
288
- try:
289
- from GroundingDINO.groundingdino import _C
290
- except:
291
- warnings.warn(
292
- "Failed to load custom C++ ops. Running on CPU mode Only in groundingdino!"
293
- )
294
-
295
  # detect boxes for "thing" categories using Grounding DINO
296
  thing_boxes, _ = dino_detection(
297
  dino_model,
@@ -306,7 +296,9 @@ def generate_panoptic_mask(
306
  # compute SAM image embedding
307
  sam_predictor.set_image(image_array)
308
  # get segmentation masks for the thing boxes
309
- thing_masks = sam_masks_from_dino_boxes(sam_predictor, image_array, thing_boxes)
 
 
310
  # get rough segmentation masks for "stuff" categories using CLIPSeg
311
  clipseg_preds, clipseg_semantic_inds = clipseg_segmentation(
312
  clipseg_processor,
@@ -366,9 +358,16 @@ sam_checkpoint = "./sam_vit_h_4b8939.pth"
366
  device = "cuda" if torch.cuda.is_available() else "cpu"
367
  print("Using device:", device)
368
 
 
 
 
 
 
 
 
 
369
  # initialize groundingdino model
370
  dino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filename, device)
371
- dino_model = dino_model.to(device)
372
 
373
  # initialize SAM
374
  sam = build_sam(checkpoint=sam_checkpoint)
@@ -458,6 +457,5 @@ if __name__ == "__main__":
458
  ],
459
  outputs=[plot],
460
  )
461
- # task_type.change(fn=change_task_type, inputs=[task_type], outputs=[inpaint_prompt])
462
 
463
  block.launch(server_name="0.0.0.0", debug=args.debug, share=args.share)
 
3
  result = subprocess.run(["pip", "install", "-e", "GroundingDINO"], check=True)
4
  print(f"pip install GroundingDINO = {result}")
5
 
 
 
 
6
  sys.path.insert(0, "./GroundingDINO")
7
 
8
  if not os.path.exists("./sam_vit_h_4b8939.pth"):
 
39
  from GroundingDINO.groundingdino.util.slconfig import SLConfig
40
  from GroundingDINO.groundingdino.util.utils import (
41
  clean_state_dict,
 
42
  )
43
+ from GroundingDINO.groundingdino.util.inference import annotate, predict
44
 
45
  # segment anything
46
  from segment_anything import build_sam, SamPredictor
 
59
  log = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
60
  print("Model loaded from {} \n => {}".format(cache_file, log))
61
  _ = model.eval()
62
+ model = model.to(device)
63
  return model
64
 
65
 
 
96
  caption=detection_prompt,
97
  box_threshold=box_threshold,
98
  text_threshold=text_threshold,
99
+ device=device,
100
  )
101
  category_ids = [category_name_to_id[phrase] for phrase in phrases]
102
 
 
111
  return boxes, category_ids
112
 
113
 
114
+ def sam_masks_from_dino_boxes(predictor, image_array, boxes, device):
115
  # box: normalized box xywh -> unnormalized xyxy
116
  H, W, _ = image_array.shape
117
  boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
 
282
  image = image.convert("RGB")
283
  image_array = np.asarray(image)
284
 
 
 
 
 
 
 
 
 
285
  # detect boxes for "thing" categories using Grounding DINO
286
  thing_boxes, _ = dino_detection(
287
  dino_model,
 
296
  # compute SAM image embedding
297
  sam_predictor.set_image(image_array)
298
  # get segmentation masks for the thing boxes
299
+ thing_masks = sam_masks_from_dino_boxes(
300
+ sam_predictor, image_array, thing_boxes, device
301
+ )
302
  # get rough segmentation masks for "stuff" categories using CLIPSeg
303
  clipseg_preds, clipseg_semantic_inds = clipseg_segmentation(
304
  clipseg_processor,
 
358
  device = "cuda" if torch.cuda.is_available() else "cpu"
359
  print("Using device:", device)
360
 
361
+ if device != "cpu":
362
+ try:
363
+ from GroundingDINO.groundingdino import _C
364
+ except:
365
+ warnings.warn(
366
+ "Failed to load custom C++ ops. Running on CPU mode Only in groundingdino!"
367
+ )
368
+
369
  # initialize groundingdino model
370
  dino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filename, device)
 
371
 
372
  # initialize SAM
373
  sam = build_sam(checkpoint=sam_checkpoint)
 
457
  ],
458
  outputs=[plot],
459
  )
 
460
 
461
  block.launch(server_name="0.0.0.0", debug=args.debug, share=args.share)