zhiweili committed
Commit 9ae974c · Parent(s): 3db6256

p2p add canny

Files changed (1):
  1. app_haircolor_img2img.py +18 -6
app_haircolor_img2img.py CHANGED
@@ -44,10 +44,16 @@ pidiNet_detector = pidiNet_detector.to(DEVICE)
 hed_detector = HEDdetector.from_pretrained('lllyasviel/Annotators')
 hed_detector = hed_detector.to(DEVICE)
 
-controlnet = ControlNetModel.from_pretrained(
-    "lllyasviel/control_v11e_sd15_ip2p",
-    torch_dtype=torch.float16,
-)
+controlnet = [
+    ControlNetModel.from_pretrained(
+        "lllyasviel/control_v11e_sd15_ip2p",
+        torch_dtype=torch.float16,
+    ),
+    ControlNetModel.from_pretrained(
+        "lllyasviel/control_v11p_sd15_canny",
+        torch_dtype=torch.float16,
+    ),
+]
 
 basepipeline = StableDiffusionControlNetPipeline.from_pretrained(
     BASE_MODEL,
@@ -70,12 +76,15 @@ def image_to_image(
     num_steps: int,
     guidance_scale: float,
     generate_size: int,
+    cond_scale1: float = 1.2,
+    cond_scale2: float = 1.2,
 ):
     run_task_time = 0
     time_cost_str = ''
     run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+    canny_image = canny_detector(input_image)
 
-    cond_image = input_image
+    cond_image = [input_image, canny_image]
 
     generator = torch.Generator(device=DEVICE).manual_seed(seed)
     generated_image = basepipeline(
@@ -87,6 +96,7 @@ def image_to_image(
         width=generate_size,
         guidance_scale=guidance_scale,
         num_inference_steps=num_steps,
+        controlnet_conditioning_scale=[cond_scale1, cond_scale2],
     ).images[0]
 
     run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
@@ -130,6 +140,8 @@ def create_demo() -> gr.Blocks:
             mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
             seed = gr.Number(label="Seed", value=8)
             category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
+            cond_scale1 = gr.Slider(minimum=0, maximum=3, value=1.2, step=0.1, label="Cond_scale1")
+            cond_scale2 = gr.Slider(minimum=0, maximum=3, value=1.2, step=0.1, label="Cond_scale2")
             g_btn = gr.Button("Edit Image")
 
         with gr.Row():
@@ -148,7 +160,7 @@
             outputs=[origin_area_image, croper],
         ).success(
             fn=image_to_image,
-            inputs=[origin_area_image, edit_prompt, seed, num_steps, guidance_scale, generate_size],
+            inputs=[origin_area_image, edit_prompt, seed, num_steps, guidance_scale, generate_size, cond_scale1, cond_scale2],
             outputs=[generated_image, generated_cost],
         ).success(
             fn=restore_result,
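Taken together, the commit loads two ControlNets (instruct-pix2pix plus canny), feeds each its own conditioning image, and exposes a per-net conditioning scale in the UI. Below is a minimal, self-contained sketch of that setup, not the repo's exact file: the BASE_MODEL checkpoint is a stand-in, and the names mirror the diff while omitting the timing and Gradio plumbing.

# Sketch of the two-ControlNet pipeline this commit builds.
# Assumption: BASE_MODEL is any SD 1.5-compatible checkpoint; the real
# app defines its own BASE_MODEL plus extra detectors and timing helpers.
import torch
from controlnet_aux import CannyDetector
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BASE_MODEL = "runwayml/stable-diffusion-v1-5"  # stand-in, not from the diff

# One ControlNet per conditioning signal, in a fixed order: [ip2p, canny].
controlnet = [
    ControlNetModel.from_pretrained(
        "lllyasviel/control_v11e_sd15_ip2p", torch_dtype=torch.float16
    ),
    ControlNetModel.from_pretrained(
        "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16
    ),
]

basepipeline = StableDiffusionControlNetPipeline.from_pretrained(
    BASE_MODEL, controlnet=controlnet, torch_dtype=torch.float16
).to(DEVICE)

canny_detector = CannyDetector()  # controlnet_aux's canny edge preprocessor

def image_to_image(input_image, edit_prompt, seed, num_steps, guidance_scale,
                   generate_size, cond_scale1=1.2, cond_scale2=1.2):
    # The canny map is derived from the input image, so the two
    # conditioning images line up with the two ControlNets above.
    canny_image = canny_detector(input_image)
    cond_image = [input_image, canny_image]
    generator = torch.Generator(device=DEVICE).manual_seed(seed)
    return basepipeline(
        edit_prompt,
        image=cond_image,
        height=generate_size,
        width=generate_size,
        guidance_scale=guidance_scale,
        num_inference_steps=num_steps,
        generator=generator,
        # One scale per ControlNet, matching list order: [ip2p, canny].
        controlnet_conditioning_scale=[cond_scale1, cond_scale2],
    ).images[0]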
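Order matters throughout: with multiple ControlNets, diffusers pairs image[i] and controlnet_conditioning_scale[i] with controlnet[i], so the two new Cond_scale sliders map one-to-one onto the ip2p and canny nets, and Gradio passes them positionally, which is why the last hunk appends cond_scale1 and cond_scale2 at the end of the inputs list to match the new trailing parameters of image_to_image.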