jiuface commited on
Commit
4c32826
1 Parent(s): 2b27106

return_rectangles

Browse files
Files changed (1) hide show
  1. app.py +32 -23
app.py CHANGED
@@ -30,7 +30,7 @@ SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
30
  @spaces.GPU(duration=20)
31
  @torch.inference_mode()
32
  @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
33
- def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=0, merge_masks=False) -> Optional[Image.Image]:
34
  if not image_input:
35
  gr.Info("Please upload an image.")
36
  return None
@@ -58,28 +58,37 @@ def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=
58
  result=result,
59
  resolution_wh=image_input.size
60
  )
61
- detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
62
- if len(detections) == 0:
63
- gr.Info("No objects detected.")
64
- return None
65
  images = []
66
- print("mask generated:", len(detections.mask))
67
- kernel_size = dilate
68
- kernel = np.ones((kernel_size, kernel_size), np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
- for i in range(len(detections.mask)):
71
- mask = detections.mask[i].astype(np.uint8) * 255
72
- if dilate > 0:
73
- mask = cv2.dilate(mask, kernel, iterations=1)
74
- images.append(mask)
75
-
76
- if merge_masks:
77
- final_images = []
78
- merged_mask = np.zeros_like(images[0], dtype=np.uint8)
79
- for mask in images:
80
- merged_mask = cv2.bitwise_or(merged_mask, mask)
81
- final_images = [merged_mask]
82
- return final_images
83
  return images
84
 
85
 
@@ -93,7 +102,7 @@ with gr.Blocks() as demo:
93
  )
94
  dilate = gr.Slider(label="dilate mask", minimum=0, maximum=50, value=10, step=1)
95
  merge_masks = gr.Checkbox(label="Merge masks", value=False)
96
-
97
  text_prompt = gr.Textbox(label='Text prompt', placeholder='Enter text prompts')
98
  submit_button = gr.Button(value='Submit', variant='primary')
99
  with gr.Column():
@@ -101,7 +110,7 @@ with gr.Blocks() as demo:
101
  print(image, image_url, task_prompt, text_prompt, image_gallery)
102
  submit_button.click(
103
  fn = process_image,
104
- inputs = [image, image_url, task_prompt, text_prompt, dilate, merge_masks],
105
  outputs = [image_gallery,],
106
  show_api=False
107
  )
 
30
  @spaces.GPU(duration=20)
31
  @torch.inference_mode()
32
  @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
33
+ def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=0, merge_masks=False, return_rectangles=False) -> Optional[Image.Image]:
34
  if not image_input:
35
  gr.Info("Please upload an image.")
36
  return None
 
58
  result=result,
59
  resolution_wh=image_input.size
60
  )
 
 
 
 
61
  images = []
62
+ if return_rectangles:
63
+ # 创建黑色背景的图片
64
+ mask_image = np.zeros((image_input.size.height, image_input.size.width), dtype=np.uint8)
65
+ bboxes = detections.get('bboxes', [])
66
+ for bbox in bboxes:
67
+ x1, y1, x2, y2 = map(int, bbox)
68
+ # 在 mask_image 上绘制白色的矩形
69
+ cv2.rectangle(mask_image, (x1, y1), (x2, y2), 255, thickness=cv2.FILLED)
70
+ images = [mask_image]
71
+ else:
72
+ # sam
73
+ detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
74
+ if len(detections) == 0:
75
+ gr.Info("No objects detected.")
76
+ return None
77
+ kernel_size = dilate
78
+ print("mask generated:", len(detections.mask))
79
+ kernel = np.ones((kernel_size, kernel_size), np.uint8)
80
+ for i in range(len(detections.mask)):
81
+ mask = detections.mask[i].astype(np.uint8) * 255
82
+ if dilate > 0:
83
+ mask = cv2.dilate(mask, kernel, iterations=1)
84
+ images.append(mask)
85
+ if merge_masks:
86
+
87
+ merged_mask = np.zeros_like(images[0], dtype=np.uint8)
88
+ for mask in images:
89
+ merged_mask = cv2.bitwise_or(merged_mask, mask)
90
+ images = [merged_mask] + images
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  return images
93
 
94
 
 
102
  )
103
  dilate = gr.Slider(label="dilate mask", minimum=0, maximum=50, value=10, step=1)
104
  merge_masks = gr.Checkbox(label="Merge masks", value=False)
105
+ return_rectangles = gr.Checkbox(label="Return rectangle masks", value=False)
106
  text_prompt = gr.Textbox(label='Text prompt', placeholder='Enter text prompts')
107
  submit_button = gr.Button(value='Submit', variant='primary')
108
  with gr.Column():
 
110
  print(image, image_url, task_prompt, text_prompt, image_gallery)
111
  submit_button.click(
112
  fn = process_image,
113
+ inputs = [image, image_url, task_prompt, text_prompt, dilate, merge_masks, return_rectangles],
114
  outputs = [image_gallery,],
115
  show_api=False
116
  )