jackyccl commited on
Commit
e98c1f9
1 Parent(s): 300b427

Add remove

Browse files
Files changed (1) hide show
  1. app.py +104 -6
app.py CHANGED
@@ -11,6 +11,8 @@ import subprocess
11
  import copy
12
  import time
13
  import warnings
 
 
14
 
15
  import torch
16
  from torchvision.ops import box_convert
@@ -26,13 +28,18 @@ import groundingdino.datasets.transforms as T
26
  # segment anything
27
  from segment_anything import build_sam, SamPredictor
28
 
 
 
 
 
 
29
  #stable diffusion
30
  from diffusers import StableDiffusionInpaintPipeline
31
 
32
  from huggingface_hub import hf_hub_download
33
 
34
- if not os.path.exists('./demo1.jpg'):
35
- os.system("wget https://github.com/IDEA-Research/Grounded-Segment-Anything/raw/main/assets/demo1.jpg")
36
 
37
  if not os.path.exists('./sam_vit_h_4b8939.pth'):
38
  logger.info(f"get sam_vit_h_4b8939.pth...")
@@ -177,6 +184,63 @@ def mix_masks(imgs):
177
  re_img = 1 - re_img
178
  return Image.fromarray(np.uint8(255*re_img))
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
181
  iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend):
182
 
@@ -199,6 +263,8 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
199
  # load image
200
  image_pil, image_tensor = load_image_and_transform(input_image['image'])
201
 
 
 
202
  # RUN GROUNDINGDINO: we skip DINO if we draw mask on the image
203
  if (task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw:
204
  pass
@@ -218,7 +284,6 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
218
  }
219
 
220
  # store and save DINO output
221
- output_images = []
222
  image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
223
  image_path = os.path.join(output_dir, f"grounding_dino_output_{file_temp}.jpg")
224
  image_with_box.save(image_path)
@@ -300,7 +365,39 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
300
  image_source_for_inpaint = image_pil.resize((512, 512))
301
  image_mask_for_inpaint = mask_pil.resize((512, 512))
302
  image_inpainting = sd_pipe(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
303
- # else: add remove option here!!
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
 
305
  image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
306
  output_images.append(image_inpainting)
@@ -330,6 +427,7 @@ def change_radio_display(task_type, mask_source_radio):
330
  # model initialization
331
  groundingDino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filename, groundingdino_device)
332
  sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
 
333
 
334
  # initialize stable-diffusion-inpainting
335
  logger.info(f"initialize stable-diffusion-inpainting...")
@@ -359,7 +457,7 @@ if __name__ == "__main__":
359
  with gr.Row():
360
  with gr.Column():
361
  input_image = gr.Image(
362
- source="upload", elem_id="image_upload", type="pil", tool="sketch", value="demo1.jpg", label="Upload")
363
  task_type = gr.Radio(["segment", "inpainting", "remove"], value="segment",
364
  label='Task type', visible=True)
365
 
@@ -368,7 +466,7 @@ if __name__ == "__main__":
368
  visible=False)
369
 
370
  text_prompt = gr.Textbox(label="Detection Prompt, seperating each name with dot '.', i.e.: bear.cat.dog.chair ]", \
371
- value='bear', placeholder="Cannot be empty")
372
  inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
373
 
374
  run_button = gr.Button(label="Run")
 
11
  import copy
12
  import time
13
  import warnings
14
+ import io
15
+ import random
16
 
17
  import torch
18
  from torchvision.ops import box_convert
 
28
  # segment anything
29
  from segment_anything import build_sam, SamPredictor
30
 
31
+ # lama-cleaner
32
+ from lama_cleaner.model_manager import ModelManager
33
+ from lama_cleaner.schema import Config as lama_Config
34
+ from lama_cleaner.helper import load_img, numpy_to_bytes, resize_max_size
35
+
36
  #stable diffusion
37
  from diffusers import StableDiffusionInpaintPipeline
38
 
39
  from huggingface_hub import hf_hub_download
40
 
41
+ if not os.path.exists('./demo2.jpg'):
42
+ os.system("wget https://github.com/IDEA-Research/Grounded-Segment-Anything/raw/main/assets/demo2.jpg")
43
 
44
  if not os.path.exists('./sam_vit_h_4b8939.pth'):
45
  logger.info(f"get sam_vit_h_4b8939.pth...")
 
184
  re_img = 1 - re_img
185
  return Image.fromarray(np.uint8(255*re_img))
186
 
