yizhangliu committed on
Commit
47dfe4c
1 Parent(s): 34fde06

update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -7
app.py CHANGED
@@ -45,7 +45,7 @@ plt = matplotlib.pyplot
45
 
46
  groundingdino_enable = True
47
  sam_enable = True
48
- inpainting_enable = False #True
49
  ram_enable = False
50
 
51
  lama_cleaner_enable = True
@@ -79,11 +79,13 @@ from io import BytesIO
79
  from diffusers import StableDiffusionInpaintPipeline
80
  from huggingface_hub import hf_hub_download
81
 
82
- from huggingface_hub import snapshot_download
83
- from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_inpainting import StableDiffusionXLInpaintPipeline
84
- from kolors.models.modeling_chatglm import ChatGLMModel
85
- from kolors.models.tokenization_chatglm import ChatGLMTokenizer
86
- from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
 
 
87
 
88
  from util_computer import computer_info
89
 
@@ -329,6 +331,7 @@ def load_sd_model(device):
329
  global sd_model
330
  logger.info(f"initialize stable-diffusion-inpainting...")
331
  sd_model = None
 
332
  if os.environ.get('IS_MY_DEBUG') is None:
333
  # sd_model = StableDiffusionInpaintPipeline.from_pretrained(
334
  # "runwayml/stable-diffusion-inpainting",
@@ -355,6 +358,7 @@ def load_sd_model(device):
355
 
356
  sd_model.to(device)
357
  sd_model.enable_attention_slicing()
 
358
 
359
  def load_lama_cleaner_model(device):
360
  # initialize lama_cleaner
@@ -613,6 +617,29 @@ def get_time_cost(run_task_time, time_cost_str):
613
  run_task_time = now_time
614
  return run_task_time, time_cost_str
615
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
616
  def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
617
  iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input, cleaner_size_limit=1080):
618
 
@@ -624,6 +651,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
624
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
625
 
626
  # logger.info(f"input_image==={input_image}")
 
627
  if 'background' in input_image.keys():
628
  input_image['image'] = input_image['background'].convert("RGB")
629
  if len(input_image['layers']) > 0:
@@ -794,7 +822,9 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
794
  image_mask_for_inpaint = Image.fromarray(255*img_arr.astype('uint8'))
795
  output_images.append(image_mask_for_inpaint.convert("RGB"))
796
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
797
- image_inpainting = sd_model(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
 
 
798
  else:
799
  # remove from mask
800
  if mask_source_radio == mask_source_segment:
 
45
 
46
  groundingdino_enable = True
47
  sam_enable = True
48
+ inpainting_enable = True
49
  ram_enable = False
50
 
51
  lama_cleaner_enable = True
 
79
  from diffusers import StableDiffusionInpaintPipeline
80
  from huggingface_hub import hf_hub_download
81
 
82
+ from gradio_client import Client, handle_file
83
+
84
+ # from huggingface_hub import snapshot_download
85
+ # from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_inpainting import StableDiffusionXLInpaintPipeline
86
+ # from kolors.models.modeling_chatglm import ChatGLMModel
87
+ # from kolors.models.tokenization_chatglm import ChatGLMTokenizer
88
+ # from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
89
 
90
  from util_computer import computer_info
91
 
 
331
  global sd_model
332
  logger.info(f"initialize stable-diffusion-inpainting...")
333
  sd_model = None
334
+ '''
335
  if os.environ.get('IS_MY_DEBUG') is None:
336
  # sd_model = StableDiffusionInpaintPipeline.from_pretrained(
337
  # "runwayml/stable-diffusion-inpainting",
 
358
 
359
  sd_model.to(device)
360
  sd_model.enable_attention_slicing()
361
+ '''
362
 
363
  def load_lama_cleaner_model(device):
364
  # initialize lama_cleaner
 
617
  run_task_time = now_time
618
  return run_task_time, time_cost_str
619
 
620
def load_kolors_inpainting(inpaint_prompt, image, mask_image):
    """Inpaint ``image`` under ``mask_image`` using the remote Kolors Space.

    Sends one request to the ``/infer`` endpoint of the
    "Kwai-Kolors/Kolors-Inpainting" Gradio Space, opens the returned result
    with PIL and returns the decoded picture.

    Replaces the former local call:
    ``sd_model(prompt=..., image=..., mask_image=...).images[0]``.

    Args:
        inpaint_prompt: Text forwarded as the ``prompt`` argument.
        image: Source picture forwarded as the ``image`` argument.
        mask_image: Mask forwarded as the ``mask_image`` argument.

    Returns:
        A PIL image. An RGBA result is flattened onto a white RGB
        background; any other mode is returned as decoded.
    """
    # NOTE(review): `image` / `mask_image` are passed through unchanged.
    # Confirm the endpoint accepts them in this form (gradio_client file
    # inputs are normally wrapped with `handle_file`).
    client = Client("Kwai-Kolors/Kolors-Inpainting")
    result = client.predict(
        prompt=inpaint_prompt,
        image=image,
        mask_image=mask_image,
        negative_prompt="broken fingers, deformed fingers, deformed hands, stumps, blurriness, low quality",
        seed=0,
        randomize_seed=True,
        guidance_scale=6,
        num_inference_steps=25,
        api_name="/infer"
    )
    logger.info(f'load_kolors_inpainting_result={result}')

    # Assumes `result` is something Image.open() accepts (e.g. a single file
    # path) -- TODO confirm against the Space's API page.
    with Image.open(result) as im:
        im.load()  # pixel data must be loaded before split()
        if im.mode == "RGBA":
            # Composite onto white using the alpha channel as the paste mask.
            flattened = Image.new("RGB", im.size, (255, 255, 255))
            flattened.paste(im, mask=im.split()[3])
            return flattened
        # Detach from the file handle that the `with` block closes.
        return im.copy()
642
+
643
  def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
644
  iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input, cleaner_size_limit=1080):
645
 
 
651
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
652
 
653
  # logger.info(f"input_image==={input_image}")
654
+ ori_input_image = input_image
655
  if 'background' in input_image.keys():
656
  input_image['image'] = input_image['background'].convert("RGB")
657
  if len(input_image['layers']) > 0:
 
822
  image_mask_for_inpaint = Image.fromarray(255*img_arr.astype('uint8'))
823
  output_images.append(image_mask_for_inpaint.convert("RGB"))
824
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
825
+
826
+ # image_inpainting = sd_model(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
827
+ image_inpainting = load_kolors_inpainting(ori_input_image, image_source_for_inpaint, image_mask_for_inpaint).images[0])
828
  else:
829
  # remove from mask
830
  if mask_source_radio == mask_source_segment: