liuyizhang commited on
Commit
0bfb07f
1 Parent(s): eeaf820

update app.py

Browse files
Files changed (2) hide show
  1. app.py +13 -8
  2. app_cli.py +10 -2
app.py CHANGED
@@ -4,6 +4,7 @@ warnings.filterwarnings('ignore')
4
 
5
  import subprocess, io, os, sys, time
6
  os.system("pip install gradio==3.36.1")
 
7
 
8
  from loguru import logger
9
 
@@ -21,8 +22,6 @@ if os.environ.get('IS_MY_DEBUG') is None:
21
 
22
  sys.path.insert(0, './GroundingDINO')
23
 
24
- import gradio as gr
25
-
26
  import argparse
27
  import copy
28
 
@@ -301,7 +300,7 @@ def load_lama_cleaner_model():
301
  device='cpu', # device,
302
  )
303
 
304
- def lama_cleaner_process(image, mask):
305
  ori_image = image
306
  if mask.shape[0] == image.shape[1] and mask.shape[1] == image.shape[0] and mask.shape[0] != mask.shape[1]:
307
  # rotate image
@@ -311,8 +310,8 @@ def lama_cleaner_process(image, mask):
311
  original_shape = ori_image.shape
312
  interpolation = cv2.INTER_CUBIC
313
 
314
- size_limit = 1080
315
- if size_limit == "Original":
316
  size_limit = max(image.shape)
317
  else:
318
  size_limit = int(size_limit)
@@ -517,7 +516,7 @@ mask_source_draw = "draw a mask on input image"
517
  mask_source_segment = "type what to detect below"
518
 
519
  def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
520
- iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation):
521
  if (task_type == 'relate anything'):
522
  output_images = relate_anything(input_image['image'], num_relation)
523
  return output_images, gr.Gallery.update(label='relate images')
@@ -644,6 +643,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
644
  image_inpainting = sd_pipe(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
645
  else:
646
  # remove from mask
 
647
  if mask_source_radio == mask_source_segment:
648
  mask_imgs = []
649
  masks_shape = masks_ori.shape
@@ -673,11 +673,16 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
673
  extend_pixels=remove_mask_extend, useRectangle=useRectangle)
674
  mask_imgs.append(mask_pil_exp)
675
  mask_pil = mix_masks(mask_imgs)
676
- output_images.append(mask_pil.convert("RGB"))
677
- image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")))
678
 
 
 
 
 
 
679
  image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
680
  output_images.append(image_inpainting)
 
681
  return output_images, gr.Gallery.update(label='result images')
682
  else:
683
  logger.info(f"task_type:{task_type} error!")
4
 
5
  import subprocess, io, os, sys, time
6
  os.system("pip install gradio==3.36.1")
7
+ import gradio as gr
8
 
9
  from loguru import logger
10
 
22
 
23
  sys.path.insert(0, './GroundingDINO')
24
 
 
 
25
  import argparse
26
  import copy
27
 
300
  device='cpu', # device,
301
  )
302
 
303
+ def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
304
  ori_image = image
305
  if mask.shape[0] == image.shape[1] and mask.shape[1] == image.shape[0] and mask.shape[0] != mask.shape[1]:
306
  # rotate image
310
  original_shape = ori_image.shape
311
  interpolation = cv2.INTER_CUBIC
312
 
313
+ size_limit = cleaner_size_limit
314
+ if size_limit == -1:
315
  size_limit = max(image.shape)
316
  else:
317
  size_limit = int(size_limit)
516
  mask_source_segment = "type what to detect below"
517
 
518
  def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
519
+ iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, cleaner_size_limit=1080):
520
  if (task_type == 'relate anything'):
521
  output_images = relate_anything(input_image['image'], num_relation)
522
  return output_images, gr.Gallery.update(label='relate images')
643
  image_inpainting = sd_pipe(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
644
  else:
645
  # remove from mask
646
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_5_')
647
  if mask_source_radio == mask_source_segment:
648
  mask_imgs = []
649
  masks_shape = masks_ori.shape
673
  extend_pixels=remove_mask_extend, useRectangle=useRectangle)
674
  mask_imgs.append(mask_pil_exp)
675
  mask_pil = mix_masks(mask_imgs)
676
+ output_images.append(mask_pil.convert("RGB"))
 
677
 
678
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
679
+ image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")), cleaner_size_limit)
680
+ output_images.append(image_inpainting)
681
+
682
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_7_')
683
  image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
684
  output_images.append(image_inpainting)
685
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
686
  return output_images, gr.Gallery.update(label='result images')
687
  else:
688
  logger.info(f"task_type:{task_type} error!")
app_cli.py CHANGED
@@ -115,7 +115,15 @@ if __name__ == '__main__':
115
  mask_source_radio = "type what to detect below",
116
  remove_mode = "rectangle", # ["segment", "rectangle"]
117
  remove_mask_extend = "10",
118
- num_relation = 5)
 
 
119
  if len(output_images) > 0:
120
- logger.info(f'save result to {args.output_image} ... ')
121
  output_images[-1].save(args.output_image)
 
 
 
 
 
 
115
  mask_source_radio = "type what to detect below",
116
  remove_mode = "rectangle", # ["segment", "rectangle"]
117
  remove_mask_extend = "10",
118
+ num_relation = 5,
119
+ cleaner_size_limit = -1,
120
+ )
121
  if len(output_images) > 0:
122
+ logger.info(f'save result to {args.output_image} ... ')
123
  output_images[-1].save(args.output_image)
124
+ # count = 0
125
+ # for output_image in output_images:
126
+ # count += 1
127
+ # if isinstance(output_image, np.ndarray):
128
+ # output_image = PIL.Image.fromarray(output_image.astype(np.uint8))
129
+ # output_image.save(args.output_image.replace(".", f"_{count}."))