187
+ def lama_cleaner_process(image, mask):
188
+ ori_image = image
189
+ if mask.shape[0] == image.shape[1] and mask.shape[1] == image.shape[0] and mask.shape[0] != mask.shape[1]:
190
+ # rotate image
191
+ ori_image = np.transpose(image[::-1, ...][:, ::-1], axes=(1, 0, 2))[::-1, ...]
192
+ image = ori_image
193
+
194
+ original_shape = ori_image.shape
195
+ interpolation = cv2.INTER_CUBIC
196
+
197
+ size_limit = 1080
198
+ if size_limit == "Original":
199
+ size_limit = max(image.shape)
200
+ else:
201
+ size_limit = int(size_limit)
202
+
203
+ config = lama_Config(
204
+ ldm_steps=25,
205
+ ldm_sampler='plms',
206
+ zits_wireframe=True,
207
+ hd_strategy='Original',
208
+ hd_strategy_crop_margin=196,
209
+ hd_strategy_crop_trigger_size=1280,
210
+ hd_strategy_resize_limit=2048,
211
+ prompt='',
212
+ use_croper=False,
213
+ croper_x=0,
214
+ croper_y=0,
215
+ croper_height=512,
216
+ croper_width=512,
217
+ sd_mask_blur=5,
218
+ sd_strength=0.75,
219
+ sd_steps=50,
220
+ sd_guidance_scale=7.5,
221
+ sd_sampler='ddim',
222
+ sd_seed=42,
223
+ cv2_flag='INPAINT_NS',
224
+ cv2_radius=5,
225
+ )
226
+
227
+ if config.sd_seed == -1:
228
+ config.sd_seed = random.randint(1, 999999999)
229
+
230
+ # logger.info(f"Origin image shape_0_: {original_shape} / {size_limit}")
231
+ image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
232
+ # logger.info(f"Resized image shape_1_: {image.shape}")
233
+
234
+ # logger.info(f"mask image shape_0_: {mask.shape} / {type(mask)}")
235
+ mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
236
+ # logger.info(f"mask image shape_1_: {mask.shape} / {type(mask)}")
237
+
238
+ res_np_img = lama_cleaner_model(image, mask, config)
239
+ torch.cuda.empty_cache()
240
+
241
+ image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png')))
242
+ return image
243
+
244
  def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
245
  iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend):
246
 
 
263
  # load image
264
  image_pil, image_tensor = load_image_and_transform(input_image['image'])
265
 
266
+ output_images = []
267
+ output_images.append(input_image['image'])
268
  # RUN GROUNDINGDINO: we skip DINO if we draw mask on the image
269
  if (task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw:
270
  pass
 
284
  }
285
 
286
  # store and save DINO output
 
287
  image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
288
  image_path = os.path.join(output_dir, f"grounding_dino_output_{file_temp}.jpg")
289
  image_with_box.save(image_path)
 
365
  image_source_for_inpaint = image_pil.resize((512, 512))
366
  image_mask_for_inpaint = mask_pil.resize((512, 512))
367
  image_inpainting = sd_pipe(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
368
+ else:
369
+ # remove from mask
370
+ if mask_source_radio == mask_source_segment:
371
+ mask_imgs = []
372
+ masks_shape = masks_ori.shape
373
+ boxes_filt_ori_array = boxes_filt_ori.numpy()
374
+ if inpaint_mode == 'merge':
375
+ extend_shape_0 = masks_shape[0]
376
+ extend_shape_1 = masks_shape[1]
377
+ else:
378
+ extend_shape_0 = 1
379
+ extend_shape_1 = 1
380
+ for i in range(extend_shape_0):
381
+ for j in range(extend_shape_1):
382
+ mask = masks_ori[i][j].cpu().numpy()
383
+ mask_pil = Image.fromarray(mask)
384
+
385
+ if remove_mode == 'segment':
386
+ useRectangle = False
387
+ else:
388
+ useRectangle = True
389
+
390
+ try:
391
+ remove_mask_extend = int(remove_mask_extend)
392
+ except:
393
+ remove_mask_extend = 10
394
+ mask_pil_exp = mask_extend(copy.deepcopy(mask_pil).convert("RGB"),
395
+ box_convert(torch.tensor(boxes_filt_ori_array[i]), in_fmt="cxcywh", out_fmt="xyxy").numpy(),
396
+ extend_pixels=remove_mask_extend, useRectangle=useRectangle)
397
+ mask_imgs.append(mask_pil_exp)
398
+ mask_pil = mix_masks(mask_imgs)
399
+ output_images.append(mask_pil.convert("RGB"))
400
+ image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")))
401
 
402
  image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
403
  output_images.append(image_inpainting)
 
427
  # model initialization
428
  groundingDino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filename, groundingdino_device)
429
  sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
430
+ lama_cleaner_model = ModelManager(name='lama',device='cpu')
431
 
432
  # initialize stable-diffusion-inpainting
433
  logger.info(f"initialize stable-diffusion-inpainting...")
 
457
  with gr.Row():
458
  with gr.Column():
459
  input_image = gr.Image(
460
+ source="upload", elem_id="image_upload", type="pil", tool="sketch", value="demo2.jpg", label="Upload")
461
  task_type = gr.Radio(["segment", "inpainting", "remove"], value="segment",
462
  label='Task type', visible=True)
463
 
 
466
  visible=False)
467
 
468
  text_prompt = gr.Textbox(label="Detection Prompt, seperating each name with dot '.', i.e.: bear.cat.dog.chair ]", \
469
+ value='dog', placeholder="Cannot be empty")
470
  inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
471
 
472
  run_button = gr.Button(label="Run")