ResearcherXman committed
Commit ec7fc1c
1 Parent(s): f4fab1d

support lcm and multi-controlnets

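In code terms, the commit adds two things to app.py: an LCM fast-inference toggle (LCM-LoRA plus an LCMScheduler swap) and the ability to stack extra ControlNets (pose/canny/depth) on top of the InstantID IdentityNet. A minimal sketch condensed from the diff below (`set_lcm` and `select_controlnets` are illustrative names, not functions in the repo):

```python
import diffusers
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel

def set_lcm(pipe, enable: bool):
    # LCM-LoRA ("latent-consistency/lcm-lora-sdxl") is loaded once at startup;
    # the toggle only swaps the scheduler and switches the LoRA on or off.
    if enable:
        pipe.scheduler = diffusers.LCMScheduler.from_config(pipe.scheduler.config)
        pipe.enable_lora()
    else:
        pipe.disable_lora()

def select_controlnets(pipe, identitynet, controlnet_map, selection, scales):
    # The IdentityNet always comes first; user-selected nets are appended,
    # and the conditioning scales are returned in the same order.
    pipe.controlnet = MultiControlNetModel(
        [identitynet] + [controlnet_map[s] for s in selection]
    )
    return [scales["identity"]] + [scales[s] for s in selection]
```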
app.py CHANGED
@@ -1,23 +1,34 @@
1
- import math
2
- import random
3
-
4
  import cv2
5
- import gradio as gr
 
6
  import numpy as np
 
7
  import PIL
8
- import spaces
9
- import torch
10
- from diffusers.models import ControlNetModel
11
  from diffusers.utils import load_image
12
  from insightface.app import FaceAnalysis
13
- from PIL import Image
14
 
15
- from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline
16
  from style_template import styles
17
 
18
  # global variable
19
  MAX_SEED = np.iinfo(np.int32).max
20
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
21
  STYLE_NAMES = list(styles.keys())
22
  DEFAULT_STYLE_NAME = "Watercolor"
23
 
@@ -33,69 +44,120 @@ hf_hub_download(
33
  hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")
34
 
35
  # Load face encoder
36
- app = FaceAnalysis(name="antelopev2", root="./", providers=["CPUExecutionProvider"])
37
  app.prepare(ctx_id=0, det_size=(640, 640))
38
 
39
  # Path to InstantID models
40
- face_adapter = "./checkpoints/ip-adapter.bin"
41
- controlnet_path = "./checkpoints/ControlNetModel"
42
 
43
- # Load pipeline
44
- controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
 
 
45
 
46
- base_model_path = "wangqixun/YamerMIX_v8"
47
 
48
  pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
49
- base_model_path,
50
- controlnet=controlnet,
51
- torch_dtype=torch.float16,
52
  safety_checker=None,
53
  feature_extractor=None,
54
  )
55
- pipe.cuda()
56
- pipe.load_ip_adapter_instantid(face_adapter)
57
- pipe.image_proj_model.to("cuda")
58
- pipe.unet.to("cuda")
59
60
 
61
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
62
  if randomize_seed:
63
  seed = random.randint(0, MAX_SEED)
64
  return seed
65
 
66
-
67
  def remove_tips():
68
  return gr.update(visible=False)
69
 
70
-
71
  def get_example():
72
  case = [
73
  [
74
  "./examples/yann-lecun_resize.jpg",
 
75
  "a man",
76
  "Snow",
77
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
78
  ],
79
  [
80
  "./examples/musk_resize.jpeg",
81
- "a man",
 
82
  "Mars",
83
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
84
  ],
85
  [
86
  "./examples/sam_resize.png",
87
- "a man",
 
88
  "Jungle",
89
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, gree",
90
  ],
91
  [
92
  "./examples/schmidhuber_resize.png",
93
- "a man",
 
94
  "Neon",
95
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
96
  ],
97
  [
98
  "./examples/kaifu_resize.png",
 
99
  "a man",
100
  "Vibrant Color",
101
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
@@ -103,50 +165,33 @@ def get_example():
103
  ]
104
  return case
105
 
106
-
107
- def run_for_examples(face_file, prompt, style, negative_prompt):
108
- return generate_image(face_file, None, prompt, negative_prompt, style, True, 30, 0.8, 0.8, 5, 42)
109
-
110
 
111
  def convert_from_cv2_to_image(img: np.ndarray) -> Image:
112
  return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
113
 
114
-
115
  def convert_from_image_to_cv2(img: Image) -> np.ndarray:
116
  return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
117
 
118
-
119
- def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
120
- stickwidth = 4
121
- limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
122
- kps = np.array(kps)
123
-
124
- w, h = image_pil.size
125
- out_img = np.zeros([h, w, 3])
126
-
127
- for i in range(len(limbSeq)):
128
- index = limbSeq[i]
129
- color = color_list[index[0]]
130
-
131
- x = kps[index][:, 0]
132
- y = kps[index][:, 1]
133
- length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
134
- angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
135
- polygon = cv2.ellipse2Poly(
136
- (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1
137
- )
138
- out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
139
- out_img = (out_img * 0.6).astype(np.uint8)
140
-
141
- for idx_kp, kp in enumerate(kps):
142
- color = color_list[idx_kp]
143
- x, y = kp
144
- out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
145
-
146
- out_img_pil = Image.fromarray(out_img.astype(np.uint8))
147
- return out_img_pil
148
-
149
-
150
  def resize_img(
151
  input_image,
152
  max_side=1280,
@@ -172,21 +217,18 @@ def resize_img(
172
  res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
173
  offset_x = (max_side - w_resize_new) // 2
174
  offset_y = (max_side - h_resize_new) // 2
175
- res[offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new] = np.array(input_image)
 
 
176
  input_image = Image.fromarray(res)
177
  return input_image
178
 
179
-
180
- def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
 
181
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
182
  return p.replace("{prompt}", positive), n + " " + negative
183
 
184
-
185
- def check_input_image(face_image):
186
- if face_image is None:
187
- raise gr.Error("Cannot find any input face image! Please upload the face image")
188
-
189
-
190
  @spaces.GPU
191
  def generate_image(
192
  face_image_path,
@@ -194,14 +236,41 @@ def generate_image(
194
  prompt,
195
  negative_prompt,
196
  style_name,
197
- enhance_face_region,
198
  num_steps,
199
  identitynet_strength_ratio,
200
  adapter_strength_ratio,
201
  guidance_scale,
202
  seed,
 
 
 
203
  progress=gr.Progress(track_tqdm=True),
204
  ):
205
  if prompt is None:
206
  prompt = "a person"
207
 
@@ -209,7 +278,7 @@ def generate_image(
209
  prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
210
 
211
  face_image = load_image(face_image_path)
212
- face_image = resize_img(face_image)
213
  face_image_cv2 = convert_from_image_to_cv2(face_image)
214
  height, width, _ = face_image_cv2.shape
215
 
@@ -217,23 +286,31 @@ def generate_image(
217
  face_info = app.get(face_image_cv2)
218
 
219
  if len(face_info) == 0:
220
- raise gr.Error("Cannot find any face in the image! Please upload another person image")
 
 
221
 
222
- face_info = sorted(face_info, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * x["bbox"][3] - x["bbox"][1])[
 
 
 
223
  -1
224
  ] # only use the maximum face
225
  face_emb = face_info["embedding"]
226
  face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["kps"])
227
-
228
  if pose_image_path is not None:
229
  pose_image = load_image(pose_image_path)
230
- pose_image = resize_img(pose_image)
 
231
  pose_image_cv2 = convert_from_image_to_cv2(pose_image)
232
 
233
  face_info = app.get(pose_image_cv2)
234
 
235
  if len(face_info) == 0:
236
- raise gr.Error("Cannot find any face in the reference image! Please upload another person image")
 
 
237
 
238
  face_info = face_info[-1]
239
  face_kps = draw_kps(pose_image, face_info["kps"])
@@ -249,6 +326,28 @@ def generate_image(
249
  else:
250
  control_mask = None
251
 
252
  generator = torch.Generator(device=device).manual_seed(seed)
253
 
254
  print("Start inference...")
@@ -259,9 +358,9 @@ def generate_image(
259
  prompt=prompt,
260
  negative_prompt=negative_prompt,
261
  image_embeds=face_emb,
262
- image=face_kps,
263
  control_mask=control_mask,
264
- controlnet_conditioning_scale=float(identitynet_strength_ratio),
265
  num_inference_steps=num_steps,
266
  guidance_scale=guidance_scale,
267
  height=height,
@@ -271,8 +370,7 @@ def generate_image(
271
 
272
  return images[0], gr.update(visible=True)
273
 
274
-
275
- ### Description
276
  title = r"""
277
  <h1 align="center">InstantID: Zero-shot Identity-Preserving Generation in Seconds</h1>
278
  """
@@ -281,12 +379,12 @@ description = r"""
281
  <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/InstantID/InstantID' target='_blank'><b>InstantID: Zero-shot Identity-Preserving Generation in Seconds</b></a>.<br>
282
 
283
  How to use:<br>
284
- 1. Upload a person image. For multiple person images, we will only detect the biggest face. Make sure face is not too small and not significantly blocked or blurred.
285
- 2. (Optionally) upload another person image as reference pose. If not uploaded, we will use the first person image to extract landmarks. If you use a cropped face at step1, it is recommeneded to upload it to extract a new pose.
286
- 3. Enter a text prompt as done in normal text-to-image models.
287
- 4. Click the <b>Submit</b> button to start customizing.
288
- 5. Share your customizd photo with your friends, enjoy😊!
289
- """
290
 
291
  article = r"""
292
  ---
@@ -295,10 +393,10 @@ article = r"""
295
  If our work is helpful for your research or applications, please cite us via:
296
  ```bibtex
297
  @article{wang2024instantid,
298
- title={InstantID: Zero-shot Identity-Preserving Generation in Seconds},
299
- author={Wang, Qixun and Bai, Xu and Wang, Haofan and Qin, Zekui and Chen, Anthony},
300
- journal={arXiv preprint arXiv:2401.07519},
301
- year={2024}
302
  }
303
  ```
304
  📧 **Contact**
@@ -308,10 +406,10 @@ If you have any questions, please feel free to open an issue or directly reach u
308
 
309
  tips = r"""
310
  ### Usage tips of InstantID
311
- 1. If you're unsatisfied with the similarity, increase the weight of controlnet_conditioning_scale (IdentityNet) and ip_adapter_scale (Adapter).
312
- 2. If the generated image is over-saturated, decrease the ip_adapter_scale. If not work, decrease controlnet_conditioning_scale.
313
- 3. If text control is not as expected, decrease ip_adapter_scale.
314
- 4. Find a good base model always makes a difference.
315
  """
316
 
317
  css = """
@@ -324,27 +422,39 @@ with gr.Blocks(css=css) as demo:
324
 
325
  with gr.Row():
326
  with gr.Column():
327
- # upload face image
328
- face_file = gr.Image(label="Upload a photo of your face", type="filepath")
329
-
330
- # optional: upload a reference pose image
331
- pose_file = gr.Image(label="Upload a reference pose image (optional)", type="filepath")
332
 
333
  # prompt
334
  prompt = gr.Textbox(
335
  label="Prompt",
336
- info="Give simple prompt is enough to achieve good face fedility",
337
  placeholder="A photo of a person",
338
  value="",
339
  )
340
 
341
  submit = gr.Button("Submit", variant="primary")
342
-
343
- style = gr.Dropdown(label="Style template", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
344
 
345
  # strength
346
  identitynet_strength_ratio = gr.Slider(
347
- label="IdentityNet strength (for fedility)",
348
  minimum=0,
349
  maximum=1.5,
350
  step=0.05,
@@ -357,26 +467,51 @@ with gr.Blocks(css=css) as demo:
357
  step=0.05,
358
  value=0.80,
359
  )
360
-
361
  with gr.Accordion(open=False, label="Advanced Options"):
362
  negative_prompt = gr.Textbox(
363
  label="Negative Prompt",
364
  placeholder="low quality",
365
- value="(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, nudity,naked, bikini, skimpy, scanty, bare skin, lingerie, swimsuit, exposed, see-through, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
366
  )
367
  num_steps = gr.Slider(
368
  label="Number of sample steps",
369
- minimum=20,
370
  maximum=100,
371
  step=1,
372
- value=30,
373
  )
374
  guidance_scale = gr.Slider(
375
  label="Guidance scale",
376
  minimum=0.1,
377
- maximum=10.0,
378
  step=0.1,
379
- value=5,
380
  )
381
  seed = gr.Slider(
382
  label="Seed",
@@ -385,18 +520,31 @@ with gr.Blocks(css=css) as demo:
385
  step=1,
386
  value=42,
387
  )
388
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
389
  enhance_face_region = gr.Checkbox(label="Enhance non-face region", value=True)
390
 
391
- with gr.Column():
392
- output_image = gr.Image(label="Generated Image")
393
- usage_tips = gr.Markdown(label="Usage tips of InstantID", value=tips, visible=False)
 
 
394
 
395
  submit.click(
396
  fn=remove_tips,
397
  outputs=usage_tips,
398
- queue=False,
399
- api_name=False,
400
  ).then(
401
  fn=randomize_seed_fn,
402
  inputs=[seed, randomize_seed],
@@ -404,11 +552,6 @@ with gr.Blocks(css=css) as demo:
404
  queue=False,
405
  api_name=False,
406
  ).then(
407
- fn=check_input_image,
408
- inputs=face_file,
409
- queue=False,
410
- api_name=False,
411
- ).success(
412
  fn=generate_image,
413
  inputs=[
414
  face_file,
@@ -416,21 +559,34 @@ with gr.Blocks(css=css) as demo:
416
  prompt,
417
  negative_prompt,
418
  style,
419
- enhance_face_region,
420
  num_steps,
421
  identitynet_strength_ratio,
422
  adapter_strength_ratio,
423
  guidance_scale,
424
  seed,
 
 
 
425
  ],
426
- outputs=[output_image, usage_tips],
427
  )
428
 
429
  gr.Examples(
430
  examples=get_example(),
431
- inputs=[face_file, prompt, style, negative_prompt],
432
- outputs=[output_image, usage_tips],
433
  fn=run_for_examples,
 
434
  cache_examples=True,
435
  )
436
 
 
1
+ import os
 
 
2
  import cv2
3
+ import torch
4
+ import random
5
  import numpy as np
6
+
7
  import PIL
8
+ from PIL import Image
9
+
10
+ import diffusers
11
  from diffusers.utils import load_image
12
+ from diffusers.models import ControlNetModel
13
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
14
+
15
+ from huggingface_hub import hf_hub_download
16
+
17
  from insightface.app import FaceAnalysis
 
18
 
 
19
  from style_template import styles
20
+ from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline, draw_kps
21
+ from model_util import load_models_xl, get_torch_device
22
+ from controlnet_util import openpose, get_depth_map, get_canny_image
23
+
24
+ import gradio as gr
25
+
26
+ import spaces
27
 
28
  # global variable
29
  MAX_SEED = np.iinfo(np.int32).max
30
+ device = get_torch_device()
31
+ dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
32
  STYLE_NAMES = list(styles.keys())
33
  DEFAULT_STYLE_NAME = "Watercolor"
34
 
 
44
  hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")
45
 
46
  # Load face encoder
47
+ app = FaceAnalysis(
48
+ name="antelopev2",
49
+ root="./",
50
+ providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
51
+ )
52
  app.prepare(ctx_id=0, det_size=(640, 640))
53
 
54
  # Path to InstantID models
55
+ face_adapter = f"./checkpoints/ip-adapter.bin"
56
+ controlnet_path = f"./checkpoints/ControlNetModel"
57
 
58
+ # Load pipeline face ControlNetModel
59
+ controlnet_identitynet = ControlNetModel.from_pretrained(
60
+ controlnet_path, torch_dtype=dtype
61
+ )
62
 
63
+ # controlnet-pose
64
+ controlnet_pose_model = "thibaud/controlnet-openpose-sdxl-1.0"
65
+ controlnet_canny_model = "diffusers/controlnet-canny-sdxl-1.0"
66
+ controlnet_depth_model = "diffusers/controlnet-depth-sdxl-1.0-small"
67
+
68
+ controlnet_pose = ControlNetModel.from_pretrained(
69
+ controlnet_pose_model, torch_dtype=dtype
70
+ ).to(device)
71
+ controlnet_canny = ControlNetModel.from_pretrained(
72
+ controlnet_canny_model, torch_dtype=dtype
73
+ ).to(device)
74
+ controlnet_depth = ControlNetModel.from_pretrained(
75
+ controlnet_depth_model, torch_dtype=dtype
76
+ ).to(device)
77
+
78
+ controlnet_map = {
79
+ "pose": controlnet_pose,
80
+ "canny": controlnet_canny,
81
+ "depth": controlnet_depth,
82
+ }
83
+ controlnet_map_fn = {
84
+ "pose": openpose,
85
+ "canny": get_canny_image,
86
+ "depth": get_depth_map,
87
+ }
88
+
89
+ pretrained_model_name_or_path = "wangqixun/YamerMIX_v8"
90
 
91
  pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
92
+ pretrained_model_name_or_path,
93
+ controlnet=[controlnet_identitynet],
94
+ torch_dtype=dtype,
95
  safety_checker=None,
96
  feature_extractor=None,
97
+ ).to(device)
98
+
99
+ pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(
100
+ pipe.scheduler.config
101
  )
102
 
103
+ pipe.load_ip_adapter_instantid(face_adapter)
104
+ # load and disable LCM
105
+ pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
106
+ pipe.disable_lora()
107
+
108
+ def toggle_lcm_ui(value):
109
+ if value:
110
+ return (
111
+ gr.update(minimum=0, maximum=100, step=1, value=5),
112
+ gr.update(minimum=0.1, maximum=20.0, step=0.1, value=1.5),
113
+ )
114
+ else:
115
+ return (
116
+ gr.update(minimum=5, maximum=100, step=1, value=30),
117
+ gr.update(minimum=0.1, maximum=20.0, step=0.1, value=5),
118
+ )
119
 
120
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
121
  if randomize_seed:
122
  seed = random.randint(0, MAX_SEED)
123
  return seed
124
 
 
125
  def remove_tips():
126
  return gr.update(visible=False)
127
 
 
128
  def get_example():
129
  case = [
130
  [
131
  "./examples/yann-lecun_resize.jpg",
132
+ None,
133
  "a man",
134
  "Snow",
135
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
136
  ],
137
  [
138
  "./examples/musk_resize.jpeg",
139
+ "./examples/poses/pose2.jpg",
140
+ "a man flying in the sky in Mars",
141
  "Mars",
142
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
143
  ],
144
  [
145
  "./examples/sam_resize.png",
146
+ "./examples/poses/pose4.jpg",
147
+ "a man doing a silly pose wearing a suite",
148
  "Jungle",
149
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, gree",
150
  ],
151
  [
152
  "./examples/schmidhuber_resize.png",
153
+ "./examples/poses/pose3.jpg",
154
+ "a man sit on a chair",
155
  "Neon",
156
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
157
  ],
158
  [
159
  "./examples/kaifu_resize.png",
160
+ "./examples/poses/pose.jpg",
161
  "a man",
162
  "Vibrant Color",
163
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
 
165
  ]
166
  return case
167
 
168
+ def run_for_examples(face_file, pose_file, prompt, style, negative_prompt):
169
+ return generate_image(
170
+ face_file,
171
+ pose_file,
172
+ prompt,
173
+ negative_prompt,
174
+ style,
175
+ 20, # num_steps
176
+ 0.8, # identitynet_strength_ratio
177
+ 0.8, # adapter_strength_ratio
178
+ 0.4, # pose_strength
179
+ 0.3, # canny_strength
180
+ 0.5, # depth_strength
181
+ ["pose", "canny"], # controlnet_selection
182
+ 5.0, # guidance_scale
183
+ 42, # seed
184
+ "EulerDiscreteScheduler", # scheduler
185
+ False, # enable_LCM
186
+ True, # enable_Face_Region
187
+ )
188
 
189
  def convert_from_cv2_to_image(img: np.ndarray) -> Image:
190
  return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
191
 
 
192
  def convert_from_image_to_cv2(img: Image) -> np.ndarray:
193
  return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
194
 
195
  def resize_img(
196
  input_image,
197
  max_side=1280,
 
217
  res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
218
  offset_x = (max_side - w_resize_new) // 2
219
  offset_y = (max_side - h_resize_new) // 2
220
+ res[
221
+ offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new
222
+ ] = np.array(input_image)
223
  input_image = Image.fromarray(res)
224
  return input_image
225
 
226
+ def apply_style(
227
+ style_name: str, positive: str, negative: str = ""
228
+ ) -> tuple[str, str]:
229
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
230
  return p.replace("{prompt}", positive), n + " " + negative
231
 
232
  @spaces.GPU
233
  def generate_image(
234
  face_image_path,
 
236
  prompt,
237
  negative_prompt,
238
  style_name,
 
239
  num_steps,
240
  identitynet_strength_ratio,
241
  adapter_strength_ratio,
242
+ pose_strength,
243
+ canny_strength,
244
+ depth_strength,
245
+ controlnet_selection,
246
  guidance_scale,
247
  seed,
248
+ scheduler,
249
+ enable_LCM,
250
+ enhance_face_region,
251
  progress=gr.Progress(track_tqdm=True),
252
  ):
253
+
254
+ if enable_LCM:
255
+ pipe.scheduler = diffusers.LCMScheduler.from_config(pipe.scheduler.config)
256
+ pipe.enable_lora()
257
+ else:
258
+ pipe.disable_lora()
259
+ scheduler_class_name = scheduler.split("-")[0]
260
+
261
+ add_kwargs = {}
262
+ if len(scheduler.split("-")) > 1:
263
+ add_kwargs["use_karras_sigmas"] = True
264
+ if len(scheduler.split("-")) > 2:
265
+ add_kwargs["algorithm_type"] = "sde-dpmsolver++"
266
+ scheduler = getattr(diffusers, scheduler_class_name)
267
+ pipe.scheduler = scheduler.from_config(pipe.scheduler.config, **add_kwargs)
268
+
269
+ if face_image_path is None:
270
+ raise gr.Error(
271
+ f"Cannot find any input face image! Please upload the face image"
272
+ )
273
+
274
  if prompt is None:
275
  prompt = "a person"
276
 
 
278
  prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
279
 
280
  face_image = load_image(face_image_path)
281
+ face_image = resize_img(face_image, max_side=1024)
282
  face_image_cv2 = convert_from_image_to_cv2(face_image)
283
  height, width, _ = face_image_cv2.shape
284
 
 
286
  face_info = app.get(face_image_cv2)
287
 
288
  if len(face_info) == 0:
289
+ raise gr.Error(
290
+ f"Unable to detect a face in the image. Please upload a different photo with a clear face."
291
+ )
292
 
293
+ face_info = sorted(
294
+ face_info,
295
+ key=lambda x: (x["bbox"][2] - x["bbox"][0]) * x["bbox"][3] - x["bbox"][1],
296
+ )[
297
  -1
298
  ] # only use the maximum face
299
  face_emb = face_info["embedding"]
300
  face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["kps"])
301
+ img_controlnet = face_image
302
  if pose_image_path is not None:
303
  pose_image = load_image(pose_image_path)
304
+ pose_image = resize_img(pose_image, max_side=1024)
305
+ img_controlnet = pose_image
306
  pose_image_cv2 = convert_from_image_to_cv2(pose_image)
307
 
308
  face_info = app.get(pose_image_cv2)
309
 
310
  if len(face_info) == 0:
311
+ raise gr.Error(
312
+ f"Cannot find any face in the reference image! Please upload another person image"
313
+ )
314
 
315
  face_info = face_info[-1]
316
  face_kps = draw_kps(pose_image, face_info["kps"])
 
326
  else:
327
  control_mask = None
328
 
329
+ if len(controlnet_selection) > 0:
330
+ controlnet_scales = {
331
+ "pose": pose_strength,
332
+ "canny": canny_strength,
333
+ "depth": depth_strength,
334
+ }
335
+ pipe.controlnet = MultiControlNetModel(
336
+ [controlnet_identitynet]
337
+ + [controlnet_map[s] for s in controlnet_selection]
338
+ )
339
+ control_scales = [float(identitynet_strength_ratio)] + [
340
+ controlnet_scales[s] for s in controlnet_selection
341
+ ]
342
+ control_images = [face_kps] + [
343
+ controlnet_map_fn[s](img_controlnet).resize((width, height))
344
+ for s in controlnet_selection
345
+ ]
346
+ else:
347
+ pipe.controlnet = controlnet_identitynet
348
+ control_scales = float(identitynet_strength_ratio)
349
+ control_images = face_kps
350
+
351
  generator = torch.Generator(device=device).manual_seed(seed)
352
 
353
  print("Start inference...")
 
358
  prompt=prompt,
359
  negative_prompt=negative_prompt,
360
  image_embeds=face_emb,
361
+ image=control_images,
362
  control_mask=control_mask,
363
+ controlnet_conditioning_scale=control_scales,
364
  num_inference_steps=num_steps,
365
  guidance_scale=guidance_scale,
366
  height=height,
 
370
 
371
  return images[0], gr.update(visible=True)
372
 
373
+ # Description
 
374
  title = r"""
375
  <h1 align="center">InstantID: Zero-shot Identity-Preserving Generation in Seconds</h1>
376
  """
 
379
  <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/InstantID/InstantID' target='_blank'><b>InstantID: Zero-shot Identity-Preserving Generation in Seconds</b></a>.<br>
380
 
381
  How to use:<br>
382
+ 1. Upload an image with a face. For images with multiple faces, we will only detect the largest face. Ensure the face is not too small and is clearly visible without significant obstructions or blurring.
383
+ 2. (Optional) You can upload another image as a reference for the face pose. If you don't, we will use the first detected face image to extract facial landmarks. If you use a cropped face at step 1, it is recommended to upload it to define a new face pose.
384
+ 3. (Optional) You can select multiple ControlNet models to control the generation process. The default is to use the IdentityNet only. The ControlNet models include pose skeleton, canny, and depth. You can adjust the strength of each ControlNet model to control the generation process.
385
+ 4. Enter a text prompt, as done in normal text-to-image models.
386
+ 5. Click the <b>Submit</b> button to begin customization.
387
+ 6. Share your customized photo with your friends and enjoy! 😊"""
388
 
389
  article = r"""
390
  ---
 
393
  If our work is helpful for your research or applications, please cite us via:
394
  ```bibtex
395
  @article{wang2024instantid,
396
+ title={InstantID: Zero-shot Identity-Preserving Generation in Seconds},
397
+ author={Wang, Qixun and Bai, Xu and Wang, Haofan and Qin, Zekui and Chen, Anthony},
398
+ journal={arXiv preprint arXiv:2401.07519},
399
+ year={2024}
400
  }
401
  ```
402
  📧 **Contact**
 
406
 
407
  tips = r"""
408
  ### Usage tips of InstantID
409
+ 1. If you're not satisfied with the similarity, try increasing the weight of "IdentityNet Strength" and "Adapter Strength."
410
+ 2. If you feel that the saturation is too high, first decrease the Adapter strength. If it remains too high, then decrease the IdentityNet strength.
411
+ 3. If you find that text control is not as expected, decrease Adapter strength.
412
+ 4. If you find that realistic style is not good enough, go for our Github repo and use a more realistic base model.
413
  """
414
 
415
  css = """
 
422
 
423
  with gr.Row():
424
  with gr.Column():
425
+ with gr.Row(equal_height=True):
426
+ # upload face image
427
+ face_file = gr.Image(
428
+ label="Upload a photo of your face", type="filepath"
429
+ )
430
+ # optional: upload a reference pose image
431
+ pose_file = gr.Image(
432
+ label="Upload a reference pose image (Optional)",
433
+ type="filepath",
434
+ )
435
 
436
  # prompt
437
  prompt = gr.Textbox(
438
  label="Prompt",
439
+ info="Give simple prompt is enough to achieve good face fidelity",
440
  placeholder="A photo of a person",
441
  value="",
442
  )
443
 
444
  submit = gr.Button("Submit", variant="primary")
445
+ enable_LCM = gr.Checkbox(
446
+ label="Enable Fast Inference with LCM", value=enable_lcm_arg,
447
+ info="LCM speeds up the inference step, the trade-off is the quality of the generated image. It performs better with portrait face images rather than distant faces",
448
+ )
449
+ style = gr.Dropdown(
450
+ label="Style template",
451
+ choices=STYLE_NAMES,
452
+ value=DEFAULT_STYLE_NAME,
453
+ )
454
 
455
  # strength
456
  identitynet_strength_ratio = gr.Slider(
457
+ label="IdentityNet strength (for fidelity)",
458
  minimum=0,
459
  maximum=1.5,
460
  step=0.05,
 
467
  step=0.05,
468
  value=0.80,
469
  )
470
+ with gr.Accordion("Controlnet"):
471
+ controlnet_selection = gr.CheckboxGroup(
472
+ ["pose", "canny", "depth"], label="Controlnet", value=["pose"],
473
+ info="Use pose for skeleton inference, canny for edge detection, and depth for depth map estimation. You can try all three to control the generation process"
474
+ )
475
+ pose_strength = gr.Slider(
476
+ label="Pose strength",
477
+ minimum=0,
478
+ maximum=1.5,
479
+ step=0.05,
480
+ value=0.40,
481
+ )
482
+ canny_strength = gr.Slider(
483
+ label="Canny strength",
484
+ minimum=0,
485
+ maximum=1.5,
486
+ step=0.05,
487
+ value=0.40,
488
+ )
489
+ depth_strength = gr.Slider(
490
+ label="Depth strength",
491
+ minimum=0,
492
+ maximum=1.5,
493
+ step=0.05,
494
+ value=0.40,
495
+ )
496
  with gr.Accordion(open=False, label="Advanced Options"):
497
  negative_prompt = gr.Textbox(
498
  label="Negative Prompt",
499
  placeholder="low quality",
500
+ value="(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
501
  )
502
  num_steps = gr.Slider(
503
  label="Number of sample steps",
504
+ minimum=1,
505
  maximum=100,
506
  step=1,
507
+ value=5 if enable_lcm_arg else 30,
508
  )
509
  guidance_scale = gr.Slider(
510
  label="Guidance scale",
511
  minimum=0.1,
512
+ maximum=20.0,
513
  step=0.1,
514
+ value=0.0 if enable_lcm_arg else 5.0,
515
  )
516
  seed = gr.Slider(
517
  label="Seed",
 
520
  step=1,
521
  value=42,
522
  )
523
+ schedulers = [
524
+ "DEISMultistepScheduler",
525
+ "HeunDiscreteScheduler",
526
+ "EulerDiscreteScheduler",
527
+ "DPMSolverMultistepScheduler",
528
+ "DPMSolverMultistepScheduler-Karras",
529
+ "DPMSolverMultistepScheduler-Karras-SDE",
530
+ ]
531
+ scheduler = gr.Dropdown(
532
+ label="Schedulers",
533
+ choices=schedulers,
534
+ value="EulerDiscreteScheduler",
535
+ )
536
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
537
  enhance_face_region = gr.Checkbox(label="Enhance non-face region", value=True)
538
 
539
+ with gr.Column(scale=1):
540
+ gallery = gr.Image(label="Generated Images")
541
+ usage_tips = gr.Markdown(
542
+ label="InstantID Usage Tips", value=tips, visible=False
543
+ )
544
 
545
  submit.click(
546
  fn=remove_tips,
547
  outputs=usage_tips,
 
 
548
  ).then(
549
  fn=randomize_seed_fn,
550
  inputs=[seed, randomize_seed],
 
552
  queue=False,
553
  api_name=False,
554
  ).then(
555
  fn=generate_image,
556
  inputs=[
557
  face_file,
 
559
  prompt,
560
  negative_prompt,
561
  style,
 
562
  num_steps,
563
  identitynet_strength_ratio,
564
  adapter_strength_ratio,
565
+ pose_strength,
566
+ canny_strength,
567
+ depth_strength,
568
+ controlnet_selection,
569
  guidance_scale,
570
  seed,
571
+ scheduler,
572
+ enable_LCM,
573
+ enhance_face_region,
574
  ],
575
+ outputs=[gallery, usage_tips],
576
+ )
577
+
578
+ enable_LCM.input(
579
+ fn=toggle_lcm_ui,
580
+ inputs=[enable_LCM],
581
+ outputs=[num_steps, guidance_scale],
582
+ queue=False,
583
  )
584
 
585
  gr.Examples(
586
  examples=get_example(),
587
+ inputs=[face_file, pose_file, prompt, style, negative_prompt],
 
588
  fn=run_for_examples,
589
+ outputs=[gallery, usage_tips],
590
  cache_examples=True,
591
  )
592
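One non-obvious convention in the new app.py above: the scheduler dropdown encodes extra solver options as suffixes on the diffusers class name ("-Karras", "-SDE"). A minimal sketch of that parsing, mirroring the logic inside `generate_image` (`parse_scheduler` is an illustrative name):

```python
import diffusers

def parse_scheduler(pipe, name: str):
    # "DPMSolverMultistepScheduler-Karras-SDE" -> class name plus extra kwargs
    parts = name.split("-")
    add_kwargs = {}
    if len(parts) > 1:   # "-Karras": use the Karras sigma schedule
        add_kwargs["use_karras_sigmas"] = True
    if len(parts) > 2:   # "-SDE": switch DPM-Solver++ to its SDE variant
        add_kwargs["algorithm_type"] = "sde-dpmsolver++"
    scheduler_cls = getattr(diffusers, parts[0])
    return scheduler_cls.from_config(pipe.scheduler.config, **add_kwargs)

# parse_scheduler(pipe, "DPMSolverMultistepScheduler-Karras-SDE") builds a
# DPMSolverMultistepScheduler with Karras sigmas and the SDE solver type.
```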
 
controlnet_util.py ADDED
@@ -0,0 +1,39 @@
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ from controlnet_aux import OpenposeDetector
5
+ from model_util import get_torch_device
6
+ import cv2
7
+
8
+
9
+ from transformers import DPTImageProcessor, DPTForDepthEstimation
10
+
11
+ device = get_torch_device()
12
+ depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
13
+ feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
14
+ openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
15
+
16
+ def get_depth_map(image):
17
+ image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
18
+ with torch.no_grad(), torch.autocast("cuda"):
19
+ depth_map = depth_estimator(image).predicted_depth
20
+
21
+ depth_map = torch.nn.functional.interpolate(
22
+ depth_map.unsqueeze(1),
23
+ size=(1024, 1024),
24
+ mode="bicubic",
25
+ align_corners=False,
26
+ )
27
+ depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
28
+ depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
29
+ depth_map = (depth_map - depth_min) / (depth_max - depth_min)
30
+ image = torch.cat([depth_map] * 3, dim=1)
31
+
32
+ image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
33
+ image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
34
+ return image
35
+
36
+ def get_canny_image(image, t1=100, t2=200):
37
+ image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
38
+ edges = cv2.Canny(image, t1, t2)
39
+ return Image.fromarray(edges, "L")
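For reference, app.py consumes these helpers through `controlnet_map_fn`: each selected ControlNet gets its conditioning image from the matching function and is resized to the generation resolution. A minimal usage sketch (the image path is taken from the example list in app.py; note that `get_depth_map` assumes a CUDA device, since it hard-codes "cuda" above):

```python
from PIL import Image
from controlnet_util import openpose, get_canny_image, get_depth_map

img = Image.open("./examples/poses/pose.jpg").convert("RGB")
width, height = 1024, 1024

conditioning = {
    "pose": openpose(img),          # OpenPose skeleton rendering
    "canny": get_canny_image(img),  # single-channel Canny edge map
    "depth": get_depth_map(img),    # MiDaS depth map scaled to 8-bit
}
# Same shape as app.py's control_images for a ["pose", "canny"] selection:
control_images = [conditioning[s].resize((width, height)) for s in ["pose", "canny"]]
```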
gradio_cached_examples/25/Generated Image/2880a3d19ef9b42e3ed2/image.png DELETED

Git LFS Details

  • SHA256: 573444e88e4bf4ab7bf4a693cf53cea3988366f4ca35b41523bf7c027802d0a6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.57 MB
gradio_cached_examples/25/Generated Image/38dde1388c41109c5d39/image.png DELETED

Git LFS Details

  • SHA256: c4e80ada96212f1acd058324b25602dc611de920de9e710c26286e751a5f1a9a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.31 MB
gradio_cached_examples/25/Generated Image/6cb5a3af223906666bfd/image.png DELETED

Git LFS Details

  • SHA256: f4b7543b3b1fd8ae301ee77c8479a7b3170e2ee063cd501a04e5e8ecc38417f6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.28 MB
gradio_cached_examples/25/Generated Image/7e6ea85f77dd925d842c/image.png DELETED

Git LFS Details

  • SHA256: abde3e97dc47d0b95f9909b2575834d6792db49816e5b40b0c8474a48fe467b2
  • Pointer size: 132 Bytes
  • Size of remote file: 2.38 MB
gradio_cached_examples/25/Generated Image/e56b029833685ff77e6a/image.png DELETED

Git LFS Details

  • SHA256: 2f0992f67fe4839bffb56cf3c527e08c652406c55729257374ee8d630ac21501
  • Pointer size: 132 Bytes
  • Size of remote file: 2.69 MB
gradio_cached_examples/25/log.csv DELETED
@@ -1,6 +0,0 @@
1
- Generated Image,Usage tips of InstantID,flag,username,timestamp
2
- "{""path"":""gradio_cached_examples/25/Generated Image/7e6ea85f77dd925d842c/image.png"",""url"":null,""size"":null,""orig_name"":""image.png"",""mime_type"":null}","{'visible': True, '__type__': 'update'}",,,2024-01-24 08:55:38.846769
3
- "{""path"":""gradio_cached_examples/25/Generated Image/2880a3d19ef9b42e3ed2/image.png"",""url"":null,""size"":null,""orig_name"":""image.png"",""mime_type"":null}","{'visible': True, '__type__': 'update'}",,,2024-01-24 08:56:11.432078
4
- "{""path"":""gradio_cached_examples/25/Generated Image/38dde1388c41109c5d39/image.png"",""url"":null,""size"":null,""orig_name"":""image.png"",""mime_type"":null}","{'visible': True, '__type__': 'update'}",,,2024-01-24 08:56:45.563918
5
- "{""path"":""gradio_cached_examples/25/Generated Image/e56b029833685ff77e6a/image.png"",""url"":null,""size"":null,""orig_name"":""image.png"",""mime_type"":null}","{'visible': True, '__type__': 'update'}",,,2024-01-24 08:57:20.321876
6
- "{""path"":""gradio_cached_examples/25/Generated Image/6cb5a3af223906666bfd/image.png"",""url"":null,""size"":null,""orig_name"":""image.png"",""mime_type"":null}","{'visible': True, '__type__': 'update'}",,,2024-01-24 08:57:53.871716
ip_adapter/attention_processor.py CHANGED
@@ -10,14 +10,11 @@ try:
10
  except Exception as e:
11
  xformers_available = False
12
 
13
-
14
-
15
  class RegionControler(object):
16
  def __init__(self) -> None:
17
  self.prompt_image_conditioning = []
18
  region_control = RegionControler()
19
 
20
-
21
  class AttnProcessor(nn.Module):
22
  r"""
23
  Default processor for performing attention-related computations.
@@ -29,7 +26,7 @@ class AttnProcessor(nn.Module):
29
  ):
30
  super().__init__()
31
 
32
- def __call__(
33
  self,
34
  attn,
35
  hidden_states,
@@ -115,7 +112,7 @@ class IPAttnProcessor(nn.Module):
115
  self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
116
  self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
117
 
118
- def __call__(
119
  self,
120
  attn,
121
  hidden_states,
@@ -180,7 +177,7 @@ class IPAttnProcessor(nn.Module):
180
  ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
181
  ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
182
  ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
183
-
184
  # region control
185
  if len(region_control.prompt_image_conditioning) == 1:
186
  region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
@@ -190,7 +187,7 @@ class IPAttnProcessor(nn.Module):
190
  mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
191
  else:
192
  mask = torch.ones_like(ip_hidden_states)
193
- ip_hidden_states = ip_hidden_states * mask
194
 
195
  hidden_states = hidden_states + self.scale * ip_hidden_states
196
 
@@ -233,7 +230,7 @@ class AttnProcessor2_0(torch.nn.Module):
233
  if not hasattr(F, "scaled_dot_product_attention"):
234
  raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
235
 
236
- def __call__(
237
  self,
238
  attn,
239
  hidden_states,
@@ -305,4 +302,145 @@ class AttnProcessor2_0(torch.nn.Module):
305
 
306
  hidden_states = hidden_states / attn.rescale_output_factor
307
 
308
  return hidden_states
 
10
  except Exception as e:
11
  xformers_available = False
12
 
 
 
13
  class RegionControler(object):
14
  def __init__(self) -> None:
15
  self.prompt_image_conditioning = []
16
  region_control = RegionControler()
17
 
 
18
  class AttnProcessor(nn.Module):
19
  r"""
20
  Default processor for performing attention-related computations.
 
26
  ):
27
  super().__init__()
28
 
29
+ def forward(
30
  self,
31
  attn,
32
  hidden_states,
 
112
  self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
113
  self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
114
 
115
+ def forward(
116
  self,
117
  attn,
118
  hidden_states,
 
177
  ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
178
  ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
179
  ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
180
+
181
  # region control
182
  if len(region_control.prompt_image_conditioning) == 1:
183
  region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
 
187
  mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
188
  else:
189
  mask = torch.ones_like(ip_hidden_states)
190
+ ip_hidden_states = ip_hidden_states * mask
191
 
192
  hidden_states = hidden_states + self.scale * ip_hidden_states
193
 
 
230
  if not hasattr(F, "scaled_dot_product_attention"):
231
  raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
232
 
233
+ def forward(
234
  self,
235
  attn,
236
  hidden_states,
 
302
 
303
  hidden_states = hidden_states / attn.rescale_output_factor
304
 
305
+ return hidden_states
306
+
307
+ class IPAttnProcessor2_0(torch.nn.Module):
308
+ r"""
309
+ Attention processor for IP-Adapater for PyTorch 2.0.
310
+ Args:
311
+ hidden_size (`int`):
312
+ The hidden size of the attention layer.
313
+ cross_attention_dim (`int`):
314
+ The number of channels in the `encoder_hidden_states`.
315
+ scale (`float`, defaults to 1.0):
316
+ the weight scale of image prompt.
317
+ num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
318
+ The context length of the image features.
319
+ """
320
+
321
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
322
+ super().__init__()
323
+
324
+ if not hasattr(F, "scaled_dot_product_attention"):
325
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
326
+
327
+ self.hidden_size = hidden_size
328
+ self.cross_attention_dim = cross_attention_dim
329
+ self.scale = scale
330
+ self.num_tokens = num_tokens
331
+
332
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
333
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
334
+
335
+ def forward(
336
+ self,
337
+ attn,
338
+ hidden_states,
339
+ encoder_hidden_states=None,
340
+ attention_mask=None,
341
+ temb=None,
342
+ ):
343
+ residual = hidden_states
344
+
345
+ if attn.spatial_norm is not None:
346
+ hidden_states = attn.spatial_norm(hidden_states, temb)
347
+
348
+ input_ndim = hidden_states.ndim
349
+
350
+ if input_ndim == 4:
351
+ batch_size, channel, height, width = hidden_states.shape
352
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
353
+
354
+ batch_size, sequence_length, _ = (
355
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
356
+ )
357
+
358
+ if attention_mask is not None:
359
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
360
+ # scaled_dot_product_attention expects attention_mask shape to be
361
+ # (batch, heads, source_length, target_length)
362
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
363
+
364
+ if attn.group_norm is not None:
365
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
366
+
367
+ query = attn.to_q(hidden_states)
368
+
369
+ if encoder_hidden_states is None:
370
+ encoder_hidden_states = hidden_states
371
+ else:
372
+ # get encoder_hidden_states, ip_hidden_states
373
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
374
+ encoder_hidden_states, ip_hidden_states = (
375
+ encoder_hidden_states[:, :end_pos, :],
376
+ encoder_hidden_states[:, end_pos:, :],
377
+ )
378
+ if attn.norm_cross:
379
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
380
+
381
+ key = attn.to_k(encoder_hidden_states)
382
+ value = attn.to_v(encoder_hidden_states)
383
+
384
+ inner_dim = key.shape[-1]
385
+ head_dim = inner_dim // attn.heads
386
+
387
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
388
+
389
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
390
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
391
+
392
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
393
+ # TODO: add support for attn.scale when we move to Torch 2.1
394
+ hidden_states = F.scaled_dot_product_attention(
395
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
396
+ )
397
+
398
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
399
+ hidden_states = hidden_states.to(query.dtype)
400
+
401
+ # for ip-adapter
402
+ ip_key = self.to_k_ip(ip_hidden_states)
403
+ ip_value = self.to_v_ip(ip_hidden_states)
404
+
405
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
406
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
407
+
408
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
409
+ # TODO: add support for attn.scale when we move to Torch 2.1
410
+ ip_hidden_states = F.scaled_dot_product_attention(
411
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
412
+ )
413
+ with torch.no_grad():
414
+ self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
415
+ #print(self.attn_map.shape)
416
+
417
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
418
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
419
+
420
+ # region control
421
+ if len(region_control.prompt_image_conditioning) == 1:
422
+ region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
423
+ if region_mask is not None:
424
+ h, w = region_mask.shape[:2]
425
+ ratio = (h * w / query.shape[1]) ** 0.5
426
+ mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
427
+ else:
428
+ mask = torch.ones_like(ip_hidden_states)
429
+ ip_hidden_states = ip_hidden_states * mask
430
+
431
+ hidden_states = hidden_states + self.scale * ip_hidden_states
432
+
433
+ # linear proj
434
+ hidden_states = attn.to_out[0](hidden_states)
435
+ # dropout
436
+ hidden_states = attn.to_out[1](hidden_states)
437
+
438
+ if input_ndim == 4:
439
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
440
+
441
+ if attn.residual_connection:
442
+ hidden_states = hidden_states + residual
443
+
444
+ hidden_states = hidden_states / attn.rescale_output_factor
445
+
446
  return hidden_states
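The changes above rename the processors' `__call__` to `forward` and add `IPAttnProcessor2_0`, a PyTorch-2.0 (`scaled_dot_product_attention`) variant of the IP-Adapter processor. For orientation, a rough sketch of how processors like these are typically installed on an SDXL UNet, following the standard IP-Adapter recipe; the actual wiring lives in the pipeline's `load_ip_adapter_instantid`, which is not part of this diff, and `install_ip_attn` with its defaults is illustrative:

```python
import torch
from ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0

def install_ip_attn(unet, num_tokens=16, scale=0.8, dtype=torch.float16):
    attn_procs = {}
    for name in unet.attn_processors.keys():
        # attn1.* are self-attention layers (no image tokens); attn2.* are cross-attention
        is_self_attn = name.endswith("attn1.processor")
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if is_self_attn:
            attn_procs[name] = AttnProcessor2_0()
        else:
            attn_procs[name] = IPAttnProcessor2_0(
                hidden_size=hidden_size,
                cross_attention_dim=unet.config.cross_attention_dim,
                scale=scale,
                num_tokens=num_tokens,
            ).to(unet.device, dtype=dtype)
    unet.set_attn_processor(attn_procs)
```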
model_util.py ADDED
@@ -0,0 +1,472 @@
1
+ from typing import Literal, Union, Optional, Tuple, List
2
+
3
+ import torch
4
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
5
+ from diffusers import (
6
+ UNet2DConditionModel,
7
+ SchedulerMixin,
8
+ StableDiffusionPipeline,
9
+ StableDiffusionXLPipeline,
10
+ AutoencoderKL,
11
+ )
12
+ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
13
+ convert_ldm_unet_checkpoint,
14
+ )
15
+ from safetensors.torch import load_file
16
+ from diffusers.schedulers import (
17
+ DDIMScheduler,
18
+ DDPMScheduler,
19
+ LMSDiscreteScheduler,
20
+ EulerDiscreteScheduler,
21
+ EulerAncestralDiscreteScheduler,
22
+ UniPCMultistepScheduler,
23
+ )
24
+
25
+ from omegaconf import OmegaConf
26
+
27
+ # DiffUsers版StableDiffusionのモデルパラメータ
28
+ NUM_TRAIN_TIMESTEPS = 1000
29
+ BETA_START = 0.00085
30
+ BETA_END = 0.0120
31
+
32
+ UNET_PARAMS_MODEL_CHANNELS = 320
33
+ UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
34
+ UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
35
+ UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
36
+ UNET_PARAMS_IN_CHANNELS = 4
37
+ UNET_PARAMS_OUT_CHANNELS = 4
38
+ UNET_PARAMS_NUM_RES_BLOCKS = 2
39
+ UNET_PARAMS_CONTEXT_DIM = 768
40
+ UNET_PARAMS_NUM_HEADS = 8
41
+ # UNET_PARAMS_USE_LINEAR_PROJECTION = False
42
+
43
+ VAE_PARAMS_Z_CHANNELS = 4
44
+ VAE_PARAMS_RESOLUTION = 256
45
+ VAE_PARAMS_IN_CHANNELS = 3
46
+ VAE_PARAMS_OUT_CH = 3
47
+ VAE_PARAMS_CH = 128
48
+ VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
49
+ VAE_PARAMS_NUM_RES_BLOCKS = 2
50
+
51
+ # V2
52
+ V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
53
+ V2_UNET_PARAMS_CONTEXT_DIM = 1024
54
+ # V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
55
+
56
+ TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
57
+ TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"
58
+
59
+ AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a", "euler", "uniPC"]
60
+
61
+ SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]
62
+
63
+ DIFFUSERS_CACHE_DIR = None # if you want to change the cache dir, change this
64
+
65
+
66
+ def load_checkpoint_with_text_encoder_conversion(ckpt_path: str, device="cpu"):
67
+ # text encoderの格納形式が違うモデルに対応する ('text_model'がない)
68
+ TEXT_ENCODER_KEY_REPLACEMENTS = [
69
+ (
70
+ "cond_stage_model.transformer.embeddings.",
71
+ "cond_stage_model.transformer.text_model.embeddings.",
72
+ ),
73
+ (
74
+ "cond_stage_model.transformer.encoder.",
75
+ "cond_stage_model.transformer.text_model.encoder.",
76
+ ),
77
+ (
78
+ "cond_stage_model.transformer.final_layer_norm.",
79
+ "cond_stage_model.transformer.text_model.final_layer_norm.",
80
+ ),
81
+ ]
82
+
83
+ if ckpt_path.endswith(".safetensors"):
84
+ checkpoint = None
85
+ state_dict = load_file(ckpt_path) # , device) # may causes error
86
+ else:
87
+ checkpoint = torch.load(ckpt_path, map_location=device)
88
+ if "state_dict" in checkpoint:
89
+ state_dict = checkpoint["state_dict"]
90
+ else:
91
+ state_dict = checkpoint
92
+ checkpoint = None
93
+
94
+ key_reps = []
95
+ for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
96
+ for key in state_dict.keys():
97
+ if key.startswith(rep_from):
98
+ new_key = rep_to + key[len(rep_from) :]
99
+ key_reps.append((key, new_key))
100
+
101
+ for key, new_key in key_reps:
102
+ state_dict[new_key] = state_dict[key]
103
+ del state_dict[key]
104
+
105
+ return checkpoint, state_dict
106
+
107
+
108
+ def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
109
+ """
110
+ Creates a config for the diffusers based on the config of the LDM model.
111
+ """
112
+ # unet_params = original_config.model.params.unet_config.params
113
+
114
+ block_out_channels = [
115
+ UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT
116
+ ]
117
+
118
+ down_block_types = []
119
+ resolution = 1
120
+ for i in range(len(block_out_channels)):
121
+ block_type = (
122
+ "CrossAttnDownBlock2D"
123
+ if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS
124
+ else "DownBlock2D"
125
+ )
126
+ down_block_types.append(block_type)
127
+ if i != len(block_out_channels) - 1:
128
+ resolution *= 2
129
+
130
+ up_block_types = []
131
+ for i in range(len(block_out_channels)):
132
+ block_type = (
133
+ "CrossAttnUpBlock2D"
134
+ if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS
135
+ else "UpBlock2D"
136
+ )
137
+ up_block_types.append(block_type)
138
+ resolution //= 2
139
+
140
+ config = dict(
141
+ sample_size=UNET_PARAMS_IMAGE_SIZE,
142
+ in_channels=UNET_PARAMS_IN_CHANNELS,
143
+ out_channels=UNET_PARAMS_OUT_CHANNELS,
144
+ down_block_types=tuple(down_block_types),
145
+ up_block_types=tuple(up_block_types),
146
+ block_out_channels=tuple(block_out_channels),
147
+ layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
148
+ cross_attention_dim=UNET_PARAMS_CONTEXT_DIM
149
+ if not v2
150
+ else V2_UNET_PARAMS_CONTEXT_DIM,
151
+ attention_head_dim=UNET_PARAMS_NUM_HEADS
152
+ if not v2
153
+ else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
154
+ # use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
155
+ )
156
+ if v2 and use_linear_projection_in_v2:
157
+ config["use_linear_projection"] = True
158
+
159
+ return config
160
+
161
+
162
+ def load_diffusers_model(
163
+ pretrained_model_name_or_path: str,
164
+ v2: bool = False,
165
+ clip_skip: Optional[int] = None,
166
+ weight_dtype: torch.dtype = torch.float32,
167
+ ) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
168
+ if v2:
169
+ tokenizer = CLIPTokenizer.from_pretrained(
170
+ TOKENIZER_V2_MODEL_NAME,
171
+ subfolder="tokenizer",
172
+ torch_dtype=weight_dtype,
173
+ cache_dir=DIFFUSERS_CACHE_DIR,
174
+ )
175
+ text_encoder = CLIPTextModel.from_pretrained(
176
+ pretrained_model_name_or_path,
177
+ subfolder="text_encoder",
178
+ # default is clip skip 2
179
+ num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
180
+ torch_dtype=weight_dtype,
181
+ cache_dir=DIFFUSERS_CACHE_DIR,
182
+ )
183
+ else:
184
+ tokenizer = CLIPTokenizer.from_pretrained(
185
+ TOKENIZER_V1_MODEL_NAME,
186
+ subfolder="tokenizer",
187
+ torch_dtype=weight_dtype,
188
+ cache_dir=DIFFUSERS_CACHE_DIR,
189
+ )
190
+ text_encoder = CLIPTextModel.from_pretrained(
191
+ pretrained_model_name_or_path,
192
+ subfolder="text_encoder",
193
+ num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
194
+ torch_dtype=weight_dtype,
195
+ cache_dir=DIFFUSERS_CACHE_DIR,
196
+ )
197
+
198
+ unet = UNet2DConditionModel.from_pretrained(
199
+ pretrained_model_name_or_path,
200
+ subfolder="unet",
201
+ torch_dtype=weight_dtype,
202
+ cache_dir=DIFFUSERS_CACHE_DIR,
203
+ )
204
+
205
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
206
+
207
+ return tokenizer, text_encoder, unet, vae
208
+
209
+
210
+ def load_checkpoint_model(
211
+ checkpoint_path: str,
212
+ v2: bool = False,
213
+ clip_skip: Optional[int] = None,
214
+ weight_dtype: torch.dtype = torch.float32,
215
+ ) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
216
+ pipe = StableDiffusionPipeline.from_single_file(
217
+ checkpoint_path,
218
+ upcast_attention=True if v2 else False,
219
+ torch_dtype=weight_dtype,
220
+ cache_dir=DIFFUSERS_CACHE_DIR,
221
+ )
222
+
223
+ _, state_dict = load_checkpoint_with_text_encoder_conversion(checkpoint_path)
224
+ unet_config = create_unet_diffusers_config(v2, use_linear_projection_in_v2=v2)
225
+ unet_config["class_embed_type"] = None
226
+ unet_config["addition_embed_type"] = None
227
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(state_dict, unet_config)
228
+ unet = UNet2DConditionModel(**unet_config)
229
+ unet.load_state_dict(converted_unet_checkpoint)
230
+
231
+ tokenizer = pipe.tokenizer
232
+ text_encoder = pipe.text_encoder
233
+ vae = pipe.vae
234
+ if clip_skip is not None:
235
+ if v2:
236
+ text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
237
+ else:
238
+ text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)
239
+
240
+ del pipe
241
+
242
+ return tokenizer, text_encoder, unet, vae
+
+
+def load_models(
+    pretrained_model_name_or_path: str,
+    scheduler_name: str,
+    v2: bool = False,
+    v_pred: bool = False,
+    weight_dtype: torch.dtype = torch.float32,
+) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin, AutoencoderKL]:
+    if pretrained_model_name_or_path.endswith(
+        ".ckpt"
+    ) or pretrained_model_name_or_path.endswith(".safetensors"):
+        tokenizer, text_encoder, unet, vae = load_checkpoint_model(
+            pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
+        )
+    else:  # diffusers
+        tokenizer, text_encoder, unet, vae = load_diffusers_model(
+            pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
+        )
+
+    if scheduler_name:
+        scheduler = create_noise_scheduler(
+            scheduler_name,
+            prediction_type="v_prediction" if v_pred else "epsilon",
+        )
+    else:
+        scheduler = None
+
+    return tokenizer, text_encoder, unet, scheduler, vae
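
A minimal usage sketch for the SD 1.x/2.x loader: the suffix check above routes .ckpt/.safetensors paths through load_checkpoint_model and everything else through load_diffusers_model. The module name model_util and the model id below are assumptions for illustration; note that a falsy scheduler_name simply returns None for the scheduler.

import torch
from model_util import load_models  # assumed import path for this module

tokenizer, text_encoder, unet, scheduler, vae = load_models(
    "runwayml/stable-diffusion-v1-5",  # a Hub id; a local *.safetensors path takes the checkpoint branch
    scheduler_name="",                 # falsy -> no scheduler is created
    v2=False,
    v_pred=False,
    weight_dtype=torch.float16,
)
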
+
+
+def load_diffusers_model_xl(
+    pretrained_model_name_or_path: str,
+    weight_dtype: torch.dtype = torch.float32,
+) -> Tuple[List[CLIPTokenizer], List[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel, AutoencoderKL]:
+    # returns [tokenizer, tokenizer_2], [text_encoder, text_encoder_2], unet, vae
+
+    tokenizers = [
+        CLIPTokenizer.from_pretrained(
+            pretrained_model_name_or_path,
+            subfolder="tokenizer",
+            torch_dtype=weight_dtype,
+            cache_dir=DIFFUSERS_CACHE_DIR,
+        ),
+        CLIPTokenizer.from_pretrained(
+            pretrained_model_name_or_path,
+            subfolder="tokenizer_2",
+            torch_dtype=weight_dtype,
+            cache_dir=DIFFUSERS_CACHE_DIR,
+            pad_token_id=0,  # same as open clip
+        ),
+    ]
+
+    text_encoders = [
+        CLIPTextModel.from_pretrained(
+            pretrained_model_name_or_path,
+            subfolder="text_encoder",
+            torch_dtype=weight_dtype,
+            cache_dir=DIFFUSERS_CACHE_DIR,
+        ),
+        CLIPTextModelWithProjection.from_pretrained(
+            pretrained_model_name_or_path,
+            subfolder="text_encoder_2",
+            torch_dtype=weight_dtype,
+            cache_dir=DIFFUSERS_CACHE_DIR,
+        ),
+    ]
+
+    unet = UNet2DConditionModel.from_pretrained(
+        pretrained_model_name_or_path,
+        subfolder="unet",
+        torch_dtype=weight_dtype,
+        cache_dir=DIFFUSERS_CACHE_DIR,
+    )
+    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
+    return tokenizers, text_encoders, unet, vae
+
+
+def load_checkpoint_model_xl(
+    checkpoint_path: str,
+    weight_dtype: torch.dtype = torch.float32,
+) -> Tuple[List[CLIPTokenizer], List[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel, AutoencoderKL]:
+    pipe = StableDiffusionXLPipeline.from_single_file(
+        checkpoint_path,
+        torch_dtype=weight_dtype,
+        cache_dir=DIFFUSERS_CACHE_DIR,
+    )
+
+    unet = pipe.unet
+    vae = pipe.vae
+    tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
+    text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
+    if len(text_encoders) == 2:
+        text_encoders[1].pad_token_id = 0
+
+    del pipe
+
+    return tokenizers, text_encoders, unet, vae
+
+
+def load_models_xl(
+    pretrained_model_name_or_path: str,
+    scheduler_name: str,
+    weight_dtype: torch.dtype = torch.float32,
+    noise_scheduler_kwargs=None,
+) -> Tuple[
+    List[CLIPTokenizer],
+    List[SDXL_TEXT_ENCODER_TYPE],
+    UNet2DConditionModel,
+    SchedulerMixin,
+    AutoencoderKL,
+]:
+    if pretrained_model_name_or_path.endswith(
+        ".ckpt"
+    ) or pretrained_model_name_or_path.endswith(".safetensors"):
+        (tokenizers, text_encoders, unet, vae) = load_checkpoint_model_xl(
+            pretrained_model_name_or_path, weight_dtype
+        )
+    else:  # diffusers
+        (tokenizers, text_encoders, unet, vae) = load_diffusers_model_xl(
+            pretrained_model_name_or_path, weight_dtype
+        )
+    if scheduler_name:
+        scheduler = create_noise_scheduler(scheduler_name, noise_scheduler_kwargs)
+    else:
+        scheduler = None
+
+    return tokenizers, text_encoders, unet, scheduler, vae
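
The SDXL variant follows the same suffix dispatch and additionally returns both tokenizers and both text encoders as lists. A hedged sketch of a call (same module-name assumption as above; the model id is illustrative):

import torch
from model_util import load_models_xl  # assumed import path

tokenizers, text_encoders, unet, scheduler, vae = load_models_xl(
    "stabilityai/stable-diffusion-xl-base-1.0",
    scheduler_name="",           # pass a falsy name to skip scheduler creation
    weight_dtype=torch.float16,
)
assert len(tokenizers) == 2 and len(text_encoders) == 2
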
+
+def create_noise_scheduler(
+    scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
+    noise_scheduler_kwargs=None,
+    prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
+) -> SchedulerMixin:
+    name = scheduler_name.lower().replace(" ", "_")
+    if name.lower() == "ddim":
+        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
+        scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))
+    elif name.lower() == "ddpm":
+        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
+        scheduler = DDPMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))
+    elif name.lower() == "lms":
+        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
+        scheduler = LMSDiscreteScheduler(
+            **OmegaConf.to_container(noise_scheduler_kwargs)
+        )
+    elif name.lower() == "euler_a":
+        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
+        scheduler = EulerAncestralDiscreteScheduler(
+            **OmegaConf.to_container(noise_scheduler_kwargs)
+        )
+    elif name.lower() == "euler":
+        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler
+        scheduler = EulerDiscreteScheduler(
+            **OmegaConf.to_container(noise_scheduler_kwargs)
+        )
+    elif name.lower() == "unipc":
+        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/unipc
+        scheduler = UniPCMultistepScheduler(
+            **OmegaConf.to_container(noise_scheduler_kwargs)
+        )
+    else:
+        raise ValueError(f"Unknown scheduler name: {name}")
+
+    return scheduler
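
Every branch above unpacks OmegaConf.to_container(noise_scheduler_kwargs), so the kwargs must be an OmegaConf mapping rather than a plain dict or None. A usage sketch with the usual SD beta schedule, shown purely for illustration:

from omegaconf import OmegaConf

scheduler_kwargs = OmegaConf.create(
    {
        "num_train_timesteps": 1000,
        "beta_start": 0.00085,
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
    }
)
scheduler = create_noise_scheduler("euler_a", scheduler_kwargs)  # EulerAncestralDiscreteScheduler
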
+
+
+def torch_gc():
+    import gc
+
+    gc.collect()
+    if torch.cuda.is_available():
+        with torch.cuda.device("cuda"):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+
+
+from enum import Enum
+
+
+class CPUState(Enum):
+    GPU = 0
+    CPU = 1
+    MPS = 2
+
+
+cpu_state = CPUState.GPU
+xpu_available = False
+directml_enabled = False
+
+
+def is_intel_xpu():
+    global cpu_state
+    global xpu_available
+    if cpu_state == CPUState.GPU:
+        if xpu_available:
+            return True
+    return False
+
+
+try:
+    import intel_extension_for_pytorch as ipex
+
+    if torch.xpu.is_available():
+        xpu_available = True
+except:
+    pass
+
+try:
+    if torch.backends.mps.is_available():
+        cpu_state = CPUState.MPS
+        import torch.mps
+except:
+    pass
+
+
+def get_torch_device():
+    global directml_enabled
+    global cpu_state
+    if directml_enabled:
+        global directml_device
+        return directml_device
+    if cpu_state == CPUState.MPS:
+        return torch.device("mps")
+    if cpu_state == CPUState.CPU:
+        return torch.device("cpu")
+    else:
+        if is_intel_xpu():
+            return torch.device("xpu")
+        else:
+            return torch.device(torch.cuda.current_device())
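
Taken together, the probes above pick CUDA, XPU, MPS, or CPU at import time. A short sketch of how the two helpers are typically paired during inference (illustrative only; both names come from this module):

import torch

device = get_torch_device()   # resolves to cuda / xpu / mps / cpu per the probes above
latents = torch.zeros(1, 4, 64, 64, device=device)
# ... run the denoising loop ...
del latents
torch_gc()                    # collect garbage and, on CUDA, empty the caching allocator
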
pipeline_stable_diffusion_xl_instantid.py → pipeline_stable_diffusion_xl_instantid_full.py RENAMED
@@ -22,7 +22,6 @@ import numpy as np
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from transformers import CLIPTokenizer
 
 from diffusers.image_processor import PipelineImageInput
 
@@ -41,8 +40,12 @@ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
 from diffusers.utils.import_utils import is_xformers_available
 
 from ip_adapter.resampler import Resampler
+from ip_adapter.utils import is_torch2_available
 
-from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
+if is_torch2_available():
+    from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
+else:
+    from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor
 from ip_adapter.attention_processor import region_control
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -102,7 +105,7 @@ EXAMPLE_DOC_STRING = """
 ```
 """
 
-
+from transformers import CLIPTokenizer
 from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
 class LongPromptWeight(object):
 
@@ -482,6 +485,34 @@ class LongPromptWeight(object):
         prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
         return prompt_embeds
 
+def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]):
+
+    stickwidth = 4
+    limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
+    kps = np.array(kps)
+
+    w, h = image_pil.size
+    out_img = np.zeros([h, w, 3])
+
+    for i in range(len(limbSeq)):
+        index = limbSeq[i]
+        color = color_list[index[0]]
+
+        x = kps[index][:, 0]
+        y = kps[index][:, 1]
+        length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
+        angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
+        polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
+        out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
+    out_img = (out_img * 0.6).astype(np.uint8)
+
+    for idx_kp, kp in enumerate(kps):
+        color = color_list[idx_kp]
+        x, y = kp
+        out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
+
+    out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
+    return out_img_pil
 
 class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
 
@@ -567,7 +598,7 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale
 
-    def _encode_prompt_image_emb(self, prompt_image_emb, device, dtype, do_classifier_free_guidance):
+    def _encode_prompt_image_emb(self, prompt_image_emb, device, num_images_per_prompt, dtype, do_classifier_free_guidance):
 
        if isinstance(prompt_image_emb, torch.Tensor):
            prompt_image_emb = prompt_image_emb.clone().detach()
@@ -583,6 +614,11 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
            prompt_image_emb = torch.cat([prompt_image_emb], dim=0)
 
        prompt_image_emb = self.image_proj_model(prompt_image_emb)
+
+        bs_embed, seq_len, _ = prompt_image_emb.shape
+        prompt_image_emb = prompt_image_emb.repeat(1, num_images_per_prompt, 1)
+        prompt_image_emb = prompt_image_emb.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
        return prompt_image_emb
 
    @torch.no_grad()
@@ -623,7 +659,13 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+
+        # IP adapter
+        ip_adapter_scale=None,
+
+        # Enhance Face Region
        control_mask = None,
+
        **kwargs,
    ):
        r"""
@@ -758,6 +800,7 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned containing the output images.
        """
+
        lpw = LongPromptWeight()
 
        callback = kwargs.pop("callback", None)
@@ -789,6 +832,10 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
                mult * [control_guidance_start],
                mult * [control_guidance_end],
            )
+
+        # 0. set ip_adapter_scale
+        if ip_adapter_scale is not None:
+            self.set_ip_adapter_scale(ip_adapter_scale)
 
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
@@ -851,6 +898,7 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
        # 3.2 Encode image prompt
        prompt_image_emb = self._encode_prompt_image_emb(image_embeds,
                                                         device,
+                                                         num_images_per_prompt,
                                                         self.unet.dtype,
                                                         self.do_classifier_free_guidance)
 
@@ -1031,24 +1079,57 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
                        controlnet_cond_scale = controlnet_cond_scale[0]
                    cond_scale = controlnet_cond_scale * controlnet_keep[i]
 
-                down_block_res_samples, mid_block_res_sample = self.controlnet(
-                    control_model_input,
-                    t,
-                    encoder_hidden_states=prompt_image_emb,
-                    controlnet_cond=image,
-                    conditioning_scale=cond_scale,
-                    guess_mode=guess_mode,
-                    added_cond_kwargs=controlnet_added_cond_kwargs,
-                    return_dict=False,
-                )
+                if isinstance(self.controlnet, MultiControlNetModel):
+                    down_block_res_samples_list, mid_block_res_sample_list = [], []
+                    for control_index in range(len(self.controlnet.nets)):
+                        controlnet = self.controlnet.nets[control_index]
+                        if control_index == 0:
+                            # assume the first controlnet is IdentityNet
+                            controlnet_prompt_embeds = prompt_image_emb
+                        else:
+                            controlnet_prompt_embeds = prompt_embeds
+                        down_block_res_samples, mid_block_res_sample = controlnet(control_model_input,
+                                                                                  t,
+                                                                                  encoder_hidden_states=controlnet_prompt_embeds,
+                                                                                  controlnet_cond=image[control_index],
+                                                                                  conditioning_scale=cond_scale[control_index],
+                                                                                  guess_mode=guess_mode,
+                                                                                  added_cond_kwargs=controlnet_added_cond_kwargs,
+                                                                                  return_dict=False)
+
+                        # controlnet mask
+                        if control_index == 0 and control_mask_wight_image_list is not None:
+                            down_block_res_samples = [
+                                down_block_res_sample * mask_weight
+                                for down_block_res_sample, mask_weight in zip(down_block_res_samples, control_mask_wight_image_list)
+                            ]
+                            mid_block_res_sample *= control_mask_wight_image_list[-1]
+
+                        down_block_res_samples_list.append(down_block_res_samples)
+                        mid_block_res_sample_list.append(mid_block_res_sample)
+
+                    mid_block_res_sample = torch.stack(mid_block_res_sample_list).sum(dim=0)
+                    down_block_res_samples = [torch.stack(down_block_res_samples).sum(dim=0) for down_block_res_samples in
+                                              zip(*down_block_res_samples_list)]
+                else:
+                    down_block_res_samples, mid_block_res_sample = self.controlnet(
+                        control_model_input,
+                        t,
+                        encoder_hidden_states=prompt_image_emb,
+                        controlnet_cond=image,
+                        conditioning_scale=cond_scale,
+                        guess_mode=guess_mode,
+                        added_cond_kwargs=controlnet_added_cond_kwargs,
+                        return_dict=False,
+                    )
 
-                # controlnet mask
-                if control_mask_wight_image_list is not None:
-                    down_block_res_samples = [
-                        down_block_res_sample * mask_weight
-                        for down_block_res_sample, mask_weight in zip(down_block_res_samples, control_mask_wight_image_list)
-                    ]
-                    mid_block_res_sample *= control_mask_wight_image_list[-1]
+                    # controlnet mask
+                    if control_mask_wight_image_list is not None:
+                        down_block_res_samples = [
+                            down_block_res_sample * mask_weight
+                            for down_block_res_sample, mask_weight in zip(down_block_res_samples, control_mask_wight_image_list)
+                        ]
+                        mid_block_res_sample *= control_mask_wight_image_list[-1]
 
                if guess_mode and self.do_classifier_free_guidance:
                    # Infered ControlNet only for the conditional batch.
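
With the MultiControlNetModel branch above, one call can combine the IdentityNet (expected at index 0) with additional spatial ControlNets, each with its own conditioning image and scale. A hedged sketch of how the new arguments line up, assuming a pipeline object pipe of the class above with the IP-Adapter already loaded, and face_emb, face_kps, pose_image, and face_mask prepared beforehand; the checkpoint path, pose model id, and scale values are illustrative:

import torch
from diffusers.models import ControlNetModel
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel

# IdentityNet checkpoint directory (placeholder path) plus an extra pose ControlNet.
identitynet = ControlNetModel.from_pretrained("path/to/InstantID/ControlNetModel", torch_dtype=torch.float16)
pose_net = ControlNetModel.from_pretrained("thibaud/controlnet-openpose-sdxl-1.0", torch_dtype=torch.float16)
pipe.controlnet = MultiControlNetModel([identitynet, pose_net]).to(pipe.device, torch.float16)

images = pipe(
    prompt="analog film photo of a man",
    image_embeds=face_emb,                     # identity embedding from the face encoder
    image=[face_kps, pose_image],              # one conditioning image per ControlNet, IdentityNet first
    controlnet_conditioning_scale=[0.8, 0.4],  # becomes cond_scale[control_index] above
    control_mask=face_mask,                    # optional face-region mask
    ip_adapter_scale=0.8,
    num_inference_steps=30,
).images
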
requirements.txt CHANGED
@@ -1,7 +1,7 @@
-diffusers==0.25.0
+diffusers==0.25.1
 torch==2.0.0
 torchvision==0.15.1
-transformers==4.36.2
+transformers==4.37.1
 accelerate
 safetensors
 einops
@@ -11,4 +11,8 @@ omegaconf
 peft
 huggingface-hub==0.20.2
 opencv-python
-insightface
+insightface
+gradio
+controlnet_aux
+gdown
+peft