print info to track progress
app.py
CHANGED
@@ -125,6 +125,7 @@ def draw_box(box, draw, label):
 
 
 def inference(raw_image, specified_tags, tagging_model_type, tagging_model, grounding_dino_model, sam_model):
+    print(f"Start processing, image size {raw_image.size}")
     raw_image = raw_image.convert("RGB")
 
     # run tagging model
@@ -165,6 +166,7 @@ def inference(raw_image, specified_tags, tagging_model_type, tagging_model, grou
     boxes_filt, scores, pred_phrases = get_grounding_output(
         grounding_dino_model, image, tags, box_threshold, text_threshold, device=device
     )
+    print("GroundingDINO finished")
 
     # run SAM
     image = np.asarray(raw_image)
@@ -179,13 +181,13 @@ def inference(raw_image, specified_tags, tagging_model_type, tagging_model, grou
 
     boxes_filt = boxes_filt.cpu()
     # use NMS to handle overlapped boxes
-    nms_idx = torchvision.ops.nms(
-        boxes_filt, scores, iou_threshold).numpy().tolist()
+    print(f"Before NMS: {boxes_filt.shape[0]} boxes")
+    nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
     boxes_filt = boxes_filt[nms_idx]
     pred_phrases = [pred_phrases[idx] for idx in nms_idx]
+    print(f"After NMS: {boxes_filt.shape[0]} boxes")
 
-    transformed_boxes = sam_model.transform.apply_boxes_torch(
-        boxes_filt, image.shape[:2]).to(device)
+    transformed_boxes = sam_model.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
 
     masks, _, _ = sam_model.predict_torch(
         point_coords=None,
@@ -193,6 +195,7 @@ def inference(raw_image, specified_tags, tagging_model_type, tagging_model, grou
         boxes=transformed_boxes.to(device),
         multimask_output=False,
     )
+    print("SAM finished")
 
     # draw output image
     mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
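For context, here is a minimal standalone sketch of the NMS filtering step that this commit collapses onto one line and instruments with before/after prints. The boxes, scores, and the iou_threshold value below are made-up placeholders for illustration, not values taken from app.py, where the boxes and scores come from get_grounding_output().

import torch
import torchvision

# Made-up example boxes in xyxy format and confidence scores (placeholders).
boxes_filt = torch.tensor([[10.0, 10.0, 100.0, 100.0],
                           [12.0, 12.0, 98.0, 102.0],
                           [200.0, 200.0, 300.0, 300.0]])
scores = torch.tensor([0.90, 0.80, 0.95])
iou_threshold = 0.5  # placeholder; app.py defines its own threshold

print(f"Before NMS: {boxes_filt.shape[0]} boxes")
# torchvision.ops.nms returns the indices of the boxes kept after suppressing
# lower-scored boxes that overlap a kept box above the IoU threshold.
nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
boxes_filt = boxes_filt[nms_idx]
print(f"After NMS: {boxes_filt.shape[0]} boxes")  # 2 boxes survive in this example

The two counts printed here are the same quantities the new print statements report in the Space's logs, which is what makes the progress of each inference call visible.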