Spaces:
Running on Zero
return_rectangles
Browse files
app.py
CHANGED
@@ -30,7 +30,7 @@ SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
|
|
30 |
@spaces.GPU(duration=20)
|
31 |
@torch.inference_mode()
|
32 |
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
33 |
-
def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=0, merge_masks=False) -> Optional[Image.Image]:
|
34 |
if not image_input:
|
35 |
gr.Info("Please upload an image.")
|
36 |
return None
|
@@ -58,28 +58,37 @@ def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=
|
|
58 |
result=result,
|
59 |
resolution_wh=image_input.size
|
60 |
)
|
61 |
-
detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
|
62 |
-
if len(detections) == 0:
|
63 |
-
gr.Info("No objects detected.")
|
64 |
-
return None
|
65 |
images = []
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
-
for i in range(len(detections.mask)):
|
71 |
-
mask = detections.mask[i].astype(np.uint8) * 255
|
72 |
-
if dilate > 0:
|
73 |
-
mask = cv2.dilate(mask, kernel, iterations=1)
|
74 |
-
images.append(mask)
|
75 |
-
|
76 |
-
if merge_masks:
|
77 |
-
final_images = []
|
78 |
-
merged_mask = np.zeros_like(images[0], dtype=np.uint8)
|
79 |
-
for mask in images:
|
80 |
-
merged_mask = cv2.bitwise_or(merged_mask, mask)
|
81 |
-
final_images = [merged_mask]
|
82 |
-
return final_images
|
83 |
return images
|
84 |
|
85 |
|
@@ -93,7 +102,7 @@ with gr.Blocks() as demo:
|
|
93 |
)
|
94 |
dilate = gr.Slider(label="dilate mask", minimum=0, maximum=50, value=10, step=1)
|
95 |
merge_masks = gr.Checkbox(label="Merge masks", value=False)
|
96 |
-
|
97 |
text_prompt = gr.Textbox(label='Text prompt', placeholder='Enter text prompts')
|
98 |
submit_button = gr.Button(value='Submit', variant='primary')
|
99 |
with gr.Column():
|
@@ -101,7 +110,7 @@ with gr.Blocks() as demo:
|
|
101 |
print(image, image_url, task_prompt, text_prompt, image_gallery)
|
102 |
submit_button.click(
|
103 |
fn = process_image,
|
104 |
-
inputs = [image, image_url, task_prompt, text_prompt, dilate, merge_masks],
|
105 |
outputs = [image_gallery,],
|
106 |
show_api=False
|
107 |
)
|
|
|
30 |
@spaces.GPU(duration=20)
|
31 |
@torch.inference_mode()
|
32 |
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
33 |
+
def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=0, merge_masks=False, return_rectangles=False) -> Optional[Image.Image]:
|
34 |
if not image_input:
|
35 |
gr.Info("Please upload an image.")
|
36 |
return None
|
|
|
58 |
result=result,
|
59 |
resolution_wh=image_input.size
|
60 |
)
|
|
|
|
|
|
|
|
|
61 |
images = []
|
62 |
+
if return_rectangles:
|
63 |
+
# Create an image with a black background
|
64 |
+
mask_image = np.zeros((image_input.size.height, image_input.size.width), dtype=np.uint8)
|
65 |
+
bboxes = detections.get('bboxes', [])
|
66 |
+
for bbox in bboxes:
|
67 |
+
x1, y1, x2, y2 = map(int, bbox)
|
68 |
+
# Draw a filled white rectangle on mask_image
|
69 |
+
cv2.rectangle(mask_image, (x1, y1), (x2, y2), 255, thickness=cv2.FILLED)
|
70 |
+
images = [mask_image]
|
71 |
+
else:
|
72 |
+
# sam
|
73 |
+
detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
|
74 |
+
if len(detections) == 0:
|
75 |
+
gr.Info("No objects detected.")
|
76 |
+
return None
|
77 |
+
kernel_size = dilate
|
78 |
+
print("mask generated:", len(detections.mask))
|
79 |
+
kernel = np.ones((kernel_size, kernel_size), np.uint8)
|
80 |
+
for i in range(len(detections.mask)):
|
81 |
+
mask = detections.mask[i].astype(np.uint8) * 255
|
82 |
+
if dilate > 0:
|
83 |
+
mask = cv2.dilate(mask, kernel, iterations=1)
|
84 |
+
images.append(mask)
|
85 |
+
if merge_masks:
|
86 |
+
|
87 |
+
merged_mask = np.zeros_like(images[0], dtype=np.uint8)
|
88 |
+
for mask in images:
|
89 |
+
merged_mask = cv2.bitwise_or(merged_mask, mask)
|
90 |
+
images = [merged_mask] + images
|
91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
return images
|
93 |
|
94 |
|
|
|
102 |
)
|
103 |
dilate = gr.Slider(label="dilate mask", minimum=0, maximum=50, value=10, step=1)
|
104 |
merge_masks = gr.Checkbox(label="Merge masks", value=False)
|
105 |
+
return_rectangles = gr.Checkbox(label="Return rectangle masks", value=False)
|
106 |
text_prompt = gr.Textbox(label='Text prompt', placeholder='Enter text prompts')
|
107 |
submit_button = gr.Button(value='Submit', variant='primary')
|
108 |
with gr.Column():
|
|
|
110 |
print(image, image_url, task_prompt, text_prompt, image_gallery)
|
111 |
submit_button.click(
|
112 |
fn = process_image,
|
113 |
+
inputs = [image, image_url, task_prompt, text_prompt, dilate, merge_masks, return_rectangles],
|
114 |
outputs = [image_gallery,],
|
115 |
show_api=False
|
116 |
)
|