Spaces:
Runtime error
Runtime error
ResearcherXman
commited on
Commit
·
0ba2339
1
Parent(s):
52ae519
update
Browse files- app.py +250 -38
- ip_adapter/attention_processor.py +308 -0
- ip_adapter/resampler.py +121 -0
- ip_adapter/utils.py +5 -0
- models/antelopev2/1k3d68.onnx +3 -0
- models/antelopev2/2d106det.onnx +3 -0
- models/antelopev2/genderage.onnx +3 -0
- models/antelopev2/glintr100.onnx +3 -0
- models/antelopev2/scrfd_10g_bnkps.onnx +3 -0
- pipeline_stable_diffusion_xl_instantid.py +1134 -0
- style_template.py +49 -0
app.py
CHANGED
@@ -1,16 +1,60 @@
|
|
1 |
import os
|
2 |
import cv2
|
3 |
import math
|
|
|
4 |
import random
|
5 |
import numpy as np
|
|
|
|
|
6 |
from PIL import Image
|
7 |
|
|
|
8 |
from diffusers.utils import load_image
|
|
|
|
|
|
|
|
|
9 |
|
|
|
|
|
|
|
|
|
10 |
import gradio as gr
|
11 |
|
12 |
# global variable
|
13 |
MAX_SEED = np.iinfo(np.int32).max
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
16 |
if randomize_seed:
|
@@ -29,14 +73,174 @@ def remove_back_to_files():
|
|
29 |
def remove_tips():
|
30 |
return gr.update(visible=False)
|
31 |
|
32 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
if face_image is None:
|
35 |
raise gr.Error(f"Cannot find any input face image! Please upload the face image")
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
face_image = load_image(face_image[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
-
return
|
40 |
|
41 |
### Description
|
42 |
title = r"""
|
@@ -47,9 +251,9 @@ description = r"""
|
|
47 |
<b>Official 🤗 Gradio demo</b> for <a href='https://github.com/InstantID/InstantID' target='_blank'><b>InstantID: Zero-shot Identity-Preserving Generation in Seconds</b></a>.<br>
|
48 |
|
49 |
How to use:<br>
|
50 |
-
1. Upload a person image
|
51 |
-
2. (Optionally) upload another person image as reference pose. If not uploaded, we will use the first person image to extract landmarks.
|
52 |
-
3. Enter a text prompt as normal text-to-image
|
53 |
4. Click the <b>Submit</b> button to start customizing.
|
54 |
5. Share your customizd photo with your friends, enjoy😊!
|
55 |
"""
|
@@ -67,7 +271,6 @@ If our work is helpful for your research or applications, please cite us via:
|
|
67 |
year={2024}
|
68 |
}
|
69 |
```
|
70 |
-
|
71 |
📧 **Contact**
|
72 |
<br>
|
73 |
If you have any questions, please feel free to open an issue or directly reach us out at <b>haofanwang.ai@gmail.com</b>.
|
@@ -75,9 +278,10 @@ If you have any questions, please feel free to open an issue or directly reach u
|
|
75 |
|
76 |
tips = r"""
|
77 |
### Usage tips of InstantID
|
78 |
-
1. If you're
|
79 |
-
2. If
|
80 |
-
3. If
|
|
|
81 |
"""
|
82 |
|
83 |
css = '''
|
@@ -113,14 +317,34 @@ with gr.Blocks(css=css) as demo:
|
|
113 |
# prompt
|
114 |
prompt = gr.Textbox(label="Prompt",
|
115 |
info="Give simple prompt is enough to achieve good face fedility",
|
116 |
-
placeholder="A photo of a
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
with gr.Accordion(open=False, label="Advanced Options"):
|
120 |
negative_prompt = gr.Textbox(
|
121 |
label="Negative Prompt",
|
122 |
placeholder="low quality",
|
123 |
-
value="
|
124 |
)
|
125 |
num_steps = gr.Slider(
|
126 |
label="Number of sample steps",
|
@@ -129,27 +353,6 @@ with gr.Blocks(css=css) as demo:
|
|
129 |
step=1,
|
130 |
value=30,
|
131 |
)
|
132 |
-
identitynet_strength_ratio = gr.Slider(
|
133 |
-
label="IdentityNet strength",
|
134 |
-
minimum=0,
|
135 |
-
maximum=1.5,
|
136 |
-
step=0.05,
|
137 |
-
value=0.65,
|
138 |
-
)
|
139 |
-
adapter_strength_ratio = gr.Slider(
|
140 |
-
label="Image adapter strength",
|
141 |
-
minimum=0,
|
142 |
-
maximum=1,
|
143 |
-
step=0.05,
|
144 |
-
value=0.30,
|
145 |
-
)
|
146 |
-
num_outputs = gr.Slider(
|
147 |
-
label="Number of output images",
|
148 |
-
minimum=1,
|
149 |
-
maximum=4,
|
150 |
-
step=1,
|
151 |
-
value=2,
|
152 |
-
)
|
153 |
guidance_scale = gr.Slider(
|
154 |
label="Guidance scale",
|
155 |
minimum=0.1,
|
@@ -165,6 +368,7 @@ with gr.Blocks(css=css) as demo:
|
|
165 |
value=42,
|
166 |
)
|
167 |
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
|
|
168 |
|
169 |
with gr.Column():
|
170 |
gallery = gr.Gallery(label="Generated Images")
|
@@ -187,10 +391,18 @@ with gr.Blocks(css=css) as demo:
|
|
187 |
api_name=False,
|
188 |
).then(
|
189 |
fn=generate_image,
|
190 |
-
inputs=[face_files, pose_files, prompt, negative_prompt, num_steps, identitynet_strength_ratio, adapter_strength_ratio,
|
191 |
outputs=[gallery, usage_tips]
|
192 |
)
|
193 |
-
|
194 |
-
gr.Markdown(article)
|
195 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
demo.launch()
|
|
|
1 |
import os
|
2 |
import cv2
|
3 |
import math
|
4 |
+
import torch
|
5 |
import random
|
6 |
import numpy as np
|
7 |
+
|
8 |
+
import PIL
|
9 |
from PIL import Image
|
10 |
|
11 |
+
import diffusers
|
12 |
from diffusers.utils import load_image
|
13 |
+
from diffusers.models import ControlNetModel
|
14 |
+
|
15 |
+
import insightface
|
16 |
+
from insightface.app import FaceAnalysis
|
17 |
|
18 |
+
from style_template import styles
|
19 |
+
from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline
|
20 |
+
|
21 |
+
import spaces
|
22 |
import gradio as gr
|
23 |
|
24 |
# global variable
|
25 |
MAX_SEED = np.iinfo(np.int32).max
|
26 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
27 |
+
STYLE_NAMES = list(styles.keys())
|
28 |
+
DEFAULT_STYLE_NAME = "Watercolor"
|
29 |
+
|
30 |
+
# download checkpoints
|
31 |
+
from huggingface_hub import hf_hub_download
|
32 |
+
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir="./checkpoints")
|
33 |
+
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir="./checkpoints")
|
34 |
+
hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")
|
35 |
+
|
36 |
+
# Load face encoder
|
37 |
+
app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
38 |
+
app.prepare(ctx_id=0, det_size=(640, 640))
|
39 |
+
|
40 |
+
# Path to InstantID models
|
41 |
+
face_adapter = f'./checkpoints/ip-adapter.bin'
|
42 |
+
controlnet_path = f'./checkpoints/ControlNetModel'
|
43 |
+
|
44 |
+
# Load pipeline
|
45 |
+
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
|
46 |
+
|
47 |
+
base_model_path = 'GHArt/Unstable_Diffusers_YamerMIX_V9_xl_fp16'
|
48 |
+
|
49 |
+
pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
|
50 |
+
base_model_path,
|
51 |
+
controlnet=controlnet,
|
52 |
+
torch_dtype=torch.float16,
|
53 |
+
safety_checker=None,
|
54 |
+
feature_extractor=None,
|
55 |
+
)
|
56 |
+
pipe.cuda()
|
57 |
+
pipe.load_ip_adapter_instantid(face_adapter)
|
58 |
|
59 |
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
60 |
if randomize_seed:
|
|
|
73 |
def remove_tips():
|
74 |
return gr.update(visible=False)
|
75 |
|
76 |
+
def get_example():
|
77 |
+
case = [
|
78 |
+
[
|
79 |
+
['./examples/yann-lecun_resize.jpg'],
|
80 |
+
"a man",
|
81 |
+
"Snow",
|
82 |
+
"(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
|
83 |
+
],
|
84 |
+
[
|
85 |
+
['./examples/musk_resize.jpeg'],
|
86 |
+
"a man",
|
87 |
+
"Mars",
|
88 |
+
"(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
|
89 |
+
],
|
90 |
+
[
|
91 |
+
['./examples/sam_resize.png'],
|
92 |
+
"a man",
|
93 |
+
"Jungle",
|
94 |
+
"(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, gree",
|
95 |
+
],
|
96 |
+
[
|
97 |
+
['./examples/schmidhuber_resize.png'],
|
98 |
+
"a man",
|
99 |
+
"Neon",
|
100 |
+
"(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
|
101 |
+
],
|
102 |
+
[
|
103 |
+
['./examples/kaifu_resize.png'],
|
104 |
+
"a man",
|
105 |
+
"Vibrant Color",
|
106 |
+
"(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
|
107 |
+
],
|
108 |
+
]
|
109 |
+
return case
|
110 |
+
|
111 |
+
def convert_from_cv2_to_image(img: np.ndarray) -> Image:
|
112 |
+
return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
113 |
+
|
114 |
+
def convert_from_image_to_cv2(img: Image) -> np.ndarray:
|
115 |
+
return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
116 |
+
|
117 |
+
def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]):
|
118 |
+
stickwidth = 4
|
119 |
+
limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
|
120 |
+
kps = np.array(kps)
|
121 |
+
|
122 |
+
w, h = image_pil.size
|
123 |
+
out_img = np.zeros([h, w, 3])
|
124 |
+
|
125 |
+
for i in range(len(limbSeq)):
|
126 |
+
index = limbSeq[i]
|
127 |
+
color = color_list[index[0]]
|
128 |
+
|
129 |
+
x = kps[index][:, 0]
|
130 |
+
y = kps[index][:, 1]
|
131 |
+
length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
|
132 |
+
angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
|
133 |
+
polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
|
134 |
+
out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
|
135 |
+
out_img = (out_img * 0.6).astype(np.uint8)
|
136 |
+
|
137 |
+
for idx_kp, kp in enumerate(kps):
|
138 |
+
color = color_list[idx_kp]
|
139 |
+
x, y = kp
|
140 |
+
out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
|
141 |
+
|
142 |
+
out_img_pil = Image.fromarray(out_img.astype(np.uint8))
|
143 |
+
return out_img_pil
|
144 |
+
|
145 |
+
def resize_img(input_image, max_side=1280, min_side=1024, size=None,
|
146 |
+
pad_to_max_side=False, mode=PIL.Image.BILINEAR, base_pixel_number=64):
|
147 |
+
|
148 |
+
w, h = input_image.size
|
149 |
+
if size is not None:
|
150 |
+
w_resize_new, h_resize_new = size
|
151 |
+
else:
|
152 |
+
ratio = min_side / min(h, w)
|
153 |
+
w, h = round(ratio*w), round(ratio*h)
|
154 |
+
ratio = max_side / max(h, w)
|
155 |
+
input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
|
156 |
+
w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
|
157 |
+
h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
|
158 |
+
input_image = input_image.resize([w_resize_new, h_resize_new], mode)
|
159 |
+
|
160 |
+
if pad_to_max_side:
|
161 |
+
res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
|
162 |
+
offset_x = (max_side - w_resize_new) // 2
|
163 |
+
offset_y = (max_side - h_resize_new) // 2
|
164 |
+
res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
|
165 |
+
input_image = Image.fromarray(res)
|
166 |
+
return input_image
|
167 |
+
|
168 |
+
def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
|
169 |
+
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
|
170 |
+
return p.replace("{prompt}", positive), n + ' ' + negative
|
171 |
+
|
172 |
+
@spaces.GPU
|
173 |
+
def generate_image(face_image, pose_image, prompt, negative_prompt, style_name, enhance_face_region, num_steps, identitynet_strength_ratio, adapter_strength_ratio, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
|
174 |
|
175 |
if face_image is None:
|
176 |
raise gr.Error(f"Cannot find any input face image! Please upload the face image")
|
177 |
+
|
178 |
+
if prompt is None:
|
179 |
+
prompt = "a person"
|
180 |
+
|
181 |
+
# apply the style template
|
182 |
+
prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
|
183 |
+
|
184 |
face_image = load_image(face_image[0])
|
185 |
+
face_image = resize_img(face_image)
|
186 |
+
face_image_cv2 = convert_from_image_to_cv2(face_image)
|
187 |
+
height, width, _ = face_image_cv2.shape
|
188 |
+
|
189 |
+
# Extract face features
|
190 |
+
face_info = app.get(face_image_cv2)
|
191 |
+
|
192 |
+
if len(face_info) == 0:
|
193 |
+
raise gr.Error(f"Cannot find any face in the image! Please upload another person image")
|
194 |
+
|
195 |
+
face_info = face_info[-1]
|
196 |
+
face_emb = face_info['embedding']
|
197 |
+
face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info['kps'])
|
198 |
+
|
199 |
+
if pose_image is not None:
|
200 |
+
pose_image = load_image(pose_image[0])
|
201 |
+
pose_image = resize_img(pose_image)
|
202 |
+
pose_image_cv2 = convert_from_image_to_cv2(pose_image)
|
203 |
+
|
204 |
+
face_info = app.get(pose_image_cv2)
|
205 |
+
|
206 |
+
if len(face_info) == 0:
|
207 |
+
raise gr.Error(f"Cannot find any face in the reference image! Please upload another person image")
|
208 |
+
|
209 |
+
face_info = face_info[-1]
|
210 |
+
face_kps = draw_kps(pose_image, face_info['kps'])
|
211 |
+
|
212 |
+
width, height = face_kps.size
|
213 |
+
|
214 |
+
if enhance_face_region:
|
215 |
+
control_mask = np.zeros([height, width, 3])
|
216 |
+
x1, y1, x2, y2 = face_info['bbox']
|
217 |
+
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
|
218 |
+
control_mask[y1:y2, x1:x2] = 255
|
219 |
+
control_mask = Image.fromarray(control_mask.astype(np.uint8))
|
220 |
+
else:
|
221 |
+
control_mask = None
|
222 |
+
|
223 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
224 |
+
|
225 |
+
print("Start inference...")
|
226 |
+
print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")
|
227 |
+
|
228 |
+
pipe.set_ip_adapter_scale(adapter_strength_ratio)
|
229 |
+
images = pipe(
|
230 |
+
prompt=prompt,
|
231 |
+
negative_prompt=negative_prompt,
|
232 |
+
image_embeds=face_emb,
|
233 |
+
image=face_kps,
|
234 |
+
control_mask=control_mask,
|
235 |
+
controlnet_conditioning_scale=float(identitynet_strength_ratio),
|
236 |
+
num_inference_steps=num_steps,
|
237 |
+
guidance_scale=guidance_scale,
|
238 |
+
height=height,
|
239 |
+
width=width,
|
240 |
+
generator=generator
|
241 |
+
).images
|
242 |
|
243 |
+
return images, gr.update(visible=True)
|
244 |
|
245 |
### Description
|
246 |
title = r"""
|
|
|
251 |
<b>Official 🤗 Gradio demo</b> for <a href='https://github.com/InstantID/InstantID' target='_blank'><b>InstantID: Zero-shot Identity-Preserving Generation in Seconds</b></a>.<br>
|
252 |
|
253 |
How to use:<br>
|
254 |
+
1. Upload a person image. For multiple person images, we will only detect the biggest face. Make sure face is not too small and not significantly blocked or blurred.
|
255 |
+
2. (Optionally) upload another person image as reference pose. If not uploaded, we will use the first person image to extract landmarks. If you use a cropped face at step1, it is recommeneded to upload it to extract a new pose.
|
256 |
+
3. Enter a text prompt as done in normal text-to-image models.
|
257 |
4. Click the <b>Submit</b> button to start customizing.
|
258 |
5. Share your customizd photo with your friends, enjoy😊!
|
259 |
"""
|
|
|
271 |
year={2024}
|
272 |
}
|
273 |
```
|
|
|
274 |
📧 **Contact**
|
275 |
<br>
|
276 |
If you have any questions, please feel free to open an issue or directly reach us out at <b>haofanwang.ai@gmail.com</b>.
|
|
|
278 |
|
279 |
tips = r"""
|
280 |
### Usage tips of InstantID
|
281 |
+
1. If you're unsatisfied with the similarity, increase the weight of controlnet_conditioning_scale (IdentityNet) and ip_adapter_scale (Adapter).
|
282 |
+
2. If the generated image is over-saturated, decrease the ip_adapter_scale. If not work, decrease controlnet_conditioning_scale.
|
283 |
+
3. If text control is not as expected, decrease ip_adapter_scale.
|
284 |
+
4. Find a good base model always makes a difference.
|
285 |
"""
|
286 |
|
287 |
css = '''
|
|
|
317 |
# prompt
|
318 |
prompt = gr.Textbox(label="Prompt",
|
319 |
info="Give simple prompt is enough to achieve good face fedility",
|
320 |
+
placeholder="A photo of a person",
|
321 |
+
value="")
|
322 |
+
|
323 |
+
submit = gr.Button("Submit", variant="primary")
|
324 |
+
|
325 |
+
style = gr.Dropdown(label="Style template", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
|
326 |
+
|
327 |
+
# strength
|
328 |
+
identitynet_strength_ratio = gr.Slider(
|
329 |
+
label="IdentityNet strength (for fedility)",
|
330 |
+
minimum=0,
|
331 |
+
maximum=1.5,
|
332 |
+
step=0.05,
|
333 |
+
value=0.80,
|
334 |
+
)
|
335 |
+
adapter_strength_ratio = gr.Slider(
|
336 |
+
label="Image adapter strength (for detail)",
|
337 |
+
minimum=0,
|
338 |
+
maximum=1.5,
|
339 |
+
step=0.05,
|
340 |
+
value=0.80,
|
341 |
+
)
|
342 |
+
|
343 |
with gr.Accordion(open=False, label="Advanced Options"):
|
344 |
negative_prompt = gr.Textbox(
|
345 |
label="Negative Prompt",
|
346 |
placeholder="low quality",
|
347 |
+
value="(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
|
348 |
)
|
349 |
num_steps = gr.Slider(
|
350 |
label="Number of sample steps",
|
|
|
353 |
step=1,
|
354 |
value=30,
|
355 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
356 |
guidance_scale = gr.Slider(
|
357 |
label="Guidance scale",
|
358 |
minimum=0.1,
|
|
|
368 |
value=42,
|
369 |
)
|
370 |
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
371 |
+
enhance_face_region = gr.Checkbox(label="Enhance non-face region", value=True)
|
372 |
|
373 |
with gr.Column():
|
374 |
gallery = gr.Gallery(label="Generated Images")
|
|
|
391 |
api_name=False,
|
392 |
).then(
|
393 |
fn=generate_image,
|
394 |
+
inputs=[face_files, pose_files, prompt, negative_prompt, style, enhance_face_region, num_steps, identitynet_strength_ratio, adapter_strength_ratio, guidance_scale, seed],
|
395 |
outputs=[gallery, usage_tips]
|
396 |
)
|
|
|
|
|
397 |
|
398 |
+
gr.Examples(
|
399 |
+
examples=get_example(),
|
400 |
+
inputs=[face_files, prompt, style, negative_prompt],
|
401 |
+
run_on_click=True,
|
402 |
+
fn=upload_example_to_gallery,
|
403 |
+
outputs=[uploaded_faces, clear_button_face, face_files],
|
404 |
+
)
|
405 |
+
|
406 |
+
gr.Markdown(article)
|
407 |
+
|
408 |
demo.launch()
|
ip_adapter/attention_processor.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
try:
|
7 |
+
import xformers
|
8 |
+
import xformers.ops
|
9 |
+
xformers_available = True
|
10 |
+
except Exception as e:
|
11 |
+
xformers_available = False
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
class RegionControler(object):
|
16 |
+
def __init__(self) -> None:
|
17 |
+
self.prompt_image_conditioning = []
|
18 |
+
region_control = RegionControler()
|
19 |
+
|
20 |
+
|
21 |
+
class AttnProcessor(nn.Module):
|
22 |
+
r"""
|
23 |
+
Default processor for performing attention-related computations.
|
24 |
+
"""
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
hidden_size=None,
|
28 |
+
cross_attention_dim=None,
|
29 |
+
):
|
30 |
+
super().__init__()
|
31 |
+
|
32 |
+
def __call__(
|
33 |
+
self,
|
34 |
+
attn,
|
35 |
+
hidden_states,
|
36 |
+
encoder_hidden_states=None,
|
37 |
+
attention_mask=None,
|
38 |
+
temb=None,
|
39 |
+
):
|
40 |
+
residual = hidden_states
|
41 |
+
|
42 |
+
if attn.spatial_norm is not None:
|
43 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
44 |
+
|
45 |
+
input_ndim = hidden_states.ndim
|
46 |
+
|
47 |
+
if input_ndim == 4:
|
48 |
+
batch_size, channel, height, width = hidden_states.shape
|
49 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
50 |
+
|
51 |
+
batch_size, sequence_length, _ = (
|
52 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
53 |
+
)
|
54 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
55 |
+
|
56 |
+
if attn.group_norm is not None:
|
57 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
58 |
+
|
59 |
+
query = attn.to_q(hidden_states)
|
60 |
+
|
61 |
+
if encoder_hidden_states is None:
|
62 |
+
encoder_hidden_states = hidden_states
|
63 |
+
elif attn.norm_cross:
|
64 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
65 |
+
|
66 |
+
key = attn.to_k(encoder_hidden_states)
|
67 |
+
value = attn.to_v(encoder_hidden_states)
|
68 |
+
|
69 |
+
query = attn.head_to_batch_dim(query)
|
70 |
+
key = attn.head_to_batch_dim(key)
|
71 |
+
value = attn.head_to_batch_dim(value)
|
72 |
+
|
73 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
74 |
+
hidden_states = torch.bmm(attention_probs, value)
|
75 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
76 |
+
|
77 |
+
# linear proj
|
78 |
+
hidden_states = attn.to_out[0](hidden_states)
|
79 |
+
# dropout
|
80 |
+
hidden_states = attn.to_out[1](hidden_states)
|
81 |
+
|
82 |
+
if input_ndim == 4:
|
83 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
84 |
+
|
85 |
+
if attn.residual_connection:
|
86 |
+
hidden_states = hidden_states + residual
|
87 |
+
|
88 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
89 |
+
|
90 |
+
return hidden_states
|
91 |
+
|
92 |
+
|
93 |
+
class IPAttnProcessor(nn.Module):
|
94 |
+
r"""
|
95 |
+
Attention processor for IP-Adapater.
|
96 |
+
Args:
|
97 |
+
hidden_size (`int`):
|
98 |
+
The hidden size of the attention layer.
|
99 |
+
cross_attention_dim (`int`):
|
100 |
+
The number of channels in the `encoder_hidden_states`.
|
101 |
+
scale (`float`, defaults to 1.0):
|
102 |
+
the weight scale of image prompt.
|
103 |
+
num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
|
104 |
+
The context length of the image features.
|
105 |
+
"""
|
106 |
+
|
107 |
+
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
|
108 |
+
super().__init__()
|
109 |
+
|
110 |
+
self.hidden_size = hidden_size
|
111 |
+
self.cross_attention_dim = cross_attention_dim
|
112 |
+
self.scale = scale
|
113 |
+
self.num_tokens = num_tokens
|
114 |
+
|
115 |
+
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
116 |
+
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
117 |
+
|
118 |
+
def __call__(
|
119 |
+
self,
|
120 |
+
attn,
|
121 |
+
hidden_states,
|
122 |
+
encoder_hidden_states=None,
|
123 |
+
attention_mask=None,
|
124 |
+
temb=None,
|
125 |
+
):
|
126 |
+
residual = hidden_states
|
127 |
+
|
128 |
+
if attn.spatial_norm is not None:
|
129 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
130 |
+
|
131 |
+
input_ndim = hidden_states.ndim
|
132 |
+
|
133 |
+
if input_ndim == 4:
|
134 |
+
batch_size, channel, height, width = hidden_states.shape
|
135 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
136 |
+
|
137 |
+
batch_size, sequence_length, _ = (
|
138 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
139 |
+
)
|
140 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
141 |
+
|
142 |
+
if attn.group_norm is not None:
|
143 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
144 |
+
|
145 |
+
query = attn.to_q(hidden_states)
|
146 |
+
|
147 |
+
if encoder_hidden_states is None:
|
148 |
+
encoder_hidden_states = hidden_states
|
149 |
+
else:
|
150 |
+
# get encoder_hidden_states, ip_hidden_states
|
151 |
+
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
152 |
+
encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
|
153 |
+
if attn.norm_cross:
|
154 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
155 |
+
|
156 |
+
key = attn.to_k(encoder_hidden_states)
|
157 |
+
value = attn.to_v(encoder_hidden_states)
|
158 |
+
|
159 |
+
query = attn.head_to_batch_dim(query)
|
160 |
+
key = attn.head_to_batch_dim(key)
|
161 |
+
value = attn.head_to_batch_dim(value)
|
162 |
+
|
163 |
+
if xformers_available:
|
164 |
+
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
165 |
+
else:
|
166 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
167 |
+
hidden_states = torch.bmm(attention_probs, value)
|
168 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
169 |
+
|
170 |
+
# for ip-adapter
|
171 |
+
ip_key = self.to_k_ip(ip_hidden_states)
|
172 |
+
ip_value = self.to_v_ip(ip_hidden_states)
|
173 |
+
|
174 |
+
ip_key = attn.head_to_batch_dim(ip_key)
|
175 |
+
ip_value = attn.head_to_batch_dim(ip_value)
|
176 |
+
|
177 |
+
if xformers_available:
|
178 |
+
ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
|
179 |
+
else:
|
180 |
+
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
181 |
+
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
182 |
+
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
183 |
+
|
184 |
+
# region control
|
185 |
+
if len(region_control.prompt_image_conditioning) == 1:
|
186 |
+
region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
|
187 |
+
if region_mask is not None:
|
188 |
+
h, w = region_mask.shape[:2]
|
189 |
+
ratio = (h * w / query.shape[1]) ** 0.5
|
190 |
+
mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
|
191 |
+
else:
|
192 |
+
mask = torch.ones_like(ip_hidden_states)
|
193 |
+
ip_hidden_states = ip_hidden_states * mask
|
194 |
+
|
195 |
+
hidden_states = hidden_states + self.scale * ip_hidden_states
|
196 |
+
|
197 |
+
# linear proj
|
198 |
+
hidden_states = attn.to_out[0](hidden_states)
|
199 |
+
# dropout
|
200 |
+
hidden_states = attn.to_out[1](hidden_states)
|
201 |
+
|
202 |
+
if input_ndim == 4:
|
203 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
204 |
+
|
205 |
+
if attn.residual_connection:
|
206 |
+
hidden_states = hidden_states + residual
|
207 |
+
|
208 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
209 |
+
|
210 |
+
return hidden_states
|
211 |
+
|
212 |
+
|
213 |
+
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
|
214 |
+
# TODO attention_mask
|
215 |
+
query = query.contiguous()
|
216 |
+
key = key.contiguous()
|
217 |
+
value = value.contiguous()
|
218 |
+
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
|
219 |
+
# hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
220 |
+
return hidden_states
|
221 |
+
|
222 |
+
|
223 |
+
class AttnProcessor2_0(torch.nn.Module):
|
224 |
+
r"""
|
225 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
226 |
+
"""
|
227 |
+
def __init__(
|
228 |
+
self,
|
229 |
+
hidden_size=None,
|
230 |
+
cross_attention_dim=None,
|
231 |
+
):
|
232 |
+
super().__init__()
|
233 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
234 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
235 |
+
|
236 |
+
def __call__(
|
237 |
+
self,
|
238 |
+
attn,
|
239 |
+
hidden_states,
|
240 |
+
encoder_hidden_states=None,
|
241 |
+
attention_mask=None,
|
242 |
+
temb=None,
|
243 |
+
):
|
244 |
+
residual = hidden_states
|
245 |
+
|
246 |
+
if attn.spatial_norm is not None:
|
247 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
248 |
+
|
249 |
+
input_ndim = hidden_states.ndim
|
250 |
+
|
251 |
+
if input_ndim == 4:
|
252 |
+
batch_size, channel, height, width = hidden_states.shape
|
253 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
254 |
+
|
255 |
+
batch_size, sequence_length, _ = (
|
256 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
257 |
+
)
|
258 |
+
|
259 |
+
if attention_mask is not None:
|
260 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
261 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
262 |
+
# (batch, heads, source_length, target_length)
|
263 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
264 |
+
|
265 |
+
if attn.group_norm is not None:
|
266 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
267 |
+
|
268 |
+
query = attn.to_q(hidden_states)
|
269 |
+
|
270 |
+
if encoder_hidden_states is None:
|
271 |
+
encoder_hidden_states = hidden_states
|
272 |
+
elif attn.norm_cross:
|
273 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
274 |
+
|
275 |
+
key = attn.to_k(encoder_hidden_states)
|
276 |
+
value = attn.to_v(encoder_hidden_states)
|
277 |
+
|
278 |
+
inner_dim = key.shape[-1]
|
279 |
+
head_dim = inner_dim // attn.heads
|
280 |
+
|
281 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
282 |
+
|
283 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
284 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
285 |
+
|
286 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
287 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
288 |
+
hidden_states = F.scaled_dot_product_attention(
|
289 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
290 |
+
)
|
291 |
+
|
292 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
293 |
+
hidden_states = hidden_states.to(query.dtype)
|
294 |
+
|
295 |
+
# linear proj
|
296 |
+
hidden_states = attn.to_out[0](hidden_states)
|
297 |
+
# dropout
|
298 |
+
hidden_states = attn.to_out[1](hidden_states)
|
299 |
+
|
300 |
+
if input_ndim == 4:
|
301 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
302 |
+
|
303 |
+
if attn.residual_connection:
|
304 |
+
hidden_states = hidden_states + residual
|
305 |
+
|
306 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
307 |
+
|
308 |
+
return hidden_states
|
ip_adapter/resampler.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
|
8 |
+
# FFN
|
9 |
+
def FeedForward(dim, mult=4):
|
10 |
+
inner_dim = int(dim * mult)
|
11 |
+
return nn.Sequential(
|
12 |
+
nn.LayerNorm(dim),
|
13 |
+
nn.Linear(dim, inner_dim, bias=False),
|
14 |
+
nn.GELU(),
|
15 |
+
nn.Linear(inner_dim, dim, bias=False),
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
def reshape_tensor(x, heads):
|
20 |
+
bs, length, width = x.shape
|
21 |
+
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
22 |
+
x = x.view(bs, length, heads, -1)
|
23 |
+
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
24 |
+
x = x.transpose(1, 2)
|
25 |
+
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
26 |
+
x = x.reshape(bs, heads, length, -1)
|
27 |
+
return x
|
28 |
+
|
29 |
+
|
30 |
+
class PerceiverAttention(nn.Module):
|
31 |
+
def __init__(self, *, dim, dim_head=64, heads=8):
|
32 |
+
super().__init__()
|
33 |
+
self.scale = dim_head**-0.5
|
34 |
+
self.dim_head = dim_head
|
35 |
+
self.heads = heads
|
36 |
+
inner_dim = dim_head * heads
|
37 |
+
|
38 |
+
self.norm1 = nn.LayerNorm(dim)
|
39 |
+
self.norm2 = nn.LayerNorm(dim)
|
40 |
+
|
41 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
42 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
43 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
44 |
+
|
45 |
+
|
46 |
+
def forward(self, x, latents):
|
47 |
+
"""
|
48 |
+
Args:
|
49 |
+
x (torch.Tensor): image features
|
50 |
+
shape (b, n1, D)
|
51 |
+
latent (torch.Tensor): latent features
|
52 |
+
shape (b, n2, D)
|
53 |
+
"""
|
54 |
+
x = self.norm1(x)
|
55 |
+
latents = self.norm2(latents)
|
56 |
+
|
57 |
+
b, l, _ = latents.shape
|
58 |
+
|
59 |
+
q = self.to_q(latents)
|
60 |
+
kv_input = torch.cat((x, latents), dim=-2)
|
61 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
62 |
+
|
63 |
+
q = reshape_tensor(q, self.heads)
|
64 |
+
k = reshape_tensor(k, self.heads)
|
65 |
+
v = reshape_tensor(v, self.heads)
|
66 |
+
|
67 |
+
# attention
|
68 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
69 |
+
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
70 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
71 |
+
out = weight @ v
|
72 |
+
|
73 |
+
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
74 |
+
|
75 |
+
return self.to_out(out)
|
76 |
+
|
77 |
+
|
78 |
+
class Resampler(nn.Module):
|
79 |
+
def __init__(
|
80 |
+
self,
|
81 |
+
dim=1024,
|
82 |
+
depth=8,
|
83 |
+
dim_head=64,
|
84 |
+
heads=16,
|
85 |
+
num_queries=8,
|
86 |
+
embedding_dim=768,
|
87 |
+
output_dim=1024,
|
88 |
+
ff_mult=4,
|
89 |
+
):
|
90 |
+
super().__init__()
|
91 |
+
|
92 |
+
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
|
93 |
+
|
94 |
+
self.proj_in = nn.Linear(embedding_dim, dim)
|
95 |
+
|
96 |
+
self.proj_out = nn.Linear(dim, output_dim)
|
97 |
+
self.norm_out = nn.LayerNorm(output_dim)
|
98 |
+
|
99 |
+
self.layers = nn.ModuleList([])
|
100 |
+
for _ in range(depth):
|
101 |
+
self.layers.append(
|
102 |
+
nn.ModuleList(
|
103 |
+
[
|
104 |
+
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
105 |
+
FeedForward(dim=dim, mult=ff_mult),
|
106 |
+
]
|
107 |
+
)
|
108 |
+
)
|
109 |
+
|
110 |
+
def forward(self, x):
|
111 |
+
|
112 |
+
latents = self.latents.repeat(x.size(0), 1, 1)
|
113 |
+
|
114 |
+
x = self.proj_in(x)
|
115 |
+
|
116 |
+
for attn, ff in self.layers:
|
117 |
+
latents = attn(x, latents) + latents
|
118 |
+
latents = ff(latents) + latents
|
119 |
+
|
120 |
+
latents = self.proj_out(latents)
|
121 |
+
return self.norm_out(latents)
|
ip_adapter/utils.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn.functional as F
|
2 |
+
|
3 |
+
|
4 |
+
def is_torch2_available():
|
5 |
+
return hasattr(F, "scaled_dot_product_attention")
|
models/antelopev2/1k3d68.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc
|
3 |
+
size 143607619
|
models/antelopev2/2d106det.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf
|
3 |
+
size 5030888
|
models/antelopev2/genderage.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb
|
3 |
+
size 1322532
|
models/antelopev2/glintr100.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4ab1d6435d639628a6f3e5008dd4f929edf4c4124b1a7169e1048f9fef534cdf
|
3 |
+
size 260665334
|
models/antelopev2/scrfd_10g_bnkps.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91
|
3 |
+
size 16923827
|
pipeline_stable_diffusion_xl_instantid.py
ADDED
@@ -0,0 +1,1134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The InstantX Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import cv2
|
19 |
+
import math
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import PIL.Image
|
23 |
+
import torch
|
24 |
+
import torch.nn.functional as F
|
25 |
+
|
26 |
+
from diffusers.image_processor import PipelineImageInput
|
27 |
+
|
28 |
+
from diffusers.models import ControlNetModel
|
29 |
+
|
30 |
+
from diffusers.utils import (
|
31 |
+
deprecate,
|
32 |
+
logging,
|
33 |
+
replace_example_docstring,
|
34 |
+
)
|
35 |
+
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
|
36 |
+
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
37 |
+
|
38 |
+
from diffusers import StableDiffusionXLControlNetPipeline
|
39 |
+
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
40 |
+
from diffusers.utils.import_utils import is_xformers_available
|
41 |
+
|
42 |
+
from ip_adapter.resampler import Resampler
|
43 |
+
from ip_adapter.utils import is_torch2_available
|
44 |
+
|
45 |
+
if is_torch2_available():
|
46 |
+
from ip_adapter.attention_processor import (
|
47 |
+
AttnProcessor2_0 as AttnProcessor,
|
48 |
+
)
|
49 |
+
from ip_adapter.attention_processor import (
|
50 |
+
IPAttnProcessor2_0 as IPAttnProcessor,
|
51 |
+
)
|
52 |
+
else:
|
53 |
+
from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
|
54 |
+
from ip_adapter.attention_processor import region_control
|
55 |
+
|
56 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
57 |
+
|
58 |
+
|
59 |
+
EXAMPLE_DOC_STRING = """
|
60 |
+
Examples:
|
61 |
+
```py
|
62 |
+
>>> # !pip install opencv-python transformers accelerate insightface
|
63 |
+
>>> import diffusers
|
64 |
+
>>> from diffusers.utils import load_image
|
65 |
+
>>> from diffusers.models import ControlNetModel
|
66 |
+
|
67 |
+
>>> import cv2
|
68 |
+
>>> import torch
|
69 |
+
>>> import numpy as np
|
70 |
+
>>> from PIL import Image
|
71 |
+
|
72 |
+
>>> from insightface.app import FaceAnalysis
|
73 |
+
>>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
|
74 |
+
|
75 |
+
>>> # download 'antelopev2' under ./models
|
76 |
+
>>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
77 |
+
>>> app.prepare(ctx_id=0, det_size=(640, 640))
|
78 |
+
|
79 |
+
>>> # download models under ./checkpoints
|
80 |
+
>>> face_adapter = f'./checkpoints/ip-adapter.bin'
|
81 |
+
>>> controlnet_path = f'./checkpoints/ControlNetModel'
|
82 |
+
|
83 |
+
>>> # load IdentityNet
|
84 |
+
>>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
|
85 |
+
|
86 |
+
>>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
|
87 |
+
... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
|
88 |
+
... )
|
89 |
+
>>> pipe.cuda()
|
90 |
+
|
91 |
+
>>> # load adapter
|
92 |
+
>>> pipe.load_ip_adapter_instantid(face_adapter)
|
93 |
+
|
94 |
+
>>> prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality"
|
95 |
+
>>> negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured"
|
96 |
+
|
97 |
+
>>> # load an image
|
98 |
+
>>> image = load_image("your-example.jpg")
|
99 |
+
|
100 |
+
>>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1]
|
101 |
+
>>> face_emb = face_info['embedding']
|
102 |
+
>>> face_kps = draw_kps(face_image, face_info['kps'])
|
103 |
+
|
104 |
+
>>> pipe.set_ip_adapter_scale(0.8)
|
105 |
+
|
106 |
+
>>> # generate image
|
107 |
+
>>> image = pipe(
|
108 |
+
... prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8
|
109 |
+
... ).images[0]
|
110 |
+
```
|
111 |
+
"""
|
112 |
+
|
113 |
+
|
114 |
+
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
|
115 |
+
class LongPromptWeight(object):
|
116 |
+
|
117 |
+
"""
|
118 |
+
Copied from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion_xl.py
|
119 |
+
"""
|
120 |
+
|
121 |
+
def __init__(self) -> None:
|
122 |
+
pass
|
123 |
+
|
124 |
+
def parse_prompt_attention(self, text):
|
125 |
+
"""
|
126 |
+
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
127 |
+
Accepted tokens are:
|
128 |
+
(abc) - increases attention to abc by a multiplier of 1.1
|
129 |
+
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
130 |
+
[abc] - decreases attention to abc by a multiplier of 1.1
|
131 |
+
\( - literal character '('
|
132 |
+
\[ - literal character '['
|
133 |
+
\) - literal character ')'
|
134 |
+
\] - literal character ']'
|
135 |
+
\\ - literal character '\'
|
136 |
+
anything else - just text
|
137 |
+
|
138 |
+
>>> parse_prompt_attention('normal text')
|
139 |
+
[['normal text', 1.0]]
|
140 |
+
>>> parse_prompt_attention('an (important) word')
|
141 |
+
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
142 |
+
>>> parse_prompt_attention('(unbalanced')
|
143 |
+
[['unbalanced', 1.1]]
|
144 |
+
>>> parse_prompt_attention('\(literal\]')
|
145 |
+
[['(literal]', 1.0]]
|
146 |
+
>>> parse_prompt_attention('(unnecessary)(parens)')
|
147 |
+
[['unnecessaryparens', 1.1]]
|
148 |
+
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
149 |
+
[['a ', 1.0],
|
150 |
+
['house', 1.5730000000000004],
|
151 |
+
[' ', 1.1],
|
152 |
+
['on', 1.0],
|
153 |
+
[' a ', 1.1],
|
154 |
+
['hill', 0.55],
|
155 |
+
[', sun, ', 1.1],
|
156 |
+
['sky', 1.4641000000000006],
|
157 |
+
['.', 1.1]]
|
158 |
+
"""
|
159 |
+
import re
|
160 |
+
|
161 |
+
re_attention = re.compile(
|
162 |
+
r"""
|
163 |
+
\\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|
|
164 |
+
\)|]|[^\\()\[\]:]+|:
|
165 |
+
""",
|
166 |
+
re.X,
|
167 |
+
)
|
168 |
+
|
169 |
+
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
|
170 |
+
|
171 |
+
res = []
|
172 |
+
round_brackets = []
|
173 |
+
square_brackets = []
|
174 |
+
|
175 |
+
round_bracket_multiplier = 1.1
|
176 |
+
square_bracket_multiplier = 1 / 1.1
|
177 |
+
|
178 |
+
def multiply_range(start_position, multiplier):
|
179 |
+
for p in range(start_position, len(res)):
|
180 |
+
res[p][1] *= multiplier
|
181 |
+
|
182 |
+
for m in re_attention.finditer(text):
|
183 |
+
text = m.group(0)
|
184 |
+
weight = m.group(1)
|
185 |
+
|
186 |
+
if text.startswith("\\"):
|
187 |
+
res.append([text[1:], 1.0])
|
188 |
+
elif text == "(":
|
189 |
+
round_brackets.append(len(res))
|
190 |
+
elif text == "[":
|
191 |
+
square_brackets.append(len(res))
|
192 |
+
elif weight is not None and len(round_brackets) > 0:
|
193 |
+
multiply_range(round_brackets.pop(), float(weight))
|
194 |
+
elif text == ")" and len(round_brackets) > 0:
|
195 |
+
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
196 |
+
elif text == "]" and len(square_brackets) > 0:
|
197 |
+
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
198 |
+
else:
|
199 |
+
parts = re.split(re_break, text)
|
200 |
+
for i, part in enumerate(parts):
|
201 |
+
if i > 0:
|
202 |
+
res.append(["BREAK", -1])
|
203 |
+
res.append([part, 1.0])
|
204 |
+
|
205 |
+
for pos in round_brackets:
|
206 |
+
multiply_range(pos, round_bracket_multiplier)
|
207 |
+
|
208 |
+
for pos in square_brackets:
|
209 |
+
multiply_range(pos, square_bracket_multiplier)
|
210 |
+
|
211 |
+
if len(res) == 0:
|
212 |
+
res = [["", 1.0]]
|
213 |
+
|
214 |
+
# merge runs of identical weights
|
215 |
+
i = 0
|
216 |
+
while i + 1 < len(res):
|
217 |
+
if res[i][1] == res[i + 1][1]:
|
218 |
+
res[i][0] += res[i + 1][0]
|
219 |
+
res.pop(i + 1)
|
220 |
+
else:
|
221 |
+
i += 1
|
222 |
+
|
223 |
+
return res
|
224 |
+
|
225 |
+
def get_prompts_tokens_with_weights(self, clip_tokenizer: CLIPTokenizer, prompt: str):
|
226 |
+
"""
|
227 |
+
Get prompt token ids and weights, this function works for both prompt and negative prompt
|
228 |
+
|
229 |
+
Args:
|
230 |
+
pipe (CLIPTokenizer)
|
231 |
+
A CLIPTokenizer
|
232 |
+
prompt (str)
|
233 |
+
A prompt string with weights
|
234 |
+
|
235 |
+
Returns:
|
236 |
+
text_tokens (list)
|
237 |
+
A list contains token ids
|
238 |
+
text_weight (list)
|
239 |
+
A list contains the correspodent weight of token ids
|
240 |
+
|
241 |
+
Example:
|
242 |
+
import torch
|
243 |
+
from transformers import CLIPTokenizer
|
244 |
+
|
245 |
+
clip_tokenizer = CLIPTokenizer.from_pretrained(
|
246 |
+
"stablediffusionapi/deliberate-v2"
|
247 |
+
, subfolder = "tokenizer"
|
248 |
+
, dtype = torch.float16
|
249 |
+
)
|
250 |
+
|
251 |
+
token_id_list, token_weight_list = get_prompts_tokens_with_weights(
|
252 |
+
clip_tokenizer = clip_tokenizer
|
253 |
+
,prompt = "a (red:1.5) cat"*70
|
254 |
+
)
|
255 |
+
"""
|
256 |
+
texts_and_weights = self.parse_prompt_attention(prompt)
|
257 |
+
text_tokens, text_weights = [], []
|
258 |
+
for word, weight in texts_and_weights:
|
259 |
+
# tokenize and discard the starting and the ending token
|
260 |
+
token = clip_tokenizer(word, truncation=False).input_ids[1:-1] # so that tokenize whatever length prompt
|
261 |
+
# the returned token is a 1d list: [320, 1125, 539, 320]
|
262 |
+
|
263 |
+
# merge the new tokens to the all tokens holder: text_tokens
|
264 |
+
text_tokens = [*text_tokens, *token]
|
265 |
+
|
266 |
+
# each token chunk will come with one weight, like ['red cat', 2.0]
|
267 |
+
# need to expand weight for each token.
|
268 |
+
chunk_weights = [weight] * len(token)
|
269 |
+
|
270 |
+
# append the weight back to the weight holder: text_weights
|
271 |
+
text_weights = [*text_weights, *chunk_weights]
|
272 |
+
return text_tokens, text_weights
|
273 |
+
|
274 |
+
def group_tokens_and_weights(self, token_ids: list, weights: list, pad_last_block=False):
|
275 |
+
"""
|
276 |
+
Produce tokens and weights in groups and pad the missing tokens
|
277 |
+
|
278 |
+
Args:
|
279 |
+
token_ids (list)
|
280 |
+
The token ids from tokenizer
|
281 |
+
weights (list)
|
282 |
+
The weights list from function get_prompts_tokens_with_weights
|
283 |
+
pad_last_block (bool)
|
284 |
+
Control if fill the last token list to 75 tokens with eos
|
285 |
+
Returns:
|
286 |
+
new_token_ids (2d list)
|
287 |
+
new_weights (2d list)
|
288 |
+
|
289 |
+
Example:
|
290 |
+
token_groups,weight_groups = group_tokens_and_weights(
|
291 |
+
token_ids = token_id_list
|
292 |
+
, weights = token_weight_list
|
293 |
+
)
|
294 |
+
"""
|
295 |
+
bos, eos = 49406, 49407
|
296 |
+
|
297 |
+
# this will be a 2d list
|
298 |
+
new_token_ids = []
|
299 |
+
new_weights = []
|
300 |
+
while len(token_ids) >= 75:
|
301 |
+
# get the first 75 tokens
|
302 |
+
head_75_tokens = [token_ids.pop(0) for _ in range(75)]
|
303 |
+
head_75_weights = [weights.pop(0) for _ in range(75)]
|
304 |
+
|
305 |
+
# extract token ids and weights
|
306 |
+
temp_77_token_ids = [bos] + head_75_tokens + [eos]
|
307 |
+
temp_77_weights = [1.0] + head_75_weights + [1.0]
|
308 |
+
|
309 |
+
# add 77 token and weights chunk to the holder list
|
310 |
+
new_token_ids.append(temp_77_token_ids)
|
311 |
+
new_weights.append(temp_77_weights)
|
312 |
+
|
313 |
+
# padding the left
|
314 |
+
if len(token_ids) >= 0:
|
315 |
+
padding_len = 75 - len(token_ids) if pad_last_block else 0
|
316 |
+
|
317 |
+
temp_77_token_ids = [bos] + token_ids + [eos] * padding_len + [eos]
|
318 |
+
new_token_ids.append(temp_77_token_ids)
|
319 |
+
|
320 |
+
temp_77_weights = [1.0] + weights + [1.0] * padding_len + [1.0]
|
321 |
+
new_weights.append(temp_77_weights)
|
322 |
+
|
323 |
+
return new_token_ids, new_weights
|
324 |
+
|
325 |
+
def get_weighted_text_embeddings_sdxl(
|
326 |
+
self,
|
327 |
+
pipe: StableDiffusionXLPipeline,
|
328 |
+
prompt: str = "",
|
329 |
+
prompt_2: str = None,
|
330 |
+
neg_prompt: str = "",
|
331 |
+
neg_prompt_2: str = None,
|
332 |
+
prompt_embeds=None,
|
333 |
+
negative_prompt_embeds=None,
|
334 |
+
pooled_prompt_embeds=None,
|
335 |
+
negative_pooled_prompt_embeds=None,
|
336 |
+
extra_emb=None,
|
337 |
+
extra_emb_alpha=0.6,
|
338 |
+
):
|
339 |
+
"""
|
340 |
+
This function can process long prompt with weights, no length limitation
|
341 |
+
for Stable Diffusion XL
|
342 |
+
|
343 |
+
Args:
|
344 |
+
pipe (StableDiffusionPipeline)
|
345 |
+
prompt (str)
|
346 |
+
prompt_2 (str)
|
347 |
+
neg_prompt (str)
|
348 |
+
neg_prompt_2 (str)
|
349 |
+
Returns:
|
350 |
+
prompt_embeds (torch.Tensor)
|
351 |
+
neg_prompt_embeds (torch.Tensor)
|
352 |
+
"""
|
353 |
+
#
|
354 |
+
if prompt_embeds is not None and \
|
355 |
+
negative_prompt_embeds is not None and \
|
356 |
+
pooled_prompt_embeds is not None and \
|
357 |
+
negative_pooled_prompt_embeds is not None:
|
358 |
+
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
359 |
+
|
360 |
+
if prompt_2:
|
361 |
+
prompt = f"{prompt} {prompt_2}"
|
362 |
+
|
363 |
+
if neg_prompt_2:
|
364 |
+
neg_prompt = f"{neg_prompt} {neg_prompt_2}"
|
365 |
+
|
366 |
+
eos = pipe.tokenizer.eos_token_id
|
367 |
+
|
368 |
+
# tokenizer 1
|
369 |
+
prompt_tokens, prompt_weights = self.get_prompts_tokens_with_weights(pipe.tokenizer, prompt)
|
370 |
+
neg_prompt_tokens, neg_prompt_weights = self.get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt)
|
371 |
+
|
372 |
+
# tokenizer 2
|
373 |
+
# prompt_tokens_2, prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer_2, prompt)
|
374 |
+
# neg_prompt_tokens_2, neg_prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer_2, neg_prompt)
|
375 |
+
# tokenizer 2 遇到 !! !!!! 等多感叹号和tokenizer 1的效果不一致
|
376 |
+
prompt_tokens_2, prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer, prompt)
|
377 |
+
neg_prompt_tokens_2, neg_prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt)
|
378 |
+
|
379 |
+
# padding the shorter one for prompt set 1
|
380 |
+
prompt_token_len = len(prompt_tokens)
|
381 |
+
neg_prompt_token_len = len(neg_prompt_tokens)
|
382 |
+
|
383 |
+
if prompt_token_len > neg_prompt_token_len:
|
384 |
+
# padding the neg_prompt with eos token
|
385 |
+
neg_prompt_tokens = neg_prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len)
|
386 |
+
neg_prompt_weights = neg_prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
|
387 |
+
else:
|
388 |
+
# padding the prompt
|
389 |
+
prompt_tokens = prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len)
|
390 |
+
prompt_weights = prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
|
391 |
+
|
392 |
+
# padding the shorter one for token set 2
|
393 |
+
prompt_token_len_2 = len(prompt_tokens_2)
|
394 |
+
neg_prompt_token_len_2 = len(neg_prompt_tokens_2)
|
395 |
+
|
396 |
+
if prompt_token_len_2 > neg_prompt_token_len_2:
|
397 |
+
# padding the neg_prompt with eos token
|
398 |
+
neg_prompt_tokens_2 = neg_prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
|
399 |
+
neg_prompt_weights_2 = neg_prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
|
400 |
+
else:
|
401 |
+
# padding the prompt
|
402 |
+
prompt_tokens_2 = prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
|
403 |
+
prompt_weights_2 = prompt_weights + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
|
404 |
+
|
405 |
+
embeds = []
|
406 |
+
neg_embeds = []
|
407 |
+
|
408 |
+
prompt_token_groups, prompt_weight_groups = self.group_tokens_and_weights(prompt_tokens.copy(), prompt_weights.copy())
|
409 |
+
|
410 |
+
neg_prompt_token_groups, neg_prompt_weight_groups = self.group_tokens_and_weights(
|
411 |
+
neg_prompt_tokens.copy(), neg_prompt_weights.copy()
|
412 |
+
)
|
413 |
+
|
414 |
+
prompt_token_groups_2, prompt_weight_groups_2 = self.group_tokens_and_weights(
|
415 |
+
prompt_tokens_2.copy(), prompt_weights_2.copy()
|
416 |
+
)
|
417 |
+
|
418 |
+
neg_prompt_token_groups_2, neg_prompt_weight_groups_2 = self.group_tokens_and_weights(
|
419 |
+
neg_prompt_tokens_2.copy(), neg_prompt_weights_2.copy()
|
420 |
+
)
|
421 |
+
|
422 |
+
# get prompt embeddings one by one is not working.
|
423 |
+
for i in range(len(prompt_token_groups)):
|
424 |
+
# get positive prompt embeddings with weights
|
425 |
+
token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
|
426 |
+
weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
|
427 |
+
|
428 |
+
token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
|
429 |
+
|
430 |
+
# use first text encoder
|
431 |
+
prompt_embeds_1 = pipe.text_encoder(token_tensor.to(pipe.device), output_hidden_states=True)
|
432 |
+
prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
|
433 |
+
|
434 |
+
# use second text encoder
|
435 |
+
prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(pipe.device), output_hidden_states=True)
|
436 |
+
prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
|
437 |
+
pooled_prompt_embeds = prompt_embeds_2[0]
|
438 |
+
|
439 |
+
prompt_embeds_list = [prompt_embeds_1_hidden_states, prompt_embeds_2_hidden_states]
|
440 |
+
token_embedding = torch.concat(prompt_embeds_list, dim=-1).squeeze(0)
|
441 |
+
|
442 |
+
for j in range(len(weight_tensor)):
|
443 |
+
if weight_tensor[j] != 1.0:
|
444 |
+
token_embedding[j] = (
|
445 |
+
token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight_tensor[j]
|
446 |
+
)
|
447 |
+
|
448 |
+
token_embedding = token_embedding.unsqueeze(0)
|
449 |
+
embeds.append(token_embedding)
|
450 |
+
|
451 |
+
# get negative prompt embeddings with weights
|
452 |
+
neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
|
453 |
+
neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
|
454 |
+
neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
|
455 |
+
|
456 |
+
# use first text encoder
|
457 |
+
neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(pipe.device), output_hidden_states=True)
|
458 |
+
neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]
|
459 |
+
|
460 |
+
# use second text encoder
|
461 |
+
neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(pipe.device), output_hidden_states=True)
|
462 |
+
neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
|
463 |
+
negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]
|
464 |
+
|
465 |
+
neg_prompt_embeds_list = [neg_prompt_embeds_1_hidden_states, neg_prompt_embeds_2_hidden_states]
|
466 |
+
neg_token_embedding = torch.concat(neg_prompt_embeds_list, dim=-1).squeeze(0)
|
467 |
+
|
468 |
+
for z in range(len(neg_weight_tensor)):
|
469 |
+
if neg_weight_tensor[z] != 1.0:
|
470 |
+
neg_token_embedding[z] = (
|
471 |
+
neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight_tensor[z]
|
472 |
+
)
|
473 |
+
|
474 |
+
neg_token_embedding = neg_token_embedding.unsqueeze(0)
|
475 |
+
neg_embeds.append(neg_token_embedding)
|
476 |
+
|
477 |
+
prompt_embeds = torch.cat(embeds, dim=1)
|
478 |
+
negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
|
479 |
+
|
480 |
+
if extra_emb is not None:
|
481 |
+
extra_emb = extra_emb.to(prompt_embeds.device, dtype=prompt_embeds.dtype) * extra_emb_alpha
|
482 |
+
prompt_embeds = torch.cat([prompt_embeds, extra_emb], 1)
|
483 |
+
negative_prompt_embeds = torch.cat([negative_prompt_embeds, torch.zeros_like(extra_emb)], 1)
|
484 |
+
print(f'fix prompt_embeds, extra_emb_alpha={extra_emb_alpha}')
|
485 |
+
|
486 |
+
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
487 |
+
|
488 |
+
def get_prompt_embeds(self, *args, **kwargs):
|
489 |
+
prompt_embeds, negative_prompt_embeds, _, _ = self.get_weighted_text_embeddings_sdxl(*args, **kwargs)
|
490 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
491 |
+
return prompt_embeds
|
492 |
+
|
493 |
+
|
494 |
+
class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
|
495 |
+
|
496 |
+
def cuda(self, dtype=torch.float16, use_xformers=False):
|
497 |
+
self.to('cuda', dtype)
|
498 |
+
|
499 |
+
if hasattr(self, 'image_proj_model'):
|
500 |
+
self.image_proj_model.to(self.unet.device).to(self.unet.dtype)
|
501 |
+
|
502 |
+
if use_xformers:
|
503 |
+
if is_xformers_available():
|
504 |
+
import xformers
|
505 |
+
from packaging import version
|
506 |
+
|
507 |
+
xformers_version = version.parse(xformers.__version__)
|
508 |
+
if xformers_version == version.parse("0.0.16"):
|
509 |
+
logger.warn(
|
510 |
+
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
511 |
+
)
|
512 |
+
self.enable_xformers_memory_efficient_attention()
|
513 |
+
else:
|
514 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
515 |
+
|
516 |
+
def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5):
|
517 |
+
self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens)
|
518 |
+
self.set_ip_adapter(model_ckpt, num_tokens, scale)
|
519 |
+
|
520 |
+
def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16):
|
521 |
+
|
522 |
+
image_proj_model = Resampler(
|
523 |
+
dim=1280,
|
524 |
+
depth=4,
|
525 |
+
dim_head=64,
|
526 |
+
heads=20,
|
527 |
+
num_queries=num_tokens,
|
528 |
+
embedding_dim=image_emb_dim,
|
529 |
+
output_dim=self.unet.config.cross_attention_dim,
|
530 |
+
ff_mult=4,
|
531 |
+
)
|
532 |
+
|
533 |
+
image_proj_model.eval()
|
534 |
+
|
535 |
+
self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype)
|
536 |
+
state_dict = torch.load(model_ckpt, map_location="cpu")
|
537 |
+
if 'image_proj' in state_dict:
|
538 |
+
state_dict = state_dict["image_proj"]
|
539 |
+
self.image_proj_model.load_state_dict(state_dict)
|
540 |
+
|
541 |
+
self.image_proj_model_in_features = image_emb_dim
|
542 |
+
|
543 |
+
def set_ip_adapter(self, model_ckpt, num_tokens, scale):
|
544 |
+
|
545 |
+
unet = self.unet
|
546 |
+
attn_procs = {}
|
547 |
+
for name in unet.attn_processors.keys():
|
548 |
+
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
549 |
+
if name.startswith("mid_block"):
|
550 |
+
hidden_size = unet.config.block_out_channels[-1]
|
551 |
+
elif name.startswith("up_blocks"):
|
552 |
+
block_id = int(name[len("up_blocks.")])
|
553 |
+
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
554 |
+
elif name.startswith("down_blocks"):
|
555 |
+
block_id = int(name[len("down_blocks.")])
|
556 |
+
hidden_size = unet.config.block_out_channels[block_id]
|
557 |
+
if cross_attention_dim is None:
|
558 |
+
attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype)
|
559 |
+
else:
|
560 |
+
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size,
|
561 |
+
cross_attention_dim=cross_attention_dim,
|
562 |
+
scale=scale,
|
563 |
+
num_tokens=num_tokens).to(unet.device, dtype=unet.dtype)
|
564 |
+
unet.set_attn_processor(attn_procs)
|
565 |
+
|
566 |
+
state_dict = torch.load(model_ckpt, map_location="cpu")
|
567 |
+
ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
|
568 |
+
if 'ip_adapter' in state_dict:
|
569 |
+
state_dict = state_dict['ip_adapter']
|
570 |
+
ip_layers.load_state_dict(state_dict)
|
571 |
+
|
572 |
+
def set_ip_adapter_scale(self, scale):
|
573 |
+
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
574 |
+
for attn_processor in unet.attn_processors.values():
|
575 |
+
if isinstance(attn_processor, IPAttnProcessor):
|
576 |
+
attn_processor.scale = scale
|
577 |
+
|
578 |
+
def _encode_prompt_image_emb(self, prompt_image_emb, device, dtype, do_classifier_free_guidance):
|
579 |
+
|
580 |
+
if isinstance(prompt_image_emb, torch.Tensor):
|
581 |
+
prompt_image_emb = prompt_image_emb.clone().detach()
|
582 |
+
else:
|
583 |
+
prompt_image_emb = torch.tensor(prompt_image_emb)
|
584 |
+
|
585 |
+
prompt_image_emb = prompt_image_emb.to(device=device, dtype=dtype)
|
586 |
+
prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features])
|
587 |
+
|
588 |
+
if do_classifier_free_guidance:
|
589 |
+
prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0)
|
590 |
+
else:
|
591 |
+
prompt_image_emb = torch.cat([prompt_image_emb], dim=0)
|
592 |
+
|
593 |
+
prompt_image_emb = self.image_proj_model(prompt_image_emb)
|
594 |
+
return prompt_image_emb
|
595 |
+
|
596 |
+
@torch.no_grad()
|
597 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
598 |
+
def __call__(
|
599 |
+
self,
|
600 |
+
prompt: Union[str, List[str]] = None,
|
601 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
602 |
+
image: PipelineImageInput = None,
|
603 |
+
height: Optional[int] = None,
|
604 |
+
width: Optional[int] = None,
|
605 |
+
num_inference_steps: int = 50,
|
606 |
+
guidance_scale: float = 5.0,
|
607 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
608 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
609 |
+
num_images_per_prompt: Optional[int] = 1,
|
610 |
+
eta: float = 0.0,
|
611 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
612 |
+
latents: Optional[torch.FloatTensor] = None,
|
613 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
614 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
615 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
616 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
617 |
+
image_embeds: Optional[torch.FloatTensor] = None,
|
618 |
+
output_type: Optional[str] = "pil",
|
619 |
+
return_dict: bool = True,
|
620 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
621 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
622 |
+
guess_mode: bool = False,
|
623 |
+
control_guidance_start: Union[float, List[float]] = 0.0,
|
624 |
+
control_guidance_end: Union[float, List[float]] = 1.0,
|
625 |
+
original_size: Tuple[int, int] = None,
|
626 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
627 |
+
target_size: Tuple[int, int] = None,
|
628 |
+
negative_original_size: Optional[Tuple[int, int]] = None,
|
629 |
+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
630 |
+
negative_target_size: Optional[Tuple[int, int]] = None,
|
631 |
+
clip_skip: Optional[int] = None,
|
632 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
633 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
634 |
+
control_mask = None,
|
635 |
+
**kwargs,
|
636 |
+
):
|
637 |
+
r"""
|
638 |
+
The call function to the pipeline for generation.
|
639 |
+
|
640 |
+
Args:
|
641 |
+
prompt (`str` or `List[str]`, *optional*):
|
642 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
643 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
644 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
645 |
+
used in both text-encoders.
|
646 |
+
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
647 |
+
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
648 |
+
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
|
649 |
+
specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
|
650 |
+
accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
|
651 |
+
and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
|
652 |
+
`init`, images must be passed as a list such that each element of the list can be correctly batched for
|
653 |
+
input to a single ControlNet.
|
654 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
655 |
+
The height in pixels of the generated image. Anything below 512 pixels won't work well for
|
656 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
657 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
658 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
659 |
+
The width in pixels of the generated image. Anything below 512 pixels won't work well for
|
660 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
661 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
662 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
663 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
664 |
+
expense of slower inference.
|
665 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
666 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
667 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
668 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
669 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
670 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
671 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
672 |
+
The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
|
673 |
+
and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
|
674 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
675 |
+
The number of images to generate per prompt.
|
676 |
+
eta (`float`, *optional*, defaults to 0.0):
|
677 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
678 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
679 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
680 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
681 |
+
generation deterministic.
|
682 |
+
latents (`torch.FloatTensor`, *optional*):
|
683 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
684 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
685 |
+
tensor is generated by sampling using the supplied random `generator`.
|
686 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
687 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
688 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
689 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
690 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
691 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
692 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
693 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
694 |
+
not provided, pooled text embeddings are generated from `prompt` input argument.
|
695 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
696 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
|
697 |
+
weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
|
698 |
+
argument.
|
699 |
+
image_embeds (`torch.FloatTensor`, *optional*):
|
700 |
+
Pre-generated image embeddings.
|
701 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
702 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
703 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
704 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
705 |
+
plain tuple.
|
706 |
+
cross_attention_kwargs (`dict`, *optional*):
|
707 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
708 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
709 |
+
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
710 |
+
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
|
711 |
+
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
|
712 |
+
the corresponding scale as a list.
|
713 |
+
guess_mode (`bool`, *optional*, defaults to `False`):
|
714 |
+
The ControlNet encoder tries to recognize the content of the input image even if you remove all
|
715 |
+
prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
|
716 |
+
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
|
717 |
+
The percentage of total steps at which the ControlNet starts applying.
|
718 |
+
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
719 |
+
The percentage of total steps at which the ControlNet stops applying.
|
720 |
+
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
721 |
+
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
722 |
+
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
723 |
+
explained in section 2.2 of
|
724 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
725 |
+
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
726 |
+
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
727 |
+
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
728 |
+
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
729 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
730 |
+
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
731 |
+
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
732 |
+
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
733 |
+
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
734 |
+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
735 |
+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
736 |
+
micro-conditioning as explained in section 2.2 of
|
737 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
738 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
739 |
+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
740 |
+
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
|
741 |
+
micro-conditioning as explained in section 2.2 of
|
742 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
743 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
744 |
+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
745 |
+
To negatively condition the generation process based on a target image resolution. It should be as same
|
746 |
+
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
747 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
748 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
749 |
+
clip_skip (`int`, *optional*):
|
750 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
751 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
752 |
+
callback_on_step_end (`Callable`, *optional*):
|
753 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
754 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
755 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
756 |
+
`callback_on_step_end_tensor_inputs`.
|
757 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
758 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
759 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
760 |
+
`._callback_tensor_inputs` attribute of your pipeine class.
|
761 |
+
|
762 |
+
Examples:
|
763 |
+
|
764 |
+
Returns:
|
765 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
766 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
767 |
+
otherwise a `tuple` is returned containing the output images.
|
768 |
+
"""
|
769 |
+
lpw = LongPromptWeight()
|
770 |
+
|
771 |
+
callback = kwargs.pop("callback", None)
|
772 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
773 |
+
|
774 |
+
if callback is not None:
|
775 |
+
deprecate(
|
776 |
+
"callback",
|
777 |
+
"1.0.0",
|
778 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
779 |
+
)
|
780 |
+
if callback_steps is not None:
|
781 |
+
deprecate(
|
782 |
+
"callback_steps",
|
783 |
+
"1.0.0",
|
784 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
785 |
+
)
|
786 |
+
|
787 |
+
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
788 |
+
|
789 |
+
# align format for control guidance
|
790 |
+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
791 |
+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
792 |
+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
793 |
+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
794 |
+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
795 |
+
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
|
796 |
+
control_guidance_start, control_guidance_end = (
|
797 |
+
mult * [control_guidance_start],
|
798 |
+
mult * [control_guidance_end],
|
799 |
+
)
|
800 |
+
|
801 |
+
# 1. Check inputs. Raise error if not correct
|
802 |
+
self.check_inputs(
|
803 |
+
prompt,
|
804 |
+
prompt_2,
|
805 |
+
image,
|
806 |
+
callback_steps,
|
807 |
+
negative_prompt,
|
808 |
+
negative_prompt_2,
|
809 |
+
prompt_embeds,
|
810 |
+
negative_prompt_embeds,
|
811 |
+
pooled_prompt_embeds,
|
812 |
+
negative_pooled_prompt_embeds,
|
813 |
+
controlnet_conditioning_scale,
|
814 |
+
control_guidance_start,
|
815 |
+
control_guidance_end,
|
816 |
+
callback_on_step_end_tensor_inputs,
|
817 |
+
)
|
818 |
+
|
819 |
+
self._guidance_scale = guidance_scale
|
820 |
+
self._clip_skip = clip_skip
|
821 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
822 |
+
|
823 |
+
# 2. Define call parameters
|
824 |
+
if prompt is not None and isinstance(prompt, str):
|
825 |
+
batch_size = 1
|
826 |
+
elif prompt is not None and isinstance(prompt, list):
|
827 |
+
batch_size = len(prompt)
|
828 |
+
else:
|
829 |
+
batch_size = prompt_embeds.shape[0]
|
830 |
+
|
831 |
+
device = self._execution_device
|
832 |
+
|
833 |
+
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
|
834 |
+
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
|
835 |
+
|
836 |
+
global_pool_conditions = (
|
837 |
+
controlnet.config.global_pool_conditions
|
838 |
+
if isinstance(controlnet, ControlNetModel)
|
839 |
+
else controlnet.nets[0].config.global_pool_conditions
|
840 |
+
)
|
841 |
+
guess_mode = guess_mode or global_pool_conditions
|
842 |
+
|
843 |
+
# 3.1 Encode input prompt
|
844 |
+
(
|
845 |
+
prompt_embeds,
|
846 |
+
negative_prompt_embeds,
|
847 |
+
pooled_prompt_embeds,
|
848 |
+
negative_pooled_prompt_embeds,
|
849 |
+
) = lpw.get_weighted_text_embeddings_sdxl(
|
850 |
+
pipe=self,
|
851 |
+
prompt=prompt,
|
852 |
+
neg_prompt=negative_prompt,
|
853 |
+
prompt_embeds=prompt_embeds,
|
854 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
855 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
856 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
857 |
+
)
|
858 |
+
|
859 |
+
# 3.2 Encode image prompt
|
860 |
+
prompt_image_emb = self._encode_prompt_image_emb(image_embeds,
|
861 |
+
device,
|
862 |
+
self.unet.dtype,
|
863 |
+
self.do_classifier_free_guidance)
|
864 |
+
|
865 |
+
# 4. Prepare image
|
866 |
+
if isinstance(controlnet, ControlNetModel):
|
867 |
+
image = self.prepare_image(
|
868 |
+
image=image,
|
869 |
+
width=width,
|
870 |
+
height=height,
|
871 |
+
batch_size=batch_size * num_images_per_prompt,
|
872 |
+
num_images_per_prompt=num_images_per_prompt,
|
873 |
+
device=device,
|
874 |
+
dtype=controlnet.dtype,
|
875 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
876 |
+
guess_mode=guess_mode,
|
877 |
+
)
|
878 |
+
height, width = image.shape[-2:]
|
879 |
+
elif isinstance(controlnet, MultiControlNetModel):
|
880 |
+
images = []
|
881 |
+
|
882 |
+
for image_ in image:
|
883 |
+
image_ = self.prepare_image(
|
884 |
+
image=image_,
|
885 |
+
width=width,
|
886 |
+
height=height,
|
887 |
+
batch_size=batch_size * num_images_per_prompt,
|
888 |
+
num_images_per_prompt=num_images_per_prompt,
|
889 |
+
device=device,
|
890 |
+
dtype=controlnet.dtype,
|
891 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
892 |
+
guess_mode=guess_mode,
|
893 |
+
)
|
894 |
+
|
895 |
+
images.append(image_)
|
896 |
+
|
897 |
+
image = images
|
898 |
+
height, width = image[0].shape[-2:]
|
899 |
+
else:
|
900 |
+
assert False
|
901 |
+
|
902 |
+
# 4.1 Region control
|
903 |
+
if control_mask is not None:
|
904 |
+
mask_weight_image = control_mask
|
905 |
+
mask_weight_image = np.array(mask_weight_image)
|
906 |
+
mask_weight_image_tensor = torch.from_numpy(mask_weight_image).to(device=device, dtype=prompt_embeds.dtype)
|
907 |
+
mask_weight_image_tensor = mask_weight_image_tensor[:, :, 0] / 255.
|
908 |
+
mask_weight_image_tensor = mask_weight_image_tensor[None, None]
|
909 |
+
h, w = mask_weight_image_tensor.shape[-2:]
|
910 |
+
control_mask_wight_image_list = []
|
911 |
+
for scale in [8, 8, 8, 16, 16, 16, 32, 32, 32]:
|
912 |
+
scale_mask_weight_image_tensor = F.interpolate(
|
913 |
+
mask_weight_image_tensor,(h // scale, w // scale), mode='bilinear')
|
914 |
+
control_mask_wight_image_list.append(scale_mask_weight_image_tensor)
|
915 |
+
region_mask = torch.from_numpy(np.array(control_mask)[:, :, 0]).to(self.unet.device, dtype=self.unet.dtype) / 255.
|
916 |
+
region_control.prompt_image_conditioning = [dict(region_mask=region_mask)]
|
917 |
+
else:
|
918 |
+
control_mask_wight_image_list = None
|
919 |
+
region_control.prompt_image_conditioning = [dict(region_mask=None)]
|
920 |
+
|
921 |
+
# 5. Prepare timesteps
|
922 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
923 |
+
timesteps = self.scheduler.timesteps
|
924 |
+
self._num_timesteps = len(timesteps)
|
925 |
+
|
926 |
+
# 6. Prepare latent variables
|
927 |
+
num_channels_latents = self.unet.config.in_channels
|
928 |
+
latents = self.prepare_latents(
|
929 |
+
batch_size * num_images_per_prompt,
|
930 |
+
num_channels_latents,
|
931 |
+
height,
|
932 |
+
width,
|
933 |
+
prompt_embeds.dtype,
|
934 |
+
device,
|
935 |
+
generator,
|
936 |
+
latents,
|
937 |
+
)
|
938 |
+
|
939 |
+
# 6.5 Optionally get Guidance Scale Embedding
|
940 |
+
timestep_cond = None
|
941 |
+
if self.unet.config.time_cond_proj_dim is not None:
|
942 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
943 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
944 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
945 |
+
).to(device=device, dtype=latents.dtype)
|
946 |
+
|
947 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
948 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
949 |
+
|
950 |
+
# 7.1 Create tensor stating which controlnets to keep
|
951 |
+
controlnet_keep = []
|
952 |
+
for i in range(len(timesteps)):
|
953 |
+
keeps = [
|
954 |
+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
955 |
+
for s, e in zip(control_guidance_start, control_guidance_end)
|
956 |
+
]
|
957 |
+
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
958 |
+
|
959 |
+
# 7.2 Prepare added time ids & embeddings
|
960 |
+
if isinstance(image, list):
|
961 |
+
original_size = original_size or image[0].shape[-2:]
|
962 |
+
else:
|
963 |
+
original_size = original_size or image.shape[-2:]
|
964 |
+
target_size = target_size or (height, width)
|
965 |
+
|
966 |
+
add_text_embeds = pooled_prompt_embeds
|
967 |
+
if self.text_encoder_2 is None:
|
968 |
+
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
969 |
+
else:
|
970 |
+
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
971 |
+
|
972 |
+
add_time_ids = self._get_add_time_ids(
|
973 |
+
original_size,
|
974 |
+
crops_coords_top_left,
|
975 |
+
target_size,
|
976 |
+
dtype=prompt_embeds.dtype,
|
977 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
978 |
+
)
|
979 |
+
|
980 |
+
if negative_original_size is not None and negative_target_size is not None:
|
981 |
+
negative_add_time_ids = self._get_add_time_ids(
|
982 |
+
negative_original_size,
|
983 |
+
negative_crops_coords_top_left,
|
984 |
+
negative_target_size,
|
985 |
+
dtype=prompt_embeds.dtype,
|
986 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
987 |
+
)
|
988 |
+
else:
|
989 |
+
negative_add_time_ids = add_time_ids
|
990 |
+
|
991 |
+
if self.do_classifier_free_guidance:
|
992 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
993 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
994 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
995 |
+
|
996 |
+
prompt_embeds = prompt_embeds.to(device)
|
997 |
+
add_text_embeds = add_text_embeds.to(device)
|
998 |
+
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
999 |
+
encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1)
|
1000 |
+
|
1001 |
+
# 8. Denoising loop
|
1002 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
1003 |
+
is_unet_compiled = is_compiled_module(self.unet)
|
1004 |
+
is_controlnet_compiled = is_compiled_module(self.controlnet)
|
1005 |
+
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
|
1006 |
+
|
1007 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1008 |
+
for i, t in enumerate(timesteps):
|
1009 |
+
# Relevant thread:
|
1010 |
+
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
|
1011 |
+
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
|
1012 |
+
torch._inductor.cudagraph_mark_step_begin()
|
1013 |
+
# expand the latents if we are doing classifier free guidance
|
1014 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
1015 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
1016 |
+
|
1017 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
1018 |
+
|
1019 |
+
# controlnet(s) inference
|
1020 |
+
if guess_mode and self.do_classifier_free_guidance:
|
1021 |
+
# Infer ControlNet only for the conditional batch.
|
1022 |
+
control_model_input = latents
|
1023 |
+
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
1024 |
+
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
1025 |
+
controlnet_added_cond_kwargs = {
|
1026 |
+
"text_embeds": add_text_embeds.chunk(2)[1],
|
1027 |
+
"time_ids": add_time_ids.chunk(2)[1],
|
1028 |
+
}
|
1029 |
+
else:
|
1030 |
+
control_model_input = latent_model_input
|
1031 |
+
controlnet_prompt_embeds = prompt_embeds
|
1032 |
+
controlnet_added_cond_kwargs = added_cond_kwargs
|
1033 |
+
|
1034 |
+
if isinstance(controlnet_keep[i], list):
|
1035 |
+
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
1036 |
+
else:
|
1037 |
+
controlnet_cond_scale = controlnet_conditioning_scale
|
1038 |
+
if isinstance(controlnet_cond_scale, list):
|
1039 |
+
controlnet_cond_scale = controlnet_cond_scale[0]
|
1040 |
+
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
1041 |
+
|
1042 |
+
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
1043 |
+
control_model_input,
|
1044 |
+
t,
|
1045 |
+
encoder_hidden_states=prompt_image_emb,
|
1046 |
+
controlnet_cond=image,
|
1047 |
+
conditioning_scale=cond_scale,
|
1048 |
+
guess_mode=guess_mode,
|
1049 |
+
added_cond_kwargs=controlnet_added_cond_kwargs,
|
1050 |
+
return_dict=False,
|
1051 |
+
)
|
1052 |
+
|
1053 |
+
# controlnet mask
|
1054 |
+
if control_mask_wight_image_list is not None:
|
1055 |
+
down_block_res_samples = [
|
1056 |
+
down_block_res_sample * mask_weight
|
1057 |
+
for down_block_res_sample, mask_weight in zip(down_block_res_samples, control_mask_wight_image_list)
|
1058 |
+
]
|
1059 |
+
mid_block_res_sample *= control_mask_wight_image_list[-1]
|
1060 |
+
|
1061 |
+
if guess_mode and self.do_classifier_free_guidance:
|
1062 |
+
# Infered ControlNet only for the conditional batch.
|
1063 |
+
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
1064 |
+
# add 0 to the unconditional batch to keep it unchanged.
|
1065 |
+
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
|
1066 |
+
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
|
1067 |
+
|
1068 |
+
# predict the noise residual
|
1069 |
+
noise_pred = self.unet(
|
1070 |
+
latent_model_input,
|
1071 |
+
t,
|
1072 |
+
encoder_hidden_states=encoder_hidden_states,
|
1073 |
+
timestep_cond=timestep_cond,
|
1074 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
1075 |
+
down_block_additional_residuals=down_block_res_samples,
|
1076 |
+
mid_block_additional_residual=mid_block_res_sample,
|
1077 |
+
added_cond_kwargs=added_cond_kwargs,
|
1078 |
+
return_dict=False,
|
1079 |
+
)[0]
|
1080 |
+
|
1081 |
+
# perform guidance
|
1082 |
+
if self.do_classifier_free_guidance:
|
1083 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1084 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
1085 |
+
|
1086 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1087 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
1088 |
+
|
1089 |
+
if callback_on_step_end is not None:
|
1090 |
+
callback_kwargs = {}
|
1091 |
+
for k in callback_on_step_end_tensor_inputs:
|
1092 |
+
callback_kwargs[k] = locals()[k]
|
1093 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1094 |
+
|
1095 |
+
latents = callback_outputs.pop("latents", latents)
|
1096 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1097 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
1098 |
+
|
1099 |
+
# call the callback, if provided
|
1100 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1101 |
+
progress_bar.update()
|
1102 |
+
if callback is not None and i % callback_steps == 0:
|
1103 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
1104 |
+
callback(step_idx, t, latents)
|
1105 |
+
|
1106 |
+
if not output_type == "latent":
|
1107 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
1108 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
1109 |
+
if needs_upcasting:
|
1110 |
+
self.upcast_vae()
|
1111 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
1112 |
+
|
1113 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
1114 |
+
|
1115 |
+
# cast back to fp16 if needed
|
1116 |
+
if needs_upcasting:
|
1117 |
+
self.vae.to(dtype=torch.float16)
|
1118 |
+
else:
|
1119 |
+
image = latents
|
1120 |
+
|
1121 |
+
if not output_type == "latent":
|
1122 |
+
# apply watermark if available
|
1123 |
+
if self.watermark is not None:
|
1124 |
+
image = self.watermark.apply_watermark(image)
|
1125 |
+
|
1126 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
1127 |
+
|
1128 |
+
# Offload all models
|
1129 |
+
self.maybe_free_model_hooks()
|
1130 |
+
|
1131 |
+
if not return_dict:
|
1132 |
+
return (image,)
|
1133 |
+
|
1134 |
+
return StableDiffusionXLPipelineOutput(images=image)
|
style_template.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
style_list = [
|
2 |
+
{
|
3 |
+
"name": "(No style)",
|
4 |
+
"prompt": "{prompt}",
|
5 |
+
"negative_prompt": "",
|
6 |
+
},
|
7 |
+
{
|
8 |
+
"name": "Watercolor",
|
9 |
+
"prompt": "watercolor painting, {prompt}. vibrant, beautiful, painterly, detailed, textural, artistic",
|
10 |
+
"negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, photorealistic, 35mm film, deformed, glitch, low contrast, noisy",
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"name": "Film Noir",
|
14 |
+
"prompt": "film noir style, ink sketch|vector, {prompt} highly detailed, sharp focus, ultra sharpness, monochrome, high contrast, dramatic shadows, 1940s style, mysterious, cinematic",
|
15 |
+
"negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
|
16 |
+
},
|
17 |
+
{
|
18 |
+
"name": "Neon",
|
19 |
+
"prompt": "masterpiece painting, buildings in the backdrop, kaleidoscope, lilac orange blue cream fuchsia bright vivid gradient colors, the scene is cinematic, {prompt}, emotional realism, double exposure, watercolor ink pencil, graded wash, color layering, magic realism, figurative painting, intricate motifs, organic tracery, polished",
|
20 |
+
"negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"name": "Jungle",
|
24 |
+
"prompt": 'waist-up "{prompt} in a Jungle" by Syd Mead, tangerine cold color palette, muted colors, detailed, 8k,photo r3al,dripping paint,3d toon style,3d style,Movie Still',
|
25 |
+
"negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"name": "Mars",
|
29 |
+
"prompt": "{prompt}, Post-apocalyptic. Mars Colony, Scavengers roam the wastelands searching for valuable resources, rovers, bright morning sunlight shining, (detailed) (intricate) (8k) (HDR) (cinematic lighting) (sharp focus)",
|
30 |
+
"negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
|
31 |
+
},
|
32 |
+
{
|
33 |
+
"name": "Vibrant Color",
|
34 |
+
"prompt": "vibrant colorful, ink sketch|vector|2d colors, at nightfall, sharp focus, {prompt}, highly detailed, sharp focus, the clouds,colorful,ultra sharpness",
|
35 |
+
"negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"name": "Snow",
|
39 |
+
"prompt": "cinema 4d render, {prompt}, high contrast, vibrant and saturated, sico style, surrounded by magical glow,floating ice shards, snow crystals, cold, windy background, frozen natural landscape in background cinematic atmosphere,highly detailed, sharp focus, intricate design, 3d, unreal engine, octane render, CG best quality, highres, photorealistic, dramatic lighting, artstation, concept art, cinematic, epic Steven Spielberg movie still, sharp focus, smoke, sparks, art by pascal blanche and greg rutkowski and repin, trending on artstation, hyperrealism painting, matte painting, 4k resolution",
|
40 |
+
"negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
|
41 |
+
},
|
42 |
+
{
|
43 |
+
"name": "Line art",
|
44 |
+
"prompt": "line art drawing {prompt} . professional, sleek, modern, minimalist, graphic, line art, vector graphics",
|
45 |
+
"negative_prompt": "anime, photorealistic, 35mm film, deformed, glitch, blurry, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, mutated, realism, realistic, impressionism, expressionism, oil, acrylic",
|
46 |
+
},
|
47 |
+
]
|
48 |
+
|
49 |
+
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
|