Spaces:
Running
on
Zero
Running
on
Zero
Upload 39 files
Browse files- .gitattributes +15 -0
- app.py +720 -0
- requirements.txt +34 -0
- src/__init__.py +0 -0
- src/attention_processor.py +146 -0
- src/detail_encoder.py +118 -0
- src/jsonl_datasets.py +186 -0
- src/kontext_custom_pipeline.py +0 -0
- src/layers.py +673 -0
- src/lora_helper.py +267 -0
- src/pipeline.py +805 -0
- src/prompt_helper.py +205 -0
- src/transformer_flux.py +583 -0
- src/transformer_with_loss.py +504 -0
- test_imgs/2.png +3 -0
- test_imgs/3.png +3 -0
- test_imgs/generated_1.png +3 -0
- test_imgs/generated_1_bbox.png +3 -0
- test_imgs/generated_2.png +3 -0
- test_imgs/generated_2_bbox.png +3 -0
- test_imgs/generated_3.png +3 -0
- test_imgs/generated_3_bbox.png +3 -0
- test_imgs/generated_3_bbox_1.png +3 -0
- test_imgs/product_1.jpg +0 -0
- test_imgs/product_1_bbox.png +3 -0
- test_imgs/product_2.png +3 -0
- test_imgs/product_2_bbox.png +3 -0
- test_imgs/product_3.png +3 -0
- test_imgs/product_3_bbox.png +3 -0
- test_imgs/product_3_bbox_1.png +3 -0
- uno/dataset/uno.py +132 -0
- uno/flux/math.py +45 -0
- uno/flux/model.py +222 -0
- uno/flux/modules/autoencoder.py +327 -0
- uno/flux/modules/conditioner.py +53 -0
- uno/flux/modules/layers.py +435 -0
- uno/flux/pipeline.py +304 -0
- uno/flux/sampling.py +252 -0
- uno/flux/util.py +396 -0
- uno/utils/convert_yaml_to_args_file.py +34 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,18 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
test_imgs/2.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
test_imgs/3.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
test_imgs/generated_1_bbox.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
test_imgs/generated_1.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
test_imgs/generated_2_bbox.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
test_imgs/generated_2.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
test_imgs/generated_3_bbox_1.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
test_imgs/generated_3_bbox.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
test_imgs/generated_3.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
test_imgs/product_1_bbox.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
test_imgs/product_2_bbox.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
test_imgs/product_2.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
test_imgs/product_3_bbox_1.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
test_imgs/product_3_bbox.png filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
test_imgs/product_3.png filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
|
@@ -0,0 +1,720 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# import os
|
| 2 |
+
# os.system("pip uninstall -y gradio")
|
| 3 |
+
# os.system("pip install gradio==5.49.1")
|
| 4 |
+
# os.system("pip uninstall -y gradio_image_annotation")
|
| 5 |
+
# os.system("pip install gradio_image_annotation==0.4.1")
|
| 6 |
+
# os.system("pip uninstall -y huggingface-hub")
|
| 7 |
+
# os.system("pip install huggingface-hub==0.35.3")
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import gradio as gr
|
| 13 |
+
from gradio_image_annotation import image_annotator
|
| 14 |
+
import numpy as np
|
| 15 |
+
import random
|
| 16 |
+
|
| 17 |
+
from diffusers import FluxTransformer2DModel, FluxKontextPipeline
|
| 18 |
+
from safetensors.torch import load_file
|
| 19 |
+
from huggingface_hub import hf_hub_download
|
| 20 |
+
from src.lora_helper import set_single_lora
|
| 21 |
+
from src.detail_encoder import DetailEncoder
|
| 22 |
+
from src.kontext_custom_pipeline import FluxKontextPipelineWithPhotoEncoderAddTokens
|
| 23 |
+
# import spaces
|
| 24 |
+
from uno.flux.pipeline import UNOPipeline
|
| 25 |
+
|
| 26 |
+
hf_hub_download(
|
| 27 |
+
repo_id="ziheng1234/ImageCritic",
|
| 28 |
+
filename="detail_encoder.safetensors",
|
| 29 |
+
local_dir="models" # 下载到本地 models/ 目录
|
| 30 |
+
)
|
| 31 |
+
hf_hub_download(
|
| 32 |
+
repo_id="ziheng1234/ImageCritic",
|
| 33 |
+
filename="lora.safetensors",
|
| 34 |
+
local_dir="models"
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
from huggingface_hub import snapshot_download
|
| 38 |
+
repo_id = "ziheng1234/kontext"
|
| 39 |
+
local_dir = "./kontext"
|
| 40 |
+
snapshot_download(
|
| 41 |
+
repo_id=repo_id,
|
| 42 |
+
local_dir=local_dir,
|
| 43 |
+
repo_type="model",
|
| 44 |
+
resume_download=True,
|
| 45 |
+
max_workers=8
|
| 46 |
+
)
|
| 47 |
+
base_path = "./models"
|
| 48 |
+
detail_encoder_path = f"{base_path}/detail_encoder.safetensors"
|
| 49 |
+
kontext_lora_path = f"{base_path}/lora.safetensors"
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def pick_kontext_resolution(w: int, h: int) -> tuple[int, int]:
|
| 53 |
+
PREFERRED_KONTEXT_RESOLUTIONS = [
|
| 54 |
+
(672, 1568), (688, 1504), (720, 1456), (752, 1392),
|
| 55 |
+
(800, 1328), (832, 1248), (880, 1184), (944, 1104),
|
| 56 |
+
(1024, 1024), (1104, 944), (1184, 880), (1248, 832),
|
| 57 |
+
(1328, 800), (1392, 752), (1456, 720), (1504, 688), (1568, 672),
|
| 58 |
+
]
|
| 59 |
+
target_ratio = w / h
|
| 60 |
+
return min(
|
| 61 |
+
PREFERRED_KONTEXT_RESOLUTIONS,
|
| 62 |
+
key=lambda wh: abs((wh[0] / wh[1]) - target_ratio),
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
MAX_SEED = np.iinfo(np.int32).max
|
| 67 |
+
|
| 68 |
+
device = None
|
| 69 |
+
pipeline = None
|
| 70 |
+
transformer = None
|
| 71 |
+
detail_encoder = None
|
| 72 |
+
stage1_pipeline = None
|
| 73 |
+
|
| 74 |
+
@spaces.GPU(duration=200)
|
| 75 |
+
def load_stage1_model():
|
| 76 |
+
global stage1_pipeline, device
|
| 77 |
+
|
| 78 |
+
if stage1_pipeline is not None:
|
| 79 |
+
return
|
| 80 |
+
|
| 81 |
+
print("加载 Stage 1 UNO Pipeline...")
|
| 82 |
+
if device is None:
|
| 83 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 84 |
+
|
| 85 |
+
model_type = "flux-dev"
|
| 86 |
+
stage1_pipeline = UNOPipeline(model_type, device, offload=False, only_lora=True, lora_rank=512)
|
| 87 |
+
print("Stage 1 模型加载完成!")
|
| 88 |
+
|
| 89 |
+
@spaces.GPU(duration=200)
|
| 90 |
+
def load_models():
|
| 91 |
+
global device, pipeline, transformer, detail_encoder
|
| 92 |
+
|
| 93 |
+
if pipeline is not None and transformer is not None and detail_encoder is not None:
|
| 94 |
+
return
|
| 95 |
+
|
| 96 |
+
print("CUDA 可用:", torch.cuda.is_available())
|
| 97 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 98 |
+
print("使用设备:", device)
|
| 99 |
+
|
| 100 |
+
dtype = torch.bfloat16 if "cuda" in device else torch.float32
|
| 101 |
+
|
| 102 |
+
print("加载 FluxKontextPipelineWithPhotoEncoderAddTokens...")
|
| 103 |
+
pipeline_local = FluxKontextPipelineWithPhotoEncoderAddTokens.from_pretrained(
|
| 104 |
+
"./kontext",
|
| 105 |
+
torch_dtype=dtype,
|
| 106 |
+
)
|
| 107 |
+
pipeline_local.to(device)
|
| 108 |
+
|
| 109 |
+
print("加载 FluxTransformer2DModel...")
|
| 110 |
+
transformer_local = FluxTransformer2DModel.from_pretrained(
|
| 111 |
+
"./kontext",
|
| 112 |
+
subfolder="transformer",
|
| 113 |
+
torch_dtype=dtype,
|
| 114 |
+
)
|
| 115 |
+
transformer_local.to(device)
|
| 116 |
+
|
| 117 |
+
print("加载 detail_encoder 权重...")
|
| 118 |
+
state_dict = load_file(detail_encoder_path)
|
| 119 |
+
detail_encoder_local = DetailEncoder().to(dtype=transformer_local.dtype, device=device)
|
| 120 |
+
detail_encoder_local.to(device)
|
| 121 |
+
|
| 122 |
+
with torch.no_grad():
|
| 123 |
+
for name, param in detail_encoder_local.named_parameters():
|
| 124 |
+
if name in state_dict:
|
| 125 |
+
added = state_dict[name].to(param.device)
|
| 126 |
+
param.add_(added)
|
| 127 |
+
|
| 128 |
+
pipeline_local.transformer = transformer_local
|
| 129 |
+
pipeline_local.detail_encoder = detail_encoder_local
|
| 130 |
+
|
| 131 |
+
print("加载 LoRA...")
|
| 132 |
+
set_single_lora(pipeline_local.transformer, kontext_lora_path, lora_weights=[1.0])
|
| 133 |
+
|
| 134 |
+
print("模型加载完成!")
|
| 135 |
+
|
| 136 |
+
# 写回全局变量
|
| 137 |
+
pipeline = pipeline_local
|
| 138 |
+
transformer = transformer_local
|
| 139 |
+
detail_encoder = detail_encoder_local
|
| 140 |
+
|
| 141 |
+
@spaces.GPU(duration=200)
|
| 142 |
+
def generate_image_method1(input_image, prompt, width, height, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28):
|
| 143 |
+
"""
|
| 144 |
+
Stage 1 - Method 1: UNO image generation
|
| 145 |
+
"""
|
| 146 |
+
load_stage1_model()
|
| 147 |
+
global stage1_pipeline
|
| 148 |
+
|
| 149 |
+
if randomize_seed:
|
| 150 |
+
seed = -1
|
| 151 |
+
|
| 152 |
+
try:
|
| 153 |
+
# UNO pipeline uses gradio_generate interface
|
| 154 |
+
output_image, output_file = stage1_pipeline.gradio_generate(
|
| 155 |
+
prompt=prompt,
|
| 156 |
+
width=int(width),
|
| 157 |
+
height=int(height),
|
| 158 |
+
guidance=guidance_scale,
|
| 159 |
+
num_steps=steps,
|
| 160 |
+
seed=seed,
|
| 161 |
+
image_prompt1=input_image,
|
| 162 |
+
image_prompt2=None,
|
| 163 |
+
image_prompt3=None,
|
| 164 |
+
image_prompt4=None,
|
| 165 |
+
)
|
| 166 |
+
used_seed = seed if seed != -1 else random.randint(0, MAX_SEED)
|
| 167 |
+
return output_image, used_seed
|
| 168 |
+
except Exception as e:
|
| 169 |
+
print(f"Stage 1 生成图像时发生错误: {e}")
|
| 170 |
+
raise gr.Error(f"生成失败:{str(e)}")
|
| 171 |
+
|
| 172 |
+
def extract_first_box(annotations: dict):
|
| 173 |
+
"""
|
| 174 |
+
从 gradio_image_annotation 的返回中拿第一个 bbox 和对应的 PIL 图像及 patch
|
| 175 |
+
|
| 176 |
+
如果没有 bbox,则自动使用整张图作为 bbox。
|
| 177 |
+
"""
|
| 178 |
+
if not annotations:
|
| 179 |
+
raise gr.Error("Missing annotation data. Please check if an image is uploaded.")
|
| 180 |
+
|
| 181 |
+
img_array = annotations.get("image", None)
|
| 182 |
+
boxes = annotations.get("boxes", [])
|
| 183 |
+
|
| 184 |
+
if img_array is None:
|
| 185 |
+
raise gr.Error("No 'image' field found in annotation.")
|
| 186 |
+
|
| 187 |
+
img = Image.fromarray(img_array)
|
| 188 |
+
|
| 189 |
+
# ✅
|
| 190 |
+
if not boxes:
|
| 191 |
+
w, h = img.size
|
| 192 |
+
xmin, ymin, xmax, ymax = 0, 0, w, h
|
| 193 |
+
else:
|
| 194 |
+
box = boxes[0]
|
| 195 |
+
xmin = int(box["xmin"])
|
| 196 |
+
ymin = int(box["ymin"])
|
| 197 |
+
xmax = int(box["xmax"])
|
| 198 |
+
ymax = int(box["ymax"])
|
| 199 |
+
|
| 200 |
+
if xmax <= xmin or ymax <= ymin:
|
| 201 |
+
raise gr.Error("Invalid bbox, please draw the box again.")
|
| 202 |
+
|
| 203 |
+
patch = img.crop((xmin, ymin, xmax, ymax))
|
| 204 |
+
return img, patch, (xmin, ymin, xmax, ymax)
|
| 205 |
+
|
| 206 |
+
@spaces.GPU(duration=200)
|
| 207 |
+
def run_with_two_bboxes(
|
| 208 |
+
annotations_A: dict | None, #
|
| 209 |
+
annotations_B: dict | None, #
|
| 210 |
+
object_name: str,
|
| 211 |
+
base_seed: int = 0,
|
| 212 |
+
): # noqa: C901
|
| 213 |
+
"""
|
| 214 |
+
"""
|
| 215 |
+
|
| 216 |
+
load_models()
|
| 217 |
+
global pipeline, device
|
| 218 |
+
if annotations_A is None:
|
| 219 |
+
raise gr.Error("please upload reference image and draw a bbox")
|
| 220 |
+
if annotations_B is None:
|
| 221 |
+
raise gr.Error("please upload input image to be corrected and draw a bbox")
|
| 222 |
+
|
| 223 |
+
# 1.
|
| 224 |
+
img1_full, patch_A, bbox_A = extract_first_box(annotations_A)
|
| 225 |
+
img2_full, patch_B, bbox_B = extract_first_box(annotations_B)
|
| 226 |
+
|
| 227 |
+
xmin_B, ymin_B, xmax_B, ymax_B = bbox_B
|
| 228 |
+
patch_w = xmax_B - xmin_B
|
| 229 |
+
patch_h = ymax_B - ymin_B
|
| 230 |
+
|
| 231 |
+
if not object_name:
|
| 232 |
+
object_name = "object"
|
| 233 |
+
|
| 234 |
+
# 2.
|
| 235 |
+
orig_w, orig_h = patch_B.size
|
| 236 |
+
target_w, target_h = pick_kontext_resolution(orig_w, orig_h)
|
| 237 |
+
width_for_model, height_for_model = target_w, target_h
|
| 238 |
+
|
| 239 |
+
# 3.
|
| 240 |
+
cond_A_image = patch_A.resize((width_for_model, height_for_model), Image.Resampling.LANCZOS)
|
| 241 |
+
cond_B_image = patch_B.resize((width_for_model, height_for_model), Image.Resampling.LANCZOS)
|
| 242 |
+
|
| 243 |
+
prompt = f"use the {object_name} in IMG1 as a reference to refine, replace, enhance the {object_name} in IMG2"
|
| 244 |
+
print("prompt:", prompt)
|
| 245 |
+
|
| 246 |
+
seed = int(base_seed)
|
| 247 |
+
gen_device = device.split(":")[0] if "cuda" in device else device
|
| 248 |
+
generator = torch.Generator(gen_device).manual_seed(seed)
|
| 249 |
+
|
| 250 |
+
try:
|
| 251 |
+
out = pipeline(
|
| 252 |
+
image_A=cond_A_image,
|
| 253 |
+
image_B=cond_B_image,
|
| 254 |
+
prompt=prompt,
|
| 255 |
+
height=height_for_model,
|
| 256 |
+
width=width_for_model,
|
| 257 |
+
guidance_scale=3.5,
|
| 258 |
+
generator=generator,
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
gen_patch_model = out.images[0]
|
| 262 |
+
|
| 263 |
+
#
|
| 264 |
+
gen_patch = gen_patch_model.resize((patch_w, patch_h), Image.Resampling.LANCZOS)
|
| 265 |
+
|
| 266 |
+
#
|
| 267 |
+
composed = img2_full.copy()
|
| 268 |
+
composed.paste(gen_patch, (xmin_B, ymin_B))
|
| 269 |
+
patch_A_resized = patch_A.resize((patch_w, patch_h), Image.Resampling.LANCZOS)
|
| 270 |
+
patch_B_resized = patch_B.resize((patch_w, patch_h), Image.Resampling.LANCZOS)
|
| 271 |
+
SPACING = 10
|
| 272 |
+
collage_w = patch_w * 3 + SPACING * 2
|
| 273 |
+
collage_h = patch_h
|
| 274 |
+
|
| 275 |
+
collage = Image.new("RGB", (collage_w, collage_h), (255, 255, 255))
|
| 276 |
+
|
| 277 |
+
x0 = 0
|
| 278 |
+
x1 = patch_w + SPACING
|
| 279 |
+
x2 = patch_w * 2 + SPACING * 2
|
| 280 |
+
|
| 281 |
+
collage.paste(patch_A_resized, (x0, 0))
|
| 282 |
+
collage.paste(patch_B_resized, (x1, 0))
|
| 283 |
+
collage.paste(gen_patch, (x2, 0))
|
| 284 |
+
|
| 285 |
+
return collage, composed
|
| 286 |
+
|
| 287 |
+
except Exception as e:
|
| 288 |
+
print(f"生成图像时发生错误: {e}")
|
| 289 |
+
raise gr.Error(f"生成失败:{str(e)}")
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
import gradio as gr
|
| 293 |
+
|
| 294 |
+
with gr.Blocks(
|
| 295 |
+
theme=gr.themes.Soft(primary_hue="blue", neutral_hue="slate"),
|
| 296 |
+
css="""
|
| 297 |
+
/* Global Clean Font */
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
/* Center container */
|
| 301 |
+
.app-container {
|
| 302 |
+
width: 100% !important;
|
| 303 |
+
max-width: 100% !important;
|
| 304 |
+
margin: 0 auto;
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
/* Title block */
|
| 308 |
+
.title-block h1 {
|
| 309 |
+
text-align: center;
|
| 310 |
+
font-size: 3rem;
|
| 311 |
+
font-weight: 1100;
|
| 312 |
+
|
| 313 |
+
/* 蓝紫渐变 */
|
| 314 |
+
background: linear-gradient(90deg, #5b8dff, #b57aff);
|
| 315 |
+
-webkit-background-clip: text;
|
| 316 |
+
color: transparent;
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
.title-block h2 {
|
| 320 |
+
text-align: center;
|
| 321 |
+
font-size: 1.6rem;
|
| 322 |
+
font-weight: 700;
|
| 323 |
+
margin-top: 0.4rem;
|
| 324 |
+
|
| 325 |
+
/* 稍弱一点的渐变 */
|
| 326 |
+
background: linear-gradient(90deg, #6da0ff, #c28aff);
|
| 327 |
+
-webkit-background-clip: text;
|
| 328 |
+
color: transparent;
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
/* Title block
|
| 332 |
+
|
| 333 |
+
.title-block h1 {
|
| 334 |
+
text-align: center; font-size: 2.4rem; font-weight: 800; color: #1f2937;
|
| 335 |
+
}
|
| 336 |
+
.title-block h2 {
|
| 337 |
+
text-align: center; font-size: 1.2rem; font-weight: 500; color: #303030; margin-top: 0.4rem;
|
| 338 |
+
}
|
| 339 |
+
*/
|
| 340 |
+
|
| 341 |
+
/* Simple card */
|
| 342 |
+
.clean-card {
|
| 343 |
+
background: #ffffff;
|
| 344 |
+
border: 1px solid #e5e7eb;
|
| 345 |
+
border-radius: 12px;
|
| 346 |
+
padding: 14px 16px;
|
| 347 |
+
margin-bottom: 10px;
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
/* Card title */
|
| 351 |
+
.clean-card-title {
|
| 352 |
+
font-size: 1.3rem;
|
| 353 |
+
font-weight: 600;
|
| 354 |
+
color: #404040;
|
| 355 |
+
margin-bottom: 6px;
|
| 356 |
+
}
|
| 357 |
+
|
| 358 |
+
/* Subtitle */
|
| 359 |
+
.clean-card-subtitle {
|
| 360 |
+
font-size: 1.1rem;
|
| 361 |
+
color: #404040;
|
| 362 |
+
margin-bottom: 8px;
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
/* Output card */
|
| 366 |
+
.output-card {
|
| 367 |
+
background: #ffffff;
|
| 368 |
+
border: 1px solid #d1d5db;
|
| 369 |
+
border-radius: 12px;
|
| 370 |
+
padding: 14px 16px;
|
| 371 |
+
}
|
| 372 |
+
.output-card1 {
|
| 373 |
+
background: #ffffff;
|
| 374 |
+
border: none !important;
|
| 375 |
+
box-shadow: none !important;
|
| 376 |
+
border-radius: 12px;
|
| 377 |
+
padding: 14px 16px;
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
/* 渐变主按钮:同时兼容 button 自己是 .color-btn,或者外层是 .color-btn 的情况 */
|
| 381 |
+
button.color-btn,
|
| 382 |
+
.color-btn button {
|
| 383 |
+
width: 100%;
|
| 384 |
+
background: linear-gradient(90deg, #3b82f6 0%, #6366f1 100%) !important;
|
| 385 |
+
color: #ffffff !important;
|
| 386 |
+
font-size: 1.05rem !important;
|
| 387 |
+
font-weight: 700 !important;
|
| 388 |
+
padding: 14px !important;
|
| 389 |
+
border-radius: 12px !important;
|
| 390 |
+
|
| 391 |
+
border: none !important;
|
| 392 |
+
box-shadow: 0 4px 12px rgba(99, 102, 241, 0.25) !important;
|
| 393 |
+
transition: 0.2s ease !important;
|
| 394 |
+
cursor: pointer;
|
| 395 |
+
}
|
| 396 |
+
|
| 397 |
+
/* Hover 效果 */
|
| 398 |
+
button.color-btn:hover,
|
| 399 |
+
.color-btn button:hover {
|
| 400 |
+
opacity: 0.92 !important;
|
| 401 |
+
transform: translateY(-1px) !important;
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
/* 按下反馈 */
|
| 405 |
+
button.color-btn:active,
|
| 406 |
+
.color-btn button:active {
|
| 407 |
+
transform: scale(0.98) !important;
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
/* 如果外面还有 wrapper,就把它搞透明一下(防止再套一层白条) */
|
| 411 |
+
.color-btn > div {
|
| 412 |
+
background: transparent !important;
|
| 413 |
+
box-shadow: none !important;
|
| 414 |
+
border: none !important;
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
.example-image img {
|
| 418 |
+
height: 400px !important;
|
| 419 |
+
object-fit: contain;
|
| 420 |
+
|
| 421 |
+
"""
|
| 422 |
+
) as demo:
|
| 423 |
+
gen_patch_out = None
|
| 424 |
+
composed_out = None
|
| 425 |
+
# -------------------------------------------------------
|
| 426 |
+
# Title
|
| 427 |
+
# -------------------------------------------------------
|
| 428 |
+
gr.Markdown(
|
| 429 |
+
"""
|
| 430 |
+
<div class="title-block">
|
| 431 |
+
<h1>The Consistency Critic:</h1>
|
| 432 |
+
<h2>Correcting Inconsistencies in Generated Images via Reference-Guided Attentive Alignment</h2>
|
| 433 |
+
</div>
|
| 434 |
+
"""
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
# ========================================================
|
| 438 |
+
# 两个 Stage 并排显示
|
| 439 |
+
# ========================================================
|
| 440 |
+
with gr.Row(elem_classes="app-container"):
|
| 441 |
+
# ========================================================
|
| 442 |
+
# STAGE 1: Image Generation (左侧)
|
| 443 |
+
# ========================================================
|
| 444 |
+
with gr.Column(scale=1):
|
| 445 |
+
gr.Markdown(
|
| 446 |
+
"""
|
| 447 |
+
<div class="clean-card">
|
| 448 |
+
<div class="clean-card-title">🎨 Stage 1: Customized Image Generation</div>
|
| 449 |
+
<div class="clean-card-subtitle">Generate images from prompts and reference image using UNO method. The output can be used as input for Stage 2.</div>
|
| 450 |
+
</div>
|
| 451 |
+
"""
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
# Stage 1 Input
|
| 455 |
+
gr.Markdown("### Input")
|
| 456 |
+
stage1_input_image = gr.Image(label="Input Image (Optional)", type="pil")
|
| 457 |
+
stage1_prompt = gr.Textbox(
|
| 458 |
+
label="Prompt",
|
| 459 |
+
placeholder="Enter your prompt for image generation",
|
| 460 |
+
lines=3
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
with gr.Row():
|
| 464 |
+
with gr.Column():
|
| 465 |
+
stage1_width = gr.Slider(512, 2048, 1024, step=16, label="Generation Width")
|
| 466 |
+
stage1_height = gr.Slider(512, 2048, 1024, step=16, label="Generation Height")
|
| 467 |
+
with gr.Accordion("Advanced Settings", open=False):
|
| 468 |
+
stage1_seed = gr.Slider(
|
| 469 |
+
label="Seed",
|
| 470 |
+
minimum=0,
|
| 471 |
+
maximum=MAX_SEED,
|
| 472 |
+
step=1,
|
| 473 |
+
value=42,
|
| 474 |
+
)
|
| 475 |
+
stage1_randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
| 476 |
+
stage1_guidance_scale = gr.Slider(
|
| 477 |
+
label="Guidance Scale",
|
| 478 |
+
minimum=1,
|
| 479 |
+
maximum=10,
|
| 480 |
+
step=0.1,
|
| 481 |
+
value=2.5,
|
| 482 |
+
)
|
| 483 |
+
stage1_steps = gr.Slider(
|
| 484 |
+
label="Steps",
|
| 485 |
+
minimum=1,
|
| 486 |
+
maximum=30,
|
| 487 |
+
value=28,
|
| 488 |
+
step=1
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
stage1_method1_btn = gr.Button("✨ Generate Image", elem_classes="color-btn")
|
| 492 |
+
|
| 493 |
+
# Stage 1 Output
|
| 494 |
+
gr.Markdown("### Output")
|
| 495 |
+
stage1_output_image = gr.Image(label="Generated Image", interactive=False)
|
| 496 |
+
stage1_used_seed = gr.Number(label="Used Seed", interactive=False)
|
| 497 |
+
|
| 498 |
+
# -------------------------------------------------------
|
| 499 |
+
# Stage 1 Examples
|
| 500 |
+
# -------------------------------------------------------
|
| 501 |
+
gr.Markdown(
|
| 502 |
+
"""
|
| 503 |
+
<div style="
|
| 504 |
+
font-size: 1.3rem;
|
| 505 |
+
font-weight: 600;
|
| 506 |
+
color: #404040;
|
| 507 |
+
margin-top: 16px;
|
| 508 |
+
margin-bottom: 6px;
|
| 509 |
+
">
|
| 510 |
+
📚 Stage 1 Example Images & Prompts
|
| 511 |
+
</div>
|
| 512 |
+
""",
|
| 513 |
+
)
|
| 514 |
+
gr.Markdown(
|
| 515 |
+
"""
|
| 516 |
+
<div style="
|
| 517 |
+
font-size: 1.1rem;
|
| 518 |
+
color: #404040;
|
| 519 |
+
margin-bottom: 8px;
|
| 520 |
+
">
|
| 521 |
+
Click on any example below to load the image and prompt into Stage 1 inputs.
|
| 522 |
+
</div>
|
| 523 |
+
""",
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
gr.Examples(
|
| 527 |
+
examples=[
|
| 528 |
+
["./test_imgs/product_3.png", "In a softly lit nursery, a baby sleeps peacefully as a parent gently applies the product to a washcloth. The scene is calm and warm, with natural light highlighting the product’s label. The camera captures a close-up, centered view, emphasizing the product’s presence and its gentle interaction with the environment."],
|
| 529 |
+
["./test_imgs/3.png", "Create an engaging lifestyle e-commerce scene where a person delicately picks up the product from a slightly shifted angle to add depth and realism, placing it within a creative photography workspace filled with soft natural light, scattered camera gear, open photo books, and warm wooden textures."],
|
| 530 |
+
["./test_imgs/2.png", "Create a stylish e-commerce scene featuring the product displayed on a modern clothing rack in a bright boutique environment, surrounded by soft natural lighting, minimalistic decor, and complementary fashion accessories"]
|
| 531 |
+
],
|
| 532 |
+
inputs=[stage1_input_image, stage1_prompt],
|
| 533 |
+
label="Click to Load Examples"
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
# ========================================================
|
| 537 |
+
# STAGE 2: Image Correction (右侧)
|
| 538 |
+
# ========================================================
|
| 539 |
+
with gr.Column(scale=1):
|
| 540 |
+
gr.Markdown(
|
| 541 |
+
"""
|
| 542 |
+
<div class="clean-card">
|
| 543 |
+
<div class="clean-card-title">🔧 Stage 2: Image Consistency Correction</div>
|
| 544 |
+
<div class="clean-card-subtitle">Refine and correct generated images using ImageCritic.</div>
|
| 545 |
+
</div>
|
| 546 |
+
"""
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
# -------------------------------------------------------
|
| 550 |
+
# Tips for Stage 2
|
| 551 |
+
# -------------------------------------------------------
|
| 552 |
+
gr.Markdown(
|
| 553 |
+
"""
|
| 554 |
+
<div class="clean-card">
|
| 555 |
+
<div class="clean-card-title">💡 Stage 2 Tips</div>
|
| 556 |
+
<div class="clean-card-subtitle">
|
| 557 |
+
• Crop both the bbox that needs to be corrected and the reference bbox, preferably covering the smallest repeating unit, to achieve better results.<br>
|
| 558 |
+
• The bbox area should ideally cover the region to be corrected and the reference region as completely as possible.<br>
|
| 559 |
+
• The aspect ratio of the bboxes should also be kept consistent to avoid errors caused by incorrect scaling.<br>
|
| 560 |
+
• If model fails to correct the image, it may be because the generated image is too similar to the reference image, causing the model to skip the repair. You can manually<b> paint that area black on a drawing board before sending to the model, or try cropping only the local region and performing multiple rounds correcting to progressively enhance the whole generated image.</b>
|
| 561 |
+
</div>
|
| 562 |
+
"""
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
# -------------------------------------------------------
|
| 566 |
+
# Image annotation area
|
| 567 |
+
# -------------------------------------------------------
|
| 568 |
+
with gr.Row():
|
| 569 |
+
# Left: Reference Image
|
| 570 |
+
with gr.Column():
|
| 571 |
+
gr.Markdown(
|
| 572 |
+
"""
|
| 573 |
+
<div class="clean-card">
|
| 574 |
+
<div class="clean-card-title">📌 Reference Image</div>
|
| 575 |
+
<div class="clean-card-subtitle">Draw a bounding box around the area for reference.</div>
|
| 576 |
+
</div>
|
| 577 |
+
"""
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
annotator_A = image_annotator(
|
| 581 |
+
value=None,
|
| 582 |
+
label="reference image",
|
| 583 |
+
label_list=["bbox for reference"],
|
| 584 |
+
label_colors = [(168, 160, 194)],
|
| 585 |
+
single_box=True,
|
| 586 |
+
image_type="numpy",
|
| 587 |
+
sources=["upload", "clipboard"],
|
| 588 |
+
height=300,
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
# Right: Image to be corrected
|
| 592 |
+
with gr.Column():
|
| 593 |
+
gr.Markdown(
|
| 594 |
+
"""
|
| 595 |
+
<div class="clean-card">
|
| 596 |
+
<div class="clean-card-title">🖼️ Input Image To Be Corrected</div>
|
| 597 |
+
<div class="clean-card-subtitle">Use the mouse wheel to zoom and draw a bounding box around the area to be corrected.</div>
|
| 598 |
+
</div>
|
| 599 |
+
"""
|
| 600 |
+
)
|
| 601 |
+
|
| 602 |
+
annotator_B = image_annotator(
|
| 603 |
+
value=None,
|
| 604 |
+
label="input image to be corrected",
|
| 605 |
+
label_list=["bbox for correction"],
|
| 606 |
+
label_colors = [(168, 160, 194)],
|
| 607 |
+
single_box=True,
|
| 608 |
+
image_type="numpy",
|
| 609 |
+
sources=["upload", "clipboard"],
|
| 610 |
+
height=300,
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
# -------------------------------------------------------
|
| 614 |
+
# Controls
|
| 615 |
+
# -------------------------------------------------------
|
| 616 |
+
with gr.Row():
|
| 617 |
+
object_name = gr.Textbox(
|
| 618 |
+
label="Caption for object (optional; using 'product' also works)",
|
| 619 |
+
value="product",
|
| 620 |
+
placeholder="e.g. product, shoes, bag, face ..."
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
+
base_seed = gr.Number(
|
| 624 |
+
label="Seed",
|
| 625 |
+
value=0,
|
| 626 |
+
precision=0,
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
# -------------------------------------------------------
|
| 630 |
+
# Run Button
|
| 631 |
+
# -------------------------------------------------------
|
| 632 |
+
run_btn = gr.Button("✨ Generate ", elem_classes="color-btn")
|
| 633 |
+
|
| 634 |
+
# ===================== 输出区 =====================
|
| 635 |
+
gr.Markdown("### Output")
|
| 636 |
+
with gr.Column(elem_classes="output-card1"):
|
| 637 |
+
gen_patch_out = gr.Image(
|
| 638 |
+
label="concatenated input-output",
|
| 639 |
+
interactive=False
|
| 640 |
+
)
|
| 641 |
+
|
| 642 |
+
with gr.Column(elem_classes="output-card1"):
|
| 643 |
+
composed_out = gr.Image(
|
| 644 |
+
label="corrected image",
|
| 645 |
+
interactive=False
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
# -------------------------------------------------------
|
| 649 |
+
# Stage 2 Example 区域整体放进一个白色卡片
|
| 650 |
+
# -------------------------------------------------------
|
| 651 |
+
with gr.Column(elem_classes="clean-card"):
|
| 652 |
+
|
| 653 |
+
gr.Markdown(
|
| 654 |
+
"""
|
| 655 |
+
<div style="
|
| 656 |
+
font-size: 1.3rem;
|
| 657 |
+
font-weight: 600;
|
| 658 |
+
color: #404040;
|
| 659 |
+
margin-bottom: 6px;
|
| 660 |
+
">
|
| 661 |
+
📚 Example Images
|
| 662 |
+
</div>
|
| 663 |
+
""",
|
| 664 |
+
)
|
| 665 |
+
|
| 666 |
+
gr.Markdown(
|
| 667 |
+
"""
|
| 668 |
+
<div style="
|
| 669 |
+
font-size: 1.1rem;
|
| 670 |
+
color: #404040;
|
| 671 |
+
margin-bottom: 8px;
|
| 672 |
+
">
|
| 673 |
+
Below are some example pairs showing how bounding boxes should be drawn.
|
| 674 |
+
You can click and drag the image below into the upper area for generation.<br>
|
| 675 |
+
<b> Full-image input is also supported, but it is recommended to use the smallest possible bounding box that covers the region to be corrected and reference bbox. For example, the bbox approach used in the first row generally produces better results than the one used in the second row.</b>
|
| 676 |
+
</div>
|
| 677 |
+
""",
|
| 678 |
+
)
|
| 679 |
+
with gr.Row():
|
| 680 |
+
gr.Image("./test_imgs/product_3.png",label="reference example", elem_classes="example-image")
|
| 681 |
+
gr.Image("./test_imgs/product_3_bbox_1.png",label="reference example with bbox",elem_classes="example-image")
|
| 682 |
+
gr.Image("./test_imgs/generated_3.png",label="input example", elem_classes="example-image")
|
| 683 |
+
gr.Image("./test_imgs/generated_3_bbox_1.png",label="input example with bbox", elem_classes="example-image")
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
with gr.Row():
|
| 687 |
+
gr.Image("./test_imgs/product_3.png",label="reference example", elem_classes="example-image")
|
| 688 |
+
gr.Image("./test_imgs/product_3_bbox.png",label="reference example with bbox",elem_classes="example-image")
|
| 689 |
+
gr.Image("./test_imgs/generated_3.png",label="input example", elem_classes="example-image")
|
| 690 |
+
gr.Image("./test_imgs/generated_3_bbox.png",label="input example with bbox", elem_classes="example-image")
|
| 691 |
+
|
| 692 |
+
with gr.Row():
|
| 693 |
+
gr.Image("./test_imgs/product_1.jpg", label="reference example", elem_classes="example-image")
|
| 694 |
+
gr.Image("./test_imgs/product_1_bbox.png", label="reference example with bbox", elem_classes="example-image")
|
| 695 |
+
gr.Image("./test_imgs/generated_1.png", label="input example", elem_classes="example-image")
|
| 696 |
+
gr.Image("./test_imgs/generated_1_bbox.png", label="input example with bbox", elem_classes="example-image")
|
| 697 |
+
|
| 698 |
+
with gr.Row():
|
| 699 |
+
gr.Image("./test_imgs/product_2.png",label="reference example", elem_classes="example-image")
|
| 700 |
+
gr.Image("./test_imgs/product_2_bbox.png",label="reference example with bbox",elem_classes="example-image")
|
| 701 |
+
gr.Image("./test_imgs/generated_2.png", label="input example", elem_classes="example-image")
|
| 702 |
+
gr.Image("./test_imgs/generated_2_bbox.png", label="input example with bbox", elem_classes="example-image")
|
| 703 |
+
|
| 704 |
+
# ========= 所有组件都定义完,再绑定按钮点击 =========
|
| 705 |
+
# Stage 1: Image Generation
|
| 706 |
+
stage1_method1_btn.click(
|
| 707 |
+
fn=generate_image_method1,
|
| 708 |
+
inputs=[stage1_input_image, stage1_prompt, stage1_width, stage1_height, stage1_seed, stage1_randomize_seed, stage1_guidance_scale, stage1_steps],
|
| 709 |
+
outputs=[stage1_output_image, stage1_used_seed],
|
| 710 |
+
)
|
| 711 |
+
|
| 712 |
+
# Stage 2: Image Correction
|
| 713 |
+
run_btn.click(
|
| 714 |
+
fn=run_with_two_bboxes,
|
| 715 |
+
inputs=[annotator_A, annotator_B, object_name, base_seed],
|
| 716 |
+
outputs=[gen_patch_out, composed_out],
|
| 717 |
+
)
|
| 718 |
+
|
| 719 |
+
if __name__ == "__main__":
|
| 720 |
+
demo.launch(server_name="0.0.0.0", server_port=7779)
|
requirements.txt
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--extra-index-url https://download.pytorch.org/whl/cu124
|
| 2 |
+
torch
|
| 3 |
+
torchvision
|
| 4 |
+
accelerate==1.10.0
|
| 5 |
+
clip @ git+https://github.com/openai/CLIP.git@dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
|
| 6 |
+
contourpy==1.3.2
|
| 7 |
+
cycler==0.12.1
|
| 8 |
+
datasets==4.0.0
|
| 9 |
+
decord==0.6.0
|
| 10 |
+
diffusers @ git+https://github.com/huggingface/diffusers.git@345864eb852b528fd1f4b6ad087fa06e0470006b
|
| 11 |
+
gradio==5.49.1
|
| 12 |
+
gradio_client==1.13.3
|
| 13 |
+
gradio_image_annotation==0.4.1
|
| 14 |
+
huggingface-hub==0.35.3
|
| 15 |
+
ipykernel==7.0.1
|
| 16 |
+
ipython==8.37.0
|
| 17 |
+
Jinja2==3.1.6
|
| 18 |
+
multiprocess==0.70.16
|
| 19 |
+
ninja==1.13.0
|
| 20 |
+
numpy==2.2.6
|
| 21 |
+
open_clip_torch==3.2.0
|
| 22 |
+
openai==1.107.2
|
| 23 |
+
opencv-python==4.12.0.88
|
| 24 |
+
opencv-python-headless==4.12.0.88
|
| 25 |
+
qwen-vl-utils==0.0.11
|
| 26 |
+
requests==2.32.5
|
| 27 |
+
safetensors==0.6.2
|
| 28 |
+
scikit-learn==1.7.2
|
| 29 |
+
tornado==6.5.2
|
| 30 |
+
tqdm==4.67.1
|
| 31 |
+
transformers==4.51.3
|
| 32 |
+
wandb==0.21.1
|
| 33 |
+
einops
|
| 34 |
+
sentencepiece
|
src/__init__.py
ADDED
|
File without changes
|
src/attention_processor.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from typing import Optional, Tuple, Dict, Any
|
| 4 |
+
import os
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
from diffusers.models.attention_processor import FluxAttnProcessor2_0
|
| 9 |
+
|
| 10 |
+
class VisualFluxAttnProcessor2_0(FluxAttnProcessor2_0):
|
| 11 |
+
"""
|
| 12 |
+
自定义的Flux注意力处理器,用于保存注意力图进行可视化
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(self, save_attention=True, save_dir="attention_maps"):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.save_attention = save_attention
|
| 18 |
+
self.save_dir = save_dir
|
| 19 |
+
self.step_counter = 0
|
| 20 |
+
|
| 21 |
+
# 创建保存目录
|
| 22 |
+
if self.save_attention:
|
| 23 |
+
os.makedirs(self.save_dir, exist_ok=True)
|
| 24 |
+
|
| 25 |
+
def save_attention_map(self, attn_weights, layer_name="", step=None):
|
| 26 |
+
"""保存注意力图"""
|
| 27 |
+
if not self.save_attention:
|
| 28 |
+
return
|
| 29 |
+
|
| 30 |
+
if step is None:
|
| 31 |
+
step = self.step_counter
|
| 32 |
+
|
| 33 |
+
# 取第一个batch和第一个head的注意力权重
|
| 34 |
+
attn_map = attn_weights[0, 0].detach().cpu().numpy() # [seq_len, seq_len]
|
| 35 |
+
|
| 36 |
+
# 创建热力图
|
| 37 |
+
plt.figure(figsize=(12, 10))
|
| 38 |
+
plt.imshow(attn_map, cmap='hot', interpolation='nearest')
|
| 39 |
+
plt.colorbar()
|
| 40 |
+
plt.title(f'Attention Map - {layer_name} - Step {step}')
|
| 41 |
+
plt.xlabel('Key Position')
|
| 42 |
+
plt.ylabel('Query Position')
|
| 43 |
+
|
| 44 |
+
# 保存图片
|
| 45 |
+
save_path = os.path.join(self.save_dir, f"attention_{layer_name}_step_{step}.png")
|
| 46 |
+
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
| 47 |
+
plt.close()
|
| 48 |
+
|
| 49 |
+
print(f"Attention map saved to: {save_path}")
|
| 50 |
+
|
| 51 |
+
def __call__(
|
| 52 |
+
self,
|
| 53 |
+
attn,
|
| 54 |
+
hidden_states: torch.Tensor,
|
| 55 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 56 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 57 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 58 |
+
use_cond: bool = False,
|
| 59 |
+
) -> torch.Tensor:
|
| 60 |
+
batch_size, sequence_length, _ = (
|
| 61 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
if attention_mask is not None:
|
| 65 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
| 66 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
| 67 |
+
|
| 68 |
+
if attn.group_norm is not None:
|
| 69 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
| 70 |
+
|
| 71 |
+
query = attn.to_q(hidden_states)
|
| 72 |
+
|
| 73 |
+
if encoder_hidden_states is None:
|
| 74 |
+
encoder_hidden_states = hidden_states
|
| 75 |
+
elif attn.norm_cross:
|
| 76 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
| 77 |
+
|
| 78 |
+
key = attn.to_k(encoder_hidden_states)
|
| 79 |
+
value = attn.to_v(encoder_hidden_states)
|
| 80 |
+
|
| 81 |
+
inner_dim = key.shape[-1]
|
| 82 |
+
head_dim = inner_dim // attn.heads
|
| 83 |
+
|
| 84 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 85 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 86 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 87 |
+
|
| 88 |
+
if attn.norm_q is not None:
|
| 89 |
+
query = attn.norm_q(query)
|
| 90 |
+
if attn.norm_k is not None:
|
| 91 |
+
key = attn.norm_k(key)
|
| 92 |
+
|
| 93 |
+
# 应用旋转位置编码
|
| 94 |
+
if image_rotary_emb is not None:
|
| 95 |
+
query = attn.rotary_emb(query, image_rotary_emb)
|
| 96 |
+
if not attn.is_cross_attention:
|
| 97 |
+
key = attn.rotary_emb(key, image_rotary_emb)
|
| 98 |
+
|
| 99 |
+
# 计算注意力权重
|
| 100 |
+
attention_scores = torch.matmul(query, key.transpose(-2, -1)) / (head_dim ** 0.5)
|
| 101 |
+
|
| 102 |
+
if attention_mask is not None:
|
| 103 |
+
attention_scores = attention_scores + attention_mask
|
| 104 |
+
|
| 105 |
+
attention_probs = F.softmax(attention_scores, dim=-1)
|
| 106 |
+
|
| 107 |
+
# 保存注意力图
|
| 108 |
+
if self.save_attention and self.step_counter % 10 == 0: # 每10步保存一次
|
| 109 |
+
layer_name = f"layer_{self.step_counter // 10}"
|
| 110 |
+
self.save_attention_map(attention_probs, layer_name, self.step_counter)
|
| 111 |
+
|
| 112 |
+
# 应用dropout
|
| 113 |
+
attention_probs = F.dropout(attention_probs, p=attn.dropout, training=attn.training)
|
| 114 |
+
|
| 115 |
+
# 计算输出
|
| 116 |
+
hidden_states = torch.matmul(attention_probs, value)
|
| 117 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 118 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 119 |
+
|
| 120 |
+
if use_cond:
|
| 121 |
+
# 处理条件分支的情况
|
| 122 |
+
seq_len = hidden_states.shape[1]
|
| 123 |
+
if seq_len % 2 == 0:
|
| 124 |
+
# 假设前半部分是原始hidden_states,后半部分是条件hidden_states
|
| 125 |
+
mid_point = seq_len // 2
|
| 126 |
+
original_hidden_states = hidden_states[:, :mid_point, :]
|
| 127 |
+
cond_hidden_states = hidden_states[:, mid_point:, :]
|
| 128 |
+
|
| 129 |
+
# 分别处理
|
| 130 |
+
original_output = attn.to_out[0](original_hidden_states)
|
| 131 |
+
cond_output = attn.to_out[0](cond_hidden_states)
|
| 132 |
+
|
| 133 |
+
if len(attn.to_out) > 1:
|
| 134 |
+
original_output = attn.to_out[1](original_output)
|
| 135 |
+
cond_output = attn.to_out[1](cond_output)
|
| 136 |
+
|
| 137 |
+
self.step_counter += 1
|
| 138 |
+
return original_output, cond_output
|
| 139 |
+
|
| 140 |
+
# 标准输出处理
|
| 141 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 142 |
+
if len(attn.to_out) > 1:
|
| 143 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 144 |
+
|
| 145 |
+
self.step_counter += 1
|
| 146 |
+
return hidden_states
|
src/detail_encoder.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Merge image encoder and fuse module to create an ID Encoder
|
| 2 |
+
# send multiple ID images, we can directly obtain the updated text encoder containing a stacked ID embedding
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from transformers.models.clip.modeling_clip import CLIPVisionModelWithProjection
|
| 8 |
+
from transformers.models.clip.configuration_clip import CLIPVisionConfig
|
| 9 |
+
from transformers import PretrainedConfig
|
| 10 |
+
|
| 11 |
+
VISION_CONFIG_DICT = {
|
| 12 |
+
"hidden_size": 1024,
|
| 13 |
+
"intermediate_size": 4096,
|
| 14 |
+
"num_attention_heads": 16,
|
| 15 |
+
"num_hidden_layers": 24,
|
| 16 |
+
"patch_size": 14,
|
| 17 |
+
"projection_dim": 768
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
class MLP(nn.Module):
|
| 21 |
+
def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True):
|
| 22 |
+
super().__init__()
|
| 23 |
+
if use_residual:
|
| 24 |
+
assert in_dim == out_dim
|
| 25 |
+
self.layernorm = nn.LayerNorm(in_dim)
|
| 26 |
+
self.fc1 = nn.Linear(in_dim, hidden_dim)
|
| 27 |
+
self.fc2 = nn.Linear(hidden_dim, out_dim)
|
| 28 |
+
self.use_residual = use_residual
|
| 29 |
+
self.act_fn = nn.GELU()
|
| 30 |
+
|
| 31 |
+
def forward(self, x):
|
| 32 |
+
residual = x
|
| 33 |
+
x = self.layernorm(x)
|
| 34 |
+
x = self.fc1(x)
|
| 35 |
+
x = self.act_fn(x)
|
| 36 |
+
x = self.fc2(x)
|
| 37 |
+
if self.use_residual:
|
| 38 |
+
x = x + residual
|
| 39 |
+
return x
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class FuseModule(nn.Module):
|
| 43 |
+
def __init__(self, prompt_embed_dim, id_embed_dim):
|
| 44 |
+
super().__init__()
|
| 45 |
+
self.mlp1 = MLP(prompt_embed_dim + id_embed_dim, prompt_embed_dim, prompt_embed_dim, use_residual=False)
|
| 46 |
+
self.mlp2 = MLP(prompt_embed_dim, prompt_embed_dim, prompt_embed_dim, use_residual=True)
|
| 47 |
+
self.layer_norm = nn.LayerNorm(prompt_embed_dim)
|
| 48 |
+
|
| 49 |
+
def fuse_fn(self, prompt_embeds, id_embeds):
|
| 50 |
+
stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
|
| 51 |
+
stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
|
| 52 |
+
stacked_id_embeds = self.mlp2(stacked_id_embeds)
|
| 53 |
+
stacked_id_embeds = self.layer_norm(stacked_id_embeds)
|
| 54 |
+
return stacked_id_embeds
|
| 55 |
+
|
| 56 |
+
def forward(
|
| 57 |
+
self,
|
| 58 |
+
prompt_embeds,
|
| 59 |
+
id_embeds,
|
| 60 |
+
class_tokens_mask,
|
| 61 |
+
) -> torch.Tensor:
|
| 62 |
+
device = prompt_embeds.device
|
| 63 |
+
class_tokens_mask = class_tokens_mask.to(device)
|
| 64 |
+
id_embeds = id_embeds.to(prompt_embeds.dtype)
|
| 65 |
+
num_inputs = class_tokens_mask.sum().unsqueeze(0).to(id_embeds.device)
|
| 66 |
+
batch_size, max_num_inputs = id_embeds.shape[:2]
|
| 67 |
+
seq_length = prompt_embeds.shape[1]
|
| 68 |
+
flat_id_embeds = id_embeds.view(
|
| 69 |
+
-1, id_embeds.shape[-2], id_embeds.shape[-1]
|
| 70 |
+
)
|
| 71 |
+
valid_id_mask = (
|
| 72 |
+
torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :]
|
| 73 |
+
< num_inputs[:, None]
|
| 74 |
+
)
|
| 75 |
+
valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()]
|
| 76 |
+
|
| 77 |
+
prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1])
|
| 78 |
+
class_tokens_mask = class_tokens_mask.view(-1)
|
| 79 |
+
valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1])
|
| 80 |
+
image_token_embeds = prompt_embeds[class_tokens_mask]
|
| 81 |
+
stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)
|
| 82 |
+
stacked_id_embeds = stacked_id_embeds.to(device=device, dtype=prompt_embeds.dtype)
|
| 83 |
+
assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
|
| 84 |
+
prompt_embeds = prompt_embeds.masked_scatter(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
|
| 85 |
+
updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1)
|
| 86 |
+
return updated_prompt_embeds
|
| 87 |
+
|
| 88 |
+
class DetailEncoder(CLIPVisionModelWithProjection):
|
| 89 |
+
def __init__(self):
|
| 90 |
+
|
| 91 |
+
super().__init__(CLIPVisionConfig(**VISION_CONFIG_DICT))
|
| 92 |
+
self.visual_projection_2 = nn.Linear(1024, 1280, bias=False)
|
| 93 |
+
self.fuse_module = FuseModule(4096, 2048)
|
| 94 |
+
|
| 95 |
+
def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask):
|
| 96 |
+
dtype = next(self.parameters()).dtype
|
| 97 |
+
device = next(self.parameters()).device
|
| 98 |
+
b, num_inputs, c, h, w = id_pixel_values.shape
|
| 99 |
+
# device setting
|
| 100 |
+
id_pixel_values = id_pixel_values.to(device=device, dtype=dtype)
|
| 101 |
+
prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
|
| 102 |
+
class_tokens_mask = class_tokens_mask.to(device=device)
|
| 103 |
+
|
| 104 |
+
id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)
|
| 105 |
+
|
| 106 |
+
id_pixel_values = F.interpolate(id_pixel_values, size=(224, 224), mode="bilinear", align_corners=False)
|
| 107 |
+
# id embeds <--> input image
|
| 108 |
+
shared_id_embeds = self.vision_model(id_pixel_values)[1]
|
| 109 |
+
id_embeds = self.visual_projection(shared_id_embeds)
|
| 110 |
+
id_embeds_2 = self.visual_projection_2(shared_id_embeds)
|
| 111 |
+
|
| 112 |
+
id_embeds = id_embeds.view(b, num_inputs, 1, -1)
|
| 113 |
+
id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
|
| 114 |
+
|
| 115 |
+
id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
|
| 116 |
+
updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask)
|
| 117 |
+
return updated_prompt_embeds
|
| 118 |
+
|
src/jsonl_datasets.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
from datasets import load_dataset
|
| 3 |
+
from torchvision import transforms
|
| 4 |
+
import random
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
Image.MAX_IMAGE_PIXELS = None
|
| 8 |
+
|
| 9 |
+
def multiple_16(num: float):
|
| 10 |
+
return int(round(num / 16) * 16)
|
| 11 |
+
|
| 12 |
+
def get_random_resolution(min_size=512, max_size=1280, multiple=16):
|
| 13 |
+
resolution = random.randint(min_size // multiple, max_size // multiple) * multiple
|
| 14 |
+
return resolution
|
| 15 |
+
|
| 16 |
+
def load_image_safely(image_path, size):
|
| 17 |
+
try:
|
| 18 |
+
image = Image.open(image_path).convert("RGB")
|
| 19 |
+
return image
|
| 20 |
+
except Exception as e:
|
| 21 |
+
print("file error: "+image_path)
|
| 22 |
+
with open("failed_images.txt", "a") as f:
|
| 23 |
+
f.write(f"{image_path}\n")
|
| 24 |
+
return Image.new("RGB", (size, size), (255, 255, 255))
|
| 25 |
+
|
| 26 |
+
def make_train_dataset(args, tokenizer, accelerator=None):
|
| 27 |
+
if args.train_data_dir is not None:
|
| 28 |
+
print("load_data")
|
| 29 |
+
dataset = load_dataset('json', data_files=args.train_data_dir)
|
| 30 |
+
|
| 31 |
+
column_names = dataset["train"].column_names
|
| 32 |
+
|
| 33 |
+
# 6. Get the column names for input/target.
|
| 34 |
+
caption_column = args.caption_column
|
| 35 |
+
target_column = args.target_column
|
| 36 |
+
if args.subject_column is not None:
|
| 37 |
+
subject_columns = args.subject_column.split(",")
|
| 38 |
+
if args.spatial_column is not None:
|
| 39 |
+
spatial_columns= args.spatial_column.split(",")
|
| 40 |
+
|
| 41 |
+
size = args.cond_size
|
| 42 |
+
noise_size = get_random_resolution(max_size=args.noise_size) # maybe 768 or higher
|
| 43 |
+
subject_cond_train_transforms = transforms.Compose(
|
| 44 |
+
[
|
| 45 |
+
transforms.Lambda(lambda img: img.resize((
|
| 46 |
+
multiple_16(size * img.size[0] / max(img.size)),
|
| 47 |
+
multiple_16(size * img.size[1] / max(img.size))
|
| 48 |
+
), resample=Image.BILINEAR)),
|
| 49 |
+
transforms.RandomHorizontalFlip(p=0.7),
|
| 50 |
+
transforms.RandomRotation(degrees=20),
|
| 51 |
+
transforms.Lambda(lambda img: transforms.Pad(
|
| 52 |
+
padding=(
|
| 53 |
+
int((size - img.size[0]) / 2),
|
| 54 |
+
int((size - img.size[1]) / 2),
|
| 55 |
+
int((size - img.size[0]) / 2),
|
| 56 |
+
int((size - img.size[1]) / 2)
|
| 57 |
+
),
|
| 58 |
+
fill=0
|
| 59 |
+
)(img)),
|
| 60 |
+
transforms.ToTensor(),
|
| 61 |
+
transforms.Normalize([0.5], [0.5]),
|
| 62 |
+
]
|
| 63 |
+
)
|
| 64 |
+
cond_train_transforms = transforms.Compose(
|
| 65 |
+
[
|
| 66 |
+
transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR),
|
| 67 |
+
transforms.CenterCrop((size, size)),
|
| 68 |
+
transforms.ToTensor(),
|
| 69 |
+
transforms.Normalize([0.5], [0.5]),
|
| 70 |
+
]
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
def train_transforms(image, noise_size):
|
| 74 |
+
train_transforms_ = transforms.Compose(
|
| 75 |
+
[
|
| 76 |
+
transforms.Lambda(lambda img: img.resize((
|
| 77 |
+
multiple_16(noise_size * img.size[0] / max(img.size)),
|
| 78 |
+
multiple_16(noise_size * img.size[1] / max(img.size))
|
| 79 |
+
), resample=Image.BILINEAR)),
|
| 80 |
+
transforms.ToTensor(),
|
| 81 |
+
transforms.Normalize([0.5], [0.5]),
|
| 82 |
+
]
|
| 83 |
+
)
|
| 84 |
+
transformed_image = train_transforms_(image)
|
| 85 |
+
return transformed_image
|
| 86 |
+
|
| 87 |
+
def load_and_transform_cond_images(images):
|
| 88 |
+
transformed_images = [cond_train_transforms(image) for image in images]
|
| 89 |
+
concatenated_image = torch.cat(transformed_images, dim=1)
|
| 90 |
+
return concatenated_image
|
| 91 |
+
|
| 92 |
+
def load_and_transform_subject_images(images):
|
| 93 |
+
transformed_images = [subject_cond_train_transforms(image) for image in images]
|
| 94 |
+
concatenated_image = torch.cat(transformed_images, dim=1)
|
| 95 |
+
return concatenated_image
|
| 96 |
+
|
| 97 |
+
tokenizer_clip = tokenizer[0]
|
| 98 |
+
tokenizer_t5 = tokenizer[1]
|
| 99 |
+
|
| 100 |
+
def tokenize_prompt_clip_t5(examples):
|
| 101 |
+
captions = []
|
| 102 |
+
for caption in examples[caption_column]:
|
| 103 |
+
if isinstance(caption, str):
|
| 104 |
+
if random.random() < 0.1:
|
| 105 |
+
captions.append(" ") # 将文本设为空
|
| 106 |
+
else:
|
| 107 |
+
captions.append(caption)
|
| 108 |
+
elif isinstance(caption, list):
|
| 109 |
+
# take a random caption if there are multiple
|
| 110 |
+
if random.random() < 0.1:
|
| 111 |
+
captions.append(" ")
|
| 112 |
+
else:
|
| 113 |
+
captions.append(random.choice(caption))
|
| 114 |
+
else:
|
| 115 |
+
raise ValueError(
|
| 116 |
+
f"Caption column `{caption_column}` should contain either strings or lists of strings."
|
| 117 |
+
)
|
| 118 |
+
text_inputs = tokenizer_clip(
|
| 119 |
+
captions,
|
| 120 |
+
padding="max_length",
|
| 121 |
+
max_length=77,
|
| 122 |
+
truncation=True,
|
| 123 |
+
return_length=False,
|
| 124 |
+
return_overflowing_tokens=False,
|
| 125 |
+
return_tensors="pt",
|
| 126 |
+
)
|
| 127 |
+
text_input_ids_1 = text_inputs.input_ids
|
| 128 |
+
|
| 129 |
+
text_inputs = tokenizer_t5(
|
| 130 |
+
captions,
|
| 131 |
+
padding="max_length",
|
| 132 |
+
max_length=512,
|
| 133 |
+
truncation=True,
|
| 134 |
+
return_length=False,
|
| 135 |
+
return_overflowing_tokens=False,
|
| 136 |
+
return_tensors="pt",
|
| 137 |
+
)
|
| 138 |
+
text_input_ids_2 = text_inputs.input_ids
|
| 139 |
+
return text_input_ids_1, text_input_ids_2
|
| 140 |
+
|
| 141 |
+
def preprocess_train(examples):
|
| 142 |
+
_examples = {}
|
| 143 |
+
if args.subject_column is not None:
|
| 144 |
+
subject_images = [[load_image_safely(examples[column][i], args.cond_size) for column in subject_columns] for i in range(len(examples[target_column]))]
|
| 145 |
+
_examples["subject_pixel_values"] = [load_and_transform_subject_images(subject) for subject in subject_images]
|
| 146 |
+
if args.spatial_column is not None:
|
| 147 |
+
spatial_images = [[load_image_safely(examples[column][i], args.cond_size) for column in spatial_columns] for i in range(len(examples[target_column]))]
|
| 148 |
+
_examples["cond_pixel_values"] = [load_and_transform_cond_images(spatial) for spatial in spatial_images]
|
| 149 |
+
target_images = [load_image_safely(image_path, args.cond_size) for image_path in examples[target_column]]
|
| 150 |
+
_examples["pixel_values"] = [train_transforms(image, noise_size) for image in target_images]
|
| 151 |
+
_examples["token_ids_clip"], _examples["token_ids_t5"] = tokenize_prompt_clip_t5(examples)
|
| 152 |
+
return _examples
|
| 153 |
+
|
| 154 |
+
if accelerator is not None:
|
| 155 |
+
with accelerator.main_process_first():
|
| 156 |
+
train_dataset = dataset["train"].with_transform(preprocess_train)
|
| 157 |
+
else:
|
| 158 |
+
train_dataset = dataset["train"].with_transform(preprocess_train)
|
| 159 |
+
|
| 160 |
+
return train_dataset
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def collate_fn(examples):
|
| 164 |
+
if examples[0].get("cond_pixel_values") is not None:
|
| 165 |
+
cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
|
| 166 |
+
cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 167 |
+
else:
|
| 168 |
+
cond_pixel_values = None
|
| 169 |
+
if examples[0].get("subject_pixel_values") is not None:
|
| 170 |
+
subject_pixel_values = torch.stack([example["subject_pixel_values"] for example in examples])
|
| 171 |
+
subject_pixel_values = subject_pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 172 |
+
else:
|
| 173 |
+
subject_pixel_values = None
|
| 174 |
+
|
| 175 |
+
target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
| 176 |
+
target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 177 |
+
token_ids_clip = torch.stack([torch.tensor(example["token_ids_clip"]) for example in examples])
|
| 178 |
+
token_ids_t5 = torch.stack([torch.tensor(example["token_ids_t5"]) for example in examples])
|
| 179 |
+
|
| 180 |
+
return {
|
| 181 |
+
"cond_pixel_values": cond_pixel_values,
|
| 182 |
+
"subject_pixel_values": subject_pixel_values,
|
| 183 |
+
"pixel_values": target_pixel_values,
|
| 184 |
+
"text_ids_1": token_ids_clip,
|
| 185 |
+
"text_ids_2": token_ids_t5,
|
| 186 |
+
}
|
src/kontext_custom_pipeline.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
src/layers.py
ADDED
|
@@ -0,0 +1,673 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import math
|
| 3 |
+
from typing import Callable, List, Optional, Tuple, Union
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from diffusers.models.attention_processor import Attention
|
| 10 |
+
|
| 11 |
+
# Global variables for attention visualization
|
| 12 |
+
step = 0
|
| 13 |
+
global_timestep = 0
|
| 14 |
+
global_timestep2 = 0
|
| 15 |
+
|
| 16 |
+
def scaled_dot_product_average_attention_map(query, key, attn_mask=None, is_causal=False, scale=None) -> torch.Tensor:
|
| 17 |
+
# Efficient implementation equivalent to the following:
|
| 18 |
+
L, S = query.size(-2), key.size(-2)
|
| 19 |
+
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
|
| 20 |
+
attn_bias = torch.zeros(L, S, dtype=query.dtype)
|
| 21 |
+
if is_causal:
|
| 22 |
+
assert attn_mask is None
|
| 23 |
+
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
|
| 24 |
+
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
| 25 |
+
attn_bias.to(query.dtype)
|
| 26 |
+
|
| 27 |
+
if attn_mask is not None:
|
| 28 |
+
if attn_mask.dtype == torch.bool:
|
| 29 |
+
attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
|
| 30 |
+
else:
|
| 31 |
+
attn_bias += attn_mask
|
| 32 |
+
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
| 33 |
+
attn_weight += attn_bias.to(attn_weight.device)
|
| 34 |
+
attn_weight = attn_weight.mean(dim=(1, 2))
|
| 35 |
+
return attn_weight
|
| 36 |
+
|
| 37 |
+
class LoRALinearLayer(nn.Module):
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
in_features: int,
|
| 41 |
+
out_features: int,
|
| 42 |
+
rank: int = 4,
|
| 43 |
+
network_alpha: Optional[float] = None,
|
| 44 |
+
device: Optional[Union[torch.device, str]] = None,
|
| 45 |
+
dtype: Optional[torch.dtype] = None,
|
| 46 |
+
):
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
|
| 49 |
+
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
|
| 50 |
+
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
| 51 |
+
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
| 52 |
+
self.network_alpha = network_alpha
|
| 53 |
+
self.rank = rank
|
| 54 |
+
self.out_features = out_features
|
| 55 |
+
self.in_features = in_features
|
| 56 |
+
|
| 57 |
+
nn.init.normal_(self.down.weight, std=1 / rank)
|
| 58 |
+
nn.init.zeros_(self.up.weight)
|
| 59 |
+
|
| 60 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 61 |
+
orig_dtype = hidden_states.dtype
|
| 62 |
+
dtype = self.down.weight.dtype
|
| 63 |
+
|
| 64 |
+
down_hidden_states = self.down(hidden_states.to(dtype))
|
| 65 |
+
up_hidden_states = self.up(down_hidden_states)
|
| 66 |
+
|
| 67 |
+
if self.network_alpha is not None:
|
| 68 |
+
up_hidden_states *= self.network_alpha / self.rank
|
| 69 |
+
|
| 70 |
+
return up_hidden_states.to(orig_dtype)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class MultiSingleStreamBlockLoraProcessor(nn.Module):
|
| 74 |
+
def __init__(self, in_features: int, out_features: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, n_loras=1):
|
| 75 |
+
super().__init__()
|
| 76 |
+
# Initialize a list to store the LoRA layers
|
| 77 |
+
self.n_loras = n_loras
|
| 78 |
+
self.q_loras = nn.ModuleList([
|
| 79 |
+
LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
|
| 80 |
+
for i in range(n_loras)
|
| 81 |
+
])
|
| 82 |
+
self.k_loras = nn.ModuleList([
|
| 83 |
+
LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
|
| 84 |
+
for i in range(n_loras)
|
| 85 |
+
])
|
| 86 |
+
self.v_loras = nn.ModuleList([
|
| 87 |
+
LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
|
| 88 |
+
for i in range(n_loras)
|
| 89 |
+
])
|
| 90 |
+
self.lora_weights = lora_weights
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def __call__(self,
|
| 94 |
+
attn: Attention,
|
| 95 |
+
hidden_states: torch.FloatTensor,
|
| 96 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
| 97 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 98 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 99 |
+
use_cond = False,
|
| 100 |
+
) -> torch.FloatTensor:
|
| 101 |
+
|
| 102 |
+
batch_size, seq_len, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 103 |
+
query = attn.to_q(hidden_states)
|
| 104 |
+
key = attn.to_k(hidden_states)
|
| 105 |
+
value = attn.to_v(hidden_states)
|
| 106 |
+
|
| 107 |
+
for i in range(self.n_loras):
|
| 108 |
+
query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
|
| 109 |
+
key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
|
| 110 |
+
value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
|
| 111 |
+
|
| 112 |
+
inner_dim = key.shape[-1]
|
| 113 |
+
head_dim = inner_dim // attn.heads
|
| 114 |
+
|
| 115 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 116 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 117 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 118 |
+
|
| 119 |
+
if attn.norm_q is not None:
|
| 120 |
+
query = attn.norm_q(query)
|
| 121 |
+
if attn.norm_k is not None:
|
| 122 |
+
key = attn.norm_k(key)
|
| 123 |
+
|
| 124 |
+
if image_rotary_emb is not None:
|
| 125 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
| 126 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
| 127 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
| 128 |
+
|
| 129 |
+
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
|
| 130 |
+
|
| 131 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 132 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 133 |
+
|
| 134 |
+
return hidden_states
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class MultiDoubleStreamBlockLoraProcessor(nn.Module):
|
| 138 |
+
def __init__(self, in_features: int, out_features: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, n_loras=1):
|
| 139 |
+
super().__init__()
|
| 140 |
+
|
| 141 |
+
# Initialize a list to store the LoRA layers
|
| 142 |
+
self.n_loras = n_loras
|
| 143 |
+
self.q_loras = nn.ModuleList([
|
| 144 |
+
LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
|
| 145 |
+
for i in range(n_loras)
|
| 146 |
+
])
|
| 147 |
+
self.k_loras = nn.ModuleList([
|
| 148 |
+
LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
|
| 149 |
+
for i in range(n_loras)
|
| 150 |
+
])
|
| 151 |
+
self.v_loras = nn.ModuleList([
|
| 152 |
+
LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
|
| 153 |
+
for i in range(n_loras)
|
| 154 |
+
])
|
| 155 |
+
self.proj_loras = nn.ModuleList([
|
| 156 |
+
LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
|
| 157 |
+
for i in range(n_loras)
|
| 158 |
+
])
|
| 159 |
+
self.lora_weights = lora_weights
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def __call__(self,
|
| 163 |
+
attn: Attention,
|
| 164 |
+
hidden_states: torch.FloatTensor,
|
| 165 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
| 166 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 167 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 168 |
+
use_cond=False,
|
| 169 |
+
) -> torch.FloatTensor:
|
| 170 |
+
|
| 171 |
+
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 172 |
+
|
| 173 |
+
# `context` projections.
|
| 174 |
+
inner_dim = 3072
|
| 175 |
+
head_dim = inner_dim // attn.heads
|
| 176 |
+
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
| 177 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
| 178 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
| 179 |
+
|
| 180 |
+
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
| 181 |
+
batch_size, -1, attn.heads, head_dim
|
| 182 |
+
).transpose(1, 2)
|
| 183 |
+
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
| 184 |
+
batch_size, -1, attn.heads, head_dim
|
| 185 |
+
).transpose(1, 2)
|
| 186 |
+
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
| 187 |
+
batch_size, -1, attn.heads, head_dim
|
| 188 |
+
).transpose(1, 2)
|
| 189 |
+
|
| 190 |
+
if attn.norm_added_q is not None:
|
| 191 |
+
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
| 192 |
+
if attn.norm_added_k is not None:
|
| 193 |
+
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
| 194 |
+
|
| 195 |
+
query = attn.to_q(hidden_states)
|
| 196 |
+
key = attn.to_k(hidden_states)
|
| 197 |
+
value = attn.to_v(hidden_states)
|
| 198 |
+
for i in range(self.n_loras):
|
| 199 |
+
query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
|
| 200 |
+
key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
|
| 201 |
+
value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
|
| 202 |
+
|
| 203 |
+
inner_dim = key.shape[-1]
|
| 204 |
+
head_dim = inner_dim // attn.heads
|
| 205 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 206 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 207 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 208 |
+
|
| 209 |
+
if attn.norm_q is not None:
|
| 210 |
+
query = attn.norm_q(query)
|
| 211 |
+
if attn.norm_k is not None:
|
| 212 |
+
key = attn.norm_k(key)
|
| 213 |
+
|
| 214 |
+
# attention
|
| 215 |
+
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
| 216 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
| 217 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
| 218 |
+
|
| 219 |
+
if image_rotary_emb is not None:
|
| 220 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
| 221 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
| 222 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
| 223 |
+
|
| 224 |
+
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
|
| 225 |
+
|
| 226 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 227 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 228 |
+
|
| 229 |
+
encoder_hidden_states, hidden_states = (
|
| 230 |
+
hidden_states[:, : encoder_hidden_states.shape[1]],
|
| 231 |
+
hidden_states[:, encoder_hidden_states.shape[1] :],
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
# Linear projection (with LoRA weight applied to each proj layer)
|
| 235 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 236 |
+
for i in range(self.n_loras):
|
| 237 |
+
hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states)
|
| 238 |
+
# dropout
|
| 239 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 240 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
| 241 |
+
|
| 242 |
+
return (hidden_states, encoder_hidden_states)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class MultiSingleStreamBlockLoraProcessorWithLoss(nn.Module):
|
| 246 |
+
def __init__(self, in_features: int, out_features: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, n_loras=1):
|
| 247 |
+
super().__init__()
|
| 248 |
+
# Initialize a list to store the LoRA layers
|
| 249 |
+
self.n_loras = n_loras
|
| 250 |
+
self.q_loras = nn.ModuleList([
|
| 251 |
+
LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
|
| 252 |
+
for i in range(n_loras)
|
| 253 |
+
])
|
| 254 |
+
self.k_loras = nn.ModuleList([
|
| 255 |
+
LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
|
| 256 |
+
for i in range(n_loras)
|
| 257 |
+
])
|
| 258 |
+
self.v_loras = nn.ModuleList([
|
| 259 |
+
LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
|
| 260 |
+
for i in range(n_loras)
|
| 261 |
+
])
|
| 262 |
+
self.lora_weights = lora_weights
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def __call__(self,
|
| 266 |
+
attn: Attention,
|
| 267 |
+
hidden_states: torch.FloatTensor,
|
| 268 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
| 269 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 270 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 271 |
+
use_cond = False,
|
| 272 |
+
) -> torch.FloatTensor:
|
| 273 |
+
|
| 274 |
+
batch_size, seq_len, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 275 |
+
query = attn.to_q(hidden_states)
|
| 276 |
+
key = attn.to_k(hidden_states)
|
| 277 |
+
value = attn.to_v(hidden_states)
|
| 278 |
+
encoder_hidden_length = 512
|
| 279 |
+
|
| 280 |
+
length = (hidden_states.shape[-2] - encoder_hidden_length) // 3
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
for i in range(self.n_loras):
|
| 284 |
+
query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
|
| 285 |
+
key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
|
| 286 |
+
value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
|
| 287 |
+
|
| 288 |
+
inner_dim = key.shape[-1]
|
| 289 |
+
head_dim = inner_dim // attn.heads
|
| 290 |
+
|
| 291 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 292 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 293 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 294 |
+
|
| 295 |
+
if attn.norm_q is not None:
|
| 296 |
+
query = attn.norm_q(query)
|
| 297 |
+
if attn.norm_k is not None:
|
| 298 |
+
key = attn.norm_k(key)
|
| 299 |
+
|
| 300 |
+
if image_rotary_emb is not None:
|
| 301 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
| 302 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
| 303 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
| 304 |
+
|
| 305 |
+
# query_cond_a = query[:, :, encoder_hidden_length+length : encoder_hidden_length+2*length, :]
|
| 306 |
+
# query_cond_b = query[:, :, encoder_hidden_length+2*length : encoder_hidden_length+3*length, :]
|
| 307 |
+
|
| 308 |
+
# key_noise = key[:, :, encoder_hidden_length:encoder_hidden_length+length, :]
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
# attention_probs_query_a_key_noise = scaled_dot_product_average_attention_map(query_cond_a, key_noise, attn_mask=attention_mask, is_causal=False)
|
| 312 |
+
# attention_probs_query_b_key_noise = scaled_dot_product_average_attention_map(query_cond_b, key_noise, attn_mask=attention_mask, is_causal=False)
|
| 313 |
+
|
| 314 |
+
# attn.attention_probs_query_a_key_noise = attention_probs_query_a_key_noise
|
| 315 |
+
# attn.attention_probs_query_b_key_noise = attention_probs_query_b_key_noise
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
|
| 319 |
+
|
| 320 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 321 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 322 |
+
|
| 323 |
+
return hidden_states
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
class MultiDoubleStreamBlockLoraProcessorWithLoss(nn.Module):
|
| 327 |
+
def __init__(self, in_features: int, out_features: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, n_loras=1):
|
| 328 |
+
super().__init__()
|
| 329 |
+
|
| 330 |
+
# Initialize a list to store the LoRA layers
|
| 331 |
+
self.n_loras = n_loras
|
| 332 |
+
self.q_loras = nn.ModuleList([
|
| 333 |
+
LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
|
| 334 |
+
for i in range(n_loras)
|
| 335 |
+
])
|
| 336 |
+
self.k_loras = nn.ModuleList([
|
| 337 |
+
LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
|
| 338 |
+
for i in range(n_loras)
|
| 339 |
+
])
|
| 340 |
+
self.v_loras = nn.ModuleList([
|
| 341 |
+
LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
|
| 342 |
+
for i in range(n_loras)
|
| 343 |
+
])
|
| 344 |
+
self.proj_loras = nn.ModuleList([
|
| 345 |
+
LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
|
| 346 |
+
for i in range(n_loras)
|
| 347 |
+
])
|
| 348 |
+
self.lora_weights = lora_weights
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def __call__(self,
|
| 352 |
+
attn: Attention,
|
| 353 |
+
hidden_states: torch.FloatTensor,
|
| 354 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
| 355 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 356 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 357 |
+
use_cond=False,
|
| 358 |
+
) -> torch.FloatTensor:
|
| 359 |
+
|
| 360 |
+
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 361 |
+
|
| 362 |
+
# `context` projections.
|
| 363 |
+
inner_dim = 3072
|
| 364 |
+
head_dim = inner_dim // attn.heads
|
| 365 |
+
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
| 366 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
| 367 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
| 368 |
+
|
| 369 |
+
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
| 370 |
+
batch_size, -1, attn.heads, head_dim
|
| 371 |
+
).transpose(1, 2)
|
| 372 |
+
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
| 373 |
+
batch_size, -1, attn.heads, head_dim
|
| 374 |
+
).transpose(1, 2)
|
| 375 |
+
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
| 376 |
+
batch_size, -1, attn.heads, head_dim
|
| 377 |
+
).transpose(1, 2)
|
| 378 |
+
|
| 379 |
+
if attn.norm_added_q is not None:
|
| 380 |
+
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
| 381 |
+
if attn.norm_added_k is not None:
|
| 382 |
+
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
| 383 |
+
|
| 384 |
+
query = attn.to_q(hidden_states)
|
| 385 |
+
key = attn.to_k(hidden_states)
|
| 386 |
+
value = attn.to_v(hidden_states)
|
| 387 |
+
length = hidden_states.shape[-2] // 3
|
| 388 |
+
|
| 389 |
+
for i in range(self.n_loras):
|
| 390 |
+
query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
|
| 391 |
+
key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
|
| 392 |
+
value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
|
| 393 |
+
|
| 394 |
+
inner_dim = key.shape[-1]
|
| 395 |
+
head_dim = inner_dim // attn.heads
|
| 396 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 397 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 398 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 399 |
+
|
| 400 |
+
if attn.norm_q is not None:
|
| 401 |
+
query = attn.norm_q(query)
|
| 402 |
+
if attn.norm_k is not None:
|
| 403 |
+
key = attn.norm_k(key)
|
| 404 |
+
|
| 405 |
+
# attention
|
| 406 |
+
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
| 407 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
| 408 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
| 409 |
+
|
| 410 |
+
if image_rotary_emb is not None:
|
| 411 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
| 412 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
| 413 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
| 414 |
+
encoder_hidden_length = 512
|
| 415 |
+
|
| 416 |
+
query_cond_a = query[:, :, encoder_hidden_length+length : encoder_hidden_length+2*length, :]
|
| 417 |
+
query_cond_b = query[:, :, encoder_hidden_length+2*length : encoder_hidden_length+3*length, :]
|
| 418 |
+
|
| 419 |
+
key_noise = key[:, :, encoder_hidden_length:encoder_hidden_length+length, :]
|
| 420 |
+
|
| 421 |
+
attention_probs_query_a_key_noise = scaled_dot_product_average_attention_map(query_cond_a, key_noise, attn_mask=attention_mask, is_causal=False)
|
| 422 |
+
attention_probs_query_b_key_noise = scaled_dot_product_average_attention_map(query_cond_b, key_noise, attn_mask=attention_mask, is_causal=False)
|
| 423 |
+
|
| 424 |
+
attn.attention_probs_query_a_key_noise = attention_probs_query_a_key_noise
|
| 425 |
+
attn.attention_probs_query_b_key_noise = attention_probs_query_b_key_noise
|
| 426 |
+
|
| 427 |
+
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
|
| 428 |
+
|
| 429 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 430 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 431 |
+
|
| 432 |
+
encoder_hidden_states, hidden_states = (
|
| 433 |
+
hidden_states[:, : encoder_hidden_states.shape[1]],
|
| 434 |
+
hidden_states[:, encoder_hidden_states.shape[1] :],
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
# Linear projection (with LoRA weight applied to each proj layer)
|
| 438 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 439 |
+
for i in range(self.n_loras):
|
| 440 |
+
hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states)
|
| 441 |
+
# dropout
|
| 442 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 443 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
| 444 |
+
|
| 445 |
+
return (hidden_states, encoder_hidden_states)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
class MultiDoubleStreamBlockLoraProcessor_visual(nn.Module):
|
| 450 |
+
def __init__(self, in_features: int, out_features: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, n_loras=1):
|
| 451 |
+
super().__init__()
|
| 452 |
+
|
| 453 |
+
# Initialize a list to store the LoRA layers
|
| 454 |
+
self.n_loras = n_loras
|
| 455 |
+
self.q_loras = nn.ModuleList([
|
| 456 |
+
LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
|
| 457 |
+
for i in range(n_loras)
|
| 458 |
+
])
|
| 459 |
+
self.k_loras = nn.ModuleList([
|
| 460 |
+
LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
|
| 461 |
+
for i in range(n_loras)
|
| 462 |
+
])
|
| 463 |
+
self.v_loras = nn.ModuleList([
|
| 464 |
+
LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
|
| 465 |
+
for i in range(n_loras)
|
| 466 |
+
])
|
| 467 |
+
self.proj_loras = nn.ModuleList([
|
| 468 |
+
LoRALinearLayer(in_features, out_features, ranks[i],network_alphas[i], device=device, dtype=dtype)
|
| 469 |
+
for i in range(n_loras)
|
| 470 |
+
])
|
| 471 |
+
self.lora_weights = lora_weights
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def __call__(self,
|
| 475 |
+
attn: Attention,
|
| 476 |
+
hidden_states: torch.FloatTensor,
|
| 477 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
| 478 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 479 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 480 |
+
use_cond=False,
|
| 481 |
+
) -> torch.FloatTensor:
|
| 482 |
+
|
| 483 |
+
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 484 |
+
|
| 485 |
+
# `context` projections.
|
| 486 |
+
inner_dim = 3072
|
| 487 |
+
head_dim = inner_dim // attn.heads
|
| 488 |
+
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
| 489 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
| 490 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
| 491 |
+
|
| 492 |
+
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
| 493 |
+
batch_size, -1, attn.heads, head_dim
|
| 494 |
+
).transpose(1, 2)
|
| 495 |
+
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
| 496 |
+
batch_size, -1, attn.heads, head_dim
|
| 497 |
+
).transpose(1, 2)
|
| 498 |
+
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
| 499 |
+
batch_size, -1, attn.heads, head_dim
|
| 500 |
+
).transpose(1, 2)
|
| 501 |
+
|
| 502 |
+
if attn.norm_added_q is not None:
|
| 503 |
+
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
| 504 |
+
if attn.norm_added_k is not None:
|
| 505 |
+
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
| 506 |
+
|
| 507 |
+
query = attn.to_q(hidden_states)
|
| 508 |
+
key = attn.to_k(hidden_states)
|
| 509 |
+
value = attn.to_v(hidden_states)
|
| 510 |
+
length = hidden_states.shape[-2] // 3
|
| 511 |
+
|
| 512 |
+
for i in range(self.n_loras):
|
| 513 |
+
query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
|
| 514 |
+
key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
|
| 515 |
+
value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
|
| 516 |
+
|
| 517 |
+
inner_dim = key.shape[-1]
|
| 518 |
+
head_dim = inner_dim // attn.heads
|
| 519 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 520 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 521 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 522 |
+
|
| 523 |
+
if attn.norm_q is not None:
|
| 524 |
+
query = attn.norm_q(query)
|
| 525 |
+
if attn.norm_k is not None:
|
| 526 |
+
key = attn.norm_k(key)
|
| 527 |
+
|
| 528 |
+
# attention
|
| 529 |
+
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
| 530 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
| 531 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
| 532 |
+
|
| 533 |
+
if image_rotary_emb is not None:
|
| 534 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
| 535 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
| 536 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
| 537 |
+
encoder_hidden_length = 512
|
| 538 |
+
|
| 539 |
+
query_cond_a = query[:, :, encoder_hidden_length+length : encoder_hidden_length+2*length, :]
|
| 540 |
+
query_cond_b = query[:, :, encoder_hidden_length+2*length : encoder_hidden_length+3*length, :]
|
| 541 |
+
|
| 542 |
+
key_noise = key[:, :, encoder_hidden_length:encoder_hidden_length+length, :]
|
| 543 |
+
|
| 544 |
+
attention_probs_query_a_key_noise = scaled_dot_product_average_attention_map(query_cond_a, key_noise, attn_mask=attention_mask, is_causal=False)
|
| 545 |
+
attention_probs_query_b_key_noise = scaled_dot_product_average_attention_map(query_cond_b, key_noise, attn_mask=attention_mask, is_causal=False)
|
| 546 |
+
|
| 547 |
+
if not hasattr(attn, 'attention_probs_query_a_key_noise'):
|
| 548 |
+
attn.attention_probs_query_a_key_noise = []
|
| 549 |
+
if not hasattr(attn, 'attention_probs_query_b_key_noise'):
|
| 550 |
+
attn.attention_probs_query_b_key_noise = []
|
| 551 |
+
|
| 552 |
+
global global_timestep
|
| 553 |
+
|
| 554 |
+
attn.attention_probs_query_a_key_noise.append((global_timestep//19, attention_probs_query_a_key_noise))
|
| 555 |
+
attn.attention_probs_query_b_key_noise.append((global_timestep//19, attention_probs_query_b_key_noise))
|
| 556 |
+
|
| 557 |
+
print(f"Global Timestep: {global_timestep//19}")
|
| 558 |
+
|
| 559 |
+
global_timestep += 1
|
| 560 |
+
|
| 561 |
+
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
|
| 562 |
+
|
| 563 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 564 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 565 |
+
|
| 566 |
+
encoder_hidden_states, hidden_states = (
|
| 567 |
+
hidden_states[:, : encoder_hidden_states.shape[1]],
|
| 568 |
+
hidden_states[:, encoder_hidden_states.shape[1] :],
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
# Linear projection (with LoRA weight applied to each proj layer)
|
| 572 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 573 |
+
for i in range(self.n_loras):
|
| 574 |
+
hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states)
|
| 575 |
+
# dropout
|
| 576 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 577 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
| 578 |
+
|
| 579 |
+
return (hidden_states, encoder_hidden_states)
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
class MultiSingleStreamBlockLoraProcessor_visual(nn.Module):
|
| 584 |
+
def __init__(self, in_features: int, out_features: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, n_loras=1):
|
| 585 |
+
super().__init__()
|
| 586 |
+
# Initialize a list to store the LoRA layers
|
| 587 |
+
self.n_loras = n_loras
|
| 588 |
+
self.q_loras = nn.ModuleList([
|
| 589 |
+
LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
|
| 590 |
+
for i in range(n_loras)
|
| 591 |
+
])
|
| 592 |
+
self.k_loras = nn.ModuleList([
|
| 593 |
+
LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
|
| 594 |
+
for i in range(n_loras)
|
| 595 |
+
])
|
| 596 |
+
self.v_loras = nn.ModuleList([
|
| 597 |
+
LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
|
| 598 |
+
for i in range(n_loras)
|
| 599 |
+
])
|
| 600 |
+
self.lora_weights = lora_weights
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
def __call__(self,
|
| 604 |
+
attn: Attention,
|
| 605 |
+
hidden_states: torch.FloatTensor,
|
| 606 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
| 607 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 608 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 609 |
+
use_cond = False,
|
| 610 |
+
) -> torch.FloatTensor:
|
| 611 |
+
|
| 612 |
+
batch_size, seq_len, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 613 |
+
query = attn.to_q(hidden_states)
|
| 614 |
+
key = attn.to_k(hidden_states)
|
| 615 |
+
value = attn.to_v(hidden_states)
|
| 616 |
+
encoder_hidden_length = 512
|
| 617 |
+
|
| 618 |
+
length = (hidden_states.shape[-2] - encoder_hidden_length) // 3
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
for i in range(self.n_loras):
|
| 622 |
+
query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
|
| 623 |
+
key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
|
| 624 |
+
value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
|
| 625 |
+
|
| 626 |
+
inner_dim = key.shape[-1]
|
| 627 |
+
head_dim = inner_dim // attn.heads
|
| 628 |
+
|
| 629 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 630 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 631 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 632 |
+
|
| 633 |
+
if attn.norm_q is not None:
|
| 634 |
+
query = attn.norm_q(query)
|
| 635 |
+
if attn.norm_k is not None:
|
| 636 |
+
key = attn.norm_k(key)
|
| 637 |
+
|
| 638 |
+
if image_rotary_emb is not None:
|
| 639 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
| 640 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
| 641 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
| 642 |
+
|
| 643 |
+
if not hasattr(attn, 'attention_probs_query_a_key_noise2'):
|
| 644 |
+
attn.attention_probs_query_a_key_noise2 = []
|
| 645 |
+
if not hasattr(attn, 'attention_probs_query_b_key_noise2'):
|
| 646 |
+
attn.attention_probs_query_b_key_noise2 = []
|
| 647 |
+
|
| 648 |
+
query_cond_a = query[:, :, encoder_hidden_length+length : encoder_hidden_length+2*length, :]
|
| 649 |
+
query_cond_b = query[:, :, encoder_hidden_length+2*length : encoder_hidden_length+3*length, :]
|
| 650 |
+
|
| 651 |
+
key_noise = key[:, :, encoder_hidden_length:encoder_hidden_length+length, :]
|
| 652 |
+
|
| 653 |
+
attention_probs_query_a_key_noise2 = scaled_dot_product_average_attention_map(query_cond_a, key_noise, attn_mask=attention_mask, is_causal=False)
|
| 654 |
+
attention_probs_query_b_key_noise2 = scaled_dot_product_average_attention_map(query_cond_b, key_noise, attn_mask=attention_mask, is_causal=False)
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
global global_timestep2
|
| 658 |
+
|
| 659 |
+
attn.attention_probs_query_a_key_noise2.append((global_timestep//38, attention_probs_query_a_key_noise2))
|
| 660 |
+
attn.attention_probs_query_b_key_noise2.append((global_timestep//38, attention_probs_query_b_key_noise2))
|
| 661 |
+
|
| 662 |
+
print(f"Global Timestep2: {global_timestep2//38}")
|
| 663 |
+
|
| 664 |
+
global_timestep2 += 1
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
|
| 669 |
+
|
| 670 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 671 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 672 |
+
|
| 673 |
+
return hidden_states
|
src/lora_helper.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from diffusers.models.attention_processor import FluxAttnProcessor2_0
|
| 2 |
+
from safetensors import safe_open
|
| 3 |
+
import re
|
| 4 |
+
import torch
|
| 5 |
+
from .layers import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor, MultiDoubleStreamBlockLoraProcessor_visual, MultiDoubleStreamBlockLoraProcessorWithLoss, MultiSingleStreamBlockLoraProcessor_visual
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
device = "cuda:0"
|
| 10 |
+
|
| 11 |
+
def load_safetensors(path):
|
| 12 |
+
tensors = {}
|
| 13 |
+
with safe_open(path, framework="pt", device="cpu") as f:
|
| 14 |
+
for key in f.keys():
|
| 15 |
+
tensors[key] = f.get_tensor(key)
|
| 16 |
+
return tensors
|
| 17 |
+
|
| 18 |
+
def get_lora_rank(checkpoint):
|
| 19 |
+
for k in checkpoint.keys():
|
| 20 |
+
if k.endswith(".down.weight"):
|
| 21 |
+
return checkpoint[k].shape[0]
|
| 22 |
+
|
| 23 |
+
def load_checkpoint(local_path):
|
| 24 |
+
if local_path is not None:
|
| 25 |
+
if '.safetensors' in local_path:
|
| 26 |
+
print(f"Loading .safetensors checkpoint from {local_path}")
|
| 27 |
+
checkpoint = load_safetensors(local_path)
|
| 28 |
+
else:
|
| 29 |
+
print(f"Loading checkpoint from {local_path}")
|
| 30 |
+
checkpoint = torch.load(local_path, map_location='cpu')
|
| 31 |
+
return checkpoint
|
| 32 |
+
|
| 33 |
+
def update_model_with_lora(checkpoint, lora_weights, transformer):
|
| 34 |
+
number = len(lora_weights)
|
| 35 |
+
ranks = [get_lora_rank(checkpoint) for _ in range(number)]
|
| 36 |
+
lora_attn_procs = {}
|
| 37 |
+
double_blocks_idx = list(range(19))
|
| 38 |
+
single_blocks_idx = list(range(38))
|
| 39 |
+
for name, attn_processor in transformer.attn_processors.items():
|
| 40 |
+
match = re.search(r'\.(\d+)\.', name)
|
| 41 |
+
if match:
|
| 42 |
+
layer_index = int(match.group(1))
|
| 43 |
+
|
| 44 |
+
if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
|
| 45 |
+
|
| 46 |
+
lora_state_dicts = {}
|
| 47 |
+
for key, value in checkpoint.items():
|
| 48 |
+
# Match based on the layer index in the key (assuming the key contains layer index)
|
| 49 |
+
if re.search(r'\.(\d+)\.', key):
|
| 50 |
+
checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
|
| 51 |
+
if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
|
| 52 |
+
lora_state_dicts[key] = value
|
| 53 |
+
|
| 54 |
+
lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
|
| 55 |
+
in_features=3072, out_features=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, n_loras=number
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
# Load the weights from the checkpoint dictionary into the corresponding layers
|
| 59 |
+
for n in range(number):
|
| 60 |
+
lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
|
| 61 |
+
lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
|
| 62 |
+
lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
|
| 63 |
+
lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
|
| 64 |
+
lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
|
| 65 |
+
lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
|
| 66 |
+
lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None)
|
| 67 |
+
lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None)
|
| 68 |
+
lora_attn_procs[name].to(device)
|
| 69 |
+
|
| 70 |
+
elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
|
| 71 |
+
|
| 72 |
+
lora_state_dicts = {}
|
| 73 |
+
for key, value in checkpoint.items():
|
| 74 |
+
# Match based on the layer index in the key (assuming the key contains layer index)
|
| 75 |
+
if re.search(r'\.(\d+)\.', key):
|
| 76 |
+
checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
|
| 77 |
+
if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
|
| 78 |
+
lora_state_dicts[key] = value
|
| 79 |
+
|
| 80 |
+
lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
|
| 81 |
+
in_features=3072, out_features=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, n_loras=number
|
| 82 |
+
)
|
| 83 |
+
# Load the weights from the checkpoint dictionary into the corresponding layers
|
| 84 |
+
for n in range(number):
|
| 85 |
+
lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
|
| 86 |
+
lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
|
| 87 |
+
lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
|
| 88 |
+
lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
|
| 89 |
+
lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
|
| 90 |
+
lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
|
| 91 |
+
lora_attn_procs[name].to(device)
|
| 92 |
+
else:
|
| 93 |
+
lora_attn_procs[name] = FluxAttnProcessor2_0()
|
| 94 |
+
|
| 95 |
+
transformer.set_attn_processor(lora_attn_procs)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size):
|
| 99 |
+
ck_number = len(checkpoints)
|
| 100 |
+
cond_lora_number = [len(ls) for ls in lora_weights]
|
| 101 |
+
cond_number = sum(cond_lora_number)
|
| 102 |
+
ranks = [get_lora_rank(checkpoint) for checkpoint in checkpoints]
|
| 103 |
+
multi_lora_weight = []
|
| 104 |
+
for ls in lora_weights:
|
| 105 |
+
for n in ls:
|
| 106 |
+
multi_lora_weight.append(n)
|
| 107 |
+
|
| 108 |
+
lora_attn_procs = {}
|
| 109 |
+
double_blocks_idx = list(range(19))
|
| 110 |
+
single_blocks_idx = list(range(38))
|
| 111 |
+
for name, attn_processor in transformer.attn_processors.items():
|
| 112 |
+
match = re.search(r'\.(\d+)\.', name)
|
| 113 |
+
if match:
|
| 114 |
+
layer_index = int(match.group(1))
|
| 115 |
+
|
| 116 |
+
if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
|
| 117 |
+
lora_state_dicts = [{} for _ in range(ck_number)]
|
| 118 |
+
for idx, checkpoint in enumerate(checkpoints):
|
| 119 |
+
for key, value in checkpoint.items():
|
| 120 |
+
# Match based on the layer index in the key (assuming the key contains layer index)
|
| 121 |
+
if re.search(r'\.(\d+)\.', key):
|
| 122 |
+
checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
|
| 123 |
+
if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
|
| 124 |
+
lora_state_dicts[idx][key] = value
|
| 125 |
+
|
| 126 |
+
lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
|
| 127 |
+
dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=cond_number
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
# Load the weights from the checkpoint dictionary into the corresponding layers
|
| 131 |
+
num = 0
|
| 132 |
+
for idx in range(ck_number):
|
| 133 |
+
for n in range(cond_lora_number[idx]):
|
| 134 |
+
lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None)
|
| 135 |
+
lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None)
|
| 136 |
+
lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None)
|
| 137 |
+
lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None)
|
| 138 |
+
lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None)
|
| 139 |
+
lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None)
|
| 140 |
+
lora_attn_procs[name].proj_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.down.weight', None)
|
| 141 |
+
lora_attn_procs[name].proj_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.up.weight', None)
|
| 142 |
+
lora_attn_procs[name].to(device)
|
| 143 |
+
num += 1
|
| 144 |
+
|
| 145 |
+
elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
|
| 146 |
+
|
| 147 |
+
lora_state_dicts = [{} for _ in range(ck_number)]
|
| 148 |
+
for idx, checkpoint in enumerate(checkpoints):
|
| 149 |
+
for key, value in checkpoint.items():
|
| 150 |
+
# Match based on the layer index in the key (assuming the key contains layer index)
|
| 151 |
+
if re.search(r'\.(\d+)\.', key):
|
| 152 |
+
checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
|
| 153 |
+
if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
|
| 154 |
+
lora_state_dicts[idx][key] = value
|
| 155 |
+
|
| 156 |
+
lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
|
| 157 |
+
dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=cond_number
|
| 158 |
+
)
|
| 159 |
+
# Load the weights from the checkpoint dictionary into the corresponding layers
|
| 160 |
+
num = 0
|
| 161 |
+
for idx in range(ck_number):
|
| 162 |
+
for n in range(cond_lora_number[idx]):
|
| 163 |
+
lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None)
|
| 164 |
+
lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None)
|
| 165 |
+
lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None)
|
| 166 |
+
lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None)
|
| 167 |
+
lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None)
|
| 168 |
+
lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None)
|
| 169 |
+
lora_attn_procs[name].to(device)
|
| 170 |
+
num += 1
|
| 171 |
+
|
| 172 |
+
else:
|
| 173 |
+
lora_attn_procs[name] = FluxAttnProcessor2_0()
|
| 174 |
+
|
| 175 |
+
transformer.set_attn_processor(lora_attn_procs)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def set_single_lora(transformer, local_path, lora_weights=[]):
|
| 179 |
+
checkpoint = load_checkpoint(local_path)
|
| 180 |
+
update_model_with_lora(checkpoint, lora_weights, transformer)
|
| 181 |
+
|
| 182 |
+
def set_single_lora_visual(transformer, local_path, lora_weights=[]):
|
| 183 |
+
checkpoint = load_checkpoint(local_path)
|
| 184 |
+
update_model_with_lora_with_visual(checkpoint, lora_weights, transformer)
|
| 185 |
+
|
| 186 |
+
def set_multi_lora(transformer, local_paths, lora_weights=[[]], cond_size=512):
|
| 187 |
+
checkpoints = [load_checkpoint(local_path) for local_path in local_paths]
|
| 188 |
+
update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size)
|
| 189 |
+
|
| 190 |
+
def unset_lora(transformer):
|
| 191 |
+
lora_attn_procs = {}
|
| 192 |
+
for name, attn_processor in transformer.attn_processors.items():
|
| 193 |
+
lora_attn_procs[name] = FluxAttnProcessor2_0()
|
| 194 |
+
transformer.set_attn_processor(lora_attn_procs)
|
| 195 |
+
|
| 196 |
+
def update_model_with_lora_with_visual(checkpoint, lora_weights, transformer):
|
| 197 |
+
number = len(lora_weights)
|
| 198 |
+
ranks = [get_lora_rank(checkpoint) for _ in range(number)]
|
| 199 |
+
lora_attn_procs = {}
|
| 200 |
+
double_blocks_idx = list(range(19))
|
| 201 |
+
single_blocks_idx = list(range(38))
|
| 202 |
+
for name, attn_processor in transformer.attn_processors.items():
|
| 203 |
+
match = re.search(r'\.(\d+)\.', name)
|
| 204 |
+
if match:
|
| 205 |
+
layer_index = int(match.group(1))
|
| 206 |
+
|
| 207 |
+
if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
|
| 208 |
+
|
| 209 |
+
lora_state_dicts = {}
|
| 210 |
+
for key, value in checkpoint.items():
|
| 211 |
+
# Match based on the layer index in the key (assuming the key contains layer index)
|
| 212 |
+
if re.search(r'\.(\d+)\.', key):
|
| 213 |
+
checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
|
| 214 |
+
if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
|
| 215 |
+
lora_state_dicts[key] = value
|
| 216 |
+
|
| 217 |
+
lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor_visual(
|
| 218 |
+
in_features=3072, out_features=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, n_loras=number
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
# Load the weights from the checkpoint dictionary into the corresponding layers
|
| 222 |
+
# for n in range(number):
|
| 223 |
+
# lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
|
| 224 |
+
# lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
|
| 225 |
+
# lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
|
| 226 |
+
# lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
|
| 227 |
+
# lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
|
| 228 |
+
# lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
|
| 229 |
+
# lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None)
|
| 230 |
+
# lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None)
|
| 231 |
+
# lora_attn_procs[name].to(device)
|
| 232 |
+
|
| 233 |
+
elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
|
| 234 |
+
|
| 235 |
+
lora_state_dicts = {}
|
| 236 |
+
for key, value in checkpoint.items():
|
| 237 |
+
# Match based on the layer index in the key (assuming the key contains layer index)
|
| 238 |
+
if re.search(r'\.(\d+)\.', key):
|
| 239 |
+
checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
|
| 240 |
+
if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
|
| 241 |
+
lora_state_dicts[key] = value
|
| 242 |
+
|
| 243 |
+
lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor_visual(
|
| 244 |
+
in_features=3072, out_features=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, n_loras=number
|
| 245 |
+
)
|
| 246 |
+
# Load the weights from the checkpoint dictionary into the corresponding layers
|
| 247 |
+
# for n in range(number):
|
| 248 |
+
# lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
|
| 249 |
+
# lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
|
| 250 |
+
# lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
|
| 251 |
+
# lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
|
| 252 |
+
# lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
|
| 253 |
+
# lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
|
| 254 |
+
# lora_attn_procs[name].to(device)
|
| 255 |
+
else:
|
| 256 |
+
lora_attn_procs[name] = FluxAttnProcessor2_0()
|
| 257 |
+
|
| 258 |
+
transformer.set_attn_processor(lora_attn_procs)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
'''
|
| 263 |
+
unset_lora(pipe.transformer)
|
| 264 |
+
lora_path = "./lora.safetensors"
|
| 265 |
+
lora_weights = [1, 1]
|
| 266 |
+
set_lora(pipe.transformer, local_path=lora_path, lora_weights=lora_weights, cond_size=512)
|
| 267 |
+
'''
|
src/pipeline.py
ADDED
|
@@ -0,0 +1,805 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
| 7 |
+
|
| 8 |
+
from diffusers.image_processor import (VaeImageProcessor)
|
| 9 |
+
from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin
|
| 10 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
| 11 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 12 |
+
from diffusers.utils import (
|
| 13 |
+
USE_PEFT_BACKEND,
|
| 14 |
+
is_torch_xla_available,
|
| 15 |
+
logging,
|
| 16 |
+
scale_lora_layers,
|
| 17 |
+
unscale_lora_layers,
|
| 18 |
+
)
|
| 19 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 20 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 21 |
+
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
| 22 |
+
from torchvision.transforms.functional import pad
|
| 23 |
+
from .transformer_flux import FluxTransformer2DModel
|
| 24 |
+
|
| 25 |
+
if is_torch_xla_available():
|
| 26 |
+
import torch_xla.core.xla_model as xm
|
| 27 |
+
|
| 28 |
+
XLA_AVAILABLE = True
|
| 29 |
+
else:
|
| 30 |
+
XLA_AVAILABLE = False
|
| 31 |
+
|
| 32 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 33 |
+
|
| 34 |
+
def calculate_shift(
|
| 35 |
+
image_seq_len,
|
| 36 |
+
base_seq_len: int = 256,
|
| 37 |
+
max_seq_len: int = 4096,
|
| 38 |
+
base_shift: float = 0.5,
|
| 39 |
+
max_shift: float = 1.16,
|
| 40 |
+
):
|
| 41 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 42 |
+
b = base_shift - m * base_seq_len
|
| 43 |
+
mu = image_seq_len * m + b
|
| 44 |
+
return mu
|
| 45 |
+
|
| 46 |
+
def prepare_latent_image_ids_2(height, width, device, dtype):
|
| 47 |
+
latent_image_ids = torch.zeros(height//2, width//2, 3, device=device, dtype=dtype)
|
| 48 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height//2, device=device)[:, None] # y坐标
|
| 49 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width//2, device=device)[None, :] # x坐标
|
| 50 |
+
return latent_image_ids
|
| 51 |
+
|
| 52 |
+
def prepare_latent_subject_ids(height, width, device, dtype):
|
| 53 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3, device=device, dtype=dtype)
|
| 54 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2, device=device)[:, None]
|
| 55 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2, device=device)[None, :]
|
| 56 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
| 57 |
+
latent_image_ids = latent_image_ids.reshape(
|
| 58 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
| 59 |
+
)
|
| 60 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
| 61 |
+
|
| 62 |
+
def resize_position_encoding(batch_size, original_height, original_width, target_height, target_width, device, dtype):
|
| 63 |
+
latent_image_ids = prepare_latent_image_ids_2(original_height, original_width, device, dtype)
|
| 64 |
+
scale_h = original_height / target_height
|
| 65 |
+
scale_w = original_width / target_width
|
| 66 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
| 67 |
+
latent_image_ids = latent_image_ids.reshape(
|
| 68 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
| 69 |
+
)
|
| 70 |
+
#spatial进行PE插值
|
| 71 |
+
latent_image_ids_resized = torch.zeros(target_height//2, target_width//2, 3, device=device, dtype=dtype)
|
| 72 |
+
for i in range(target_height//2):
|
| 73 |
+
for j in range(target_width//2):
|
| 74 |
+
latent_image_ids_resized[i, j, 1] = i*scale_h
|
| 75 |
+
latent_image_ids_resized[i, j, 2] = j*scale_w
|
| 76 |
+
cond_latent_image_id_height, cond_latent_image_id_width, cond_latent_image_id_channels = latent_image_ids_resized.shape
|
| 77 |
+
cond_latent_image_ids = latent_image_ids_resized.reshape(
|
| 78 |
+
cond_latent_image_id_height * cond_latent_image_id_width, cond_latent_image_id_channels
|
| 79 |
+
)
|
| 80 |
+
# latent_image_ids_ = torch.concat([latent_image_ids, cond_latent_image_ids], dim=0)
|
| 81 |
+
return latent_image_ids, cond_latent_image_ids #, latent_image_ids_
|
| 82 |
+
|
| 83 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
| 84 |
+
def retrieve_latents(
|
| 85 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
| 86 |
+
):
|
| 87 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
| 88 |
+
return encoder_output.latent_dist.sample(generator)
|
| 89 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
| 90 |
+
return encoder_output.latent_dist.mode()
|
| 91 |
+
elif hasattr(encoder_output, "latents"):
|
| 92 |
+
return encoder_output.latents
|
| 93 |
+
else:
|
| 94 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 98 |
+
def retrieve_timesteps(
|
| 99 |
+
scheduler,
|
| 100 |
+
num_inference_steps: Optional[int] = None,
|
| 101 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 102 |
+
timesteps: Optional[List[int]] = None,
|
| 103 |
+
sigmas: Optional[List[float]] = None,
|
| 104 |
+
**kwargs,
|
| 105 |
+
):
|
| 106 |
+
"""
|
| 107 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 108 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
scheduler (`SchedulerMixin`):
|
| 112 |
+
The scheduler to get timesteps from.
|
| 113 |
+
num_inference_steps (`int`):
|
| 114 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 115 |
+
must be `None`.
|
| 116 |
+
device (`str` or `torch.device`, *optional*):
|
| 117 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 118 |
+
timesteps (`List[int]`, *optional*):
|
| 119 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 120 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 121 |
+
sigmas (`List[float]`, *optional*):
|
| 122 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 123 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 127 |
+
second element is the number of inference steps.
|
| 128 |
+
"""
|
| 129 |
+
if timesteps is not None and sigmas is not None:
|
| 130 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 131 |
+
if timesteps is not None:
|
| 132 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 133 |
+
if not accepts_timesteps:
|
| 134 |
+
raise ValueError(
|
| 135 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 136 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 137 |
+
)
|
| 138 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 139 |
+
timesteps = scheduler.timesteps
|
| 140 |
+
num_inference_steps = len(timesteps)
|
| 141 |
+
elif sigmas is not None:
|
| 142 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 143 |
+
if not accept_sigmas:
|
| 144 |
+
raise ValueError(
|
| 145 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 146 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 147 |
+
)
|
| 148 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 149 |
+
timesteps = scheduler.timesteps
|
| 150 |
+
num_inference_steps = len(timesteps)
|
| 151 |
+
else:
|
| 152 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 153 |
+
timesteps = scheduler.timesteps
|
| 154 |
+
return timesteps, num_inference_steps
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
|
| 158 |
+
r"""
|
| 159 |
+
The Flux pipeline for text-to-image generation.
|
| 160 |
+
|
| 161 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
| 162 |
+
|
| 163 |
+
Args:
|
| 164 |
+
transformer ([`FluxTransformer2DModel`]):
|
| 165 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
| 166 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
| 167 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 168 |
+
vae ([`AutoencoderKL`]):
|
| 169 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 170 |
+
text_encoder ([`CLIPTextModel`]):
|
| 171 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
| 172 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
| 173 |
+
text_encoder_2 ([`T5EncoderModel`]):
|
| 174 |
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
| 175 |
+
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
| 176 |
+
tokenizer (`CLIPTokenizer`):
|
| 177 |
+
Tokenizer of class
|
| 178 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 179 |
+
tokenizer_2 (`T5TokenizerFast`):
|
| 180 |
+
Second Tokenizer of class
|
| 181 |
+
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
| 182 |
+
"""
|
| 183 |
+
|
| 184 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
| 185 |
+
_optional_components = []
|
| 186 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
| 187 |
+
|
| 188 |
+
def __init__(
|
| 189 |
+
self,
|
| 190 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 191 |
+
vae: AutoencoderKL,
|
| 192 |
+
text_encoder: CLIPTextModel,
|
| 193 |
+
tokenizer: CLIPTokenizer,
|
| 194 |
+
text_encoder_2: T5EncoderModel,
|
| 195 |
+
tokenizer_2: T5TokenizerFast,
|
| 196 |
+
transformer: FluxTransformer2DModel,
|
| 197 |
+
):
|
| 198 |
+
super().__init__()
|
| 199 |
+
|
| 200 |
+
self.register_modules(
|
| 201 |
+
vae=vae,
|
| 202 |
+
text_encoder=text_encoder,
|
| 203 |
+
text_encoder_2=text_encoder_2,
|
| 204 |
+
tokenizer=tokenizer,
|
| 205 |
+
tokenizer_2=tokenizer_2,
|
| 206 |
+
transformer=transformer,
|
| 207 |
+
scheduler=scheduler,
|
| 208 |
+
)
|
| 209 |
+
self.vae_scale_factor = (
|
| 210 |
+
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
|
| 211 |
+
)
|
| 212 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 213 |
+
self.tokenizer_max_length = (
|
| 214 |
+
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
| 215 |
+
)
|
| 216 |
+
self.default_sample_size = 64
|
| 217 |
+
|
| 218 |
+
def _get_t5_prompt_embeds(
|
| 219 |
+
self,
|
| 220 |
+
prompt: Union[str, List[str]] = None,
|
| 221 |
+
num_images_per_prompt: int = 1,
|
| 222 |
+
max_sequence_length: int = 512,
|
| 223 |
+
device: Optional[torch.device] = None,
|
| 224 |
+
dtype: Optional[torch.dtype] = None,
|
| 225 |
+
):
|
| 226 |
+
device = device or self._execution_device
|
| 227 |
+
dtype = dtype or self.text_encoder.dtype
|
| 228 |
+
|
| 229 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 230 |
+
batch_size = len(prompt)
|
| 231 |
+
|
| 232 |
+
text_inputs = self.tokenizer_2(
|
| 233 |
+
prompt,
|
| 234 |
+
padding="max_length",
|
| 235 |
+
max_length=max_sequence_length,
|
| 236 |
+
truncation=True,
|
| 237 |
+
return_length=False,
|
| 238 |
+
return_overflowing_tokens=False,
|
| 239 |
+
return_tensors="pt",
|
| 240 |
+
)
|
| 241 |
+
text_input_ids = text_inputs.input_ids
|
| 242 |
+
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
| 243 |
+
|
| 244 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 245 |
+
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1: -1])
|
| 246 |
+
logger.warning(
|
| 247 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 248 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
|
| 252 |
+
|
| 253 |
+
dtype = self.text_encoder_2.dtype
|
| 254 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 255 |
+
|
| 256 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 257 |
+
|
| 258 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
| 259 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 260 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 261 |
+
|
| 262 |
+
return prompt_embeds
|
| 263 |
+
|
| 264 |
+
def _get_clip_prompt_embeds(
|
| 265 |
+
self,
|
| 266 |
+
prompt: Union[str, List[str]],
|
| 267 |
+
num_images_per_prompt: int = 1,
|
| 268 |
+
device: Optional[torch.device] = None,
|
| 269 |
+
):
|
| 270 |
+
device = device or self._execution_device
|
| 271 |
+
|
| 272 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 273 |
+
batch_size = len(prompt)
|
| 274 |
+
|
| 275 |
+
text_inputs = self.tokenizer(
|
| 276 |
+
prompt,
|
| 277 |
+
padding="max_length",
|
| 278 |
+
max_length=self.tokenizer_max_length,
|
| 279 |
+
truncation=True,
|
| 280 |
+
return_overflowing_tokens=False,
|
| 281 |
+
return_length=False,
|
| 282 |
+
return_tensors="pt",
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
text_input_ids = text_inputs.input_ids
|
| 286 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 287 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 288 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1: -1])
|
| 289 |
+
logger.warning(
|
| 290 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
| 291 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
| 292 |
+
)
|
| 293 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
| 294 |
+
|
| 295 |
+
# Use pooled output of CLIPTextModel
|
| 296 |
+
prompt_embeds = prompt_embeds.pooler_output
|
| 297 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
| 298 |
+
|
| 299 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 300 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
|
| 301 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
| 302 |
+
|
| 303 |
+
return prompt_embeds
|
| 304 |
+
|
| 305 |
+
def encode_prompt(
|
| 306 |
+
self,
|
| 307 |
+
prompt: Union[str, List[str]],
|
| 308 |
+
prompt_2: Union[str, List[str]],
|
| 309 |
+
device: Optional[torch.device] = None,
|
| 310 |
+
num_images_per_prompt: int = 1,
|
| 311 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 312 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 313 |
+
max_sequence_length: int = 512,
|
| 314 |
+
lora_scale: Optional[float] = None,
|
| 315 |
+
):
|
| 316 |
+
r"""
|
| 317 |
+
|
| 318 |
+
Args:
|
| 319 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 320 |
+
prompt to be encoded
|
| 321 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 322 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 323 |
+
used in all text-encoders
|
| 324 |
+
device: (`torch.device`):
|
| 325 |
+
torch device
|
| 326 |
+
num_images_per_prompt (`int`):
|
| 327 |
+
number of images that should be generated per prompt
|
| 328 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 329 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 330 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 331 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 332 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 333 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 334 |
+
lora_scale (`float`, *optional*):
|
| 335 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
| 336 |
+
"""
|
| 337 |
+
device = device or self._execution_device
|
| 338 |
+
|
| 339 |
+
# set lora scale so that monkey patched LoRA
|
| 340 |
+
# function of text encoder can correctly access it
|
| 341 |
+
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
| 342 |
+
self._lora_scale = lora_scale
|
| 343 |
+
|
| 344 |
+
# dynamically adjust the LoRA scale
|
| 345 |
+
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
| 346 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
| 347 |
+
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
| 348 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
| 349 |
+
|
| 350 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 351 |
+
|
| 352 |
+
if prompt_embeds is None:
|
| 353 |
+
prompt_2 = prompt_2 or prompt
|
| 354 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
| 355 |
+
|
| 356 |
+
# We only use the pooled prompt output from the CLIPTextModel
|
| 357 |
+
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
| 358 |
+
prompt=prompt,
|
| 359 |
+
device=device,
|
| 360 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 361 |
+
)
|
| 362 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 363 |
+
prompt=prompt_2,
|
| 364 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 365 |
+
max_sequence_length=max_sequence_length,
|
| 366 |
+
device=device,
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
if self.text_encoder is not None:
|
| 370 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 371 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
| 372 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
| 373 |
+
|
| 374 |
+
if self.text_encoder_2 is not None:
|
| 375 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 376 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
| 377 |
+
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
| 378 |
+
|
| 379 |
+
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
| 380 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
| 381 |
+
|
| 382 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
| 383 |
+
|
| 384 |
+
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
|
| 385 |
+
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
| 386 |
+
if isinstance(generator, list):
|
| 387 |
+
image_latents = [
|
| 388 |
+
retrieve_latents(self.vae.encode(image[i: i + 1]), generator=generator[i])
|
| 389 |
+
for i in range(image.shape[0])
|
| 390 |
+
]
|
| 391 |
+
image_latents = torch.cat(image_latents, dim=0)
|
| 392 |
+
else:
|
| 393 |
+
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
| 394 |
+
|
| 395 |
+
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
| 396 |
+
|
| 397 |
+
return image_latents
|
| 398 |
+
|
| 399 |
+
def check_inputs(
|
| 400 |
+
self,
|
| 401 |
+
prompt,
|
| 402 |
+
prompt_2,
|
| 403 |
+
height,
|
| 404 |
+
width,
|
| 405 |
+
prompt_embeds=None,
|
| 406 |
+
pooled_prompt_embeds=None,
|
| 407 |
+
callback_on_step_end_tensor_inputs=None,
|
| 408 |
+
max_sequence_length=None,
|
| 409 |
+
):
|
| 410 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 411 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 412 |
+
|
| 413 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 414 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 415 |
+
):
|
| 416 |
+
raise ValueError(
|
| 417 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
if prompt is not None and prompt_embeds is not None:
|
| 421 |
+
raise ValueError(
|
| 422 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 423 |
+
" only forward one of the two."
|
| 424 |
+
)
|
| 425 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
| 426 |
+
raise ValueError(
|
| 427 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 428 |
+
" only forward one of the two."
|
| 429 |
+
)
|
| 430 |
+
elif prompt is None and prompt_embeds is None:
|
| 431 |
+
raise ValueError(
|
| 432 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 433 |
+
)
|
| 434 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 435 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 436 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
| 437 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
| 438 |
+
|
| 439 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
| 440 |
+
raise ValueError(
|
| 441 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
| 445 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
| 446 |
+
|
| 447 |
+
@staticmethod
|
| 448 |
+
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
| 449 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
| 450 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
| 451 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
| 452 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
| 453 |
+
latent_image_ids = latent_image_ids.reshape(
|
| 454 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
| 455 |
+
)
|
| 456 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
| 457 |
+
|
| 458 |
+
@staticmethod
|
| 459 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
| 460 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
| 461 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
| 462 |
+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
| 463 |
+
return latents
|
| 464 |
+
|
| 465 |
+
@staticmethod
|
| 466 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
| 467 |
+
batch_size, num_patches, channels = latents.shape
|
| 468 |
+
|
| 469 |
+
height = height // vae_scale_factor
|
| 470 |
+
width = width // vae_scale_factor
|
| 471 |
+
|
| 472 |
+
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
|
| 473 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
| 474 |
+
|
| 475 |
+
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
|
| 476 |
+
|
| 477 |
+
return latents
|
| 478 |
+
|
| 479 |
+
def enable_vae_slicing(self):
|
| 480 |
+
r"""
|
| 481 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
| 482 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
| 483 |
+
"""
|
| 484 |
+
self.vae.enable_slicing()
|
| 485 |
+
|
| 486 |
+
def disable_vae_slicing(self):
|
| 487 |
+
r"""
|
| 488 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
| 489 |
+
computing decoding in one step.
|
| 490 |
+
"""
|
| 491 |
+
self.vae.disable_slicing()
|
| 492 |
+
|
| 493 |
+
def enable_vae_tiling(self):
|
| 494 |
+
r"""
|
| 495 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
| 496 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
| 497 |
+
processing larger images.
|
| 498 |
+
"""
|
| 499 |
+
self.vae.enable_tiling()
|
| 500 |
+
|
| 501 |
+
def disable_vae_tiling(self):
|
| 502 |
+
r"""
|
| 503 |
+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
| 504 |
+
computing decoding in one step.
|
| 505 |
+
"""
|
| 506 |
+
self.vae.disable_tiling()
|
| 507 |
+
|
| 508 |
+
def prepare_latents(
|
| 509 |
+
self,
|
| 510 |
+
batch_size,
|
| 511 |
+
num_channels_latents,
|
| 512 |
+
height,
|
| 513 |
+
width,
|
| 514 |
+
dtype,
|
| 515 |
+
device,
|
| 516 |
+
generator,
|
| 517 |
+
subject_image,
|
| 518 |
+
condition_image,
|
| 519 |
+
latents=None,
|
| 520 |
+
cond_number=1,
|
| 521 |
+
sub_number=1
|
| 522 |
+
):
|
| 523 |
+
height_cond = 2 * (self.cond_size // self.vae_scale_factor)
|
| 524 |
+
width_cond = 2 * (self.cond_size // self.vae_scale_factor)
|
| 525 |
+
height = 2 * (int(height) // self.vae_scale_factor)
|
| 526 |
+
width = 2 * (int(width) // self.vae_scale_factor)
|
| 527 |
+
|
| 528 |
+
shape = (batch_size, num_channels_latents, height, width) # 1 16 106 80
|
| 529 |
+
noise_latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 530 |
+
noise_latents = self._pack_latents(noise_latents, batch_size, num_channels_latents, height, width)
|
| 531 |
+
noise_latent_image_ids, cond_latent_image_ids = resize_position_encoding(
|
| 532 |
+
batch_size,
|
| 533 |
+
height,
|
| 534 |
+
width,
|
| 535 |
+
height_cond,
|
| 536 |
+
width_cond,
|
| 537 |
+
device,
|
| 538 |
+
dtype,
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
latents_to_concat = [] # 不包含 latents
|
| 542 |
+
latents_ids_to_concat = [noise_latent_image_ids]
|
| 543 |
+
|
| 544 |
+
# subject
|
| 545 |
+
if subject_image is not None:
|
| 546 |
+
shape_subject = (batch_size, num_channels_latents, height_cond*sub_number, width_cond)
|
| 547 |
+
subject_image = subject_image.to(device=device, dtype=dtype)
|
| 548 |
+
subject_image_latents = self._encode_vae_image(image=subject_image, generator=generator)
|
| 549 |
+
subject_latents = self._pack_latents(subject_image_latents, batch_size, num_channels_latents, height_cond*sub_number, width_cond)
|
| 550 |
+
mask2 = torch.zeros(shape_subject, device=device, dtype=dtype)
|
| 551 |
+
mask2 = self._pack_latents(mask2, batch_size, num_channels_latents, height_cond*sub_number, width_cond)
|
| 552 |
+
latent_subject_ids = prepare_latent_subject_ids(height_cond, width_cond, device, dtype)
|
| 553 |
+
latent_subject_ids[:, 1] += 64 # fixed offset
|
| 554 |
+
subject_latent_image_ids = torch.concat([latent_subject_ids for _ in range(sub_number)], dim=-2)
|
| 555 |
+
latents_to_concat.append(subject_latents)
|
| 556 |
+
latents_ids_to_concat.append(subject_latent_image_ids)
|
| 557 |
+
|
| 558 |
+
# spatial
|
| 559 |
+
if condition_image is not None:
|
| 560 |
+
shape_cond = (batch_size, num_channels_latents, height_cond*cond_number, width_cond)
|
| 561 |
+
condition_image = condition_image.to(device=device, dtype=dtype)
|
| 562 |
+
image_latents = self._encode_vae_image(image=condition_image, generator=generator)
|
| 563 |
+
cond_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height_cond*cond_number, width_cond)
|
| 564 |
+
mask3 = torch.zeros(shape_cond, device=device, dtype=dtype)
|
| 565 |
+
mask3 = self._pack_latents(mask3, batch_size, num_channels_latents, height_cond*cond_number, width_cond)
|
| 566 |
+
cond_latent_image_ids = cond_latent_image_ids
|
| 567 |
+
cond_latent_image_ids = torch.concat([cond_latent_image_ids for _ in range(cond_number)], dim=-2)
|
| 568 |
+
latents_ids_to_concat.append(cond_latent_image_ids)
|
| 569 |
+
latents_to_concat.append(cond_latents)
|
| 570 |
+
|
| 571 |
+
cond_latents = torch.concat(latents_to_concat, dim=-2)
|
| 572 |
+
latent_image_ids = torch.concat(latents_ids_to_concat, dim=-2)
|
| 573 |
+
return cond_latents, latent_image_ids, noise_latents
|
| 574 |
+
|
| 575 |
+
@property
|
| 576 |
+
def guidance_scale(self):
|
| 577 |
+
return self._guidance_scale
|
| 578 |
+
|
| 579 |
+
@property
|
| 580 |
+
def joint_attention_kwargs(self):
|
| 581 |
+
return self._joint_attention_kwargs
|
| 582 |
+
|
| 583 |
+
@property
|
| 584 |
+
def num_timesteps(self):
|
| 585 |
+
return self._num_timesteps
|
| 586 |
+
|
| 587 |
+
@property
|
| 588 |
+
def interrupt(self):
|
| 589 |
+
return self._interrupt
|
| 590 |
+
|
| 591 |
+
@torch.no_grad()
|
| 592 |
+
def __call__(
|
| 593 |
+
self,
|
| 594 |
+
prompt: Union[str, List[str]] = None,
|
| 595 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 596 |
+
height: Optional[int] = None,
|
| 597 |
+
width: Optional[int] = None,
|
| 598 |
+
num_inference_steps: int = 28,
|
| 599 |
+
timesteps: List[int] = None,
|
| 600 |
+
guidance_scale: float = 3.5,
|
| 601 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 602 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 603 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 604 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 605 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 606 |
+
output_type: Optional[str] = "pil",
|
| 607 |
+
return_dict: bool = True,
|
| 608 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 609 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 610 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 611 |
+
max_sequence_length: int = 512,
|
| 612 |
+
spatial_images=None,
|
| 613 |
+
subject_images=None,
|
| 614 |
+
cond_size=512,
|
| 615 |
+
):
|
| 616 |
+
|
| 617 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 618 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 619 |
+
self.cond_size = cond_size
|
| 620 |
+
|
| 621 |
+
# 1. Check inputs. Raise error if not correct
|
| 622 |
+
self.check_inputs(
|
| 623 |
+
prompt,
|
| 624 |
+
prompt_2,
|
| 625 |
+
height,
|
| 626 |
+
width,
|
| 627 |
+
prompt_embeds=prompt_embeds,
|
| 628 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 629 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 630 |
+
max_sequence_length=max_sequence_length,
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
self._guidance_scale = guidance_scale
|
| 634 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
| 635 |
+
self._interrupt = False
|
| 636 |
+
|
| 637 |
+
cond_number = len(spatial_images)
|
| 638 |
+
sub_number = len(subject_images)
|
| 639 |
+
|
| 640 |
+
if sub_number > 0:
|
| 641 |
+
subject_image_ls = []
|
| 642 |
+
for subject_image in subject_images:
|
| 643 |
+
w, h = subject_image.size[:2]
|
| 644 |
+
scale = self.cond_size / max(h, w)
|
| 645 |
+
new_h, new_w = int(h * scale), int(w * scale)
|
| 646 |
+
subject_image = self.image_processor.preprocess(subject_image, height=new_h, width=new_w)
|
| 647 |
+
subject_image = subject_image.to(dtype=torch.float32)
|
| 648 |
+
pad_h = cond_size - subject_image.shape[-2]
|
| 649 |
+
pad_w = cond_size - subject_image.shape[-1]
|
| 650 |
+
subject_image = pad(
|
| 651 |
+
subject_image,
|
| 652 |
+
padding=(int(pad_w / 2), int(pad_h / 2), int(pad_w / 2), int(pad_h / 2)),
|
| 653 |
+
fill=0
|
| 654 |
+
)
|
| 655 |
+
subject_image_ls.append(subject_image)
|
| 656 |
+
subject_image = torch.concat(subject_image_ls, dim=-2)
|
| 657 |
+
else:
|
| 658 |
+
subject_image = None
|
| 659 |
+
|
| 660 |
+
if cond_number > 0:
|
| 661 |
+
condition_image_ls = []
|
| 662 |
+
for img in spatial_images:
|
| 663 |
+
condition_image = self.image_processor.preprocess(img, height=self.cond_size, width=self.cond_size)
|
| 664 |
+
condition_image = condition_image.to(dtype=torch.float32)
|
| 665 |
+
condition_image_ls.append(condition_image)
|
| 666 |
+
condition_image = torch.concat(condition_image_ls, dim=-2)
|
| 667 |
+
else:
|
| 668 |
+
condition_image = None
|
| 669 |
+
|
| 670 |
+
# 2. Define call parameters
|
| 671 |
+
if prompt is not None and isinstance(prompt, str):
|
| 672 |
+
batch_size = 1
|
| 673 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 674 |
+
batch_size = len(prompt)
|
| 675 |
+
else:
|
| 676 |
+
batch_size = prompt_embeds.shape[0]
|
| 677 |
+
|
| 678 |
+
device = self._execution_device
|
| 679 |
+
|
| 680 |
+
lora_scale = (
|
| 681 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
| 682 |
+
)
|
| 683 |
+
(
|
| 684 |
+
prompt_embeds,
|
| 685 |
+
pooled_prompt_embeds,
|
| 686 |
+
text_ids,
|
| 687 |
+
) = self.encode_prompt(
|
| 688 |
+
prompt=prompt,
|
| 689 |
+
prompt_2=prompt_2,
|
| 690 |
+
prompt_embeds=prompt_embeds,
|
| 691 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 692 |
+
device=device,
|
| 693 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 694 |
+
max_sequence_length=max_sequence_length,
|
| 695 |
+
lora_scale=lora_scale,
|
| 696 |
+
)
|
| 697 |
+
|
| 698 |
+
# 4. Prepare latent variables
|
| 699 |
+
num_channels_latents = self.transformer.config.in_channels // 4 # 16
|
| 700 |
+
cond_latents, latent_image_ids, noise_latents = self.prepare_latents(
|
| 701 |
+
batch_size * num_images_per_prompt,
|
| 702 |
+
num_channels_latents,
|
| 703 |
+
height,
|
| 704 |
+
width,
|
| 705 |
+
prompt_embeds.dtype,
|
| 706 |
+
device,
|
| 707 |
+
generator,
|
| 708 |
+
subject_image,
|
| 709 |
+
condition_image,
|
| 710 |
+
latents,
|
| 711 |
+
cond_number,
|
| 712 |
+
sub_number
|
| 713 |
+
)
|
| 714 |
+
latents = noise_latents
|
| 715 |
+
# 5. Prepare timesteps
|
| 716 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
| 717 |
+
image_seq_len = latents.shape[1]
|
| 718 |
+
mu = calculate_shift(
|
| 719 |
+
image_seq_len,
|
| 720 |
+
self.scheduler.config.base_image_seq_len,
|
| 721 |
+
self.scheduler.config.max_image_seq_len,
|
| 722 |
+
self.scheduler.config.base_shift,
|
| 723 |
+
self.scheduler.config.max_shift,
|
| 724 |
+
)
|
| 725 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 726 |
+
self.scheduler,
|
| 727 |
+
num_inference_steps,
|
| 728 |
+
device,
|
| 729 |
+
timesteps,
|
| 730 |
+
sigmas,
|
| 731 |
+
mu=mu,
|
| 732 |
+
)
|
| 733 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 734 |
+
self._num_timesteps = len(timesteps)
|
| 735 |
+
|
| 736 |
+
# handle guidance
|
| 737 |
+
if self.transformer.config.guidance_embeds:
|
| 738 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
| 739 |
+
guidance = guidance.expand(latents.shape[0])
|
| 740 |
+
else:
|
| 741 |
+
guidance = None
|
| 742 |
+
|
| 743 |
+
# 6. Denoising loop
|
| 744 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 745 |
+
for i, t in enumerate(timesteps):
|
| 746 |
+
if self.interrupt:
|
| 747 |
+
continue
|
| 748 |
+
|
| 749 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 750 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
| 751 |
+
noise_pred = self.transformer(
|
| 752 |
+
hidden_states=latents, # 1 4096 64
|
| 753 |
+
cond_hidden_states=cond_latents,
|
| 754 |
+
timestep=timestep / 1000,
|
| 755 |
+
guidance=guidance,
|
| 756 |
+
pooled_projections=pooled_prompt_embeds,
|
| 757 |
+
encoder_hidden_states=prompt_embeds,
|
| 758 |
+
txt_ids=text_ids,
|
| 759 |
+
img_ids=latent_image_ids,
|
| 760 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 761 |
+
return_dict=False,
|
| 762 |
+
)[0]
|
| 763 |
+
|
| 764 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 765 |
+
latents_dtype = latents.dtype
|
| 766 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 767 |
+
latents = latents
|
| 768 |
+
|
| 769 |
+
if latents.dtype != latents_dtype:
|
| 770 |
+
if torch.backends.mps.is_available():
|
| 771 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 772 |
+
latents = latents.to(latents_dtype)
|
| 773 |
+
|
| 774 |
+
if callback_on_step_end is not None:
|
| 775 |
+
callback_kwargs = {}
|
| 776 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 777 |
+
callback_kwargs[k] = locals()[k]
|
| 778 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 779 |
+
|
| 780 |
+
latents = callback_outputs.pop("latents", latents)
|
| 781 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 782 |
+
|
| 783 |
+
# call the callback, if provided
|
| 784 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 785 |
+
progress_bar.update()
|
| 786 |
+
|
| 787 |
+
if XLA_AVAILABLE:
|
| 788 |
+
xm.mark_step()
|
| 789 |
+
|
| 790 |
+
if output_type == "latent":
|
| 791 |
+
image = latents
|
| 792 |
+
|
| 793 |
+
else:
|
| 794 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
| 795 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
| 796 |
+
image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
|
| 797 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 798 |
+
|
| 799 |
+
# Offload all models
|
| 800 |
+
self.maybe_free_model_hooks()
|
| 801 |
+
|
| 802 |
+
if not return_dict:
|
| 803 |
+
return (image,)
|
| 804 |
+
|
| 805 |
+
return FluxPipelineOutput(images=image)
|
src/prompt_helper.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def load_text_encoders(args, class_one, class_two):
|
| 5 |
+
text_encoder_one = class_one.from_pretrained(
|
| 6 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
| 7 |
+
)
|
| 8 |
+
text_encoder_two = class_two.from_pretrained(
|
| 9 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
|
| 10 |
+
)
|
| 11 |
+
return text_encoder_one, text_encoder_two
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def tokenize_prompt(tokenizer, prompt, max_sequence_length):
|
| 15 |
+
text_inputs = tokenizer(
|
| 16 |
+
prompt,
|
| 17 |
+
padding="max_length",
|
| 18 |
+
max_length=max_sequence_length,
|
| 19 |
+
truncation=True,
|
| 20 |
+
return_length=False,
|
| 21 |
+
return_overflowing_tokens=False,
|
| 22 |
+
return_tensors="pt",
|
| 23 |
+
)
|
| 24 |
+
text_input_ids = text_inputs.input_ids
|
| 25 |
+
return text_input_ids
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def tokenize_prompt_clip(tokenizer, prompt):
|
| 29 |
+
text_inputs = tokenizer(
|
| 30 |
+
prompt,
|
| 31 |
+
padding="max_length",
|
| 32 |
+
max_length=77,
|
| 33 |
+
truncation=True,
|
| 34 |
+
return_length=False,
|
| 35 |
+
return_overflowing_tokens=False,
|
| 36 |
+
return_tensors="pt",
|
| 37 |
+
)
|
| 38 |
+
text_input_ids = text_inputs.input_ids
|
| 39 |
+
return text_input_ids
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def tokenize_prompt_t5(tokenizer, prompt):
|
| 43 |
+
text_inputs = tokenizer(
|
| 44 |
+
prompt,
|
| 45 |
+
padding="max_length",
|
| 46 |
+
max_length=512,
|
| 47 |
+
truncation=True,
|
| 48 |
+
return_length=False,
|
| 49 |
+
return_overflowing_tokens=False,
|
| 50 |
+
return_tensors="pt",
|
| 51 |
+
)
|
| 52 |
+
text_input_ids = text_inputs.input_ids
|
| 53 |
+
return text_input_ids
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _encode_prompt_with_t5(
|
| 57 |
+
text_encoder,
|
| 58 |
+
tokenizer,
|
| 59 |
+
max_sequence_length=512,
|
| 60 |
+
prompt=None,
|
| 61 |
+
num_images_per_prompt=1,
|
| 62 |
+
device=None,
|
| 63 |
+
text_input_ids=None,
|
| 64 |
+
):
|
| 65 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 66 |
+
batch_size = len(prompt)
|
| 67 |
+
|
| 68 |
+
if tokenizer is not None:
|
| 69 |
+
text_inputs = tokenizer(
|
| 70 |
+
prompt,
|
| 71 |
+
padding="max_length",
|
| 72 |
+
max_length=max_sequence_length,
|
| 73 |
+
truncation=True,
|
| 74 |
+
return_length=False,
|
| 75 |
+
return_overflowing_tokens=False,
|
| 76 |
+
return_tensors="pt",
|
| 77 |
+
)
|
| 78 |
+
text_input_ids = text_inputs.input_ids
|
| 79 |
+
else:
|
| 80 |
+
if text_input_ids is None:
|
| 81 |
+
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
|
| 82 |
+
|
| 83 |
+
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
| 84 |
+
|
| 85 |
+
dtype = text_encoder.dtype
|
| 86 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 87 |
+
|
| 88 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 89 |
+
|
| 90 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
| 91 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 92 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 93 |
+
|
| 94 |
+
return prompt_embeds
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _encode_prompt_with_clip(
|
| 98 |
+
text_encoder,
|
| 99 |
+
tokenizer,
|
| 100 |
+
prompt: str,
|
| 101 |
+
device=None,
|
| 102 |
+
text_input_ids=None,
|
| 103 |
+
num_images_per_prompt: int = 1,
|
| 104 |
+
):
|
| 105 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 106 |
+
batch_size = len(prompt)
|
| 107 |
+
|
| 108 |
+
if tokenizer is not None:
|
| 109 |
+
text_inputs = tokenizer(
|
| 110 |
+
prompt,
|
| 111 |
+
padding="max_length",
|
| 112 |
+
max_length=77,
|
| 113 |
+
truncation=True,
|
| 114 |
+
return_overflowing_tokens=False,
|
| 115 |
+
return_length=False,
|
| 116 |
+
return_tensors="pt",
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
text_input_ids = text_inputs.input_ids
|
| 120 |
+
else:
|
| 121 |
+
if text_input_ids is None:
|
| 122 |
+
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
|
| 123 |
+
|
| 124 |
+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
| 125 |
+
|
| 126 |
+
# Use pooled output of CLIPTextModel
|
| 127 |
+
prompt_embeds = prompt_embeds.pooler_output
|
| 128 |
+
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
|
| 129 |
+
|
| 130 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 131 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 132 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
| 133 |
+
|
| 134 |
+
return prompt_embeds
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def encode_prompt(
|
| 138 |
+
text_encoders,
|
| 139 |
+
tokenizers,
|
| 140 |
+
prompt: str,
|
| 141 |
+
max_sequence_length,
|
| 142 |
+
device=None,
|
| 143 |
+
num_images_per_prompt: int = 1,
|
| 144 |
+
text_input_ids_list=None,
|
| 145 |
+
):
|
| 146 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 147 |
+
dtype = text_encoders[0].dtype
|
| 148 |
+
|
| 149 |
+
pooled_prompt_embeds = _encode_prompt_with_clip(
|
| 150 |
+
text_encoder=text_encoders[0],
|
| 151 |
+
tokenizer=tokenizers[0],
|
| 152 |
+
prompt=prompt,
|
| 153 |
+
device=device if device is not None else text_encoders[0].device,
|
| 154 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 155 |
+
text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
prompt_embeds = _encode_prompt_with_t5(
|
| 159 |
+
text_encoder=text_encoders[1],
|
| 160 |
+
tokenizer=tokenizers[1],
|
| 161 |
+
max_sequence_length=max_sequence_length,
|
| 162 |
+
prompt=prompt,
|
| 163 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 164 |
+
device=device if device is not None else text_encoders[1].device,
|
| 165 |
+
text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
| 169 |
+
|
| 170 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def encode_token_ids(text_encoders, tokens, accelerator, num_images_per_prompt=1, device=None):
|
| 174 |
+
text_encoder_clip = text_encoders[0]
|
| 175 |
+
text_encoder_t5 = text_encoders[1]
|
| 176 |
+
tokens_clip, tokens_t5 = tokens[0], tokens[1]
|
| 177 |
+
batch_size = tokens_clip.shape[0]
|
| 178 |
+
|
| 179 |
+
if device == "cpu":
|
| 180 |
+
device = "cpu"
|
| 181 |
+
else:
|
| 182 |
+
device = accelerator.device
|
| 183 |
+
|
| 184 |
+
# clip
|
| 185 |
+
prompt_embeds = text_encoder_clip(tokens_clip.to(device), output_hidden_states=False)
|
| 186 |
+
# Use pooled output of CLIPTextModel
|
| 187 |
+
prompt_embeds = prompt_embeds.pooler_output
|
| 188 |
+
prompt_embeds = prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device)
|
| 189 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 190 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 191 |
+
pooled_prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
| 192 |
+
pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device)
|
| 193 |
+
|
| 194 |
+
# t5
|
| 195 |
+
prompt_embeds = text_encoder_t5(tokens_t5.to(device))[0]
|
| 196 |
+
dtype = text_encoder_t5.dtype
|
| 197 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=accelerator.device)
|
| 198 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 199 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
| 200 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 201 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 202 |
+
|
| 203 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=accelerator.device, dtype=dtype)
|
| 204 |
+
|
| 205 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
src/transformer_flux.py
ADDED
|
@@ -0,0 +1,583 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 9 |
+
from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
| 10 |
+
from diffusers.models.attention import FeedForward
|
| 11 |
+
from diffusers.models.attention_processor import (
|
| 12 |
+
Attention,
|
| 13 |
+
AttentionProcessor,
|
| 14 |
+
FluxAttnProcessor2_0,
|
| 15 |
+
FluxAttnProcessor2_0_NPU,
|
| 16 |
+
FusedFluxAttnProcessor2_0,
|
| 17 |
+
)
|
| 18 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 19 |
+
from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
| 20 |
+
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
| 21 |
+
from diffusers.utils.import_utils import is_torch_npu_available
|
| 22 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 23 |
+
from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
| 24 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 25 |
+
|
| 26 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 27 |
+
|
| 28 |
+
@maybe_allow_in_graph
|
| 29 |
+
class FluxSingleTransformerBlock(nn.Module):
|
| 30 |
+
|
| 31 |
+
def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
| 34 |
+
|
| 35 |
+
self.norm = AdaLayerNormZeroSingle(dim)
|
| 36 |
+
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
|
| 37 |
+
self.act_mlp = nn.GELU(approximate="tanh")
|
| 38 |
+
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
| 39 |
+
|
| 40 |
+
if is_torch_npu_available():
|
| 41 |
+
processor = FluxAttnProcessor2_0_NPU()
|
| 42 |
+
else:
|
| 43 |
+
processor = FluxAttnProcessor2_0()
|
| 44 |
+
self.attn = Attention(
|
| 45 |
+
query_dim=dim,
|
| 46 |
+
cross_attention_dim=None,
|
| 47 |
+
dim_head=attention_head_dim,
|
| 48 |
+
heads=num_attention_heads,
|
| 49 |
+
out_dim=dim,
|
| 50 |
+
bias=True,
|
| 51 |
+
processor=processor,
|
| 52 |
+
qk_norm="rms_norm",
|
| 53 |
+
eps=1e-6,
|
| 54 |
+
pre_only=True,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
def forward(
|
| 58 |
+
self,
|
| 59 |
+
hidden_states: torch.Tensor,
|
| 60 |
+
cond_hidden_states: torch.Tensor,
|
| 61 |
+
temb: torch.Tensor,
|
| 62 |
+
cond_temb: torch.Tensor,
|
| 63 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 64 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 65 |
+
) -> torch.Tensor:
|
| 66 |
+
use_cond = cond_hidden_states is not None
|
| 67 |
+
|
| 68 |
+
residual = hidden_states
|
| 69 |
+
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
| 70 |
+
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
| 71 |
+
|
| 72 |
+
if use_cond:
|
| 73 |
+
residual_cond = cond_hidden_states
|
| 74 |
+
norm_cond_hidden_states, cond_gate = self.norm(cond_hidden_states, emb=cond_temb)
|
| 75 |
+
mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_cond_hidden_states))
|
| 76 |
+
|
| 77 |
+
norm_hidden_states_concat = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2)
|
| 78 |
+
|
| 79 |
+
joint_attention_kwargs = joint_attention_kwargs or {}
|
| 80 |
+
attn_output = self.attn(
|
| 81 |
+
hidden_states=norm_hidden_states_concat,
|
| 82 |
+
image_rotary_emb=image_rotary_emb,
|
| 83 |
+
use_cond=use_cond,
|
| 84 |
+
**joint_attention_kwargs,
|
| 85 |
+
)
|
| 86 |
+
if use_cond:
|
| 87 |
+
attn_output, cond_attn_output = attn_output
|
| 88 |
+
|
| 89 |
+
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
| 90 |
+
gate = gate.unsqueeze(1)
|
| 91 |
+
hidden_states = gate * self.proj_out(hidden_states)
|
| 92 |
+
hidden_states = residual + hidden_states
|
| 93 |
+
|
| 94 |
+
if use_cond:
|
| 95 |
+
condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
|
| 96 |
+
cond_gate = cond_gate.unsqueeze(1)
|
| 97 |
+
condition_latents = cond_gate * self.proj_out(condition_latents)
|
| 98 |
+
condition_latents = residual_cond + condition_latents
|
| 99 |
+
|
| 100 |
+
if hidden_states.dtype == torch.float16:
|
| 101 |
+
hidden_states = hidden_states.clip(-65504, 65504)
|
| 102 |
+
|
| 103 |
+
return hidden_states, condition_latents if use_cond else None
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@maybe_allow_in_graph
|
| 107 |
+
class FluxTransformerBlock(nn.Module):
|
| 108 |
+
def __init__(
|
| 109 |
+
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
|
| 110 |
+
):
|
| 111 |
+
super().__init__()
|
| 112 |
+
|
| 113 |
+
self.norm1 = AdaLayerNormZero(dim)
|
| 114 |
+
|
| 115 |
+
self.norm1_context = AdaLayerNormZero(dim)
|
| 116 |
+
|
| 117 |
+
if hasattr(F, "scaled_dot_product_attention"):
|
| 118 |
+
processor = FluxAttnProcessor2_0()
|
| 119 |
+
else:
|
| 120 |
+
raise ValueError(
|
| 121 |
+
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
|
| 122 |
+
)
|
| 123 |
+
self.attn = Attention(
|
| 124 |
+
query_dim=dim,
|
| 125 |
+
cross_attention_dim=None,
|
| 126 |
+
added_kv_proj_dim=dim,
|
| 127 |
+
dim_head=attention_head_dim,
|
| 128 |
+
heads=num_attention_heads,
|
| 129 |
+
out_dim=dim,
|
| 130 |
+
context_pre_only=False,
|
| 131 |
+
bias=True,
|
| 132 |
+
processor=processor,
|
| 133 |
+
qk_norm=qk_norm,
|
| 134 |
+
eps=eps,
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
| 138 |
+
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
| 139 |
+
|
| 140 |
+
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
| 141 |
+
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
| 142 |
+
|
| 143 |
+
# let chunk size default to None
|
| 144 |
+
self._chunk_size = None
|
| 145 |
+
self._chunk_dim = 0
|
| 146 |
+
|
| 147 |
+
def forward(
|
| 148 |
+
self,
|
| 149 |
+
hidden_states: torch.Tensor,
|
| 150 |
+
cond_hidden_states: torch.Tensor,
|
| 151 |
+
encoder_hidden_states: torch.Tensor,
|
| 152 |
+
temb: torch.Tensor,
|
| 153 |
+
cond_temb: torch.Tensor,
|
| 154 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 155 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 156 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 157 |
+
use_cond = cond_hidden_states is not None
|
| 158 |
+
|
| 159 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
| 160 |
+
if use_cond:
|
| 161 |
+
(
|
| 162 |
+
norm_cond_hidden_states,
|
| 163 |
+
cond_gate_msa,
|
| 164 |
+
cond_shift_mlp,
|
| 165 |
+
cond_scale_mlp,
|
| 166 |
+
cond_gate_mlp,
|
| 167 |
+
) = self.norm1(cond_hidden_states, emb=cond_temb)
|
| 168 |
+
|
| 169 |
+
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
| 170 |
+
encoder_hidden_states, emb=temb
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
norm_hidden_states = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2)
|
| 174 |
+
|
| 175 |
+
joint_attention_kwargs = joint_attention_kwargs or {}
|
| 176 |
+
# Attention.
|
| 177 |
+
attention_outputs = self.attn(
|
| 178 |
+
hidden_states=norm_hidden_states,
|
| 179 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
| 180 |
+
image_rotary_emb=image_rotary_emb,
|
| 181 |
+
use_cond=use_cond,
|
| 182 |
+
**joint_attention_kwargs,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
attn_output, context_attn_output = attention_outputs[:2]
|
| 186 |
+
cond_attn_output = attention_outputs[2] if use_cond else None
|
| 187 |
+
|
| 188 |
+
# Process attention outputs for the `hidden_states`.
|
| 189 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
| 190 |
+
hidden_states = hidden_states + attn_output
|
| 191 |
+
|
| 192 |
+
if use_cond:
|
| 193 |
+
cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
|
| 194 |
+
cond_hidden_states = cond_hidden_states + cond_attn_output
|
| 195 |
+
|
| 196 |
+
norm_hidden_states = self.norm2(hidden_states)
|
| 197 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 198 |
+
|
| 199 |
+
if use_cond:
|
| 200 |
+
norm_cond_hidden_states = self.norm2(cond_hidden_states)
|
| 201 |
+
norm_cond_hidden_states = (
|
| 202 |
+
norm_cond_hidden_states * (1 + cond_scale_mlp[:, None])
|
| 203 |
+
+ cond_shift_mlp[:, None]
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
ff_output = self.ff(norm_hidden_states)
|
| 207 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
| 208 |
+
hidden_states = hidden_states + ff_output
|
| 209 |
+
|
| 210 |
+
if use_cond:
|
| 211 |
+
cond_ff_output = self.ff(norm_cond_hidden_states)
|
| 212 |
+
cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
|
| 213 |
+
cond_hidden_states = cond_hidden_states + cond_ff_output
|
| 214 |
+
|
| 215 |
+
# Process attention outputs for the `encoder_hidden_states`.
|
| 216 |
+
|
| 217 |
+
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
| 218 |
+
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
| 219 |
+
|
| 220 |
+
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
| 221 |
+
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
| 222 |
+
|
| 223 |
+
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
| 224 |
+
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
| 225 |
+
if encoder_hidden_states.dtype == torch.float16:
|
| 226 |
+
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
| 227 |
+
|
| 228 |
+
return encoder_hidden_states, hidden_states, cond_hidden_states if use_cond else None
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class FluxTransformer2DModel(
|
| 232 |
+
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
|
| 233 |
+
):
|
| 234 |
+
_supports_gradient_checkpointing = True
|
| 235 |
+
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
|
| 236 |
+
|
| 237 |
+
@register_to_config
|
| 238 |
+
def __init__(
|
| 239 |
+
self,
|
| 240 |
+
patch_size: int = 1,
|
| 241 |
+
in_channels: int = 64,
|
| 242 |
+
out_channels: Optional[int] = None,
|
| 243 |
+
num_layers: int = 19,
|
| 244 |
+
num_single_layers: int = 38,
|
| 245 |
+
attention_head_dim: int = 128,
|
| 246 |
+
num_attention_heads: int = 24,
|
| 247 |
+
joint_attention_dim: int = 4096,
|
| 248 |
+
pooled_projection_dim: int = 768,
|
| 249 |
+
guidance_embeds: bool = False,
|
| 250 |
+
axes_dims_rope: Tuple[int] = (16, 56, 56),
|
| 251 |
+
):
|
| 252 |
+
super().__init__()
|
| 253 |
+
self.out_channels = out_channels or in_channels
|
| 254 |
+
self.inner_dim = num_attention_heads * attention_head_dim
|
| 255 |
+
|
| 256 |
+
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
|
| 257 |
+
|
| 258 |
+
text_time_guidance_cls = (
|
| 259 |
+
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
|
| 260 |
+
)
|
| 261 |
+
self.time_text_embed = text_time_guidance_cls(
|
| 262 |
+
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
|
| 266 |
+
self.x_embedder = nn.Linear(in_channels, self.inner_dim)
|
| 267 |
+
|
| 268 |
+
self.transformer_blocks = nn.ModuleList(
|
| 269 |
+
[
|
| 270 |
+
FluxTransformerBlock(
|
| 271 |
+
dim=self.inner_dim,
|
| 272 |
+
num_attention_heads=num_attention_heads,
|
| 273 |
+
attention_head_dim=attention_head_dim,
|
| 274 |
+
)
|
| 275 |
+
for _ in range(num_layers)
|
| 276 |
+
]
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
self.single_transformer_blocks = nn.ModuleList(
|
| 280 |
+
[
|
| 281 |
+
FluxSingleTransformerBlock(
|
| 282 |
+
dim=self.inner_dim,
|
| 283 |
+
num_attention_heads=num_attention_heads,
|
| 284 |
+
attention_head_dim=attention_head_dim,
|
| 285 |
+
)
|
| 286 |
+
for _ in range(num_single_layers)
|
| 287 |
+
]
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
| 291 |
+
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
| 292 |
+
|
| 293 |
+
self.gradient_checkpointing = False
|
| 294 |
+
|
| 295 |
+
@property
|
| 296 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
| 297 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 298 |
+
r"""
|
| 299 |
+
Returns:
|
| 300 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
| 301 |
+
indexed by its weight name.
|
| 302 |
+
"""
|
| 303 |
+
# set recursively
|
| 304 |
+
processors = {}
|
| 305 |
+
|
| 306 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 307 |
+
if hasattr(module, "get_processor"):
|
| 308 |
+
processors[f"{name}.processor"] = module.get_processor()
|
| 309 |
+
|
| 310 |
+
for sub_name, child in module.named_children():
|
| 311 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 312 |
+
|
| 313 |
+
return processors
|
| 314 |
+
|
| 315 |
+
for name, module in self.named_children():
|
| 316 |
+
fn_recursive_add_processors(name, module, processors)
|
| 317 |
+
|
| 318 |
+
return processors
|
| 319 |
+
|
| 320 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
| 321 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
| 322 |
+
r"""
|
| 323 |
+
Sets the attention processor to use to compute attention.
|
| 324 |
+
|
| 325 |
+
Parameters:
|
| 326 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 327 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 328 |
+
for **all** `Attention` layers.
|
| 329 |
+
|
| 330 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 331 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 332 |
+
|
| 333 |
+
"""
|
| 334 |
+
count = len(self.attn_processors.keys())
|
| 335 |
+
|
| 336 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 337 |
+
raise ValueError(
|
| 338 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 339 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 343 |
+
if hasattr(module, "set_processor"):
|
| 344 |
+
if not isinstance(processor, dict):
|
| 345 |
+
module.set_processor(processor)
|
| 346 |
+
else:
|
| 347 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 348 |
+
|
| 349 |
+
for sub_name, child in module.named_children():
|
| 350 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 351 |
+
|
| 352 |
+
for name, module in self.named_children():
|
| 353 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 354 |
+
|
| 355 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
|
| 356 |
+
def fuse_qkv_projections(self):
|
| 357 |
+
"""
|
| 358 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
| 359 |
+
are fused. For cross-attention modules, key and value projection matrices are fused.
|
| 360 |
+
|
| 361 |
+
<Tip warning={true}>
|
| 362 |
+
|
| 363 |
+
This API is 🧪 experimental.
|
| 364 |
+
|
| 365 |
+
</Tip>
|
| 366 |
+
"""
|
| 367 |
+
self.original_attn_processors = None
|
| 368 |
+
|
| 369 |
+
for _, attn_processor in self.attn_processors.items():
|
| 370 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
| 371 |
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
| 372 |
+
|
| 373 |
+
self.original_attn_processors = self.attn_processors
|
| 374 |
+
|
| 375 |
+
for module in self.modules():
|
| 376 |
+
if isinstance(module, Attention):
|
| 377 |
+
module.fuse_projections(fuse=True)
|
| 378 |
+
|
| 379 |
+
self.set_attn_processor(FusedFluxAttnProcessor2_0())
|
| 380 |
+
|
| 381 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
| 382 |
+
def unfuse_qkv_projections(self):
|
| 383 |
+
"""Disables the fused QKV projection if enabled.
|
| 384 |
+
|
| 385 |
+
<Tip warning={true}>
|
| 386 |
+
|
| 387 |
+
This API is 🧪 experimental.
|
| 388 |
+
|
| 389 |
+
</Tip>
|
| 390 |
+
|
| 391 |
+
"""
|
| 392 |
+
if self.original_attn_processors is not None:
|
| 393 |
+
self.set_attn_processor(self.original_attn_processors)
|
| 394 |
+
|
| 395 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 396 |
+
if hasattr(module, "gradient_checkpointing"):
|
| 397 |
+
module.gradient_checkpointing = value
|
| 398 |
+
|
| 399 |
+
def forward(
|
| 400 |
+
self,
|
| 401 |
+
hidden_states: torch.Tensor,
|
| 402 |
+
cond_hidden_states: torch.Tensor = None,
|
| 403 |
+
encoder_hidden_states: torch.Tensor = None,
|
| 404 |
+
pooled_projections: torch.Tensor = None,
|
| 405 |
+
timestep: torch.LongTensor = None,
|
| 406 |
+
img_ids: torch.Tensor = None,
|
| 407 |
+
txt_ids: torch.Tensor = None,
|
| 408 |
+
guidance: torch.Tensor = None,
|
| 409 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 410 |
+
controlnet_block_samples=None,
|
| 411 |
+
controlnet_single_block_samples=None,
|
| 412 |
+
return_dict: bool = True,
|
| 413 |
+
controlnet_blocks_repeat: bool = False,
|
| 414 |
+
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
| 415 |
+
if cond_hidden_states is not None:
|
| 416 |
+
use_condition = True
|
| 417 |
+
else:
|
| 418 |
+
use_condition = False
|
| 419 |
+
|
| 420 |
+
if joint_attention_kwargs is not None:
|
| 421 |
+
joint_attention_kwargs = joint_attention_kwargs.copy()
|
| 422 |
+
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
| 423 |
+
else:
|
| 424 |
+
lora_scale = 1.0
|
| 425 |
+
|
| 426 |
+
if USE_PEFT_BACKEND:
|
| 427 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
| 428 |
+
scale_lora_layers(self, lora_scale)
|
| 429 |
+
else:
|
| 430 |
+
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
| 431 |
+
logger.warning(
|
| 432 |
+
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
hidden_states = self.x_embedder(hidden_states)
|
| 436 |
+
cond_hidden_states = self.x_embedder(cond_hidden_states)
|
| 437 |
+
|
| 438 |
+
timestep = timestep.to(hidden_states.dtype) * 1000
|
| 439 |
+
if guidance is not None:
|
| 440 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
| 441 |
+
else:
|
| 442 |
+
guidance = None
|
| 443 |
+
|
| 444 |
+
temb = (
|
| 445 |
+
self.time_text_embed(timestep, pooled_projections)
|
| 446 |
+
if guidance is None
|
| 447 |
+
else self.time_text_embed(timestep, guidance, pooled_projections)
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
cond_temb = (
|
| 451 |
+
self.time_text_embed(torch.ones_like(timestep) * 0, pooled_projections)
|
| 452 |
+
if guidance is None
|
| 453 |
+
else self.time_text_embed(
|
| 454 |
+
torch.ones_like(timestep) * 0, guidance, pooled_projections
|
| 455 |
+
)
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
| 459 |
+
|
| 460 |
+
if txt_ids.ndim == 3:
|
| 461 |
+
logger.warning(
|
| 462 |
+
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
| 463 |
+
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
| 464 |
+
)
|
| 465 |
+
txt_ids = txt_ids[0]
|
| 466 |
+
if img_ids.ndim == 3:
|
| 467 |
+
logger.warning(
|
| 468 |
+
"Passing `img_ids` 3d torch.Tensor is deprecated."
|
| 469 |
+
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
| 470 |
+
)
|
| 471 |
+
img_ids = img_ids[0]
|
| 472 |
+
|
| 473 |
+
ids = torch.cat((txt_ids, img_ids), dim=0)
|
| 474 |
+
image_rotary_emb = self.pos_embed(ids)
|
| 475 |
+
|
| 476 |
+
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
|
| 477 |
+
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
|
| 478 |
+
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
|
| 479 |
+
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
|
| 480 |
+
|
| 481 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
| 482 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 483 |
+
|
| 484 |
+
def create_custom_forward(module, return_dict=None):
|
| 485 |
+
def custom_forward(*inputs):
|
| 486 |
+
if return_dict is not None:
|
| 487 |
+
return module(*inputs, return_dict=return_dict)
|
| 488 |
+
else:
|
| 489 |
+
return module(*inputs)
|
| 490 |
+
|
| 491 |
+
return custom_forward
|
| 492 |
+
|
| 493 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 494 |
+
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
| 495 |
+
create_custom_forward(block),
|
| 496 |
+
hidden_states,
|
| 497 |
+
encoder_hidden_states,
|
| 498 |
+
temb,
|
| 499 |
+
image_rotary_emb,
|
| 500 |
+
cond_temb=cond_temb if use_condition else None,
|
| 501 |
+
cond_hidden_states=cond_hidden_states if use_condition else None,
|
| 502 |
+
**ckpt_kwargs,
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
else:
|
| 506 |
+
encoder_hidden_states, hidden_states, cond_hidden_states = block(
|
| 507 |
+
hidden_states=hidden_states,
|
| 508 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 509 |
+
cond_hidden_states=cond_hidden_states if use_condition else None,
|
| 510 |
+
temb=temb,
|
| 511 |
+
cond_temb=cond_temb if use_condition else None,
|
| 512 |
+
image_rotary_emb=image_rotary_emb,
|
| 513 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
# controlnet residual
|
| 517 |
+
if controlnet_block_samples is not None:
|
| 518 |
+
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
|
| 519 |
+
interval_control = int(np.ceil(interval_control))
|
| 520 |
+
# For Xlabs ControlNet.
|
| 521 |
+
if controlnet_blocks_repeat:
|
| 522 |
+
hidden_states = (
|
| 523 |
+
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
|
| 524 |
+
)
|
| 525 |
+
else:
|
| 526 |
+
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
| 527 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
| 528 |
+
|
| 529 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
| 530 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 531 |
+
|
| 532 |
+
def create_custom_forward(module, return_dict=None):
|
| 533 |
+
def custom_forward(*inputs):
|
| 534 |
+
if return_dict is not None:
|
| 535 |
+
return module(*inputs, return_dict=return_dict)
|
| 536 |
+
else:
|
| 537 |
+
return module(*inputs)
|
| 538 |
+
|
| 539 |
+
return custom_forward
|
| 540 |
+
|
| 541 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 542 |
+
hidden_states, cond_hidden_states = torch.utils.checkpoint.checkpoint(
|
| 543 |
+
create_custom_forward(block),
|
| 544 |
+
hidden_states,
|
| 545 |
+
temb,
|
| 546 |
+
image_rotary_emb,
|
| 547 |
+
cond_temb=cond_temb if use_condition else None,
|
| 548 |
+
cond_hidden_states=cond_hidden_states if use_condition else None,
|
| 549 |
+
**ckpt_kwargs,
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
else:
|
| 553 |
+
hidden_states, cond_hidden_states = block(
|
| 554 |
+
hidden_states=hidden_states,
|
| 555 |
+
cond_hidden_states=cond_hidden_states if use_condition else None,
|
| 556 |
+
temb=temb,
|
| 557 |
+
cond_temb=cond_temb if use_condition else None,
|
| 558 |
+
image_rotary_emb=image_rotary_emb,
|
| 559 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
| 560 |
+
)
|
| 561 |
+
|
| 562 |
+
# controlnet residual
|
| 563 |
+
if controlnet_single_block_samples is not None:
|
| 564 |
+
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
|
| 565 |
+
interval_control = int(np.ceil(interval_control))
|
| 566 |
+
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
| 567 |
+
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
| 568 |
+
+ controlnet_single_block_samples[index_block // interval_control]
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
| 572 |
+
|
| 573 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 574 |
+
output = self.proj_out(hidden_states)
|
| 575 |
+
|
| 576 |
+
if USE_PEFT_BACKEND:
|
| 577 |
+
# remove `lora_scale` from each PEFT layer
|
| 578 |
+
unscale_lora_layers(self, lora_scale)
|
| 579 |
+
|
| 580 |
+
if not return_dict:
|
| 581 |
+
return (output,)
|
| 582 |
+
|
| 583 |
+
return Transformer2DModelOutput(sample=output)
|
src/transformer_with_loss.py
ADDED
|
@@ -0,0 +1,504 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 9 |
+
from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
| 10 |
+
from diffusers.models.attention import FeedForward
|
| 11 |
+
from diffusers.models.attention_processor import (
|
| 12 |
+
Attention,
|
| 13 |
+
AttentionProcessor,
|
| 14 |
+
FluxAttnProcessor2_0,
|
| 15 |
+
FluxAttnProcessor2_0_NPU,
|
| 16 |
+
FusedFluxAttnProcessor2_0,
|
| 17 |
+
)
|
| 18 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 19 |
+
from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
| 20 |
+
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers, deprecate
|
| 21 |
+
from diffusers.utils.import_utils import is_torch_npu_available
|
| 22 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 23 |
+
from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
| 24 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 25 |
+
from diffusers import CacheMixin
|
| 26 |
+
|
| 27 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@maybe_allow_in_graph
|
| 31 |
+
class FluxSingleTransformerBlock(nn.Module):
|
| 32 |
+
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
|
| 33 |
+
super().__init__()
|
| 34 |
+
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
| 35 |
+
|
| 36 |
+
self.norm = AdaLayerNormZeroSingle(dim)
|
| 37 |
+
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
|
| 38 |
+
self.act_mlp = nn.GELU(approximate="tanh")
|
| 39 |
+
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
| 40 |
+
|
| 41 |
+
if is_torch_npu_available():
|
| 42 |
+
deprecation_message = (
|
| 43 |
+
"Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
|
| 44 |
+
"should be set explicitly using the `set_attn_processor` method."
|
| 45 |
+
)
|
| 46 |
+
deprecate("npu_processor", "0.34.0", deprecation_message)
|
| 47 |
+
processor = FluxAttnProcessor2_0_NPU()
|
| 48 |
+
else:
|
| 49 |
+
processor = FluxAttnProcessor2_0()
|
| 50 |
+
|
| 51 |
+
self.attn = Attention(
|
| 52 |
+
query_dim=dim,
|
| 53 |
+
cross_attention_dim=None,
|
| 54 |
+
dim_head=attention_head_dim,
|
| 55 |
+
heads=num_attention_heads,
|
| 56 |
+
out_dim=dim,
|
| 57 |
+
bias=True,
|
| 58 |
+
processor=processor,
|
| 59 |
+
qk_norm="rms_norm",
|
| 60 |
+
eps=1e-6,
|
| 61 |
+
pre_only=True,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
def forward(
|
| 65 |
+
self,
|
| 66 |
+
hidden_states: torch.Tensor,
|
| 67 |
+
temb: torch.Tensor,
|
| 68 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 69 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 70 |
+
) -> torch.Tensor:
|
| 71 |
+
residual = hidden_states
|
| 72 |
+
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
| 73 |
+
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
| 74 |
+
joint_attention_kwargs = joint_attention_kwargs or {}
|
| 75 |
+
attn_output = self.attn(
|
| 76 |
+
hidden_states=norm_hidden_states,
|
| 77 |
+
image_rotary_emb=image_rotary_emb,
|
| 78 |
+
**joint_attention_kwargs,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
| 82 |
+
gate = gate.unsqueeze(1)
|
| 83 |
+
hidden_states = gate * self.proj_out(hidden_states)
|
| 84 |
+
hidden_states = residual + hidden_states
|
| 85 |
+
if hidden_states.dtype == torch.float16:
|
| 86 |
+
hidden_states = hidden_states.clip(-65504, 65504)
|
| 87 |
+
|
| 88 |
+
return hidden_states
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@maybe_allow_in_graph
|
| 92 |
+
class FluxTransformerBlock(nn.Module):
|
| 93 |
+
def __init__(
|
| 94 |
+
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
|
| 95 |
+
):
|
| 96 |
+
super().__init__()
|
| 97 |
+
|
| 98 |
+
self.norm1 = AdaLayerNormZero(dim)
|
| 99 |
+
self.norm1_context = AdaLayerNormZero(dim)
|
| 100 |
+
|
| 101 |
+
self.attn = Attention(
|
| 102 |
+
query_dim=dim,
|
| 103 |
+
cross_attention_dim=None,
|
| 104 |
+
added_kv_proj_dim=dim,
|
| 105 |
+
dim_head=attention_head_dim,
|
| 106 |
+
heads=num_attention_heads,
|
| 107 |
+
out_dim=dim,
|
| 108 |
+
context_pre_only=False,
|
| 109 |
+
bias=True,
|
| 110 |
+
processor=FluxAttnProcessor2_0(),
|
| 111 |
+
qk_norm=qk_norm,
|
| 112 |
+
eps=eps,
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
| 116 |
+
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
| 117 |
+
|
| 118 |
+
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
| 119 |
+
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
| 120 |
+
|
| 121 |
+
def forward(
|
| 122 |
+
self,
|
| 123 |
+
hidden_states: torch.Tensor,
|
| 124 |
+
encoder_hidden_states: torch.Tensor,
|
| 125 |
+
temb: torch.Tensor,
|
| 126 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 127 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 128 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 129 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
| 130 |
+
|
| 131 |
+
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
| 132 |
+
encoder_hidden_states, emb=temb
|
| 133 |
+
)
|
| 134 |
+
joint_attention_kwargs = joint_attention_kwargs or {}
|
| 135 |
+
# Attention.
|
| 136 |
+
attention_outputs = self.attn(
|
| 137 |
+
hidden_states=norm_hidden_states,
|
| 138 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
| 139 |
+
image_rotary_emb=image_rotary_emb,
|
| 140 |
+
**joint_attention_kwargs,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
if len(attention_outputs) == 2:
|
| 144 |
+
attn_output, context_attn_output = attention_outputs
|
| 145 |
+
elif len(attention_outputs) == 3:
|
| 146 |
+
attn_output, context_attn_output, ip_attn_output = attention_outputs
|
| 147 |
+
|
| 148 |
+
# Process attention outputs for the `hidden_states`.
|
| 149 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
| 150 |
+
hidden_states = hidden_states + attn_output
|
| 151 |
+
|
| 152 |
+
norm_hidden_states = self.norm2(hidden_states)
|
| 153 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 154 |
+
|
| 155 |
+
ff_output = self.ff(norm_hidden_states)
|
| 156 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
| 157 |
+
|
| 158 |
+
hidden_states = hidden_states + ff_output
|
| 159 |
+
if len(attention_outputs) == 3:
|
| 160 |
+
hidden_states = hidden_states + ip_attn_output
|
| 161 |
+
|
| 162 |
+
# Process attention outputs for the `encoder_hidden_states`.
|
| 163 |
+
|
| 164 |
+
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
| 165 |
+
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
| 166 |
+
|
| 167 |
+
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
| 168 |
+
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
| 169 |
+
|
| 170 |
+
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
| 171 |
+
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
| 172 |
+
if encoder_hidden_states.dtype == torch.float16:
|
| 173 |
+
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
| 174 |
+
|
| 175 |
+
return encoder_hidden_states, hidden_states
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class FluxTransformer2DModelWithLoss(
|
| 179 |
+
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
|
| 180 |
+
):
|
| 181 |
+
_supports_gradient_checkpointing = True
|
| 182 |
+
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
|
| 183 |
+
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
| 184 |
+
_repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
|
| 185 |
+
|
| 186 |
+
@register_to_config
|
| 187 |
+
def __init__(
|
| 188 |
+
self,
|
| 189 |
+
patch_size: int = 1,
|
| 190 |
+
in_channels: int = 64,
|
| 191 |
+
out_channels: Optional[int] = None,
|
| 192 |
+
num_layers: int = 19,
|
| 193 |
+
num_single_layers: int = 38,
|
| 194 |
+
attention_head_dim: int = 128,
|
| 195 |
+
num_attention_heads: int = 24,
|
| 196 |
+
joint_attention_dim: int = 4096,
|
| 197 |
+
pooled_projection_dim: int = 768,
|
| 198 |
+
guidance_embeds: bool = False,
|
| 199 |
+
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
| 200 |
+
):
|
| 201 |
+
super().__init__()
|
| 202 |
+
self.out_channels = out_channels or in_channels
|
| 203 |
+
self.inner_dim = num_attention_heads * attention_head_dim
|
| 204 |
+
|
| 205 |
+
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
|
| 206 |
+
|
| 207 |
+
text_time_guidance_cls = (
|
| 208 |
+
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
|
| 209 |
+
)
|
| 210 |
+
self.time_text_embed = text_time_guidance_cls(
|
| 211 |
+
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
|
| 215 |
+
self.x_embedder = nn.Linear(in_channels, self.inner_dim)
|
| 216 |
+
|
| 217 |
+
self.transformer_blocks = nn.ModuleList(
|
| 218 |
+
[
|
| 219 |
+
FluxTransformerBlock(
|
| 220 |
+
dim=self.inner_dim,
|
| 221 |
+
num_attention_heads=num_attention_heads,
|
| 222 |
+
attention_head_dim=attention_head_dim,
|
| 223 |
+
)
|
| 224 |
+
for _ in range(num_layers)
|
| 225 |
+
]
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
self.single_transformer_blocks = nn.ModuleList(
|
| 229 |
+
[
|
| 230 |
+
FluxSingleTransformerBlock(
|
| 231 |
+
dim=self.inner_dim,
|
| 232 |
+
num_attention_heads=num_attention_heads,
|
| 233 |
+
attention_head_dim=attention_head_dim,
|
| 234 |
+
)
|
| 235 |
+
for _ in range(num_single_layers)
|
| 236 |
+
]
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
| 240 |
+
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
| 241 |
+
|
| 242 |
+
self.gradient_checkpointing = False
|
| 243 |
+
|
| 244 |
+
@property
|
| 245 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
| 246 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 247 |
+
r"""
|
| 248 |
+
Returns:
|
| 249 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
| 250 |
+
indexed by its weight name.
|
| 251 |
+
"""
|
| 252 |
+
# set recursively
|
| 253 |
+
processors = {}
|
| 254 |
+
|
| 255 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 256 |
+
if hasattr(module, "get_processor"):
|
| 257 |
+
processors[f"{name}.processor"] = module.get_processor()
|
| 258 |
+
|
| 259 |
+
for sub_name, child in module.named_children():
|
| 260 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 261 |
+
|
| 262 |
+
return processors
|
| 263 |
+
|
| 264 |
+
for name, module in self.named_children():
|
| 265 |
+
fn_recursive_add_processors(name, module, processors)
|
| 266 |
+
|
| 267 |
+
return processors
|
| 268 |
+
|
| 269 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
| 270 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
| 271 |
+
r"""
|
| 272 |
+
Sets the attention processor to use to compute attention.
|
| 273 |
+
|
| 274 |
+
Parameters:
|
| 275 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 276 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 277 |
+
for **all** `Attention` layers.
|
| 278 |
+
|
| 279 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 280 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 281 |
+
|
| 282 |
+
"""
|
| 283 |
+
count = len(self.attn_processors.keys())
|
| 284 |
+
|
| 285 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 286 |
+
raise ValueError(
|
| 287 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 288 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 292 |
+
if hasattr(module, "set_processor"):
|
| 293 |
+
if not isinstance(processor, dict):
|
| 294 |
+
module.set_processor(processor)
|
| 295 |
+
else:
|
| 296 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 297 |
+
|
| 298 |
+
for sub_name, child in module.named_children():
|
| 299 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 300 |
+
|
| 301 |
+
for name, module in self.named_children():
|
| 302 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 303 |
+
|
| 304 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
|
| 305 |
+
def fuse_qkv_projections(self):
|
| 306 |
+
"""
|
| 307 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
| 308 |
+
are fused. For cross-attention modules, key and value projection matrices are fused.
|
| 309 |
+
|
| 310 |
+
<Tip warning={true}>
|
| 311 |
+
|
| 312 |
+
This API is 🧪 experimental.
|
| 313 |
+
|
| 314 |
+
</Tip>
|
| 315 |
+
"""
|
| 316 |
+
self.original_attn_processors = None
|
| 317 |
+
|
| 318 |
+
for _, attn_processor in self.attn_processors.items():
|
| 319 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
| 320 |
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
| 321 |
+
|
| 322 |
+
self.original_attn_processors = self.attn_processors
|
| 323 |
+
|
| 324 |
+
for module in self.modules():
|
| 325 |
+
if isinstance(module, Attention):
|
| 326 |
+
module.fuse_projections(fuse=True)
|
| 327 |
+
|
| 328 |
+
self.set_attn_processor(FusedFluxAttnProcessor2_0())
|
| 329 |
+
|
| 330 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
| 331 |
+
def unfuse_qkv_projections(self):
|
| 332 |
+
"""Disables the fused QKV projection if enabled.
|
| 333 |
+
|
| 334 |
+
<Tip warning={true}>
|
| 335 |
+
|
| 336 |
+
This API is 🧪 experimental.
|
| 337 |
+
|
| 338 |
+
</Tip>
|
| 339 |
+
|
| 340 |
+
"""
|
| 341 |
+
if self.original_attn_processors is not None:
|
| 342 |
+
self.set_attn_processor(self.original_attn_processors)
|
| 343 |
+
|
| 344 |
+
def forward(
|
| 345 |
+
self,
|
| 346 |
+
hidden_states: torch.Tensor,
|
| 347 |
+
encoder_hidden_states: torch.Tensor = None,
|
| 348 |
+
pooled_projections: torch.Tensor = None,
|
| 349 |
+
timestep: torch.LongTensor = None,
|
| 350 |
+
img_ids: torch.Tensor = None,
|
| 351 |
+
txt_ids: torch.Tensor = None,
|
| 352 |
+
guidance: torch.Tensor = None,
|
| 353 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 354 |
+
controlnet_block_samples=None,
|
| 355 |
+
controlnet_single_block_samples=None,
|
| 356 |
+
return_dict: bool = True,
|
| 357 |
+
controlnet_blocks_repeat: bool = False,
|
| 358 |
+
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
| 359 |
+
"""
|
| 360 |
+
The [`FluxTransformer2DModel`] forward method.
|
| 361 |
+
|
| 362 |
+
Args:
|
| 363 |
+
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
| 364 |
+
Input `hidden_states`.
|
| 365 |
+
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
| 366 |
+
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
| 367 |
+
pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
| 368 |
+
from the embeddings of input conditions.
|
| 369 |
+
timestep ( `torch.LongTensor`):
|
| 370 |
+
Used to indicate denoising step.
|
| 371 |
+
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
| 372 |
+
A list of tensors that if specified are added to the residuals of transformer blocks.
|
| 373 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 374 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 375 |
+
`self.processor` in
|
| 376 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 377 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 378 |
+
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
| 379 |
+
tuple.
|
| 380 |
+
|
| 381 |
+
Returns:
|
| 382 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
| 383 |
+
`tuple` where the first element is the sample tensor.
|
| 384 |
+
"""
|
| 385 |
+
if joint_attention_kwargs is not None:
|
| 386 |
+
joint_attention_kwargs = joint_attention_kwargs.copy()
|
| 387 |
+
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
| 388 |
+
else:
|
| 389 |
+
lora_scale = 1.0
|
| 390 |
+
|
| 391 |
+
if USE_PEFT_BACKEND:
|
| 392 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
| 393 |
+
scale_lora_layers(self, lora_scale)
|
| 394 |
+
else:
|
| 395 |
+
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
| 396 |
+
logger.warning(
|
| 397 |
+
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
hidden_states = self.x_embedder(hidden_states)
|
| 401 |
+
|
| 402 |
+
timestep = timestep.to(hidden_states.dtype) * 1000
|
| 403 |
+
if guidance is not None:
|
| 404 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
| 405 |
+
|
| 406 |
+
temb = (
|
| 407 |
+
self.time_text_embed(timestep, pooled_projections)
|
| 408 |
+
if guidance is None
|
| 409 |
+
else self.time_text_embed(timestep, guidance, pooled_projections)
|
| 410 |
+
)
|
| 411 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
| 412 |
+
|
| 413 |
+
if txt_ids.ndim == 3:
|
| 414 |
+
logger.warning(
|
| 415 |
+
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
| 416 |
+
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
| 417 |
+
)
|
| 418 |
+
txt_ids = txt_ids[0]
|
| 419 |
+
if img_ids.ndim == 3:
|
| 420 |
+
logger.warning(
|
| 421 |
+
"Passing `img_ids` 3d torch.Tensor is deprecated."
|
| 422 |
+
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
| 423 |
+
)
|
| 424 |
+
img_ids = img_ids[0]
|
| 425 |
+
|
| 426 |
+
ids = torch.cat((txt_ids, img_ids), dim=0)
|
| 427 |
+
image_rotary_emb = self.pos_embed(ids)
|
| 428 |
+
|
| 429 |
+
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
|
| 430 |
+
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
|
| 431 |
+
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
|
| 432 |
+
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
|
| 433 |
+
|
| 434 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
| 435 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 436 |
+
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
| 437 |
+
block,
|
| 438 |
+
hidden_states,
|
| 439 |
+
encoder_hidden_states,
|
| 440 |
+
temb,
|
| 441 |
+
image_rotary_emb,
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
else:
|
| 445 |
+
encoder_hidden_states, hidden_states = block(
|
| 446 |
+
hidden_states=hidden_states,
|
| 447 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 448 |
+
temb=temb,
|
| 449 |
+
image_rotary_emb=image_rotary_emb,
|
| 450 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
# controlnet residual
|
| 454 |
+
if controlnet_block_samples is not None:
|
| 455 |
+
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
|
| 456 |
+
interval_control = int(np.ceil(interval_control))
|
| 457 |
+
# For Xlabs ControlNet.
|
| 458 |
+
if controlnet_blocks_repeat:
|
| 459 |
+
hidden_states = (
|
| 460 |
+
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
|
| 461 |
+
)
|
| 462 |
+
else:
|
| 463 |
+
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
| 464 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
| 465 |
+
|
| 466 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
| 467 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 468 |
+
hidden_states = self._gradient_checkpointing_func(
|
| 469 |
+
block,
|
| 470 |
+
hidden_states,
|
| 471 |
+
temb,
|
| 472 |
+
image_rotary_emb,
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
else:
|
| 476 |
+
hidden_states = block(
|
| 477 |
+
hidden_states=hidden_states,
|
| 478 |
+
temb=temb,
|
| 479 |
+
image_rotary_emb=image_rotary_emb,
|
| 480 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
# controlnet residual
|
| 484 |
+
if controlnet_single_block_samples is not None:
|
| 485 |
+
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
|
| 486 |
+
interval_control = int(np.ceil(interval_control))
|
| 487 |
+
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
| 488 |
+
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
| 489 |
+
+ controlnet_single_block_samples[index_block // interval_control]
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
| 493 |
+
|
| 494 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 495 |
+
output = self.proj_out(hidden_states)
|
| 496 |
+
|
| 497 |
+
if USE_PEFT_BACKEND:
|
| 498 |
+
# remove `lora_scale` from each PEFT layer
|
| 499 |
+
unscale_lora_layers(self, lora_scale)
|
| 500 |
+
|
| 501 |
+
if not return_dict:
|
| 502 |
+
return (output,)
|
| 503 |
+
|
| 504 |
+
return Transformer2DModelOutput(sample=output)
|
test_imgs/2.png
ADDED
|
Git LFS Details
|
test_imgs/3.png
ADDED
|
Git LFS Details
|
test_imgs/generated_1.png
ADDED
|
Git LFS Details
|
test_imgs/generated_1_bbox.png
ADDED
|
Git LFS Details
|
test_imgs/generated_2.png
ADDED
|
Git LFS Details
|
test_imgs/generated_2_bbox.png
ADDED
|
Git LFS Details
|
test_imgs/generated_3.png
ADDED
|
Git LFS Details
|
test_imgs/generated_3_bbox.png
ADDED
|
Git LFS Details
|
test_imgs/generated_3_bbox_1.png
ADDED
|
Git LFS Details
|
test_imgs/product_1.jpg
ADDED
|
test_imgs/product_1_bbox.png
ADDED
|
Git LFS Details
|
test_imgs/product_2.png
ADDED
|
Git LFS Details
|
test_imgs/product_2_bbox.png
ADDED
|
Git LFS Details
|
test_imgs/product_3.png
ADDED
|
Git LFS Details
|
test_imgs/product_3_bbox.png
ADDED
|
Git LFS Details
|
test_imgs/product_3_bbox_1.png
ADDED
|
Git LFS Details
|
uno/dataset/uno.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
|
| 2 |
+
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
|
| 3 |
+
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import json
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
import torchvision.transforms.functional as TVF
|
| 22 |
+
from torch.utils.data import DataLoader, Dataset
|
| 23 |
+
from torchvision.transforms import Compose, Normalize, ToTensor
|
| 24 |
+
|
| 25 |
+
def bucket_images(images: list[torch.Tensor], resolution: int = 512):
|
| 26 |
+
bucket_override=[
|
| 27 |
+
# h w
|
| 28 |
+
(256, 768),
|
| 29 |
+
(320, 768),
|
| 30 |
+
(320, 704),
|
| 31 |
+
(384, 640),
|
| 32 |
+
(448, 576),
|
| 33 |
+
(512, 512),
|
| 34 |
+
(576, 448),
|
| 35 |
+
(640, 384),
|
| 36 |
+
(704, 320),
|
| 37 |
+
(768, 320),
|
| 38 |
+
(768, 256)
|
| 39 |
+
]
|
| 40 |
+
bucket_override = [(int(h / 512 * resolution), int(w / 512 * resolution)) for h, w in bucket_override]
|
| 41 |
+
bucket_override = [(h // 16 * 16, w // 16 * 16) for h, w in bucket_override]
|
| 42 |
+
|
| 43 |
+
aspect_ratios = [image.shape[-2] / image.shape[-1] for image in images]
|
| 44 |
+
mean_aspect_ratio = np.mean(aspect_ratios)
|
| 45 |
+
|
| 46 |
+
new_h, new_w = bucket_override[0]
|
| 47 |
+
min_aspect_diff = np.abs(new_h / new_w - mean_aspect_ratio)
|
| 48 |
+
for h, w in bucket_override:
|
| 49 |
+
aspect_diff = np.abs(h / w - mean_aspect_ratio)
|
| 50 |
+
if aspect_diff < min_aspect_diff:
|
| 51 |
+
min_aspect_diff = aspect_diff
|
| 52 |
+
new_h, new_w = h, w
|
| 53 |
+
|
| 54 |
+
images = [TVF.resize(image, (new_h, new_w)) for image in images]
|
| 55 |
+
images = torch.stack(images, dim=0)
|
| 56 |
+
return images
|
| 57 |
+
|
| 58 |
+
class FluxPairedDatasetV2(Dataset):
|
| 59 |
+
def __init__(self, json_file: str, resolution: int, resolution_ref: int | None = None):
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.json_file = json_file
|
| 62 |
+
self.resolution = resolution
|
| 63 |
+
self.resolution_ref = resolution_ref if resolution_ref is not None else resolution
|
| 64 |
+
self.image_root = os.path.dirname(json_file)
|
| 65 |
+
|
| 66 |
+
with open(self.json_file, "rt") as f:
|
| 67 |
+
self.data_dicts = json.load(f)
|
| 68 |
+
|
| 69 |
+
self.transform = Compose([
|
| 70 |
+
ToTensor(),
|
| 71 |
+
Normalize([0.5], [0.5]),
|
| 72 |
+
])
|
| 73 |
+
|
| 74 |
+
def __getitem__(self, idx):
|
| 75 |
+
data_dict = self.data_dicts[idx]
|
| 76 |
+
image_paths = [data_dict["image_path"]] if "image_path" in data_dict else data_dict["image_paths"]
|
| 77 |
+
txt = data_dict["prompt"]
|
| 78 |
+
image_tgt_path = data_dict.get("image_tgt_path", None)
|
| 79 |
+
ref_imgs = [
|
| 80 |
+
Image.open(os.path.join(self.image_root, path)).convert("RGB")
|
| 81 |
+
for path in image_paths
|
| 82 |
+
]
|
| 83 |
+
ref_imgs = [self.transform(img) for img in ref_imgs]
|
| 84 |
+
img = None
|
| 85 |
+
if image_tgt_path is not None:
|
| 86 |
+
img = Image.open(os.path.join(self.image_root, image_tgt_path)).convert("RGB")
|
| 87 |
+
img = self.transform(img)
|
| 88 |
+
|
| 89 |
+
return {
|
| 90 |
+
"img": img,
|
| 91 |
+
"txt": txt,
|
| 92 |
+
"ref_imgs": ref_imgs,
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
def __len__(self):
|
| 96 |
+
return len(self.data_dicts)
|
| 97 |
+
|
| 98 |
+
def collate_fn(self, batch):
|
| 99 |
+
img = [data["img"] for data in batch]
|
| 100 |
+
txt = [data["txt"] for data in batch]
|
| 101 |
+
ref_imgs = [data["ref_imgs"] for data in batch]
|
| 102 |
+
assert all([len(ref_imgs[0]) == len(ref_imgs[i]) for i in range(len(ref_imgs))])
|
| 103 |
+
|
| 104 |
+
n_ref = len(ref_imgs[0])
|
| 105 |
+
|
| 106 |
+
img = bucket_images(img, self.resolution)
|
| 107 |
+
ref_imgs_new = []
|
| 108 |
+
for i in range(n_ref):
|
| 109 |
+
ref_imgs_i = [refs[i] for refs in ref_imgs]
|
| 110 |
+
ref_imgs_i = bucket_images(ref_imgs_i, self.resolution_ref)
|
| 111 |
+
ref_imgs_new.append(ref_imgs_i)
|
| 112 |
+
|
| 113 |
+
return {
|
| 114 |
+
"txt": txt,
|
| 115 |
+
"img": img,
|
| 116 |
+
"ref_imgs": ref_imgs_new,
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
if __name__ == '__main__':
|
| 120 |
+
import argparse
|
| 121 |
+
from pprint import pprint
|
| 122 |
+
parser = argparse.ArgumentParser()
|
| 123 |
+
# parser.add_argument("--json_file", type=str, required=True)
|
| 124 |
+
parser.add_argument("--json_file", type=str, default="datasets/fake_train_data.json")
|
| 125 |
+
args = parser.parse_args()
|
| 126 |
+
dataset = FluxPairedDatasetV2(args.json_file, 512)
|
| 127 |
+
dataloder = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)
|
| 128 |
+
|
| 129 |
+
for i, data_dict in enumerate(dataloder):
|
| 130 |
+
pprint(i)
|
| 131 |
+
pprint(data_dict)
|
| 132 |
+
breakpoint()
|
uno/flux/math.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
|
| 2 |
+
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
|
| 3 |
+
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from einops import rearrange
|
| 18 |
+
from torch import Tensor
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
| 22 |
+
q, k = apply_rope(q, k, pe)
|
| 23 |
+
|
| 24 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
| 25 |
+
x = rearrange(x, "B H L D -> B L (H D)")
|
| 26 |
+
|
| 27 |
+
return x
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
| 31 |
+
assert dim % 2 == 0
|
| 32 |
+
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
| 33 |
+
omega = 1.0 / (theta**scale)
|
| 34 |
+
out = torch.einsum("...n,d->...nd", pos, omega)
|
| 35 |
+
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
| 36 |
+
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
| 37 |
+
return out.float()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
| 41 |
+
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
| 42 |
+
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
| 43 |
+
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
| 44 |
+
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
| 45 |
+
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
uno/flux/model.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
|
| 2 |
+
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
|
| 3 |
+
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from torch import Tensor, nn
|
| 20 |
+
|
| 21 |
+
from .modules.layers import DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class FluxParams:
|
| 26 |
+
in_channels: int
|
| 27 |
+
vec_in_dim: int
|
| 28 |
+
context_in_dim: int
|
| 29 |
+
hidden_size: int
|
| 30 |
+
mlp_ratio: float
|
| 31 |
+
num_heads: int
|
| 32 |
+
depth: int
|
| 33 |
+
depth_single_blocks: int
|
| 34 |
+
axes_dim: list[int]
|
| 35 |
+
theta: int
|
| 36 |
+
qkv_bias: bool
|
| 37 |
+
guidance_embed: bool
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class Flux(nn.Module):
|
| 41 |
+
"""
|
| 42 |
+
Transformer model for flow matching on sequences.
|
| 43 |
+
"""
|
| 44 |
+
_supports_gradient_checkpointing = True
|
| 45 |
+
|
| 46 |
+
def __init__(self, params: FluxParams):
|
| 47 |
+
super().__init__()
|
| 48 |
+
|
| 49 |
+
self.params = params
|
| 50 |
+
self.in_channels = params.in_channels
|
| 51 |
+
self.out_channels = self.in_channels
|
| 52 |
+
if params.hidden_size % params.num_heads != 0:
|
| 53 |
+
raise ValueError(
|
| 54 |
+
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
| 55 |
+
)
|
| 56 |
+
pe_dim = params.hidden_size // params.num_heads
|
| 57 |
+
if sum(params.axes_dim) != pe_dim:
|
| 58 |
+
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
| 59 |
+
self.hidden_size = params.hidden_size
|
| 60 |
+
self.num_heads = params.num_heads
|
| 61 |
+
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
| 62 |
+
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
| 63 |
+
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
| 64 |
+
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
| 65 |
+
self.guidance_in = (
|
| 66 |
+
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
|
| 67 |
+
)
|
| 68 |
+
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
| 69 |
+
|
| 70 |
+
self.double_blocks = nn.ModuleList(
|
| 71 |
+
[
|
| 72 |
+
DoubleStreamBlock(
|
| 73 |
+
self.hidden_size,
|
| 74 |
+
self.num_heads,
|
| 75 |
+
mlp_ratio=params.mlp_ratio,
|
| 76 |
+
qkv_bias=params.qkv_bias,
|
| 77 |
+
)
|
| 78 |
+
for _ in range(params.depth)
|
| 79 |
+
]
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
self.single_blocks = nn.ModuleList(
|
| 83 |
+
[
|
| 84 |
+
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
|
| 85 |
+
for _ in range(params.depth_single_blocks)
|
| 86 |
+
]
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
| 90 |
+
self.gradient_checkpointing = False
|
| 91 |
+
|
| 92 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 93 |
+
if hasattr(module, "gradient_checkpointing"):
|
| 94 |
+
module.gradient_checkpointing = value
|
| 95 |
+
|
| 96 |
+
@property
|
| 97 |
+
def attn_processors(self):
|
| 98 |
+
# set recursively
|
| 99 |
+
processors = {} # type: dict[str, nn.Module]
|
| 100 |
+
|
| 101 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
|
| 102 |
+
if hasattr(module, "set_processor"):
|
| 103 |
+
processors[f"{name}.processor"] = module.processor
|
| 104 |
+
|
| 105 |
+
for sub_name, child in module.named_children():
|
| 106 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 107 |
+
|
| 108 |
+
return processors
|
| 109 |
+
|
| 110 |
+
for name, module in self.named_children():
|
| 111 |
+
fn_recursive_add_processors(name, module, processors)
|
| 112 |
+
|
| 113 |
+
return processors
|
| 114 |
+
|
| 115 |
+
def set_attn_processor(self, processor):
|
| 116 |
+
r"""
|
| 117 |
+
Sets the attention processor to use to compute attention.
|
| 118 |
+
|
| 119 |
+
Parameters:
|
| 120 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 121 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 122 |
+
for **all** `Attention` layers.
|
| 123 |
+
|
| 124 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 125 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 126 |
+
|
| 127 |
+
"""
|
| 128 |
+
count = len(self.attn_processors.keys())
|
| 129 |
+
|
| 130 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 131 |
+
raise ValueError(
|
| 132 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 133 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 137 |
+
if hasattr(module, "set_processor"):
|
| 138 |
+
if not isinstance(processor, dict):
|
| 139 |
+
module.set_processor(processor)
|
| 140 |
+
else:
|
| 141 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 142 |
+
|
| 143 |
+
for sub_name, child in module.named_children():
|
| 144 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 145 |
+
|
| 146 |
+
for name, module in self.named_children():
|
| 147 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 148 |
+
|
| 149 |
+
def forward(
|
| 150 |
+
self,
|
| 151 |
+
img: Tensor,
|
| 152 |
+
img_ids: Tensor,
|
| 153 |
+
txt: Tensor,
|
| 154 |
+
txt_ids: Tensor,
|
| 155 |
+
timesteps: Tensor,
|
| 156 |
+
y: Tensor,
|
| 157 |
+
guidance: Tensor | None = None,
|
| 158 |
+
ref_img: Tensor | None = None,
|
| 159 |
+
ref_img_ids: Tensor | None = None,
|
| 160 |
+
) -> Tensor:
|
| 161 |
+
if img.ndim != 3 or txt.ndim != 3:
|
| 162 |
+
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
| 163 |
+
|
| 164 |
+
# running on sequences img
|
| 165 |
+
img = self.img_in(img)
|
| 166 |
+
vec = self.time_in(timestep_embedding(timesteps, 256))
|
| 167 |
+
if self.params.guidance_embed:
|
| 168 |
+
if guidance is None:
|
| 169 |
+
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
| 170 |
+
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
| 171 |
+
vec = vec + self.vector_in(y)
|
| 172 |
+
txt = self.txt_in(txt)
|
| 173 |
+
|
| 174 |
+
ids = torch.cat((txt_ids, img_ids), dim=1)
|
| 175 |
+
|
| 176 |
+
# concat ref_img/img
|
| 177 |
+
img_end = img.shape[1]
|
| 178 |
+
if ref_img is not None:
|
| 179 |
+
if isinstance(ref_img, tuple) or isinstance(ref_img, list):
|
| 180 |
+
img_in = [img] + [self.img_in(ref) for ref in ref_img]
|
| 181 |
+
img_ids = [ids] + [ref_ids for ref_ids in ref_img_ids]
|
| 182 |
+
img = torch.cat(img_in, dim=1)
|
| 183 |
+
ids = torch.cat(img_ids, dim=1)
|
| 184 |
+
else:
|
| 185 |
+
img = torch.cat((img, self.img_in(ref_img)), dim=1)
|
| 186 |
+
ids = torch.cat((ids, ref_img_ids), dim=1)
|
| 187 |
+
pe = self.pe_embedder(ids)
|
| 188 |
+
|
| 189 |
+
for index_block, block in enumerate(self.double_blocks):
|
| 190 |
+
if self.training and self.gradient_checkpointing:
|
| 191 |
+
img, txt = torch.utils.checkpoint.checkpoint(
|
| 192 |
+
block,
|
| 193 |
+
img=img,
|
| 194 |
+
txt=txt,
|
| 195 |
+
vec=vec,
|
| 196 |
+
pe=pe,
|
| 197 |
+
use_reentrant=False,
|
| 198 |
+
)
|
| 199 |
+
else:
|
| 200 |
+
img, txt = block(
|
| 201 |
+
img=img,
|
| 202 |
+
txt=txt,
|
| 203 |
+
vec=vec,
|
| 204 |
+
pe=pe
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
img = torch.cat((txt, img), 1)
|
| 208 |
+
for block in self.single_blocks:
|
| 209 |
+
if self.training and self.gradient_checkpointing:
|
| 210 |
+
img = torch.utils.checkpoint.checkpoint(
|
| 211 |
+
block,
|
| 212 |
+
img, vec=vec, pe=pe,
|
| 213 |
+
use_reentrant=False
|
| 214 |
+
)
|
| 215 |
+
else:
|
| 216 |
+
img = block(img, vec=vec, pe=pe)
|
| 217 |
+
img = img[:, txt.shape[1] :, ...]
|
| 218 |
+
# index img
|
| 219 |
+
img = img[:, :img_end, ...]
|
| 220 |
+
|
| 221 |
+
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
| 222 |
+
return img
|
uno/flux/modules/autoencoder.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
|
| 2 |
+
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
|
| 3 |
+
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from einops import rearrange
|
| 20 |
+
from torch import Tensor, nn
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class AutoEncoderParams:
|
| 25 |
+
resolution: int
|
| 26 |
+
in_channels: int
|
| 27 |
+
ch: int
|
| 28 |
+
out_ch: int
|
| 29 |
+
ch_mult: list[int]
|
| 30 |
+
num_res_blocks: int
|
| 31 |
+
z_channels: int
|
| 32 |
+
scale_factor: float
|
| 33 |
+
shift_factor: float
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def swish(x: Tensor) -> Tensor:
|
| 37 |
+
return x * torch.sigmoid(x)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class AttnBlock(nn.Module):
|
| 41 |
+
def __init__(self, in_channels: int):
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.in_channels = in_channels
|
| 44 |
+
|
| 45 |
+
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
| 46 |
+
|
| 47 |
+
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 48 |
+
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 49 |
+
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 50 |
+
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 51 |
+
|
| 52 |
+
def attention(self, h_: Tensor) -> Tensor:
|
| 53 |
+
h_ = self.norm(h_)
|
| 54 |
+
q = self.q(h_)
|
| 55 |
+
k = self.k(h_)
|
| 56 |
+
v = self.v(h_)
|
| 57 |
+
|
| 58 |
+
b, c, h, w = q.shape
|
| 59 |
+
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
| 60 |
+
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
| 61 |
+
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
| 62 |
+
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
| 63 |
+
|
| 64 |
+
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
| 65 |
+
|
| 66 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 67 |
+
return x + self.proj_out(self.attention(x))
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class ResnetBlock(nn.Module):
|
| 71 |
+
def __init__(self, in_channels: int, out_channels: int):
|
| 72 |
+
super().__init__()
|
| 73 |
+
self.in_channels = in_channels
|
| 74 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 75 |
+
self.out_channels = out_channels
|
| 76 |
+
|
| 77 |
+
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
| 78 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 79 |
+
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
| 80 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 81 |
+
if self.in_channels != self.out_channels:
|
| 82 |
+
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
| 83 |
+
|
| 84 |
+
def forward(self, x):
|
| 85 |
+
h = x
|
| 86 |
+
h = self.norm1(h)
|
| 87 |
+
h = swish(h)
|
| 88 |
+
h = self.conv1(h)
|
| 89 |
+
|
| 90 |
+
h = self.norm2(h)
|
| 91 |
+
h = swish(h)
|
| 92 |
+
h = self.conv2(h)
|
| 93 |
+
|
| 94 |
+
if self.in_channels != self.out_channels:
|
| 95 |
+
x = self.nin_shortcut(x)
|
| 96 |
+
|
| 97 |
+
return x + h
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class Downsample(nn.Module):
|
| 101 |
+
def __init__(self, in_channels: int):
|
| 102 |
+
super().__init__()
|
| 103 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
| 104 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
| 105 |
+
|
| 106 |
+
def forward(self, x: Tensor):
|
| 107 |
+
pad = (0, 1, 0, 1)
|
| 108 |
+
x = nn.functional.pad(x, pad, mode="constant", value=0)
|
| 109 |
+
x = self.conv(x)
|
| 110 |
+
return x
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class Upsample(nn.Module):
|
| 114 |
+
def __init__(self, in_channels: int):
|
| 115 |
+
super().__init__()
|
| 116 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
| 117 |
+
|
| 118 |
+
def forward(self, x: Tensor):
|
| 119 |
+
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
| 120 |
+
x = self.conv(x)
|
| 121 |
+
return x
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class Encoder(nn.Module):
|
| 125 |
+
def __init__(
|
| 126 |
+
self,
|
| 127 |
+
resolution: int,
|
| 128 |
+
in_channels: int,
|
| 129 |
+
ch: int,
|
| 130 |
+
ch_mult: list[int],
|
| 131 |
+
num_res_blocks: int,
|
| 132 |
+
z_channels: int,
|
| 133 |
+
):
|
| 134 |
+
super().__init__()
|
| 135 |
+
self.ch = ch
|
| 136 |
+
self.num_resolutions = len(ch_mult)
|
| 137 |
+
self.num_res_blocks = num_res_blocks
|
| 138 |
+
self.resolution = resolution
|
| 139 |
+
self.in_channels = in_channels
|
| 140 |
+
# downsampling
|
| 141 |
+
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
| 142 |
+
|
| 143 |
+
curr_res = resolution
|
| 144 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
| 145 |
+
self.in_ch_mult = in_ch_mult
|
| 146 |
+
self.down = nn.ModuleList()
|
| 147 |
+
block_in = self.ch
|
| 148 |
+
for i_level in range(self.num_resolutions):
|
| 149 |
+
block = nn.ModuleList()
|
| 150 |
+
attn = nn.ModuleList()
|
| 151 |
+
block_in = ch * in_ch_mult[i_level]
|
| 152 |
+
block_out = ch * ch_mult[i_level]
|
| 153 |
+
for _ in range(self.num_res_blocks):
|
| 154 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
| 155 |
+
block_in = block_out
|
| 156 |
+
down = nn.Module()
|
| 157 |
+
down.block = block
|
| 158 |
+
down.attn = attn
|
| 159 |
+
if i_level != self.num_resolutions - 1:
|
| 160 |
+
down.downsample = Downsample(block_in)
|
| 161 |
+
curr_res = curr_res // 2
|
| 162 |
+
self.down.append(down)
|
| 163 |
+
|
| 164 |
+
# middle
|
| 165 |
+
self.mid = nn.Module()
|
| 166 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 167 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 168 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 169 |
+
|
| 170 |
+
# end
|
| 171 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
| 172 |
+
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
| 173 |
+
|
| 174 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 175 |
+
# downsampling
|
| 176 |
+
hs = [self.conv_in(x)]
|
| 177 |
+
for i_level in range(self.num_resolutions):
|
| 178 |
+
for i_block in range(self.num_res_blocks):
|
| 179 |
+
h = self.down[i_level].block[i_block](hs[-1])
|
| 180 |
+
if len(self.down[i_level].attn) > 0:
|
| 181 |
+
h = self.down[i_level].attn[i_block](h)
|
| 182 |
+
hs.append(h)
|
| 183 |
+
if i_level != self.num_resolutions - 1:
|
| 184 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
| 185 |
+
|
| 186 |
+
# middle
|
| 187 |
+
h = hs[-1]
|
| 188 |
+
h = self.mid.block_1(h)
|
| 189 |
+
h = self.mid.attn_1(h)
|
| 190 |
+
h = self.mid.block_2(h)
|
| 191 |
+
# end
|
| 192 |
+
h = self.norm_out(h)
|
| 193 |
+
h = swish(h)
|
| 194 |
+
h = self.conv_out(h)
|
| 195 |
+
return h
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class Decoder(nn.Module):
|
| 199 |
+
def __init__(
|
| 200 |
+
self,
|
| 201 |
+
ch: int,
|
| 202 |
+
out_ch: int,
|
| 203 |
+
ch_mult: list[int],
|
| 204 |
+
num_res_blocks: int,
|
| 205 |
+
in_channels: int,
|
| 206 |
+
resolution: int,
|
| 207 |
+
z_channels: int,
|
| 208 |
+
):
|
| 209 |
+
super().__init__()
|
| 210 |
+
self.ch = ch
|
| 211 |
+
self.num_resolutions = len(ch_mult)
|
| 212 |
+
self.num_res_blocks = num_res_blocks
|
| 213 |
+
self.resolution = resolution
|
| 214 |
+
self.in_channels = in_channels
|
| 215 |
+
self.ffactor = 2 ** (self.num_resolutions - 1)
|
| 216 |
+
|
| 217 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
| 218 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
| 219 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
| 220 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
| 221 |
+
|
| 222 |
+
# z to block_in
|
| 223 |
+
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
| 224 |
+
|
| 225 |
+
# middle
|
| 226 |
+
self.mid = nn.Module()
|
| 227 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 228 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 229 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 230 |
+
|
| 231 |
+
# upsampling
|
| 232 |
+
self.up = nn.ModuleList()
|
| 233 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 234 |
+
block = nn.ModuleList()
|
| 235 |
+
attn = nn.ModuleList()
|
| 236 |
+
block_out = ch * ch_mult[i_level]
|
| 237 |
+
for _ in range(self.num_res_blocks + 1):
|
| 238 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
| 239 |
+
block_in = block_out
|
| 240 |
+
up = nn.Module()
|
| 241 |
+
up.block = block
|
| 242 |
+
up.attn = attn
|
| 243 |
+
if i_level != 0:
|
| 244 |
+
up.upsample = Upsample(block_in)
|
| 245 |
+
curr_res = curr_res * 2
|
| 246 |
+
self.up.insert(0, up) # prepend to get consistent order
|
| 247 |
+
|
| 248 |
+
# end
|
| 249 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
| 250 |
+
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
| 251 |
+
|
| 252 |
+
def forward(self, z: Tensor) -> Tensor:
|
| 253 |
+
# z to block_in
|
| 254 |
+
h = self.conv_in(z)
|
| 255 |
+
|
| 256 |
+
# middle
|
| 257 |
+
h = self.mid.block_1(h)
|
| 258 |
+
h = self.mid.attn_1(h)
|
| 259 |
+
h = self.mid.block_2(h)
|
| 260 |
+
|
| 261 |
+
# upsampling
|
| 262 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 263 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 264 |
+
h = self.up[i_level].block[i_block](h)
|
| 265 |
+
if len(self.up[i_level].attn) > 0:
|
| 266 |
+
h = self.up[i_level].attn[i_block](h)
|
| 267 |
+
if i_level != 0:
|
| 268 |
+
h = self.up[i_level].upsample(h)
|
| 269 |
+
|
| 270 |
+
# end
|
| 271 |
+
h = self.norm_out(h)
|
| 272 |
+
h = swish(h)
|
| 273 |
+
h = self.conv_out(h)
|
| 274 |
+
return h
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class DiagonalGaussian(nn.Module):
|
| 278 |
+
def __init__(self, sample: bool = True, chunk_dim: int = 1):
|
| 279 |
+
super().__init__()
|
| 280 |
+
self.sample = sample
|
| 281 |
+
self.chunk_dim = chunk_dim
|
| 282 |
+
|
| 283 |
+
def forward(self, z: Tensor) -> Tensor:
|
| 284 |
+
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
| 285 |
+
if self.sample:
|
| 286 |
+
std = torch.exp(0.5 * logvar)
|
| 287 |
+
return mean + std * torch.randn_like(mean)
|
| 288 |
+
else:
|
| 289 |
+
return mean
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
class AutoEncoder(nn.Module):
|
| 293 |
+
def __init__(self, params: AutoEncoderParams):
|
| 294 |
+
super().__init__()
|
| 295 |
+
self.encoder = Encoder(
|
| 296 |
+
resolution=params.resolution,
|
| 297 |
+
in_channels=params.in_channels,
|
| 298 |
+
ch=params.ch,
|
| 299 |
+
ch_mult=params.ch_mult,
|
| 300 |
+
num_res_blocks=params.num_res_blocks,
|
| 301 |
+
z_channels=params.z_channels,
|
| 302 |
+
)
|
| 303 |
+
self.decoder = Decoder(
|
| 304 |
+
resolution=params.resolution,
|
| 305 |
+
in_channels=params.in_channels,
|
| 306 |
+
ch=params.ch,
|
| 307 |
+
out_ch=params.out_ch,
|
| 308 |
+
ch_mult=params.ch_mult,
|
| 309 |
+
num_res_blocks=params.num_res_blocks,
|
| 310 |
+
z_channels=params.z_channels,
|
| 311 |
+
)
|
| 312 |
+
self.reg = DiagonalGaussian()
|
| 313 |
+
|
| 314 |
+
self.scale_factor = params.scale_factor
|
| 315 |
+
self.shift_factor = params.shift_factor
|
| 316 |
+
|
| 317 |
+
def encode(self, x: Tensor) -> Tensor:
|
| 318 |
+
z = self.reg(self.encoder(x))
|
| 319 |
+
z = self.scale_factor * (z - self.shift_factor)
|
| 320 |
+
return z
|
| 321 |
+
|
| 322 |
+
def decode(self, z: Tensor) -> Tensor:
|
| 323 |
+
z = z / self.scale_factor + self.shift_factor
|
| 324 |
+
return self.decoder(z)
|
| 325 |
+
|
| 326 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 327 |
+
return self.decode(self.encode(x))
|
uno/flux/modules/conditioner.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
|
| 2 |
+
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
|
| 3 |
+
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from torch import Tensor, nn
|
| 17 |
+
from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
|
| 18 |
+
T5Tokenizer)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class HFEmbedder(nn.Module):
|
| 22 |
+
def __init__(self, version: str, max_length: int, **hf_kwargs):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.is_clip = version.startswith("openai")
|
| 25 |
+
self.max_length = max_length
|
| 26 |
+
self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
|
| 27 |
+
|
| 28 |
+
if self.is_clip:
|
| 29 |
+
self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
|
| 30 |
+
self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
|
| 31 |
+
else:
|
| 32 |
+
self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
|
| 33 |
+
self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
|
| 34 |
+
|
| 35 |
+
self.hf_module = self.hf_module.eval().requires_grad_(False)
|
| 36 |
+
|
| 37 |
+
def forward(self, text: list[str]) -> Tensor:
|
| 38 |
+
batch_encoding = self.tokenizer(
|
| 39 |
+
text,
|
| 40 |
+
truncation=True,
|
| 41 |
+
max_length=self.max_length,
|
| 42 |
+
return_length=False,
|
| 43 |
+
return_overflowing_tokens=False,
|
| 44 |
+
padding="max_length",
|
| 45 |
+
return_tensors="pt",
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
outputs = self.hf_module(
|
| 49 |
+
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
|
| 50 |
+
attention_mask=None,
|
| 51 |
+
output_hidden_states=False,
|
| 52 |
+
)
|
| 53 |
+
return outputs[self.output_key]
|
uno/flux/modules/layers.py
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
|
| 2 |
+
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
|
| 3 |
+
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import math
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from einops import rearrange
|
| 21 |
+
from torch import Tensor, nn
|
| 22 |
+
|
| 23 |
+
from ..math import attention, rope
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
|
| 26 |
+
class EmbedND(nn.Module):
|
| 27 |
+
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
| 28 |
+
super().__init__()
|
| 29 |
+
self.dim = dim
|
| 30 |
+
self.theta = theta
|
| 31 |
+
self.axes_dim = axes_dim
|
| 32 |
+
|
| 33 |
+
def forward(self, ids: Tensor) -> Tensor:
|
| 34 |
+
n_axes = ids.shape[-1]
|
| 35 |
+
emb = torch.cat(
|
| 36 |
+
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
| 37 |
+
dim=-3,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
return emb.unsqueeze(1)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
| 44 |
+
"""
|
| 45 |
+
Create sinusoidal timestep embeddings.
|
| 46 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
| 47 |
+
These may be fractional.
|
| 48 |
+
:param dim: the dimension of the output.
|
| 49 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 50 |
+
:return: an (N, D) Tensor of positional embeddings.
|
| 51 |
+
"""
|
| 52 |
+
t = time_factor * t
|
| 53 |
+
half = dim // 2
|
| 54 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
| 55 |
+
t.device
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
args = t[:, None].float() * freqs[None]
|
| 59 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 60 |
+
if dim % 2:
|
| 61 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 62 |
+
if torch.is_floating_point(t):
|
| 63 |
+
embedding = embedding.to(t)
|
| 64 |
+
return embedding
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class MLPEmbedder(nn.Module):
|
| 68 |
+
def __init__(self, in_dim: int, hidden_dim: int):
|
| 69 |
+
super().__init__()
|
| 70 |
+
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
| 71 |
+
self.silu = nn.SiLU()
|
| 72 |
+
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
| 73 |
+
|
| 74 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 75 |
+
return self.out_layer(self.silu(self.in_layer(x)))
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class RMSNorm(torch.nn.Module):
|
| 79 |
+
def __init__(self, dim: int):
|
| 80 |
+
super().__init__()
|
| 81 |
+
self.scale = nn.Parameter(torch.ones(dim))
|
| 82 |
+
|
| 83 |
+
def forward(self, x: Tensor):
|
| 84 |
+
x_dtype = x.dtype
|
| 85 |
+
x = x.float()
|
| 86 |
+
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
| 87 |
+
return (x * rrms).to(dtype=x_dtype) * self.scale
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class QKNorm(torch.nn.Module):
|
| 91 |
+
def __init__(self, dim: int):
|
| 92 |
+
super().__init__()
|
| 93 |
+
self.query_norm = RMSNorm(dim)
|
| 94 |
+
self.key_norm = RMSNorm(dim)
|
| 95 |
+
|
| 96 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
| 97 |
+
q = self.query_norm(q)
|
| 98 |
+
k = self.key_norm(k)
|
| 99 |
+
return q.to(v), k.to(v)
|
| 100 |
+
|
| 101 |
+
class LoRALinearLayer(nn.Module):
|
| 102 |
+
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
|
| 103 |
+
super().__init__()
|
| 104 |
+
|
| 105 |
+
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
|
| 106 |
+
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
|
| 107 |
+
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
| 108 |
+
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
| 109 |
+
self.network_alpha = network_alpha
|
| 110 |
+
self.rank = rank
|
| 111 |
+
|
| 112 |
+
nn.init.normal_(self.down.weight, std=1 / rank)
|
| 113 |
+
nn.init.zeros_(self.up.weight)
|
| 114 |
+
|
| 115 |
+
def forward(self, hidden_states):
|
| 116 |
+
orig_dtype = hidden_states.dtype
|
| 117 |
+
dtype = self.down.weight.dtype
|
| 118 |
+
|
| 119 |
+
down_hidden_states = self.down(hidden_states.to(dtype))
|
| 120 |
+
up_hidden_states = self.up(down_hidden_states)
|
| 121 |
+
|
| 122 |
+
if self.network_alpha is not None:
|
| 123 |
+
up_hidden_states *= self.network_alpha / self.rank
|
| 124 |
+
|
| 125 |
+
return up_hidden_states.to(orig_dtype)
|
| 126 |
+
|
| 127 |
+
class FLuxSelfAttnProcessor:
|
| 128 |
+
def __call__(self, attn, x, pe, **attention_kwargs):
|
| 129 |
+
qkv = attn.qkv(x)
|
| 130 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
| 131 |
+
q, k = attn.norm(q, k, v)
|
| 132 |
+
x = attention(q, k, v, pe=pe)
|
| 133 |
+
x = attn.proj(x)
|
| 134 |
+
return x
|
| 135 |
+
|
| 136 |
+
class LoraFluxAttnProcessor(nn.Module):
|
| 137 |
+
|
| 138 |
+
def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
|
| 139 |
+
super().__init__()
|
| 140 |
+
self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
|
| 141 |
+
self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
|
| 142 |
+
self.lora_weight = lora_weight
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def __call__(self, attn, x, pe, **attention_kwargs):
|
| 146 |
+
qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
|
| 147 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
| 148 |
+
q, k = attn.norm(q, k, v)
|
| 149 |
+
x = attention(q, k, v, pe=pe)
|
| 150 |
+
x = attn.proj(x) + self.proj_lora(x) * self.lora_weight
|
| 151 |
+
return x
|
| 152 |
+
|
| 153 |
+
class SelfAttention(nn.Module):
|
| 154 |
+
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
| 155 |
+
super().__init__()
|
| 156 |
+
self.num_heads = num_heads
|
| 157 |
+
head_dim = dim // num_heads
|
| 158 |
+
|
| 159 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 160 |
+
self.norm = QKNorm(head_dim)
|
| 161 |
+
self.proj = nn.Linear(dim, dim)
|
| 162 |
+
def forward():
|
| 163 |
+
pass
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
@dataclass
|
| 167 |
+
class ModulationOut:
|
| 168 |
+
shift: Tensor
|
| 169 |
+
scale: Tensor
|
| 170 |
+
gate: Tensor
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class Modulation(nn.Module):
|
| 174 |
+
def __init__(self, dim: int, double: bool):
|
| 175 |
+
super().__init__()
|
| 176 |
+
self.is_double = double
|
| 177 |
+
self.multiplier = 6 if double else 3
|
| 178 |
+
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
| 179 |
+
|
| 180 |
+
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
| 181 |
+
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
| 182 |
+
|
| 183 |
+
return (
|
| 184 |
+
ModulationOut(*out[:3]),
|
| 185 |
+
ModulationOut(*out[3:]) if self.is_double else None,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
class DoubleStreamBlockLoraProcessor(nn.Module):
|
| 189 |
+
def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
|
| 190 |
+
super().__init__()
|
| 191 |
+
self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
|
| 192 |
+
self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
|
| 193 |
+
self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
|
| 194 |
+
self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
|
| 195 |
+
self.lora_weight = lora_weight
|
| 196 |
+
|
| 197 |
+
def forward(self, attn, img, txt, vec, pe, **attention_kwargs):
|
| 198 |
+
img_mod1, img_mod2 = attn.img_mod(vec)
|
| 199 |
+
txt_mod1, txt_mod2 = attn.txt_mod(vec)
|
| 200 |
+
|
| 201 |
+
# prepare image for attention
|
| 202 |
+
img_modulated = attn.img_norm1(img)
|
| 203 |
+
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
| 204 |
+
img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight
|
| 205 |
+
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
|
| 206 |
+
img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
|
| 207 |
+
|
| 208 |
+
# prepare txt for attention
|
| 209 |
+
txt_modulated = attn.txt_norm1(txt)
|
| 210 |
+
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
| 211 |
+
txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight
|
| 212 |
+
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
|
| 213 |
+
txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
|
| 214 |
+
|
| 215 |
+
# run actual attention
|
| 216 |
+
q = torch.cat((txt_q, img_q), dim=2)
|
| 217 |
+
k = torch.cat((txt_k, img_k), dim=2)
|
| 218 |
+
v = torch.cat((txt_v, img_v), dim=2)
|
| 219 |
+
|
| 220 |
+
attn1 = attention(q, k, v, pe=pe)
|
| 221 |
+
txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
|
| 222 |
+
|
| 223 |
+
# calculate the img bloks
|
| 224 |
+
img = img + img_mod1.gate * (attn.img_attn.proj(img_attn) + self.proj_lora1(img_attn) * self.lora_weight)
|
| 225 |
+
img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
|
| 226 |
+
|
| 227 |
+
# calculate the txt bloks
|
| 228 |
+
txt = txt + txt_mod1.gate * (attn.txt_attn.proj(txt_attn) + self.proj_lora2(txt_attn) * self.lora_weight)
|
| 229 |
+
txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
|
| 230 |
+
return img, txt
|
| 231 |
+
|
| 232 |
+
class DoubleStreamBlockProcessor:
|
| 233 |
+
def __call__(self, attn, img, txt, vec, pe, **attention_kwargs):
|
| 234 |
+
img_mod1, img_mod2 = attn.img_mod(vec)
|
| 235 |
+
txt_mod1, txt_mod2 = attn.txt_mod(vec)
|
| 236 |
+
|
| 237 |
+
# prepare image for attention
|
| 238 |
+
img_modulated = attn.img_norm1(img)
|
| 239 |
+
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
| 240 |
+
img_qkv = attn.img_attn.qkv(img_modulated)
|
| 241 |
+
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
|
| 242 |
+
img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
|
| 243 |
+
|
| 244 |
+
# prepare txt for attention
|
| 245 |
+
txt_modulated = attn.txt_norm1(txt)
|
| 246 |
+
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
| 247 |
+
txt_qkv = attn.txt_attn.qkv(txt_modulated)
|
| 248 |
+
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
|
| 249 |
+
txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
|
| 250 |
+
|
| 251 |
+
# run actual attention
|
| 252 |
+
q = torch.cat((txt_q, img_q), dim=2)
|
| 253 |
+
k = torch.cat((txt_k, img_k), dim=2)
|
| 254 |
+
v = torch.cat((txt_v, img_v), dim=2)
|
| 255 |
+
|
| 256 |
+
attn1 = attention(q, k, v, pe=pe)
|
| 257 |
+
txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
|
| 258 |
+
|
| 259 |
+
# calculate the img bloks
|
| 260 |
+
img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
|
| 261 |
+
img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
|
| 262 |
+
|
| 263 |
+
# calculate the txt bloks
|
| 264 |
+
txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
|
| 265 |
+
txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
|
| 266 |
+
return img, txt
|
| 267 |
+
|
| 268 |
+
class DoubleStreamBlock(nn.Module):
|
| 269 |
+
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
|
| 270 |
+
super().__init__()
|
| 271 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 272 |
+
self.num_heads = num_heads
|
| 273 |
+
self.hidden_size = hidden_size
|
| 274 |
+
self.head_dim = hidden_size // num_heads
|
| 275 |
+
|
| 276 |
+
self.img_mod = Modulation(hidden_size, double=True)
|
| 277 |
+
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 278 |
+
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
| 279 |
+
|
| 280 |
+
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 281 |
+
self.img_mlp = nn.Sequential(
|
| 282 |
+
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
| 283 |
+
nn.GELU(approximate="tanh"),
|
| 284 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
self.txt_mod = Modulation(hidden_size, double=True)
|
| 288 |
+
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 289 |
+
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
| 290 |
+
|
| 291 |
+
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 292 |
+
self.txt_mlp = nn.Sequential(
|
| 293 |
+
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
| 294 |
+
nn.GELU(approximate="tanh"),
|
| 295 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
| 296 |
+
)
|
| 297 |
+
processor = DoubleStreamBlockProcessor()
|
| 298 |
+
self.set_processor(processor)
|
| 299 |
+
|
| 300 |
+
def set_processor(self, processor) -> None:
|
| 301 |
+
self.processor = processor
|
| 302 |
+
|
| 303 |
+
def get_processor(self):
|
| 304 |
+
return self.processor
|
| 305 |
+
|
| 306 |
+
def forward(
|
| 307 |
+
self,
|
| 308 |
+
img: Tensor,
|
| 309 |
+
txt: Tensor,
|
| 310 |
+
vec: Tensor,
|
| 311 |
+
pe: Tensor,
|
| 312 |
+
image_proj: Tensor = None,
|
| 313 |
+
ip_scale: float =1.0,
|
| 314 |
+
) -> tuple[Tensor, Tensor]:
|
| 315 |
+
if image_proj is None:
|
| 316 |
+
return self.processor(self, img, txt, vec, pe)
|
| 317 |
+
else:
|
| 318 |
+
return self.processor(self, img, txt, vec, pe, image_proj, ip_scale)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
class SingleStreamBlockLoraProcessor(nn.Module):
|
| 322 |
+
def __init__(self, dim: int, rank: int = 4, network_alpha = None, lora_weight: float = 1):
|
| 323 |
+
super().__init__()
|
| 324 |
+
self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
|
| 325 |
+
self.proj_lora = LoRALinearLayer(15360, dim, rank, network_alpha)
|
| 326 |
+
self.lora_weight = lora_weight
|
| 327 |
+
|
| 328 |
+
def forward(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
| 329 |
+
|
| 330 |
+
mod, _ = attn.modulation(vec)
|
| 331 |
+
x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
|
| 332 |
+
qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
|
| 333 |
+
qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight
|
| 334 |
+
|
| 335 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
|
| 336 |
+
q, k = attn.norm(q, k, v)
|
| 337 |
+
|
| 338 |
+
# compute attention
|
| 339 |
+
attn_1 = attention(q, k, v, pe=pe)
|
| 340 |
+
|
| 341 |
+
# compute activation in mlp stream, cat again and run second linear layer
|
| 342 |
+
output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
|
| 343 |
+
output = output + self.proj_lora(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) * self.lora_weight
|
| 344 |
+
output = x + mod.gate * output
|
| 345 |
+
return output
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
class SingleStreamBlockProcessor:
|
| 349 |
+
def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor, **attention_kwargs) -> Tensor:
|
| 350 |
+
|
| 351 |
+
mod, _ = attn.modulation(vec)
|
| 352 |
+
x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
|
| 353 |
+
qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
|
| 354 |
+
|
| 355 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
|
| 356 |
+
q, k = attn.norm(q, k, v)
|
| 357 |
+
|
| 358 |
+
# compute attention
|
| 359 |
+
attn_1 = attention(q, k, v, pe=pe)
|
| 360 |
+
|
| 361 |
+
# compute activation in mlp stream, cat again and run second linear layer
|
| 362 |
+
output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
|
| 363 |
+
output = x + mod.gate * output
|
| 364 |
+
return output
|
| 365 |
+
|
| 366 |
+
class SingleStreamBlock(nn.Module):
|
| 367 |
+
"""
|
| 368 |
+
A DiT block with parallel linear layers as described in
|
| 369 |
+
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
| 370 |
+
"""
|
| 371 |
+
|
| 372 |
+
def __init__(
|
| 373 |
+
self,
|
| 374 |
+
hidden_size: int,
|
| 375 |
+
num_heads: int,
|
| 376 |
+
mlp_ratio: float = 4.0,
|
| 377 |
+
qk_scale: float | None = None,
|
| 378 |
+
):
|
| 379 |
+
super().__init__()
|
| 380 |
+
self.hidden_dim = hidden_size
|
| 381 |
+
self.num_heads = num_heads
|
| 382 |
+
self.head_dim = hidden_size // num_heads
|
| 383 |
+
self.scale = qk_scale or self.head_dim**-0.5
|
| 384 |
+
|
| 385 |
+
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 386 |
+
# qkv and mlp_in
|
| 387 |
+
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
| 388 |
+
# proj and mlp_out
|
| 389 |
+
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
| 390 |
+
|
| 391 |
+
self.norm = QKNorm(self.head_dim)
|
| 392 |
+
|
| 393 |
+
self.hidden_size = hidden_size
|
| 394 |
+
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 395 |
+
|
| 396 |
+
self.mlp_act = nn.GELU(approximate="tanh")
|
| 397 |
+
self.modulation = Modulation(hidden_size, double=False)
|
| 398 |
+
|
| 399 |
+
processor = SingleStreamBlockProcessor()
|
| 400 |
+
self.set_processor(processor)
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
def set_processor(self, processor) -> None:
|
| 404 |
+
self.processor = processor
|
| 405 |
+
|
| 406 |
+
def get_processor(self):
|
| 407 |
+
return self.processor
|
| 408 |
+
|
| 409 |
+
def forward(
|
| 410 |
+
self,
|
| 411 |
+
x: Tensor,
|
| 412 |
+
vec: Tensor,
|
| 413 |
+
pe: Tensor,
|
| 414 |
+
image_proj: Tensor | None = None,
|
| 415 |
+
ip_scale: float = 1.0,
|
| 416 |
+
) -> Tensor:
|
| 417 |
+
if image_proj is None:
|
| 418 |
+
return self.processor(self, x, vec, pe)
|
| 419 |
+
else:
|
| 420 |
+
return self.processor(self, x, vec, pe, image_proj, ip_scale)
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
class LastLayer(nn.Module):
|
| 425 |
+
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
| 426 |
+
super().__init__()
|
| 427 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 428 |
+
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
| 429 |
+
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
| 430 |
+
|
| 431 |
+
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
| 432 |
+
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
| 433 |
+
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
| 434 |
+
x = self.linear(x)
|
| 435 |
+
return x
|
uno/flux/pipeline.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
|
| 2 |
+
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
|
| 3 |
+
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
from typing import Literal
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from einops import rearrange
|
| 21 |
+
from PIL import ExifTags, Image
|
| 22 |
+
import torchvision.transforms.functional as TVF
|
| 23 |
+
|
| 24 |
+
from uno.flux.modules.layers import (
|
| 25 |
+
DoubleStreamBlockLoraProcessor,
|
| 26 |
+
DoubleStreamBlockProcessor,
|
| 27 |
+
SingleStreamBlockLoraProcessor,
|
| 28 |
+
SingleStreamBlockProcessor,
|
| 29 |
+
)
|
| 30 |
+
from uno.flux.sampling import denoise, get_noise, get_schedule, prepare_multi_ip, unpack
|
| 31 |
+
from uno.flux.util import (
|
| 32 |
+
get_lora_rank,
|
| 33 |
+
load_ae,
|
| 34 |
+
load_checkpoint,
|
| 35 |
+
load_clip,
|
| 36 |
+
load_flow_model,
|
| 37 |
+
load_flow_model_only_lora,
|
| 38 |
+
load_flow_model_quintized,
|
| 39 |
+
load_t5,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def find_nearest_scale(image_h, image_w, predefined_scales):
|
| 44 |
+
"""
|
| 45 |
+
根据图片的高度和宽度,找到最近的预定义尺度。
|
| 46 |
+
|
| 47 |
+
:param image_h: 图片的高度
|
| 48 |
+
:param image_w: 图片的宽度
|
| 49 |
+
:param predefined_scales: 预定义尺度列表 [(h1, w1), (h2, w2), ...]
|
| 50 |
+
:return: 最近的预定义尺度 (h, w)
|
| 51 |
+
"""
|
| 52 |
+
# 计算输入图片的长宽比
|
| 53 |
+
image_ratio = image_h / image_w
|
| 54 |
+
|
| 55 |
+
# 初始化变量以存储最小差异和最近的尺度
|
| 56 |
+
min_diff = float('inf')
|
| 57 |
+
nearest_scale = None
|
| 58 |
+
|
| 59 |
+
# 遍历所有预定义尺度,找到与输入图片长宽比最接近的尺度
|
| 60 |
+
for scale_h, scale_w in predefined_scales:
|
| 61 |
+
predefined_ratio = scale_h / scale_w
|
| 62 |
+
diff = abs(predefined_ratio - image_ratio)
|
| 63 |
+
|
| 64 |
+
if diff < min_diff:
|
| 65 |
+
min_diff = diff
|
| 66 |
+
nearest_scale = (scale_h, scale_w)
|
| 67 |
+
|
| 68 |
+
return nearest_scale
|
| 69 |
+
|
| 70 |
+
def preprocess_ref(raw_image: Image.Image, long_size: int = 512):
|
| 71 |
+
# 获取原始图像的宽度和高度
|
| 72 |
+
image_w, image_h = raw_image.size
|
| 73 |
+
|
| 74 |
+
# 计算长边和短边
|
| 75 |
+
if image_w >= image_h:
|
| 76 |
+
new_w = long_size
|
| 77 |
+
new_h = int((long_size / image_w) * image_h)
|
| 78 |
+
else:
|
| 79 |
+
new_h = long_size
|
| 80 |
+
new_w = int((long_size / image_h) * image_w)
|
| 81 |
+
|
| 82 |
+
# 按新的宽高进行等比例缩放
|
| 83 |
+
raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS)
|
| 84 |
+
target_w = new_w // 16 * 16
|
| 85 |
+
target_h = new_h // 16 * 16
|
| 86 |
+
|
| 87 |
+
# 计算裁剪的起始坐标以实现中心裁剪
|
| 88 |
+
left = (new_w - target_w) // 2
|
| 89 |
+
top = (new_h - target_h) // 2
|
| 90 |
+
right = left + target_w
|
| 91 |
+
bottom = top + target_h
|
| 92 |
+
|
| 93 |
+
# 进行中心裁剪
|
| 94 |
+
raw_image = raw_image.crop((left, top, right, bottom))
|
| 95 |
+
|
| 96 |
+
# 转换为 RGB 模式
|
| 97 |
+
raw_image = raw_image.convert("RGB")
|
| 98 |
+
return raw_image
|
| 99 |
+
|
| 100 |
+
class UNOPipeline:
|
| 101 |
+
def __init__(
|
| 102 |
+
self,
|
| 103 |
+
model_type: str,
|
| 104 |
+
device: torch.device,
|
| 105 |
+
offload: bool = False,
|
| 106 |
+
only_lora: bool = False,
|
| 107 |
+
lora_rank: int = 16
|
| 108 |
+
):
|
| 109 |
+
self.device = device
|
| 110 |
+
self.offload = offload
|
| 111 |
+
self.model_type = model_type
|
| 112 |
+
|
| 113 |
+
self.clip = load_clip(self.device)
|
| 114 |
+
self.t5 = load_t5(self.device, max_length=512)
|
| 115 |
+
self.ae = load_ae(model_type, device="cpu" if offload else self.device)
|
| 116 |
+
if "fp8" in model_type:
|
| 117 |
+
self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.device)
|
| 118 |
+
elif only_lora:
|
| 119 |
+
self.model = load_flow_model_only_lora(
|
| 120 |
+
model_type, device="cpu" if offload else self.device, lora_rank=lora_rank
|
| 121 |
+
)
|
| 122 |
+
else:
|
| 123 |
+
self.model = load_flow_model(model_type, device="cpu" if offload else self.device)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def load_ckpt(self, ckpt_path):
|
| 127 |
+
if ckpt_path is not None:
|
| 128 |
+
from safetensors.torch import load_file as load_sft
|
| 129 |
+
print("Loading checkpoint to replace old keys")
|
| 130 |
+
# load_sft doesn't support torch.device
|
| 131 |
+
if ckpt_path.endswith('safetensors'):
|
| 132 |
+
sd = load_sft(ckpt_path, device='cpu')
|
| 133 |
+
missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
|
| 134 |
+
else:
|
| 135 |
+
dit_state = torch.load(ckpt_path, map_location='cpu')
|
| 136 |
+
sd = {}
|
| 137 |
+
for k in dit_state.keys():
|
| 138 |
+
sd[k.replace('module.','')] = dit_state[k]
|
| 139 |
+
missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
|
| 140 |
+
self.model.to(str(self.device))
|
| 141 |
+
print(f"missing keys: {missing}\n\n\n\n\nunexpected keys: {unexpected}")
|
| 142 |
+
|
| 143 |
+
def set_lora(self, local_path: str = None, repo_id: str = None,
|
| 144 |
+
name: str = None, lora_weight: int = 0.7):
|
| 145 |
+
checkpoint = load_checkpoint(local_path, repo_id, name)
|
| 146 |
+
self.update_model_with_lora(checkpoint, lora_weight)
|
| 147 |
+
|
| 148 |
+
def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: int = 0.7):
|
| 149 |
+
checkpoint = load_checkpoint(
|
| 150 |
+
None, self.hf_lora_collection, self.lora_types_to_names[lora_type]
|
| 151 |
+
)
|
| 152 |
+
self.update_model_with_lora(checkpoint, lora_weight)
|
| 153 |
+
|
| 154 |
+
def update_model_with_lora(self, checkpoint, lora_weight):
|
| 155 |
+
rank = get_lora_rank(checkpoint)
|
| 156 |
+
lora_attn_procs = {}
|
| 157 |
+
|
| 158 |
+
for name, _ in self.model.attn_processors.items():
|
| 159 |
+
lora_state_dict = {}
|
| 160 |
+
for k in checkpoint.keys():
|
| 161 |
+
if name in k:
|
| 162 |
+
lora_state_dict[k[len(name) + 1:]] = checkpoint[k] * lora_weight
|
| 163 |
+
|
| 164 |
+
if len(lora_state_dict):
|
| 165 |
+
if name.startswith("single_blocks"):
|
| 166 |
+
lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=3072, rank=rank)
|
| 167 |
+
else:
|
| 168 |
+
lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank)
|
| 169 |
+
lora_attn_procs[name].load_state_dict(lora_state_dict)
|
| 170 |
+
lora_attn_procs[name].to(self.device)
|
| 171 |
+
else:
|
| 172 |
+
if name.startswith("single_blocks"):
|
| 173 |
+
lora_attn_procs[name] = SingleStreamBlockProcessor()
|
| 174 |
+
else:
|
| 175 |
+
lora_attn_procs[name] = DoubleStreamBlockProcessor()
|
| 176 |
+
|
| 177 |
+
self.model.set_attn_processor(lora_attn_procs)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def __call__(
|
| 181 |
+
self,
|
| 182 |
+
prompt: str,
|
| 183 |
+
width: int = 512,
|
| 184 |
+
height: int = 512,
|
| 185 |
+
guidance: float = 4,
|
| 186 |
+
num_steps: int = 50,
|
| 187 |
+
seed: int = 123456789,
|
| 188 |
+
**kwargs
|
| 189 |
+
):
|
| 190 |
+
width = 16 * (width // 16)
|
| 191 |
+
height = 16 * (height // 16)
|
| 192 |
+
|
| 193 |
+
return self.forward(
|
| 194 |
+
prompt,
|
| 195 |
+
width,
|
| 196 |
+
height,
|
| 197 |
+
guidance,
|
| 198 |
+
num_steps,
|
| 199 |
+
seed,
|
| 200 |
+
**kwargs
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
@torch.inference_mode()
|
| 204 |
+
def gradio_generate(
|
| 205 |
+
self,
|
| 206 |
+
prompt: str,
|
| 207 |
+
width: int,
|
| 208 |
+
height: int,
|
| 209 |
+
guidance: float,
|
| 210 |
+
num_steps: int,
|
| 211 |
+
seed: int,
|
| 212 |
+
image_prompt1: Image.Image,
|
| 213 |
+
image_prompt2: Image.Image,
|
| 214 |
+
image_prompt3: Image.Image,
|
| 215 |
+
image_prompt4: Image.Image,
|
| 216 |
+
):
|
| 217 |
+
ref_imgs = [image_prompt1, image_prompt2, image_prompt3, image_prompt4]
|
| 218 |
+
ref_imgs = [img for img in ref_imgs if isinstance(img, Image.Image)]
|
| 219 |
+
ref_long_side = 512 if len(ref_imgs) <= 1 else 320
|
| 220 |
+
ref_imgs = [preprocess_ref(img, ref_long_side) for img in ref_imgs]
|
| 221 |
+
|
| 222 |
+
seed = seed if seed != -1 else torch.randint(0, 10 ** 8, (1,)).item()
|
| 223 |
+
|
| 224 |
+
img = self(prompt=prompt, width=width, height=height, guidance=guidance,
|
| 225 |
+
num_steps=num_steps, seed=seed, ref_imgs=ref_imgs)
|
| 226 |
+
|
| 227 |
+
filename = f"output/gradio/{seed}_{prompt[:20]}.png"
|
| 228 |
+
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
| 229 |
+
exif_data = Image.Exif()
|
| 230 |
+
exif_data[ExifTags.Base.Make] = "UNO"
|
| 231 |
+
exif_data[ExifTags.Base.Model] = self.model_type
|
| 232 |
+
info = f"{prompt=}, {seed=}, {width=}, {height=}, {guidance=}, {num_steps=}"
|
| 233 |
+
exif_data[ExifTags.Base.ImageDescription] = info
|
| 234 |
+
img.save(filename, format="png", exif=exif_data)
|
| 235 |
+
return img, filename
|
| 236 |
+
|
| 237 |
+
@torch.inference_mode
|
| 238 |
+
def forward(
|
| 239 |
+
self,
|
| 240 |
+
prompt: str,
|
| 241 |
+
width: int,
|
| 242 |
+
height: int,
|
| 243 |
+
guidance: float,
|
| 244 |
+
num_steps: int,
|
| 245 |
+
seed: int,
|
| 246 |
+
ref_imgs: list[Image.Image] | None = None,
|
| 247 |
+
pe: Literal['d', 'h', 'w', 'o'] = 'd',
|
| 248 |
+
):
|
| 249 |
+
x = get_noise(
|
| 250 |
+
1, height, width, device=self.device,
|
| 251 |
+
dtype=torch.bfloat16, seed=seed
|
| 252 |
+
)
|
| 253 |
+
timesteps = get_schedule(
|
| 254 |
+
num_steps,
|
| 255 |
+
(width // 8) * (height // 8) // (16 * 16),
|
| 256 |
+
shift=True,
|
| 257 |
+
)
|
| 258 |
+
if self.offload:
|
| 259 |
+
self.ae.encoder = self.ae.encoder.to(self.device)
|
| 260 |
+
x_1_refs = [
|
| 261 |
+
self.ae.encode(
|
| 262 |
+
(TVF.to_tensor(ref_img) * 2.0 - 1.0)
|
| 263 |
+
.unsqueeze(0).to(self.device, torch.float32)
|
| 264 |
+
).to(torch.bfloat16)
|
| 265 |
+
for ref_img in ref_imgs
|
| 266 |
+
]
|
| 267 |
+
|
| 268 |
+
if self.offload:
|
| 269 |
+
self.ae.encoder = self.offload_model_to_cpu(self.ae.encoder)
|
| 270 |
+
self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
|
| 271 |
+
inp_cond = prepare_multi_ip(
|
| 272 |
+
t5=self.t5, clip=self.clip,
|
| 273 |
+
img=x,
|
| 274 |
+
prompt=prompt, ref_imgs=x_1_refs, pe=pe
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
if self.offload:
|
| 278 |
+
self.offload_model_to_cpu(self.t5, self.clip)
|
| 279 |
+
self.model = self.model.to(self.device)
|
| 280 |
+
|
| 281 |
+
x = denoise(
|
| 282 |
+
self.model,
|
| 283 |
+
**inp_cond,
|
| 284 |
+
timesteps=timesteps,
|
| 285 |
+
guidance=guidance,
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
if self.offload:
|
| 289 |
+
self.offload_model_to_cpu(self.model)
|
| 290 |
+
self.ae.decoder.to(x.device)
|
| 291 |
+
x = unpack(x.float(), height, width)
|
| 292 |
+
x = self.ae.decode(x)
|
| 293 |
+
self.offload_model_to_cpu(self.ae.decoder)
|
| 294 |
+
|
| 295 |
+
x1 = x.clamp(-1, 1)
|
| 296 |
+
x1 = rearrange(x1[-1], "c h w -> h w c")
|
| 297 |
+
output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
|
| 298 |
+
return output_img
|
| 299 |
+
|
| 300 |
+
def offload_model_to_cpu(self, *models):
|
| 301 |
+
if not self.offload: return
|
| 302 |
+
for model in models:
|
| 303 |
+
model.cpu()
|
| 304 |
+
torch.cuda.empty_cache()
|
uno/flux/sampling.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
|
| 2 |
+
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
|
| 3 |
+
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import math
|
| 17 |
+
from typing import Literal
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from einops import rearrange, repeat
|
| 21 |
+
from torch import Tensor
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
|
| 24 |
+
from .model import Flux
|
| 25 |
+
from .modules.conditioner import HFEmbedder
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_noise(
|
| 29 |
+
num_samples: int,
|
| 30 |
+
height: int,
|
| 31 |
+
width: int,
|
| 32 |
+
device: torch.device,
|
| 33 |
+
dtype: torch.dtype,
|
| 34 |
+
seed: int,
|
| 35 |
+
):
|
| 36 |
+
return torch.randn(
|
| 37 |
+
num_samples,
|
| 38 |
+
16,
|
| 39 |
+
# allow for packing
|
| 40 |
+
2 * math.ceil(height / 16),
|
| 41 |
+
2 * math.ceil(width / 16),
|
| 42 |
+
device=device,
|
| 43 |
+
dtype=dtype,
|
| 44 |
+
generator=torch.Generator(device=device).manual_seed(seed),
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def prepare(
|
| 49 |
+
t5: HFEmbedder,
|
| 50 |
+
clip: HFEmbedder,
|
| 51 |
+
img: Tensor,
|
| 52 |
+
prompt: str | list[str],
|
| 53 |
+
ref_img: None | Tensor=None,
|
| 54 |
+
pe: Literal['d', 'h', 'w', 'o'] ='d'
|
| 55 |
+
) -> dict[str, Tensor]:
|
| 56 |
+
assert pe in ['d', 'h', 'w', 'o']
|
| 57 |
+
bs, c, h, w = img.shape
|
| 58 |
+
if bs == 1 and not isinstance(prompt, str):
|
| 59 |
+
bs = len(prompt)
|
| 60 |
+
|
| 61 |
+
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
| 62 |
+
if img.shape[0] == 1 and bs > 1:
|
| 63 |
+
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
| 64 |
+
|
| 65 |
+
img_ids = torch.zeros(h // 2, w // 2, 3)
|
| 66 |
+
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
| 67 |
+
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
| 68 |
+
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
| 69 |
+
|
| 70 |
+
if ref_img is not None:
|
| 71 |
+
_, _, ref_h, ref_w = ref_img.shape
|
| 72 |
+
ref_img = rearrange(ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
| 73 |
+
if ref_img.shape[0] == 1 and bs > 1:
|
| 74 |
+
ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
|
| 75 |
+
ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3)
|
| 76 |
+
# img id分别在宽高偏移各自最大值
|
| 77 |
+
h_offset = h // 2 if pe in {'d', 'h'} else 0
|
| 78 |
+
w_offset = w // 2 if pe in {'d', 'w'} else 0
|
| 79 |
+
ref_img_ids[..., 1] = ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None] + h_offset
|
| 80 |
+
ref_img_ids[..., 2] = ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :] + w_offset
|
| 81 |
+
ref_img_ids = repeat(ref_img_ids, "h w c -> b (h w) c", b=bs)
|
| 82 |
+
|
| 83 |
+
if isinstance(prompt, str):
|
| 84 |
+
prompt = [prompt]
|
| 85 |
+
txt = t5(prompt)
|
| 86 |
+
if txt.shape[0] == 1 and bs > 1:
|
| 87 |
+
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
|
| 88 |
+
txt_ids = torch.zeros(bs, txt.shape[1], 3)
|
| 89 |
+
|
| 90 |
+
vec = clip(prompt)
|
| 91 |
+
if vec.shape[0] == 1 and bs > 1:
|
| 92 |
+
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
|
| 93 |
+
|
| 94 |
+
if ref_img is not None:
|
| 95 |
+
return {
|
| 96 |
+
"img": img,
|
| 97 |
+
"img_ids": img_ids.to(img.device),
|
| 98 |
+
"ref_img": ref_img,
|
| 99 |
+
"ref_img_ids": ref_img_ids.to(img.device),
|
| 100 |
+
"txt": txt.to(img.device),
|
| 101 |
+
"txt_ids": txt_ids.to(img.device),
|
| 102 |
+
"vec": vec.to(img.device),
|
| 103 |
+
}
|
| 104 |
+
else:
|
| 105 |
+
return {
|
| 106 |
+
"img": img,
|
| 107 |
+
"img_ids": img_ids.to(img.device),
|
| 108 |
+
"txt": txt.to(img.device),
|
| 109 |
+
"txt_ids": txt_ids.to(img.device),
|
| 110 |
+
"vec": vec.to(img.device),
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
def prepare_multi_ip(
|
| 114 |
+
t5: HFEmbedder,
|
| 115 |
+
clip: HFEmbedder,
|
| 116 |
+
img: Tensor,
|
| 117 |
+
prompt: str | list[str],
|
| 118 |
+
ref_imgs: list[Tensor] | None = None,
|
| 119 |
+
pe: Literal['d', 'h', 'w', 'o'] = 'd'
|
| 120 |
+
) -> dict[str, Tensor]:
|
| 121 |
+
assert pe in ['d', 'h', 'w', 'o']
|
| 122 |
+
bs, c, h, w = img.shape
|
| 123 |
+
if bs == 1 and not isinstance(prompt, str):
|
| 124 |
+
bs = len(prompt)
|
| 125 |
+
|
| 126 |
+
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
| 127 |
+
if img.shape[0] == 1 and bs > 1:
|
| 128 |
+
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
| 129 |
+
|
| 130 |
+
img_ids = torch.zeros(h // 2, w // 2, 3)
|
| 131 |
+
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
| 132 |
+
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
| 133 |
+
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
| 134 |
+
|
| 135 |
+
ref_img_ids = []
|
| 136 |
+
ref_imgs_list = []
|
| 137 |
+
pe_shift_w, pe_shift_h = w // 2, h // 2
|
| 138 |
+
for ref_img in ref_imgs:
|
| 139 |
+
_, _, ref_h1, ref_w1 = ref_img.shape
|
| 140 |
+
ref_img = rearrange(ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
| 141 |
+
if ref_img.shape[0] == 1 and bs > 1:
|
| 142 |
+
ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
|
| 143 |
+
ref_img_ids1 = torch.zeros(ref_h1 // 2, ref_w1 // 2, 3)
|
| 144 |
+
# img id分别���宽高偏移各自最大值
|
| 145 |
+
h_offset = pe_shift_h if pe in {'d', 'h'} else 0
|
| 146 |
+
w_offset = pe_shift_w if pe in {'d', 'w'} else 0
|
| 147 |
+
ref_img_ids1[..., 1] = ref_img_ids1[..., 1] + torch.arange(ref_h1 // 2)[:, None] + h_offset
|
| 148 |
+
ref_img_ids1[..., 2] = ref_img_ids1[..., 2] + torch.arange(ref_w1 // 2)[None, :] + w_offset
|
| 149 |
+
ref_img_ids1 = repeat(ref_img_ids1, "h w c -> b (h w) c", b=bs)
|
| 150 |
+
ref_img_ids.append(ref_img_ids1)
|
| 151 |
+
ref_imgs_list.append(ref_img)
|
| 152 |
+
|
| 153 |
+
# 更新pe shift
|
| 154 |
+
pe_shift_h += ref_h1 // 2
|
| 155 |
+
pe_shift_w += ref_w1 // 2
|
| 156 |
+
|
| 157 |
+
if isinstance(prompt, str):
|
| 158 |
+
prompt = [prompt]
|
| 159 |
+
txt = t5(prompt)
|
| 160 |
+
if txt.shape[0] == 1 and bs > 1:
|
| 161 |
+
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
|
| 162 |
+
txt_ids = torch.zeros(bs, txt.shape[1], 3)
|
| 163 |
+
|
| 164 |
+
vec = clip(prompt)
|
| 165 |
+
if vec.shape[0] == 1 and bs > 1:
|
| 166 |
+
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
|
| 167 |
+
|
| 168 |
+
return {
|
| 169 |
+
"img": img,
|
| 170 |
+
"img_ids": img_ids.to(img.device),
|
| 171 |
+
"ref_img": tuple(ref_imgs_list),
|
| 172 |
+
"ref_img_ids": [ref_img_id.to(img.device) for ref_img_id in ref_img_ids],
|
| 173 |
+
"txt": txt.to(img.device),
|
| 174 |
+
"txt_ids": txt_ids.to(img.device),
|
| 175 |
+
"vec": vec.to(img.device),
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def time_shift(mu: float, sigma: float, t: Tensor):
|
| 180 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def get_lin_function(
|
| 184 |
+
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
|
| 185 |
+
):
|
| 186 |
+
m = (y2 - y1) / (x2 - x1)
|
| 187 |
+
b = y1 - m * x1
|
| 188 |
+
return lambda x: m * x + b
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def get_schedule(
|
| 192 |
+
num_steps: int,
|
| 193 |
+
image_seq_len: int,
|
| 194 |
+
base_shift: float = 0.5,
|
| 195 |
+
max_shift: float = 1.15,
|
| 196 |
+
shift: bool = True,
|
| 197 |
+
) -> list[float]:
|
| 198 |
+
# extra step for zero
|
| 199 |
+
timesteps = torch.linspace(1, 0, num_steps + 1)
|
| 200 |
+
|
| 201 |
+
# shifting the schedule to favor high timesteps for higher signal images
|
| 202 |
+
if shift:
|
| 203 |
+
# eastimate mu based on linear estimation between two points
|
| 204 |
+
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
| 205 |
+
timesteps = time_shift(mu, 1.0, timesteps)
|
| 206 |
+
|
| 207 |
+
return timesteps.tolist()
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def denoise(
|
| 211 |
+
model: Flux,
|
| 212 |
+
# model input
|
| 213 |
+
img: Tensor,
|
| 214 |
+
img_ids: Tensor,
|
| 215 |
+
txt: Tensor,
|
| 216 |
+
txt_ids: Tensor,
|
| 217 |
+
vec: Tensor,
|
| 218 |
+
# sampling parameters
|
| 219 |
+
timesteps: list[float],
|
| 220 |
+
guidance: float = 4.0,
|
| 221 |
+
ref_img: Tensor=None,
|
| 222 |
+
ref_img_ids: Tensor=None,
|
| 223 |
+
):
|
| 224 |
+
i = 0
|
| 225 |
+
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
| 226 |
+
for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1):
|
| 227 |
+
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
| 228 |
+
pred = model(
|
| 229 |
+
img=img,
|
| 230 |
+
img_ids=img_ids,
|
| 231 |
+
ref_img=ref_img,
|
| 232 |
+
ref_img_ids=ref_img_ids,
|
| 233 |
+
txt=txt,
|
| 234 |
+
txt_ids=txt_ids,
|
| 235 |
+
y=vec,
|
| 236 |
+
timesteps=t_vec,
|
| 237 |
+
guidance=guidance_vec
|
| 238 |
+
)
|
| 239 |
+
img = img + (t_prev - t_curr) * pred
|
| 240 |
+
i += 1
|
| 241 |
+
return img
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
| 245 |
+
return rearrange(
|
| 246 |
+
x,
|
| 247 |
+
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
| 248 |
+
h=math.ceil(height / 16),
|
| 249 |
+
w=math.ceil(width / 16),
|
| 250 |
+
ph=2,
|
| 251 |
+
pw=2,
|
| 252 |
+
)
|
uno/flux/util.py
ADDED
|
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
|
| 2 |
+
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
|
| 3 |
+
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import json
|
| 21 |
+
import numpy as np
|
| 22 |
+
from huggingface_hub import hf_hub_download
|
| 23 |
+
from safetensors import safe_open
|
| 24 |
+
from safetensors.torch import load_file as load_sft
|
| 25 |
+
|
| 26 |
+
from .model import Flux, FluxParams
|
| 27 |
+
from .modules.autoencoder import AutoEncoder, AutoEncoderParams
|
| 28 |
+
from .modules.conditioner import HFEmbedder
|
| 29 |
+
|
| 30 |
+
import re
|
| 31 |
+
from uno.flux.modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor
|
| 32 |
+
def load_model(ckpt, device='cpu'):
|
| 33 |
+
if ckpt.endswith('safetensors'):
|
| 34 |
+
from safetensors import safe_open
|
| 35 |
+
pl_sd = {}
|
| 36 |
+
with safe_open(ckpt, framework="pt", device=device) as f:
|
| 37 |
+
for k in f.keys():
|
| 38 |
+
pl_sd[k] = f.get_tensor(k)
|
| 39 |
+
else:
|
| 40 |
+
pl_sd = torch.load(ckpt, map_location=device)
|
| 41 |
+
return pl_sd
|
| 42 |
+
|
| 43 |
+
def load_safetensors(path):
|
| 44 |
+
tensors = {}
|
| 45 |
+
with safe_open(path, framework="pt", device="cpu") as f:
|
| 46 |
+
for key in f.keys():
|
| 47 |
+
tensors[key] = f.get_tensor(key)
|
| 48 |
+
return tensors
|
| 49 |
+
|
| 50 |
+
def get_lora_rank(checkpoint):
|
| 51 |
+
for k in checkpoint.keys():
|
| 52 |
+
if k.endswith(".down.weight"):
|
| 53 |
+
return checkpoint[k].shape[0]
|
| 54 |
+
|
| 55 |
+
def load_checkpoint(local_path, repo_id, name):
|
| 56 |
+
if local_path is not None:
|
| 57 |
+
if '.safetensors' in local_path:
|
| 58 |
+
print(f"Loading .safetensors checkpoint from {local_path}")
|
| 59 |
+
checkpoint = load_safetensors(local_path)
|
| 60 |
+
else:
|
| 61 |
+
print(f"Loading checkpoint from {local_path}")
|
| 62 |
+
checkpoint = torch.load(local_path, map_location='cpu')
|
| 63 |
+
elif repo_id is not None and name is not None:
|
| 64 |
+
print(f"Loading checkpoint {name} from repo id {repo_id}")
|
| 65 |
+
checkpoint = load_from_repo_id(repo_id, name)
|
| 66 |
+
else:
|
| 67 |
+
raise ValueError(
|
| 68 |
+
"LOADING ERROR: you must specify local_path or repo_id with name in HF to download"
|
| 69 |
+
)
|
| 70 |
+
return checkpoint
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def c_crop(image):
|
| 74 |
+
width, height = image.size
|
| 75 |
+
new_size = min(width, height)
|
| 76 |
+
left = (width - new_size) / 2
|
| 77 |
+
top = (height - new_size) / 2
|
| 78 |
+
right = (width + new_size) / 2
|
| 79 |
+
bottom = (height + new_size) / 2
|
| 80 |
+
return image.crop((left, top, right, bottom))
|
| 81 |
+
|
| 82 |
+
def pad64(x):
|
| 83 |
+
return int(np.ceil(float(x) / 64.0) * 64 - x)
|
| 84 |
+
|
| 85 |
+
def HWC3(x):
|
| 86 |
+
assert x.dtype == np.uint8
|
| 87 |
+
if x.ndim == 2:
|
| 88 |
+
x = x[:, :, None]
|
| 89 |
+
assert x.ndim == 3
|
| 90 |
+
H, W, C = x.shape
|
| 91 |
+
assert C == 1 or C == 3 or C == 4
|
| 92 |
+
if C == 3:
|
| 93 |
+
return x
|
| 94 |
+
if C == 1:
|
| 95 |
+
return np.concatenate([x, x, x], axis=2)
|
| 96 |
+
if C == 4:
|
| 97 |
+
color = x[:, :, 0:3].astype(np.float32)
|
| 98 |
+
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
|
| 99 |
+
y = color * alpha + 255.0 * (1.0 - alpha)
|
| 100 |
+
y = y.clip(0, 255).astype(np.uint8)
|
| 101 |
+
return y
|
| 102 |
+
|
| 103 |
+
@dataclass
|
| 104 |
+
class ModelSpec:
|
| 105 |
+
params: FluxParams
|
| 106 |
+
ae_params: AutoEncoderParams
|
| 107 |
+
ckpt_path: str | None
|
| 108 |
+
ae_path: str | None
|
| 109 |
+
repo_id: str | None
|
| 110 |
+
repo_flow: str | None
|
| 111 |
+
repo_ae: str | None
|
| 112 |
+
repo_id_ae: str | None
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
configs = {
|
| 116 |
+
"flux-dev": ModelSpec(
|
| 117 |
+
repo_id="black-forest-labs/FLUX.1-dev",
|
| 118 |
+
repo_id_ae="black-forest-labs/FLUX.1-dev",
|
| 119 |
+
repo_flow="flux1-dev.safetensors",
|
| 120 |
+
repo_ae="ae.safetensors",
|
| 121 |
+
ckpt_path=os.getenv("FLUX_DEV"),
|
| 122 |
+
params=FluxParams(
|
| 123 |
+
in_channels=64,
|
| 124 |
+
vec_in_dim=768,
|
| 125 |
+
context_in_dim=4096,
|
| 126 |
+
hidden_size=3072,
|
| 127 |
+
mlp_ratio=4.0,
|
| 128 |
+
num_heads=24,
|
| 129 |
+
depth=19,
|
| 130 |
+
depth_single_blocks=38,
|
| 131 |
+
axes_dim=[16, 56, 56],
|
| 132 |
+
theta=10_000,
|
| 133 |
+
qkv_bias=True,
|
| 134 |
+
guidance_embed=True,
|
| 135 |
+
),
|
| 136 |
+
ae_path=os.getenv("AE"),
|
| 137 |
+
ae_params=AutoEncoderParams(
|
| 138 |
+
resolution=256,
|
| 139 |
+
in_channels=3,
|
| 140 |
+
ch=128,
|
| 141 |
+
out_ch=3,
|
| 142 |
+
ch_mult=[1, 2, 4, 4],
|
| 143 |
+
num_res_blocks=2,
|
| 144 |
+
z_channels=16,
|
| 145 |
+
scale_factor=0.3611,
|
| 146 |
+
shift_factor=0.1159,
|
| 147 |
+
),
|
| 148 |
+
),
|
| 149 |
+
"flux-dev-fp8": ModelSpec(
|
| 150 |
+
repo_id="XLabs-AI/flux-dev-fp8",
|
| 151 |
+
repo_id_ae="black-forest-labs/FLUX.1-dev",
|
| 152 |
+
repo_flow="flux-dev-fp8.safetensors",
|
| 153 |
+
repo_ae="ae.safetensors",
|
| 154 |
+
ckpt_path=os.getenv("FLUX_DEV_FP8"),
|
| 155 |
+
params=FluxParams(
|
| 156 |
+
in_channels=64,
|
| 157 |
+
vec_in_dim=768,
|
| 158 |
+
context_in_dim=4096,
|
| 159 |
+
hidden_size=3072,
|
| 160 |
+
mlp_ratio=4.0,
|
| 161 |
+
num_heads=24,
|
| 162 |
+
depth=19,
|
| 163 |
+
depth_single_blocks=38,
|
| 164 |
+
axes_dim=[16, 56, 56],
|
| 165 |
+
theta=10_000,
|
| 166 |
+
qkv_bias=True,
|
| 167 |
+
guidance_embed=True,
|
| 168 |
+
),
|
| 169 |
+
ae_path=os.getenv("AE"),
|
| 170 |
+
ae_params=AutoEncoderParams(
|
| 171 |
+
resolution=256,
|
| 172 |
+
in_channels=3,
|
| 173 |
+
ch=128,
|
| 174 |
+
out_ch=3,
|
| 175 |
+
ch_mult=[1, 2, 4, 4],
|
| 176 |
+
num_res_blocks=2,
|
| 177 |
+
z_channels=16,
|
| 178 |
+
scale_factor=0.3611,
|
| 179 |
+
shift_factor=0.1159,
|
| 180 |
+
),
|
| 181 |
+
),
|
| 182 |
+
"flux-schnell": ModelSpec(
|
| 183 |
+
repo_id="black-forest-labs/FLUX.1-schnell",
|
| 184 |
+
repo_id_ae="black-forest-labs/FLUX.1-dev",
|
| 185 |
+
repo_flow="flux1-schnell.safetensors",
|
| 186 |
+
repo_ae="ae.safetensors",
|
| 187 |
+
ckpt_path=os.getenv("FLUX_SCHNELL"),
|
| 188 |
+
params=FluxParams(
|
| 189 |
+
in_channels=64,
|
| 190 |
+
vec_in_dim=768,
|
| 191 |
+
context_in_dim=4096,
|
| 192 |
+
hidden_size=3072,
|
| 193 |
+
mlp_ratio=4.0,
|
| 194 |
+
num_heads=24,
|
| 195 |
+
depth=19,
|
| 196 |
+
depth_single_blocks=38,
|
| 197 |
+
axes_dim=[16, 56, 56],
|
| 198 |
+
theta=10_000,
|
| 199 |
+
qkv_bias=True,
|
| 200 |
+
guidance_embed=False,
|
| 201 |
+
),
|
| 202 |
+
ae_path=os.getenv("AE"),
|
| 203 |
+
ae_params=AutoEncoderParams(
|
| 204 |
+
resolution=256,
|
| 205 |
+
in_channels=3,
|
| 206 |
+
ch=128,
|
| 207 |
+
out_ch=3,
|
| 208 |
+
ch_mult=[1, 2, 4, 4],
|
| 209 |
+
num_res_blocks=2,
|
| 210 |
+
z_channels=16,
|
| 211 |
+
scale_factor=0.3611,
|
| 212 |
+
shift_factor=0.1159,
|
| 213 |
+
),
|
| 214 |
+
),
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
|
| 219 |
+
if len(missing) > 0 and len(unexpected) > 0:
|
| 220 |
+
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
|
| 221 |
+
print("\n" + "-" * 79 + "\n")
|
| 222 |
+
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
|
| 223 |
+
elif len(missing) > 0:
|
| 224 |
+
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
|
| 225 |
+
elif len(unexpected) > 0:
|
| 226 |
+
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
|
| 227 |
+
|
| 228 |
+
def load_from_repo_id(repo_id, checkpoint_name):
|
| 229 |
+
ckpt_path = hf_hub_download(repo_id, checkpoint_name)
|
| 230 |
+
sd = load_sft(ckpt_path, device='cpu')
|
| 231 |
+
return sd
|
| 232 |
+
|
| 233 |
+
def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
|
| 234 |
+
# Loading Flux
|
| 235 |
+
print("Init model")
|
| 236 |
+
ckpt_path = configs[name].ckpt_path
|
| 237 |
+
if (
|
| 238 |
+
ckpt_path is None
|
| 239 |
+
and configs[name].repo_id is not None
|
| 240 |
+
and configs[name].repo_flow is not None
|
| 241 |
+
and hf_download
|
| 242 |
+
):
|
| 243 |
+
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
|
| 244 |
+
|
| 245 |
+
with torch.device("meta" if ckpt_path is not None else device):
|
| 246 |
+
model = Flux(configs[name].params).to(torch.bfloat16)
|
| 247 |
+
|
| 248 |
+
if ckpt_path is not None:
|
| 249 |
+
print("Loading checkpoint")
|
| 250 |
+
# load_sft doesn't support torch.device
|
| 251 |
+
sd = load_model(ckpt_path, device=str(device))
|
| 252 |
+
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
|
| 253 |
+
print_load_warning(missing, unexpected)
|
| 254 |
+
return model
|
| 255 |
+
|
| 256 |
+
def load_flow_model_only_lora(
|
| 257 |
+
name: str,
|
| 258 |
+
device: str | torch.device = "cuda",
|
| 259 |
+
hf_download: bool = True,
|
| 260 |
+
lora_rank: int = 16
|
| 261 |
+
):
|
| 262 |
+
# Loading Flux
|
| 263 |
+
print("Init model")
|
| 264 |
+
ckpt_path = configs[name].ckpt_path
|
| 265 |
+
if (
|
| 266 |
+
ckpt_path is None
|
| 267 |
+
and configs[name].repo_id is not None
|
| 268 |
+
and configs[name].repo_flow is not None
|
| 269 |
+
and hf_download
|
| 270 |
+
):
|
| 271 |
+
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))
|
| 272 |
+
|
| 273 |
+
if hf_download:
|
| 274 |
+
# lora_ckpt_path = hf_hub_download("bytedance-research/UNO", "dit_lora.safetensors")
|
| 275 |
+
try:
|
| 276 |
+
lora_ckpt_path = hf_hub_download("bytedance-research/UNO", "dit_lora.safetensors")
|
| 277 |
+
except:
|
| 278 |
+
lora_ckpt_path = os.environ.get("LORA", None)
|
| 279 |
+
else:
|
| 280 |
+
lora_ckpt_path = os.environ.get("LORA", None)
|
| 281 |
+
|
| 282 |
+
with torch.device("meta" if ckpt_path is not None else device):
|
| 283 |
+
model = Flux(configs[name].params)
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
model = set_lora(model, lora_rank, device="meta" if lora_ckpt_path is not None else device)
|
| 287 |
+
|
| 288 |
+
if ckpt_path is not None:
|
| 289 |
+
print("Loading lora")
|
| 290 |
+
lora_sd = load_sft(lora_ckpt_path, device=str(device)) if lora_ckpt_path.endswith("safetensors")\
|
| 291 |
+
else torch.load(lora_ckpt_path, map_location='cpu')
|
| 292 |
+
|
| 293 |
+
print("Loading main checkpoint")
|
| 294 |
+
# load_sft doesn't support torch.device
|
| 295 |
+
|
| 296 |
+
if ckpt_path.endswith('safetensors'):
|
| 297 |
+
sd = load_sft(ckpt_path, device=str(device))
|
| 298 |
+
sd.update(lora_sd)
|
| 299 |
+
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
|
| 300 |
+
else:
|
| 301 |
+
dit_state = torch.load(ckpt_path, map_location='cpu')
|
| 302 |
+
sd = {}
|
| 303 |
+
for k in dit_state.keys():
|
| 304 |
+
sd[k.replace('module.','')] = dit_state[k]
|
| 305 |
+
sd.update(lora_sd)
|
| 306 |
+
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
|
| 307 |
+
model.to(str(device))
|
| 308 |
+
print_load_warning(missing, unexpected)
|
| 309 |
+
return model
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def set_lora(
|
| 313 |
+
model: Flux,
|
| 314 |
+
lora_rank: int,
|
| 315 |
+
double_blocks_indices: list[int] | None = None,
|
| 316 |
+
single_blocks_indices: list[int] | None = None,
|
| 317 |
+
device: str | torch.device = "cpu",
|
| 318 |
+
) -> Flux:
|
| 319 |
+
double_blocks_indices = list(range(model.params.depth)) if double_blocks_indices is None else double_blocks_indices
|
| 320 |
+
single_blocks_indices = list(range(model.params.depth_single_blocks)) if single_blocks_indices is None \
|
| 321 |
+
else single_blocks_indices
|
| 322 |
+
|
| 323 |
+
lora_attn_procs = {}
|
| 324 |
+
with torch.device(device):
|
| 325 |
+
for name, attn_processor in model.attn_processors.items():
|
| 326 |
+
match = re.search(r'\.(\d+)\.', name)
|
| 327 |
+
if match:
|
| 328 |
+
layer_index = int(match.group(1))
|
| 329 |
+
|
| 330 |
+
if name.startswith("double_blocks") and layer_index in double_blocks_indices:
|
| 331 |
+
lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
|
| 332 |
+
elif name.startswith("single_blocks") and layer_index in single_blocks_indices:
|
| 333 |
+
lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
|
| 334 |
+
else:
|
| 335 |
+
lora_attn_procs[name] = attn_processor
|
| 336 |
+
model.set_attn_processor(lora_attn_procs)
|
| 337 |
+
return model
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
|
| 341 |
+
# Loading Flux
|
| 342 |
+
from optimum.quanto import requantize
|
| 343 |
+
print("Init model")
|
| 344 |
+
ckpt_path = configs[name].ckpt_path
|
| 345 |
+
if (
|
| 346 |
+
ckpt_path is None
|
| 347 |
+
and configs[name].repo_id is not None
|
| 348 |
+
and configs[name].repo_flow is not None
|
| 349 |
+
and hf_download
|
| 350 |
+
):
|
| 351 |
+
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
|
| 352 |
+
json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json')
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
model = Flux(configs[name].params).to(torch.bfloat16)
|
| 356 |
+
|
| 357 |
+
print("Loading checkpoint")
|
| 358 |
+
# load_sft doesn't support torch.device
|
| 359 |
+
sd = load_sft(ckpt_path, device='cpu')
|
| 360 |
+
with open(json_path, "r") as f:
|
| 361 |
+
quantization_map = json.load(f)
|
| 362 |
+
print("Start a quantization process...")
|
| 363 |
+
requantize(model, sd, quantization_map, device=device)
|
| 364 |
+
print("Model is quantized!")
|
| 365 |
+
return model
|
| 366 |
+
|
| 367 |
+
def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
|
| 368 |
+
# max length 64, 128, 256 and 512 should work (if your sequence is short enough)
|
| 369 |
+
version = os.environ.get("T5", "xlabs-ai/xflux_text_encoders")
|
| 370 |
+
return HFEmbedder(version, max_length=max_length, torch_dtype=torch.bfloat16).to(device)
|
| 371 |
+
|
| 372 |
+
def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
|
| 373 |
+
version = os.environ.get("CLIP", "openai/clip-vit-large-patch14")
|
| 374 |
+
return HFEmbedder(version, max_length=77, torch_dtype=torch.bfloat16).to(device)
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
|
| 378 |
+
ckpt_path = configs[name].ae_path
|
| 379 |
+
if (
|
| 380 |
+
ckpt_path is None
|
| 381 |
+
and configs[name].repo_id is not None
|
| 382 |
+
and configs[name].repo_ae is not None
|
| 383 |
+
and hf_download
|
| 384 |
+
):
|
| 385 |
+
ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae)
|
| 386 |
+
|
| 387 |
+
# Loading the autoencoder
|
| 388 |
+
print("Init AE")
|
| 389 |
+
with torch.device("meta" if ckpt_path is not None else device):
|
| 390 |
+
ae = AutoEncoder(configs[name].ae_params)
|
| 391 |
+
|
| 392 |
+
if ckpt_path is not None:
|
| 393 |
+
sd = load_sft(ckpt_path, device=str(device))
|
| 394 |
+
missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
|
| 395 |
+
print_load_warning(missing, unexpected)
|
| 396 |
+
return ae
|
uno/utils/convert_yaml_to_args_file.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
|
| 2 |
+
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import yaml
|
| 17 |
+
|
| 18 |
+
parser = argparse.ArgumentParser()
|
| 19 |
+
parser.add_argument("--yaml", type=str, required=True)
|
| 20 |
+
parser.add_argument("--arg", type=str, required=True)
|
| 21 |
+
args = parser.parse_args()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
with open(args.yaml, "r") as f:
|
| 25 |
+
data = yaml.safe_load(f)
|
| 26 |
+
|
| 27 |
+
with open(args.arg, "w") as f:
|
| 28 |
+
for k, v in data.items():
|
| 29 |
+
if isinstance(v, list):
|
| 30 |
+
v = list(map(str, v))
|
| 31 |
+
v = " ".join(v)
|
| 32 |
+
if v is None:
|
| 33 |
+
continue
|
| 34 |
+
print(f"--{k} {v}", end=" ", file=f)
|