ResearcherXman committed on
Commit
0ba2339
1 Parent(s): 52ae519
app.py CHANGED
@@ -1,16 +1,60 @@
1
  import os
2
  import cv2
3
  import math
 
4
  import random
5
  import numpy as np
 
 
6
  from PIL import Image
7
 
 
8
  from diffusers.utils import load_image
9
10
  import gradio as gr
11
 
12
  # global variable
13
  MAX_SEED = np.iinfo(np.int32).max
14
 
15
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
16
  if randomize_seed:
@@ -29,14 +73,174 @@ def remove_back_to_files():
29
  def remove_tips():
30
  return gr.update(visible=False)
31
 
32
- def generate_image(face_image, pose_image, prompt, negative_prompt, num_steps, identitynet_strength_ratio, adapter_strength_ratio, num_outputs, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
33
 
34
  if face_image is None:
35
  raise gr.Error(f"Cannot find any input face image! Please upload the face image")
36
-
37
  face_image = load_image(face_image[0])
38
 
39
- return [face_image], gr.update(visible=True)
40
 
41
  ### Description
42
  title = r"""
@@ -47,9 +251,9 @@ description = r"""
47
  <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/InstantID/InstantID' target='_blank'><b>InstantID: Zero-shot Identity-Preserving Generation in Seconds</b></a>.<br>
48
 
49
  How to use:<br>
50
- 1. Upload a person image or cropped face image. For multiple person images, we will only detect the biggest face. Make sure face is in good condition and not significantly blocked or blurred.
51
- 2. (Optionally) upload another person image as reference pose. If not uploaded, we will use the first person image to extract landmarks.
52
- 3. Enter a text prompt as normal text-to-image model.
53
  4. Click the <b>Submit</b> button to start customizing.
54
  5. Share your customized photo with your friends, enjoy 😊!
55
  """
@@ -67,7 +271,6 @@ If our work is helpful for your research or applications, please cite us via:
67
  year={2024}
68
  }
69
  ```
70
-
71
  📧 **Contact**
72
  <br>
73
  If you have any questions, please feel free to open an issue or directly reach us out at <b>haofanwang.ai@gmail.com</b>.
@@ -75,9 +278,10 @@ If you have any questions, please feel free to open an issue or directly reach u
75
 
76
  tips = r"""
77
  ### Usage tips of InstantID
78
- 1. If you're not satisfied with the similarity, scroll down to "Advanced Options" and increase the weight of "IdentityNet Strength" and "Adapter Strength".
79
- 2. If you feel that the saturation is too high, first decrease the Adapter strength. If it is still too high, then decrease the IdentityNet strength.
80
- 3. If you find that text control is not as expected, decrease Adapter strength.
 
81
  """
82
 
83
  css = '''
@@ -113,14 +317,34 @@ with gr.Blocks(css=css) as demo:
113
  # prompt
114
  prompt = gr.Textbox(label="Prompt",
115
  info="Give simple prompt is enough to achieve good face fedility",
116
- placeholder="A photo of a man/woman")
117
- submit = gr.Button("Submit")
118
-
119
  with gr.Accordion(open=False, label="Advanced Options"):
120
  negative_prompt = gr.Textbox(
121
  label="Negative Prompt",
122
  placeholder="low quality",
123
- value="nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
124
  )
125
  num_steps = gr.Slider(
126
  label="Number of sample steps",
@@ -129,27 +353,6 @@ with gr.Blocks(css=css) as demo:
129
  step=1,
130
  value=30,
131
  )
132
- identitynet_strength_ratio = gr.Slider(
133
- label="IdentityNet strength",
134
- minimum=0,
135
- maximum=1.5,
136
- step=0.05,
137
- value=0.65,
138
- )
139
- adapter_strength_ratio = gr.Slider(
140
- label="Image adapter strength",
141
- minimum=0,
142
- maximum=1,
143
- step=0.05,
144
- value=0.30,
145
- )
146
- num_outputs = gr.Slider(
147
- label="Number of output images",
148
- minimum=1,
149
- maximum=4,
150
- step=1,
151
- value=2,
152
- )
153
  guidance_scale = gr.Slider(
154
  label="Guidance scale",
155
  minimum=0.1,
@@ -165,6 +368,7 @@ with gr.Blocks(css=css) as demo:
165
  value=42,
166
  )
167
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
 
168
 
169
  with gr.Column():
170
  gallery = gr.Gallery(label="Generated Images")
@@ -187,10 +391,18 @@ with gr.Blocks(css=css) as demo:
187
  api_name=False,
188
  ).then(
189
  fn=generate_image,
190
- inputs=[face_files, pose_files, prompt, negative_prompt, num_steps, identitynet_strength_ratio, adapter_strength_ratio, num_outputs, guidance_scale, seed],
191
  outputs=[gallery, usage_tips]
192
  )
193
-
194
- gr.Markdown(article)
195
196
  demo.launch()
 
1
  import os
2
  import cv2
3
  import math
4
+ import torch
5
  import random
6
  import numpy as np
7
+
8
+ import PIL
9
  from PIL import Image
10
 
11
+ import diffusers
12
  from diffusers.utils import load_image
13
+ from diffusers.models import ControlNetModel
14
+
15
+ import insightface
16
+ from insightface.app import FaceAnalysis
17
 
18
+ from style_template import styles
19
+ from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline
20
+
21
+ import spaces
22
  import gradio as gr
23
 
24
  # global variable
25
  MAX_SEED = np.iinfo(np.int32).max
26
+ device = "cuda" if torch.cuda.is_available() else "cpu"
27
+ STYLE_NAMES = list(styles.keys())
28
+ DEFAULT_STYLE_NAME = "Watercolor"
29
+
30
+ # download checkpoints
31
+ from huggingface_hub import hf_hub_download
32
+ hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir="./checkpoints")
33
+ hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir="./checkpoints")
34
+ hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")
35
+
36
+ # Load face encoder
37
+ app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
38
+ app.prepare(ctx_id=0, det_size=(640, 640))
39
+
40
+ # Path to InstantID models
41
+ face_adapter = f'./checkpoints/ip-adapter.bin'
42
+ controlnet_path = f'./checkpoints/ControlNetModel'
43
+
44
+ # Load pipeline
45
+ controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
46
+
47
+ base_model_path = 'GHArt/Unstable_Diffusers_YamerMIX_V9_xl_fp16'
48
+
49
+ pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
50
+ base_model_path,
51
+ controlnet=controlnet,
52
+ torch_dtype=torch.float16,
53
+ safety_checker=None,
54
+ feature_extractor=None,
55
+ )
56
+ pipe.cuda()
57
+ pipe.load_ip_adapter_instantid(face_adapter)
58
 
59
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
60
  if randomize_seed:
 
73
  def remove_tips():
74
  return gr.update(visible=False)
75
 
76
+ def get_example():
77
+ case = [
78
+ [
79
+ ['./examples/yann-lecun_resize.jpg'],
80
+ "a man",
81
+ "Snow",
82
+ "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
83
+ ],
84
+ [
85
+ ['./examples/musk_resize.jpeg'],
86
+ "a man",
87
+ "Mars",
88
+ "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
89
+ ],
90
+ [
91
+ ['./examples/sam_resize.png'],
92
+ "a man",
93
+ "Jungle",
94
+ "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, gree",
95
+ ],
96
+ [
97
+ ['./examples/schmidhuber_resize.png'],
98
+ "a man",
99
+ "Neon",
100
+ "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
101
+ ],
102
+ [
103
+ ['./examples/kaifu_resize.png'],
104
+ "a man",
105
+ "Vibrant Color",
106
+ "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
107
+ ],
108
+ ]
109
+ return case
110
+
111
+ def convert_from_cv2_to_image(img: np.ndarray) -> Image:
112
+ return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
113
+
114
+ def convert_from_image_to_cv2(img: Image) -> np.ndarray:
115
+ return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
116
+
117
+ def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]):
118
+ stickwidth = 4
119
+ limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
120
+ kps = np.array(kps)
121
+
122
+ w, h = image_pil.size
123
+ out_img = np.zeros([h, w, 3])
124
+
125
+ for i in range(len(limbSeq)):
126
+ index = limbSeq[i]
127
+ color = color_list[index[0]]
128
+
129
+ x = kps[index][:, 0]
130
+ y = kps[index][:, 1]
131
+ length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
132
+ angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
133
+ polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
134
+ out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
135
+ out_img = (out_img * 0.6).astype(np.uint8)
136
+
137
+ for idx_kp, kp in enumerate(kps):
138
+ color = color_list[idx_kp]
139
+ x, y = kp
140
+ out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
141
+
142
+ out_img_pil = Image.fromarray(out_img.astype(np.uint8))
143
+ return out_img_pil
144
+
145
+ def resize_img(input_image, max_side=1280, min_side=1024, size=None,
146
+ pad_to_max_side=False, mode=PIL.Image.BILINEAR, base_pixel_number=64):
147
+
148
+ w, h = input_image.size
149
+ if size is not None:
150
+ w_resize_new, h_resize_new = size
151
+ else:
152
+ ratio = min_side / min(h, w)
153
+ w, h = round(ratio*w), round(ratio*h)
154
+ ratio = max_side / max(h, w)
155
+ input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
156
+ w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
157
+ h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
158
+ input_image = input_image.resize([w_resize_new, h_resize_new], mode)
159
+
160
+ if pad_to_max_side:
161
+ res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
162
+ offset_x = (max_side - w_resize_new) // 2
163
+ offset_y = (max_side - h_resize_new) // 2
164
+ res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
165
+ input_image = Image.fromarray(res)
166
+ return input_image
167
+
168
+ def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
169
+ p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
170
+ return p.replace("{prompt}", positive), n + ' ' + negative
171
+
172
+ @spaces.GPU
173
+ def generate_image(face_image, pose_image, prompt, negative_prompt, style_name, enhance_face_region, num_steps, identitynet_strength_ratio, adapter_strength_ratio, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
174
 
175
  if face_image is None:
176
  raise gr.Error(f"Cannot find any input face image! Please upload the face image")
177
+
178
+ if prompt is None:
179
+ prompt = "a person"
180
+
181
+ # apply the style template
182
+ prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
183
+
184
  face_image = load_image(face_image[0])
185
+ face_image = resize_img(face_image)
186
+ face_image_cv2 = convert_from_image_to_cv2(face_image)
187
+ height, width, _ = face_image_cv2.shape
188
+
189
+ # Extract face features
190
+ face_info = app.get(face_image_cv2)
191
+
192
+ if len(face_info) == 0:
193
+ raise gr.Error(f"Cannot find any face in the image! Please upload another person image")
194
+
195
+ face_info = face_info[-1]
196
+ face_emb = face_info['embedding']
197
+ face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info['kps'])
198
+
199
+ if pose_image is not None:
200
+ pose_image = load_image(pose_image[0])
201
+ pose_image = resize_img(pose_image)
202
+ pose_image_cv2 = convert_from_image_to_cv2(pose_image)
203
+
204
+ face_info = app.get(pose_image_cv2)
205
+
206
+ if len(face_info) == 0:
207
+ raise gr.Error(f"Cannot find any face in the reference image! Please upload another person image")
208
+
209
+ face_info = face_info[-1]
210
+ face_kps = draw_kps(pose_image, face_info['kps'])
211
+
212
+ width, height = face_kps.size
213
+
214
+ if enhance_face_region:
215
+ control_mask = np.zeros([height, width, 3])
216
+ x1, y1, x2, y2 = face_info['bbox']
217
+ x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
218
+ control_mask[y1:y2, x1:x2] = 255
219
+ control_mask = Image.fromarray(control_mask.astype(np.uint8))
220
+ else:
221
+ control_mask = None
222
+
223
+ generator = torch.Generator(device=device).manual_seed(seed)
224
+
225
+ print("Start inference...")
226
+ print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")
227
+
228
+ pipe.set_ip_adapter_scale(adapter_strength_ratio)
229
+ images = pipe(
230
+ prompt=prompt,
231
+ negative_prompt=negative_prompt,
232
+ image_embeds=face_emb,
233
+ image=face_kps,
234
+ control_mask=control_mask,
235
+ controlnet_conditioning_scale=float(identitynet_strength_ratio),
236
+ num_inference_steps=num_steps,
237
+ guidance_scale=guidance_scale,
238
+ height=height,
239
+ width=width,
240
+ generator=generator
241
+ ).images
242
 
243
+ return images, gr.update(visible=True)
244
 
245
  ### Description
246
  title = r"""
 
251
  <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/InstantID/InstantID' target='_blank'><b>InstantID: Zero-shot Identity-Preserving Generation in Seconds</b></a>.<br>
252
 
253
  How to use:<br>
254
+ 1. Upload a person image. For multiple person images, we will only detect the biggest face. Make sure the face is not too small and is not significantly blocked or blurred.
254
+ 2. (Optionally) upload another person image as a reference pose. If not uploaded, we will use the first person image to extract landmarks. If you used a cropped face in step 1, it is recommended to upload it to extract a new pose.
255
+ 3. Enter a text prompt as you would for a normal text-to-image model.
257
  4. Click the <b>Submit</b> button to start customizing.
258
  5. Share your customized photo with your friends, enjoy 😊!
259
  """
 
271
  year={2024}
272
  }
273
  ```
 
274
  📧 **Contact**
275
  <br>
276
  If you have any questions, please feel free to open an issue or directly reach us out at <b>haofanwang.ai@gmail.com</b>.
 
278
 
279
  tips = r"""
280
  ### Usage tips of InstantID
281
+ 1. If you're not satisfied with the similarity, increase the controlnet_conditioning_scale (IdentityNet) and ip_adapter_scale (Adapter) weights.
282
+ 2. If the generated image is over-saturated, decrease the ip_adapter_scale first. If that does not work, decrease the controlnet_conditioning_scale.
283
+ 3. If text control is not as expected, decrease ip_adapter_scale.
284
+ 4. Finding a good base model always makes a difference.
285
  """
286
 
287
  css = '''
 
317
  # prompt
318
  prompt = gr.Textbox(label="Prompt",
319
  info="Give simple prompt is enough to achieve good face fedility",
320
+ placeholder="A photo of a person",
321
+ value="")
322
+
323
+ submit = gr.Button("Submit", variant="primary")
324
+
325
+ style = gr.Dropdown(label="Style template", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
326
+
327
+ # strength
328
+ identitynet_strength_ratio = gr.Slider(
329
+ label="IdentityNet strength (for fedility)",
330
+ minimum=0,
331
+ maximum=1.5,
332
+ step=0.05,
333
+ value=0.80,
334
+ )
335
+ adapter_strength_ratio = gr.Slider(
336
+ label="Image adapter strength (for detail)",
337
+ minimum=0,
338
+ maximum=1.5,
339
+ step=0.05,
340
+ value=0.80,
341
+ )
342
+
343
  with gr.Accordion(open=False, label="Advanced Options"):
344
  negative_prompt = gr.Textbox(
345
  label="Negative Prompt",
346
  placeholder="low quality",
347
+ value="(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
348
  )
349
  num_steps = gr.Slider(
350
  label="Number of sample steps",
 
353
  step=1,
354
  value=30,
355
  )
356
  guidance_scale = gr.Slider(
357
  label="Guidance scale",
358
  minimum=0.1,
 
368
  value=42,
369
  )
370
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
371
+ enhance_face_region = gr.Checkbox(label="Enhance non-face region", value=True)
372
 
373
  with gr.Column():
374
  gallery = gr.Gallery(label="Generated Images")
 
391
  api_name=False,
392
  ).then(
393
  fn=generate_image,
394
+ inputs=[face_files, pose_files, prompt, negative_prompt, style, enhance_face_region, num_steps, identitynet_strength_ratio, adapter_strength_ratio, guidance_scale, seed],
395
  outputs=[gallery, usage_tips]
396
  )
 
 
397
 
398
+ gr.Examples(
399
+ examples=get_example(),
400
+ inputs=[face_files, prompt, style, negative_prompt],
401
+ run_on_click=True,
402
+ fn=upload_example_to_gallery,
403
+ outputs=[uploaded_faces, clear_button_face, face_files],
404
+ )
405
+
406
+ gr.Markdown(article)
407
+
408
  demo.launch()
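
Editor's note: for reference, the generation path added in the diff above (face embedding extraction with insightface, keypoint conditioning via draw_kps, then the InstantID pipeline call) can be exercised without the Gradio UI. A minimal sketch, assuming the module-level objects defined earlier in the new app.py (device, app, pipe, resize_img, draw_kps, convert_from_image_to_cv2) are in scope and that the example image path exists; the prompt, negative prompt, and guidance value are illustrative, not the Space's exact defaults.

import torch
from diffusers.utils import load_image

face_image = resize_img(load_image("./examples/musk_resize.jpeg"))  # example image shipped with the Space
face_cv2 = convert_from_image_to_cv2(face_image)

face_info = app.get(face_cv2)[-1]                   # app.py uses the last detected face
face_emb = face_info["embedding"]                   # 512-d identity embedding
face_kps = draw_kps(face_image, face_info["kps"])   # 5-point keypoint control image

pipe.set_ip_adapter_scale(0.8)                      # "Image adapter strength" slider
images = pipe(
    prompt="a man, watercolor painting",
    negative_prompt="lowres, low quality, worst quality",
    image_embeds=face_emb,                          # identity condition for the IP-Adapter branch
    image=face_kps,                                 # spatial condition for IdentityNet (ControlNet)
    controlnet_conditioning_scale=0.8,              # "IdentityNet strength" slider
    num_inference_steps=30,
    guidance_scale=5.0,
    generator=torch.Generator(device=device).manual_seed(42),
).images
images[0].save("result.jpg")
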
ip_adapter/attention_processor.py ADDED
@@ -0,0 +1,308 @@
1
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ try:
7
+ import xformers
8
+ import xformers.ops
9
+ xformers_available = True
10
+ except Exception as e:
11
+ xformers_available = False
12
+
13
+
14
+
15
+ class RegionControler(object):
16
+ def __init__(self) -> None:
17
+ self.prompt_image_conditioning = []
18
+ region_control = RegionControler()
19
+
20
+
21
+ class AttnProcessor(nn.Module):
22
+ r"""
23
+ Default processor for performing attention-related computations.
24
+ """
25
+ def __init__(
26
+ self,
27
+ hidden_size=None,
28
+ cross_attention_dim=None,
29
+ ):
30
+ super().__init__()
31
+
32
+ def __call__(
33
+ self,
34
+ attn,
35
+ hidden_states,
36
+ encoder_hidden_states=None,
37
+ attention_mask=None,
38
+ temb=None,
39
+ ):
40
+ residual = hidden_states
41
+
42
+ if attn.spatial_norm is not None:
43
+ hidden_states = attn.spatial_norm(hidden_states, temb)
44
+
45
+ input_ndim = hidden_states.ndim
46
+
47
+ if input_ndim == 4:
48
+ batch_size, channel, height, width = hidden_states.shape
49
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
50
+
51
+ batch_size, sequence_length, _ = (
52
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
53
+ )
54
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
55
+
56
+ if attn.group_norm is not None:
57
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
58
+
59
+ query = attn.to_q(hidden_states)
60
+
61
+ if encoder_hidden_states is None:
62
+ encoder_hidden_states = hidden_states
63
+ elif attn.norm_cross:
64
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
65
+
66
+ key = attn.to_k(encoder_hidden_states)
67
+ value = attn.to_v(encoder_hidden_states)
68
+
69
+ query = attn.head_to_batch_dim(query)
70
+ key = attn.head_to_batch_dim(key)
71
+ value = attn.head_to_batch_dim(value)
72
+
73
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
74
+ hidden_states = torch.bmm(attention_probs, value)
75
+ hidden_states = attn.batch_to_head_dim(hidden_states)
76
+
77
+ # linear proj
78
+ hidden_states = attn.to_out[0](hidden_states)
79
+ # dropout
80
+ hidden_states = attn.to_out[1](hidden_states)
81
+
82
+ if input_ndim == 4:
83
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
84
+
85
+ if attn.residual_connection:
86
+ hidden_states = hidden_states + residual
87
+
88
+ hidden_states = hidden_states / attn.rescale_output_factor
89
+
90
+ return hidden_states
91
+
92
+
93
+ class IPAttnProcessor(nn.Module):
94
+ r"""
95
+ Attention processor for IP-Adapter.
96
+ Args:
97
+ hidden_size (`int`):
98
+ The hidden size of the attention layer.
99
+ cross_attention_dim (`int`):
100
+ The number of channels in the `encoder_hidden_states`.
101
+ scale (`float`, defaults to 1.0):
102
+ the weight scale of image prompt.
103
+ num_tokens (`int`, defaults to 4; for ip_adapter_plus it should be 16):
104
+ The context length of the image features.
105
+ """
106
+
107
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
108
+ super().__init__()
109
+
110
+ self.hidden_size = hidden_size
111
+ self.cross_attention_dim = cross_attention_dim
112
+ self.scale = scale
113
+ self.num_tokens = num_tokens
114
+
115
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
116
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
117
+
118
+ def __call__(
119
+ self,
120
+ attn,
121
+ hidden_states,
122
+ encoder_hidden_states=None,
123
+ attention_mask=None,
124
+ temb=None,
125
+ ):
126
+ residual = hidden_states
127
+
128
+ if attn.spatial_norm is not None:
129
+ hidden_states = attn.spatial_norm(hidden_states, temb)
130
+
131
+ input_ndim = hidden_states.ndim
132
+
133
+ if input_ndim == 4:
134
+ batch_size, channel, height, width = hidden_states.shape
135
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
136
+
137
+ batch_size, sequence_length, _ = (
138
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
139
+ )
140
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
141
+
142
+ if attn.group_norm is not None:
143
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
144
+
145
+ query = attn.to_q(hidden_states)
146
+
147
+ if encoder_hidden_states is None:
148
+ encoder_hidden_states = hidden_states
149
+ else:
150
+ # get encoder_hidden_states, ip_hidden_states
151
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
152
+ encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
153
+ if attn.norm_cross:
154
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
155
+
156
+ key = attn.to_k(encoder_hidden_states)
157
+ value = attn.to_v(encoder_hidden_states)
158
+
159
+ query = attn.head_to_batch_dim(query)
160
+ key = attn.head_to_batch_dim(key)
161
+ value = attn.head_to_batch_dim(value)
162
+
163
+ if xformers_available:
164
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
165
+ else:
166
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
167
+ hidden_states = torch.bmm(attention_probs, value)
168
+ hidden_states = attn.batch_to_head_dim(hidden_states)
169
+
170
+ # for ip-adapter
171
+ ip_key = self.to_k_ip(ip_hidden_states)
172
+ ip_value = self.to_v_ip(ip_hidden_states)
173
+
174
+ ip_key = attn.head_to_batch_dim(ip_key)
175
+ ip_value = attn.head_to_batch_dim(ip_value)
176
+
177
+ if xformers_available:
178
+ ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
179
+ else:
180
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
181
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
182
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
183
+
184
+ # region control
185
+ if len(region_control.prompt_image_conditioning) == 1:
186
+ region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
187
+ if region_mask is not None:
188
+ h, w = region_mask.shape[:2]
189
+ ratio = (h * w / query.shape[1]) ** 0.5
190
+ mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
191
+ else:
192
+ mask = torch.ones_like(ip_hidden_states)
193
+ ip_hidden_states = ip_hidden_states * mask
194
+
195
+ hidden_states = hidden_states + self.scale * ip_hidden_states
196
+
197
+ # linear proj
198
+ hidden_states = attn.to_out[0](hidden_states)
199
+ # dropout
200
+ hidden_states = attn.to_out[1](hidden_states)
201
+
202
+ if input_ndim == 4:
203
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
204
+
205
+ if attn.residual_connection:
206
+ hidden_states = hidden_states + residual
207
+
208
+ hidden_states = hidden_states / attn.rescale_output_factor
209
+
210
+ return hidden_states
211
+
212
+
213
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
214
+ # TODO attention_mask
215
+ query = query.contiguous()
216
+ key = key.contiguous()
217
+ value = value.contiguous()
218
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
219
+ # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
220
+ return hidden_states
221
+
222
+
223
+ class AttnProcessor2_0(torch.nn.Module):
224
+ r"""
225
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
226
+ """
227
+ def __init__(
228
+ self,
229
+ hidden_size=None,
230
+ cross_attention_dim=None,
231
+ ):
232
+ super().__init__()
233
+ if not hasattr(F, "scaled_dot_product_attention"):
234
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
235
+
236
+ def __call__(
237
+ self,
238
+ attn,
239
+ hidden_states,
240
+ encoder_hidden_states=None,
241
+ attention_mask=None,
242
+ temb=None,
243
+ ):
244
+ residual = hidden_states
245
+
246
+ if attn.spatial_norm is not None:
247
+ hidden_states = attn.spatial_norm(hidden_states, temb)
248
+
249
+ input_ndim = hidden_states.ndim
250
+
251
+ if input_ndim == 4:
252
+ batch_size, channel, height, width = hidden_states.shape
253
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
254
+
255
+ batch_size, sequence_length, _ = (
256
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
257
+ )
258
+
259
+ if attention_mask is not None:
260
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
261
+ # scaled_dot_product_attention expects attention_mask shape to be
262
+ # (batch, heads, source_length, target_length)
263
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
264
+
265
+ if attn.group_norm is not None:
266
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
267
+
268
+ query = attn.to_q(hidden_states)
269
+
270
+ if encoder_hidden_states is None:
271
+ encoder_hidden_states = hidden_states
272
+ elif attn.norm_cross:
273
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
274
+
275
+ key = attn.to_k(encoder_hidden_states)
276
+ value = attn.to_v(encoder_hidden_states)
277
+
278
+ inner_dim = key.shape[-1]
279
+ head_dim = inner_dim // attn.heads
280
+
281
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
282
+
283
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
284
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
285
+
286
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
287
+ # TODO: add support for attn.scale when we move to Torch 2.1
288
+ hidden_states = F.scaled_dot_product_attention(
289
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
290
+ )
291
+
292
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
293
+ hidden_states = hidden_states.to(query.dtype)
294
+
295
+ # linear proj
296
+ hidden_states = attn.to_out[0](hidden_states)
297
+ # dropout
298
+ hidden_states = attn.to_out[1](hidden_states)
299
+
300
+ if input_ndim == 4:
301
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
302
+
303
+ if attn.residual_connection:
304
+ hidden_states = hidden_states + residual
305
+
306
+ hidden_states = hidden_states / attn.rescale_output_factor
307
+
308
+ return hidden_states
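
Editor's note: the split between AttnProcessor (self-attention) and IPAttnProcessor (cross-attention carrying the extra identity tokens) is wired into the SDXL UNet by the pipeline added later in this commit. A condensed sketch of that wiring, mirroring set_ip_adapter in pipeline_stable_diffusion_xl_instantid.py; the helper name and default arguments are illustrative.

from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor

def attach_ip_adapter_processors(unet, num_tokens=16, scale=0.5):
    attn_procs = {}
    for name in unet.attn_processors.keys():
        # "attn1" layers are self-attention and keep the default processor;
        # "attn2" (cross-attention) layers get the IP-Adapter variant.
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype)
        else:
            attn_procs[name] = IPAttnProcessor(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                scale=scale,
                num_tokens=num_tokens,
            ).to(unet.device, dtype=unet.dtype)
    unet.set_attn_processor(attn_procs)
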
ip_adapter/resampler.py ADDED
@@ -0,0 +1,121 @@
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
+ # FFN
9
+ def FeedForward(dim, mult=4):
10
+ inner_dim = int(dim * mult)
11
+ return nn.Sequential(
12
+ nn.LayerNorm(dim),
13
+ nn.Linear(dim, inner_dim, bias=False),
14
+ nn.GELU(),
15
+ nn.Linear(inner_dim, dim, bias=False),
16
+ )
17
+
18
+
19
+ def reshape_tensor(x, heads):
20
+ bs, length, width = x.shape
21
+ #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
22
+ x = x.view(bs, length, heads, -1)
23
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
24
+ x = x.transpose(1, 2)
25
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
26
+ x = x.reshape(bs, heads, length, -1)
27
+ return x
28
+
29
+
30
+ class PerceiverAttention(nn.Module):
31
+ def __init__(self, *, dim, dim_head=64, heads=8):
32
+ super().__init__()
33
+ self.scale = dim_head**-0.5
34
+ self.dim_head = dim_head
35
+ self.heads = heads
36
+ inner_dim = dim_head * heads
37
+
38
+ self.norm1 = nn.LayerNorm(dim)
39
+ self.norm2 = nn.LayerNorm(dim)
40
+
41
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
42
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
43
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
44
+
45
+
46
+ def forward(self, x, latents):
47
+ """
48
+ Args:
49
+ x (torch.Tensor): image features
50
+ shape (b, n1, D)
51
+ latent (torch.Tensor): latent features
52
+ shape (b, n2, D)
53
+ """
54
+ x = self.norm1(x)
55
+ latents = self.norm2(latents)
56
+
57
+ b, l, _ = latents.shape
58
+
59
+ q = self.to_q(latents)
60
+ kv_input = torch.cat((x, latents), dim=-2)
61
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
62
+
63
+ q = reshape_tensor(q, self.heads)
64
+ k = reshape_tensor(k, self.heads)
65
+ v = reshape_tensor(v, self.heads)
66
+
67
+ # attention
68
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
69
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
70
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
71
+ out = weight @ v
72
+
73
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
74
+
75
+ return self.to_out(out)
76
+
77
+
78
+ class Resampler(nn.Module):
79
+ def __init__(
80
+ self,
81
+ dim=1024,
82
+ depth=8,
83
+ dim_head=64,
84
+ heads=16,
85
+ num_queries=8,
86
+ embedding_dim=768,
87
+ output_dim=1024,
88
+ ff_mult=4,
89
+ ):
90
+ super().__init__()
91
+
92
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
93
+
94
+ self.proj_in = nn.Linear(embedding_dim, dim)
95
+
96
+ self.proj_out = nn.Linear(dim, output_dim)
97
+ self.norm_out = nn.LayerNorm(output_dim)
98
+
99
+ self.layers = nn.ModuleList([])
100
+ for _ in range(depth):
101
+ self.layers.append(
102
+ nn.ModuleList(
103
+ [
104
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
105
+ FeedForward(dim=dim, mult=ff_mult),
106
+ ]
107
+ )
108
+ )
109
+
110
+ def forward(self, x):
111
+
112
+ latents = self.latents.repeat(x.size(0), 1, 1)
113
+
114
+ x = self.proj_in(x)
115
+
116
+ for attn, ff in self.layers:
117
+ latents = attn(x, latents) + latents
118
+ latents = ff(latents) + latents
119
+
120
+ latents = self.proj_out(latents)
121
+ return self.norm_out(latents)
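
Editor's note: in InstantID this Resampler serves as the image projection model: it maps the single 512-d insightface face embedding to num_queries identity tokens at the UNet's cross-attention width, which the IPAttnProcessor above then consumes. A small sketch using the configuration that set_image_proj_model in the pipeline file below passes in; the random input and the 2048 output_dim (SDXL's cross_attention_dim) are stated here as assumptions for illustration.

import torch
from ip_adapter.resampler import Resampler

image_proj_model = Resampler(
    dim=1280, depth=4, dim_head=64, heads=20,
    num_queries=16,          # num_tokens: one identity token per learned query
    embedding_dim=512,       # insightface 'antelopev2' face embedding size
    output_dim=2048,         # SDXL UNet cross_attention_dim
    ff_mult=4,
)

face_emb = torch.randn(1, 1, 512)      # (batch, sequence, embedding_dim)
tokens = image_proj_model(face_emb)    # -> (1, 16, 2048), appended to the text tokens
print(tokens.shape)
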
ip_adapter/utils.py ADDED
@@ -0,0 +1,5 @@
1
+ import torch.nn.functional as F
2
+
3
+
4
+ def is_torch2_available():
5
+ return hasattr(F, "scaled_dot_product_attention")
models/antelopev2/1k3d68.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc
3
+ size 143607619
models/antelopev2/2d106det.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf
3
+ size 5030888
models/antelopev2/genderage.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb
3
+ size 1322532
models/antelopev2/glintr100.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ab1d6435d639628a6f3e5008dd4f929edf4c4124b1a7169e1048f9fef534cdf
3
+ size 260665334
models/antelopev2/scrfd_10g_bnkps.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91
3
+ size 16923827
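
Editor's note: these five LFS-tracked ONNX files form the 'antelopev2' insightface model pack that app.py loads: scrfd_10g_bnkps is the face detector, glintr100 the recognition model producing the 512-d embedding, 2d106det and 1k3d68 are landmark models, and genderage predicts gender and age. A minimal loading sketch, assuming root='./' contains models/antelopev2/ and that the example image path exists locally.

import cv2
from insightface.app import FaceAnalysis

app = FaceAnalysis(name="antelopev2", root="./",
                   providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
app.prepare(ctx_id=0, det_size=(640, 640))

img = cv2.imread("./examples/musk_resize.jpeg")    # BGR image, path is an assumption
faces = app.get(img)                               # scrfd detects, glintr100 embeds
print(len(faces), faces[-1]["embedding"].shape)    # embedding is 512-d
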
pipeline_stable_diffusion_xl_instantid.py ADDED
@@ -0,0 +1,1134 @@
1
+ # Copyright 2024 The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import cv2
19
+ import math
20
+
21
+ import numpy as np
22
+ import PIL.Image
23
+ import torch
24
+ import torch.nn.functional as F
25
+
26
+ from diffusers.image_processor import PipelineImageInput
27
+
28
+ from diffusers.models import ControlNetModel
29
+
30
+ from diffusers.utils import (
31
+ deprecate,
32
+ logging,
33
+ replace_example_docstring,
34
+ )
35
+ from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
36
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
37
+
38
+ from diffusers import StableDiffusionXLControlNetPipeline
39
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
40
+ from diffusers.utils.import_utils import is_xformers_available
41
+
42
+ from ip_adapter.resampler import Resampler
43
+ from ip_adapter.utils import is_torch2_available
44
+
45
+ if is_torch2_available():
46
+ from ip_adapter.attention_processor import (
47
+ AttnProcessor2_0 as AttnProcessor,
48
+ )
49
+ from ip_adapter.attention_processor import (
50
+ IPAttnProcessor2_0 as IPAttnProcessor,
51
+ )
52
+ else:
53
+ from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
54
+ from ip_adapter.attention_processor import region_control
55
+
56
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
57
+
58
+
59
+ EXAMPLE_DOC_STRING = """
60
+ Examples:
61
+ ```py
62
+ >>> # !pip install opencv-python transformers accelerate insightface
63
+ >>> import diffusers
64
+ >>> from diffusers.utils import load_image
65
+ >>> from diffusers.models import ControlNetModel
66
+
67
+ >>> import cv2
68
+ >>> import torch
69
+ >>> import numpy as np
70
+ >>> from PIL import Image
71
+
72
+ >>> from insightface.app import FaceAnalysis
73
+ >>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
74
+
75
+ >>> # download 'antelopev2' under ./models
76
+ >>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
77
+ >>> app.prepare(ctx_id=0, det_size=(640, 640))
78
+
79
+ >>> # download models under ./checkpoints
80
+ >>> face_adapter = f'./checkpoints/ip-adapter.bin'
81
+ >>> controlnet_path = f'./checkpoints/ControlNetModel'
82
+
83
+ >>> # load IdentityNet
84
+ >>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
85
+
86
+ >>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
87
+ ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
88
+ ... )
89
+ >>> pipe.cuda()
90
+
91
+ >>> # load adapter
92
+ >>> pipe.load_ip_adapter_instantid(face_adapter)
93
+
94
+ >>> prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality"
95
+ >>> negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured"
96
+
97
+ >>> # load an image
98
+ >>> image = load_image("your-example.jpg")
99
+
100
+ >>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1]
101
+ >>> face_emb = face_info['embedding']
102
+ >>> face_kps = draw_kps(face_image, face_info['kps'])
103
+
104
+ >>> pipe.set_ip_adapter_scale(0.8)
105
+
106
+ >>> # generate image
107
+ >>> image = pipe(
108
+ ... prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8
109
+ ... ).images[0]
110
+ ```
111
+ """
112
+
113
+
114
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
115
+ class LongPromptWeight(object):
116
+
117
+ """
118
+ Copied from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion_xl.py
119
+ """
120
+
121
+ def __init__(self) -> None:
122
+ pass
123
+
124
+ def parse_prompt_attention(self, text):
125
+ """
126
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
127
+ Accepted tokens are:
128
+ (abc) - increases attention to abc by a multiplier of 1.1
129
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
130
+ [abc] - decreases attention to abc by a multiplier of 1.1
131
+ \( - literal character '('
132
+ \[ - literal character '['
133
+ \) - literal character ')'
134
+ \] - literal character ']'
135
+ \\ - literal character '\'
136
+ anything else - just text
137
+
138
+ >>> parse_prompt_attention('normal text')
139
+ [['normal text', 1.0]]
140
+ >>> parse_prompt_attention('an (important) word')
141
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
142
+ >>> parse_prompt_attention('(unbalanced')
143
+ [['unbalanced', 1.1]]
144
+ >>> parse_prompt_attention('\(literal\]')
145
+ [['(literal]', 1.0]]
146
+ >>> parse_prompt_attention('(unnecessary)(parens)')
147
+ [['unnecessaryparens', 1.1]]
148
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
149
+ [['a ', 1.0],
150
+ ['house', 1.5730000000000004],
151
+ [' ', 1.1],
152
+ ['on', 1.0],
153
+ [' a ', 1.1],
154
+ ['hill', 0.55],
155
+ [', sun, ', 1.1],
156
+ ['sky', 1.4641000000000006],
157
+ ['.', 1.1]]
158
+ """
159
+ import re
160
+
161
+ re_attention = re.compile(
162
+ r"""
163
+ \\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|
164
+ \)|]|[^\\()\[\]:]+|:
165
+ """,
166
+ re.X,
167
+ )
168
+
169
+ re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
170
+
171
+ res = []
172
+ round_brackets = []
173
+ square_brackets = []
174
+
175
+ round_bracket_multiplier = 1.1
176
+ square_bracket_multiplier = 1 / 1.1
177
+
178
+ def multiply_range(start_position, multiplier):
179
+ for p in range(start_position, len(res)):
180
+ res[p][1] *= multiplier
181
+
182
+ for m in re_attention.finditer(text):
183
+ text = m.group(0)
184
+ weight = m.group(1)
185
+
186
+ if text.startswith("\\"):
187
+ res.append([text[1:], 1.0])
188
+ elif text == "(":
189
+ round_brackets.append(len(res))
190
+ elif text == "[":
191
+ square_brackets.append(len(res))
192
+ elif weight is not None and len(round_brackets) > 0:
193
+ multiply_range(round_brackets.pop(), float(weight))
194
+ elif text == ")" and len(round_brackets) > 0:
195
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
196
+ elif text == "]" and len(square_brackets) > 0:
197
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
198
+ else:
199
+ parts = re.split(re_break, text)
200
+ for i, part in enumerate(parts):
201
+ if i > 0:
202
+ res.append(["BREAK", -1])
203
+ res.append([part, 1.0])
204
+
205
+ for pos in round_brackets:
206
+ multiply_range(pos, round_bracket_multiplier)
207
+
208
+ for pos in square_brackets:
209
+ multiply_range(pos, square_bracket_multiplier)
210
+
211
+ if len(res) == 0:
212
+ res = [["", 1.0]]
213
+
214
+ # merge runs of identical weights
215
+ i = 0
216
+ while i + 1 < len(res):
217
+ if res[i][1] == res[i + 1][1]:
218
+ res[i][0] += res[i + 1][0]
219
+ res.pop(i + 1)
220
+ else:
221
+ i += 1
222
+
223
+ return res
224
+
225
+ def get_prompts_tokens_with_weights(self, clip_tokenizer: CLIPTokenizer, prompt: str):
226
+ """
227
+ Get prompt token ids and weights, this function works for both prompt and negative prompt
228
+
229
+ Args:
230
+ pipe (CLIPTokenizer)
231
+ A CLIPTokenizer
232
+ prompt (str)
233
+ A prompt string with weights
234
+
235
+ Returns:
236
+ text_tokens (list)
237
+ A list containing token ids
239
+ text_weight (list)
240
+ A list containing the corresponding weight of each token id
240
+
241
+ Example:
242
+ import torch
243
+ from transformers import CLIPTokenizer
244
+
245
+ clip_tokenizer = CLIPTokenizer.from_pretrained(
246
+ "stablediffusionapi/deliberate-v2"
247
+ , subfolder = "tokenizer"
248
+ , dtype = torch.float16
249
+ )
250
+
251
+ token_id_list, token_weight_list = get_prompts_tokens_with_weights(
252
+ clip_tokenizer = clip_tokenizer
253
+ ,prompt = "a (red:1.5) cat"*70
254
+ )
255
+ """
256
+ texts_and_weights = self.parse_prompt_attention(prompt)
257
+ text_tokens, text_weights = [], []
258
+ for word, weight in texts_and_weights:
259
+ # tokenize and discard the starting and the ending token
260
+ token = clip_tokenizer(word, truncation=False).input_ids[1:-1] # so that prompts of any length can be tokenized
261
+ # the returned token is a 1d list: [320, 1125, 539, 320]
262
+
263
+ # merge the new tokens to the all tokens holder: text_tokens
264
+ text_tokens = [*text_tokens, *token]
265
+
266
+ # each token chunk will come with one weight, like ['red cat', 2.0]
267
+ # need to expand weight for each token.
268
+ chunk_weights = [weight] * len(token)
269
+
270
+ # append the weight back to the weight holder: text_weights
271
+ text_weights = [*text_weights, *chunk_weights]
272
+ return text_tokens, text_weights
273
+
274
+ def group_tokens_and_weights(self, token_ids: list, weights: list, pad_last_block=False):
275
+ """
276
+ Produce tokens and weights in groups and pad the missing tokens
277
+
278
+ Args:
279
+ token_ids (list)
280
+ The token ids from tokenizer
281
+ weights (list)
282
+ The weights list from function get_prompts_tokens_with_weights
283
+ pad_last_block (bool)
284
+ Controls whether to pad the last token chunk to 75 tokens with eos
285
+ Returns:
286
+ new_token_ids (2d list)
287
+ new_weights (2d list)
288
+
289
+ Example:
290
+ token_groups,weight_groups = group_tokens_and_weights(
291
+ token_ids = token_id_list
292
+ , weights = token_weight_list
293
+ )
294
+ """
295
+ bos, eos = 49406, 49407
296
+
297
+ # this will be a 2d list
298
+ new_token_ids = []
299
+ new_weights = []
300
+ while len(token_ids) >= 75:
301
+ # get the first 75 tokens
302
+ head_75_tokens = [token_ids.pop(0) for _ in range(75)]
303
+ head_75_weights = [weights.pop(0) for _ in range(75)]
304
+
305
+ # extract token ids and weights
306
+ temp_77_token_ids = [bos] + head_75_tokens + [eos]
307
+ temp_77_weights = [1.0] + head_75_weights + [1.0]
308
+
309
+ # add 77 token and weights chunk to the holder list
310
+ new_token_ids.append(temp_77_token_ids)
311
+ new_weights.append(temp_77_weights)
312
+
313
+ # padding the left
314
+ if len(token_ids) >= 0:
315
+ padding_len = 75 - len(token_ids) if pad_last_block else 0
316
+
317
+ temp_77_token_ids = [bos] + token_ids + [eos] * padding_len + [eos]
318
+ new_token_ids.append(temp_77_token_ids)
319
+
320
+ temp_77_weights = [1.0] + weights + [1.0] * padding_len + [1.0]
321
+ new_weights.append(temp_77_weights)
322
+
323
+ return new_token_ids, new_weights
324
+
325
+ def get_weighted_text_embeddings_sdxl(
326
+ self,
327
+ pipe: StableDiffusionXLPipeline,
328
+ prompt: str = "",
329
+ prompt_2: str = None,
330
+ neg_prompt: str = "",
331
+ neg_prompt_2: str = None,
332
+ prompt_embeds=None,
333
+ negative_prompt_embeds=None,
334
+ pooled_prompt_embeds=None,
335
+ negative_pooled_prompt_embeds=None,
336
+ extra_emb=None,
337
+ extra_emb_alpha=0.6,
338
+ ):
339
+ """
340
+ This function can process long prompt with weights, no length limitation
341
+ for Stable Diffusion XL
342
+
343
+ Args:
344
+ pipe (StableDiffusionPipeline)
345
+ prompt (str)
346
+ prompt_2 (str)
347
+ neg_prompt (str)
348
+ neg_prompt_2 (str)
349
+ Returns:
350
+ prompt_embeds (torch.Tensor)
351
+ neg_prompt_embeds (torch.Tensor)
352
+ """
353
+ #
354
+ if prompt_embeds is not None and \
355
+ negative_prompt_embeds is not None and \
356
+ pooled_prompt_embeds is not None and \
357
+ negative_pooled_prompt_embeds is not None:
358
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
359
+
360
+ if prompt_2:
361
+ prompt = f"{prompt} {prompt_2}"
362
+
363
+ if neg_prompt_2:
364
+ neg_prompt = f"{neg_prompt} {neg_prompt_2}"
365
+
366
+ eos = pipe.tokenizer.eos_token_id
367
+
368
+ # tokenizer 1
369
+ prompt_tokens, prompt_weights = self.get_prompts_tokens_with_weights(pipe.tokenizer, prompt)
370
+ neg_prompt_tokens, neg_prompt_weights = self.get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt)
371
+
372
+ # tokenizer 2
373
+ # prompt_tokens_2, prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer_2, prompt)
374
+ # neg_prompt_tokens_2, neg_prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer_2, neg_prompt)
375
+ # tokenizer 2 behaves inconsistently with tokenizer 1 on runs of exclamation marks such as !! or !!!!, so tokenizer 1 is reused here
376
+ prompt_tokens_2, prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer, prompt)
377
+ neg_prompt_tokens_2, neg_prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt)
378
+
379
+ # padding the shorter one for prompt set 1
380
+ prompt_token_len = len(prompt_tokens)
381
+ neg_prompt_token_len = len(neg_prompt_tokens)
382
+
383
+ if prompt_token_len > neg_prompt_token_len:
384
+ # padding the neg_prompt with eos token
385
+ neg_prompt_tokens = neg_prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len)
386
+ neg_prompt_weights = neg_prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
387
+ else:
388
+ # padding the prompt
389
+ prompt_tokens = prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len)
390
+ prompt_weights = prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
391
+
392
+ # padding the shorter one for token set 2
393
+ prompt_token_len_2 = len(prompt_tokens_2)
394
+ neg_prompt_token_len_2 = len(neg_prompt_tokens_2)
395
+
396
+ if prompt_token_len_2 > neg_prompt_token_len_2:
397
+ # padding the neg_prompt with eos token
398
+ neg_prompt_tokens_2 = neg_prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
399
+ neg_prompt_weights_2 = neg_prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
400
+ else:
401
+ # padding the prompt
402
+ prompt_tokens_2 = prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
403
+ prompt_weights_2 = prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
404
+
405
+ embeds = []
406
+ neg_embeds = []
407
+
408
+ prompt_token_groups, prompt_weight_groups = self.group_tokens_and_weights(prompt_tokens.copy(), prompt_weights.copy())
409
+
410
+ neg_prompt_token_groups, neg_prompt_weight_groups = self.group_tokens_and_weights(
411
+ neg_prompt_tokens.copy(), neg_prompt_weights.copy()
412
+ )
413
+
414
+ prompt_token_groups_2, prompt_weight_groups_2 = self.group_tokens_and_weights(
415
+ prompt_tokens_2.copy(), prompt_weights_2.copy()
416
+ )
417
+
418
+ neg_prompt_token_groups_2, neg_prompt_weight_groups_2 = self.group_tokens_and_weights(
419
+ neg_prompt_tokens_2.copy(), neg_prompt_weights_2.copy()
420
+ )
421
+
422
+ # getting prompt embeddings one by one does not work; process each 77-token chunk in this loop
423
+ for i in range(len(prompt_token_groups)):
424
+ # get positive prompt embeddings with weights
425
+ token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
426
+ weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
427
+
428
+ token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
429
+
430
+ # use first text encoder
431
+ prompt_embeds_1 = pipe.text_encoder(token_tensor.to(pipe.device), output_hidden_states=True)
432
+ prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
433
+
434
+ # use second text encoder
435
+ prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(pipe.device), output_hidden_states=True)
436
+ prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
437
+ pooled_prompt_embeds = prompt_embeds_2[0]
438
+
439
+ prompt_embeds_list = [prompt_embeds_1_hidden_states, prompt_embeds_2_hidden_states]
440
+ token_embedding = torch.concat(prompt_embeds_list, dim=-1).squeeze(0)
441
+
442
+ for j in range(len(weight_tensor)):
443
+ if weight_tensor[j] != 1.0:
444
+ token_embedding[j] = (
445
+ token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight_tensor[j]
446
+ )
447
+
448
+ token_embedding = token_embedding.unsqueeze(0)
449
+ embeds.append(token_embedding)
450
+
451
+ # get negative prompt embeddings with weights
452
+ neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
453
+ neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
454
+ neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
455
+
456
+ # use first text encoder
457
+ neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(pipe.device), output_hidden_states=True)
458
+ neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]
459
+
460
+ # use second text encoder
461
+ neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(pipe.device), output_hidden_states=True)
462
+ neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
463
+ negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]
464
+
465
+ neg_prompt_embeds_list = [neg_prompt_embeds_1_hidden_states, neg_prompt_embeds_2_hidden_states]
466
+ neg_token_embedding = torch.concat(neg_prompt_embeds_list, dim=-1).squeeze(0)
467
+
468
+ for z in range(len(neg_weight_tensor)):
469
+ if neg_weight_tensor[z] != 1.0:
470
+ neg_token_embedding[z] = (
471
+ neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight_tensor[z]
472
+ )
473
+
474
+ neg_token_embedding = neg_token_embedding.unsqueeze(0)
475
+ neg_embeds.append(neg_token_embedding)
476
+
477
+ prompt_embeds = torch.cat(embeds, dim=1)
478
+ negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
479
+
480
+ if extra_emb is not None:
481
+ extra_emb = extra_emb.to(prompt_embeds.device, dtype=prompt_embeds.dtype) * extra_emb_alpha
482
+ prompt_embeds = torch.cat([prompt_embeds, extra_emb], 1)
483
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, torch.zeros_like(extra_emb)], 1)
484
+ print(f'fix prompt_embeds, extra_emb_alpha={extra_emb_alpha}')
485
+
486
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
487
+
488
+ def get_prompt_embeds(self, *args, **kwargs):
489
+ prompt_embeds, negative_prompt_embeds, _, _ = self.get_weighted_text_embeddings_sdxl(*args, **kwargs)
490
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
491
+ return prompt_embeds
492
+
493
+
494
+ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
495
+
496
+ def cuda(self, dtype=torch.float16, use_xformers=False):
497
+ self.to('cuda', dtype)
498
+
499
+ if hasattr(self, 'image_proj_model'):
500
+ self.image_proj_model.to(self.unet.device).to(self.unet.dtype)
501
+
502
+ if use_xformers:
503
+ if is_xformers_available():
504
+ import xformers
505
+ from packaging import version
506
+
507
+ xformers_version = version.parse(xformers.__version__)
508
+ if xformers_version == version.parse("0.0.16"):
509
+ logger.warn(
510
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
511
+ )
512
+ self.enable_xformers_memory_efficient_attention()
513
+ else:
514
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
515
+
516
+ def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5):
517
+ self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens)
518
+ self.set_ip_adapter(model_ckpt, num_tokens, scale)
519
+
520
+ def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16):
521
+
522
+ image_proj_model = Resampler(
523
+ dim=1280,
524
+ depth=4,
525
+ dim_head=64,
526
+ heads=20,
527
+ num_queries=num_tokens,
528
+ embedding_dim=image_emb_dim,
529
+ output_dim=self.unet.config.cross_attention_dim,
530
+ ff_mult=4,
531
+ )
532
+
533
+ image_proj_model.eval()
534
+
535
+ self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype)
536
+ state_dict = torch.load(model_ckpt, map_location="cpu")
537
+ if 'image_proj' in state_dict:
538
+ state_dict = state_dict["image_proj"]
539
+ self.image_proj_model.load_state_dict(state_dict)
540
+
541
+ self.image_proj_model_in_features = image_emb_dim
542
+
543
+ def set_ip_adapter(self, model_ckpt, num_tokens, scale):
544
+
545
+ unet = self.unet
546
+ attn_procs = {}
547
+ for name in unet.attn_processors.keys():
548
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
549
+ if name.startswith("mid_block"):
550
+ hidden_size = unet.config.block_out_channels[-1]
551
+ elif name.startswith("up_blocks"):
552
+ block_id = int(name[len("up_blocks.")])
553
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
554
+ elif name.startswith("down_blocks"):
555
+ block_id = int(name[len("down_blocks.")])
556
+ hidden_size = unet.config.block_out_channels[block_id]
557
+ if cross_attention_dim is None:
558
+ attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype)
559
+ else:
560
+ attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size,
561
+ cross_attention_dim=cross_attention_dim,
562
+ scale=scale,
563
+ num_tokens=num_tokens).to(unet.device, dtype=unet.dtype)
564
+ unet.set_attn_processor(attn_procs)
565
+
566
+ state_dict = torch.load(model_ckpt, map_location="cpu")
567
+ ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
568
+ if 'ip_adapter' in state_dict:
569
+ state_dict = state_dict['ip_adapter']
570
+ ip_layers.load_state_dict(state_dict)
571
+
572
+ def set_ip_adapter_scale(self, scale):
573
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
574
+ for attn_processor in unet.attn_processors.values():
575
+ if isinstance(attn_processor, IPAttnProcessor):
576
+ attn_processor.scale = scale
577
+
578
+ def _encode_prompt_image_emb(self, prompt_image_emb, device, dtype, do_classifier_free_guidance):
579
+
580
+ if isinstance(prompt_image_emb, torch.Tensor):
581
+ prompt_image_emb = prompt_image_emb.clone().detach()
582
+ else:
583
+ prompt_image_emb = torch.tensor(prompt_image_emb)
584
+
585
+ prompt_image_emb = prompt_image_emb.to(device=device, dtype=dtype)
586
+ prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features])
587
+
588
+ if do_classifier_free_guidance:
589
+ prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0)
590
+ else:
591
+ prompt_image_emb = torch.cat([prompt_image_emb], dim=0)
592
+
593
+ prompt_image_emb = self.image_proj_model(prompt_image_emb)
594
+ return prompt_image_emb
595
+
596
+ @torch.no_grad()
597
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
598
+ def __call__(
599
+ self,
600
+ prompt: Union[str, List[str]] = None,
601
+ prompt_2: Optional[Union[str, List[str]]] = None,
602
+ image: PipelineImageInput = None,
603
+ height: Optional[int] = None,
604
+ width: Optional[int] = None,
605
+ num_inference_steps: int = 50,
606
+ guidance_scale: float = 5.0,
607
+ negative_prompt: Optional[Union[str, List[str]]] = None,
608
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
609
+ num_images_per_prompt: Optional[int] = 1,
610
+ eta: float = 0.0,
611
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
612
+ latents: Optional[torch.FloatTensor] = None,
613
+ prompt_embeds: Optional[torch.FloatTensor] = None,
614
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
615
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
616
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
617
+ image_embeds: Optional[torch.FloatTensor] = None,
618
+ output_type: Optional[str] = "pil",
619
+ return_dict: bool = True,
620
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
621
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
622
+ guess_mode: bool = False,
623
+ control_guidance_start: Union[float, List[float]] = 0.0,
624
+ control_guidance_end: Union[float, List[float]] = 1.0,
625
+ original_size: Tuple[int, int] = None,
626
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
627
+ target_size: Tuple[int, int] = None,
628
+ negative_original_size: Optional[Tuple[int, int]] = None,
629
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
630
+ negative_target_size: Optional[Tuple[int, int]] = None,
631
+ clip_skip: Optional[int] = None,
632
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
633
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
634
+ control_mask = None,
635
+ **kwargs,
636
+ ):
637
+ r"""
638
+ The call function to the pipeline for generation.
639
+
640
+ Args:
641
+ prompt (`str` or `List[str]`, *optional*):
642
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
643
+ prompt_2 (`str` or `List[str]`, *optional*):
644
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
645
+ used in both text-encoders.
646
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
647
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
648
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
649
+ specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
650
+ accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
651
+ and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
652
+ `init`, images must be passed as a list such that each element of the list can be correctly batched for
653
+ input to a single ControlNet.
654
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
655
+ The height in pixels of the generated image. Anything below 512 pixels won't work well for
656
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
657
+ and checkpoints that are not specifically fine-tuned on low resolutions.
658
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
659
+ The width in pixels of the generated image. Anything below 512 pixels won't work well for
660
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
661
+ and checkpoints that are not specifically fine-tuned on low resolutions.
662
+ num_inference_steps (`int`, *optional*, defaults to 50):
663
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
664
+ expense of slower inference.
665
+ guidance_scale (`float`, *optional*, defaults to 5.0):
666
+ A higher guidance scale value encourages the model to generate images closely linked to the text
667
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
668
+ negative_prompt (`str` or `List[str]`, *optional*):
669
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
670
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
671
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
672
+ The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
673
+ and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
674
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
675
+ The number of images to generate per prompt.
676
+ eta (`float`, *optional*, defaults to 0.0):
677
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
678
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
679
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
680
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
681
+ generation deterministic.
682
+ latents (`torch.FloatTensor`, *optional*):
683
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
684
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
685
+ tensor is generated by sampling using the supplied random `generator`.
686
+ prompt_embeds (`torch.FloatTensor`, *optional*):
687
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
688
+ provided, text embeddings are generated from the `prompt` input argument.
689
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
690
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
691
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
692
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
693
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
694
+ not provided, pooled text embeddings are generated from `prompt` input argument.
695
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
696
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
697
+ weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
698
+ argument.
699
+ image_embeds (`torch.FloatTensor`, *optional*):
700
+ Pre-generated image embeddings.
701
+ output_type (`str`, *optional*, defaults to `"pil"`):
702
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
703
+ return_dict (`bool`, *optional*, defaults to `True`):
704
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
705
+ plain tuple.
706
+ cross_attention_kwargs (`dict`, *optional*):
707
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
708
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
709
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
710
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
711
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
712
+ the corresponding scale as a list.
713
+ guess_mode (`bool`, *optional*, defaults to `False`):
714
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
715
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
716
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
717
+ The percentage of total steps at which the ControlNet starts applying.
718
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
719
+ The percentage of total steps at which the ControlNet stops applying.
720
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
721
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
722
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
723
+ explained in section 2.2 of
724
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
725
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
726
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
727
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
728
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
729
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
730
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
731
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
732
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
733
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
734
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
735
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
736
+ micro-conditioning as explained in section 2.2 of
737
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
738
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
739
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
740
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
741
+ micro-conditioning as explained in section 2.2 of
742
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
743
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
744
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
745
+ To negatively condition the generation process based on a target image resolution. It should be the same
746
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
747
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
748
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
749
+ clip_skip (`int`, *optional*):
750
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
751
+ the output of the pre-final layer will be used for computing the prompt embeddings.
752
+ callback_on_step_end (`Callable`, *optional*):
753
+ A function called at the end of each denoising step during inference. The function is called
754
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
755
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
756
+ `callback_on_step_end_tensor_inputs`.
757
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
758
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
759
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
760
+ `._callback_tensor_inputs` attribute of your pipeline class.
761
+
762
+ Examples:
763
+
764
+ Returns:
765
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
766
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
767
+ otherwise a `tuple` is returned containing the output images.
768
+ """
769
+ lpw = LongPromptWeight()
770
+
771
+ callback = kwargs.pop("callback", None)
772
+ callback_steps = kwargs.pop("callback_steps", None)
773
+
774
+ if callback is not None:
775
+ deprecate(
776
+ "callback",
777
+ "1.0.0",
778
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
779
+ )
780
+ if callback_steps is not None:
781
+ deprecate(
782
+ "callback_steps",
783
+ "1.0.0",
784
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
785
+ )
786
+
787
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
788
+
789
+ # align format for control guidance
790
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
791
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
792
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
793
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
794
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
795
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
796
+ control_guidance_start, control_guidance_end = (
797
+ mult * [control_guidance_start],
798
+ mult * [control_guidance_end],
799
+ )
800
+
801
+ # 1. Check inputs. Raise error if not correct
802
+ self.check_inputs(
803
+ prompt,
804
+ prompt_2,
805
+ image,
806
+ callback_steps,
807
+ negative_prompt,
808
+ negative_prompt_2,
809
+ prompt_embeds,
810
+ negative_prompt_embeds,
811
+ pooled_prompt_embeds,
812
+ negative_pooled_prompt_embeds,
813
+ controlnet_conditioning_scale,
814
+ control_guidance_start,
815
+ control_guidance_end,
816
+ callback_on_step_end_tensor_inputs,
817
+ )
818
+
819
+ self._guidance_scale = guidance_scale
820
+ self._clip_skip = clip_skip
821
+ self._cross_attention_kwargs = cross_attention_kwargs
822
+
823
+ # 2. Define call parameters
824
+ if prompt is not None and isinstance(prompt, str):
825
+ batch_size = 1
826
+ elif prompt is not None and isinstance(prompt, list):
827
+ batch_size = len(prompt)
828
+ else:
829
+ batch_size = prompt_embeds.shape[0]
830
+
831
+ device = self._execution_device
832
+
833
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
834
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
835
+
836
+ global_pool_conditions = (
837
+ controlnet.config.global_pool_conditions
838
+ if isinstance(controlnet, ControlNetModel)
839
+ else controlnet.nets[0].config.global_pool_conditions
840
+ )
841
+ guess_mode = guess_mode or global_pool_conditions
842
+
843
+ # 3.1 Encode input prompt
844
+ (
845
+ prompt_embeds,
846
+ negative_prompt_embeds,
847
+ pooled_prompt_embeds,
848
+ negative_pooled_prompt_embeds,
849
+ ) = lpw.get_weighted_text_embeddings_sdxl(
850
+ pipe=self,
851
+ prompt=prompt,
852
+ neg_prompt=negative_prompt,
853
+ prompt_embeds=prompt_embeds,
854
+ negative_prompt_embeds=negative_prompt_embeds,
855
+ pooled_prompt_embeds=pooled_prompt_embeds,
856
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
857
+ )
858
+
859
+ # 3.2 Encode image prompt
860
+ prompt_image_emb = self._encode_prompt_image_emb(image_embeds,
861
+ device,
862
+ self.unet.dtype,
863
+ self.do_classifier_free_guidance)
864
+
865
+ # 4. Prepare image
866
+ if isinstance(controlnet, ControlNetModel):
867
+ image = self.prepare_image(
868
+ image=image,
869
+ width=width,
870
+ height=height,
871
+ batch_size=batch_size * num_images_per_prompt,
872
+ num_images_per_prompt=num_images_per_prompt,
873
+ device=device,
874
+ dtype=controlnet.dtype,
875
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
876
+ guess_mode=guess_mode,
877
+ )
878
+ height, width = image.shape[-2:]
879
+ elif isinstance(controlnet, MultiControlNetModel):
880
+ images = []
881
+
882
+ for image_ in image:
883
+ image_ = self.prepare_image(
884
+ image=image_,
885
+ width=width,
886
+ height=height,
887
+ batch_size=batch_size * num_images_per_prompt,
888
+ num_images_per_prompt=num_images_per_prompt,
889
+ device=device,
890
+ dtype=controlnet.dtype,
891
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
892
+ guess_mode=guess_mode,
893
+ )
894
+
895
+ images.append(image_)
896
+
897
+ image = images
898
+ height, width = image[0].shape[-2:]
899
+ else:
900
+ assert False
901
+
902
+ # 4.1 Region control
903
+ if control_mask is not None:
904
+ mask_weight_image = control_mask
905
+ mask_weight_image = np.array(mask_weight_image)
906
+ mask_weight_image_tensor = torch.from_numpy(mask_weight_image).to(device=device, dtype=prompt_embeds.dtype)
907
+ mask_weight_image_tensor = mask_weight_image_tensor[:, :, 0] / 255.
908
+ mask_weight_image_tensor = mask_weight_image_tensor[None, None]
909
+ h, w = mask_weight_image_tensor.shape[-2:]
910
+ control_mask_wight_image_list = []
911
+ for scale in [8, 8, 8, 16, 16, 16, 32, 32, 32]:
912
+ scale_mask_weight_image_tensor = F.interpolate(
913
+ mask_weight_image_tensor,(h // scale, w // scale), mode='bilinear')
914
+ control_mask_wight_image_list.append(scale_mask_weight_image_tensor)
915
+ region_mask = torch.from_numpy(np.array(control_mask)[:, :, 0]).to(self.unet.device, dtype=self.unet.dtype) / 255.
916
+ region_control.prompt_image_conditioning = [dict(region_mask=region_mask)]
917
+ else:
918
+ control_mask_wight_image_list = None
919
+ region_control.prompt_image_conditioning = [dict(region_mask=None)]
920
+
921
+ # 5. Prepare timesteps
922
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
923
+ timesteps = self.scheduler.timesteps
924
+ self._num_timesteps = len(timesteps)
925
+
926
+ # 6. Prepare latent variables
927
+ num_channels_latents = self.unet.config.in_channels
928
+ latents = self.prepare_latents(
929
+ batch_size * num_images_per_prompt,
930
+ num_channels_latents,
931
+ height,
932
+ width,
933
+ prompt_embeds.dtype,
934
+ device,
935
+ generator,
936
+ latents,
937
+ )
938
+
939
+ # 6.5 Optionally get Guidance Scale Embedding
940
+ timestep_cond = None
941
+ if self.unet.config.time_cond_proj_dim is not None:
942
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
943
+ timestep_cond = self.get_guidance_scale_embedding(
944
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
945
+ ).to(device=device, dtype=latents.dtype)
946
+
947
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
948
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
949
+
950
+ # 7.1 Create tensor stating which controlnets to keep
951
+ controlnet_keep = []
952
+ for i in range(len(timesteps)):
953
+ keeps = [
954
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
955
+ for s, e in zip(control_guidance_start, control_guidance_end)
956
+ ]
957
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
958
+
959
+ # 7.2 Prepare added time ids & embeddings
960
+ if isinstance(image, list):
961
+ original_size = original_size or image[0].shape[-2:]
962
+ else:
963
+ original_size = original_size or image.shape[-2:]
964
+ target_size = target_size or (height, width)
965
+
966
+ add_text_embeds = pooled_prompt_embeds
967
+ if self.text_encoder_2 is None:
968
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
969
+ else:
970
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
971
+
972
+ add_time_ids = self._get_add_time_ids(
973
+ original_size,
974
+ crops_coords_top_left,
975
+ target_size,
976
+ dtype=prompt_embeds.dtype,
977
+ text_encoder_projection_dim=text_encoder_projection_dim,
978
+ )
979
+
980
+ if negative_original_size is not None and negative_target_size is not None:
981
+ negative_add_time_ids = self._get_add_time_ids(
982
+ negative_original_size,
983
+ negative_crops_coords_top_left,
984
+ negative_target_size,
985
+ dtype=prompt_embeds.dtype,
986
+ text_encoder_projection_dim=text_encoder_projection_dim,
987
+ )
988
+ else:
989
+ negative_add_time_ids = add_time_ids
990
+
991
+ if self.do_classifier_free_guidance:
992
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
993
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
994
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
995
+
996
+ prompt_embeds = prompt_embeds.to(device)
997
+ add_text_embeds = add_text_embeds.to(device)
998
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
999
+ encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1)
1000
+
1001
+ # 8. Denoising loop
1002
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1003
+ is_unet_compiled = is_compiled_module(self.unet)
1004
+ is_controlnet_compiled = is_compiled_module(self.controlnet)
1005
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
1006
+
1007
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1008
+ for i, t in enumerate(timesteps):
1009
+ # Relevant thread:
1010
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
1011
+ if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
1012
+ torch._inductor.cudagraph_mark_step_begin()
1013
+ # expand the latents if we are doing classifier free guidance
1014
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1015
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1016
+
1017
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1018
+
1019
+ # controlnet(s) inference
1020
+ if guess_mode and self.do_classifier_free_guidance:
1021
+ # Infer ControlNet only for the conditional batch.
1022
+ control_model_input = latents
1023
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
1024
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1025
+ controlnet_added_cond_kwargs = {
1026
+ "text_embeds": add_text_embeds.chunk(2)[1],
1027
+ "time_ids": add_time_ids.chunk(2)[1],
1028
+ }
1029
+ else:
1030
+ control_model_input = latent_model_input
1031
+ controlnet_prompt_embeds = prompt_embeds
1032
+ controlnet_added_cond_kwargs = added_cond_kwargs
1033
+
1034
+ if isinstance(controlnet_keep[i], list):
1035
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1036
+ else:
1037
+ controlnet_cond_scale = controlnet_conditioning_scale
1038
+ if isinstance(controlnet_cond_scale, list):
1039
+ controlnet_cond_scale = controlnet_cond_scale[0]
1040
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
1041
+
1042
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1043
+ control_model_input,
1044
+ t,
1045
+ encoder_hidden_states=prompt_image_emb,
1046
+ controlnet_cond=image,
1047
+ conditioning_scale=cond_scale,
1048
+ guess_mode=guess_mode,
1049
+ added_cond_kwargs=controlnet_added_cond_kwargs,
1050
+ return_dict=False,
1051
+ )
1052
+
1053
+ # controlnet mask
1054
+ if control_mask_wight_image_list is not None:
1055
+ down_block_res_samples = [
1056
+ down_block_res_sample * mask_weight
1057
+ for down_block_res_sample, mask_weight in zip(down_block_res_samples, control_mask_wight_image_list)
1058
+ ]
1059
+ mid_block_res_sample *= control_mask_wight_image_list[-1]
1060
+
1061
+ if guess_mode and self.do_classifier_free_guidance:
1062
+ # Inferred ControlNet only for the conditional batch.
1063
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
1064
+ # add 0 to the unconditional batch to keep it unchanged.
1065
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1066
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1067
+
1068
+ # predict the noise residual
1069
+ noise_pred = self.unet(
1070
+ latent_model_input,
1071
+ t,
1072
+ encoder_hidden_states=encoder_hidden_states,
1073
+ timestep_cond=timestep_cond,
1074
+ cross_attention_kwargs=self.cross_attention_kwargs,
1075
+ down_block_additional_residuals=down_block_res_samples,
1076
+ mid_block_additional_residual=mid_block_res_sample,
1077
+ added_cond_kwargs=added_cond_kwargs,
1078
+ return_dict=False,
1079
+ )[0]
1080
+
1081
+ # perform guidance
1082
+ if self.do_classifier_free_guidance:
1083
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1084
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1085
+
1086
+ # compute the previous noisy sample x_t -> x_t-1
1087
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1088
+
1089
+ if callback_on_step_end is not None:
1090
+ callback_kwargs = {}
1091
+ for k in callback_on_step_end_tensor_inputs:
1092
+ callback_kwargs[k] = locals()[k]
1093
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1094
+
1095
+ latents = callback_outputs.pop("latents", latents)
1096
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1097
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1098
+
1099
+ # call the callback, if provided
1100
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1101
+ progress_bar.update()
1102
+ if callback is not None and i % callback_steps == 0:
1103
+ step_idx = i // getattr(self.scheduler, "order", 1)
1104
+ callback(step_idx, t, latents)
1105
+
1106
+ if not output_type == "latent":
1107
+ # make sure the VAE is in float32 mode, as it overflows in float16
1108
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1109
+ if needs_upcasting:
1110
+ self.upcast_vae()
1111
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1112
+
1113
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1114
+
1115
+ # cast back to fp16 if needed
1116
+ if needs_upcasting:
1117
+ self.vae.to(dtype=torch.float16)
1118
+ else:
1119
+ image = latents
1120
+
1121
+ if not output_type == "latent":
1122
+ # apply watermark if available
1123
+ if self.watermark is not None:
1124
+ image = self.watermark.apply_watermark(image)
1125
+
1126
+ image = self.image_processor.postprocess(image, output_type=output_type)
1127
+
1128
+ # Offload all models
1129
+ self.maybe_free_model_hooks()
1130
+
1131
+ if not return_dict:
1132
+ return (image,)
1133
+
1134
+ return StableDiffusionXLPipelineOutput(images=image)
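Taken together, a typical invocation of this pipeline mirrors the Gradio demo above: the 512-d face embedding is passed as `image_embeds` and the face-keypoint image as the ControlNet condition. The sketch below is hedged — the module name, checkpoint paths, and the dummy `face_emb`/`face_kps` stand-ins are assumptions for illustration, not values taken from this commit:

```python
# Hedged usage sketch; paths, model IDs, and the dummy inputs are placeholders.
import torch
from PIL import Image
from diffusers.models import ControlNetModel
from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline  # assumed module name

controlnet = ControlNetModel.from_pretrained("path/to/ControlNetModel", torch_dtype=torch.float16)
pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.cuda()
pipe.load_ip_adapter_instantid("path/to/ip-adapter.bin")

# In the demo these come from the face-analysis step; random stand-ins keep the sketch self-contained.
face_emb = torch.randn(512)
face_kps = Image.new("RGB", (1024, 1024))

images = pipe(
    prompt="a photo of a man, cinematic lighting",
    image_embeds=face_emb,                # fed through the Resampler as the identity tokens
    image=face_kps,                       # keypoint image used as the ControlNet condition
    controlnet_conditioning_scale=0.8,
    num_inference_steps=30,
    guidance_scale=5.0,
).images
```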
style_template.py ADDED
@@ -0,0 +1,49 @@
1
+ style_list = [
2
+ {
3
+ "name": "(No style)",
4
+ "prompt": "{prompt}",
5
+ "negative_prompt": "",
6
+ },
7
+ {
8
+ "name": "Watercolor",
9
+ "prompt": "watercolor painting, {prompt}. vibrant, beautiful, painterly, detailed, textural, artistic",
10
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, photorealistic, 35mm film, deformed, glitch, low contrast, noisy",
11
+ },
12
+ {
13
+ "name": "Film Noir",
14
+ "prompt": "film noir style, ink sketch|vector, {prompt} highly detailed, sharp focus, ultra sharpness, monochrome, high contrast, dramatic shadows, 1940s style, mysterious, cinematic",
15
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
16
+ },
17
+ {
18
+ "name": "Neon",
19
+ "prompt": "masterpiece painting, buildings in the backdrop, kaleidoscope, lilac orange blue cream fuchsia bright vivid gradient colors, the scene is cinematic, {prompt}, emotional realism, double exposure, watercolor ink pencil, graded wash, color layering, magic realism, figurative painting, intricate motifs, organic tracery, polished",
20
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
21
+ },
22
+ {
23
+ "name": "Jungle",
24
+ "prompt": 'waist-up "{prompt} in a Jungle" by Syd Mead, tangerine cold color palette, muted colors, detailed, 8k,photo r3al,dripping paint,3d toon style,3d style,Movie Still',
25
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
26
+ },
27
+ {
28
+ "name": "Mars",
29
+ "prompt": "{prompt}, Post-apocalyptic. Mars Colony, Scavengers roam the wastelands searching for valuable resources, rovers, bright morning sunlight shining, (detailed) (intricate) (8k) (HDR) (cinematic lighting) (sharp focus)",
30
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
31
+ },
32
+ {
33
+ "name": "Vibrant Color",
34
+ "prompt": "vibrant colorful, ink sketch|vector|2d colors, at nightfall, sharp focus, {prompt}, highly detailed, sharp focus, the clouds,colorful,ultra sharpness",
35
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
36
+ },
37
+ {
38
+ "name": "Snow",
39
+ "prompt": "cinema 4d render, {prompt}, high contrast, vibrant and saturated, sico style, surrounded by magical glow,floating ice shards, snow crystals, cold, windy background, frozen natural landscape in background cinematic atmosphere,highly detailed, sharp focus, intricate design, 3d, unreal engine, octane render, CG best quality, highres, photorealistic, dramatic lighting, artstation, concept art, cinematic, epic Steven Spielberg movie still, sharp focus, smoke, sparks, art by pascal blanche and greg rutkowski and repin, trending on artstation, hyperrealism painting, matte painting, 4k resolution",
40
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
41
+ },
42
+ {
43
+ "name": "Line art",
44
+ "prompt": "line art drawing {prompt} . professional, sleek, modern, minimalist, graphic, line art, vector graphics",
45
+ "negative_prompt": "anime, photorealistic, 35mm film, deformed, glitch, blurry, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, mutated, realism, realistic, impressionism, expressionism, oil, acrylic",
46
+ },
47
+ ]
48
+
49
+ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
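Each entry pairs a positive and a negative template around a `{prompt}` placeholder, and `styles` indexes them by name, so applying a style is just a lookup plus substitution. A small illustrative helper (not part of this commit) could look like:

```python
def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
    # Fall back to the plain template if the style name is unknown.
    p, n = styles.get(style_name, styles["(No style)"])
    return p.replace("{prompt}", positive), (n + " " + negative).strip()

prompt, negative_prompt = apply_style("Watercolor", "a man wearing a hat")
```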