liuyizhang commited on
Commit
5247a47
1 Parent(s): 68cab41

update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -242,7 +242,8 @@ groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
242
 
243
  # initialize SAM
244
  logger.info(f"initialize SAM model...")
245
- sam_model = build_sam(checkpoint=sam_checkpoint).to(device)
 
246
  sam_predictor = SamPredictor(sam_model)
247
  sam_mask_generator = SamAutomaticMaskGenerator(sam_model)
248
 
@@ -558,7 +559,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
558
  boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
559
  boxes_filt[i][2:] += boxes_filt[i][:2]
560
 
561
- boxes_filt = boxes_filt.cpu()
562
  transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
563
 
564
  masks, _, _, _ = sam_predictor.predict_torch(
 
242
 
243
  # initialize SAM
244
  logger.info(f"initialize SAM model...")
245
+ sam_device = device
246
+ sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
247
  sam_predictor = SamPredictor(sam_model)
248
  sam_mask_generator = SamAutomaticMaskGenerator(sam_model)
249
 
 
559
  boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
560
  boxes_filt[i][2:] += boxes_filt[i][:2]
561
 
562
+ boxes_filt = boxes_filt.to(sam_device)
563
  transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
564
 
565
  masks, _, _, _ = sam_predictor.predict_torch(