jackyccl commited on
Commit
16b1d75
1 Parent(s): 99c8a45

Add mask extension function to inpaint mode.

Browse files
Files changed (1) hide show
  1. app.py +38 -30
app.py CHANGED
@@ -250,6 +250,38 @@ def xywh_to_xyxy(box, sizeW, sizeH):
250
  box = box.numpy()
251
  return box
252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
254
  iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend):
255
 
@@ -372,40 +404,16 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
372
  if task_type == 'inpainting':
373
  # inpainting pipeline
374
  image_source_for_inpaint = image_pil.resize((512, 512))
 
 
 
375
  image_mask_for_inpaint = mask_pil.resize((512, 512))
376
  image_inpainting = sd_pipe(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
377
  else:
378
  # remove from mask
379
  if mask_source_radio == mask_source_segment:
380
- mask_imgs = []
381
- masks_shape = masks_ori.shape
382
- boxes_filt_ori_array = boxes_filt_ori.numpy()
383
- if inpaint_mode == 'merge':
384
- extend_shape_0 = masks_shape[0]
385
- extend_shape_1 = masks_shape[1]
386
- else:
387
- extend_shape_0 = 1
388
- extend_shape_1 = 1
389
- for i in range(extend_shape_0):
390
- for j in range(extend_shape_1):
391
- mask = masks_ori[i][j].cpu().numpy()
392
- mask_pil = Image.fromarray(mask)
393
-
394
- if remove_mode == 'segment':
395
- useRectangle = False
396
- else:
397
- useRectangle = True
398
-
399
- try:
400
- remove_mask_extend = int(remove_mask_extend)
401
- except:
402
- remove_mask_extend = 10
403
- mask_pil_exp = mask_extend(copy.deepcopy(mask_pil).convert("RGB"),
404
- # box_convert(torch.tensor(boxes_filt_ori_array[i]), in_fmt="cxcywh", out_fmt="xyxy").numpy(),
405
- xywh_to_xyxy(torch.tensor(boxes_filt_ori_array[i]), size[0], size[1]),
406
- extend_pixels=remove_mask_extend, useRectangle=useRectangle)
407
- mask_imgs.append(mask_pil_exp)
408
- mask_pil = mix_masks(mask_imgs)
409
  output_images.append(mask_pil.convert("RGB"))
410
  image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")))
411
 
@@ -495,7 +503,7 @@ if __name__ == "__main__":
495
  with gr.Column(scale=1):
496
  remove_mode = gr.Radio(["segment", "rectangle"], value="segment", label='remove mode')
497
  with gr.Column(scale=1):
498
- remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')
499
 
500
  with gr.Column():
501
  gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", visible=True
 
250
  box = box.numpy()
251
  return box
252
 
253
+ def to_extend_mask(segment_mask, boxes_filt, size, remove_mask_extend, remove_mode):
254
+ # remove from mask
255
+ mask_imgs = []
256
+ masks_shape = segment_mask.shape
257
+ boxes_filt_ori_array = boxes_filt.numpy()
258
+ if inpaint_mode == 'merge':
259
+ extend_shape_0 = masks_shape[0]
260
+ extend_shape_1 = masks_shape[1]
261
+ else:
262
+ extend_shape_0 = 1
263
+ extend_shape_1 = 1
264
+ for i in range(extend_shape_0):
265
+ for j in range(extend_shape_1):
266
+ mask = segment_mask[i][j].cpu().numpy()
267
+ mask_pil = Image.fromarray(mask)
268
+
269
+ if remove_mode == 'segment':
270
+ useRectangle = False
271
+ else:
272
+ useRectangle = True
273
+
274
+ try:
275
+ remove_mask_extend = int(remove_mask_extend)
276
+ except:
277
+ remove_mask_extend = 10
278
+ mask_pil_exp = mask_extend(copy.deepcopy(mask_pil).convert("RGB"),
279
+ xywh_to_xyxy(torch.tensor(boxes_filt_ori_array[i]), size[0], size[1]),
280
+ extend_pixels=remove_mask_extend, useRectangle=useRectangle)
281
+ mask_imgs.append(mask_pil_exp)
282
+ mask_pil = mix_masks(mask_imgs)
283
+ return mask_pil
284
+
285
  def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
286
  iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend):
287
 
 
404
  if task_type == 'inpainting':
405
  # inpainting pipeline
406
  image_source_for_inpaint = image_pil.resize((512, 512))
407
+ if remove_mask_extend:
408
+ mask_pil = to_extend_mask(masks_ori, boxes_filt_ori, size, remove_mask_extend, remove_mode)
409
+ output_images.append(mask_pil.convert("RGB"))
410
  image_mask_for_inpaint = mask_pil.resize((512, 512))
411
  image_inpainting = sd_pipe(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
412
  else:
413
  # remove from mask
414
  if mask_source_radio == mask_source_segment:
415
+ if remove_mask_extend:
416
+ mask_pil = to_extend_mask(masks_ori, boxes_filt_ori, size, remove_mask_extend, remove_mode)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
417
  output_images.append(mask_pil.convert("RGB"))
418
  image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")))
419
 
 
503
  with gr.Column(scale=1):
504
  remove_mode = gr.Radio(["segment", "rectangle"], value="segment", label='remove mode')
505
  with gr.Column(scale=1):
506
+ remove_mask_extend = gr.Textbox(label="Enlarge Mask (Empty: no mask extension, default: 10)")
507
 
508
  with gr.Column():
509
  gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", visible=True