qyoo committed
Commit 5d216fa · 1 Parent(s): 0a7a656
Files changed (36)
  1. README.md +5 -7
  2. app.py +361 -123
  3. demo/book.jpg +0 -0
  4. demo/horse.jpg +0 -0
  5. demo/statue.jpg +0 -0
  6. demo/t-shirt.jpg +0 -0
  7. ip_adapter/__init__.py +23 -0
  8. ip_adapter/__pycache__/__init__.cpython-310.pyc +0 -0
  9. ip_adapter/__pycache__/attention_processor.cpython-310.pyc +0 -0
  10. ip_adapter/__pycache__/custom_pipelines.cpython-310.pyc +0 -0
  11. ip_adapter/__pycache__/ip_adapter.cpython-310.pyc +0 -0
  12. ip_adapter/__pycache__/resampler.cpython-310.pyc +0 -0
  13. ip_adapter/__pycache__/utils.cpython-310.pyc +0 -0
  14. ip_adapter/attention_processor.py +948 -0
  15. ip_adapter/custom_pipelines.py +805 -0
  16. ip_adapter/ip_adapter.py +1043 -0
  17. ip_adapter/resampler.py +247 -0
  18. ip_adapter/utils.py +140 -0
  19. omini_control/__init__.py +0 -0
  20. omini_control/__pycache__/__init__.cpython-310.pyc +0 -0
  21. omini_control/__pycache__/block.cpython-310.pyc +0 -0
  22. omini_control/__pycache__/concept_alignment.cpython-310.pyc +0 -0
  23. omini_control/__pycache__/conceptrol.cpython-310.pyc +0 -0
  24. omini_control/__pycache__/condition.cpython-310.pyc +0 -0
  25. omini_control/__pycache__/flux_conceptrol_pipeline.cpython-310.pyc +0 -0
  26. omini_control/__pycache__/lora_controller.cpython-310.pyc +0 -0
  27. omini_control/__pycache__/transformer.cpython-310.pyc +0 -0
  28. omini_control/block.py +354 -0
  29. omini_control/conceptrol.py +208 -0
  30. omini_control/condition.py +124 -0
  31. omini_control/flux_conceptrol_pipeline.py +368 -0
  32. omini_control/lora_controller.py +75 -0
  33. omini_control/transformer.py +273 -0
  34. requirements.txt +8 -5
  35. style.css +95 -0
  36. utils.py +212 -0
README.md CHANGED
@@ -1,14 +1,12 @@
 ---
-title: Conceptrol
-emoji: 🖼
-colorFrom: purple
+title: PAID
+emoji: 🏢
+colorFrom: pink
 colorTo: red
 sdk: gradio
-sdk_version: 5.0.1
+sdk_version: 4.22.0
 app_file: app.py
 pinned: false
-license: apache-2.0
-short_description: A free lunch eliciting personalized ability of adapters
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,154 +1,392 @@
1
  import gradio as gr
2
  import numpy as np
3
- import random
4
-
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
  import torch
 
8
 
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
 
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
 
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
19
 
20
- MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
-
23
-
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
- def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
- progress=gr.Progress(track_tqdm=True),
35
- ):
36
- if randomize_seed:
37
- seed = random.randint(0, MAX_SEED)
38
-
39
- generator = torch.Generator().manual_seed(seed)
40
-
41
- image = pipe(
42
- prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
- height=height,
48
- generator=generator,
49
- ).images[0]
50
-
51
- return image, seed
52
-
53
-
54
- examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
- ]
59
-
60
- css = """
61
- #col-container {
62
- margin: 0 auto;
63
- max-width: 640px;
64
  }
 
 
 
 
65
  """
66
 
67
- with gr.Blocks(css=css) as demo:
68
- with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
 
 
70
 
71
- with gr.Row():
72
- prompt = gr.Text(
73
- label="Prompt",
74
- show_label=False,
75
- max_lines=1,
76
- placeholder="Enter your prompt",
77
- container=False,
78
- )
 
 
79
 
80
- run_button = gr.Button("Run", scale=0, variant="primary")
81
 
82
- result = gr.Image(label="Result", show_label=False)
83
 
84
- with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
90
- )
91
 
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
98
- )
99
 
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
 
102
- with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
- )
110
 
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
  )
 
118
 
119
- with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
  minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
 
126
  )
127
-
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
132
- step=1,
133
- value=2, # Replace with defaults that work for your model
134
)
135
 
136
- gr.Examples(examples=examples, inputs=[prompt])
137
- gr.on(
138
- triggers=[run_button.click, prompt.submit],
139
- fn=infer,
140
  inputs=[
141
  prompt,
 
142
  negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
  num_inference_steps,
 
 
 
 
 
149
  ],
150
- outputs=[result, seed],
151
  )
 
152
 
153
- if __name__ == "__main__":
154
- demo.launch()
 
1
+ import os
2
+
3
  import gradio as gr
4
  import numpy as np
 
 
 
 
5
  import torch
6
+ from PIL import Image
7
 
8
+ from ip_adapter import (
9
+ ConceptrolIPAdapterPlus,
10
+ ConceptrolIPAdapterPlusXL,
11
+ )
12
+ from ip_adapter.custom_pipelines import (
13
+ StableDiffusionCustomPipeline,
14
+ StableDiffusionXLCustomPipeline,
15
+ )
16
+ from omini_control.conceptrol import Conceptrol
17
+ from omini_control.flux_conceptrol_pipeline import FluxConceptrolPipeline
18
 
 
 
 
 
19
 
20
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
21
 
22
+ title = r"""
23
+ <h1 align="center">Conceptrol: Concept Control of Zero-shot Personalized Image Generation</h1>
24
+ """
25
+
26
+ description = r"""
27
+ <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/QY-H00/Conceptrol/tree/public' target='_blank'><b>Conceptrol: Concept Control of Zero-shot Personalized Image Generation</b></a>.<br>
28
+ How to use:<br>
29
+ 1. Input the text prompt, the visual specification (reference image), and the textual concept of the personalization target.
30
+ 2. Choose your preferred base model; switching to a model for the first time may take about 30 minutes to download it.
31
+ 3. Each inference takes about 10s for SD-series models, about 30s for SDXL-series, and about 50s for FLUX.
32
+ 4. Click the <b>Generate</b> button to enjoy! 😊
33
+ """
34
+
35
+ article = r"""
36
+ ---
37
+ ✒️ **Citation**
38
+ <br>
39
+ If you find this demo or our paper useful, please consider citing:
40
+ ```bibtex
41
+ @article{he2025conceptrol,
42
+ title={Conceptrol: Concept Control of Zero-shot Personalized Image Generation},
43
+ author={He, Qiyuan and Yao, Angela},
44
+ journal={arXiv preprint arXiv:2403.17924},
45
+ year={2024}
46
  }
47
+ ```
48
+ 📧 **Contact**
49
+ <br>
50
+ If you have any questions, please feel free to open an issue in our <a href='https://github.com/QY-H00/Conceptrol/tree/public' target='_blank'><b>GitHub repo</b></a> or reach us directly at <b>qhe@u.nus.edu.sg</b>.
51
  """
52
 
53
+ MAX_SEED = np.iinfo(np.int32).max
54
+ CACHE_EXAMPLES = False
55
+ USE_TORCH_COMPILE = False
56
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
57
+ PREVIEW_IMAGES = False
58
 
59
+ # Default settings
60
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
61
+ adapter_name = "h94/IP-Adapter/models/ip-adapter-plus_sd15.bin"
62
+ pipe = StableDiffusionCustomPipeline.from_pretrained(
63
+ "SG161222/Realistic_Vision_V5.1_noVAE",
64
+ torch_dtype=torch.float16,
65
+ feature_extractor=None,
66
+ safety_checker=None
67
+ )
68
+ pipeline = ConceptrolIPAdapterPlus(pipe, "", adapter_name, device, num_tokens=16)
69
 
70
+ def change_model_fn(model_name: str) -> None:
71
+ global device, pipeline
72
+
73
+ # Clear GPU memory
74
+ if torch.cuda.is_available():
75
+ if pipeline is not None:
76
+ del pipeline
77
+ torch.cuda.empty_cache()
78
+
79
+ name_mapping = {
80
+ "SD1.5-512": "stable-diffusion-v1-5/stable-diffusion-v1-5",
81
+ "AOM3 (SD-based)": "hogiahien/aom3",
82
+ "RealVis-v5.1 (SD-based)": "SG161222/Realistic_Vision_V5.1_noVAE",
83
+ "SDXL-1024": "stabilityai/stable-diffusion-xl-base-1.0",
84
+ "RealVisXL-v5.0 (SDXL-based)": "SG161222/RealVisXL_V5.0",
85
+ "Playground-XL-v2 (SDXL-based)": "playgroundai/playground-v2.5-1024px-aesthetic",
86
+ "Animagine-XL-v4.0 (SDXL-based)": "cagliostrolab/animagine-xl-4.0",
87
+ "FLUX-schnell": "black-forest-labs/FLUX.1-schnell"
88
+ }
89
+ if "XL" in model_name:
90
+ adapter_name = "h94/IP-Adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.safetensors"
91
+ pipe = StableDiffusionXLCustomPipeline.from_pretrained(
92
+ name_mapping[model_name],
93
+ # variant="fp16",
94
+ torch_dtype=torch.float16,
95
+ feature_extractor=None
96
+ )
97
+ pipeline = ConceptrolIPAdapterPlusXL(pipe, "", adapter_name, device, num_tokens=16)
98
+ globals()["pipeline"] = pipeline
99
+
100
+ elif "FLUX" in model_name:
101
+ adapter_name = "Yuanshi/OminiControl"
102
+ pipeline = FluxConceptrolPipeline.from_pretrained(
103
+ name_mapping[model_name], torch_dtype=torch.bfloat16
104
+ ).to(device)
105
+ pipeline.load_lora_weights(
106
+ adapter_name,
107
+ weight_name="omini/subject_512.safetensors",
108
+ adapter_name="subject",
109
+ )
110
+ config = {"name": "conceptrol"}
111
+ conceptrol = Conceptrol(config)
112
+ pipeline.load_conceptrol(conceptrol)
113
+ globals()["pipeline"] = pipeline
114
+ globals()["pipeline"].to(device, dtype=torch.bfloat16)
115
+
116
+ elif "XL" not in model_name and "FLUX" not in model_name:
117
+ adapter_name = "h94/IP-Adapter/models/ip-adapter-plus_sd15.bin"
118
+ pipe = StableDiffusionCustomPipeline.from_pretrained(
119
+ name_mapping[model_name],
120
+ torch_dtype=torch.float16,
121
+ feature_extractor=None,
122
+ safety_checker=None
123
+ )
124
+ pipeline = ConceptrolIPAdapterPlus(pipe, "", adapter_name, device, num_tokens=16)
125
+ globals()["pipeline"] = pipeline
126
+ else:
127
+ raise KeyError("Not supported model name!")
128
 
 
129
 
130
+ def save_image(img, index):
131
+ unique_name = f"{index}.png"
132
+ img = Image.fromarray(img)
133
+ img.save(unique_name)
134
+ return unique_name
 
 
135
 
 
 
 
 
 
 
 
136
 
137
+ def get_example() -> list[list[str | float | int]]:
138
+ case = [
139
+ [
140
+ "A statue is reading the book in the cafe, best quality, high quality",
141
+ "statue",
142
+ "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
143
+ Image.open("demo/statue.jpg"),
144
+ 50,
145
+ 6.0,
146
+ 1.0,
147
+ 0.2,
148
+ 42,
149
+ "RealVis-v5.1 (SD-based)"
150
+ ],
151
+ [
152
+ "A hyper-realistic, high-resolution photograph of an astronaut in a meticulously detailed space suit riding a majestic horse across an otherworldly landscape. The image features dynamic lighting, rich textures, and a cinematic atmosphere, capturing every intricate detail in stunning clarity.",
153
+ "horse",
154
+ "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
155
+ Image.open("demo/horse.jpg"),
156
+ 50,
157
+ 6.0,
158
+ 1.0,
159
+ 0.2,
160
+ 42,
161
+ "RealVisXL-v5.0 (SDXL-based)"
162
+ ],
163
+ [
164
+ "A man wearing a T-shirt walking on the street",
165
+ "T-shirt",
166
+ "",
167
+ Image.open("demo/t-shirt.jpg"),
168
+ 20,
169
+ 3.5,
170
+ 1.0,
171
+ 0.0,
172
+ 42,
173
+ "FLUX-schnell"
174
+ ]
175
+ ]
176
+ return case
177
+
178
+
179
+ def change_generate_button_fn(enable: int) -> gr.Button:
180
+ if enable == 0:
181
+ return gr.Button(interactive=False, value="Switching Model...")
182
+ else:
183
+ return gr.Button(interactive=True, value="Generate")
184
+
185
+
186
+ def dynamic_gallery_fn():
187
+ return gr.Image(label="Result", show_label=False)
188
 
 
 
 
 
 
 
 
 
189
 
190
+ @torch.no_grad()
191
+ def generate(
192
+ prompt="a statue is reading the book in the cafe",
193
+ subject="cat",
194
+ negative_prompt="",
195
+ image=None,
196
+ num_inference_steps=20,
197
+ guidance_scale=3.5,
198
+ condition_scale=1.0,
199
+ control_guidance_start=0.0,
200
+ seed=0,
201
+ model_name="RealVis-v5.1 (SD-based)"
202
+ ) -> np.ndarray:
203
+ global pipeline
204
+ change_model_fn(model_name)
205
+ if isinstance(pipeline, FluxConceptrolPipeline):
206
+ images = pipeline(
207
+ prompt=prompt,
208
+ image=image,
209
+ subject=subject,
210
+ num_inference_steps=num_inference_steps,
211
+ guidance_scale=guidance_scale,
212
+ condition_scale=condition_scale,
213
+ control_guidance_start=control_guidance_start,
214
+ height=512,
215
+ width=512,
216
+ seed=seed,
217
+ ).images[0]
218
+ elif isinstance(pipeline, ConceptrolIPAdapterPlus) or isinstance(pipeline, ConceptrolIPAdapterPlusXL):
219
+ images = pipeline.generate(
220
+ prompt=prompt,
221
+ pil_images=[image],
222
+ subjects=[subject],
223
+ num_samples=1,
224
+ num_inference_steps=50,
225
+ scale=condition_scale,
226
+ negative_prompt=negative_prompt,
227
+ control_guidance_start=control_guidance_start,
228
+ seed=seed,
229
+ )[0]
230
+ else:
231
+ raise TypeError("Unsupported Pipeline")
232
+
233
+ return images
234
+
235
+ with gr.Blocks(css="style.css") as demo:
236
+ gr.Markdown(title)
237
+ gr.Markdown(description)
238
+ with gr.Row(elem_classes="grid-container"):
239
+ with gr.Group():
240
+ with gr.Row(elem_classes="flex-grow"):
241
+ with gr.Column(elem_classes="grid-item"): # 左侧列
242
+ prompt = gr.Text(
243
+ label="Prompt",
244
+ max_lines=3,
245
+ placeholder="Enter the Descriptive Prompt",
246
+ interactive=True,
247
+ value="A statue is reading the book in the cafe, best quality, high quality",
248
+ )
249
+ textual_concept = gr.Text(
250
+ label="Textual Concept",
251
+ max_lines=3,
252
+ placeholder="Enter the Textual Concept required customization",
253
+ interactive=True,
254
+ value="statue",
255
+ )
256
+ negative_prompt = gr.Text(
257
+ label="Negative prompt",
258
+ max_lines=3,
259
+ placeholder="Enter a Negative Prompt",
260
+ interactive=True,
261
+ value="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality"
262
+ )
263
+
264
+ with gr.Row(elem_classes="flex-grow"):
265
+ image_prompt = gr.Image(
266
+ label="Reference Image for customization",
267
+ interactive=True,
268
+ height=280
269
  )
270
+
271
 
272
+ with gr.Group():
273
+ with gr.Column(elem_classes="grid-item"): # 右侧列
274
+ with gr.Row(elem_classes="flex-grow"):
275
+
276
+ with gr.Group():
277
+ # result = gr.Gallery(label="Result", show_label=False, rows=1, columns=1)
278
+ result = gr.Image(label="Result", show_label=False, height=238, width=256)
279
+ generate_button = gr.Button(value="Generate", variant="primary")
280
+
281
+ with gr.Accordion("Advanced options", open=True):
282
+ with gr.Row():
283
+ with gr.Column():
284
+ # with gr.Row(elem_classes="flex-grow"):
285
+ model_choice = gr.Dropdown(
286
+ [
287
+ "AOM3 (SD-based)",
288
+ "SD1.5-512",
289
+ "RealVis-v5.1 (SD-based)",
290
+ "SDXL-1024",
291
+ "RealVisXL-v5.0 (SDXL-based)",
292
+ "Animagine-XL-v4.0 (SDXL-based)",
293
+ "FLUX-schnell"
294
+ ],
295
+ label="Model",
296
+ value="RealVis-v5.1 (SD-based)",
297
+ interactive=True,
298
+ info="XL-Series takes longer time and FLUX takes even more",
299
+ )
300
+ condition_scale = gr.Slider(
301
+ label="Condition Scale of Reference Image",
302
+ minimum=0.4,
303
+ maximum=1.5,
304
+ step=0.05,
305
+ value=1.0,
306
+ interactive=True,
307
+ )
308
+ warmup_ratio = gr.Slider(
309
+ label="Warmup Ratio",
310
  minimum=0.0,
311
+ maximum=1,
312
+ step=0.05,
313
+ value=0.2,
314
+ interactive=True,
315
  )
316
+ guidance_scale = gr.Slider(
317
+ label="Guidance Scale",
318
+ minimum=0,
319
+ maximum=10,
320
+ step=0.1,
321
+ value=5.0,
322
+ interactive=True,
323
  )
324
+ num_inference_steps = gr.Slider(
325
+ label="Inference Steps",
326
+ minimum=10,
327
+ maximum=50,
328
+ step=1,
329
+ value=50,
330
+ interactive=True,
331
+ )
332
+ with gr.Column():
333
+ seed = gr.Slider(
334
+ label="Seed",
335
+ minimum=0,
336
+ maximum=MAX_SEED,
337
+ step=1,
338
+ value=0,
339
+ )
340
 
341
+ gr.Examples(
342
+ examples=get_example(),
 
 
343
  inputs=[
344
  prompt,
345
+ textual_concept,
346
  negative_prompt,
347
+ image_prompt,
 
 
 
 
348
  num_inference_steps,
349
+ guidance_scale,
350
+ condition_scale,
351
+ warmup_ratio,
352
+ seed,
353
+ model_choice
354
  ],
355
+ cache_examples=CACHE_EXAMPLES,
356
+ )
357
+
358
+ # model_choice.change(
359
+ # fn=change_generate_button_fn,
360
+ # inputs=gr.Number(0, visible=False),
361
+ # outputs=generate_button,
362
+ # )
363
+
364
+ # .then(fn=change_model_fn, inputs=model_choice).then(
365
+ # fn=change_generate_button_fn,
366
+ # inputs=gr.Number(1, visible=False),
367
+ # outputs=generate_button,
368
+ # )
369
+
370
+ inputs = [
371
+ prompt,
372
+ textual_concept,
373
+ negative_prompt,
374
+ image_prompt,
375
+ num_inference_steps,
376
+ guidance_scale,
377
+ condition_scale,
378
+ warmup_ratio,
379
+ seed,
380
+ model_choice
381
+ ]
382
+ generate_button.click(
383
+ fn=dynamic_gallery_fn,
384
+ outputs=result,
385
+ ).then(
386
+ fn=generate,
387
+ inputs=inputs,
388
+ outputs=result,
389
  )
390
+ gr.Markdown(article)
391
 
392
+ demo.launch()
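
For reference, here is a minimal sketch of driving the same SD1.5 Conceptrol path programmatically, using only the classes and arguments that appear in the app.py diff above. It assumes `generate` returns PIL images (the demo feeds its result straight to `gr.Image`); the image path and output file name are placeholders, so treat this as an illustration rather than an official API.

```python
# Minimal sketch: the SD1.5 Conceptrol path from app.py, without the Gradio UI.
# Mirrors the calls shown in the diff above; image path and output name are placeholders.
import torch
from PIL import Image

from ip_adapter import ConceptrolIPAdapterPlus
from ip_adapter.custom_pipelines import StableDiffusionCustomPipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Base SD1.5 checkpoint wrapped by the custom pipeline (same defaults as app.py).
pipe = StableDiffusionCustomPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE",
    torch_dtype=torch.float16,
    feature_extractor=None,
    safety_checker=None,
)
adapter_name = "h94/IP-Adapter/models/ip-adapter-plus_sd15.bin"
pipeline = ConceptrolIPAdapterPlus(pipe, "", adapter_name, device, num_tokens=16)

# One reference image plus the textual concept it should be bound to.
image = Image.open("demo/statue.jpg")
result = pipeline.generate(
    prompt="A statue is reading the book in the cafe, best quality, high quality",
    pil_images=[image],
    subjects=["statue"],
    num_samples=1,
    num_inference_steps=50,
    scale=1.0,                   # condition scale of the reference image
    negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy",
    control_guidance_start=0.2,  # warmup ratio before the adapter is applied
    seed=42,
)[0]
result.save("statue_in_cafe.png")
```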
 
demo/book.jpg ADDED
demo/horse.jpg ADDED
demo/statue.jpg ADDED
demo/t-shirt.jpg ADDED
ip_adapter/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ from .ip_adapter import (
2
+ IPAdapter,
3
+ IPAdapterFull,
4
+ IPAdapterPlus,
5
+ IPAdapterPlusXL,
6
+ IPAdapterXL,
7
+ ConceptrolIPAdapter,
8
+ ConceptrolIPAdapterPlus,
9
+ ConceptrolIPAdapterPlusXL,
10
+ ConceptrolIPAdapterXL,
11
+ )
12
+
13
+ __all__ = [
14
+ "IPAdapter",
15
+ "IPAdapterPlus",
16
+ "IPAdapterPlusXL",
17
+ "IPAdapterXL",
18
+ "IPAdapterFull",
19
+ "ConceptrolIPAdapter",
20
+ "ConceptrolIPAdapterPlus",
21
+ "ConceptrolIPAdapterXL",
22
+ "ConceptrolIPAdapterPlusXL",
23
+ ]
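
The classes exported here cover the SD and SDXL paths; the FLUX branch in app.py instead goes through omini_control. As a companion to the sketch above, here is the FLUX path mirroring the `change_model_fn` and `generate` calls and the FLUX-schnell example values from `get_example()`; file paths are placeholders and the block is illustrative only.

```python
# Sketch of the FLUX-schnell Conceptrol path from app.py (OminiControl subject LoRA).
import torch
from PIL import Image

from omini_control.conceptrol import Conceptrol
from omini_control.flux_conceptrol_pipeline import FluxConceptrolPipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

pipeline = FluxConceptrolPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to(device)
pipeline.load_lora_weights(
    "Yuanshi/OminiControl",
    weight_name="omini/subject_512.safetensors",
    adapter_name="subject",
)
pipeline.load_conceptrol(Conceptrol({"name": "conceptrol"}))

image = pipeline(
    prompt="A man wearing a T-shirt walking on the street",
    image=Image.open("demo/t-shirt.jpg"),
    subject="T-shirt",
    num_inference_steps=20,
    guidance_scale=3.5,
    condition_scale=1.0,
    control_guidance_start=0.0,
    height=512,
    width=512,
    seed=42,
).images[0]
image.save("t-shirt_on_street.png")
```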
ip_adapter/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (488 Bytes)
ip_adapter/__pycache__/attention_processor.cpython-310.pyc ADDED
Binary file (14.2 kB)
ip_adapter/__pycache__/custom_pipelines.cpython-310.pyc ADDED
Binary file (28.9 kB)
ip_adapter/__pycache__/ip_adapter.cpython-310.pyc ADDED
Binary file (17.9 kB)
ip_adapter/__pycache__/resampler.cpython-310.pyc ADDED
Binary file (5.62 kB)
ip_adapter/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.67 kB)
ip_adapter/attention_processor.py ADDED
@@ -0,0 +1,948 @@
1
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ # Global Variable
7
+ global_concept_mask = []
8
+ attn_mask_logs = {}
9
+ text_attn_map_logs = {}
10
+ image_attn_map_logs = {}
11
+
12
+
13
+ class AttnProcessor(nn.Module):
14
+ r"""
15
+ Default processor for performing attention-related computations.
16
+ """
17
+
18
+ def __init__(
19
+ self,
20
+ hidden_size=None,
21
+ cross_attention_dim=None,
22
+ ):
23
+ super().__init__()
24
+
25
+ def __call__(
26
+ self,
27
+ attn,
28
+ hidden_states,
29
+ encoder_hidden_states=None,
30
+ attention_mask=None,
31
+ temb=None,
32
+ *args,
33
+ **kwargs,
34
+ ):
35
+ residual = hidden_states
36
+
37
+ if attn.spatial_norm is not None:
38
+ hidden_states = attn.spatial_norm(hidden_states, temb)
39
+
40
+ input_ndim = hidden_states.ndim
41
+
42
+ if input_ndim == 4:
43
+ batch_size, channel, height, width = hidden_states.shape
44
+ hidden_states = hidden_states.view(
45
+ batch_size, channel, height * width
46
+ ).transpose(1, 2)
47
+
48
+ batch_size, sequence_length, _ = (
49
+ hidden_states.shape
50
+ if encoder_hidden_states is None
51
+ else encoder_hidden_states.shape
52
+ )
53
+ attention_mask = attn.prepare_attention_mask(
54
+ attention_mask, sequence_length, batch_size
55
+ )
56
+
57
+ if attn.group_norm is not None:
58
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
59
+ 1, 2
60
+ )
61
+
62
+ query = attn.to_q(hidden_states)
63
+
64
+ if encoder_hidden_states is None:
65
+ encoder_hidden_states = hidden_states
66
+ elif attn.norm_cross:
67
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
68
+ encoder_hidden_states
69
+ )
70
+
71
+ key = attn.to_k(encoder_hidden_states)
72
+ value = attn.to_v(encoder_hidden_states)
73
+
74
+ query = attn.head_to_batch_dim(query)
75
+ key = attn.head_to_batch_dim(key)
76
+ value = attn.head_to_batch_dim(value)
77
+
78
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
79
+ hidden_states = torch.bmm(attention_probs, value)
80
+ hidden_states = attn.batch_to_head_dim(hidden_states)
81
+
82
+ # linear proj
83
+ hidden_states = attn.to_out[0](hidden_states)
84
+ # dropout
85
+ hidden_states = attn.to_out[1](hidden_states)
86
+
87
+ if input_ndim == 4:
88
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
89
+ batch_size, channel, height, width
90
+ )
91
+
92
+ if attn.residual_connection:
93
+ hidden_states = hidden_states + residual
94
+
95
+ hidden_states = hidden_states / attn.rescale_output_factor
96
+
97
+ return hidden_states
98
+
99
+
100
+ class IPAttnProcessor(nn.Module):
101
+ r"""
102
+ Attention processor for IP-Adapter.
103
+ Args:
104
+ hidden_size (`int`):
105
+ The hidden size of the attention layer.
106
+ cross_attention_dim (`int`):
107
+ The number of channels in the `encoder_hidden_states`.
108
+ scale (`float`, defaults to 1.0):
109
+ the weight scale of image prompt.
110
+ num_tokens (`int`, defaults to 4; should be 16 for ip_adapter_plus):
111
+ The context length of the image features.
112
+ """
113
+
114
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
115
+ super().__init__()
116
+
117
+ self.hidden_size = hidden_size
118
+ self.cross_attention_dim = cross_attention_dim
119
+ self.scale = scale
120
+ self.num_tokens = num_tokens
121
+
122
+ self.to_k_ip = nn.Linear(
123
+ cross_attention_dim or hidden_size, hidden_size, bias=False
124
+ )
125
+ self.to_v_ip = nn.Linear(
126
+ cross_attention_dim or hidden_size, hidden_size, bias=False
127
+ )
128
+
129
+ def __call__(
130
+ self,
131
+ attn,
132
+ hidden_states,
133
+ encoder_hidden_states=None,
134
+ attention_mask=None,
135
+ temb=None,
136
+ *args,
137
+ **kwargs,
138
+ ):
139
+ global global_concept_mask
140
+ global attn_mask_logs
141
+ global text_attn_map_logs
142
+ global image_attn_map_logs
143
+ residual = hidden_states
144
+
145
+ if attn.spatial_norm is not None:
146
+ hidden_states = attn.spatial_norm(hidden_states, temb)
147
+
148
+ input_ndim = hidden_states.ndim
149
+
150
+ if input_ndim == 4:
151
+ batch_size, channel, height, width = hidden_states.shape
152
+ hidden_states = hidden_states.view(
153
+ batch_size, channel, height * width
154
+ ).transpose(1, 2)
155
+
156
+ batch_size, sequence_length, _ = (
157
+ hidden_states.shape
158
+ if encoder_hidden_states is None
159
+ else encoder_hidden_states.shape
160
+ )
161
+ attention_mask = attn.prepare_attention_mask(
162
+ attention_mask, sequence_length, batch_size
163
+ )
164
+
165
+ if attn.group_norm is not None:
166
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
167
+ 1, 2
168
+ )
169
+
170
+ query = attn.to_q(hidden_states)
171
+
172
+ if encoder_hidden_states is None:
173
+ encoder_hidden_states = hidden_states
174
+ else:
175
+ # get encoder_hidden_states, ip_hidden_states
176
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
177
+ encoder_hidden_states, ip_hidden_states = (
178
+ encoder_hidden_states[:, :end_pos, :],
179
+ encoder_hidden_states[:, end_pos:, :],
180
+ )
181
+ if attn.norm_cross:
182
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
183
+ encoder_hidden_states
184
+ )
185
+
186
+ key = attn.to_k(encoder_hidden_states)
187
+ value = attn.to_v(encoder_hidden_states)
188
+
189
+ query = attn.head_to_batch_dim(query)
190
+ key = attn.head_to_batch_dim(key)
191
+ value = attn.head_to_batch_dim(value)
192
+
193
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
194
+ hidden_states = torch.bmm(attention_probs, value)
195
+ hidden_states = attn.batch_to_head_dim(hidden_states)
196
+
197
+ # for ip-adapter
198
+ ip_key = self.to_k_ip(ip_hidden_states)
199
+ ip_value = self.to_v_ip(ip_hidden_states)
200
+
201
+ ip_key = attn.head_to_batch_dim(ip_key)
202
+ ip_value = attn.head_to_batch_dim(ip_value)
203
+
204
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
205
+ self.attn_map = ip_attention_probs
206
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
207
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
208
+
209
+ hidden_states = hidden_states + self.scale * ip_hidden_states
210
+
211
+ # linear proj
212
+ hidden_states = attn.to_out[0](hidden_states)
213
+ # dropout
214
+ hidden_states = attn.to_out[1](hidden_states)
215
+
216
+ if input_ndim == 4:
217
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
218
+ batch_size, channel, height, width
219
+ )
220
+
221
+ if attn.residual_connection:
222
+ hidden_states = hidden_states + residual
223
+
224
+ hidden_states = hidden_states / attn.rescale_output_factor
225
+
226
+ return hidden_states
227
+
228
+
229
+ class ConceptrolAttnProcessor(nn.Module):
230
+ r"""
231
+ Attention processor for Conceptrol, built on IP-Adapter.
232
+ Args:
233
+ hidden_size (`int`):
234
+ The hidden size of the attention layer.
235
+ cross_attention_dim (`int`):
236
+ The number of channels in the `encoder_hidden_states`.
237
+ scale (`float`, defaults to 1.0):
238
+ the weight scale of image prompt.
239
+ num_tokens (`int`, defaults to 4; should be 16 for ip_adapter_plus):
240
+ The context length of the image features.
241
+ """
242
+
243
+ def __init__(
244
+ self,
245
+ hidden_size,
246
+ cross_attention_dim=None,
247
+ scale=1.0,
248
+ num_tokens=4,
249
+ textual_concept_idxs=None,
250
+ name=None,
251
+ global_masking=False,
252
+ adaptive_scale_mask=False,
253
+ concept_mask_layer=None,
254
+ ):
255
+ super().__init__()
256
+
257
+ self.hidden_size = hidden_size
258
+ self.cross_attention_dim = cross_attention_dim
259
+ self.scale = scale
260
+ self.num_tokens = num_tokens
261
+
262
+ self.textual_concept_idxs = textual_concept_idxs
263
+ self.name = name
264
+
265
+ self.to_k_ip = nn.Linear(
266
+ cross_attention_dim or hidden_size, hidden_size, bias=False
267
+ )
268
+ self.to_v_ip = nn.Linear(
269
+ cross_attention_dim or hidden_size, hidden_size, bias=False
270
+ )
271
+
272
+ self.global_masking = global_masking
273
+ self.adaptive_scale_mask = adaptive_scale_mask
274
+
275
+ if concept_mask_layer is None:
276
+ concept_mask_layer = [
277
+ "mid_block.attentions.0.transformer_blocks.0.attn2.processor"
278
+ ] # For SD
279
+ print("Warning: Using default concept mask layer for SD. For SDXL, use 'up_blocks.0.attentions.1.transformer_blocks.5.attn2.processor'")
280
+ # concept_mask_layer = ['up_blocks.0.attentions.1.transformer_blocks.1.attn2.processor'] # For SDXL
281
+ self.concept_mask_layer = concept_mask_layer
282
+
283
+ def set_global_view(self, attn_procs):
284
+ self.attn_procs = attn_procs
285
+ # print(self.name, self.attn_procs.keys())
286
+
287
+ def __call__(
288
+ self,
289
+ attn,
290
+ hidden_states,
291
+ encoder_hidden_states=None,
292
+ attention_mask=None,
293
+ temb=None,
294
+ *args,
295
+ **kwargs,
296
+ ):
297
+ global global_concept_mask
298
+ global attn_mask_logs
299
+
300
+ if self.textual_concept_idxs is None:
301
+ raise ValueError(
302
+ "textual_concept_idxs should be provided for ConceptrolAttnProcessor"
303
+ )
304
+ residual = hidden_states
305
+
306
+ if attn.spatial_norm is not None:
307
+ hidden_states = attn.spatial_norm(hidden_states, temb)
308
+
309
+ input_ndim = hidden_states.ndim
310
+
311
+ if input_ndim == 4:
312
+ batch_size, channel, height, width = hidden_states.shape
313
+ hidden_states = hidden_states.view(
314
+ batch_size, channel, height * width
315
+ ).transpose(1, 2)
316
+
317
+ batch_size, sequence_length, _ = (
318
+ hidden_states.shape
319
+ if encoder_hidden_states is None
320
+ else encoder_hidden_states.shape
321
+ )
322
+ attention_mask = attn.prepare_attention_mask(
323
+ attention_mask, sequence_length, batch_size
324
+ )
325
+
326
+ if attn.group_norm is not None:
327
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
328
+ 1, 2
329
+ )
330
+
331
+ query = attn.to_q(hidden_states)
332
+
333
+ if encoder_hidden_states is None:
334
+ encoder_hidden_states = hidden_states
335
+ else:
336
+ # get encoder_hidden_states, ip_hidden_states
337
+ end_pos = 77 # Both SD and SDXL use 77 as length of text tokens
338
+ encoder_hidden_states, ip_hidden_states_cat = (
339
+ encoder_hidden_states[:, :end_pos, :],
340
+ encoder_hidden_states[:, end_pos:, :],
341
+ )
342
+ num_concepts = ip_hidden_states_cat.shape[1] // self.num_tokens
343
+ ip_hidden_states_list = torch.chunk(
344
+ ip_hidden_states_cat, num_concepts, dim=1
345
+ )
346
+ assert len(ip_hidden_states_list) == len(
347
+ self.textual_concept_idxs
348
+ ), f"register_idxs should have the same length as the number of concepts, but got {len(ip_hidden_states_list)} and {len(self.textual_concept_idxs)}"
349
+ if attn.norm_cross:
350
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
351
+ encoder_hidden_states
352
+ )
353
+
354
+ key = attn.to_k(encoder_hidden_states)
355
+ value = attn.to_v(encoder_hidden_states)
356
+
357
+ query = attn.head_to_batch_dim(query) # [16, 4096, 40]
358
+ key = attn.head_to_batch_dim(key) # [16, 77, 40]
359
+ value = attn.head_to_batch_dim(value) # [16, 77, 40]
360
+
361
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
362
+ hidden_states = torch.bmm(attention_probs, value)
363
+ hidden_states = attn.batch_to_head_dim(hidden_states)
364
+
365
+ concept_mask_layer = self.concept_mask_layer
366
+ if len(global_concept_mask) == 0:
367
+ global_concept_mask = [None for _ in range(len(self.textual_concept_idxs))]
368
+ for i in range(len(self.textual_concept_idxs)):
369
+ ip_hidden_states = ip_hidden_states_list[i]
370
+ textual_concept_start_idx, textual_concept_end_idx = (
371
+ self.textual_concept_idxs[i]
372
+ )
373
+ ip_key = self.to_k_ip(ip_hidden_states)
374
+ ip_value = self.to_v_ip(ip_hidden_states)
375
+
376
+ ip_key = attn.head_to_batch_dim(ip_key) # [16, 4, 40]
377
+ ip_value = attn.head_to_batch_dim(ip_value) # [16, 4, 40]
378
+
379
+ # attention_probs: [20/40, 4096, 77]
380
+
381
+ ip_attention_mask = attention_probs[
382
+ :, :, textual_concept_start_idx:textual_concept_end_idx
383
+ ] # [16, 4096, T]
384
+ ip_attention_mask = torch.mean(
385
+ ip_attention_mask, dim=-1, keepdim=True
386
+ ) # [16, 4096, 1]
387
+ ip_attention_mask = attn.batch_to_head_dim(
388
+ ip_attention_mask
389
+ ) # [2, 4096, 8]
390
+ ip_attention_mask = torch.mean(
391
+ ip_attention_mask, dim=-1, keepdim=True
392
+ ) # [2, 4096, 1]
393
+
394
+ ip_attention_mask = ip_attention_mask / (
395
+ torch.amax(ip_attention_mask, dim=-2, keepdim=True) + 1e-6
396
+ )
397
+
398
+ ip_attention_mask = ip_attention_mask[1:2] # (use the classifier one)
399
+
400
+ # Visualization
401
+ if self.name not in attn_mask_logs:
402
+ attn_mask_logs[self.name] = []
403
+ text_attn_map_logs[self.name] = []
404
+ image_attn_map_logs[self.name] = []
405
+ attn_mask_logs[self.name].append(
406
+ ip_attention_mask.detach().cpu().numpy()[0, :, 0]
407
+ )
408
+ text_attn_map_logs[self.name].append(
409
+ ip_attention_mask.detach().cpu().numpy()[0, :, 0]
410
+ )
411
+
412
+ if self.global_masking and (
413
+ self.name == concept_mask_layer[0]
414
+ ):
415
+ global_concept_mask[i] = ip_attention_mask
416
+
417
+ if (
418
+ self.global_masking
419
+ and self.name != concept_mask_layer[0]
420
+ and global_concept_mask[i] is not None
421
+ ):
422
+ original_dim = int(global_concept_mask[i].shape[1] ** 0.5)
423
+ target_dim = int(hidden_states.shape[1] ** 0.5)
424
+ global_concept_mask_2d = global_concept_mask[i].view(
425
+ global_concept_mask[i].shape[0], 1, original_dim, original_dim
426
+ )
427
+ resized_global_concept_mask_2d = F.interpolate(
428
+ global_concept_mask_2d,
429
+ size=(target_dim, target_dim),
430
+ mode="nearest",
431
+ )
432
+ resized_global_concept_mask = resized_global_concept_mask_2d.view(
433
+ global_concept_mask[i].shape[0], -1, 1
434
+ )
435
+ ip_attention_mask = resized_global_concept_mask
436
+
437
+ ip_attention_probs = attn.get_attention_scores(
438
+ query, ip_key, None
439
+ ) # [16, 4096, 4]
440
+
441
+ # Visualization
442
+ ip_attention_map = attention_probs[:, :, 15:16] # [16, 4096, T]
443
+ ip_attention_map = torch.mean(
444
+ ip_attention_map, dim=-1, keepdim=True
445
+ ) # [16, 4096, 1]
446
+ ip_attention_map = torch.mean(
447
+ ip_attention_map, dim=-1, keepdim=True
448
+ ) # [16, 4096, 1]
449
+ ip_attention_map = attn.batch_to_head_dim(ip_attention_map) # [2, 4096, 8]
450
+ ip_attention_map = torch.mean(
451
+ ip_attention_map, dim=-1, keepdim=True
452
+ ) # [2, 4096, 1]
453
+ ip_attention_map = ip_attention_map / (
454
+ torch.amax(ip_attention_map, dim=-2, keepdim=True) + 1e-6
455
+ )
456
+ ip_attention_map = ip_attention_map[1:2] # (use the classifier one)
457
+ image_attn_map_logs[self.name].append(
458
+ ip_attention_map.detach().cpu().numpy()[0, :, 0]
459
+ )
460
+
461
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) # [16, 4096, 40]
462
+ ip_hidden_states = attn.batch_to_head_dim(
463
+ ip_hidden_states
464
+ ) # [2, 4096, 320]
465
+ ip_hidden_states = ip_hidden_states * ip_attention_mask
466
+
467
+ if self.adaptive_scale_mask:
468
+ raise ValueError("adaptive_scale_mask is deprecated already")
469
+
470
+ hidden_states += self.scale * ip_hidden_states
471
+
472
+ # linear proj
473
+ hidden_states = attn.to_out[0](hidden_states)
474
+ # dropout
475
+ hidden_states = attn.to_out[1](hidden_states)
476
+
477
+ if input_ndim == 4:
478
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
479
+ batch_size, channel, height, width
480
+ )
481
+
482
+ if attn.residual_connection:
483
+ hidden_states = hidden_states + residual
484
+
485
+ hidden_states = hidden_states / attn.rescale_output_factor
486
+
487
+ return hidden_states
488
+
489
+
490
+ class AttnProcessor2_0(torch.nn.Module):
491
+ r"""
492
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
493
+ """
494
+
495
+ def __init__(
496
+ self,
497
+ hidden_size=None,
498
+ cross_attention_dim=None,
499
+ ):
500
+ super().__init__()
501
+ if not hasattr(F, "scaled_dot_product_attention"):
502
+ raise ImportError(
503
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
504
+ )
505
+
506
+ def __call__(
507
+ self,
508
+ attn,
509
+ hidden_states,
510
+ encoder_hidden_states=None,
511
+ attention_mask=None,
512
+ temb=None,
513
+ *args,
514
+ **kwargs,
515
+ ):
516
+ residual = hidden_states
517
+
518
+ if attn.spatial_norm is not None:
519
+ hidden_states = attn.spatial_norm(hidden_states, temb)
520
+
521
+ input_ndim = hidden_states.ndim
522
+
523
+ if input_ndim == 4:
524
+ batch_size, channel, height, width = hidden_states.shape
525
+ hidden_states = hidden_states.view(
526
+ batch_size, channel, height * width
527
+ ).transpose(1, 2)
528
+
529
+ batch_size, sequence_length, _ = (
530
+ hidden_states.shape
531
+ if encoder_hidden_states is None
532
+ else encoder_hidden_states.shape
533
+ )
534
+
535
+ if attention_mask is not None:
536
+ attention_mask = attn.prepare_attention_mask(
537
+ attention_mask, sequence_length, batch_size
538
+ )
539
+ # scaled_dot_product_attention expects attention_mask shape to be
540
+ # (batch, heads, source_length, target_length)
541
+ attention_mask = attention_mask.view(
542
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
543
+ )
544
+
545
+ if attn.group_norm is not None:
546
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
547
+ 1, 2
548
+ )
549
+
550
+ query = attn.to_q(hidden_states)
551
+
552
+ if encoder_hidden_states is None:
553
+ encoder_hidden_states = hidden_states
554
+ elif attn.norm_cross:
555
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
556
+ encoder_hidden_states
557
+ )
558
+
559
+ key = attn.to_k(encoder_hidden_states)
560
+ value = attn.to_v(encoder_hidden_states)
561
+
562
+ inner_dim = key.shape[-1]
563
+ head_dim = inner_dim // attn.heads
564
+
565
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
566
+
567
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
568
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
569
+
570
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
571
+ # TODO: add support for attn.scale when we move to Torch 2.1
572
+ hidden_states = F.scaled_dot_product_attention(
573
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
574
+ )
575
+
576
+ hidden_states = hidden_states.transpose(1, 2).reshape(
577
+ batch_size, -1, attn.heads * head_dim
578
+ )
579
+ hidden_states = hidden_states.to(query.dtype)
580
+
581
+ # linear proj
582
+ hidden_states = attn.to_out[0](hidden_states)
583
+ # dropout
584
+ hidden_states = attn.to_out[1](hidden_states)
585
+
586
+ if input_ndim == 4:
587
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
588
+ batch_size, channel, height, width
589
+ )
590
+
591
+ if attn.residual_connection:
592
+ hidden_states = hidden_states + residual
593
+
594
+ hidden_states = hidden_states / attn.rescale_output_factor
595
+
596
+ return hidden_states
597
+
598
+
599
+ class IPAttnProcessor2_0(torch.nn.Module):
600
+ r"""
601
+ Attention processor for IP-Adapter for PyTorch 2.0.
602
+ Args:
603
+ hidden_size (`int`):
604
+ The hidden size of the attention layer.
605
+ cross_attention_dim (`int`):
606
+ The number of channels in the `encoder_hidden_states`.
607
+ scale (`float`, defaults to 1.0):
608
+ the weight scale of image prompt.
609
+ num_tokens (`int`, defaults to 4; should be 16 for ip_adapter_plus):
610
+ The context length of the image features.
611
+ """
612
+
613
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
614
+ super().__init__()
615
+
616
+ if not hasattr(F, "scaled_dot_product_attention"):
617
+ raise ImportError(
618
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
619
+ )
620
+
621
+ self.hidden_size = hidden_size
622
+ self.cross_attention_dim = cross_attention_dim
623
+ self.scale = scale
624
+ self.num_tokens = num_tokens
625
+
626
+ self.to_k_ip = nn.Linear(
627
+ cross_attention_dim or hidden_size, hidden_size, bias=False
628
+ )
629
+ self.to_v_ip = nn.Linear(
630
+ cross_attention_dim or hidden_size, hidden_size, bias=False
631
+ )
632
+
633
+ def __call__(
634
+ self,
635
+ attn,
636
+ hidden_states,
637
+ encoder_hidden_states=None,
638
+ attention_mask=None,
639
+ temb=None,
640
+ *args,
641
+ **kwargs,
642
+ ):
643
+ residual = hidden_states
644
+
645
+ if attn.spatial_norm is not None:
646
+ hidden_states = attn.spatial_norm(hidden_states, temb)
647
+
648
+ input_ndim = hidden_states.ndim
649
+
650
+ if input_ndim == 4:
651
+ batch_size, channel, height, width = hidden_states.shape
652
+ hidden_states = hidden_states.view(
653
+ batch_size, channel, height * width
654
+ ).transpose(1, 2)
655
+
656
+ batch_size, sequence_length, _ = (
657
+ hidden_states.shape
658
+ if encoder_hidden_states is None
659
+ else encoder_hidden_states.shape
660
+ )
661
+
662
+ if attention_mask is not None:
663
+ attention_mask = attn.prepare_attention_mask(
664
+ attention_mask, sequence_length, batch_size
665
+ )
666
+ # scaled_dot_product_attention expects attention_mask shape to be
667
+ # (batch, heads, source_length, target_length)
668
+ attention_mask = attention_mask.view(
669
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
670
+ )
671
+
672
+ if attn.group_norm is not None:
673
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
674
+ 1, 2
675
+ )
676
+
677
+ query = attn.to_q(hidden_states)
678
+
679
+ if encoder_hidden_states is None:
680
+ encoder_hidden_states = hidden_states
681
+ else:
682
+ # get encoder_hidden_states, ip_hidden_states
683
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
684
+ encoder_hidden_states, ip_hidden_states = (
685
+ encoder_hidden_states[:, :end_pos, :],
686
+ encoder_hidden_states[:, end_pos:, :],
687
+ )
688
+ if attn.norm_cross:
689
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
690
+ encoder_hidden_states
691
+ )
692
+
693
+ key = attn.to_k(encoder_hidden_states)
694
+ value = attn.to_v(encoder_hidden_states)
695
+
696
+ inner_dim = key.shape[-1]
697
+ head_dim = inner_dim // attn.heads
698
+
699
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
700
+
701
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
702
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
703
+
704
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
705
+ # TODO: add support for attn.scale when we move to Torch 2.1
706
+ hidden_states = F.scaled_dot_product_attention(
707
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
708
+ )
709
+
710
+ hidden_states = hidden_states.transpose(1, 2).reshape(
711
+ batch_size, -1, attn.heads * head_dim
712
+ )
713
+ hidden_states = hidden_states.to(query.dtype)
714
+
715
+ # for ip-adapter
716
+ ip_key = self.to_k_ip(ip_hidden_states)
717
+ ip_value = self.to_v_ip(ip_hidden_states)
718
+
719
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
720
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
721
+
722
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
723
+ # TODO: add support for attn.scale when we move to Torch 2.1
724
+ ip_hidden_states = F.scaled_dot_product_attention(
725
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
726
+ )
727
+ with torch.no_grad():
728
+ self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
729
+ # print(self.attn_map.shape)
730
+
731
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
732
+ batch_size, -1, attn.heads * head_dim
733
+ )
734
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
735
+
736
+ hidden_states = hidden_states + self.scale * ip_hidden_states
737
+
738
+ # linear proj
739
+ hidden_states = attn.to_out[0](hidden_states)
740
+ # dropout
741
+ hidden_states = attn.to_out[1](hidden_states)
742
+
743
+ if input_ndim == 4:
744
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
745
+ batch_size, channel, height, width
746
+ )
747
+
748
+ if attn.residual_connection:
749
+ hidden_states = hidden_states + residual
750
+
751
+ hidden_states = hidden_states / attn.rescale_output_factor
752
+
753
+ return hidden_states
754
+
755
+
756
+ ## for controlnet
757
+ class CNAttnProcessor:
758
+ r"""
759
+ Default processor for performing attention-related computations.
760
+ """
761
+
762
+ def __init__(self, num_tokens=4):
763
+ self.num_tokens = num_tokens
764
+
765
+ def __call__(
766
+ self,
767
+ attn,
768
+ hidden_states,
769
+ encoder_hidden_states=None,
770
+ attention_mask=None,
771
+ temb=None,
772
+ *args,
773
+ **kwargs,
774
+ ):
775
+ residual = hidden_states
776
+
777
+ if attn.spatial_norm is not None:
778
+ hidden_states = attn.spatial_norm(hidden_states, temb)
779
+
780
+ input_ndim = hidden_states.ndim
781
+
782
+ if input_ndim == 4:
783
+ batch_size, channel, height, width = hidden_states.shape
784
+ hidden_states = hidden_states.view(
785
+ batch_size, channel, height * width
786
+ ).transpose(1, 2)
787
+
788
+ batch_size, sequence_length, _ = (
789
+ hidden_states.shape
790
+ if encoder_hidden_states is None
791
+ else encoder_hidden_states.shape
792
+ )
793
+ attention_mask = attn.prepare_attention_mask(
794
+ attention_mask, sequence_length, batch_size
795
+ )
796
+
797
+ if attn.group_norm is not None:
798
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
799
+ 1, 2
800
+ )
801
+
802
+ query = attn.to_q(hidden_states)
803
+
804
+ if encoder_hidden_states is None:
805
+ encoder_hidden_states = hidden_states
806
+ else:
807
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
808
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
809
+ if attn.norm_cross:
810
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
811
+ encoder_hidden_states
812
+ )
813
+
814
+ key = attn.to_k(encoder_hidden_states)
815
+ value = attn.to_v(encoder_hidden_states)
816
+
817
+ query = attn.head_to_batch_dim(query)
818
+ key = attn.head_to_batch_dim(key)
819
+ value = attn.head_to_batch_dim(value)
820
+
821
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
822
+ hidden_states = torch.bmm(attention_probs, value)
823
+ hidden_states = attn.batch_to_head_dim(hidden_states)
824
+
825
+ # linear proj
826
+ hidden_states = attn.to_out[0](hidden_states)
827
+ # dropout
828
+ hidden_states = attn.to_out[1](hidden_states)
829
+
830
+ if input_ndim == 4:
831
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
832
+ batch_size, channel, height, width
833
+ )
834
+
835
+ if attn.residual_connection:
836
+ hidden_states = hidden_states + residual
837
+
838
+ hidden_states = hidden_states / attn.rescale_output_factor
839
+
840
+ return hidden_states
841
+
842
+
843
+ class CNAttnProcessor2_0:
844
+ r"""
845
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
846
+ """
847
+
848
+ def __init__(self, num_tokens=4):
849
+ if not hasattr(F, "scaled_dot_product_attention"):
850
+ raise ImportError(
851
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
852
+ )
853
+ self.num_tokens = num_tokens
854
+
855
+ def __call__(
856
+ self,
857
+ attn,
858
+ hidden_states,
859
+ encoder_hidden_states=None,
860
+ attention_mask=None,
861
+ temb=None,
862
+ *args,
863
+ **kwargs,
864
+ ):
865
+ residual = hidden_states
866
+
867
+ if attn.spatial_norm is not None:
868
+ hidden_states = attn.spatial_norm(hidden_states, temb)
869
+
870
+ input_ndim = hidden_states.ndim
871
+
872
+ if input_ndim == 4:
873
+ batch_size, channel, height, width = hidden_states.shape
874
+ hidden_states = hidden_states.view(
875
+ batch_size, channel, height * width
876
+ ).transpose(1, 2)
877
+
878
+ batch_size, sequence_length, _ = (
879
+ hidden_states.shape
880
+ if encoder_hidden_states is None
881
+ else encoder_hidden_states.shape
882
+ )
883
+
884
+ if attention_mask is not None:
885
+ attention_mask = attn.prepare_attention_mask(
886
+ attention_mask, sequence_length, batch_size
887
+ )
888
+ # scaled_dot_product_attention expects attention_mask shape to be
889
+ # (batch, heads, source_length, target_length)
890
+ attention_mask = attention_mask.view(
891
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
892
+ )
893
+
894
+ if attn.group_norm is not None:
895
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
896
+ 1, 2
897
+ )
898
+
899
+ query = attn.to_q(hidden_states)
900
+
901
+ if encoder_hidden_states is None:
902
+ encoder_hidden_states = hidden_states
903
+ else:
904
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
905
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
906
+ if attn.norm_cross:
907
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
908
+ encoder_hidden_states
909
+ )
910
+
911
+ key = attn.to_k(encoder_hidden_states)
912
+ value = attn.to_v(encoder_hidden_states)
913
+
914
+ inner_dim = key.shape[-1]
915
+ head_dim = inner_dim // attn.heads
916
+
917
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
918
+
919
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
920
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
921
+
922
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
923
+ # TODO: add support for attn.scale when we move to Torch 2.1
924
+ hidden_states = F.scaled_dot_product_attention(
925
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
926
+ )
927
+
928
+ hidden_states = hidden_states.transpose(1, 2).reshape(
929
+ batch_size, -1, attn.heads * head_dim
930
+ )
931
+ hidden_states = hidden_states.to(query.dtype)
932
+
933
+ # linear proj
934
+ hidden_states = attn.to_out[0](hidden_states)
935
+ # dropout
936
+ hidden_states = attn.to_out[1](hidden_states)
937
+
938
+ if input_ndim == 4:
939
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
940
+ batch_size, channel, height, width
941
+ )
942
+
943
+ if attn.residual_connection:
944
+ hidden_states = hidden_states + residual
945
+
946
+ hidden_states = hidden_states / attn.rescale_output_factor
947
+
948
+ return hidden_states
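
A simplified, self-contained sketch of the core computation that `IPAttnProcessor` and `ConceptrolAttnProcessor` above implement: decoupled cross-attention in which the text branch's attention over the concept's token span is turned into a spatial mask that gates the image-prompt (IP-Adapter) branch. It uses single-head attention on plain tensors and, for brevity, shares one tensor as key and value per branch; shapes and names are illustrative simplifications, not the actual diffusers interfaces.

```python
# Conceptrol attention step, stripped to its essentials (single head, plain tensors).
import torch
import torch.nn.functional as F


def conceptrol_cross_attention(
    query,         # [B, N_pix, D]  latent queries
    text_kv,       # [B, 77, D]     projected text tokens (used as both key and value here)
    ip_kv,         # [B, 16, D]     projected image-prompt tokens
    concept_span,  # (start, end)   token indices of the textual concept in the prompt
    scale=1.0,     # weight of the image branch (self.scale in the processors above)
):
    d = query.shape[-1]

    # 1) Ordinary text cross-attention.
    text_probs = F.softmax(query @ text_kv.transpose(-2, -1) / d**0.5, dim=-1)  # [B, N_pix, 77]
    text_out = text_probs @ text_kv                                             # [B, N_pix, D]

    # 2) Concept mask: mean attention paid to the concept's tokens,
    #    normalized by its spatial maximum (as in ConceptrolAttnProcessor).
    start, end = concept_span
    mask = text_probs[:, :, start:end].mean(dim=-1, keepdim=True)               # [B, N_pix, 1]
    mask = mask / (mask.amax(dim=-2, keepdim=True) + 1e-6)

    # 3) Image-prompt attention (IP-Adapter branch), gated by the concept mask.
    ip_probs = F.softmax(query @ ip_kv.transpose(-2, -1) / d**0.5, dim=-1)      # [B, N_pix, 16]
    ip_out = (ip_probs @ ip_kv) * mask                                          # [B, N_pix, D]

    # 4) Decoupled combination: text branch plus the scaled, masked image branch.
    return text_out + scale * ip_out


# Toy shapes just to show the call.
q, t, ip = torch.randn(1, 4096, 320), torch.randn(1, 77, 320), torch.randn(1, 16, 320)
print(conceptrol_cross_attention(q, t, ip, concept_span=(4, 6)).shape)  # torch.Size([1, 4096, 320])
```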
ip_adapter/custom_pipelines.py ADDED
@@ -0,0 +1,805 @@
1
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
5
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
6
+ from diffusers.image_processor import PipelineImageInput
7
+ from diffusers.pipelines.stable_diffusion.pipeline_output import (
8
+ StableDiffusionPipelineOutput,
9
+ )
10
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
11
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
12
+ rescale_noise_cfg,
13
+ )
14
+
15
+ from .attention_processor import IPAttnProcessor, ConceptrolAttnProcessor
16
+ from . import attention_processor
17
+
18
+
19
+ class StableDiffusionXLCustomPipeline(StableDiffusionXLPipeline):
20
+ def set_scale(self, scale):
21
+ for attn_processor in self.unet.attn_processors.values():
22
+ if isinstance(attn_processor, (IPAttnProcessor, ConceptrolAttnProcessor)):
23
+ attn_processor.scale = scale
24
+
25
+ @torch.no_grad()
26
+ def __call__( # noqa: C901
27
+ self,
28
+ prompt: Optional[Union[str, List[str]]] = None,
29
+ prompt_2: Optional[Union[str, List[str]]] = None,
30
+ height: Optional[int] = None,
31
+ width: Optional[int] = None,
32
+ num_inference_steps: int = 50,
33
+ denoising_end: Optional[float] = None,
34
+ guidance_scale: float = 6.0,
35
+ negative_prompt: Optional[Union[str, List[str]]] = None,
36
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
37
+ num_images_per_prompt: Optional[int] = 1,
38
+ eta: float = 0.0,
39
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
40
+ latents: Optional[torch.FloatTensor] = None,
41
+ prompt_embeds: Optional[torch.FloatTensor] = None,
42
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
43
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
44
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
45
+ output_type: Optional[str] = "pil",
46
+ return_dict: bool = True,
47
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
48
+ callback_steps: int = 1,
49
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
50
+ guidance_rescale: float = 0.0,
51
+ original_size: Optional[Tuple[int, int]] = None,
52
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
53
+ target_size: Optional[Tuple[int, int]] = None,
54
+ negative_original_size: Optional[Tuple[int, int]] = None,
55
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
56
+ negative_target_size: Optional[Tuple[int, int]] = None,
57
+ control_guidance_start: float = 0.0,
58
+ control_guidance_end: float = 1.0,
59
+ ):
60
+ r"""
61
+ Function invoked when calling the pipeline for generation.
62
+
63
+ Args:
64
+ prompt (`str` or `List[str]`, *optional*):
65
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
66
+ instead.
67
+ prompt_2 (`str` or `List[str]`, *optional*):
68
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
69
+ used in both text-encoders
70
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
71
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
72
+ Anything below 512 pixels won't work well for
73
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
74
+ and checkpoints that are not specifically fine-tuned on low resolutions.
75
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
76
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
77
+ Anything below 512 pixels won't work well for
78
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
79
+ and checkpoints that are not specifically fine-tuned on low resolutions.
80
+ num_inference_steps (`int`, *optional*, defaults to 50):
81
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
82
+ expense of slower inference.
83
+ denoising_end (`float`, *optional*):
84
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
85
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
86
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
87
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
88
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
89
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
90
+ guidance_scale (`float`, *optional*, defaults to 6.0):
91
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
92
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
93
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
94
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
95
+ usually at the expense of lower image quality.
96
+ negative_prompt (`str` or `List[str]`, *optional*):
97
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
98
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
99
+ less than `1`).
100
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
101
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
102
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
103
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
104
+ The number of images to generate per prompt.
105
+ eta (`float`, *optional*, defaults to 0.0):
106
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
107
+ [`schedulers.DDIMScheduler`], will be ignored for others.
108
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
109
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
110
+ to make generation deterministic.
111
+ latents (`torch.FloatTensor`, *optional*):
112
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
113
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
114
+ tensor will be generated by sampling using the supplied random `generator`.
115
+ prompt_embeds (`torch.FloatTensor`, *optional*):
116
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
117
+ provided, text embeddings will be generated from `prompt` input argument.
118
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
119
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
120
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
121
+ argument.
122
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
123
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
124
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
125
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
126
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
127
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
128
+ input argument.
129
+ output_type (`str`, *optional*, defaults to `"pil"`):
130
+ The output format of the generated image. Choose between
131
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
132
+ return_dict (`bool`, *optional*, defaults to `True`):
133
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
134
+ of a plain tuple.
135
+ callback (`Callable`, *optional*):
136
+ A function that will be called every `callback_steps` steps during inference. The function will be
137
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
138
+ callback_steps (`int`, *optional*, defaults to 1):
139
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
140
+ called at every step.
141
+ cross_attention_kwargs (`dict`, *optional*):
142
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
143
+ `self.processor` in
144
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
145
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
146
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
147
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf), where `guidance_rescale` is defined as `φ` in equation 16 of
148
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
149
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
150
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
151
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
152
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
153
+ explained in section 2.2 of
154
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
155
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
156
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
157
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
158
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
159
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
160
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
161
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
162
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
163
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
164
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
165
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
166
+ micro-conditioning as explained in section 2.2 of
167
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
168
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
169
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
170
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
171
+ micro-conditioning as explained in section 2.2 of
172
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
173
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
174
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
175
+ To negatively condition the generation process based on a target image resolution. It should be the same
176
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
177
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
178
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
179
+ control_guidance_start (`float`, *optional*, defaults to 0.0):
180
+ The percentage of total steps at which the ControlNet starts applying.
181
+ control_guidance_end (`float`, *optional*, defaults to 1.0):
182
+ The percentage of total steps at which the ControlNet stops applying.
183
+
184
+ Examples:
185
+
186
+ Returns:
187
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
188
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
189
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
190
+ """
191
+
192
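+ # Reset the module-level attention logging buffers that the Conceptrol attention processors write into on every call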
+ attention_processor.attn_mask_logs = {}
193
+ attention_processor.image_attn_map_logs = {}
194
+ attention_processor.text_attn_map_logs = {}
195
+ # Clear the global concept mask
196
+ attention_processor.global_concept_mask = []
197
+
198
+ # 0. Default height and width to unet
199
+ height = height or self.default_sample_size * self.vae_scale_factor
200
+ width = width or self.default_sample_size * self.vae_scale_factor
201
+
202
+ original_size = original_size or (height, width)
203
+ target_size = target_size or (height, width)
204
+
205
+ # 1. Check inputs. Raise error if not correct
206
+ self.check_inputs(
207
+ prompt,
208
+ prompt_2,
209
+ height,
210
+ width,
211
+ callback_steps,
212
+ negative_prompt,
213
+ negative_prompt_2,
214
+ prompt_embeds,
215
+ negative_prompt_embeds,
216
+ pooled_prompt_embeds,
217
+ negative_pooled_prompt_embeds,
218
+ )
219
+
220
+ # 2. Define call parameters
221
+ if prompt is not None and isinstance(prompt, str):
222
+ batch_size = 1
223
+ elif prompt is not None and isinstance(prompt, list):
224
+ batch_size = len(prompt)
225
+ else:
226
+ batch_size = prompt_embeds.shape[0]
227
+
228
+ device = self._execution_device
229
+
230
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
231
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
232
+ # corresponds to doing no classifier free guidance.
233
+ do_classifier_free_guidance = guidance_scale > 1.0
234
+
235
+ # 3. Encode input prompt
236
+ text_encoder_lora_scale = (
237
+ cross_attention_kwargs.get("scale", None)
238
+ if cross_attention_kwargs is not None
239
+ else None
240
+ )
241
+ (
242
+ prompt_embeds,
243
+ negative_prompt_embeds,
244
+ pooled_prompt_embeds,
245
+ negative_pooled_prompt_embeds,
246
+ ) = self.encode_prompt(
247
+ prompt=prompt,
248
+ prompt_2=prompt_2,
249
+ device=device,
250
+ num_images_per_prompt=num_images_per_prompt,
251
+ do_classifier_free_guidance=do_classifier_free_guidance,
252
+ negative_prompt=negative_prompt,
253
+ negative_prompt_2=negative_prompt_2,
254
+ prompt_embeds=prompt_embeds,
255
+ negative_prompt_embeds=negative_prompt_embeds,
256
+ pooled_prompt_embeds=pooled_prompt_embeds,
257
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
258
+ lora_scale=text_encoder_lora_scale,
259
+ )
260
+
261
+ # 4. Prepare timesteps
262
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
263
+
264
+ timesteps = self.scheduler.timesteps
265
+
266
+ # 5. Prepare latent variables
267
+ num_channels_latents = self.unet.config.in_channels
268
+ latents = self.prepare_latents(
269
+ batch_size * num_images_per_prompt,
270
+ num_channels_latents,
271
+ height,
272
+ width,
273
+ prompt_embeds.dtype,
274
+ device,
275
+ generator,
276
+ latents,
277
+ )
278
+
279
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
280
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
281
+
282
+ # 7. Prepare added time ids & embeddings
283
+ add_text_embeds = pooled_prompt_embeds
284
+ if self.text_encoder_2 is None:
285
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
286
+ else:
287
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
288
+
289
+ add_time_ids = self._get_add_time_ids(
290
+ original_size,
291
+ crops_coords_top_left,
292
+ target_size,
293
+ dtype=prompt_embeds.dtype,
294
+ text_encoder_projection_dim=text_encoder_projection_dim,
295
+ )
296
+ if negative_original_size is not None and negative_target_size is not None:
297
+ negative_add_time_ids = self._get_add_time_ids(
298
+ negative_original_size,
299
+ negative_crops_coords_top_left,
300
+ negative_target_size,
301
+ dtype=prompt_embeds.dtype,
302
+ text_encoder_projection_dim=text_encoder_projection_dim,
303
+ )
304
+ else:
305
+ negative_add_time_ids = add_time_ids
306
+
307
+ if do_classifier_free_guidance:
308
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
309
+ add_text_embeds = torch.cat(
310
+ [negative_pooled_prompt_embeds, add_text_embeds], dim=0
311
+ )
312
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
313
+
314
+ prompt_embeds = prompt_embeds.to(device)
315
+ add_text_embeds = add_text_embeds.to(device)
316
+ add_time_ids = add_time_ids.to(device).repeat(
317
+ batch_size * num_images_per_prompt, 1
318
+ )
319
+
320
+ # 8. Denoising loop
321
+ num_warmup_steps = max(
322
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
323
+ )
324
+
325
+ # 7.1 Apply denoising_end
326
+ if (
327
+ denoising_end is not None
328
+ and isinstance(denoising_end, float)
329
+ and denoising_end > 0
330
+ and denoising_end < 1
331
+ ):
332
+ discrete_timestep_cutoff = int(
333
+ round(
334
+ self.scheduler.config.num_train_timesteps
335
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
336
+ )
337
+ )
338
+ num_inference_steps = len(
339
+ list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))
340
+ )
341
+ timesteps = timesteps[:num_inference_steps]
342
+
343
+ # get init conditioning scale
344
+ for attn_processor in self.unet.attn_processors.values():
345
+ if isinstance(attn_processor, (IPAttnProcessor, ConceptrolAttnProcessor)):
346
+ conditioning_scale = attn_processor.scale
347
+ break
348
+
349
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
350
+ for i, t in enumerate(timesteps):
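+ # Outside the [control_guidance_start, control_guidance_end] fraction of steps the IP-Adapter/Conceptrol attention scale is zeroed; inside the window the original scale is restored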
351
+ if (i / len(timesteps) < control_guidance_start) or (
352
+ (i + 1) / len(timesteps) > control_guidance_end
353
+ ):
354
+ self.set_scale(0.0)
355
+ else:
356
+ self.set_scale(conditioning_scale)
357
+
358
+ # expand the latents if we are doing classifier free guidance
359
+ latent_model_input = (
360
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
361
+ )
362
+
363
+ latent_model_input = self.scheduler.scale_model_input(
364
+ latent_model_input, t
365
+ )
366
+
367
+ # predict the noise residual
368
+ added_cond_kwargs = {
369
+ "text_embeds": add_text_embeds,
370
+ "time_ids": add_time_ids,
371
+ }
372
+
373
+ noise_pred = self.unet(
374
+ latent_model_input,
375
+ t,
376
+ encoder_hidden_states=prompt_embeds,
377
+ cross_attention_kwargs=cross_attention_kwargs,
378
+ added_cond_kwargs=added_cond_kwargs,
379
+ return_dict=False,
380
+ )[0]
381
+
382
+ # perform guidance
383
+ if do_classifier_free_guidance:
384
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
385
+ noise_pred = noise_pred_uncond + guidance_scale * (
386
+ noise_pred_text - noise_pred_uncond
387
+ )
388
+
389
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
390
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
391
+ noise_pred = rescale_noise_cfg(
392
+ noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
393
+ )
394
+
395
+ # compute the previous noisy sample x_t -> x_t-1
396
+ latents = self.scheduler.step(
397
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
398
+ )[0]
399
+
400
+ # call the callback, if provided
401
+ if i == len(timesteps) - 1 or (
402
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
403
+ ):
404
+ progress_bar.update()
405
+ if callback is not None and i % callback_steps == 0:
406
+ callback(i, t, latents)
407
+
408
+ if output_type != "latent":
409
+ # make sure the VAE is in float32 mode, as it overflows in float16
410
+ needs_upcasting = (
411
+ self.vae.dtype == torch.float16 and self.vae.config.force_upcast
412
+ )
413
+
414
+ if needs_upcasting:
415
+ self.upcast_vae()
416
+ latents = latents.to(
417
+ next(iter(self.vae.post_quant_conv.parameters())).dtype
418
+ )
419
+
420
+ image = self.vae.decode(
421
+ latents / self.vae.config.scaling_factor, return_dict=False
422
+ )[0]
423
+
424
+ # cast back to fp16 if needed
425
+ if needs_upcasting:
426
+ self.vae.to(dtype=torch.float16)
427
+ else:
428
+ image = latents
429
+
430
+ if output_type != "latent":
431
+ # apply watermark if available
432
+ if self.watermark is not None:
433
+ image = self.watermark.apply_watermark(image)
434
+
435
+ image = self.image_processor.postprocess(image, output_type=output_type)
436
+
437
+ # Offload all models
438
+ self.maybe_free_model_hooks()
439
+
440
+ if not return_dict:
441
+ return (image,)
442
+
443
+ return StableDiffusionXLPipelineOutput(images=image)
444
+
445
+
446
+ class StableDiffusionCustomPipeline(StableDiffusionPipeline):
447
+ def set_scale(self, scale):
448
+ for attn_processor in self.unet.attn_processors.values():
449
+ if isinstance(attn_processor, (IPAttnProcessor, ConceptrolAttnProcessor)):
450
+ attn_processor.scale = scale
451
+
452
+ @torch.no_grad()
453
+ def __call__(
454
+ self,
455
+ prompt: Union[str, List[str]] = None,
456
+ height: Optional[int] = None,
457
+ width: Optional[int] = None,
458
+ num_inference_steps: int = 50,
459
+ timesteps: List[int] = None,
460
+ guidance_scale: float = 7.5,
461
+ negative_prompt: Optional[Union[str, List[str]]] = None,
462
+ num_images_per_prompt: Optional[int] = 1,
463
+ eta: float = 0.0,
464
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
465
+ latents: Optional[torch.Tensor] = None,
466
+ prompt_embeds: Optional[torch.Tensor] = None,
467
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
468
+ ip_adapter_image: Optional[PipelineImageInput] = None,
469
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
470
+ output_type: Optional[str] = "pil",
471
+ return_dict: bool = True,
472
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
473
+ guidance_rescale: float = 0.0,
474
+ clip_skip: Optional[int] = None,
475
+ callback_on_step_end: Optional[
476
+ Union[
477
+ Callable[[int, int, Dict], None],
478
+ PipelineCallback,
479
+ MultiPipelineCallbacks,
480
+ ]
481
+ ] = None,
482
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
483
+ control_guidance_start: float = 0.0,
484
+ control_guidance_end: float = 1.0,
485
+ **kwargs,
486
+ ):
487
+ r"""
488
+ The call function to the pipeline for generation.
489
+
490
+ Args:
491
+ prompt (`str` or `List[str]`, *optional*):
492
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
493
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
494
+ The height in pixels of the generated image.
495
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
496
+ The width in pixels of the generated image.
497
+ num_inference_steps (`int`, *optional*, defaults to 50):
498
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
499
+ expense of slower inference.
500
+ timesteps (`List[int]`, *optional*):
501
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
502
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
503
+ passed will be used. Must be in descending order.
504
+ sigmas (`List[float]`, *optional*):
505
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
506
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
507
+ will be used.
508
+ guidance_scale (`float`, *optional*, defaults to 7.5):
509
+ A higher guidance scale value encourages the model to generate images closely linked to the text
510
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
511
+ negative_prompt (`str` or `List[str]`, *optional*):
512
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
513
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
514
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
515
+ The number of images to generate per prompt.
516
+ eta (`float`, *optional*, defaults to 0.0):
517
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
518
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
519
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
520
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
521
+ generation deterministic.
522
+ latents (`torch.Tensor`, *optional*):
523
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
524
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
525
+ tensor is generated by sampling using the supplied random `generator`.
526
+ prompt_embeds (`torch.Tensor`, *optional*):
527
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
528
+ provided, text embeddings are generated from the `prompt` input argument.
529
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
530
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
531
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
532
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
533
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
534
+ Pre-generated image embeddings for IP-Adapter. It should be a list of the same length as the number of
535
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
536
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
537
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
538
+ output_type (`str`, *optional*, defaults to `"pil"`):
539
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
540
+ return_dict (`bool`, *optional*, defaults to `True`):
541
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
542
+ plain tuple.
543
+ cross_attention_kwargs (`dict`, *optional*):
544
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
545
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
546
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
547
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
548
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
549
+ using zero terminal SNR.
550
+ clip_skip (`int`, *optional*):
551
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
552
+ the output of the pre-final layer will be used for computing the prompt embeddings.
553
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
554
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
555
+ each denoising step during inference, with the following arguments: `callback_on_step_end(self:
556
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
557
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
558
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
559
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
560
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
561
+ `._callback_tensor_inputs` attribute of your pipeline class.
562
+
563
+ Examples:
564
+
565
+ Returns:
566
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
567
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
568
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
569
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
570
+ "not-safe-for-work" (nsfw) content.
571
+ """
572
+
573
+ attention_processor.attn_mask_logs = {}
574
+ attention_processor.image_attn_map_logs = {}
575
+ attention_processor.text_attn_map_logs = {}
576
+ attention_processor.global_concept_mask = []
577
+
578
+ callback = kwargs.pop("callback", None)
579
+ callback_steps = kwargs.pop("callback_steps", None)
580
+
581
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
582
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
583
+
584
+ # 0. Default height and width to unet
585
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
586
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
587
+ # to deal with lora scaling and other possible forward hooks
588
+
589
+ # 1. Check inputs. Raise error if not correct
590
+ self.check_inputs(
591
+ prompt,
592
+ height,
593
+ width,
594
+ callback_steps,
595
+ negative_prompt,
596
+ prompt_embeds,
597
+ negative_prompt_embeds,
598
+ ip_adapter_image,
599
+ ip_adapter_image_embeds,
600
+ callback_on_step_end_tensor_inputs,
601
+ )
602
+
603
+ self._guidance_scale = guidance_scale
604
+ self._guidance_rescale = guidance_rescale
605
+ self._clip_skip = clip_skip
606
+ self._cross_attention_kwargs = cross_attention_kwargs
607
+ self._interrupt = False
608
+
609
+ # 2. Define call parameters
610
+ if prompt is not None and isinstance(prompt, str):
611
+ batch_size = 1
612
+ elif prompt is not None and isinstance(prompt, list):
613
+ batch_size = len(prompt)
614
+ else:
615
+ batch_size = prompt_embeds.shape[0]
616
+
617
+ device = self._execution_device
618
+
619
+ # 3. Encode input prompt
620
+ lora_scale = (
621
+ self.cross_attention_kwargs.get("scale", None)
622
+ if self.cross_attention_kwargs is not None
623
+ else None
624
+ )
625
+
626
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
627
+ prompt,
628
+ device,
629
+ num_images_per_prompt,
630
+ self.do_classifier_free_guidance,
631
+ negative_prompt,
632
+ prompt_embeds=prompt_embeds,
633
+ negative_prompt_embeds=negative_prompt_embeds,
634
+ lora_scale=lora_scale,
635
+ clip_skip=self.clip_skip,
636
+ )
637
+
638
+ # For classifier free guidance, we need to do two forward passes.
639
+ # Here we concatenate the unconditional and text embeddings into a single batch
640
+ # to avoid doing two forward passes
641
+ if self.do_classifier_free_guidance:
642
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
643
+
644
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
645
+ image_embeds = self.prepare_ip_adapter_image_embeds(
646
+ ip_adapter_image,
647
+ ip_adapter_image_embeds,
648
+ device,
649
+ batch_size * num_images_per_prompt,
650
+ self.do_classifier_free_guidance,
651
+ )
652
+
653
+ # 4. Prepare timesteps
654
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
655
+ timesteps = self.scheduler.timesteps
656
+
657
+ # 5. Prepare latent variables
658
+ num_channels_latents = self.unet.config.in_channels
659
+ latents = self.prepare_latents(
660
+ batch_size * num_images_per_prompt,
661
+ num_channels_latents,
662
+ height,
663
+ width,
664
+ prompt_embeds.dtype,
665
+ device,
666
+ generator,
667
+ latents,
668
+ )
669
+
670
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
671
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
672
+
673
+ # 6.1 Add image embeds for IP-Adapter
674
+ added_cond_kwargs = (
675
+ {"image_embeds": image_embeds}
676
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
677
+ else None
678
+ )
679
+
680
+ # 6.2 Optionally get Guidance Scale Embedding
681
+ timestep_cond = None
682
+ if self.unet.config.time_cond_proj_dim is not None:
683
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
684
+ batch_size * num_images_per_prompt
685
+ )
686
+ timestep_cond = self.get_guidance_scale_embedding(
687
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
688
+ ).to(device=device, dtype=latents.dtype)
689
+
690
+ # 7. Denoising loop
691
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
692
+ self._num_timesteps = len(timesteps)
693
+
694
+ # get init conditioning scale
695
+ for attn_processor in self.unet.attn_processors.values():
696
+ if isinstance(attn_processor, (ConceptrolAttnProcessor, IPAttnProcessor)):
697
+ conditioning_scale = attn_processor.scale
698
+ break
699
+
700
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
701
+ for i, t in enumerate(timesteps):
702
+ if (i / len(timesteps) < control_guidance_start) or (
703
+ (i + 1) / len(timesteps) > control_guidance_end
704
+ ):
705
+ self.set_scale(0.0)
706
+ else:
707
+ self.set_scale(conditioning_scale)
708
+
709
+ if self.interrupt:
710
+ continue
711
+
712
+ # expand the latents if we are doing classifier free guidance
713
+ latent_model_input = (
714
+ torch.cat([latents] * 2)
715
+ if self.do_classifier_free_guidance
716
+ else latents
717
+ )
718
+ latent_model_input = self.scheduler.scale_model_input(
719
+ latent_model_input, t
720
+ )
721
+
722
+ # predict the noise residual
723
+ noise_pred = self.unet(
724
+ latent_model_input,
725
+ t,
726
+ encoder_hidden_states=prompt_embeds,
727
+ timestep_cond=timestep_cond,
728
+ cross_attention_kwargs=self.cross_attention_kwargs,
729
+ added_cond_kwargs=added_cond_kwargs,
730
+ return_dict=False,
731
+ )[0]
732
+
733
+ # perform guidance
734
+ if self.do_classifier_free_guidance:
735
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
736
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
737
+ noise_pred_text - noise_pred_uncond
738
+ )
739
+
740
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
741
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
742
+ noise_pred = rescale_noise_cfg(
743
+ noise_pred,
744
+ noise_pred_text,
745
+ guidance_rescale=self.guidance_rescale,
746
+ )
747
+
748
+ # compute the previous noisy sample x_t -> x_t-1
749
+ results = self.scheduler.step(
750
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
751
+ )
752
+ latents = results[0]
753
+ # pred_original = results[1]
754
+
755
+ if callback_on_step_end is not None:
756
+ callback_kwargs = {}
757
+ for k in callback_on_step_end_tensor_inputs:
758
+ callback_kwargs[k] = locals()[k]
759
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
760
+
761
+ latents = callback_outputs.pop("latents", latents)
762
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
763
+ negative_prompt_embeds = callback_outputs.pop(
764
+ "negative_prompt_embeds", negative_prompt_embeds
765
+ )
766
+
767
+ # call the callback, if provided
768
+ if i == len(timesteps) - 1 or (
769
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
770
+ ):
771
+ progress_bar.update()
772
+ if callback is not None and i % callback_steps == 0:
773
+ step_idx = i // getattr(self.scheduler, "order", 1)
774
+ callback(step_idx, t, latents)
775
+
776
+ if output_type != "latent":
777
+ image = self.vae.decode(
778
+ latents / self.vae.config.scaling_factor,
779
+ return_dict=False,
780
+ generator=generator,
781
+ )[0]
782
+ image, has_nsfw_concept = self.run_safety_checker(
783
+ image, device, prompt_embeds.dtype
784
+ )
785
+ else:
786
+ image = latents
787
+ has_nsfw_concept = None
788
+
789
+ if has_nsfw_concept is None:
790
+ do_denormalize = [True] * image.shape[0]
791
+ else:
792
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
793
+ image = self.image_processor.postprocess(
794
+ image, output_type=output_type, do_denormalize=do_denormalize
795
+ )
796
+
797
+ # Offload all models
798
+ self.maybe_free_model_hooks()
799
+
800
+ if not return_dict:
801
+ return (image, has_nsfw_concept)
802
+
803
+ return StableDiffusionPipelineOutput(
804
+ images=image, nsfw_content_detected=has_nsfw_concept
805
+ )
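A minimal usage sketch of the custom pipelines above, wired through the ConceptrolIPAdapter class added in ip_adapter/ip_adapter.py below. The base-model id, IP-Adapter checkpoint path, device, and demo image are illustrative assumptions, not taken from this commit; the only behavioural additions over the stock diffusers pipelines are set_scale and the control_guidance_start/control_guidance_end window, which zeroes the image-adapter attention outside the chosen fraction of denoising steps.

import torch
from PIL import Image

from ip_adapter.custom_pipelines import StableDiffusionCustomPipeline
from ip_adapter.ip_adapter import ConceptrolIPAdapter

# Assumed SD 1.5 base checkpoint; substitute whatever checkpoint the Space actually loads.
pipe = StableDiffusionCustomPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    safety_checker=None,
)

# The adapter moves the pipeline to the device and installs the Conceptrol
# attention processors whose scale the custom __call__ reads back.
ip_model = ConceptrolIPAdapter(
    pipe,
    image_encoder_path="h94/IP-Adapter",                  # assumed encoder repo
    ip_ckpt="h94/IP-Adapter/models/ip-adapter_sd15.bin",  # assumed checkpoint
    device="cuda",
)

images = ip_model.generate(
    pil_images=[Image.open("demo/statue.jpg")],
    prompt="a statue in a garden",
    subjects=["statue"],           # Conceptrol requires the textual concepts
    scale=0.6,
    guidance_scale=7.5,
    num_inference_steps=30,
    control_guidance_start=0.0,    # adapter active from the first step
    control_guidance_end=0.8,      # and switched off for the last 20% of steps
)
images[0].save("statue_in_garden.png")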
ip_adapter/ip_adapter.py ADDED
@@ -0,0 +1,1043 @@
1
+ import os
2
+ from typing import List
3
+
4
+ import torch
5
+ from diffusers.pipelines.controlnet import MultiControlNetModel
6
+ from PIL import Image
7
+ from safetensors import safe_open
8
+ from transformers import (
9
+ CLIPImageProcessor,
10
+ CLIPVisionModelWithProjection,
11
+ CLIPTokenizer,
12
+ )
13
+
14
+ from .attention_processor import (
15
+ AttnProcessor,
16
+ CNAttnProcessor,
17
+ IPAttnProcessor,
18
+ ConceptrolAttnProcessor,
19
+ )
20
+ from .resampler import Resampler
21
+ from .utils import get_generator
22
+ from huggingface_hub import hf_hub_download
23
+
24
+ SD_CONCEPT_LAYER = ["up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor"]
25
+ SDXL_CONCEPT_LAYER = ["up_blocks.0.attentions.1.transformer_blocks.3.attn2.processor"]
26
+
27
+
28
+ class ImageProjModel(torch.nn.Module):
29
+ """Projection Model"""
30
+
31
+ def __init__(
32
+ self,
33
+ cross_attention_dim=1024,
34
+ clip_embeddings_dim=1024,
35
+ clip_extra_context_tokens=4,
36
+ ):
37
+ super().__init__()
38
+
39
+ self.generator = None
40
+ self.cross_attention_dim = cross_attention_dim
41
+ self.clip_extra_context_tokens = clip_extra_context_tokens
42
+ self.proj = torch.nn.Linear(
43
+ clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
44
+ )
45
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
46
+
47
+ def forward(self, image_embeds):
48
+ embeds = image_embeds
49
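+ # Projects the pooled CLIP image embedding (B, clip_embeddings_dim) into clip_extra_context_tokens context tokens of size cross_attention_dim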
+ clip_extra_context_tokens = self.proj(embeds).reshape(
50
+ -1, self.clip_extra_context_tokens, self.cross_attention_dim
51
+ )
52
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
53
+ return clip_extra_context_tokens
54
+
55
+
56
+ class MLPProjModel(torch.nn.Module):
57
+ """SD model with image prompt"""
58
+
59
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
60
+ super().__init__()
61
+
62
+ self.proj = torch.nn.Sequential(
63
+ torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
64
+ torch.nn.GELU(),
65
+ torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
66
+ torch.nn.LayerNorm(cross_attention_dim),
67
+ )
68
+
69
+ def forward(self, image_embeds):
70
+ clip_extra_context_tokens = self.proj(image_embeds)
71
+ return clip_extra_context_tokens
72
+
73
+
74
+ class IPAdapter:
75
+ def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
76
+ self.device = device
77
+ self.image_encoder_path = image_encoder_path
78
+ self.ip_ckpt = ip_ckpt
79
+ self.num_tokens = num_tokens
80
+
81
+ self.pipe = sd_pipe.to(self.device)
82
+ self.set_ip_adapter()
83
+
84
+ # load image encoder
85
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
86
+ "h94/IP-Adapter",
87
+ subfolder="models/image_encoder",
88
+ torch_dtype=torch.float16,
89
+ ).to(self.device, dtype=torch.float16)
90
+ self.clip_image_processor = CLIPImageProcessor()
91
+ # image proj model
92
+ self.image_proj_model = self.init_proj()
93
+
94
+ self.load_ip_adapter()
95
+
96
+ def init_proj(self):
97
+ image_proj_model = ImageProjModel(
98
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
99
+ clip_embeddings_dim=self.image_encoder.config.projection_dim,
100
+ clip_extra_context_tokens=self.num_tokens,
101
+ ).to(self.device, dtype=torch.float16)
102
+ return image_proj_model
103
+
104
+ def set_ip_adapter(self):
105
+ unet = self.pipe.unet
106
+ attn_procs = {}
107
+ for name in unet.attn_processors.keys(): # noqa: SIM118
108
+ cross_attention_dim = (
109
+ None
110
+ if name.endswith("attn1.processor")
111
+ else unet.config.cross_attention_dim
112
+ )
113
+ if name.startswith("mid_block"):
114
+ hidden_size = unet.config.block_out_channels[-1]
115
+ elif name.startswith("up_blocks"):
116
+ block_id = int(name[len("up_blocks.")])
117
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
118
+ elif name.startswith("down_blocks"):
119
+ block_id = int(name[len("down_blocks.")])
120
+ hidden_size = unet.config.block_out_channels[block_id]
121
+ if cross_attention_dim is None:
122
+ attn_procs[name] = AttnProcessor()
123
+ else:
124
+ attn_procs[name] = IPAttnProcessor(
125
+ hidden_size=hidden_size,
126
+ cross_attention_dim=cross_attention_dim,
127
+ scale=1.0,
128
+ num_tokens=self.num_tokens,
129
+ ).to(self.device, dtype=torch.float16)
130
+ unet.set_attn_processor(attn_procs)
131
+ if hasattr(self.pipe, "controlnet"):
132
+ if isinstance(self.pipe.controlnet, MultiControlNetModel):
133
+ for controlnet in self.pipe.controlnet.nets:
134
+ controlnet.set_attn_processor(
135
+ CNAttnProcessor(num_tokens=self.num_tokens)
136
+ )
137
+ else:
138
+ self.pipe.controlnet.set_attn_processor(
139
+ CNAttnProcessor(num_tokens=self.num_tokens)
140
+ )
141
+
142
+ def load_ip_adapter(self):
143
+ if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
144
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
145
+ with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
146
+ for key in f.keys(): # noqa: SIM118
147
+ if key.startswith("image_proj."):
148
+ state_dict["image_proj"][key.replace("image_proj.", "")] = (
149
+ f.get_tensor(key)
150
+ )
151
+ elif key.startswith("ip_adapter."):
152
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = (
153
+ f.get_tensor(key)
154
+ )
155
+ else:
156
+ state_dict = torch.load(self.ip_ckpt, map_location="cpu")
157
+ self.image_proj_model.load_state_dict(state_dict["image_proj"])
158
+ ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
159
+ ip_layers.load_state_dict(state_dict["ip_adapter"])
160
+
161
+ @torch.inference_mode()
162
+ def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
163
+ if pil_image is not None:
164
+ if isinstance(pil_image, Image.Image):
165
+ pil_image = [pil_image]
166
+ clip_image = self.clip_image_processor(
167
+ images=pil_image, return_tensors="pt"
168
+ ).pixel_values
169
+ clip_image_embeds = self.image_encoder(
170
+ clip_image.to(self.device, dtype=torch.float16)
171
+ ).image_embeds
172
+ else:
173
+ clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
174
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
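+ # Unconditional image tokens are produced from a zeroed CLIP embedding and used for classifier-free guidance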
175
+ uncond_image_prompt_embeds = self.image_proj_model(
176
+ torch.zeros_like(clip_image_embeds)
177
+ )
178
+ return image_prompt_embeds, uncond_image_prompt_embeds
179
+
180
+ def set_scale(self, scale):
181
+ for attn_processor in self.pipe.unet.attn_processors.values():
182
+ if isinstance(attn_processor, IPAttnProcessor):
183
+ attn_processor.scale = scale
184
+
185
+ def generate(
186
+ self,
187
+ pil_images=None,
188
+ clip_image_embeds=None,
189
+ prompt=None,
190
+ negative_prompt=None,
191
+ scale=1.0,
192
+ num_samples=1,
193
+ guidance_scale=7.5,
194
+ num_inference_steps=30,
195
+ **kwargs,
196
+ ):
197
+ self.set_scale(scale)
198
+
199
+ num_prompts = 1 if pil_images is not None else clip_image_embeds.size(0)
200
+
201
+ if prompt is None:
202
+ prompt = "best quality, high quality"
203
+ if negative_prompt is None:
204
+ negative_prompt = (
205
+ "monochrome, lowres, bad anatomy, worst quality, low quality"
206
+ )
207
+
208
+ if not isinstance(prompt, List):
209
+ prompt = [prompt] * num_prompts
210
+ if not isinstance(negative_prompt, List):
211
+ negative_prompt = [negative_prompt] * num_prompts
212
+
213
+ image_prompt_embeds_list = []
214
+ uncond_image_prompt_embeds_list = []
215
+ for pil_image in pil_images:
216
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
217
+ pil_image=pil_image, clip_image_embeds=clip_image_embeds
218
+ )
219
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
220
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
221
+ image_prompt_embeds = image_prompt_embeds.view(
222
+ bs_embed * num_samples, seq_len, -1
223
+ )
224
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(
225
+ 1, num_samples, 1
226
+ )
227
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(
228
+ bs_embed * num_samples, seq_len, -1
229
+ )
230
+ image_prompt_embeds_list.append(image_prompt_embeds)
231
+ uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds)
232
+
233
+ with torch.inference_mode():
234
+ prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
235
+ prompt,
236
+ device=self.device,
237
+ num_images_per_prompt=num_samples,
238
+ do_classifier_free_guidance=True,
239
+ negative_prompt=negative_prompt,
240
+ )
241
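+ # Image prompt tokens are appended to the text tokens along the sequence dimension; the adapter attention processors treat these trailing tokens as the image condition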
+ prompt_embeds = torch.cat(
242
+ [prompt_embeds_, *image_prompt_embeds_list], dim=1
243
+ )
244
+ negative_prompt_embeds = torch.cat(
245
+ [negative_prompt_embeds_, *uncond_image_prompt_embeds_list], dim=1
246
+ )
247
+
248
+ # generator = get_generator(seed, self.device)
249
+
250
+ images = self.pipe(
251
+ prompt_embeds=prompt_embeds,
252
+ negative_prompt_embeds=negative_prompt_embeds,
253
+ guidance_scale=guidance_scale,
254
+ num_inference_steps=num_inference_steps,
255
+ # generator=generator,
256
+ **kwargs,
257
+ ).images
258
+
259
+ return images
260
+
261
+
262
+ class ConceptrolIPAdapter:
263
+ def __init__(
264
+ self,
265
+ sd_pipe,
266
+ image_encoder_path,
267
+ ip_ckpt,
268
+ device,
269
+ num_tokens=4,
270
+ global_masking=False,
271
+ adaptive_scale_mask=False,
272
+ ):
273
+ self.device = device
274
+ self.image_encoder_path = image_encoder_path
275
+ self.ip_ckpt = ip_ckpt
276
+ self.num_tokens = num_tokens
277
+
278
+ self.pipe = sd_pipe.to(self.device)
279
+ self.set_ip_adapter(global_masking, adaptive_scale_mask)
280
+
281
+ # load image encoder
282
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
283
+ "h94/IP-Adapter",
284
+ subfolder="models/image_encoder",
285
+ torch_dtype=torch.float16,
286
+ ).to(self.device, dtype=torch.float16)
287
+ self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
288
+ self.clip_image_processor = CLIPImageProcessor()
289
+ # image proj model
290
+ self.image_proj_model = self.init_proj()
291
+
292
+ self.load_ip_adapter()
293
+
294
+ def init_proj(self):
295
+ image_proj_model = ImageProjModel(
296
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
297
+ clip_embeddings_dim=self.image_encoder.config.projection_dim,
298
+ clip_extra_context_tokens=self.num_tokens,
299
+ ).to(self.device, dtype=torch.float16)
300
+ return image_proj_model
301
+
302
+ def set_ip_adapter(self, global_masking, adaptive_scale_mask):
303
+ unet = self.pipe.unet
304
+ attn_procs = {}
305
+ for name in unet.attn_processors.keys(): # noqa: SIM118
306
+ cross_attention_dim = (
307
+ None
308
+ if name.endswith("attn1.processor")
309
+ else unet.config.cross_attention_dim
310
+ )
311
+ if name.startswith("mid_block"):
312
+ hidden_size = unet.config.block_out_channels[-1]
313
+ elif name.startswith("up_blocks"):
314
+ block_id = int(name[len("up_blocks.")])
315
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
316
+ elif name.startswith("down_blocks"):
317
+ block_id = int(name[len("down_blocks.")])
318
+ hidden_size = unet.config.block_out_channels[block_id]
319
+ if cross_attention_dim is None:
320
+ attn_procs[name] = AttnProcessor()
321
+ else:
322
+ attn_procs[name] = ConceptrolAttnProcessor(
323
+ hidden_size=hidden_size,
324
+ cross_attention_dim=cross_attention_dim,
325
+ scale=1.0,
326
+ num_tokens=self.num_tokens,
327
+ name=name,
328
+ global_masking=global_masking,
329
+ adaptive_scale_mask=adaptive_scale_mask,
330
+ concept_mask_layer=SD_CONCEPT_LAYER,
331
+ ).to(self.device, dtype=torch.float16)
332
+ unet.set_attn_processor(attn_procs)
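+ # Give every cross-attention processor a view of the full processor dict so the concept mask computed at SD_CONCEPT_LAYER can be shared across layers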
333
+ for name in unet.attn_processors.keys(): # noqa: SIM118
334
+ cross_attention_dim = (
335
+ None
336
+ if name.endswith("attn1.processor")
337
+ else unet.config.cross_attention_dim
338
+ )
339
+ if cross_attention_dim is not None:
340
+ unet.attn_processors[name].set_global_view(unet.attn_processors)
341
+ if hasattr(self.pipe, "controlnet"):
342
+ if isinstance(self.pipe.controlnet, MultiControlNetModel):
343
+ for controlnet in self.pipe.controlnet.nets:
344
+ controlnet.set_attn_processor(
345
+ CNAttnProcessor(num_tokens=self.num_tokens)
346
+ )
347
+ else:
348
+ self.pipe.controlnet.set_attn_processor(
349
+ CNAttnProcessor(num_tokens=self.num_tokens)
350
+ )
351
+
352
+ def load_ip_adapter(self):
353
+ ckpt_path = self.ip_ckpt
354
+ # If the checkpoint path is not an existing file and is not a full URL,
355
+ # assume it's a Huggingface repository specification.
356
+ if not os.path.exists(self.ip_ckpt) and not self.ip_ckpt.startswith("http"):
357
+ # If a colon is present, use it to split repo_id and filename.
358
+ if ":" in self.ip_ckpt:
359
+ repo_id, filename = self.ip_ckpt.split(":", 1)
360
+ else:
361
+ parts = self.ip_ckpt.split('/')
362
+ if len(parts) > 2:
363
+ # For example, "h94/IP-Adapter/models/ip-adapter-plus_sd15.bin"
364
+ # repo_id becomes "h94/IP-Adapter" and filename "models/ip-adapter-plus_sd15.bin".
365
+ repo_id = '/'.join(parts[:2])
366
+ filename = '/'.join(parts[2:])
367
+ else:
368
+ repo_id = self.ip_ckpt
369
+ filename = "models/ip-adapter-plus_sd15.bin" # default filename if not specified
370
+ ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
371
+
372
+ # Load the state dictionary from the checkpoint file.
373
+ if os.path.splitext(ckpt_path)[-1] == ".safetensors":
374
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
375
+ with safe_open(ckpt_path, framework="pt", device="cpu") as f:
376
+ for key in f.keys():
377
+ if key.startswith("image_proj."):
378
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
379
+ elif key.startswith("ip_adapter."):
380
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
381
+ else:
382
+ state_dict = torch.load(ckpt_path, map_location="cpu")
383
+
384
+ # Load the state dictionaries into the corresponding models.
385
+ self.image_proj_model.load_state_dict(state_dict["image_proj"])
386
+ ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
387
+ ip_layers.load_state_dict(state_dict["ip_adapter"])
388
+
389
+ @torch.inference_mode()
390
+ def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
391
+ if pil_image is not None:
392
+ if isinstance(pil_image, Image.Image):
393
+ pil_image = [pil_image]
394
+ clip_image = self.clip_image_processor(
395
+ images=pil_image, return_tensors="pt"
396
+ ).pixel_values
397
+ clip_image_embeds = self.image_encoder(
398
+ clip_image.to(self.device, dtype=torch.float16)
399
+ ).image_embeds
400
+ else:
401
+ clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
402
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
403
+ uncond_image_prompt_embeds = self.image_proj_model(
404
+ torch.zeros_like(clip_image_embeds)
405
+ )
406
+ return image_prompt_embeds, uncond_image_prompt_embeds
407
+
408
+ def set_scale(self, scale):
409
+ for attn_processor in self.pipe.unet.attn_processors.values():
410
+ if isinstance(attn_processor, ConceptrolAttnProcessor):
411
+ attn_processor.scale = scale
412
+
413
+ def load_textual_concept(self, prompt, subjects):
414
+ tokens = self.tokenizer.tokenize(prompt)
415
+ textual_concept_idxs = []
416
+ offset = 1 # TODO: change back to 1 if not true
417
+
418
+ for subject in subjects:
419
+ subject_tokens = self.tokenizer.tokenize(subject)
420
+ start_idx = tokens.index(subject_tokens[0]) + offset
421
+ end_idx = tokens.index(subject_tokens[-1]) + offset
422
+ textual_concept_idxs.append((start_idx, end_idx + 1))
423
+ print("Locate:", subject, start_idx, end_idx + 1)
424
+
425
+ for attn_processor in self.pipe.unet.attn_processors.values():
426
+ if isinstance(attn_processor, ConceptrolAttnProcessor):
427
+ attn_processor.textual_concept_idxs = textual_concept_idxs
428
+
429
+ def generate(
430
+ self,
431
+ pil_images=None,
432
+ clip_image_embeds=None,
433
+ prompt=None,
434
+ negative_prompt=None,
435
+ scale=1.0,
436
+ num_samples=1,
437
+ seed=42,
438
+ subjects=None,
439
+ guidance_scale=7.5,
440
+ num_inference_steps=30,
441
+ **kwargs,
442
+ ):
443
+ self.set_scale(scale)
444
+
445
+ num_prompts = 1 # multiple prompts are not supported
446
+
447
+ if prompt is None:
448
+ prompt = "best quality, high quality"
449
+ if negative_prompt is None:
450
+ negative_prompt = (
451
+ "monochrome, lowres, bad anatomy, worst quality, low quality"
452
+ )
453
+
454
+ if subjects:
455
+ self.load_textual_concept(prompt, subjects)
456
+ else:
457
+ raise ValueError("Subjects must be provided")
458
+
459
+ if not isinstance(prompt, List):
460
+ prompt = [prompt] * num_prompts
461
+ if not isinstance(negative_prompt, List):
462
+ negative_prompt = [negative_prompt] * num_prompts
463
+
464
+ image_prompt_embeds_list = []
465
+ uncond_image_prompt_embeds_list = []
466
+ for pil_image in pil_images:
467
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
468
+ pil_image=pil_image, clip_image_embeds=clip_image_embeds
469
+ )
470
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
471
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
472
+ image_prompt_embeds = image_prompt_embeds.view(
473
+ bs_embed * num_samples, seq_len, -1
474
+ )
475
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(
476
+ 1, num_samples, 1
477
+ )
478
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(
479
+ bs_embed * num_samples, seq_len, -1
480
+ )
481
+ image_prompt_embeds_list.append(image_prompt_embeds)
482
+ uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds)
483
+
484
+ with torch.inference_mode():
485
+ prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
486
+ prompt,
487
+ device=self.device,
488
+ num_images_per_prompt=num_samples,
489
+ do_classifier_free_guidance=True,
490
+ negative_prompt=negative_prompt,
491
+ )
492
+ prompt_embeds = torch.cat(
493
+ [prompt_embeds_, *image_prompt_embeds_list], dim=1
494
+ )
495
+ negative_prompt_embeds = torch.cat(
496
+ [negative_prompt_embeds_, *uncond_image_prompt_embeds_list], dim=1
497
+ )
498
+
499
+ generator = get_generator(seed, self.device)
500
+
501
+ images = self.pipe(
502
+ prompt_embeds=prompt_embeds,
503
+ negative_prompt_embeds=negative_prompt_embeds,
504
+ guidance_scale=guidance_scale,
505
+ num_inference_steps=num_inference_steps,
506
+ generator=generator,
507
+ **kwargs,
508
+ ).images
509
+
510
+ return images
511
+
512
+
513
+ class IPAdapterXL(IPAdapter):
514
+ """SDXL"""
515
+
516
+ def generate(
517
+ self,
518
+ pil_images,
519
+ prompt=None,
520
+ negative_prompt=None,
521
+ scale=1.0,
522
+ num_samples=1,
523
+ seed=None,
524
+ num_inference_steps=30,
525
+ **kwargs,
526
+ ):
527
+ self.set_scale(scale)
528
+
529
+ num_prompts = 1 # multiple prompts are not supported
530
+
531
+ if prompt is None:
532
+ prompt = "best quality, high quality"
533
+ if negative_prompt is None:
534
+ negative_prompt = (
535
+ "monochrome, lowres, bad anatomy, worst quality, low quality"
536
+ )
537
+
538
+ if not isinstance(prompt, List):
539
+ prompt = [prompt] * num_prompts
540
+ if not isinstance(negative_prompt, List):
541
+ negative_prompt = [negative_prompt] * num_prompts
542
+
543
+ image_prompt_embeds_list = []
544
+ uncond_image_prompt_embeds_list = []
545
+ for pil_image in pil_images:
546
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
547
+ pil_image=pil_image
548
+ )
549
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
550
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
551
+ image_prompt_embeds = image_prompt_embeds.view(
552
+ bs_embed * num_samples, seq_len, -1
553
+ )
554
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(
555
+ 1, num_samples, 1
556
+ )
557
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(
558
+ bs_embed * num_samples, seq_len, -1
559
+ )
560
+ image_prompt_embeds_list.append(image_prompt_embeds)
561
+ uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds)
562
+
563
+ with torch.inference_mode():
564
+ (
565
+ prompt_embeds,
566
+ negative_prompt_embeds,
567
+ pooled_prompt_embeds,
568
+ negative_pooled_prompt_embeds,
569
+ ) = self.pipe.encode_prompt(
570
+ prompt,
571
+ num_images_per_prompt=num_samples,
572
+ do_classifier_free_guidance=True,
573
+ negative_prompt=negative_prompt,
574
+ )
575
+ prompt_embeds = torch.cat([prompt_embeds, *image_prompt_embeds_list], dim=1)
576
+ negative_prompt_embeds = torch.cat(
577
+ [negative_prompt_embeds, *uncond_image_prompt_embeds_list], dim=1
578
+ )
579
+
580
+ generator = get_generator(seed, self.device)
581
+
582
+ images = self.pipe(
583
+ prompt_embeds=prompt_embeds,
584
+ negative_prompt_embeds=negative_prompt_embeds,
585
+ pooled_prompt_embeds=pooled_prompt_embeds,
586
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
587
+ num_inference_steps=num_inference_steps,
588
+ generator=generator,
589
+ **kwargs,
590
+ ).images
591
+
592
+ return images
593
+
594
+
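
A minimal usage sketch for IPAdapterXL.generate, illustrative only: the import path, constructor arguments, and checkpoint paths are assumptions (they follow the common IP-Adapter layout and are not confirmed by this diff).

import torch
from PIL import Image
from diffusers import StableDiffusionXLPipeline
from ip_adapter import IPAdapterXL  # assumed export from ip_adapter/__init__.py

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
# Constructor arguments below follow the standard IP-Adapter signature
# (pipe, image_encoder_path, ip_ckpt, device); the paths are placeholders.
ip_model = IPAdapterXL(pipe, "path/to/image_encoder", "path/to/ip-adapter_sdxl.bin", "cuda")

images = ip_model.generate(
    pil_images=[Image.open("demo/statue.jpg")],   # list of reference images
    prompt="a statue in a sunlit garden, best quality",
    scale=0.8,                                    # strength of the image prompt
    num_samples=1,
    seed=42,
    num_inference_steps=30,
)
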
595
+ class ConceptrolIPAdapterXL(ConceptrolIPAdapter):
596
+ """SDXL"""
597
+
598
+ def set_ip_adapter(self, global_masking, adaptive_scale_mask):
599
+ unet = self.pipe.unet
600
+ attn_procs = {}
601
+ for name in unet.attn_processors.keys(): # noqa: SIM118
602
+ cross_attention_dim = (
603
+ None
604
+ if name.endswith("attn1.processor")
605
+ else unet.config.cross_attention_dim
606
+ )
607
+ if name.startswith("mid_block"):
608
+ hidden_size = unet.config.block_out_channels[-1]
609
+ elif name.startswith("up_blocks"):
610
+ block_id = int(name[len("up_blocks.")])
611
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
612
+ elif name.startswith("down_blocks"):
613
+ block_id = int(name[len("down_blocks.")])
614
+ hidden_size = unet.config.block_out_channels[block_id]
615
+ if cross_attention_dim is None:
616
+ attn_procs[name] = AttnProcessor()
617
+ else:
618
+ attn_procs[name] = ConceptrolAttnProcessor(
619
+ hidden_size=hidden_size,
620
+ cross_attention_dim=cross_attention_dim,
621
+ scale=1.0,
622
+ num_tokens=self.num_tokens,
623
+ name=name,
624
+ global_masking=global_masking,
625
+ adaptive_scale_mask=adaptive_scale_mask,
626
+ concept_mask_layer=SDXL_CONCEPT_LAYER,
627
+ ).to(self.device, dtype=torch.float16)
628
+ unet.set_attn_processor(attn_procs)
629
+ for name in unet.attn_processors.keys(): # noqa: SIM118
630
+ cross_attention_dim = (
631
+ None
632
+ if name.endswith("attn1.processor")
633
+ else unet.config.cross_attention_dim
634
+ )
635
+ if cross_attention_dim is not None:
636
+ unet.attn_processors[name].set_global_view(unet.attn_processors)
637
+ if hasattr(self.pipe, "controlnet"):
638
+ if isinstance(self.pipe.controlnet, MultiControlNetModel):
639
+ for controlnet in self.pipe.controlnet.nets:
640
+ controlnet.set_attn_processor(
641
+ CNAttnProcessor(num_tokens=self.num_tokens)
642
+ )
643
+ else:
644
+ self.pipe.controlnet.set_attn_processor(
645
+ CNAttnProcessor(num_tokens=self.num_tokens)
646
+ )
647
+
648
+ def generate(
649
+ self,
650
+ pil_images=None,
651
+ prompt=None,
652
+ negative_prompt=None,
653
+ subjects=None,
654
+ scale=1.0,
655
+ num_samples=1,
656
+ num_inference_steps=30,
657
+ seed=None,
658
+ **kwargs,
659
+ ):
660
+ self.set_scale(scale)
661
+
662
+ num_prompts = 1  # multiple prompts are not supported
663
+
664
+ if prompt is None:
665
+ prompt = "best quality, high quality"
666
+ if negative_prompt is None:
667
+ negative_prompt = (
668
+ "monochrome, lowres, bad anatomy, worst quality, low quality"
669
+ )
670
+
671
+ if subjects:
672
+ self.load_textual_concept(prompt, subjects)
673
+ else:
674
+ raise ValueError("Subjects must be provided")
675
+
676
+ if not isinstance(prompt, List):
677
+ prompt = [prompt] * num_prompts
678
+ if not isinstance(negative_prompt, List):
679
+ negative_prompt = [negative_prompt] * num_prompts
680
+
681
+ image_prompt_embeds_list = []
682
+ uncond_image_prompt_embeds_list = []
683
+ for pil_image in pil_images:
684
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
685
+ pil_image=pil_image
686
+ )
687
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
688
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
689
+ image_prompt_embeds = image_prompt_embeds.view(
690
+ bs_embed * num_samples, seq_len, -1
691
+ )
692
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(
693
+ 1, num_samples, 1
694
+ )
695
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(
696
+ bs_embed * num_samples, seq_len, -1
697
+ )
698
+ image_prompt_embeds_list.append(image_prompt_embeds)
699
+ uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds)
700
+
701
+ with torch.inference_mode():
702
+ (
703
+ prompt_embeds,
704
+ negative_prompt_embeds,
705
+ pooled_prompt_embeds,
706
+ negative_pooled_prompt_embeds,
707
+ ) = self.pipe.encode_prompt(
708
+ prompt,
709
+ num_images_per_prompt=num_samples,
710
+ do_classifier_free_guidance=True,
711
+ negative_prompt=negative_prompt,
712
+ )
713
+ prompt_embeds = torch.cat([prompt_embeds, *image_prompt_embeds_list], dim=1)
714
+ negative_prompt_embeds = torch.cat(
715
+ [negative_prompt_embeds, *uncond_image_prompt_embeds_list], dim=1
716
+ )
717
+
718
+ generator = get_generator(seed, self.device)
719
+
720
+ images = self.pipe(
721
+ prompt_embeds=prompt_embeds,
722
+ negative_prompt_embeds=negative_prompt_embeds,
723
+ pooled_prompt_embeds=pooled_prompt_embeds,
724
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
725
+ num_inference_steps=num_inference_steps,
726
+ generator=generator,
727
+ **kwargs,
728
+ ).images
729
+
730
+ return images
731
+
732
+
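
For comparison, a hypothetical call to ConceptrolIPAdapterXL.generate. The API difference visible in this file is the required subjects list, which binds each reference image to a textual concept in the prompt via load_textual_concept; the adapter construction is assumed to mirror the previous sketch.

from PIL import Image

# `conceptrol_ip_model` stands in for a ConceptrolIPAdapterXL built the same way
# as `ip_model` in the sketch above (constructor details are assumptions).
images = conceptrol_ip_model.generate(
    pil_images=[Image.open("demo/horse.jpg"), Image.open("demo/t-shirt.jpg")],
    prompt="a horse wearing a t-shirt on a beach",
    subjects=["horse", "t-shirt"],  # one textual concept per reference image
    scale=1.0,
    num_samples=1,
    seed=42,
)
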
733
+ class IPAdapterPlus(IPAdapter):
734
+ """IP-Adapter with fine-grained features"""
735
+
736
+ def init_proj(self):
737
+ image_proj_model = Resampler(
738
+ dim=self.pipe.unet.config.cross_attention_dim,
739
+ depth=4,
740
+ dim_head=64,
741
+ heads=12,
742
+ num_queries=self.num_tokens,
743
+ embedding_dim=self.image_encoder.config.hidden_size,
744
+ output_dim=self.pipe.unet.config.cross_attention_dim,
745
+ ff_mult=4,
746
+ ).to(self.device, dtype=torch.float16)
747
+ return image_proj_model
748
+
749
+ @torch.inference_mode()
750
+ def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
751
+ if isinstance(pil_image, Image.Image):
752
+ pil_image = [pil_image]
753
+ clip_image = self.clip_image_processor(
754
+ images=pil_image, return_tensors="pt"
755
+ ).pixel_values
756
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
757
+ clip_image_embeds = self.image_encoder(
758
+ clip_image, output_hidden_states=True
759
+ ).hidden_states[-2]
760
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
761
+ uncond_clip_image_embeds = self.image_encoder(
762
+ torch.zeros_like(clip_image), output_hidden_states=True
763
+ ).hidden_states[-2]
764
+ uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
765
+ return image_prompt_embeds, uncond_image_prompt_embeds
766
+
767
+
768
+ class ConceptrolIPAdapterPlus(ConceptrolIPAdapter):
769
+ """IP-Adapter with fine-grained features"""
770
+
771
+ def init_proj(self):
772
+ image_proj_model = Resampler(
773
+ dim=self.pipe.unet.config.cross_attention_dim,
774
+ depth=4,
775
+ dim_head=64,
776
+ heads=12,
777
+ num_queries=self.num_tokens,
778
+ embedding_dim=self.image_encoder.config.hidden_size,
779
+ output_dim=self.pipe.unet.config.cross_attention_dim,
780
+ ff_mult=4,
781
+ ).to(self.device, dtype=torch.float16)
782
+ return image_proj_model
783
+
784
+ @torch.inference_mode()
785
+ def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
786
+ if isinstance(pil_image, Image.Image):
787
+ pil_image = [pil_image]
788
+ clip_image = self.clip_image_processor(
789
+ images=pil_image, return_tensors="pt"
790
+ ).pixel_values
791
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
792
+ clip_image_embeds = self.image_encoder(
793
+ clip_image, output_hidden_states=True
794
+ ).hidden_states[-2]
795
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
796
+ uncond_clip_image_embeds = self.image_encoder(
797
+ torch.zeros_like(clip_image), output_hidden_states=True
798
+ ).hidden_states[-2]
799
+ uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
800
+ return image_prompt_embeds, uncond_image_prompt_embeds
801
+
802
+
803
+ class IPAdapterFull(IPAdapterPlus):
804
+ """IP-Adapter with full features"""
805
+
806
+ def init_proj(self):
807
+ image_proj_model = MLPProjModel(
808
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
809
+ clip_embeddings_dim=self.image_encoder.config.hidden_size,
810
+ ).to(self.device, dtype=torch.float16)
811
+ return image_proj_model
812
+
813
+
814
+ class IPAdapterPlusXL(IPAdapter):
815
+ """SDXL"""
816
+
817
+ def init_proj(self):
818
+ image_proj_model = Resampler(
819
+ dim=1280,
820
+ depth=4,
821
+ dim_head=64,
822
+ heads=20,
823
+ num_queries=self.num_tokens,
824
+ embedding_dim=self.image_encoder.config.hidden_size,
825
+ output_dim=self.pipe.unet.config.cross_attention_dim,
826
+ ff_mult=4,
827
+ ).to(self.device, dtype=torch.float16)
828
+ return image_proj_model
829
+
830
+ @torch.inference_mode()
831
+ def get_image_embeds(self, pil_image):
832
+ if isinstance(pil_image, Image.Image):
833
+ pil_image = [pil_image]
834
+ clip_image = self.clip_image_processor(
835
+ images=pil_image, return_tensors="pt"
836
+ ).pixel_values
837
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
838
+ clip_image_embeds = self.image_encoder(
839
+ clip_image, output_hidden_states=True
840
+ ).hidden_states[-2]
841
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
842
+ uncond_clip_image_embeds = self.image_encoder(
843
+ torch.zeros_like(clip_image), output_hidden_states=True
844
+ ).hidden_states[-2]
845
+ uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
846
+ return image_prompt_embeds, uncond_image_prompt_embeds
847
+
848
+ def generate(
849
+ self,
850
+ pil_images=None,
851
+ prompt=None,
852
+ negative_prompt=None,
853
+ scale=1.0,
854
+ num_samples=1,
855
+ seed=42,
856
+ num_inference_steps=30,
857
+ **kwargs,
858
+ ):
859
+ self.set_scale(scale)
860
+
861
+ num_prompts = 1  # multiple prompts are not supported
862
+
863
+ if prompt is None:
864
+ prompt = "best quality, high quality"
865
+ if negative_prompt is None:
866
+ negative_prompt = (
867
+ "monochrome, lowres, bad anatomy, worst quality, low quality"
868
+ )
869
+
870
+ if not isinstance(prompt, List):
871
+ prompt = [prompt] * num_prompts
872
+ if not isinstance(negative_prompt, List):
873
+ negative_prompt = [negative_prompt] * num_prompts
874
+
875
+ image_prompt_embeds_list = []
876
+ uncond_image_prompt_embeds_list = []
877
+ for pil_image in pil_images:
878
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
879
+ pil_image=pil_image
880
+ )
881
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
882
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
883
+ image_prompt_embeds = image_prompt_embeds.view(
884
+ bs_embed * num_samples, seq_len, -1
885
+ )
886
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(
887
+ 1, num_samples, 1
888
+ )
889
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(
890
+ bs_embed * num_samples, seq_len, -1
891
+ )
892
+ image_prompt_embeds_list.append(image_prompt_embeds)
893
+ uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds)
894
+
895
+ with torch.inference_mode():
896
+ (
897
+ prompt_embeds,
898
+ negative_prompt_embeds,
899
+ pooled_prompt_embeds,
900
+ negative_pooled_prompt_embeds,
901
+ ) = self.pipe.encode_prompt(
902
+ prompt,
903
+ num_images_per_prompt=num_samples,
904
+ do_classifier_free_guidance=True,
905
+ negative_prompt=negative_prompt,
906
+ )
907
+ prompt_embeds = torch.cat([prompt_embeds, *image_prompt_embeds_list], dim=1)
908
+ negative_prompt_embeds = torch.cat(
909
+ [negative_prompt_embeds, *uncond_image_prompt_embeds_list], dim=1
910
+ )
911
+
912
+ generator = get_generator(seed, self.device)
913
+
914
+ images = self.pipe(
915
+ prompt_embeds=prompt_embeds,
916
+ negative_prompt_embeds=negative_prompt_embeds,
917
+ pooled_prompt_embeds=pooled_prompt_embeds,
918
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
919
+ num_inference_steps=num_inference_steps,
920
+ generator=generator,
921
+ **kwargs,
922
+ ).images
923
+
924
+ return images
925
+
926
+
927
+ class ConceptrolIPAdapterPlusXL(ConceptrolIPAdapterXL):
928
+ """SDXL"""
929
+
930
+ def init_proj(self):
931
+ image_proj_model = Resampler(
932
+ dim=1280,
933
+ depth=4,
934
+ dim_head=64,
935
+ heads=20,
936
+ num_queries=self.num_tokens,
937
+ embedding_dim=self.image_encoder.config.hidden_size,
938
+ output_dim=self.pipe.unet.config.cross_attention_dim,
939
+ ff_mult=4,
940
+ ).to(self.device, dtype=torch.float16)
941
+ return image_proj_model
942
+
943
+ @torch.inference_mode()
944
+ def get_image_embeds(self, pil_image):
945
+ if isinstance(pil_image, Image.Image):
946
+ pil_image = [pil_image]
947
+ clip_image = self.clip_image_processor(
948
+ images=pil_image, return_tensors="pt"
949
+ ).pixel_values
950
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
951
+ clip_image_embeds = self.image_encoder(
952
+ clip_image, output_hidden_states=True
953
+ ).hidden_states[-2]
954
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
955
+ uncond_clip_image_embeds = self.image_encoder(
956
+ torch.zeros_like(clip_image), output_hidden_states=True
957
+ ).hidden_states[-2]
958
+ uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
959
+ return image_prompt_embeds, uncond_image_prompt_embeds
960
+
961
+ def generate(
962
+ self,
963
+ pil_images=None,
964
+ prompt=None,
965
+ negative_prompt=None,
966
+ scale=1.0,
967
+ subjects=None,
968
+ num_samples=1,
969
+ seed=42,
970
+ num_inference_steps=30,
971
+ **kwargs,
972
+ ):
973
+ self.set_scale(scale)
974
+
975
+ num_prompts = 1  # multiple prompts are not supported
976
+
977
+ if prompt is None:
978
+ prompt = "best quality, high quality"
979
+ if negative_prompt is None:
980
+ negative_prompt = (
981
+ "monochrome, lowres, bad anatomy, worst quality, low quality"
982
+ )
983
+
984
+ if subjects:
985
+ self.load_textual_concept(prompt, subjects)
986
+ else:
987
+ raise ValueError("Subjects must be provided")
988
+
989
+ if not isinstance(prompt, List):
990
+ prompt = [prompt] * num_prompts
991
+ if not isinstance(negative_prompt, List):
992
+ negative_prompt = [negative_prompt] * num_prompts
993
+
994
+ image_prompt_embeds_list = []
995
+ uncond_image_prompt_embeds_list = []
996
+ for pil_image in pil_images:
997
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
998
+ pil_image=pil_image
999
+ )
1000
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
1001
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
1002
+ image_prompt_embeds = image_prompt_embeds.view(
1003
+ bs_embed * num_samples, seq_len, -1
1004
+ )
1005
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(
1006
+ 1, num_samples, 1
1007
+ )
1008
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(
1009
+ bs_embed * num_samples, seq_len, -1
1010
+ )
1011
+ image_prompt_embeds_list.append(image_prompt_embeds)
1012
+ uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds)
1013
+
1014
+ with torch.inference_mode():
1015
+ (
1016
+ prompt_embeds,
1017
+ negative_prompt_embeds,
1018
+ pooled_prompt_embeds,
1019
+ negative_pooled_prompt_embeds,
1020
+ ) = self.pipe.encode_prompt(
1021
+ prompt,
1022
+ num_images_per_prompt=num_samples,
1023
+ do_classifier_free_guidance=True,
1024
+ negative_prompt=negative_prompt,
1025
+ )
1026
+ prompt_embeds = torch.cat([prompt_embeds, *image_prompt_embeds_list], dim=1)
1027
+ negative_prompt_embeds = torch.cat(
1028
+ [negative_prompt_embeds, *uncond_image_prompt_embeds_list], dim=1
1029
+ )
1030
+
1031
+ generator = get_generator(seed, self.device)
1032
+
1033
+ images = self.pipe(
1034
+ prompt_embeds=prompt_embeds,
1035
+ negative_prompt_embeds=negative_prompt_embeds,
1036
+ pooled_prompt_embeds=pooled_prompt_embeds,
1037
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1038
+ num_inference_steps=num_inference_steps,
1039
+ generator=generator,
1040
+ **kwargs,
1041
+ ).images
1042
+
1043
+ return images
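
All of the generate variants above share one embedding layout: the per-image tokens from get_image_embeds are appended to the text embeddings along the sequence dimension, for both the conditional and unconditional branches. A small shape-bookkeeping sketch (example sizes only; num_tokens and the embedding width depend on the adapter variant):

import torch

text_len, num_tokens, dim = 77, 16, 2048   # illustrative values, not fixed by this file
prompt_embeds_ = torch.randn(1, text_len, dim)
image_prompt_embeds_list = [torch.randn(1, num_tokens, dim) for _ in range(2)]

# Resulting cross-attention sequence: [text tokens | image tokens of ref 1 | ref 2]
prompt_embeds = torch.cat([prompt_embeds_, *image_prompt_embeds_list], dim=1)
assert prompt_embeds.shape == (1, text_len + 2 * num_tokens, dim)
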
ip_adapter/resampler.py ADDED
@@ -0,0 +1,247 @@
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from einops.layers.torch import Rearrange
10
+
11
+
12
+ # FFN
13
+ def FeedForward(dim, mult=4):
14
+ inner_dim = int(dim * mult)
15
+ return nn.Sequential(
16
+ nn.LayerNorm(dim),
17
+ nn.Linear(dim, inner_dim, bias=False),
18
+ nn.GELU(),
19
+ nn.Linear(inner_dim, dim, bias=False),
20
+ )
21
+
22
+
23
+ def reshape_tensor(x, heads):
24
+ bs, length, width = x.shape
25
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
26
+ x = x.view(bs, length, heads, -1)
27
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
28
+ x = x.transpose(1, 2)
29
+ # reshape keeps (bs, n_heads, length, dim_per_head); heads are not merged into the batch dimension
30
+ x = x.reshape(bs, heads, length, -1)
31
+ return x
32
+
33
+
34
+ class PerceiverAttention(nn.Module):
35
+ def __init__(self, *, dim, dim_head=64, heads=8):
36
+ super().__init__()
37
+ self.scale = dim_head**-0.5
38
+ self.dim_head = dim_head
39
+ self.heads = heads
40
+ inner_dim = dim_head * heads
41
+
42
+ self.norm1 = nn.LayerNorm(dim)
43
+ self.norm2 = nn.LayerNorm(dim)
44
+
45
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
46
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
47
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
48
+
49
+ def forward(self, x, latents):
50
+ """
51
+ Args:
52
+ x (torch.Tensor): image features
53
+ shape (b, n1, D)
54
+ latents (torch.Tensor): latent features
55
+ shape (b, n2, D)
56
+ """
57
+ x = self.norm1(x)
58
+ latents = self.norm2(latents)
59
+
60
+ b, l, _ = latents.shape
61
+
62
+ q = self.to_q(latents)
63
+ kv_input = torch.cat((x, latents), dim=-2)
64
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
65
+
66
+ q = reshape_tensor(q, self.heads)
67
+ k = reshape_tensor(k, self.heads)
68
+ v = reshape_tensor(v, self.heads)
69
+
70
+ # attention
71
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
72
+ weight = (q * scale) @ (k * scale).transpose(
73
+ -2, -1
74
+ ) # More stable with f16 than dividing afterwards
75
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
76
+ out = weight @ v
77
+
78
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
79
+
80
+ return self.to_out(out)
81
+
82
+
83
+ class Resampler(nn.Module):
84
+ def __init__(
85
+ self,
86
+ dim=1024,
87
+ depth=8,
88
+ dim_head=64,
89
+ heads=16,
90
+ num_queries=8,
91
+ embedding_dim=768,
92
+ output_dim=1024,
93
+ ff_mult=4,
94
+ max_seq_len: int = 257, # CLIP tokens + CLS token
95
+ apply_pos_emb: bool = False,
96
+ num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
97
+ ):
98
+ super().__init__()
99
+ self.pos_emb = (
100
+ nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
101
+ )
102
+
103
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
104
+
105
+ self.proj_in = nn.Linear(embedding_dim, dim)
106
+
107
+ self.proj_out = nn.Linear(dim, output_dim)
108
+ self.norm_out = nn.LayerNorm(output_dim)
109
+
110
+ self.to_latents_from_mean_pooled_seq = (
111
+ nn.Sequential(
112
+ nn.LayerNorm(dim),
113
+ nn.Linear(dim, dim * num_latents_mean_pooled),
114
+ Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
115
+ )
116
+ if num_latents_mean_pooled > 0
117
+ else None
118
+ )
119
+
120
+ self.layers = nn.ModuleList([])
121
+ for _ in range(depth):
122
+ self.layers.append(
123
+ nn.ModuleList(
124
+ [
125
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
126
+ FeedForward(dim=dim, mult=ff_mult),
127
+ ]
128
+ )
129
+ )
130
+
131
+ def forward(self, x):
132
+ if self.pos_emb is not None:
133
+ n, device = x.shape[1], x.device
134
+ pos_emb = self.pos_emb(torch.arange(n, device=device))
135
+ x = x + pos_emb
136
+
137
+ latents = self.latents.repeat(x.size(0), 1, 1)
138
+
139
+ x = self.proj_in(x)
140
+
141
+ if self.to_latents_from_mean_pooled_seq:
142
+ meanpooled_seq = masked_mean(
143
+ x,
144
+ dim=1,
145
+ mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool),
146
+ )
147
+ meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
148
+ latents = torch.cat((meanpooled_latents, latents), dim=-2)
149
+
150
+ for attn, ff in self.layers:
151
+ latents = attn(x, latents) + latents
152
+ latents = ff(latents) + latents
153
+
154
+ latents = self.proj_out(latents)
155
+ return self.norm_out(latents)
156
+
157
+
158
+ class ResamplerZeroInOut(nn.Module):
159
+ def __init__(
160
+ self,
161
+ dim=1024,
162
+ depth=8,
163
+ dim_head=64,
164
+ heads=16,
165
+ num_queries=8,
166
+ embedding_dim=768,
167
+ output_dim=1024,
168
+ ff_mult=4,
169
+ max_seq_len: int = 257, # CLIP tokens + CLS token
170
+ apply_pos_emb: bool = False,
171
+ num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
172
+ ):
173
+ super().__init__()
174
+ self.pos_emb = (
175
+ nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
176
+ )
177
+
178
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
179
+
180
+ self.proj_in_zero = nn.Linear(embedding_dim, embedding_dim, bias=False)
181
+ self.proj_in = nn.Linear(embedding_dim, dim)
182
+ self.proj_out = nn.Linear(dim, output_dim)
183
+ self.proj_out_zero = nn.Linear(output_dim, output_dim, bias=False)
184
+ self.norm_out = nn.LayerNorm(output_dim)
185
+
186
+ nn.init.zeros_(self.proj_in_zero.weight)
187
+ nn.init.zeros_(self.proj_out_zero.weight)
188
+
189
+ self.to_latents_from_mean_pooled_seq = (
190
+ nn.Sequential(
191
+ nn.LayerNorm(dim),
192
+ nn.Linear(dim, dim * num_latents_mean_pooled),
193
+ Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
194
+ )
195
+ if num_latents_mean_pooled > 0
196
+ else None
197
+ )
198
+
199
+ self.layers = nn.ModuleList([])
200
+ for _ in range(depth):
201
+ self.layers.append(
202
+ nn.ModuleList(
203
+ [
204
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
205
+ FeedForward(dim=dim, mult=ff_mult),
206
+ ]
207
+ )
208
+ )
209
+
210
+ def forward(self, x):
211
+ if self.pos_emb is not None:
212
+ n, device = x.shape[1], x.device
213
+ pos_emb = self.pos_emb(torch.arange(n, device=device))
214
+ x = x + pos_emb
215
+
216
+ latents = self.latents.repeat(x.size(0), 1, 1)
217
+
218
+ x = self.proj_in_zero(x)
219
+ x = self.proj_in(x)
220
+
221
+ if self.to_latents_from_mean_pooled_seq:
222
+ meanpooled_seq = masked_mean(
223
+ x,
224
+ dim=1,
225
+ mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool),
226
+ )
227
+ meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
228
+ latents = torch.cat((meanpooled_latents, latents), dim=-2)
229
+
230
+ for attn, ff in self.layers:
231
+ latents = attn(x, latents) + latents
232
+ latents = ff(latents) + latents
233
+
234
+ latents = self.proj_out(latents)
235
+ latents = self.proj_out_zero(latents)
236
+ return self.norm_out(latents)
237
+
238
+
239
+ def masked_mean(t, *, dim, mask=None):
240
+ if mask is None:
241
+ return t.mean(dim=dim)
242
+
243
+ denom = mask.sum(dim=dim, keepdim=True)
244
+ mask = rearrange(mask, "b n -> b n 1")
245
+ masked_t = t.masked_fill(~mask, 0.0)
246
+
247
+ return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
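
A quick shape check for Resampler (illustrative; the constructor values below are example settings, not ones mandated by this file): it compresses a variable-length CLIP hidden-state sequence into a fixed set of num_queries tokens in the cross-attention width.

import torch
from ip_adapter.resampler import Resampler

resampler = Resampler(
    dim=1024, depth=2, dim_head=64, heads=12,
    num_queries=16, embedding_dim=1280, output_dim=2048, ff_mult=4,
)
clip_hidden_states = torch.randn(1, 257, 1280)  # e.g. penultimate CLIP hidden states
tokens = resampler(clip_hidden_states)
assert tokens.shape == (1, 16, 2048)            # num_queries tokens of output_dim width
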
ip_adapter/utils.py ADDED
@@ -0,0 +1,140 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from PIL import Image
5
+
6
+ # global variable
7
+ raw_attn_maps = {}
8
+ raw_ip_attn_maps = {}
9
+ attn_maps = {}
10
+ ip_attn_maps = {}
11
+
12
+
13
+ def hook_fn(name):
14
+ def forward_hook(module, input, output):
15
+ if hasattr(module.processor, "attn_map"):
16
+ if name not in raw_attn_maps:
17
+ raw_attn_maps[name] = []
18
+ if name not in raw_ip_attn_maps:
19
+ raw_ip_attn_maps[name] = []
20
+ raw_attn_maps[name].append(module.processor.attn_map)
21
+ raw_ip_attn_maps[name].append(module.processor.ip_attn_map)
22
+ del module.processor.attn_map
23
+ del module.processor.ip_attn_map
24
+
25
+ return forward_hook
26
+
27
+
28
+ def post_process_attn_maps():
29
+ global raw_attn_maps, raw_ip_attn_maps, attn_maps, ip_attn_maps
30
+ attn_maps = [
31
+ dict(zip(raw_attn_maps.keys(), values))
32
+ for values in zip(*raw_attn_maps.values())
33
+ ]
34
+ ip_attn_maps = [
35
+ dict(zip(raw_ip_attn_maps.keys(), values))
36
+ for values in zip(*raw_ip_attn_maps.values())
37
+ ]
38
+
39
+ return attn_maps, ip_attn_maps
40
+
41
+
42
+ def register_cross_attention_hook(unet):
43
+ for name, module in unet.named_modules():
44
+ if name.split(".")[-1].startswith("attn2"):
45
+ module.register_forward_hook(hook_fn(name))
46
+
47
+ return unet
48
+
49
+
50
+ def upscale(attn_map, target_size):
51
+ attn_map = torch.mean(attn_map, dim=0)
52
+ attn_map = attn_map.permute(1, 0)
53
+ temp_size = None
54
+
55
+ for i in range(0, 5):
56
+ scale = 2**i
57
+ if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[
58
+ 1
59
+ ] * 64:
60
+ temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
61
+ break
62
+
63
+ assert temp_size is not None, "temp_size cannot be None"
64
+
65
+ attn_map = attn_map.view(attn_map.shape[0], *temp_size)
66
+
67
+ attn_map = F.interpolate(
68
+ attn_map.unsqueeze(0).to(dtype=torch.float32),
69
+ size=target_size,
70
+ mode="bilinear",
71
+ align_corners=False,
72
+ )[0]
73
+
74
+ attn_map = torch.softmax(attn_map, dim=0)
75
+ return attn_map
76
+
77
+
78
+ def get_net_attn_map(
79
+ image_size, batch_size=2, instance_or_negative=False, detach=True, step=-1
80
+ ):
81
+
82
+ idx = 0 if instance_or_negative else 1
83
+ net_attn_maps = []
84
+ net_ip_attn_maps = []
85
+
86
+ for _, attn_map in attn_maps[step].items():
87
+ attn_map = attn_map.cpu() if detach else attn_map
88
+ attn_map = torch.chunk(attn_map, batch_size)[
89
+ idx
90
+ ].squeeze() # get the attention map of text
91
+ attn_map = upscale(attn_map, image_size)
92
+ net_attn_maps.append(attn_map)
93
+
94
+ net_attn_maps = torch.mean(torch.stack(net_attn_maps, dim=0), dim=0)
95
+
96
+ for _, attn_map in ip_attn_maps[step].items():
97
+ attn_map = attn_map.cpu() if detach else attn_map
98
+ attn_map = torch.chunk(attn_map, batch_size)[
99
+ idx
100
+ ].squeeze() # get the attention map of text
101
+ attn_map = upscale(attn_map, image_size)
102
+ net_ip_attn_maps.append(attn_map)
103
+
104
+ net_ip_attn_maps = torch.mean(torch.stack(net_ip_attn_maps, dim=0), dim=0)
105
+
106
+ return net_attn_maps, net_ip_attn_maps
107
+
108
+
109
+ def attnmaps2images(net_attn_maps):
110
+ images = []
111
+
112
+ for attn_map in net_attn_maps:
113
+ attn_map = attn_map.cpu().numpy()
114
+ normalized_attn_map = (
115
+ (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
116
+ )
117
+ normalized_attn_map = normalized_attn_map.astype(np.uint8)
118
+ image = Image.fromarray(normalized_attn_map)
119
+ images.append(image)
120
+
121
+ return images
122
+
123
+
124
+ def is_torch2_available():
125
+ return hasattr(F, "scaled_dot_product_attention")
126
+
127
+
128
+ def get_generator(seed, device):
129
+
130
+ if seed is not None:
131
+ if isinstance(seed, list):
132
+ generator = [
133
+ torch.Generator(device).manual_seed(seed_item) for seed_item in seed
134
+ ]
135
+ else:
136
+ generator = torch.Generator(device).manual_seed(seed)
137
+ else:
138
+ generator = None
139
+
140
+ return generator
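
A minimal check of get_generator's behavior (no assumptions beyond the code above): a single int yields one torch.Generator, a list yields one generator per seed, and None disables explicit seeding.

import torch
from ip_adapter.utils import get_generator

g_single = get_generator(42, "cpu")
g_batch = get_generator([1, 2, 3], "cpu")
g_none = get_generator(None, "cpu")

assert isinstance(g_single, torch.Generator)
assert isinstance(g_batch, list) and len(g_batch) == 3
assert g_none is None
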
omini_control/__init__.py ADDED
File without changes
omini_control/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (153 Bytes). View file
 
omini_control/__pycache__/block.cpython-310.pyc ADDED
Binary file (6.15 kB). View file
 
omini_control/__pycache__/concept_alignment.cpython-310.pyc ADDED
Binary file (14.1 kB). View file
 
omini_control/__pycache__/conceptrol.cpython-310.pyc ADDED
Binary file (4.1 kB). View file
 
omini_control/__pycache__/condition.cpython-310.pyc ADDED
Binary file (3.34 kB). View file
 
omini_control/__pycache__/flux_conceptrol_pipeline.cpython-310.pyc ADDED
Binary file (7.59 kB). View file
 
omini_control/__pycache__/lora_controller.cpython-310.pyc ADDED
Binary file (3 kB). View file
 
omini_control/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (5.12 kB). View file
 
omini_control/block.py ADDED
@@ -0,0 +1,354 @@
1
+ import torch
2
+ from typing import Optional, Dict, Any
3
+ from diffusers.models.attention_processor import Attention, F
4
+ from .lora_controller import enable_lora
5
+ from .conceptrol import Conceptrol
6
+
7
+
8
+ def attn_forward(
9
+ attn: Attention,
10
+ hidden_states: torch.FloatTensor,
11
+ encoder_hidden_states: torch.FloatTensor = None,
12
+ condition_latents: torch.FloatTensor = None,
13
+ attention_mask: Optional[torch.FloatTensor] = None,
14
+ image_rotary_emb: Optional[torch.Tensor] = None,
15
+ cond_rotary_emb: Optional[torch.Tensor] = None,
16
+ model_config: Optional[Dict[str, Any]] = {},
17
+ conceptrol: Conceptrol = None,
18
+ ) -> torch.FloatTensor:
19
+ global attn_maps
20
+ batch_size, _, _ = (
21
+ hidden_states.shape
22
+ if encoder_hidden_states is None
23
+ else encoder_hidden_states.shape
24
+ )
25
+
26
+ with enable_lora(
27
+ (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False)
28
+ ):
29
+ # `sample` projections.
30
+ query = attn.to_q(hidden_states)
31
+ key = attn.to_k(hidden_states)
32
+ value = attn.to_v(hidden_states)
33
+
34
+ inner_dim = key.shape[-1]
35
+ head_dim = inner_dim // attn.heads
36
+
37
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
38
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
39
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
40
+
41
+ if attn.norm_q is not None:
42
+ query = attn.norm_q(query)
43
+ if attn.norm_k is not None:
44
+ key = attn.norm_k(key)
45
+
46
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
47
+ if encoder_hidden_states is not None:
48
+ # `context` projections.
49
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
50
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
51
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
52
+
53
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
54
+ batch_size, -1, attn.heads, head_dim
55
+ ).transpose(1, 2)
56
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
57
+ batch_size, -1, attn.heads, head_dim
58
+ ).transpose(1, 2)
59
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
60
+ batch_size, -1, attn.heads, head_dim
61
+ ).transpose(1, 2)
62
+
63
+ if attn.norm_added_q is not None:
64
+ encoder_hidden_states_query_proj = attn.norm_added_q(
65
+ encoder_hidden_states_query_proj
66
+ )
67
+ if attn.norm_added_k is not None:
68
+ encoder_hidden_states_key_proj = attn.norm_added_k(
69
+ encoder_hidden_states_key_proj
70
+ )
71
+
72
+ # attention
73
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
74
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
75
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
76
+
77
+ if image_rotary_emb is not None:
78
+ from diffusers.models.embeddings import apply_rotary_emb
79
+
80
+ query = apply_rotary_emb(query, image_rotary_emb)
81
+ key = apply_rotary_emb(key, image_rotary_emb)
82
+
83
+ if condition_latents is not None:
84
+ cond_query = attn.to_q(condition_latents)
85
+ cond_key = attn.to_k(condition_latents)
86
+ cond_value = attn.to_v(condition_latents)
87
+
88
+ cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(
89
+ 1, 2
90
+ )
91
+ cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
92
+ cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(
93
+ 1, 2
94
+ )
95
+ if attn.norm_q is not None:
96
+ cond_query = attn.norm_q(cond_query)
97
+ if attn.norm_k is not None:
98
+ cond_key = attn.norm_k(cond_key)
99
+
100
+ if cond_rotary_emb is not None:
101
+ cond_query = apply_rotary_emb(cond_query, cond_rotary_emb)
102
+ cond_key = apply_rotary_emb(cond_key, cond_rotary_emb)
103
+
104
+ if condition_latents is not None:
105
+ query = torch.cat([query, cond_query], dim=2)
106
+ key = torch.cat([key, cond_key], dim=2)
107
+ value = torch.cat([value, cond_value], dim=2)
108
+
109
+ if not model_config.get("union_cond_attn", True):
110
+ # If we don't want to use the union condition attention, we need to mask the attention
111
+ # between the hidden states and the condition latents
112
+ attention_mask = torch.ones(
113
+ query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
114
+ )
115
+ condition_n = cond_query.shape[2]
116
+ attention_mask[-condition_n:, :-condition_n] = False
117
+ attention_mask[:-condition_n, -condition_n:] = False
118
+ if hasattr(attn, "c_factor"):
119
+ attention_mask = torch.zeros(
120
+ query.shape[2], key.shape[2], device=query.device, dtype=query.dtype
121
+ )
122
+ condition_n = cond_query.shape[2]
123
+ bias = torch.log(attn.c_factor[0])
124
+ attention_mask[-condition_n:, :-condition_n] = bias
125
+ attention_mask[:-condition_n, -condition_n:] = bias
126
+
127
+ if conceptrol is None:
128
+ print(
129
+ "Conceptrol using this stuff indicates that the implementation is problematic"
130
+ )
131
+ hidden_states = F.scaled_dot_product_attention(
132
+ query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask
133
+ )
134
+ hidden_states = hidden_states.transpose(1, 2).reshape(
135
+ batch_size, -1, attn.heads * head_dim
136
+ )
137
+ hidden_states = hidden_states.to(query.dtype)
138
+ else:
139
+ conceptrolled_attention_probs = conceptrol(
140
+ query, key, attention_mask, c_factor=attn.c_factor
141
+ )
142
+ hidden_states = conceptrolled_attention_probs @ value
143
+ hidden_states = hidden_states.transpose(1, 2).reshape(
144
+ batch_size, -1, attn.heads * head_dim
145
+ )
146
+ hidden_states = hidden_states.to(query.dtype)
147
+
148
+ if encoder_hidden_states is not None:
149
+ if condition_latents is not None:
150
+ encoder_hidden_states, hidden_states, condition_latents = (
151
+ hidden_states[:, : encoder_hidden_states.shape[1]],
152
+ hidden_states[
153
+ :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]
154
+ ],
155
+ hidden_states[:, -condition_latents.shape[1] :],
156
+ )
157
+ else:
158
+ encoder_hidden_states, hidden_states = (
159
+ hidden_states[:, : encoder_hidden_states.shape[1]],
160
+ hidden_states[:, encoder_hidden_states.shape[1] :],
161
+ )
162
+
163
+ with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)):
164
+ # linear proj
165
+ hidden_states = attn.to_out[0](hidden_states)
166
+ # dropout
167
+ hidden_states = attn.to_out[1](hidden_states)
168
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
169
+
170
+ if condition_latents is not None:
171
+ condition_latents = attn.to_out[0](condition_latents)
172
+ condition_latents = attn.to_out[1](condition_latents)
173
+
174
+ return (
175
+ (hidden_states, encoder_hidden_states, condition_latents)
176
+ if condition_latents is not None
177
+ else (hidden_states, encoder_hidden_states)
178
+ )
179
+ elif condition_latents is not None:
180
+ # if there are condition_latents, we need to separate the hidden_states and the condition_latents
181
+ hidden_states, condition_latents = (
182
+ hidden_states[:, : -condition_latents.shape[1]],
183
+ hidden_states[:, -condition_latents.shape[1] :],
184
+ )
185
+ return hidden_states, condition_latents
186
+ else:
187
+ return hidden_states
188
+
189
+
190
+ def block_forward(
191
+ self,
192
+ hidden_states: torch.FloatTensor,
193
+ encoder_hidden_states: torch.FloatTensor,
194
+ condition_latents: torch.FloatTensor,
195
+ temb: torch.FloatTensor,
196
+ cond_temb: torch.FloatTensor,
197
+ cond_rotary_emb=None,
198
+ image_rotary_emb=None,
199
+ model_config: Optional[Dict[str, Any]] = {},
200
+ conceptrol: Conceptrol = None,
201
+ ):
202
+ use_cond = condition_latents is not None
203
+ with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)):
204
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
205
+ hidden_states, emb=temb
206
+ )
207
+
208
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
209
+ self.norm1_context(encoder_hidden_states, emb=temb)
210
+ )
211
+
212
+ if use_cond:
213
+ (
214
+ norm_condition_latents,
215
+ cond_gate_msa,
216
+ cond_shift_mlp,
217
+ cond_scale_mlp,
218
+ cond_gate_mlp,
219
+ ) = self.norm1(condition_latents, emb=cond_temb)
220
+
221
+ # Attention.
222
+ result = attn_forward(
223
+ self.attn,
224
+ model_config=model_config,
225
+ hidden_states=norm_hidden_states,
226
+ encoder_hidden_states=norm_encoder_hidden_states,
227
+ condition_latents=norm_condition_latents if use_cond else None,
228
+ image_rotary_emb=image_rotary_emb,
229
+ cond_rotary_emb=cond_rotary_emb if use_cond else None,
230
+ conceptrol=conceptrol if use_cond else None,
231
+ )
232
+ attn_output, context_attn_output = result[:2]
233
+ cond_attn_output = result[2] if use_cond else None
234
+
235
+ # Process attention outputs for the `hidden_states`.
236
+ # 1. hidden_states
237
+ attn_output = gate_msa.unsqueeze(1) * attn_output
238
+ hidden_states = hidden_states + attn_output
239
+ # 2. encoder_hidden_states
240
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
241
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
242
+ # 3. condition_latents
243
+ if use_cond:
244
+ cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
245
+ condition_latents = condition_latents + cond_attn_output
246
+ if model_config.get("add_cond_attn", False):
247
+ hidden_states += cond_attn_output
248
+
249
+ # LayerNorm + MLP.
250
+ # 1. hidden_states
251
+ norm_hidden_states = self.norm2(hidden_states)
252
+ norm_hidden_states = (
253
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
254
+ )
255
+ # 2. encoder_hidden_states
256
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
257
+ norm_encoder_hidden_states = (
258
+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
259
+ )
260
+ # 3. condition_latents
261
+ if use_cond:
262
+ norm_condition_latents = self.norm2(condition_latents)
263
+ norm_condition_latents = (
264
+ norm_condition_latents * (1 + cond_scale_mlp[:, None])
265
+ + cond_shift_mlp[:, None]
266
+ )
267
+
268
+ # Feed-forward.
269
+ with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)):
270
+ # 1. hidden_states
271
+ ff_output = self.ff(norm_hidden_states)
272
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
273
+ # 2. encoder_hidden_states
274
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
275
+ context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
276
+ # 3. condition_latents
277
+ if use_cond:
278
+ cond_ff_output = self.ff(norm_condition_latents)
279
+ cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
280
+
281
+ # Process feed-forward outputs.
282
+ hidden_states = hidden_states + ff_output
283
+ encoder_hidden_states = encoder_hidden_states + context_ff_output
284
+ if use_cond:
285
+ condition_latents = condition_latents + cond_ff_output
286
+
287
+ # Clip to avoid overflow.
288
+ if encoder_hidden_states.dtype == torch.float16:
289
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
290
+
291
+ return encoder_hidden_states, hidden_states, condition_latents if use_cond else None
292
+
293
+
294
+ def single_block_forward(
295
+ self,
296
+ hidden_states: torch.FloatTensor,
297
+ temb: torch.FloatTensor,
298
+ image_rotary_emb=None,
299
+ condition_latents: torch.FloatTensor = None,
300
+ cond_temb: torch.FloatTensor = None,
301
+ cond_rotary_emb=None,
302
+ model_config: Optional[Dict[str, Any]] = {},
303
+ conceptrol: Conceptrol = None,
304
+ ):
305
+
306
+ using_cond = condition_latents is not None
307
+ residual = hidden_states
308
+ with enable_lora(
309
+ (
310
+ self.norm.linear,
311
+ self.proj_mlp,
312
+ ),
313
+ model_config.get("latent_lora", False),
314
+ ):
315
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
316
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
317
+ if using_cond:
318
+ residual_cond = condition_latents
319
+ norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb)
320
+ mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents))
321
+
322
+ attn_output = attn_forward(
323
+ self.attn,
324
+ model_config=model_config,
325
+ hidden_states=norm_hidden_states,
326
+ image_rotary_emb=image_rotary_emb,
327
+ **(
328
+ {
329
+ "condition_latents": norm_condition_latents,
330
+ "cond_rotary_emb": cond_rotary_emb if using_cond else None,
331
+ "conceptrol": conceptrol if using_cond else None,
332
+ }
333
+ if using_cond
334
+ else {}
335
+ ),
336
+ )
337
+ if using_cond:
338
+ attn_output, cond_attn_output = attn_output
339
+
340
+ with enable_lora((self.proj_out,), model_config.get("latent_lora", False)):
341
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
342
+ gate = gate.unsqueeze(1)
343
+ hidden_states = gate * self.proj_out(hidden_states)
344
+ hidden_states = residual + hidden_states
345
+ if using_cond:
346
+ condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
347
+ cond_gate = cond_gate.unsqueeze(1)
348
+ condition_latents = cond_gate * self.proj_out(condition_latents)
349
+ condition_latents = residual_cond + condition_latents
350
+
351
+ if hidden_states.dtype == torch.float16:
352
+ hidden_states = hidden_states.clip(-65504, 65504)
353
+
354
+ return hidden_states if not using_cond else (hidden_states, condition_latents)
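
The c_factor handling in attn_forward adds log(c_factor) to the attention logits of the condition block, which scales that block's unnormalized attention weight by c_factor before the softmax renormalizes. A toy numeric check of that identity (values are arbitrary):

import math
import torch

logits = torch.tensor([1.0, 2.0, 0.5])   # toy logits: [text, latent, condition]
c = 2.0                                   # condition scale
bias = torch.tensor([0.0, 0.0, math.log(c)])

plain = torch.softmax(logits, dim=-1)
biased = torch.softmax(logits + bias, dim=-1)

# Relative to any unbiased entry, the condition entry's weight is multiplied by c.
assert torch.isclose(biased[2] / biased[1], torch.tensor(c) * plain[2] / plain[1])
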
omini_control/conceptrol.py ADDED
@@ -0,0 +1,208 @@
1
+ import os
2
+ import math
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import torch
6
+
7
+
8
+ class Conceptrol:
9
+ def __init__(self, config):
10
+ if "name" not in config:
11
+ raise KeyError("config must provide 'name' as either 'conceptrol' or 'ominicontrol'")
12
+
13
+ name = config["name"]
14
+ if name not in ["conceptrol", "ominicontrol"]:
15
+ raise ValueError(
16
+ f"Name must be one of ['conceptrol', 'ominicontrol'], got {name}"
17
+ )
18
+
19
+ try:
20
+ log_attn_map = config["log_attn_map"]
21
+ except KeyError:
22
+ log_attn_map = False
23
+
24
+ # static
25
+ self.NUM_BLOCKS = 19 # this is fixed for FLUX
26
+ self.M = 512 # num of text tokens, fixed for FLUX
27
+ self.N = 1024 # num of latent / image condition tokens, fixed for FLUX
28
+ self.EP = -10e6
29
+ self.CONCEPT_BLOCK_IDX = 18
30
+
31
+ # fixed during one generation
32
+ self.name = name
33
+
34
+ # variable during one generation
35
+ self.textual_concept_mask = None
36
+ self.forward_count = 0
37
+
38
+ # log out for visualization
39
+ if log_attn_map:
40
+ self.attn_maps = {"latent_to_concept": [], "latent_to_image": []}
41
+
42
+ def __call__(
43
+ self,
44
+ query: torch.FloatTensor,
45
+ key: torch.FloatTensor,
46
+ attention_mask: torch.Tensor,
47
+ c_factor: float = 1.0,
48
+ ) -> torch.Tensor:
49
+
50
+ if not hasattr(self, "textual_concept_idx"):
51
+ raise AttributeError(
52
+ "textual_concept_idx must be registered before calling Conceptrol"
53
+ )
54
+
55
+ # Skip computation for ominicontrol
56
+ if self.name == "ominicontrol":
57
+ scale_factor = 1 / math.sqrt(query.size(-1))
58
+ attention_weight = (
59
+ query @ key.transpose(-2, -1) * scale_factor + attention_mask
60
+ )
61
+ attention_probs = torch.softmax(
62
+ attention_weight, dim=-1
63
+ ) # [B, H, M+2N, M+2N]
64
+ return attention_probs
65
+
66
+ if not self.textual_concept_idx[0] < self.textual_concept_idx[1]:
67
+ raise ValueError(
68
+ f"register_idx[0] must be less than register_idx[1], "
69
+ f"got {self.textual_concept_idx[0]} >= {self.textual_concept_idx[1]}"
70
+ )
71
+
72
+ ### Reset attention mask predefined in ominicontrol
73
+ attention_mask = torch.zeros_like(attention_mask)
74
+ bias = torch.log(c_factor[0])
75
+ # attention of image condition to latent
76
+ attention_mask[-self.N :, self.M : -self.N] = bias
77
+ # attention of latent to image condition
78
+ attention_mask[self.M : -self.N, -self.N :] = bias
79
+
80
+ # attention of textual concept to image condition
81
+ attention_mask[
82
+ self.textual_concept_idx[0] : self.textual_concept_idx[1], -self.N :
83
+ ] = bias
84
+ # attention of other words to image condition (set as negative inf)
85
+ attention_mask[: self.textual_concept_idx[0], -self.N :] = self.EP
86
+ attention_mask[self.textual_concept_idx[1] : self.M, -self.N :] = self.EP
87
+
88
+ # If textual_concept_mask has not been set yet, we are still in the layers before the first concept-specific block
89
+ if self.textual_concept_mask is None:
90
+ self.textual_concept_mask = (
91
+ torch.zeros_like(attention_mask).unsqueeze(0).unsqueeze(0)
92
+ )
93
+
94
+ ### Compute attention
95
+ scale_factor = 1 / math.sqrt(query.size(-1))
96
+ attention_weight = (
97
+ query @ key.transpose(-2, -1) * scale_factor
98
+ + attention_mask
99
+ + self.textual_concept_mask
100
+ )
101
+ # [B, H, M+2N, M+2N]
102
+ attention_probs = torch.softmax(attention_weight, dim=-1)
103
+
104
+ ### Extract textual concept mask if it's concept-specific block
105
+ is_concept_block = (
106
+ self.forward_count % self.NUM_BLOCKS == self.CONCEPT_BLOCK_IDX
107
+ )
108
+ if is_concept_block:
109
+ # Shape: [B, H, N, S], where S is the token numbers of the subject
110
+ textual_concept_mask_local = attention_probs[
111
+ :,
112
+ :,
113
+ self.M : -self.N,
114
+ self.textual_concept_idx[0] : self.textual_concept_idx[1],
115
+ ]
116
+ # Consider the ratio within context of text
117
+ textual_concept_mask_local = textual_concept_mask_local / torch.sum(
118
+ attention_probs[:, :, self.M : -self.N, : self.M], dim=-1, keepdim=True
119
+ )
120
+ # Average over words and head, Shape: [B, 1, N, 1]
121
+ textual_concept_mask_local = torch.mean(
122
+ textual_concept_mask_local, dim=(-1, 1), keepdim=True
123
+ )
124
+ # Normalize to average as 1
125
+ textual_concept_mask_local = textual_concept_mask_local / torch.mean(
126
+ textual_concept_mask_local, dim=-2, keepdim=True
127
+ )
128
+
129
+ self.textual_concept_mask = (
130
+ torch.zeros_like(attention_mask).unsqueeze(0).unsqueeze(0)
131
+ )
132
+ # log(A) in the paper
133
+ self.textual_concept_mask[:, :, self.M : -self.N, -self.N :] = torch.log(
134
+ textual_concept_mask_local
135
+ )
136
+
137
+ self.forward_count += 1
138
+
139
+ return attention_probs
140
+
141
+ def register(self, textual_concept_idx):
142
+ self.textual_concept_idx = textual_concept_idx
143
+
144
+ def visualize_attn_map(self, config_name: str, subject: str):
145
+ global global_concept_mask
146
+ global forward_count
147
+
148
+ save_dir = f"attn_maps/{config_name}/{subject}"
149
+ if not os.path.exists(save_dir):
150
+ os.makedirs(save_dir)
151
+ for attn_map_name, attn_maps in self.attn_maps.items():
152
+ if "token_to_token" in attn_map_name:
153
+ continue
154
+ plt.figure()
155
+
156
+ rows, cols = 8, 19
157
+ fig, axes = plt.subplots(
158
+ rows, cols, figsize=(64 * cols / 100, 64 * rows / 100)
159
+ )
160
+ fig.subplots_adjust(
161
+ wspace=0.1, hspace=0.1
162
+ ) # Adjust spacing between subplots
163
+
164
+ # Plot each array in the list on the grid
165
+ for i, ax in enumerate(axes.flatten()):
166
+ if i < len(attn_maps): # Only plot existing arrays
167
+ attn_map = attn_maps[i] / np.amax(attn_maps[i])
168
+ ax.imshow(attn_map, cmap="viridis")
169
+ ax.axis("off") # Turn off axes for clarity
170
+ else:
171
+ ax.axis("off") # Turn off unused subplots
172
+
173
+ fig.set_size_inches(64 * cols / 100, 64 * rows / 100)
174
+ save_path = os.path.join(save_dir, f"{attn_map_name}.jpg")
175
+ plt.savefig(save_path)
176
+ plt.close()
177
+
178
+ for attn_map_name, attn_maps in self.attn_maps.items():
179
+ if "token_to_token" not in attn_map_name:
180
+ continue
181
+ plt.figure()
182
+
183
+ rows, cols = 8, 19
184
+ fig, axes = plt.subplots(
185
+ rows, cols, figsize=(2560 * cols / 100, 2560 * rows / 100)
186
+ )
187
+ fig.subplots_adjust(
188
+ wspace=0.1, hspace=0.1
189
+ ) # Adjust spacing between subplots
190
+
191
+ # Plot each array in the list on the grid
192
+ for i, ax in enumerate(axes.flatten()):
193
+ if i < len(attn_maps): # Only plot existing arrays
194
+ attn_map = attn_maps[i] / np.amax(attn_maps[i])
195
+ ax.imshow(attn_map, cmap="viridis")
196
+ ax.axis("off") # Turn off axes for clarity
197
+ else:
198
+ ax.axis("off") # Turn off unused subplots
199
+
200
+ fig.set_size_inches(64 * cols / 100, 64 * rows / 100)
201
+ save_path = os.path.join(save_dir, f"{attn_map_name}.jpg")
202
+ plt.savefig(save_path)
203
+ plt.close()
204
+
205
+ for attn_map_name in self.attn_maps.keys():
206
+ self.attn_maps[attn_map_name] = []
207
+ global_concept_mask = None
208
+ forward_count = 0
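
Conceptrol.__call__ assumes the FLUX token layout [text (M=512) | latent (N=1024) | image condition (N=1024)] and rebuilds the attention mask so that only the registered concept span of the text can attend to the condition tokens. A toy reconstruction of that block layout with tiny sizes (illustrative only; the class hard-codes M and N):

import math
import torch

M, N = 4, 3                      # toy sizes; the real values are 512 and 1024
EP = -10e6                       # large negative, effectively -inf
bias = math.log(2.0)             # log(c_factor)
concept_idx = (1, 3)             # token span registered via Conceptrol.register

mask = torch.zeros(M + 2 * N, M + 2 * N)
mask[-N:, M:-N] = bias                              # image condition -> latent
mask[M:-N, -N:] = bias                              # latent -> image condition
mask[concept_idx[0]:concept_idx[1], -N:] = bias     # concept tokens -> condition
mask[:concept_idx[0], -N:] = EP                     # all other text tokens are blocked
mask[concept_idx[1]:M, -N:] = EP
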
omini_control/condition.py ADDED
@@ -0,0 +1,124 @@
1
+ import torch
2
+ from typing import Union, Tuple
3
+ from diffusers.pipelines import FluxPipeline
4
+ from PIL import Image, ImageFilter
5
+ import numpy as np
6
+ import cv2
7
+
8
+ condition_dict = {
9
+ "depth": 0,
10
+ "canny": 1,
11
+ "subject": 4,
12
+ "coloring": 6,
13
+ "deblurring": 7,
14
+ "fill": 9,
15
+ }
16
+
17
+
18
+ class Condition(object):
19
+ def __init__(
20
+ self,
21
+ condition_type: str,
22
+ raw_img: Union[Image.Image, torch.Tensor] = None,
23
+ condition: Union[Image.Image, torch.Tensor] = None,
24
+ mask=None,
25
+ ) -> None:
26
+ self.condition_type = condition_type
27
+ assert raw_img is not None or condition is not None
28
+ if raw_img is not None:
29
+ self.condition = self.get_condition(condition_type, raw_img)
30
+ else:
31
+ self.condition = condition
32
+ # TODO: Add mask support
33
+ assert mask is None, "Mask not supported yet"
34
+
35
+ def get_condition(
36
+ self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
37
+ ) -> Union[Image.Image, torch.Tensor]:
38
+ """
39
+ Returns the condition image.
40
+ """
41
+ if condition_type == "depth":
42
+ from transformers import pipeline
43
+
44
+ depth_pipe = pipeline(
45
+ task="depth-estimation",
46
+ model="LiheYoung/depth-anything-small-hf",
47
+ device="cuda",
48
+ )
49
+ source_image = raw_img.convert("RGB")
50
+ condition_img = depth_pipe(source_image)["depth"].convert("RGB")
51
+ return condition_img
52
+ elif condition_type == "canny":
53
+ img = np.array(raw_img)
54
+ edges = cv2.Canny(img, 100, 200)
55
+ edges = Image.fromarray(edges).convert("RGB")
56
+ return edges
57
+ elif condition_type == "subject":
58
+ return raw_img
59
+ elif condition_type == "coloring":
60
+ return raw_img.convert("L").convert("RGB")
61
+ elif condition_type == "deblurring":
62
+ condition_image = (
63
+ raw_img.convert("RGB")
64
+ .filter(ImageFilter.GaussianBlur(10))
65
+ .convert("RGB")
66
+ )
67
+ return condition_image
68
+ elif condition_type == "fill":
69
+ return raw_img.convert("RGB")
70
+ return self.condition
71
+
72
+ @property
73
+ def type_id(self) -> int:
74
+ """
75
+ Returns the type id of the condition.
76
+ """
77
+ return condition_dict[self.condition_type]
78
+
79
+ @classmethod
80
+ def get_type_id(cls, condition_type: str) -> int:
81
+ """
82
+ Returns the type id of the condition.
83
+ """
84
+ return condition_dict[condition_type]
85
+
86
+ def _encode_image(self, pipe: FluxPipeline, cond_img: Image.Image) -> torch.Tensor:
87
+ """
88
+ Encodes an image condition into tokens using the pipeline.
89
+ """
90
+ cond_img = pipe.image_processor.preprocess(cond_img)
91
+ cond_img = cond_img.to(pipe.device).to(pipe.dtype)
92
+ cond_img = pipe.vae.encode(cond_img).latent_dist.sample()
93
+ cond_img = (
94
+ cond_img - pipe.vae.config.shift_factor
95
+ ) * pipe.vae.config.scaling_factor
96
+ cond_tokens = pipe._pack_latents(cond_img, *cond_img.shape)
97
+ cond_ids = pipe._prepare_latent_image_ids(
98
+ cond_img.shape[0],
99
+ cond_img.shape[2],
100
+ cond_img.shape[3],
101
+ pipe.device,
102
+ pipe.dtype,
103
+ )
104
+ return cond_tokens, cond_ids
105
+
106
+ def encode(self, pipe: FluxPipeline) -> Tuple[torch.Tensor, torch.Tensor, int]:
107
+ """
108
+ Encodes the condition into tokens, ids and type_id.
109
+ """
110
+ if self.condition_type in [
111
+ "depth",
112
+ "canny",
113
+ "subject",
114
+ "coloring",
115
+ "deblurring",
116
+ "fill",
117
+ ]:
118
+ tokens, ids = self._encode_image(pipe, self.condition)
119
+ else:
120
+ raise NotImplementedError(
121
+ f"Condition type {self.condition_type} not implemented"
122
+ )
123
+ type_id = torch.ones_like(ids[:, :1]) * self.type_id
124
+ return tokens, ids, type_id
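
A hypothetical end-to-end use of Condition with a FLUX pipeline (the checkpoint id is a placeholder; encode relies on the private FluxPipeline helpers _pack_latents and _prepare_latent_image_ids used above, so it assumes a diffusers version that still exposes them):

import torch
from PIL import Image
from diffusers.pipelines import FluxPipeline
from omini_control.condition import Condition

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",   # placeholder checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")

condition = Condition("subject", raw_img=Image.open("demo/book.jpg").resize((512, 512)))
tokens, ids, type_id = condition.encode(pipe)   # latent tokens, positional ids, type id (4 for "subject")
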
omini_control/flux_conceptrol_pipeline.py ADDED
@@ -0,0 +1,368 @@
+ import torch
+ from diffusers.pipelines import FluxPipeline
+ from typing import List, Union, Optional, Dict, Any, Callable
+ from .transformer import tranformer_forward
+ from .condition import Condition
+ from .conceptrol import Conceptrol
+
+ from diffusers.pipelines.flux.pipeline_flux import (
+     FluxPipelineOutput,
+     calculate_shift,
+     retrieve_timesteps,
+     np,
+ )
+
+ denoising_images = []
+
+
+ def prepare_params(
+     prompt: Union[str, List[str]] = None,
+     prompt_2: Optional[Union[str, List[str]]] = None,
+     height: Optional[int] = 512,
+     width: Optional[int] = 512,
+     num_inference_steps: int = 28,
+     timesteps: List[int] = None,
+     guidance_scale: float = 3.5,
+     num_images_per_prompt: Optional[int] = 1,
+     generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+     latents: Optional[torch.FloatTensor] = None,
+     prompt_embeds: Optional[torch.FloatTensor] = None,
+     pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+     output_type: Optional[str] = "pil",
+     return_dict: bool = True,
+     joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+     callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+     callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+     max_sequence_length: int = 512,
+     **kwargs: dict,
+ ):
+     return (
+         prompt,
+         prompt_2,
+         height,
+         width,
+         num_inference_steps,
+         timesteps,
+         guidance_scale,
+         num_images_per_prompt,
+         generator,
+         latents,
+         prompt_embeds,
+         pooled_prompt_embeds,
+         output_type,
+         return_dict,
+         joint_attention_kwargs,
+         callback_on_step_end,
+         callback_on_step_end_tensor_inputs,
+         max_sequence_length,
+     )
+
+
+ def seed_everything(seed: int = 42):
+     torch.backends.cudnn.deterministic = True
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+
+
+ def set_scale(pipe, condition_scale):
+     for name, module in pipe.transformer.named_modules():
+         if not name.endswith(".attn"):
+             continue
+         module.c_factor = torch.ones(1, 1) * condition_scale
+
+
+ class FluxConceptrolPipeline(FluxPipeline):
+
+     def find_subsequence(self, text, sub):
+         sub_len = len(sub)
+         for i in range(len(text) - sub_len + 1):
+             if text[i : i + sub_len] == sub:
+                 return i, i + sub_len  # Return start and end indices
+         return None
+
+     def locate_subject(self, prompt, subject, max_length=512):
+         text_inputs = self.tokenizer_2.tokenize(
+             prompt,
+             padding="max_length",
+             max_length=max_length,
+             truncation=True,
+             return_length=False,
+             return_overflowing_tokens=False,
+             return_tensors="pt",
+         )
+         subject_inputs = self.tokenizer_2.tokenize(
+             subject,
+             truncation=True,
+             return_length=False,
+             return_overflowing_tokens=False,
+             return_tensors="pt",
+         )
+         print("Text Inputs:", text_inputs)
+         print("Subject Inputs:", subject_inputs)
+         print(self.find_subsequence(text_inputs, subject_inputs))
+         return self.find_subsequence(text_inputs, subject_inputs)
+
+         text_input_ids = text_inputs
+         return (
+             text_input_ids.index(subject_inputs[0]),
+             text_input_ids.index(subject_inputs[-1]) + 1,
+         )
+
+     def load_conceptrol(self, conceptrol):
+         self.conceptrol = conceptrol
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         image=None,
+         model_config: Optional[Dict[str, Any]] = {},
+         condition_scale: float = 1.0,
+         subject: Optional[str] = None,
+         control_guidance_start: float = 0.0,
+         control_guidance_end: float = 1.0,
+         conceptrol: Conceptrol = None,
+         seed: int = 42,
+         **params: dict,
+     ):
+         seed_everything(seed)
+
+         if conceptrol is None:
+             if not hasattr(self, "conceptrol"):
+                 raise ValueError("Default conceptrol not loaded. Please call load_conceptrol() first.")
+             conceptrol = self.conceptrol
+
+         conditions = [Condition("subject", image.convert("RGB").resize((512, 512)))]
+         if condition_scale != 1:
+             for name, module in self.transformer.named_modules():
+                 if not name.endswith(".attn"):
+                     continue
+                 module.c_factor = torch.ones(1, 1) * condition_scale
+
+         (
+             prompt,
+             prompt_2,
+             height,
+             width,
+             num_inference_steps,
+             timesteps,
+             guidance_scale,
+             num_images_per_prompt,
+             generator,
+             latents,
+             prompt_embeds,
+             pooled_prompt_embeds,
+             output_type,
+             return_dict,
+             joint_attention_kwargs,
+             callback_on_step_end,
+             callback_on_step_end_tensor_inputs,
+             max_sequence_length,
+         ) = prepare_params(**params)
+
+         if subject is not None:
+             textual_concept_idx = self.locate_subject(params["prompt"], subject)
+         else:
+             raise ValueError("Subject has to be provided")
+
+         if textual_concept_idx is None:
+             raise ValueError("Textual concept idx has to be provided")
+
+         conceptrol.register(textual_concept_idx)
+
+         height = height or self.default_sample_size * self.vae_scale_factor
+         width = width or self.default_sample_size * self.vae_scale_factor
+
+         # 1. Check inputs. Raise error if not correct
+         self.check_inputs(
+             prompt,
+             prompt_2,
+             height,
+             width,
+             prompt_embeds=prompt_embeds,
+             pooled_prompt_embeds=pooled_prompt_embeds,
+             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+             max_sequence_length=max_sequence_length,
+         )
+
+         self._guidance_scale = guidance_scale
+         self._joint_attention_kwargs = joint_attention_kwargs
+         self._interrupt = False
+
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         device = self._execution_device
+
+         lora_scale = (
+             self.joint_attention_kwargs.get("scale", None)
+             if self.joint_attention_kwargs is not None
+             else None
+         )
+         (
+             prompt_embeds,
+             pooled_prompt_embeds,
+             text_ids,
+         ) = self.encode_prompt(
+             prompt=prompt,
+             prompt_2=prompt_2,
+             prompt_embeds=prompt_embeds,
+             pooled_prompt_embeds=pooled_prompt_embeds,
+             device=device,
+             num_images_per_prompt=num_images_per_prompt,
+             max_sequence_length=max_sequence_length,
+             lora_scale=lora_scale,
+         )
+
+         # 4. Prepare latent variables
+         num_channels_latents = self.transformer.config.in_channels // 4
+         latents, latent_image_ids = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+         )
+
+         # 4.1. Prepare conditions
+         condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3))
+         use_condition = conditions is not None or []
+         if use_condition:
+             assert len(conditions) <= 1, "Only one condition is supported for now."
+             self.set_adapters(conditions[0].condition_type)
+             for condition in conditions:
+                 tokens, ids, type_id = condition.encode(self)
+                 condition_latents.append(tokens)  # [batch_size, token_n, token_dim]
+                 condition_ids.append(ids)  # [token_n, id_dim(3)]
+                 condition_type_ids.append(type_id)  # [token_n, 1]
+             condition_latents = torch.cat(condition_latents, dim=1)
+             condition_ids = torch.cat(condition_ids, dim=0)
+             if condition.condition_type == "subject":
+                 condition_ids[:, 2] += width // 16
+             condition_type_ids = torch.cat(condition_type_ids, dim=0)
+
+         # 5. Prepare timesteps
+         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+         image_seq_len = latents.shape[1]
+         mu = calculate_shift(
+             image_seq_len,
+             self.scheduler.config.base_image_seq_len,
+             self.scheduler.config.max_image_seq_len,
+             self.scheduler.config.base_shift,
+             self.scheduler.config.max_shift,
+         )
+         timesteps, num_inference_steps = retrieve_timesteps(
+             self.scheduler,
+             num_inference_steps,
+             device,
+             timesteps,
+             sigmas,
+             mu=mu,
+         )
+         num_warmup_steps = max(
+             len(timesteps) - num_inference_steps * self.scheduler.order, 0
+         )
+         self._num_timesteps = len(timesteps)
+
+         # 6. Denoising loop
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 if (i / len(timesteps) < control_guidance_start) or (
+                     (i + 1) / len(timesteps) > control_guidance_end
+                 ):
+                     set_scale(self, 0.5)  # Warmup required for the first few steps
+                 else:
+                     set_scale(self, condition_scale)
+                 if self.interrupt:
+                     continue
+
+                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+                 # handle guidance
+                 if self.transformer.config.guidance_embeds:
+                     guidance = torch.tensor([guidance_scale], device=device)
+                     guidance = guidance.expand(latents.shape[0])
+                 else:
+                     guidance = None
+                 noise_pred = tranformer_forward(
+                     self.transformer,
+                     model_config=model_config,
+                     conceptrol=conceptrol,
+                     # Inputs of the condition (new feature)
+                     condition_latents=condition_latents if use_condition else None,
+                     condition_ids=condition_ids if use_condition else None,
+                     condition_type_ids=condition_type_ids if use_condition else None,
+                     # Inputs to the original transformer
+                     hidden_states=latents,
+                     timestep=timestep / 1000,
+                     guidance=guidance,
+                     pooled_projections=pooled_prompt_embeds,
+                     encoder_hidden_states=prompt_embeds,
+                     txt_ids=text_ids,
+                     img_ids=latent_image_ids,
+                     joint_attention_kwargs=self.joint_attention_kwargs,
+                     return_dict=False,
+                 )[0]
+
+                 # compute the previous noisy sample x_t -> x_t-1
+                 latents_dtype = latents.dtype
+                 latents = self.scheduler.step(
+                     noise_pred, t, latents, return_dict=False
+                 )[0]
+
+                 if latents.dtype != latents_dtype:
+                     if torch.backends.mps.is_available():
+                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                         latents = latents.to(latents_dtype)
+
+                 if callback_on_step_end is not None:
+                     callback_kwargs = {}
+                     for k in callback_on_step_end_tensor_inputs:
+                         callback_kwargs[k] = locals()[k]
+                     callback_outputs = callback_on_step_end(
+                         self, latents, callback_kwargs
+                     )
+
+                     global denoising_images
+                     denoising_images.append(callback_outputs)
+
+                 # call the callback, if provided
+                 if i == len(timesteps) - 1 or (
+                     (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+                 ):
+                     progress_bar.update()
+
+         if output_type == "latent":
+             image = latents
+
+         else:
+             latents = self._unpack_latents(
+                 latents, height, width, self.vae_scale_factor
+             )
+             latents = (
+                 latents / self.vae.config.scaling_factor
+             ) + self.vae.config.shift_factor
+             image = self.vae.decode(latents, return_dict=False)[0]
+             image = self.image_processor.postprocess(image, output_type=output_type)
+
+         # Offload all models
+         self.maybe_free_model_hooks()
+
+         if condition_scale != 1:
+             for name, module in self.transformer.named_modules():
+                 if not name.endswith(".attn"):
+                     continue
+                 del module.c_factor
+
+         if not return_dict:
+             return (image,)
+
+         return FluxPipelineOutput(images=image)
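
A hedged end-to-end sketch of how this pipeline might be driven. The checkpoint and LoRA names below are illustrative assumptions (not taken from this commit), the `set_adapters("subject")` call above expects a LoRA adapter registered under the condition-type name, and the Conceptrol constructor lives in omini_control/conceptrol.py (not shown here), so its arguments are elided.

    import torch
    from PIL import Image
    from omini_control.conceptrol import Conceptrol
    from omini_control.flux_conceptrol_pipeline import FluxConceptrolPipeline

    # Illustrative model / LoRA identifiers only
    pipe = FluxConceptrolPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")
    pipe.load_lora_weights(
        "Yuanshi/OminiControl",
        weight_name="omini/subject_512.safetensors",
        adapter_name="subject",  # must match the condition type used above
    )

    conceptrol = Conceptrol(...)  # constructor arguments omitted; see conceptrol.py
    reference_image = Image.open("reference.jpg")  # hypothetical concept image

    image = pipe(
        prompt="A statue is reading a book in a cafe",
        subject="statue",                 # substring of the prompt to bind the concept to
        image=reference_image,
        condition_scale=1.0,
        num_inference_steps=28,
        height=512,
        width=512,
        conceptrol=conceptrol,
    ).images[0]
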
omini_control/lora_controller.py ADDED
@@ -0,0 +1,75 @@
+ from peft.tuners.tuners_utils import BaseTunerLayer
+ from typing import List, Any, Optional, Type
+
+
+ class enable_lora:
+     def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
+         self.activated: bool = activated
+         if activated:
+             return
+         self.lora_modules: List[BaseTunerLayer] = [
+             each for each in lora_modules if isinstance(each, BaseTunerLayer)
+         ]
+         self.scales = [
+             {
+                 active_adapter: lora_module.scaling[active_adapter]
+                 for active_adapter in lora_module.active_adapters
+             }
+             for lora_module in self.lora_modules
+         ]
+
+     def __enter__(self) -> None:
+         if self.activated:
+             return
+
+         for lora_module in self.lora_modules:
+             if not isinstance(lora_module, BaseTunerLayer):
+                 continue
+             lora_module.scale_layer(0)
+
+     def __exit__(
+         self,
+         exc_type: Optional[Type[BaseException]],
+         exc_val: Optional[BaseException],
+         exc_tb: Optional[Any],
+     ) -> None:
+         if self.activated:
+             return
+         for i, lora_module in enumerate(self.lora_modules):
+             if not isinstance(lora_module, BaseTunerLayer):
+                 continue
+             for active_adapter in lora_module.active_adapters:
+                 lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
+
+
+ class set_lora_scale:
+     def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
+         self.lora_modules: List[BaseTunerLayer] = [
+             each for each in lora_modules if isinstance(each, BaseTunerLayer)
+         ]
+         self.scales = [
+             {
+                 active_adapter: lora_module.scaling[active_adapter]
+                 for active_adapter in lora_module.active_adapters
+             }
+             for lora_module in self.lora_modules
+         ]
+         self.scale = scale
+
+     def __enter__(self) -> None:
+         for lora_module in self.lora_modules:
+             if not isinstance(lora_module, BaseTunerLayer):
+                 continue
+             lora_module.scale_layer(self.scale)
+
+     def __exit__(
+         self,
+         exc_type: Optional[Type[BaseException]],
+         exc_val: Optional[BaseException],
+         exc_tb: Optional[Any],
+     ) -> None:
+         for i, lora_module in enumerate(self.lora_modules):
+             if not isinstance(lora_module, BaseTunerLayer):
+                 continue
+             for active_adapter in lora_module.active_adapters:
+                 lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
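
A small sketch of what these context managers are for, assuming `transformer` is a PEFT-wrapped FLUX transformer and `latents` / `condition_latents` are hypothetical token tensors: with `activated=False`, `enable_lora` temporarily zeroes the LoRA scaling of the wrapped layers (so only base weights run), while `set_lora_scale` temporarily overrides the scaling; both write the saved per-adapter values back on exit.

    from omini_control.lora_controller import enable_lora, set_lora_scale

    with enable_lora((transformer.x_embedder,), False):
        # LoRA contribution of x_embedder is scaled to 0 inside this block
        out = transformer.x_embedder(latents)

    with set_lora_scale((transformer.x_embedder,), scale=0.5):
        # every active adapter on x_embedder runs at half strength here
        out = transformer.x_embedder(condition_latents)

This mirrors how transformer.py uses `enable_lora` around `x_embedder`: the condition tokens keep the subject LoRA, while the ordinary image tokens can optionally bypass it via `model_config["latent_lora"]`.
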
omini_control/transformer.py ADDED
@@ -0,0 +1,273 @@
+ import torch
+ from typing import Optional, Dict, Any
+ from .block import block_forward, single_block_forward
+ from .lora_controller import enable_lora
+ from .conceptrol import Conceptrol
+ from diffusers.models.transformers.transformer_flux import (
+     FluxTransformer2DModel,
+     Transformer2DModelOutput,
+     USE_PEFT_BACKEND,
+     is_torch_version,
+     scale_lora_layers,
+     unscale_lora_layers,
+     logger,
+ )
+ import numpy as np
+
+
+ def prepare_params(
+     hidden_states: torch.Tensor,
+     encoder_hidden_states: torch.Tensor = None,
+     pooled_projections: torch.Tensor = None,
+     timestep: torch.LongTensor = None,
+     img_ids: torch.Tensor = None,
+     txt_ids: torch.Tensor = None,
+     guidance: torch.Tensor = None,
+     joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+     controlnet_block_samples=None,
+     controlnet_single_block_samples=None,
+     return_dict: bool = True,
+     **kwargs: dict,
+ ):
+     return (
+         hidden_states,
+         encoder_hidden_states,
+         pooled_projections,
+         timestep,
+         img_ids,
+         txt_ids,
+         guidance,
+         joint_attention_kwargs,
+         controlnet_block_samples,
+         controlnet_single_block_samples,
+         return_dict,
+     )
+
+
+ def tranformer_forward(
+     transformer: FluxTransformer2DModel,
+     condition_latents: torch.Tensor,
+     condition_ids: torch.Tensor,
+     condition_type_ids: torch.Tensor,
+     model_config: Optional[Dict[str, Any]] = {},
+     return_conditional_latents: bool = False,
+     c_t=0,
+     conceptrol: Conceptrol = None,
+     **params: dict,
+ ):
+     self = transformer
+     use_condition = condition_latents is not None
+     use_condition_in_single_blocks = model_config.get(
+         "use_condition_in_single_blocks", True
+     )
+     # if return_conditional_latents is True, use_condition and use_condition_in_single_blocks must be True
+     assert not return_conditional_latents or (
+         use_condition and use_condition_in_single_blocks
+     ), "`return_conditional_latents` is True, `use_condition` and `use_condition_in_single_blocks` must be True"
+
+     (
+         hidden_states,
+         encoder_hidden_states,
+         pooled_projections,
+         timestep,
+         img_ids,
+         txt_ids,
+         guidance,
+         joint_attention_kwargs,
+         controlnet_block_samples,
+         controlnet_single_block_samples,
+         return_dict,
+     ) = prepare_params(**params)
+
+     if joint_attention_kwargs is not None:
+         joint_attention_kwargs = joint_attention_kwargs.copy()
+         lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+     else:
+         lora_scale = 1.0
+
+     if USE_PEFT_BACKEND:
+         # weight the lora layers by setting `lora_scale` for each PEFT layer
+         scale_lora_layers(self, lora_scale)
+     else:
+         if (
+             joint_attention_kwargs is not None
+             and joint_attention_kwargs.get("scale", None) is not None
+         ):
+             logger.warning(
+                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+             )
+     with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
+         hidden_states = self.x_embedder(hidden_states)
+         condition_latents = self.x_embedder(condition_latents) if use_condition else None
+
+     timestep = timestep.to(hidden_states.dtype) * 1000
+     if guidance is not None:
+         guidance = guidance.to(hidden_states.dtype) * 1000
+     else:
+         guidance = None
+     temb = (
+         self.time_text_embed(timestep, pooled_projections)
+         if guidance is None
+         else self.time_text_embed(timestep, guidance, pooled_projections)
+     )
+     cond_temb = (
+         self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
+         if guidance is None
+         else self.time_text_embed(
+             torch.ones_like(timestep) * c_t * 1000, guidance, pooled_projections
+         )
+     )
+     if hasattr(self, "cond_type_embed") and condition_type_ids is not None:
+         cond_type_proj = self.time_text_embed.time_proj(condition_type_ids[0])
+         cond_type_emb = self.cond_type_embed(cond_type_proj.to(dtype=cond_temb.dtype))
+         cond_temb = cond_temb + cond_type_emb
+     encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+     if txt_ids.ndim == 3:
+         logger.warning(
+             "Passing `txt_ids` 3d torch.Tensor is deprecated."
+             "Please remove the batch dimension and pass it as a 2d torch Tensor"
+         )
+         txt_ids = txt_ids[0]
+     if img_ids.ndim == 3:
+         logger.warning(
+             "Passing `img_ids` 3d torch.Tensor is deprecated."
+             "Please remove the batch dimension and pass it as a 2d torch Tensor"
+         )
+         img_ids = img_ids[0]
+
+     ids = torch.cat((txt_ids, img_ids), dim=0)
+     image_rotary_emb = self.pos_embed(ids)
+     if use_condition:
+         cond_ids = condition_ids
+         cond_rotary_emb = self.pos_embed(cond_ids)
+
+     # hidden_states = torch.cat([hidden_states, condition_latents], dim=1)
+
+     for index_block, block in enumerate(self.transformer_blocks):
+         if self.training and self.gradient_checkpointing:
+
+             def create_custom_forward(module, return_dict=None):
+                 def custom_forward(*inputs):
+                     if return_dict is not None:
+                         return module(*inputs, return_dict=return_dict)
+                     else:
+                         return module(*inputs)
+
+                 return custom_forward
+
+             ckpt_kwargs: Dict[str, Any] = (
+                 {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+             )
+             encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
+                 create_custom_forward(block),
+                 hidden_states,
+                 encoder_hidden_states,
+                 temb,
+                 image_rotary_emb,
+                 **ckpt_kwargs,
+             )
+
+         else:
+             encoder_hidden_states, hidden_states, condition_latents = block_forward(
+                 block,
+                 model_config=model_config,
+                 hidden_states=hidden_states,
+                 encoder_hidden_states=encoder_hidden_states,
+                 condition_latents=condition_latents if use_condition else None,
+                 temb=temb,
+                 cond_temb=cond_temb if use_condition else None,
+                 cond_rotary_emb=cond_rotary_emb if use_condition else None,
+                 image_rotary_emb=image_rotary_emb,
+                 conceptrol=conceptrol,
+             )
+
+         # controlnet residual
+         if controlnet_block_samples is not None:
+             interval_control = len(self.transformer_blocks) / len(
+                 controlnet_block_samples
+             )
+             interval_control = int(np.ceil(interval_control))
+             hidden_states = (
+                 hidden_states
+                 + controlnet_block_samples[index_block // interval_control]
+             )
+     hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+     for index_block, block in enumerate(self.single_transformer_blocks):
+         if self.training and self.gradient_checkpointing:
+
+             def create_custom_forward(module, return_dict=None):
+                 def custom_forward(*inputs):
+                     if return_dict is not None:
+                         return module(*inputs, return_dict=return_dict)
+                     else:
+                         return module(*inputs)
+
+                 return custom_forward
+
+             ckpt_kwargs: Dict[str, Any] = (
+                 {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+             )
+             hidden_states = torch.utils.checkpoint.checkpoint(
+                 create_custom_forward(block),
+                 hidden_states,
+                 temb,
+                 image_rotary_emb,
+                 **ckpt_kwargs,
+             )
+
+         else:
+             result = single_block_forward(
+                 block,
+                 model_config=model_config,
+                 hidden_states=hidden_states,
+                 temb=temb,
+                 image_rotary_emb=image_rotary_emb,
+                 **(
+                     {
+                         "condition_latents": condition_latents,
+                         "cond_temb": cond_temb,
+                         "cond_rotary_emb": cond_rotary_emb,
+                         "conceptrol": conceptrol,
+                     }
+                     if use_condition_in_single_blocks and use_condition
+                     else {}
+                 ),
+             )
+             if use_condition_in_single_blocks and use_condition:
+                 hidden_states, condition_latents = result
+             else:
+                 hidden_states = result
+
+         # controlnet residual
+         if controlnet_single_block_samples is not None:
+             interval_control = len(self.single_transformer_blocks) / len(
+                 controlnet_single_block_samples
+             )
+             interval_control = int(np.ceil(interval_control))
+             hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+                 hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+                 + controlnet_single_block_samples[index_block // interval_control]
+             )
+
+     hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+
+     hidden_states = self.norm_out(hidden_states, temb)
+     output = self.proj_out(hidden_states)
+     if return_conditional_latents:
+         condition_latents = (
+             self.norm_out(condition_latents, cond_temb) if use_condition else None
+         )
+         condition_output = self.proj_out(condition_latents) if use_condition else None
+
+     if USE_PEFT_BACKEND:
+         # remove `lora_scale` from each PEFT layer
+         unscale_lora_layers(self, lora_scale)
+
+     if not return_dict:
+         return (
+             (output,) if not return_conditional_latents else (output, condition_output)
+         )
+
+     return Transformer2DModelOutput(sample=output)
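
One detail worth spelling out is the ControlNet-residual indexing used in both loops above: when fewer residual tensors than transformer blocks are supplied, each residual is reused for a contiguous run of blocks via `index_block // interval_control`. A tiny self-contained illustration of that mapping (pure Python, block and sample counts chosen only as an example):

    import numpy as np

    num_blocks = 19                            # e.g. len(self.transformer_blocks)
    controlnet_block_samples = list(range(4))  # stand-ins for 4 residual tensors

    interval_control = int(np.ceil(num_blocks / len(controlnet_block_samples)))
    mapping = [index_block // interval_control for index_block in range(num_blocks)]
    print(interval_control)  # 5
    print(mapping)           # [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3]
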
requirements.txt CHANGED
@@ -1,6 +1,9 @@
- accelerate
- diffusers
- invisible_watermark
- torch
  transformers
- xformers
+ diffusers
+ peft
+ opencv-python
+ protobuf
+ sentencepiece
+ gradio
+ jupyter
+ torchao
style.css ADDED
@@ -0,0 +1,95 @@
+ h1 {
+     text-align: center;
+     justify-content: center;
+ }
+
+ [role="tabpanel"] {
+     border: 0
+ }
+
+ #duplicate-button {
+     margin: auto;
+     color: #fff;
+     background: #1565c0;
+     border-radius: 100vh;
+ }
+
+ .gradio-container {
+     max-width: 690px ! important;
+ }
+
+ .equal-height {
+     display: flex;
+     flex: 1;
+ }
+
+ .grid-container {
+     display: grid;
+     grid-template-columns: 1fr 1fr; /* two columns of equal width */
+     gap: 20px;
+     height: 100%; /* make sure the container fills the full height */
+ }
+
+ .grid-item {
+     display: flex;
+     flex-direction: column;
+     height: 100%;
+ }
+
+ .flex-grow {
+     flex-grow: 1; /* let this element take up the remaining height */
+     display: flex;
+     flex-direction: column;
+ }
+
+ #share-btn-container {
+     padding-left: 0.5rem !important;
+     padding-right: 0.5rem !important;
+     background-color: #000000;
+     justify-content: center;
+     align-items: center;
+     border-radius: 9999px !important;
+     max-width: 13rem;
+     margin-left: auto;
+     margin-top: 0.35em;
+ }
+
+ div#share-btn-container>div {
+     flex-direction: row;
+     background: black;
+     align-items: center
+ }
+
+ #share-btn-container:hover {
+     background-color: #060606
+ }
+
+ #share-btn {
+     all: initial;
+     color: #ffffff;
+     font-weight: 600;
+     cursor: pointer;
+     font-family: 'IBM Plex Sans', sans-serif;
+     margin-left: 0.5rem !important;
+     padding-top: 0.5rem !important;
+     padding-bottom: 0.5rem !important;
+     right: 0;
+     font-size: 15px;
+ }
+
+ #share-btn * {
+     all: unset
+ }
+
+ #share-btn-container div:nth-child(-n+2) {
+     width: auto !important;
+     min-height: 0px !important;
+ }
+
+ #share-btn-container .wrap {
+     display: none !important
+ }
+
+ #share-btn-container.hidden {
+     display: none !important
+ }
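
For context, a stylesheet like this is typically handed to the Gradio app at construction time; a minimal sketch, assuming the file sits next to app.py as it does in this commit (the component inside the block is only a placeholder):

    import gradio as gr

    with gr.Blocks(css=open("style.css").read()) as demo:
        gr.Markdown("# Demo")  # the centered h1 rule above applies to this heading
        ...

    demo.launch()
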
utils.py ADDED
@@ -0,0 +1,212 @@
+ import os
+ from typing import Optional
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ from lpips import LPIPS
+ from PIL import Image
+ from torchvision.transforms import Normalize
+
+
+ def show_images_horizontally(
+     list_of_files: np.array, output_file: Optional[str] = None, interact: bool = False
+ ) -> None:
+     """
+     Visualize the list of images horizontally and save the figure as PNG.
+
+     Args:
+         list_of_files: The list of images as numpy array with shape (N, H, W, C).
+         output_file: The output file path to save the figure as PNG.
+         interact: Whether to show the figure interactively (e.g. in a notebook) instead of saving it.
+     """
+     number_of_files = len(list_of_files)
+
+     heights = [a.shape[0] for a in list_of_files]
+     widths = [a.shape[1] for a in list_of_files]
+
+     fig_width = 8.0  # inches
+     fig_height = fig_width * sum(heights) / sum(widths)
+
+     # Create a figure with subplots
+     _, axs = plt.subplots(
+         1, number_of_files, figsize=(fig_width * number_of_files, fig_height)
+     )
+     plt.tight_layout()
+     for i in range(number_of_files):
+         _image = list_of_files[i]
+         axs[i].imshow(_image)
+         axs[i].axis("off")
+
+     # Save the figure as PNG
+     if interact:
+         plt.show()
+     else:
+         plt.savefig(output_file, bbox_inches="tight", pad_inches=0.25)
+
+
+ def image_grids(images, rows=None, cols=None):
+     if not images:
+         raise ValueError("The image list is empty.")
+
+     n_images = len(images)
+     if cols is None:
+         cols = int(n_images**0.5)
+     if rows is None:
+         rows = (n_images + cols - 1) // cols
+
+     width, height = images[0].size
+     grid_width = cols * width
+     grid_height = rows * height
+
+     grid_image = Image.new("RGB", (grid_width, grid_height))
+
+     for i, image in enumerate(images):
+         row, col = divmod(i, cols)
+         grid_image.paste(image, (col * width, row * height))
+
+     return grid_image
+
+
+ def save_image(image: np.array, file_name: str) -> None:
+     """
+     Save the image as JPG.
+
+     Args:
+         image: The input image as numpy array with shape (H, W, C).
+         file_name: The file name to save the image.
+     """
+     image = Image.fromarray(image)
+     image.save(file_name)
+
+
+ def load_and_process_images(load_dir: str) -> np.array:
+     """
+     Load and process the images into numpy array from the directory.
+
+     Args:
+         load_dir: The directory to load the images.
+
+     Returns:
+         images: The images as numpy array with shape (N, H, W, C).
+     """
+     images = []
+     print(load_dir)
+     filenames = sorted(
+         os.listdir(load_dir), key=lambda x: int(x.split(".")[0])
+     )  # Ensure the files are sorted numerically
+     for filename in filenames:
+         if filename.endswith(".jpg"):
+             img = Image.open(os.path.join(load_dir, filename))
+             img_array = (
+                 np.asarray(img) / 255.0
+             )  # Convert to numpy array and scale pixel values to [0, 1]
+             images.append(img_array)
+     return images
+
+
+ def compute_lpips(images: np.array, lpips_model: LPIPS) -> np.array:
+     """
+     Compute the LPIPS of the input images.
+
+     Args:
+         images: The input images as numpy array with shape (N, H, W, C).
+         lpips_model: The LPIPS model used to compute perceptual distances.
+
+     Returns:
+         distances: The LPIPS of the input images.
+     """
+     # Get device of lpips_model
+     device = next(lpips_model.parameters()).device
+     device = str(device)
+
+     # Change the input images into tensor
+     images = torch.tensor(images).to(device).float()
+     images = torch.permute(images, (0, 3, 1, 2))
+     normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     images = normalize(images)
+
+     # Compute the LPIPS between each adjacent input images
+     distances = []
+     for i in range(images.shape[0]):
+         if i == images.shape[0] - 1:
+             break
+         img1 = images[i].unsqueeze(0)
+         img2 = images[i + 1].unsqueeze(0)
+         loss = lpips_model(img1, img2)
+         distances.append(loss.item())
+     distances = np.array(distances)
+     return distances
+
+
+ def compute_gini(distances: np.array) -> float:
+     """
+     Compute the Gini index of the input distances.
+
+     Args:
+         distances: The input distances as numpy array.
+
+     Returns:
+         gini: The Gini index of the input distances.
+     """
+     if len(distances) < 2:
+         return 0.0  # Gini index is 0 for less than two elements
+
+     # Sort the list of distances
+     sorted_distances = sorted(distances)
+     n = len(sorted_distances)
+     mean_distance = sum(sorted_distances) / n
+
+     # Compute the sum of absolute differences
+     sum_of_differences = 0
+     for di in sorted_distances:
+         for dj in sorted_distances:
+             sum_of_differences += abs(di - dj)
+
+     # Normalize the sum of differences by the mean and the number of elements
+     gini = sum_of_differences / (2 * n * n * mean_distance)
+     return gini
+
+
+ def compute_smoothness_and_consistency(images: np.array, lpips_model: LPIPS) -> tuple:
+     """
+     Compute the smoothness and consistency of the input images.
+
+     Args:
+         images: The input images as numpy array with shape (N, H, W, C).
+         lpips_model: The LPIPS model used to compute perceptual distances.
+
+     Returns:
+         smoothness: One minus gini index of LPIPS of consecutive images.
+         consistency: The mean LPIPS of consecutive images.
+         max_inception_distance: The maximum LPIPS of consecutive images.
+     """
+     distances = compute_lpips(images, lpips_model)
+     smoothness = 1 - compute_gini(distances)
+     consistency = np.mean(distances)
+     max_inception_distance = np.max(distances)
+     return smoothness, consistency, max_inception_distance
+
+
+ def separate_source_and_interpolated_images(images: np.array) -> tuple:
+     """
+     Separate the input images into source and interpolated images.
+     The input source is the start and end of the images, while the interpolated images are the rest.
+
+     Args:
+         images: The input images as numpy array with shape (N, H, W, C).
+
+     Returns:
+         source: The source images as numpy array with shape (2, H, W, C).
+         interpolation: The interpolated images as numpy array with shape (N-2, H, W, C).
+     """
+     # Check if the array has at least two elements
+     if len(images) < 2:
+         raise ValueError("The input array should have at least two elements.")
+
+     # Separate the array into two parts
+     # First part takes the first and last element
+     source = np.array([images[0], images[-1]])
+     # Second part takes the rest of the elements
+     interpolation = images[1:-1]
+     return source, interpolation
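
A short sketch of how these helpers compose, assuming a directory of numbered .jpg frames (the "outputs/run_0" path is hypothetical), images of equal size, and the lpips package installed; `net="vgg"` is just one valid backbone choice.

    import numpy as np
    from lpips import LPIPS

    from utils import (
        load_and_process_images,
        compute_smoothness_and_consistency,
        separate_source_and_interpolated_images,
    )

    lpips_model = LPIPS(net="vgg")  # perceptual metric consumed by compute_lpips
    images = np.stack(load_and_process_images("outputs/run_0"))  # (N, H, W, C) in [0, 1]

    smoothness, consistency, max_dist = compute_smoothness_and_consistency(
        images, lpips_model
    )
    source, interpolation = separate_source_and_interpolated_images(images)
    print(f"smoothness={smoothness:.3f} consistency={consistency:.3f} max={max_dist:.3f}")
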