majinyu committed
Commit b683f14
1 Parent(s): 0da4c5f

print info to track progress

Files changed (1): app.py (+7 -4)
app.py CHANGED
@@ -125,6 +125,7 @@ def draw_box(box, draw, label):
 
 
 def inference(raw_image, specified_tags, tagging_model_type, tagging_model, grounding_dino_model, sam_model):
+    print(f"Start processing, image size {raw_image.size}")
     raw_image = raw_image.convert("RGB")
 
     # run tagging model
@@ -165,6 +166,7 @@ def inference(raw_image, specified_tags, tagging_model_type, tagging_model, grou
     boxes_filt, scores, pred_phrases = get_grounding_output(
         grounding_dino_model, image, tags, box_threshold, text_threshold, device=device
     )
+    print("GroundingDINO finished")
 
     # run SAM
     image = np.asarray(raw_image)
@@ -179,13 +181,13 @@ def inference(raw_image, specified_tags, tagging_model_type, tagging_model, grou
 
     boxes_filt = boxes_filt.cpu()
     # use NMS to handle overlapped boxes
-    nms_idx = torchvision.ops.nms(
-        boxes_filt, scores, iou_threshold).numpy().tolist()
+    print(f"Before NMS: {boxes_filt.shape[0]} boxes")
+    nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
     boxes_filt = boxes_filt[nms_idx]
     pred_phrases = [pred_phrases[idx] for idx in nms_idx]
+    print(f"After NMS: {boxes_filt.shape[0]} boxes")
 
-    transformed_boxes = sam_model.transform.apply_boxes_torch(
-        boxes_filt, image.shape[:2]).to(device)
+    transformed_boxes = sam_model.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
 
     masks, _, _ = sam_model.predict_torch(
         point_coords=None,
@@ -193,6 +195,7 @@ def inference(raw_image, specified_tags, tagging_model_type, tagging_model, grou
         boxes=transformed_boxes.to(device),
         multimask_output=False,
     )
+    print("SAM finished")
 
     # draw output image
     mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
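
For reference, a minimal standalone sketch (not part of app.py or this commit) of what the reformatted torchvision.ops.nms call does; the boxes, scores, and iou_threshold values below are invented for illustration, and the prints mirror the new before/after progress messages:

import torch
import torchvision

# Hypothetical boxes in (x1, y1, x2, y2) format: two near-duplicates plus one separate box.
boxes = torch.tensor([
    [10.0, 10.0, 100.0, 100.0],
    [12.0, 12.0, 102.0, 102.0],    # overlaps the first box almost completely
    [200.0, 200.0, 300.0, 300.0],
])
scores = torch.tensor([0.9, 0.8, 0.7])
iou_threshold = 0.5  # same role as the threshold used in app.py

print(f"Before NMS: {boxes.shape[0]} boxes")   # Before NMS: 3 boxes
keep = torchvision.ops.nms(boxes, scores, iou_threshold).numpy().tolist()
boxes = boxes[keep]
print(f"After NMS: {boxes.shape[0]} boxes")    # After NMS: 2 boxes

torchvision.ops.nms returns the indices of the boxes to keep, sorted by descending score, so the two prints report how many overlapping boxes were dropped.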