ziheng1234 committed
Commit 3e8fe6c · verified · 1 Parent(s): e5cb7a0

Upload 39 files

.gitattributes CHANGED
@@ -33,3 +33,18 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ test_imgs/2.png filter=lfs diff=lfs merge=lfs -text
37
+ test_imgs/3.png filter=lfs diff=lfs merge=lfs -text
38
+ test_imgs/generated_1_bbox.png filter=lfs diff=lfs merge=lfs -text
39
+ test_imgs/generated_1.png filter=lfs diff=lfs merge=lfs -text
40
+ test_imgs/generated_2_bbox.png filter=lfs diff=lfs merge=lfs -text
41
+ test_imgs/generated_2.png filter=lfs diff=lfs merge=lfs -text
42
+ test_imgs/generated_3_bbox_1.png filter=lfs diff=lfs merge=lfs -text
43
+ test_imgs/generated_3_bbox.png filter=lfs diff=lfs merge=lfs -text
44
+ test_imgs/generated_3.png filter=lfs diff=lfs merge=lfs -text
45
+ test_imgs/product_1_bbox.png filter=lfs diff=lfs merge=lfs -text
46
+ test_imgs/product_2_bbox.png filter=lfs diff=lfs merge=lfs -text
47
+ test_imgs/product_2.png filter=lfs diff=lfs merge=lfs -text
48
+ test_imgs/product_3_bbox_1.png filter=lfs diff=lfs merge=lfs -text
49
+ test_imgs/product_3_bbox.png filter=lfs diff=lfs merge=lfs -text
50
+ test_imgs/product_3.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,720 @@
1
+ # import os
2
+ # os.system("pip uninstall -y gradio")
3
+ # os.system("pip install gradio==5.49.1")
4
+ # os.system("pip uninstall -y gradio_image_annotation")
5
+ # os.system("pip install gradio_image_annotation==0.4.1")
6
+ # os.system("pip uninstall -y huggingface-hub")
7
+ # os.system("pip install huggingface-hub==0.35.3")
8
+
9
+
10
+ import torch
11
+ from PIL import Image
12
+ import gradio as gr
13
+ from gradio_image_annotation import image_annotator
14
+ import numpy as np
15
+ import random
16
+
17
+ from diffusers import FluxTransformer2DModel, FluxKontextPipeline
18
+ from safetensors.torch import load_file
19
+ from huggingface_hub import hf_hub_download
20
+ from src.lora_helper import set_single_lora
21
+ from src.detail_encoder import DetailEncoder
22
+ from src.kontext_custom_pipeline import FluxKontextPipelineWithPhotoEncoderAddTokens
23
+ import spaces
24
+ from uno.flux.pipeline import UNOPipeline
25
+
26
+ hf_hub_download(
27
+ repo_id="ziheng1234/ImageCritic",
28
+ filename="detail_encoder.safetensors",
29
+ local_dir="models" # 下载到本地 models/ 目录
30
+ )
31
+ hf_hub_download(
32
+ repo_id="ziheng1234/ImageCritic",
33
+ filename="lora.safetensors",
34
+ local_dir="models"
35
+ )
36
+
37
+ from huggingface_hub import snapshot_download
38
+ repo_id = "ziheng1234/kontext"
39
+ local_dir = "./kontext"
40
+ snapshot_download(
41
+ repo_id=repo_id,
42
+ local_dir=local_dir,
43
+ repo_type="model",
44
+ resume_download=True,
45
+ max_workers=8
46
+ )
47
+ base_path = "./models"
48
+ detail_encoder_path = f"{base_path}/detail_encoder.safetensors"
49
+ kontext_lora_path = f"{base_path}/lora.safetensors"
50
+
51
+
52
+ def pick_kontext_resolution(w: int, h: int) -> tuple[int, int]:
53
+ PREFERRED_KONTEXT_RESOLUTIONS = [
54
+ (672, 1568), (688, 1504), (720, 1456), (752, 1392),
55
+ (800, 1328), (832, 1248), (880, 1184), (944, 1104),
56
+ (1024, 1024), (1104, 944), (1184, 880), (1248, 832),
57
+ (1328, 800), (1392, 752), (1456, 720), (1504, 688), (1568, 672),
58
+ ]
59
+ target_ratio = w / h
60
+ return min(
61
+ PREFERRED_KONTEXT_RESOLUTIONS,
62
+ key=lambda wh: abs((wh[0] / wh[1]) - target_ratio),
63
+ )
64
+
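To make the selection rule concrete, here is a minimal, self-contained sketch of the same aspect-ratio matching (using only a subset of the preferred resolutions for illustration); the full candidate list lives in `pick_kontext_resolution` above.

```python
# Minimal sketch of the aspect-ratio matching performed by pick_kontext_resolution:
# the candidate whose width/height ratio is closest to the input's ratio wins.
PREFERRED_SUBSET = [(832, 1248), (880, 1184), (944, 1104), (1024, 1024)]  # illustration only

def closest_resolution(w: int, h: int) -> tuple[int, int]:
    target_ratio = w / h
    return min(PREFERRED_SUBSET, key=lambda wh: abs(wh[0] / wh[1] - target_ratio))

print(closest_resolution(640, 900))  # -> (880, 1184): ratio 0.743 is nearest to 0.711
```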
65
+
66
+ MAX_SEED = np.iinfo(np.int32).max
67
+
68
+ device = None
69
+ pipeline = None
70
+ transformer = None
71
+ detail_encoder = None
72
+ stage1_pipeline = None
73
+
74
+ @spaces.GPU(duration=200)
75
+ def load_stage1_model():
76
+ global stage1_pipeline, device
77
+
78
+ if stage1_pipeline is not None:
79
+ return
80
+
81
+ print("加载 Stage 1 UNO Pipeline...")
82
+ if device is None:
83
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
84
+
85
+ model_type = "flux-dev"
86
+ stage1_pipeline = UNOPipeline(model_type, device, offload=False, only_lora=True, lora_rank=512)
87
+ print("Stage 1 模型加载完成!")
88
+
89
+ @spaces.GPU(duration=200)
90
+ def load_models():
91
+ global device, pipeline, transformer, detail_encoder
92
+
93
+ if pipeline is not None and transformer is not None and detail_encoder is not None:
94
+ return
95
+
96
+ print("CUDA 可用:", torch.cuda.is_available())
97
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
98
+ print("使用设备:", device)
99
+
100
+ dtype = torch.bfloat16 if "cuda" in device else torch.float32
101
+
102
+ print("加载 FluxKontextPipelineWithPhotoEncoderAddTokens...")
103
+ pipeline_local = FluxKontextPipelineWithPhotoEncoderAddTokens.from_pretrained(
104
+ "./kontext",
105
+ torch_dtype=dtype,
106
+ )
107
+ pipeline_local.to(device)
108
+
109
+ print("加载 FluxTransformer2DModel...")
110
+ transformer_local = FluxTransformer2DModel.from_pretrained(
111
+ "./kontext",
112
+ subfolder="transformer",
113
+ torch_dtype=dtype,
114
+ )
115
+ transformer_local.to(device)
116
+
117
+ print("加载 detail_encoder 权重...")
118
+ state_dict = load_file(detail_encoder_path)
119
+ detail_encoder_local = DetailEncoder().to(dtype=transformer_local.dtype, device=device)
120
+ detail_encoder_local.to(device)
121
+
122
+ with torch.no_grad():
123
+ for name, param in detail_encoder_local.named_parameters():
124
+ if name in state_dict:
125
+ added = state_dict[name].to(param.device)
126
+ param.add_(added)
127
+
128
+ pipeline_local.transformer = transformer_local
129
+ pipeline_local.detail_encoder = detail_encoder_local
130
+
131
+ print("加载 LoRA...")
132
+ set_single_lora(pipeline_local.transformer, kontext_lora_path, lora_weights=[1.0])
133
+
134
+ print("模型加载完成!")
135
+
136
+ # write back to the globals
137
+ pipeline = pipeline_local
138
+ transformer = transformer_local
139
+ detail_encoder = detail_encoder_local
140
+
141
+ @spaces.GPU(duration=200)
142
+ def generate_image_method1(input_image, prompt, width, height, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28):
143
+ """
144
+ Stage 1 - Method 1: UNO image generation
145
+ """
146
+ load_stage1_model()
147
+ global stage1_pipeline
148
+
149
+ if randomize_seed:
150
+ seed = -1
151
+
152
+ try:
153
+ # UNO pipeline uses gradio_generate interface
154
+ output_image, output_file = stage1_pipeline.gradio_generate(
155
+ prompt=prompt,
156
+ width=int(width),
157
+ height=int(height),
158
+ guidance=guidance_scale,
159
+ num_steps=steps,
160
+ seed=seed,
161
+ image_prompt1=input_image,
162
+ image_prompt2=None,
163
+ image_prompt3=None,
164
+ image_prompt4=None,
165
+ )
166
+ used_seed = seed if seed != -1 else random.randint(0, MAX_SEED)
167
+ return output_image, used_seed
168
+ except Exception as e:
169
+ print(f"Stage 1 生成图像时发生错误: {e}")
170
+ raise gr.Error(f"生成失败:{str(e)}")
171
+
172
+ def extract_first_box(annotations: dict):
173
+ """
174
+ Take the first bbox from the gradio_image_annotation payload, together with the full PIL image and the cropped patch.
175
+
176
+ If no bbox was drawn, the whole image is used as the bbox.
177
+ """
178
+ if not annotations:
179
+ raise gr.Error("Missing annotation data. Please check if an image is uploaded.")
180
+
181
+ img_array = annotations.get("image", None)
182
+ boxes = annotations.get("boxes", [])
183
+
184
+ if img_array is None:
185
+ raise gr.Error("No 'image' field found in annotation.")
186
+
187
+ img = Image.fromarray(img_array)
188
+
189
+ # fall back to the full image when no bbox was drawn
190
+ if not boxes:
191
+ w, h = img.size
192
+ xmin, ymin, xmax, ymax = 0, 0, w, h
193
+ else:
194
+ box = boxes[0]
195
+ xmin = int(box["xmin"])
196
+ ymin = int(box["ymin"])
197
+ xmax = int(box["xmax"])
198
+ ymax = int(box["ymax"])
199
+
200
+ if xmax <= xmin or ymax <= ymin:
201
+ raise gr.Error("Invalid bbox, please draw the box again.")
202
+
203
+ patch = img.crop((xmin, ymin, xmax, ymax))
204
+ return img, patch, (xmin, ymin, xmax, ymax)
205
+
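For reference, a rough sketch of the annotation payload this helper consumes (the key names follow the code above: an `image` array plus `boxes` with pixel coordinates); the exact structure is whatever `gradio_image_annotation` returns, so treat this as illustrative only.

```python
import numpy as np

# Illustrative annotation dict in the shape extract_first_box expects:
# an H x W x 3 uint8 image and a list of pixel-coordinate boxes.
fake_annotation = {
    "image": np.zeros((512, 512, 3), dtype=np.uint8),
    "boxes": [{"xmin": 100, "ymin": 120, "xmax": 300, "ymax": 360}],
}
img, patch, (xmin, ymin, xmax, ymax) = extract_first_box(fake_annotation)
print(patch.size)  # -> (200, 240): the cropped PIL patch (width, height)
```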
206
+ @spaces.GPU(duration=200)
207
+ def run_with_two_bboxes(
208
+ annotations_A: dict | None,  # reference image annotation
209
+ annotations_B: dict | None,  # annotation of the image to be corrected
210
+ object_name: str,
211
+ base_seed: int = 0,
212
+ ): # noqa: C901
213
+ """
214
+ """
215
+
216
+ load_models()
217
+ global pipeline, device
218
+ if annotations_A is None:
219
+ raise gr.Error("please upload reference image and draw a bbox")
220
+ if annotations_B is None:
221
+ raise gr.Error("please upload input image to be corrected and draw a bbox")
222
+
223
+ # 1. extract the bbox, full image and cropped patch from each annotation
224
+ img1_full, patch_A, bbox_A = extract_first_box(annotations_A)
225
+ img2_full, patch_B, bbox_B = extract_first_box(annotations_B)
226
+
227
+ xmin_B, ymin_B, xmax_B, ymax_B = bbox_B
228
+ patch_w = xmax_B - xmin_B
229
+ patch_h = ymax_B - ymin_B
230
+
231
+ if not object_name:
232
+ object_name = "object"
233
+
234
+ # 2. pick the preferred Kontext resolution closest to the patch aspect ratio
235
+ orig_w, orig_h = patch_B.size
236
+ target_w, target_h = pick_kontext_resolution(orig_w, orig_h)
237
+ width_for_model, height_for_model = target_w, target_h
238
+
239
+ # 3. resize both patches to the model resolution
240
+ cond_A_image = patch_A.resize((width_for_model, height_for_model), Image.Resampling.LANCZOS)
241
+ cond_B_image = patch_B.resize((width_for_model, height_for_model), Image.Resampling.LANCZOS)
242
+
243
+ prompt = f"use the {object_name} in IMG1 as a reference to refine, replace, enhance the {object_name} in IMG2"
244
+ print("prompt:", prompt)
245
+
246
+ seed = int(base_seed)
247
+ gen_device = device.split(":")[0] if "cuda" in device else device
248
+ generator = torch.Generator(gen_device).manual_seed(seed)
249
+
250
+ try:
251
+ out = pipeline(
252
+ image_A=cond_A_image,
253
+ image_B=cond_B_image,
254
+ prompt=prompt,
255
+ height=height_for_model,
256
+ width=width_for_model,
257
+ guidance_scale=3.5,
258
+ generator=generator,
259
+ )
260
+
261
+ gen_patch_model = out.images[0]
262
+
263
+ # resize the generated patch back to the original bbox size
264
+ gen_patch = gen_patch_model.resize((patch_w, patch_h), Image.Resampling.LANCZOS)
265
+
266
+ # paste the corrected patch back and build a side-by-side collage
267
+ composed = img2_full.copy()
268
+ composed.paste(gen_patch, (xmin_B, ymin_B))
269
+ patch_A_resized = patch_A.resize((patch_w, patch_h), Image.Resampling.LANCZOS)
270
+ patch_B_resized = patch_B.resize((patch_w, patch_h), Image.Resampling.LANCZOS)
271
+ SPACING = 10
272
+ collage_w = patch_w * 3 + SPACING * 2
273
+ collage_h = patch_h
274
+
275
+ collage = Image.new("RGB", (collage_w, collage_h), (255, 255, 255))
276
+
277
+ x0 = 0
278
+ x1 = patch_w + SPACING
279
+ x2 = patch_w * 2 + SPACING * 2
280
+
281
+ collage.paste(patch_A_resized, (x0, 0))
282
+ collage.paste(patch_B_resized, (x1, 0))
283
+ collage.paste(gen_patch, (x2, 0))
284
+
285
+ return collage, composed
286
+
287
+ except Exception as e:
288
+ print(f"生成图像时发生错误: {e}")
289
+ raise gr.Error(f"生成失败:{str(e)}")
290
+
291
+
292
+ import gradio as gr
293
+
294
+ with gr.Blocks(
295
+ theme=gr.themes.Soft(primary_hue="blue", neutral_hue="slate"),
296
+ css="""
297
+ /* Global Clean Font */
298
+
299
+
300
+ /* Center container */
301
+ .app-container {
302
+ width: 100% !important;
303
+ max-width: 100% !important;
304
+ margin: 0 auto;
305
+ }
306
+
307
+ /* Title block */
308
+ .title-block h1 {
309
+ text-align: center;
310
+ font-size: 3rem;
311
+ font-weight: 1100;
312
+
313
+ /* blue-purple gradient */
314
+ background: linear-gradient(90deg, #5b8dff, #b57aff);
315
+ -webkit-background-clip: text;
316
+ color: transparent;
317
+ }
318
+
319
+ .title-block h2 {
320
+ text-align: center;
321
+ font-size: 1.6rem;
322
+ font-weight: 700;
323
+ margin-top: 0.4rem;
324
+
325
+ /* slightly softer gradient */
326
+ background: linear-gradient(90deg, #6da0ff, #c28aff);
327
+ -webkit-background-clip: text;
328
+ color: transparent;
329
+ }
330
+
331
+ /* Title block
332
+
333
+ .title-block h1 {
334
+ text-align: center; font-size: 2.4rem; font-weight: 800; color: #1f2937;
335
+ }
336
+ .title-block h2 {
337
+ text-align: center; font-size: 1.2rem; font-weight: 500; color: #303030; margin-top: 0.4rem;
338
+ }
339
+ */
340
+
341
+ /* Simple card */
342
+ .clean-card {
343
+ background: #ffffff;
344
+ border: 1px solid #e5e7eb;
345
+ border-radius: 12px;
346
+ padding: 14px 16px;
347
+ margin-bottom: 10px;
348
+ }
349
+
350
+ /* Card title */
351
+ .clean-card-title {
352
+ font-size: 1.3rem;
353
+ font-weight: 600;
354
+ color: #404040;
355
+ margin-bottom: 6px;
356
+ }
357
+
358
+ /* Subtitle */
359
+ .clean-card-subtitle {
360
+ font-size: 1.1rem;
361
+ color: #404040;
362
+ margin-bottom: 8px;
363
+ }
364
+
365
+ /* Output card */
366
+ .output-card {
367
+ background: #ffffff;
368
+ border: 1px solid #d1d5db;
369
+ border-radius: 12px;
370
+ padding: 14px 16px;
371
+ }
372
+ .output-card1 {
373
+ background: #ffffff;
374
+ border: none !important;
375
+ box-shadow: none !important;
376
+ border-radius: 12px;
377
+ padding: 14px 16px;
378
+ }
379
+
380
+ /* Gradient primary button: works whether the button itself has .color-btn or its wrapper does */
381
+ button.color-btn,
382
+ .color-btn button {
383
+ width: 100%;
384
+ background: linear-gradient(90deg, #3b82f6 0%, #6366f1 100%) !important;
385
+ color: #ffffff !important;
386
+ font-size: 1.05rem !important;
387
+ font-weight: 700 !important;
388
+ padding: 14px !important;
389
+ border-radius: 12px !important;
390
+
391
+ border: none !important;
392
+ box-shadow: 0 4px 12px rgba(99, 102, 241, 0.25) !important;
393
+ transition: 0.2s ease !important;
394
+ cursor: pointer;
395
+ }
396
+
397
+ /* Hover effect */
398
+ button.color-btn:hover,
399
+ .color-btn button:hover {
400
+ opacity: 0.92 !important;
401
+ transform: translateY(-1px) !important;
402
+ }
403
+
404
+ /* Pressed feedback */
405
+ button.color-btn:active,
406
+ .color-btn button:active {
407
+ transform: scale(0.98) !important;
408
+ }
409
+
410
+ /* If there is an outer wrapper, make it transparent (avoids an extra white bar) */
411
+ .color-btn > div {
412
+ background: transparent !important;
413
+ box-shadow: none !important;
414
+ border: none !important;
415
+ }
416
+
417
+ .example-image img {
418
+ height: 400px !important;
419
+ object-fit: contain;
+ }
420
+
421
+ """
422
+ ) as demo:
423
+ gen_patch_out = None
424
+ composed_out = None
425
+ # -------------------------------------------------------
426
+ # Title
427
+ # -------------------------------------------------------
428
+ gr.Markdown(
429
+ """
430
+ <div class="title-block">
431
+ <h1>The Consistency Critic:</h1>
432
+ <h2>Correcting Inconsistencies in Generated Images via Reference-Guided Attentive Alignment</h2>
433
+ </div>
434
+ """
435
+ )
436
+
437
+ # ========================================================
438
+ # show the two stages side by side
439
+ # ========================================================
440
+ with gr.Row(elem_classes="app-container"):
441
+ # ========================================================
442
+ # STAGE 1: Image Generation (left column)
443
+ # ========================================================
444
+ with gr.Column(scale=1):
445
+ gr.Markdown(
446
+ """
447
+ <div class="clean-card">
448
+ <div class="clean-card-title">🎨 Stage 1: Customized Image Generation</div>
449
+ <div class="clean-card-subtitle">Generate images from prompts and reference image using UNO method. The output can be used as input for Stage 2.</div>
450
+ </div>
451
+ """
452
+ )
453
+
454
+ # Stage 1 Input
455
+ gr.Markdown("### Input")
456
+ stage1_input_image = gr.Image(label="Input Image (Optional)", type="pil")
457
+ stage1_prompt = gr.Textbox(
458
+ label="Prompt",
459
+ placeholder="Enter your prompt for image generation",
460
+ lines=3
461
+ )
462
+
463
+ with gr.Row():
464
+ with gr.Column():
465
+ stage1_width = gr.Slider(512, 2048, 1024, step=16, label="Generation Width")
466
+ stage1_height = gr.Slider(512, 2048, 1024, step=16, label="Generation Height")
467
+ with gr.Accordion("Advanced Settings", open=False):
468
+ stage1_seed = gr.Slider(
469
+ label="Seed",
470
+ minimum=0,
471
+ maximum=MAX_SEED,
472
+ step=1,
473
+ value=42,
474
+ )
475
+ stage1_randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
476
+ stage1_guidance_scale = gr.Slider(
477
+ label="Guidance Scale",
478
+ minimum=1,
479
+ maximum=10,
480
+ step=0.1,
481
+ value=2.5,
482
+ )
483
+ stage1_steps = gr.Slider(
484
+ label="Steps",
485
+ minimum=1,
486
+ maximum=30,
487
+ value=28,
488
+ step=1
489
+ )
490
+
491
+ stage1_method1_btn = gr.Button("✨ Generate Image", elem_classes="color-btn")
492
+
493
+ # Stage 1 Output
494
+ gr.Markdown("### Output")
495
+ stage1_output_image = gr.Image(label="Generated Image", interactive=False)
496
+ stage1_used_seed = gr.Number(label="Used Seed", interactive=False)
497
+
498
+ # -------------------------------------------------------
499
+ # Stage 1 Examples
500
+ # -------------------------------------------------------
501
+ gr.Markdown(
502
+ """
503
+ <div style="
504
+ font-size: 1.3rem;
505
+ font-weight: 600;
506
+ color: #404040;
507
+ margin-top: 16px;
508
+ margin-bottom: 6px;
509
+ ">
510
+ 📚 Stage 1 Example Images & Prompts
511
+ </div>
512
+ """,
513
+ )
514
+ gr.Markdown(
515
+ """
516
+ <div style="
517
+ font-size: 1.1rem;
518
+ color: #404040;
519
+ margin-bottom: 8px;
520
+ ">
521
+ Click on any example below to load the image and prompt into Stage 1 inputs.
522
+ </div>
523
+ """,
524
+ )
525
+
526
+ gr.Examples(
527
+ examples=[
528
+ ["./test_imgs/product_3.png", "In a softly lit nursery, a baby sleeps peacefully as a parent gently applies the product to a washcloth. The scene is calm and warm, with natural light highlighting the product’s label. The camera captures a close-up, centered view, emphasizing the product’s presence and its gentle interaction with the environment."],
529
+ ["./test_imgs/3.png", "Create an engaging lifestyle e-commerce scene where a person delicately picks up the product from a slightly shifted angle to add depth and realism, placing it within a creative photography workspace filled with soft natural light, scattered camera gear, open photo books, and warm wooden textures."],
530
+ ["./test_imgs/2.png", "Create a stylish e-commerce scene featuring the product displayed on a modern clothing rack in a bright boutique environment, surrounded by soft natural lighting, minimalistic decor, and complementary fashion accessories"]
531
+ ],
532
+ inputs=[stage1_input_image, stage1_prompt],
533
+ label="Click to Load Examples"
534
+ )
535
+
536
+ # ========================================================
537
+ # STAGE 2: Image Correction (right column)
538
+ # ========================================================
539
+ with gr.Column(scale=1):
540
+ gr.Markdown(
541
+ """
542
+ <div class="clean-card">
543
+ <div class="clean-card-title">🔧 Stage 2: Image Consistency Correction</div>
544
+ <div class="clean-card-subtitle">Refine and correct generated images using ImageCritic.</div>
545
+ </div>
546
+ """
547
+ )
548
+
549
+ # -------------------------------------------------------
550
+ # Tips for Stage 2
551
+ # -------------------------------------------------------
552
+ gr.Markdown(
553
+ """
554
+ <div class="clean-card">
555
+ <div class="clean-card-title">💡 Stage 2 Tips</div>
556
+ <div class="clean-card-subtitle">
557
+ • Crop both the bbox that needs to be corrected and the reference bbox, preferably covering the smallest repeating unit, to achieve better results.<br>
558
+ • The bbox area should ideally cover the region to be corrected and the reference region as completely as possible.<br>
559
+ • The aspect ratio of the bboxes should also be kept consistent to avoid errors caused by incorrect scaling.<br>
560
+ • If the model fails to correct the image, it may be because the generated image is too similar to the reference image, causing the model to skip the repair. You can manually <b>paint that area black in a drawing tool before sending it to the model, or crop only the local region and perform multiple rounds of correction to progressively enhance the whole generated image.</b>
561
+ </div>
562
+ """
563
+ )
564
+
565
+ # -------------------------------------------------------
566
+ # Image annotation area
567
+ # -------------------------------------------------------
568
+ with gr.Row():
569
+ # Left: Reference Image
570
+ with gr.Column():
571
+ gr.Markdown(
572
+ """
573
+ <div class="clean-card">
574
+ <div class="clean-card-title">📌 Reference Image</div>
575
+ <div class="clean-card-subtitle">Draw a bounding box around the area for reference.</div>
576
+ </div>
577
+ """
578
+ )
579
+
580
+ annotator_A = image_annotator(
581
+ value=None,
582
+ label="reference image",
583
+ label_list=["bbox for reference"],
584
+ label_colors = [(168, 160, 194)],
585
+ single_box=True,
586
+ image_type="numpy",
587
+ sources=["upload", "clipboard"],
588
+ height=300,
589
+ )
590
+
591
+ # Right: Image to be corrected
592
+ with gr.Column():
593
+ gr.Markdown(
594
+ """
595
+ <div class="clean-card">
596
+ <div class="clean-card-title">🖼️ Input Image To Be Corrected</div>
597
+ <div class="clean-card-subtitle">Use the mouse wheel to zoom and draw a bounding box around the area to be corrected.</div>
598
+ </div>
599
+ """
600
+ )
601
+
602
+ annotator_B = image_annotator(
603
+ value=None,
604
+ label="input image to be corrected",
605
+ label_list=["bbox for correction"],
606
+ label_colors = [(168, 160, 194)],
607
+ single_box=True,
608
+ image_type="numpy",
609
+ sources=["upload", "clipboard"],
610
+ height=300,
611
+ )
612
+
613
+ # -------------------------------------------------------
614
+ # Controls
615
+ # -------------------------------------------------------
616
+ with gr.Row():
617
+ object_name = gr.Textbox(
618
+ label="Caption for object (optional; using 'product' also works)",
619
+ value="product",
620
+ placeholder="e.g. product, shoes, bag, face ..."
621
+ )
622
+
623
+ base_seed = gr.Number(
624
+ label="Seed",
625
+ value=0,
626
+ precision=0,
627
+ )
628
+
629
+ # -------------------------------------------------------
630
+ # Run Button
631
+ # -------------------------------------------------------
632
+ run_btn = gr.Button("✨ Generate ", elem_classes="color-btn")
633
+
634
+ # ===================== Output area =====================
635
+ gr.Markdown("### Output")
636
+ with gr.Column(elem_classes="output-card1"):
637
+ gen_patch_out = gr.Image(
638
+ label="concatenated input-output",
639
+ interactive=False
640
+ )
641
+
642
+ with gr.Column(elem_classes="output-card1"):
643
+ composed_out = gr.Image(
644
+ label="corrected image",
645
+ interactive=False
646
+ )
647
+
648
+ # -------------------------------------------------------
649
+ # Put the whole Stage 2 example area inside a white card
650
+ # -------------------------------------------------------
651
+ with gr.Column(elem_classes="clean-card"):
652
+
653
+ gr.Markdown(
654
+ """
655
+ <div style="
656
+ font-size: 1.3rem;
657
+ font-weight: 600;
658
+ color: #404040;
659
+ margin-bottom: 6px;
660
+ ">
661
+ 📚 Example Images
662
+ </div>
663
+ """,
664
+ )
665
+
666
+ gr.Markdown(
667
+ """
668
+ <div style="
669
+ font-size: 1.1rem;
670
+ color: #404040;
671
+ margin-bottom: 8px;
672
+ ">
673
+ Below are some example pairs showing how bounding boxes should be drawn.
674
+ You can click and drag the image below into the upper area for generation.<br>
675
+ <b> Full-image input is also supported, but it is recommended to use the smallest possible bounding box that covers the region to be corrected and reference bbox. For example, the bbox approach used in the first row generally produces better results than the one used in the second row.</b>
676
+ </div>
677
+ """,
678
+ )
679
+ with gr.Row():
680
+ gr.Image("./test_imgs/product_3.png",label="reference example", elem_classes="example-image")
681
+ gr.Image("./test_imgs/product_3_bbox_1.png",label="reference example with bbox",elem_classes="example-image")
682
+ gr.Image("./test_imgs/generated_3.png",label="input example", elem_classes="example-image")
683
+ gr.Image("./test_imgs/generated_3_bbox_1.png",label="input example with bbox", elem_classes="example-image")
684
+
685
+
686
+ with gr.Row():
687
+ gr.Image("./test_imgs/product_3.png",label="reference example", elem_classes="example-image")
688
+ gr.Image("./test_imgs/product_3_bbox.png",label="reference example with bbox",elem_classes="example-image")
689
+ gr.Image("./test_imgs/generated_3.png",label="input example", elem_classes="example-image")
690
+ gr.Image("./test_imgs/generated_3_bbox.png",label="input example with bbox", elem_classes="example-image")
691
+
692
+ with gr.Row():
693
+ gr.Image("./test_imgs/product_1.jpg", label="reference example", elem_classes="example-image")
694
+ gr.Image("./test_imgs/product_1_bbox.png", label="reference example with bbox", elem_classes="example-image")
695
+ gr.Image("./test_imgs/generated_1.png", label="input example", elem_classes="example-image")
696
+ gr.Image("./test_imgs/generated_1_bbox.png", label="input example with bbox", elem_classes="example-image")
697
+
698
+ with gr.Row():
699
+ gr.Image("./test_imgs/product_2.png",label="reference example", elem_classes="example-image")
700
+ gr.Image("./test_imgs/product_2_bbox.png",label="reference example with bbox",elem_classes="example-image")
701
+ gr.Image("./test_imgs/generated_2.png", label="input example", elem_classes="example-image")
702
+ gr.Image("./test_imgs/generated_2_bbox.png", label="input example with bbox", elem_classes="example-image")
703
+
704
+ # ========= Bind the button clicks after all components are defined =========
705
+ # Stage 1: Image Generation
706
+ stage1_method1_btn.click(
707
+ fn=generate_image_method1,
708
+ inputs=[stage1_input_image, stage1_prompt, stage1_width, stage1_height, stage1_seed, stage1_randomize_seed, stage1_guidance_scale, stage1_steps],
709
+ outputs=[stage1_output_image, stage1_used_seed],
710
+ )
711
+
712
+ # Stage 2: Image Correction
713
+ run_btn.click(
714
+ fn=run_with_two_bboxes,
715
+ inputs=[annotator_A, annotator_B, object_name, base_seed],
716
+ outputs=[gen_patch_out, composed_out],
717
+ )
718
+
719
+ if __name__ == "__main__":
720
+ demo.launch(server_name="0.0.0.0", server_port=7779)
requirements.txt ADDED
@@ -0,0 +1,34 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cu124
2
+ torch
3
+ torchvision
4
+ accelerate==1.10.0
5
+ clip @ git+https://github.com/openai/CLIP.git@dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
6
+ contourpy==1.3.2
7
+ cycler==0.12.1
8
+ datasets==4.0.0
9
+ decord==0.6.0
10
+ diffusers @ git+https://github.com/huggingface/diffusers.git@345864eb852b528fd1f4b6ad087fa06e0470006b
11
+ gradio==5.49.1
12
+ gradio_client==1.13.3
13
+ gradio_image_annotation==0.4.1
14
+ huggingface-hub==0.35.3
15
+ ipykernel==7.0.1
16
+ ipython==8.37.0
17
+ Jinja2==3.1.6
18
+ multiprocess==0.70.16
19
+ ninja==1.13.0
20
+ numpy==2.2.6
21
+ open_clip_torch==3.2.0
22
+ openai==1.107.2
23
+ opencv-python==4.12.0.88
24
+ opencv-python-headless==4.12.0.88
25
+ qwen-vl-utils==0.0.11
26
+ requests==2.32.5
27
+ safetensors==0.6.2
28
+ scikit-learn==1.7.2
29
+ tornado==6.5.2
30
+ tqdm==4.67.1
31
+ transformers==4.51.3
32
+ wandb==0.21.1
33
+ einops
34
+ sentencepiece
src/__init__.py ADDED
File without changes
src/attention_processor.py ADDED
@@ -0,0 +1,146 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from typing import Optional, Tuple, Dict, Any
4
+ import os
5
+ import numpy as np
6
+ from PIL import Image
7
+ import matplotlib.pyplot as plt
8
+ from diffusers.models.attention_processor import FluxAttnProcessor2_0
9
+
10
+ class VisualFluxAttnProcessor2_0(FluxAttnProcessor2_0):
11
+ """
12
+ Custom Flux attention processor that saves attention maps for visualization.
13
+ """
14
+
15
+ def __init__(self, save_attention=True, save_dir="attention_maps"):
16
+ super().__init__()
17
+ self.save_attention = save_attention
18
+ self.save_dir = save_dir
19
+ self.step_counter = 0
20
+
21
+ # create the save directory
22
+ if self.save_attention:
23
+ os.makedirs(self.save_dir, exist_ok=True)
24
+
25
+ def save_attention_map(self, attn_weights, layer_name="", step=None):
26
+ """保存注意力图"""
27
+ if not self.save_attention:
28
+ return
29
+
30
+ if step is None:
31
+ step = self.step_counter
32
+
33
+ # take the attention weights of the first batch element and the first head
34
+ attn_map = attn_weights[0, 0].detach().cpu().numpy() # [seq_len, seq_len]
35
+
36
+ # draw the heatmap
37
+ plt.figure(figsize=(12, 10))
38
+ plt.imshow(attn_map, cmap='hot', interpolation='nearest')
39
+ plt.colorbar()
40
+ plt.title(f'Attention Map - {layer_name} - Step {step}')
41
+ plt.xlabel('Key Position')
42
+ plt.ylabel('Query Position')
43
+
44
+ # save the figure
45
+ save_path = os.path.join(self.save_dir, f"attention_{layer_name}_step_{step}.png")
46
+ plt.savefig(save_path, dpi=150, bbox_inches='tight')
47
+ plt.close()
48
+
49
+ print(f"Attention map saved to: {save_path}")
50
+
51
+ def __call__(
52
+ self,
53
+ attn,
54
+ hidden_states: torch.Tensor,
55
+ encoder_hidden_states: Optional[torch.Tensor] = None,
56
+ attention_mask: Optional[torch.Tensor] = None,
57
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
58
+ use_cond: bool = False,
59
+ ) -> torch.Tensor:
60
+ batch_size, sequence_length, _ = (
61
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
62
+ )
63
+
64
+ if attention_mask is not None:
65
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
66
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
67
+
68
+ if attn.group_norm is not None:
69
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
70
+
71
+ query = attn.to_q(hidden_states)
72
+
73
+ if encoder_hidden_states is None:
74
+ encoder_hidden_states = hidden_states
75
+ elif attn.norm_cross:
76
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
77
+
78
+ key = attn.to_k(encoder_hidden_states)
79
+ value = attn.to_v(encoder_hidden_states)
80
+
81
+ inner_dim = key.shape[-1]
82
+ head_dim = inner_dim // attn.heads
83
+
84
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
85
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
86
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
87
+
88
+ if attn.norm_q is not None:
89
+ query = attn.norm_q(query)
90
+ if attn.norm_k is not None:
91
+ key = attn.norm_k(key)
92
+
93
+ # apply rotary position embeddings
94
+ if image_rotary_emb is not None:
95
+ query = attn.rotary_emb(query, image_rotary_emb)
96
+ if not attn.is_cross_attention:
97
+ key = attn.rotary_emb(key, image_rotary_emb)
98
+
99
+ # compute attention scores
100
+ attention_scores = torch.matmul(query, key.transpose(-2, -1)) / (head_dim ** 0.5)
101
+
102
+ if attention_mask is not None:
103
+ attention_scores = attention_scores + attention_mask
104
+
105
+ attention_probs = F.softmax(attention_scores, dim=-1)
106
+
107
+ # save the attention map
108
+ if self.save_attention and self.step_counter % 10 == 0:  # save every 10 steps
109
+ layer_name = f"layer_{self.step_counter // 10}"
110
+ self.save_attention_map(attention_probs, layer_name, self.step_counter)
111
+
112
+ # apply dropout
113
+ attention_probs = F.dropout(attention_probs, p=attn.dropout, training=attn.training)
114
+
115
+ # compute the attention output
116
+ hidden_states = torch.matmul(attention_probs, value)
117
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
118
+ hidden_states = hidden_states.to(query.dtype)
119
+
120
+ if use_cond:
121
+ # handle the conditional-branch case
122
+ seq_len = hidden_states.shape[1]
123
+ if seq_len % 2 == 0:
124
+ # assume the first half is the original hidden_states and the second half is the conditional hidden_states
125
+ mid_point = seq_len // 2
126
+ original_hidden_states = hidden_states[:, :mid_point, :]
127
+ cond_hidden_states = hidden_states[:, mid_point:, :]
128
+
129
+ # project each half separately
130
+ original_output = attn.to_out[0](original_hidden_states)
131
+ cond_output = attn.to_out[0](cond_hidden_states)
132
+
133
+ if len(attn.to_out) > 1:
134
+ original_output = attn.to_out[1](original_output)
135
+ cond_output = attn.to_out[1](cond_output)
136
+
137
+ self.step_counter += 1
138
+ return original_output, cond_output
139
+
140
+ # standard output projection
141
+ hidden_states = attn.to_out[0](hidden_states)
142
+ if len(attn.to_out) > 1:
143
+ hidden_states = attn.to_out[1](hidden_states)
144
+
145
+ self.step_counter += 1
146
+ return hidden_states
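The commit does not show how this processor is wired in; as a hedged sketch (assuming the transformer exposes the usual diffusers `attn_processors` / `set_attn_processor` API), it could be attached roughly like this:

```python
from diffusers import FluxTransformer2DModel
from src.attention_processor import VisualFluxAttnProcessor2_0

# Illustrative only: install the visualization processor on every attention layer.
transformer = FluxTransformer2DModel.from_pretrained("./kontext", subfolder="transformer")
transformer.set_attn_processor({
    name: VisualFluxAttnProcessor2_0(save_attention=True, save_dir="attention_maps")
    for name in transformer.attn_processors
})
```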
src/detail_encoder.py ADDED
@@ -0,0 +1,118 @@
1
+ # Merge image encoder and fuse module to create an ID Encoder
2
+ # send multiple ID images, we can directly obtain the updated text encoder containing a stacked ID embedding
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from transformers.models.clip.modeling_clip import CLIPVisionModelWithProjection
8
+ from transformers.models.clip.configuration_clip import CLIPVisionConfig
9
+ from transformers import PretrainedConfig
10
+
11
+ VISION_CONFIG_DICT = {
12
+ "hidden_size": 1024,
13
+ "intermediate_size": 4096,
14
+ "num_attention_heads": 16,
15
+ "num_hidden_layers": 24,
16
+ "patch_size": 14,
17
+ "projection_dim": 768
18
+ }
19
+
20
+ class MLP(nn.Module):
21
+ def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True):
22
+ super().__init__()
23
+ if use_residual:
24
+ assert in_dim == out_dim
25
+ self.layernorm = nn.LayerNorm(in_dim)
26
+ self.fc1 = nn.Linear(in_dim, hidden_dim)
27
+ self.fc2 = nn.Linear(hidden_dim, out_dim)
28
+ self.use_residual = use_residual
29
+ self.act_fn = nn.GELU()
30
+
31
+ def forward(self, x):
32
+ residual = x
33
+ x = self.layernorm(x)
34
+ x = self.fc1(x)
35
+ x = self.act_fn(x)
36
+ x = self.fc2(x)
37
+ if self.use_residual:
38
+ x = x + residual
39
+ return x
40
+
41
+
42
+ class FuseModule(nn.Module):
43
+ def __init__(self, prompt_embed_dim, id_embed_dim):
44
+ super().__init__()
45
+ self.mlp1 = MLP(prompt_embed_dim + id_embed_dim, prompt_embed_dim, prompt_embed_dim, use_residual=False)
46
+ self.mlp2 = MLP(prompt_embed_dim, prompt_embed_dim, prompt_embed_dim, use_residual=True)
47
+ self.layer_norm = nn.LayerNorm(prompt_embed_dim)
48
+
49
+ def fuse_fn(self, prompt_embeds, id_embeds):
50
+ stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
51
+ stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
52
+ stacked_id_embeds = self.mlp2(stacked_id_embeds)
53
+ stacked_id_embeds = self.layer_norm(stacked_id_embeds)
54
+ return stacked_id_embeds
55
+
56
+ def forward(
57
+ self,
58
+ prompt_embeds,
59
+ id_embeds,
60
+ class_tokens_mask,
61
+ ) -> torch.Tensor:
62
+ device = prompt_embeds.device
63
+ class_tokens_mask = class_tokens_mask.to(device)
64
+ id_embeds = id_embeds.to(prompt_embeds.dtype)
65
+ num_inputs = class_tokens_mask.sum().unsqueeze(0).to(id_embeds.device)
66
+ batch_size, max_num_inputs = id_embeds.shape[:2]
67
+ seq_length = prompt_embeds.shape[1]
68
+ flat_id_embeds = id_embeds.view(
69
+ -1, id_embeds.shape[-2], id_embeds.shape[-1]
70
+ )
71
+ valid_id_mask = (
72
+ torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :]
73
+ < num_inputs[:, None]
74
+ )
75
+ valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()]
76
+
77
+ prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1])
78
+ class_tokens_mask = class_tokens_mask.view(-1)
79
+ valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1])
80
+ image_token_embeds = prompt_embeds[class_tokens_mask]
81
+ stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)
82
+ stacked_id_embeds = stacked_id_embeds.to(device=device, dtype=prompt_embeds.dtype)
83
+ assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
84
+ prompt_embeds = prompt_embeds.masked_scatter(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
85
+ updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1)
86
+ return updated_prompt_embeds
87
+
88
+ class DetailEncoder(CLIPVisionModelWithProjection):
89
+ def __init__(self):
90
+
91
+ super().__init__(CLIPVisionConfig(**VISION_CONFIG_DICT))
92
+ self.visual_projection_2 = nn.Linear(1024, 1280, bias=False)
93
+ self.fuse_module = FuseModule(4096, 2048)
94
+
95
+ def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask):
96
+ dtype = next(self.parameters()).dtype
97
+ device = next(self.parameters()).device
98
+ b, num_inputs, c, h, w = id_pixel_values.shape
99
+ # device setting
100
+ id_pixel_values = id_pixel_values.to(device=device, dtype=dtype)
101
+ prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
102
+ class_tokens_mask = class_tokens_mask.to(device=device)
103
+
104
+ id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)
105
+
106
+ id_pixel_values = F.interpolate(id_pixel_values, size=(224, 224), mode="bilinear", align_corners=False)
107
+ # id embeds <--> input image
108
+ shared_id_embeds = self.vision_model(id_pixel_values)[1]
109
+ id_embeds = self.visual_projection(shared_id_embeds)
110
+ id_embeds_2 = self.visual_projection_2(shared_id_embeds)
111
+
112
+ id_embeds = id_embeds.view(b, num_inputs, 1, -1)
113
+ id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
114
+
115
+ id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
116
+ updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask)
117
+ return updated_prompt_embeds
118
+
src/jsonl_datasets.py ADDED
@@ -0,0 +1,186 @@
1
+ from PIL import Image
2
+ from datasets import load_dataset
3
+ from torchvision import transforms
4
+ import random
5
+ import torch
6
+
7
+ Image.MAX_IMAGE_PIXELS = None
8
+
9
+ def multiple_16(num: float):
10
+ return int(round(num / 16) * 16)
11
+
12
+ def get_random_resolution(min_size=512, max_size=1280, multiple=16):
13
+ resolution = random.randint(min_size // multiple, max_size // multiple) * multiple
14
+ return resolution
15
+
16
+ def load_image_safely(image_path, size):
17
+ try:
18
+ image = Image.open(image_path).convert("RGB")
19
+ return image
20
+ except Exception as e:
21
+ print("file error: "+image_path)
22
+ with open("failed_images.txt", "a") as f:
23
+ f.write(f"{image_path}\n")
24
+ return Image.new("RGB", (size, size), (255, 255, 255))
25
+
26
+ def make_train_dataset(args, tokenizer, accelerator=None):
27
+ if args.train_data_dir is not None:
28
+ print("load_data")
29
+ dataset = load_dataset('json', data_files=args.train_data_dir)
30
+
31
+ column_names = dataset["train"].column_names
32
+
33
+ # Get the column names for input/target.
34
+ caption_column = args.caption_column
35
+ target_column = args.target_column
36
+ if args.subject_column is not None:
37
+ subject_columns = args.subject_column.split(",")
38
+ if args.spatial_column is not None:
39
+ spatial_columns= args.spatial_column.split(",")
40
+
41
+ size = args.cond_size
42
+ noise_size = get_random_resolution(max_size=args.noise_size) # maybe 768 or higher
43
+ subject_cond_train_transforms = transforms.Compose(
44
+ [
45
+ transforms.Lambda(lambda img: img.resize((
46
+ multiple_16(size * img.size[0] / max(img.size)),
47
+ multiple_16(size * img.size[1] / max(img.size))
48
+ ), resample=Image.BILINEAR)),
49
+ transforms.RandomHorizontalFlip(p=0.7),
50
+ transforms.RandomRotation(degrees=20),
51
+ transforms.Lambda(lambda img: transforms.Pad(
52
+ padding=(
53
+ int((size - img.size[0]) / 2),
54
+ int((size - img.size[1]) / 2),
55
+ int((size - img.size[0]) / 2),
56
+ int((size - img.size[1]) / 2)
57
+ ),
58
+ fill=0
59
+ )(img)),
60
+ transforms.ToTensor(),
61
+ transforms.Normalize([0.5], [0.5]),
62
+ ]
63
+ )
64
+ cond_train_transforms = transforms.Compose(
65
+ [
66
+ transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR),
67
+ transforms.CenterCrop((size, size)),
68
+ transforms.ToTensor(),
69
+ transforms.Normalize([0.5], [0.5]),
70
+ ]
71
+ )
72
+
73
+ def train_transforms(image, noise_size):
74
+ train_transforms_ = transforms.Compose(
75
+ [
76
+ transforms.Lambda(lambda img: img.resize((
77
+ multiple_16(noise_size * img.size[0] / max(img.size)),
78
+ multiple_16(noise_size * img.size[1] / max(img.size))
79
+ ), resample=Image.BILINEAR)),
80
+ transforms.ToTensor(),
81
+ transforms.Normalize([0.5], [0.5]),
82
+ ]
83
+ )
84
+ transformed_image = train_transforms_(image)
85
+ return transformed_image
86
+
87
+ def load_and_transform_cond_images(images):
88
+ transformed_images = [cond_train_transforms(image) for image in images]
89
+ concatenated_image = torch.cat(transformed_images, dim=1)
90
+ return concatenated_image
91
+
92
+ def load_and_transform_subject_images(images):
93
+ transformed_images = [subject_cond_train_transforms(image) for image in images]
94
+ concatenated_image = torch.cat(transformed_images, dim=1)
95
+ return concatenated_image
96
+
97
+ tokenizer_clip = tokenizer[0]
98
+ tokenizer_t5 = tokenizer[1]
99
+
100
+ def tokenize_prompt_clip_t5(examples):
101
+ captions = []
102
+ for caption in examples[caption_column]:
103
+ if isinstance(caption, str):
104
+ if random.random() < 0.1:
105
+ captions.append(" ") # 将文本设为空
106
+ else:
107
+ captions.append(caption)
108
+ elif isinstance(caption, list):
109
+ # take a random caption if there are multiple
110
+ if random.random() < 0.1:
111
+ captions.append(" ")
112
+ else:
113
+ captions.append(random.choice(caption))
114
+ else:
115
+ raise ValueError(
116
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
117
+ )
118
+ text_inputs = tokenizer_clip(
119
+ captions,
120
+ padding="max_length",
121
+ max_length=77,
122
+ truncation=True,
123
+ return_length=False,
124
+ return_overflowing_tokens=False,
125
+ return_tensors="pt",
126
+ )
127
+ text_input_ids_1 = text_inputs.input_ids
128
+
129
+ text_inputs = tokenizer_t5(
130
+ captions,
131
+ padding="max_length",
132
+ max_length=512,
133
+ truncation=True,
134
+ return_length=False,
135
+ return_overflowing_tokens=False,
136
+ return_tensors="pt",
137
+ )
138
+ text_input_ids_2 = text_inputs.input_ids
139
+ return text_input_ids_1, text_input_ids_2
140
+
141
+ def preprocess_train(examples):
142
+ _examples = {}
143
+ if args.subject_column is not None:
144
+ subject_images = [[load_image_safely(examples[column][i], args.cond_size) for column in subject_columns] for i in range(len(examples[target_column]))]
145
+ _examples["subject_pixel_values"] = [load_and_transform_subject_images(subject) for subject in subject_images]
146
+ if args.spatial_column is not None:
147
+ spatial_images = [[load_image_safely(examples[column][i], args.cond_size) for column in spatial_columns] for i in range(len(examples[target_column]))]
148
+ _examples["cond_pixel_values"] = [load_and_transform_cond_images(spatial) for spatial in spatial_images]
149
+ target_images = [load_image_safely(image_path, args.cond_size) for image_path in examples[target_column]]
150
+ _examples["pixel_values"] = [train_transforms(image, noise_size) for image in target_images]
151
+ _examples["token_ids_clip"], _examples["token_ids_t5"] = tokenize_prompt_clip_t5(examples)
152
+ return _examples
153
+
154
+ if accelerator is not None:
155
+ with accelerator.main_process_first():
156
+ train_dataset = dataset["train"].with_transform(preprocess_train)
157
+ else:
158
+ train_dataset = dataset["train"].with_transform(preprocess_train)
159
+
160
+ return train_dataset
161
+
162
+
163
+ def collate_fn(examples):
164
+ if examples[0].get("cond_pixel_values") is not None:
165
+ cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
166
+ cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
167
+ else:
168
+ cond_pixel_values = None
169
+ if examples[0].get("subject_pixel_values") is not None:
170
+ subject_pixel_values = torch.stack([example["subject_pixel_values"] for example in examples])
171
+ subject_pixel_values = subject_pixel_values.to(memory_format=torch.contiguous_format).float()
172
+ else:
173
+ subject_pixel_values = None
174
+
175
+ target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
176
+ target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
177
+ token_ids_clip = torch.stack([torch.tensor(example["token_ids_clip"]) for example in examples])
178
+ token_ids_t5 = torch.stack([torch.tensor(example["token_ids_t5"]) for example in examples])
179
+
180
+ return {
181
+ "cond_pixel_values": cond_pixel_values,
182
+ "subject_pixel_values": subject_pixel_values,
183
+ "pixel_values": target_pixel_values,
184
+ "text_ids_1": token_ids_clip,
185
+ "text_ids_2": token_ids_t5,
186
+ }
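The column names come from the training arguments (`caption_column`, `target_column`, `subject_column`, `spatial_column`), so the key names below are hypothetical; a record in the JSONL file passed as `train_data_dir` could look like:

```python
import json

# Hypothetical JSONL record for make_train_dataset; the real key names are whatever
# the corresponding args.*_column options are set to, values are captions or image paths.
record = {
    "caption": "a product photo on a wooden table",  # args.caption_column
    "target": "data/target/0001.png",                 # args.target_column
    "subject": "data/subject/0001.png",               # args.subject_column (comma-separable)
    "spatial": "data/spatial/0001.png",               # args.spatial_column (comma-separable)
}
with open("train.jsonl", "a") as f:
    f.write(json.dumps(record) + "\n")
```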
src/kontext_custom_pipeline.py ADDED
The diff for this file is too large to render. See raw diff
 
src/layers.py ADDED
@@ -0,0 +1,673 @@
1
+ import inspect
2
+ import math
3
+ from typing import Callable, List, Optional, Tuple, Union
4
+ from einops import rearrange
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from torch import Tensor
9
+ from diffusers.models.attention_processor import Attention
10
+
11
+ # Global variables for attention visualization
12
+ step = 0
13
+ global_timestep = 0
14
+ global_timestep2 = 0
15
+
16
+ def scaled_dot_product_average_attention_map(query, key, attn_mask=None, is_causal=False, scale=None) -> torch.Tensor:
17
+ # Efficient implementation equivalent to the following:
18
+ L, S = query.size(-2), key.size(-2)
19
+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
20
+ attn_bias = torch.zeros(L, S, dtype=query.dtype)
21
+ if is_causal:
22
+ assert attn_mask is None
23
+ temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
24
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
25
+ attn_bias.to(query.dtype)
26
+
27
+ if attn_mask is not None:
28
+ if attn_mask.dtype == torch.bool:
29
+ attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
30
+ else:
31
+ attn_bias += attn_mask
32
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
33
+ attn_weight += attn_bias.to(attn_weight.device)
34
+ attn_weight = attn_weight.mean(dim=(1, 2))
35
+ return attn_weight
36
+
37
+ class LoRALinearLayer(nn.Module):
38
+ def __init__(
39
+ self,
40
+ in_features: int,
41
+ out_features: int,
42
+ rank: int = 4,
43
+ network_alpha: Optional[float] = None,
44
+ device: Optional[Union[torch.device, str]] = None,
45
+ dtype: Optional[torch.dtype] = None,
46
+ ):
47
+ super().__init__()
48
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
49
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
50
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
51
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
52
+ self.network_alpha = network_alpha
53
+ self.rank = rank
54
+ self.out_features = out_features
55
+ self.in_features = in_features
56
+
57
+ nn.init.normal_(self.down.weight, std=1 / rank)
58
+ nn.init.zeros_(self.up.weight)
59
+
60
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
61
+ orig_dtype = hidden_states.dtype
62
+ dtype = self.down.weight.dtype
63
+
64
+ down_hidden_states = self.down(hidden_states.to(dtype))
65
+ up_hidden_states = self.up(down_hidden_states)
66
+
67
+ if self.network_alpha is not None:
68
+ up_hidden_states *= self.network_alpha / self.rank
69
+
70
+ return up_hidden_states.to(orig_dtype)
71
+
72
+
73
+ class MultiSingleStreamBlockLoraProcessor(nn.Module):
74
+ def __init__(self, in_features: int, out_features: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, n_loras=1):
75
+ super().__init__()
76
+ # Initialize a list to store the LoRA layers
77
+ self.n_loras = n_loras
78
+ self.q_loras = nn.ModuleList([
79
+ LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
80
+ for i in range(n_loras)
81
+ ])
82
+ self.k_loras = nn.ModuleList([
83
+ LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
84
+ for i in range(n_loras)
85
+ ])
86
+ self.v_loras = nn.ModuleList([
87
+ LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
88
+ for i in range(n_loras)
89
+ ])
90
+ self.lora_weights = lora_weights
91
+
92
+
93
+ def __call__(self,
94
+ attn: Attention,
95
+ hidden_states: torch.FloatTensor,
96
+ encoder_hidden_states: torch.FloatTensor = None,
97
+ attention_mask: Optional[torch.FloatTensor] = None,
98
+ image_rotary_emb: Optional[torch.Tensor] = None,
99
+ use_cond = False,
100
+ ) -> torch.FloatTensor:
101
+
102
+ batch_size, seq_len, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
103
+ query = attn.to_q(hidden_states)
104
+ key = attn.to_k(hidden_states)
105
+ value = attn.to_v(hidden_states)
106
+
107
+ for i in range(self.n_loras):
108
+ query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
109
+ key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
110
+ value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
111
+
112
+ inner_dim = key.shape[-1]
113
+ head_dim = inner_dim // attn.heads
114
+
115
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
116
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
117
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
118
+
119
+ if attn.norm_q is not None:
120
+ query = attn.norm_q(query)
121
+ if attn.norm_k is not None:
122
+ key = attn.norm_k(key)
123
+
124
+ if image_rotary_emb is not None:
125
+ from diffusers.models.embeddings import apply_rotary_emb
126
+ query = apply_rotary_emb(query, image_rotary_emb)
127
+ key = apply_rotary_emb(key, image_rotary_emb)
128
+
129
+ hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
130
+
131
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
132
+ hidden_states = hidden_states.to(query.dtype)
133
+
134
+ return hidden_states
135
+
136
+
137
+ class MultiDoubleStreamBlockLoraProcessor(nn.Module):
138
+ def __init__(self, in_features: int, out_features: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, n_loras=1):
139
+ super().__init__()
140
+
141
+ # Initialize a list to store the LoRA layers
142
+ self.n_loras = n_loras
143
+ self.q_loras = nn.ModuleList([
144
+ LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
145
+ for i in range(n_loras)
146
+ ])
147
+ self.k_loras = nn.ModuleList([
148
+ LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
149
+ for i in range(n_loras)
150
+ ])
151
+ self.v_loras = nn.ModuleList([
152
+ LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
153
+ for i in range(n_loras)
154
+ ])
155
+ self.proj_loras = nn.ModuleList([
156
+ LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
157
+ for i in range(n_loras)
158
+ ])
159
+ self.lora_weights = lora_weights
160
+
161
+
162
+ def __call__(self,
163
+ attn: Attention,
164
+ hidden_states: torch.FloatTensor,
165
+ encoder_hidden_states: torch.FloatTensor = None,
166
+ attention_mask: Optional[torch.FloatTensor] = None,
167
+ image_rotary_emb: Optional[torch.Tensor] = None,
168
+ use_cond=False,
169
+ ) -> torch.FloatTensor:
170
+
171
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
172
+
173
+ # `context` projections.
174
+ inner_dim = 3072
175
+ head_dim = inner_dim // attn.heads
176
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
177
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
178
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
179
+
180
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
181
+ batch_size, -1, attn.heads, head_dim
182
+ ).transpose(1, 2)
183
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
184
+ batch_size, -1, attn.heads, head_dim
185
+ ).transpose(1, 2)
186
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
187
+ batch_size, -1, attn.heads, head_dim
188
+ ).transpose(1, 2)
189
+
190
+ if attn.norm_added_q is not None:
191
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
192
+ if attn.norm_added_k is not None:
193
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
194
+
195
+ query = attn.to_q(hidden_states)
196
+ key = attn.to_k(hidden_states)
197
+ value = attn.to_v(hidden_states)
198
+ for i in range(self.n_loras):
199
+ query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
200
+ key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
201
+ value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
202
+
203
+ inner_dim = key.shape[-1]
204
+ head_dim = inner_dim // attn.heads
205
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
206
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
207
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
208
+
209
+ if attn.norm_q is not None:
210
+ query = attn.norm_q(query)
211
+ if attn.norm_k is not None:
212
+ key = attn.norm_k(key)
213
+
214
+ # attention
215
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
216
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
217
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
218
+
219
+ if image_rotary_emb is not None:
220
+ from diffusers.models.embeddings import apply_rotary_emb
221
+ query = apply_rotary_emb(query, image_rotary_emb)
222
+ key = apply_rotary_emb(key, image_rotary_emb)
223
+
224
+ hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
225
+
226
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
227
+ hidden_states = hidden_states.to(query.dtype)
228
+
229
+ encoder_hidden_states, hidden_states = (
230
+ hidden_states[:, : encoder_hidden_states.shape[1]],
231
+ hidden_states[:, encoder_hidden_states.shape[1] :],
232
+ )
233
+
234
+ # Linear projection (with LoRA weight applied to each proj layer)
235
+ hidden_states = attn.to_out[0](hidden_states)
236
+ for i in range(self.n_loras):
237
+ hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states)
238
+ # dropout
239
+ hidden_states = attn.to_out[1](hidden_states)
240
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
241
+
242
+ return (hidden_states, encoder_hidden_states)
243
+
244
+
245
+ class MultiSingleStreamBlockLoraProcessorWithLoss(nn.Module):
246
+ def __init__(self, in_features: int, out_features: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, n_loras=1):
247
+ super().__init__()
248
+ # Initialize a list to store the LoRA layers
249
+ self.n_loras = n_loras
250
+ self.q_loras = nn.ModuleList([
251
+ LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
252
+ for i in range(n_loras)
253
+ ])
254
+ self.k_loras = nn.ModuleList([
255
+ LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
256
+ for i in range(n_loras)
257
+ ])
258
+ self.v_loras = nn.ModuleList([
259
+ LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
260
+ for i in range(n_loras)
261
+ ])
262
+ self.lora_weights = lora_weights
263
+
264
+
265
+ def __call__(self,
266
+ attn: Attention,
267
+ hidden_states: torch.FloatTensor,
268
+ encoder_hidden_states: torch.FloatTensor = None,
269
+ attention_mask: Optional[torch.FloatTensor] = None,
270
+ image_rotary_emb: Optional[torch.Tensor] = None,
271
+ use_cond = False,
272
+ ) -> torch.FloatTensor:
273
+
274
+ batch_size, seq_len, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
275
+ query = attn.to_q(hidden_states)
276
+ key = attn.to_k(hidden_states)
277
+ value = attn.to_v(hidden_states)
278
+ encoder_hidden_length = 512
279
+
280
+ length = (hidden_states.shape[-2] - encoder_hidden_length) // 3
281
+
282
+
283
+ for i in range(self.n_loras):
284
+ query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
285
+ key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
286
+ value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
287
+
288
+ inner_dim = key.shape[-1]
289
+ head_dim = inner_dim // attn.heads
290
+
291
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
292
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
293
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
294
+
295
+ if attn.norm_q is not None:
296
+ query = attn.norm_q(query)
297
+ if attn.norm_k is not None:
298
+ key = attn.norm_k(key)
299
+
300
+ if image_rotary_emb is not None:
301
+ from diffusers.models.embeddings import apply_rotary_emb
302
+ query = apply_rotary_emb(query, image_rotary_emb)
303
+ key = apply_rotary_emb(key, image_rotary_emb)
304
+
305
+ # query_cond_a = query[:, :, encoder_hidden_length+length : encoder_hidden_length+2*length, :]
306
+ # query_cond_b = query[:, :, encoder_hidden_length+2*length : encoder_hidden_length+3*length, :]
307
+
308
+ # key_noise = key[:, :, encoder_hidden_length:encoder_hidden_length+length, :]
309
+
310
+
311
+ # attention_probs_query_a_key_noise = scaled_dot_product_average_attention_map(query_cond_a, key_noise, attn_mask=attention_mask, is_causal=False)
312
+ # attention_probs_query_b_key_noise = scaled_dot_product_average_attention_map(query_cond_b, key_noise, attn_mask=attention_mask, is_causal=False)
313
+
314
+ # attn.attention_probs_query_a_key_noise = attention_probs_query_a_key_noise
315
+ # attn.attention_probs_query_b_key_noise = attention_probs_query_b_key_noise
316
+
317
+
318
+ hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
319
+
320
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
321
+ hidden_states = hidden_states.to(query.dtype)
322
+
323
+ return hidden_states
324
+
325
+
326
+ class MultiDoubleStreamBlockLoraProcessorWithLoss(nn.Module):
327
+ def __init__(self, in_features: int, out_features: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, n_loras=1):
328
+ super().__init__()
329
+
330
+ # Initialize a list to store the LoRA layers
331
+ self.n_loras = n_loras
332
+ self.q_loras = nn.ModuleList([
333
+ LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
334
+ for i in range(n_loras)
335
+ ])
336
+ self.k_loras = nn.ModuleList([
337
+ LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
338
+ for i in range(n_loras)
339
+ ])
340
+ self.v_loras = nn.ModuleList([
341
+ LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
342
+ for i in range(n_loras)
343
+ ])
344
+ self.proj_loras = nn.ModuleList([
345
+ LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
346
+ for i in range(n_loras)
347
+ ])
348
+ self.lora_weights = lora_weights
349
+
350
+
351
+ def __call__(self,
352
+ attn: Attention,
353
+ hidden_states: torch.FloatTensor,
354
+ encoder_hidden_states: torch.FloatTensor = None,
355
+ attention_mask: Optional[torch.FloatTensor] = None,
356
+ image_rotary_emb: Optional[torch.Tensor] = None,
357
+ use_cond=False,
358
+ ) -> torch.FloatTensor:
359
+
360
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
361
+
362
+ # `context` projections.
363
+ inner_dim = 3072
364
+ head_dim = inner_dim // attn.heads
365
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
366
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
367
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
368
+
369
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
370
+ batch_size, -1, attn.heads, head_dim
371
+ ).transpose(1, 2)
372
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
373
+ batch_size, -1, attn.heads, head_dim
374
+ ).transpose(1, 2)
375
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
376
+ batch_size, -1, attn.heads, head_dim
377
+ ).transpose(1, 2)
378
+
379
+ if attn.norm_added_q is not None:
380
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
381
+ if attn.norm_added_k is not None:
382
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
383
+
384
+ query = attn.to_q(hidden_states)
385
+ key = attn.to_k(hidden_states)
386
+ value = attn.to_v(hidden_states)
387
+ length = hidden_states.shape[-2] // 3
388
+
389
+ for i in range(self.n_loras):
390
+ query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
391
+ key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
392
+ value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
393
+
394
+ inner_dim = key.shape[-1]
395
+ head_dim = inner_dim // attn.heads
396
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
397
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
398
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
399
+
400
+ if attn.norm_q is not None:
401
+ query = attn.norm_q(query)
402
+ if attn.norm_k is not None:
403
+ key = attn.norm_k(key)
404
+
405
+ # attention
406
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
407
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
408
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
409
+
410
+ if image_rotary_emb is not None:
411
+ from diffusers.models.embeddings import apply_rotary_emb
412
+ query = apply_rotary_emb(query, image_rotary_emb)
413
+ key = apply_rotary_emb(key, image_rotary_emb)
414
+ encoder_hidden_length = 512
415
+
416
+ query_cond_a = query[:, :, encoder_hidden_length+length : encoder_hidden_length+2*length, :]
417
+ query_cond_b = query[:, :, encoder_hidden_length+2*length : encoder_hidden_length+3*length, :]
418
+
419
+ key_noise = key[:, :, encoder_hidden_length:encoder_hidden_length+length, :]
420
+
421
+ attention_probs_query_a_key_noise = scaled_dot_product_average_attention_map(query_cond_a, key_noise, attn_mask=attention_mask, is_causal=False)
422
+ attention_probs_query_b_key_noise = scaled_dot_product_average_attention_map(query_cond_b, key_noise, attn_mask=attention_mask, is_causal=False)
423
+
424
+ attn.attention_probs_query_a_key_noise = attention_probs_query_a_key_noise
425
+ attn.attention_probs_query_b_key_noise = attention_probs_query_b_key_noise
426
+
427
+ hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
428
+
429
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
430
+ hidden_states = hidden_states.to(query.dtype)
431
+
432
+ encoder_hidden_states, hidden_states = (
433
+ hidden_states[:, : encoder_hidden_states.shape[1]],
434
+ hidden_states[:, encoder_hidden_states.shape[1] :],
435
+ )
436
+
437
+ # Linear projection (with LoRA weight applied to each proj layer)
438
+ hidden_states = attn.to_out[0](hidden_states)
439
+ for i in range(self.n_loras):
440
+ hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states)
441
+ # dropout
442
+ hidden_states = attn.to_out[1](hidden_states)
443
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
444
+
445
+ return (hidden_states, encoder_hidden_states)
446
+
447
+
448
+
449
+ class MultiDoubleStreamBlockLoraProcessor_visual(nn.Module):
450
+ def __init__(self, in_features: int, out_features: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, n_loras=1):
451
+ super().__init__()
452
+
453
+ # Initialize a list to store the LoRA layers
454
+ self.n_loras = n_loras
455
+ self.q_loras = nn.ModuleList([
456
+ LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
457
+ for i in range(n_loras)
458
+ ])
459
+ self.k_loras = nn.ModuleList([
460
+ LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
461
+ for i in range(n_loras)
462
+ ])
463
+ self.v_loras = nn.ModuleList([
464
+ LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
465
+ for i in range(n_loras)
466
+ ])
467
+ self.proj_loras = nn.ModuleList([
468
+ LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
469
+ for i in range(n_loras)
470
+ ])
471
+ self.lora_weights = lora_weights
472
+
473
+
474
+ def __call__(self,
475
+ attn: Attention,
476
+ hidden_states: torch.FloatTensor,
477
+ encoder_hidden_states: torch.FloatTensor = None,
478
+ attention_mask: Optional[torch.FloatTensor] = None,
479
+ image_rotary_emb: Optional[torch.Tensor] = None,
480
+ use_cond=False,
481
+ ) -> torch.FloatTensor:
482
+
483
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
484
+
485
+ # `context` projections.
486
+ inner_dim = 3072
487
+ head_dim = inner_dim // attn.heads
488
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
489
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
490
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
491
+
492
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
493
+ batch_size, -1, attn.heads, head_dim
494
+ ).transpose(1, 2)
495
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
496
+ batch_size, -1, attn.heads, head_dim
497
+ ).transpose(1, 2)
498
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
499
+ batch_size, -1, attn.heads, head_dim
500
+ ).transpose(1, 2)
501
+
502
+ if attn.norm_added_q is not None:
503
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
504
+ if attn.norm_added_k is not None:
505
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
506
+
507
+ query = attn.to_q(hidden_states)
508
+ key = attn.to_k(hidden_states)
509
+ value = attn.to_v(hidden_states)
510
+ length = hidden_states.shape[-2] // 3
511
+
512
+ for i in range(self.n_loras):
513
+ query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
514
+ key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
515
+ value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
516
+
517
+ inner_dim = key.shape[-1]
518
+ head_dim = inner_dim // attn.heads
519
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
520
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
521
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
522
+
523
+ if attn.norm_q is not None:
524
+ query = attn.norm_q(query)
525
+ if attn.norm_k is not None:
526
+ key = attn.norm_k(key)
527
+
528
+ # attention
529
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
530
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
531
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
532
+
533
+ if image_rotary_emb is not None:
534
+ from diffusers.models.embeddings import apply_rotary_emb
535
+ query = apply_rotary_emb(query, image_rotary_emb)
536
+ key = apply_rotary_emb(key, image_rotary_emb)
537
+ encoder_hidden_length = 512
538
+
539
+ query_cond_a = query[:, :, encoder_hidden_length+length : encoder_hidden_length+2*length, :]
540
+ query_cond_b = query[:, :, encoder_hidden_length+2*length : encoder_hidden_length+3*length, :]
541
+
542
+ key_noise = key[:, :, encoder_hidden_length:encoder_hidden_length+length, :]
543
+
544
+ attention_probs_query_a_key_noise = scaled_dot_product_average_attention_map(query_cond_a, key_noise, attn_mask=attention_mask, is_causal=False)
545
+ attention_probs_query_b_key_noise = scaled_dot_product_average_attention_map(query_cond_b, key_noise, attn_mask=attention_mask, is_causal=False)
546
+
547
+ if not hasattr(attn, 'attention_probs_query_a_key_noise'):
548
+ attn.attention_probs_query_a_key_noise = []
549
+ if not hasattr(attn, 'attention_probs_query_b_key_noise'):
550
+ attn.attention_probs_query_b_key_noise = []
551
+
552
+ global global_timestep
553
+
554
+ attn.attention_probs_query_a_key_noise.append((global_timestep//19, attention_probs_query_a_key_noise))
555
+ attn.attention_probs_query_b_key_noise.append((global_timestep//19, attention_probs_query_b_key_noise))
556
+
557
+ print(f"Global Timestep: {global_timestep//19}")
558
+
559
+ global_timestep += 1
560
+
561
+ hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
562
+
563
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
564
+ hidden_states = hidden_states.to(query.dtype)
565
+
566
+ encoder_hidden_states, hidden_states = (
567
+ hidden_states[:, : encoder_hidden_states.shape[1]],
568
+ hidden_states[:, encoder_hidden_states.shape[1] :],
569
+ )
570
+
571
+ # Linear projection (with LoRA weight applied to each proj layer)
572
+ hidden_states = attn.to_out[0](hidden_states)
573
+ for i in range(self.n_loras):
574
+ hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states)
575
+ # dropout
576
+ hidden_states = attn.to_out[1](hidden_states)
577
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
578
+
579
+ return (hidden_states, encoder_hidden_states)
580
+
581
+
582
+
583
+ class MultiSingleStreamBlockLoraProcessor_visual(nn.Module):
584
+ def __init__(self, in_features: int, out_features: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, n_loras=1):
585
+ super().__init__()
586
+ # Initialize a list to store the LoRA layers
587
+ self.n_loras = n_loras
588
+ self.q_loras = nn.ModuleList([
589
+ LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
590
+ for i in range(n_loras)
591
+ ])
592
+ self.k_loras = nn.ModuleList([
593
+ LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
594
+ for i in range(n_loras)
595
+ ])
596
+ self.v_loras = nn.ModuleList([
597
+ LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
598
+ for i in range(n_loras)
599
+ ])
600
+ self.lora_weights = lora_weights
601
+
602
+
603
+ def __call__(self,
604
+ attn: Attention,
605
+ hidden_states: torch.FloatTensor,
606
+ encoder_hidden_states: torch.FloatTensor = None,
607
+ attention_mask: Optional[torch.FloatTensor] = None,
608
+ image_rotary_emb: Optional[torch.Tensor] = None,
609
+ use_cond = False,
610
+ ) -> torch.FloatTensor:
611
+
612
+ batch_size, seq_len, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
613
+ query = attn.to_q(hidden_states)
614
+ key = attn.to_k(hidden_states)
615
+ value = attn.to_v(hidden_states)
616
+ encoder_hidden_length = 512
617
+
618
+ length = (hidden_states.shape[-2] - encoder_hidden_length) // 3
619
+
620
+
621
+ for i in range(self.n_loras):
622
+ query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
623
+ key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
624
+ value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
625
+
626
+ inner_dim = key.shape[-1]
627
+ head_dim = inner_dim // attn.heads
628
+
629
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
630
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
631
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
632
+
633
+ if attn.norm_q is not None:
634
+ query = attn.norm_q(query)
635
+ if attn.norm_k is not None:
636
+ key = attn.norm_k(key)
637
+
638
+ if image_rotary_emb is not None:
639
+ from diffusers.models.embeddings import apply_rotary_emb
640
+ query = apply_rotary_emb(query, image_rotary_emb)
641
+ key = apply_rotary_emb(key, image_rotary_emb)
642
+
643
+ if not hasattr(attn, 'attention_probs_query_a_key_noise2'):
644
+ attn.attention_probs_query_a_key_noise2 = []
645
+ if not hasattr(attn, 'attention_probs_query_b_key_noise2'):
646
+ attn.attention_probs_query_b_key_noise2 = []
647
+
648
+ query_cond_a = query[:, :, encoder_hidden_length+length : encoder_hidden_length+2*length, :]
649
+ query_cond_b = query[:, :, encoder_hidden_length+2*length : encoder_hidden_length+3*length, :]
650
+
651
+ key_noise = key[:, :, encoder_hidden_length:encoder_hidden_length+length, :]
652
+
653
+ attention_probs_query_a_key_noise2 = scaled_dot_product_average_attention_map(query_cond_a, key_noise, attn_mask=attention_mask, is_causal=False)
654
+ attention_probs_query_b_key_noise2 = scaled_dot_product_average_attention_map(query_cond_b, key_noise, attn_mask=attention_mask, is_causal=False)
655
+
656
+
657
+ global global_timestep2
658
+
659
+ attn.attention_probs_query_a_key_noise2.append((global_timestep2//38, attention_probs_query_a_key_noise2))
660
+ attn.attention_probs_query_b_key_noise2.append((global_timestep2//38, attention_probs_query_b_key_noise2))
661
+
662
+ print(f"Global Timestep2: {global_timestep2//38}")
663
+
664
+ global_timestep2 += 1
665
+
666
+
667
+
668
+ hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
669
+
670
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
671
+ hidden_states = hidden_states.to(query.dtype)
672
+
673
+ return hidden_states
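For reference, every processor above computes its q/k/v (and output) projections as the frozen base linear plus a weighted sum of LoRA residuals. Below is a minimal stand-alone sketch of that composition; `SketchLoRALinear`, the feature sizes, and the rank/alpha values are illustrative assumptions standing in for the `LoRALinearLayer` defined earlier in src/layers.py, not the actual implementation.

import torch
import torch.nn as nn

class SketchLoRALinear(nn.Module):
    # Illustrative stand-in for LoRALinearLayer: rank-r down projection followed by an up projection.
    def __init__(self, in_features, out_features, rank, alpha):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        self.scale = alpha / rank

    def forward(self, x):
        return self.up(self.down(x)) * self.scale

base_to_q = nn.Linear(3072, 3072)                                   # plays the role of attn.to_q
q_loras = nn.ModuleList([SketchLoRALinear(3072, 3072, 128, 128) for _ in range(2)])
lora_weights = [1.0, 1.0]

hidden_states = torch.randn(1, 16, 3072)
query = base_to_q(hidden_states)
for w, lora in zip(lora_weights, q_loras):
    query = query + w * lora(hidden_states)                         # same accumulation as in the processors above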
src/lora_helper.py ADDED
@@ -0,0 +1,267 @@
1
+ from diffusers.models.attention_processor import FluxAttnProcessor2_0
2
+ from safetensors import safe_open
3
+ import re
4
+ import torch
5
+ from .layers import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor, MultiDoubleStreamBlockLoraProcessor_visual, MultiDoubleStreamBlockLoraProcessorWithLoss, MultiSingleStreamBlockLoraProcessor_visual
6
+
7
+
8
+
9
+ device = "cuda:0"
10
+
11
+ def load_safetensors(path):
12
+ tensors = {}
13
+ with safe_open(path, framework="pt", device="cpu") as f:
14
+ for key in f.keys():
15
+ tensors[key] = f.get_tensor(key)
16
+ return tensors
17
+
18
+ def get_lora_rank(checkpoint):
19
+ for k in checkpoint.keys():
20
+ if k.endswith(".down.weight"):
21
+ return checkpoint[k].shape[0]
22
+
23
+ def load_checkpoint(local_path):
24
+ if local_path is not None:
25
+ if '.safetensors' in local_path:
26
+ print(f"Loading .safetensors checkpoint from {local_path}")
27
+ checkpoint = load_safetensors(local_path)
28
+ else:
29
+ print(f"Loading checkpoint from {local_path}")
30
+ checkpoint = torch.load(local_path, map_location='cpu')
31
+ return checkpoint
32
+
33
+ def update_model_with_lora(checkpoint, lora_weights, transformer):
34
+ number = len(lora_weights)
35
+ ranks = [get_lora_rank(checkpoint) for _ in range(number)]
36
+ lora_attn_procs = {}
37
+ double_blocks_idx = list(range(19))
38
+ single_blocks_idx = list(range(38))
39
+ for name, attn_processor in transformer.attn_processors.items():
40
+ match = re.search(r'\.(\d+)\.', name)
41
+ if match:
42
+ layer_index = int(match.group(1))
43
+
44
+ if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
45
+
46
+ lora_state_dicts = {}
47
+ for key, value in checkpoint.items():
48
+ # Match based on the layer index in the key (assuming the key contains layer index)
49
+ if re.search(r'\.(\d+)\.', key):
50
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
51
+ if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
52
+ lora_state_dicts[key] = value
53
+
54
+ lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
55
+ in_features=3072, out_features=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, n_loras=number
56
+ )
57
+
58
+ # Load the weights from the checkpoint dictionary into the corresponding layers
59
+ for n in range(number):
60
+ lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
61
+ lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
62
+ lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
63
+ lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
64
+ lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
65
+ lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
66
+ lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None)
67
+ lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None)
68
+ lora_attn_procs[name].to(device)
69
+
70
+ elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
71
+
72
+ lora_state_dicts = {}
73
+ for key, value in checkpoint.items():
74
+ # Match based on the layer index in the key (assuming the key contains layer index)
75
+ if re.search(r'\.(\d+)\.', key):
76
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
77
+ if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
78
+ lora_state_dicts[key] = value
79
+
80
+ lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
81
+ in_features=3072, out_features=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, n_loras=number
82
+ )
83
+ # Load the weights from the checkpoint dictionary into the corresponding layers
84
+ for n in range(number):
85
+ lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
86
+ lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
87
+ lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
88
+ lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
89
+ lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
90
+ lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
91
+ lora_attn_procs[name].to(device)
92
+ else:
93
+ lora_attn_procs[name] = FluxAttnProcessor2_0()
94
+
95
+ transformer.set_attn_processor(lora_attn_procs)
96
+
97
+
98
+ def update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size):
99
+ ck_number = len(checkpoints)
100
+ cond_lora_number = [len(ls) for ls in lora_weights]
101
+ cond_number = sum(cond_lora_number)
102
+ ranks = [get_lora_rank(checkpoint) for checkpoint in checkpoints]
103
+ multi_lora_weight = []
104
+ for ls in lora_weights:
105
+ for n in ls:
106
+ multi_lora_weight.append(n)
107
+
108
+ lora_attn_procs = {}
109
+ double_blocks_idx = list(range(19))
110
+ single_blocks_idx = list(range(38))
111
+ for name, attn_processor in transformer.attn_processors.items():
112
+ match = re.search(r'\.(\d+)\.', name)
113
+ if match:
114
+ layer_index = int(match.group(1))
115
+
116
+ if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
117
+ lora_state_dicts = [{} for _ in range(ck_number)]
118
+ for idx, checkpoint in enumerate(checkpoints):
119
+ for key, value in checkpoint.items():
120
+ # Match based on the layer index in the key (assuming the key contains layer index)
121
+ if re.search(r'\.(\d+)\.', key):
122
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
123
+ if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
124
+ lora_state_dicts[idx][key] = value
125
+
126
+ lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
127
+ in_features=3072, out_features=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, n_loras=cond_number
128
+ )
129
+
130
+ # Load the weights from the checkpoint dictionary into the corresponding layers
131
+ num = 0
132
+ for idx in range(ck_number):
133
+ for n in range(cond_lora_number[idx]):
134
+ lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None)
135
+ lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None)
136
+ lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None)
137
+ lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None)
138
+ lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None)
139
+ lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None)
140
+ lora_attn_procs[name].proj_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.down.weight', None)
141
+ lora_attn_procs[name].proj_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.up.weight', None)
142
+ lora_attn_procs[name].to(device)
143
+ num += 1
144
+
145
+ elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
146
+
147
+ lora_state_dicts = [{} for _ in range(ck_number)]
148
+ for idx, checkpoint in enumerate(checkpoints):
149
+ for key, value in checkpoint.items():
150
+ # Match based on the layer index in the key (assuming the key contains layer index)
151
+ if re.search(r'\.(\d+)\.', key):
152
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
153
+ if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
154
+ lora_state_dicts[idx][key] = value
155
+
156
+ lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
157
+ in_features=3072, out_features=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, n_loras=cond_number
158
+ )
159
+ # Load the weights from the checkpoint dictionary into the corresponding layers
160
+ num = 0
161
+ for idx in range(ck_number):
162
+ for n in range(cond_lora_number[idx]):
163
+ lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None)
164
+ lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None)
165
+ lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None)
166
+ lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None)
167
+ lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None)
168
+ lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None)
169
+ lora_attn_procs[name].to(device)
170
+ num += 1
171
+
172
+ else:
173
+ lora_attn_procs[name] = FluxAttnProcessor2_0()
174
+
175
+ transformer.set_attn_processor(lora_attn_procs)
176
+
177
+
178
+ def set_single_lora(transformer, local_path, lora_weights=[]):
179
+ checkpoint = load_checkpoint(local_path)
180
+ update_model_with_lora(checkpoint, lora_weights, transformer)
181
+
182
+ def set_single_lora_visual(transformer, local_path, lora_weights=[]):
183
+ checkpoint = load_checkpoint(local_path)
184
+ update_model_with_lora_with_visual(checkpoint, lora_weights, transformer)
185
+
186
+ def set_multi_lora(transformer, local_paths, lora_weights=[[]], cond_size=512):
187
+ checkpoints = [load_checkpoint(local_path) for local_path in local_paths]
188
+ update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size)
189
+
190
+ def unset_lora(transformer):
191
+ lora_attn_procs = {}
192
+ for name, attn_processor in transformer.attn_processors.items():
193
+ lora_attn_procs[name] = FluxAttnProcessor2_0()
194
+ transformer.set_attn_processor(lora_attn_procs)
195
+
196
+ def update_model_with_lora_with_visual(checkpoint, lora_weights, transformer):
197
+ number = len(lora_weights)
198
+ ranks = [get_lora_rank(checkpoint) for _ in range(number)]
199
+ lora_attn_procs = {}
200
+ double_blocks_idx = list(range(19))
201
+ single_blocks_idx = list(range(38))
202
+ for name, attn_processor in transformer.attn_processors.items():
203
+ match = re.search(r'\.(\d+)\.', name)
204
+ if match:
205
+ layer_index = int(match.group(1))
206
+
207
+ if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
208
+
209
+ lora_state_dicts = {}
210
+ for key, value in checkpoint.items():
211
+ # Match based on the layer index in the key (assuming the key contains layer index)
212
+ if re.search(r'\.(\d+)\.', key):
213
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
214
+ if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
215
+ lora_state_dicts[key] = value
216
+
217
+ lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor_visual(
218
+ in_features=3072, out_features=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, n_loras=number
219
+ )
220
+
221
+ # Load the weights from the checkpoint dictionary into the corresponding layers
222
+ # for n in range(number):
223
+ # lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
224
+ # lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
225
+ # lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
226
+ # lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
227
+ # lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
228
+ # lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
229
+ # lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None)
230
+ # lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None)
231
+ # lora_attn_procs[name].to(device)
232
+
233
+ elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
234
+
235
+ lora_state_dicts = {}
236
+ for key, value in checkpoint.items():
237
+ # Match based on the layer index in the key (assuming the key contains layer index)
238
+ if re.search(r'\.(\d+)\.', key):
239
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
240
+ if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
241
+ lora_state_dicts[key] = value
242
+
243
+ lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor_visual(
244
+ in_features=3072, out_features=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, n_loras=number
245
+ )
246
+ # Load the weights from the checkpoint dictionary into the corresponding layers
247
+ # for n in range(number):
248
+ # lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
249
+ # lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
250
+ # lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
251
+ # lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
252
+ # lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
253
+ # lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
254
+ # lora_attn_procs[name].to(device)
255
+ else:
256
+ lora_attn_procs[name] = FluxAttnProcessor2_0()
257
+
258
+ transformer.set_attn_processor(lora_attn_procs)
259
+
260
+
261
+
262
+ '''
263
+ unset_lora(pipe.transformer)
264
+ lora_path = "./lora.safetensors"
265
+ lora_weights = [1, 1]
266
+ set_single_lora(pipe.transformer, local_path=lora_path, lora_weights=lora_weights)
267
+ '''
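A slightly fuller usage sketch of the helpers defined above; the `pipe` object and the checkpoint paths are placeholders, and only the call signatures are taken from this file.

from src.lora_helper import set_single_lora, set_multi_lora, unset_lora

# one checkpoint driving a single condition branch
set_single_lora(pipe.transformer, "./lora.safetensors", lora_weights=[1.0])

# two checkpoints, one condition each, applied at half strength
set_multi_lora(pipe.transformer,
               ["./subject.safetensors", "./style.safetensors"],
               lora_weights=[[0.5], [0.5]], cond_size=512)

# restore the stock FluxAttnProcessor2_0 on every block
unset_lora(pipe.transformer)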
src/pipeline.py ADDED
@@ -0,0 +1,805 @@
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
7
+
8
+ from diffusers.image_processor import (VaeImageProcessor)
9
+ from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin
10
+ from diffusers.models.autoencoders import AutoencoderKL
11
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
12
+ from diffusers.utils import (
13
+ USE_PEFT_BACKEND,
14
+ is_torch_xla_available,
15
+ logging,
16
+ scale_lora_layers,
17
+ unscale_lora_layers,
18
+ )
19
+ from diffusers.utils.torch_utils import randn_tensor
20
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
21
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
22
+ from torchvision.transforms.functional import pad
23
+ from .transformer_flux import FluxTransformer2DModel
24
+
25
+ if is_torch_xla_available():
26
+ import torch_xla.core.xla_model as xm
27
+
28
+ XLA_AVAILABLE = True
29
+ else:
30
+ XLA_AVAILABLE = False
31
+
32
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33
+
34
+ def calculate_shift(
35
+ image_seq_len,
36
+ base_seq_len: int = 256,
37
+ max_seq_len: int = 4096,
38
+ base_shift: float = 0.5,
39
+ max_shift: float = 1.16,
40
+ ):
41
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
42
+ b = base_shift - m * base_seq_len
43
+ mu = image_seq_len * m + b
44
+ return mu
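The shift is a straight line through (base_seq_len, base_shift) and (max_seq_len, max_shift); with the defaults above:

calculate_shift(256)   # -> 0.50  (base_shift)
calculate_shift(2176)  # -> 0.83  (midpoint)
calculate_shift(4096)  # -> 1.16  (max_shift)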
45
+
46
+ def prepare_latent_image_ids_2(height, width, device, dtype):
47
+ latent_image_ids = torch.zeros(height//2, width//2, 3, device=device, dtype=dtype)
48
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height//2, device=device)[:, None] # y coordinate
50
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width//2, device=device)[None, :] # x coordinate
50
+ return latent_image_ids
51
+
52
+ def prepare_latent_subject_ids(height, width, device, dtype):
53
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3, device=device, dtype=dtype)
54
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2, device=device)[:, None]
55
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2, device=device)[None, :]
56
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
57
+ latent_image_ids = latent_image_ids.reshape(
58
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
59
+ )
60
+ return latent_image_ids.to(device=device, dtype=dtype)
61
+
62
+ def resize_position_encoding(batch_size, original_height, original_width, target_height, target_width, device, dtype):
63
+ latent_image_ids = prepare_latent_image_ids_2(original_height, original_width, device, dtype)
64
+ scale_h = original_height / target_height
65
+ scale_w = original_width / target_width
66
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
67
+ latent_image_ids = latent_image_ids.reshape(
68
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
69
+ )
70
+ # interpolate the spatial position encoding (PE) to the condition resolution
71
+ latent_image_ids_resized = torch.zeros(target_height//2, target_width//2, 3, device=device, dtype=dtype)
72
+ for i in range(target_height//2):
73
+ for j in range(target_width//2):
74
+ latent_image_ids_resized[i, j, 1] = i*scale_h
75
+ latent_image_ids_resized[i, j, 2] = j*scale_w
76
+ cond_latent_image_id_height, cond_latent_image_id_width, cond_latent_image_id_channels = latent_image_ids_resized.shape
77
+ cond_latent_image_ids = latent_image_ids_resized.reshape(
78
+ cond_latent_image_id_height * cond_latent_image_id_width, cond_latent_image_id_channels
79
+ )
80
+ # latent_image_ids_ = torch.concat([latent_image_ids, cond_latent_image_ids], dim=0)
81
+ return latent_image_ids, cond_latent_image_ids #, latent_image_ids_
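The nested loop above only fills scaled row/column indices; an equivalent vectorized sketch (same variables, same result up to dtype rounding) would be:

rows = torch.arange(target_height // 2, device=device, dtype=dtype) * scale_h
cols = torch.arange(target_width // 2, device=device, dtype=dtype) * scale_w
latent_image_ids_resized = torch.zeros(target_height // 2, target_width // 2, 3, device=device, dtype=dtype)
latent_image_ids_resized[..., 1] = rows[:, None]
latent_image_ids_resized[..., 2] = cols[None, :]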
82
+
83
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
84
+ def retrieve_latents(
85
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
86
+ ):
87
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
88
+ return encoder_output.latent_dist.sample(generator)
89
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
90
+ return encoder_output.latent_dist.mode()
91
+ elif hasattr(encoder_output, "latents"):
92
+ return encoder_output.latents
93
+ else:
94
+ raise AttributeError("Could not access latents of provided encoder_output")
95
+
96
+
97
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
98
+ def retrieve_timesteps(
99
+ scheduler,
100
+ num_inference_steps: Optional[int] = None,
101
+ device: Optional[Union[str, torch.device]] = None,
102
+ timesteps: Optional[List[int]] = None,
103
+ sigmas: Optional[List[float]] = None,
104
+ **kwargs,
105
+ ):
106
+ """
107
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
108
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
109
+
110
+ Args:
111
+ scheduler (`SchedulerMixin`):
112
+ The scheduler to get timesteps from.
113
+ num_inference_steps (`int`):
114
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
115
+ must be `None`.
116
+ device (`str` or `torch.device`, *optional*):
117
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
118
+ timesteps (`List[int]`, *optional*):
119
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
120
+ `num_inference_steps` and `sigmas` must be `None`.
121
+ sigmas (`List[float]`, *optional*):
122
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
123
+ `num_inference_steps` and `timesteps` must be `None`.
124
+
125
+ Returns:
126
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
127
+ second element is the number of inference steps.
128
+ """
129
+ if timesteps is not None and sigmas is not None:
130
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
131
+ if timesteps is not None:
132
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
133
+ if not accepts_timesteps:
134
+ raise ValueError(
135
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
136
+ f" timestep schedules. Please check whether you are using the correct scheduler."
137
+ )
138
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
139
+ timesteps = scheduler.timesteps
140
+ num_inference_steps = len(timesteps)
141
+ elif sigmas is not None:
142
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
143
+ if not accept_sigmas:
144
+ raise ValueError(
145
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
146
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
147
+ )
148
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
149
+ timesteps = scheduler.timesteps
150
+ num_inference_steps = len(timesteps)
151
+ else:
152
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
153
+ timesteps = scheduler.timesteps
154
+ return timesteps, num_inference_steps
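For reference, the typical call pattern from a Flux-style `__call__` (a sketch only; `image_seq_len`, `num_inference_steps`, `device`, and `self.scheduler` are assumed to be in scope, and `mu` comes from `calculate_shift` above):

sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
mu = calculate_shift(image_seq_len)
timesteps, num_inference_steps = retrieve_timesteps(
    self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu
)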
155
+
156
+
157
+ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
158
+ r"""
159
+ The Flux pipeline for text-to-image generation.
160
+
161
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
162
+
163
+ Args:
164
+ transformer ([`FluxTransformer2DModel`]):
165
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
166
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
167
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
168
+ vae ([`AutoencoderKL`]):
169
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
170
+ text_encoder ([`CLIPTextModel`]):
171
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
172
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
173
+ text_encoder_2 ([`T5EncoderModel`]):
174
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
175
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
176
+ tokenizer (`CLIPTokenizer`):
177
+ Tokenizer of class
178
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
179
+ tokenizer_2 (`T5TokenizerFast`):
180
+ Second Tokenizer of class
181
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
182
+ """
183
+
184
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
185
+ _optional_components = []
186
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
187
+
188
+ def __init__(
189
+ self,
190
+ scheduler: FlowMatchEulerDiscreteScheduler,
191
+ vae: AutoencoderKL,
192
+ text_encoder: CLIPTextModel,
193
+ tokenizer: CLIPTokenizer,
194
+ text_encoder_2: T5EncoderModel,
195
+ tokenizer_2: T5TokenizerFast,
196
+ transformer: FluxTransformer2DModel,
197
+ ):
198
+ super().__init__()
199
+
200
+ self.register_modules(
201
+ vae=vae,
202
+ text_encoder=text_encoder,
203
+ text_encoder_2=text_encoder_2,
204
+ tokenizer=tokenizer,
205
+ tokenizer_2=tokenizer_2,
206
+ transformer=transformer,
207
+ scheduler=scheduler,
208
+ )
209
+ self.vae_scale_factor = (
210
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
211
+ )
212
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
213
+ self.tokenizer_max_length = (
214
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
215
+ )
216
+ self.default_sample_size = 64
217
+
218
+ def _get_t5_prompt_embeds(
219
+ self,
220
+ prompt: Union[str, List[str]] = None,
221
+ num_images_per_prompt: int = 1,
222
+ max_sequence_length: int = 512,
223
+ device: Optional[torch.device] = None,
224
+ dtype: Optional[torch.dtype] = None,
225
+ ):
226
+ device = device or self._execution_device
227
+ dtype = dtype or self.text_encoder.dtype
228
+
229
+ prompt = [prompt] if isinstance(prompt, str) else prompt
230
+ batch_size = len(prompt)
231
+
232
+ text_inputs = self.tokenizer_2(
233
+ prompt,
234
+ padding="max_length",
235
+ max_length=max_sequence_length,
236
+ truncation=True,
237
+ return_length=False,
238
+ return_overflowing_tokens=False,
239
+ return_tensors="pt",
240
+ )
241
+ text_input_ids = text_inputs.input_ids
242
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
243
+
244
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
245
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1: -1])
246
+ logger.warning(
247
+ "The following part of your input was truncated because `max_sequence_length` is set to "
248
+ f" {max_sequence_length} tokens: {removed_text}"
249
+ )
250
+
251
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
252
+
253
+ dtype = self.text_encoder_2.dtype
254
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
255
+
256
+ _, seq_len, _ = prompt_embeds.shape
257
+
258
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
259
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
260
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
261
+
262
+ return prompt_embeds
263
+
264
+ def _get_clip_prompt_embeds(
265
+ self,
266
+ prompt: Union[str, List[str]],
267
+ num_images_per_prompt: int = 1,
268
+ device: Optional[torch.device] = None,
269
+ ):
270
+ device = device or self._execution_device
271
+
272
+ prompt = [prompt] if isinstance(prompt, str) else prompt
273
+ batch_size = len(prompt)
274
+
275
+ text_inputs = self.tokenizer(
276
+ prompt,
277
+ padding="max_length",
278
+ max_length=self.tokenizer_max_length,
279
+ truncation=True,
280
+ return_overflowing_tokens=False,
281
+ return_length=False,
282
+ return_tensors="pt",
283
+ )
284
+
285
+ text_input_ids = text_inputs.input_ids
286
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
287
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
288
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1: -1])
289
+ logger.warning(
290
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
291
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
292
+ )
293
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
294
+
295
+ # Use pooled output of CLIPTextModel
296
+ prompt_embeds = prompt_embeds.pooler_output
297
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
298
+
299
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
300
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
301
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
302
+
303
+ return prompt_embeds
304
+
305
+ def encode_prompt(
306
+ self,
307
+ prompt: Union[str, List[str]],
308
+ prompt_2: Union[str, List[str]],
309
+ device: Optional[torch.device] = None,
310
+ num_images_per_prompt: int = 1,
311
+ prompt_embeds: Optional[torch.FloatTensor] = None,
312
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
313
+ max_sequence_length: int = 512,
314
+ lora_scale: Optional[float] = None,
315
+ ):
316
+ r"""
317
+
318
+ Args:
319
+ prompt (`str` or `List[str]`, *optional*):
320
+ prompt to be encoded
321
+ prompt_2 (`str` or `List[str]`, *optional*):
322
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
323
+ used in all text-encoders
324
+ device: (`torch.device`):
325
+ torch device
326
+ num_images_per_prompt (`int`):
327
+ number of images that should be generated per prompt
328
+ prompt_embeds (`torch.FloatTensor`, *optional*):
329
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
330
+ provided, text embeddings will be generated from `prompt` input argument.
331
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
332
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
333
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
334
+ lora_scale (`float`, *optional*):
335
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
336
+ """
337
+ device = device or self._execution_device
338
+
339
+ # set lora scale so that monkey patched LoRA
340
+ # function of text encoder can correctly access it
341
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
342
+ self._lora_scale = lora_scale
343
+
344
+ # dynamically adjust the LoRA scale
345
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
346
+ scale_lora_layers(self.text_encoder, lora_scale)
347
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
348
+ scale_lora_layers(self.text_encoder_2, lora_scale)
349
+
350
+ prompt = [prompt] if isinstance(prompt, str) else prompt
351
+
352
+ if prompt_embeds is None:
353
+ prompt_2 = prompt_2 or prompt
354
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
355
+
356
+ # We only use the pooled prompt output from the CLIPTextModel
357
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
358
+ prompt=prompt,
359
+ device=device,
360
+ num_images_per_prompt=num_images_per_prompt,
361
+ )
362
+ prompt_embeds = self._get_t5_prompt_embeds(
363
+ prompt=prompt_2,
364
+ num_images_per_prompt=num_images_per_prompt,
365
+ max_sequence_length=max_sequence_length,
366
+ device=device,
367
+ )
368
+
369
+ if self.text_encoder is not None:
370
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
371
+ # Retrieve the original scale by scaling back the LoRA layers
372
+ unscale_lora_layers(self.text_encoder, lora_scale)
373
+
374
+ if self.text_encoder_2 is not None:
375
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
376
+ # Retrieve the original scale by scaling back the LoRA layers
377
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
378
+
379
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
380
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
381
+
382
+ return prompt_embeds, pooled_prompt_embeds, text_ids
383
+
384
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
385
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
386
+ if isinstance(generator, list):
387
+ image_latents = [
388
+ retrieve_latents(self.vae.encode(image[i: i + 1]), generator=generator[i])
389
+ for i in range(image.shape[0])
390
+ ]
391
+ image_latents = torch.cat(image_latents, dim=0)
392
+ else:
393
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
394
+
395
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
396
+
397
+ return image_latents
398
+
399
+ def check_inputs(
400
+ self,
401
+ prompt,
402
+ prompt_2,
403
+ height,
404
+ width,
405
+ prompt_embeds=None,
406
+ pooled_prompt_embeds=None,
407
+ callback_on_step_end_tensor_inputs=None,
408
+ max_sequence_length=None,
409
+ ):
410
+ if height % 8 != 0 or width % 8 != 0:
411
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
412
+
413
+ if callback_on_step_end_tensor_inputs is not None and not all(
414
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
415
+ ):
416
+ raise ValueError(
417
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
418
+ )
419
+
420
+ if prompt is not None and prompt_embeds is not None:
421
+ raise ValueError(
422
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
423
+ " only forward one of the two."
424
+ )
425
+ elif prompt_2 is not None and prompt_embeds is not None:
426
+ raise ValueError(
427
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
428
+ " only forward one of the two."
429
+ )
430
+ elif prompt is None and prompt_embeds is None:
431
+ raise ValueError(
432
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
433
+ )
434
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
435
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
436
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
437
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
438
+
439
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
440
+ raise ValueError(
441
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
442
+ )
443
+
444
+ if max_sequence_length is not None and max_sequence_length > 512:
445
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
446
+
447
+ @staticmethod
448
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
449
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
450
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
451
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
452
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
453
+ latent_image_ids = latent_image_ids.reshape(
454
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
455
+ )
456
+ return latent_image_ids.to(device=device, dtype=dtype)
457
+
458
+ @staticmethod
459
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
460
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
461
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
462
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
463
+ return latents
464
+
465
+ @staticmethod
466
+ def _unpack_latents(latents, height, width, vae_scale_factor):
467
+ batch_size, num_patches, channels = latents.shape
468
+
469
+ height = height // vae_scale_factor
470
+ width = width // vae_scale_factor
471
+
472
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
473
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
474
+
475
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
476
+
477
+ return latents
478
+
479
+ def enable_vae_slicing(self):
480
+ r"""
481
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
482
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
483
+ """
484
+ self.vae.enable_slicing()
485
+
486
+ def disable_vae_slicing(self):
487
+ r"""
488
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
489
+ computing decoding in one step.
490
+ """
491
+ self.vae.disable_slicing()
492
+
493
+ def enable_vae_tiling(self):
494
+ r"""
495
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
496
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and for
497
+ processing larger images.
498
+ """
499
+ self.vae.enable_tiling()
500
+
501
+ def disable_vae_tiling(self):
502
+ r"""
503
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
504
+ computing decoding in one step.
505
+ """
506
+ self.vae.disable_tiling()
507
+
508
+ def prepare_latents(
509
+ self,
510
+ batch_size,
511
+ num_channels_latents,
512
+ height,
513
+ width,
514
+ dtype,
515
+ device,
516
+ generator,
517
+ subject_image,
518
+ condition_image,
519
+ latents=None,
520
+ cond_number=1,
521
+ sub_number=1
522
+ ):
523
+ height_cond = 2 * (self.cond_size // self.vae_scale_factor)
524
+ width_cond = 2 * (self.cond_size // self.vae_scale_factor)
525
+ height = 2 * (int(height) // self.vae_scale_factor)
526
+ width = 2 * (int(width) // self.vae_scale_factor)
527
+
528
+ shape = (batch_size, num_channels_latents, height, width) # 1 16 106 80
529
+ noise_latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
530
+ noise_latents = self._pack_latents(noise_latents, batch_size, num_channels_latents, height, width)
531
+ noise_latent_image_ids, cond_latent_image_ids = resize_position_encoding(
532
+ batch_size,
533
+ height,
534
+ width,
535
+ height_cond,
536
+ width_cond,
537
+ device,
538
+ dtype,
539
+ )
540
+
541
+ latents_to_concat = []  # does not include the noise latents
542
+ latents_ids_to_concat = [noise_latent_image_ids]
543
+
544
+ # subject
545
+ if subject_image is not None:
546
+ shape_subject = (batch_size, num_channels_latents, height_cond*sub_number, width_cond)
547
+ subject_image = subject_image.to(device=device, dtype=dtype)
548
+ subject_image_latents = self._encode_vae_image(image=subject_image, generator=generator)
549
+ subject_latents = self._pack_latents(subject_image_latents, batch_size, num_channels_latents, height_cond*sub_number, width_cond)
550
+ mask2 = torch.zeros(shape_subject, device=device, dtype=dtype)
551
+ mask2 = self._pack_latents(mask2, batch_size, num_channels_latents, height_cond*sub_number, width_cond)
552
+ latent_subject_ids = prepare_latent_subject_ids(height_cond, width_cond, device, dtype)
553
+ latent_subject_ids[:, 1] += 64 # fixed offset
554
+ subject_latent_image_ids = torch.concat([latent_subject_ids for _ in range(sub_number)], dim=-2)
555
+ latents_to_concat.append(subject_latents)
556
+ latents_ids_to_concat.append(subject_latent_image_ids)
557
+
558
+ # spatial
559
+ if condition_image is not None:
560
+ shape_cond = (batch_size, num_channels_latents, height_cond*cond_number, width_cond)
561
+ condition_image = condition_image.to(device=device, dtype=dtype)
562
+ image_latents = self._encode_vae_image(image=condition_image, generator=generator)
563
+ cond_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height_cond*cond_number, width_cond)
564
+ mask3 = torch.zeros(shape_cond, device=device, dtype=dtype)
565
+ mask3 = self._pack_latents(mask3, batch_size, num_channels_latents, height_cond*cond_number, width_cond)
566
+ # repeat the condition position ids below, once per spatial condition image
567
+ cond_latent_image_ids = torch.concat([cond_latent_image_ids for _ in range(cond_number)], dim=-2)
568
+ latents_ids_to_concat.append(cond_latent_image_ids)
569
+ latents_to_concat.append(cond_latents)
570
+
571
+ cond_latents = torch.concat(latents_to_concat, dim=-2)
572
+ latent_image_ids = torch.concat(latents_ids_to_concat, dim=-2)
573
+ return cond_latents, latent_image_ids, noise_latents
574
+
575
+ @property
576
+ def guidance_scale(self):
577
+ return self._guidance_scale
578
+
579
+ @property
580
+ def joint_attention_kwargs(self):
581
+ return self._joint_attention_kwargs
582
+
583
+ @property
584
+ def num_timesteps(self):
585
+ return self._num_timesteps
586
+
587
+ @property
588
+ def interrupt(self):
589
+ return self._interrupt
590
+
591
+ @torch.no_grad()
592
+ def __call__(
593
+ self,
594
+ prompt: Union[str, List[str]] = None,
595
+ prompt_2: Optional[Union[str, List[str]]] = None,
596
+ height: Optional[int] = None,
597
+ width: Optional[int] = None,
598
+ num_inference_steps: int = 28,
599
+ timesteps: List[int] = None,
600
+ guidance_scale: float = 3.5,
601
+ num_images_per_prompt: Optional[int] = 1,
602
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
603
+ latents: Optional[torch.FloatTensor] = None,
604
+ prompt_embeds: Optional[torch.FloatTensor] = None,
605
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
606
+ output_type: Optional[str] = "pil",
607
+ return_dict: bool = True,
608
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
609
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
610
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
611
+ max_sequence_length: int = 512,
612
+ spatial_images=None,
613
+ subject_images=None,
614
+ cond_size=512,
615
+ ):
616
+
617
+ height = height or self.default_sample_size * self.vae_scale_factor
618
+ width = width or self.default_sample_size * self.vae_scale_factor
619
+ self.cond_size = cond_size
620
+
621
+ # 1. Check inputs. Raise error if not correct
622
+ self.check_inputs(
623
+ prompt,
624
+ prompt_2,
625
+ height,
626
+ width,
627
+ prompt_embeds=prompt_embeds,
628
+ pooled_prompt_embeds=pooled_prompt_embeds,
629
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
630
+ max_sequence_length=max_sequence_length,
631
+ )
632
+
633
+ self._guidance_scale = guidance_scale
634
+ self._joint_attention_kwargs = joint_attention_kwargs
635
+ self._interrupt = False
636
+
637
+ cond_number = len(spatial_images) if spatial_images is not None else 0
638
+ sub_number = len(subject_images) if subject_images is not None else 0
639
+
640
+ if sub_number > 0:
641
+ subject_image_ls = []
642
+ for subject_image in subject_images:
643
+ w, h = subject_image.size[:2]
644
+ scale = self.cond_size / max(h, w)
645
+ new_h, new_w = int(h * scale), int(w * scale)
646
+ subject_image = self.image_processor.preprocess(subject_image, height=new_h, width=new_w)
647
+ subject_image = subject_image.to(dtype=torch.float32)
648
+ pad_h = cond_size - subject_image.shape[-2]
649
+ pad_w = cond_size - subject_image.shape[-1]
650
+ subject_image = pad(
651
+ subject_image,
652
+ padding=(int(pad_w / 2), int(pad_h / 2), int(pad_w / 2), int(pad_h / 2)),
653
+ fill=0
654
+ )
655
+ subject_image_ls.append(subject_image)
656
+ subject_image = torch.concat(subject_image_ls, dim=-2)
657
+ else:
658
+ subject_image = None
659
+
660
+ if cond_number > 0:
661
+ condition_image_ls = []
662
+ for img in spatial_images:
663
+ condition_image = self.image_processor.preprocess(img, height=self.cond_size, width=self.cond_size)
664
+ condition_image = condition_image.to(dtype=torch.float32)
665
+ condition_image_ls.append(condition_image)
666
+ condition_image = torch.concat(condition_image_ls, dim=-2)
667
+ else:
668
+ condition_image = None
669
+
670
+ # 2. Define call parameters
671
+ if prompt is not None and isinstance(prompt, str):
672
+ batch_size = 1
673
+ elif prompt is not None and isinstance(prompt, list):
674
+ batch_size = len(prompt)
675
+ else:
676
+ batch_size = prompt_embeds.shape[0]
677
+
678
+ device = self._execution_device
679
+
680
+ lora_scale = (
681
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
682
+ )
683
+ (
684
+ prompt_embeds,
685
+ pooled_prompt_embeds,
686
+ text_ids,
687
+ ) = self.encode_prompt(
688
+ prompt=prompt,
689
+ prompt_2=prompt_2,
690
+ prompt_embeds=prompt_embeds,
691
+ pooled_prompt_embeds=pooled_prompt_embeds,
692
+ device=device,
693
+ num_images_per_prompt=num_images_per_prompt,
694
+ max_sequence_length=max_sequence_length,
695
+ lora_scale=lora_scale,
696
+ )
697
+
698
+ # 4. Prepare latent variables
699
+ num_channels_latents = self.transformer.config.in_channels // 4 # 16
700
+ cond_latents, latent_image_ids, noise_latents = self.prepare_latents(
701
+ batch_size * num_images_per_prompt,
702
+ num_channels_latents,
703
+ height,
704
+ width,
705
+ prompt_embeds.dtype,
706
+ device,
707
+ generator,
708
+ subject_image,
709
+ condition_image,
710
+ latents,
711
+ cond_number,
712
+ sub_number
713
+ )
714
+ latents = noise_latents
715
+ # 5. Prepare timesteps
716
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
717
+ image_seq_len = latents.shape[1]
718
+ mu = calculate_shift(
719
+ image_seq_len,
720
+ self.scheduler.config.base_image_seq_len,
721
+ self.scheduler.config.max_image_seq_len,
722
+ self.scheduler.config.base_shift,
723
+ self.scheduler.config.max_shift,
724
+ )
725
+ timesteps, num_inference_steps = retrieve_timesteps(
726
+ self.scheduler,
727
+ num_inference_steps,
728
+ device,
729
+ timesteps,
730
+ sigmas,
731
+ mu=mu,
732
+ )
733
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
734
+ self._num_timesteps = len(timesteps)
735
+
736
+ # handle guidance
737
+ if self.transformer.config.guidance_embeds:
738
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
739
+ guidance = guidance.expand(latents.shape[0])
740
+ else:
741
+ guidance = None
742
+
743
+ # 6. Denoising loop
744
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
745
+ for i, t in enumerate(timesteps):
746
+ if self.interrupt:
747
+ continue
748
+
749
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
750
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
751
+ noise_pred = self.transformer(
752
+ hidden_states=latents, # 1 4096 64
753
+ cond_hidden_states=cond_latents,
754
+ timestep=timestep / 1000,
755
+ guidance=guidance,
756
+ pooled_projections=pooled_prompt_embeds,
757
+ encoder_hidden_states=prompt_embeds,
758
+ txt_ids=text_ids,
759
+ img_ids=latent_image_ids,
760
+ joint_attention_kwargs=self.joint_attention_kwargs,
761
+ return_dict=False,
762
+ )[0]
763
+
764
+ # compute the previous noisy sample x_t -> x_t-1
765
+ latents_dtype = latents.dtype
766
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
767
+ latents = latents
768
+
769
+ if latents.dtype != latents_dtype:
770
+ if torch.backends.mps.is_available():
771
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
772
+ latents = latents.to(latents_dtype)
773
+
774
+ if callback_on_step_end is not None:
775
+ callback_kwargs = {}
776
+ for k in callback_on_step_end_tensor_inputs:
777
+ callback_kwargs[k] = locals()[k]
778
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
779
+
780
+ latents = callback_outputs.pop("latents", latents)
781
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
782
+
783
+ # call the callback, if provided
784
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
785
+ progress_bar.update()
786
+
787
+ if XLA_AVAILABLE:
788
+ xm.mark_step()
789
+
790
+ if output_type == "latent":
791
+ image = latents
792
+
793
+ else:
794
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
795
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
796
+ image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
797
+ image = self.image_processor.postprocess(image, output_type=output_type)
798
+
799
+ # Offload all models
800
+ self.maybe_free_model_hooks()
801
+
802
+ if not return_dict:
803
+ return (image,)
804
+
805
+ return FluxPipelineOutput(images=image)
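A minimal usage sketch of the `__call__` defined above, for orientation only: the pipeline class name, checkpoint path, and image paths below are placeholders and are not taken from this commit; only the keyword arguments mirror the signature above.

import torch
from PIL import Image

# "CustomFluxPipeline" and the paths are placeholders -- substitute the pipeline class
# defined in this file and the checkpoint/images actually used by this repository.
pipe = CustomFluxPipeline.from_pretrained("path/to/checkpoint", torch_dtype=torch.bfloat16).to("cuda")

spatial_images = [Image.open("path/to/spatial_condition.png").convert("RGB")]   # resized to cond_size x cond_size
subject_images = [Image.open("path/to/subject_reference.png").convert("RGB")]   # scaled and padded to cond_size x cond_size

image = pipe(
    prompt="a product photo on a marble table, studio lighting",
    height=1024,
    width=768,
    num_inference_steps=28,
    guidance_scale=3.5,
    spatial_images=spatial_images,
    subject_images=subject_images,
    cond_size=512,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("generated.png")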
src/prompt_helper.py ADDED
@@ -0,0 +1,205 @@
1
+ import torch
2
+
3
+
4
+ def load_text_encoders(args, class_one, class_two):
5
+ text_encoder_one = class_one.from_pretrained(
6
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
7
+ )
8
+ text_encoder_two = class_two.from_pretrained(
9
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
10
+ )
11
+ return text_encoder_one, text_encoder_two
12
+
13
+
14
+ def tokenize_prompt(tokenizer, prompt, max_sequence_length):
15
+ text_inputs = tokenizer(
16
+ prompt,
17
+ padding="max_length",
18
+ max_length=max_sequence_length,
19
+ truncation=True,
20
+ return_length=False,
21
+ return_overflowing_tokens=False,
22
+ return_tensors="pt",
23
+ )
24
+ text_input_ids = text_inputs.input_ids
25
+ return text_input_ids
26
+
27
+
28
+ def tokenize_prompt_clip(tokenizer, prompt):
29
+ text_inputs = tokenizer(
30
+ prompt,
31
+ padding="max_length",
32
+ max_length=77,
33
+ truncation=True,
34
+ return_length=False,
35
+ return_overflowing_tokens=False,
36
+ return_tensors="pt",
37
+ )
38
+ text_input_ids = text_inputs.input_ids
39
+ return text_input_ids
40
+
41
+
42
+ def tokenize_prompt_t5(tokenizer, prompt):
43
+ text_inputs = tokenizer(
44
+ prompt,
45
+ padding="max_length",
46
+ max_length=512,
47
+ truncation=True,
48
+ return_length=False,
49
+ return_overflowing_tokens=False,
50
+ return_tensors="pt",
51
+ )
52
+ text_input_ids = text_inputs.input_ids
53
+ return text_input_ids
54
+
55
+
56
+ def _encode_prompt_with_t5(
57
+ text_encoder,
58
+ tokenizer,
59
+ max_sequence_length=512,
60
+ prompt=None,
61
+ num_images_per_prompt=1,
62
+ device=None,
63
+ text_input_ids=None,
64
+ ):
65
+ prompt = [prompt] if isinstance(prompt, str) else prompt
66
+ batch_size = len(prompt)
67
+
68
+ if tokenizer is not None:
69
+ text_inputs = tokenizer(
70
+ prompt,
71
+ padding="max_length",
72
+ max_length=max_sequence_length,
73
+ truncation=True,
74
+ return_length=False,
75
+ return_overflowing_tokens=False,
76
+ return_tensors="pt",
77
+ )
78
+ text_input_ids = text_inputs.input_ids
79
+ else:
80
+ if text_input_ids is None:
81
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
82
+
83
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
84
+
85
+ dtype = text_encoder.dtype
86
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
87
+
88
+ _, seq_len, _ = prompt_embeds.shape
89
+
90
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
91
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
92
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
93
+
94
+ return prompt_embeds
95
+
96
+
97
+ def _encode_prompt_with_clip(
98
+ text_encoder,
99
+ tokenizer,
100
+ prompt: str,
101
+ device=None,
102
+ text_input_ids=None,
103
+ num_images_per_prompt: int = 1,
104
+ ):
105
+ prompt = [prompt] if isinstance(prompt, str) else prompt
106
+ batch_size = len(prompt)
107
+
108
+ if tokenizer is not None:
109
+ text_inputs = tokenizer(
110
+ prompt,
111
+ padding="max_length",
112
+ max_length=77,
113
+ truncation=True,
114
+ return_overflowing_tokens=False,
115
+ return_length=False,
116
+ return_tensors="pt",
117
+ )
118
+
119
+ text_input_ids = text_inputs.input_ids
120
+ else:
121
+ if text_input_ids is None:
122
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
123
+
124
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
125
+
126
+ # Use pooled output of CLIPTextModel
127
+ prompt_embeds = prompt_embeds.pooler_output
128
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
129
+
130
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
131
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
132
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
133
+
134
+ return prompt_embeds
135
+
136
+
137
+ def encode_prompt(
138
+ text_encoders,
139
+ tokenizers,
140
+ prompt: str,
141
+ max_sequence_length,
142
+ device=None,
143
+ num_images_per_prompt: int = 1,
144
+ text_input_ids_list=None,
145
+ ):
146
+ prompt = [prompt] if isinstance(prompt, str) else prompt
147
+ dtype = text_encoders[0].dtype
148
+
149
+ pooled_prompt_embeds = _encode_prompt_with_clip(
150
+ text_encoder=text_encoders[0],
151
+ tokenizer=tokenizers[0],
152
+ prompt=prompt,
153
+ device=device if device is not None else text_encoders[0].device,
154
+ num_images_per_prompt=num_images_per_prompt,
155
+ text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
156
+ )
157
+
158
+ prompt_embeds = _encode_prompt_with_t5(
159
+ text_encoder=text_encoders[1],
160
+ tokenizer=tokenizers[1],
161
+ max_sequence_length=max_sequence_length,
162
+ prompt=prompt,
163
+ num_images_per_prompt=num_images_per_prompt,
164
+ device=device if device is not None else text_encoders[1].device,
165
+ text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
166
+ )
167
+
168
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
169
+
170
+ return prompt_embeds, pooled_prompt_embeds, text_ids
171
+
172
+
173
+ def encode_token_ids(text_encoders, tokens, accelerator, num_images_per_prompt=1, device=None):
174
+ text_encoder_clip = text_encoders[0]
175
+ text_encoder_t5 = text_encoders[1]
176
+ tokens_clip, tokens_t5 = tokens[0], tokens[1]
177
+ batch_size = tokens_clip.shape[0]
178
+
179
+ if device == "cpu":
180
+ device = "cpu"
181
+ else:
182
+ device = accelerator.device
183
+
184
+ # clip
185
+ prompt_embeds = text_encoder_clip(tokens_clip.to(device), output_hidden_states=False)
186
+ # Use pooled output of CLIPTextModel
187
+ prompt_embeds = prompt_embeds.pooler_output
188
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device)
189
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
190
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
191
+ pooled_prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
192
+ pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device)
193
+
194
+ # t5
195
+ prompt_embeds = text_encoder_t5(tokens_t5.to(device))[0]
196
+ dtype = text_encoder_t5.dtype
197
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=accelerator.device)
198
+ _, seq_len, _ = prompt_embeds.shape
199
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
200
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
201
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
202
+
203
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=accelerator.device, dtype=dtype)
204
+
205
+ return prompt_embeds, pooled_prompt_embeds, text_ids
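A hedged sketch of how these helpers can be wired to the standard FLUX text encoders; the base checkpoint id below is an assumption, not something this commit specifies.

import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from src.prompt_helper import encode_prompt

base = "black-forest-labs/FLUX.1-dev"  # assumed base checkpoint, not confirmed by this repo
tokenizer_clip = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer")
tokenizer_t5 = T5TokenizerFast.from_pretrained(base, subfolder="tokenizer_2")
text_encoder_clip = CLIPTextModel.from_pretrained(base, subfolder="text_encoder")
text_encoder_t5 = T5EncoderModel.from_pretrained(base, subfolder="text_encoder_2")

# CLIP contributes the pooled embedding, T5 the per-token sequence embedding.
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
    text_encoders=[text_encoder_clip, text_encoder_t5],
    tokenizers=[tokenizer_clip, tokenizer_t5],
    prompt="a red backpack on a wooden bench",
    max_sequence_length=512,
    device="cpu",
    num_images_per_prompt=1,
)
# Under that checkpoint: prompt_embeds (1, 512, 4096), pooled_prompt_embeds (1, 768), text_ids (512, 3).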
src/transformer_flux.py ADDED
@@ -0,0 +1,583 @@
1
+ from typing import Any, Dict, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
10
+ from diffusers.models.attention import FeedForward
11
+ from diffusers.models.attention_processor import (
12
+ Attention,
13
+ AttentionProcessor,
14
+ FluxAttnProcessor2_0,
15
+ FluxAttnProcessor2_0_NPU,
16
+ FusedFluxAttnProcessor2_0,
17
+ )
18
+ from diffusers.models.modeling_utils import ModelMixin
19
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
20
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
21
+ from diffusers.utils.import_utils import is_torch_npu_available
22
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
23
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
24
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
25
+
26
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27
+
28
+ @maybe_allow_in_graph
29
+ class FluxSingleTransformerBlock(nn.Module):
30
+
31
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
32
+ super().__init__()
33
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
34
+
35
+ self.norm = AdaLayerNormZeroSingle(dim)
36
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
37
+ self.act_mlp = nn.GELU(approximate="tanh")
38
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
39
+
40
+ if is_torch_npu_available():
41
+ processor = FluxAttnProcessor2_0_NPU()
42
+ else:
43
+ processor = FluxAttnProcessor2_0()
44
+ self.attn = Attention(
45
+ query_dim=dim,
46
+ cross_attention_dim=None,
47
+ dim_head=attention_head_dim,
48
+ heads=num_attention_heads,
49
+ out_dim=dim,
50
+ bias=True,
51
+ processor=processor,
52
+ qk_norm="rms_norm",
53
+ eps=1e-6,
54
+ pre_only=True,
55
+ )
56
+
57
+ def forward(
58
+ self,
59
+ hidden_states: torch.Tensor,
60
+ cond_hidden_states: torch.Tensor,
61
+ temb: torch.Tensor,
62
+ cond_temb: torch.Tensor,
63
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
64
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
65
+ ) -> torch.Tensor:
66
+ use_cond = cond_hidden_states is not None
67
+
68
+ residual = hidden_states
69
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
70
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
71
+
72
+ if use_cond:
73
+ residual_cond = cond_hidden_states
74
+ norm_cond_hidden_states, cond_gate = self.norm(cond_hidden_states, emb=cond_temb)
75
+ mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_cond_hidden_states))
76
+
77
+ norm_hidden_states_concat = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2)
78
+
79
+ joint_attention_kwargs = joint_attention_kwargs or {}
80
+ attn_output = self.attn(
81
+ hidden_states=norm_hidden_states_concat,
82
+ image_rotary_emb=image_rotary_emb,
83
+ use_cond=use_cond,
84
+ **joint_attention_kwargs,
85
+ )
86
+ if use_cond:
87
+ attn_output, cond_attn_output = attn_output
88
+
89
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
90
+ gate = gate.unsqueeze(1)
91
+ hidden_states = gate * self.proj_out(hidden_states)
92
+ hidden_states = residual + hidden_states
93
+
94
+ if use_cond:
95
+ condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
96
+ cond_gate = cond_gate.unsqueeze(1)
97
+ condition_latents = cond_gate * self.proj_out(condition_latents)
98
+ condition_latents = residual_cond + condition_latents
99
+
100
+ if hidden_states.dtype == torch.float16:
101
+ hidden_states = hidden_states.clip(-65504, 65504)
102
+
103
+ return hidden_states, condition_latents if use_cond else None
104
+
105
+
106
+ @maybe_allow_in_graph
107
+ class FluxTransformerBlock(nn.Module):
108
+ def __init__(
109
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
110
+ ):
111
+ super().__init__()
112
+
113
+ self.norm1 = AdaLayerNormZero(dim)
114
+
115
+ self.norm1_context = AdaLayerNormZero(dim)
116
+
117
+ if hasattr(F, "scaled_dot_product_attention"):
118
+ processor = FluxAttnProcessor2_0()
119
+ else:
120
+ raise ValueError(
121
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
122
+ )
123
+ self.attn = Attention(
124
+ query_dim=dim,
125
+ cross_attention_dim=None,
126
+ added_kv_proj_dim=dim,
127
+ dim_head=attention_head_dim,
128
+ heads=num_attention_heads,
129
+ out_dim=dim,
130
+ context_pre_only=False,
131
+ bias=True,
132
+ processor=processor,
133
+ qk_norm=qk_norm,
134
+ eps=eps,
135
+ )
136
+
137
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
138
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
139
+
140
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
141
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
142
+
143
+ # let chunk size default to None
144
+ self._chunk_size = None
145
+ self._chunk_dim = 0
146
+
147
+ def forward(
148
+ self,
149
+ hidden_states: torch.Tensor,
150
+ cond_hidden_states: torch.Tensor,
151
+ encoder_hidden_states: torch.Tensor,
152
+ temb: torch.Tensor,
153
+ cond_temb: torch.Tensor,
154
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
155
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
156
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
157
+ use_cond = cond_hidden_states is not None
158
+
159
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
160
+ if use_cond:
161
+ (
162
+ norm_cond_hidden_states,
163
+ cond_gate_msa,
164
+ cond_shift_mlp,
165
+ cond_scale_mlp,
166
+ cond_gate_mlp,
167
+ ) = self.norm1(cond_hidden_states, emb=cond_temb)
168
+
169
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
170
+ encoder_hidden_states, emb=temb
171
+ )
172
+
173
+ norm_hidden_states = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2)
174
+
175
+ joint_attention_kwargs = joint_attention_kwargs or {}
176
+ # Attention.
177
+ attention_outputs = self.attn(
178
+ hidden_states=norm_hidden_states,
179
+ encoder_hidden_states=norm_encoder_hidden_states,
180
+ image_rotary_emb=image_rotary_emb,
181
+ use_cond=use_cond,
182
+ **joint_attention_kwargs,
183
+ )
184
+
185
+ attn_output, context_attn_output = attention_outputs[:2]
186
+ cond_attn_output = attention_outputs[2] if use_cond else None
187
+
188
+ # Process attention outputs for the `hidden_states`.
189
+ attn_output = gate_msa.unsqueeze(1) * attn_output
190
+ hidden_states = hidden_states + attn_output
191
+
192
+ if use_cond:
193
+ cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
194
+ cond_hidden_states = cond_hidden_states + cond_attn_output
195
+
196
+ norm_hidden_states = self.norm2(hidden_states)
197
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
198
+
199
+ if use_cond:
200
+ norm_cond_hidden_states = self.norm2(cond_hidden_states)
201
+ norm_cond_hidden_states = (
202
+ norm_cond_hidden_states * (1 + cond_scale_mlp[:, None])
203
+ + cond_shift_mlp[:, None]
204
+ )
205
+
206
+ ff_output = self.ff(norm_hidden_states)
207
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
208
+ hidden_states = hidden_states + ff_output
209
+
210
+ if use_cond:
211
+ cond_ff_output = self.ff(norm_cond_hidden_states)
212
+ cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
213
+ cond_hidden_states = cond_hidden_states + cond_ff_output
214
+
215
+ # Process attention outputs for the `encoder_hidden_states`.
216
+
217
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
218
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
219
+
220
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
221
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
222
+
223
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
224
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
225
+ if encoder_hidden_states.dtype == torch.float16:
226
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
227
+
228
+ return encoder_hidden_states, hidden_states, cond_hidden_states if use_cond else None
229
+
230
+
231
+ class FluxTransformer2DModel(
232
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
233
+ ):
234
+ _supports_gradient_checkpointing = True
235
+ _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
236
+
237
+ @register_to_config
238
+ def __init__(
239
+ self,
240
+ patch_size: int = 1,
241
+ in_channels: int = 64,
242
+ out_channels: Optional[int] = None,
243
+ num_layers: int = 19,
244
+ num_single_layers: int = 38,
245
+ attention_head_dim: int = 128,
246
+ num_attention_heads: int = 24,
247
+ joint_attention_dim: int = 4096,
248
+ pooled_projection_dim: int = 768,
249
+ guidance_embeds: bool = False,
250
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
251
+ ):
252
+ super().__init__()
253
+ self.out_channels = out_channels or in_channels
254
+ self.inner_dim = num_attention_heads * attention_head_dim
255
+
256
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
257
+
258
+ text_time_guidance_cls = (
259
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
260
+ )
261
+ self.time_text_embed = text_time_guidance_cls(
262
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
263
+ )
264
+
265
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
266
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim)
267
+
268
+ self.transformer_blocks = nn.ModuleList(
269
+ [
270
+ FluxTransformerBlock(
271
+ dim=self.inner_dim,
272
+ num_attention_heads=num_attention_heads,
273
+ attention_head_dim=attention_head_dim,
274
+ )
275
+ for _ in range(num_layers)
276
+ ]
277
+ )
278
+
279
+ self.single_transformer_blocks = nn.ModuleList(
280
+ [
281
+ FluxSingleTransformerBlock(
282
+ dim=self.inner_dim,
283
+ num_attention_heads=num_attention_heads,
284
+ attention_head_dim=attention_head_dim,
285
+ )
286
+ for _ in range(num_single_layers)
287
+ ]
288
+ )
289
+
290
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
291
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
292
+
293
+ self.gradient_checkpointing = False
294
+
295
+ @property
296
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
297
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
298
+ r"""
299
+ Returns:
300
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
301
+ indexed by their weight names.
302
+ """
303
+ # set recursively
304
+ processors = {}
305
+
306
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
307
+ if hasattr(module, "get_processor"):
308
+ processors[f"{name}.processor"] = module.get_processor()
309
+
310
+ for sub_name, child in module.named_children():
311
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
312
+
313
+ return processors
314
+
315
+ for name, module in self.named_children():
316
+ fn_recursive_add_processors(name, module, processors)
317
+
318
+ return processors
319
+
320
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
321
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
322
+ r"""
323
+ Sets the attention processor to use to compute attention.
324
+
325
+ Parameters:
326
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
327
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
328
+ for **all** `Attention` layers.
329
+
330
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
331
+ processor. This is strongly recommended when setting trainable attention processors.
332
+
333
+ """
334
+ count = len(self.attn_processors.keys())
335
+
336
+ if isinstance(processor, dict) and len(processor) != count:
337
+ raise ValueError(
338
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
339
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
340
+ )
341
+
342
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
343
+ if hasattr(module, "set_processor"):
344
+ if not isinstance(processor, dict):
345
+ module.set_processor(processor)
346
+ else:
347
+ module.set_processor(processor.pop(f"{name}.processor"))
348
+
349
+ for sub_name, child in module.named_children():
350
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
351
+
352
+ for name, module in self.named_children():
353
+ fn_recursive_attn_processor(name, module, processor)
354
+
355
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
356
+ def fuse_qkv_projections(self):
357
+ """
358
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
359
+ are fused. For cross-attention modules, key and value projection matrices are fused.
360
+
361
+ <Tip warning={true}>
362
+
363
+ This API is 🧪 experimental.
364
+
365
+ </Tip>
366
+ """
367
+ self.original_attn_processors = None
368
+
369
+ for _, attn_processor in self.attn_processors.items():
370
+ if "Added" in str(attn_processor.__class__.__name__):
371
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
372
+
373
+ self.original_attn_processors = self.attn_processors
374
+
375
+ for module in self.modules():
376
+ if isinstance(module, Attention):
377
+ module.fuse_projections(fuse=True)
378
+
379
+ self.set_attn_processor(FusedFluxAttnProcessor2_0())
380
+
381
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
382
+ def unfuse_qkv_projections(self):
383
+ """Disables the fused QKV projection if enabled.
384
+
385
+ <Tip warning={true}>
386
+
387
+ This API is 🧪 experimental.
388
+
389
+ </Tip>
390
+
391
+ """
392
+ if self.original_attn_processors is not None:
393
+ self.set_attn_processor(self.original_attn_processors)
394
+
395
+ def _set_gradient_checkpointing(self, module, value=False):
396
+ if hasattr(module, "gradient_checkpointing"):
397
+ module.gradient_checkpointing = value
398
+
399
+ def forward(
400
+ self,
401
+ hidden_states: torch.Tensor,
402
+ cond_hidden_states: torch.Tensor = None,
403
+ encoder_hidden_states: torch.Tensor = None,
404
+ pooled_projections: torch.Tensor = None,
405
+ timestep: torch.LongTensor = None,
406
+ img_ids: torch.Tensor = None,
407
+ txt_ids: torch.Tensor = None,
408
+ guidance: torch.Tensor = None,
409
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
410
+ controlnet_block_samples=None,
411
+ controlnet_single_block_samples=None,
412
+ return_dict: bool = True,
413
+ controlnet_blocks_repeat: bool = False,
414
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
415
+ if cond_hidden_states is not None:
416
+ use_condition = True
417
+ else:
418
+ use_condition = False
419
+
420
+ if joint_attention_kwargs is not None:
421
+ joint_attention_kwargs = joint_attention_kwargs.copy()
422
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
423
+ else:
424
+ lora_scale = 1.0
425
+
426
+ if USE_PEFT_BACKEND:
427
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
428
+ scale_lora_layers(self, lora_scale)
429
+ else:
430
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
431
+ logger.warning(
432
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
433
+ )
434
+
435
+ hidden_states = self.x_embedder(hidden_states)
436
+ cond_hidden_states = self.x_embedder(cond_hidden_states) if use_condition else None
437
+
438
+ timestep = timestep.to(hidden_states.dtype) * 1000
439
+ if guidance is not None:
440
+ guidance = guidance.to(hidden_states.dtype) * 1000
441
+ else:
442
+ guidance = None
443
+
444
+ temb = (
445
+ self.time_text_embed(timestep, pooled_projections)
446
+ if guidance is None
447
+ else self.time_text_embed(timestep, guidance, pooled_projections)
448
+ )
449
+
450
+ cond_temb = (
451
+ self.time_text_embed(torch.ones_like(timestep) * 0, pooled_projections)
452
+ if guidance is None
453
+ else self.time_text_embed(
454
+ torch.ones_like(timestep) * 0, guidance, pooled_projections
455
+ )
456
+ )
457
+
458
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
459
+
460
+ if txt_ids.ndim == 3:
461
+ logger.warning(
462
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
463
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
464
+ )
465
+ txt_ids = txt_ids[0]
466
+ if img_ids.ndim == 3:
467
+ logger.warning(
468
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
469
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
470
+ )
471
+ img_ids = img_ids[0]
472
+
473
+ ids = torch.cat((txt_ids, img_ids), dim=0)
474
+ image_rotary_emb = self.pos_embed(ids)
475
+
476
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
477
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
478
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
479
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
480
+
481
+ for index_block, block in enumerate(self.transformer_blocks):
482
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
483
+
484
+ def create_custom_forward(module, return_dict=None):
485
+ def custom_forward(*inputs):
486
+ if return_dict is not None:
487
+ return module(*inputs, return_dict=return_dict)
488
+ else:
489
+ return module(*inputs)
490
+
491
+ return custom_forward
492
+
493
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
494
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
495
+ create_custom_forward(block),
496
+ hidden_states,
497
+ encoder_hidden_states,
498
+ temb,
499
+ image_rotary_emb,
500
+ cond_temb=cond_temb if use_condition else None,
501
+ cond_hidden_states=cond_hidden_states if use_condition else None,
502
+ **ckpt_kwargs,
503
+ )
504
+
505
+ else:
506
+ encoder_hidden_states, hidden_states, cond_hidden_states = block(
507
+ hidden_states=hidden_states,
508
+ encoder_hidden_states=encoder_hidden_states,
509
+ cond_hidden_states=cond_hidden_states if use_condition else None,
510
+ temb=temb,
511
+ cond_temb=cond_temb if use_condition else None,
512
+ image_rotary_emb=image_rotary_emb,
513
+ joint_attention_kwargs=joint_attention_kwargs,
514
+ )
515
+
516
+ # controlnet residual
517
+ if controlnet_block_samples is not None:
518
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
519
+ interval_control = int(np.ceil(interval_control))
520
+ # For Xlabs ControlNet.
521
+ if controlnet_blocks_repeat:
522
+ hidden_states = (
523
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
524
+ )
525
+ else:
526
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
527
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
528
+
529
+ for index_block, block in enumerate(self.single_transformer_blocks):
530
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
531
+
532
+ def create_custom_forward(module, return_dict=None):
533
+ def custom_forward(*inputs):
534
+ if return_dict is not None:
535
+ return module(*inputs, return_dict=return_dict)
536
+ else:
537
+ return module(*inputs)
538
+
539
+ return custom_forward
540
+
541
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
542
+ hidden_states, cond_hidden_states = torch.utils.checkpoint.checkpoint(
543
+ create_custom_forward(block),
544
+ hidden_states,
545
+ temb,
546
+ image_rotary_emb,
547
+ cond_temb=cond_temb if use_condition else None,
548
+ cond_hidden_states=cond_hidden_states if use_condition else None,
549
+ **ckpt_kwargs,
550
+ )
551
+
552
+ else:
553
+ hidden_states, cond_hidden_states = block(
554
+ hidden_states=hidden_states,
555
+ cond_hidden_states=cond_hidden_states if use_condition else None,
556
+ temb=temb,
557
+ cond_temb=cond_temb if use_condition else None,
558
+ image_rotary_emb=image_rotary_emb,
559
+ joint_attention_kwargs=joint_attention_kwargs,
560
+ )
561
+
562
+ # controlnet residual
563
+ if controlnet_single_block_samples is not None:
564
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
565
+ interval_control = int(np.ceil(interval_control))
566
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
567
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
568
+ + controlnet_single_block_samples[index_block // interval_control]
569
+ )
570
+
571
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
572
+
573
+ hidden_states = self.norm_out(hidden_states, temb)
574
+ output = self.proj_out(hidden_states)
575
+
576
+ if USE_PEFT_BACKEND:
577
+ # remove `lora_scale` from each PEFT layer
578
+ unscale_lora_layers(self, lora_scale)
579
+
580
+ if not return_dict:
581
+ return (output,)
582
+
583
+ return Transformer2DModelOutput(sample=output)
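For reference, this transformer consumes packed latents of shape (batch, (H/2)*(W/2), C*4), which is where `in_channels=64` comes from (16 VAE channels times a 2x2 patch, matching `num_channels_latents = in_channels // 4` in the pipeline). The following self-contained sketch mirrors the pipeline's `_pack_latents`/`_unpack_latents` patchification; it is a re-implementation for illustration, not an import from this repository.

import torch

def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    # (B, C, H, W) -> (B, (H//2)*(W//2), C*4): each 2x2 spatial patch becomes one image token
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)

def unpack_latents(packed: torch.Tensor, h: int, w: int) -> torch.Tensor:
    # exact inverse of pack_latents for a latent grid of size h x w
    b, _, c4 = packed.shape
    c = c4 // 4
    latents = packed.view(b, h // 2, w // 2, c, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    return latents.reshape(b, c, h, w)

x = torch.randn(1, 16, 128, 128)   # a 128x128x16 latent, e.g. a 1024x1024 image under 8x VAE downsampling
seq = pack_latents(x)              # (1, 4096, 64): 4096 image tokens of width 64, matching in_channels above
assert torch.equal(unpack_latents(seq, 128, 128), x)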
src/transformer_with_loss.py ADDED
@@ -0,0 +1,504 @@
1
+ from typing import Any, Dict, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
10
+ from diffusers.models.attention import FeedForward
11
+ from diffusers.models.attention_processor import (
12
+ Attention,
13
+ AttentionProcessor,
14
+ FluxAttnProcessor2_0,
15
+ FluxAttnProcessor2_0_NPU,
16
+ FusedFluxAttnProcessor2_0,
17
+ )
18
+ from diffusers.models.modeling_utils import ModelMixin
19
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
20
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers, deprecate
21
+ from diffusers.utils.import_utils import is_torch_npu_available
22
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
23
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
24
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
25
+ from diffusers import CacheMixin
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
+ @maybe_allow_in_graph
31
+ class FluxSingleTransformerBlock(nn.Module):
32
+ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
33
+ super().__init__()
34
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
35
+
36
+ self.norm = AdaLayerNormZeroSingle(dim)
37
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
38
+ self.act_mlp = nn.GELU(approximate="tanh")
39
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
40
+
41
+ if is_torch_npu_available():
42
+ deprecation_message = (
43
+ "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
44
+ "should be set explicitly using the `set_attn_processor` method."
45
+ )
46
+ deprecate("npu_processor", "0.34.0", deprecation_message)
47
+ processor = FluxAttnProcessor2_0_NPU()
48
+ else:
49
+ processor = FluxAttnProcessor2_0()
50
+
51
+ self.attn = Attention(
52
+ query_dim=dim,
53
+ cross_attention_dim=None,
54
+ dim_head=attention_head_dim,
55
+ heads=num_attention_heads,
56
+ out_dim=dim,
57
+ bias=True,
58
+ processor=processor,
59
+ qk_norm="rms_norm",
60
+ eps=1e-6,
61
+ pre_only=True,
62
+ )
63
+
64
+ def forward(
65
+ self,
66
+ hidden_states: torch.Tensor,
67
+ temb: torch.Tensor,
68
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
69
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
70
+ ) -> torch.Tensor:
71
+ residual = hidden_states
72
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
73
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
74
+ joint_attention_kwargs = joint_attention_kwargs or {}
75
+ attn_output = self.attn(
76
+ hidden_states=norm_hidden_states,
77
+ image_rotary_emb=image_rotary_emb,
78
+ **joint_attention_kwargs,
79
+ )
80
+
81
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
82
+ gate = gate.unsqueeze(1)
83
+ hidden_states = gate * self.proj_out(hidden_states)
84
+ hidden_states = residual + hidden_states
85
+ if hidden_states.dtype == torch.float16:
86
+ hidden_states = hidden_states.clip(-65504, 65504)
87
+
88
+ return hidden_states
89
+
90
+
91
+ @maybe_allow_in_graph
92
+ class FluxTransformerBlock(nn.Module):
93
+ def __init__(
94
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
95
+ ):
96
+ super().__init__()
97
+
98
+ self.norm1 = AdaLayerNormZero(dim)
99
+ self.norm1_context = AdaLayerNormZero(dim)
100
+
101
+ self.attn = Attention(
102
+ query_dim=dim,
103
+ cross_attention_dim=None,
104
+ added_kv_proj_dim=dim,
105
+ dim_head=attention_head_dim,
106
+ heads=num_attention_heads,
107
+ out_dim=dim,
108
+ context_pre_only=False,
109
+ bias=True,
110
+ processor=FluxAttnProcessor2_0(),
111
+ qk_norm=qk_norm,
112
+ eps=eps,
113
+ )
114
+
115
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
116
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
117
+
118
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
119
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
120
+
121
+ def forward(
122
+ self,
123
+ hidden_states: torch.Tensor,
124
+ encoder_hidden_states: torch.Tensor,
125
+ temb: torch.Tensor,
126
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
127
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
128
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
129
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
130
+
131
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
132
+ encoder_hidden_states, emb=temb
133
+ )
134
+ joint_attention_kwargs = joint_attention_kwargs or {}
135
+ # Attention.
136
+ attention_outputs = self.attn(
137
+ hidden_states=norm_hidden_states,
138
+ encoder_hidden_states=norm_encoder_hidden_states,
139
+ image_rotary_emb=image_rotary_emb,
140
+ **joint_attention_kwargs,
141
+ )
142
+
143
+ if len(attention_outputs) == 2:
144
+ attn_output, context_attn_output = attention_outputs
145
+ elif len(attention_outputs) == 3:
146
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
147
+
148
+ # Process attention outputs for the `hidden_states`.
149
+ attn_output = gate_msa.unsqueeze(1) * attn_output
150
+ hidden_states = hidden_states + attn_output
151
+
152
+ norm_hidden_states = self.norm2(hidden_states)
153
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
154
+
155
+ ff_output = self.ff(norm_hidden_states)
156
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
157
+
158
+ hidden_states = hidden_states + ff_output
159
+ if len(attention_outputs) == 3:
160
+ hidden_states = hidden_states + ip_attn_output
161
+
162
+ # Process attention outputs for the `encoder_hidden_states`.
163
+
164
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
165
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
166
+
167
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
168
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
169
+
170
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
171
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
172
+ if encoder_hidden_states.dtype == torch.float16:
173
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
174
+
175
+ return encoder_hidden_states, hidden_states
176
+
177
+
178
+ class FluxTransformer2DModelWithLoss(
179
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
180
+ ):
181
+ _supports_gradient_checkpointing = True
182
+ _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
183
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
184
+ _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
185
+
186
+ @register_to_config
187
+ def __init__(
188
+ self,
189
+ patch_size: int = 1,
190
+ in_channels: int = 64,
191
+ out_channels: Optional[int] = None,
192
+ num_layers: int = 19,
193
+ num_single_layers: int = 38,
194
+ attention_head_dim: int = 128,
195
+ num_attention_heads: int = 24,
196
+ joint_attention_dim: int = 4096,
197
+ pooled_projection_dim: int = 768,
198
+ guidance_embeds: bool = False,
199
+ axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
200
+ ):
201
+ super().__init__()
202
+ self.out_channels = out_channels or in_channels
203
+ self.inner_dim = num_attention_heads * attention_head_dim
204
+
205
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
206
+
207
+ text_time_guidance_cls = (
208
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
209
+ )
210
+ self.time_text_embed = text_time_guidance_cls(
211
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
212
+ )
213
+
214
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
215
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim)
216
+
217
+ self.transformer_blocks = nn.ModuleList(
218
+ [
219
+ FluxTransformerBlock(
220
+ dim=self.inner_dim,
221
+ num_attention_heads=num_attention_heads,
222
+ attention_head_dim=attention_head_dim,
223
+ )
224
+ for _ in range(num_layers)
225
+ ]
226
+ )
227
+
228
+ self.single_transformer_blocks = nn.ModuleList(
229
+ [
230
+ FluxSingleTransformerBlock(
231
+ dim=self.inner_dim,
232
+ num_attention_heads=num_attention_heads,
233
+ attention_head_dim=attention_head_dim,
234
+ )
235
+ for _ in range(num_single_layers)
236
+ ]
237
+ )
238
+
239
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
240
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
241
+
242
+ self.gradient_checkpointing = False
243
+
244
+ @property
245
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
246
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
247
+ r"""
248
+ Returns:
249
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
250
+ indexed by its weight name.
251
+ """
252
+ # set recursively
253
+ processors = {}
254
+
255
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
256
+ if hasattr(module, "get_processor"):
257
+ processors[f"{name}.processor"] = module.get_processor()
258
+
259
+ for sub_name, child in module.named_children():
260
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
261
+
262
+ return processors
263
+
264
+ for name, module in self.named_children():
265
+ fn_recursive_add_processors(name, module, processors)
266
+
267
+ return processors
268
+
269
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
270
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
271
+ r"""
272
+ Sets the attention processor to use to compute attention.
273
+
274
+ Parameters:
275
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
276
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
277
+ for **all** `Attention` layers.
278
+
279
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
280
+ processor. This is strongly recommended when setting trainable attention processors.
281
+
282
+ """
283
+ count = len(self.attn_processors.keys())
284
+
285
+ if isinstance(processor, dict) and len(processor) != count:
286
+ raise ValueError(
287
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
288
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
289
+ )
290
+
291
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
292
+ if hasattr(module, "set_processor"):
293
+ if not isinstance(processor, dict):
294
+ module.set_processor(processor)
295
+ else:
296
+ module.set_processor(processor.pop(f"{name}.processor"))
297
+
298
+ for sub_name, child in module.named_children():
299
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
300
+
301
+ for name, module in self.named_children():
302
+ fn_recursive_attn_processor(name, module, processor)
303
+
304
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
305
+ def fuse_qkv_projections(self):
306
+ """
307
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
308
+ are fused. For cross-attention modules, key and value projection matrices are fused.
309
+
310
+ <Tip warning={true}>
311
+
312
+ This API is 🧪 experimental.
313
+
314
+ </Tip>
315
+ """
316
+ self.original_attn_processors = None
317
+
318
+ for _, attn_processor in self.attn_processors.items():
319
+ if "Added" in str(attn_processor.__class__.__name__):
320
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
321
+
322
+ self.original_attn_processors = self.attn_processors
323
+
324
+ for module in self.modules():
325
+ if isinstance(module, Attention):
326
+ module.fuse_projections(fuse=True)
327
+
328
+ self.set_attn_processor(FusedFluxAttnProcessor2_0())
329
+
330
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
331
+ def unfuse_qkv_projections(self):
332
+ """Disables the fused QKV projection if enabled.
333
+
334
+ <Tip warning={true}>
335
+
336
+ This API is 🧪 experimental.
337
+
338
+ </Tip>
339
+
340
+ """
341
+ if self.original_attn_processors is not None:
342
+ self.set_attn_processor(self.original_attn_processors)
343
+
344
+ def forward(
345
+ self,
346
+ hidden_states: torch.Tensor,
347
+ encoder_hidden_states: torch.Tensor = None,
348
+ pooled_projections: torch.Tensor = None,
349
+ timestep: torch.LongTensor = None,
350
+ img_ids: torch.Tensor = None,
351
+ txt_ids: torch.Tensor = None,
352
+ guidance: torch.Tensor = None,
353
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
354
+ controlnet_block_samples=None,
355
+ controlnet_single_block_samples=None,
356
+ return_dict: bool = True,
357
+ controlnet_blocks_repeat: bool = False,
358
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
359
+ """
360
+ The [`FluxTransformer2DModel`] forward method.
361
+
362
+ Args:
363
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
364
+ Input `hidden_states`.
365
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
366
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
367
+ pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
368
+ from the embeddings of input conditions.
369
+ timestep ( `torch.LongTensor`):
370
+ Used to indicate denoising step.
371
+ controlnet_block_samples (`list` of `torch.Tensor`):
372
+ A list of tensors that if specified are added to the residuals of transformer blocks.
373
+ joint_attention_kwargs (`dict`, *optional*):
374
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
375
+ `self.processor` in
376
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
377
+ return_dict (`bool`, *optional*, defaults to `True`):
378
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
379
+ tuple.
380
+
381
+ Returns:
382
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
383
+ `tuple` where the first element is the sample tensor.
384
+ """
385
+ if joint_attention_kwargs is not None:
386
+ joint_attention_kwargs = joint_attention_kwargs.copy()
387
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
388
+ else:
389
+ lora_scale = 1.0
390
+
391
+ if USE_PEFT_BACKEND:
392
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
393
+ scale_lora_layers(self, lora_scale)
394
+ else:
395
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
396
+ logger.warning(
397
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
398
+ )
399
+
400
+ hidden_states = self.x_embedder(hidden_states)
401
+
402
+ timestep = timestep.to(hidden_states.dtype) * 1000
403
+ if guidance is not None:
404
+ guidance = guidance.to(hidden_states.dtype) * 1000
405
+
406
+ temb = (
407
+ self.time_text_embed(timestep, pooled_projections)
408
+ if guidance is None
409
+ else self.time_text_embed(timestep, guidance, pooled_projections)
410
+ )
411
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
412
+
413
+ if txt_ids.ndim == 3:
414
+ logger.warning(
415
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
416
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
417
+ )
418
+ txt_ids = txt_ids[0]
419
+ if img_ids.ndim == 3:
420
+ logger.warning(
421
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
422
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
423
+ )
424
+ img_ids = img_ids[0]
425
+
426
+ ids = torch.cat((txt_ids, img_ids), dim=0)
427
+ image_rotary_emb = self.pos_embed(ids)
428
+
429
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
430
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
431
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
432
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
433
+
434
+ for index_block, block in enumerate(self.transformer_blocks):
435
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
436
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
437
+ block,
438
+ hidden_states,
439
+ encoder_hidden_states,
440
+ temb,
441
+ image_rotary_emb,
442
+ )
443
+
444
+ else:
445
+ encoder_hidden_states, hidden_states = block(
446
+ hidden_states=hidden_states,
447
+ encoder_hidden_states=encoder_hidden_states,
448
+ temb=temb,
449
+ image_rotary_emb=image_rotary_emb,
450
+ joint_attention_kwargs=joint_attention_kwargs,
451
+ )
452
+
453
+ # controlnet residual
454
+ if controlnet_block_samples is not None:
455
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
456
+ interval_control = int(np.ceil(interval_control))
457
+ # For Xlabs ControlNet.
458
+ if controlnet_blocks_repeat:
459
+ hidden_states = (
460
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
461
+ )
462
+ else:
463
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
464
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
465
+
466
+ for index_block, block in enumerate(self.single_transformer_blocks):
467
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
468
+ hidden_states = self._gradient_checkpointing_func(
469
+ block,
470
+ hidden_states,
471
+ temb,
472
+ image_rotary_emb,
473
+ )
474
+
475
+ else:
476
+ hidden_states = block(
477
+ hidden_states=hidden_states,
478
+ temb=temb,
479
+ image_rotary_emb=image_rotary_emb,
480
+ joint_attention_kwargs=joint_attention_kwargs,
481
+ )
482
+
483
+ # controlnet residual
484
+ if controlnet_single_block_samples is not None:
485
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
486
+ interval_control = int(np.ceil(interval_control))
487
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
488
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
489
+ + controlnet_single_block_samples[index_block // interval_control]
490
+ )
491
+
492
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
493
+
494
+ hidden_states = self.norm_out(hidden_states, temb)
495
+ output = self.proj_out(hidden_states)
496
+
497
+ if USE_PEFT_BACKEND:
498
+ # remove `lora_scale` from each PEFT layer
499
+ unscale_lora_layers(self, lora_scale)
500
+
501
+ if not return_dict:
502
+ return (output,)
503
+
504
+ return Transformer2DModelOutput(sample=output)
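Note on the ControlNet residual wiring in the forward pass above: when fewer ControlNet block samples are supplied than there are transformer blocks, each sample is reused for a contiguous run of blocks via `index_block // interval_control`, while `controlnet_blocks_repeat=True` (the XLabs-style ControlNets) cycles through the samples instead. A minimal standalone sketch of the two mappings, with block and sample counts made up for illustration:

import numpy as np

num_blocks = 19    # e.g. len(self.transformer_blocks)
num_samples = 6    # e.g. len(controlnet_block_samples)
interval = int(np.ceil(num_blocks / num_samples))  # 4

# default mapping: each sample covers `interval` consecutive blocks
print([i // interval for i in range(num_blocks)])
# [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4]

# controlnet_blocks_repeat=True: samples are cycled across blocks
print([i % num_samples for i in range(num_blocks)])
# [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0]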
test_imgs/2.png ADDED

Git LFS Details

  • SHA256: 913839ca3ca83963f6b52309394a0ea3ca7b701a7290d7db41d7e4f1879e7467
  • Pointer size: 132 Bytes
  • Size of remote file: 1.04 MB
test_imgs/3.png ADDED

Git LFS Details

  • SHA256: bc366a2bbdeaa883141155214f2af4236c354457b8ca7b608ead64b21b053a2f
  • Pointer size: 131 Bytes
  • Size of remote file: 261 kB
test_imgs/generated_1.png ADDED

Git LFS Details

  • SHA256: 5e95bb2e854b04958fbcd1013a01ed929a36c5c1b43887c3fad0d646b0768703
  • Pointer size: 131 Bytes
  • Size of remote file: 323 kB
test_imgs/generated_1_bbox.png ADDED

Git LFS Details

  • SHA256: a3a5fdaf6a6998a5b8559aca5c9571cde99da1945cc8418ea8f679cd1ce6b4bf
  • Pointer size: 131 Bytes
  • Size of remote file: 383 kB
test_imgs/generated_2.png ADDED

Git LFS Details

  • SHA256: 193dca8f7fab34f802fc6ba7623346a5542871d6f190ba6a412280298fc2c6a4
  • Pointer size: 131 Bytes
  • Size of remote file: 522 kB
test_imgs/generated_2_bbox.png ADDED

Git LFS Details

  • SHA256: 6376fd6ebb4966d7a26b973663304516ee5878c0a81db93c71493c60a339050d
  • Pointer size: 131 Bytes
  • Size of remote file: 192 kB
test_imgs/generated_3.png ADDED

Git LFS Details

  • SHA256: 2087fe809477c88c3a46ade8d58d55277fa05ef922f68030e5aebc13991f9043
  • Pointer size: 132 Bytes
  • Size of remote file: 1.42 MB
test_imgs/generated_3_bbox.png ADDED

Git LFS Details

  • SHA256: fdaa47d38698863c7dc4ca34ed841b7b70370bebd772b1b2fa55eb9504bb2bd8
  • Pointer size: 131 Bytes
  • Size of remote file: 637 kB
test_imgs/generated_3_bbox_1.png ADDED

Git LFS Details

  • SHA256: 5b3ab07a05abc6c2a8808211ef642b8ec498d69ab4e5c86bc46b14c882153fb5
  • Pointer size: 131 Bytes
  • Size of remote file: 434 kB
test_imgs/product_1.jpg ADDED
test_imgs/product_1_bbox.png ADDED

Git LFS Details

  • SHA256: 17d8dd3de98689c8db96b80d6de5f3561fc4fc73f8a8bee5b892d3da235c65dc
  • Pointer size: 131 Bytes
  • Size of remote file: 352 kB
test_imgs/product_2.png ADDED

Git LFS Details

  • SHA256: 2570c1fc8e4da310f78981e9dc050bf9049f2038343fe03669ff35fb2c9c00f8
  • Pointer size: 131 Bytes
  • Size of remote file: 496 kB
test_imgs/product_2_bbox.png ADDED

Git LFS Details

  • SHA256: 3c042db4ecc731c06f4d8040d46597e8d1c9459aa06e4431f4372c9b67efd9d4
  • Pointer size: 131 Bytes
  • Size of remote file: 475 kB
test_imgs/product_3.png ADDED

Git LFS Details

  • SHA256: 3217819ed69d19c2e309897f673a47c6f7910e74e239e8cbcb182e6df72a567d
  • Pointer size: 131 Bytes
  • Size of remote file: 367 kB
test_imgs/product_3_bbox.png ADDED

Git LFS Details

  • SHA256: 61d09d4f7e03c4447cde438befe4be7bc697be5e28aaf447f23315f9dc38de41
  • Pointer size: 131 Bytes
  • Size of remote file: 138 kB
test_imgs/product_3_bbox_1.png ADDED

Git LFS Details

  • SHA256: 70dd0fff6ed5164d9ddff1c2184aa725321a70ea5598985393431d758291343d
  • Pointer size: 131 Bytes
  • Size of remote file: 147 kB
uno/dataset/uno.py ADDED
@@ -0,0 +1,132 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import json
17
+ import os
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torchvision.transforms.functional as TVF
22
+ from torch.utils.data import DataLoader, Dataset
23
+ from torchvision.transforms import Compose, Normalize, ToTensor
+ from PIL import Image  # used by __getitem__ to load reference and target images
24
+
25
+ def bucket_images(images: list[torch.Tensor], resolution: int = 512):
26
+ bucket_override=[
27
+ # h w
28
+ (256, 768),
29
+ (320, 768),
30
+ (320, 704),
31
+ (384, 640),
32
+ (448, 576),
33
+ (512, 512),
34
+ (576, 448),
35
+ (640, 384),
36
+ (704, 320),
37
+ (768, 320),
38
+ (768, 256)
39
+ ]
40
+ bucket_override = [(int(h / 512 * resolution), int(w / 512 * resolution)) for h, w in bucket_override]
41
+ bucket_override = [(h // 16 * 16, w // 16 * 16) for h, w in bucket_override]
42
+
43
+ aspect_ratios = [image.shape[-2] / image.shape[-1] for image in images]
44
+ mean_aspect_ratio = np.mean(aspect_ratios)
45
+
46
+ new_h, new_w = bucket_override[0]
47
+ min_aspect_diff = np.abs(new_h / new_w - mean_aspect_ratio)
48
+ for h, w in bucket_override:
49
+ aspect_diff = np.abs(h / w - mean_aspect_ratio)
50
+ if aspect_diff < min_aspect_diff:
51
+ min_aspect_diff = aspect_diff
52
+ new_h, new_w = h, w
53
+
54
+ images = [TVF.resize(image, (new_h, new_w)) for image in images]
55
+ images = torch.stack(images, dim=0)
56
+ return images
57
+
58
+ class FluxPairedDatasetV2(Dataset):
59
+ def __init__(self, json_file: str, resolution: int, resolution_ref: int | None = None):
60
+ super().__init__()
61
+ self.json_file = json_file
62
+ self.resolution = resolution
63
+ self.resolution_ref = resolution_ref if resolution_ref is not None else resolution
64
+ self.image_root = os.path.dirname(json_file)
65
+
66
+ with open(self.json_file, "rt") as f:
67
+ self.data_dicts = json.load(f)
68
+
69
+ self.transform = Compose([
70
+ ToTensor(),
71
+ Normalize([0.5], [0.5]),
72
+ ])
73
+
74
+ def __getitem__(self, idx):
75
+ data_dict = self.data_dicts[idx]
76
+ image_paths = [data_dict["image_path"]] if "image_path" in data_dict else data_dict["image_paths"]
77
+ txt = data_dict["prompt"]
78
+ image_tgt_path = data_dict.get("image_tgt_path", None)
79
+ ref_imgs = [
80
+ Image.open(os.path.join(self.image_root, path)).convert("RGB")
81
+ for path in image_paths
82
+ ]
83
+ ref_imgs = [self.transform(img) for img in ref_imgs]
84
+ img = None
85
+ if image_tgt_path is not None:
86
+ img = Image.open(os.path.join(self.image_root, image_tgt_path)).convert("RGB")
87
+ img = self.transform(img)
88
+
89
+ return {
90
+ "img": img,
91
+ "txt": txt,
92
+ "ref_imgs": ref_imgs,
93
+ }
94
+
95
+ def __len__(self):
96
+ return len(self.data_dicts)
97
+
98
+ def collate_fn(self, batch):
99
+ img = [data["img"] for data in batch]
100
+ txt = [data["txt"] for data in batch]
101
+ ref_imgs = [data["ref_imgs"] for data in batch]
102
+ assert all([len(ref_imgs[0]) == len(ref_imgs[i]) for i in range(len(ref_imgs))])
103
+
104
+ n_ref = len(ref_imgs[0])
105
+
106
+ img = bucket_images(img, self.resolution)
107
+ ref_imgs_new = []
108
+ for i in range(n_ref):
109
+ ref_imgs_i = [refs[i] for refs in ref_imgs]
110
+ ref_imgs_i = bucket_images(ref_imgs_i, self.resolution_ref)
111
+ ref_imgs_new.append(ref_imgs_i)
112
+
113
+ return {
114
+ "txt": txt,
115
+ "img": img,
116
+ "ref_imgs": ref_imgs_new,
117
+ }
118
+
119
+ if __name__ == '__main__':
120
+ import argparse
121
+ from pprint import pprint
122
+ parser = argparse.ArgumentParser()
123
+ # parser.add_argument("--json_file", type=str, required=True)
124
+ parser.add_argument("--json_file", type=str, default="datasets/fake_train_data.json")
125
+ args = parser.parse_args()
126
+ dataset = FluxPairedDatasetV2(args.json_file, 512)
127
+ dataloader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)
128
+
129
+ for i, data_dict in enumerate(dataloader):
130
+ pprint(i)
131
+ pprint(data_dict)
132
+ breakpoint()
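For reference, a sketch of the JSON layout `FluxPairedDatasetV2` expects, inferred from `__getitem__` above: each record carries `image_paths` (or a single `image_path`), a `prompt`, and an optional `image_tgt_path`, with all paths resolved relative to the JSON file's directory. The file names below are made up, and the import assumes the repo root is on PYTHONPATH:

import json, os, tempfile
from PIL import Image
from torch.utils.data import DataLoader
from uno.dataset.uno import FluxPairedDatasetV2

root = tempfile.mkdtemp()
Image.new("RGB", (640, 480), "white").save(os.path.join(root, "ref_0.png"))
Image.new("RGB", (512, 512), "gray").save(os.path.join(root, "tgt_0.png"))
records = [{
    "image_paths": ["ref_0.png"],                   # one or more reference images
    "prompt": "a product photo on a wooden table",
    "image_tgt_path": "tgt_0.png",                  # optional target image
}]
json_path = os.path.join(root, "train.json")
with open(json_path, "wt") as f:
    json.dump(records, f)

ds = FluxPairedDatasetV2(json_path, resolution=512)
dl = DataLoader(ds, batch_size=1, collate_fn=ds.collate_fn)
batch = next(iter(dl))
print(batch["img"].shape)                    # torch.Size([1, 3, 512, 512])
print([r.shape for r in batch["ref_imgs"]])  # [torch.Size([1, 3, 448, 576])], nearest aspect bucket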
uno/flux/math.py ADDED
@@ -0,0 +1,45 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from einops import rearrange
18
+ from torch import Tensor
19
+
20
+
21
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
22
+ q, k = apply_rope(q, k, pe)
23
+
24
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
25
+ x = rearrange(x, "B H L D -> B L (H D)")
26
+
27
+ return x
28
+
29
+
30
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
31
+ assert dim % 2 == 0
32
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
33
+ omega = 1.0 / (theta**scale)
34
+ out = torch.einsum("...n,d->...nd", pos, omega)
35
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
36
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
37
+ return out.float()
38
+
39
+
40
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
41
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
42
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
43
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
44
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
45
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
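A quick shape check for the RoPE helpers above (tensor sizes chosen arbitrarily): `rope` produces per-position rotation matrices of shape (..., dim/2, 2, 2), `apply_rope` rotates query/key channels pairwise, and `attention` folds the heads back into the channel dimension. The import assumes the repo root is on PYTHONPATH:

import torch
from uno.flux.math import attention, rope

B, H, L, D = 1, 2, 8, 16                           # D must be even for RoPE
pos = torch.arange(L, dtype=torch.float64)[None]   # (B, L) token positions
pe = rope(pos, D, theta=10_000)                    # (B, L, D/2, 2, 2)
pe = pe.unsqueeze(1)                               # (B, 1, L, D/2, 2, 2), broadcasts over heads

q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)
out = attention(q, k, v, pe=pe)                    # heads merged: (B, L, H * D)
print(out.shape)                                   # torch.Size([1, 8, 32])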
uno/flux/model.py ADDED
@@ -0,0 +1,222 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+
18
+ import torch
19
+ from torch import Tensor, nn
20
+
21
+ from .modules.layers import DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding
22
+
23
+
24
+ @dataclass
25
+ class FluxParams:
26
+ in_channels: int
27
+ vec_in_dim: int
28
+ context_in_dim: int
29
+ hidden_size: int
30
+ mlp_ratio: float
31
+ num_heads: int
32
+ depth: int
33
+ depth_single_blocks: int
34
+ axes_dim: list[int]
35
+ theta: int
36
+ qkv_bias: bool
37
+ guidance_embed: bool
38
+
39
+
40
+ class Flux(nn.Module):
41
+ """
42
+ Transformer model for flow matching on sequences.
43
+ """
44
+ _supports_gradient_checkpointing = True
45
+
46
+ def __init__(self, params: FluxParams):
47
+ super().__init__()
48
+
49
+ self.params = params
50
+ self.in_channels = params.in_channels
51
+ self.out_channels = self.in_channels
52
+ if params.hidden_size % params.num_heads != 0:
53
+ raise ValueError(
54
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
55
+ )
56
+ pe_dim = params.hidden_size // params.num_heads
57
+ if sum(params.axes_dim) != pe_dim:
58
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
59
+ self.hidden_size = params.hidden_size
60
+ self.num_heads = params.num_heads
61
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
62
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
63
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
64
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
65
+ self.guidance_in = (
66
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
67
+ )
68
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
69
+
70
+ self.double_blocks = nn.ModuleList(
71
+ [
72
+ DoubleStreamBlock(
73
+ self.hidden_size,
74
+ self.num_heads,
75
+ mlp_ratio=params.mlp_ratio,
76
+ qkv_bias=params.qkv_bias,
77
+ )
78
+ for _ in range(params.depth)
79
+ ]
80
+ )
81
+
82
+ self.single_blocks = nn.ModuleList(
83
+ [
84
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
85
+ for _ in range(params.depth_single_blocks)
86
+ ]
87
+ )
88
+
89
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
90
+ self.gradient_checkpointing = False
91
+
92
+ def _set_gradient_checkpointing(self, module, value=False):
93
+ if hasattr(module, "gradient_checkpointing"):
94
+ module.gradient_checkpointing = value
95
+
96
+ @property
97
+ def attn_processors(self):
98
+ # set recursively
99
+ processors = {} # type: dict[str, nn.Module]
100
+
101
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
102
+ if hasattr(module, "set_processor"):
103
+ processors[f"{name}.processor"] = module.processor
104
+
105
+ for sub_name, child in module.named_children():
106
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
107
+
108
+ return processors
109
+
110
+ for name, module in self.named_children():
111
+ fn_recursive_add_processors(name, module, processors)
112
+
113
+ return processors
114
+
115
+ def set_attn_processor(self, processor):
116
+ r"""
117
+ Sets the attention processor to use to compute attention.
118
+
119
+ Parameters:
120
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
121
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
122
+ for **all** `Attention` layers.
123
+
124
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
125
+ processor. This is strongly recommended when setting trainable attention processors.
126
+
127
+ """
128
+ count = len(self.attn_processors.keys())
129
+
130
+ if isinstance(processor, dict) and len(processor) != count:
131
+ raise ValueError(
132
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
133
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
134
+ )
135
+
136
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
137
+ if hasattr(module, "set_processor"):
138
+ if not isinstance(processor, dict):
139
+ module.set_processor(processor)
140
+ else:
141
+ module.set_processor(processor.pop(f"{name}.processor"))
142
+
143
+ for sub_name, child in module.named_children():
144
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
145
+
146
+ for name, module in self.named_children():
147
+ fn_recursive_attn_processor(name, module, processor)
148
+
149
+ def forward(
150
+ self,
151
+ img: Tensor,
152
+ img_ids: Tensor,
153
+ txt: Tensor,
154
+ txt_ids: Tensor,
155
+ timesteps: Tensor,
156
+ y: Tensor,
157
+ guidance: Tensor | None = None,
158
+ ref_img: Tensor | None = None,
159
+ ref_img_ids: Tensor | None = None,
160
+ ) -> Tensor:
161
+ if img.ndim != 3 or txt.ndim != 3:
162
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
163
+
164
+ # running on sequences img
165
+ img = self.img_in(img)
166
+ vec = self.time_in(timestep_embedding(timesteps, 256))
167
+ if self.params.guidance_embed:
168
+ if guidance is None:
169
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
170
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
171
+ vec = vec + self.vector_in(y)
172
+ txt = self.txt_in(txt)
173
+
174
+ ids = torch.cat((txt_ids, img_ids), dim=1)
175
+
176
+ # concat ref_img/img
177
+ img_end = img.shape[1]
178
+ if ref_img is not None:
179
+ if isinstance(ref_img, tuple) or isinstance(ref_img, list):
180
+ img_in = [img] + [self.img_in(ref) for ref in ref_img]
181
+ img_ids = [ids] + [ref_ids for ref_ids in ref_img_ids]
182
+ img = torch.cat(img_in, dim=1)
183
+ ids = torch.cat(img_ids, dim=1)
184
+ else:
185
+ img = torch.cat((img, self.img_in(ref_img)), dim=1)
186
+ ids = torch.cat((ids, ref_img_ids), dim=1)
187
+ pe = self.pe_embedder(ids)
188
+
189
+ for index_block, block in enumerate(self.double_blocks):
190
+ if self.training and self.gradient_checkpointing:
191
+ img, txt = torch.utils.checkpoint.checkpoint(
192
+ block,
193
+ img=img,
194
+ txt=txt,
195
+ vec=vec,
196
+ pe=pe,
197
+ use_reentrant=False,
198
+ )
199
+ else:
200
+ img, txt = block(
201
+ img=img,
202
+ txt=txt,
203
+ vec=vec,
204
+ pe=pe
205
+ )
206
+
207
+ img = torch.cat((txt, img), 1)
208
+ for block in self.single_blocks:
209
+ if self.training and self.gradient_checkpointing:
210
+ img = torch.utils.checkpoint.checkpoint(
211
+ block,
212
+ img, vec=vec, pe=pe,
213
+ use_reentrant=False
214
+ )
215
+ else:
216
+ img = block(img, vec=vec, pe=pe)
217
+ img = img[:, txt.shape[1] :, ...]
218
+ # index img
219
+ img = img[:, :img_end, ...]
220
+
221
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
222
+ return img
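A smoke test for the `Flux` module above with deliberately tiny, made-up dimensions (the released FLUX checkpoints use hidden size 3072 with 19 double and 38 single blocks, as in the transformer config earlier in this commit); it only checks that hidden_size is divisible by num_heads, that axes_dim sums to the head dimension, and that shapes round-trip:

import torch
from uno.flux.model import Flux, FluxParams

params = FluxParams(
    in_channels=16, vec_in_dim=32, context_in_dim=32,
    hidden_size=64, mlp_ratio=4.0, num_heads=4,        # head_dim = 16
    depth=1, depth_single_blocks=1,
    axes_dim=[4, 6, 6], theta=10_000,                  # sums to head_dim
    qkv_bias=True, guidance_embed=False,
)
model = Flux(params)

B, L_img, L_txt = 1, 32, 8
img = torch.randn(B, L_img, 16)                            # packed latent tokens
txt = torch.randn(B, L_txt, 32)                            # text tokens (context_in_dim)
img_ids = torch.zeros(B, L_img, 3, dtype=torch.float64)    # positional ids consumed by RoPE
txt_ids = torch.zeros(B, L_txt, 3, dtype=torch.float64)
out = model(
    img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids,
    timesteps=torch.tensor([0.5]), y=torch.randn(B, 32),
)
print(out.shape)                                           # torch.Size([1, 32, 16])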
uno/flux/modules/autoencoder.py ADDED
@@ -0,0 +1,327 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+
18
+ import torch
19
+ from einops import rearrange
20
+ from torch import Tensor, nn
21
+
22
+
23
+ @dataclass
24
+ class AutoEncoderParams:
25
+ resolution: int
26
+ in_channels: int
27
+ ch: int
28
+ out_ch: int
29
+ ch_mult: list[int]
30
+ num_res_blocks: int
31
+ z_channels: int
32
+ scale_factor: float
33
+ shift_factor: float
34
+
35
+
36
+ def swish(x: Tensor) -> Tensor:
37
+ return x * torch.sigmoid(x)
38
+
39
+
40
+ class AttnBlock(nn.Module):
41
+ def __init__(self, in_channels: int):
42
+ super().__init__()
43
+ self.in_channels = in_channels
44
+
45
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
46
+
47
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
48
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
49
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
50
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
51
+
52
+ def attention(self, h_: Tensor) -> Tensor:
53
+ h_ = self.norm(h_)
54
+ q = self.q(h_)
55
+ k = self.k(h_)
56
+ v = self.v(h_)
57
+
58
+ b, c, h, w = q.shape
59
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
60
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
61
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
62
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
63
+
64
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
65
+
66
+ def forward(self, x: Tensor) -> Tensor:
67
+ return x + self.proj_out(self.attention(x))
68
+
69
+
70
+ class ResnetBlock(nn.Module):
71
+ def __init__(self, in_channels: int, out_channels: int):
72
+ super().__init__()
73
+ self.in_channels = in_channels
74
+ out_channels = in_channels if out_channels is None else out_channels
75
+ self.out_channels = out_channels
76
+
77
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
79
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
80
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
81
+ if self.in_channels != self.out_channels:
82
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
83
+
84
+ def forward(self, x):
85
+ h = x
86
+ h = self.norm1(h)
87
+ h = swish(h)
88
+ h = self.conv1(h)
89
+
90
+ h = self.norm2(h)
91
+ h = swish(h)
92
+ h = self.conv2(h)
93
+
94
+ if self.in_channels != self.out_channels:
95
+ x = self.nin_shortcut(x)
96
+
97
+ return x + h
98
+
99
+
100
+ class Downsample(nn.Module):
101
+ def __init__(self, in_channels: int):
102
+ super().__init__()
103
+ # no asymmetric padding in torch conv, must do it ourselves
104
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
105
+
106
+ def forward(self, x: Tensor):
107
+ pad = (0, 1, 0, 1)
108
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
109
+ x = self.conv(x)
110
+ return x
111
+
112
+
113
+ class Upsample(nn.Module):
114
+ def __init__(self, in_channels: int):
115
+ super().__init__()
116
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
117
+
118
+ def forward(self, x: Tensor):
119
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
120
+ x = self.conv(x)
121
+ return x
122
+
123
+
124
+ class Encoder(nn.Module):
125
+ def __init__(
126
+ self,
127
+ resolution: int,
128
+ in_channels: int,
129
+ ch: int,
130
+ ch_mult: list[int],
131
+ num_res_blocks: int,
132
+ z_channels: int,
133
+ ):
134
+ super().__init__()
135
+ self.ch = ch
136
+ self.num_resolutions = len(ch_mult)
137
+ self.num_res_blocks = num_res_blocks
138
+ self.resolution = resolution
139
+ self.in_channels = in_channels
140
+ # downsampling
141
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
142
+
143
+ curr_res = resolution
144
+ in_ch_mult = (1,) + tuple(ch_mult)
145
+ self.in_ch_mult = in_ch_mult
146
+ self.down = nn.ModuleList()
147
+ block_in = self.ch
148
+ for i_level in range(self.num_resolutions):
149
+ block = nn.ModuleList()
150
+ attn = nn.ModuleList()
151
+ block_in = ch * in_ch_mult[i_level]
152
+ block_out = ch * ch_mult[i_level]
153
+ for _ in range(self.num_res_blocks):
154
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
155
+ block_in = block_out
156
+ down = nn.Module()
157
+ down.block = block
158
+ down.attn = attn
159
+ if i_level != self.num_resolutions - 1:
160
+ down.downsample = Downsample(block_in)
161
+ curr_res = curr_res // 2
162
+ self.down.append(down)
163
+
164
+ # middle
165
+ self.mid = nn.Module()
166
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
167
+ self.mid.attn_1 = AttnBlock(block_in)
168
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
169
+
170
+ # end
171
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
172
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
173
+
174
+ def forward(self, x: Tensor) -> Tensor:
175
+ # downsampling
176
+ hs = [self.conv_in(x)]
177
+ for i_level in range(self.num_resolutions):
178
+ for i_block in range(self.num_res_blocks):
179
+ h = self.down[i_level].block[i_block](hs[-1])
180
+ if len(self.down[i_level].attn) > 0:
181
+ h = self.down[i_level].attn[i_block](h)
182
+ hs.append(h)
183
+ if i_level != self.num_resolutions - 1:
184
+ hs.append(self.down[i_level].downsample(hs[-1]))
185
+
186
+ # middle
187
+ h = hs[-1]
188
+ h = self.mid.block_1(h)
189
+ h = self.mid.attn_1(h)
190
+ h = self.mid.block_2(h)
191
+ # end
192
+ h = self.norm_out(h)
193
+ h = swish(h)
194
+ h = self.conv_out(h)
195
+ return h
196
+
197
+
198
+ class Decoder(nn.Module):
199
+ def __init__(
200
+ self,
201
+ ch: int,
202
+ out_ch: int,
203
+ ch_mult: list[int],
204
+ num_res_blocks: int,
205
+ in_channels: int,
206
+ resolution: int,
207
+ z_channels: int,
208
+ ):
209
+ super().__init__()
210
+ self.ch = ch
211
+ self.num_resolutions = len(ch_mult)
212
+ self.num_res_blocks = num_res_blocks
213
+ self.resolution = resolution
214
+ self.in_channels = in_channels
215
+ self.ffactor = 2 ** (self.num_resolutions - 1)
216
+
217
+ # compute in_ch_mult, block_in and curr_res at lowest res
218
+ block_in = ch * ch_mult[self.num_resolutions - 1]
219
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
220
+ self.z_shape = (1, z_channels, curr_res, curr_res)
221
+
222
+ # z to block_in
223
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
224
+
225
+ # middle
226
+ self.mid = nn.Module()
227
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
228
+ self.mid.attn_1 = AttnBlock(block_in)
229
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
230
+
231
+ # upsampling
232
+ self.up = nn.ModuleList()
233
+ for i_level in reversed(range(self.num_resolutions)):
234
+ block = nn.ModuleList()
235
+ attn = nn.ModuleList()
236
+ block_out = ch * ch_mult[i_level]
237
+ for _ in range(self.num_res_blocks + 1):
238
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
239
+ block_in = block_out
240
+ up = nn.Module()
241
+ up.block = block
242
+ up.attn = attn
243
+ if i_level != 0:
244
+ up.upsample = Upsample(block_in)
245
+ curr_res = curr_res * 2
246
+ self.up.insert(0, up) # prepend to get consistent order
247
+
248
+ # end
249
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
250
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
251
+
252
+ def forward(self, z: Tensor) -> Tensor:
253
+ # z to block_in
254
+ h = self.conv_in(z)
255
+
256
+ # middle
257
+ h = self.mid.block_1(h)
258
+ h = self.mid.attn_1(h)
259
+ h = self.mid.block_2(h)
260
+
261
+ # upsampling
262
+ for i_level in reversed(range(self.num_resolutions)):
263
+ for i_block in range(self.num_res_blocks + 1):
264
+ h = self.up[i_level].block[i_block](h)
265
+ if len(self.up[i_level].attn) > 0:
266
+ h = self.up[i_level].attn[i_block](h)
267
+ if i_level != 0:
268
+ h = self.up[i_level].upsample(h)
269
+
270
+ # end
271
+ h = self.norm_out(h)
272
+ h = swish(h)
273
+ h = self.conv_out(h)
274
+ return h
275
+
276
+
277
+ class DiagonalGaussian(nn.Module):
278
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
279
+ super().__init__()
280
+ self.sample = sample
281
+ self.chunk_dim = chunk_dim
282
+
283
+ def forward(self, z: Tensor) -> Tensor:
284
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
285
+ if self.sample:
286
+ std = torch.exp(0.5 * logvar)
287
+ return mean + std * torch.randn_like(mean)
288
+ else:
289
+ return mean
290
+
291
+
292
+ class AutoEncoder(nn.Module):
293
+ def __init__(self, params: AutoEncoderParams):
294
+ super().__init__()
295
+ self.encoder = Encoder(
296
+ resolution=params.resolution,
297
+ in_channels=params.in_channels,
298
+ ch=params.ch,
299
+ ch_mult=params.ch_mult,
300
+ num_res_blocks=params.num_res_blocks,
301
+ z_channels=params.z_channels,
302
+ )
303
+ self.decoder = Decoder(
304
+ resolution=params.resolution,
305
+ in_channels=params.in_channels,
306
+ ch=params.ch,
307
+ out_ch=params.out_ch,
308
+ ch_mult=params.ch_mult,
309
+ num_res_blocks=params.num_res_blocks,
310
+ z_channels=params.z_channels,
311
+ )
312
+ self.reg = DiagonalGaussian()
313
+
314
+ self.scale_factor = params.scale_factor
315
+ self.shift_factor = params.shift_factor
316
+
317
+ def encode(self, x: Tensor) -> Tensor:
318
+ z = self.reg(self.encoder(x))
319
+ z = self.scale_factor * (z - self.shift_factor)
320
+ return z
321
+
322
+ def decode(self, z: Tensor) -> Tensor:
323
+ z = z / self.scale_factor + self.shift_factor
324
+ return self.decoder(z)
325
+
326
+ def forward(self, x: Tensor) -> Tensor:
327
+ return self.decode(self.encode(x))
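A toy round-trip through the `AutoEncoder` above; the configuration values are made up so the example runs on CPU in a moment (the real FLUX VAE uses a much larger `ch`, four `ch_mult` levels, and non-trivial scale/shift factors). One spatial 2x downsample happens per `ch_mult` level beyond the first:

import torch
from uno.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams

params = AutoEncoderParams(
    resolution=64, in_channels=3, ch=32, out_ch=3,
    ch_mult=[1, 2], num_res_blocks=1, z_channels=4,
    scale_factor=1.0, shift_factor=0.0,
)
ae = AutoEncoder(params)

x = torch.randn(1, 3, 64, 64)
z = ae.encode(x)        # sampled latent: (1, 4, 32, 32)
x_rec = ae.decode(z)    # reconstruction: (1, 3, 64, 64)
print(z.shape, x_rec.shape)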
uno/flux/modules/conditioner.py ADDED
@@ -0,0 +1,53 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from torch import Tensor, nn
17
+ from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
18
+ T5Tokenizer)
19
+
20
+
21
+ class HFEmbedder(nn.Module):
22
+ def __init__(self, version: str, max_length: int, **hf_kwargs):
23
+ super().__init__()
24
+ self.is_clip = version.startswith("openai")
25
+ self.max_length = max_length
26
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
27
+
28
+ if self.is_clip:
29
+ self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
30
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
31
+ else:
32
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
33
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
34
+
35
+ self.hf_module = self.hf_module.eval().requires_grad_(False)
36
+
37
+ def forward(self, text: list[str]) -> Tensor:
38
+ batch_encoding = self.tokenizer(
39
+ text,
40
+ truncation=True,
41
+ max_length=self.max_length,
42
+ return_length=False,
43
+ return_overflowing_tokens=False,
44
+ padding="max_length",
45
+ return_tensors="pt",
46
+ )
47
+
48
+ outputs = self.hf_module(
49
+ input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
50
+ attention_mask=None,
51
+ output_hidden_states=False,
52
+ )
53
+ return outputs[self.output_key]
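A usage sketch for `HFEmbedder`: a version string starting with "openai" selects the CLIP branch and returns pooled embeddings, anything else selects the T5 branch and returns per-token hidden states. The model names and max lengths below are common choices, not taken from this file, and both require downloading the corresponding weights:

from uno.flux.modules.conditioner import HFEmbedder

clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77)
t5 = HFEmbedder("google/t5-v1_1-small", max_length=256)

prompt = ["a product photo on a marble table"]
vec = clip(prompt)   # pooler_output, shape (1, 768) for this CLIP
txt = t5(prompt)     # last_hidden_state, shape (1, 256, d_model)
print(vec.shape, txt.shape)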
uno/flux/modules/layers.py ADDED
@@ -0,0 +1,435 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from dataclasses import dataclass
18
+
19
+ import torch
20
+ from einops import rearrange
21
+ from torch import Tensor, nn
22
+
23
+ from ..math import attention, rope
24
+ import torch.nn.functional as F
25
+
26
+ class EmbedND(nn.Module):
27
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
28
+ super().__init__()
29
+ self.dim = dim
30
+ self.theta = theta
31
+ self.axes_dim = axes_dim
32
+
33
+ def forward(self, ids: Tensor) -> Tensor:
34
+ n_axes = ids.shape[-1]
35
+ emb = torch.cat(
36
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
37
+ dim=-3,
38
+ )
39
+
40
+ return emb.unsqueeze(1)
41
+
42
+
43
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
44
+ """
45
+ Create sinusoidal timestep embeddings.
46
+ :param t: a 1-D Tensor of N indices, one per batch element.
47
+ These may be fractional.
48
+ :param dim: the dimension of the output.
49
+ :param max_period: controls the minimum frequency of the embeddings.
50
+ :return: an (N, D) Tensor of positional embeddings.
51
+ """
52
+ t = time_factor * t
53
+ half = dim // 2
54
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
55
+ t.device
56
+ )
57
+
58
+ args = t[:, None].float() * freqs[None]
59
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
60
+ if dim % 2:
61
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
62
+ if torch.is_floating_point(t):
63
+ embedding = embedding.to(t)
64
+ return embedding
65
+
66
+
67
+ class MLPEmbedder(nn.Module):
68
+ def __init__(self, in_dim: int, hidden_dim: int):
69
+ super().__init__()
70
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
71
+ self.silu = nn.SiLU()
72
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
73
+
74
+ def forward(self, x: Tensor) -> Tensor:
75
+ return self.out_layer(self.silu(self.in_layer(x)))
76
+
77
+
78
+ class RMSNorm(torch.nn.Module):
79
+ def __init__(self, dim: int):
80
+ super().__init__()
81
+ self.scale = nn.Parameter(torch.ones(dim))
82
+
83
+ def forward(self, x: Tensor):
84
+ x_dtype = x.dtype
85
+ x = x.float()
86
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
87
+ return (x * rrms).to(dtype=x_dtype) * self.scale
88
+
89
+
90
+ class QKNorm(torch.nn.Module):
91
+ def __init__(self, dim: int):
92
+ super().__init__()
93
+ self.query_norm = RMSNorm(dim)
94
+ self.key_norm = RMSNorm(dim)
95
+
96
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
97
+ q = self.query_norm(q)
98
+ k = self.key_norm(k)
99
+ return q.to(v), k.to(v)
100
+
101
+ class LoRALinearLayer(nn.Module):
102
+ def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
103
+ super().__init__()
104
+
105
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
106
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
107
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
108
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
109
+ self.network_alpha = network_alpha
110
+ self.rank = rank
111
+
112
+ nn.init.normal_(self.down.weight, std=1 / rank)
113
+ nn.init.zeros_(self.up.weight)
114
+
115
+ def forward(self, hidden_states):
116
+ orig_dtype = hidden_states.dtype
117
+ dtype = self.down.weight.dtype
118
+
119
+ down_hidden_states = self.down(hidden_states.to(dtype))
120
+ up_hidden_states = self.up(down_hidden_states)
121
+
122
+ if self.network_alpha is not None:
123
+ up_hidden_states *= self.network_alpha / self.rank
124
+
125
+ return up_hidden_states.to(orig_dtype)
126
+
127
+ class FLuxSelfAttnProcessor:
128
+ def __call__(self, attn, x, pe, **attention_kwargs):
129
+ qkv = attn.qkv(x)
130
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)  # heads live on the attention module, not the processor
131
+ q, k = attn.norm(q, k, v)
132
+ x = attention(q, k, v, pe=pe)
133
+ x = attn.proj(x)
134
+ return x
135
+
136
+ class LoraFluxAttnProcessor(nn.Module):
137
+
138
+ def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
139
+ super().__init__()
140
+ self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
141
+ self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
142
+ self.lora_weight = lora_weight
143
+
144
+
145
+ def __call__(self, attn, x, pe, **attention_kwargs):
146
+ qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
147
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)  # heads live on the attention module, not the processor
148
+ q, k = attn.norm(q, k, v)
149
+ x = attention(q, k, v, pe=pe)
150
+ x = attn.proj(x) + self.proj_lora(x) * self.lora_weight
151
+ return x
152
+
153
+ class SelfAttention(nn.Module):
154
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
155
+ super().__init__()
156
+ self.num_heads = num_heads
157
+ head_dim = dim // num_heads
158
+
159
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
160
+ self.norm = QKNorm(head_dim)
161
+ self.proj = nn.Linear(dim, dim)
162
+ def forward(self):  # unused stub; the block's processor calls qkv/norm/proj directly
163
+ pass
164
+
165
+
166
+ @dataclass
167
+ class ModulationOut:
168
+ shift: Tensor
169
+ scale: Tensor
170
+ gate: Tensor
171
+
172
+
173
+ class Modulation(nn.Module):
174
+ def __init__(self, dim: int, double: bool):
175
+ super().__init__()
176
+ self.is_double = double
177
+ self.multiplier = 6 if double else 3
178
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
179
+
180
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
181
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
182
+
183
+ return (
184
+ ModulationOut(*out[:3]),
185
+ ModulationOut(*out[3:]) if self.is_double else None,
186
+ )
187
+
188
+ class DoubleStreamBlockLoraProcessor(nn.Module):
189
+ def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
190
+ super().__init__()
191
+ self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
192
+ self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
193
+ self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
194
+ self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
195
+ self.lora_weight = lora_weight
196
+
197
+ def forward(self, attn, img, txt, vec, pe, **attention_kwargs):
198
+ img_mod1, img_mod2 = attn.img_mod(vec)
199
+ txt_mod1, txt_mod2 = attn.txt_mod(vec)
200
+
201
+ # prepare image for attention
202
+ img_modulated = attn.img_norm1(img)
203
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
204
+ img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight
205
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
206
+ img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
207
+
208
+ # prepare txt for attention
209
+ txt_modulated = attn.txt_norm1(txt)
210
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
211
+ txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight
212
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
213
+ txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
214
+
215
+ # run actual attention
216
+ q = torch.cat((txt_q, img_q), dim=2)
217
+ k = torch.cat((txt_k, img_k), dim=2)
218
+ v = torch.cat((txt_v, img_v), dim=2)
219
+
220
+ attn1 = attention(q, k, v, pe=pe)
221
+ txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
222
+
223
+ # calculate the img blocks
224
+ img = img + img_mod1.gate * (attn.img_attn.proj(img_attn) + self.proj_lora1(img_attn) * self.lora_weight)
225
+ img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
226
+
227
+ # calculate the txt blocks
228
+ txt = txt + txt_mod1.gate * (attn.txt_attn.proj(txt_attn) + self.proj_lora2(txt_attn) * self.lora_weight)
229
+ txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
230
+ return img, txt
231
+
232
+ class DoubleStreamBlockProcessor:
233
+ def __call__(self, attn, img, txt, vec, pe, **attention_kwargs):
234
+ img_mod1, img_mod2 = attn.img_mod(vec)
235
+ txt_mod1, txt_mod2 = attn.txt_mod(vec)
236
+
237
+ # prepare image for attention
238
+ img_modulated = attn.img_norm1(img)
239
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
240
+ img_qkv = attn.img_attn.qkv(img_modulated)
241
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
242
+ img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
243
+
244
+ # prepare txt for attention
245
+ txt_modulated = attn.txt_norm1(txt)
246
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
247
+ txt_qkv = attn.txt_attn.qkv(txt_modulated)
248
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
249
+ txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
250
+
251
+ # run actual attention
252
+ q = torch.cat((txt_q, img_q), dim=2)
253
+ k = torch.cat((txt_k, img_k), dim=2)
254
+ v = torch.cat((txt_v, img_v), dim=2)
255
+
256
+ attn1 = attention(q, k, v, pe=pe)
257
+ txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
258
+
259
+ # calculate the img blocks
260
+ img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
261
+ img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
262
+
263
+ # calculate the txt blocks
264
+ txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
265
+ txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
266
+ return img, txt
267
+
268
+ class DoubleStreamBlock(nn.Module):
269
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
270
+ super().__init__()
271
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
272
+ self.num_heads = num_heads
273
+ self.hidden_size = hidden_size
274
+ self.head_dim = hidden_size // num_heads
275
+
276
+ self.img_mod = Modulation(hidden_size, double=True)
277
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
278
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
279
+
280
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
281
+ self.img_mlp = nn.Sequential(
282
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
283
+ nn.GELU(approximate="tanh"),
284
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
285
+ )
286
+
287
+ self.txt_mod = Modulation(hidden_size, double=True)
288
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
289
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
290
+
291
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
292
+ self.txt_mlp = nn.Sequential(
293
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
294
+ nn.GELU(approximate="tanh"),
295
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
296
+ )
297
+ processor = DoubleStreamBlockProcessor()
298
+ self.set_processor(processor)
299
+
300
+ def set_processor(self, processor) -> None:
301
+ self.processor = processor
302
+
303
+ def get_processor(self):
304
+ return self.processor
305
+
306
+ def forward(
307
+ self,
308
+ img: Tensor,
309
+ txt: Tensor,
310
+ vec: Tensor,
311
+ pe: Tensor,
312
+ image_proj: Tensor = None,
313
+ ip_scale: float = 1.0,
314
+ ) -> tuple[Tensor, Tensor]:
315
+ if image_proj is None:
316
+ return self.processor(self, img, txt, vec, pe)
317
+ else:
318
+ return self.processor(self, img, txt, vec, pe, image_proj, ip_scale)
319
+
320
+
321
+ class SingleStreamBlockLoraProcessor(nn.Module):
322
+ def __init__(self, dim: int, rank: int = 4, network_alpha = None, lora_weight: float = 1):
323
+ super().__init__()
324
+ self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
325
+ self.proj_lora = LoRALinearLayer(15360, dim, rank, network_alpha)
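+ # note: 15360 = hidden_size + mlp_hidden_dim (3072 + 4 * 3072) for the flux-dev config, i.e. the input width of linear2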
326
+ self.lora_weight = lora_weight
327
+
328
+ def forward(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
329
+
330
+ mod, _ = attn.modulation(vec)
331
+ x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
332
+ qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
333
+ qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight
334
+
335
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
336
+ q, k = attn.norm(q, k, v)
337
+
338
+ # compute attention
339
+ attn_1 = attention(q, k, v, pe=pe)
340
+
341
+ # compute activation in mlp stream, cat again and run second linear layer
342
+ output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
343
+ output = output + self.proj_lora(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) * self.lora_weight
344
+ output = x + mod.gate * output
345
+ return output
346
+
347
+
348
+ class SingleStreamBlockProcessor:
349
+ def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor, **attention_kwargs) -> Tensor:
350
+
351
+ mod, _ = attn.modulation(vec)
352
+ x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
353
+ qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
354
+
355
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
356
+ q, k = attn.norm(q, k, v)
357
+
358
+ # compute attention
359
+ attn_1 = attention(q, k, v, pe=pe)
360
+
361
+ # compute activation in mlp stream, cat again and run second linear layer
362
+ output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
363
+ output = x + mod.gate * output
364
+ return output
365
+
366
+ class SingleStreamBlock(nn.Module):
367
+ """
368
+ A DiT block with parallel linear layers as described in
369
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
370
+ """
371
+
372
+ def __init__(
373
+ self,
374
+ hidden_size: int,
375
+ num_heads: int,
376
+ mlp_ratio: float = 4.0,
377
+ qk_scale: float | None = None,
378
+ ):
379
+ super().__init__()
380
+ self.hidden_dim = hidden_size
381
+ self.num_heads = num_heads
382
+ self.head_dim = hidden_size // num_heads
383
+ self.scale = qk_scale or self.head_dim**-0.5
384
+
385
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
386
+ # qkv and mlp_in
387
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
388
+ # proj and mlp_out
389
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
390
+
391
+ self.norm = QKNorm(self.head_dim)
392
+
393
+ self.hidden_size = hidden_size
394
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
395
+
396
+ self.mlp_act = nn.GELU(approximate="tanh")
397
+ self.modulation = Modulation(hidden_size, double=False)
398
+
399
+ processor = SingleStreamBlockProcessor()
400
+ self.set_processor(processor)
401
+
402
+
403
+ def set_processor(self, processor) -> None:
404
+ self.processor = processor
405
+
406
+ def get_processor(self):
407
+ return self.processor
408
+
409
+ def forward(
410
+ self,
411
+ x: Tensor,
412
+ vec: Tensor,
413
+ pe: Tensor,
414
+ image_proj: Tensor | None = None,
415
+ ip_scale: float = 1.0,
416
+ ) -> Tensor:
417
+ if image_proj is None:
418
+ return self.processor(self, x, vec, pe)
419
+ else:
420
+ return self.processor(self, x, vec, pe, image_proj, ip_scale)
421
+
422
+
423
+
424
+ class LastLayer(nn.Module):
425
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
426
+ super().__init__()
427
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
428
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
429
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
430
+
431
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
432
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
433
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
434
+ x = self.linear(x)
435
+ return x
uno/flux/pipeline.py ADDED
@@ -0,0 +1,304 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from typing import Literal
18
+
19
+ import torch
20
+ from einops import rearrange
21
+ from PIL import ExifTags, Image
22
+ import torchvision.transforms.functional as TVF
23
+
24
+ from uno.flux.modules.layers import (
25
+ DoubleStreamBlockLoraProcessor,
26
+ DoubleStreamBlockProcessor,
27
+ SingleStreamBlockLoraProcessor,
28
+ SingleStreamBlockProcessor,
29
+ )
30
+ from uno.flux.sampling import denoise, get_noise, get_schedule, prepare_multi_ip, unpack
31
+ from uno.flux.util import (
32
+ get_lora_rank,
33
+ load_ae,
34
+ load_checkpoint,
35
+ load_clip,
36
+ load_flow_model,
37
+ load_flow_model_only_lora,
38
+ load_flow_model_quintized,
39
+ load_t5,
40
+ )
41
+
42
+
43
+ def find_nearest_scale(image_h, image_w, predefined_scales):
44
+ """
45
+ Find the nearest predefined scale for a given image height and width.
46
+
47
+ :param image_h: height of the image
48
+ :param image_w: width of the image
49
+ :param predefined_scales: list of predefined scales [(h1, w1), (h2, w2), ...]
50
+ :return: the nearest predefined scale (h, w)
51
+ """
52
+ # compute the aspect ratio of the input image
53
+ image_ratio = image_h / image_w
54
+
55
+ # initialize trackers for the smallest ratio difference and the nearest scale
56
+ min_diff = float('inf')
57
+ nearest_scale = None
58
+
59
+ # iterate over all predefined scales and pick the one whose aspect ratio is closest to the input image
60
+ for scale_h, scale_w in predefined_scales:
61
+ predefined_ratio = scale_h / scale_w
62
+ diff = abs(predefined_ratio - image_ratio)
63
+
64
+ if diff < min_diff:
65
+ min_diff = diff
66
+ nearest_scale = (scale_h, scale_w)
67
+
68
+ return nearest_scale
69
+
70
+ def preprocess_ref(raw_image: Image.Image, long_size: int = 512):
71
+ # get the width and height of the original image
72
+ image_w, image_h = raw_image.size
73
+
74
+ # compute the new long and short sides (long side scaled to long_size)
75
+ if image_w >= image_h:
76
+ new_w = long_size
77
+ new_h = int((long_size / image_w) * image_h)
78
+ else:
79
+ new_h = long_size
80
+ new_w = int((long_size / image_h) * image_w)
81
+
82
+ # resize proportionally to the new width and height
83
+ raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS)
84
+ target_w = new_w // 16 * 16
85
+ target_h = new_h // 16 * 16
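+ # crop targets are snapped down to multiples of 16 so the 8x-downsampled VAE latent can be packed into 2x2 patches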
86
+
87
+ # compute the starting coordinates for a center crop
88
+ left = (new_w - target_w) // 2
89
+ top = (new_h - target_h) // 2
90
+ right = left + target_w
91
+ bottom = top + target_h
92
+
93
+ # apply the center crop
94
+ raw_image = raw_image.crop((left, top, right, bottom))
95
+
96
+ # convert to RGB mode
97
+ raw_image = raw_image.convert("RGB")
98
+ return raw_image
99
+
100
+ class UNOPipeline:
101
+ def __init__(
102
+ self,
103
+ model_type: str,
104
+ device: torch.device,
105
+ offload: bool = False,
106
+ only_lora: bool = False,
107
+ lora_rank: int = 16
108
+ ):
109
+ self.device = device
110
+ self.offload = offload
111
+ self.model_type = model_type
112
+
113
+ self.clip = load_clip(self.device)
114
+ self.t5 = load_t5(self.device, max_length=512)
115
+ self.ae = load_ae(model_type, device="cpu" if offload else self.device)
116
+ if "fp8" in model_type:
117
+ self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.device)
118
+ elif only_lora:
119
+ self.model = load_flow_model_only_lora(
120
+ model_type, device="cpu" if offload else self.device, lora_rank=lora_rank
121
+ )
122
+ else:
123
+ self.model = load_flow_model(model_type, device="cpu" if offload else self.device)
124
+
125
+
126
+ def load_ckpt(self, ckpt_path):
127
+ if ckpt_path is not None:
128
+ from safetensors.torch import load_file as load_sft
129
+ print("Loading checkpoint to replace old keys")
130
+ # load_sft doesn't support torch.device
131
+ if ckpt_path.endswith('safetensors'):
132
+ sd = load_sft(ckpt_path, device='cpu')
133
+ missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
134
+ else:
135
+ dit_state = torch.load(ckpt_path, map_location='cpu')
136
+ sd = {}
137
+ for k in dit_state.keys():
138
+ sd[k.replace('module.','')] = dit_state[k]
139
+ missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
140
+ self.model.to(str(self.device))
141
+ print(f"missing keys: {missing}\n\n\n\n\nunexpected keys: {unexpected}")
142
+
143
+ def set_lora(self, local_path: str | None = None, repo_id: str | None = None,
144
+ name: str | None = None, lora_weight: float = 0.7):
145
+ checkpoint = load_checkpoint(local_path, repo_id, name)
146
+ self.update_model_with_lora(checkpoint, lora_weight)
147
+
148
+ def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: float = 0.7):
149
+ checkpoint = load_checkpoint(
150
+ None, self.hf_lora_collection, self.lora_types_to_names[lora_type]
151
+ )
152
+ self.update_model_with_lora(checkpoint, lora_weight)
153
+
154
+ def update_model_with_lora(self, checkpoint, lora_weight):
155
+ rank = get_lora_rank(checkpoint)
156
+ lora_attn_procs = {}
157
+
158
+ for name, _ in self.model.attn_processors.items():
159
+ lora_state_dict = {}
160
+ for k in checkpoint.keys():
161
+ if name in k:
162
+ lora_state_dict[k[len(name) + 1:]] = checkpoint[k] * lora_weight
163
+
164
+ if len(lora_state_dict):
165
+ if name.startswith("single_blocks"):
166
+ lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=3072, rank=rank)
167
+ else:
168
+ lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank)
169
+ lora_attn_procs[name].load_state_dict(lora_state_dict)
170
+ lora_attn_procs[name].to(self.device)
171
+ else:
172
+ if name.startswith("single_blocks"):
173
+ lora_attn_procs[name] = SingleStreamBlockProcessor()
174
+ else:
175
+ lora_attn_procs[name] = DoubleStreamBlockProcessor()
176
+
177
+ self.model.set_attn_processor(lora_attn_procs)
178
+
179
+
180
+ def __call__(
181
+ self,
182
+ prompt: str,
183
+ width: int = 512,
184
+ height: int = 512,
185
+ guidance: float = 4,
186
+ num_steps: int = 50,
187
+ seed: int = 123456789,
188
+ **kwargs
189
+ ):
190
+ width = 16 * (width // 16)
191
+ height = 16 * (height // 16)
192
+
193
+ return self.forward(
194
+ prompt,
195
+ width,
196
+ height,
197
+ guidance,
198
+ num_steps,
199
+ seed,
200
+ **kwargs
201
+ )
202
+
203
+ @torch.inference_mode()
204
+ def gradio_generate(
205
+ self,
206
+ prompt: str,
207
+ width: int,
208
+ height: int,
209
+ guidance: float,
210
+ num_steps: int,
211
+ seed: int,
212
+ image_prompt1: Image.Image,
213
+ image_prompt2: Image.Image,
214
+ image_prompt3: Image.Image,
215
+ image_prompt4: Image.Image,
216
+ ):
217
+ ref_imgs = [image_prompt1, image_prompt2, image_prompt3, image_prompt4]
218
+ ref_imgs = [img for img in ref_imgs if isinstance(img, Image.Image)]
219
+ ref_long_side = 512 if len(ref_imgs) <= 1 else 320
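+ # a single reference keeps a 512 long side; multiple references are resized to 320 each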
220
+ ref_imgs = [preprocess_ref(img, ref_long_side) for img in ref_imgs]
221
+
222
+ seed = seed if seed != -1 else torch.randint(0, 10 ** 8, (1,)).item()
223
+
224
+ img = self(prompt=prompt, width=width, height=height, guidance=guidance,
225
+ num_steps=num_steps, seed=seed, ref_imgs=ref_imgs)
226
+
227
+ filename = f"output/gradio/{seed}_{prompt[:20]}.png"
228
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
229
+ exif_data = Image.Exif()
230
+ exif_data[ExifTags.Base.Make] = "UNO"
231
+ exif_data[ExifTags.Base.Model] = self.model_type
232
+ info = f"{prompt=}, {seed=}, {width=}, {height=}, {guidance=}, {num_steps=}"
233
+ exif_data[ExifTags.Base.ImageDescription] = info
234
+ img.save(filename, format="png", exif=exif_data)
235
+ return img, filename
236
+
237
+ @torch.inference_mode()
238
+ def forward(
239
+ self,
240
+ prompt: str,
241
+ width: int,
242
+ height: int,
243
+ guidance: float,
244
+ num_steps: int,
245
+ seed: int,
246
+ ref_imgs: list[Image.Image] | None = None,
247
+ pe: Literal['d', 'h', 'w', 'o'] = 'd',
248
+ ):
249
+ x = get_noise(
250
+ 1, height, width, device=self.device,
251
+ dtype=torch.bfloat16, seed=seed
252
+ )
253
+ timesteps = get_schedule(
254
+ num_steps,
255
+ (width // 8) * (height // 8) // (16 * 16),
256
+ shift=True,
257
+ )
258
+ if self.offload:
259
+ self.ae.encoder = self.ae.encoder.to(self.device)
260
+ x_1_refs = [
261
+ self.ae.encode(
262
+ (TVF.to_tensor(ref_img) * 2.0 - 1.0)
263
+ .unsqueeze(0).to(self.device, torch.float32)
264
+ ).to(torch.bfloat16)
265
+ for ref_img in ref_imgs
266
+ ]
267
+
268
+ if self.offload:
269
+ self.ae.encoder = self.offload_model_to_cpu(self.ae.encoder)
270
+ self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
271
+ inp_cond = prepare_multi_ip(
272
+ t5=self.t5, clip=self.clip,
273
+ img=x,
274
+ prompt=prompt, ref_imgs=x_1_refs, pe=pe
275
+ )
276
+
277
+ if self.offload:
278
+ self.offload_model_to_cpu(self.t5, self.clip)
279
+ self.model = self.model.to(self.device)
280
+
281
+ x = denoise(
282
+ self.model,
283
+ **inp_cond,
284
+ timesteps=timesteps,
285
+ guidance=guidance,
286
+ )
287
+
288
+ if self.offload:
289
+ self.offload_model_to_cpu(self.model)
290
+ self.ae.decoder.to(x.device)
291
+ x = unpack(x.float(), height, width)
292
+ x = self.ae.decode(x)
293
+ self.offload_model_to_cpu(self.ae.decoder)
294
+
295
+ x1 = x.clamp(-1, 1)
296
+ x1 = rearrange(x1[-1], "c h w -> h w c")
297
+ output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
298
+ return output_img
299
+
300
+ def offload_model_to_cpu(self, *models):
301
+ if not self.offload: return
302
+ for model in models:
303
+ model.cpu()
304
+ torch.cuda.empty_cache()
uno/flux/sampling.py ADDED
@@ -0,0 +1,252 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from typing import Literal
18
+
19
+ import torch
20
+ from einops import rearrange, repeat
21
+ from torch import Tensor
22
+ from tqdm import tqdm
23
+
24
+ from .model import Flux
25
+ from .modules.conditioner import HFEmbedder
26
+
27
+
28
+ def get_noise(
29
+ num_samples: int,
30
+ height: int,
31
+ width: int,
32
+ device: torch.device,
33
+ dtype: torch.dtype,
34
+ seed: int,
35
+ ):
36
+ return torch.randn(
37
+ num_samples,
38
+ 16,
39
+ # allow for packing
40
+ 2 * math.ceil(height / 16),
41
+ 2 * math.ceil(width / 16),
42
+ device=device,
43
+ dtype=dtype,
44
+ generator=torch.Generator(device=device).manual_seed(seed),
45
+ )
46
+
47
+
48
+ def prepare(
49
+ t5: HFEmbedder,
50
+ clip: HFEmbedder,
51
+ img: Tensor,
52
+ prompt: str | list[str],
53
+ ref_img: None | Tensor = None,
54
+ pe: Literal['d', 'h', 'w', 'o'] = 'd'
55
+ ) -> dict[str, Tensor]:
56
+ assert pe in ['d', 'h', 'w', 'o']
57
+ bs, c, h, w = img.shape
58
+ if bs == 1 and not isinstance(prompt, str):
59
+ bs = len(prompt)
60
+
61
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
62
+ if img.shape[0] == 1 and bs > 1:
63
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
64
+
65
+ img_ids = torch.zeros(h // 2, w // 2, 3)
66
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
67
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
68
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
69
+
70
+ if ref_img is not None:
71
+ _, _, ref_h, ref_w = ref_img.shape
72
+ ref_img = rearrange(ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
73
+ if ref_img.shape[0] == 1 and bs > 1:
74
+ ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
75
+ ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3)
76
+ # offset the reference img ids in height/width by the main image's max ids
77
+ h_offset = h // 2 if pe in {'d', 'h'} else 0
78
+ w_offset = w // 2 if pe in {'d', 'w'} else 0
79
+ ref_img_ids[..., 1] = ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None] + h_offset
80
+ ref_img_ids[..., 2] = ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :] + w_offset
81
+ ref_img_ids = repeat(ref_img_ids, "h w c -> b (h w) c", b=bs)
82
+
83
+ if isinstance(prompt, str):
84
+ prompt = [prompt]
85
+ txt = t5(prompt)
86
+ if txt.shape[0] == 1 and bs > 1:
87
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
88
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
89
+
90
+ vec = clip(prompt)
91
+ if vec.shape[0] == 1 and bs > 1:
92
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
93
+
94
+ if ref_img is not None:
95
+ return {
96
+ "img": img,
97
+ "img_ids": img_ids.to(img.device),
98
+ "ref_img": ref_img,
99
+ "ref_img_ids": ref_img_ids.to(img.device),
100
+ "txt": txt.to(img.device),
101
+ "txt_ids": txt_ids.to(img.device),
102
+ "vec": vec.to(img.device),
103
+ }
104
+ else:
105
+ return {
106
+ "img": img,
107
+ "img_ids": img_ids.to(img.device),
108
+ "txt": txt.to(img.device),
109
+ "txt_ids": txt_ids.to(img.device),
110
+ "vec": vec.to(img.device),
111
+ }
112
+
113
+ def prepare_multi_ip(
114
+ t5: HFEmbedder,
115
+ clip: HFEmbedder,
116
+ img: Tensor,
117
+ prompt: str | list[str],
118
+ ref_imgs: list[Tensor] | None = None,
119
+ pe: Literal['d', 'h', 'w', 'o'] = 'd'
120
+ ) -> dict[str, Tensor]:
121
+ assert pe in ['d', 'h', 'w', 'o']
122
+ bs, c, h, w = img.shape
123
+ if bs == 1 and not isinstance(prompt, str):
124
+ bs = len(prompt)
125
+
126
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
127
+ if img.shape[0] == 1 and bs > 1:
128
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
129
+
130
+ img_ids = torch.zeros(h // 2, w // 2, 3)
131
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
132
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
133
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
134
+
135
+ ref_img_ids = []
136
+ ref_imgs_list = []
137
+ pe_shift_w, pe_shift_h = w // 2, h // 2
138
+ for ref_img in ref_imgs:
139
+ _, _, ref_h1, ref_w1 = ref_img.shape
140
+ ref_img = rearrange(ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
141
+ if ref_img.shape[0] == 1 and bs > 1:
142
+ ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
143
+ ref_img_ids1 = torch.zeros(ref_h1 // 2, ref_w1 // 2, 3)
144
+ # offset each reference's img ids in height/width by the running maxima
145
+ h_offset = pe_shift_h if pe in {'d', 'h'} else 0
146
+ w_offset = pe_shift_w if pe in {'d', 'w'} else 0
147
+ ref_img_ids1[..., 1] = ref_img_ids1[..., 1] + torch.arange(ref_h1 // 2)[:, None] + h_offset
148
+ ref_img_ids1[..., 2] = ref_img_ids1[..., 2] + torch.arange(ref_w1 // 2)[None, :] + w_offset
149
+ ref_img_ids1 = repeat(ref_img_ids1, "h w c -> b (h w) c", b=bs)
150
+ ref_img_ids.append(ref_img_ids1)
151
+ ref_imgs_list.append(ref_img)
152
+
153
+ # update the pe shift so the next reference image starts after this one
154
+ pe_shift_h += ref_h1 // 2
155
+ pe_shift_w += ref_w1 // 2
156
+
157
+ if isinstance(prompt, str):
158
+ prompt = [prompt]
159
+ txt = t5(prompt)
160
+ if txt.shape[0] == 1 and bs > 1:
161
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
162
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
163
+
164
+ vec = clip(prompt)
165
+ if vec.shape[0] == 1 and bs > 1:
166
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
167
+
168
+ return {
169
+ "img": img,
170
+ "img_ids": img_ids.to(img.device),
171
+ "ref_img": tuple(ref_imgs_list),
172
+ "ref_img_ids": [ref_img_id.to(img.device) for ref_img_id in ref_img_ids],
173
+ "txt": txt.to(img.device),
174
+ "txt_ids": txt_ids.to(img.device),
175
+ "vec": vec.to(img.device),
176
+ }
177
+
178
+
179
+ def time_shift(mu: float, sigma: float, t: Tensor):
180
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
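+ # i.e. t' = e^mu / (e^mu + (1/t - 1)^sigma); larger mu pushes the schedule toward t = 1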
181
+
182
+
183
+ def get_lin_function(
184
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
185
+ ):
186
+ m = (y2 - y1) / (x2 - x1)
187
+ b = y1 - m * x1
188
+ return lambda x: m * x + b
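+ # linear fit through (x1, y1) = (256, 0.5) and (x2, y2) = (4096, 1.15), mapping image sequence length to mu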
189
+
190
+
191
+ def get_schedule(
192
+ num_steps: int,
193
+ image_seq_len: int,
194
+ base_shift: float = 0.5,
195
+ max_shift: float = 1.15,
196
+ shift: bool = True,
197
+ ) -> list[float]:
198
+ # extra step for zero
199
+ timesteps = torch.linspace(1, 0, num_steps + 1)
200
+
201
+ # shifting the schedule to favor high timesteps for higher signal images
202
+ if shift:
203
+ # estimate mu based on linear interpolation between two points
204
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
205
+ timesteps = time_shift(mu, 1.0, timesteps)
206
+
207
+ return timesteps.tolist()
208
+
209
+
210
+ def denoise(
211
+ model: Flux,
212
+ # model input
213
+ img: Tensor,
214
+ img_ids: Tensor,
215
+ txt: Tensor,
216
+ txt_ids: Tensor,
217
+ vec: Tensor,
218
+ # sampling parameters
219
+ timesteps: list[float],
220
+ guidance: float = 4.0,
221
+ ref_img: Tensor | None = None,
222
+ ref_img_ids: Tensor | None = None,
223
+ ):
224
+ i = 0
225
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
226
+ for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1):
227
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
228
+ pred = model(
229
+ img=img,
230
+ img_ids=img_ids,
231
+ ref_img=ref_img,
232
+ ref_img_ids=ref_img_ids,
233
+ txt=txt,
234
+ txt_ids=txt_ids,
235
+ y=vec,
236
+ timesteps=t_vec,
237
+ guidance=guidance_vec
238
+ )
239
+ img = img + (t_prev - t_curr) * pred
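+ # first-order (Euler) update along the predicted velocity, stepping from t_curr to t_prev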
240
+ i += 1
241
+ return img
242
+
243
+
244
+ def unpack(x: Tensor, height: int, width: int) -> Tensor:
245
+ return rearrange(
246
+ x,
247
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
248
+ h=math.ceil(height / 16),
249
+ w=math.ceil(width / 16),
250
+ ph=2,
251
+ pw=2,
252
+ )
uno/flux/util.py ADDED
@@ -0,0 +1,396 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from dataclasses import dataclass
18
+
19
+ import torch
20
+ import json
21
+ import numpy as np
22
+ from huggingface_hub import hf_hub_download
23
+ from safetensors import safe_open
24
+ from safetensors.torch import load_file as load_sft
25
+
26
+ from .model import Flux, FluxParams
27
+ from .modules.autoencoder import AutoEncoder, AutoEncoderParams
28
+ from .modules.conditioner import HFEmbedder
29
+
30
+ import re
31
+ from uno.flux.modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor
32
+ def load_model(ckpt, device='cpu'):
33
+ if ckpt.endswith('safetensors'):
34
+ from safetensors import safe_open
35
+ pl_sd = {}
36
+ with safe_open(ckpt, framework="pt", device=device) as f:
37
+ for k in f.keys():
38
+ pl_sd[k] = f.get_tensor(k)
39
+ else:
40
+ pl_sd = torch.load(ckpt, map_location=device)
41
+ return pl_sd
42
+
43
+ def load_safetensors(path):
44
+ tensors = {}
45
+ with safe_open(path, framework="pt", device="cpu") as f:
46
+ for key in f.keys():
47
+ tensors[key] = f.get_tensor(key)
48
+ return tensors
49
+
50
+ def get_lora_rank(checkpoint):
51
+ for k in checkpoint.keys():
52
+ if k.endswith(".down.weight"):
53
+ return checkpoint[k].shape[0]
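+ # the LoRA down-projection weight has shape (rank, in_features), so dim 0 gives the rank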
54
+
55
+ def load_checkpoint(local_path, repo_id, name):
56
+ if local_path is not None:
57
+ if '.safetensors' in local_path:
58
+ print(f"Loading .safetensors checkpoint from {local_path}")
59
+ checkpoint = load_safetensors(local_path)
60
+ else:
61
+ print(f"Loading checkpoint from {local_path}")
62
+ checkpoint = torch.load(local_path, map_location='cpu')
63
+ elif repo_id is not None and name is not None:
64
+ print(f"Loading checkpoint {name} from repo id {repo_id}")
65
+ checkpoint = load_from_repo_id(repo_id, name)
66
+ else:
67
+ raise ValueError(
68
+ "LOADING ERROR: you must specify local_path or repo_id with name in HF to download"
69
+ )
70
+ return checkpoint
71
+
72
+
73
+ def c_crop(image):
74
+ width, height = image.size
75
+ new_size = min(width, height)
76
+ left = (width - new_size) / 2
77
+ top = (height - new_size) / 2
78
+ right = (width + new_size) / 2
79
+ bottom = (height + new_size) / 2
80
+ return image.crop((left, top, right, bottom))
81
+
82
+ def pad64(x):
83
+ return int(np.ceil(float(x) / 64.0) * 64 - x)
84
+
85
+ def HWC3(x):
86
+ assert x.dtype == np.uint8
87
+ if x.ndim == 2:
88
+ x = x[:, :, None]
89
+ assert x.ndim == 3
90
+ H, W, C = x.shape
91
+ assert C == 1 or C == 3 or C == 4
92
+ if C == 3:
93
+ return x
94
+ if C == 1:
95
+ return np.concatenate([x, x, x], axis=2)
96
+ if C == 4:
97
+ color = x[:, :, 0:3].astype(np.float32)
98
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
99
+ y = color * alpha + 255.0 * (1.0 - alpha)
100
+ y = y.clip(0, 255).astype(np.uint8)
101
+ return y
102
+
103
+ @dataclass
104
+ class ModelSpec:
105
+ params: FluxParams
106
+ ae_params: AutoEncoderParams
107
+ ckpt_path: str | None
108
+ ae_path: str | None
109
+ repo_id: str | None
110
+ repo_flow: str | None
111
+ repo_ae: str | None
112
+ repo_id_ae: str | None
113
+
114
+
115
+ configs = {
116
+ "flux-dev": ModelSpec(
117
+ repo_id="black-forest-labs/FLUX.1-dev",
118
+ repo_id_ae="black-forest-labs/FLUX.1-dev",
119
+ repo_flow="flux1-dev.safetensors",
120
+ repo_ae="ae.safetensors",
121
+ ckpt_path=os.getenv("FLUX_DEV"),
122
+ params=FluxParams(
123
+ in_channels=64,
124
+ vec_in_dim=768,
125
+ context_in_dim=4096,
126
+ hidden_size=3072,
127
+ mlp_ratio=4.0,
128
+ num_heads=24,
129
+ depth=19,
130
+ depth_single_blocks=38,
131
+ axes_dim=[16, 56, 56],
132
+ theta=10_000,
133
+ qkv_bias=True,
134
+ guidance_embed=True,
135
+ ),
136
+ ae_path=os.getenv("AE"),
137
+ ae_params=AutoEncoderParams(
138
+ resolution=256,
139
+ in_channels=3,
140
+ ch=128,
141
+ out_ch=3,
142
+ ch_mult=[1, 2, 4, 4],
143
+ num_res_blocks=2,
144
+ z_channels=16,
145
+ scale_factor=0.3611,
146
+ shift_factor=0.1159,
147
+ ),
148
+ ),
149
+ "flux-dev-fp8": ModelSpec(
150
+ repo_id="XLabs-AI/flux-dev-fp8",
151
+ repo_id_ae="black-forest-labs/FLUX.1-dev",
152
+ repo_flow="flux-dev-fp8.safetensors",
153
+ repo_ae="ae.safetensors",
154
+ ckpt_path=os.getenv("FLUX_DEV_FP8"),
155
+ params=FluxParams(
156
+ in_channels=64,
157
+ vec_in_dim=768,
158
+ context_in_dim=4096,
159
+ hidden_size=3072,
160
+ mlp_ratio=4.0,
161
+ num_heads=24,
162
+ depth=19,
163
+ depth_single_blocks=38,
164
+ axes_dim=[16, 56, 56],
165
+ theta=10_000,
166
+ qkv_bias=True,
167
+ guidance_embed=True,
168
+ ),
169
+ ae_path=os.getenv("AE"),
170
+ ae_params=AutoEncoderParams(
171
+ resolution=256,
172
+ in_channels=3,
173
+ ch=128,
174
+ out_ch=3,
175
+ ch_mult=[1, 2, 4, 4],
176
+ num_res_blocks=2,
177
+ z_channels=16,
178
+ scale_factor=0.3611,
179
+ shift_factor=0.1159,
180
+ ),
181
+ ),
182
+ "flux-schnell": ModelSpec(
183
+ repo_id="black-forest-labs/FLUX.1-schnell",
184
+ repo_id_ae="black-forest-labs/FLUX.1-dev",
185
+ repo_flow="flux1-schnell.safetensors",
186
+ repo_ae="ae.safetensors",
187
+ ckpt_path=os.getenv("FLUX_SCHNELL"),
188
+ params=FluxParams(
189
+ in_channels=64,
190
+ vec_in_dim=768,
191
+ context_in_dim=4096,
192
+ hidden_size=3072,
193
+ mlp_ratio=4.0,
194
+ num_heads=24,
195
+ depth=19,
196
+ depth_single_blocks=38,
197
+ axes_dim=[16, 56, 56],
198
+ theta=10_000,
199
+ qkv_bias=True,
200
+ guidance_embed=False,
201
+ ),
202
+ ae_path=os.getenv("AE"),
203
+ ae_params=AutoEncoderParams(
204
+ resolution=256,
205
+ in_channels=3,
206
+ ch=128,
207
+ out_ch=3,
208
+ ch_mult=[1, 2, 4, 4],
209
+ num_res_blocks=2,
210
+ z_channels=16,
211
+ scale_factor=0.3611,
212
+ shift_factor=0.1159,
213
+ ),
214
+ ),
215
+ }
216
+
217
+
218
+ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
219
+ if len(missing) > 0 and len(unexpected) > 0:
220
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
221
+ print("\n" + "-" * 79 + "\n")
222
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
223
+ elif len(missing) > 0:
224
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
225
+ elif len(unexpected) > 0:
226
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
227
+
228
+ def load_from_repo_id(repo_id, checkpoint_name):
229
+ ckpt_path = hf_hub_download(repo_id, checkpoint_name)
230
+ sd = load_sft(ckpt_path, device='cpu')
231
+ return sd
232
+
233
+ def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
234
+ # Loading Flux
235
+ print("Init model")
236
+ ckpt_path = configs[name].ckpt_path
237
+ if (
238
+ ckpt_path is None
239
+ and configs[name].repo_id is not None
240
+ and configs[name].repo_flow is not None
241
+ and hf_download
242
+ ):
243
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
244
+
245
+ with torch.device("meta" if ckpt_path is not None else device):
246
+ model = Flux(configs[name].params).to(torch.bfloat16)
247
+
248
+ if ckpt_path is not None:
249
+ print("Loading checkpoint")
250
+ # load_sft doesn't support torch.device
251
+ sd = load_model(ckpt_path, device=str(device))
252
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
253
+ print_load_warning(missing, unexpected)
254
+ return model
255
+
256
+ def load_flow_model_only_lora(
257
+ name: str,
258
+ device: str | torch.device = "cuda",
259
+ hf_download: bool = True,
260
+ lora_rank: int = 16
261
+ ):
262
+ # Loading Flux
263
+ print("Init model")
264
+ ckpt_path = configs[name].ckpt_path
265
+ if (
266
+ ckpt_path is None
267
+ and configs[name].repo_id is not None
268
+ and configs[name].repo_flow is not None
269
+ and hf_download
270
+ ):
271
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))
272
+
273
+ if hf_download:
274
+ # lora_ckpt_path = hf_hub_download("bytedance-research/UNO", "dit_lora.safetensors")
275
+ try:
276
+ lora_ckpt_path = hf_hub_download("bytedance-research/UNO", "dit_lora.safetensors")
277
+ except Exception:
278
+ lora_ckpt_path = os.environ.get("LORA", None)
279
+ else:
280
+ lora_ckpt_path = os.environ.get("LORA", None)
281
+
282
+ with torch.device("meta" if ckpt_path is not None else device):
283
+ model = Flux(configs[name].params)
284
+
285
+
286
+ model = set_lora(model, lora_rank, device="meta" if lora_ckpt_path is not None else device)
287
+
288
+ if ckpt_path is not None:
289
+ print("Loading lora")
290
+ lora_sd = load_sft(lora_ckpt_path, device=str(device)) if lora_ckpt_path.endswith("safetensors")\
291
+ else torch.load(lora_ckpt_path, map_location='cpu')
292
+
293
+ print("Loading main checkpoint")
294
+ # load_sft doesn't support torch.device
295
+
296
+ if ckpt_path.endswith('safetensors'):
297
+ sd = load_sft(ckpt_path, device=str(device))
298
+ sd.update(lora_sd)
299
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
300
+ else:
301
+ dit_state = torch.load(ckpt_path, map_location='cpu')
302
+ sd = {}
303
+ for k in dit_state.keys():
304
+ sd[k.replace('module.','')] = dit_state[k]
305
+ sd.update(lora_sd)
306
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
307
+ model.to(str(device))
308
+ print_load_warning(missing, unexpected)
309
+ return model
310
+
311
+
312
+ def set_lora(
313
+ model: Flux,
314
+ lora_rank: int,
315
+ double_blocks_indices: list[int] | None = None,
316
+ single_blocks_indices: list[int] | None = None,
317
+ device: str | torch.device = "cpu",
318
+ ) -> Flux:
319
+ double_blocks_indices = list(range(model.params.depth)) if double_blocks_indices is None else double_blocks_indices
320
+ single_blocks_indices = list(range(model.params.depth_single_blocks)) if single_blocks_indices is None \
321
+ else single_blocks_indices
322
+
323
+ lora_attn_procs = {}
324
+ with torch.device(device):
325
+ for name, attn_processor in model.attn_processors.items():
326
+ match = re.search(r'\.(\d+)\.', name)
327
+ if match:
328
+ layer_index = int(match.group(1))
329
+
330
+ if name.startswith("double_blocks") and layer_index in double_blocks_indices:
331
+ lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
332
+ elif name.startswith("single_blocks") and layer_index in single_blocks_indices:
333
+ lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
334
+ else:
335
+ lora_attn_procs[name] = attn_processor
336
+ model.set_attn_processor(lora_attn_procs)
337
+ return model
338
+
339
+
340
+ def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
341
+ # Loading Flux
342
+ from optimum.quanto import requantize
343
+ print("Init model")
344
+ ckpt_path = configs[name].ckpt_path
345
+ if (
346
+ ckpt_path is None
347
+ and configs[name].repo_id is not None
348
+ and configs[name].repo_flow is not None
349
+ and hf_download
350
+ ):
351
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
352
+ json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json')
353
+
354
+
355
+ model = Flux(configs[name].params).to(torch.bfloat16)
356
+
357
+ print("Loading checkpoint")
358
+ # load_sft doesn't support torch.device
359
+ sd = load_sft(ckpt_path, device='cpu')
360
+ with open(json_path, "r") as f:
361
+ quantization_map = json.load(f)
362
+ print("Start a quantization process...")
363
+ requantize(model, sd, quantization_map, device=device)
364
+ print("Model is quantized!")
365
+ return model
366
+
367
+ def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
368
+ # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
369
+ version = os.environ.get("T5", "xlabs-ai/xflux_text_encoders")
370
+ return HFEmbedder(version, max_length=max_length, torch_dtype=torch.bfloat16).to(device)
371
+
372
+ def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
373
+ version = os.environ.get("CLIP", "openai/clip-vit-large-patch14")
374
+ return HFEmbedder(version, max_length=77, torch_dtype=torch.bfloat16).to(device)
375
+
376
+
377
+ def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
378
+ ckpt_path = configs[name].ae_path
379
+ if (
380
+ ckpt_path is None
381
+ and configs[name].repo_id is not None
382
+ and configs[name].repo_ae is not None
383
+ and hf_download
384
+ ):
385
+ ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae)
386
+
387
+ # Loading the autoencoder
388
+ print("Init AE")
389
+ with torch.device("meta" if ckpt_path is not None else device):
390
+ ae = AutoEncoder(configs[name].ae_params)
391
+
392
+ if ckpt_path is not None:
393
+ sd = load_sft(ckpt_path, device=str(device))
394
+ missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
395
+ print_load_warning(missing, unexpected)
396
+ return ae
uno/utils/convert_yaml_to_args_file.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import yaml
17
+
18
+ parser = argparse.ArgumentParser()
19
+ parser.add_argument("--yaml", type=str, required=True)
20
+ parser.add_argument("--arg", type=str, required=True)
21
+ args = parser.parse_args()
22
+
23
+
24
+ with open(args.yaml, "r") as f:
25
+ data = yaml.safe_load(f)
26
+
27
+ with open(args.arg, "w") as f:
28
+ for k, v in data.items():
29
+ if isinstance(v, list):
30
+ v = list(map(str, v))
31
+ v = " ".join(v)
32
+ if v is None:
33
+ continue
34
+ print(f"--{k} {v}", end=" ", file=f)