dragonSwing committed
Commit 2396fdf
1 parent: 53c5524

Enable mask smoothing and mask expansion

Files changed (1):
app.py  +58 -15
app.py CHANGED
@@ -1,9 +1,10 @@
+import functools
 import json
 import os
-import subprocess
 import sys
 import tempfile
 
+import cv2
 import gradio as gr
 import numpy as np
 import supervision as sv
@@ -86,7 +87,16 @@ grounding_dino_model = DinoModel(
 )
 
 
-def process(image_path, task, prompt, box_threshold, text_threshold, iou_threshold):
+def process(
+    image_path,
+    task,
+    prompt,
+    box_threshold,
+    text_threshold,
+    iou_threshold,
+    kernel_size,
+    expand_mask,
+):
     global tag2text_model, sam_predictor, sam_automask_generator, grounding_dino_model, device
     output_gallery = []
     detections = None
@@ -97,6 +107,7 @@ def process(image_path, task, prompt, box_threshold, text_threshold, iou_thresho
     image = Image.open(image_path)
     image_pil = image.convert("RGB")
     image = np.array(image_pil)
+    orig_image = image.copy()
 
     # Extract image metadata
     filename = os.path.basename(image_path)
@@ -106,7 +117,7 @@ def process(image_path, task, prompt, box_threshold, text_threshold, iou_thresho
     metadata["image"]["height"] = h
 
     # Generate tags
-    if task in ["auto", "detect"] and prompt == "":
+    if task in ["auto", "detection"] and prompt == "":
         tags, caption = generate_tags(tag2text_model, image_pil, "None", device)
         prompt = " . ".join(tags)
         print(f"Caption: {caption}")
@@ -146,20 +157,38 @@ def process(image_path, task, prompt, box_threshold, text_threshold, iou_thresho
 
     # Segmentation
     if task in ["auto", "segment"]:
+        kernel = cv2.getStructuringElement(
+            cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1)
+        )
         if detections:
             masks, scores = segment(
-                sam_predictor, image=image, boxes=detections.xyxy
+                sam_predictor, image=orig_image, boxes=detections.xyxy
             )
+            if expand_mask:
+                masks = [
+                    cv2.dilate(mask.astype(np.uint8), kernel) for mask in masks
+                ]
+            else:
+                masks = [
+                    cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
+                    for mask in masks
+                ]
             detections.mask = masks
+            binary_mask = functools.reduce(
+                lambda x, y: x + y, detections.mask
+            ).astype(np.bool)
         else:
-            masks = sam_automask_generator.generate(image)
+            masks = sam_automask_generator.generate(orig_image)
            sorted_generated_masks = sorted(
                 masks, key=lambda x: x["area"], reverse=True
             )
 
             xywh = np.array([mask["bbox"] for mask in sorted_generated_masks])
             mask = np.array(
-                [mask["segmentation"] for mask in sorted_generated_masks]
+                [
+                    cv2.dilate(mask["segmentation"].astype(np.uint8), kernel)
+                    for mask in sorted_generated_masks
+                ]
             )
             scores = np.array(
                 [mask["predicted_iou"] for mask in sorted_generated_masks]
@@ -167,9 +196,7 @@ def process(image_path, task, prompt, box_threshold, text_threshold, iou_thresho
             detections = sv.Detections(
                 xyxy=xywh_to_xyxy(boxes_xywh=xywh), mask=mask
             )
-        # opacity = 0.4
-        # mask_image, _ = show_anns_sam(masks)
-        # annotated_image = np.uint8(mask_image * opacity + image * (1 - opacity))
+            binary_mask = None
 
         mask_annotator = sv.MaskAnnotator()
         mask_image = np.zeros_like(image, dtype=np.uint8)
@@ -177,7 +204,13 @@ def process(image_path, task, prompt, box_threshold, text_threshold, iou_thresho
             mask_image, detections=detections, opacity=1
         )
         annotated_image = mask_annotator.annotate(image, detections=detections)
+
         output_gallery.append(mask_image)
+        if binary_mask is not None:
+            binary_mask_image = binary_mask * 255
+            cutout_image = np.expand_dims(binary_mask, axis=-1) * orig_image
+            output_gallery.append(binary_mask_image)
+            output_gallery.append(cutout_image)
         output_gallery.append(annotated_image)
 
         # ToDo: Extract metadata
@@ -203,7 +236,7 @@ def process(image_path, task, prompt, box_threshold, text_threshold, iou_thresho
 
     meta_file = tempfile.NamedTemporaryFile(delete=False, suffix=".json")
     meta_file_path = meta_file.name
-    with open(meta_file_path, "w") as fp:
+    with open(meta_file_path, "w", encoding="utf-8") as fp:
         json.dump(metadata, fp)
 
     return output_gallery, meta_file_path
@@ -231,7 +264,6 @@ with gr.Blocks(css="style.css", title=title) as demo:
                 value=0.3,
                 step=0.05,
                 label="Box threshold",
-                info="Hash size to use for image hashing",
             )
             text_threshold = gr.Slider(
                 minimum=0,
@@ -239,7 +271,6 @@ with gr.Blocks(css="style.css", title=title) as demo:
                 value=0.25,
                 step=0.05,
                 label="Text threshold",
-                info="Number of history images used to find out duplicate image",
             )
             iou_threshold = gr.Slider(
                 minimum=0,
@@ -247,7 +278,18 @@ with gr.Blocks(css="style.css", title=title) as demo:
                 value=0.5,
                 step=0.05,
                 label="IOU threshold",
-                info="Minimum similarity threshold (in percent) to consider 2 images to be similar",
+                info="Intersection over Union threshold",
+            )
+            kernel_size = gr.Slider(
+                minimum=1,
+                maximum=5,
+                value=2,
+                step=1,
+                label="Kernel size",
+                info="Use to smooth segment masks",
+            )
+            expand_mask = gr.Checkbox(
+                label="Expand mask",
             )
             run_button = gr.Button(label="Run")
 
@@ -256,12 +298,11 @@ with gr.Blocks(css="style.css", title=title) as demo:
                 label="Generated images", show_label=False, elem_id="gallery"
             ).style(preview=True, grid=2, object_fit="scale-down")
             meta_file = gr.File(label="Metadata file")
-
     with gr.Row(elem_classes=["container"]):
         gr.Examples(
             [
                 ["examples/dog.png", "auto", ""],
-                ["examples/eiffel.png", "auto", ""],
+                ["examples/eiffel.jpg", "auto", "tower . lake . grass . sky"],
                 ["examples/eiffel.png", "segment", ""],
                 ["examples/girl.png", "auto", "girl . face"],
                 ["examples/horse.png", "detect", "horse"],
@@ -279,6 +320,8 @@ with gr.Blocks(css="style.css", title=title) as demo:
             box_threshold,
             text_threshold,
            iou_threshold,
+            kernel_size,
+            expand_mask,
         ],
         outputs=[gallery, meta_file],
     )
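
The core of this change is the morphology step added to the segmentation branch of process(): an elliptical kernel of side 2 * kernel_size + 1 is applied to every SAM mask, using cv2.dilate when "Expand mask" is checked and cv2.morphologyEx with cv2.MORPH_CLOSE otherwise. Below is a minimal, self-contained sketch of that step on a synthetic boolean mask; the helper name smooth_or_expand and the toy mask are illustrative only and do not appear in app.py.

import cv2
import numpy as np


def smooth_or_expand(mask, kernel_size, expand):
    # Elliptical structuring element of side 2 * kernel_size + 1, as in the commit.
    kernel = cv2.getStructuringElement(
        cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1)
    )
    mask_u8 = mask.astype(np.uint8)  # OpenCV morphology expects uint8, not bool
    if expand:
        # "Expand mask": grow the region outward by roughly kernel_size pixels.
        return cv2.dilate(mask_u8, kernel)
    # Default: morphological closing fills small holes and smooths jagged edges
    # without noticeably enlarging the mask.
    return cv2.morphologyEx(mask_u8, cv2.MORPH_CLOSE, kernel)


# Toy SAM-style boolean mask: a filled square with a one-pixel hole.
mask = np.zeros((64, 64), dtype=bool)
mask[16:48, 16:48] = True
mask[30, 30] = False

smoothed = smooth_or_expand(mask, kernel_size=2, expand=False)
expanded = smooth_or_expand(mask, kernel_size=2, expand=True)
# Closing fills the hole; dilation grows the covered area.
print(int(mask.sum()), int(smoothed.sum()), int(expanded.sum()))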
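
The grounded branch also gains two extra gallery entries built from the union of the per-object masks: a black-and-white mask image and a cutout of the original image. The sketch below mirrors that logic with the diff's variable names on stand-in data; it uses the builtin bool rather than the np.bool alias seen in app.py, since that alias is deprecated and removed in recent NumPy releases.

import functools

import numpy as np

h, w = 4, 6
orig_image = np.random.randint(0, 255, size=(h, w, 3), dtype=np.uint8)  # stand-in for the RGB input
masks = [np.zeros((h, w), dtype=np.uint8) for _ in range(2)]  # stand-ins for per-object SAM masks
masks[0][:2, :3] = 1
masks[1][2:, 3:] = 1

# Union of all object masks: any pixel covered by at least one mask is foreground.
binary_mask = functools.reduce(lambda x, y: x + y, masks).astype(bool)

binary_mask_image = binary_mask * 255  # white foreground on black background
cutout_image = np.expand_dims(binary_mask, axis=-1) * orig_image  # background pixels zeroed out

print(binary_mask_image.shape, cutout_image.shape)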