zhiweili committed on
Commit
e2f86ff
·
1 Parent(s): 9d380a5

add t2i-adapter-sketch

Browse files
Files changed (1) hide show
  1. app_haircolor.py +26 -7
app_haircolor.py CHANGED
@@ -17,9 +17,11 @@ from diffusers import (
17
  from controlnet_aux import (
18
  LineartDetector,
19
  CannyDetector,
 
 
20
  )
21
 
22
- BASE_MODEL = "SG161222/RealVisXL_V5.0_Lightning"
23
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
24
 
25
  DEFAULT_EDIT_PROMPT = "a woman, blue hair, high detailed"
@@ -30,8 +32,16 @@ DEFAULT_CATEGORY = "hair"
30
  lineart_detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
31
  lineart_detector = lineart_detector.to(DEVICE)
32
 
 
 
 
33
  canndy_detector = CannyDetector()
34
 
 
 
 
 
 
35
  adapters = MultiAdapter(
36
  [
37
  T2IAdapter.from_pretrained(
@@ -44,6 +54,11 @@ adapters = MultiAdapter(
44
  torch_dtype=torch.float16,
45
  varient="fp16",
46
  ),
 
 
 
 
 
47
  ]
48
  )
49
  adapters = adapters.to(torch.float16)
@@ -61,7 +76,7 @@ basepipeline = basepipeline.to(DEVICE)
61
 
62
  basepipeline.enable_model_cpu_offload()
63
 
64
- @spaces.GPU(duration=30)
65
  def image_to_image(
66
  input_image: Image,
67
  edit_prompt: str,
@@ -71,6 +86,7 @@ def image_to_image(
71
  generate_size: int,
72
  lineart_scale: float = 1.0,
73
  canny_scale: float = 0.5,
 
74
  ):
75
  run_task_time = 0
76
  time_cost_str = ''
@@ -79,9 +95,11 @@ def image_to_image(
79
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
80
  canny_image = canndy_detector(input_image, 384, generate_size)
81
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
 
 
82
 
83
- cond_image = [lineart_image, canny_image]
84
- cond_scale = [lineart_scale, canny_scale]
85
 
86
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
87
  generated_image = basepipeline(
@@ -127,8 +145,9 @@ def create_demo() -> gr.Blocks:
127
  mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
128
  with gr.Column():
129
  mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
130
- lineart_scale = gr.Slider(minimum=0, maximum=2, value=0.3, step=0.1, label="Lineart Scale")
131
- canny_scale = gr.Slider(minimum=0, maximum=2, value=0.7, step=0.1, label="Canny Scale")
 
132
  g_btn = gr.Button("Edit Image")
133
 
134
  with gr.Row():
@@ -147,7 +166,7 @@ def create_demo() -> gr.Blocks:
147
  outputs=[origin_area_image, croper],
148
  ).success(
149
  fn=image_to_image,
150
- inputs=[origin_area_image, edit_prompt,seed, num_steps, guidance_scale, generate_size, lineart_scale, canny_scale],
151
  outputs=[generated_image, generated_cost],
152
  ).success(
153
  fn=restore_result,
 
17
  from controlnet_aux import (
18
  LineartDetector,
19
  CannyDetector,
20
+ PidiNetDetector,
21
+ MidasDetector,
22
  )
23
 
24
+ BASE_MODEL = "stabilityai/sdxl-turbo"
25
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
 
27
  DEFAULT_EDIT_PROMPT = "a woman, blue hair, high detailed"
 
32
  lineart_detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
33
  lineart_detector = lineart_detector.to(DEVICE)
34
 
35
+ pidinet_detector = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
36
+ pidinet_detector = pidinet_detector.to(DEVICE)
37
+
38
  canndy_detector = CannyDetector()
39
 
40
+ midas_detector = MidasDetector.from_pretrained(
41
+ "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
42
+ )
43
+ midas_detector = midas_detector.to(DEVICE)
44
+
45
  adapters = MultiAdapter(
46
  [
47
  T2IAdapter.from_pretrained(
 
54
  torch_dtype=torch.float16,
55
  varient="fp16",
56
  ),
57
+ T2IAdapter.from_pretrained(
58
+ "TencentARC/t2i-adapter-sketch-sdxl-1.0",
59
+ torch_dtype=torch.float16,
60
+ varient="fp16",
61
+ ),
62
  ]
63
  )
64
  adapters = adapters.to(torch.float16)
 
76
 
77
  basepipeline.enable_model_cpu_offload()
78
 
79
+ @spaces.GPU(duration=15)
80
  def image_to_image(
81
  input_image: Image,
82
  edit_prompt: str,
 
86
  generate_size: int,
87
  lineart_scale: float = 1.0,
88
  canny_scale: float = 0.5,
89
+ sketch_scale:float = 0.5,
90
  ):
91
  run_task_time = 0
92
  time_cost_str = ''
 
95
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
96
  canny_image = canndy_detector(input_image, 384, generate_size)
97
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
98
+ sketch_image = pidinet_detector(input_image, 512, generate_size)
99
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
100
 
101
+ cond_image = [lineart_image, canny_image, sketch_image]
102
+ cond_scale = [lineart_scale, canny_scale, sketch_scale]
103
 
104
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
105
  generated_image = basepipeline(
 
145
  mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
146
  with gr.Column():
147
  mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
148
+ lineart_scale = gr.Slider(minimum=0, maximum=5, value=1, step=0.1, label="Lineart Scale")
149
+ canny_scale = gr.Slider(minimum=0, maximum=5, value=0.7, step=0.1, label="Canny Scale")
150
+ sketch_scale = gr.Slider(minimum=0, maximum=5, value=1, step=0.1, label="Sketch Scale")
151
  g_btn = gr.Button("Edit Image")
152
 
153
  with gr.Row():
 
166
  outputs=[origin_area_image, croper],
167
  ).success(
168
  fn=image_to_image,
169
+ inputs=[origin_area_image, edit_prompt,seed, num_steps, guidance_scale, generate_size, lineart_scale, canny_scale, sketch_scale],
170
  outputs=[generated_image, generated_cost],
171
  ).success(
172
  fn=restore_result,