Spaces: Running on Zero

Commit 703e263 by wenmengzhou
1 Parent(s): 359b5e8

add code and adapt to zero gpus

This view is limited to 50 files because it contains too many changes.
- .gitattributes +1 -0
- app.py +255 -0
- diffsynth/__init__.py +6 -0
- diffsynth/configs/__init__.py +0 -0
- diffsynth/configs/model_config.py +275 -0
- diffsynth/controlnets/__init__.py +2 -0
- diffsynth/controlnets/controlnet_unit.py +54 -0
- diffsynth/controlnets/processors.py +51 -0
- diffsynth/data/__init__.py +1 -0
- diffsynth/data/simple_text_image.py +35 -0
- diffsynth/data/video.py +148 -0
- diffsynth/extensions/ESRGAN/__init__.py +118 -0
- diffsynth/extensions/FastBlend/__init__.py +63 -0
- diffsynth/extensions/FastBlend/api.py +397 -0
- diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
- diffsynth/extensions/FastBlend/data.py +146 -0
- diffsynth/extensions/FastBlend/patch_match.py +298 -0
- diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
- diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
- diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
- diffsynth/extensions/FastBlend/runners/fast.py +141 -0
- diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
- diffsynth/extensions/RIFE/__init__.py +242 -0
- diffsynth/extensions/__init__.py +0 -0
- diffsynth/models/__init__.py +1 -0
- diffsynth/models/attention.py +89 -0
- diffsynth/models/downloader.py +66 -0
- diffsynth/models/flux_dit.py +575 -0
- diffsynth/models/flux_text_encoder.py +93 -0
- diffsynth/models/flux_vae.py +303 -0
- diffsynth/models/hunyuan_dit.py +451 -0
- diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
- diffsynth/models/kolors_text_encoder.py +1552 -0
- diffsynth/models/lora.py +195 -0
- diffsynth/models/model_manager.py +543 -0
- diffsynth/models/sd3_dit.py +798 -0
- diffsynth/models/sd3_text_encoder.py +0 -0
- diffsynth/models/sd3_vae_decoder.py +81 -0
- diffsynth/models/sd3_vae_encoder.py +95 -0
- diffsynth/models/sd_controlnet.py +589 -0
- diffsynth/models/sd_ipadapter.py +57 -0
- diffsynth/models/sd_motion.py +199 -0
- diffsynth/models/sd_text_encoder.py +321 -0
- diffsynth/models/sd_unet.py +0 -0
- diffsynth/models/sd_vae_decoder.py +336 -0
- diffsynth/models/sd_vae_encoder.py +282 -0
- diffsynth/models/sdxl_controlnet.py +318 -0
- diffsynth/models/sdxl_ipadapter.py +122 -0
- diffsynth/models/sdxl_motion.py +104 -0
- diffsynth/models/sdxl_text_encoder.py +759 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,255 @@
import spaces
import os
os.system("pip install -r requirements.txt")
from diffsynth import download_models
download_models(["Kolors", "FLUX.1-dev"])

import gradio as gr
from diffsynth import ModelManager, SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
import os, torch
from PIL import Image
import numpy as np


config = {
    "model_config": {
        "Stable Diffusion": {
            "model_folder": "models/stable_diffusion",
            "pipeline_class": SDImagePipeline,
            "default_parameters": {
                "cfg_scale": 7.0,
                "height": 512,
                "width": 512,
            }
        },
        "Stable Diffusion XL": {
            "model_folder": "models/stable_diffusion_xl",
            "pipeline_class": SDXLImagePipeline,
            "default_parameters": {
                "cfg_scale": 7.0,
            }
        },
        "Stable Diffusion 3": {
            "model_folder": "models/stable_diffusion_3",
            "pipeline_class": SD3ImagePipeline,
            "default_parameters": {
                "cfg_scale": 7.0,
            }
        },
        "Stable Diffusion XL Turbo": {
            "model_folder": "models/stable_diffusion_xl_turbo",
            "pipeline_class": SDXLImagePipeline,
            "default_parameters": {
                "negative_prompt": "",
                "cfg_scale": 1.0,
                "num_inference_steps": 1,
                "height": 512,
                "width": 512,
            }
        },
        "Kolors": {
            "model_folder": "models/kolors",
            "pipeline_class": SDXLImagePipeline,
            "default_parameters": {
                "cfg_scale": 7.0,
            }
        },
        "HunyuanDiT": {
            "model_folder": "models/HunyuanDiT",
            "pipeline_class": HunyuanDiTImagePipeline,
            "default_parameters": {
                "cfg_scale": 7.0,
            }
        },
        "FLUX": {
            "model_folder": "models/FLUX",
            "pipeline_class": FluxImagePipeline,
            "default_parameters": {
                "cfg_scale": 1.0,
            }
        }
    },
    "max_num_painter_layers": 3,
    "max_num_model_cache": 2,
}


def load_model_list(model_type):
    if model_type is None:
        return []
    folder = config["model_config"][model_type]["model_folder"]
    file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
    if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
        file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
    file_list = sorted(file_list)
    return file_list


def load_model(model_type, model_path):
    global model_dict
    model_key = f"{model_type}:{model_path}"
    if model_key in model_dict:
        return model_dict[model_key]
    model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
    model_manager = ModelManager()
    if model_type == "HunyuanDiT":
        model_manager.load_models([
            os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
            os.path.join(model_path, "mt5/pytorch_model.bin"),
            os.path.join(model_path, "model/pytorch_model_ema.pt"),
            os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
        ])
    elif model_type == "Kolors":
        model_manager.load_models([
            os.path.join(model_path, "text_encoder"),
            os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
            os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
        ])
    elif model_type == "FLUX":
        model_manager.torch_dtype = torch.bfloat16
        file_list = [
            os.path.join(model_path, "text_encoder/model.safetensors"),
            os.path.join(model_path, "text_encoder_2"),
        ]
        for file_name in os.listdir(model_path):
            if file_name.endswith(".safetensors"):
                file_list.append(os.path.join(model_path, file_name))
        model_manager.load_models(file_list)
    else:
        model_manager.load_model(model_path)
    pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
    while len(model_dict) + 1 > config["max_num_model_cache"]:
        key = next(iter(model_dict.keys()))
        model_manager_to_release, _ = model_dict[key]
        model_manager_to_release.to("cpu")
        del model_dict[key]
        torch.cuda.empty_cache()
    model_dict[model_key] = model_manager, pipe
    return model_manager, pipe


model_dict = {}

with gr.Blocks() as app:
    gr.Markdown("# DiffSynth-Studio Painter")
    with gr.Row():
        with gr.Column(scale=382, min_width=100):

            with gr.Accordion(label="Model"):
                model_type = gr.Dropdown(choices=["Kolors", "FLUX"], label="Model type", value="Kolors")
                model_path = gr.Dropdown(choices=["Kolors"], interactive=True, label="Model path", value="Kolors")

                @gr.on(inputs=model_type, outputs=model_path, triggers=model_type.change)
                def model_type_to_model_path(model_type):
                    return gr.Dropdown(choices=load_model_list(model_type))

            with gr.Accordion(label="Prompt"):
                prompt = gr.Textbox(label="Prompt", lines=3)
                negative_prompt = gr.Textbox(label="Negative prompt", lines=1)
                cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=7.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
                embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, interactive=True, label="Embedded guidance scale (only for FLUX)")

            with gr.Accordion(label="Image"):
                num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, interactive=True, label="Inference steps")
                height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
                width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
                with gr.Column():
                    use_fixed_seed = gr.Checkbox(value=True, interactive=False, label="Use fixed seed")
                    seed = gr.Number(minimum=0, maximum=10**9, value=0, interactive=True, label="Random seed", show_label=False)

            @gr.on(
                inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
                outputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
                triggers=model_path.change
            )
            def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width):
                load_model(model_type, model_path)
                cfg_scale = config["model_config"][model_type]["default_parameters"].get("cfg_scale", cfg_scale)
                embedded_guidance = config["model_config"][model_type]["default_parameters"].get("embedded_guidance", embedded_guidance)
                num_inference_steps = config["model_config"][model_type]["default_parameters"].get("num_inference_steps", num_inference_steps)
                height = config["model_config"][model_type]["default_parameters"].get("height", height)
                width = config["model_config"][model_type]["default_parameters"].get("width", width)
                return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width


        with gr.Column(scale=618, min_width=100):
            with gr.Accordion(label="Painter"):
                enable_local_prompt_list = []
                local_prompt_list = []
                mask_scale_list = []
                canvas_list = []
                for painter_layer_id in range(config["max_num_painter_layers"]):
                    with gr.Tab(label=f"Layer {painter_layer_id}"):
                        enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}")
                        local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
                        mask_scale = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Mask scale", key=f"mask_scale_{painter_layer_id}")
                        canvas = gr.ImageEditor(canvas_size=(512, 1), sources=None, layers=False, interactive=True, image_mode="RGBA",
                                                brush=gr.Brush(default_size=100, default_color="#000000", colors=["#000000"]),
                                                label="Painter", key=f"canvas_{painter_layer_id}")
                        @gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear, enable_local_prompt.change], show_progress="hidden")
                        def resize_canvas(height, width, canvas):
                            h, w = canvas["background"].shape[:2]
                            if h != height or width != w:
                                return np.ones((height, width, 3), dtype=np.uint8) * 255
                            else:
                                return canvas

                        enable_local_prompt_list.append(enable_local_prompt)
                        local_prompt_list.append(local_prompt)
                        mask_scale_list.append(mask_scale)
                        canvas_list.append(canvas)
            with gr.Accordion(label="Results"):
                run_button = gr.Button(value="Generate", variant="primary")
                output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
                output_to_painter_button = gr.Button(value="Set as painter's background")
                painter_background = gr.State(None)
                input_background = gr.State(None)
                @gr.on(
                    inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list,
                    outputs=[output_image],
                    triggers=run_button.click
                )
                @spaces.GPU
                def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed, *args, progress=gr.Progress()):
                    _, pipe = load_model(model_type, model_path)
                    input_params = {
                        "prompt": prompt,
                        "negative_prompt": negative_prompt,
                        "cfg_scale": cfg_scale,
                        "num_inference_steps": num_inference_steps,
                        "height": height,
                        "width": width,
                        "progress_bar_cmd": progress.tqdm,
                    }
                    if isinstance(pipe, FluxImagePipeline):
                        input_params["embedded_guidance"] = embedded_guidance
                    enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = (
                        args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
                        args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
                        args[2 * config["max_num_painter_layers"]: 3 * config["max_num_painter_layers"]],
                        args[3 * config["max_num_painter_layers"]: 4 * config["max_num_painter_layers"]]
                    )
                    local_prompts, masks, mask_scales = [], [], []
                    for enable_local_prompt, local_prompt, mask_scale, canvas in zip(
                        enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list
                    ):
                        if enable_local_prompt:
                            local_prompts.append(local_prompt)
                            masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
                            mask_scales.append(mask_scale)
                    input_params.update({
                        "local_prompts": local_prompts,
                        "masks": masks,
                        "mask_scales": mask_scales,
                    })
                    torch.manual_seed(seed)
                    image = pipe(**input_params)
                    return image

                @gr.on(inputs=[output_image] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
                def send_output_to_painter_background(output_image, *canvas_list):
                    for canvas in canvas_list:
                        h, w = canvas["background"].shape[:2]
                        canvas["background"] = output_image.resize((w, h))
                    return tuple(canvas_list)
app.launch()
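The `load_model` function above caches at most `config["max_num_model_cache"]` loaded pipelines, reuses a cached entry on a key hit, and otherwise evicts the oldest entry (dict insertion order) to the CPU before loading a new model. A minimal standalone sketch of that same first-in-first-out eviction policy, using a hypothetical `DummyModel` stand-in instead of a real `ModelManager`:

# Minimal sketch of the FIFO model cache used by load_model (DummyModel is a hypothetical stand-in).
class DummyModel:
    def __init__(self, name):
        self.name = name
    def to(self, device):
        print(f"moving {self.name} to {device}")

cache, max_cache = {}, 2

def get_model(key):
    if key in cache:
        return cache[key]           # cache hit: reuse the already-loaded model
    model = DummyModel(key)         # cache miss: "load" a new model
    while len(cache) + 1 > max_cache:
        oldest = next(iter(cache))  # dicts preserve insertion order, so this is the oldest entry
        cache[oldest].to("cpu")     # release it from the accelerator before dropping the reference
        del cache[oldest]
    cache[key] = model
    return model

for key in ["Kolors:Kolors", "FLUX:FLUX.1-dev", "Kolors:Kolors", "HunyuanDiT:t2i"]:
    get_model(key)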
diffsynth/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .data import *
from .models import *
from .prompters import *
from .schedulers import *
from .pipelines import *
from .controlnets import *
diffsynth/configs/__init__.py
ADDED
File without changes
diffsynth/configs/model_config.py
ADDED
@@ -0,0 +1,275 @@
from typing_extensions import Literal, TypeAlias

from ..models.sd_text_encoder import SDTextEncoder
from ..models.sd_unet import SDUNet
from ..models.sd_vae_encoder import SDVAEEncoder
from ..models.sd_vae_decoder import SDVAEDecoder

from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from ..models.sdxl_unet import SDXLUNet
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
from ..models.sdxl_vae_encoder import SDXLVAEEncoder

from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
from ..models.sd3_dit import SD3DiT
from ..models.sd3_vae_decoder import SD3VAEDecoder
from ..models.sd3_vae_encoder import SD3VAEEncoder

from ..models.sd_controlnet import SDControlNet
from ..models.sdxl_controlnet import SDXLControlNetUnion

from ..models.sd_motion import SDMotionModel
from ..models.sdxl_motion import SDXLMotionModel

from ..models.svd_image_encoder import SVDImageEncoder
from ..models.svd_unet import SVDUNet
from ..models.svd_vae_decoder import SVDVAEDecoder
from ..models.svd_vae_encoder import SVDVAEEncoder

from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder

from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from ..models.hunyuan_dit import HunyuanDiT


from ..models.flux_dit import FluxDiT
from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder


model_loader_configs = [
    # These configs are provided for detecting model type automatically.
    # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
    (None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
    (None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
    (None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
    (None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
    (None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
    (None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
    (None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
    (None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
    (None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
    (None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
    (None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
    (None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
    (None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
    (None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
    (None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
    (None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
    (None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
    (None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
    (None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
    (None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
    (None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
    (None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
    (None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
    (None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
    (None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
    (None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["flux_text_encoder_1"], [FluxTextEncoder1], "diffusers"),
    (None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
    (None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
    (None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai")
]
huggingface_model_loader_configs = [
    # These configs are provided for detecting model type automatically.
    # The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
    ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
    ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
    ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
    ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
]
patch_model_loader_configs = [
    # These configs are provided for detecting model type automatically.
    # The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
    ("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
]

preset_models_on_huggingface = {
    "HunyuanDiT": [
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
    ],
    "stable-video-diffusion-img2vid-xt": [
        ("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
    ],
    "ExVideo-SVD-128f-v1": [
        ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
    ],
}
preset_models_on_modelscope = {
    # Hunyuan DiT
    "HunyuanDiT": [
        ("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
        ("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
        ("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
        ("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
    ],
    # Stable Video Diffusion
    "stable-video-diffusion-img2vid-xt": [
        ("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
    ],
    # ExVideo
    "ExVideo-SVD-128f-v1": [
        ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
    ],
    # Stable Diffusion
    "StableDiffusion_v15": [
        ("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
    ],
    "DreamShaper_8": [
        ("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
    ],
    "AingDiffusion_v12": [
        ("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
    ],
    "Flat2DAnimerge_v45Sharp": [
        ("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
    ],
    # Textual Inversion
    "TextualInversion_VeryBadImageNegative_v1.3": [
        ("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
    ],
    # Stable Diffusion XL
    "StableDiffusionXL_v1": [
        ("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
    ],
    "BluePencilXL_v200": [
        ("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
    ],
    "StableDiffusionXL_Turbo": [
        ("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
    ],
    "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
        ("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
    ],
    # Stable Diffusion 3
    "StableDiffusion3": [
        ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
    ],
    "StableDiffusion3_without_T5": [
        ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
    ],
    # ControlNet
    "ControlNet_v11f1p_sd15_depth": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
        ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
    ],
    "ControlNet_v11p_sd15_softedge": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
        ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
    ],
    "ControlNet_v11f1e_sd15_tile": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
    ],
    "ControlNet_v11p_sd15_lineart": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
        ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
        ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
    ],
    "ControlNet_union_sdxl_promax": [
        ("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
        ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
    ],
    # AnimateDiff
    "AnimateDiff_v2": [
        ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
    ],
    "AnimateDiff_xl_beta": [
        ("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
    ],
    # RIFE
    "RIFE": [
        ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
    ],
    # Beautiful Prompt
    "BeautifulPrompt": [
        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
    ],
    # Translator
    "opus-mt-zh-en": [
        ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
        ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
        ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
        ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
        ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
        ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
        ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
        ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
    ],
    # IP-Adapter
    "IP-Adapter-SD": [
        ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
        ("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
    ],
    "IP-Adapter-SDXL": [
        ("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
        ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
    ],
    # Kolors
    "Kolors": [
        ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
        ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
    ],
    "SDXL-vae-fp16-fix": [
        ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
    ],
    # FLUX
    "FLUX.1-dev": [
        ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
        ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
        ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
        ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
        ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
        ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
        ("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
    ]
}
Preset_model_id: TypeAlias = Literal[
    "HunyuanDiT",
    "stable-video-diffusion-img2vid-xt",
    "ExVideo-SVD-128f-v1",
    "StableDiffusion_v15",
    "DreamShaper_8",
    "AingDiffusion_v12",
    "Flat2DAnimerge_v45Sharp",
    "TextualInversion_VeryBadImageNegative_v1.3",
    "StableDiffusionXL_v1",
    "BluePencilXL_v200",
    "StableDiffusionXL_Turbo",
    "ControlNet_v11f1p_sd15_depth",
    "ControlNet_v11p_sd15_softedge",
    "ControlNet_v11f1e_sd15_tile",
    "ControlNet_v11p_sd15_lineart",
    "AnimateDiff_v2",
    "AnimateDiff_xl_beta",
    "RIFE",
    "BeautifulPrompt",
    "opus-mt-zh-en",
    "IP-Adapter-SD",
    "IP-Adapter-SDXL",
    "StableDiffusion3",
    "StableDiffusion3_without_T5",
    "Kolors",
    "SDXL-vae-fp16-fix",
    "ControlNet_union_sdxl_promax",
    "FLUX.1-dev",
    "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
]
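Each preset table above maps a preset ID to a list of (repo_id, file_in_repo, local_directory) triples. As an illustration only (DiffSynth's actual downloader lives in diffsynth/models/downloader.py and may differ), such a table could be resolved against the Hugging Face Hub like this, using one entry copied from preset_models_on_huggingface:

# Illustrative sketch, not the project's downloader: resolve (repo, file, local dir) triples
# with huggingface_hub.hf_hub_download.
from huggingface_hub import hf_hub_download

example_table = {
    "ExVideo-SVD-128f-v1": [
        ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
    ],
}

def fetch_preset(preset_id, table=example_table):
    for repo_id, filename, local_dir in table[preset_id]:
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)

# fetch_preset("ExVideo-SVD-128f-v1")  # downloads ~several GB; uncomment to run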
diffsynth/controlnets/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
from .processors import Annotator
diffsynth/controlnets/controlnet_unit.py
ADDED
@@ -0,0 +1,54 @@
import torch
import numpy as np
from .processors import Processor_id


class ControlNetConfigUnit:
    def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
        self.processor_id = processor_id
        self.model_path = model_path
        self.scale = scale


class ControlNetUnit:
    def __init__(self, processor, model, scale=1.0):
        self.processor = processor
        self.model = model
        self.scale = scale


class MultiControlNetManager:
    def __init__(self, controlnet_units=[]):
        self.processors = [unit.processor for unit in controlnet_units]
        self.models = [unit.model for unit in controlnet_units]
        self.scales = [unit.scale for unit in controlnet_units]

    def process_image(self, image, processor_id=None):
        if processor_id is None:
            processed_image = [processor(image) for processor in self.processors]
        else:
            processed_image = [self.processors[processor_id](image)]
        processed_image = torch.concat([
            torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
            for image_ in processed_image
        ], dim=0)
        return processed_image

    def __call__(
        self,
        sample, timestep, encoder_hidden_states, conditionings,
        tiled=False, tile_size=64, tile_stride=32, **kwargs
    ):
        res_stack = None
        for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
            res_stack_ = model(
                sample, timestep, encoder_hidden_states, conditioning, **kwargs,
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
                processor_id=processor.processor_id
            )
            res_stack_ = [res * scale for res in res_stack_]
            if res_stack is None:
                res_stack = res_stack_
            else:
                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
        return res_stack
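In `MultiControlNetManager.__call__`, each ControlNet's residual stack is scaled by its unit's weight and then summed element-wise across units. A minimal sketch of that same weighted sum with plain tensors standing in for real ControlNet outputs:

import torch

# Stand-in residual stacks: two "ControlNets", each returning three feature tensors.
res_a = [torch.ones(1, 4, 8, 8) for _ in range(3)]
res_b = [torch.full((1, 4, 8, 8), 2.0) for _ in range(3)]
scales = [0.8, 0.5]

combined = None
for res_stack_, scale in zip([res_a, res_b], scales):
    res_stack_ = [res * scale for res in res_stack_]   # scale each unit's residuals
    combined = res_stack_ if combined is None else [i + j for i, j in zip(combined, res_stack_)]

print(combined[0].mean())  # 0.8 * 1 + 0.5 * 2 = 1.8 at every position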
diffsynth/controlnets/processors.py
ADDED
@@ -0,0 +1,51 @@
from typing_extensions import Literal, TypeAlias
import warnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from controlnet_aux.processor import (
        CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
    )


Processor_id: TypeAlias = Literal[
    "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
]

class Annotator:
    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'):
        if processor_id == "canny":
            self.processor = CannyDetector()
        elif processor_id == "depth":
            self.processor = MidasDetector.from_pretrained(model_path).to(device)
        elif processor_id == "softedge":
            self.processor = HEDdetector.from_pretrained(model_path).to(device)
        elif processor_id == "lineart":
            self.processor = LineartDetector.from_pretrained(model_path).to(device)
        elif processor_id == "lineart_anime":
            self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
        elif processor_id == "openpose":
            self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
        elif processor_id == "tile":
            self.processor = None
        else:
            raise ValueError(f"Unsupported processor_id: {processor_id}")

        self.processor_id = processor_id
        self.detect_resolution = detect_resolution

    def __call__(self, image):
        width, height = image.size
        if self.processor_id == "openpose":
            kwargs = {
                "include_body": True,
                "include_hand": True,
                "include_face": True
            }
        else:
            kwargs = {}
        if self.processor is not None:
            detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
            image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
        image = image.resize((width, height))
        return image
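A hedged usage sketch for the Annotator class above: "canny" needs no pretrained weights, while the other processors expect controlnet_aux weights under models/Annotators; "input.png" is a hypothetical file and the import path assumes the package is installed.

from PIL import Image
from diffsynth.controlnets.processors import Annotator  # assumes diffsynth is importable

annotator = Annotator("canny")                      # CannyDetector needs no downloaded weights
control_image = annotator(Image.open("input.png"))  # hypothetical input image
control_image.save("canny_control.png")             # returned at the same size as the input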
diffsynth/data/__init__.py
ADDED
@@ -0,0 +1 @@
from .video import VideoData, save_video, save_frames
diffsynth/data/simple_text_image.py
ADDED
@@ -0,0 +1,35 @@
import torch, os
from torchvision import transforms
import pandas as pd
from PIL import Image


class TextImageDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
        self.steps_per_epoch = steps_per_epoch
        metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
        self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
        self.text = metadata["text"].to_list()
        self.image_processor = transforms.Compose(
            [
                transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
                transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __getitem__(self, index):
        data_id = torch.randint(0, len(self.path), (1,))[0]
        data_id = (data_id + index) % len(self.path) # For fixed seed.
        text = self.text[data_id]
        image = Image.open(self.path[data_id]).convert("RGB")
        image = self.image_processor(image)
        return {"text": text, "image": image}

    def __len__(self):
        return self.steps_per_epoch
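TextImageDataset reports steps_per_epoch as its length and samples images pseudo-randomly, so an epoch length is decoupled from the dataset size. A hedged sketch of wrapping it in a standard DataLoader, assuming a hypothetical data/my_dataset folder containing train/metadata.csv with "file_name" and "text" columns:

from torch.utils.data import DataLoader
from diffsynth.data.simple_text_image import TextImageDataset  # assumes diffsynth is importable

# "data/my_dataset" is a hypothetical dataset root with train/metadata.csv inside it.
dataset = TextImageDataset("data/my_dataset", steps_per_epoch=1000, height=1024, width=1024)
loader = DataLoader(dataset, batch_size=4, num_workers=2)

batch = next(iter(loader))
print(batch["image"].shape)  # torch.Size([4, 3, 1024, 1024]); pixel values normalized to [-1, 1]
print(batch["text"][:2])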
diffsynth/data/video.py
ADDED
@@ -0,0 +1,148 @@
import imageio, os
import numpy as np
from PIL import Image
from tqdm import tqdm


class LowMemoryVideo:
    def __init__(self, file_name):
        self.reader = imageio.get_reader(file_name)

    def __len__(self):
        return self.reader.count_frames()

    def __getitem__(self, item):
        return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")

    def __del__(self):
        self.reader.close()


def split_file_name(file_name):
    result = []
    number = -1
    for i in file_name:
        if ord(i)>=ord("0") and ord(i)<=ord("9"):
            if number == -1:
                number = 0
            number = number*10 + ord(i) - ord("0")
        else:
            if number != -1:
                result.append(number)
                number = -1
            result.append(i)
    if number != -1:
        result.append(number)
    result = tuple(result)
    return result


def search_for_images(folder):
    file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
    file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
    file_list = [i[1] for i in sorted(file_list)]
    file_list = [os.path.join(folder, i) for i in file_list]
    return file_list


class LowMemoryImageFolder:
    def __init__(self, folder, file_list=None):
        if file_list is None:
            self.file_list = search_for_images(folder)
        else:
            self.file_list = [os.path.join(folder, file_name) for file_name in file_list]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, item):
        return Image.open(self.file_list[item]).convert("RGB")

    def __del__(self):
        pass


def crop_and_resize(image, height, width):
    image = np.array(image)
    image_height, image_width, _ = image.shape
    if image_height / image_width < height / width:
        croped_width = int(image_height / height * width)
        left = (image_width - croped_width) // 2
        image = image[:, left: left+croped_width]
        image = Image.fromarray(image).resize((width, height))
    else:
        croped_height = int(image_width / width * height)
        left = (image_height - croped_height) // 2
        image = image[left: left+croped_height, :]
        image = Image.fromarray(image).resize((width, height))
    return image


class VideoData:
    def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
        if video_file is not None:
            self.data_type = "video"
            self.data = LowMemoryVideo(video_file, **kwargs)
        elif image_folder is not None:
            self.data_type = "images"
            self.data = LowMemoryImageFolder(image_folder, **kwargs)
        else:
            raise ValueError("Cannot open video or image folder")
        self.length = None
        self.set_shape(height, width)

    def raw_data(self):
        frames = []
        for i in range(self.__len__()):
            frames.append(self.__getitem__(i))
        return frames

    def set_length(self, length):
        self.length = length

    def set_shape(self, height, width):
        self.height = height
        self.width = width

    def __len__(self):
        if self.length is None:
            return len(self.data)
        else:
            return self.length

    def shape(self):
        if self.height is not None and self.width is not None:
            return self.height, self.width
        else:
            height, width, _ = self.__getitem__(0).shape
            return height, width

    def __getitem__(self, item):
        frame = self.data.__getitem__(item)
        width, height = frame.size
        if self.height is not None and self.width is not None:
            if self.height != height or self.width != width:
                frame = crop_and_resize(frame, self.height, self.width)
        return frame

    def __del__(self):
        pass

    def save_images(self, folder):
        os.makedirs(folder, exist_ok=True)
        for i in tqdm(range(self.__len__()), desc="Saving images"):
            frame = self.__getitem__(i)
            frame.save(os.path.join(folder, f"{i}.png"))


def save_video(frames, save_path, fps, quality=9):
    writer = imageio.get_writer(save_path, fps=fps, quality=quality)
    for frame in tqdm(frames, desc="Saving video"):
        frame = np.array(frame)
        writer.append_data(frame)
    writer.close()

def save_frames(frames, save_path):
    os.makedirs(save_path, exist_ok=True)
    for i, frame in enumerate(tqdm(frames, desc="Saving images")):
        frame.save(os.path.join(save_path, f"{i}.png"))
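VideoData reads frames lazily from either a video file or an image folder and center-crops/resizes them on access when a target height and width are set. A hedged usage sketch with hypothetical paths (mp4 I/O additionally requires an ffmpeg backend for imageio):

from diffsynth.data.video import VideoData, save_video  # assumes diffsynth is importable

video = VideoData(video_file="input.mp4", height=512, width=768)  # hypothetical input file
print(len(video), video.shape())                                  # frame count and (height, width)
frames = [video[i] for i in range(min(16, len(video)))]           # PIL images, cropped and resized lazily
save_video(frames, "clip.mp4", fps=8)                             # re-encode the selected frames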
diffsynth/extensions/ESRGAN/__init__.py
ADDED
@@ -0,0 +1,118 @@
import torch
from einops import repeat
from PIL import Image
import numpy as np


class ResidualDenseBlock(torch.nn.Module):

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(torch.nn.Module):

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        return out * 0.2 + x


class RRDBNet(torch.nn.Module):

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32):
        super(RRDBNet, self).__init__()
        self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
        self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        feat = x
        feat = self.conv_first(feat)
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        # upsample
        feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
        feat = self.lrelu(self.conv_up1(feat))
        feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
        feat = self.lrelu(self.conv_up2(feat))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out


class ESRGAN(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    @staticmethod
    def from_pretrained(model_path):
        model = RRDBNet()
        state_dict = torch.load(model_path, map_location="cpu")["params_ema"]
        model.load_state_dict(state_dict)
        model.eval()
        return ESRGAN(model)

    def process_image(self, image):
        image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
        return image

    def process_images(self, images):
        images = [self.process_image(image) for image in images]
        images = torch.stack(images)
        return images

    def decode_images(self, images):
        images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
        images = [Image.fromarray(image) for image in images]
        return images

    @torch.no_grad()
    def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
        # Preprocess
        input_tensor = self.process_images(images)

        # Interpolate
        output_tensor = []
        for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
            batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
            batch_input_tensor = input_tensor[batch_id: batch_id_]
            batch_input_tensor = batch_input_tensor.to(
                device=self.model.conv_first.weight.device,
                dtype=self.model.conv_first.weight.dtype)
            batch_output_tensor = self.model(batch_input_tensor)
            output_tensor.append(batch_output_tensor.cpu())

        # Output
        output_tensor = torch.concat(output_tensor, dim=0)

        # To images
        output_images = self.decode_images(output_tensor)
        return output_images
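ESRGAN.from_pretrained loads a "params_ema" state dict into RRDBNet, and upscale() doubles the resolution twice via nearest-neighbor repeat plus convolution, so the output is 4x the input size. A hedged usage sketch with hypothetical file paths; the checkpoint must match the default RRDBNet layout:

from PIL import Image
from diffsynth.extensions.ESRGAN import ESRGAN  # assumes diffsynth is importable

# Hypothetical checkpoint path; the file must contain a "params_ema" state dict for RRDBNet.
esrgan = ESRGAN.from_pretrained("models/ESRGAN/RealESRGAN_x4plus.pth")
low_res = [Image.open("frame_0.png").convert("RGB")]   # hypothetical input frame
high_res = esrgan.upscale(low_res, batch_size=1)
high_res[0].save("frame_0_4x.png")                     # each repeat(...) step doubles H and W, 4x overall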
diffsynth/extensions/FastBlend/__init__.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from .runners.fast import TableManager, PyramidPatchMatcher
from PIL import Image
import numpy as np
import cupy as cp


class FastBlendSmoother:
    def __init__(self):
        self.batch_size = 8
        self.window_size = 64
        self.ebsynth_config = {
            "minimum_patch_size": 5,
            "threads_per_block": 8,
            "num_iter": 5,
            "gpu_id": 0,
            "guide_weight": 10.0,
            "initialize": "identity",
            "tracking_window_size": 0,
        }

    @staticmethod
    def from_model_manager(model_manager):
        # TODO: fetch GPU ID from model_manager
        return FastBlendSmoother()

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
        frames_guide = [np.array(frame) for frame in frames_guide]
        frames_style = [np.array(frame) for frame in frames_style]
        table_manager = TableManager()
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            **ebsynth_config
        )
        # left part
        table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
        table_l = table_manager.remapping_table_to_blending_table(table_l)
        table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
        # right part
        table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
        table_r = table_manager.remapping_table_to_blending_table(table_r)
        table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
        # merge
        frames = []
        for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
            weight_m = -1
            weight = weight_l + weight_m + weight_r
            frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
            frames.append(frame)
        frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
        return frames

    def __call__(self, rendered_frames, original_frames=None, **kwargs):
        frames = self.run(
            original_frames, rendered_frames,
            self.batch_size, self.window_size, self.ebsynth_config
        )
        mempool = cp.get_default_memory_pool()
        pinned_mempool = cp.get_default_pinned_memory_pool()
        mempool.free_all_blocks()
        pinned_mempool.free_all_blocks()
        return frames
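A short usage sketch for `FastBlendSmoother`, assuming the guide and rendered frames are loaded as PIL images of identical length and resolution (paths are placeholders); CuPy and a CUDA device are required:

from PIL import Image
from diffsynth.extensions.FastBlend import FastBlendSmoother

def blend_rendered_video(guide_paths, rendered_paths):
    guide_frames = [Image.open(p).convert("RGB") for p in guide_paths]
    rendered_frames = [Image.open(p).convert("RGB") for p in rendered_paths]
    smoother = FastBlendSmoother()  # defaults: batch_size=8, window_size=64
    # __call__ runs the four-step fast-mode pipeline and frees the CuPy memory pools afterwards.
    return smoother(rendered_frames, original_frames=guide_frames)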
diffsynth/extensions/FastBlend/api.py
ADDED
@@ -0,0 +1,397 @@
from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner
from .data import VideoData, get_video_fps, save_video, search_for_images
import os
import gradio as gr


def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder):
    frames_guide = VideoData(video_guide, video_guide_folder)
    frames_style = VideoData(video_style, video_style_folder)
    message = ""
    if len(frames_guide) < len(frames_style):
        message += f"The number of frames mismatches. Only the first {len(frames_guide)} frames of style video will be used.\n"
        frames_style.set_length(len(frames_guide))
    elif len(frames_guide) > len(frames_style):
        message += f"The number of frames mismatches. Only the first {len(frames_style)} frames of guide video will be used.\n"
        frames_guide.set_length(len(frames_style))
    height_guide, width_guide = frames_guide.shape()
    height_style, width_style = frames_style.shape()
    if height_guide != height_style or width_guide != width_style:
        message += f"The shape of frames mismatches. The frames in style video will be resized to (height: {height_guide}, width: {width_guide})\n"
        frames_style.set_shape(height_guide, width_guide)
    return frames_guide, frames_style, message


def smooth_video(
    video_guide,
    video_guide_folder,
    video_style,
    video_style_folder,
    mode,
    window_size,
    batch_size,
    tracking_window_size,
    output_path,
    fps,
    minimum_patch_size,
    num_iter,
    guide_weight,
    initialize,
    progress = None,
):
    # input
    frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder)
    if len(message) > 0:
        print(message)
    # output
    if output_path == "":
        if video_style is None:
            output_path = os.path.join(video_style_folder, "output")
        else:
            output_path = os.path.join(os.path.split(video_style)[0], "output")
        os.makedirs(output_path, exist_ok=True)
        print("No valid output_path. Your video will be saved here:", output_path)
    elif not os.path.exists(output_path):
        os.makedirs(output_path, exist_ok=True)
        print("Your video will be saved here:", output_path)
    frames_path = os.path.join(output_path, "frames")
    video_path = os.path.join(output_path, "video.mp4")
    os.makedirs(frames_path, exist_ok=True)
    # process
    if mode == "Fast" or mode == "Balanced":
        tracking_window_size = 0
    ebsynth_config = {
        "minimum_patch_size": minimum_patch_size,
        "threads_per_block": 8,
        "num_iter": num_iter,
        "gpu_id": 0,
        "guide_weight": guide_weight,
        "initialize": initialize,
        "tracking_window_size": tracking_window_size,
    }
    if mode == "Fast":
        FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
    elif mode == "Balanced":
        BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
    elif mode == "Accurate":
        AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
    # output
    try:
        fps = int(fps)
    except:
        fps = get_video_fps(video_style) if video_style is not None else 30
    print("Fps:", fps)
    print("Saving video...")
    video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps)
    print("Success!")
    print("Your frames are here:", frames_path)
    print("Your video is here:", video_path)
    return output_path, fps, video_path


class KeyFrameMatcher:
    def __init__(self):
        pass

    def extract_number_from_filename(self, file_name):
        result = []
        number = -1
        for i in file_name:
            if ord(i)>=ord("0") and ord(i)<=ord("9"):
                if number == -1:
                    number = 0
                number = number*10 + ord(i) - ord("0")
            else:
                if number != -1:
                    result.append(number)
                    number = -1
        if number != -1:
            result.append(number)
        result = tuple(result)
        return result

    def extract_number_from_filenames(self, file_names):
        numbers = [self.extract_number_from_filename(file_name) for file_name in file_names]
        min_length = min(len(i) for i in numbers)
        for i in range(min_length-1, -1, -1):
            if len(set(number[i] for number in numbers))==len(file_names):
                return [number[i] for number in numbers]
        return list(range(len(file_names)))

    def match_using_filename(self, file_names_a, file_names_b):
        file_names_b_set = set(file_names_b)
        matched_file_name = []
        for file_name in file_names_a:
            if file_name not in file_names_b_set:
                matched_file_name.append(None)
            else:
                matched_file_name.append(file_name)
        return matched_file_name

    def match_using_numbers(self, file_names_a, file_names_b):
        numbers_a = self.extract_number_from_filenames(file_names_a)
        numbers_b = self.extract_number_from_filenames(file_names_b)
        numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)}
        matched_file_name = []
        for number in numbers_a:
            if number in numbers_b_dict:
                matched_file_name.append(numbers_b_dict[number])
            else:
                matched_file_name.append(None)
        return matched_file_name

    def match_filenames(self, file_names_a, file_names_b):
        matched_file_name = self.match_using_filename(file_names_a, file_names_b)
        if sum([i is not None for i in matched_file_name]) > 0:
            return matched_file_name
        matched_file_name = self.match_using_numbers(file_names_a, file_names_b)
        return matched_file_name


def detect_frames(frames_path, keyframes_path):
    if not os.path.exists(frames_path) and not os.path.exists(keyframes_path):
        return "Please input the directory of guide video and rendered frames"
    elif not os.path.exists(frames_path):
        return "Please input the directory of guide video"
    elif not os.path.exists(keyframes_path):
        return "Please input the directory of rendered frames"
    frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
    keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
    if len(frames)==0:
        return f"No images detected in {frames_path}"
    if len(keyframes)==0:
        return f"No images detected in {keyframes_path}"
    matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
    max_filename_length = max([len(i) for i in frames])
    if sum([i is not None for i in matched_keyframes])==0:
        message = ""
        for frame, matched_keyframe in zip(frames, matched_keyframes):
            message += frame + " " * (max_filename_length - len(frame) + 1)
            message += "--> No matched keyframes\n"
    else:
        message = ""
        for frame, matched_keyframe in zip(frames, matched_keyframes):
            message += frame + " " * (max_filename_length - len(frame) + 1)
            if matched_keyframe is None:
                message += "--> [to be rendered]\n"
            else:
                message += f"--> {matched_keyframe}\n"
    return message


def check_input_for_interpolating(frames_path, keyframes_path):
    # search for images
    frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
    keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
    # match frames
    matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
    file_list = [file_name for file_name in matched_keyframes if file_name is not None]
    index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None]
    frames_guide = VideoData(None, frames_path)
    frames_style = VideoData(None, keyframes_path, file_list=file_list)
    # match shape
    message = ""
    height_guide, width_guide = frames_guide.shape()
    height_style, width_style = frames_style.shape()
    if height_guide != height_style or width_guide != width_style:
        message += f"The shape of frames mismatches. The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide})\n"
        frames_style.set_shape(height_guide, width_guide)
    return frames_guide, frames_style, index_style, message


def interpolate_video(
    frames_path,
    keyframes_path,
    output_path,
    fps,
    batch_size,
    tracking_window_size,
    minimum_patch_size,
    num_iter,
    guide_weight,
    initialize,
    progress = None,
):
    # input
    frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path)
    if len(message) > 0:
        print(message)
    # output
    if output_path == "":
        output_path = os.path.join(keyframes_path, "output")
        os.makedirs(output_path, exist_ok=True)
        print("No valid output_path. Your video will be saved here:", output_path)
    elif not os.path.exists(output_path):
        os.makedirs(output_path, exist_ok=True)
        print("Your video will be saved here:", output_path)
    output_frames_path = os.path.join(output_path, "frames")
    output_video_path = os.path.join(output_path, "video.mp4")
    os.makedirs(output_frames_path, exist_ok=True)
    # process
    ebsynth_config = {
        "minimum_patch_size": minimum_patch_size,
        "threads_per_block": 8,
        "num_iter": num_iter,
        "gpu_id": 0,
        "guide_weight": guide_weight,
        "initialize": initialize,
        "tracking_window_size": tracking_window_size
    }
    if len(index_style)==1:
        InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
    else:
        InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
    try:
        fps = int(fps)
    except:
        fps = 30
    print("Fps:", fps)
    print("Saving video...")
    video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps)
    print("Success!")
    print("Your frames are here:", output_frames_path)
    print("Your video is here:", video_path)
    return output_path, fps, video_path


def on_ui_tabs():
    with gr.Blocks(analytics_enabled=False) as ui_component:
        with gr.Tab("Blend"):
            gr.Markdown("""
# Blend

Given a guide video and a style video, this algorithm makes the style video fluent according to the motion features of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see an example. Note that this extension doesn't support long videos; please use short videos (e.g., several seconds). The algorithm is mainly designed for 512*512 resolution. Please use a larger `Minimum patch size` for higher resolutions.
            """)
            with gr.Row():
                with gr.Column():
                    with gr.Tab("Guide video"):
                        video_guide = gr.Video(label="Guide video")
                    with gr.Tab("Guide video (images format)"):
                        video_guide_folder = gr.Textbox(label="Guide video (images format)", value="")
                with gr.Column():
                    with gr.Tab("Style video"):
                        video_style = gr.Video(label="Style video")
                    with gr.Tab("Style video (images format)"):
                        video_style_folder = gr.Textbox(label="Style video (images format)", value="")
                with gr.Column():
                    output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video")
                    fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
                    video_output = gr.Video(label="Output video", interactive=False, show_share_button=True)
                    btn = gr.Button(value="Blend")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("# Settings")
                    mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True)
                    window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True)
                    batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
                    tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True)
                    gr.Markdown("## Advanced Settings")
                    minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True)
                    num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
                    guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
                    initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
                with gr.Column():
                    gr.Markdown("""
# Reference

* Output directory: the directory to save the video.
* Inference mode

|Mode|Time|Memory|Quality|Frame by frame output|Description|
|-|-|-|-|-|-|
|Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires much RAM but is fast.|
|Balanced|■■|■|■■|Yes|Blend the frames naively.|
|Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. Quality is best when [batch size] >= [sliding window size] * 2 + 1.|

* Sliding window size: the algorithm blends the frames in a sliding window. If the size is n, each frame is blended with the previous n frames and the next n frames. A large sliding window can make the video fluent but sometimes blurry.
* Batch size: a larger batch size makes the program faster but requires more VRAM.
* Tracking window size (only for accurate mode): the size of the window in which the algorithm tracks moving objects. Empirically, 1 is enough.
* Advanced settings
    * Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5)
    * Number of iterations: the number of iterations of patch matching. (Default: 5)
    * Guide weight: a parameter that determines how much of the guide video's motion features are applied to the style video. (Default: 10)
    * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
                    """)
            btn.click(
                smooth_video,
                inputs=[
                    video_guide,
                    video_guide_folder,
                    video_style,
                    video_style_folder,
                    mode,
                    window_size,
                    batch_size,
                    tracking_window_size,
                    output_path,
                    fps,
                    minimum_patch_size,
                    num_iter,
                    guide_weight,
                    initialize
                ],
                outputs=[output_path, fps, video_output]
            )
        with gr.Tab("Interpolate"):
            gr.Markdown("""
# Interpolate

Given a guide video and some rendered keyframes, this algorithm renders the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see an example. The algorithm is experimental and is only tested for 512*512 resolution.
            """)
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        with gr.Column():
                            video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="")
                        with gr.Column():
                            rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="")
                    with gr.Row():
                        detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False)
                    video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
                    rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
                with gr.Column():
                    output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes")
                    fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
                    video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True)
                    btn_ = gr.Button(value="Interpolate")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("# Settings")
                    batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
                    tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True)
                    gr.Markdown("## Advanced Settings")
                    minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True)
                    num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
                    guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
                    initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
                with gr.Column():
                    gr.Markdown("""
# Reference

* Output directory: the directory to save the video.
* Batch size: a larger batch size makes the program faster but requires more VRAM.
* Tracking window size: the size of the window in which the algorithm tracks moving objects. Empirically, 1 is enough.
* Advanced settings
    * Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than that in blending. (Default: 15)**
    * Number of iterations: the number of iterations of patch matching. (Default: 5)
    * Guide weight: a parameter that determines how much of the guide video's motion features are applied to the style video. (Default: 10)
    * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
                    """)
            btn_.click(
                interpolate_video,
                inputs=[
                    video_guide_folder_,
                    rendered_keyframes_,
                    output_path_,
                    fps_,
                    batch_size_,
                    tracking_window_size_,
                    minimum_patch_size_,
                    num_iter_,
                    guide_weight_,
                    initialize_,
                ],
                outputs=[output_path_, fps_, video_output_]
            )

    return [(ui_component, "FastBlend", "FastBlend_ui")]
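The tuple returned by `on_ui_tabs` follows the sd-webui extension convention. A minimal sketch for serving the same tabs as a standalone Gradio app (outside the webui) could look like this:

from diffsynth.extensions.FastBlend.api import on_ui_tabs

if __name__ == "__main__":
    ui_component, tab_title, elem_id = on_ui_tabs()[0]
    ui_component.launch()  # gr.Blocks instance; serves the Blend and Interpolate tabs locally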
diffsynth/extensions/FastBlend/cupy_kernels.py
ADDED
@@ -0,0 +1,119 @@
import cupy as cp

remapping_kernel = cp.RawKernel(r'''
extern "C" __global__
void remap(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source_style,
    const int* nnf,
    float* target_style
) {
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    if (x >= height or y >= width) return;
    const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
    const int min_px = x < r ? -x : -r;
    const int max_px = x + r > height - 1 ? height - 1 - x : r;
    const int min_py = y < r ? -y : -r;
    const int max_py = y + r > width - 1 ? width - 1 - y : r;
    int num = 0;
    for (int px = min_px; px <= max_px; px++){
        for (int py = min_py; py <= max_py; py++){
            const int nid = (x + px) * width + y + py;
            const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
            const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
            if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue;
            const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
            num++;
            for (int c = 0; c < channel; c++){
                target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
            }
        }
    }
    for (int c = 0; c < channel; c++){
        target_style[z + pid * channel + c] /= num;
    }
}
''', 'remap')


patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void patch_error(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source,
    const int* nnf,
    const float* target,
    float* error
) {
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    if (x >= height or y >= width) return;
    const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
    const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
    float e = 0;
    for (int px = -r; px <= r; px++){
        for (int py = -r; py <= r; py++){
            const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
            const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
            for (int c = 0; c < channel; c++){
                const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
                e += diff * diff;
            }
        }
    }
    error[blockIdx.z * height * width + x * width + y] = e;
}
''', 'patch_error')


pairwise_patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void pairwise_patch_error(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source_a,
    const int* nnf_a,
    const float* source_b,
    const int* nnf_b,
    float* error
) {
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    if (x >= height or y >= width) return;
    const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
    const int x_a = nnf_a[z_nnf + 0];
    const int y_a = nnf_a[z_nnf + 1];
    const int x_b = nnf_b[z_nnf + 0];
    const int y_b = nnf_b[z_nnf + 1];
    float e = 0;
    for (int px = -r; px <= r; px++){
        for (int py = -r; py <= r; py++){
            const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
            const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
            for (int c = 0; c < channel; c++){
                const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
                e += diff * diff;
            }
        }
    }
    error[blockIdx.z * height * width + x * width + y] = e;
}
''', 'pairwise_patch_error')
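For reference, these raw kernels are launched with an explicit (grid, block, args) call, as done in patch_match.py. Below is a rough sketch of a standalone launch of `remapping_kernel`; the array shapes and the single fixed patch size are illustrative assumptions, not part of the committed code:

import cupy as cp
from diffsynth.extensions.FastBlend.cupy_kernels import remapping_kernel

def remap_batch(source_style, nnf, patch_size=5, threads_per_block=8):
    # source_style: float32, shape (batch, height + 2*pad, width + 2*pad, channel), already padded
    # nnf:          int32,   shape (batch, height, width, 2) holding (x, y) source coordinates
    pad_size = patch_size // 2
    batch, height, width = nnf.shape[0], nnf.shape[1], nnf.shape[2]
    channel = source_style.shape[3]
    target_style = cp.zeros_like(source_style)
    grid = (
        (height + threads_per_block - 1) // threads_per_block,
        (width + threads_per_block - 1) // threads_per_block,
        batch,
    )
    block = (threads_per_block, threads_per_block)
    remapping_kernel(grid, block,
        (height, width, channel, patch_size, pad_size, source_style, nnf, target_style))
    return target_style  # padded; crop [pad_size:-pad_size] on the spatial axes to unpad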
diffsynth/extensions/FastBlend/data.py
ADDED
@@ -0,0 +1,146 @@
import imageio, os
import numpy as np
from PIL import Image


def read_video(file_name):
    reader = imageio.get_reader(file_name)
    video = []
    for frame in reader:
        frame = np.array(frame)
        video.append(frame)
    reader.close()
    return video


def get_video_fps(file_name):
    reader = imageio.get_reader(file_name)
    fps = reader.get_meta_data()["fps"]
    reader.close()
    return fps


def save_video(frames_path, video_path, num_frames, fps):
    writer = imageio.get_writer(video_path, fps=fps, quality=9)
    for i in range(num_frames):
        frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
        writer.append_data(frame)
    writer.close()
    return video_path


class LowMemoryVideo:
    def __init__(self, file_name):
        self.reader = imageio.get_reader(file_name)

    def __len__(self):
        return self.reader.count_frames()

    def __getitem__(self, item):
        return np.array(self.reader.get_data(item))

    def __del__(self):
        self.reader.close()


def split_file_name(file_name):
    result = []
    number = -1
    for i in file_name:
        if ord(i)>=ord("0") and ord(i)<=ord("9"):
            if number == -1:
                number = 0
            number = number*10 + ord(i) - ord("0")
        else:
            if number != -1:
                result.append(number)
                number = -1
            result.append(i)
    if number != -1:
        result.append(number)
    result = tuple(result)
    return result


def search_for_images(folder):
    file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
    file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
    file_list = [i[1] for i in sorted(file_list)]
    file_list = [os.path.join(folder, i) for i in file_list]
    return file_list


def read_images(folder):
    file_list = search_for_images(folder)
    frames = [np.array(Image.open(i)) for i in file_list]
    return frames


class LowMemoryImageFolder:
    def __init__(self, folder, file_list=None):
        if file_list is None:
            self.file_list = search_for_images(folder)
        else:
            self.file_list = [os.path.join(folder, file_name) for file_name in file_list]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, item):
        return np.array(Image.open(self.file_list[item]))

    def __del__(self):
        pass


class VideoData:
    def __init__(self, video_file, image_folder, **kwargs):
        if video_file is not None:
            self.data_type = "video"
            self.data = LowMemoryVideo(video_file, **kwargs)
        elif image_folder is not None:
            self.data_type = "images"
            self.data = LowMemoryImageFolder(image_folder, **kwargs)
        else:
            raise ValueError("Cannot open video or image folder")
        self.length = None
        self.height = None
        self.width = None

    def raw_data(self):
        frames = []
        for i in range(self.__len__()):
            frames.append(self.__getitem__(i))
        return frames

    def set_length(self, length):
        self.length = length

    def set_shape(self, height, width):
        self.height = height
        self.width = width

    def __len__(self):
        if self.length is None:
            return len(self.data)
        else:
            return self.length

    def shape(self):
        if self.height is not None and self.width is not None:
            return self.height, self.width
        else:
            height, width, _ = self.__getitem__(0).shape
            return height, width

    def __getitem__(self, item):
        frame = self.data.__getitem__(item)
        height, width, _ = frame.shape
        if self.height is not None and self.width is not None:
            if self.height != height or self.width != width:
                frame = Image.fromarray(frame).resize((self.width, self.height))
                frame = np.array(frame)
        return frame

    def __del__(self):
        pass
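A short sketch of how these helpers compose (paths are placeholders): `VideoData` lazily reads either a video file or an image folder, and `save_video` encodes numbered PNG frames into an mp4:

from diffsynth.extensions.FastBlend.data import VideoData, get_video_fps, save_video

def inspect_and_encode(video_file, frames_dir, output_mp4):
    frames = VideoData(video_file, None)  # pass a video file OR an image folder; the other argument is None
    height, width = frames.shape()
    print(len(frames), "frames of size", (height, width), "at", get_video_fps(video_file), "fps")
    # Frames in frames_dir must be named 00000.png, 00001.png, ...
    return save_video(frames_dir, output_mp4, num_frames=len(frames), fps=30)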
diffsynth/extensions/FastBlend/patch_match.py
ADDED
@@ -0,0 +1,298 @@
from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
import numpy as np
import cupy as cp
import cv2


class PatchMatcher:
    def __init__(
        self, height, width, channel, minimum_patch_size,
        threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
        random_search_steps=3, random_search_range=4,
        use_mean_target_style=False, use_pairwise_patch_error=False,
        tracking_window_size=0
    ):
        self.height = height
        self.width = width
        self.channel = channel
        self.minimum_patch_size = minimum_patch_size
        self.threads_per_block = threads_per_block
        self.num_iter = num_iter
        self.gpu_id = gpu_id
        self.guide_weight = guide_weight
        self.random_search_steps = random_search_steps
        self.random_search_range = random_search_range
        self.use_mean_target_style = use_mean_target_style
        self.use_pairwise_patch_error = use_pairwise_patch_error
        self.tracking_window_size = tracking_window_size

        self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
        self.pad_size = self.patch_size_list[0] // 2
        self.grid = (
            (height + threads_per_block - 1) // threads_per_block,
            (width + threads_per_block - 1) // threads_per_block
        )
        self.block = (threads_per_block, threads_per_block)

    def pad_image(self, image):
        return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))

    def unpad_image(self, image):
        return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]

    def apply_nnf_to_image(self, nnf, source):
        batch_size = source.shape[0]
        target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
        remapping_kernel(
            self.grid + (batch_size,),
            self.block,
            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
        )
        return target

    def get_patch_error(self, source, nnf, target):
        batch_size = source.shape[0]
        error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
        patch_error_kernel(
            self.grid + (batch_size,),
            self.block,
            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
        )
        return error

    def get_pairwise_patch_error(self, source, nnf):
        batch_size = source.shape[0]//2
        error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
        source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
        source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
        pairwise_patch_error_kernel(
            self.grid + (batch_size,),
            self.block,
            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
        )
        error = error.repeat(2, axis=0)
        return error

    def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
        error_guide = self.get_patch_error(source_guide, nnf, target_guide)
        if self.use_mean_target_style:
            target_style = self.apply_nnf_to_image(nnf, source_style)
            target_style = target_style.mean(axis=0, keepdims=True)
            target_style = target_style.repeat(source_guide.shape[0], axis=0)
        if self.use_pairwise_patch_error:
            error_style = self.get_pairwise_patch_error(source_style, nnf)
        else:
            error_style = self.get_patch_error(source_style, nnf, target_style)
        error = error_guide * self.guide_weight + error_style
        return error

    def clamp_bound(self, nnf):
        nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
        nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
        return nnf

    def random_step(self, nnf, r):
        batch_size = nnf.shape[0]
        step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
        upd_nnf = self.clamp_bound(nnf + step)
        return upd_nnf

    def neighboor_step(self, nnf, d):
        if d==0:
            upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
            upd_nnf[:, :, :, 0] += 1
        elif d==1:
            upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
            upd_nnf[:, :, :, 1] += 1
        elif d==2:
            upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
            upd_nnf[:, :, :, 0] -= 1
        elif d==3:
            upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
            upd_nnf[:, :, :, 1] -= 1
        upd_nnf = self.clamp_bound(upd_nnf)
        return upd_nnf

    def shift_nnf(self, nnf, d):
        if d>0:
            d = min(nnf.shape[0], d)
            upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
        else:
            d = max(-nnf.shape[0], d)
            upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
        return upd_nnf

    def track_step(self, nnf, d):
        if self.use_pairwise_patch_error:
            upd_nnf = cp.zeros_like(nnf)
            upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
            upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
        else:
            upd_nnf = self.shift_nnf(nnf, d)
        return upd_nnf

    def C(self, n, m):
        # not used
        c = 1
        for i in range(1, n+1):
            c *= i
        for i in range(1, m+1):
            c //= i
        for i in range(1, n-m+1):
            c //= i
        return c

    def bezier_step(self, nnf, r):
        # not used
        n = r * 2 - 1
        upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
        for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
            if d>0:
                ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
            elif d<0:
                ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
            upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
        upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
        return upd_nnf

    def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
        upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
        upd_idx = (upd_err < err)
        nnf[upd_idx] = upd_nnf[upd_idx]
        err[upd_idx] = upd_err[upd_idx]
        return nnf, err

    def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
        for d in cp.random.permutation(4):
            upd_nnf = self.neighboor_step(nnf, d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
        for i in range(self.random_search_steps):
            upd_nnf = self.random_step(nnf, self.random_search_range)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
        for d in range(1, self.tracking_window_size + 1):
            upd_nnf = self.track_step(nnf, d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
            upd_nnf = self.track_step(nnf, -d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
        nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
        nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
        nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
        return nnf, err

    def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
        with cp.cuda.Device(self.gpu_id):
            source_guide = self.pad_image(source_guide)
            target_guide = self.pad_image(target_guide)
            source_style = self.pad_image(source_style)
            for it in range(self.num_iter):
                self.patch_size = self.patch_size_list[it]
                target_style = self.apply_nnf_to_image(nnf, source_style)
                err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
                nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
            target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
        return nnf, target_style


class PyramidPatchMatcher:
    def __init__(
        self, image_height, image_width, channel, minimum_patch_size,
        threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
        use_mean_target_style=False, use_pairwise_patch_error=False,
        tracking_window_size=0,
        initialize="identity"
    ):
        maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
        self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
        self.pyramid_heights = []
        self.pyramid_widths = []
        self.patch_matchers = []
        self.minimum_patch_size = minimum_patch_size
        self.num_iter = num_iter
        self.gpu_id = gpu_id
        self.initialize = initialize
        for level in range(self.pyramid_level):
            height = image_height//(2**(self.pyramid_level - 1 - level))
            width = image_width//(2**(self.pyramid_level - 1 - level))
            self.pyramid_heights.append(height)
            self.pyramid_widths.append(width)
            self.patch_matchers.append(PatchMatcher(
                height, width, channel, minimum_patch_size=minimum_patch_size,
                threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
                use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
                tracking_window_size=tracking_window_size
            ))

    def resample_image(self, images, level):
        height, width = self.pyramid_heights[level], self.pyramid_widths[level]
        images = images.get()
        images_resample = []
        for image in images:
            image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
            images_resample.append(image_resample)
        images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
        return images_resample

    def initialize_nnf(self, batch_size):
        if self.initialize == "random":
            height, width = self.pyramid_heights[0], self.pyramid_widths[0]
            nnf = cp.stack([
                cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
                cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
            ], axis=3)
        elif self.initialize == "identity":
            height, width = self.pyramid_heights[0], self.pyramid_widths[0]
            nnf = cp.stack([
                cp.repeat(cp.arange(height), width).reshape(height, width),
                cp.tile(cp.arange(width), height).reshape(height, width)
            ], axis=2)
            nnf = cp.stack([nnf] * batch_size)
        else:
            raise NotImplementedError()
        return nnf

    def update_nnf(self, nnf, level):
        # upscale
        nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
        nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1
        nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1
        # check if scale is 2
        height, width = self.pyramid_heights[level], self.pyramid_widths[level]
        if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
            nnf = nnf.get().astype(np.float32)
            nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
            nnf = cp.array(np.stack(nnf), dtype=cp.int32)
            nnf = self.patch_matchers[level].clamp_bound(nnf)
        return nnf

    def apply_nnf_to_image(self, nnf, image):
        with cp.cuda.Device(self.gpu_id):
            image = self.patch_matchers[-1].pad_image(image)
            image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
        return image

    def estimate_nnf(self, source_guide, target_guide, source_style):
        with cp.cuda.Device(self.gpu_id):
            if not isinstance(source_guide, cp.ndarray):
                source_guide = cp.array(source_guide, dtype=cp.float32)
            if not isinstance(target_guide, cp.ndarray):
                target_guide = cp.array(target_guide, dtype=cp.float32)
            if not isinstance(source_style, cp.ndarray):
                source_style = cp.array(source_style, dtype=cp.float32)
            for level in range(self.pyramid_level):
                nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
                source_guide_ = self.resample_image(source_guide, level)
                target_guide_ = self.resample_image(target_guide, level)
                source_style_ = self.resample_image(source_style, level)
                nnf, target_style = self.patch_matchers[level].estimate_nnf(
                    source_guide_, target_guide_, source_style_, nnf
                )
        return nnf.get(), target_style.get()
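A minimal sketch of calling `PyramidPatchMatcher` directly on NumPy frames; the parameter values mirror the defaults used elsewhere in this extension and the inputs are assumed to be batched HWC arrays:

import numpy as np
from diffsynth.extensions.FastBlend.patch_match import PyramidPatchMatcher

def remap_style_to_target(source_guide, target_guide, source_style):
    # Each input: np.ndarray of shape (batch, height, width, 3); frames should be reasonably large (e.g. 512x512).
    matcher = PyramidPatchMatcher(
        image_height=source_guide.shape[1],
        image_width=source_guide.shape[2],
        channel=3,
        minimum_patch_size=5,
        num_iter=5,
        guide_weight=10.0,
        initialize="identity",
    )
    # Returns the nearest-neighbor field and the style frames remapped onto the target guide.
    nnf, target_style = matcher.estimate_nnf(source_guide, target_guide, source_style)
    return nnf, target_style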
diffsynth/extensions/FastBlend/runners/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .accurate import AccurateModeRunner
from .fast import FastModeRunner
from .balanced import BalancedModeRunner
from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner
diffsynth/extensions/FastBlend/runners/accurate.py
ADDED
@@ -0,0 +1,35 @@
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm


class AccurateModeRunner:
    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            use_mean_target_style=True,
            **ebsynth_config
        )
        # run
        n = len(frames_style)
        for target in tqdm(range(n), desc=desc):
            l, r = max(target - window_size, 0), min(target + window_size + 1, n)
            remapped_frames = []
            for i in range(l, r, batch_size):
                j = min(i + batch_size, r)
                source_guide = np.stack([frames_guide[source] for source in range(i, j)])
                target_guide = np.stack([frames_guide[target]] * (j - i))
                source_style = np.stack([frames_style[source] for source in range(i, j)])
                _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
                remapped_frames.append(target_style)
            frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
            frame = frame.clip(0, 255).astype("uint8")
            if save_path is not None:
                Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
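A usage sketch for `AccurateModeRunner`; the frame lists and output directory are placeholders, and the config keys are the ones consumed by `PyramidPatchMatcher`:

from diffsynth.extensions.FastBlend.runners.accurate import AccurateModeRunner

def run_accurate(frames_guide, frames_style, save_path):
    # frames_guide / frames_style: lists of HWC uint8 NumPy arrays of equal length and size.
    ebsynth_config = {
        "minimum_patch_size": 5,
        "threads_per_block": 8,
        "num_iter": 5,
        "gpu_id": 0,
        "guide_weight": 10.0,
        "initialize": "identity",
        "tracking_window_size": 0,
    }
    # Blended frames are written to save_path as 00000.png, 00001.png, ...
    AccurateModeRunner().run(
        frames_guide, frames_style,
        batch_size=8, window_size=15,
        ebsynth_config=ebsynth_config, save_path=save_path,
    )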
diffsynth/extensions/FastBlend/runners/balanced.py
ADDED
@@ -0,0 +1,46 @@
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm


class BalancedModeRunner:
    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            **ebsynth_config
        )
        # tasks
        n = len(frames_style)
        tasks = []
        for target in range(n):
            for source in range(target - window_size, target + window_size + 1):
                if source >= 0 and source < n and source != target:
                    tasks.append((source, target))
        # run
        frames = [(None, 1) for i in range(n)]
        for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
            tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
            source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
            target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
            source_style = np.stack([frames_style[source] for source, target in tasks_batch])
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for (source, target), result in zip(tasks_batch, target_style):
                frame, weight = frames[target]
                if frame is None:
                    frame = frames_style[target]
                frames[target] = (
                    frame * (weight / (weight + 1)) + result / (weight + 1),
                    weight + 1
                )
                if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
                    frame = frame.clip(0, 255).astype("uint8")
                    if save_path is not None:
                        Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
                    frames[target] = (None, 1)
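The incremental update used here (and in `build_remapping_table` in fast.py below) is a running mean: each new remapped frame is folded into the accumulated frame with weight 1. A small self-contained check of that identity:

import numpy as np

samples = [np.float64(x) for x in (10.0, 20.0, 40.0)]
frame, weight = samples[0], 1
for result in samples[1:]:
    # Same update rule as in the runner: frame * (weight / (weight + 1)) + result / (weight + 1)
    frame = frame * (weight / (weight + 1)) + result / (weight + 1)
    weight += 1
assert np.isclose(frame, np.mean(samples))  # the incremental update equals the plain mean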
diffsynth/extensions/FastBlend/runners/fast.py
ADDED
@@ -0,0 +1,141 @@
from ..patch_match import PyramidPatchMatcher
import functools, os
import numpy as np
from PIL import Image
from tqdm import tqdm


class TableManager:
    def __init__(self):
        pass

    def task_list(self, n):
        tasks = []
        max_level = 1
        while (1<<max_level)<=n:
            max_level += 1
        for i in range(n):
            j = i
            for level in range(max_level):
                if i&(1<<level):
                    continue
                j |= 1<<level
                if j>=n:
                    break
                meta_data = {
                    "source": i,
                    "target": j,
                    "level": level + 1
                }
                tasks.append(meta_data)
        tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
        return tasks

    def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
        n = len(frames_guide)
        tasks = self.task_list(n)
        remapping_table = [[(frames_style[i], 1)] for i in range(n)]
        for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
            tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
            source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
            target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
            source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for task, result in zip(tasks_batch, target_style):
                target, level = task["target"], task["level"]
                if len(remapping_table[target])==level:
                    remapping_table[target].append((result, 1))
                else:
                    frame, weight = remapping_table[target][level]
                    remapping_table[target][level] = (
                        frame * (weight / (weight + 1)) + result / (weight + 1),
                        weight + 1
                    )
        return remapping_table

    def remapping_table_to_blending_table(self, table):
        for i in range(len(table)):
            for j in range(1, len(table[i])):
                frame_1, weight_1 = table[i][j-1]
                frame_2, weight_2 = table[i][j]
                frame = (frame_1 + frame_2) / 2
                weight = weight_1 + weight_2
                table[i][j] = (frame, weight)
        return table

    def tree_query(self, leftbound, rightbound):
        node_list = []
        node_index = rightbound
        while node_index>=leftbound:
            node_level = 0
            while (1<<node_level)&node_index and node_index-(1<<node_level+1)+1>=leftbound:
                node_level += 1
            node_list.append((node_index, node_level))
            node_index -= 1<<node_level
        return node_list

    def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
        n = len(blending_table)
        tasks = []
        frames_result = []
        for target in range(n):
            node_list = self.tree_query(max(target-window_size, 0), target)
            for source, level in node_list:
                if source!=target:
                    meta_data = {
                        "source": source,
                        "target": target,
                        "level": level
                    }
                    tasks.append(meta_data)
                else:
                    frames_result.append(blending_table[target][level])
        for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
            tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
            source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
            target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
            source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in tasks_batch])
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for task, frame_2 in zip(tasks_batch, target_style):
                source, target, level = task["source"], task["target"], task["level"]
                frame_1, weight_1 = frames_result[target]
                weight_2 = blending_table[source][level][1]
                weight = weight_1 + weight_2
                frame = frame_1 * (weight_1 / weight) + frame_2 * (weight_2 / weight)
                frames_result[target] = (frame, weight)
        return frames_result


class FastModeRunner:
    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
        frames_guide = frames_guide.raw_data()
        frames_style = frames_style.raw_data()
        table_manager = TableManager()
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            **ebsynth_config
        )
        # left part
        table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="Fast Mode Step 1/4")
        table_l = table_manager.remapping_table_to_blending_table(table_l)
        table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 2/4")
        # right part
        table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="Fast Mode Step 3/4")
        table_r = table_manager.remapping_table_to_blending_table(table_r)
        table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 4/4")[::-1]
        # merge
        frames = []
        for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
            weight_m = -1
            weight = weight_l + weight_m + weight_r
            frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
            frames.append(frame)
        frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
if save_path is not None:
|
140 |
+
for target, frame in enumerate(frames):
|
141 |
+
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
|
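The TableManager above builds its remapping table with a binary-lifting scheme: task_list precomputes, for each frame, blended results that cover power-of-two blocks of frames, and tree_query later splits an arbitrary window [leftbound, rightbound] into a handful of those blocks. The snippet below re-implements tree_query standalone to show the decomposition (re-implemented rather than imported, because the module above depends on the PyramidPatchMatcher GPU stack); the window bounds are arbitrary example values, and reading a node (i, level) as the 2**level frames ending at i follows from the loop condition.

# Standalone copy of TableManager.tree_query, for illustration only.
def tree_query(leftbound, rightbound):
    node_list = []
    node_index = rightbound
    while node_index >= leftbound:
        node_level = 0
        # Grow the node while the doubled block would still fit inside the window.
        while (1 << node_level) & node_index and node_index - (1 << (node_level + 1)) + 1 >= leftbound:
            node_level += 1
        node_list.append((node_index, node_level))
        node_index -= 1 << node_level
    return node_list

# The 31-frame window [27, 57] is covered by five nodes: [56,57], [48,55], [32,47], [28,31], [27,27].
print(tree_query(27, 57))  # [(57, 1), (55, 3), (47, 4), (31, 2), (27, 0)]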
diffsynth/extensions/FastBlend/runners/interpolation.py
ADDED
@@ -0,0 +1,121 @@
|
1 |
+
from ..patch_match import PyramidPatchMatcher
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
|
8 |
+
class InterpolationModeRunner:
|
9 |
+
def __init__(self):
|
10 |
+
pass
|
11 |
+
|
12 |
+
def get_index_dict(self, index_style):
|
13 |
+
index_dict = {}
|
14 |
+
for i, index in enumerate(index_style):
|
15 |
+
index_dict[index] = i
|
16 |
+
return index_dict
|
17 |
+
|
18 |
+
def get_weight(self, l, m, r):
|
19 |
+
weight_l, weight_r = abs(m - r), abs(m - l)
|
20 |
+
if weight_l + weight_r == 0:
|
21 |
+
weight_l, weight_r = 0.5, 0.5
|
22 |
+
else:
|
23 |
+
weight_l, weight_r = weight_l / (weight_l + weight_r), weight_r / (weight_l + weight_r)
|
24 |
+
return weight_l, weight_r
|
25 |
+
|
26 |
+
def get_task_group(self, index_style, n):
|
27 |
+
task_group = []
|
28 |
+
index_style = sorted(index_style)
|
29 |
+
# first frame
|
30 |
+
if index_style[0]>0:
|
31 |
+
tasks = []
|
32 |
+
for m in range(index_style[0]):
|
33 |
+
tasks.append((index_style[0], m, index_style[0]))
|
34 |
+
task_group.append(tasks)
|
35 |
+
# middle frames
|
36 |
+
for l, r in zip(index_style[:-1], index_style[1:]):
|
37 |
+
tasks = []
|
38 |
+
for m in range(l, r):
|
39 |
+
tasks.append((l, m, r))
|
40 |
+
task_group.append(tasks)
|
41 |
+
# last frame
|
42 |
+
tasks = []
|
43 |
+
for m in range(index_style[-1], n):
|
44 |
+
tasks.append((index_style[-1], m, index_style[-1]))
|
45 |
+
task_group.append(tasks)
|
46 |
+
return task_group
|
47 |
+
|
48 |
+
def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
|
49 |
+
patch_match_engine = PyramidPatchMatcher(
|
50 |
+
image_height=frames_style[0].shape[0],
|
51 |
+
image_width=frames_style[0].shape[1],
|
52 |
+
channel=3,
|
53 |
+
use_mean_target_style=False,
|
54 |
+
use_pairwise_patch_error=True,
|
55 |
+
**ebsynth_config
|
56 |
+
)
|
57 |
+
# task
|
58 |
+
index_dict = self.get_index_dict(index_style)
|
59 |
+
task_group = self.get_task_group(index_style, len(frames_guide))
|
60 |
+
# run
|
61 |
+
for tasks in task_group:
|
62 |
+
index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
|
63 |
+
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
|
64 |
+
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
|
65 |
+
source_guide, target_guide, source_style = [], [], []
|
66 |
+
for l, m, r in tasks_batch:
|
67 |
+
# l -> m
|
68 |
+
source_guide.append(frames_guide[l])
|
69 |
+
target_guide.append(frames_guide[m])
|
70 |
+
source_style.append(frames_style[index_dict[l]])
|
71 |
+
# r -> m
|
72 |
+
source_guide.append(frames_guide[r])
|
73 |
+
target_guide.append(frames_guide[m])
|
74 |
+
source_style.append(frames_style[index_dict[r]])
|
75 |
+
source_guide = np.stack(source_guide)
|
76 |
+
target_guide = np.stack(target_guide)
|
77 |
+
source_style = np.stack(source_style)
|
78 |
+
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
79 |
+
if save_path is not None:
|
80 |
+
for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
|
81 |
+
weight_l, weight_r = self.get_weight(l, m, r)
|
82 |
+
frame = frame_l * weight_l + frame_r * weight_r
|
83 |
+
frame = frame.clip(0, 255).astype("uint8")
|
84 |
+
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))
|
85 |
+
|
86 |
+
|
87 |
+
class InterpolationModeSingleFrameRunner:
|
88 |
+
def __init__(self):
|
89 |
+
pass
|
90 |
+
|
91 |
+
def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
|
92 |
+
# check input
|
93 |
+
tracking_window_size = ebsynth_config["tracking_window_size"]
|
94 |
+
if tracking_window_size * 2 >= batch_size:
|
95 |
+
raise ValueError("batch_size should be larger than track_window_size * 2")
|
96 |
+
frame_style = frames_style[0]
|
97 |
+
frame_guide = frames_guide[index_style[0]]
|
98 |
+
patch_match_engine = PyramidPatchMatcher(
|
99 |
+
image_height=frame_style.shape[0],
|
100 |
+
image_width=frame_style.shape[1],
|
101 |
+
channel=3,
|
102 |
+
**ebsynth_config
|
103 |
+
)
|
104 |
+
# run
|
105 |
+
frame_id, n = 0, len(frames_guide)
|
106 |
+
for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
|
107 |
+
if i + batch_size > n:
|
108 |
+
l, r = max(n - batch_size, 0), n
|
109 |
+
else:
|
110 |
+
l, r = i, i + batch_size
|
111 |
+
source_guide = np.stack([frame_guide] * (r-l))
|
112 |
+
target_guide = np.stack([frames_guide[i] for i in range(l, r)])
|
113 |
+
source_style = np.stack([frame_style] * (r-l))
|
114 |
+
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
115 |
+
for i, frame in zip(range(l, r), target_style):
|
116 |
+
if i==frame_id:
|
117 |
+
frame = frame.clip(0, 255).astype("uint8")
|
118 |
+
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
|
119 |
+
frame_id += 1
|
120 |
+
if r < n and r-frame_id <= tracking_window_size:
|
121 |
+
break
|
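InterpolationModeRunner renders each in-between frame m twice, once from the keyframe l on its left and once from the keyframe r on its right, then blends the two renderings with weights proportional to the distance to the opposite keyframe (get_weight above). A standalone worked example of that weighting, independent of the patch-match engine:

# Same arithmetic as InterpolationModeRunner.get_weight, shown on concrete numbers.
def get_weight(l, m, r):
    weight_l, weight_r = abs(m - r), abs(m - l)
    if weight_l + weight_r == 0:  # m coincides with both keyframes
        return 0.5, 0.5
    total = weight_l + weight_r
    return weight_l / total, weight_r / total

# Frame 13 lies between keyframes 10 and 20; the rendering propagated from the
# nearer keyframe (10) receives the larger weight.
print(get_weight(10, 13, 20))  # (0.7, 0.3)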
diffsynth/extensions/RIFE/__init__.py
ADDED
@@ -0,0 +1,242 @@
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
|
8 |
+
def warp(tenInput, tenFlow, device):
|
9 |
+
backwarp_tenGrid = {}
|
10 |
+
k = (str(tenFlow.device), str(tenFlow.size()))
|
11 |
+
if k not in backwarp_tenGrid:
|
12 |
+
tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
|
13 |
+
1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
|
14 |
+
tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
|
15 |
+
1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
|
16 |
+
backwarp_tenGrid[k] = torch.cat(
|
17 |
+
[tenHorizontal, tenVertical], 1).to(device)
|
18 |
+
|
19 |
+
tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
|
20 |
+
tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
|
21 |
+
|
22 |
+
g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
|
23 |
+
return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
|
24 |
+
|
25 |
+
|
26 |
+
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
|
27 |
+
return nn.Sequential(
|
28 |
+
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
|
29 |
+
padding=padding, dilation=dilation, bias=True),
|
30 |
+
nn.PReLU(out_planes)
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
class IFBlock(nn.Module):
|
35 |
+
def __init__(self, in_planes, c=64):
|
36 |
+
super(IFBlock, self).__init__()
|
37 |
+
self.conv0 = nn.Sequential(conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1),)
|
38 |
+
self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
|
39 |
+
self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
|
40 |
+
self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
|
41 |
+
self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
|
42 |
+
self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 4, 4, 2, 1))
|
43 |
+
self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 1, 4, 2, 1))
|
44 |
+
|
45 |
+
def forward(self, x, flow, scale=1):
|
46 |
+
x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
|
47 |
+
flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
|
48 |
+
feat = self.conv0(torch.cat((x, flow), 1))
|
49 |
+
feat = self.convblock0(feat) + feat
|
50 |
+
feat = self.convblock1(feat) + feat
|
51 |
+
feat = self.convblock2(feat) + feat
|
52 |
+
feat = self.convblock3(feat) + feat
|
53 |
+
flow = self.conv1(feat)
|
54 |
+
mask = self.conv2(feat)
|
55 |
+
flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale
|
56 |
+
mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
|
57 |
+
return flow, mask
|
58 |
+
|
59 |
+
|
60 |
+
class IFNet(nn.Module):
|
61 |
+
def __init__(self):
|
62 |
+
super(IFNet, self).__init__()
|
63 |
+
self.block0 = IFBlock(7+4, c=90)
|
64 |
+
self.block1 = IFBlock(7+4, c=90)
|
65 |
+
self.block2 = IFBlock(7+4, c=90)
|
66 |
+
self.block_tea = IFBlock(10+4, c=90)
|
67 |
+
|
68 |
+
def forward(self, x, scale_list=[4, 2, 1], training=False):
|
69 |
+
if training == False:
|
70 |
+
channel = x.shape[1] // 2
|
71 |
+
img0 = x[:, :channel]
|
72 |
+
img1 = x[:, channel:]
|
73 |
+
flow_list = []
|
74 |
+
merged = []
|
75 |
+
mask_list = []
|
76 |
+
warped_img0 = img0
|
77 |
+
warped_img1 = img1
|
78 |
+
flow = (x[:, :4]).detach() * 0
|
79 |
+
mask = (x[:, :1]).detach() * 0
|
80 |
+
block = [self.block0, self.block1, self.block2]
|
81 |
+
for i in range(3):
|
82 |
+
f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
|
83 |
+
f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
|
84 |
+
flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
|
85 |
+
mask = mask + (m0 + (-m1)) / 2
|
86 |
+
mask_list.append(mask)
|
87 |
+
flow_list.append(flow)
|
88 |
+
warped_img0 = warp(img0, flow[:, :2], device=x.device)
|
89 |
+
warped_img1 = warp(img1, flow[:, 2:4], device=x.device)
|
90 |
+
merged.append((warped_img0, warped_img1))
|
91 |
+
'''
|
92 |
+
c0 = self.contextnet(img0, flow[:, :2])
|
93 |
+
c1 = self.contextnet(img1, flow[:, 2:4])
|
94 |
+
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
|
95 |
+
res = tmp[:, 1:4] * 2 - 1
|
96 |
+
'''
|
97 |
+
for i in range(3):
|
98 |
+
mask_list[i] = torch.sigmoid(mask_list[i])
|
99 |
+
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
|
100 |
+
return flow_list, mask_list[2], merged
|
101 |
+
|
102 |
+
@staticmethod
|
103 |
+
def state_dict_converter():
|
104 |
+
return IFNetStateDictConverter()
|
105 |
+
|
106 |
+
|
107 |
+
class IFNetStateDictConverter:
|
108 |
+
def __init__(self):
|
109 |
+
pass
|
110 |
+
|
111 |
+
def from_diffusers(self, state_dict):
|
112 |
+
state_dict_ = {k.replace("module.", ""): v for k, v in state_dict.items()}
|
113 |
+
return state_dict_
|
114 |
+
|
115 |
+
def from_civitai(self, state_dict):
|
116 |
+
return self.from_diffusers(state_dict)
|
117 |
+
|
118 |
+
|
119 |
+
class RIFEInterpolater:
|
120 |
+
def __init__(self, model, device="cuda"):
|
121 |
+
self.model = model
|
122 |
+
self.device = device
|
123 |
+
# IFNet does not support float16, so float32 is used here
|
124 |
+
self.torch_dtype = torch.float32
|
125 |
+
|
126 |
+
@staticmethod
|
127 |
+
def from_model_manager(model_manager):
|
128 |
+
return RIFEInterpolater(model_manager.RIFE, device=model_manager.device)
|
129 |
+
|
130 |
+
def process_image(self, image):
|
131 |
+
width, height = image.size
|
132 |
+
if width % 32 != 0 or height % 32 != 0:
|
133 |
+
width = (width + 31) // 32
|
134 |
+
height = (height + 31) // 32
|
135 |
+
image = image.resize((width, height))
|
136 |
+
image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
|
137 |
+
return image
|
138 |
+
|
139 |
+
def process_images(self, images):
|
140 |
+
images = [self.process_image(image) for image in images]
|
141 |
+
images = torch.stack(images)
|
142 |
+
return images
|
143 |
+
|
144 |
+
def decode_images(self, images):
|
145 |
+
images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
|
146 |
+
images = [Image.fromarray(image) for image in images]
|
147 |
+
return images
|
148 |
+
|
149 |
+
def add_interpolated_images(self, images, interpolated_images):
|
150 |
+
output_images = []
|
151 |
+
for image, interpolated_image in zip(images, interpolated_images):
|
152 |
+
output_images.append(image)
|
153 |
+
output_images.append(interpolated_image)
|
154 |
+
output_images.append(images[-1])
|
155 |
+
return output_images
|
156 |
+
|
157 |
+
|
158 |
+
@torch.no_grad()
|
159 |
+
def interpolate_(self, images, scale=1.0):
|
160 |
+
input_tensor = self.process_images(images)
|
161 |
+
input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1)
|
162 |
+
input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype)
|
163 |
+
flow, mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale])
|
164 |
+
output_images = self.decode_images(merged[2].cpu())
|
165 |
+
if output_images[0].size != images[0].size:
|
166 |
+
output_images = [image.resize(images[0].size) for image in output_images]
|
167 |
+
return output_images
|
168 |
+
|
169 |
+
|
170 |
+
@torch.no_grad()
|
171 |
+
def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, progress_bar=lambda x:x):
|
172 |
+
# Preprocess
|
173 |
+
processed_images = self.process_images(images)
|
174 |
+
|
175 |
+
for iter in range(num_iter):
|
176 |
+
# Input
|
177 |
+
input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1)
|
178 |
+
|
179 |
+
# Interpolate
|
180 |
+
output_tensor = []
|
181 |
+
for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
|
182 |
+
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
|
183 |
+
batch_input_tensor = input_tensor[batch_id: batch_id_]
|
184 |
+
batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
|
185 |
+
flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
|
186 |
+
output_tensor.append(merged[2].cpu())
|
187 |
+
|
188 |
+
# Output
|
189 |
+
output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1)
|
190 |
+
processed_images = self.add_interpolated_images(processed_images, output_tensor)
|
191 |
+
processed_images = torch.stack(processed_images)
|
192 |
+
|
193 |
+
# To images
|
194 |
+
output_images = self.decode_images(processed_images)
|
195 |
+
if output_images[0].size != images[0].size:
|
196 |
+
output_images = [image.resize(images[0].size) for image in output_images]
|
197 |
+
return output_images
|
198 |
+
|
199 |
+
|
200 |
+
class RIFESmoother(RIFEInterpolater):
|
201 |
+
def __init__(self, model, device="cuda"):
|
202 |
+
super(RIFESmoother, self).__init__(model, device=device)
|
203 |
+
|
204 |
+
@staticmethod
|
205 |
+
def from_model_manager(model_manager):
|
206 |
+
return RIFESmoother(model_manager.RIFE, device=model_manager.device)
|
207 |
+
|
208 |
+
def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
|
209 |
+
output_tensor = []
|
210 |
+
for batch_id in range(0, input_tensor.shape[0], batch_size):
|
211 |
+
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
|
212 |
+
batch_input_tensor = input_tensor[batch_id: batch_id_]
|
213 |
+
batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
|
214 |
+
flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
|
215 |
+
output_tensor.append(merged[2].cpu())
|
216 |
+
output_tensor = torch.concat(output_tensor, dim=0)
|
217 |
+
return output_tensor
|
218 |
+
|
219 |
+
@torch.no_grad()
|
220 |
+
def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs):
|
221 |
+
# Preprocess
|
222 |
+
processed_images = self.process_images(rendered_frames)
|
223 |
+
|
224 |
+
for iter in range(num_iter):
|
225 |
+
# Input
|
226 |
+
input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)
|
227 |
+
|
228 |
+
# Interpolate
|
229 |
+
output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
|
230 |
+
|
231 |
+
# Blend
|
232 |
+
input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
|
233 |
+
output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
|
234 |
+
|
235 |
+
# Add to frames
|
236 |
+
processed_images[1:-1] = output_tensor
|
237 |
+
|
238 |
+
# To images
|
239 |
+
output_images = self.decode_images(processed_images)
|
240 |
+
if output_images[0].size != rendered_frames[0].size:
|
241 |
+
output_images = [image.resize(rendered_frames[0].size) for image in output_images]
|
242 |
+
return output_images
|
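RIFEInterpolater.interpolate inserts one synthesized frame between every adjacent pair per iteration, so n input frames become 2n - 1 after each pass (16 -> 31 -> 61 for num_iter=2). The sketch below shows an assumed calling pattern: the checkpoint path and the ModelManager loading call are hypothetical, inferred only from the fact that from_model_manager reads model_manager.RIFE and model_manager.device; they are not shown in this excerpt.

from PIL import Image
from diffsynth import ModelManager  # assumption: ModelManager is re-exported at package level
from diffsynth.extensions.RIFE import RIFEInterpolater

model_manager = ModelManager()
model_manager.load_models(["models/RIFE/flownet.pkl"])  # hypothetical method name and local path

interpolater = RIFEInterpolater.from_model_manager(model_manager)
frames = [Image.open(f"frames/{i:05d}.png") for i in range(16)]
# Two passes: 16 frames -> 31 -> 61.
smooth = interpolater.interpolate(frames, scale=1.0, batch_size=4, num_iter=2)
print(len(smooth))  # 61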
diffsynth/extensions/__init__.py
ADDED
File without changes
|
diffsynth/models/__init__.py
ADDED
@@ -0,0 +1 @@
|
1 |
+
from .model_manager import *
|
diffsynth/models/attention.py
ADDED
@@ -0,0 +1,89 @@
|
1 |
+
import torch
|
2 |
+
from einops import rearrange
|
3 |
+
|
4 |
+
|
5 |
+
def low_version_attention(query, key, value, attn_bias=None):
|
6 |
+
scale = 1 / query.shape[-1] ** 0.5
|
7 |
+
query = query * scale
|
8 |
+
attn = torch.matmul(query, key.transpose(-2, -1))
|
9 |
+
if attn_bias is not None:
|
10 |
+
attn = attn + attn_bias
|
11 |
+
attn = attn.softmax(-1)
|
12 |
+
return attn @ value
|
13 |
+
|
14 |
+
|
15 |
+
class Attention(torch.nn.Module):
|
16 |
+
|
17 |
+
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
18 |
+
super().__init__()
|
19 |
+
dim_inner = head_dim * num_heads
|
20 |
+
kv_dim = kv_dim if kv_dim is not None else q_dim
|
21 |
+
self.num_heads = num_heads
|
22 |
+
self.head_dim = head_dim
|
23 |
+
|
24 |
+
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
25 |
+
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
26 |
+
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
27 |
+
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
|
28 |
+
|
29 |
+
def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
|
30 |
+
batch_size = q.shape[0]
|
31 |
+
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
32 |
+
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
33 |
+
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
|
34 |
+
hidden_states = hidden_states + scale * ip_hidden_states
|
35 |
+
return hidden_states
|
36 |
+
|
37 |
+
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
38 |
+
if encoder_hidden_states is None:
|
39 |
+
encoder_hidden_states = hidden_states
|
40 |
+
|
41 |
+
batch_size = encoder_hidden_states.shape[0]
|
42 |
+
|
43 |
+
q = self.to_q(hidden_states)
|
44 |
+
k = self.to_k(encoder_hidden_states)
|
45 |
+
v = self.to_v(encoder_hidden_states)
|
46 |
+
|
47 |
+
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
48 |
+
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
49 |
+
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
50 |
+
|
51 |
+
if qkv_preprocessor is not None:
|
52 |
+
q, k, v = qkv_preprocessor(q, k, v)
|
53 |
+
|
54 |
+
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
55 |
+
if ipadapter_kwargs is not None:
|
56 |
+
hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
|
57 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
58 |
+
hidden_states = hidden_states.to(q.dtype)
|
59 |
+
|
60 |
+
hidden_states = self.to_out(hidden_states)
|
61 |
+
|
62 |
+
return hidden_states
|
63 |
+
|
64 |
+
def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
65 |
+
if encoder_hidden_states is None:
|
66 |
+
encoder_hidden_states = hidden_states
|
67 |
+
|
68 |
+
q = self.to_q(hidden_states)
|
69 |
+
k = self.to_k(encoder_hidden_states)
|
70 |
+
v = self.to_v(encoder_hidden_states)
|
71 |
+
|
72 |
+
q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
|
73 |
+
k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
|
74 |
+
v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
|
75 |
+
|
76 |
+
if attn_mask is not None:
|
77 |
+
hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
|
78 |
+
else:
|
79 |
+
import xformers.ops as xops
|
80 |
+
hidden_states = xops.memory_efficient_attention(q, k, v)
|
81 |
+
hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
|
82 |
+
|
83 |
+
hidden_states = hidden_states.to(q.dtype)
|
84 |
+
hidden_states = self.to_out(hidden_states)
|
85 |
+
|
86 |
+
return hidden_states
|
87 |
+
|
88 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
89 |
+
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
|
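The Attention module above is a thin multi-head wrapper around torch.nn.functional.scaled_dot_product_attention, with optional cross-attention via kv_dim and IP-Adapter hooks via ipadapter_kwargs. A minimal shape check with illustrative dimensions (the sizes are arbitrary, not tied to any particular model in this repo):

import torch
from diffsynth.models.attention import Attention  # module added in this commit

# 320-dim image tokens attending over a 768-dim text context.
attn = Attention(q_dim=320, num_heads=8, head_dim=40, kv_dim=768, bias_q=True, bias_kv=True, bias_out=True)
x = torch.randn(2, 4096, 320)      # (batch, image tokens, q_dim)
context = torch.randn(2, 77, 768)  # (batch, text tokens, kv_dim)
out = attn(x, encoder_hidden_states=context)
print(out.shape)                   # torch.Size([2, 4096, 320])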
diffsynth/models/downloader.py
ADDED
@@ -0,0 +1,66 @@
|
1 |
+
from huggingface_hub import hf_hub_download
|
2 |
+
from modelscope import snapshot_download
|
3 |
+
import os, shutil
|
4 |
+
from typing_extensions import Literal, TypeAlias
|
5 |
+
from typing import List
|
6 |
+
from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id
|
7 |
+
|
8 |
+
|
9 |
+
def download_from_modelscope(model_id, origin_file_path, local_dir):
|
10 |
+
os.makedirs(local_dir, exist_ok=True)
|
11 |
+
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
12 |
+
print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
|
13 |
+
return
|
14 |
+
else:
|
15 |
+
print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
|
16 |
+
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
|
17 |
+
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
18 |
+
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
|
19 |
+
if downloaded_file_path != target_file_path:
|
20 |
+
shutil.move(downloaded_file_path, target_file_path)
|
21 |
+
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
|
22 |
+
|
23 |
+
|
24 |
+
def download_from_huggingface(model_id, origin_file_path, local_dir):
|
25 |
+
os.makedirs(local_dir, exist_ok=True)
|
26 |
+
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
27 |
+
print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
|
28 |
+
return
|
29 |
+
else:
|
30 |
+
print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
|
31 |
+
hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
|
32 |
+
|
33 |
+
|
34 |
+
Preset_model_website: TypeAlias = Literal[
|
35 |
+
"HuggingFace",
|
36 |
+
"ModelScope",
|
37 |
+
]
|
38 |
+
website_to_preset_models = {
|
39 |
+
"HuggingFace": preset_models_on_huggingface,
|
40 |
+
"ModelScope": preset_models_on_modelscope,
|
41 |
+
}
|
42 |
+
website_to_download_fn = {
|
43 |
+
"HuggingFace": download_from_huggingface,
|
44 |
+
"ModelScope": download_from_modelscope,
|
45 |
+
}
|
46 |
+
|
47 |
+
|
48 |
+
def download_models(
|
49 |
+
model_id_list: List[Preset_model_id] = [],
|
50 |
+
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
51 |
+
):
|
52 |
+
print(f"Downloading models: {model_id_list}")
|
53 |
+
downloaded_files = []
|
54 |
+
for model_id in model_id_list:
|
55 |
+
for website in downloading_priority:
|
56 |
+
if model_id in website_to_preset_models[website]:
|
57 |
+
for model_id, origin_file_path, local_dir in website_to_preset_models[website][model_id]:
|
58 |
+
# Check if the file is downloaded.
|
59 |
+
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
|
60 |
+
if file_to_download in downloaded_files:
|
61 |
+
continue
|
62 |
+
# Download
|
63 |
+
website_to_download_fn[website](model_id, origin_file_path, local_dir)
|
64 |
+
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
65 |
+
downloaded_files.append(file_to_download)
|
66 |
+
return downloaded_files
|
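download_models resolves each preset id against the ModelScope table first and falls back to HuggingFace, and it skips any file that is already present in its target directory. A hedged example call; the id string below is a placeholder, since the valid keys are defined by preset_models_on_modelscope / preset_models_on_huggingface in configs/model_config.py, which is not shown in this excerpt.

from diffsynth.models.downloader import download_models

# "HunyuanDiT" is a placeholder id; replace it with a key from the preset tables.
paths = download_models(
    model_id_list=["HunyuanDiT"],
    downloading_priority=["ModelScope", "HuggingFace"],
)
print(paths)  # local paths of the files that are now available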
diffsynth/models/flux_dit.py
ADDED
@@ -0,0 +1,575 @@
|
1 |
+
import torch
|
2 |
+
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm
|
3 |
+
from einops import rearrange
|
4 |
+
from .tiler import TileWorker
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
class RoPEEmbedding(torch.nn.Module):
|
9 |
+
def __init__(self, dim, theta, axes_dim):
|
10 |
+
super().__init__()
|
11 |
+
self.dim = dim
|
12 |
+
self.theta = theta
|
13 |
+
self.axes_dim = axes_dim
|
14 |
+
|
15 |
+
|
16 |
+
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
17 |
+
assert dim % 2 == 0, "The dimension must be even."
|
18 |
+
|
19 |
+
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
20 |
+
omega = 1.0 / (theta**scale)
|
21 |
+
|
22 |
+
batch_size, seq_length = pos.shape
|
23 |
+
out = torch.einsum("...n,d->...nd", pos, omega)
|
24 |
+
cos_out = torch.cos(out)
|
25 |
+
sin_out = torch.sin(out)
|
26 |
+
|
27 |
+
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
|
28 |
+
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
|
29 |
+
return out.float()
|
30 |
+
|
31 |
+
|
32 |
+
def forward(self, ids):
|
33 |
+
n_axes = ids.shape[-1]
|
34 |
+
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
|
35 |
+
return emb.unsqueeze(1)
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
class RMSNorm(torch.nn.Module):
|
40 |
+
def __init__(self, dim, eps):
|
41 |
+
super().__init__()
|
42 |
+
self.weight = torch.nn.Parameter(torch.ones((dim,)))
|
43 |
+
self.eps = eps
|
44 |
+
|
45 |
+
def forward(self, hidden_states):
|
46 |
+
input_dtype = hidden_states.dtype
|
47 |
+
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
48 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
49 |
+
hidden_states = hidden_states.to(input_dtype) * self.weight
|
50 |
+
return hidden_states
|
51 |
+
|
52 |
+
|
53 |
+
|
54 |
+
class FluxJointAttention(torch.nn.Module):
|
55 |
+
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
|
56 |
+
super().__init__()
|
57 |
+
self.num_heads = num_heads
|
58 |
+
self.head_dim = head_dim
|
59 |
+
self.only_out_a = only_out_a
|
60 |
+
|
61 |
+
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
|
62 |
+
self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)
|
63 |
+
|
64 |
+
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
|
65 |
+
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
|
66 |
+
self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
|
67 |
+
self.norm_k_b = RMSNorm(head_dim, eps=1e-6)
|
68 |
+
|
69 |
+
self.a_to_out = torch.nn.Linear(dim_a, dim_a)
|
70 |
+
if not only_out_a:
|
71 |
+
self.b_to_out = torch.nn.Linear(dim_b, dim_b)
|
72 |
+
|
73 |
+
|
74 |
+
def apply_rope(self, xq, xk, freqs_cis):
|
75 |
+
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
76 |
+
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
77 |
+
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
78 |
+
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
79 |
+
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
80 |
+
|
81 |
+
|
82 |
+
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb):
|
83 |
+
batch_size = hidden_states_a.shape[0]
|
84 |
+
|
85 |
+
# Part A
|
86 |
+
qkv_a = self.a_to_qkv(hidden_states_a)
|
87 |
+
qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
88 |
+
q_a, k_a, v_a = qkv_a.chunk(3, dim=1)
|
89 |
+
q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
|
90 |
+
|
91 |
+
# Part B
|
92 |
+
qkv_b = self.b_to_qkv(hidden_states_b)
|
93 |
+
qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
94 |
+
q_b, k_b, v_b = qkv_b.chunk(3, dim=1)
|
95 |
+
q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b)
|
96 |
+
|
97 |
+
q = torch.concat([q_b, q_a], dim=2)
|
98 |
+
k = torch.concat([k_b, k_a], dim=2)
|
99 |
+
v = torch.concat([v_b, v_a], dim=2)
|
100 |
+
|
101 |
+
q, k = self.apply_rope(q, k, image_rotary_emb)
|
102 |
+
|
103 |
+
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
104 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
105 |
+
hidden_states = hidden_states.to(q.dtype)
|
106 |
+
hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
|
107 |
+
hidden_states_a = self.a_to_out(hidden_states_a)
|
108 |
+
if self.only_out_a:
|
109 |
+
return hidden_states_a
|
110 |
+
else:
|
111 |
+
hidden_states_b = self.b_to_out(hidden_states_b)
|
112 |
+
return hidden_states_a, hidden_states_b
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
class FluxJointTransformerBlock(torch.nn.Module):
|
117 |
+
def __init__(self, dim, num_attention_heads):
|
118 |
+
super().__init__()
|
119 |
+
self.norm1_a = AdaLayerNorm(dim)
|
120 |
+
self.norm1_b = AdaLayerNorm(dim)
|
121 |
+
|
122 |
+
self.attn = FluxJointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
|
123 |
+
|
124 |
+
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
125 |
+
self.ff_a = torch.nn.Sequential(
|
126 |
+
torch.nn.Linear(dim, dim*4),
|
127 |
+
torch.nn.GELU(approximate="tanh"),
|
128 |
+
torch.nn.Linear(dim*4, dim)
|
129 |
+
)
|
130 |
+
|
131 |
+
self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
132 |
+
self.ff_b = torch.nn.Sequential(
|
133 |
+
torch.nn.Linear(dim, dim*4),
|
134 |
+
torch.nn.GELU(approximate="tanh"),
|
135 |
+
torch.nn.Linear(dim*4, dim)
|
136 |
+
)
|
137 |
+
|
138 |
+
|
139 |
+
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
|
140 |
+
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
|
141 |
+
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
|
142 |
+
|
143 |
+
# Attention
|
144 |
+
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb)
|
145 |
+
|
146 |
+
# Part A
|
147 |
+
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
148 |
+
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
|
149 |
+
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
|
150 |
+
|
151 |
+
# Part B
|
152 |
+
hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
|
153 |
+
norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
|
154 |
+
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
|
155 |
+
|
156 |
+
return hidden_states_a, hidden_states_b
|
157 |
+
|
158 |
+
|
159 |
+
|
160 |
+
class FluxSingleAttention(torch.nn.Module):
|
161 |
+
def __init__(self, dim_a, dim_b, num_heads, head_dim):
|
162 |
+
super().__init__()
|
163 |
+
self.num_heads = num_heads
|
164 |
+
self.head_dim = head_dim
|
165 |
+
|
166 |
+
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
|
167 |
+
|
168 |
+
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
|
169 |
+
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
|
170 |
+
|
171 |
+
|
172 |
+
def apply_rope(self, xq, xk, freqs_cis):
|
173 |
+
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
174 |
+
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
175 |
+
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
176 |
+
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
177 |
+
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
178 |
+
|
179 |
+
|
180 |
+
def forward(self, hidden_states, image_rotary_emb):
|
181 |
+
batch_size = hidden_states.shape[0]
|
182 |
+
|
183 |
+
qkv_a = self.a_to_qkv(hidden_states)
|
184 |
+
qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
185 |
+
q_a, k_a, v = qkv_a.chunk(3, dim=1)
|
186 |
+
q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
|
187 |
+
|
188 |
+
q, k = self.apply_rope(q_a, k_a, image_rotary_emb)
|
189 |
+
|
190 |
+
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
191 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
192 |
+
hidden_states = hidden_states.to(q.dtype)
|
193 |
+
return hidden_states
|
194 |
+
|
195 |
+
|
196 |
+
|
197 |
+
class AdaLayerNormSingle(torch.nn.Module):
|
198 |
+
def __init__(self, dim):
|
199 |
+
super().__init__()
|
200 |
+
self.silu = torch.nn.SiLU()
|
201 |
+
self.linear = torch.nn.Linear(dim, 3 * dim, bias=True)
|
202 |
+
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
203 |
+
|
204 |
+
|
205 |
+
def forward(self, x, emb):
|
206 |
+
emb = self.linear(self.silu(emb))
|
207 |
+
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
|
208 |
+
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
209 |
+
return x, gate_msa
|
210 |
+
|
211 |
+
|
212 |
+
|
213 |
+
class FluxSingleTransformerBlock(torch.nn.Module):
|
214 |
+
def __init__(self, dim, num_attention_heads):
|
215 |
+
super().__init__()
|
216 |
+
self.num_heads = num_attention_heads
|
217 |
+
self.head_dim = dim // num_attention_heads
|
218 |
+
self.dim = dim
|
219 |
+
|
220 |
+
self.norm = AdaLayerNormSingle(dim)
|
221 |
+
# self.proj_in = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), torch.nn.GELU(approximate="tanh"))
|
222 |
+
# self.attn = FluxSingleAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
|
223 |
+
self.linear = torch.nn.Linear(dim, dim * (3 + 4))
|
224 |
+
self.norm_q_a = RMSNorm(self.head_dim, eps=1e-6)
|
225 |
+
self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6)
|
226 |
+
|
227 |
+
self.proj_out = torch.nn.Linear(dim * 5, dim)
|
228 |
+
|
229 |
+
|
230 |
+
def apply_rope(self, xq, xk, freqs_cis):
|
231 |
+
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
232 |
+
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
233 |
+
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
234 |
+
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
235 |
+
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
236 |
+
|
237 |
+
|
238 |
+
def process_attention(self, hidden_states, image_rotary_emb):
|
239 |
+
batch_size = hidden_states.shape[0]
|
240 |
+
|
241 |
+
qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
242 |
+
q, k, v = qkv.chunk(3, dim=1)
|
243 |
+
q, k = self.norm_q_a(q), self.norm_k_a(k)
|
244 |
+
|
245 |
+
q, k = self.apply_rope(q, k, image_rotary_emb)
|
246 |
+
|
247 |
+
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
248 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
249 |
+
hidden_states = hidden_states.to(q.dtype)
|
250 |
+
return hidden_states
|
251 |
+
|
252 |
+
|
253 |
+
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
|
254 |
+
residual = hidden_states_a
|
255 |
+
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
|
256 |
+
hidden_states_a = self.linear(norm_hidden_states)
|
257 |
+
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
|
258 |
+
|
259 |
+
attn_output = self.process_attention(attn_output, image_rotary_emb)
|
260 |
+
mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
|
261 |
+
|
262 |
+
hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
263 |
+
hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
|
264 |
+
hidden_states_a = residual + hidden_states_a
|
265 |
+
|
266 |
+
return hidden_states_a, hidden_states_b
|
267 |
+
|
268 |
+
|
269 |
+
|
270 |
+
class AdaLayerNormContinuous(torch.nn.Module):
|
271 |
+
def __init__(self, dim):
|
272 |
+
super().__init__()
|
273 |
+
self.silu = torch.nn.SiLU()
|
274 |
+
self.linear = torch.nn.Linear(dim, dim * 2, bias=True)
|
275 |
+
self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
|
276 |
+
|
277 |
+
def forward(self, x, conditioning):
|
278 |
+
emb = self.linear(self.silu(conditioning))
|
279 |
+
scale, shift = torch.chunk(emb, 2, dim=1)
|
280 |
+
x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
|
281 |
+
return x
|
282 |
+
|
283 |
+
|
284 |
+
|
285 |
+
class FluxDiT(torch.nn.Module):
|
286 |
+
def __init__(self):
|
287 |
+
super().__init__()
|
288 |
+
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
289 |
+
self.time_embedder = TimestepEmbeddings(256, 3072)
|
290 |
+
self.guidance_embedder = TimestepEmbeddings(256, 3072)
|
291 |
+
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
292 |
+
self.context_embedder = torch.nn.Linear(4096, 3072)
|
293 |
+
self.x_embedder = torch.nn.Linear(64, 3072)
|
294 |
+
|
295 |
+
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
|
296 |
+
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
|
297 |
+
|
298 |
+
self.norm_out = AdaLayerNormContinuous(3072)
|
299 |
+
self.proj_out = torch.nn.Linear(3072, 64)
|
300 |
+
|
301 |
+
|
302 |
+
def patchify(self, hidden_states):
|
303 |
+
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
|
304 |
+
return hidden_states
|
305 |
+
|
306 |
+
|
307 |
+
def unpatchify(self, hidden_states, height, width):
|
308 |
+
hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
|
309 |
+
return hidden_states
|
310 |
+
|
311 |
+
|
312 |
+
def prepare_image_ids(self, latents):
|
313 |
+
batch_size, _, height, width = latents.shape
|
314 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
315 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
316 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
317 |
+
|
318 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
319 |
+
|
320 |
+
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
321 |
+
latent_image_ids = latent_image_ids.reshape(
|
322 |
+
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
323 |
+
)
|
324 |
+
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
|
325 |
+
|
326 |
+
return latent_image_ids
|
327 |
+
|
328 |
+
|
329 |
+
def tiled_forward(
|
330 |
+
self,
|
331 |
+
hidden_states,
|
332 |
+
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
|
333 |
+
tile_size=128, tile_stride=64,
|
334 |
+
**kwargs
|
335 |
+
):
|
336 |
+
# Due to the global positional embedding, we cannot implement layer-wise tiled forward.
|
337 |
+
hidden_states = TileWorker().tiled_forward(
|
338 |
+
lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None),
|
339 |
+
hidden_states,
|
340 |
+
tile_size,
|
341 |
+
tile_stride,
|
342 |
+
tile_device=hidden_states.device,
|
343 |
+
tile_dtype=hidden_states.dtype
|
344 |
+
)
|
345 |
+
return hidden_states
|
346 |
+
|
347 |
+
|
348 |
+
def forward(
|
349 |
+
self,
|
350 |
+
hidden_states,
|
351 |
+
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
|
352 |
+
tiled=False, tile_size=128, tile_stride=64,
|
353 |
+
**kwargs
|
354 |
+
):
|
355 |
+
if tiled:
|
356 |
+
return self.tiled_forward(
|
357 |
+
hidden_states,
|
358 |
+
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
|
359 |
+
tile_size=tile_size, tile_stride=tile_stride,
|
360 |
+
**kwargs
|
361 |
+
)
|
362 |
+
|
363 |
+
if image_ids is None:
|
364 |
+
image_ids = self.prepare_image_ids(hidden_states)
|
365 |
+
|
366 |
+
conditioning = self.time_embedder(timestep, hidden_states.dtype)\
|
367 |
+
+ self.guidance_embedder(guidance, hidden_states.dtype)\
|
368 |
+
+ self.pooled_text_embedder(pooled_prompt_emb)
|
369 |
+
prompt_emb = self.context_embedder(prompt_emb)
|
370 |
+
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
371 |
+
|
372 |
+
height, width = hidden_states.shape[-2:]
|
373 |
+
hidden_states = self.patchify(hidden_states)
|
374 |
+
hidden_states = self.x_embedder(hidden_states)
|
375 |
+
|
376 |
+
for block in self.blocks:
|
377 |
+
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
378 |
+
|
379 |
+
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
380 |
+
for block in self.single_blocks:
|
381 |
+
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
382 |
+
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
383 |
+
|
384 |
+
hidden_states = self.norm_out(hidden_states, conditioning)
|
385 |
+
hidden_states = self.proj_out(hidden_states)
|
386 |
+
hidden_states = self.unpatchify(hidden_states, height, width)
|
387 |
+
|
388 |
+
return hidden_states
|
389 |
+
|
390 |
+
|
391 |
+
@staticmethod
|
392 |
+
def state_dict_converter():
|
393 |
+
return FluxDiTStateDictConverter()
|
394 |
+
|
395 |
+
|
396 |
+
|
397 |
+
class FluxDiTStateDictConverter:
|
398 |
+
def __init__(self):
|
399 |
+
pass
|
400 |
+
|
401 |
+
def from_diffusers(self, state_dict):
|
402 |
+
rename_dict = {
|
403 |
+
"context_embedder": "context_embedder",
|
404 |
+
"x_embedder": "x_embedder",
|
405 |
+
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
|
406 |
+
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
|
407 |
+
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
|
408 |
+
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
|
409 |
+
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
|
410 |
+
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
|
411 |
+
"norm_out.linear": "norm_out.linear",
|
412 |
+
"proj_out": "proj_out",
|
413 |
+
|
414 |
+
"norm1.linear": "norm1_a.linear",
|
415 |
+
"norm1_context.linear": "norm1_b.linear",
|
416 |
+
"attn.to_q": "attn.a_to_q",
|
417 |
+
"attn.to_k": "attn.a_to_k",
|
418 |
+
"attn.to_v": "attn.a_to_v",
|
419 |
+
"attn.to_out.0": "attn.a_to_out",
|
420 |
+
"attn.add_q_proj": "attn.b_to_q",
|
421 |
+
"attn.add_k_proj": "attn.b_to_k",
|
422 |
+
"attn.add_v_proj": "attn.b_to_v",
|
423 |
+
"attn.to_add_out": "attn.b_to_out",
|
424 |
+
"ff.net.0.proj": "ff_a.0",
|
425 |
+
"ff.net.2": "ff_a.2",
|
426 |
+
"ff_context.net.0.proj": "ff_b.0",
|
427 |
+
"ff_context.net.2": "ff_b.2",
|
428 |
+
"attn.norm_q": "attn.norm_q_a",
|
429 |
+
"attn.norm_k": "attn.norm_k_a",
|
430 |
+
"attn.norm_added_q": "attn.norm_q_b",
|
431 |
+
"attn.norm_added_k": "attn.norm_k_b",
|
432 |
+
}
|
433 |
+
rename_dict_single = {
|
434 |
+
"attn.to_q": "a_to_q",
|
435 |
+
"attn.to_k": "a_to_k",
|
436 |
+
"attn.to_v": "a_to_v",
|
437 |
+
"attn.norm_q": "norm_q_a",
|
438 |
+
"attn.norm_k": "norm_k_a",
|
439 |
+
"norm.linear": "norm.linear",
|
440 |
+
"proj_mlp": "proj_in_besides_attn",
|
441 |
+
"proj_out": "proj_out",
|
442 |
+
}
|
443 |
+
state_dict_ = {}
|
444 |
+
for name, param in state_dict.items():
|
445 |
+
if name in rename_dict:
|
446 |
+
state_dict_[rename_dict[name]] = param
|
447 |
+
elif name.endswith(".weight") or name.endswith(".bias"):
|
448 |
+
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
449 |
+
prefix = name[:-len(suffix)]
|
450 |
+
if prefix in rename_dict:
|
451 |
+
state_dict_[rename_dict[prefix] + suffix] = param
|
452 |
+
elif prefix.startswith("transformer_blocks."):
|
453 |
+
names = prefix.split(".")
|
454 |
+
names[0] = "blocks"
|
455 |
+
middle = ".".join(names[2:])
|
456 |
+
if middle in rename_dict:
|
457 |
+
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
|
458 |
+
state_dict_[name_] = param
|
459 |
+
elif prefix.startswith("single_transformer_blocks."):
|
460 |
+
names = prefix.split(".")
|
461 |
+
names[0] = "single_blocks"
|
462 |
+
middle = ".".join(names[2:])
|
463 |
+
if middle in rename_dict_single:
|
464 |
+
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
|
465 |
+
state_dict_[name_] = param
|
466 |
+
else:
|
467 |
+
print(name)
|
468 |
+
else:
|
469 |
+
print(name)
|
470 |
+
for name in list(state_dict_.keys()):
|
471 |
+
if ".proj_in_besides_attn." in name:
|
472 |
+
name_ = name.replace(".proj_in_besides_attn.", ".linear.")
|
473 |
+
param = torch.concat([
|
474 |
+
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
|
475 |
+
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
|
476 |
+
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
|
477 |
+
state_dict_[name],
|
478 |
+
], dim=0)
|
479 |
+
state_dict_[name_] = param
|
480 |
+
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
|
481 |
                state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
                state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
                state_dict_.pop(name)
        for name in list(state_dict_.keys()):
            for component in ["a", "b"]:
                if f".{component}_to_q." in name:
                    name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
                    param = torch.concat([
                        state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
                        state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
                        state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
                    ], dim=0)
                    state_dict_[name_] = param
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
        return state_dict_

    def from_civitai(self, state_dict):
        rename_dict = {
            "time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
            "time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
            "time_in.out_layer.bias": "time_embedder.timestep_embedder.2.bias",
            "time_in.out_layer.weight": "time_embedder.timestep_embedder.2.weight",
            "txt_in.bias": "context_embedder.bias",
            "txt_in.weight": "context_embedder.weight",
            "vector_in.in_layer.bias": "pooled_text_embedder.0.bias",
            "vector_in.in_layer.weight": "pooled_text_embedder.0.weight",
            "vector_in.out_layer.bias": "pooled_text_embedder.2.bias",
            "vector_in.out_layer.weight": "pooled_text_embedder.2.weight",
            "final_layer.linear.bias": "proj_out.bias",
            "final_layer.linear.weight": "proj_out.weight",
            "guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias",
            "guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight",
            "guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias",
            "guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight",
            "img_in.bias": "x_embedder.bias",
            "img_in.weight": "x_embedder.weight",
            "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
            "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
        }
        suffix_rename_dict = {
            "img_attn.norm.key_norm.scale": "attn.norm_k_a.weight",
            "img_attn.norm.query_norm.scale": "attn.norm_q_a.weight",
            "img_attn.proj.bias": "attn.a_to_out.bias",
            "img_attn.proj.weight": "attn.a_to_out.weight",
            "img_attn.qkv.bias": "attn.a_to_qkv.bias",
            "img_attn.qkv.weight": "attn.a_to_qkv.weight",
            "img_mlp.0.bias": "ff_a.0.bias",
            "img_mlp.0.weight": "ff_a.0.weight",
            "img_mlp.2.bias": "ff_a.2.bias",
            "img_mlp.2.weight": "ff_a.2.weight",
            "img_mod.lin.bias": "norm1_a.linear.bias",
            "img_mod.lin.weight": "norm1_a.linear.weight",
            "txt_attn.norm.key_norm.scale": "attn.norm_k_b.weight",
            "txt_attn.norm.query_norm.scale": "attn.norm_q_b.weight",
            "txt_attn.proj.bias": "attn.b_to_out.bias",
            "txt_attn.proj.weight": "attn.b_to_out.weight",
            "txt_attn.qkv.bias": "attn.b_to_qkv.bias",
            "txt_attn.qkv.weight": "attn.b_to_qkv.weight",
            "txt_mlp.0.bias": "ff_b.0.bias",
            "txt_mlp.0.weight": "ff_b.0.weight",
            "txt_mlp.2.bias": "ff_b.2.bias",
            "txt_mlp.2.weight": "ff_b.2.weight",
            "txt_mod.lin.bias": "norm1_b.linear.bias",
            "txt_mod.lin.weight": "norm1_b.linear.weight",

            "linear1.bias": "linear.bias",
            "linear1.weight": "linear.weight",
            "linear2.bias": "proj_out.bias",
            "linear2.weight": "proj_out.weight",
            "modulation.lin.bias": "norm.linear.bias",
            "modulation.lin.weight": "norm.linear.weight",
            "norm.key_norm.scale": "norm_k_a.weight",
            "norm.query_norm.scale": "norm_q_a.weight",
        }
        state_dict_ = {}
        for name, param in state_dict.items():
            names = name.split(".")
            if name in rename_dict:
                rename = rename_dict[name]
                if name.startswith("final_layer.adaLN_modulation.1."):
                    param = torch.concat([param[3072:], param[:3072]], dim=0)
                state_dict_[rename] = param
            elif names[0] == "double_blocks":
                rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
                state_dict_[rename] = param
            elif names[0] == "single_blocks":
                if ".".join(names[2:]) in suffix_rename_dict:
                    rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
                    state_dict_[rename] = param
            else:
                print(name)
        return state_dict_
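A minimal sketch (not part of the commit) of what the q/k/v fusion in from_diffusers above does; the tensor sizes and key names below are toy values for illustration, not the model's real dimensions:

import torch

state_dict_ = {
    "blocks.0.attn.a_to_q.weight": torch.randn(8, 8),   # toy shapes, illustrative only
    "blocks.0.attn.a_to_k.weight": torch.randn(8, 8),
    "blocks.0.attn.a_to_v.weight": torch.randn(8, 8),
}
name, component = "blocks.0.attn.a_to_q.weight", "a"
# stack the three projections along the output dimension, as the converter does
fused = torch.concat([
    state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_{x}.")]
    for x in ("q", "k", "v")
], dim=0)
print(fused.shape)  # torch.Size([24, 8]): a single a_to_qkv.weight matrix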
diffsynth/models/flux_text_encoder.py
ADDED
@@ -0,0 +1,93 @@
import torch
from transformers import T5EncoderModel, T5Config
from .sd_text_encoder import SDTextEncoder


class FluxTextEncoder1(SDTextEncoder):
    def __init__(self, vocab_size=49408):
        super().__init__(vocab_size=vocab_size)

    def forward(self, input_ids, clip_skip=2):
        embeds = self.token_embedding(input_ids) + self.position_embeds
        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
        for encoder_id, encoder in enumerate(self.encoders):
            embeds = encoder(embeds, attn_mask=attn_mask)
            if encoder_id + clip_skip == len(self.encoders):
                hidden_states = embeds
        embeds = self.final_layer_norm(embeds)
        pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
        return embeds, pooled_embeds

    @staticmethod
    def state_dict_converter():
        return FluxTextEncoder1StateDictConverter()


class FluxTextEncoder2(T5EncoderModel):
    def __init__(self, config):
        super().__init__(config)
        self.eval()

    def forward(self, input_ids):
        outputs = super().forward(input_ids=input_ids)
        prompt_emb = outputs.last_hidden_state
        return prompt_emb

    @staticmethod
    def state_dict_converter():
        return FluxTextEncoder2StateDictConverter()


class FluxTextEncoder1StateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        rename_dict = {
            "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
            "text_model.embeddings.position_embedding.weight": "position_embeds",
            "text_model.final_layer_norm.weight": "final_layer_norm.weight",
            "text_model.final_layer_norm.bias": "final_layer_norm.bias"
        }
        attn_rename_dict = {
            "self_attn.q_proj": "attn.to_q",
            "self_attn.k_proj": "attn.to_k",
            "self_attn.v_proj": "attn.to_v",
            "self_attn.out_proj": "attn.to_out",
            "layer_norm1": "layer_norm1",
            "layer_norm2": "layer_norm2",
            "mlp.fc1": "fc1",
            "mlp.fc2": "fc2",
        }
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if name == "text_model.embeddings.position_embedding.weight":
                    param = param.reshape((1, param.shape[0], param.shape[1]))
                state_dict_[rename_dict[name]] = param
            elif name.startswith("text_model.encoder.layers."):
                param = state_dict[name]
                names = name.split(".")
                layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
                name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
                state_dict_[name_] = param
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)


class FluxTextEncoder2StateDictConverter():
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = state_dict
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
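A small self-contained sketch of the CLIP layer renaming performed by FluxTextEncoder1StateDictConverter.from_diffusers above; the example key is illustrative, the mapping logic is the one in the file:

attn_rename_dict = {"self_attn.q_proj": "attn.to_q"}
name = "text_model.encoder.layers.5.self_attn.q_proj.weight"
names = name.split(".")
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
print(name_)  # encoders.5.attn.to_q.weight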
diffsynth/models/flux_vae.py
ADDED
@@ -0,0 +1,303 @@
from .sd3_vae_encoder import SD3VAEEncoder, SDVAEEncoderStateDictConverter
from .sd3_vae_decoder import SD3VAEDecoder, SDVAEDecoderStateDictConverter


class FluxVAEEncoder(SD3VAEEncoder):
    def __init__(self):
        super().__init__()
        self.scaling_factor = 0.3611
        self.shift_factor = 0.1159

    @staticmethod
    def state_dict_converter():
        return FluxVAEEncoderStateDictConverter()


class FluxVAEDecoder(SD3VAEDecoder):
    def __init__(self):
        super().__init__()
        self.scaling_factor = 0.3611
        self.shift_factor = 0.1159

    @staticmethod
    def state_dict_converter():
        return FluxVAEDecoderStateDictConverter()


class FluxVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
    def __init__(self):
        pass

    def from_civitai(self, state_dict):
        rename_dict = {
            "encoder.conv_in.bias": "conv_in.bias",
            "encoder.conv_in.weight": "conv_in.weight",
            "encoder.conv_out.bias": "conv_out.bias",
            "encoder.conv_out.weight": "conv_out.weight",
            "encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
            "encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
            "encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
            "encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
            "encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
            "encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
            "encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
            "encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
            "encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
            "encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
            "encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
            "encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
            "encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
            "encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
            "encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
            "encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
            "encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
            "encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
            "encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
            "encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
            "encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
            "encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
            "encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
            "encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
            "encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
            "encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
            "encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
            "encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
            "encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
            "encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
            "encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
            "encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
            "encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
            "encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
            "encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
            "encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
            "encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
            "encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
            "encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
            "encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
            "encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
            "encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
            "encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
            "encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
            "encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
            "encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
            "encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
            "encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
            "encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
            "encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
            "encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
            "encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
            "encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
            "encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
            "encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
            "encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
            "encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
            "encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
            "encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
            "encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
            "encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
            "encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
            "encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
            "encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
            "encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
            "encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
            "encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
            "encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
            "encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
            "encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
            "encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
            "encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
            "encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
            "encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
            "encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
            "encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
            "encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
            "encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
            "encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
            "encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
            "encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
            "encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
            "encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
            "encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
            "encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
            "encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
            "encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
            "encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
            "encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
            "encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
            "encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
            "encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
            "encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
            "encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
            "encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
            "encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
            "encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
            "encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
            "encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
            "encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
            "encoder.norm_out.bias": "conv_norm_out.bias",
            "encoder.norm_out.weight": "conv_norm_out.weight",
        }
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if "transformer_blocks" in rename_dict[name]:
                    param = param.squeeze()
                state_dict_[rename_dict[name]] = param
        return state_dict_


class FluxVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
    def __init__(self):
        pass

    def from_civitai(self, state_dict):
        rename_dict = {
            "decoder.conv_in.bias": "conv_in.bias",
            "decoder.conv_in.weight": "conv_in.weight",
            "decoder.conv_out.bias": "conv_out.bias",
            "decoder.conv_out.weight": "conv_out.weight",
            "decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
            "decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
            "decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
            "decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
            "decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
            "decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
            "decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
            "decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
            "decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
            "decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
            "decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
            "decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
            "decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
            "decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
            "decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
            "decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
            "decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
            "decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
            "decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
            "decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
            "decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
            "decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
            "decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
            "decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
            "decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
            "decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
            "decoder.norm_out.bias": "conv_norm_out.bias",
            "decoder.norm_out.weight": "conv_norm_out.weight",
            "decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
            "decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
            "decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
            "decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
            "decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
            "decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
            "decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
            "decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
            "decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
            "decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
            "decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
            "decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
            "decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
            "decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
            "decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
            "decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
            "decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
            "decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
            "decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
            "decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
            "decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
            "decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
            "decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
            "decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
            "decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
            "decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
            "decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
            "decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
            "decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
            "decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
            "decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
            "decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
            "decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
            "decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
            "decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
            "decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
            "decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
            "decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
            "decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
            "decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
            "decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
            "decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
            "decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
            "decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
            "decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
            "decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
            "decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
            "decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
            "decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
            "decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
            "decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
            "decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
            "decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
            "decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
            "decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
            "decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
            "decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
            "decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
            "decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
            "decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
            "decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
            "decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
            "decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
            "decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
            "decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
            "decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
            "decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
            "decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
            "decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
            "decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
            "decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
            "decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
            "decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
            "decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
            "decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
            "decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
            "decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
            "decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
            "decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
            "decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
            "decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
            "decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
            "decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
            "decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
            "decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
            "decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
            "decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
            "decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
            "decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
            "decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
            "decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
            "decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
            "decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
            "decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
            "decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
            "decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
            "decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
            "decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
            "decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
            "decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
            "decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
            "decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
            "decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
            "decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
            "decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
            "decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
        }
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if "transformer_blocks" in rename_dict[name]:
                    param = param.squeeze()
                state_dict_[rename_dict[name]] = param
        return state_dict_
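The diff only sets scaling_factor and shift_factor here; how they are consumed is not shown in this file. Assuming the usual convention for shifted and scaled VAE latents, the round trip would look like the sketch below (an assumption for illustration, not code from this commit):

import torch

scaling_factor, shift_factor = 0.3611, 0.1159
z = torch.randn(1, 16, 4, 4)                       # toy latent tensor
z_scaled = (z - shift_factor) * scaling_factor     # assumed encode-side normalization
z_back = z_scaled / scaling_factor + shift_factor  # assumed decode-side de-normalization
print(torch.allclose(z, z_back))                   # True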
diffsynth/models/hunyuan_dit.py
ADDED
@@ -0,0 +1,451 @@
from .attention import Attention
from einops import repeat, rearrange
import math
import torch


class HunyuanDiTRotaryEmbedding(torch.nn.Module):

    def __init__(self, q_norm_shape=88, k_norm_shape=88, rotary_emb_on_k=True):
        super().__init__()
        self.q_norm = torch.nn.LayerNorm((q_norm_shape,), elementwise_affine=True, eps=1e-06)
        self.k_norm = torch.nn.LayerNorm((k_norm_shape,), elementwise_affine=True, eps=1e-06)
        self.rotary_emb_on_k = rotary_emb_on_k
        self.k_cache, self.v_cache = [], []

    def reshape_for_broadcast(self, freqs_cis, x):
        ndim = x.ndim
        shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)

    def rotate_half(self, x):
        x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
        return torch.stack([-x_imag, x_real], dim=-1).flatten(3)

    def apply_rotary_emb(self, xq, xk, freqs_cis):
        xk_out = None
        cos, sin = self.reshape_for_broadcast(freqs_cis, xq)
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq)
        if xk is not None:
            xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk)
        return xq_out, xk_out

    def forward(self, q, k, v, freqs_cis_img, to_cache=False):
        # norm
        q = self.q_norm(q)
        k = self.k_norm(k)

        # RoPE
        if self.rotary_emb_on_k:
            q, k = self.apply_rotary_emb(q, k, freqs_cis_img)
        else:
            q, _ = self.apply_rotary_emb(q, None, freqs_cis_img)

        if to_cache:
            self.k_cache.append(k)
            self.v_cache.append(v)
        elif len(self.k_cache) > 0 and len(self.v_cache) > 0:
            k = torch.concat([k] + self.k_cache, dim=2)
            v = torch.concat([v] + self.v_cache, dim=2)
            self.k_cache, self.v_cache = [], []
        return q, k, v


class FP32_Layernorm(torch.nn.LayerNorm):
    def forward(self, inputs):
        origin_dtype = inputs.dtype
        return torch.nn.functional.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).to(origin_dtype)


class FP32_SiLU(torch.nn.SiLU):
    def forward(self, inputs):
        origin_dtype = inputs.dtype
        return torch.nn.functional.silu(inputs.float(), inplace=False).to(origin_dtype)


class HunyuanDiTFinalLayer(torch.nn.Module):
    def __init__(self, final_hidden_size=1408, condition_dim=1408, patch_size=2, out_channels=8):
        super().__init__()
        self.norm_final = torch.nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = torch.nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = torch.nn.Sequential(
            FP32_SiLU(),
            torch.nn.Linear(condition_dim, 2 * final_hidden_size, bias=True)
        )

    def modulate(self, x, shift, scale):
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    def forward(self, hidden_states, condition_emb):
        shift, scale = self.adaLN_modulation(condition_emb).chunk(2, dim=1)
        hidden_states = self.modulate(self.norm_final(hidden_states), shift, scale)
        hidden_states = self.linear(hidden_states)
        return hidden_states


class HunyuanDiTBlock(torch.nn.Module):

    def __init__(
        self,
        hidden_dim=1408,
        condition_dim=1408,
        num_heads=16,
        mlp_ratio=4.3637,
        text_dim=1024,
        skip_connection=False
    ):
        super().__init__()
        self.norm1 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
        self.rota1 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads)
        self.attn1 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
        self.norm2 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
        self.rota2 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads, rotary_emb_on_k=False)
        self.attn2 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, kv_dim=text_dim, bias_q=True, bias_kv=True, bias_out=True)
        self.norm3 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
        self.modulation = torch.nn.Sequential(FP32_SiLU(), torch.nn.Linear(condition_dim, hidden_dim, bias=True))
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, int(hidden_dim*mlp_ratio), bias=True),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(int(hidden_dim*mlp_ratio), hidden_dim, bias=True)
        )
        if skip_connection:
            self.skip_norm = FP32_Layernorm((hidden_dim * 2,), eps=1e-6, elementwise_affine=True)
            self.skip_linear = torch.nn.Linear(hidden_dim * 2, hidden_dim, bias=True)
        else:
            self.skip_norm, self.skip_linear = None, None

    def forward(self, hidden_states, condition_emb, text_emb, freq_cis_img, residual=None, to_cache=False):
        # Long Skip Connection
        if self.skip_norm is not None and self.skip_linear is not None:
            hidden_states = torch.cat([hidden_states, residual], dim=-1)
            hidden_states = self.skip_norm(hidden_states)
            hidden_states = self.skip_linear(hidden_states)

        # Self-Attention
        shift_msa = self.modulation(condition_emb).unsqueeze(dim=1)
        attn_input = self.norm1(hidden_states) + shift_msa
        hidden_states = hidden_states + self.attn1(attn_input, qkv_preprocessor=lambda q, k, v: self.rota1(q, k, v, freq_cis_img, to_cache=to_cache))

        # Cross-Attention
        attn_input = self.norm3(hidden_states)
        hidden_states = hidden_states + self.attn2(attn_input, text_emb, qkv_preprocessor=lambda q, k, v: self.rota2(q, k, v, freq_cis_img))

        # FFN Layer
        mlp_input = self.norm2(hidden_states)
        hidden_states = hidden_states + self.mlp(mlp_input)
        return hidden_states


class AttentionPool(torch.nn.Module):
    def __init__(self, spacial_dim, embed_dim, num_heads, output_dim = None):
        super().__init__()
        self.positional_embedding = torch.nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.c_proj = torch.nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.permute(1, 0, 2)  # NLC -> LNC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (L+1)NC
        x, _ = torch.nn.functional.multi_head_attention_forward(
            query=x[:1], key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return x.squeeze(0)


class PatchEmbed(torch.nn.Module):
    def __init__(
        self,
        patch_size=(2, 2),
        in_chans=4,
        embed_dim=1408,
        bias=True,
    ):
        super().__init__()
        self.proj = torch.nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return x


def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
    # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=t.device)  # size: [dim/2], an exponentially decaying curve
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
    else:
        embedding = repeat(t, "b -> b d", d=dim)
    return embedding


class TimestepEmbedder(torch.nn.Module):
    def __init__(self, hidden_size=1408, frequency_embedding_size=256):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    def forward(self, t):
        t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
        t_emb = self.mlp(t_freq)
        return t_emb


class HunyuanDiT(torch.nn.Module):
    def __init__(self, num_layers_down=21, num_layers_up=19, in_channels=4, out_channels=8, hidden_dim=1408, text_dim=1024, t5_dim=2048, text_length=77, t5_length=256):
        super().__init__()

        # Embedders
        self.text_emb_padding = torch.nn.Parameter(torch.randn(text_length + t5_length, text_dim, dtype=torch.float32))
        self.t5_embedder = torch.nn.Sequential(
            torch.nn.Linear(t5_dim, t5_dim * 4, bias=True),
            FP32_SiLU(),
            torch.nn.Linear(t5_dim * 4, text_dim, bias=True),
        )
        self.t5_pooler = AttentionPool(t5_length, t5_dim, num_heads=8, output_dim=1024)
        self.style_embedder = torch.nn.Parameter(torch.randn(hidden_dim))
        self.patch_embedder = PatchEmbed(in_chans=in_channels)
        self.timestep_embedder = TimestepEmbedder()
        self.extra_embedder = torch.nn.Sequential(
            torch.nn.Linear(256 * 6 + 1024 + hidden_dim, hidden_dim * 4),
            FP32_SiLU(),
            torch.nn.Linear(hidden_dim * 4, hidden_dim),
        )

        # Transformer blocks
        self.num_layers_down = num_layers_down
        self.num_layers_up = num_layers_up
        self.blocks = torch.nn.ModuleList(
            [HunyuanDiTBlock(skip_connection=False) for _ in range(num_layers_down)] + \
            [HunyuanDiTBlock(skip_connection=True) for _ in range(num_layers_up)]
        )

        # Output layers
        self.final_layer = HunyuanDiTFinalLayer()
        self.out_channels = out_channels

    def prepare_text_emb(self, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5):
        text_emb_mask = text_emb_mask.bool()
        text_emb_mask_t5 = text_emb_mask_t5.bool()
        text_emb_t5 = self.t5_embedder(text_emb_t5)
        text_emb = torch.cat([text_emb, text_emb_t5], dim=1)
        text_emb_mask = torch.cat([text_emb_mask, text_emb_mask_t5], dim=-1)
        text_emb = torch.where(text_emb_mask.unsqueeze(2), text_emb, self.text_emb_padding.to(text_emb))
        return text_emb

    def prepare_extra_emb(self, text_emb_t5, timestep, size_emb, dtype, batch_size):
        # Text embedding
        pooled_text_emb_t5 = self.t5_pooler(text_emb_t5)

        # Timestep embedding
        timestep_emb = self.timestep_embedder(timestep)

        # Size embedding
        size_emb = timestep_embedding(size_emb.view(-1), 256).to(dtype)
        size_emb = size_emb.view(-1, 6 * 256)

        # Style embedding
        style_emb = repeat(self.style_embedder, "D -> B D", B=batch_size)

        # Concatenate all extra vectors
        extra_emb = torch.cat([pooled_text_emb_t5, size_emb, style_emb], dim=1)
        condition_emb = timestep_emb + self.extra_embedder(extra_emb)

        return condition_emb

    def unpatchify(self, x, h, w):
        return rearrange(x, "B (H W) (P Q C) -> B C (H P) (W Q)", H=h, W=w, P=2, Q=2)

    def build_mask(self, data, is_bound):
        _, _, H, W = data.shape
        h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
        w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
        border_width = (H + W) // 4
        pad = torch.ones_like(h) * border_width
        mask = torch.stack([
            pad if is_bound[0] else h + 1,
            pad if is_bound[1] else H - h,
            pad if is_bound[2] else w + 1,
            pad if is_bound[3] else W - w
        ]).min(dim=0).values
        mask = mask.clip(1, border_width)
        mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
        mask = rearrange(mask, "H W -> 1 H W")
        return mask

    def tiled_block_forward(self, block, hidden_states, condition_emb, text_emb, freq_cis_img, residual, torch_dtype, data_device, computation_device, tile_size, tile_stride):
        B, C, H, W = hidden_states.shape

        weight = torch.zeros((1, 1, H, W), dtype=torch_dtype, device=data_device)
        values = torch.zeros((B, C, H, W), dtype=torch_dtype, device=data_device)

        # Split tasks
        tasks = []
        for h in range(0, H, tile_stride):
            for w in range(0, W, tile_stride):
                if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
                    continue
                h_, w_ = h + tile_size, w + tile_size
                if h_ > H: h, h_ = H - tile_size, H
                if w_ > W: w, w_ = W - tile_size, W
                tasks.append((h, h_, w, w_))

        # Run
        for hl, hr, wl, wr in tasks:
            hidden_states_batch = hidden_states[:, :, hl:hr, wl:wr].to(computation_device)
            hidden_states_batch = rearrange(hidden_states_batch, "B C H W -> B (H W) C")
            if residual is not None:
                residual_batch = residual[:, :, hl:hr, wl:wr].to(computation_device)
                residual_batch = rearrange(residual_batch, "B C H W -> B (H W) C")
            else:
                residual_batch = None

            # Forward
            hidden_states_batch = block(hidden_states_batch, condition_emb, text_emb, freq_cis_img, residual_batch).to(data_device)
            hidden_states_batch = rearrange(hidden_states_batch, "B (H W) C -> B C H W", H=hr-hl)

            mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
            values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
            weight[:, :, hl:hr, wl:wr] += mask
        values /= weight
        return values

    def forward(
        self, hidden_states, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5, timestep, size_emb, freq_cis_img,
        tiled=False, tile_size=64, tile_stride=32,
        to_cache=False,
        use_gradient_checkpointing=False,
    ):
        # Embeddings
        text_emb = self.prepare_text_emb(text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5)
        condition_emb = self.prepare_extra_emb(text_emb_t5, timestep, size_emb, hidden_states.dtype, hidden_states.shape[0])

        # Input
        height, width = hidden_states.shape[-2], hidden_states.shape[-1]
        hidden_states = self.patch_embedder(hidden_states)

        # Blocks
        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward
        if tiled:
            hidden_states = rearrange(hidden_states, "B (H W) C -> B C H W", H=height//2)
            residuals = []
            for block_id, block in enumerate(self.blocks):
                residual = residuals.pop() if block_id >= self.num_layers_down else None
                hidden_states = self.tiled_block_forward(
                    block, hidden_states, condition_emb, text_emb, freq_cis_img, residual,
                    torch_dtype=hidden_states.dtype, data_device=hidden_states.device, computation_device=hidden_states.device,
                    tile_size=tile_size, tile_stride=tile_stride
                )
                if block_id < self.num_layers_down - 2:
                    residuals.append(hidden_states)
            hidden_states = rearrange(hidden_states, "B C H W -> B (H W) C")
        else:
            residuals = []
            for block_id, block in enumerate(self.blocks):
                residual = residuals.pop() if block_id >= self.num_layers_down else None
                if self.training and use_gradient_checkpointing:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states, condition_emb, text_emb, freq_cis_img, residual,
                        use_reentrant=False,
                    )
                else:
                    hidden_states = block(hidden_states, condition_emb, text_emb, freq_cis_img, residual, to_cache=to_cache)
                if block_id < self.num_layers_down - 2:
                    residuals.append(hidden_states)

        # Output
        hidden_states = self.final_layer(hidden_states, condition_emb)
        hidden_states = self.unpatchify(hidden_states, height//2, width//2)
        hidden_states, _ = hidden_states.chunk(2, dim=1)
        return hidden_states

    @staticmethod
    def state_dict_converter():
        return HunyuanDiTStateDictConverter()


class HunyuanDiTStateDictConverter():
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = {}
        for name, param in state_dict.items():
            name_ = name
            name_ = name_.replace(".default_modulation.", ".modulation.")
            name_ = name_.replace(".mlp.fc1.", ".mlp.0.")
            name_ = name_.replace(".mlp.fc2.", ".mlp.2.")
            name_ = name_.replace(".attn1.q_norm.", ".rota1.q_norm.")
            name_ = name_.replace(".attn2.q_norm.", ".rota2.q_norm.")
            name_ = name_.replace(".attn1.k_norm.", ".rota1.k_norm.")
            name_ = name_.replace(".attn2.k_norm.", ".rota2.k_norm.")
            name_ = name_.replace(".q_proj.", ".to_q.")
            name_ = name_.replace(".out_proj.", ".to_out.")
            name_ = name_.replace("text_embedding_padding", "text_emb_padding")
            name_ = name_.replace("mlp_t5.0.", "t5_embedder.0.")
            name_ = name_.replace("mlp_t5.2.", "t5_embedder.2.")
            name_ = name_.replace("pooler.", "t5_pooler.")
            name_ = name_.replace("x_embedder.", "patch_embedder.")
            name_ = name_.replace("t_embedder.", "timestep_embedder.")
            name_ = name_.replace("t5_pooler.to_q.", "t5_pooler.q_proj.")
            name_ = name_.replace("style_embedder.weight", "style_embedder")
            if ".kv_proj." in name_:
                param_k = param[:param.shape[0]//2]
                param_v = param[param.shape[0]//2:]
                state_dict_[name_.replace(".kv_proj.", ".to_k.")] = param_k
                state_dict_[name_.replace(".kv_proj.", ".to_v.")] = param_v
            elif ".Wqkv." in name_:
                param_q = param[:param.shape[0]//3]
                param_k = param[param.shape[0]//3:param.shape[0]//3*2]
                param_v = param[param.shape[0]//3*2:]
                state_dict_[name_.replace(".Wqkv.", ".to_q.")] = param_q
                state_dict_[name_.replace(".Wqkv.", ".to_k.")] = param_k
                state_dict_[name_.replace(".Wqkv.", ".to_v.")] = param_v
            elif "style_embedder" in name_:
                state_dict_[name_] = param.squeeze()
            else:
                state_dict_[name_] = param
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
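A hedged usage sketch of the timestep_embedding helper defined above; the import path assumes the repository is installed so that the diffsynth package is importable:

import torch
from diffsynth.models.hunyuan_dit import timestep_embedding  # module path per this commit

t = torch.tensor([0, 500, 999])
emb = timestep_embedding(t, dim=256)
print(emb.shape)        # torch.Size([3, 256]): [cos | sin] sinusoidal features per timestep
print(emb.abs().max())  # <= 1, since entries are cos/sin of t * exp(-log(10000) * i / 128)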
diffsynth/models/hunyuan_dit_text_encoder.py
ADDED
@@ -0,0 +1,163 @@
from transformers import BertModel, BertConfig, T5EncoderModel, T5Config
import torch


class HunyuanDiTCLIPTextEncoder(BertModel):
    def __init__(self):
        config = BertConfig(
            _name_or_path = "",
            architectures = ["BertModel"],
            attention_probs_dropout_prob = 0.1,
            bos_token_id = 0,
            classifier_dropout = None,
            directionality = "bidi",
            eos_token_id = 2,
            hidden_act = "gelu",
            hidden_dropout_prob = 0.1,
            hidden_size = 1024,
            initializer_range = 0.02,
            intermediate_size = 4096,
            layer_norm_eps = 1e-12,
            max_position_embeddings = 512,
            model_type = "bert",
            num_attention_heads = 16,
            num_hidden_layers = 24,
            output_past = True,
            pad_token_id = 0,
            pooler_fc_size = 768,
            pooler_num_attention_heads = 12,
            pooler_num_fc_layers = 3,
            pooler_size_per_head = 128,
            pooler_type = "first_token_transform",
            position_embedding_type = "absolute",
            torch_dtype = "float32",
            transformers_version = "4.37.2",
            type_vocab_size = 2,
            use_cache = True,
            vocab_size = 47020
        )
        super().__init__(config, add_pooling_layer=False)
        self.eval()

    def forward(self, input_ids, attention_mask, clip_skip=1):
        input_shape = input_ids.size()

        batch_size, seq_length = input_shape
        device = input_ids.device

        past_key_values_length = 0

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=None,
            token_type_ids=None,
            inputs_embeds=None,
            past_key_values_length=0,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            past_key_values=None,
            use_cache=False,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )
        all_hidden_states = encoder_outputs.hidden_states
        prompt_emb = all_hidden_states[-clip_skip]
        if clip_skip > 1:
            mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std()
            prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
        return prompt_emb

    @staticmethod
    def state_dict_converter():
        return HunyuanDiTCLIPTextEncoderStateDictConverter()


class HunyuanDiTT5TextEncoder(T5EncoderModel):
    def __init__(self):
        config = T5Config(
            _name_or_path = "../HunyuanDiT/t2i/mt5",
            architectures = ["MT5ForConditionalGeneration"],
            classifier_dropout = 0.0,
            d_ff = 5120,
            d_kv = 64,
            d_model = 2048,
            decoder_start_token_id = 0,
            dense_act_fn = "gelu_new",
            dropout_rate = 0.1,
            eos_token_id = 1,
            feed_forward_proj = "gated-gelu",
            initializer_factor = 1.0,
            is_encoder_decoder = True,
            is_gated_act = True,
            layer_norm_epsilon = 1e-06,
            model_type = "t5",
            num_decoder_layers = 24,
            num_heads = 32,
            num_layers = 24,
            output_past = True,
            pad_token_id = 0,
            relative_attention_max_distance = 128,
            relative_attention_num_buckets = 32,
            tie_word_embeddings = False,
            tokenizer_class = "T5Tokenizer",
            transformers_version = "4.37.2",
            use_cache = True,
            vocab_size = 250112
        )
        super().__init__(config)
        self.eval()

    def forward(self, input_ids, attention_mask, clip_skip=1):
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        prompt_emb = outputs.hidden_states[-clip_skip]
        if clip_skip > 1:
            mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std()
            prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
        return prompt_emb

    @staticmethod
    def state_dict_converter():
        return HunyuanDiTT5TextEncoderStateDictConverter()


class HunyuanDiTCLIPTextEncoderStateDictConverter():
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = {name[5:]: param for name, param in state_dict.items() if name.startswith("bert.")}
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)


class HunyuanDiTT5TextEncoderStateDictConverter():
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("encoder.")}
        state_dict_["shared.weight"] = state_dict["shared.weight"]
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
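A small sketch of the clip_skip re-normalization used in both encoders above: the hidden state from an earlier layer is rescaled to the mean and std of the final layer (toy tensors, illustrative only):

import torch

final_layer = torch.randn(1, 77, 1024) * 3.0 + 0.5   # stand-in for hidden_states[-1]
earlier_layer = torch.randn(1, 77, 1024)             # stand-in for hidden_states[-clip_skip]
mean, std = final_layer.mean(), final_layer.std()
prompt_emb = (earlier_layer - earlier_layer.mean()) / earlier_layer.std() * std + mean
print(round(prompt_emb.mean().item(), 2), round(prompt_emb.std().item(), 2))  # ~0.5, ~3.0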
diffsynth/models/kolors_text_encoder.py
ADDED
@@ -0,0 +1,1552 @@
1 |
+
"""
|
2 |
+
This model is copied from https://github.com/Kwai-Kolors/Kolors/tree/master/kolors/models.
|
3 |
+
We didn't modify this model.
|
4 |
+
The tensor operation is performed in the prompter.
|
5 |
+
"""
|
6 |
+
|
7 |
+
|
8 |
+
""" PyTorch ChatGLM model. """
|
9 |
+
|
10 |
+
import math
|
11 |
+
import copy
|
12 |
+
import warnings
|
13 |
+
import re
|
14 |
+
import sys
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.utils.checkpoint
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch import nn
|
20 |
+
from torch.nn import CrossEntropyLoss, LayerNorm
|
21 |
+
from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
|
22 |
+
from torch.nn.utils import skip_init
|
23 |
+
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
|
24 |
+
from copy import deepcopy
|
25 |
+
|
26 |
+
from transformers.modeling_outputs import (
|
27 |
+
BaseModelOutputWithPast,
|
28 |
+
CausalLMOutputWithPast,
|
29 |
+
SequenceClassifierOutputWithPast,
|
30 |
+
)
|
31 |
+
from transformers.modeling_utils import PreTrainedModel
|
32 |
+
from transformers.utils import logging
|
33 |
+
from transformers.generation.logits_process import LogitsProcessor
|
34 |
+
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
|
35 |
+
from transformers import PretrainedConfig
|
36 |
+
from torch.nn.parameter import Parameter
|
37 |
+
import bz2
|
38 |
+
import torch
|
39 |
+
import base64
|
40 |
+
import ctypes
|
41 |
+
from transformers.utils import logging
|
42 |
+
from typing import List
|
43 |
+
|
44 |
+
|
45 |
+
|
46 |
+
logger = logging.get_logger(__name__)
|
47 |
+
|
48 |
+
try:
|
49 |
+
from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
|
50 |
+
|
51 |
+
|
52 |
+
class Kernel:
|
53 |
+
def __init__(self, code: bytes, function_names: List[str]):
|
54 |
+
self.code = code
|
55 |
+
self._function_names = function_names
|
56 |
+
self._cmodule = LazyKernelCModule(self.code)
|
57 |
+
|
58 |
+
for name in self._function_names:
|
59 |
+
setattr(self, name, KernelFunction(self._cmodule, name))
|
60 |
+
|
61 |
+
|
62 |
+
quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+
9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpi
x+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ"
|
63 |
+
|
64 |
+
kernels = Kernel(
|
65 |
+
bz2.decompress(base64.b64decode(quantization_code)),
|
66 |
+
[
|
67 |
+
"int4WeightCompression",
|
68 |
+
"int4WeightExtractionFloat",
|
69 |
+
"int4WeightExtractionHalf",
|
70 |
+
"int8WeightExtractionFloat",
|
71 |
+
"int8WeightExtractionHalf",
|
72 |
+
],
|
73 |
+
)
|
74 |
+
except Exception as exception:
|
75 |
+
kernels = None
|
76 |
+
logger.warning("Failed to load cpm_kernels:" + str(exception))
|
77 |
+
|
78 |
+
|
79 |
+
class W8A16Linear(torch.autograd.Function):
|
80 |
+
@staticmethod
|
81 |
+
def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
|
82 |
+
ctx.inp_shape = inp.size()
|
83 |
+
ctx.weight_bit_width = weight_bit_width
|
84 |
+
out_features = quant_w.size(0)
|
85 |
+
inp = inp.contiguous().view(-1, inp.size(-1))
|
86 |
+
weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
|
87 |
+
ctx.weight_shape = weight.size()
|
88 |
+
output = inp.mm(weight.t())
|
89 |
+
ctx.save_for_backward(inp, quant_w, scale_w)
|
90 |
+
return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
|
91 |
+
|
92 |
+
@staticmethod
|
93 |
+
def backward(ctx, grad_output: torch.Tensor):
|
94 |
+
inp, quant_w, scale_w = ctx.saved_tensors
|
95 |
+
weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)
|
96 |
+
grad_output = grad_output.contiguous().view(-1, weight.size(0))
|
97 |
+
grad_input = grad_output.mm(weight)
|
98 |
+
grad_weight = grad_output.t().mm(inp)
|
99 |
+
return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
|
100 |
+
|
101 |
+
|
102 |
+
def compress_int4_weight(weight: torch.Tensor): # (n, m)
|
103 |
+
with torch.cuda.device(weight.device):
|
104 |
+
n, m = weight.size(0), weight.size(1)
|
105 |
+
assert m % 2 == 0
|
106 |
+
m = m // 2
|
107 |
+
out = torch.empty(n, m, dtype=torch.int8, device="cuda")
|
108 |
+
stream = torch.cuda.current_stream()
|
109 |
+
|
110 |
+
gridDim = (n, 1, 1)
|
111 |
+
blockDim = (min(round_up(m, 32), 1024), 1, 1)
|
112 |
+
|
113 |
+
kernels.int4WeightCompression(
|
114 |
+
gridDim,
|
115 |
+
blockDim,
|
116 |
+
0,
|
117 |
+
stream,
|
118 |
+
[ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
|
119 |
+
)
|
120 |
+
return out
|
121 |
+
|
122 |
+
|
123 |
+
def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
|
124 |
+
assert scale_list.dtype in [torch.half, torch.bfloat16]
|
125 |
+
assert weight.dtype in [torch.int8]
|
126 |
+
if source_bit_width == 8:
|
127 |
+
return weight.to(scale_list.dtype) * scale_list[:, None]
|
128 |
+
elif source_bit_width == 4:
|
129 |
+
func = (
|
130 |
+
kernels.int4WeightExtractionHalf if scale_list.dtype == torch.half else kernels.int4WeightExtractionBFloat16
|
131 |
+
)
|
132 |
+
else:
|
133 |
+
assert False, "Unsupported bit-width"
|
134 |
+
|
135 |
+
with torch.cuda.device(weight.device):
|
136 |
+
n, m = weight.size(0), weight.size(1)
|
137 |
+
out = torch.empty(n, m * (8 // source_bit_width), dtype=scale_list.dtype, device="cuda")
|
138 |
+
stream = torch.cuda.current_stream()
|
139 |
+
|
140 |
+
gridDim = (n, 1, 1)
|
141 |
+
blockDim = (min(round_up(m, 32), 1024), 1, 1)
|
142 |
+
|
143 |
+
func(
|
144 |
+
gridDim,
|
145 |
+
blockDim,
|
146 |
+
0,
|
147 |
+
stream,
|
148 |
+
[
|
149 |
+
ctypes.c_void_p(weight.data_ptr()),
|
150 |
+
ctypes.c_void_p(scale_list.data_ptr()),
|
151 |
+
ctypes.c_void_p(out.data_ptr()),
|
152 |
+
ctypes.c_int32(n),
|
153 |
+
ctypes.c_int32(m),
|
154 |
+
],
|
155 |
+
)
|
156 |
+
return out
|
157 |
+
|
158 |
+
|
159 |
+
class QuantizedLinear(torch.nn.Module):
|
160 |
+
def __init__(self, weight_bit_width: int, weight, bias=None, device="cuda", dtype=None, empty_init=False):
|
161 |
+
super().__init__()
|
162 |
+
weight = weight.to(device) # ensure the weight is on the cuda device
|
163 |
+
assert str(weight.device).startswith(
|
164 |
+
'cuda'), 'The weights that need to be quantized should be on the CUDA device'
|
165 |
+
self.weight_bit_width = weight_bit_width
|
166 |
+
shape = weight.shape
|
167 |
+
|
168 |
+
if weight is None or empty_init:
|
169 |
+
self.weight = torch.empty(shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=device)
|
170 |
+
self.weight_scale = torch.empty(shape[0], dtype=dtype, device=device)
|
171 |
+
else:
|
172 |
+
self.weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)
|
173 |
+
self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
|
174 |
+
if weight_bit_width == 4:
|
175 |
+
self.weight = compress_int4_weight(self.weight)
|
176 |
+
|
177 |
+
self.weight = Parameter(self.weight.to(device), requires_grad=False)
|
178 |
+
self.weight_scale = Parameter(self.weight_scale.to(device), requires_grad=False)
|
179 |
+
self.bias = Parameter(bias.to(device), requires_grad=False) if bias is not None else None
|
180 |
+
|
181 |
+
def forward(self, input):
|
182 |
+
output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width)
|
183 |
+
if self.bias is not None:
|
184 |
+
output = output + self.bias
|
185 |
+
return output
|
186 |
+
|
187 |
+
|
188 |
+
def quantize(model, weight_bit_width, empty_init=False, device=None):
|
189 |
+
"""Replace fp16 linear with quantized linear"""
|
190 |
+
for layer in model.layers:
|
191 |
+
layer.self_attention.query_key_value = QuantizedLinear(
|
192 |
+
weight_bit_width=weight_bit_width,
|
193 |
+
weight=layer.self_attention.query_key_value.weight,
|
194 |
+
bias=layer.self_attention.query_key_value.bias,
|
195 |
+
dtype=layer.self_attention.query_key_value.weight.dtype,
|
196 |
+
device=layer.self_attention.query_key_value.weight.device if device is None else device,
|
197 |
+
empty_init=empty_init
|
198 |
+
)
|
199 |
+
layer.self_attention.dense = QuantizedLinear(
|
200 |
+
weight_bit_width=weight_bit_width,
|
201 |
+
weight=layer.self_attention.dense.weight,
|
202 |
+
bias=layer.self_attention.dense.bias,
|
203 |
+
dtype=layer.self_attention.dense.weight.dtype,
|
204 |
+
device=layer.self_attention.dense.weight.device if device is None else device,
|
205 |
+
empty_init=empty_init
|
206 |
+
)
|
207 |
+
layer.mlp.dense_h_to_4h = QuantizedLinear(
|
208 |
+
weight_bit_width=weight_bit_width,
|
209 |
+
weight=layer.mlp.dense_h_to_4h.weight,
|
210 |
+
bias=layer.mlp.dense_h_to_4h.bias,
|
211 |
+
dtype=layer.mlp.dense_h_to_4h.weight.dtype,
|
212 |
+
device=layer.mlp.dense_h_to_4h.weight.device if device is None else device,
|
213 |
+
empty_init=empty_init
|
214 |
+
)
|
215 |
+
layer.mlp.dense_4h_to_h = QuantizedLinear(
|
216 |
+
weight_bit_width=weight_bit_width,
|
217 |
+
weight=layer.mlp.dense_4h_to_h.weight,
|
218 |
+
bias=layer.mlp.dense_4h_to_h.bias,
|
219 |
+
dtype=layer.mlp.dense_4h_to_h.weight.dtype,
|
220 |
+
device=layer.mlp.dense_4h_to_h.weight.device if device is None else device,
|
221 |
+
empty_init=empty_init
|
222 |
+
)
|
223 |
+
|
224 |
+
return model
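A minimal sketch of the symmetric per-row quantization that QuantizedLinear applies above, in plain PyTorch (no cpm_kernels, no packed int4; the sizes and dtype are arbitrary, only the scale/round/dequantize math mirrors the code):

import torch

def fake_quantize_rows(weight: torch.Tensor, weight_bit_width: int = 8):
    # Per-output-row abs-max scale, same rule as QuantizedLinear.__init__.
    weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)
    quantized = torch.round(weight / weight_scale[:, None]).to(torch.int8)
    # Dequantize the way extract_weight_to_half does for the 8-bit path.
    dequantized = quantized.to(weight.dtype) * weight_scale[:, None]
    return quantized, weight_scale, dequantized

w = torch.randn(16, 32)
q, scale, w_hat = fake_quantize_rows(w)
print(q.dtype, scale.shape, (w - w_hat).abs().max())  # torch.int8, torch.Size([16]), small error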
|
225 |
+
|
226 |
+
|
227 |
+
|
228 |
+
class ChatGLMConfig(PretrainedConfig):
|
229 |
+
model_type = "chatglm"
|
230 |
+
def __init__(
|
231 |
+
self,
|
232 |
+
num_layers=28,
|
233 |
+
padded_vocab_size=65024,
|
234 |
+
hidden_size=4096,
|
235 |
+
ffn_hidden_size=13696,
|
236 |
+
kv_channels=128,
|
237 |
+
num_attention_heads=32,
|
238 |
+
seq_length=2048,
|
239 |
+
hidden_dropout=0.0,
|
240 |
+
classifier_dropout=None,
|
241 |
+
attention_dropout=0.0,
|
242 |
+
layernorm_epsilon=1e-5,
|
243 |
+
rmsnorm=True,
|
244 |
+
apply_residual_connection_post_layernorm=False,
|
245 |
+
post_layer_norm=True,
|
246 |
+
add_bias_linear=False,
|
247 |
+
add_qkv_bias=False,
|
248 |
+
bias_dropout_fusion=True,
|
249 |
+
multi_query_attention=False,
|
250 |
+
multi_query_group_num=1,
|
251 |
+
apply_query_key_layer_scaling=True,
|
252 |
+
attention_softmax_in_fp32=True,
|
253 |
+
fp32_residual_connection=False,
|
254 |
+
quantization_bit=0,
|
255 |
+
pre_seq_len=None,
|
256 |
+
prefix_projection=False,
|
257 |
+
**kwargs
|
258 |
+
):
|
259 |
+
self.num_layers = num_layers
|
260 |
+
self.vocab_size = padded_vocab_size
|
261 |
+
self.padded_vocab_size = padded_vocab_size
|
262 |
+
self.hidden_size = hidden_size
|
263 |
+
self.ffn_hidden_size = ffn_hidden_size
|
264 |
+
self.kv_channels = kv_channels
|
265 |
+
self.num_attention_heads = num_attention_heads
|
266 |
+
self.seq_length = seq_length
|
267 |
+
self.hidden_dropout = hidden_dropout
|
268 |
+
self.classifier_dropout = classifier_dropout
|
269 |
+
self.attention_dropout = attention_dropout
|
270 |
+
self.layernorm_epsilon = layernorm_epsilon
|
271 |
+
self.rmsnorm = rmsnorm
|
272 |
+
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
|
273 |
+
self.post_layer_norm = post_layer_norm
|
274 |
+
self.add_bias_linear = add_bias_linear
|
275 |
+
self.add_qkv_bias = add_qkv_bias
|
276 |
+
self.bias_dropout_fusion = bias_dropout_fusion
|
277 |
+
self.multi_query_attention = multi_query_attention
|
278 |
+
self.multi_query_group_num = multi_query_group_num
|
279 |
+
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
|
280 |
+
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
|
281 |
+
self.fp32_residual_connection = fp32_residual_connection
|
282 |
+
self.quantization_bit = quantization_bit
|
283 |
+
self.pre_seq_len = pre_seq_len
|
284 |
+
self.prefix_projection = prefix_projection
|
285 |
+
super().__init__(**kwargs)
|
286 |
+
|
287 |
+
|
288 |
+
|
289 |
+
# flags required to enable jit fusion kernels
|
290 |
+
|
291 |
+
if sys.platform != 'darwin':
|
292 |
+
torch._C._jit_set_profiling_mode(False)
|
293 |
+
torch._C._jit_set_profiling_executor(False)
|
294 |
+
torch._C._jit_override_can_fuse_on_cpu(True)
|
295 |
+
torch._C._jit_override_can_fuse_on_gpu(True)
|
296 |
+
|
297 |
+
logger = logging.get_logger(__name__)
|
298 |
+
|
299 |
+
_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
|
300 |
+
_CONFIG_FOR_DOC = "ChatGLM6BConfig"
|
301 |
+
|
302 |
+
CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
303 |
+
"THUDM/chatglm3-6b-base",
|
304 |
+
# See all ChatGLM models at https://huggingface.co/models?filter=chatglm
|
305 |
+
]
|
306 |
+
|
307 |
+
|
308 |
+
def default_init(cls, *args, **kwargs):
|
309 |
+
return cls(*args, **kwargs)
|
310 |
+
|
311 |
+
|
312 |
+
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
313 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
314 |
+
if torch.isnan(scores).any() or torch.isinf(scores).any():
|
315 |
+
scores.zero_()
|
316 |
+
scores[..., 5] = 5e4
|
317 |
+
return scores
|
318 |
+
|
319 |
+
|
320 |
+
class PrefixEncoder(torch.nn.Module):
|
321 |
+
"""
|
322 |
+
The torch.nn model to encode the prefix
|
323 |
+
Input shape: (batch-size, prefix-length)
|
324 |
+
Output shape: (batch-size, prefix-length, 2*layers*hidden)
|
325 |
+
"""
|
326 |
+
|
327 |
+
def __init__(self, config: ChatGLMConfig):
|
328 |
+
super().__init__()
|
329 |
+
self.prefix_projection = config.prefix_projection
|
330 |
+
if self.prefix_projection:
|
331 |
+
# Use a two-layer MLP to encode the prefix
|
332 |
+
kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
|
333 |
+
self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
|
334 |
+
self.trans = torch.nn.Sequential(
|
335 |
+
torch.nn.Linear(kv_size, config.hidden_size),
|
336 |
+
torch.nn.Tanh(),
|
337 |
+
torch.nn.Linear(config.hidden_size, kv_size)
|
338 |
+
)
|
339 |
+
else:
|
340 |
+
self.embedding = torch.nn.Embedding(config.pre_seq_len,
|
341 |
+
config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
|
342 |
+
|
343 |
+
def forward(self, prefix: torch.Tensor):
|
344 |
+
if self.prefix_projection:
|
345 |
+
prefix_tokens = self.embedding(prefix)
|
346 |
+
past_key_values = self.trans(prefix_tokens)
|
347 |
+
else:
|
348 |
+
past_key_values = self.embedding(prefix)
|
349 |
+
return past_key_values
|
350 |
+
|
351 |
+
|
352 |
+
def split_tensor_along_last_dim(
|
353 |
+
tensor: torch.Tensor,
|
354 |
+
num_partitions: int,
|
355 |
+
contiguous_split_chunks: bool = False,
|
356 |
+
) -> List[torch.Tensor]:
|
357 |
+
"""Split a tensor along its last dimension.
|
358 |
+
|
359 |
+
Arguments:
|
360 |
+
tensor: input tensor.
|
361 |
+
num_partitions: number of partitions to split the tensor
|
362 |
+
contiguous_split_chunks: If True, make each chunk contiguous
|
363 |
+
in memory.
|
364 |
+
|
365 |
+
Returns:
|
366 |
+
A list of Tensors
|
367 |
+
"""
|
368 |
+
# Get the size and dimension.
|
369 |
+
last_dim = tensor.dim() - 1
|
370 |
+
last_dim_size = tensor.size()[last_dim] // num_partitions
|
371 |
+
# Split.
|
372 |
+
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
|
373 |
+
# Note: torch.split does not create contiguous tensors by default.
|
374 |
+
if contiguous_split_chunks:
|
375 |
+
return tuple(chunk.contiguous() for chunk in tensor_list)
|
376 |
+
|
377 |
+
return tensor_list
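For reference, a toy call of the helper above on a tensor whose last dimension holds three fused partitions (the sizes are made up):

import torch

mixed = torch.arange(24, dtype=torch.float32).reshape(2, 1, 2, 6)   # e.g. [sq, b, np, 3 * hn]
q, k, v = split_tensor_along_last_dim(mixed, 3)
print(q.shape, k.shape, v.shape)   # torch.Size([2, 1, 2, 2]) each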
|
378 |
+
|
379 |
+
|
380 |
+
class RotaryEmbedding(nn.Module):
|
381 |
+
def __init__(self, dim, original_impl=False, device=None, dtype=None):
|
382 |
+
super().__init__()
|
383 |
+
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
|
384 |
+
self.register_buffer("inv_freq", inv_freq)
|
385 |
+
self.dim = dim
|
386 |
+
self.original_impl = original_impl
|
387 |
+
|
388 |
+
def forward_impl(
|
389 |
+
self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
|
390 |
+
):
|
391 |
+
"""Enhanced Transformer with Rotary Position Embedding.
|
392 |
+
|
393 |
+
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
|
394 |
+
transformers/rope/__init__.py. MIT License:
|
395 |
+
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
|
396 |
+
"""
|
397 |
+
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
|
398 |
+
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
|
399 |
+
|
400 |
+
# Create position indexes `[0, 1, ..., seq_len - 1]`
|
401 |
+
seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)
|
402 |
+
|
403 |
+
# Calculate the product of position index and $\theta_i$
|
404 |
+
idx_theta = torch.outer(seq_idx, theta).float()
|
405 |
+
|
406 |
+
cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
|
407 |
+
|
408 |
+
# this is to mimic the behaviour of complex32, else we will get different results
|
409 |
+
if dtype in (torch.float16, torch.bfloat16, torch.int8):
|
410 |
+
cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
|
411 |
+
return cache
|
412 |
+
|
413 |
+
def forward(self, max_seq_len, offset=0):
|
414 |
+
return self.forward_impl(
|
415 |
+
max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
|
416 |
+
)
|
417 |
+
|
418 |
+
|
419 |
+
@torch.jit.script
|
420 |
+
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
|
421 |
+
# x: [sq, b, np, hn]
|
422 |
+
sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
|
423 |
+
rot_dim = rope_cache.shape[-2] * 2
|
424 |
+
x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
|
425 |
+
# truncate to support variable sizes
|
426 |
+
rope_cache = rope_cache[:sq]
|
427 |
+
xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
|
428 |
+
rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
|
429 |
+
x_out2 = torch.stack(
|
430 |
+
[
|
431 |
+
xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
|
432 |
+
xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
|
433 |
+
],
|
434 |
+
-1,
|
435 |
+
)
|
436 |
+
x_out2 = x_out2.flatten(3)
|
437 |
+
return torch.cat((x_out2, x_pass), dim=-1)
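A small numerical sketch of the rotary embedding machinery above, using the RotaryEmbedding and apply_rotary_pos_emb defined in this file with toy sizes; it checks that the rotated channels keep their norm and the pass-through channels are untouched.

import torch

# Toy sizes: 5 tokens, batch 1, 2 heads, 8 channels per head; rotary covers the
# first rot_dim = 4 channels, the remaining 4 pass through unchanged.
rotary = RotaryEmbedding(dim=4, original_impl=True)
cache = rotary(max_seq_len=5)                     # [5, 2, 2] stack of (cos, sin)
x = torch.randn(5, 1, 2, 8)
x_rot = apply_rotary_pos_emb(x, cache)

print(torch.allclose(x[..., :4].norm(dim=-1), x_rot[..., :4].norm(dim=-1), atol=1e-4))  # True
print(torch.equal(x[..., 4:], x_rot[..., 4:]))                                          # True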
|
438 |
+
|
439 |
+
|
440 |
+
class RMSNorm(torch.nn.Module):
|
441 |
+
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
|
442 |
+
super().__init__()
|
443 |
+
self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
|
444 |
+
self.eps = eps
|
445 |
+
|
446 |
+
def forward(self, hidden_states: torch.Tensor):
|
447 |
+
input_dtype = hidden_states.dtype
|
448 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
449 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
450 |
+
|
451 |
+
return (self.weight * hidden_states).to(input_dtype)
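A quick sanity check of the RMSNorm above (toy shapes, weight set to ones): the output rows come back with unit root-mean-square.

import torch

norm = RMSNorm(normalized_shape=8, eps=1e-5, dtype=torch.float32)
with torch.no_grad():
    norm.weight.fill_(1.0)

x = torch.randn(2, 3, 8) * 10.0
y = norm(x)
print(torch.allclose(y.pow(2).mean(-1).sqrt(), torch.ones(2, 3), atol=1e-3))  # True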
|
452 |
+
|
453 |
+
|
454 |
+
class CoreAttention(torch.nn.Module):
|
455 |
+
def __init__(self, config: ChatGLMConfig, layer_number):
|
456 |
+
super(CoreAttention, self).__init__()
|
457 |
+
|
458 |
+
self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
|
459 |
+
self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
|
460 |
+
if self.apply_query_key_layer_scaling:
|
461 |
+
self.attention_softmax_in_fp32 = True
|
462 |
+
self.layer_number = max(1, layer_number)
|
463 |
+
|
464 |
+
projection_size = config.kv_channels * config.num_attention_heads
|
465 |
+
|
466 |
+
# Per attention head and per partition values.
|
467 |
+
self.hidden_size_per_partition = projection_size
|
468 |
+
self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
|
469 |
+
self.num_attention_heads_per_partition = config.num_attention_heads
|
470 |
+
|
471 |
+
coeff = None
|
472 |
+
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
|
473 |
+
if self.apply_query_key_layer_scaling:
|
474 |
+
coeff = self.layer_number
|
475 |
+
self.norm_factor *= coeff
|
476 |
+
self.coeff = coeff
|
477 |
+
|
478 |
+
self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
|
479 |
+
|
480 |
+
def forward(self, query_layer, key_layer, value_layer, attention_mask):
|
481 |
+
pytorch_major_version = int(torch.__version__.split('.')[0])
|
482 |
+
if pytorch_major_version >= 2:
|
483 |
+
query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
|
484 |
+
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
|
485 |
+
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
|
486 |
+
is_causal=True)
|
487 |
+
else:
|
488 |
+
if attention_mask is not None:
|
489 |
+
attention_mask = ~attention_mask
|
490 |
+
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
|
491 |
+
attention_mask)
|
492 |
+
context_layer = context_layer.permute(2, 0, 1, 3)
|
493 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
494 |
+
context_layer = context_layer.reshape(*new_context_layer_shape)
|
495 |
+
else:
|
496 |
+
# Raw attention scores
|
497 |
+
|
498 |
+
# [b, np, sq, sk]
|
499 |
+
output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
|
500 |
+
|
501 |
+
# [sq, b, np, hn] -> [sq, b * np, hn]
|
502 |
+
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
|
503 |
+
# [sk, b, np, hn] -> [sk, b * np, hn]
|
504 |
+
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
|
505 |
+
|
506 |
+
# preallocating input tensor: [b * np, sq, sk]
|
507 |
+
matmul_input_buffer = torch.empty(
|
508 |
+
output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
|
509 |
+
device=query_layer.device
|
510 |
+
)
|
511 |
+
|
512 |
+
# Raw attention scores. [b * np, sq, sk]
|
513 |
+
matmul_result = torch.baddbmm(
|
514 |
+
matmul_input_buffer,
|
515 |
+
query_layer.transpose(0, 1), # [b * np, sq, hn]
|
516 |
+
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
|
517 |
+
beta=0.0,
|
518 |
+
alpha=(1.0 / self.norm_factor),
|
519 |
+
)
|
520 |
+
|
521 |
+
# change view to [b, np, sq, sk]
|
522 |
+
attention_scores = matmul_result.view(*output_size)
|
523 |
+
|
524 |
+
# ===========================
|
525 |
+
# Attention probs and dropout
|
526 |
+
# ===========================
|
527 |
+
|
528 |
+
# attention scores and attention mask [b, np, sq, sk]
|
529 |
+
if self.attention_softmax_in_fp32:
|
530 |
+
attention_scores = attention_scores.float()
|
531 |
+
if self.coeff is not None:
|
532 |
+
attention_scores = attention_scores * self.coeff
|
533 |
+
if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
|
534 |
+
attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
|
535 |
+
device=attention_scores.device, dtype=torch.bool)
|
536 |
+
attention_mask.tril_()
|
537 |
+
attention_mask = ~attention_mask
|
538 |
+
if attention_mask is not None:
|
539 |
+
attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
|
540 |
+
attention_probs = F.softmax(attention_scores, dim=-1)
|
541 |
+
attention_probs = attention_probs.type_as(value_layer)
|
542 |
+
|
543 |
+
# This is actually dropping out entire tokens to attend to, which might
|
544 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
545 |
+
attention_probs = self.attention_dropout(attention_probs)
|
546 |
+
# =========================
|
547 |
+
# Context layer. [sq, b, hp]
|
548 |
+
# =========================
|
549 |
+
|
550 |
+
# value_layer -> context layer.
|
551 |
+
# [sk, b, np, hn] --> [b, np, sq, hn]
|
552 |
+
|
553 |
+
# context layer shape: [b, np, sq, hn]
|
554 |
+
output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
|
555 |
+
# change view [sk, b * np, hn]
|
556 |
+
value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
|
557 |
+
# change view [b * np, sq, sk]
|
558 |
+
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
|
559 |
+
# matmul: [b * np, sq, hn]
|
560 |
+
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
|
561 |
+
# change view [b, np, sq, hn]
|
562 |
+
context_layer = context_layer.view(*output_size)
|
563 |
+
# [b, np, sq, hn] --> [sq, b, np, hn]
|
564 |
+
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
|
565 |
+
# [sq, b, np, hn] --> [sq, b, hp]
|
566 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
567 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
568 |
+
|
569 |
+
return context_layer
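The PyTorch >= 2 branch above is essentially a layout shuffle around scaled_dot_product_attention; a stripped-down sketch of that path with made-up sizes:

import torch
import torch.nn.functional as F

sq, b, np_, hn = 6, 2, 4, 16          # [sq, b, np, hn] is the layout used throughout
q = torch.randn(sq, b, np_, hn)
k = torch.randn(sq, b, np_, hn)
v = torch.randn(sq, b, np_, hn)

# [sq, b, np, hn] -> [b, np, sq, hn] for SDPA, causal mask for the prefill case.
q_, k_, v_ = (t.permute(1, 2, 0, 3) for t in (q, k, v))
ctx = F.scaled_dot_product_attention(q_, k_, v_, is_causal=True)

# Back to [sq, b, np * hn], matching the reshape at the end of CoreAttention.forward.
ctx = ctx.permute(2, 0, 1, 3).reshape(sq, b, np_ * hn)
print(ctx.shape)  # torch.Size([6, 2, 64])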
|
570 |
+
|
571 |
+
|
572 |
+
class SelfAttention(torch.nn.Module):
|
573 |
+
"""Parallel self-attention layer abstract class.
|
574 |
+
|
575 |
+
Self-attention layer takes input with size [s, b, h]
|
576 |
+
and returns output of the same size.
|
577 |
+
"""
|
578 |
+
|
579 |
+
def __init__(self, config: ChatGLMConfig, layer_number, device=None):
|
580 |
+
super(SelfAttention, self).__init__()
|
581 |
+
self.layer_number = max(1, layer_number)
|
582 |
+
|
583 |
+
self.projection_size = config.kv_channels * config.num_attention_heads
|
584 |
+
|
585 |
+
# Per attention head and per partition values.
|
586 |
+
self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
|
587 |
+
self.num_attention_heads_per_partition = config.num_attention_heads
|
588 |
+
|
589 |
+
self.multi_query_attention = config.multi_query_attention
|
590 |
+
self.qkv_hidden_size = 3 * self.projection_size
|
591 |
+
if self.multi_query_attention:
|
592 |
+
self.num_multi_query_groups_per_partition = config.multi_query_group_num
|
593 |
+
self.qkv_hidden_size = (
|
594 |
+
self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
|
595 |
+
)
|
596 |
+
self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
|
597 |
+
bias=config.add_bias_linear or config.add_qkv_bias,
|
598 |
+
device=device, **_config_to_kwargs(config)
|
599 |
+
)
|
600 |
+
|
601 |
+
self.core_attention = CoreAttention(config, self.layer_number)
|
602 |
+
|
603 |
+
# Output.
|
604 |
+
self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
|
605 |
+
device=device, **_config_to_kwargs(config)
|
606 |
+
)
|
607 |
+
|
608 |
+
def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
|
609 |
+
if self.multi_query_attention:
|
610 |
+
num_attention_heads = self.num_multi_query_groups_per_partition
|
611 |
+
else:
|
612 |
+
num_attention_heads = self.num_attention_heads_per_partition
|
613 |
+
return torch.empty(
|
614 |
+
inference_max_sequence_len,
|
615 |
+
batch_size,
|
616 |
+
num_attention_heads,
|
617 |
+
self.hidden_size_per_attention_head,
|
618 |
+
dtype=dtype,
|
619 |
+
device=device,
|
620 |
+
)
|
621 |
+
|
622 |
+
def forward(
|
623 |
+
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
|
624 |
+
):
|
625 |
+
# hidden_states: [sq, b, h]
|
626 |
+
|
627 |
+
# =================================================
|
628 |
+
# Pre-allocate memory for key-values for inference.
|
629 |
+
# =================================================
|
630 |
+
# =====================
|
631 |
+
# Query, Key, and Value
|
632 |
+
# =====================
|
633 |
+
|
634 |
+
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
|
635 |
+
mixed_x_layer = self.query_key_value(hidden_states)
|
636 |
+
|
637 |
+
if self.multi_query_attention:
|
638 |
+
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
|
639 |
+
[
|
640 |
+
self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
|
641 |
+
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
642 |
+
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
643 |
+
],
|
644 |
+
dim=-1,
|
645 |
+
)
|
646 |
+
query_layer = query_layer.view(
|
647 |
+
query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
|
648 |
+
)
|
649 |
+
key_layer = key_layer.view(
|
650 |
+
key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
|
651 |
+
)
|
652 |
+
value_layer = value_layer.view(
|
653 |
+
value_layer.size()[:-1]
|
654 |
+
+ (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
|
655 |
+
)
|
656 |
+
else:
|
657 |
+
new_tensor_shape = mixed_x_layer.size()[:-1] + \
|
658 |
+
(self.num_attention_heads_per_partition,
|
659 |
+
3 * self.hidden_size_per_attention_head)
|
660 |
+
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
|
661 |
+
|
662 |
+
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
|
663 |
+
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
|
664 |
+
|
665 |
+
# apply relative positional encoding (rotary embedding)
|
666 |
+
if rotary_pos_emb is not None:
|
667 |
+
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
|
668 |
+
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
|
669 |
+
|
670 |
+
# adjust key and value for inference
|
671 |
+
if kv_cache is not None:
|
672 |
+
cache_k, cache_v = kv_cache
|
673 |
+
key_layer = torch.cat((cache_k, key_layer), dim=0)
|
674 |
+
value_layer = torch.cat((cache_v, value_layer), dim=0)
|
675 |
+
if use_cache:
|
676 |
+
kv_cache = (key_layer, value_layer)
|
677 |
+
else:
|
678 |
+
kv_cache = None
|
679 |
+
|
680 |
+
if self.multi_query_attention:
|
681 |
+
key_layer = key_layer.unsqueeze(-2)
|
682 |
+
key_layer = key_layer.expand(
|
683 |
+
-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
|
684 |
+
)
|
685 |
+
key_layer = key_layer.contiguous().view(
|
686 |
+
key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
|
687 |
+
)
|
688 |
+
value_layer = value_layer.unsqueeze(-2)
|
689 |
+
value_layer = value_layer.expand(
|
690 |
+
-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
|
691 |
+
)
|
692 |
+
value_layer = value_layer.contiguous().view(
|
693 |
+
value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
|
694 |
+
)
|
695 |
+
|
696 |
+
# ==================================
|
697 |
+
# core attention computation
|
698 |
+
# ==================================
|
699 |
+
|
700 |
+
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
|
701 |
+
|
702 |
+
# =================
|
703 |
+
# Output. [sq, b, h]
|
704 |
+
# =================
|
705 |
+
|
706 |
+
output = self.dense(context_layer)
|
707 |
+
|
708 |
+
return output, kv_cache
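A shape-only sketch of the multi-query broadcast above, where a small number of K/V groups is expanded to the full head count (the sizes are invented, not the real config values):

import torch

sq, b = 4, 1
num_heads, num_groups, hn = 8, 2, 16   # 8 query heads share 2 K/V groups

key = torch.randn(sq, b, num_groups, hn)
key = key.unsqueeze(-2)                                      # [sq, b, groups, 1, hn]
key = key.expand(-1, -1, -1, num_heads // num_groups, -1)    # [sq, b, groups, 4, hn]
key = key.contiguous().view(sq, b, num_heads, hn)            # [sq, b, heads, hn]
print(key.shape)  # torch.Size([4, 1, 8, 16])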
|
709 |
+
|
710 |
+
|
711 |
+
def _config_to_kwargs(args):
|
712 |
+
common_kwargs = {
|
713 |
+
"dtype": args.torch_dtype,
|
714 |
+
}
|
715 |
+
return common_kwargs
|
716 |
+
|
717 |
+
|
718 |
+
class MLP(torch.nn.Module):
|
719 |
+
"""MLP.
|
720 |
+
|
721 |
+
MLP will take the input with h hidden state, project it to 4*h
|
722 |
+
hidden dimension, perform nonlinear transformation, and project the
|
723 |
+
state back into h hidden dimension.
|
724 |
+
"""
|
725 |
+
|
726 |
+
def __init__(self, config: ChatGLMConfig, device=None):
|
727 |
+
super(MLP, self).__init__()
|
728 |
+
|
729 |
+
self.add_bias = config.add_bias_linear
|
730 |
+
|
731 |
+
# Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
|
732 |
+
self.dense_h_to_4h = nn.Linear(
|
733 |
+
config.hidden_size,
|
734 |
+
config.ffn_hidden_size * 2,
|
735 |
+
bias=self.add_bias,
|
736 |
+
device=device,
|
737 |
+
**_config_to_kwargs(config)
|
738 |
+
)
|
739 |
+
|
740 |
+
def swiglu(x):
|
741 |
+
x = torch.chunk(x, 2, dim=-1)
|
742 |
+
return F.silu(x[0]) * x[1]
|
743 |
+
|
744 |
+
self.activation_func = swiglu
|
745 |
+
|
746 |
+
# Project back to h.
|
747 |
+
self.dense_4h_to_h = nn.Linear(
|
748 |
+
config.ffn_hidden_size,
|
749 |
+
config.hidden_size,
|
750 |
+
bias=self.add_bias,
|
751 |
+
device=device,
|
752 |
+
**_config_to_kwargs(config)
|
753 |
+
)
|
754 |
+
|
755 |
+
def forward(self, hidden_states):
|
756 |
+
# [s, b, 4hp]
|
757 |
+
intermediate_parallel = self.dense_h_to_4h(hidden_states)
|
758 |
+
intermediate_parallel = self.activation_func(intermediate_parallel)
|
759 |
+
# [s, b, h]
|
760 |
+
output = self.dense_4h_to_h(intermediate_parallel)
|
761 |
+
return output
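The swiglu closure defined in MLP.__init__ is the whole non-linearity; written out on a toy tensor whose last dimension plays the role of 2 * ffn_hidden_size:

import torch
import torch.nn.functional as F

h = torch.randn(3, 2, 8)                 # pretend dense_h_to_4h output, last dim doubled
gate, value = torch.chunk(h, 2, dim=-1)
out = F.silu(gate) * value               # same as the swiglu defined above
print(out.shape)                          # torch.Size([3, 2, 4])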
|
762 |
+
|
763 |
+
|
764 |
+
class GLMBlock(torch.nn.Module):
|
765 |
+
"""A single transformer layer.
|
766 |
+
|
767 |
+
Transformer layer takes input with size [s, b, h] and returns an
|
768 |
+
output of the same size.
|
769 |
+
"""
|
770 |
+
|
771 |
+
def __init__(self, config: ChatGLMConfig, layer_number, device=None):
|
772 |
+
super(GLMBlock, self).__init__()
|
773 |
+
self.layer_number = layer_number
|
774 |
+
|
775 |
+
self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
|
776 |
+
|
777 |
+
self.fp32_residual_connection = config.fp32_residual_connection
|
778 |
+
|
779 |
+
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
780 |
+
# Layernorm on the input data.
|
781 |
+
self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
782 |
+
dtype=config.torch_dtype)
|
783 |
+
|
784 |
+
# Self attention.
|
785 |
+
self.self_attention = SelfAttention(config, layer_number, device=device)
|
786 |
+
self.hidden_dropout = config.hidden_dropout
|
787 |
+
|
788 |
+
# Layernorm on the attention output
|
789 |
+
self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
790 |
+
dtype=config.torch_dtype)
|
791 |
+
|
792 |
+
# MLP
|
793 |
+
self.mlp = MLP(config, device=device)
|
794 |
+
|
795 |
+
def forward(
|
796 |
+
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
|
797 |
+
):
|
798 |
+
# hidden_states: [s, b, h]
|
799 |
+
|
800 |
+
# Layer norm at the beginning of the transformer layer.
|
801 |
+
layernorm_output = self.input_layernorm(hidden_states)
|
802 |
+
# Self attention.
|
803 |
+
attention_output, kv_cache = self.self_attention(
|
804 |
+
layernorm_output,
|
805 |
+
attention_mask,
|
806 |
+
rotary_pos_emb,
|
807 |
+
kv_cache=kv_cache,
|
808 |
+
use_cache=use_cache
|
809 |
+
)
|
810 |
+
|
811 |
+
# Residual connection.
|
812 |
+
if self.apply_residual_connection_post_layernorm:
|
813 |
+
residual = layernorm_output
|
814 |
+
else:
|
815 |
+
residual = hidden_states
|
816 |
+
|
817 |
+
layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
|
818 |
+
layernorm_input = residual + layernorm_input
|
819 |
+
|
820 |
+
# Layer norm post the self attention.
|
821 |
+
layernorm_output = self.post_attention_layernorm(layernorm_input)
|
822 |
+
|
823 |
+
# MLP.
|
824 |
+
mlp_output = self.mlp(layernorm_output)
|
825 |
+
|
826 |
+
# Second residual connection.
|
827 |
+
if self.apply_residual_connection_post_layernorm:
|
828 |
+
residual = layernorm_output
|
829 |
+
else:
|
830 |
+
residual = layernorm_input
|
831 |
+
|
832 |
+
output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
|
833 |
+
output = residual + output
|
834 |
+
|
835 |
+
return output, kv_cache
|
836 |
+
|
837 |
+
|
838 |
+
class GLMTransformer(torch.nn.Module):
|
839 |
+
"""Transformer class."""
|
840 |
+
|
841 |
+
def __init__(self, config: ChatGLMConfig, device=None):
|
842 |
+
super(GLMTransformer, self).__init__()
|
843 |
+
|
844 |
+
self.fp32_residual_connection = config.fp32_residual_connection
|
845 |
+
self.post_layer_norm = config.post_layer_norm
|
846 |
+
|
847 |
+
# Number of layers.
|
848 |
+
self.num_layers = config.num_layers
|
849 |
+
|
850 |
+
# Transformer layers.
|
851 |
+
def build_layer(layer_number):
|
852 |
+
return GLMBlock(config, layer_number, device=device)
|
853 |
+
|
854 |
+
self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
|
855 |
+
|
856 |
+
if self.post_layer_norm:
|
857 |
+
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
858 |
+
# Final layer norm before output.
|
859 |
+
self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
860 |
+
dtype=config.torch_dtype)
|
861 |
+
|
862 |
+
self.gradient_checkpointing = False
|
863 |
+
|
864 |
+
def _get_layer(self, layer_number):
|
865 |
+
return self.layers[layer_number]
|
866 |
+
|
867 |
+
def forward(
|
868 |
+
self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
|
869 |
+
use_cache: Optional[bool] = True,
|
870 |
+
output_hidden_states: Optional[bool] = False,
|
871 |
+
):
|
872 |
+
if not kv_caches:
|
873 |
+
kv_caches = [None for _ in range(self.num_layers)]
|
874 |
+
presents = () if use_cache else None
|
875 |
+
if self.gradient_checkpointing and self.training:
|
876 |
+
if use_cache:
|
877 |
+
logger.warning_once(
|
878 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
879 |
+
)
|
880 |
+
use_cache = False
|
881 |
+
|
882 |
+
all_self_attentions = None
|
883 |
+
all_hidden_states = () if output_hidden_states else None
|
884 |
+
for index in range(self.num_layers):
|
885 |
+
if output_hidden_states:
|
886 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
887 |
+
|
888 |
+
layer = self._get_layer(index)
|
889 |
+
if self.gradient_checkpointing and self.training:
|
890 |
+
layer_ret = torch.utils.checkpoint.checkpoint(
|
891 |
+
layer,
|
892 |
+
hidden_states,
|
893 |
+
attention_mask,
|
894 |
+
rotary_pos_emb,
|
895 |
+
kv_caches[index],
|
896 |
+
use_cache
|
897 |
+
)
|
898 |
+
else:
|
899 |
+
layer_ret = layer(
|
900 |
+
hidden_states,
|
901 |
+
attention_mask,
|
902 |
+
rotary_pos_emb,
|
903 |
+
kv_cache=kv_caches[index],
|
904 |
+
use_cache=use_cache
|
905 |
+
)
|
906 |
+
hidden_states, kv_cache = layer_ret
|
907 |
+
if use_cache:
|
908 |
+
presents = presents + (kv_cache,)
|
909 |
+
|
910 |
+
if output_hidden_states:
|
911 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
912 |
+
|
913 |
+
# Final layer norm.
|
914 |
+
if self.post_layer_norm:
|
915 |
+
hidden_states = self.final_layernorm(hidden_states)
|
916 |
+
|
917 |
+
return hidden_states, presents, all_hidden_states, all_self_attentions
|
918 |
+
|
919 |
+
|
920 |
+
class ChatGLMPreTrainedModel(PreTrainedModel):
|
921 |
+
"""
|
922 |
+
An abstract class to handle weights initialization and
|
923 |
+
a simple interface for downloading and loading pretrained models.
|
924 |
+
"""
|
925 |
+
|
926 |
+
is_parallelizable = False
|
927 |
+
supports_gradient_checkpointing = True
|
928 |
+
config_class = ChatGLMConfig
|
929 |
+
base_model_prefix = "transformer"
|
930 |
+
_no_split_modules = ["GLMBlock"]
|
931 |
+
|
932 |
+
def _init_weights(self, module: nn.Module):
|
933 |
+
"""Initialize the weights."""
|
934 |
+
return
|
935 |
+
|
936 |
+
def get_masks(self, input_ids, past_key_values, padding_mask=None):
|
937 |
+
batch_size, seq_length = input_ids.shape
|
938 |
+
full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
|
939 |
+
full_attention_mask.tril_()
|
940 |
+
past_length = 0
|
941 |
+
if past_key_values:
|
942 |
+
past_length = past_key_values[0][0].shape[0]
|
943 |
+
if past_length:
|
944 |
+
full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
|
945 |
+
device=input_ids.device), full_attention_mask), dim=-1)
|
946 |
+
if padding_mask is not None:
|
947 |
+
full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
|
948 |
+
if not past_length and padding_mask is not None:
|
949 |
+
full_attention_mask -= padding_mask.unsqueeze(-1) - 1
|
950 |
+
full_attention_mask = (full_attention_mask < 0.5).bool()
|
951 |
+
full_attention_mask.unsqueeze_(1)
|
952 |
+
return full_attention_mask
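A toy walk-through of the mask construction above for the no-cache case (the batch and padding pattern are invented; True in the final tensor means "do not attend"):

import torch

batch_size, seq_length = 1, 4
padding_mask = torch.tensor([[1, 1, 1, 0]])          # last position is padding

full = torch.ones(batch_size, seq_length, seq_length)
full.tril_()                                          # causal part: attend to self and past
full = full * padding_mask.unsqueeze(1)               # block attention to padded key positions
full -= padding_mask.unsqueeze(-1) - 1                # keep padded query rows from being fully masked
full = (full < 0.5).bool().unsqueeze(1)               # [b, 1, sq, sk], True = masked
print(full.int())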
|
953 |
+
|
954 |
+
def get_position_ids(self, input_ids, device):
|
955 |
+
batch_size, seq_length = input_ids.shape
|
956 |
+
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
957 |
+
return position_ids
|
958 |
+
|
959 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
960 |
+
if isinstance(module, GLMTransformer):
|
961 |
+
module.gradient_checkpointing = value
|
962 |
+
|
963 |
+
|
964 |
+
class Embedding(torch.nn.Module):
|
965 |
+
"""Language model embeddings."""
|
966 |
+
|
967 |
+
def __init__(self, config: ChatGLMConfig, device=None):
|
968 |
+
super(Embedding, self).__init__()
|
969 |
+
|
970 |
+
self.hidden_size = config.hidden_size
|
971 |
+
# Word embeddings (parallel).
|
972 |
+
self.word_embeddings = nn.Embedding(
|
973 |
+
config.padded_vocab_size,
|
974 |
+
self.hidden_size,
|
975 |
+
dtype=config.torch_dtype,
|
976 |
+
device=device
|
977 |
+
)
|
978 |
+
self.fp32_residual_connection = config.fp32_residual_connection
|
979 |
+
|
980 |
+
def forward(self, input_ids):
|
981 |
+
# Embeddings.
|
982 |
+
words_embeddings = self.word_embeddings(input_ids)
|
983 |
+
embeddings = words_embeddings
|
984 |
+
# Data format change to avoid explicit transposes: [b s h] --> [s b h].
|
985 |
+
embeddings = embeddings.transpose(0, 1).contiguous()
|
986 |
+
# If the input flag for fp32 residual connection is set, convert for float.
|
987 |
+
if self.fp32_residual_connection:
|
988 |
+
embeddings = embeddings.float()
|
989 |
+
return embeddings
|
990 |
+
|
991 |
+
|
class ChatGLMModel(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
        super().__init__(config)
        if empty_init:
            init_method = skip_init
        else:
            init_method = default_init
        init_kwargs = {}
        if device is not None:
            init_kwargs["device"] = device
        self.embedding = init_method(Embedding, config, **init_kwargs)
        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels

        # Rotary positional embeddings
        self.seq_length = config.seq_length
        rotary_dim = (
            config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
        )

        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
                                              dtype=config.torch_dtype)
        self.encoder = init_method(GLMTransformer, config, **init_kwargs)
        self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
                                        dtype=config.torch_dtype, **init_kwargs)
        self.pre_seq_len = config.pre_seq_len
        self.prefix_projection = config.prefix_projection
        if self.pre_seq_len is not None:
            for param in self.parameters():
                param.requires_grad = False
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
            self.prefix_encoder = PrefixEncoder(config)
            self.dropout = torch.nn.Dropout(0.1)

    def get_input_embeddings(self):
        return self.embedding.word_embeddings

    def get_prompt(self, batch_size, device, dtype=torch.half):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
        past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.num_layers * 2,
            self.multi_query_group_num,
            self.kv_channels
        )
        # seq_len, b, nh, hidden_size
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
        return past_key_values

    def forward(
            self,
            input_ids,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.BoolTensor] = None,
            full_attention_mask: Optional[torch.BoolTensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ):
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_length = input_ids.shape

        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        if self.pre_seq_len is not None:
            if past_key_values is None:
                past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
                                                  dtype=inputs_embeds.dtype)
            if attention_mask is not None:
                attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
                                            attention_mask], dim=-1)

        if full_attention_mask is None:
            if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)

        # Rotary positional embeddings
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        if position_ids is not None:
            rotary_pos_emb = rotary_pos_emb[position_ids]
        else:
            rotary_pos_emb = rotary_pos_emb[None, :seq_length]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

        # Run encoder.
        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
            inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
            kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
        )

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def quantize(self, weight_bit_width: int):
        # from .quantization import quantize
        quantize(self.encoder, weight_bit_width)
        return self


class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        super().__init__(config)

        self.max_sequence_length = config.max_length
        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
        self.config = config
        self.quantized = False

        if self.config.quantization_bit:
            self.quantize(self.config.quantization_bit, empty_init=True)

    def _update_model_kwargs_for_generation(
            self,
            outputs: ModelOutput,
            model_kwargs: Dict[str, Any],
            is_encoder_decoder: bool = False,
            standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        # update past_key_values
        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
            outputs, standardize_cache_format=standardize_cache_format
        )

        # update attention mask
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )

        # update position ids
        if "position_ids" in model_kwargs:
            position_ids = model_kwargs["position_ids"]
            new_position_id = position_ids[..., -1:].clone()
            new_position_id += 1
            model_kwargs["position_ids"] = torch.cat(
                [position_ids, new_position_id], dim=-1
            )

        model_kwargs["is_first_forward"] = False
        return model_kwargs

    def prepare_inputs_for_generation(
            self,
            input_ids: torch.LongTensor,
            past_key_values: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            is_first_forward: bool = True,
            **kwargs
    ) -> dict:
        # only last token for input_ids if past is not None
        if position_ids is None:
            position_ids = self.get_position_ids(input_ids, device=input_ids.device)
        if not is_first_forward:
            if past_key_values is not None:
                position_ids = position_ids[..., -1:]
                input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "return_last_logit": True,
            "use_cache": use_cache
        }

    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            return_last_logit: Optional[bool] = False,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        if return_last_logit:
            hidden_states = hidden_states[-1:]
        lm_logits = self.transformer.output_layer(hidden_states)
        lm_logits = lm_logits.transpose(0, 1).contiguous()

        loss = None
        if labels is not None:
            lm_logits = lm_logits.to(torch.float32)

            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            lm_logits = lm_logits.to(hidden_states.dtype)
            loss = loss.to(hidden_states.dtype)

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(
            past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Output shares the same memory storage as `past`.
        """
        return tuple(
            (
                layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
                layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
            )
            for layer_past in past
        )

    def process_response(self, output, history):
        content = ""
        history = deepcopy(history)
        for response in output.split("<|assistant|>"):
            metadata, content = response.split("\n", maxsplit=1)
            if not metadata.strip():
                content = content.strip()
                history.append({"role": "assistant", "metadata": metadata, "content": content})
                content = content.replace("[[训练时间]]", "2023年")
            else:
                history.append({"role": "assistant", "metadata": metadata, "content": content})
                if history[0]["role"] == "system" and "tools" in history[0]:
                    content = "\n".join(content.split("\n")[1:-1])

                    def tool_call(**kwargs):
                        return kwargs

                    parameters = eval(content)
                    content = {"name": metadata.strip(), "parameters": parameters}
                else:
                    content = {"name": metadata.strip(), "content": content}
        return content, history

    @torch.inference_mode()
    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
             max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
             **kwargs):
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        inputs = tokenizer.build_chat_input(query, history=history, role=role)
        inputs = inputs.to(self.device)
        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                        tokenizer.get_command("<|observation|>")]
        outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
        response = tokenizer.decode(outputs)
        history.append({"role": role, "content": query})
        response, history = self.process_response(response, history)
        return response, history

    @torch.inference_mode()
    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
                    past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
                    logits_processor=None, return_past_key_values=False, **kwargs):
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                        tokenizer.get_command("<|observation|>")]
        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        if past_key_values is None:
            inputs = tokenizer.build_chat_input(query, history=history, role=role)
        else:
            inputs = tokenizer.build_chat_input(query, role=role)
        inputs = inputs.to(self.device)
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[0]
            if self.transformer.pre_seq_len is not None:
                past_length -= self.transformer.pre_seq_len
            inputs.position_ids += past_length
            attention_mask = inputs.attention_mask
            attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
            inputs['attention_mask'] = attention_mask
        history.append({"role": role, "content": query})
        for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
                                            eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
                                            **gen_kwargs):
            if return_past_key_values:
                outputs, past_key_values = outputs
            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
            response = tokenizer.decode(outputs)
            if response and response[-1] != "�":
                response, new_history = self.process_response(response, history)
                if return_past_key_values:
                    yield response, new_history, past_key_values
                else:
                    yield response, new_history

    @torch.inference_mode()
    def stream_generate(
            self,
            input_ids,
            generation_config: Optional[GenerationConfig] = None,
            logits_processor: Optional[LogitsProcessorList] = None,
            stopping_criteria: Optional[StoppingCriteriaList] = None,
            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
            return_past_key_values=False,
            **kwargs,
    ):
        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]

        if generation_config is None:
            generation_config = self.generation_config
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)
        model_kwargs["use_cache"] = generation_config.use_cache
        bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id

        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None

        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
            if not has_default_max_length:
                logger.warn(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
                    UserWarning,
                )

        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )

        # 2. Set generation parameters if not already defined
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )

        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )
        logits_warper = self._get_logits_warper(generation_config)

        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        scores = None
        while True:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            if generation_config.do_sample:
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(probs, dim=-1)
            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )
            if return_past_key_values:
                yield input_ids, outputs.past_key_values
            else:
                yield input_ids
            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                break

    def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
        if bits == 0:
            return

        # from .quantization import quantize

        if self.quantized:
            logger.info("Already quantized.")
            return self

        self.quantized = True

        self.config.quantization_bit = bits

        self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,
                                            **kwargs)
        return self


class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)

        self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
        if config.classifier_dropout is not None:
            self.dropout = nn.Dropout(config.classifier_dropout)
        else:
            self.dropout = None
        self.config = config

        if self.config.quantization_bit:
            self.quantize(self.config.quantization_bit, empty_init=True)

    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            full_attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
            inputs_embeds: Optional[torch.LongTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            full_attention_mask=full_attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        pooled_hidden_states = hidden_states[-1]
        if self.dropout is not None:
            pooled_hidden_states = self.dropout(pooled_hidden_states)
        logits = self.classifier_head(pooled_hidden_states)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze().float(), labels.squeeze())
                else:
                    loss = loss_fct(logits.float(), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
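
Not part of the commit: a minimal, hypothetical sketch of how the `chat()` helper defined above is normally driven. It assumes a ChatGLM3-compatible tokenizer and checkpoint loaded with `trust_remote_code`; the model id below is only a placeholder, and within DiffSynth this class actually serves as the Kolors text encoder rather than as a chatbot.

# Hypothetical usage sketch for ChatGLMForConditionalGeneration.chat() (not part of this diff).
import torch
from transformers import AutoTokenizer, AutoModel

model_id = "THUDM/chatglm3-6b"  # placeholder: any checkpoint matching this modeling code
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.float16).cuda().eval()

# chat() builds the prompt via tokenizer.build_chat_input(), samples with generate(),
# then cleans the decoded text with process_response() and appends it to the history.
response, history = model.chat(tokenizer, "Describe a sunset over the sea.", history=[], role="user")
print(response)
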
diffsynth/models/lora.py
ADDED
@@ -0,0 +1,195 @@
import torch
from .sd_unet import SDUNet
from .sdxl_unet import SDXLUNet
from .sd_text_encoder import SDTextEncoder
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sd3_dit import SD3DiT
from .hunyuan_dit import HunyuanDiT


class LoRAFromCivitai:
    def __init__(self):
        self.supported_model_classes = []
        self.lora_prefix = []
        self.renamed_lora_prefix = {}
        self.special_keys = {}

    def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
        renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
        state_dict_ = {}
        for key in state_dict:
            if ".lora_up" not in key:
                continue
            if not key.startswith(lora_prefix):
                continue
            weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
            weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
            if len(weight_up.shape) == 4:
                weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
                weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                lora_weight = alpha * torch.mm(weight_up, weight_down)
            target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
            for special_key in self.special_keys:
                target_name = target_name.replace(special_key, self.special_keys[special_key])
            state_dict_[target_name] = lora_weight.cpu()
        return state_dict_

    def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
        state_dict_model = model.state_dict()
        state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
        if model_resource == "diffusers":
            state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
        elif model_resource == "civitai":
            state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
        if len(state_dict_lora) > 0:
            print(f"    {len(state_dict_lora)} tensors are updated.")
        for name in state_dict_lora:
            state_dict_model[name] += state_dict_lora[name].to(
                dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
        model.load_state_dict(state_dict_model)

    def match(self, model, state_dict_lora):
        for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
            if not isinstance(model, model_class):
                continue
            state_dict_model = model.state_dict()
            for model_resource in ["diffusers", "civitai"]:
                try:
                    state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
                    converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
                        else model.__class__.state_dict_converter().from_civitai
                    state_dict_lora_ = converter_fn(state_dict_lora_)
                    if len(state_dict_lora_) == 0:
                        continue
                    for name in state_dict_lora_:
                        if name not in state_dict_model:
                            break
                    else:
                        return lora_prefix, model_resource
                except:
                    pass
        return None


class SDLoRAFromCivitai(LoRAFromCivitai):
    def __init__(self):
        super().__init__()
        self.supported_model_classes = [SDUNet, SDTextEncoder]
        self.lora_prefix = ["lora_unet_", "lora_te_"]
        self.special_keys = {
            "down.blocks": "down_blocks",
            "up.blocks": "up_blocks",
            "mid.block": "mid_block",
            "proj.in": "proj_in",
            "proj.out": "proj_out",
            "transformer.blocks": "transformer_blocks",
            "to.q": "to_q",
            "to.k": "to_k",
            "to.v": "to_v",
            "to.out": "to_out",
            "text.model": "text_model",
            "self.attn.q.proj": "self_attn.q_proj",
            "self.attn.k.proj": "self_attn.k_proj",
            "self.attn.v.proj": "self_attn.v_proj",
            "self.attn.out.proj": "self_attn.out_proj",
            "input.blocks": "model.diffusion_model.input_blocks",
            "middle.block": "model.diffusion_model.middle_block",
            "output.blocks": "model.diffusion_model.output_blocks",
        }


class SDXLLoRAFromCivitai(LoRAFromCivitai):
    def __init__(self):
        super().__init__()
        self.supported_model_classes = [SDXLUNet, SDXLTextEncoder, SDXLTextEncoder2]
        self.lora_prefix = ["lora_unet_", "lora_te1_", "lora_te2_"]
        self.renamed_lora_prefix = {"lora_te2_": "2"}
        self.special_keys = {
            "down.blocks": "down_blocks",
            "up.blocks": "up_blocks",
            "mid.block": "mid_block",
            "proj.in": "proj_in",
            "proj.out": "proj_out",
            "transformer.blocks": "transformer_blocks",
            "to.q": "to_q",
            "to.k": "to_k",
            "to.v": "to_v",
            "to.out": "to_out",
            "text.model": "conditioner.embedders.0.transformer.text_model",
            "self.attn.q.proj": "self_attn.q_proj",
            "self.attn.k.proj": "self_attn.k_proj",
            "self.attn.v.proj": "self_attn.v_proj",
            "self.attn.out.proj": "self_attn.out_proj",
            "input.blocks": "model.diffusion_model.input_blocks",
            "middle.block": "model.diffusion_model.middle_block",
            "output.blocks": "model.diffusion_model.output_blocks",
            "2conditioner.embedders.0.transformer.text_model.encoder.layers": "text_model.encoder.layers"
        }


class GeneralLoRAFromPeft:
    def __init__(self):
        self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT]

    def convert_state_dict(self, state_dict, alpha=1.0, device="cuda", torch_dtype=torch.float16):
        state_dict_ = {}
        for key in state_dict:
            if ".lora_B." not in key:
                continue
            weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
            weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
            if len(weight_up.shape) == 4:
                weight_up = weight_up.squeeze(3).squeeze(2)
                weight_down = weight_down.squeeze(3).squeeze(2)
                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                lora_weight = alpha * torch.mm(weight_up, weight_down)
            keys = key.split(".")
            keys.pop(keys.index("lora_B") + 1)
            keys.pop(keys.index("lora_B"))
            target_name = ".".join(keys)
            state_dict_[target_name] = lora_weight.cpu()
        return state_dict_

    def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
        state_dict_model = model.state_dict()
        for name, param in state_dict_model.items():
            torch_dtype = param.dtype
            device = param.device
            break
        state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, device=device, torch_dtype=torch_dtype)
        if len(state_dict_lora) > 0:
            print(f"    {len(state_dict_lora)} tensors are updated.")
        for name in state_dict_lora:
            state_dict_model[name] += state_dict_lora[name].to(
                dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
        model.load_state_dict(state_dict_model)

    def match(self, model, state_dict_lora):
        for model_class in self.supported_model_classes:
            if not isinstance(model, model_class):
                continue
            state_dict_model = model.state_dict()
            try:
                state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0)
                if len(state_dict_lora_) == 0:
                    continue
                for name in state_dict_lora_:
                    if name not in state_dict_model:
                        break
                else:
                    return "", ""
            except:
                pass
        return None
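
For orientation, a small illustrative sketch of the weight-merging rule these classes implement: a PEFT-style `lora_A`/`lora_B` pair is collapsed into a delta that is added onto the matching base weight. All tensor names and sizes below are made up for the example.

# Illustrative only: the merge rule used by GeneralLoRAFromPeft.convert_state_dict() and load().
import torch

rank, d_in, d_out, alpha = 8, 320, 320, 1.0
base_weight = torch.randn(d_out, d_in)     # e.g. a "...to_q.weight" tensor in the base model (hypothetical)
lora_A = torch.randn(rank, d_in)           # "...lora_A...weight" in the LoRA file (the "down" projection)
lora_B = torch.randn(d_out, rank)          # "...lora_B...weight" (the "up" projection)

delta = alpha * torch.mm(lora_B, lora_A)   # same shape as base_weight
merged = base_weight + delta               # what load() writes back through model.load_state_dict()
assert merged.shape == base_weight.shape
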
diffsynth/models/model_manager.py
ADDED
@@ -0,0 +1,543 @@
1 |
+
import os, torch, hashlib, json, importlib
|
2 |
+
from safetensors import safe_open
|
3 |
+
from torch import Tensor
|
4 |
+
from typing_extensions import Literal, TypeAlias
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
from .downloader import download_models, Preset_model_id, Preset_model_website
|
8 |
+
|
9 |
+
from .sd_text_encoder import SDTextEncoder
|
10 |
+
from .sd_unet import SDUNet
|
11 |
+
from .sd_vae_encoder import SDVAEEncoder
|
12 |
+
from .sd_vae_decoder import SDVAEDecoder
|
13 |
+
from .lora import SDLoRAFromCivitai, SDXLLoRAFromCivitai, GeneralLoRAFromPeft
|
14 |
+
|
15 |
+
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
16 |
+
from .sdxl_unet import SDXLUNet
|
17 |
+
from .sdxl_vae_decoder import SDXLVAEDecoder
|
18 |
+
from .sdxl_vae_encoder import SDXLVAEEncoder
|
19 |
+
|
20 |
+
from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
|
21 |
+
from .sd3_dit import SD3DiT
|
22 |
+
from .sd3_vae_decoder import SD3VAEDecoder
|
23 |
+
from .sd3_vae_encoder import SD3VAEEncoder
|
24 |
+
|
25 |
+
from .sd_controlnet import SDControlNet
|
26 |
+
from .sdxl_controlnet import SDXLControlNetUnion
|
27 |
+
|
28 |
+
from .sd_motion import SDMotionModel
|
29 |
+
from .sdxl_motion import SDXLMotionModel
|
30 |
+
|
31 |
+
from .svd_image_encoder import SVDImageEncoder
|
32 |
+
from .svd_unet import SVDUNet
|
33 |
+
from .svd_vae_decoder import SVDVAEDecoder
|
34 |
+
from .svd_vae_encoder import SVDVAEEncoder
|
35 |
+
|
36 |
+
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
|
37 |
+
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
38 |
+
|
39 |
+
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
40 |
+
from .hunyuan_dit import HunyuanDiT
|
41 |
+
|
42 |
+
from .flux_dit import FluxDiT
|
43 |
+
from .flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
|
44 |
+
from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
45 |
+
|
46 |
+
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
def load_state_dict(file_path, torch_dtype=None):
|
51 |
+
if file_path.endswith(".safetensors"):
|
52 |
+
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
|
53 |
+
else:
|
54 |
+
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
|
55 |
+
|
56 |
+
|
57 |
+
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
58 |
+
state_dict = {}
|
59 |
+
with safe_open(file_path, framework="pt", device="cpu") as f:
|
60 |
+
for k in f.keys():
|
61 |
+
state_dict[k] = f.get_tensor(k)
|
62 |
+
if torch_dtype is not None:
|
63 |
+
state_dict[k] = state_dict[k].to(torch_dtype)
|
64 |
+
return state_dict
|
65 |
+
|
66 |
+
|
67 |
+
def load_state_dict_from_bin(file_path, torch_dtype=None):
|
68 |
+
state_dict = torch.load(file_path, map_location="cpu")
|
69 |
+
if torch_dtype is not None:
|
70 |
+
for i in state_dict:
|
71 |
+
if isinstance(state_dict[i], torch.Tensor):
|
72 |
+
state_dict[i] = state_dict[i].to(torch_dtype)
|
73 |
+
return state_dict
|
74 |
+
|
75 |
+
|
76 |
+
def search_for_embeddings(state_dict):
|
77 |
+
embeddings = []
|
78 |
+
for k in state_dict:
|
79 |
+
if isinstance(state_dict[k], torch.Tensor):
|
80 |
+
embeddings.append(state_dict[k])
|
81 |
+
elif isinstance(state_dict[k], dict):
|
82 |
+
embeddings += search_for_embeddings(state_dict[k])
|
83 |
+
return embeddings
|
84 |
+
|
85 |
+
|
86 |
+
def search_parameter(param, state_dict):
|
87 |
+
for name, param_ in state_dict.items():
|
88 |
+
if param.numel() == param_.numel():
|
89 |
+
if param.shape == param_.shape:
|
90 |
+
if torch.dist(param, param_) < 1e-3:
|
91 |
+
return name
|
92 |
+
else:
|
93 |
+
if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
|
94 |
+
return name
|
95 |
+
return None
|
96 |
+
|
97 |
+
|
98 |
+
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
|
99 |
+
matched_keys = set()
|
100 |
+
with torch.no_grad():
|
101 |
+
for name in source_state_dict:
|
102 |
+
rename = search_parameter(source_state_dict[name], target_state_dict)
|
103 |
+
if rename is not None:
|
104 |
+
print(f'"{name}": "{rename}",')
|
105 |
+
matched_keys.add(rename)
|
106 |
+
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
|
107 |
+
length = source_state_dict[name].shape[0] // 3
|
108 |
+
rename = []
|
109 |
+
for i in range(3):
|
110 |
+
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
|
111 |
+
if None not in rename:
|
112 |
+
print(f'"{name}": {rename},')
|
113 |
+
for rename_ in rename:
|
114 |
+
matched_keys.add(rename_)
|
115 |
+
for name in target_state_dict:
|
116 |
+
if name not in matched_keys:
|
117 |
+
print("Cannot find", name, target_state_dict[name].shape)
|
118 |
+
|
119 |
+
|
120 |
+
def search_for_files(folder, extensions):
|
121 |
+
files = []
|
122 |
+
if os.path.isdir(folder):
|
123 |
+
for file in sorted(os.listdir(folder)):
|
124 |
+
files += search_for_files(os.path.join(folder, file), extensions)
|
125 |
+
elif os.path.isfile(folder):
|
126 |
+
for extension in extensions:
|
127 |
+
if folder.endswith(extension):
|
128 |
+
files.append(folder)
|
129 |
+
break
|
130 |
+
return files
|
131 |
+
|
132 |
+
|
133 |
+
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
|
134 |
+
keys = []
|
135 |
+
for key, value in state_dict.items():
|
136 |
+
if isinstance(key, str):
|
137 |
+
if isinstance(value, Tensor):
|
138 |
+
if with_shape:
|
139 |
+
shape = "_".join(map(str, list(value.shape)))
|
140 |
+
keys.append(key + ":" + shape)
|
141 |
+
keys.append(key)
|
142 |
+
elif isinstance(value, dict):
|
143 |
+
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
|
144 |
+
keys.sort()
|
145 |
+
keys_str = ",".join(keys)
|
146 |
+
return keys_str
|
147 |
+
|
148 |
+
|
149 |
+
def split_state_dict_with_prefix(state_dict):
|
150 |
+
keys = sorted([key for key in state_dict if isinstance(key, str)])
|
151 |
+
prefix_dict = {}
|
152 |
+
for key in keys:
|
153 |
+
prefix = key if "." not in key else key.split(".")[0]
|
154 |
+
if prefix not in prefix_dict:
|
155 |
+
prefix_dict[prefix] = []
|
156 |
+
prefix_dict[prefix].append(key)
|
157 |
+
state_dicts = []
|
158 |
+
for prefix, keys in prefix_dict.items():
|
159 |
+
sub_state_dict = {key: state_dict[key] for key in keys}
|
160 |
+
state_dicts.append(sub_state_dict)
|
161 |
+
return state_dicts
|
162 |
+
|
163 |
+
|
164 |
+
def hash_state_dict_keys(state_dict, with_shape=True):
|
165 |
+
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
166 |
+
keys_str = keys_str.encode(encoding="UTF-8")
|
167 |
+
return hashlib.md5(keys_str).hexdigest()
|
168 |
+
|
169 |
+
|
170 |
+
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
|
171 |
+
loaded_model_names, loaded_models = [], []
|
172 |
+
for model_name, model_class in zip(model_names, model_classes):
|
173 |
+
print(f" model_name: {model_name} model_class: {model_class.__name__}")
|
174 |
+
state_dict_converter = model_class.state_dict_converter()
|
175 |
+
if model_resource == "civitai":
|
176 |
+
state_dict_results = state_dict_converter.from_civitai(state_dict)
|
177 |
+
elif model_resource == "diffusers":
|
178 |
+
state_dict_results = state_dict_converter.from_diffusers(state_dict)
|
179 |
+
if isinstance(state_dict_results, tuple):
|
180 |
+
model_state_dict, extra_kwargs = state_dict_results
|
181 |
+
print(f" This model is initialized with extra kwargs: {extra_kwargs}")
|
182 |
+
else:
|
183 |
+
model_state_dict, extra_kwargs = state_dict_results, {}
|
184 |
+
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
|
185 |
+
model = model_class(**extra_kwargs).to(dtype=torch_dtype, device=device)
|
186 |
+
model.load_state_dict(model_state_dict)
|
187 |
+
loaded_model_names.append(model_name)
|
188 |
+
loaded_models.append(model)
|
189 |
+
return loaded_model_names, loaded_models
|
190 |
+
|
191 |
+
|
192 |
+
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
|
193 |
+
loaded_model_names, loaded_models = [], []
|
194 |
+
for model_name, model_class in zip(model_names, model_classes):
|
195 |
+
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
|
196 |
+
if torch_dtype == torch.float16 and hasattr(model, "half"):
|
197 |
+
model = model.half()
|
198 |
+
model = model.to(device=device)
|
199 |
+
loaded_model_names.append(model_name)
|
200 |
+
loaded_models.append(model)
|
201 |
+
return loaded_model_names, loaded_models
|
202 |
+
|
203 |
+
|
204 |
+
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
|
205 |
+
print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
|
206 |
+
base_state_dict = base_model.state_dict()
|
207 |
+
base_model.to("cpu")
|
208 |
+
del base_model
|
209 |
+
model = model_class(**extra_kwargs)
|
210 |
+
model.load_state_dict(base_state_dict, strict=False)
|
211 |
+
model.load_state_dict(state_dict, strict=False)
|
212 |
+
model.to(dtype=torch_dtype, device=device)
|
213 |
+
return model
|
214 |
+
|
215 |
+
|
216 |
+
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
|
217 |
+
loaded_model_names, loaded_models = [], []
|
218 |
+
for model_name, model_class in zip(model_names, model_classes):
|
219 |
+
while True:
|
220 |
+
for model_id in range(len(model_manager.model)):
|
221 |
+
base_model_name = model_manager.model_name[model_id]
|
222 |
+
if base_model_name == model_name:
|
223 |
+
base_model_path = model_manager.model_path[model_id]
|
224 |
+
base_model = model_manager.model[model_id]
|
225 |
+
print(f" Adding patch model to {base_model_name} ({base_model_path})")
|
226 |
+
patched_model = load_single_patch_model_from_single_file(
|
227 |
+
state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
|
228 |
+
loaded_model_names.append(base_model_name)
|
229 |
+
loaded_models.append(patched_model)
|
230 |
+
model_manager.model.pop(model_id)
|
231 |
+
model_manager.model_path.pop(model_id)
|
232 |
+
model_manager.model_name.pop(model_id)
|
233 |
+
break
|
234 |
+
else:
|
235 |
+
break
|
236 |
+
return loaded_model_names, loaded_models
|
237 |
+
|
238 |
+
|
239 |
+
|
240 |
+
class ModelDetectorTemplate:
|
241 |
+
def __init__(self):
|
242 |
+
pass
|
243 |
+
|
244 |
+
def match(self, file_path="", state_dict={}):
|
245 |
+
return False
|
246 |
+
|
247 |
+
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
248 |
+
return [], []
|
249 |
+
|
250 |
+
|
251 |
+
|
252 |
+
class ModelDetectorFromSingleFile:
|
253 |
+
def __init__(self, model_loader_configs=[]):
|
254 |
+
self.keys_hash_with_shape_dict = {}
|
255 |
+
self.keys_hash_dict = {}
|
256 |
+
for metadata in model_loader_configs:
|
257 |
+
self.add_model_metadata(*metadata)
|
258 |
+
|
259 |
+
|
260 |
+
def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
|
261 |
+
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
|
262 |
+
if keys_hash is not None:
|
263 |
+
self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
|
264 |
+
|
265 |
+
|
266 |
+
def match(self, file_path="", state_dict={}):
|
267 |
+
if os.path.isdir(file_path):
|
268 |
+
return False
|
269 |
+
if len(state_dict) == 0:
|
270 |
+
state_dict = load_state_dict(file_path)
|
271 |
+
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
272 |
+
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
273 |
+
return True
|
274 |
+
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
|
275 |
+
if keys_hash in self.keys_hash_dict:
|
276 |
+
return True
|
277 |
+
return False
|
278 |
+
|
279 |
+
|
280 |
+
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
281 |
+
if len(state_dict) == 0:
|
282 |
+
state_dict = load_state_dict(file_path)
|
283 |
+
|
284 |
+
# Load models with strict matching
|
285 |
+
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
286 |
+
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
287 |
+
model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
|
288 |
+
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
|
289 |
+
return loaded_model_names, loaded_models
|
290 |
+
|
291 |
+
# Load models without strict matching
|
292 |
+
# (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
|
293 |
+
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
|
294 |
+
if keys_hash in self.keys_hash_dict:
|
295 |
+
model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
|
296 |
+
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
|
297 |
+
return loaded_model_names, loaded_models
|
298 |
+
|
299 |
+
return loaded_model_names, loaded_models
|
300 |
+
|
301 |
+
|
302 |
+
|
303 |
+
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
|
304 |
+
def __init__(self, model_loader_configs=[]):
|
305 |
+
super().__init__(model_loader_configs)
|
306 |
+
|
307 |
+
|
308 |
+
def match(self, file_path="", state_dict={}):
|
309 |
+
if os.path.isdir(file_path):
|
310 |
+
return False
|
311 |
+
if len(state_dict) == 0:
|
312 |
+
state_dict = load_state_dict(file_path)
|
313 |
+
splited_state_dict = split_state_dict_with_prefix(state_dict)
|
314 |
+
for sub_state_dict in splited_state_dict:
|
315 |
+
if super().match(file_path, sub_state_dict):
|
316 |
+
return True
|
317 |
+
return False
|
318 |
+
|
319 |
+
|
320 |
+
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
321 |
+
# Split the state_dict and load from each component
|
322 |
+
splited_state_dict = split_state_dict_with_prefix(state_dict)
|
323 |
+
valid_state_dict = {}
|
324 |
+
for sub_state_dict in splited_state_dict:
|
325 |
+
if super().match(file_path, sub_state_dict):
|
326 |
+
valid_state_dict.update(sub_state_dict)
|
327 |
+
if super().match(file_path, valid_state_dict):
|
328 |
+
loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
|
329 |
+
else:
|
330 |
+
loaded_model_names, loaded_models = [], []
|
331 |
+
for sub_state_dict in splited_state_dict:
|
332 |
+
if super().match(file_path, sub_state_dict):
|
333 |
+
loaded_model_names_, loaded_models_ = super().load(file_path, valid_state_dict, device, torch_dtype)
|
334 |
+
loaded_model_names += loaded_model_names_
|
335 |
+
loaded_models += loaded_models_
|
336 |
+
return loaded_model_names, loaded_models
|
337 |
+
|
338 |
+
|
339 |
+
|
340 |
+
class ModelDetectorFromHuggingfaceFolder:
|
341 |
+
def __init__(self, model_loader_configs=[]):
|
342 |
+
self.architecture_dict = {}
|
343 |
+
for metadata in model_loader_configs:
|
344 |
+
self.add_model_metadata(*metadata)
|
345 |
+
|
346 |
+
|
347 |
+
def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
|
348 |
+
self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
|
349 |
+
|
350 |
+
|
351 |
+
def match(self, file_path="", state_dict={}):
|
352 |
+
if os.path.isfile(file_path):
|
353 |
+
return False
|
354 |
+
file_list = os.listdir(file_path)
|
355 |
+
if "config.json" not in file_list:
|
356 |
+
return False
|
357 |
+
with open(os.path.join(file_path, "config.json"), "r") as f:
|
358 |
+
config = json.load(f)
|
359 |
+
if "architectures" not in config:
|
360 |
+
return False
|
361 |
+
return True
|
362 |
+
|
363 |
+
|
364 |
+
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
365 |
+
with open(os.path.join(file_path, "config.json"), "r") as f:
|
366 |
+
config = json.load(f)
|
367 |
+
loaded_model_names, loaded_models = [], []
|
368 |
+
for architecture in config["architectures"]:
|
369 |
+
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
|
370 |
+
if redirected_architecture is not None:
|
371 |
+
architecture = redirected_architecture
|
372 |
+
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
|
373 |
+
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
|
374 |
+
loaded_model_names += loaded_model_names_
|
375 |
+
loaded_models += loaded_models_
|
376 |
+
return loaded_model_names, loaded_models
|
377 |
+
|
378 |
+
|
379 |
+
|
380 |
+
class ModelDetectorFromPatchedSingleFile:
|
381 |
+
def __init__(self, model_loader_configs=[]):
|
382 |
+
self.keys_hash_with_shape_dict = {}
|
383 |
+
for metadata in model_loader_configs:
|
384 |
+
self.add_model_metadata(*metadata)
|
385 |
+
|
386 |
+
|
387 |
+
def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
|
388 |
+
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
|
389 |
+
|
390 |
+
|
391 |
+
def match(self, file_path="", state_dict={}):
|
392 |
+
if os.path.isdir(file_path):
|
393 |
+
return False
|
394 |
+
if len(state_dict) == 0:
|
395 |
+
state_dict = load_state_dict(file_path)
|
396 |
+
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
397 |
+
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
398 |
+
return True
|
399 |
+
return False
|
400 |
+
|
401 |
+
|
402 |
+
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
|
403 |
+
if len(state_dict) == 0:
|
404 |
+
state_dict = load_state_dict(file_path)
|
405 |
+
|
406 |
+
# Load models with strict matching
|
407 |
+
loaded_model_names, loaded_models = [], []
|
408 |
+
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
409 |
+
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
410 |
+
model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
|
411 |
+
loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
|
412 |
+
                    state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
                loaded_model_names += loaded_model_names_
                loaded_models += loaded_models_
        return loaded_model_names, loaded_models


class ModelManager:
    def __init__(
        self,
        torch_dtype=torch.float16,
        device="cuda",
        model_id_list: List[Preset_model_id] = [],
        downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
        file_path_list: List[str] = [],
    ):
        self.torch_dtype = torch_dtype
        self.device = device
        self.model = []
        self.model_path = []
        self.model_name = []
        downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
        self.model_detector = [
            ModelDetectorFromSingleFile(model_loader_configs),
            ModelDetectorFromSplitedSingleFile(model_loader_configs),
            ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
            ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
        ]
        self.load_models(downloaded_files + file_path_list)

    def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
        print(f"Loading models from file: {file_path}")
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
        for model_name, model in zip(model_names, models):
            self.model.append(model)
            self.model_path.append(file_path)
            self.model_name.append(model_name)
        print(f"    The following models are loaded: {model_names}.")

    def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
        print(f"Loading models from folder: {file_path}")
        model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
        for model_name, model in zip(model_names, models):
            self.model.append(model)
            self.model_path.append(file_path)
            self.model_name.append(model_name)
        print(f"    The following models are loaded: {model_names}.")

    def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
        print(f"Loading patch models from file: {file_path}")
        model_names, models = load_patch_model_from_single_file(
            state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
        for model_name, model in zip(model_names, models):
            self.model.append(model)
            self.model_path.append(file_path)
            self.model_name.append(model_name)
        print(f"    The following patched models are loaded: {model_names}.")

    def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
        print(f"Loading LoRA models from file: {file_path}")
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
            for lora in [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft()]:
                match_results = lora.match(model, state_dict)
                if match_results is not None:
                    print(f"    Adding LoRA to {model_name} ({model_path}).")
                    lora_prefix, model_resource = match_results
                    lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
                    break

    def load_model(self, file_path, model_names=None):
        print(f"Loading models from: {file_path}")
        if os.path.isfile(file_path):
            state_dict = load_state_dict(file_path)
        else:
            state_dict = None
        for model_detector in self.model_detector:
            if model_detector.match(file_path, state_dict):
                model_names, models = model_detector.load(
                    file_path, state_dict,
                    device=self.device, torch_dtype=self.torch_dtype,
                    allowed_model_names=model_names, model_manager=self
                )
                for model_name, model in zip(model_names, models):
                    self.model.append(model)
                    self.model_path.append(file_path)
                    self.model_name.append(model_name)
                print(f"    The following models are loaded: {model_names}.")
                break
        else:
            print(f"    We cannot detect the model type. No models are loaded.")

    def load_models(self, file_path_list, model_names=None):
        for file_path in file_path_list:
            self.load_model(file_path, model_names)

    def fetch_model(self, model_name, file_path=None, require_model_path=False):
        fetched_models = []
        fetched_model_paths = []
        for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
            if file_path is not None and file_path != model_path:
                continue
            if model_name == model_name_:
                fetched_models.append(model)
                fetched_model_paths.append(model_path)
        if len(fetched_models) == 0:
            print(f"No {model_name} models available.")
            return None
        if len(fetched_models) == 1:
            print(f"Using {model_name} from {fetched_model_paths[0]}.")
        else:
            print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
        if require_model_path:
            return fetched_models[0], fetched_model_paths[0]
        else:
            return fetched_models[0]

    def to(self, device):
        for model in self.model:
            model.to(device)
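Editor's note: a minimal usage sketch of the ModelManager API above, not part of the commit. The checkpoint path is a placeholder and the model name passed to fetch_model is an assumption; it must match a name produced by the model detectors and the entries in configs/model_config.py.

import torch
from diffsynth import ModelManager

manager = ModelManager(
    torch_dtype=torch.float16,
    device="cpu",                                        # load on CPU first, move to GPU later
    file_path_list=["models/your_checkpoint.safetensors"],  # placeholder path
)
unet = manager.fetch_model("sd_unet")                    # assumed model name; returns None if absent
manager.to("cuda")                                       # move every loaded model to the GPU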
diffsynth/models/sd3_dit.py
ADDED
@@ -0,0 +1,798 @@
import torch
from einops import rearrange
from .svd_unet import TemporalTimesteps
from .tiler import TileWorker


class PatchEmbed(torch.nn.Module):
    def __init__(self, patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192):
        super().__init__()
        self.pos_embed_max_size = pos_embed_max_size
        self.patch_size = patch_size

        self.proj = torch.nn.Conv2d(in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size)
        self.pos_embed = torch.nn.Parameter(torch.zeros(1, self.pos_embed_max_size, self.pos_embed_max_size, 1536))

    def cropped_pos_embed(self, height, width):
        height = height // self.patch_size
        width = width // self.patch_size
        top = (self.pos_embed_max_size - height) // 2
        left = (self.pos_embed_max_size - width) // 2
        spatial_pos_embed = self.pos_embed[:, top : top + height, left : left + width, :].flatten(1, 2)
        return spatial_pos_embed

    def forward(self, latent):
        height, width = latent.shape[-2:]
        latent = self.proj(latent)
        latent = latent.flatten(2).transpose(1, 2)
        pos_embed = self.cropped_pos_embed(height, width)
        return latent + pos_embed
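# ---------------------------------------------------------------------------
# Editor's note (sketch, not part of the committed file): a quick shape check
# for PatchEmbed with the defaults above. A 128x128, 16-channel latent becomes
# (128/2) * (128/2) = 4096 tokens of width 1536, and the positional embedding
# is centre-cropped out of the 192x192 learned grid before being added.
def _patch_embed_shape_demo():
    latent = torch.randn(1, 16, 128, 128)   # dummy latent
    tokens = PatchEmbed()(latent)
    print(tokens.shape)                     # torch.Size([1, 4096, 1536])
# ---------------------------------------------------------------------------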


class TimestepEmbeddings(torch.nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = torch.nn.Sequential(
            torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
        )

    def forward(self, timestep, dtype):
        time_emb = self.time_proj(timestep).to(dtype)
        time_emb = self.timestep_embedder(time_emb)
        return time_emb


class AdaLayerNorm(torch.nn.Module):
    def __init__(self, dim, single=False):
        super().__init__()
        self.single = single
        self.linear = torch.nn.Linear(dim, dim * (2 if single else 6))
        self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(torch.nn.functional.silu(emb))
        if self.single:
            scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
            x = self.norm(x) * (1 + scale) + shift
            return x
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
            x = self.norm(x) * (1 + scale_msa) + shift_msa
            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
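# ---------------------------------------------------------------------------
# Editor's note (sketch, not part of the committed file): AdaLayerNorm projects
# the conditioning vector into per-channel shift/scale/gate terms, i.e.
# x_mod = LayerNorm(x) * (1 + scale) + shift, with the gates consumed later in
# the transformer block. A small shape check with the SD3 width of 1536:
def _ada_layer_norm_demo():
    x, emb = torch.randn(1, 4096, 1536), torch.randn(1, 1536)
    x_mod, gate_msa, shift_mlp, scale_mlp, gate_mlp = AdaLayerNorm(1536)(x, emb)
    print(x_mod.shape, gate_msa.shape)  # torch.Size([1, 4096, 1536]) torch.Size([1, 1, 1536])
# ---------------------------------------------------------------------------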


class JointAttention(torch.nn.Module):
    def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.only_out_a = only_out_a

        self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
        self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)

        self.a_to_out = torch.nn.Linear(dim_a, dim_a)
        if not only_out_a:
            self.b_to_out = torch.nn.Linear(dim_b, dim_b)

    def forward(self, hidden_states_a, hidden_states_b):
        batch_size = hidden_states_a.shape[0]

        qkv = torch.concat([self.a_to_qkv(hidden_states_a), self.b_to_qkv(hidden_states_b)], dim=1)
        qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q, k, v = qkv.chunk(3, dim=1)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)
        hidden_states_a, hidden_states_b = hidden_states[:, :hidden_states_a.shape[1]], hidden_states[:, hidden_states_a.shape[1]:]
        hidden_states_a = self.a_to_out(hidden_states_a)
        if self.only_out_a:
            return hidden_states_a
        else:
            hidden_states_b = self.b_to_out(hidden_states_b)
            return hidden_states_a, hidden_states_b
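# ---------------------------------------------------------------------------
# Editor's note (sketch, not part of the committed file): both streams (image
# tokens "a" and text tokens "b") are projected to QKV, concatenated along the
# sequence axis, attended to jointly in a single scaled_dot_product_attention
# call, then split back and projected out. With dim 1536 and 24 heads of 64
# (the sizes used below); the token counts here are arbitrary dummies:
def _joint_attention_demo():
    attn = JointAttention(dim_a=1536, dim_b=1536, num_heads=24, head_dim=64)
    image_tokens, text_tokens = torch.randn(1, 4096, 1536), torch.randn(1, 333, 1536)
    out_a, out_b = attn(image_tokens, text_tokens)
    print(out_a.shape, out_b.shape)  # torch.Size([1, 4096, 1536]) torch.Size([1, 333, 1536])
# ---------------------------------------------------------------------------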


class JointTransformerBlock(torch.nn.Module):
    def __init__(self, dim, num_attention_heads):
        super().__init__()
        self.norm1_a = AdaLayerNorm(dim)
        self.norm1_b = AdaLayerNorm(dim)

        self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)

        self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_a = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim*4, dim)
        )

        self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_b = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim*4, dim)
        )

    def forward(self, hidden_states_a, hidden_states_b, temb):
        norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
        norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)

        # Attention
        attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)

        # Part A
        hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
        norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
        hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)

        # Part B
        hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
        norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
        hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)

        return hidden_states_a, hidden_states_b


class JointTransformerFinalBlock(torch.nn.Module):
    def __init__(self, dim, num_attention_heads):
        super().__init__()
        self.norm1_a = AdaLayerNorm(dim)
        self.norm1_b = AdaLayerNorm(dim, single=True)

        self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, only_out_a=True)

        self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_a = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim*4, dim)
        )

    def forward(self, hidden_states_a, hidden_states_b, temb):
        norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
        norm_hidden_states_b = self.norm1_b(hidden_states_b, emb=temb)

        # Attention
        attn_output_a = self.attn(norm_hidden_states_a, norm_hidden_states_b)

        # Part A
        hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
        norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
        hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)

        return hidden_states_a, hidden_states_b


class SD3DiT(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192)
        self.time_embedder = TimestepEmbeddings(256, 1536)
        self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(2048, 1536), torch.nn.SiLU(), torch.nn.Linear(1536, 1536))
        self.context_embedder = torch.nn.Linear(4096, 1536)
        self.blocks = torch.nn.ModuleList([JointTransformerBlock(1536, 24) for _ in range(23)] + [JointTransformerFinalBlock(1536, 24)])
        self.norm_out = AdaLayerNorm(1536, single=True)
        self.proj_out = torch.nn.Linear(1536, 64)

    def tiled_forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size=128, tile_stride=64):
        # Due to the global positional embedding, we cannot implement layer-wise tiled forward.
        hidden_states = TileWorker().tiled_forward(
            lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb),
            hidden_states,
            tile_size,
            tile_stride,
            tile_device=hidden_states.device,
            tile_dtype=hidden_states.dtype
        )
        return hidden_states

    def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tiled=False, tile_size=128, tile_stride=64, use_gradient_checkpointing=False):
        if tiled:
            return self.tiled_forward(hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size, tile_stride)
        conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
        prompt_emb = self.context_embedder(prompt_emb)

        height, width = hidden_states.shape[-2:]
        hidden_states = self.pos_embedder(hidden_states)

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        for block in self.blocks:
            if self.training and use_gradient_checkpointing:
                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states, prompt_emb, conditioning,
                    use_reentrant=False,
                )
            else:
                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning)

        hidden_states = self.norm_out(hidden_states, conditioning)
        hidden_states = self.proj_out(hidden_states)
        hidden_states = rearrange(hidden_states, "B (H W) (P Q C) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
        return hidden_states

    @staticmethod
    def state_dict_converter():
        return SD3DiTStateDictConverter()
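# ---------------------------------------------------------------------------
# Editor's note (sketch, not part of the committed file): the final rearrange
# in SD3DiT.forward turns each 64-dim output token back into a 2x2 patch of a
# 16-channel latent. A standalone check of that un-patchify step:
def _unpatchify_demo():
    tokens = torch.randn(1, 64 * 64, 2 * 2 * 16)   # 4096 tokens for a 128x128 latent
    latent = rearrange(tokens, "B (H W) (P Q C) -> B C (H P) (W Q)", P=2, Q=2, H=64, W=64)
    print(latent.shape)                            # torch.Size([1, 16, 128, 128])
# ---------------------------------------------------------------------------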


class SD3DiTStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        rename_dict = {
            "context_embedder": "context_embedder",
            "pos_embed.pos_embed": "pos_embedder.pos_embed",
            "pos_embed.proj": "pos_embedder.proj",
            "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
            "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
            "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
            "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
            "norm_out.linear": "norm_out.linear",
            "proj_out": "proj_out",

            "norm1.linear": "norm1_a.linear",
            "norm1_context.linear": "norm1_b.linear",
            "attn.to_q": "attn.a_to_q",
            "attn.to_k": "attn.a_to_k",
            "attn.to_v": "attn.a_to_v",
            "attn.to_out.0": "attn.a_to_out",
            "attn.add_q_proj": "attn.b_to_q",
            "attn.add_k_proj": "attn.b_to_k",
            "attn.add_v_proj": "attn.b_to_v",
            "attn.to_add_out": "attn.b_to_out",
            "ff.net.0.proj": "ff_a.0",
            "ff.net.2": "ff_a.2",
            "ff_context.net.0.proj": "ff_b.0",
            "ff_context.net.2": "ff_b.2",
        }
        state_dict_ = {}
        for name, param in state_dict.items():
            if name in rename_dict:
                if name == "pos_embed.pos_embed":
                    param = param.reshape((1, 192, 192, 1536))
                state_dict_[rename_dict[name]] = param
            elif name.endswith(".weight") or name.endswith(".bias"):
                suffix = ".weight" if name.endswith(".weight") else ".bias"
                prefix = name[:-len(suffix)]
                if prefix in rename_dict:
                    state_dict_[rename_dict[prefix] + suffix] = param
                elif prefix.startswith("transformer_blocks."):
                    names = prefix.split(".")
                    names[0] = "blocks"
                    middle = ".".join(names[2:])
                    if middle in rename_dict:
                        name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
                        state_dict_[name_] = param
        return state_dict_
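# ---------------------------------------------------------------------------
# Editor's note (sketch, not part of the committed file): how from_diffusers
# maps one Diffusers-style key into this module's naming, using a dummy tensor.
def _from_diffusers_rename_demo():
    dummy = {"transformer_blocks.0.attn.to_q.weight": torch.zeros(1536, 1536)}
    converted = SD3DiTStateDictConverter().from_diffusers(dummy)
    print(list(converted))  # ['blocks.0.attn.a_to_q.weight']
# ---------------------------------------------------------------------------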
287 |
+
|
288 |
+
def from_civitai(self, state_dict):
|
289 |
+
rename_dict = {
|
290 |
+
"model.diffusion_model.context_embedder.bias": "context_embedder.bias",
|
291 |
+
"model.diffusion_model.context_embedder.weight": "context_embedder.weight",
|
292 |
+
"model.diffusion_model.final_layer.linear.bias": "proj_out.bias",
|
293 |
+
"model.diffusion_model.final_layer.linear.weight": "proj_out.weight",
|
294 |
+
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias": "blocks.0.norm1_b.linear.bias",
|
295 |
+
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.weight": "blocks.0.norm1_b.linear.weight",
|
296 |
+
"model.diffusion_model.joint_blocks.0.context_block.attn.proj.bias": "blocks.0.attn.b_to_out.bias",
|
297 |
+
"model.diffusion_model.joint_blocks.0.context_block.attn.proj.weight": "blocks.0.attn.b_to_out.weight",
|
298 |
+
"model.diffusion_model.joint_blocks.0.context_block.attn.qkv.bias": ['blocks.0.attn.b_to_q.bias', 'blocks.0.attn.b_to_k.bias', 'blocks.0.attn.b_to_v.bias'],
|
299 |
+
"model.diffusion_model.joint_blocks.0.context_block.attn.qkv.weight": ['blocks.0.attn.b_to_q.weight', 'blocks.0.attn.b_to_k.weight', 'blocks.0.attn.b_to_v.weight'],
|
300 |
+
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.bias": "blocks.0.ff_b.0.bias",
|
301 |
+
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.weight": "blocks.0.ff_b.0.weight",
|
302 |
+
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.bias": "blocks.0.ff_b.2.bias",
|
303 |
+
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.weight": "blocks.0.ff_b.2.weight",
|
304 |
+
"model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.bias": "blocks.0.norm1_a.linear.bias",
|
305 |
+
"model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.weight": "blocks.0.norm1_a.linear.weight",
|
306 |
+
"model.diffusion_model.joint_blocks.0.x_block.attn.proj.bias": "blocks.0.attn.a_to_out.bias",
|
307 |
+
"model.diffusion_model.joint_blocks.0.x_block.attn.proj.weight": "blocks.0.attn.a_to_out.weight",
|
308 |
+
"model.diffusion_model.joint_blocks.0.x_block.attn.qkv.bias": ['blocks.0.attn.a_to_q.bias', 'blocks.0.attn.a_to_k.bias', 'blocks.0.attn.a_to_v.bias'],
|
309 |
+
"model.diffusion_model.joint_blocks.0.x_block.attn.qkv.weight": ['blocks.0.attn.a_to_q.weight', 'blocks.0.attn.a_to_k.weight', 'blocks.0.attn.a_to_v.weight'],
|
310 |
+
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.bias": "blocks.0.ff_a.0.bias",
|
311 |
+
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.weight": "blocks.0.ff_a.0.weight",
|
312 |
+
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.bias": "blocks.0.ff_a.2.bias",
|
313 |
+
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.weight": "blocks.0.ff_a.2.weight",
|
314 |
+
"model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.bias": "blocks.1.norm1_b.linear.bias",
|
315 |
+
"model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.weight": "blocks.1.norm1_b.linear.weight",
|
316 |
+
"model.diffusion_model.joint_blocks.1.context_block.attn.proj.bias": "blocks.1.attn.b_to_out.bias",
|
317 |
+
"model.diffusion_model.joint_blocks.1.context_block.attn.proj.weight": "blocks.1.attn.b_to_out.weight",
|
318 |
+
"model.diffusion_model.joint_blocks.1.context_block.attn.qkv.bias": ['blocks.1.attn.b_to_q.bias', 'blocks.1.attn.b_to_k.bias', 'blocks.1.attn.b_to_v.bias'],
|
319 |
+
"model.diffusion_model.joint_blocks.1.context_block.attn.qkv.weight": ['blocks.1.attn.b_to_q.weight', 'blocks.1.attn.b_to_k.weight', 'blocks.1.attn.b_to_v.weight'],
|
320 |
+
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.bias": "blocks.1.ff_b.0.bias",
|
321 |
+
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.weight": "blocks.1.ff_b.0.weight",
|
322 |
+
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.bias": "blocks.1.ff_b.2.bias",
|
323 |
+
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.weight": "blocks.1.ff_b.2.weight",
|
324 |
+
"model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.bias": "blocks.1.norm1_a.linear.bias",
|
325 |
+
"model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.weight": "blocks.1.norm1_a.linear.weight",
|
326 |
+
"model.diffusion_model.joint_blocks.1.x_block.attn.proj.bias": "blocks.1.attn.a_to_out.bias",
|
327 |
+
"model.diffusion_model.joint_blocks.1.x_block.attn.proj.weight": "blocks.1.attn.a_to_out.weight",
|
328 |
+
"model.diffusion_model.joint_blocks.1.x_block.attn.qkv.bias": ['blocks.1.attn.a_to_q.bias', 'blocks.1.attn.a_to_k.bias', 'blocks.1.attn.a_to_v.bias'],
|
329 |
+
"model.diffusion_model.joint_blocks.1.x_block.attn.qkv.weight": ['blocks.1.attn.a_to_q.weight', 'blocks.1.attn.a_to_k.weight', 'blocks.1.attn.a_to_v.weight'],
|
330 |
+
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.bias": "blocks.1.ff_a.0.bias",
|
331 |
+
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.weight": "blocks.1.ff_a.0.weight",
|
332 |
+
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.bias": "blocks.1.ff_a.2.bias",
|
333 |
+
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.weight": "blocks.1.ff_a.2.weight",
|
334 |
+
"model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.bias": "blocks.10.norm1_b.linear.bias",
|
335 |
+
"model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.weight": "blocks.10.norm1_b.linear.weight",
|
336 |
+
"model.diffusion_model.joint_blocks.10.context_block.attn.proj.bias": "blocks.10.attn.b_to_out.bias",
|
337 |
+
"model.diffusion_model.joint_blocks.10.context_block.attn.proj.weight": "blocks.10.attn.b_to_out.weight",
|
338 |
+
"model.diffusion_model.joint_blocks.10.context_block.attn.qkv.bias": ['blocks.10.attn.b_to_q.bias', 'blocks.10.attn.b_to_k.bias', 'blocks.10.attn.b_to_v.bias'],
|
339 |
+
"model.diffusion_model.joint_blocks.10.context_block.attn.qkv.weight": ['blocks.10.attn.b_to_q.weight', 'blocks.10.attn.b_to_k.weight', 'blocks.10.attn.b_to_v.weight'],
|
340 |
+
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.bias": "blocks.10.ff_b.0.bias",
|
341 |
+
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.weight": "blocks.10.ff_b.0.weight",
|
342 |
+
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.bias": "blocks.10.ff_b.2.bias",
|
343 |
+
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.weight": "blocks.10.ff_b.2.weight",
|
344 |
+
"model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.bias": "blocks.10.norm1_a.linear.bias",
|
345 |
+
"model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.weight": "blocks.10.norm1_a.linear.weight",
|
346 |
+
"model.diffusion_model.joint_blocks.10.x_block.attn.proj.bias": "blocks.10.attn.a_to_out.bias",
|
347 |
+
"model.diffusion_model.joint_blocks.10.x_block.attn.proj.weight": "blocks.10.attn.a_to_out.weight",
|
348 |
+
"model.diffusion_model.joint_blocks.10.x_block.attn.qkv.bias": ['blocks.10.attn.a_to_q.bias', 'blocks.10.attn.a_to_k.bias', 'blocks.10.attn.a_to_v.bias'],
|
349 |
+
"model.diffusion_model.joint_blocks.10.x_block.attn.qkv.weight": ['blocks.10.attn.a_to_q.weight', 'blocks.10.attn.a_to_k.weight', 'blocks.10.attn.a_to_v.weight'],
|
350 |
+
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.bias": "blocks.10.ff_a.0.bias",
|
351 |
+
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.weight": "blocks.10.ff_a.0.weight",
|
352 |
+
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.bias": "blocks.10.ff_a.2.bias",
|
353 |
+
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.weight": "blocks.10.ff_a.2.weight",
|
354 |
+
"model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.bias": "blocks.11.norm1_b.linear.bias",
|
355 |
+
"model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.weight": "blocks.11.norm1_b.linear.weight",
|
356 |
+
"model.diffusion_model.joint_blocks.11.context_block.attn.proj.bias": "blocks.11.attn.b_to_out.bias",
|
357 |
+
"model.diffusion_model.joint_blocks.11.context_block.attn.proj.weight": "blocks.11.attn.b_to_out.weight",
|
358 |
+
"model.diffusion_model.joint_blocks.11.context_block.attn.qkv.bias": ['blocks.11.attn.b_to_q.bias', 'blocks.11.attn.b_to_k.bias', 'blocks.11.attn.b_to_v.bias'],
|
359 |
+
"model.diffusion_model.joint_blocks.11.context_block.attn.qkv.weight": ['blocks.11.attn.b_to_q.weight', 'blocks.11.attn.b_to_k.weight', 'blocks.11.attn.b_to_v.weight'],
|
360 |
+
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.bias": "blocks.11.ff_b.0.bias",
|
361 |
+
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.weight": "blocks.11.ff_b.0.weight",
|
362 |
+
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.bias": "blocks.11.ff_b.2.bias",
|
363 |
+
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.weight": "blocks.11.ff_b.2.weight",
|
364 |
+
"model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.bias": "blocks.11.norm1_a.linear.bias",
|
365 |
+
"model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.weight": "blocks.11.norm1_a.linear.weight",
|
366 |
+
"model.diffusion_model.joint_blocks.11.x_block.attn.proj.bias": "blocks.11.attn.a_to_out.bias",
|
367 |
+
"model.diffusion_model.joint_blocks.11.x_block.attn.proj.weight": "blocks.11.attn.a_to_out.weight",
|
368 |
+
"model.diffusion_model.joint_blocks.11.x_block.attn.qkv.bias": ['blocks.11.attn.a_to_q.bias', 'blocks.11.attn.a_to_k.bias', 'blocks.11.attn.a_to_v.bias'],
|
369 |
+
"model.diffusion_model.joint_blocks.11.x_block.attn.qkv.weight": ['blocks.11.attn.a_to_q.weight', 'blocks.11.attn.a_to_k.weight', 'blocks.11.attn.a_to_v.weight'],
|
370 |
+
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.bias": "blocks.11.ff_a.0.bias",
|
371 |
+
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.weight": "blocks.11.ff_a.0.weight",
|
372 |
+
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.bias": "blocks.11.ff_a.2.bias",
|
373 |
+
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.weight": "blocks.11.ff_a.2.weight",
|
374 |
+
"model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.bias": "blocks.12.norm1_b.linear.bias",
|
375 |
+
"model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.weight": "blocks.12.norm1_b.linear.weight",
|
376 |
+
"model.diffusion_model.joint_blocks.12.context_block.attn.proj.bias": "blocks.12.attn.b_to_out.bias",
|
377 |
+
"model.diffusion_model.joint_blocks.12.context_block.attn.proj.weight": "blocks.12.attn.b_to_out.weight",
|
378 |
+
"model.diffusion_model.joint_blocks.12.context_block.attn.qkv.bias": ['blocks.12.attn.b_to_q.bias', 'blocks.12.attn.b_to_k.bias', 'blocks.12.attn.b_to_v.bias'],
|
379 |
+
"model.diffusion_model.joint_blocks.12.context_block.attn.qkv.weight": ['blocks.12.attn.b_to_q.weight', 'blocks.12.attn.b_to_k.weight', 'blocks.12.attn.b_to_v.weight'],
|
380 |
+
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.bias": "blocks.12.ff_b.0.bias",
|
381 |
+
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.weight": "blocks.12.ff_b.0.weight",
|
382 |
+
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.bias": "blocks.12.ff_b.2.bias",
|
383 |
+
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.weight": "blocks.12.ff_b.2.weight",
|
384 |
+
"model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.bias": "blocks.12.norm1_a.linear.bias",
|
385 |
+
"model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.weight": "blocks.12.norm1_a.linear.weight",
|
386 |
+
"model.diffusion_model.joint_blocks.12.x_block.attn.proj.bias": "blocks.12.attn.a_to_out.bias",
|
387 |
+
"model.diffusion_model.joint_blocks.12.x_block.attn.proj.weight": "blocks.12.attn.a_to_out.weight",
|
388 |
+
"model.diffusion_model.joint_blocks.12.x_block.attn.qkv.bias": ['blocks.12.attn.a_to_q.bias', 'blocks.12.attn.a_to_k.bias', 'blocks.12.attn.a_to_v.bias'],
|
389 |
+
"model.diffusion_model.joint_blocks.12.x_block.attn.qkv.weight": ['blocks.12.attn.a_to_q.weight', 'blocks.12.attn.a_to_k.weight', 'blocks.12.attn.a_to_v.weight'],
|
390 |
+
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.bias": "blocks.12.ff_a.0.bias",
|
391 |
+
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.weight": "blocks.12.ff_a.0.weight",
|
392 |
+
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.bias": "blocks.12.ff_a.2.bias",
|
393 |
+
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.weight": "blocks.12.ff_a.2.weight",
|
394 |
+
"model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.bias": "blocks.13.norm1_b.linear.bias",
|
395 |
+
"model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.weight": "blocks.13.norm1_b.linear.weight",
|
396 |
+
"model.diffusion_model.joint_blocks.13.context_block.attn.proj.bias": "blocks.13.attn.b_to_out.bias",
|
397 |
+
"model.diffusion_model.joint_blocks.13.context_block.attn.proj.weight": "blocks.13.attn.b_to_out.weight",
|
398 |
+
"model.diffusion_model.joint_blocks.13.context_block.attn.qkv.bias": ['blocks.13.attn.b_to_q.bias', 'blocks.13.attn.b_to_k.bias', 'blocks.13.attn.b_to_v.bias'],
|
399 |
+
"model.diffusion_model.joint_blocks.13.context_block.attn.qkv.weight": ['blocks.13.attn.b_to_q.weight', 'blocks.13.attn.b_to_k.weight', 'blocks.13.attn.b_to_v.weight'],
|
400 |
+
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.bias": "blocks.13.ff_b.0.bias",
|
401 |
+
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.weight": "blocks.13.ff_b.0.weight",
|
402 |
+
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.bias": "blocks.13.ff_b.2.bias",
|
403 |
+
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.weight": "blocks.13.ff_b.2.weight",
|
404 |
+
"model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.bias": "blocks.13.norm1_a.linear.bias",
|
405 |
+
"model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.weight": "blocks.13.norm1_a.linear.weight",
|
406 |
+
"model.diffusion_model.joint_blocks.13.x_block.attn.proj.bias": "blocks.13.attn.a_to_out.bias",
|
407 |
+
"model.diffusion_model.joint_blocks.13.x_block.attn.proj.weight": "blocks.13.attn.a_to_out.weight",
|
408 |
+
"model.diffusion_model.joint_blocks.13.x_block.attn.qkv.bias": ['blocks.13.attn.a_to_q.bias', 'blocks.13.attn.a_to_k.bias', 'blocks.13.attn.a_to_v.bias'],
|
409 |
+
"model.diffusion_model.joint_blocks.13.x_block.attn.qkv.weight": ['blocks.13.attn.a_to_q.weight', 'blocks.13.attn.a_to_k.weight', 'blocks.13.attn.a_to_v.weight'],
|
410 |
+
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.bias": "blocks.13.ff_a.0.bias",
|
411 |
+
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.weight": "blocks.13.ff_a.0.weight",
|
412 |
+
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.bias": "blocks.13.ff_a.2.bias",
|
413 |
+
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.weight": "blocks.13.ff_a.2.weight",
|
414 |
+
"model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.bias": "blocks.14.norm1_b.linear.bias",
|
415 |
+
"model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.weight": "blocks.14.norm1_b.linear.weight",
|
416 |
+
"model.diffusion_model.joint_blocks.14.context_block.attn.proj.bias": "blocks.14.attn.b_to_out.bias",
|
417 |
+
"model.diffusion_model.joint_blocks.14.context_block.attn.proj.weight": "blocks.14.attn.b_to_out.weight",
|
418 |
+
"model.diffusion_model.joint_blocks.14.context_block.attn.qkv.bias": ['blocks.14.attn.b_to_q.bias', 'blocks.14.attn.b_to_k.bias', 'blocks.14.attn.b_to_v.bias'],
|
419 |
+
"model.diffusion_model.joint_blocks.14.context_block.attn.qkv.weight": ['blocks.14.attn.b_to_q.weight', 'blocks.14.attn.b_to_k.weight', 'blocks.14.attn.b_to_v.weight'],
|
420 |
+
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.bias": "blocks.14.ff_b.0.bias",
|
421 |
+
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.weight": "blocks.14.ff_b.0.weight",
|
422 |
+
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.bias": "blocks.14.ff_b.2.bias",
|
423 |
+
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.weight": "blocks.14.ff_b.2.weight",
|
424 |
+
"model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.bias": "blocks.14.norm1_a.linear.bias",
|
425 |
+
"model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.weight": "blocks.14.norm1_a.linear.weight",
|
426 |
+
"model.diffusion_model.joint_blocks.14.x_block.attn.proj.bias": "blocks.14.attn.a_to_out.bias",
|
427 |
+
"model.diffusion_model.joint_blocks.14.x_block.attn.proj.weight": "blocks.14.attn.a_to_out.weight",
|
428 |
+
"model.diffusion_model.joint_blocks.14.x_block.attn.qkv.bias": ['blocks.14.attn.a_to_q.bias', 'blocks.14.attn.a_to_k.bias', 'blocks.14.attn.a_to_v.bias'],
|
429 |
+
"model.diffusion_model.joint_blocks.14.x_block.attn.qkv.weight": ['blocks.14.attn.a_to_q.weight', 'blocks.14.attn.a_to_k.weight', 'blocks.14.attn.a_to_v.weight'],
|
430 |
+
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.bias": "blocks.14.ff_a.0.bias",
|
431 |
+
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.weight": "blocks.14.ff_a.0.weight",
|
432 |
+
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.bias": "blocks.14.ff_a.2.bias",
|
433 |
+
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.weight": "blocks.14.ff_a.2.weight",
|
434 |
+
"model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.bias": "blocks.15.norm1_b.linear.bias",
|
435 |
+
"model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.weight": "blocks.15.norm1_b.linear.weight",
|
436 |
+
"model.diffusion_model.joint_blocks.15.context_block.attn.proj.bias": "blocks.15.attn.b_to_out.bias",
|
437 |
+
"model.diffusion_model.joint_blocks.15.context_block.attn.proj.weight": "blocks.15.attn.b_to_out.weight",
|
438 |
+
"model.diffusion_model.joint_blocks.15.context_block.attn.qkv.bias": ['blocks.15.attn.b_to_q.bias', 'blocks.15.attn.b_to_k.bias', 'blocks.15.attn.b_to_v.bias'],
|
439 |
+
"model.diffusion_model.joint_blocks.15.context_block.attn.qkv.weight": ['blocks.15.attn.b_to_q.weight', 'blocks.15.attn.b_to_k.weight', 'blocks.15.attn.b_to_v.weight'],
|
440 |
+
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.bias": "blocks.15.ff_b.0.bias",
|
441 |
+
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.weight": "blocks.15.ff_b.0.weight",
|
442 |
+
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.bias": "blocks.15.ff_b.2.bias",
|
443 |
+
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.weight": "blocks.15.ff_b.2.weight",
|
444 |
+
"model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.bias": "blocks.15.norm1_a.linear.bias",
|
445 |
+
"model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.weight": "blocks.15.norm1_a.linear.weight",
|
446 |
+
"model.diffusion_model.joint_blocks.15.x_block.attn.proj.bias": "blocks.15.attn.a_to_out.bias",
|
447 |
+
"model.diffusion_model.joint_blocks.15.x_block.attn.proj.weight": "blocks.15.attn.a_to_out.weight",
|
448 |
+
"model.diffusion_model.joint_blocks.15.x_block.attn.qkv.bias": ['blocks.15.attn.a_to_q.bias', 'blocks.15.attn.a_to_k.bias', 'blocks.15.attn.a_to_v.bias'],
|
449 |
+
"model.diffusion_model.joint_blocks.15.x_block.attn.qkv.weight": ['blocks.15.attn.a_to_q.weight', 'blocks.15.attn.a_to_k.weight', 'blocks.15.attn.a_to_v.weight'],
|
450 |
+
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.bias": "blocks.15.ff_a.0.bias",
|
451 |
+
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.weight": "blocks.15.ff_a.0.weight",
|
452 |
+
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.bias": "blocks.15.ff_a.2.bias",
|
453 |
+
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.weight": "blocks.15.ff_a.2.weight",
|
454 |
+
"model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.bias": "blocks.16.norm1_b.linear.bias",
|
455 |
+
"model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.weight": "blocks.16.norm1_b.linear.weight",
|
456 |
+
"model.diffusion_model.joint_blocks.16.context_block.attn.proj.bias": "blocks.16.attn.b_to_out.bias",
|
457 |
+
"model.diffusion_model.joint_blocks.16.context_block.attn.proj.weight": "blocks.16.attn.b_to_out.weight",
|
458 |
+
"model.diffusion_model.joint_blocks.16.context_block.attn.qkv.bias": ['blocks.16.attn.b_to_q.bias', 'blocks.16.attn.b_to_k.bias', 'blocks.16.attn.b_to_v.bias'],
|
459 |
+
"model.diffusion_model.joint_blocks.16.context_block.attn.qkv.weight": ['blocks.16.attn.b_to_q.weight', 'blocks.16.attn.b_to_k.weight', 'blocks.16.attn.b_to_v.weight'],
|
460 |
+
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.bias": "blocks.16.ff_b.0.bias",
|
461 |
+
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.weight": "blocks.16.ff_b.0.weight",
|
462 |
+
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.bias": "blocks.16.ff_b.2.bias",
|
463 |
+
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.weight": "blocks.16.ff_b.2.weight",
|
464 |
+
"model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.bias": "blocks.16.norm1_a.linear.bias",
|
465 |
+
"model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.weight": "blocks.16.norm1_a.linear.weight",
|
466 |
+
"model.diffusion_model.joint_blocks.16.x_block.attn.proj.bias": "blocks.16.attn.a_to_out.bias",
|
467 |
+
"model.diffusion_model.joint_blocks.16.x_block.attn.proj.weight": "blocks.16.attn.a_to_out.weight",
|
468 |
+
"model.diffusion_model.joint_blocks.16.x_block.attn.qkv.bias": ['blocks.16.attn.a_to_q.bias', 'blocks.16.attn.a_to_k.bias', 'blocks.16.attn.a_to_v.bias'],
|
469 |
+
"model.diffusion_model.joint_blocks.16.x_block.attn.qkv.weight": ['blocks.16.attn.a_to_q.weight', 'blocks.16.attn.a_to_k.weight', 'blocks.16.attn.a_to_v.weight'],
|
470 |
+
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.bias": "blocks.16.ff_a.0.bias",
|
471 |
+
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.weight": "blocks.16.ff_a.0.weight",
|
472 |
+
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.bias": "blocks.16.ff_a.2.bias",
|
473 |
+
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.weight": "blocks.16.ff_a.2.weight",
|
474 |
+
"model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.bias": "blocks.17.norm1_b.linear.bias",
|
475 |
+
"model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.weight": "blocks.17.norm1_b.linear.weight",
|
476 |
+
"model.diffusion_model.joint_blocks.17.context_block.attn.proj.bias": "blocks.17.attn.b_to_out.bias",
|
477 |
+
"model.diffusion_model.joint_blocks.17.context_block.attn.proj.weight": "blocks.17.attn.b_to_out.weight",
|
478 |
+
"model.diffusion_model.joint_blocks.17.context_block.attn.qkv.bias": ['blocks.17.attn.b_to_q.bias', 'blocks.17.attn.b_to_k.bias', 'blocks.17.attn.b_to_v.bias'],
|
479 |
+
"model.diffusion_model.joint_blocks.17.context_block.attn.qkv.weight": ['blocks.17.attn.b_to_q.weight', 'blocks.17.attn.b_to_k.weight', 'blocks.17.attn.b_to_v.weight'],
|
480 |
+
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.bias": "blocks.17.ff_b.0.bias",
|
481 |
+
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.weight": "blocks.17.ff_b.0.weight",
|
482 |
+
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.bias": "blocks.17.ff_b.2.bias",
|
483 |
+
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.weight": "blocks.17.ff_b.2.weight",
|
484 |
+
"model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.bias": "blocks.17.norm1_a.linear.bias",
|
485 |
+
"model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.weight": "blocks.17.norm1_a.linear.weight",
|
486 |
+
"model.diffusion_model.joint_blocks.17.x_block.attn.proj.bias": "blocks.17.attn.a_to_out.bias",
|
487 |
+
"model.diffusion_model.joint_blocks.17.x_block.attn.proj.weight": "blocks.17.attn.a_to_out.weight",
|
488 |
+
"model.diffusion_model.joint_blocks.17.x_block.attn.qkv.bias": ['blocks.17.attn.a_to_q.bias', 'blocks.17.attn.a_to_k.bias', 'blocks.17.attn.a_to_v.bias'],
|
489 |
+
"model.diffusion_model.joint_blocks.17.x_block.attn.qkv.weight": ['blocks.17.attn.a_to_q.weight', 'blocks.17.attn.a_to_k.weight', 'blocks.17.attn.a_to_v.weight'],
|
490 |
+
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.bias": "blocks.17.ff_a.0.bias",
|
491 |
+
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.weight": "blocks.17.ff_a.0.weight",
|
492 |
+
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.bias": "blocks.17.ff_a.2.bias",
|
493 |
+
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.weight": "blocks.17.ff_a.2.weight",
|
494 |
+
"model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.bias": "blocks.18.norm1_b.linear.bias",
|
495 |
+
"model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.weight": "blocks.18.norm1_b.linear.weight",
|
496 |
+
"model.diffusion_model.joint_blocks.18.context_block.attn.proj.bias": "blocks.18.attn.b_to_out.bias",
|
497 |
+
"model.diffusion_model.joint_blocks.18.context_block.attn.proj.weight": "blocks.18.attn.b_to_out.weight",
|
498 |
+
"model.diffusion_model.joint_blocks.18.context_block.attn.qkv.bias": ['blocks.18.attn.b_to_q.bias', 'blocks.18.attn.b_to_k.bias', 'blocks.18.attn.b_to_v.bias'],
|
499 |
+
"model.diffusion_model.joint_blocks.18.context_block.attn.qkv.weight": ['blocks.18.attn.b_to_q.weight', 'blocks.18.attn.b_to_k.weight', 'blocks.18.attn.b_to_v.weight'],
|
500 |
+
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.bias": "blocks.18.ff_b.0.bias",
|
501 |
+
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.weight": "blocks.18.ff_b.0.weight",
|
502 |
+
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.bias": "blocks.18.ff_b.2.bias",
|
503 |
+
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.weight": "blocks.18.ff_b.2.weight",
|
504 |
+
"model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.bias": "blocks.18.norm1_a.linear.bias",
|
505 |
+
"model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.weight": "blocks.18.norm1_a.linear.weight",
|
506 |
+
"model.diffusion_model.joint_blocks.18.x_block.attn.proj.bias": "blocks.18.attn.a_to_out.bias",
|
507 |
+
"model.diffusion_model.joint_blocks.18.x_block.attn.proj.weight": "blocks.18.attn.a_to_out.weight",
|
508 |
+
"model.diffusion_model.joint_blocks.18.x_block.attn.qkv.bias": ['blocks.18.attn.a_to_q.bias', 'blocks.18.attn.a_to_k.bias', 'blocks.18.attn.a_to_v.bias'],
|
509 |
+
"model.diffusion_model.joint_blocks.18.x_block.attn.qkv.weight": ['blocks.18.attn.a_to_q.weight', 'blocks.18.attn.a_to_k.weight', 'blocks.18.attn.a_to_v.weight'],
|
510 |
+
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.bias": "blocks.18.ff_a.0.bias",
|
511 |
+
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.weight": "blocks.18.ff_a.0.weight",
|
512 |
+
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.bias": "blocks.18.ff_a.2.bias",
|
513 |
+
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.weight": "blocks.18.ff_a.2.weight",
|
514 |
+
"model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.bias": "blocks.19.norm1_b.linear.bias",
|
515 |
+
"model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.weight": "blocks.19.norm1_b.linear.weight",
|
516 |
+
"model.diffusion_model.joint_blocks.19.context_block.attn.proj.bias": "blocks.19.attn.b_to_out.bias",
|
517 |
+
"model.diffusion_model.joint_blocks.19.context_block.attn.proj.weight": "blocks.19.attn.b_to_out.weight",
|
518 |
+
"model.diffusion_model.joint_blocks.19.context_block.attn.qkv.bias": ['blocks.19.attn.b_to_q.bias', 'blocks.19.attn.b_to_k.bias', 'blocks.19.attn.b_to_v.bias'],
|
519 |
+
"model.diffusion_model.joint_blocks.19.context_block.attn.qkv.weight": ['blocks.19.attn.b_to_q.weight', 'blocks.19.attn.b_to_k.weight', 'blocks.19.attn.b_to_v.weight'],
|
520 |
+
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.bias": "blocks.19.ff_b.0.bias",
|
521 |
+
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.weight": "blocks.19.ff_b.0.weight",
|
522 |
+
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.bias": "blocks.19.ff_b.2.bias",
|
523 |
+
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.weight": "blocks.19.ff_b.2.weight",
|
524 |
+
"model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.bias": "blocks.19.norm1_a.linear.bias",
|
525 |
+
"model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.weight": "blocks.19.norm1_a.linear.weight",
|
526 |
+
"model.diffusion_model.joint_blocks.19.x_block.attn.proj.bias": "blocks.19.attn.a_to_out.bias",
|
527 |
+
"model.diffusion_model.joint_blocks.19.x_block.attn.proj.weight": "blocks.19.attn.a_to_out.weight",
|
528 |
+
"model.diffusion_model.joint_blocks.19.x_block.attn.qkv.bias": ['blocks.19.attn.a_to_q.bias', 'blocks.19.attn.a_to_k.bias', 'blocks.19.attn.a_to_v.bias'],
|
529 |
+
"model.diffusion_model.joint_blocks.19.x_block.attn.qkv.weight": ['blocks.19.attn.a_to_q.weight', 'blocks.19.attn.a_to_k.weight', 'blocks.19.attn.a_to_v.weight'],
|
530 |
+
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.bias": "blocks.19.ff_a.0.bias",
|
531 |
+
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.weight": "blocks.19.ff_a.0.weight",
|
532 |
+
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.bias": "blocks.19.ff_a.2.bias",
|
533 |
+
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.weight": "blocks.19.ff_a.2.weight",
|
534 |
+
"model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.bias": "blocks.2.norm1_b.linear.bias",
|
535 |
+
"model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.weight": "blocks.2.norm1_b.linear.weight",
|
536 |
+
"model.diffusion_model.joint_blocks.2.context_block.attn.proj.bias": "blocks.2.attn.b_to_out.bias",
|
537 |
+
"model.diffusion_model.joint_blocks.2.context_block.attn.proj.weight": "blocks.2.attn.b_to_out.weight",
|
538 |
+
"model.diffusion_model.joint_blocks.2.context_block.attn.qkv.bias": ['blocks.2.attn.b_to_q.bias', 'blocks.2.attn.b_to_k.bias', 'blocks.2.attn.b_to_v.bias'],
|
539 |
+
"model.diffusion_model.joint_blocks.2.context_block.attn.qkv.weight": ['blocks.2.attn.b_to_q.weight', 'blocks.2.attn.b_to_k.weight', 'blocks.2.attn.b_to_v.weight'],
|
540 |
+
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.bias": "blocks.2.ff_b.0.bias",
|
541 |
+
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.weight": "blocks.2.ff_b.0.weight",
|
542 |
+
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.bias": "blocks.2.ff_b.2.bias",
|
543 |
+
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.weight": "blocks.2.ff_b.2.weight",
|
544 |
+
"model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.bias": "blocks.2.norm1_a.linear.bias",
|
545 |
+
"model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.weight": "blocks.2.norm1_a.linear.weight",
|
546 |
+
"model.diffusion_model.joint_blocks.2.x_block.attn.proj.bias": "blocks.2.attn.a_to_out.bias",
|
547 |
+
"model.diffusion_model.joint_blocks.2.x_block.attn.proj.weight": "blocks.2.attn.a_to_out.weight",
|
548 |
+
"model.diffusion_model.joint_blocks.2.x_block.attn.qkv.bias": ['blocks.2.attn.a_to_q.bias', 'blocks.2.attn.a_to_k.bias', 'blocks.2.attn.a_to_v.bias'],
|
549 |
+
"model.diffusion_model.joint_blocks.2.x_block.attn.qkv.weight": ['blocks.2.attn.a_to_q.weight', 'blocks.2.attn.a_to_k.weight', 'blocks.2.attn.a_to_v.weight'],
|
550 |
+
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.bias": "blocks.2.ff_a.0.bias",
|
551 |
+
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.weight": "blocks.2.ff_a.0.weight",
|
552 |
+
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.bias": "blocks.2.ff_a.2.bias",
|
553 |
+
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.weight": "blocks.2.ff_a.2.weight",
|
554 |
+
"model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.bias": "blocks.20.norm1_b.linear.bias",
|
555 |
+
"model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.weight": "blocks.20.norm1_b.linear.weight",
|
556 |
+
"model.diffusion_model.joint_blocks.20.context_block.attn.proj.bias": "blocks.20.attn.b_to_out.bias",
|
557 |
+
"model.diffusion_model.joint_blocks.20.context_block.attn.proj.weight": "blocks.20.attn.b_to_out.weight",
|
558 |
+
"model.diffusion_model.joint_blocks.20.context_block.attn.qkv.bias": ['blocks.20.attn.b_to_q.bias', 'blocks.20.attn.b_to_k.bias', 'blocks.20.attn.b_to_v.bias'],
|
559 |
+
"model.diffusion_model.joint_blocks.20.context_block.attn.qkv.weight": ['blocks.20.attn.b_to_q.weight', 'blocks.20.attn.b_to_k.weight', 'blocks.20.attn.b_to_v.weight'],
|
560 |
+
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.bias": "blocks.20.ff_b.0.bias",
|
561 |
+
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.weight": "blocks.20.ff_b.0.weight",
|
562 |
+
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.bias": "blocks.20.ff_b.2.bias",
|
563 |
+
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.weight": "blocks.20.ff_b.2.weight",
|
564 |
+
"model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.bias": "blocks.20.norm1_a.linear.bias",
|
565 |
+
"model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.weight": "blocks.20.norm1_a.linear.weight",
|
566 |
+
"model.diffusion_model.joint_blocks.20.x_block.attn.proj.bias": "blocks.20.attn.a_to_out.bias",
|
567 |
+
"model.diffusion_model.joint_blocks.20.x_block.attn.proj.weight": "blocks.20.attn.a_to_out.weight",
|
568 |
+
"model.diffusion_model.joint_blocks.20.x_block.attn.qkv.bias": ['blocks.20.attn.a_to_q.bias', 'blocks.20.attn.a_to_k.bias', 'blocks.20.attn.a_to_v.bias'],
|
569 |
+
"model.diffusion_model.joint_blocks.20.x_block.attn.qkv.weight": ['blocks.20.attn.a_to_q.weight', 'blocks.20.attn.a_to_k.weight', 'blocks.20.attn.a_to_v.weight'],
|
570 |
+
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.bias": "blocks.20.ff_a.0.bias",
|
571 |
+
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.weight": "blocks.20.ff_a.0.weight",
|
572 |
+
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.bias": "blocks.20.ff_a.2.bias",
|
573 |
+
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.weight": "blocks.20.ff_a.2.weight",
|
574 |
+
"model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.bias": "blocks.21.norm1_b.linear.bias",
|
575 |
+
"model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.weight": "blocks.21.norm1_b.linear.weight",
|
576 |
+
"model.diffusion_model.joint_blocks.21.context_block.attn.proj.bias": "blocks.21.attn.b_to_out.bias",
|
577 |
+
"model.diffusion_model.joint_blocks.21.context_block.attn.proj.weight": "blocks.21.attn.b_to_out.weight",
|
578 |
+
"model.diffusion_model.joint_blocks.21.context_block.attn.qkv.bias": ['blocks.21.attn.b_to_q.bias', 'blocks.21.attn.b_to_k.bias', 'blocks.21.attn.b_to_v.bias'],
|
579 |
+
"model.diffusion_model.joint_blocks.21.context_block.attn.qkv.weight": ['blocks.21.attn.b_to_q.weight', 'blocks.21.attn.b_to_k.weight', 'blocks.21.attn.b_to_v.weight'],
|
580 |
+
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.bias": "blocks.21.ff_b.0.bias",
|
581 |
+
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.weight": "blocks.21.ff_b.0.weight",
|
582 |
+
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.bias": "blocks.21.ff_b.2.bias",
|
583 |
+
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.weight": "blocks.21.ff_b.2.weight",
|
584 |
+
"model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.bias": "blocks.21.norm1_a.linear.bias",
|
585 |
+
"model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.weight": "blocks.21.norm1_a.linear.weight",
|
586 |
+
"model.diffusion_model.joint_blocks.21.x_block.attn.proj.bias": "blocks.21.attn.a_to_out.bias",
|
587 |
+
"model.diffusion_model.joint_blocks.21.x_block.attn.proj.weight": "blocks.21.attn.a_to_out.weight",
|
588 |
+
"model.diffusion_model.joint_blocks.21.x_block.attn.qkv.bias": ['blocks.21.attn.a_to_q.bias', 'blocks.21.attn.a_to_k.bias', 'blocks.21.attn.a_to_v.bias'],
|
589 |
+
"model.diffusion_model.joint_blocks.21.x_block.attn.qkv.weight": ['blocks.21.attn.a_to_q.weight', 'blocks.21.attn.a_to_k.weight', 'blocks.21.attn.a_to_v.weight'],
|
590 |
+
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.bias": "blocks.21.ff_a.0.bias",
|
591 |
+
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.weight": "blocks.21.ff_a.0.weight",
|
592 |
+
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.bias": "blocks.21.ff_a.2.bias",
|
593 |
+
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.weight": "blocks.21.ff_a.2.weight",
|
594 |
+
"model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.bias": "blocks.22.norm1_b.linear.bias",
|
595 |
+
"model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.weight": "blocks.22.norm1_b.linear.weight",
|
596 |
+
"model.diffusion_model.joint_blocks.22.context_block.attn.proj.bias": "blocks.22.attn.b_to_out.bias",
|
597 |
+
"model.diffusion_model.joint_blocks.22.context_block.attn.proj.weight": "blocks.22.attn.b_to_out.weight",
|
598 |
+
"model.diffusion_model.joint_blocks.22.context_block.attn.qkv.bias": ['blocks.22.attn.b_to_q.bias', 'blocks.22.attn.b_to_k.bias', 'blocks.22.attn.b_to_v.bias'],
|
599 |
+
"model.diffusion_model.joint_blocks.22.context_block.attn.qkv.weight": ['blocks.22.attn.b_to_q.weight', 'blocks.22.attn.b_to_k.weight', 'blocks.22.attn.b_to_v.weight'],
|
600 |
+
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.bias": "blocks.22.ff_b.0.bias",
|
601 |
+
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.weight": "blocks.22.ff_b.0.weight",
|
602 |
+
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.bias": "blocks.22.ff_b.2.bias",
|
603 |
+
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.weight": "blocks.22.ff_b.2.weight",
|
604 |
+
"model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.bias": "blocks.22.norm1_a.linear.bias",
|
605 |
+
"model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.weight": "blocks.22.norm1_a.linear.weight",
|
606 |
+
"model.diffusion_model.joint_blocks.22.x_block.attn.proj.bias": "blocks.22.attn.a_to_out.bias",
|
607 |
+
"model.diffusion_model.joint_blocks.22.x_block.attn.proj.weight": "blocks.22.attn.a_to_out.weight",
|
608 |
+
"model.diffusion_model.joint_blocks.22.x_block.attn.qkv.bias": ['blocks.22.attn.a_to_q.bias', 'blocks.22.attn.a_to_k.bias', 'blocks.22.attn.a_to_v.bias'],
|
609 |
+
"model.diffusion_model.joint_blocks.22.x_block.attn.qkv.weight": ['blocks.22.attn.a_to_q.weight', 'blocks.22.attn.a_to_k.weight', 'blocks.22.attn.a_to_v.weight'],
|
610 |
+
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.bias": "blocks.22.ff_a.0.bias",
|
611 |
+
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.weight": "blocks.22.ff_a.0.weight",
|
612 |
+
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.bias": "blocks.22.ff_a.2.bias",
|
613 |
+
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.weight": "blocks.22.ff_a.2.weight",
|
614 |
+
"model.diffusion_model.joint_blocks.23.context_block.attn.qkv.bias": ['blocks.23.attn.b_to_q.bias', 'blocks.23.attn.b_to_k.bias', 'blocks.23.attn.b_to_v.bias'],
|
615 |
+
"model.diffusion_model.joint_blocks.23.context_block.attn.qkv.weight": ['blocks.23.attn.b_to_q.weight', 'blocks.23.attn.b_to_k.weight', 'blocks.23.attn.b_to_v.weight'],
|
616 |
+
"model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.bias": "blocks.23.norm1_a.linear.bias",
|
617 |
+
"model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.weight": "blocks.23.norm1_a.linear.weight",
|
618 |
+
"model.diffusion_model.joint_blocks.23.x_block.attn.proj.bias": "blocks.23.attn.a_to_out.bias",
|
619 |
+
"model.diffusion_model.joint_blocks.23.x_block.attn.proj.weight": "blocks.23.attn.a_to_out.weight",
|
620 |
+
"model.diffusion_model.joint_blocks.23.x_block.attn.qkv.bias": ['blocks.23.attn.a_to_q.bias', 'blocks.23.attn.a_to_k.bias', 'blocks.23.attn.a_to_v.bias'],
|
621 |
+
"model.diffusion_model.joint_blocks.23.x_block.attn.qkv.weight": ['blocks.23.attn.a_to_q.weight', 'blocks.23.attn.a_to_k.weight', 'blocks.23.attn.a_to_v.weight'],
|
622 |
+
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.bias": "blocks.23.ff_a.0.bias",
|
623 |
+
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.weight": "blocks.23.ff_a.0.weight",
|
624 |
+
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.bias": "blocks.23.ff_a.2.bias",
|
625 |
+
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.weight": "blocks.23.ff_a.2.weight",
|
626 |
+
"model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.bias": "blocks.3.norm1_b.linear.bias",
|
627 |
+
"model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.weight": "blocks.3.norm1_b.linear.weight",
|
628 |
+
"model.diffusion_model.joint_blocks.3.context_block.attn.proj.bias": "blocks.3.attn.b_to_out.bias",
|
629 |
+
"model.diffusion_model.joint_blocks.3.context_block.attn.proj.weight": "blocks.3.attn.b_to_out.weight",
|
630 |
+
"model.diffusion_model.joint_blocks.3.context_block.attn.qkv.bias": ['blocks.3.attn.b_to_q.bias', 'blocks.3.attn.b_to_k.bias', 'blocks.3.attn.b_to_v.bias'],
|
631 |
+
"model.diffusion_model.joint_blocks.3.context_block.attn.qkv.weight": ['blocks.3.attn.b_to_q.weight', 'blocks.3.attn.b_to_k.weight', 'blocks.3.attn.b_to_v.weight'],
|
632 |
+
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.bias": "blocks.3.ff_b.0.bias",
|
633 |
+
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.weight": "blocks.3.ff_b.0.weight",
|
634 |
+
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.bias": "blocks.3.ff_b.2.bias",
|
635 |
+
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.weight": "blocks.3.ff_b.2.weight",
|
636 |
+
"model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.bias": "blocks.3.norm1_a.linear.bias",
|
637 |
+
"model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.weight": "blocks.3.norm1_a.linear.weight",
|
638 |
+
"model.diffusion_model.joint_blocks.3.x_block.attn.proj.bias": "blocks.3.attn.a_to_out.bias",
|
639 |
+
"model.diffusion_model.joint_blocks.3.x_block.attn.proj.weight": "blocks.3.attn.a_to_out.weight",
|
640 |
+
"model.diffusion_model.joint_blocks.3.x_block.attn.qkv.bias": ['blocks.3.attn.a_to_q.bias', 'blocks.3.attn.a_to_k.bias', 'blocks.3.attn.a_to_v.bias'],
|
641 |
+
"model.diffusion_model.joint_blocks.3.x_block.attn.qkv.weight": ['blocks.3.attn.a_to_q.weight', 'blocks.3.attn.a_to_k.weight', 'blocks.3.attn.a_to_v.weight'],
|
642 |
+
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.bias": "blocks.3.ff_a.0.bias",
|
643 |
+
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.weight": "blocks.3.ff_a.0.weight",
|
644 |
+
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.bias": "blocks.3.ff_a.2.bias",
|
645 |
+
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.weight": "blocks.3.ff_a.2.weight",
|
646 |
+
"model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.bias": "blocks.4.norm1_b.linear.bias",
|
647 |
+
"model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.weight": "blocks.4.norm1_b.linear.weight",
|
648 |
+
"model.diffusion_model.joint_blocks.4.context_block.attn.proj.bias": "blocks.4.attn.b_to_out.bias",
|
649 |
+
"model.diffusion_model.joint_blocks.4.context_block.attn.proj.weight": "blocks.4.attn.b_to_out.weight",
|
650 |
+
"model.diffusion_model.joint_blocks.4.context_block.attn.qkv.bias": ['blocks.4.attn.b_to_q.bias', 'blocks.4.attn.b_to_k.bias', 'blocks.4.attn.b_to_v.bias'],
|
651 |
+
"model.diffusion_model.joint_blocks.4.context_block.attn.qkv.weight": ['blocks.4.attn.b_to_q.weight', 'blocks.4.attn.b_to_k.weight', 'blocks.4.attn.b_to_v.weight'],
|
652 |
+
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.bias": "blocks.4.ff_b.0.bias",
|
653 |
+
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.weight": "blocks.4.ff_b.0.weight",
|
654 |
+
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.bias": "blocks.4.ff_b.2.bias",
|
655 |
+
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.weight": "blocks.4.ff_b.2.weight",
|
656 |
+
"model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.bias": "blocks.4.norm1_a.linear.bias",
|
657 |
+
"model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.weight": "blocks.4.norm1_a.linear.weight",
|
658 |
+
"model.diffusion_model.joint_blocks.4.x_block.attn.proj.bias": "blocks.4.attn.a_to_out.bias",
|
659 |
+
"model.diffusion_model.joint_blocks.4.x_block.attn.proj.weight": "blocks.4.attn.a_to_out.weight",
|
660 |
+
"model.diffusion_model.joint_blocks.4.x_block.attn.qkv.bias": ['blocks.4.attn.a_to_q.bias', 'blocks.4.attn.a_to_k.bias', 'blocks.4.attn.a_to_v.bias'],
|
661 |
+
"model.diffusion_model.joint_blocks.4.x_block.attn.qkv.weight": ['blocks.4.attn.a_to_q.weight', 'blocks.4.attn.a_to_k.weight', 'blocks.4.attn.a_to_v.weight'],
|
662 |
+
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.bias": "blocks.4.ff_a.0.bias",
|
663 |
+
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.weight": "blocks.4.ff_a.0.weight",
|
664 |
+
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.bias": "blocks.4.ff_a.2.bias",
|
665 |
+
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.weight": "blocks.4.ff_a.2.weight",
|
666 |
+
"model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.bias": "blocks.5.norm1_b.linear.bias",
|
667 |
+
"model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.weight": "blocks.5.norm1_b.linear.weight",
|
668 |
+
"model.diffusion_model.joint_blocks.5.context_block.attn.proj.bias": "blocks.5.attn.b_to_out.bias",
|
669 |
+
"model.diffusion_model.joint_blocks.5.context_block.attn.proj.weight": "blocks.5.attn.b_to_out.weight",
|
670 |
+
"model.diffusion_model.joint_blocks.5.context_block.attn.qkv.bias": ['blocks.5.attn.b_to_q.bias', 'blocks.5.attn.b_to_k.bias', 'blocks.5.attn.b_to_v.bias'],
|
671 |
+
"model.diffusion_model.joint_blocks.5.context_block.attn.qkv.weight": ['blocks.5.attn.b_to_q.weight', 'blocks.5.attn.b_to_k.weight', 'blocks.5.attn.b_to_v.weight'],
|
672 |
+
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.bias": "blocks.5.ff_b.0.bias",
|
673 |
+
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.weight": "blocks.5.ff_b.0.weight",
|
674 |
+
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.bias": "blocks.5.ff_b.2.bias",
|
675 |
+
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.weight": "blocks.5.ff_b.2.weight",
|
676 |
+
"model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.bias": "blocks.5.norm1_a.linear.bias",
|
677 |
+
"model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.weight": "blocks.5.norm1_a.linear.weight",
|
678 |
+
"model.diffusion_model.joint_blocks.5.x_block.attn.proj.bias": "blocks.5.attn.a_to_out.bias",
|
679 |
+
"model.diffusion_model.joint_blocks.5.x_block.attn.proj.weight": "blocks.5.attn.a_to_out.weight",
|
680 |
+
"model.diffusion_model.joint_blocks.5.x_block.attn.qkv.bias": ['blocks.5.attn.a_to_q.bias', 'blocks.5.attn.a_to_k.bias', 'blocks.5.attn.a_to_v.bias'],
|
681 |
+
"model.diffusion_model.joint_blocks.5.x_block.attn.qkv.weight": ['blocks.5.attn.a_to_q.weight', 'blocks.5.attn.a_to_k.weight', 'blocks.5.attn.a_to_v.weight'],
|
682 |
+
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.bias": "blocks.5.ff_a.0.bias",
|
683 |
+
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.weight": "blocks.5.ff_a.0.weight",
|
684 |
+
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.bias": "blocks.5.ff_a.2.bias",
|
685 |
+
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.weight": "blocks.5.ff_a.2.weight",
|
686 |
+
"model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.bias": "blocks.6.norm1_b.linear.bias",
|
687 |
+
"model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.weight": "blocks.6.norm1_b.linear.weight",
|
688 |
+
"model.diffusion_model.joint_blocks.6.context_block.attn.proj.bias": "blocks.6.attn.b_to_out.bias",
|
689 |
+
"model.diffusion_model.joint_blocks.6.context_block.attn.proj.weight": "blocks.6.attn.b_to_out.weight",
|
690 |
+
"model.diffusion_model.joint_blocks.6.context_block.attn.qkv.bias": ['blocks.6.attn.b_to_q.bias', 'blocks.6.attn.b_to_k.bias', 'blocks.6.attn.b_to_v.bias'],
|
691 |
+
"model.diffusion_model.joint_blocks.6.context_block.attn.qkv.weight": ['blocks.6.attn.b_to_q.weight', 'blocks.6.attn.b_to_k.weight', 'blocks.6.attn.b_to_v.weight'],
|
692 |
+
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.bias": "blocks.6.ff_b.0.bias",
|
693 |
+
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.weight": "blocks.6.ff_b.0.weight",
|
694 |
+
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.bias": "blocks.6.ff_b.2.bias",
|
695 |
+
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.weight": "blocks.6.ff_b.2.weight",
|
696 |
+
"model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.bias": "blocks.6.norm1_a.linear.bias",
|
697 |
+
"model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.weight": "blocks.6.norm1_a.linear.weight",
|
698 |
+
"model.diffusion_model.joint_blocks.6.x_block.attn.proj.bias": "blocks.6.attn.a_to_out.bias",
|
699 |
+
"model.diffusion_model.joint_blocks.6.x_block.attn.proj.weight": "blocks.6.attn.a_to_out.weight",
|
700 |
+
"model.diffusion_model.joint_blocks.6.x_block.attn.qkv.bias": ['blocks.6.attn.a_to_q.bias', 'blocks.6.attn.a_to_k.bias', 'blocks.6.attn.a_to_v.bias'],
|
701 |
+
"model.diffusion_model.joint_blocks.6.x_block.attn.qkv.weight": ['blocks.6.attn.a_to_q.weight', 'blocks.6.attn.a_to_k.weight', 'blocks.6.attn.a_to_v.weight'],
|
702 |
+
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.bias": "blocks.6.ff_a.0.bias",
|
703 |
+
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.weight": "blocks.6.ff_a.0.weight",
|
704 |
+
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.bias": "blocks.6.ff_a.2.bias",
|
705 |
+
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.weight": "blocks.6.ff_a.2.weight",
|
706 |
+
"model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.bias": "blocks.7.norm1_b.linear.bias",
|
707 |
+
"model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.weight": "blocks.7.norm1_b.linear.weight",
|
708 |
+
"model.diffusion_model.joint_blocks.7.context_block.attn.proj.bias": "blocks.7.attn.b_to_out.bias",
|
709 |
+
"model.diffusion_model.joint_blocks.7.context_block.attn.proj.weight": "blocks.7.attn.b_to_out.weight",
|
710 |
+
"model.diffusion_model.joint_blocks.7.context_block.attn.qkv.bias": ['blocks.7.attn.b_to_q.bias', 'blocks.7.attn.b_to_k.bias', 'blocks.7.attn.b_to_v.bias'],
|
711 |
+
"model.diffusion_model.joint_blocks.7.context_block.attn.qkv.weight": ['blocks.7.attn.b_to_q.weight', 'blocks.7.attn.b_to_k.weight', 'blocks.7.attn.b_to_v.weight'],
|
712 |
+
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.bias": "blocks.7.ff_b.0.bias",
|
713 |
+
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.weight": "blocks.7.ff_b.0.weight",
|
714 |
+
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.bias": "blocks.7.ff_b.2.bias",
|
715 |
+
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.weight": "blocks.7.ff_b.2.weight",
|
716 |
+
"model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.bias": "blocks.7.norm1_a.linear.bias",
|
717 |
+
"model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.weight": "blocks.7.norm1_a.linear.weight",
|
718 |
+
"model.diffusion_model.joint_blocks.7.x_block.attn.proj.bias": "blocks.7.attn.a_to_out.bias",
|
719 |
+
"model.diffusion_model.joint_blocks.7.x_block.attn.proj.weight": "blocks.7.attn.a_to_out.weight",
|
720 |
+
"model.diffusion_model.joint_blocks.7.x_block.attn.qkv.bias": ['blocks.7.attn.a_to_q.bias', 'blocks.7.attn.a_to_k.bias', 'blocks.7.attn.a_to_v.bias'],
|
721 |
+
"model.diffusion_model.joint_blocks.7.x_block.attn.qkv.weight": ['blocks.7.attn.a_to_q.weight', 'blocks.7.attn.a_to_k.weight', 'blocks.7.attn.a_to_v.weight'],
|
722 |
+
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.bias": "blocks.7.ff_a.0.bias",
|
723 |
+
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.weight": "blocks.7.ff_a.0.weight",
|
724 |
+
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.bias": "blocks.7.ff_a.2.bias",
|
725 |
+
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.weight": "blocks.7.ff_a.2.weight",
|
726 |
+
"model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.bias": "blocks.8.norm1_b.linear.bias",
|
727 |
+
"model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.weight": "blocks.8.norm1_b.linear.weight",
|
728 |
+
"model.diffusion_model.joint_blocks.8.context_block.attn.proj.bias": "blocks.8.attn.b_to_out.bias",
|
729 |
+
"model.diffusion_model.joint_blocks.8.context_block.attn.proj.weight": "blocks.8.attn.b_to_out.weight",
|
730 |
+
"model.diffusion_model.joint_blocks.8.context_block.attn.qkv.bias": ['blocks.8.attn.b_to_q.bias', 'blocks.8.attn.b_to_k.bias', 'blocks.8.attn.b_to_v.bias'],
|
731 |
+
"model.diffusion_model.joint_blocks.8.context_block.attn.qkv.weight": ['blocks.8.attn.b_to_q.weight', 'blocks.8.attn.b_to_k.weight', 'blocks.8.attn.b_to_v.weight'],
|
732 |
+
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.bias": "blocks.8.ff_b.0.bias",
|
733 |
+
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.weight": "blocks.8.ff_b.0.weight",
|
734 |
+
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.bias": "blocks.8.ff_b.2.bias",
|
735 |
+
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.weight": "blocks.8.ff_b.2.weight",
|
736 |
+
"model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.bias": "blocks.8.norm1_a.linear.bias",
|
737 |
+
"model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.weight": "blocks.8.norm1_a.linear.weight",
|
738 |
+
"model.diffusion_model.joint_blocks.8.x_block.attn.proj.bias": "blocks.8.attn.a_to_out.bias",
|
739 |
+
"model.diffusion_model.joint_blocks.8.x_block.attn.proj.weight": "blocks.8.attn.a_to_out.weight",
|
740 |
+
"model.diffusion_model.joint_blocks.8.x_block.attn.qkv.bias": ['blocks.8.attn.a_to_q.bias', 'blocks.8.attn.a_to_k.bias', 'blocks.8.attn.a_to_v.bias'],
|
741 |
+
"model.diffusion_model.joint_blocks.8.x_block.attn.qkv.weight": ['blocks.8.attn.a_to_q.weight', 'blocks.8.attn.a_to_k.weight', 'blocks.8.attn.a_to_v.weight'],
|
742 |
+
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.bias": "blocks.8.ff_a.0.bias",
|
743 |
+
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.weight": "blocks.8.ff_a.0.weight",
|
744 |
+
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.bias": "blocks.8.ff_a.2.bias",
|
745 |
+
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.weight": "blocks.8.ff_a.2.weight",
|
746 |
+
"model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.bias": "blocks.9.norm1_b.linear.bias",
|
747 |
+
"model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.weight": "blocks.9.norm1_b.linear.weight",
|
748 |
+
"model.diffusion_model.joint_blocks.9.context_block.attn.proj.bias": "blocks.9.attn.b_to_out.bias",
|
749 |
+
"model.diffusion_model.joint_blocks.9.context_block.attn.proj.weight": "blocks.9.attn.b_to_out.weight",
|
750 |
+
"model.diffusion_model.joint_blocks.9.context_block.attn.qkv.bias": ['blocks.9.attn.b_to_q.bias', 'blocks.9.attn.b_to_k.bias', 'blocks.9.attn.b_to_v.bias'],
|
751 |
+
"model.diffusion_model.joint_blocks.9.context_block.attn.qkv.weight": ['blocks.9.attn.b_to_q.weight', 'blocks.9.attn.b_to_k.weight', 'blocks.9.attn.b_to_v.weight'],
|
752 |
+
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.bias": "blocks.9.ff_b.0.bias",
|
753 |
+
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.weight": "blocks.9.ff_b.0.weight",
|
754 |
+
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.bias": "blocks.9.ff_b.2.bias",
|
755 |
+
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.weight": "blocks.9.ff_b.2.weight",
|
756 |
+
"model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.bias": "blocks.9.norm1_a.linear.bias",
|
757 |
+
"model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.weight": "blocks.9.norm1_a.linear.weight",
|
758 |
+
"model.diffusion_model.joint_blocks.9.x_block.attn.proj.bias": "blocks.9.attn.a_to_out.bias",
|
759 |
+
"model.diffusion_model.joint_blocks.9.x_block.attn.proj.weight": "blocks.9.attn.a_to_out.weight",
|
760 |
+
"model.diffusion_model.joint_blocks.9.x_block.attn.qkv.bias": ['blocks.9.attn.a_to_q.bias', 'blocks.9.attn.a_to_k.bias', 'blocks.9.attn.a_to_v.bias'],
|
761 |
+
"model.diffusion_model.joint_blocks.9.x_block.attn.qkv.weight": ['blocks.9.attn.a_to_q.weight', 'blocks.9.attn.a_to_k.weight', 'blocks.9.attn.a_to_v.weight'],
|
762 |
+
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.bias": "blocks.9.ff_a.0.bias",
|
763 |
+
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.weight": "blocks.9.ff_a.0.weight",
|
764 |
+
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.bias": "blocks.9.ff_a.2.bias",
|
765 |
+
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.weight": "blocks.9.ff_a.2.weight",
|
766 |
+
"model.diffusion_model.pos_embed": "pos_embedder.pos_embed",
|
767 |
+
"model.diffusion_model.t_embedder.mlp.0.bias": "time_embedder.timestep_embedder.0.bias",
|
768 |
+
"model.diffusion_model.t_embedder.mlp.0.weight": "time_embedder.timestep_embedder.0.weight",
|
769 |
+
"model.diffusion_model.t_embedder.mlp.2.bias": "time_embedder.timestep_embedder.2.bias",
|
770 |
+
"model.diffusion_model.t_embedder.mlp.2.weight": "time_embedder.timestep_embedder.2.weight",
|
771 |
+
"model.diffusion_model.x_embedder.proj.bias": "pos_embedder.proj.bias",
|
772 |
+
"model.diffusion_model.x_embedder.proj.weight": "pos_embedder.proj.weight",
|
773 |
+
"model.diffusion_model.y_embedder.mlp.0.bias": "pooled_text_embedder.0.bias",
|
774 |
+
"model.diffusion_model.y_embedder.mlp.0.weight": "pooled_text_embedder.0.weight",
|
775 |
+
"model.diffusion_model.y_embedder.mlp.2.bias": "pooled_text_embedder.2.bias",
|
776 |
+
"model.diffusion_model.y_embedder.mlp.2.weight": "pooled_text_embedder.2.weight",
|
777 |
+
|
778 |
+
"model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.weight": "blocks.23.norm1_b.linear.weight",
|
779 |
+
"model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.bias": "blocks.23.norm1_b.linear.bias",
|
780 |
+
"model.diffusion_model.final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
|
781 |
+
"model.diffusion_model.final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
|
782 |
+
}
|
783 |
+
state_dict_ = {}
|
784 |
+
for name in state_dict:
|
785 |
+
if name in rename_dict:
|
786 |
+
param = state_dict[name]
|
787 |
+
if name.startswith("model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1."):
|
788 |
+
param = torch.concat([param[1536:], param[:1536]], axis=0)
|
789 |
+
elif name.startswith("model.diffusion_model.final_layer.adaLN_modulation.1."):
|
790 |
+
param = torch.concat([param[1536:], param[:1536]], axis=0)
|
791 |
+
elif name == "model.diffusion_model.pos_embed":
|
792 |
+
param = param.reshape((1, 192, 192, 1536))
|
793 |
+
if isinstance(rename_dict[name], str):
|
794 |
+
state_dict_[rename_dict[name]] = param
|
795 |
+
else:
|
796 |
+
name_ = rename_dict[name][0].replace(".a_to_q.", ".a_to_qkv.").replace(".b_to_q.", ".b_to_qkv.")
|
797 |
+
state_dict_[name_] = param
|
798 |
+
return state_dict_
|
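An illustrative sketch (assuming only PyTorch; not taken from the commit itself): the converter above swaps the two 1536-wide halves of the adaLN modulation weights/biases for `joint_blocks.23.context_block` and `final_layer`, and maps each fused `qkv` tensor onto a single key derived from the first name in the list, with `.a_to_q.`/`.b_to_q.` rewritten to `.a_to_qkv.`/`.b_to_qkv.`. The half-swap behaves like this:

import torch

def swap_adaln_halves(param: torch.Tensor, hidden_dim: int = 1536) -> torch.Tensor:
    # Reorder [first_half, second_half] -> [second_half, first_half] along dim 0,
    # mirroring torch.concat([param[1536:], param[:1536]], axis=0) in the converter above.
    return torch.concat([param[hidden_dim:], param[:hidden_dim]], dim=0)

bias = torch.arange(2 * 1536, dtype=torch.float32)
swapped = swap_adaln_halves(bias)
assert torch.equal(swapped[:1536], bias[1536:])
assert torch.equal(swapped[1536:], bias[:1536])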
diffsynth/models/sd3_text_encoder.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
diffsynth/models/sd3_vae_decoder.py
ADDED
@@ -0,0 +1,81 @@
1 |
+
import torch
|
2 |
+
from .sd_vae_decoder import VAEAttentionBlock, SDVAEDecoderStateDictConverter
|
3 |
+
from .sd_unet import ResnetBlock, UpSampler
|
4 |
+
from .tiler import TileWorker
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
class SD3VAEDecoder(torch.nn.Module):
|
9 |
+
def __init__(self):
|
10 |
+
super().__init__()
|
11 |
+
self.scaling_factor = 1.5305 # Different from SD 1.x
|
12 |
+
self.shift_factor = 0.0609 # Different from SD 1.x
|
13 |
+
self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x
|
14 |
+
|
15 |
+
self.blocks = torch.nn.ModuleList([
|
16 |
+
# UNetMidBlock2D
|
17 |
+
ResnetBlock(512, 512, eps=1e-6),
|
18 |
+
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
19 |
+
ResnetBlock(512, 512, eps=1e-6),
|
20 |
+
# UpDecoderBlock2D
|
21 |
+
ResnetBlock(512, 512, eps=1e-6),
|
22 |
+
ResnetBlock(512, 512, eps=1e-6),
|
23 |
+
ResnetBlock(512, 512, eps=1e-6),
|
24 |
+
UpSampler(512),
|
25 |
+
# UpDecoderBlock2D
|
26 |
+
ResnetBlock(512, 512, eps=1e-6),
|
27 |
+
ResnetBlock(512, 512, eps=1e-6),
|
28 |
+
ResnetBlock(512, 512, eps=1e-6),
|
29 |
+
UpSampler(512),
|
30 |
+
# UpDecoderBlock2D
|
31 |
+
ResnetBlock(512, 256, eps=1e-6),
|
32 |
+
ResnetBlock(256, 256, eps=1e-6),
|
33 |
+
ResnetBlock(256, 256, eps=1e-6),
|
34 |
+
UpSampler(256),
|
35 |
+
# UpDecoderBlock2D
|
36 |
+
ResnetBlock(256, 128, eps=1e-6),
|
37 |
+
ResnetBlock(128, 128, eps=1e-6),
|
38 |
+
ResnetBlock(128, 128, eps=1e-6),
|
39 |
+
])
|
40 |
+
|
41 |
+
self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
|
42 |
+
self.conv_act = torch.nn.SiLU()
|
43 |
+
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
|
44 |
+
|
45 |
+
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
46 |
+
hidden_states = TileWorker().tiled_forward(
|
47 |
+
lambda x: self.forward(x),
|
48 |
+
sample,
|
49 |
+
tile_size,
|
50 |
+
tile_stride,
|
51 |
+
tile_device=sample.device,
|
52 |
+
tile_dtype=sample.dtype
|
53 |
+
)
|
54 |
+
return hidden_states
|
55 |
+
|
56 |
+
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
57 |
+
# For the VAE decoder, we do not need to apply the tiler to each individual layer.
|
58 |
+
if tiled:
|
59 |
+
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
60 |
+
|
61 |
+
# 1. pre-process
|
62 |
+
hidden_states = sample / self.scaling_factor + self.shift_factor
|
63 |
+
hidden_states = self.conv_in(hidden_states)
|
64 |
+
time_emb = None
|
65 |
+
text_emb = None
|
66 |
+
res_stack = None
|
67 |
+
|
68 |
+
# 2. blocks
|
69 |
+
for i, block in enumerate(self.blocks):
|
70 |
+
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
71 |
+
|
72 |
+
# 3. output
|
73 |
+
hidden_states = self.conv_norm_out(hidden_states)
|
74 |
+
hidden_states = self.conv_act(hidden_states)
|
75 |
+
hidden_states = self.conv_out(hidden_states)
|
76 |
+
|
77 |
+
return hidden_states
|
78 |
+
|
79 |
+
@staticmethod
|
80 |
+
def state_dict_converter():
|
81 |
+
return SDVAEDecoderStateDictConverter()
|
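An illustrative usage sketch of SD3VAEDecoder (shapes and module path are assumptions based on the files added in this commit, not an excerpt from it): forward() first rescales the 16-channel latent by 1/scaling_factor and adds shift_factor, then runs the mid/up blocks; three UpSampler stages give an 8x spatial upscale.

import torch
from diffsynth.models.sd3_vae_decoder import SD3VAEDecoder

decoder = SD3VAEDecoder().eval()
latents = torch.randn(1, 16, 64, 64)   # (B, 16, H/8, W/8), as produced by the SD3 VAE encoder
with torch.no_grad():
    image = decoder(latents)           # (1, 3, 512, 512): three UpSampler stages -> 8x upscale
print(image.shape)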
diffsynth/models/sd3_vae_encoder.py
ADDED
@@ -0,0 +1,95 @@
1 |
+
import torch
|
2 |
+
from .sd_unet import ResnetBlock, DownSampler
|
3 |
+
from .sd_vae_encoder import VAEAttentionBlock, SDVAEEncoderStateDictConverter
|
4 |
+
from .tiler import TileWorker
|
5 |
+
from einops import rearrange
|
6 |
+
|
7 |
+
|
8 |
+
class SD3VAEEncoder(torch.nn.Module):
|
9 |
+
def __init__(self):
|
10 |
+
super().__init__()
|
11 |
+
self.scaling_factor = 1.5305 # Different from SD 1.x
|
12 |
+
self.shift_factor = 0.0609 # Different from SD 1.x
|
13 |
+
self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
|
14 |
+
|
15 |
+
self.blocks = torch.nn.ModuleList([
|
16 |
+
# DownEncoderBlock2D
|
17 |
+
ResnetBlock(128, 128, eps=1e-6),
|
18 |
+
ResnetBlock(128, 128, eps=1e-6),
|
19 |
+
DownSampler(128, padding=0, extra_padding=True),
|
20 |
+
# DownEncoderBlock2D
|
21 |
+
ResnetBlock(128, 256, eps=1e-6),
|
22 |
+
ResnetBlock(256, 256, eps=1e-6),
|
23 |
+
DownSampler(256, padding=0, extra_padding=True),
|
24 |
+
# DownEncoderBlock2D
|
25 |
+
ResnetBlock(256, 512, eps=1e-6),
|
26 |
+
ResnetBlock(512, 512, eps=1e-6),
|
27 |
+
DownSampler(512, padding=0, extra_padding=True),
|
28 |
+
# DownEncoderBlock2D
|
29 |
+
ResnetBlock(512, 512, eps=1e-6),
|
30 |
+
ResnetBlock(512, 512, eps=1e-6),
|
31 |
+
# UNetMidBlock2D
|
32 |
+
ResnetBlock(512, 512, eps=1e-6),
|
33 |
+
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
34 |
+
ResnetBlock(512, 512, eps=1e-6),
|
35 |
+
])
|
36 |
+
|
37 |
+
self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
|
38 |
+
self.conv_act = torch.nn.SiLU()
|
39 |
+
self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)
|
40 |
+
|
41 |
+
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
42 |
+
hidden_states = TileWorker().tiled_forward(
|
43 |
+
lambda x: self.forward(x),
|
44 |
+
sample,
|
45 |
+
tile_size,
|
46 |
+
tile_stride,
|
47 |
+
tile_device=sample.device,
|
48 |
+
tile_dtype=sample.dtype
|
49 |
+
)
|
50 |
+
return hidden_states
|
51 |
+
|
52 |
+
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
53 |
+
# For the VAE encoder, we do not need to apply the tiler to each individual layer.
|
54 |
+
if tiled:
|
55 |
+
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
56 |
+
|
57 |
+
# 1. pre-process
|
58 |
+
hidden_states = self.conv_in(sample)
|
59 |
+
time_emb = None
|
60 |
+
text_emb = None
|
61 |
+
res_stack = None
|
62 |
+
|
63 |
+
# 2. blocks
|
64 |
+
for i, block in enumerate(self.blocks):
|
65 |
+
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
66 |
+
|
67 |
+
# 3. output
|
68 |
+
hidden_states = self.conv_norm_out(hidden_states)
|
69 |
+
hidden_states = self.conv_act(hidden_states)
|
70 |
+
hidden_states = self.conv_out(hidden_states)
|
71 |
+
hidden_states = hidden_states[:, :16]  # keep the 16 mean channels of the latent distribution; drop the logvar half
|
72 |
+
hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor
|
73 |
+
|
74 |
+
return hidden_states
|
75 |
+
|
76 |
+
def encode_video(self, sample, batch_size=8):
|
77 |
+
B = sample.shape[0]
|
78 |
+
hidden_states = []
|
79 |
+
|
80 |
+
for i in range(0, sample.shape[2], batch_size):
|
81 |
+
|
82 |
+
j = min(i + batch_size, sample.shape[2])
|
83 |
+
sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
|
84 |
+
|
85 |
+
hidden_states_batch = self(sample_batch)
|
86 |
+
hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
|
87 |
+
|
88 |
+
hidden_states.append(hidden_states_batch)
|
89 |
+
|
90 |
+
hidden_states = torch.concat(hidden_states, dim=2)
|
91 |
+
return hidden_states
|
92 |
+
|
93 |
+
@staticmethod
|
94 |
+
def state_dict_converter():
|
95 |
+
return SDVAEEncoderStateDictConverter()
|
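An illustrative sketch of SD3VAEEncoder.encode_video (shapes are assumptions, not from the commit): each temporal chunk is flattened from (B, C, T, H, W) to (B*T, C, H, W), encoded, and the chunks are re-stacked along the time axis.

import torch
from diffsynth.models.sd3_vae_encoder import SD3VAEEncoder

encoder = SD3VAEEncoder().eval()
video = torch.randn(1, 3, 16, 256, 256)                # (B, C, T, H, W)
with torch.no_grad():
    latents = encoder.encode_video(video, batch_size=8)
print(latents.shape)                                    # (1, 16, 16, 32, 32): 8x spatial downscale, 16 latent channels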
diffsynth/models/sd_controlnet.py
ADDED
@@ -0,0 +1,589 @@
1 |
+
import torch
|
2 |
+
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
|
3 |
+
from .tiler import TileWorker
|
4 |
+
|
5 |
+
|
6 |
+
class ControlNetConditioningLayer(torch.nn.Module):
|
7 |
+
def __init__(self, channels = (3, 16, 32, 96, 256, 320)):
|
8 |
+
super().__init__()
|
9 |
+
self.blocks = torch.nn.ModuleList([])
|
10 |
+
self.blocks.append(torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1))
|
11 |
+
self.blocks.append(torch.nn.SiLU())
|
12 |
+
for i in range(1, len(channels) - 2):
|
13 |
+
self.blocks.append(torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1))
|
14 |
+
self.blocks.append(torch.nn.SiLU())
|
15 |
+
self.blocks.append(torch.nn.Conv2d(channels[i], channels[i+1], kernel_size=3, padding=1, stride=2))
|
16 |
+
self.blocks.append(torch.nn.SiLU())
|
17 |
+
self.blocks.append(torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1))
|
18 |
+
|
19 |
+
def forward(self, conditioning):
|
20 |
+
for block in self.blocks:
|
21 |
+
conditioning = block(conditioning)
|
22 |
+
return conditioning
|
23 |
+
|
24 |
+
|
25 |
+
class SDControlNet(torch.nn.Module):
|
26 |
+
def __init__(self, global_pool=False):
|
27 |
+
super().__init__()
|
28 |
+
self.time_proj = Timesteps(320)
|
29 |
+
self.time_embedding = torch.nn.Sequential(
|
30 |
+
torch.nn.Linear(320, 1280),
|
31 |
+
torch.nn.SiLU(),
|
32 |
+
torch.nn.Linear(1280, 1280)
|
33 |
+
)
|
34 |
+
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
|
35 |
+
|
36 |
+
self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
|
37 |
+
|
38 |
+
self.blocks = torch.nn.ModuleList([
|
39 |
+
# CrossAttnDownBlock2D
|
40 |
+
ResnetBlock(320, 320, 1280),
|
41 |
+
AttentionBlock(8, 40, 320, 1, 768),
|
42 |
+
PushBlock(),
|
43 |
+
ResnetBlock(320, 320, 1280),
|
44 |
+
AttentionBlock(8, 40, 320, 1, 768),
|
45 |
+
PushBlock(),
|
46 |
+
DownSampler(320),
|
47 |
+
PushBlock(),
|
48 |
+
# CrossAttnDownBlock2D
|
49 |
+
ResnetBlock(320, 640, 1280),
|
50 |
+
AttentionBlock(8, 80, 640, 1, 768),
|
51 |
+
PushBlock(),
|
52 |
+
ResnetBlock(640, 640, 1280),
|
53 |
+
AttentionBlock(8, 80, 640, 1, 768),
|
54 |
+
PushBlock(),
|
55 |
+
DownSampler(640),
|
56 |
+
PushBlock(),
|
57 |
+
# CrossAttnDownBlock2D
|
58 |
+
ResnetBlock(640, 1280, 1280),
|
59 |
+
AttentionBlock(8, 160, 1280, 1, 768),
|
60 |
+
PushBlock(),
|
61 |
+
ResnetBlock(1280, 1280, 1280),
|
62 |
+
AttentionBlock(8, 160, 1280, 1, 768),
|
63 |
+
PushBlock(),
|
64 |
+
DownSampler(1280),
|
65 |
+
PushBlock(),
|
66 |
+
# DownBlock2D
|
67 |
+
ResnetBlock(1280, 1280, 1280),
|
68 |
+
PushBlock(),
|
69 |
+
ResnetBlock(1280, 1280, 1280),
|
70 |
+
PushBlock(),
|
71 |
+
# UNetMidBlock2DCrossAttn
|
72 |
+
ResnetBlock(1280, 1280, 1280),
|
73 |
+
AttentionBlock(8, 160, 1280, 1, 768),
|
74 |
+
ResnetBlock(1280, 1280, 1280),
|
75 |
+
PushBlock()
|
76 |
+
])
|
77 |
+
|
78 |
+
self.controlnet_blocks = torch.nn.ModuleList([
|
79 |
+
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
|
80 |
+
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
|
81 |
+
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
|
82 |
+
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
|
83 |
+
torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
|
84 |
+
torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
|
85 |
+
torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
|
86 |
+
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
|
87 |
+
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
|
88 |
+
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
|
89 |
+
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
|
90 |
+
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
|
91 |
+
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
|
92 |
+
])
|
93 |
+
|
94 |
+
self.global_pool = global_pool
|
95 |
+
|
96 |
+
def forward(
|
97 |
+
self,
|
98 |
+
sample, timestep, encoder_hidden_states, conditioning,
|
99 |
+
tiled=False, tile_size=64, tile_stride=32,
|
100 |
+
**kwargs
|
101 |
+
):
|
102 |
+
# 1. time
|
103 |
+
time_emb = self.time_proj(timestep).to(sample.dtype)
|
104 |
+
time_emb = self.time_embedding(time_emb)
|
105 |
+
time_emb = time_emb.repeat(sample.shape[0], 1)
|
106 |
+
|
107 |
+
# 2. pre-process
|
108 |
+
height, width = sample.shape[2], sample.shape[3]
|
109 |
+
hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning)
|
110 |
+
text_emb = encoder_hidden_states
|
111 |
+
res_stack = [hidden_states]
|
112 |
+
|
113 |
+
# 3. blocks
|
114 |
+
for i, block in enumerate(self.blocks):
|
115 |
+
if tiled and not isinstance(block, PushBlock):
|
116 |
+
_, _, inter_height, _ = hidden_states.shape
|
117 |
+
resize_scale = inter_height / height
|
118 |
+
hidden_states = TileWorker().tiled_forward(
|
119 |
+
lambda x: block(x, time_emb, text_emb, res_stack)[0],
|
120 |
+
hidden_states,
|
121 |
+
int(tile_size * resize_scale),
|
122 |
+
int(tile_stride * resize_scale),
|
123 |
+
tile_device=hidden_states.device,
|
124 |
+
tile_dtype=hidden_states.dtype
|
125 |
+
)
|
126 |
+
else:
|
127 |
+
hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
|
128 |
+
|
129 |
+
# 4. ControlNet blocks
|
130 |
+
controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
|
131 |
+
|
132 |
+
# pool
|
133 |
+
if self.global_pool:
|
134 |
+
controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]
|
135 |
+
|
136 |
+
return controlnet_res_stack
|
137 |
+
|
138 |
+
@staticmethod
|
139 |
+
def state_dict_converter():
|
140 |
+
return SDControlNetStateDictConverter()
|
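An illustrative sketch of a single SDControlNet call (tensor shapes are assumptions, not from the commit): the conditioning image stays in pixel space at 8x the latent resolution; ControlNetConditioningLayer downsamples it by 8 through three stride-2 convolutions so it can be added to conv_in(sample), and the forward pass returns 13 residual tensors, one per entry in controlnet_blocks.

import torch
from diffsynth.models.sd_controlnet import SDControlNet

controlnet = SDControlNet().eval()
latent = torch.randn(1, 4, 64, 64)                      # SD 1.x latent
conditioning = torch.randn(1, 3, 512, 512)              # control image in pixel space
text_emb = torch.randn(1, 77, 768)                      # CLIP text embeddings
timestep = torch.tensor(500)
with torch.no_grad():
    res_stack = controlnet(latent, timestep, text_emb, conditioning)
print(len(res_stack))                                    # 13 residual feature maps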
141 |
+
|
142 |
+
|
143 |
+
class SDControlNetStateDictConverter:
|
144 |
+
def __init__(self):
|
145 |
+
pass
|
146 |
+
|
147 |
+
def from_diffusers(self, state_dict):
|
148 |
+
# architecture
|
149 |
+
block_types = [
|
150 |
+
'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
|
151 |
+
'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
|
152 |
+
'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
|
153 |
+
'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock',
|
154 |
+
'ResnetBlock', 'AttentionBlock', 'ResnetBlock',
|
155 |
+
'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'UpSampler',
|
156 |
+
'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
|
157 |
+
'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
|
158 |
+
'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock'
|
159 |
+
]
|
160 |
+
|
161 |
+
# controlnet_rename_dict
|
162 |
+
controlnet_rename_dict = {
|
163 |
+
"controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
|
164 |
+
"controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
|
165 |
+
"controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
|
166 |
+
"controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
|
167 |
+
"controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
|
168 |
+
"controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
|
169 |
+
"controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
|
170 |
+
"controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
|
171 |
+
"controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
|
172 |
+
"controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
|
173 |
+
"controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
|
174 |
+
"controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
|
175 |
+
"controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
|
176 |
+
"controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
|
177 |
+
"controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
|
178 |
+
"controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
|
179 |
+
}
|
180 |
+
|
181 |
+
# Rename each parameter
|
182 |
+
name_list = sorted([name for name in state_dict])
|
183 |
+
rename_dict = {}
|
184 |
+
block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
|
185 |
+
last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
|
186 |
+
for name in name_list:
|
187 |
+
names = name.split(".")
|
188 |
+
if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
|
189 |
+
pass
|
190 |
+
elif name in controlnet_rename_dict:
|
191 |
+
names = controlnet_rename_dict[name].split(".")
|
192 |
+
elif names[0] == "controlnet_down_blocks":
|
193 |
+
names[0] = "controlnet_blocks"
|
194 |
+
elif names[0] == "controlnet_mid_block":
|
195 |
+
names = ["controlnet_blocks", "12", names[-1]]
|
196 |
+
elif names[0] in ["time_embedding", "add_embedding"]:
|
197 |
+
if names[0] == "add_embedding":
|
198 |
+
names[0] = "add_time_embedding"
|
199 |
+
names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
|
200 |
+
elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
|
201 |
+
if names[0] == "mid_block":
|
202 |
+
names.insert(1, "0")
|
203 |
+
block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
|
204 |
+
block_type_with_id = ".".join(names[:4])
|
205 |
+
if block_type_with_id != last_block_type_with_id[block_type]:
|
206 |
+
block_id[block_type] += 1
|
207 |
+
last_block_type_with_id[block_type] = block_type_with_id
|
208 |
+
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
|
209 |
+
block_id[block_type] += 1
|
210 |
+
block_type_with_id = ".".join(names[:4])
|
211 |
+
names = ["blocks", str(block_id[block_type])] + names[4:]
|
212 |
+
if "ff" in names:
|
213 |
+
ff_index = names.index("ff")
|
214 |
+
component = ".".join(names[ff_index:ff_index+3])
|
215 |
+
component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
|
216 |
+
names = names[:ff_index] + [component] + names[ff_index+3:]
|
217 |
+
if "to_out" in names:
|
218 |
+
names.pop(names.index("to_out") + 1)
|
219 |
+
else:
|
220 |
+
raise ValueError(f"Unknown parameters: {name}")
|
221 |
+
rename_dict[name] = ".".join(names)
|
222 |
+
|
223 |
+
# Convert state_dict
|
224 |
+
state_dict_ = {}
|
225 |
+
for name, param in state_dict.items():
|
226 |
+
if ".proj_in." in name or ".proj_out." in name:
|
227 |
+
param = param.squeeze()
|
228 |
+
if rename_dict[name] in [
|
229 |
+
"controlnet_blocks.1.bias", "controlnet_blocks.2.bias", "controlnet_blocks.3.bias", "controlnet_blocks.5.bias", "controlnet_blocks.6.bias",
|
230 |
+
"controlnet_blocks.8.bias", "controlnet_blocks.9.bias", "controlnet_blocks.10.bias", "controlnet_blocks.11.bias", "controlnet_blocks.12.bias"
|
231 |
+
]:
|
232 |
+
continue
|
233 |
+
state_dict_[rename_dict[name]] = param
|
234 |
+
return state_dict_
|
235 |
+
|
236 |
+
def from_civitai(self, state_dict):
|
237 |
+
if "mid_block.resnets.1.time_emb_proj.weight" in state_dict:
|
238 |
+
# For controlnets in diffusers format
|
239 |
+
return self.from_diffusers(state_dict)
|
240 |
+
rename_dict = {
|
241 |
+
"control_model.time_embed.0.weight": "time_embedding.0.weight",
|
242 |
+
"control_model.time_embed.0.bias": "time_embedding.0.bias",
|
243 |
+
"control_model.time_embed.2.weight": "time_embedding.2.weight",
|
244 |
+
"control_model.time_embed.2.bias": "time_embedding.2.bias",
|
245 |
+
"control_model.input_blocks.0.0.weight": "conv_in.weight",
|
246 |
+
"control_model.input_blocks.0.0.bias": "conv_in.bias",
|
247 |
+
"control_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight",
|
248 |
+
"control_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias",
|
249 |
+
"control_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight",
|
250 |
+
"control_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias",
|
251 |
+
"control_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight",
|
252 |
+
"control_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias",
|
253 |
+
"control_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight",
|
254 |
+
"control_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias",
|
255 |
+
"control_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight",
|
256 |
+
"control_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias",
|
257 |
+
"control_model.input_blocks.1.1.norm.weight": "blocks.1.norm.weight",
|
258 |
+
"control_model.input_blocks.1.1.norm.bias": "blocks.1.norm.bias",
|
259 |
+
"control_model.input_blocks.1.1.proj_in.weight": "blocks.1.proj_in.weight",
|
260 |
+
"control_model.input_blocks.1.1.proj_in.bias": "blocks.1.proj_in.bias",
|
261 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.1.transformer_blocks.0.attn1.to_q.weight",
|
262 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.1.transformer_blocks.0.attn1.to_k.weight",
|
263 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.1.transformer_blocks.0.attn1.to_v.weight",
|
264 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.1.transformer_blocks.0.attn1.to_out.weight",
|
265 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.1.transformer_blocks.0.attn1.to_out.bias",
|
266 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.1.transformer_blocks.0.act_fn.proj.weight",
|
267 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.1.transformer_blocks.0.act_fn.proj.bias",
|
268 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.1.transformer_blocks.0.ff.weight",
|
269 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "blocks.1.transformer_blocks.0.ff.bias",
|
270 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.1.transformer_blocks.0.attn2.to_q.weight",
|
271 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.1.transformer_blocks.0.attn2.to_k.weight",
|
272 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.1.transformer_blocks.0.attn2.to_v.weight",
|
273 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.1.transformer_blocks.0.attn2.to_out.weight",
|
274 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.1.transformer_blocks.0.attn2.to_out.bias",
|
275 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.1.transformer_blocks.0.norm1.weight",
|
276 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.1.transformer_blocks.0.norm1.bias",
|
277 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.1.transformer_blocks.0.norm2.weight",
|
278 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.1.transformer_blocks.0.norm2.bias",
|
279 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.1.transformer_blocks.0.norm3.weight",
|
280 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.1.transformer_blocks.0.norm3.bias",
|
281 |
+
"control_model.input_blocks.1.1.proj_out.weight": "blocks.1.proj_out.weight",
|
282 |
+
"control_model.input_blocks.1.1.proj_out.bias": "blocks.1.proj_out.bias",
|
283 |
+
"control_model.input_blocks.2.0.in_layers.0.weight": "blocks.3.norm1.weight",
|
284 |
+
"control_model.input_blocks.2.0.in_layers.0.bias": "blocks.3.norm1.bias",
|
285 |
+
"control_model.input_blocks.2.0.in_layers.2.weight": "blocks.3.conv1.weight",
|
286 |
+
"control_model.input_blocks.2.0.in_layers.2.bias": "blocks.3.conv1.bias",
|
287 |
+
"control_model.input_blocks.2.0.emb_layers.1.weight": "blocks.3.time_emb_proj.weight",
|
288 |
+
"control_model.input_blocks.2.0.emb_layers.1.bias": "blocks.3.time_emb_proj.bias",
|
289 |
+
"control_model.input_blocks.2.0.out_layers.0.weight": "blocks.3.norm2.weight",
|
290 |
+
"control_model.input_blocks.2.0.out_layers.0.bias": "blocks.3.norm2.bias",
|
291 |
+
"control_model.input_blocks.2.0.out_layers.3.weight": "blocks.3.conv2.weight",
|
292 |
+
"control_model.input_blocks.2.0.out_layers.3.bias": "blocks.3.conv2.bias",
|
293 |
+
"control_model.input_blocks.2.1.norm.weight": "blocks.4.norm.weight",
|
294 |
+
"control_model.input_blocks.2.1.norm.bias": "blocks.4.norm.bias",
|
295 |
+
"control_model.input_blocks.2.1.proj_in.weight": "blocks.4.proj_in.weight",
|
296 |
+
"control_model.input_blocks.2.1.proj_in.bias": "blocks.4.proj_in.bias",
|
297 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.4.transformer_blocks.0.attn1.to_q.weight",
|
298 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.4.transformer_blocks.0.attn1.to_k.weight",
|
299 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.4.transformer_blocks.0.attn1.to_v.weight",
|
300 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.4.transformer_blocks.0.attn1.to_out.weight",
|
301 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.4.transformer_blocks.0.attn1.to_out.bias",
|
302 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.4.transformer_blocks.0.act_fn.proj.weight",
|
303 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.4.transformer_blocks.0.act_fn.proj.bias",
|
304 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.4.transformer_blocks.0.ff.weight",
|
305 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.4.transformer_blocks.0.ff.bias",
|
306 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.4.transformer_blocks.0.attn2.to_q.weight",
|
307 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.4.transformer_blocks.0.attn2.to_k.weight",
|
308 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.4.transformer_blocks.0.attn2.to_v.weight",
|
309 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.4.transformer_blocks.0.attn2.to_out.weight",
|
310 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.4.transformer_blocks.0.attn2.to_out.bias",
|
311 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.4.transformer_blocks.0.norm1.weight",
|
312 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.4.transformer_blocks.0.norm1.bias",
|
313 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.4.transformer_blocks.0.norm2.weight",
|
314 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.4.transformer_blocks.0.norm2.bias",
|
315 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.4.transformer_blocks.0.norm3.weight",
|
316 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.4.transformer_blocks.0.norm3.bias",
|
317 |
+
"control_model.input_blocks.2.1.proj_out.weight": "blocks.4.proj_out.weight",
|
318 |
+
"control_model.input_blocks.2.1.proj_out.bias": "blocks.4.proj_out.bias",
|
319 |
+
"control_model.input_blocks.3.0.op.weight": "blocks.6.conv.weight",
|
320 |
+
"control_model.input_blocks.3.0.op.bias": "blocks.6.conv.bias",
|
321 |
+
"control_model.input_blocks.4.0.in_layers.0.weight": "blocks.8.norm1.weight",
|
322 |
+
"control_model.input_blocks.4.0.in_layers.0.bias": "blocks.8.norm1.bias",
|
323 |
+
"control_model.input_blocks.4.0.in_layers.2.weight": "blocks.8.conv1.weight",
|
324 |
+
"control_model.input_blocks.4.0.in_layers.2.bias": "blocks.8.conv1.bias",
|
325 |
+
"control_model.input_blocks.4.0.emb_layers.1.weight": "blocks.8.time_emb_proj.weight",
|
326 |
+
"control_model.input_blocks.4.0.emb_layers.1.bias": "blocks.8.time_emb_proj.bias",
|
327 |
+
"control_model.input_blocks.4.0.out_layers.0.weight": "blocks.8.norm2.weight",
|
328 |
+
"control_model.input_blocks.4.0.out_layers.0.bias": "blocks.8.norm2.bias",
|
329 |
+
"control_model.input_blocks.4.0.out_layers.3.weight": "blocks.8.conv2.weight",
|
330 |
+
"control_model.input_blocks.4.0.out_layers.3.bias": "blocks.8.conv2.bias",
|
331 |
+
"control_model.input_blocks.4.0.skip_connection.weight": "blocks.8.conv_shortcut.weight",
|
332 |
+
"control_model.input_blocks.4.0.skip_connection.bias": "blocks.8.conv_shortcut.bias",
|
333 |
+
"control_model.input_blocks.4.1.norm.weight": "blocks.9.norm.weight",
|
334 |
+
"control_model.input_blocks.4.1.norm.bias": "blocks.9.norm.bias",
|
335 |
+
"control_model.input_blocks.4.1.proj_in.weight": "blocks.9.proj_in.weight",
|
336 |
+
"control_model.input_blocks.4.1.proj_in.bias": "blocks.9.proj_in.bias",
|
337 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.9.transformer_blocks.0.attn1.to_q.weight",
|
338 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.9.transformer_blocks.0.attn1.to_k.weight",
|
339 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.9.transformer_blocks.0.attn1.to_v.weight",
|
340 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.9.transformer_blocks.0.attn1.to_out.weight",
|
341 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.9.transformer_blocks.0.attn1.to_out.bias",
|
342 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.9.transformer_blocks.0.act_fn.proj.weight",
|
343 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.9.transformer_blocks.0.act_fn.proj.bias",
|
344 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.9.transformer_blocks.0.ff.weight",
|
345 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.9.transformer_blocks.0.ff.bias",
|
346 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.9.transformer_blocks.0.attn2.to_q.weight",
|
347 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.9.transformer_blocks.0.attn2.to_k.weight",
|
348 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.9.transformer_blocks.0.attn2.to_v.weight",
|
349 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.9.transformer_blocks.0.attn2.to_out.weight",
|
350 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.9.transformer_blocks.0.attn2.to_out.bias",
|
351 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.9.transformer_blocks.0.norm1.weight",
|
352 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.9.transformer_blocks.0.norm1.bias",
|
353 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.9.transformer_blocks.0.norm2.weight",
|
354 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.9.transformer_blocks.0.norm2.bias",
|
355 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.9.transformer_blocks.0.norm3.weight",
|
356 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.9.transformer_blocks.0.norm3.bias",
|
357 |
+
"control_model.input_blocks.4.1.proj_out.weight": "blocks.9.proj_out.weight",
|
358 |
+
"control_model.input_blocks.4.1.proj_out.bias": "blocks.9.proj_out.bias",
|
359 |
+
"control_model.input_blocks.5.0.in_layers.0.weight": "blocks.11.norm1.weight",
|
360 |
+
"control_model.input_blocks.5.0.in_layers.0.bias": "blocks.11.norm1.bias",
|
361 |
+
"control_model.input_blocks.5.0.in_layers.2.weight": "blocks.11.conv1.weight",
|
362 |
+
"control_model.input_blocks.5.0.in_layers.2.bias": "blocks.11.conv1.bias",
|
363 |
+
"control_model.input_blocks.5.0.emb_layers.1.weight": "blocks.11.time_emb_proj.weight",
|
364 |
+
"control_model.input_blocks.5.0.emb_layers.1.bias": "blocks.11.time_emb_proj.bias",
|
365 |
+
"control_model.input_blocks.5.0.out_layers.0.weight": "blocks.11.norm2.weight",
|
366 |
+
"control_model.input_blocks.5.0.out_layers.0.bias": "blocks.11.norm2.bias",
|
367 |
+
"control_model.input_blocks.5.0.out_layers.3.weight": "blocks.11.conv2.weight",
|
368 |
+
"control_model.input_blocks.5.0.out_layers.3.bias": "blocks.11.conv2.bias",
|
369 |
+
"control_model.input_blocks.5.1.norm.weight": "blocks.12.norm.weight",
|
370 |
+
"control_model.input_blocks.5.1.norm.bias": "blocks.12.norm.bias",
|
371 |
+
"control_model.input_blocks.5.1.proj_in.weight": "blocks.12.proj_in.weight",
|
372 |
+
"control_model.input_blocks.5.1.proj_in.bias": "blocks.12.proj_in.bias",
|
373 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.12.transformer_blocks.0.attn1.to_q.weight",
|
374 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.12.transformer_blocks.0.attn1.to_k.weight",
|
375 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.12.transformer_blocks.0.attn1.to_v.weight",
|
376 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.12.transformer_blocks.0.attn1.to_out.weight",
|
377 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.12.transformer_blocks.0.attn1.to_out.bias",
|
378 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.12.transformer_blocks.0.act_fn.proj.weight",
|
379 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.12.transformer_blocks.0.act_fn.proj.bias",
|
380 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.12.transformer_blocks.0.ff.weight",
|
381 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.12.transformer_blocks.0.ff.bias",
|
382 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.12.transformer_blocks.0.attn2.to_q.weight",
|
383 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.12.transformer_blocks.0.attn2.to_k.weight",
|
384 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.12.transformer_blocks.0.attn2.to_v.weight",
|
385 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.12.transformer_blocks.0.attn2.to_out.weight",
|
386 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.12.transformer_blocks.0.attn2.to_out.bias",
|
387 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.12.transformer_blocks.0.norm1.weight",
|
388 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.12.transformer_blocks.0.norm1.bias",
|
389 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.12.transformer_blocks.0.norm2.weight",
|
390 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.12.transformer_blocks.0.norm2.bias",
|
391 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.12.transformer_blocks.0.norm3.weight",
|
392 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.12.transformer_blocks.0.norm3.bias",
|
393 |
+
"control_model.input_blocks.5.1.proj_out.weight": "blocks.12.proj_out.weight",
|
394 |
+
"control_model.input_blocks.5.1.proj_out.bias": "blocks.12.proj_out.bias",
|
395 |
+
"control_model.input_blocks.6.0.op.weight": "blocks.14.conv.weight",
|
396 |
+
"control_model.input_blocks.6.0.op.bias": "blocks.14.conv.bias",
|
397 |
+
"control_model.input_blocks.7.0.in_layers.0.weight": "blocks.16.norm1.weight",
|
398 |
+
"control_model.input_blocks.7.0.in_layers.0.bias": "blocks.16.norm1.bias",
|
399 |
+
"control_model.input_blocks.7.0.in_layers.2.weight": "blocks.16.conv1.weight",
|
400 |
+
"control_model.input_blocks.7.0.in_layers.2.bias": "blocks.16.conv1.bias",
|
401 |
+
"control_model.input_blocks.7.0.emb_layers.1.weight": "blocks.16.time_emb_proj.weight",
|
402 |
+
"control_model.input_blocks.7.0.emb_layers.1.bias": "blocks.16.time_emb_proj.bias",
|
403 |
+
"control_model.input_blocks.7.0.out_layers.0.weight": "blocks.16.norm2.weight",
|
404 |
+
"control_model.input_blocks.7.0.out_layers.0.bias": "blocks.16.norm2.bias",
|
405 |
+
"control_model.input_blocks.7.0.out_layers.3.weight": "blocks.16.conv2.weight",
|
406 |
+
"control_model.input_blocks.7.0.out_layers.3.bias": "blocks.16.conv2.bias",
|
407 |
+
"control_model.input_blocks.7.0.skip_connection.weight": "blocks.16.conv_shortcut.weight",
|
408 |
+
"control_model.input_blocks.7.0.skip_connection.bias": "blocks.16.conv_shortcut.bias",
|
409 |
+
"control_model.input_blocks.7.1.norm.weight": "blocks.17.norm.weight",
|
410 |
+
"control_model.input_blocks.7.1.norm.bias": "blocks.17.norm.bias",
|
411 |
+
"control_model.input_blocks.7.1.proj_in.weight": "blocks.17.proj_in.weight",
|
412 |
+
"control_model.input_blocks.7.1.proj_in.bias": "blocks.17.proj_in.bias",
|
413 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.17.transformer_blocks.0.attn1.to_q.weight",
|
414 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.17.transformer_blocks.0.attn1.to_k.weight",
|
415 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.17.transformer_blocks.0.attn1.to_v.weight",
|
416 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.17.transformer_blocks.0.attn1.to_out.weight",
|
417 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.17.transformer_blocks.0.attn1.to_out.bias",
|
418 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.17.transformer_blocks.0.act_fn.proj.weight",
|
419 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.17.transformer_blocks.0.act_fn.proj.bias",
|
420 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.17.transformer_blocks.0.ff.weight",
|
421 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.17.transformer_blocks.0.ff.bias",
|
422 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.17.transformer_blocks.0.attn2.to_q.weight",
|
423 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.17.transformer_blocks.0.attn2.to_k.weight",
|
424 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.17.transformer_blocks.0.attn2.to_v.weight",
|
425 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.17.transformer_blocks.0.attn2.to_out.weight",
|
426 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.17.transformer_blocks.0.attn2.to_out.bias",
|
427 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.17.transformer_blocks.0.norm1.weight",
|
428 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.17.transformer_blocks.0.norm1.bias",
|
429 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.17.transformer_blocks.0.norm2.weight",
|
430 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.17.transformer_blocks.0.norm2.bias",
|
431 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.17.transformer_blocks.0.norm3.weight",
|
432 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.17.transformer_blocks.0.norm3.bias",
|
433 |
+
"control_model.input_blocks.7.1.proj_out.weight": "blocks.17.proj_out.weight",
|
434 |
+
"control_model.input_blocks.7.1.proj_out.bias": "blocks.17.proj_out.bias",
|
435 |
+
"control_model.input_blocks.8.0.in_layers.0.weight": "blocks.19.norm1.weight",
|
436 |
+
"control_model.input_blocks.8.0.in_layers.0.bias": "blocks.19.norm1.bias",
|
437 |
+
"control_model.input_blocks.8.0.in_layers.2.weight": "blocks.19.conv1.weight",
|
438 |
+
"control_model.input_blocks.8.0.in_layers.2.bias": "blocks.19.conv1.bias",
|
439 |
+
"control_model.input_blocks.8.0.emb_layers.1.weight": "blocks.19.time_emb_proj.weight",
|
440 |
+
"control_model.input_blocks.8.0.emb_layers.1.bias": "blocks.19.time_emb_proj.bias",
|
441 |
+
"control_model.input_blocks.8.0.out_layers.0.weight": "blocks.19.norm2.weight",
|
442 |
+
"control_model.input_blocks.8.0.out_layers.0.bias": "blocks.19.norm2.bias",
|
443 |
+
"control_model.input_blocks.8.0.out_layers.3.weight": "blocks.19.conv2.weight",
|
444 |
+
"control_model.input_blocks.8.0.out_layers.3.bias": "blocks.19.conv2.bias",
|
445 |
+
"control_model.input_blocks.8.1.norm.weight": "blocks.20.norm.weight",
|
446 |
+
"control_model.input_blocks.8.1.norm.bias": "blocks.20.norm.bias",
|
447 |
+
"control_model.input_blocks.8.1.proj_in.weight": "blocks.20.proj_in.weight",
|
448 |
+
"control_model.input_blocks.8.1.proj_in.bias": "blocks.20.proj_in.bias",
|
449 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.20.transformer_blocks.0.attn1.to_q.weight",
|
450 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.20.transformer_blocks.0.attn1.to_k.weight",
|
451 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.20.transformer_blocks.0.attn1.to_v.weight",
|
452 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.20.transformer_blocks.0.attn1.to_out.weight",
|
453 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.20.transformer_blocks.0.attn1.to_out.bias",
|
454 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.20.transformer_blocks.0.act_fn.proj.weight",
|
455 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.20.transformer_blocks.0.act_fn.proj.bias",
|
456 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.20.transformer_blocks.0.ff.weight",
|
457 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.20.transformer_blocks.0.ff.bias",
|
458 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.20.transformer_blocks.0.attn2.to_q.weight",
|
459 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.20.transformer_blocks.0.attn2.to_k.weight",
|
460 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.20.transformer_blocks.0.attn2.to_v.weight",
|
461 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.20.transformer_blocks.0.attn2.to_out.weight",
|
462 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.20.transformer_blocks.0.attn2.to_out.bias",
|
463 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.20.transformer_blocks.0.norm1.weight",
|
464 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.20.transformer_blocks.0.norm1.bias",
|
465 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.20.transformer_blocks.0.norm2.weight",
|
466 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.20.transformer_blocks.0.norm2.bias",
|
467 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.20.transformer_blocks.0.norm3.weight",
|
468 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.20.transformer_blocks.0.norm3.bias",
|
469 |
+
"control_model.input_blocks.8.1.proj_out.weight": "blocks.20.proj_out.weight",
|
470 |
+
"control_model.input_blocks.8.1.proj_out.bias": "blocks.20.proj_out.bias",
|
471 |
+
"control_model.input_blocks.9.0.op.weight": "blocks.22.conv.weight",
|
472 |
+
"control_model.input_blocks.9.0.op.bias": "blocks.22.conv.bias",
|
473 |
+
"control_model.input_blocks.10.0.in_layers.0.weight": "blocks.24.norm1.weight",
|
474 |
+
"control_model.input_blocks.10.0.in_layers.0.bias": "blocks.24.norm1.bias",
|
475 |
+
"control_model.input_blocks.10.0.in_layers.2.weight": "blocks.24.conv1.weight",
|
476 |
+
"control_model.input_blocks.10.0.in_layers.2.bias": "blocks.24.conv1.bias",
|
477 |
+
"control_model.input_blocks.10.0.emb_layers.1.weight": "blocks.24.time_emb_proj.weight",
|
478 |
+
"control_model.input_blocks.10.0.emb_layers.1.bias": "blocks.24.time_emb_proj.bias",
|
479 |
+
"control_model.input_blocks.10.0.out_layers.0.weight": "blocks.24.norm2.weight",
|
480 |
+
"control_model.input_blocks.10.0.out_layers.0.bias": "blocks.24.norm2.bias",
|
481 |
+
"control_model.input_blocks.10.0.out_layers.3.weight": "blocks.24.conv2.weight",
|
482 |
+
"control_model.input_blocks.10.0.out_layers.3.bias": "blocks.24.conv2.bias",
|
483 |
+
"control_model.input_blocks.11.0.in_layers.0.weight": "blocks.26.norm1.weight",
|
484 |
+
"control_model.input_blocks.11.0.in_layers.0.bias": "blocks.26.norm1.bias",
|
485 |
+
"control_model.input_blocks.11.0.in_layers.2.weight": "blocks.26.conv1.weight",
|
486 |
+
"control_model.input_blocks.11.0.in_layers.2.bias": "blocks.26.conv1.bias",
|
487 |
+
"control_model.input_blocks.11.0.emb_layers.1.weight": "blocks.26.time_emb_proj.weight",
|
488 |
+
"control_model.input_blocks.11.0.emb_layers.1.bias": "blocks.26.time_emb_proj.bias",
|
489 |
+
"control_model.input_blocks.11.0.out_layers.0.weight": "blocks.26.norm2.weight",
|
490 |
+
"control_model.input_blocks.11.0.out_layers.0.bias": "blocks.26.norm2.bias",
|
491 |
+
"control_model.input_blocks.11.0.out_layers.3.weight": "blocks.26.conv2.weight",
|
492 |
+
"control_model.input_blocks.11.0.out_layers.3.bias": "blocks.26.conv2.bias",
|
493 |
+
"control_model.zero_convs.0.0.weight": "controlnet_blocks.0.weight",
|
494 |
+
"control_model.zero_convs.0.0.bias": "controlnet_blocks.0.bias",
|
495 |
+
"control_model.zero_convs.1.0.weight": "controlnet_blocks.1.weight",
|
496 |
+
"control_model.zero_convs.1.0.bias": "controlnet_blocks.0.bias",
|
497 |
+
"control_model.zero_convs.2.0.weight": "controlnet_blocks.2.weight",
|
498 |
+
"control_model.zero_convs.2.0.bias": "controlnet_blocks.0.bias",
|
499 |
+
"control_model.zero_convs.3.0.weight": "controlnet_blocks.3.weight",
|
500 |
+
"control_model.zero_convs.3.0.bias": "controlnet_blocks.0.bias",
|
501 |
+
"control_model.zero_convs.4.0.weight": "controlnet_blocks.4.weight",
|
502 |
+
"control_model.zero_convs.4.0.bias": "controlnet_blocks.4.bias",
|
503 |
+
"control_model.zero_convs.5.0.weight": "controlnet_blocks.5.weight",
|
504 |
+
"control_model.zero_convs.5.0.bias": "controlnet_blocks.4.bias",
|
505 |
+
"control_model.zero_convs.6.0.weight": "controlnet_blocks.6.weight",
|
506 |
+
"control_model.zero_convs.6.0.bias": "controlnet_blocks.4.bias",
|
507 |
+
"control_model.zero_convs.7.0.weight": "controlnet_blocks.7.weight",
|
508 |
+
"control_model.zero_convs.7.0.bias": "controlnet_blocks.7.bias",
|
509 |
+
"control_model.zero_convs.8.0.weight": "controlnet_blocks.8.weight",
|
510 |
+
"control_model.zero_convs.8.0.bias": "controlnet_blocks.7.bias",
|
511 |
+
"control_model.zero_convs.9.0.weight": "controlnet_blocks.9.weight",
|
512 |
+
"control_model.zero_convs.9.0.bias": "controlnet_blocks.7.bias",
|
513 |
+
"control_model.zero_convs.10.0.weight": "controlnet_blocks.10.weight",
|
514 |
+
"control_model.zero_convs.10.0.bias": "controlnet_blocks.7.bias",
|
515 |
+
"control_model.zero_convs.11.0.weight": "controlnet_blocks.11.weight",
|
516 |
+
"control_model.zero_convs.11.0.bias": "controlnet_blocks.7.bias",
|
517 |
+
"control_model.input_hint_block.0.weight": "controlnet_conv_in.blocks.0.weight",
|
518 |
+
"control_model.input_hint_block.0.bias": "controlnet_conv_in.blocks.0.bias",
|
519 |
+
"control_model.input_hint_block.2.weight": "controlnet_conv_in.blocks.2.weight",
|
520 |
+
"control_model.input_hint_block.2.bias": "controlnet_conv_in.blocks.2.bias",
|
521 |
+
"control_model.input_hint_block.4.weight": "controlnet_conv_in.blocks.4.weight",
|
522 |
+
"control_model.input_hint_block.4.bias": "controlnet_conv_in.blocks.4.bias",
|
523 |
+
"control_model.input_hint_block.6.weight": "controlnet_conv_in.blocks.6.weight",
|
524 |
+
"control_model.input_hint_block.6.bias": "controlnet_conv_in.blocks.6.bias",
|
525 |
+
"control_model.input_hint_block.8.weight": "controlnet_conv_in.blocks.8.weight",
|
526 |
+
"control_model.input_hint_block.8.bias": "controlnet_conv_in.blocks.8.bias",
|
527 |
+
"control_model.input_hint_block.10.weight": "controlnet_conv_in.blocks.10.weight",
|
528 |
+
"control_model.input_hint_block.10.bias": "controlnet_conv_in.blocks.10.bias",
|
529 |
+
"control_model.input_hint_block.12.weight": "controlnet_conv_in.blocks.12.weight",
|
530 |
+
"control_model.input_hint_block.12.bias": "controlnet_conv_in.blocks.12.bias",
|
531 |
+
"control_model.input_hint_block.14.weight": "controlnet_conv_in.blocks.14.weight",
|
532 |
+
"control_model.input_hint_block.14.bias": "controlnet_conv_in.blocks.14.bias",
|
533 |
+
"control_model.middle_block.0.in_layers.0.weight": "blocks.28.norm1.weight",
|
534 |
+
"control_model.middle_block.0.in_layers.0.bias": "blocks.28.norm1.bias",
|
535 |
+
"control_model.middle_block.0.in_layers.2.weight": "blocks.28.conv1.weight",
|
536 |
+
"control_model.middle_block.0.in_layers.2.bias": "blocks.28.conv1.bias",
|
537 |
+
"control_model.middle_block.0.emb_layers.1.weight": "blocks.28.time_emb_proj.weight",
|
538 |
+
"control_model.middle_block.0.emb_layers.1.bias": "blocks.28.time_emb_proj.bias",
|
539 |
+
"control_model.middle_block.0.out_layers.0.weight": "blocks.28.norm2.weight",
|
540 |
+
"control_model.middle_block.0.out_layers.0.bias": "blocks.28.norm2.bias",
|
541 |
+
"control_model.middle_block.0.out_layers.3.weight": "blocks.28.conv2.weight",
|
542 |
+
"control_model.middle_block.0.out_layers.3.bias": "blocks.28.conv2.bias",
|
543 |
+
"control_model.middle_block.1.norm.weight": "blocks.29.norm.weight",
|
544 |
+
"control_model.middle_block.1.norm.bias": "blocks.29.norm.bias",
|
545 |
+
"control_model.middle_block.1.proj_in.weight": "blocks.29.proj_in.weight",
|
546 |
+
"control_model.middle_block.1.proj_in.bias": "blocks.29.proj_in.bias",
|
547 |
+
"control_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.29.transformer_blocks.0.attn1.to_q.weight",
|
548 |
+
"control_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.29.transformer_blocks.0.attn1.to_k.weight",
|
549 |
+
"control_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.29.transformer_blocks.0.attn1.to_v.weight",
|
550 |
+
"control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.29.transformer_blocks.0.attn1.to_out.weight",
|
551 |
+
"control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.29.transformer_blocks.0.attn1.to_out.bias",
|
552 |
+
"control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.29.transformer_blocks.0.act_fn.proj.weight",
|
553 |
+
"control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.29.transformer_blocks.0.act_fn.proj.bias",
|
554 |
+
"control_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.29.transformer_blocks.0.ff.weight",
|
555 |
+
"control_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.29.transformer_blocks.0.ff.bias",
|
556 |
+
"control_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.29.transformer_blocks.0.attn2.to_q.weight",
|
557 |
+
"control_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.29.transformer_blocks.0.attn2.to_k.weight",
|
558 |
+
"control_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.29.transformer_blocks.0.attn2.to_v.weight",
|
559 |
+
"control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.29.transformer_blocks.0.attn2.to_out.weight",
|
560 |
+
"control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.29.transformer_blocks.0.attn2.to_out.bias",
|
561 |
+
"control_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.29.transformer_blocks.0.norm1.weight",
|
562 |
+
"control_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.29.transformer_blocks.0.norm1.bias",
|
563 |
+
"control_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.29.transformer_blocks.0.norm2.weight",
|
564 |
+
"control_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.29.transformer_blocks.0.norm2.bias",
|
565 |
+
"control_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.29.transformer_blocks.0.norm3.weight",
|
566 |
+
"control_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.29.transformer_blocks.0.norm3.bias",
|
567 |
+
"control_model.middle_block.1.proj_out.weight": "blocks.29.proj_out.weight",
|
568 |
+
"control_model.middle_block.1.proj_out.bias": "blocks.29.proj_out.bias",
|
569 |
+
"control_model.middle_block.2.in_layers.0.weight": "blocks.30.norm1.weight",
|
570 |
+
"control_model.middle_block.2.in_layers.0.bias": "blocks.30.norm1.bias",
|
571 |
+
"control_model.middle_block.2.in_layers.2.weight": "blocks.30.conv1.weight",
|
572 |
+
"control_model.middle_block.2.in_layers.2.bias": "blocks.30.conv1.bias",
|
573 |
+
"control_model.middle_block.2.emb_layers.1.weight": "blocks.30.time_emb_proj.weight",
|
574 |
+
"control_model.middle_block.2.emb_layers.1.bias": "blocks.30.time_emb_proj.bias",
|
575 |
+
"control_model.middle_block.2.out_layers.0.weight": "blocks.30.norm2.weight",
|
576 |
+
"control_model.middle_block.2.out_layers.0.bias": "blocks.30.norm2.bias",
|
577 |
+
"control_model.middle_block.2.out_layers.3.weight": "blocks.30.conv2.weight",
|
578 |
+
"control_model.middle_block.2.out_layers.3.bias": "blocks.30.conv2.bias",
|
579 |
+
"control_model.middle_block_out.0.weight": "controlnet_blocks.12.weight",
|
580 |
+
"control_model.middle_block_out.0.bias": "controlnet_blocks.7.bias",
|
581 |
+
}
|
582 |
+
state_dict_ = {}
|
583 |
+
for name in state_dict:
|
584 |
+
if name in rename_dict:
|
585 |
+
param = state_dict[name]
|
586 |
+
if ".proj_in." in name or ".proj_out." in name:
|
587 |
+
param = param.squeeze()
|
588 |
+
state_dict_[rename_dict[name]] = param
|
589 |
+
return state_dict_
|
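Note on the converter tail above: it follows the same recipe used throughout these model files. A flat rename dict maps the original civitai/LDM-style checkpoint keys onto the re-implemented block names, and the 1x1 convolution weights of proj_in/proj_out are squeezed down to 2-D because the target blocks implement those projections as torch.nn.Linear. A minimal, self-contained sketch of that loop on a toy tensor (illustrative shapes only, not the real checkpoint):

import torch

# Toy stand-in for one proj_in entry of a civitai-style ControlNet checkpoint.
state_dict = {
    "control_model.middle_block.1.proj_in.weight": torch.randn(1280, 1280, 1, 1),  # 1x1 Conv2d weight
    "control_model.middle_block.1.proj_in.bias": torch.randn(1280),
}
rename_dict = {
    "control_model.middle_block.1.proj_in.weight": "blocks.29.proj_in.weight",
    "control_model.middle_block.1.proj_in.bias": "blocks.29.proj_in.bias",
}
state_dict_ = {}
for name, param in state_dict.items():
    if name in rename_dict:
        # proj_in / proj_out are stored as 1x1 convolutions; the re-implemented
        # blocks use Linear layers, so the trailing singleton dims are squeezed away.
        if ".proj_in." in name or ".proj_out." in name:
            param = param.squeeze()
        state_dict_[rename_dict[name]] = param
print(state_dict_["blocks.29.proj_in.weight"].shape)  # torch.Size([1280, 1280])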
diffsynth/models/sd_ipadapter.py
ADDED
@@ -0,0 +1,57 @@
1 |
+
from .svd_image_encoder import SVDImageEncoder
|
2 |
+
from .sdxl_ipadapter import IpAdapterImageProjModel, IpAdapterModule, SDXLIpAdapterStateDictConverter
|
3 |
+
from transformers import CLIPImageProcessor
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
class IpAdapterCLIPImageEmbedder(SVDImageEncoder):
|
8 |
+
def __init__(self):
|
9 |
+
super().__init__()
|
10 |
+
self.image_processor = CLIPImageProcessor()
|
11 |
+
|
12 |
+
def forward(self, image):
|
13 |
+
pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
|
14 |
+
pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype)
|
15 |
+
return super().forward(pixel_values)
|
16 |
+
|
17 |
+
|
18 |
+
class SDIpAdapter(torch.nn.Module):
|
19 |
+
def __init__(self):
|
20 |
+
super().__init__()
|
21 |
+
shape_list = [(768, 320)] * 2 + [(768, 640)] * 2 + [(768, 1280)] * 5 + [(768, 640)] * 3 + [(768, 320)] * 3 + [(768, 1280)] * 1
|
22 |
+
self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
|
23 |
+
self.image_proj = IpAdapterImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4)
|
24 |
+
self.set_full_adapter()
|
25 |
+
|
26 |
+
def set_full_adapter(self):
|
27 |
+
block_ids = [1, 4, 9, 12, 17, 20, 40, 43, 46, 50, 53, 56, 60, 63, 66, 29]
|
28 |
+
self.call_block_id = {(i, 0): j for j, i in enumerate(block_ids)}
|
29 |
+
|
30 |
+
def set_less_adapter(self):
|
31 |
+
# IP-Adapter for SD v1.5 doesn't support this feature.
|
32 |
+
self.set_full_adapter()
|
33 |
+
|
34 |
+
def forward(self, hidden_states, scale=1.0):
|
35 |
+
hidden_states = self.image_proj(hidden_states)
|
36 |
+
hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
|
37 |
+
ip_kv_dict = {}
|
38 |
+
for (block_id, transformer_id) in self.call_block_id:
|
39 |
+
ipadapter_id = self.call_block_id[(block_id, transformer_id)]
|
40 |
+
ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
|
41 |
+
if block_id not in ip_kv_dict:
|
42 |
+
ip_kv_dict[block_id] = {}
|
43 |
+
ip_kv_dict[block_id][transformer_id] = {
|
44 |
+
"ip_k": ip_k,
|
45 |
+
"ip_v": ip_v,
|
46 |
+
"scale": scale
|
47 |
+
}
|
48 |
+
return ip_kv_dict
|
49 |
+
|
50 |
+
@staticmethod
|
51 |
+
def state_dict_converter():
|
52 |
+
return SDIpAdapterStateDictConverter()
|
53 |
+
|
54 |
+
|
55 |
+
class SDIpAdapterStateDictConverter(SDXLIpAdapterStateDictConverter):
|
56 |
+
def __init__(self):
|
57 |
+
pass
|
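Note on sd_ipadapter.py above: it wires a frozen CLIP image encoder to a set of small key/value projection modules. image_proj expands the pooled image embedding into 4 extra context tokens, and each IpAdapterModule produces an (ip_k, ip_v) pair that the UNet cross-attention layers listed in call_block_id can mix in with the given scale. A shape-level sketch under assumed dimensions (the real IpAdapterImageProjModel and IpAdapterModule are defined in sdxl_ipadapter.py and include normalization; the layers below are hypothetical stand-ins):

import torch

clip_dim, cross_dim, num_tokens = 1024, 768, 4   # assumed: pooled CLIP dim -> SD cross-attention dim, 4 image tokens
image_proj = torch.nn.Linear(clip_dim, cross_dim * num_tokens)   # stand-in for IpAdapterImageProjModel
to_k = torch.nn.Linear(cross_dim, 320, bias=False)               # stand-in for one IpAdapterModule(768, 320)
to_v = torch.nn.Linear(cross_dim, 320, bias=False)

image_emb = torch.randn(1, clip_dim)                             # pooled CLIP image embedding
tokens = image_proj(image_emb).view(1, num_tokens, cross_dim)    # 4 extra "image prompt" tokens
ip_k, ip_v = to_k(tokens), to_v(tokens)                          # injected as additional K/V in cross-attention
print(ip_k.shape, ip_v.shape)                                    # torch.Size([1, 4, 320]) twice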
diffsynth/models/sd_motion.py
ADDED
@@ -0,0 +1,199 @@
1 |
+
from .sd_unet import SDUNet, Attention, GEGLU
|
2 |
+
import torch
|
3 |
+
from einops import rearrange, repeat
|
4 |
+
|
5 |
+
|
6 |
+
class TemporalTransformerBlock(torch.nn.Module):
|
7 |
+
|
8 |
+
def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32):
|
9 |
+
super().__init__()
|
10 |
+
|
11 |
+
# 1. Self-Attn
|
12 |
+
self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
|
13 |
+
self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
14 |
+
self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
|
15 |
+
|
16 |
+
# 2. Cross-Attn
|
17 |
+
self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
|
18 |
+
self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
19 |
+
self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
|
20 |
+
|
21 |
+
# 3. Feed-forward
|
22 |
+
self.norm3 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
23 |
+
self.act_fn = GEGLU(dim, dim * 4)
|
24 |
+
self.ff = torch.nn.Linear(dim * 4, dim)
|
25 |
+
|
26 |
+
|
27 |
+
def forward(self, hidden_states, batch_size=1):
|
28 |
+
|
29 |
+
# 1. Self-Attention
|
30 |
+
norm_hidden_states = self.norm1(hidden_states)
|
31 |
+
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
|
32 |
+
attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]])
|
33 |
+
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
|
34 |
+
hidden_states = attn_output + hidden_states
|
35 |
+
|
36 |
+
# 2. Cross-Attention
|
37 |
+
norm_hidden_states = self.norm2(hidden_states)
|
38 |
+
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
|
39 |
+
attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]])
|
40 |
+
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
|
41 |
+
hidden_states = attn_output + hidden_states
|
42 |
+
|
43 |
+
# 3. Feed-forward
|
44 |
+
norm_hidden_states = self.norm3(hidden_states)
|
45 |
+
ff_output = self.act_fn(norm_hidden_states)
|
46 |
+
ff_output = self.ff(ff_output)
|
47 |
+
hidden_states = ff_output + hidden_states
|
48 |
+
|
49 |
+
return hidden_states
|
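The block above is the core AnimateDiff-style trick: spatial tokens are regrouped so that attention runs across the frame axis instead of the spatial axis, with a learned positional embedding added per frame. A small sketch of the einops round trip (toy tensor, hypothetical sizes):

import torch
from einops import rearrange

batch_size, frames, tokens, channels = 2, 16, 64, 320
x = torch.randn(batch_size * frames, tokens, channels)           # "(b f) h c" as it arrives from the UNet

x_t = rearrange(x, "(b f) h c -> (b h) f c", b=batch_size)       # every spatial position becomes a length-f sequence
assert x_t.shape == (batch_size * tokens, frames, channels)
# ... temporal self-attention (plus pe1/pe2) is applied here over the f axis ...
x_back = rearrange(x_t, "(b h) f c -> (b f) h c", b=batch_size)  # restore the original layout for the residual add
assert torch.allclose(x, x_back)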
50 |
+
|
51 |
+
|
52 |
+
class TemporalBlock(torch.nn.Module):
|
53 |
+
|
54 |
+
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
|
55 |
+
super().__init__()
|
56 |
+
inner_dim = num_attention_heads * attention_head_dim
|
57 |
+
|
58 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
|
59 |
+
self.proj_in = torch.nn.Linear(in_channels, inner_dim)
|
60 |
+
|
61 |
+
self.transformer_blocks = torch.nn.ModuleList([
|
62 |
+
TemporalTransformerBlock(
|
63 |
+
inner_dim,
|
64 |
+
num_attention_heads,
|
65 |
+
attention_head_dim
|
66 |
+
)
|
67 |
+
for d in range(num_layers)
|
68 |
+
])
|
69 |
+
|
70 |
+
self.proj_out = torch.nn.Linear(inner_dim, in_channels)
|
71 |
+
|
72 |
+
def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1):
|
73 |
+
batch, _, height, width = hidden_states.shape
|
74 |
+
residual = hidden_states
|
75 |
+
|
76 |
+
hidden_states = self.norm(hidden_states)
|
77 |
+
inner_dim = hidden_states.shape[1]
|
78 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
79 |
+
hidden_states = self.proj_in(hidden_states)
|
80 |
+
|
81 |
+
for block in self.transformer_blocks:
|
82 |
+
hidden_states = block(
|
83 |
+
hidden_states,
|
84 |
+
batch_size=batch_size
|
85 |
+
)
|
86 |
+
|
87 |
+
hidden_states = self.proj_out(hidden_states)
|
88 |
+
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
89 |
+
hidden_states = hidden_states + residual
|
90 |
+
|
91 |
+
return hidden_states, time_emb, text_emb, res_stack
|
92 |
+
|
93 |
+
|
94 |
+
class SDMotionModel(torch.nn.Module):
|
95 |
+
def __init__(self):
|
96 |
+
super().__init__()
|
97 |
+
self.motion_modules = torch.nn.ModuleList([
|
98 |
+
TemporalBlock(8, 40, 320, eps=1e-6),
|
99 |
+
TemporalBlock(8, 40, 320, eps=1e-6),
|
100 |
+
TemporalBlock(8, 80, 640, eps=1e-6),
|
101 |
+
TemporalBlock(8, 80, 640, eps=1e-6),
|
102 |
+
TemporalBlock(8, 160, 1280, eps=1e-6),
|
103 |
+
TemporalBlock(8, 160, 1280, eps=1e-6),
|
104 |
+
TemporalBlock(8, 160, 1280, eps=1e-6),
|
105 |
+
TemporalBlock(8, 160, 1280, eps=1e-6),
|
106 |
+
TemporalBlock(8, 160, 1280, eps=1e-6),
|
107 |
+
TemporalBlock(8, 160, 1280, eps=1e-6),
|
108 |
+
TemporalBlock(8, 160, 1280, eps=1e-6),
|
109 |
+
TemporalBlock(8, 160, 1280, eps=1e-6),
|
110 |
+
TemporalBlock(8, 160, 1280, eps=1e-6),
|
111 |
+
TemporalBlock(8, 160, 1280, eps=1e-6),
|
112 |
+
TemporalBlock(8, 160, 1280, eps=1e-6),
|
113 |
+
TemporalBlock(8, 80, 640, eps=1e-6),
|
114 |
+
TemporalBlock(8, 80, 640, eps=1e-6),
|
115 |
+
TemporalBlock(8, 80, 640, eps=1e-6),
|
116 |
+
TemporalBlock(8, 40, 320, eps=1e-6),
|
117 |
+
TemporalBlock(8, 40, 320, eps=1e-6),
|
118 |
+
TemporalBlock(8, 40, 320, eps=1e-6),
|
119 |
+
])
|
120 |
+
self.call_block_id = {
|
121 |
+
1: 0,
|
122 |
+
4: 1,
|
123 |
+
9: 2,
|
124 |
+
12: 3,
|
125 |
+
17: 4,
|
126 |
+
20: 5,
|
127 |
+
24: 6,
|
128 |
+
26: 7,
|
129 |
+
29: 8,
|
130 |
+
32: 9,
|
131 |
+
34: 10,
|
132 |
+
36: 11,
|
133 |
+
40: 12,
|
134 |
+
43: 13,
|
135 |
+
46: 14,
|
136 |
+
50: 15,
|
137 |
+
53: 16,
|
138 |
+
56: 17,
|
139 |
+
60: 18,
|
140 |
+
63: 19,
|
141 |
+
66: 20
|
142 |
+
}
|
143 |
+
|
144 |
+
def forward(self):
|
145 |
+
pass
|
146 |
+
|
147 |
+
@staticmethod
|
148 |
+
def state_dict_converter():
|
149 |
+
return SDMotionModelStateDictConverter()
|
150 |
+
|
151 |
+
|
152 |
+
class SDMotionModelStateDictConverter:
|
153 |
+
def __init__(self):
|
154 |
+
pass
|
155 |
+
|
156 |
+
def from_diffusers(self, state_dict):
|
157 |
+
rename_dict = {
|
158 |
+
"norm": "norm",
|
159 |
+
"proj_in": "proj_in",
|
160 |
+
"transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
|
161 |
+
"transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
|
162 |
+
"transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
|
163 |
+
"transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
|
164 |
+
"transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
|
165 |
+
"transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
|
166 |
+
"transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
|
167 |
+
"transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
|
168 |
+
"transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
|
169 |
+
"transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
|
170 |
+
"transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
|
171 |
+
"transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
|
172 |
+
"transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
|
173 |
+
"transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
|
174 |
+
"transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
|
175 |
+
"proj_out": "proj_out",
|
176 |
+
}
|
177 |
+
name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
|
178 |
+
name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
|
179 |
+
name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
|
180 |
+
state_dict_ = {}
|
181 |
+
last_prefix, module_id = "", -1
|
182 |
+
for name in name_list:
|
183 |
+
names = name.split(".")
|
184 |
+
prefix_index = names.index("temporal_transformer") + 1
|
185 |
+
prefix = ".".join(names[:prefix_index])
|
186 |
+
if prefix != last_prefix:
|
187 |
+
last_prefix = prefix
|
188 |
+
module_id += 1
|
189 |
+
middle_name = ".".join(names[prefix_index:-1])
|
190 |
+
suffix = names[-1]
|
191 |
+
if "pos_encoder" in names:
|
192 |
+
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
|
193 |
+
else:
|
194 |
+
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
|
195 |
+
state_dict_[rename] = state_dict[name]
|
196 |
+
return state_dict_
|
197 |
+
|
198 |
+
def from_civitai(self, state_dict):
|
199 |
+
return self.from_diffusers(state_dict)
|
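Note on the converter above: it does not hard-code every key. It sorts the down/mid/up motion-module keys, truncates each name at "temporal_transformer" to get a per-module prefix, and bumps a running module index whenever the prefix changes, so the 21 temporal blocks are numbered in traversal order. A condensed sketch of that grouping step on a few hypothetical AnimateDiff-style key names:

names = [
    "down_blocks.0.motion_modules.0.temporal_transformer.proj_in.weight",
    "down_blocks.0.motion_modules.0.temporal_transformer.proj_out.weight",
    "down_blocks.0.motion_modules.1.temporal_transformer.proj_in.weight",
]
last_prefix, module_id = "", -1
for name in names:
    parts = name.split(".")
    prefix = ".".join(parts[:parts.index("temporal_transformer") + 1])
    if prefix != last_prefix:             # a new temporal_transformer prefix starts a new motion module
        last_prefix, module_id = prefix, module_id + 1
    print(module_id, name)                # prints 0, 0, 1 -> motion_modules.0, .0, .1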
diffsynth/models/sd_text_encoder.py
ADDED
@@ -0,0 +1,321 @@
1 |
+
import torch
|
2 |
+
from .attention import Attention
|
3 |
+
|
4 |
+
|
5 |
+
class CLIPEncoderLayer(torch.nn.Module):
|
6 |
+
def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
|
7 |
+
super().__init__()
|
8 |
+
self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
|
9 |
+
self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
|
10 |
+
self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
|
11 |
+
self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
|
12 |
+
self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
|
13 |
+
|
14 |
+
self.use_quick_gelu = use_quick_gelu
|
15 |
+
|
16 |
+
def quickGELU(self, x):
|
17 |
+
return x * torch.sigmoid(1.702 * x)
|
18 |
+
|
19 |
+
def forward(self, hidden_states, attn_mask=None):
|
20 |
+
residual = hidden_states
|
21 |
+
|
22 |
+
hidden_states = self.layer_norm1(hidden_states)
|
23 |
+
hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
|
24 |
+
hidden_states = residual + hidden_states
|
25 |
+
|
26 |
+
residual = hidden_states
|
27 |
+
hidden_states = self.layer_norm2(hidden_states)
|
28 |
+
hidden_states = self.fc1(hidden_states)
|
29 |
+
if self.use_quick_gelu:
|
30 |
+
hidden_states = self.quickGELU(hidden_states)
|
31 |
+
else:
|
32 |
+
hidden_states = torch.nn.functional.gelu(hidden_states)
|
33 |
+
hidden_states = self.fc2(hidden_states)
|
34 |
+
hidden_states = residual + hidden_states
|
35 |
+
|
36 |
+
return hidden_states
|
37 |
+
|
38 |
+
|
39 |
+
class SDTextEncoder(torch.nn.Module):
|
40 |
+
def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
|
41 |
+
super().__init__()
|
42 |
+
|
43 |
+
# token_embedding
|
44 |
+
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
|
45 |
+
|
46 |
+
# position_embeds (This is a fixed tensor)
|
47 |
+
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
|
48 |
+
|
49 |
+
# encoders
|
50 |
+
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
|
51 |
+
|
52 |
+
# attn_mask
|
53 |
+
self.attn_mask = self.attention_mask(max_position_embeddings)
|
54 |
+
|
55 |
+
# final_layer_norm
|
56 |
+
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
|
57 |
+
|
58 |
+
def attention_mask(self, length):
|
59 |
+
mask = torch.empty(length, length)
|
60 |
+
mask.fill_(float("-inf"))
|
61 |
+
mask.triu_(1)
|
62 |
+
return mask
|
63 |
+
|
64 |
+
def forward(self, input_ids, clip_skip=1):
|
65 |
+
embeds = self.token_embedding(input_ids) + self.position_embeds
|
66 |
+
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
67 |
+
for encoder_id, encoder in enumerate(self.encoders):
|
68 |
+
embeds = encoder(embeds, attn_mask=attn_mask)
|
69 |
+
if encoder_id + clip_skip == len(self.encoders):
|
70 |
+
break
|
71 |
+
embeds = self.final_layer_norm(embeds)
|
72 |
+
return embeds
|
73 |
+
|
74 |
+
@staticmethod
|
75 |
+
def state_dict_converter():
|
76 |
+
return SDTextEncoderStateDictConverter()
|
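Two details of SDTextEncoder above worth noting: attention_mask builds the standard additive causal mask (zeros on and below the diagonal, -inf above it), and clip_skip stops the encoder stack early so hidden states from an earlier layer are returned, with clip_skip=1 meaning all 12 layers run. A quick illustration of both (toy length, not tied to the real tokenizer):

import torch

length = 4
mask = torch.empty(length, length)
mask.fill_(float("-inf"))
mask.triu_(1)                      # upper triangle stays -inf, diagonal and below become 0
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])

num_encoder_layers, clip_skip = 12, 2
layers_run = num_encoder_layers - clip_skip + 1   # the loop breaks once encoder_id + clip_skip == 12
print(layers_run)                                 # 11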
77 |
+
|
78 |
+
|
79 |
+
class SDTextEncoderStateDictConverter:
|
80 |
+
def __init__(self):
|
81 |
+
pass
|
82 |
+
|
83 |
+
def from_diffusers(self, state_dict):
|
84 |
+
rename_dict = {
|
85 |
+
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
86 |
+
"text_model.embeddings.position_embedding.weight": "position_embeds",
|
87 |
+
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
|
88 |
+
"text_model.final_layer_norm.bias": "final_layer_norm.bias"
|
89 |
+
}
|
90 |
+
attn_rename_dict = {
|
91 |
+
"self_attn.q_proj": "attn.to_q",
|
92 |
+
"self_attn.k_proj": "attn.to_k",
|
93 |
+
"self_attn.v_proj": "attn.to_v",
|
94 |
+
"self_attn.out_proj": "attn.to_out",
|
95 |
+
"layer_norm1": "layer_norm1",
|
96 |
+
"layer_norm2": "layer_norm2",
|
97 |
+
"mlp.fc1": "fc1",
|
98 |
+
"mlp.fc2": "fc2",
|
99 |
+
}
|
100 |
+
state_dict_ = {}
|
101 |
+
for name in state_dict:
|
102 |
+
if name in rename_dict:
|
103 |
+
param = state_dict[name]
|
104 |
+
if name == "text_model.embeddings.position_embedding.weight":
|
105 |
+
param = param.reshape((1, param.shape[0], param.shape[1]))
|
106 |
+
state_dict_[rename_dict[name]] = param
|
107 |
+
elif name.startswith("text_model.encoder.layers."):
|
108 |
+
param = state_dict[name]
|
109 |
+
names = name.split(".")
|
110 |
+
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
|
111 |
+
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
|
112 |
+
state_dict_[name_] = param
|
113 |
+
return state_dict_
|
114 |
+
|
115 |
+
def from_civitai(self, state_dict):
|
116 |
+
rename_dict = {
|
117 |
+
"cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
118 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
|
119 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
|
120 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
|
121 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
|
122 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
|
123 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
|
124 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
|
125 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
|
126 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
|
127 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
|
128 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
|
129 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
|
130 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
|
131 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
|
132 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
|
133 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
|
134 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
|
135 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
|
136 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
|
137 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
|
138 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
|
139 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
|
140 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
|
141 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
|
142 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
|
143 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
|
144 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
|
145 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
|
146 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
|
147 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
|
148 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
|
149 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
|
150 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
|
151 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
|
152 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
|
153 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
|
154 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
|
155 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
|
156 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
|
157 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
|
158 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
|
159 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
|
160 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
|
161 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
|
162 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
|
163 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
|
164 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
|
165 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
|
166 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias",
|
167 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight",
|
168 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias",
|
169 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight",
|
170 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias",
|
171 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight",
|
172 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias",
|
173 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight",
|
174 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias",
|
175 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight",
|
176 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias",
|
177 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight",
|
178 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias",
|
179 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight",
|
180 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias",
|
181 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight",
|
182 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
|
183 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
|
184 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
|
185 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
|
186 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
|
187 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
|
188 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
|
189 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
|
190 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
|
191 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
|
192 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
|
193 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
|
194 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
|
195 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
|
196 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
|
197 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
|
198 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
|
199 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
|
200 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
|
201 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
|
202 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
|
203 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
|
204 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
|
205 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
|
206 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
|
207 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
|
208 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
|
209 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
|
210 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
|
211 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
|
212 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
|
213 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
|
214 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
|
215 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
|
216 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
|
217 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
|
218 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
|
219 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
|
220 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
|
221 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
|
222 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
|
223 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
|
224 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
|
            "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
            "cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias",
            "cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight",
            "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds"
        }
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight":
                    param = param.reshape((1, param.shape[0], param.shape[1]))
                state_dict_[rename_dict[name]] = param
        return state_dict_
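
A minimal usage sketch for the converter above (not part of the committed file): remapping the text-encoder weights of a single-file Stable Diffusion checkpoint. The SDTextEncoder class name, its state_dict_converter() entry point, and the checkpoint path are assumptions here, mirroring the pattern used by the other model modules in this commit.

# Hypothetical usage of the civitai-style key remapping shown above.
from safetensors.torch import load_file
from diffsynth.models.sd_text_encoder import SDTextEncoder  # assumed import path

state_dict = load_file("sd_v1-5.safetensors")          # placeholder checkpoint path
converter = SDTextEncoder.state_dict_converter()       # assumed API, as in the VAE modules below
renamed = converter.from_civitai(state_dict)           # only the mapped text-encoder keys survive

text_encoder = SDTextEncoder()
text_encoder.load_state_dict(renamed, strict=False)    # strict=False: keys outside the mapping keep their defaults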
diffsynth/models/sd_unet.py ADDED (diff too large to render; see raw diff)
diffsynth/models/sd_vae_decoder.py ADDED @@ -0,0 +1,336 @@
import torch
from .attention import Attention
from .sd_unet import ResnetBlock, UpSampler
from .tiler import TileWorker


class VAEAttentionBlock(torch.nn.Module):

    def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)

        self.transformer_blocks = torch.nn.ModuleList([
            Attention(
                inner_dim,
                num_attention_heads,
                attention_head_dim,
                bias_q=True,
                bias_kv=True,
                bias_out=True
            )
            for d in range(num_layers)
        ])

    def forward(self, hidden_states, time_emb, text_emb, res_stack):
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states)

        hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
        hidden_states = hidden_states + residual

        return hidden_states, time_emb, text_emb, res_stack


class SDVAEDecoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scaling_factor = 0.18215
        self.post_quant_conv = torch.nn.Conv2d(4, 4, kernel_size=1)
        self.conv_in = torch.nn.Conv2d(4, 512, kernel_size=3, padding=1)

        self.blocks = torch.nn.ModuleList([
            # UNetMidBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            # UpDecoderBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            UpSampler(512),
            # UpDecoderBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            UpSampler(512),
            # UpDecoderBlock2D
            ResnetBlock(512, 256, eps=1e-6),
            ResnetBlock(256, 256, eps=1e-6),
            ResnetBlock(256, 256, eps=1e-6),
            UpSampler(256),
            # UpDecoderBlock2D
            ResnetBlock(256, 128, eps=1e-6),
            ResnetBlock(128, 128, eps=1e-6),
            ResnetBlock(128, 128, eps=1e-6),
        ])

        self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-5)
        self.conv_act = torch.nn.SiLU()
        self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)

    def tiled_forward(self, sample, tile_size=64, tile_stride=32):
        hidden_states = TileWorker().tiled_forward(
            lambda x: self.forward(x),
            sample,
            tile_size,
            tile_stride,
            tile_device=sample.device,
            tile_dtype=sample.dtype
        )
        return hidden_states

    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        original_dtype = sample.dtype
        sample = sample.to(dtype=next(iter(self.parameters())).dtype)
        # For VAE Decoder, we do not need to apply the tiler on each layer.
        if tiled:
            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)

        # 1. pre-process
        sample = sample / self.scaling_factor
        hidden_states = self.post_quant_conv(sample)
        hidden_states = self.conv_in(hidden_states)
        time_emb = None
        text_emb = None
        res_stack = None

        # 2. blocks
        for i, block in enumerate(self.blocks):
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)

        # 3. output
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        hidden_states = hidden_states.to(original_dtype)

        return hidden_states

    @staticmethod
    def state_dict_converter():
        return SDVAEDecoderStateDictConverter()


class SDVAEDecoderStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        # architecture
        block_types = [
            'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock',
            'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
            'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
            'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
            'ResnetBlock', 'ResnetBlock', 'ResnetBlock'
        ]

        # Rename each parameter
        local_rename_dict = {
            "post_quant_conv": "post_quant_conv",
            "decoder.conv_in": "conv_in",
            "decoder.mid_block.attentions.0.group_norm": "blocks.1.norm",
            "decoder.mid_block.attentions.0.to_q": "blocks.1.transformer_blocks.0.to_q",
            "decoder.mid_block.attentions.0.to_k": "blocks.1.transformer_blocks.0.to_k",
            "decoder.mid_block.attentions.0.to_v": "blocks.1.transformer_blocks.0.to_v",
            "decoder.mid_block.attentions.0.to_out.0": "blocks.1.transformer_blocks.0.to_out",
            "decoder.mid_block.resnets.0.norm1": "blocks.0.norm1",
            "decoder.mid_block.resnets.0.conv1": "blocks.0.conv1",
            "decoder.mid_block.resnets.0.norm2": "blocks.0.norm2",
            "decoder.mid_block.resnets.0.conv2": "blocks.0.conv2",
            "decoder.mid_block.resnets.1.norm1": "blocks.2.norm1",
            "decoder.mid_block.resnets.1.conv1": "blocks.2.conv1",
            "decoder.mid_block.resnets.1.norm2": "blocks.2.norm2",
            "decoder.mid_block.resnets.1.conv2": "blocks.2.conv2",
            "decoder.conv_norm_out": "conv_norm_out",
            "decoder.conv_out": "conv_out",
        }
        name_list = sorted([name for name in state_dict])
        rename_dict = {}
        block_id = {"ResnetBlock": 2, "DownSampler": 2, "UpSampler": 2}
        last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
        for name in name_list:
            names = name.split(".")
            name_prefix = ".".join(names[:-1])
            if name_prefix in local_rename_dict:
                rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
            elif name.startswith("decoder.up_blocks"):
                block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
                block_type_with_id = ".".join(names[:5])
                if block_type_with_id != last_block_type_with_id[block_type]:
                    block_id[block_type] += 1
                last_block_type_with_id[block_type] = block_type_with_id
                while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
                    block_id[block_type] += 1
                block_type_with_id = ".".join(names[:5])
                names = ["blocks", str(block_id[block_type])] + names[5:]
                rename_dict[name] = ".".join(names)

        # Convert state_dict
        state_dict_ = {}
        for name, param in state_dict.items():
            if name in rename_dict:
                state_dict_[rename_dict[name]] = param
        return state_dict_

    def from_civitai(self, state_dict):
        rename_dict = {
            "first_stage_model.decoder.conv_in.bias": "conv_in.bias",
            "first_stage_model.decoder.conv_in.weight": "conv_in.weight",
            "first_stage_model.decoder.conv_out.bias": "conv_out.bias",
            "first_stage_model.decoder.conv_out.weight": "conv_out.weight",
            "first_stage_model.decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
            "first_stage_model.decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
            "first_stage_model.decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
            "first_stage_model.decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
            "first_stage_model.decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
            "first_stage_model.decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
            "first_stage_model.decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
            "first_stage_model.decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
            "first_stage_model.decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
            "first_stage_model.decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
            "first_stage_model.decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
            "first_stage_model.decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
            "first_stage_model.decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
            "first_stage_model.decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
            "first_stage_model.decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
            "first_stage_model.decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
            "first_stage_model.decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
            "first_stage_model.decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
            "first_stage_model.decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
            "first_stage_model.decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
            "first_stage_model.decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
            "first_stage_model.decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
            "first_stage_model.decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
            "first_stage_model.decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
            "first_stage_model.decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
            "first_stage_model.decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
            "first_stage_model.decoder.norm_out.bias": "conv_norm_out.bias",
            "first_stage_model.decoder.norm_out.weight": "conv_norm_out.weight",
            "first_stage_model.decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
            "first_stage_model.decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
            "first_stage_model.decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
            "first_stage_model.decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
            "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
            "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
            "first_stage_model.decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
            "first_stage_model.decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
            "first_stage_model.decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
            "first_stage_model.decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
            "first_stage_model.decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
            "first_stage_model.decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
            "first_stage_model.decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
            "first_stage_model.decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
            "first_stage_model.decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
            "first_stage_model.decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
            "first_stage_model.decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
            "first_stage_model.decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
            "first_stage_model.decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
            "first_stage_model.decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
            "first_stage_model.decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
            "first_stage_model.decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
            "first_stage_model.decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
            "first_stage_model.decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
            "first_stage_model.decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
            "first_stage_model.decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
            "first_stage_model.decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
            "first_stage_model.decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
            "first_stage_model.decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
            "first_stage_model.decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
            "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
            "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
            "first_stage_model.decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
            "first_stage_model.decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
            "first_stage_model.decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
            "first_stage_model.decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
            "first_stage_model.decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
            "first_stage_model.decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
            "first_stage_model.decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
            "first_stage_model.decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
            "first_stage_model.decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
            "first_stage_model.decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
            "first_stage_model.decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
            "first_stage_model.decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
            "first_stage_model.decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
            "first_stage_model.decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
            "first_stage_model.decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
            "first_stage_model.decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
            "first_stage_model.decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
            "first_stage_model.decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
            "first_stage_model.decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
            "first_stage_model.decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
            "first_stage_model.decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
            "first_stage_model.decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
            "first_stage_model.decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
            "first_stage_model.decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
            "first_stage_model.decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
            "first_stage_model.decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
            "first_stage_model.decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
            "first_stage_model.decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
            "first_stage_model.decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
            "first_stage_model.decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
            "first_stage_model.decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
            "first_stage_model.decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
            "first_stage_model.decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
            "first_stage_model.decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
            "first_stage_model.decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
            "first_stage_model.decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
            "first_stage_model.decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
            "first_stage_model.decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
            "first_stage_model.decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
            "first_stage_model.decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
            "first_stage_model.decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
            "first_stage_model.decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
            "first_stage_model.decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
            "first_stage_model.decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
            "first_stage_model.decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
            "first_stage_model.decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
            "first_stage_model.decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
            "first_stage_model.decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
            "first_stage_model.decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
            "first_stage_model.decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
            "first_stage_model.decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
            "first_stage_model.decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
            "first_stage_model.decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
            "first_stage_model.decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
            "first_stage_model.decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
            "first_stage_model.decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
            "first_stage_model.decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
            "first_stage_model.decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
            "first_stage_model.decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
            "first_stage_model.decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
            "first_stage_model.decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
            "first_stage_model.decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
            "first_stage_model.decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
            "first_stage_model.decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
            "first_stage_model.decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
            "first_stage_model.decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
            "first_stage_model.decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
            "first_stage_model.decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
            "first_stage_model.decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
            "first_stage_model.decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
            "first_stage_model.decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
            "first_stage_model.decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
            "first_stage_model.decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
            "first_stage_model.decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
            "first_stage_model.post_quant_conv.bias": "post_quant_conv.bias",
            "first_stage_model.post_quant_conv.weight": "post_quant_conv.weight",
        }
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if "transformer_blocks" in rename_dict[name]:
                    param = param.squeeze()
                state_dict_[rename_dict[name]] = param
        return state_dict_
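
A minimal decoding sketch for the module above (not part of the committed file). The latent shape and value range are assumptions: SD 1.x latents have 4 channels at 1/8 resolution, so a 4x64x64 latent decodes to a 3x512x512 image roughly in [-1, 1].

import torch

decoder = SDVAEDecoder().eval()          # randomly initialized here; real weights come from the converters above
latents = torch.randn(1, 4, 64, 64)      # placeholder latent batch
with torch.no_grad():
    image = decoder(latents)             # -> (1, 3, 512, 512)
    # Tiled decoding trades extra compute for lower peak memory on large latents.
    image_tiled = decoder(latents, tiled=True, tile_size=64, tile_stride=32)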
diffsynth/models/sd_vae_encoder.py ADDED @@ -0,0 +1,282 @@
import torch
from .sd_unet import ResnetBlock, DownSampler
from .sd_vae_decoder import VAEAttentionBlock
from .tiler import TileWorker
from einops import rearrange


class SDVAEEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scaling_factor = 0.18215
        self.quant_conv = torch.nn.Conv2d(8, 8, kernel_size=1)
        self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)

        self.blocks = torch.nn.ModuleList([
            # DownEncoderBlock2D
            ResnetBlock(128, 128, eps=1e-6),
            ResnetBlock(128, 128, eps=1e-6),
            DownSampler(128, padding=0, extra_padding=True),
            # DownEncoderBlock2D
            ResnetBlock(128, 256, eps=1e-6),
            ResnetBlock(256, 256, eps=1e-6),
            DownSampler(256, padding=0, extra_padding=True),
            # DownEncoderBlock2D
            ResnetBlock(256, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            DownSampler(512, padding=0, extra_padding=True),
            # DownEncoderBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            # UNetMidBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
        ])

        self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
        self.conv_act = torch.nn.SiLU()
        self.conv_out = torch.nn.Conv2d(512, 8, kernel_size=3, padding=1)

    def tiled_forward(self, sample, tile_size=64, tile_stride=32):
        hidden_states = TileWorker().tiled_forward(
            lambda x: self.forward(x),
            sample,
            tile_size,
            tile_stride,
            tile_device=sample.device,
            tile_dtype=sample.dtype
        )
        return hidden_states

    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        original_dtype = sample.dtype
        sample = sample.to(dtype=next(iter(self.parameters())).dtype)
        # For VAE Encoder, we do not need to apply the tiler on each layer.
        if tiled:
            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)

        # 1. pre-process
        hidden_states = self.conv_in(sample)
        time_emb = None
        text_emb = None
        res_stack = None

        # 2. blocks
        for i, block in enumerate(self.blocks):
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)

        # 3. output
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        hidden_states = self.quant_conv(hidden_states)
        hidden_states = hidden_states[:, :4]
        hidden_states *= self.scaling_factor
        hidden_states = hidden_states.to(original_dtype)

        return hidden_states

    def encode_video(self, sample, batch_size=8):
        B = sample.shape[0]
        hidden_states = []

        for i in range(0, sample.shape[2], batch_size):

            j = min(i + batch_size, sample.shape[2])
            sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")

            hidden_states_batch = self(sample_batch)
            hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)

            hidden_states.append(hidden_states_batch)

        hidden_states = torch.concat(hidden_states, dim=2)
        return hidden_states

    @staticmethod
    def state_dict_converter():
        return SDVAEEncoderStateDictConverter()


class SDVAEEncoderStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        # architecture
        block_types = [
            'ResnetBlock', 'ResnetBlock', 'DownSampler',
            'ResnetBlock', 'ResnetBlock', 'DownSampler',
            'ResnetBlock', 'ResnetBlock', 'DownSampler',
            'ResnetBlock', 'ResnetBlock',
            'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock'
        ]

        # Rename each parameter
        local_rename_dict = {
            "quant_conv": "quant_conv",
            "encoder.conv_in": "conv_in",
            "encoder.mid_block.attentions.0.group_norm": "blocks.12.norm",
            "encoder.mid_block.attentions.0.to_q": "blocks.12.transformer_blocks.0.to_q",
            "encoder.mid_block.attentions.0.to_k": "blocks.12.transformer_blocks.0.to_k",
            "encoder.mid_block.attentions.0.to_v": "blocks.12.transformer_blocks.0.to_v",
            "encoder.mid_block.attentions.0.to_out.0": "blocks.12.transformer_blocks.0.to_out",
            "encoder.mid_block.resnets.0.norm1": "blocks.11.norm1",
            "encoder.mid_block.resnets.0.conv1": "blocks.11.conv1",
            "encoder.mid_block.resnets.0.norm2": "blocks.11.norm2",
            "encoder.mid_block.resnets.0.conv2": "blocks.11.conv2",
            "encoder.mid_block.resnets.1.norm1": "blocks.13.norm1",
            "encoder.mid_block.resnets.1.conv1": "blocks.13.conv1",
            "encoder.mid_block.resnets.1.norm2": "blocks.13.norm2",
            "encoder.mid_block.resnets.1.conv2": "blocks.13.conv2",
            "encoder.conv_norm_out": "conv_norm_out",
            "encoder.conv_out": "conv_out",
        }
        name_list = sorted([name for name in state_dict])
        rename_dict = {}
        block_id = {"ResnetBlock": -1, "DownSampler": -1, "UpSampler": -1}
        last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
        for name in name_list:
            names = name.split(".")
            name_prefix = ".".join(names[:-1])
            if name_prefix in local_rename_dict:
                rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
            elif name.startswith("encoder.down_blocks"):
                block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
                block_type_with_id = ".".join(names[:5])
                if block_type_with_id != last_block_type_with_id[block_type]:
                    block_id[block_type] += 1
                last_block_type_with_id[block_type] = block_type_with_id
                while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
                    block_id[block_type] += 1
                block_type_with_id = ".".join(names[:5])
                names = ["blocks", str(block_id[block_type])] + names[5:]
                rename_dict[name] = ".".join(names)

        # Convert state_dict
        state_dict_ = {}
        for name, param in state_dict.items():
            if name in rename_dict:
                state_dict_[rename_dict[name]] = param
        return state_dict_
+
|
164 |
+
def from_civitai(self, state_dict):
|
165 |
+
rename_dict = {
|
166 |
+
"first_stage_model.encoder.conv_in.bias": "conv_in.bias",
|
167 |
+
"first_stage_model.encoder.conv_in.weight": "conv_in.weight",
|
168 |
+
"first_stage_model.encoder.conv_out.bias": "conv_out.bias",
|
169 |
+
"first_stage_model.encoder.conv_out.weight": "conv_out.weight",
|
170 |
+
"first_stage_model.encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
|
171 |
+
"first_stage_model.encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
|
172 |
+
"first_stage_model.encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
|
173 |
+
"first_stage_model.encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
|
174 |
+
"first_stage_model.encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
|
175 |
+
"first_stage_model.encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
|
176 |
+
"first_stage_model.encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
|
177 |
+
"first_stage_model.encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
|
178 |
+
"first_stage_model.encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
|
179 |
+
"first_stage_model.encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
|
180 |
+
"first_stage_model.encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
|
181 |
+
"first_stage_model.encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
|
182 |
+
"first_stage_model.encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
|
183 |
+
"first_stage_model.encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
|
184 |
+
"first_stage_model.encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
|
185 |
+
"first_stage_model.encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
|
186 |
+
"first_stage_model.encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
|
187 |
+
"first_stage_model.encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
|
188 |
+
"first_stage_model.encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
|
189 |
+
"first_stage_model.encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
|
190 |
+
"first_stage_model.encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
|
191 |
+
"first_stage_model.encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
|
192 |
+
"first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
|
193 |
+
"first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
|
194 |
+
"first_stage_model.encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
|
195 |
+
"first_stage_model.encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
|
196 |
+
"first_stage_model.encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
|
197 |
+
"first_stage_model.encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
|
198 |
+
"first_stage_model.encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
|
199 |
+
"first_stage_model.encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
|
200 |
+
"first_stage_model.encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
|
201 |
+
"first_stage_model.encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
|
202 |
+
"first_stage_model.encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
|
203 |
+
"first_stage_model.encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
|
204 |
+
"first_stage_model.encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
|
205 |
+
"first_stage_model.encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
|
206 |
+
"first_stage_model.encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
|
207 |
+
"first_stage_model.encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
|
208 |
+
"first_stage_model.encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
|
209 |
+
"first_stage_model.encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
|
210 |
+
"first_stage_model.encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
|
211 |
+
"first_stage_model.encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
|
212 |
+
"first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
|
213 |
+
"first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
|
214 |
+
"first_stage_model.encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
|
215 |
+
"first_stage_model.encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
|
216 |
+
"first_stage_model.encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
|
217 |
+
"first_stage_model.encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
|
218 |
+
"first_stage_model.encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
|
219 |
+
"first_stage_model.encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
|
220 |
+
"first_stage_model.encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
|
221 |
+
"first_stage_model.encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
|
222 |
+
"first_stage_model.encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
|
223 |
+
"first_stage_model.encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
|
224 |
+
"first_stage_model.encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
|
225 |
+
"first_stage_model.encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
|
226 |
+
"first_stage_model.encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
|
227 |
+
"first_stage_model.encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
|
228 |
+
"first_stage_model.encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
|
229 |
+
"first_stage_model.encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
|
230 |
+
"first_stage_model.encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
|
231 |
+
"first_stage_model.encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
|
232 |
+
"first_stage_model.encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
|
233 |
+
"first_stage_model.encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
|
234 |
+
"first_stage_model.encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
|
235 |
+
"first_stage_model.encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
|
236 |
+
"first_stage_model.encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
|
237 |
+
"first_stage_model.encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
|
238 |
+
"first_stage_model.encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
|
239 |
+
"first_stage_model.encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
|
240 |
+
"first_stage_model.encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
|
241 |
+
"first_stage_model.encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
|
242 |
+
"first_stage_model.encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
|
243 |
+
"first_stage_model.encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
|
244 |
+
"first_stage_model.encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
|
245 |
+
"first_stage_model.encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
|
246 |
+
"first_stage_model.encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
|
247 |
+
"first_stage_model.encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
|
248 |
+
"first_stage_model.encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
|
249 |
+
"first_stage_model.encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
|
250 |
+
"first_stage_model.encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
|
251 |
+
"first_stage_model.encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
|
252 |
+
"first_stage_model.encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
|
253 |
+
"first_stage_model.encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
|
254 |
+
"first_stage_model.encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
|
255 |
+
"first_stage_model.encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
|
256 |
+
"first_stage_model.encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
|
257 |
+
"first_stage_model.encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
|
258 |
+
"first_stage_model.encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
|
259 |
+
"first_stage_model.encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
|
260 |
+
"first_stage_model.encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
|
261 |
+
"first_stage_model.encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
|
262 |
+
"first_stage_model.encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
|
263 |
+
"first_stage_model.encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
|
264 |
+
"first_stage_model.encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
|
265 |
+
"first_stage_model.encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
|
266 |
+
"first_stage_model.encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
|
267 |
+
"first_stage_model.encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
|
268 |
+
"first_stage_model.encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
|
269 |
+
"first_stage_model.encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
|
270 |
+
"first_stage_model.encoder.norm_out.bias": "conv_norm_out.bias",
|
271 |
+
"first_stage_model.encoder.norm_out.weight": "conv_norm_out.weight",
|
272 |
+
"first_stage_model.quant_conv.bias": "quant_conv.bias",
|
273 |
+
"first_stage_model.quant_conv.weight": "quant_conv.weight",
|
274 |
+
}
|
275 |
+
state_dict_ = {}
|
276 |
+
for name in state_dict:
|
277 |
+
if name in rename_dict:
|
278 |
+
param = state_dict[name]
|
279 |
+
if "transformer_blocks" in rename_dict[name]:
|
280 |
+
param = param.squeeze()
|
281 |
+
state_dict_[rename_dict[name]] = param
|
282 |
+
return state_dict_
|
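
A minimal encoding sketch for the module above (not part of the committed file). The input range and shapes are assumptions: images are expected roughly in [-1, 1], and encode_video takes a (B, C, T, H, W) tensor and processes it in frame batches.

import torch

encoder = SDVAEEncoder().eval()                      # randomly initialized for illustration
image = torch.rand(1, 3, 512, 512) * 2 - 1           # fake image in [-1, 1]
with torch.no_grad():
    latents = encoder(image)                         # -> (1, 4, 64, 64), already multiplied by scaling_factor
    video = torch.rand(1, 3, 16, 512, 512) * 2 - 1   # (B, C, T, H, W)
    video_latents = encoder.encode_video(video, batch_size=8)  # -> (1, 4, 16, 64, 64)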
diffsynth/models/sdxl_controlnet.py ADDED @@ -0,0 +1,318 @@
import torch
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
from .sdxl_unet import SDXLUNet
from .tiler import TileWorker
from .sd_controlnet import ControlNetConditioningLayer
from collections import OrderedDict


class QuickGELU(torch.nn.Module):

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(torch.nn.Module):

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = torch.nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = torch.nn.LayerNorm(d_model)
        self.mlp = torch.nn.Sequential(OrderedDict([
            ("c_fc", torch.nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", torch.nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = torch.nn.LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class SDXLControlNetUnion(torch.nn.Module):
    def __init__(self, global_pool=False):
        super().__init__()
        self.time_proj = Timesteps(320)
        self.time_embedding = torch.nn.Sequential(
            torch.nn.Linear(320, 1280),
            torch.nn.SiLU(),
            torch.nn.Linear(1280, 1280)
        )
        self.add_time_proj = Timesteps(256)
        self.add_time_embedding = torch.nn.Sequential(
            torch.nn.Linear(2816, 1280),
            torch.nn.SiLU(),
            torch.nn.Linear(1280, 1280)
        )
        self.control_type_proj = Timesteps(256)
        self.control_type_embedding = torch.nn.Sequential(
            torch.nn.Linear(256 * 8, 1280),
            torch.nn.SiLU(),
            torch.nn.Linear(1280, 1280)
        )
        self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)

        self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
        self.controlnet_transformer = ResidualAttentionBlock(320, 8)
        self.task_embedding = torch.nn.Parameter(torch.randn(8, 320))
        self.spatial_ch_projs = torch.nn.Linear(320, 320)

        self.blocks = torch.nn.ModuleList([
            # DownBlock2D
            ResnetBlock(320, 320, 1280),
            PushBlock(),
            ResnetBlock(320, 320, 1280),
            PushBlock(),
            DownSampler(320),
            PushBlock(),
            # CrossAttnDownBlock2D
            ResnetBlock(320, 640, 1280),
            AttentionBlock(10, 64, 640, 2, 2048),
            PushBlock(),
            ResnetBlock(640, 640, 1280),
            AttentionBlock(10, 64, 640, 2, 2048),
            PushBlock(),
            DownSampler(640),
            PushBlock(),
            # CrossAttnDownBlock2D
            ResnetBlock(640, 1280, 1280),
            AttentionBlock(20, 64, 1280, 10, 2048),
            PushBlock(),
            ResnetBlock(1280, 1280, 1280),
            AttentionBlock(20, 64, 1280, 10, 2048),
            PushBlock(),
            # UNetMidBlock2DCrossAttn
            ResnetBlock(1280, 1280, 1280),
            AttentionBlock(20, 64, 1280, 10, 2048),
            ResnetBlock(1280, 1280, 1280),
            PushBlock()
        ])

        self.controlnet_blocks = torch.nn.ModuleList([
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
            torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
            torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
            torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
        ])

        self.global_pool = global_pool

        # 0 -- openpose
        # 1 -- depth
        # 2 -- hed/pidi/scribble/ted
        # 3 -- canny/lineart/anime_lineart/mlsd
        # 4 -- normal
        # 5 -- segment
        # 6 -- tile
        # 7 -- repaint
        self.task_id = {
            "openpose": 0,
            "depth": 1,
            "softedge": 2,
            "canny": 3,
            "lineart": 3,
            "lineart_anime": 3,
            "tile": 6,
            "inpaint": 7
        }

    def fuse_condition_to_input(self, hidden_states, task_id, conditioning):
        controlnet_cond = self.controlnet_conv_in(conditioning)
        feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
        feat_seq = feat_seq + self.task_embedding[task_id]
        x = torch.stack([feat_seq, torch.mean(hidden_states, dim=(2, 3))], dim=1)
        x = self.controlnet_transformer(x)

        alpha = self.spatial_ch_projs(x[:,0]).unsqueeze(-1).unsqueeze(-1)
        controlnet_cond_fuser = controlnet_cond + alpha

        hidden_states = hidden_states + controlnet_cond_fuser
        return hidden_states

    def forward(
        self,
        sample, timestep, encoder_hidden_states,
        conditioning, processor_id, add_time_id, add_text_embeds,
        tiled=False, tile_size=64, tile_stride=32,
        unet: SDXLUNet = None,
        **kwargs
    ):
        task_id = self.task_id[processor_id]

        # 1. time
        t_emb = self.time_proj(timestep).to(sample.dtype)
        t_emb = self.time_embedding(t_emb)

        time_embeds = self.add_time_proj(add_time_id)
        time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
        add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
        add_embeds = add_embeds.to(sample.dtype)
        if unet is not None and unet.is_kolors:
            add_embeds = unet.add_time_embedding(add_embeds)
        else:
            add_embeds = self.add_time_embedding(add_embeds)

        control_type = torch.zeros((sample.shape[0], 8), dtype=sample.dtype, device=sample.device)
        control_type[:, task_id] = 1
        control_embeds = self.control_type_proj(control_type.flatten())
        control_embeds = control_embeds.reshape((sample.shape[0], -1))
        control_embeds = control_embeds.to(sample.dtype)
        control_embeds = self.control_type_embedding(control_embeds)
        time_emb = t_emb + add_embeds + control_embeds

        # 2. pre-process
        height, width = sample.shape[2], sample.shape[3]
        hidden_states = self.conv_in(sample)
        hidden_states = self.fuse_condition_to_input(hidden_states, task_id, conditioning)
        text_emb = encoder_hidden_states
        if unet is not None and unet.is_kolors:
            text_emb = unet.text_intermediate_proj(text_emb)
        res_stack = [hidden_states]

        # 3. blocks
        for i, block in enumerate(self.blocks):
            if tiled and not isinstance(block, PushBlock):
                _, _, inter_height, _ = hidden_states.shape
                resize_scale = inter_height / height
                hidden_states = TileWorker().tiled_forward(
                    lambda x: block(x, time_emb, text_emb, res_stack)[0],
                    hidden_states,
                    int(tile_size * resize_scale),
                    int(tile_stride * resize_scale),
                    tile_device=hidden_states.device,
                    tile_dtype=hidden_states.dtype
                )
            else:
                hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)

        # 4. ControlNet blocks
        controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]

        # pool
        if self.global_pool:
            controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]

        return controlnet_res_stack

    @staticmethod
    def state_dict_converter():
        return SDXLControlNetUnionStateDictConverter()


class SDXLControlNetUnionStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        # architecture
        block_types = [
            "ResnetBlock", "PushBlock", "ResnetBlock", "PushBlock", "DownSampler", "PushBlock",
            "ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock", "DownSampler", "PushBlock",
            "ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock",
            "ResnetBlock", "AttentionBlock", "ResnetBlock", "PushBlock"
        ]

        # controlnet_rename_dict
        controlnet_rename_dict = {
            "controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
            "controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
            "controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
            "controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
            "controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
            "controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
            "controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
            "controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
            "controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
            "controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
            "controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
            "controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
            "controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
            "controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
            "controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
            "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
            "control_add_embedding.linear_1.weight": "control_type_embedding.0.weight",
            "control_add_embedding.linear_1.bias": "control_type_embedding.0.bias",
            "control_add_embedding.linear_2.weight": "control_type_embedding.2.weight",
            "control_add_embedding.linear_2.bias": "control_type_embedding.2.bias",
        }

        # Rename each parameter
        name_list = sorted([name for name in state_dict])
        rename_dict = {}
        block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
        last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
|
264 |
+
for name in name_list:
|
265 |
+
names = name.split(".")
|
266 |
+
if names[0] in ["conv_in", "conv_norm_out", "conv_out", "task_embedding", "spatial_ch_projs"]:
|
267 |
+
pass
|
268 |
+
elif name in controlnet_rename_dict:
|
269 |
+
names = controlnet_rename_dict[name].split(".")
|
270 |
+
elif names[0] == "controlnet_down_blocks":
|
271 |
+
names[0] = "controlnet_blocks"
|
272 |
+
elif names[0] == "controlnet_mid_block":
|
273 |
+
names = ["controlnet_blocks", "9", names[-1]]
|
274 |
+
elif names[0] in ["time_embedding", "add_embedding"]:
|
275 |
+
if names[0] == "add_embedding":
|
276 |
+
names[0] = "add_time_embedding"
|
277 |
+
names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
|
278 |
+
elif names[0] == "control_add_embedding":
|
279 |
+
names[0] = "control_type_embedding"
|
280 |
+
elif names[0] == "transformer_layes":
|
281 |
+
names[0] = "controlnet_transformer"
|
282 |
+
names.pop(1)
|
283 |
+
elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
|
284 |
+
if names[0] == "mid_block":
|
285 |
+
names.insert(1, "0")
|
286 |
+
block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
|
287 |
+
block_type_with_id = ".".join(names[:4])
|
288 |
+
if block_type_with_id != last_block_type_with_id[block_type]:
|
289 |
+
block_id[block_type] += 1
|
290 |
+
last_block_type_with_id[block_type] = block_type_with_id
|
291 |
+
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
|
292 |
+
block_id[block_type] += 1
|
293 |
+
block_type_with_id = ".".join(names[:4])
|
294 |
+
names = ["blocks", str(block_id[block_type])] + names[4:]
|
295 |
+
if "ff" in names:
|
296 |
+
ff_index = names.index("ff")
|
297 |
+
component = ".".join(names[ff_index:ff_index+3])
|
298 |
+
component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
|
299 |
+
names = names[:ff_index] + [component] + names[ff_index+3:]
|
300 |
+
if "to_out" in names:
|
301 |
+
names.pop(names.index("to_out") + 1)
|
302 |
+
else:
|
303 |
+
print(name, state_dict[name].shape)
|
304 |
+
# raise ValueError(f"Unknown parameters: {name}")
|
305 |
+
rename_dict[name] = ".".join(names)
|
306 |
+
|
307 |
+
# Convert state_dict
|
308 |
+
state_dict_ = {}
|
309 |
+
for name, param in state_dict.items():
|
310 |
+
if name not in rename_dict:
|
311 |
+
continue
|
312 |
+
if ".proj_in." in name or ".proj_out." in name:
|
313 |
+
param = param.squeeze()
|
314 |
+
state_dict_[rename_dict[name]] = param
|
315 |
+
return state_dict_
|
316 |
+
|
317 |
+
def from_civitai(self, state_dict):
|
318 |
+
return self.from_diffusers(state_dict)
|
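The union ControlNet above conditions one network on eight control tasks: the forward pass one-hot-encodes the task id into control_type, embeds it, and adds it to the time embedding, while fuse_condition_to_input folds the pooled condition features (shifted by a learned per-task embedding) back into the latent feature map. The snippet below is a minimal, self-contained sketch of that fusion idea only; TinyUnionFuser is a hypothetical stand-in, not the class above, and it drops the two-token controlnet_transformer used by the real module.

import torch

class TinyUnionFuser(torch.nn.Module):
    def __init__(self, channels=320, num_tasks=8):
        super().__init__()
        # One learned embedding per control task (openpose, depth, ..., repaint).
        self.task_embedding = torch.nn.Parameter(torch.zeros(num_tasks, channels))
        self.spatial_ch_projs = torch.nn.Linear(channels, channels)

    def forward(self, hidden_states, controlnet_cond, task_id):
        # Pool the condition map to one value per channel and shift it by the task embedding.
        feat_seq = controlnet_cond.mean(dim=(2, 3)) + self.task_embedding[task_id]
        # Project the pooled vector and broadcast it back over the spatial grid.
        alpha = self.spatial_ch_projs(feat_seq).unsqueeze(-1).unsqueeze(-1)
        # Residual fusion into the latent feature map.
        return hidden_states + controlnet_cond + alpha

fuser = TinyUnionFuser()
latent = torch.randn(1, 320, 32, 32)        # stands in for conv_in(sample)
condition = torch.randn(1, 320, 32, 32)     # stands in for controlnet_conv_in(conditioning)
out = fuser(latent, condition, task_id=3)   # 3 == the canny/lineart family in task_id above
print(out.shape)                            # torch.Size([1, 320, 32, 32])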
diffsynth/models/sdxl_ipadapter.py
ADDED
@@ -0,0 +1,122 @@
from .svd_image_encoder import SVDImageEncoder
from transformers import CLIPImageProcessor
import torch


class IpAdapterXLCLIPImageEmbedder(SVDImageEncoder):
    def __init__(self):
        super().__init__(embed_dim=1664, encoder_intermediate_size=8192, projection_dim=1280, num_encoder_layers=48, num_heads=16, head_dim=104)
        self.image_processor = CLIPImageProcessor()

    def forward(self, image):
        pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype)
        return super().forward(pixel_values)


class IpAdapterImageProjModel(torch.nn.Module):
    def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        clip_extra_context_tokens = self.proj(image_embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens


class IpAdapterModule(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
        self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, hidden_states):
        ip_k = self.to_k_ip(hidden_states)
        ip_v = self.to_v_ip(hidden_states)
        return ip_k, ip_v


class SDXLIpAdapter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        shape_list = [(2048, 640)] * 4 + [(2048, 1280)] * 50 + [(2048, 640)] * 6 + [(2048, 1280)] * 10
        self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
        self.image_proj = IpAdapterImageProjModel()
        self.set_full_adapter()

    def set_full_adapter(self):
        map_list = sum([
            [(7, i) for i in range(2)],
            [(10, i) for i in range(2)],
            [(15, i) for i in range(10)],
            [(18, i) for i in range(10)],
            [(25, i) for i in range(10)],
            [(28, i) for i in range(10)],
            [(31, i) for i in range(10)],
            [(35, i) for i in range(2)],
            [(38, i) for i in range(2)],
            [(41, i) for i in range(2)],
            [(21, i) for i in range(10)],
        ], [])
        self.call_block_id = {i: j for j, i in enumerate(map_list)}

    def set_less_adapter(self):
        map_list = sum([
            [(7, i) for i in range(2)],
            [(10, i) for i in range(2)],
            [(15, i) for i in range(10)],
            [(18, i) for i in range(10)],
            [(25, i) for i in range(10)],
            [(28, i) for i in range(10)],
            [(31, i) for i in range(10)],
            [(35, i) for i in range(2)],
            [(38, i) for i in range(2)],
            [(41, i) for i in range(2)],
            [(21, i) for i in range(10)],
        ], [])
        self.call_block_id = {i: j for j, i in enumerate(map_list) if j>=34 and j<44}

    def forward(self, hidden_states, scale=1.0):
        hidden_states = self.image_proj(hidden_states)
        hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
        ip_kv_dict = {}
        for (block_id, transformer_id) in self.call_block_id:
            ipadapter_id = self.call_block_id[(block_id, transformer_id)]
            ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
            if block_id not in ip_kv_dict:
                ip_kv_dict[block_id] = {}
            ip_kv_dict[block_id][transformer_id] = {
                "ip_k": ip_k,
                "ip_v": ip_v,
                "scale": scale
            }
        return ip_kv_dict

    @staticmethod
    def state_dict_converter():
        return SDXLIpAdapterStateDictConverter()


class SDXLIpAdapterStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = {}
        for name in state_dict["ip_adapter"]:
            names = name.split(".")
            layer_id = str(int(names[0]) // 2)
            name_ = ".".join(["ipadapter_modules"] + [layer_id] + names[1:])
            state_dict_[name_] = state_dict["ip_adapter"][name]
        for name in state_dict["image_proj"]:
            name_ = "image_proj." + name
            state_dict_[name_] = state_dict["image_proj"][name]
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
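For reference, a hypothetical usage sketch of the IP-Adapter pieces above (not part of this commit; the import path simply mirrors this repo's layout). It shows how a single CLIP image embedding is expanded into four pseudo text tokens and then into per-block key/value tensors:

import torch
from diffsynth.models.sdxl_ipadapter import SDXLIpAdapter

ip_adapter = SDXLIpAdapter()
# A fake global image embedding; in practice it comes from IpAdapterXLCLIPImageEmbedder,
# whose projection_dim (1280) matches clip_embeddings_dim of IpAdapterImageProjModel.
image_embeds = torch.randn(1, 1280)
ip_kv_dict = ip_adapter(image_embeds, scale=0.8)
# One entry per UNet block id listed in set_full_adapter(); each transformer layer in the
# block gets its own key/value pair plus the guidance scale.
print(sorted(ip_kv_dict))                # [7, 10, 15, 18, 21, 25, 28, 31, 35, 38, 41]
print(ip_kv_dict[7][0]["ip_k"].shape)    # torch.Size([1, 4, 640])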
diffsynth/models/sdxl_motion.py
ADDED
@@ -0,0 +1,104 @@
from .sd_motion import TemporalBlock
import torch



class SDXLMotionModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.motion_modules = torch.nn.ModuleList([
            TemporalBlock(8, 320//8, 320, eps=1e-6),
            TemporalBlock(8, 320//8, 320, eps=1e-6),

            TemporalBlock(8, 640//8, 640, eps=1e-6),
            TemporalBlock(8, 640//8, 640, eps=1e-6),

            TemporalBlock(8, 1280//8, 1280, eps=1e-6),
            TemporalBlock(8, 1280//8, 1280, eps=1e-6),

            TemporalBlock(8, 1280//8, 1280, eps=1e-6),
            TemporalBlock(8, 1280//8, 1280, eps=1e-6),
            TemporalBlock(8, 1280//8, 1280, eps=1e-6),

            TemporalBlock(8, 640//8, 640, eps=1e-6),
            TemporalBlock(8, 640//8, 640, eps=1e-6),
            TemporalBlock(8, 640//8, 640, eps=1e-6),

            TemporalBlock(8, 320//8, 320, eps=1e-6),
            TemporalBlock(8, 320//8, 320, eps=1e-6),
            TemporalBlock(8, 320//8, 320, eps=1e-6),
        ])
        self.call_block_id = {
            0: 0,
            2: 1,
            7: 2,
            10: 3,
            15: 4,
            18: 5,
            25: 6,
            28: 7,
            31: 8,
            35: 9,
            38: 10,
            41: 11,
            44: 12,
            46: 13,
            48: 14,
        }

    def forward(self):
        pass

    @staticmethod
    def state_dict_converter():
        return SDMotionModelStateDictConverter()


class SDMotionModelStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        rename_dict = {
            "norm": "norm",
            "proj_in": "proj_in",
            "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
            "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
            "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
            "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
            "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
            "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
            "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
            "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
            "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
            "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
            "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
            "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
            "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
            "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
            "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
            "proj_out": "proj_out",
        }
        name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
        name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
        name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
        state_dict_ = {}
        last_prefix, module_id = "", -1
        for name in name_list:
            names = name.split(".")
            prefix_index = names.index("temporal_transformer") + 1
            prefix = ".".join(names[:prefix_index])
            if prefix != last_prefix:
                last_prefix = prefix
                module_id += 1
            middle_name = ".".join(names[prefix_index:-1])
            suffix = names[-1]
            if "pos_encoder" in names:
                rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
            else:
                rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
            state_dict_[rename] = state_dict[name]
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
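A small sketch of how the converter above groups parameters: every distinct prefix ending in "temporal_transformer" becomes the next index of motion_modules, regardless of which UNet stage it came from. The toy state dict here is made up, but its key layout follows what from_diffusers expects, and the import path mirrors this repo's layout:

import torch
from diffsynth.models.sdxl_motion import SDMotionModelStateDictConverter

fake_state_dict = {
    "down_blocks.0.motion_modules.0.temporal_transformer.norm.weight": torch.zeros(320),
    "down_blocks.0.motion_modules.0.temporal_transformer.norm.bias": torch.zeros(320),
    "down_blocks.0.motion_modules.1.temporal_transformer.norm.weight": torch.zeros(320),
}
converted = SDMotionModelStateDictConverter().from_diffusers(fake_state_dict)
# Two distinct "...temporal_transformer" prefixes -> motion_modules.0 and motion_modules.1.
print(sorted(converted))
# ['motion_modules.0.norm.bias', 'motion_modules.0.norm.weight', 'motion_modules.1.norm.weight']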
diffsynth/models/sdxl_text_encoder.py
ADDED
@@ -0,0 +1,759 @@
import torch
from .sd_text_encoder import CLIPEncoderLayer


class SDXLTextEncoder(torch.nn.Module):
    def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=11, encoder_intermediate_size=3072):
        super().__init__()

        # token_embedding
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)

        # position_embeds (This is a fixed tensor)
        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))

        # encoders
        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])

        # attn_mask
        self.attn_mask = self.attention_mask(max_position_embeddings)

        # The text encoder is different to that in Stable Diffusion 1.x.
        # It does not include final_layer_norm.

    def attention_mask(self, length):
        mask = torch.empty(length, length)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    def forward(self, input_ids, clip_skip=1):
        embeds = self.token_embedding(input_ids) + self.position_embeds
        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
        for encoder_id, encoder in enumerate(self.encoders):
            embeds = encoder(embeds, attn_mask=attn_mask)
            if encoder_id + clip_skip == len(self.encoders):
                break
        return embeds

    @staticmethod
    def state_dict_converter():
        return SDXLTextEncoderStateDictConverter()


class SDXLTextEncoder2(torch.nn.Module):
    def __init__(self, embed_dim=1280, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=32, encoder_intermediate_size=5120):
        super().__init__()

        # token_embedding
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)

        # position_embeds (This is a fixed tensor)
        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))

        # encoders
        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=20, head_dim=64, use_quick_gelu=False) for _ in range(num_encoder_layers)])

        # attn_mask
        self.attn_mask = self.attention_mask(max_position_embeddings)

        # final_layer_norm
        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)

        # text_projection
        self.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)

    def attention_mask(self, length):
        mask = torch.empty(length, length)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    def forward(self, input_ids, clip_skip=2):
        embeds = self.token_embedding(input_ids) + self.position_embeds
        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
        for encoder_id, encoder in enumerate(self.encoders):
            embeds = encoder(embeds, attn_mask=attn_mask)
            if encoder_id + clip_skip == len(self.encoders):
                hidden_states = embeds
        embeds = self.final_layer_norm(embeds)
        pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
        pooled_embeds = self.text_projection(pooled_embeds)
        return pooled_embeds, hidden_states

    @staticmethod
    def state_dict_converter():
        return SDXLTextEncoder2StateDictConverter()


class SDXLTextEncoderStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        rename_dict = {
            "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
            "text_model.embeddings.position_embedding.weight": "position_embeds",
            "text_model.final_layer_norm.weight": "final_layer_norm.weight",
            "text_model.final_layer_norm.bias": "final_layer_norm.bias"
        }
        attn_rename_dict = {
            "self_attn.q_proj": "attn.to_q",
            "self_attn.k_proj": "attn.to_k",
            "self_attn.v_proj": "attn.to_v",
            "self_attn.out_proj": "attn.to_out",
            "layer_norm1": "layer_norm1",
            "layer_norm2": "layer_norm2",
            "mlp.fc1": "fc1",
            "mlp.fc2": "fc2",
        }
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if name == "text_model.embeddings.position_embedding.weight":
                    param = param.reshape((1, param.shape[0], param.shape[1]))
                state_dict_[rename_dict[name]] = param
            elif name.startswith("text_model.encoder.layers."):
                param = state_dict[name]
                names = name.split(".")
                layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
                name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
                state_dict_[name_] = param
        return state_dict_

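Both encoders honor clip_skip by cutting the encoder loop short (or, in SDXLTextEncoder2, remembering the hidden state at that depth while the remaining layers still run), and SDXLTextEncoder2 additionally pools the sequence at the position of the largest token id, which for a CLIP tokenizer is the end-of-text token (id 49407). A quick illustration of that pooling index; the token ids here are made up apart from the BOS/EOS ids:

import torch

input_ids = torch.tensor([[49406, 320, 1125, 49407, 0, 0]])  # BOS, two prompt tokens, EOS, padding
eot_position = input_ids.to(dtype=torch.int).argmax(dim=-1)
print(eot_position)  # tensor([3]) -> this slot's embedding is what text_projection pools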
125 |
+
def from_civitai(self, state_dict):
|
126 |
+
rename_dict = {
|
127 |
+
"conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight": "position_embeds",
|
128 |
+
"conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
129 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
|
130 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
|
131 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
|
132 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
|
133 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
|
134 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
|
135 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
|
136 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
|
137 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
|
138 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
|
139 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
|
140 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
|
141 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
|
142 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
|
143 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
|
144 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
|
145 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
|
146 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
|
147 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
|
148 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
|
149 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
|
150 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
|
151 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
|
152 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
|
153 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
|
154 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
|
155 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
|
156 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
|
157 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
|
158 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
|
159 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
|
160 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
|
161 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
|
162 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
|
163 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
|
164 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
|
165 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
|
166 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
|
167 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
|
168 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
|
169 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
|
170 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
|
171 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
|
172 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
|
173 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
|
174 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
|
175 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
|
176 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
|
177 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
|
178 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
|
179 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
|
180 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
|
181 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
|
182 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
|
183 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
|
184 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
|
185 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
|
186 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
|
187 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
|
188 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
|
189 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
|
190 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
|
191 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
|
192 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
|
193 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
|
194 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
|
195 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
|
196 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
|
197 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
|
198 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
|
199 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
|
200 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
|
201 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
|
202 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
|
203 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
|
204 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
|
205 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
|
206 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
|
207 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
|
208 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
|
209 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
|
210 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
|
211 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
|
212 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
|
213 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
|
214 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
|
215 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
|
216 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
|
217 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
|
218 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
|
219 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
|
220 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
|
221 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
|
222 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
|
223 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
|
224 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
|
225 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
|
226 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
|
227 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
|
228 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
|
229 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
|
230 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
|
231 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
|
232 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
|
233 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
|
234 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
|
235 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
|
236 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
|
237 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
|
238 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
|
239 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
|
240 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
|
241 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
|
242 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
|
243 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
|
244 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
|
245 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
|
246 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
|
247 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
|
248 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
|
249 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
|
250 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
|
251 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
|
252 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
|
253 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
|
254 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
|
255 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
|
256 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
|
257 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
|
258 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
|
259 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
|
260 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
|
261 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
|
262 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
|
263 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
|
264 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
|
265 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
|
266 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
|
267 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
|
268 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
|
269 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
|
270 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
|
271 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
|
272 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
|
273 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
|
274 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
|
275 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
|
276 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
|
277 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
|
278 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
|
279 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
|
280 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
|
281 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
|
282 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
|
283 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
|
284 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
|
285 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
|
286 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
|
287 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
|
288 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
|
289 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
|
290 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
|
291 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
|
292 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
|
293 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
|
294 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
|
295 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
|
296 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
|
297 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
|
298 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
|
299 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
|
300 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
|
301 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
|
302 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
|
303 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
|
304 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
|
305 |
+
}
|
306 |
+
state_dict_ = {}
|
307 |
+
for name in state_dict:
|
308 |
+
if name in rename_dict:
|
309 |
+
param = state_dict[name]
|
310 |
+
if name == "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight":
|
311 |
+
param = param.reshape((1, param.shape[0], param.shape[1]))
|
312 |
+
state_dict_[rename_dict[name]] = param
|
313 |
+
return state_dict_
|
314 |
+
|
315 |
+
|
class SDXLTextEncoder2StateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        rename_dict = {
            "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
            "text_model.embeddings.position_embedding.weight": "position_embeds",
            "text_model.final_layer_norm.weight": "final_layer_norm.weight",
            "text_model.final_layer_norm.bias": "final_layer_norm.bias",
            "text_projection.weight": "text_projection.weight"
        }
        attn_rename_dict = {
            "self_attn.q_proj": "attn.to_q",
            "self_attn.k_proj": "attn.to_k",
            "self_attn.v_proj": "attn.to_v",
            "self_attn.out_proj": "attn.to_out",
            "layer_norm1": "layer_norm1",
            "layer_norm2": "layer_norm2",
            "mlp.fc1": "fc1",
            "mlp.fc2": "fc2",
        }
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if name == "text_model.embeddings.position_embedding.weight":
                    param = param.reshape((1, param.shape[0], param.shape[1]))
                state_dict_[rename_dict[name]] = param
            elif name.startswith("text_model.encoder.layers."):
                param = state_dict[name]
                names = name.split(".")
                layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
                name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
                state_dict_[name_] = param
        return state_dict_

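One detail worth noting about the civitai-style conversion that follows: the original CLIP checkpoint stores a fused attention projection (attn.in_proj_weight / in_proj_bias), so the rename table maps each of those keys to a list of three target names for to_q / to_k / to_v. The actual splitting is presumably handled by the loading code elsewhere in the repo; the usual convention for in_proj_weight is an equal three-way chunk along dim 0, roughly:

import torch

embed_dim = 1280
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)  # fused q/k/v projection
to_q, to_k, to_v = torch.chunk(in_proj_weight, 3, dim=0)
print(to_q.shape, to_k.shape, to_v.shape)  # three (1280, 1280) matrices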
353 |
+
def from_civitai(self, state_dict):
|
354 |
+
rename_dict = {
|
355 |
+
"conditioner.embedders.1.model.ln_final.bias": "final_layer_norm.bias",
|
356 |
+
"conditioner.embedders.1.model.ln_final.weight": "final_layer_norm.weight",
|
357 |
+
"conditioner.embedders.1.model.positional_embedding": "position_embeds",
|
358 |
+
"conditioner.embedders.1.model.token_embedding.weight": "token_embedding.weight",
|
359 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias": ['encoders.0.attn.to_q.bias', 'encoders.0.attn.to_k.bias', 'encoders.0.attn.to_v.bias'],
|
360 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": ['encoders.0.attn.to_q.weight', 'encoders.0.attn.to_k.weight', 'encoders.0.attn.to_v.weight'],
|
361 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.bias": "encoders.0.attn.to_out.bias",
|
362 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.weight": "encoders.0.attn.to_out.weight",
|
363 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias": "encoders.0.layer_norm1.bias",
|
364 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight": "encoders.0.layer_norm1.weight",
|
365 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.ln_2.bias": "encoders.0.layer_norm2.bias",
|
366 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.ln_2.weight": "encoders.0.layer_norm2.weight",
|
367 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.bias": "encoders.0.fc1.bias",
|
368 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.weight": "encoders.0.fc1.weight",
|
369 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.bias": "encoders.0.fc2.bias",
|
370 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.weight": "encoders.0.fc2.weight",
|
371 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias": ['encoders.1.attn.to_q.bias', 'encoders.1.attn.to_k.bias', 'encoders.1.attn.to_v.bias'],
|
372 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight": ['encoders.1.attn.to_q.weight', 'encoders.1.attn.to_k.weight', 'encoders.1.attn.to_v.weight'],
|
373 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.bias": "encoders.1.attn.to_out.bias",
|
374 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.weight": "encoders.1.attn.to_out.weight",
|
375 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.ln_1.bias": "encoders.1.layer_norm1.bias",
|
376 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.ln_1.weight": "encoders.1.layer_norm1.weight",
|
377 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.ln_2.bias": "encoders.1.layer_norm2.bias",
|
378 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.ln_2.weight": "encoders.1.layer_norm2.weight",
|
379 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.bias": "encoders.1.fc1.bias",
|
380 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.weight": "encoders.1.fc1.weight",
|
381 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.bias": "encoders.1.fc2.bias",
|
382 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.weight": "encoders.1.fc2.weight",
|
383 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias": ['encoders.10.attn.to_q.bias', 'encoders.10.attn.to_k.bias', 'encoders.10.attn.to_v.bias'],
|
384 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight": ['encoders.10.attn.to_q.weight', 'encoders.10.attn.to_k.weight', 'encoders.10.attn.to_v.weight'],
|
385 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.bias": "encoders.10.attn.to_out.bias",
|
386 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.weight": "encoders.10.attn.to_out.weight",
|
387 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.ln_1.bias": "encoders.10.layer_norm1.bias",
|
388 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.ln_1.weight": "encoders.10.layer_norm1.weight",
|
389 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.ln_2.bias": "encoders.10.layer_norm2.bias",
|
390 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.ln_2.weight": "encoders.10.layer_norm2.weight",
|
391 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.bias": "encoders.10.fc1.bias",
|
392 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.weight": "encoders.10.fc1.weight",
|
393 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.bias": "encoders.10.fc2.bias",
|
394 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.weight": "encoders.10.fc2.weight",
|
395 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias": ['encoders.11.attn.to_q.bias', 'encoders.11.attn.to_k.bias', 'encoders.11.attn.to_v.bias'],
|
396 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight": ['encoders.11.attn.to_q.weight', 'encoders.11.attn.to_k.weight', 'encoders.11.attn.to_v.weight'],
|
397 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.bias": "encoders.11.attn.to_out.bias",
|
398 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.weight": "encoders.11.attn.to_out.weight",
|
399 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.ln_1.bias": "encoders.11.layer_norm1.bias",
|
400 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.ln_1.weight": "encoders.11.layer_norm1.weight",
|
401 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.ln_2.bias": "encoders.11.layer_norm2.bias",
|
402 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.ln_2.weight": "encoders.11.layer_norm2.weight",
|
403 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.bias": "encoders.11.fc1.bias",
|
404 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.weight": "encoders.11.fc1.weight",
|
405 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.bias": "encoders.11.fc2.bias",
|
406 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.weight": "encoders.11.fc2.weight",
|
407 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias": ['encoders.12.attn.to_q.bias', 'encoders.12.attn.to_k.bias', 'encoders.12.attn.to_v.bias'],
|
408 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight": ['encoders.12.attn.to_q.weight', 'encoders.12.attn.to_k.weight', 'encoders.12.attn.to_v.weight'],
|
409 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.bias": "encoders.12.attn.to_out.bias",
|
410 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.weight": "encoders.12.attn.to_out.weight",
|
411 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.ln_1.bias": "encoders.12.layer_norm1.bias",
|
412 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.ln_1.weight": "encoders.12.layer_norm1.weight",
|
413 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.ln_2.bias": "encoders.12.layer_norm2.bias",
|
414 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.ln_2.weight": "encoders.12.layer_norm2.weight",
|
415 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.bias": "encoders.12.fc1.bias",
|
416 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.weight": "encoders.12.fc1.weight",
|
417 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.bias": "encoders.12.fc2.bias",
|
418 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.weight": "encoders.12.fc2.weight",
|
419 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias": ['encoders.13.attn.to_q.bias', 'encoders.13.attn.to_k.bias', 'encoders.13.attn.to_v.bias'],
|
420 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight": ['encoders.13.attn.to_q.weight', 'encoders.13.attn.to_k.weight', 'encoders.13.attn.to_v.weight'],
|
421 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.bias": "encoders.13.attn.to_out.bias",
|
422 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.weight": "encoders.13.attn.to_out.weight",
|
423 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.ln_1.bias": "encoders.13.layer_norm1.bias",
|
424 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.ln_1.weight": "encoders.13.layer_norm1.weight",
|
425 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.ln_2.bias": "encoders.13.layer_norm2.bias",
|
426 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.ln_2.weight": "encoders.13.layer_norm2.weight",
|
427 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.bias": "encoders.13.fc1.bias",
|
428 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.weight": "encoders.13.fc1.weight",
|
429 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.bias": "encoders.13.fc2.bias",
|
430 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.weight": "encoders.13.fc2.weight",
|
431 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias": ['encoders.14.attn.to_q.bias', 'encoders.14.attn.to_k.bias', 'encoders.14.attn.to_v.bias'],
|
432 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight": ['encoders.14.attn.to_q.weight', 'encoders.14.attn.to_k.weight', 'encoders.14.attn.to_v.weight'],
|
433 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.bias": "encoders.14.attn.to_out.bias",
|
434 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.weight": "encoders.14.attn.to_out.weight",
|
435 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.ln_1.bias": "encoders.14.layer_norm1.bias",
|
436 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.ln_1.weight": "encoders.14.layer_norm1.weight",
|
437 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.ln_2.bias": "encoders.14.layer_norm2.bias",
|
438 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.ln_2.weight": "encoders.14.layer_norm2.weight",
|
439 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.bias": "encoders.14.fc1.bias",
|
440 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.weight": "encoders.14.fc1.weight",
|
441 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.bias": "encoders.14.fc2.bias",
|
442 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.weight": "encoders.14.fc2.weight",
|
443 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias": ['encoders.15.attn.to_q.bias', 'encoders.15.attn.to_k.bias', 'encoders.15.attn.to_v.bias'],
|
444 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight": ['encoders.15.attn.to_q.weight', 'encoders.15.attn.to_k.weight', 'encoders.15.attn.to_v.weight'],
|
445 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.bias": "encoders.15.attn.to_out.bias",
|
446 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.weight": "encoders.15.attn.to_out.weight",
|
447 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.ln_1.bias": "encoders.15.layer_norm1.bias",
|
448 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.ln_1.weight": "encoders.15.layer_norm1.weight",
|
449 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.ln_2.bias": "encoders.15.layer_norm2.bias",
|
450 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.ln_2.weight": "encoders.15.layer_norm2.weight",
|
451 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.bias": "encoders.15.fc1.bias",
|
452 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.weight": "encoders.15.fc1.weight",
|
453 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.bias": "encoders.15.fc2.bias",
|
454 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.weight": "encoders.15.fc2.weight",
|
455 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias": ['encoders.16.attn.to_q.bias', 'encoders.16.attn.to_k.bias', 'encoders.16.attn.to_v.bias'],
|
456 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight": ['encoders.16.attn.to_q.weight', 'encoders.16.attn.to_k.weight', 'encoders.16.attn.to_v.weight'],
|
457 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.bias": "encoders.16.attn.to_out.bias",
|
458 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.weight": "encoders.16.attn.to_out.weight",
|
459 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.ln_1.bias": "encoders.16.layer_norm1.bias",
|
460 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.ln_1.weight": "encoders.16.layer_norm1.weight",
|
461 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.ln_2.bias": "encoders.16.layer_norm2.bias",
|
462 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.ln_2.weight": "encoders.16.layer_norm2.weight",
|
463 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.bias": "encoders.16.fc1.bias",
|
464 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.weight": "encoders.16.fc1.weight",
|
465 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.bias": "encoders.16.fc2.bias",
|
466 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.weight": "encoders.16.fc2.weight",
|
467 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias": ['encoders.17.attn.to_q.bias', 'encoders.17.attn.to_k.bias', 'encoders.17.attn.to_v.bias'],
|
468 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight": ['encoders.17.attn.to_q.weight', 'encoders.17.attn.to_k.weight', 'encoders.17.attn.to_v.weight'],
|
469 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.bias": "encoders.17.attn.to_out.bias",
|
470 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.weight": "encoders.17.attn.to_out.weight",
|
471 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.ln_1.bias": "encoders.17.layer_norm1.bias",
|
472 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.ln_1.weight": "encoders.17.layer_norm1.weight",
|
473 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.ln_2.bias": "encoders.17.layer_norm2.bias",
|
474 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.ln_2.weight": "encoders.17.layer_norm2.weight",
|
475 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.bias": "encoders.17.fc1.bias",
|
476 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.weight": "encoders.17.fc1.weight",
|
477 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.bias": "encoders.17.fc2.bias",
|
478 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.weight": "encoders.17.fc2.weight",
|
479 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias": ['encoders.18.attn.to_q.bias', 'encoders.18.attn.to_k.bias', 'encoders.18.attn.to_v.bias'],
|
480 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight": ['encoders.18.attn.to_q.weight', 'encoders.18.attn.to_k.weight', 'encoders.18.attn.to_v.weight'],
|
481 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.bias": "encoders.18.attn.to_out.bias",
|
482 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.weight": "encoders.18.attn.to_out.weight",
|
483 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.ln_1.bias": "encoders.18.layer_norm1.bias",
|
484 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.ln_1.weight": "encoders.18.layer_norm1.weight",
|
485 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.ln_2.bias": "encoders.18.layer_norm2.bias",
|
486 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.ln_2.weight": "encoders.18.layer_norm2.weight",
|
487 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.bias": "encoders.18.fc1.bias",
|
488 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.weight": "encoders.18.fc1.weight",
|
489 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.bias": "encoders.18.fc2.bias",
|
490 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.weight": "encoders.18.fc2.weight",
|
491 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias": ['encoders.19.attn.to_q.bias', 'encoders.19.attn.to_k.bias', 'encoders.19.attn.to_v.bias'],
|
492 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight": ['encoders.19.attn.to_q.weight', 'encoders.19.attn.to_k.weight', 'encoders.19.attn.to_v.weight'],
|
493 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.bias": "encoders.19.attn.to_out.bias",
|
494 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.weight": "encoders.19.attn.to_out.weight",
|
495 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.ln_1.bias": "encoders.19.layer_norm1.bias",
|
496 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.ln_1.weight": "encoders.19.layer_norm1.weight",
|
497 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.ln_2.bias": "encoders.19.layer_norm2.bias",
|
498 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.ln_2.weight": "encoders.19.layer_norm2.weight",
|
499 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.bias": "encoders.19.fc1.bias",
|
500 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.weight": "encoders.19.fc1.weight",
|
501 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.bias": "encoders.19.fc2.bias",
|
502 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.weight": "encoders.19.fc2.weight",
|
503 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias": ['encoders.2.attn.to_q.bias', 'encoders.2.attn.to_k.bias', 'encoders.2.attn.to_v.bias'],
|
504 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight": ['encoders.2.attn.to_q.weight', 'encoders.2.attn.to_k.weight', 'encoders.2.attn.to_v.weight'],
|
505 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.bias": "encoders.2.attn.to_out.bias",
|
506 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.weight": "encoders.2.attn.to_out.weight",
|
507 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.ln_1.bias": "encoders.2.layer_norm1.bias",
|
508 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.ln_1.weight": "encoders.2.layer_norm1.weight",
|
509 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.ln_2.bias": "encoders.2.layer_norm2.bias",
|
510 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.ln_2.weight": "encoders.2.layer_norm2.weight",
|
511 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.bias": "encoders.2.fc1.bias",
|
512 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.weight": "encoders.2.fc1.weight",
|
513 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.bias": "encoders.2.fc2.bias",
|
514 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.weight": "encoders.2.fc2.weight",
|
515 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias": ['encoders.20.attn.to_q.bias', 'encoders.20.attn.to_k.bias', 'encoders.20.attn.to_v.bias'],
|
516 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight": ['encoders.20.attn.to_q.weight', 'encoders.20.attn.to_k.weight', 'encoders.20.attn.to_v.weight'],
|
517 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.bias": "encoders.20.attn.to_out.bias",
|
518 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.weight": "encoders.20.attn.to_out.weight",
|
519 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.ln_1.bias": "encoders.20.layer_norm1.bias",
|
520 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.ln_1.weight": "encoders.20.layer_norm1.weight",
|
521 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.ln_2.bias": "encoders.20.layer_norm2.bias",
|
522 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.ln_2.weight": "encoders.20.layer_norm2.weight",
|
523 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.bias": "encoders.20.fc1.bias",
|
524 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.weight": "encoders.20.fc1.weight",
|
525 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.bias": "encoders.20.fc2.bias",
|
526 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.weight": "encoders.20.fc2.weight",
|
527 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias": ['encoders.21.attn.to_q.bias', 'encoders.21.attn.to_k.bias', 'encoders.21.attn.to_v.bias'],
|
528 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight": ['encoders.21.attn.to_q.weight', 'encoders.21.attn.to_k.weight', 'encoders.21.attn.to_v.weight'],
|
529 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.bias": "encoders.21.attn.to_out.bias",
|
530 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.weight": "encoders.21.attn.to_out.weight",
|
531 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.ln_1.bias": "encoders.21.layer_norm1.bias",
|
532 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.ln_1.weight": "encoders.21.layer_norm1.weight",
|
533 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.ln_2.bias": "encoders.21.layer_norm2.bias",
|
534 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.ln_2.weight": "encoders.21.layer_norm2.weight",
|
535 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.bias": "encoders.21.fc1.bias",
|
536 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.weight": "encoders.21.fc1.weight",
|
537 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.bias": "encoders.21.fc2.bias",
|
538 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.weight": "encoders.21.fc2.weight",
|
539 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias": ['encoders.22.attn.to_q.bias', 'encoders.22.attn.to_k.bias', 'encoders.22.attn.to_v.bias'],
|
540 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight": ['encoders.22.attn.to_q.weight', 'encoders.22.attn.to_k.weight', 'encoders.22.attn.to_v.weight'],
|
541 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.bias": "encoders.22.attn.to_out.bias",
|
542 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.weight": "encoders.22.attn.to_out.weight",
|
543 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.ln_1.bias": "encoders.22.layer_norm1.bias",
|
544 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.ln_1.weight": "encoders.22.layer_norm1.weight",
|
545 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.ln_2.bias": "encoders.22.layer_norm2.bias",
|
546 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.ln_2.weight": "encoders.22.layer_norm2.weight",
|
547 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.bias": "encoders.22.fc1.bias",
|
548 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.weight": "encoders.22.fc1.weight",
|
549 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.bias": "encoders.22.fc2.bias",
|
550 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.weight": "encoders.22.fc2.weight",
|
551 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias": ['encoders.23.attn.to_q.bias', 'encoders.23.attn.to_k.bias', 'encoders.23.attn.to_v.bias'],
|
552 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight": ['encoders.23.attn.to_q.weight', 'encoders.23.attn.to_k.weight', 'encoders.23.attn.to_v.weight'],
|
553 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.bias": "encoders.23.attn.to_out.bias",
|
554 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.weight": "encoders.23.attn.to_out.weight",
|
555 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.ln_1.bias": "encoders.23.layer_norm1.bias",
|
556 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.ln_1.weight": "encoders.23.layer_norm1.weight",
|
557 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.ln_2.bias": "encoders.23.layer_norm2.bias",
|
558 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.ln_2.weight": "encoders.23.layer_norm2.weight",
|
559 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.bias": "encoders.23.fc1.bias",
|
560 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.weight": "encoders.23.fc1.weight",
|
561 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.bias": "encoders.23.fc2.bias",
|
562 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.weight": "encoders.23.fc2.weight",
|
563 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias": ['encoders.24.attn.to_q.bias', 'encoders.24.attn.to_k.bias', 'encoders.24.attn.to_v.bias'],
|
564 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight": ['encoders.24.attn.to_q.weight', 'encoders.24.attn.to_k.weight', 'encoders.24.attn.to_v.weight'],
|
565 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.bias": "encoders.24.attn.to_out.bias",
|
566 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.weight": "encoders.24.attn.to_out.weight",
|
567 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.ln_1.bias": "encoders.24.layer_norm1.bias",
|
568 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.ln_1.weight": "encoders.24.layer_norm1.weight",
|
569 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.ln_2.bias": "encoders.24.layer_norm2.bias",
|
570 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.ln_2.weight": "encoders.24.layer_norm2.weight",
|
571 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.bias": "encoders.24.fc1.bias",
|
572 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.weight": "encoders.24.fc1.weight",
|
573 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.bias": "encoders.24.fc2.bias",
|
574 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.weight": "encoders.24.fc2.weight",
|
575 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias": ['encoders.25.attn.to_q.bias', 'encoders.25.attn.to_k.bias', 'encoders.25.attn.to_v.bias'],
|
576 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight": ['encoders.25.attn.to_q.weight', 'encoders.25.attn.to_k.weight', 'encoders.25.attn.to_v.weight'],
|
577 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.bias": "encoders.25.attn.to_out.bias",
|
578 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.weight": "encoders.25.attn.to_out.weight",
|
579 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.ln_1.bias": "encoders.25.layer_norm1.bias",
|
580 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.ln_1.weight": "encoders.25.layer_norm1.weight",
|
581 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.ln_2.bias": "encoders.25.layer_norm2.bias",
|
582 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.ln_2.weight": "encoders.25.layer_norm2.weight",
|
583 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.bias": "encoders.25.fc1.bias",
|
584 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.weight": "encoders.25.fc1.weight",
|
585 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.bias": "encoders.25.fc2.bias",
|
586 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.weight": "encoders.25.fc2.weight",
|
587 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias": ['encoders.26.attn.to_q.bias', 'encoders.26.attn.to_k.bias', 'encoders.26.attn.to_v.bias'],
|
588 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight": ['encoders.26.attn.to_q.weight', 'encoders.26.attn.to_k.weight', 'encoders.26.attn.to_v.weight'],
|
589 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.bias": "encoders.26.attn.to_out.bias",
|
590 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.weight": "encoders.26.attn.to_out.weight",
|
591 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.ln_1.bias": "encoders.26.layer_norm1.bias",
|
592 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.ln_1.weight": "encoders.26.layer_norm1.weight",
|
593 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.ln_2.bias": "encoders.26.layer_norm2.bias",
|
594 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.ln_2.weight": "encoders.26.layer_norm2.weight",
|
595 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.bias": "encoders.26.fc1.bias",
|
596 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.weight": "encoders.26.fc1.weight",
|
597 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.bias": "encoders.26.fc2.bias",
|
598 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.weight": "encoders.26.fc2.weight",
|
599 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias": ['encoders.27.attn.to_q.bias', 'encoders.27.attn.to_k.bias', 'encoders.27.attn.to_v.bias'],
|
600 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight": ['encoders.27.attn.to_q.weight', 'encoders.27.attn.to_k.weight', 'encoders.27.attn.to_v.weight'],
|
601 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.bias": "encoders.27.attn.to_out.bias",
|
602 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.weight": "encoders.27.attn.to_out.weight",
|
603 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.ln_1.bias": "encoders.27.layer_norm1.bias",
|
604 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.ln_1.weight": "encoders.27.layer_norm1.weight",
|
605 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.ln_2.bias": "encoders.27.layer_norm2.bias",
|
606 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.ln_2.weight": "encoders.27.layer_norm2.weight",
|
607 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.bias": "encoders.27.fc1.bias",
|
608 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.weight": "encoders.27.fc1.weight",
|
609 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.bias": "encoders.27.fc2.bias",
|
610 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.weight": "encoders.27.fc2.weight",
|
611 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias": ['encoders.28.attn.to_q.bias', 'encoders.28.attn.to_k.bias', 'encoders.28.attn.to_v.bias'],
|
612 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight": ['encoders.28.attn.to_q.weight', 'encoders.28.attn.to_k.weight', 'encoders.28.attn.to_v.weight'],
|
613 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.bias": "encoders.28.attn.to_out.bias",
|
614 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.weight": "encoders.28.attn.to_out.weight",
|
615 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.ln_1.bias": "encoders.28.layer_norm1.bias",
|
616 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.ln_1.weight": "encoders.28.layer_norm1.weight",
|
617 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.ln_2.bias": "encoders.28.layer_norm2.bias",
|
618 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.ln_2.weight": "encoders.28.layer_norm2.weight",
|
619 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.bias": "encoders.28.fc1.bias",
|
620 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.weight": "encoders.28.fc1.weight",
|
621 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.bias": "encoders.28.fc2.bias",
|
622 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.weight": "encoders.28.fc2.weight",
|
623 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias": ['encoders.29.attn.to_q.bias', 'encoders.29.attn.to_k.bias', 'encoders.29.attn.to_v.bias'],
|
624 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight": ['encoders.29.attn.to_q.weight', 'encoders.29.attn.to_k.weight', 'encoders.29.attn.to_v.weight'],
|
625 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.bias": "encoders.29.attn.to_out.bias",
|
626 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.weight": "encoders.29.attn.to_out.weight",
|
627 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.ln_1.bias": "encoders.29.layer_norm1.bias",
|
628 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.ln_1.weight": "encoders.29.layer_norm1.weight",
|
629 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.ln_2.bias": "encoders.29.layer_norm2.bias",
|
630 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.ln_2.weight": "encoders.29.layer_norm2.weight",
|
631 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.bias": "encoders.29.fc1.bias",
|
632 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.weight": "encoders.29.fc1.weight",
|
633 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.bias": "encoders.29.fc2.bias",
|
634 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.weight": "encoders.29.fc2.weight",
|
635 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias": ['encoders.3.attn.to_q.bias', 'encoders.3.attn.to_k.bias', 'encoders.3.attn.to_v.bias'],
|
636 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight": ['encoders.3.attn.to_q.weight', 'encoders.3.attn.to_k.weight', 'encoders.3.attn.to_v.weight'],
|
637 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.bias": "encoders.3.attn.to_out.bias",
|
638 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.weight": "encoders.3.attn.to_out.weight",
|
639 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.ln_1.bias": "encoders.3.layer_norm1.bias",
|
640 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.ln_1.weight": "encoders.3.layer_norm1.weight",
|
641 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.ln_2.bias": "encoders.3.layer_norm2.bias",
|
642 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.ln_2.weight": "encoders.3.layer_norm2.weight",
|
643 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.bias": "encoders.3.fc1.bias",
|
644 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.weight": "encoders.3.fc1.weight",
|
645 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.bias": "encoders.3.fc2.bias",
|
646 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.weight": "encoders.3.fc2.weight",
|
647 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias": ['encoders.30.attn.to_q.bias', 'encoders.30.attn.to_k.bias', 'encoders.30.attn.to_v.bias'],
|
648 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight": ['encoders.30.attn.to_q.weight', 'encoders.30.attn.to_k.weight', 'encoders.30.attn.to_v.weight'],
|
649 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.bias": "encoders.30.attn.to_out.bias",
|
650 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.weight": "encoders.30.attn.to_out.weight",
|
651 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.ln_1.bias": "encoders.30.layer_norm1.bias",
|
652 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.ln_1.weight": "encoders.30.layer_norm1.weight",
|
653 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.ln_2.bias": "encoders.30.layer_norm2.bias",
|
654 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.ln_2.weight": "encoders.30.layer_norm2.weight",
|
655 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.bias": "encoders.30.fc1.bias",
|
656 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.weight": "encoders.30.fc1.weight",
|
657 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.bias": "encoders.30.fc2.bias",
|
658 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.weight": "encoders.30.fc2.weight",
|
659 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias": ['encoders.31.attn.to_q.bias', 'encoders.31.attn.to_k.bias', 'encoders.31.attn.to_v.bias'],
|
660 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight": ['encoders.31.attn.to_q.weight', 'encoders.31.attn.to_k.weight', 'encoders.31.attn.to_v.weight'],
|
661 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.bias": "encoders.31.attn.to_out.bias",
|
662 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.weight": "encoders.31.attn.to_out.weight",
|
663 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.ln_1.bias": "encoders.31.layer_norm1.bias",
|
664 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.ln_1.weight": "encoders.31.layer_norm1.weight",
|
665 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.ln_2.bias": "encoders.31.layer_norm2.bias",
|
666 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.ln_2.weight": "encoders.31.layer_norm2.weight",
|
667 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.bias": "encoders.31.fc1.bias",
|
668 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.weight": "encoders.31.fc1.weight",
|
669 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.bias": "encoders.31.fc2.bias",
|
670 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.weight": "encoders.31.fc2.weight",
|
671 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias": ['encoders.4.attn.to_q.bias', 'encoders.4.attn.to_k.bias', 'encoders.4.attn.to_v.bias'],
|
672 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight": ['encoders.4.attn.to_q.weight', 'encoders.4.attn.to_k.weight', 'encoders.4.attn.to_v.weight'],
|
673 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.bias": "encoders.4.attn.to_out.bias",
|
674 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.weight": "encoders.4.attn.to_out.weight",
|
675 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.ln_1.bias": "encoders.4.layer_norm1.bias",
|
676 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.ln_1.weight": "encoders.4.layer_norm1.weight",
|
677 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.ln_2.bias": "encoders.4.layer_norm2.bias",
|
678 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.ln_2.weight": "encoders.4.layer_norm2.weight",
|
679 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.bias": "encoders.4.fc1.bias",
|
680 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.weight": "encoders.4.fc1.weight",
|
681 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.bias": "encoders.4.fc2.bias",
|
682 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.weight": "encoders.4.fc2.weight",
|
683 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias": ['encoders.5.attn.to_q.bias', 'encoders.5.attn.to_k.bias', 'encoders.5.attn.to_v.bias'],
|
684 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight": ['encoders.5.attn.to_q.weight', 'encoders.5.attn.to_k.weight', 'encoders.5.attn.to_v.weight'],
|
685 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.bias": "encoders.5.attn.to_out.bias",
|
686 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.weight": "encoders.5.attn.to_out.weight",
|
687 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.ln_1.bias": "encoders.5.layer_norm1.bias",
|
688 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.ln_1.weight": "encoders.5.layer_norm1.weight",
|
689 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.ln_2.bias": "encoders.5.layer_norm2.bias",
|
690 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.ln_2.weight": "encoders.5.layer_norm2.weight",
|
691 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.bias": "encoders.5.fc1.bias",
|
692 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.weight": "encoders.5.fc1.weight",
|
693 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.bias": "encoders.5.fc2.bias",
|
694 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.weight": "encoders.5.fc2.weight",
|
695 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias": ['encoders.6.attn.to_q.bias', 'encoders.6.attn.to_k.bias', 'encoders.6.attn.to_v.bias'],
|
696 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight": ['encoders.6.attn.to_q.weight', 'encoders.6.attn.to_k.weight', 'encoders.6.attn.to_v.weight'],
|
697 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.bias": "encoders.6.attn.to_out.bias",
|
698 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.weight": "encoders.6.attn.to_out.weight",
|
699 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.ln_1.bias": "encoders.6.layer_norm1.bias",
|
700 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.ln_1.weight": "encoders.6.layer_norm1.weight",
|
701 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.ln_2.bias": "encoders.6.layer_norm2.bias",
|
702 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.ln_2.weight": "encoders.6.layer_norm2.weight",
|
703 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.bias": "encoders.6.fc1.bias",
|
704 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.weight": "encoders.6.fc1.weight",
|
705 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.bias": "encoders.6.fc2.bias",
|
706 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.weight": "encoders.6.fc2.weight",
|
707 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias": ['encoders.7.attn.to_q.bias', 'encoders.7.attn.to_k.bias', 'encoders.7.attn.to_v.bias'],
|
708 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight": ['encoders.7.attn.to_q.weight', 'encoders.7.attn.to_k.weight', 'encoders.7.attn.to_v.weight'],
|
709 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.bias": "encoders.7.attn.to_out.bias",
|
710 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.weight": "encoders.7.attn.to_out.weight",
|
711 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.ln_1.bias": "encoders.7.layer_norm1.bias",
|
712 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.ln_1.weight": "encoders.7.layer_norm1.weight",
|
713 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.ln_2.bias": "encoders.7.layer_norm2.bias",
|
714 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.ln_2.weight": "encoders.7.layer_norm2.weight",
|
715 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.bias": "encoders.7.fc1.bias",
|
716 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.weight": "encoders.7.fc1.weight",
|
717 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.bias": "encoders.7.fc2.bias",
|
718 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.weight": "encoders.7.fc2.weight",
|
719 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias": ['encoders.8.attn.to_q.bias', 'encoders.8.attn.to_k.bias', 'encoders.8.attn.to_v.bias'],
|
720 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight": ['encoders.8.attn.to_q.weight', 'encoders.8.attn.to_k.weight', 'encoders.8.attn.to_v.weight'],
|
721 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.bias": "encoders.8.attn.to_out.bias",
|
722 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.weight": "encoders.8.attn.to_out.weight",
|
723 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.ln_1.bias": "encoders.8.layer_norm1.bias",
|
724 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.ln_1.weight": "encoders.8.layer_norm1.weight",
|
725 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.ln_2.bias": "encoders.8.layer_norm2.bias",
|
726 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.ln_2.weight": "encoders.8.layer_norm2.weight",
|
727 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.bias": "encoders.8.fc1.bias",
|
728 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.weight": "encoders.8.fc1.weight",
|
729 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.bias": "encoders.8.fc2.bias",
|
730 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.weight": "encoders.8.fc2.weight",
|
731 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias": ['encoders.9.attn.to_q.bias', 'encoders.9.attn.to_k.bias', 'encoders.9.attn.to_v.bias'],
|
732 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight": ['encoders.9.attn.to_q.weight', 'encoders.9.attn.to_k.weight', 'encoders.9.attn.to_v.weight'],
|
733 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.bias": "encoders.9.attn.to_out.bias",
|
734 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.weight": "encoders.9.attn.to_out.weight",
|
735 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.ln_1.bias": "encoders.9.layer_norm1.bias",
|
736 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.ln_1.weight": "encoders.9.layer_norm1.weight",
|
737 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.ln_2.bias": "encoders.9.layer_norm2.bias",
|
738 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.ln_2.weight": "encoders.9.layer_norm2.weight",
|
739 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.bias": "encoders.9.fc1.bias",
|
740 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.weight": "encoders.9.fc1.weight",
|
741 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias": "encoders.9.fc2.bias",
|
742 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight": "encoders.9.fc2.weight",
|
743 |
+
"conditioner.embedders.1.model.text_projection": "text_projection.weight",
|
744 |
+
}
|
745 |
+
state_dict_ = {}
|
746 |
+
for name in state_dict:
|
747 |
+
if name in rename_dict:
|
748 |
+
param = state_dict[name]
|
749 |
+
if name == "conditioner.embedders.1.model.positional_embedding":
|
750 |
+
param = param.reshape((1, param.shape[0], param.shape[1]))
|
751 |
+
elif name == "conditioner.embedders.1.model.text_projection":
|
752 |
+
param = param.T
|
753 |
+
if isinstance(rename_dict[name], str):
|
754 |
+
state_dict_[rename_dict[name]] = param
|
755 |
+
else:
|
756 |
+
length = param.shape[0] // 3
|
757 |
+
for i, rename in enumerate(rename_dict[name]):
|
758 |
+
state_dict_[rename] = param[i*length: i*length+length]
|
759 |
+
return state_dict_
|
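The tail of this chunk is the conversion loop: string targets are copied through directly, the positional embedding gains a batch dimension, the text projection is transposed, and each fused OpenCLIP `in_proj` tensor is split into equal thirds for the separate to_q / to_k / to_v parameters. A minimal standalone sketch of that splitting rule follows; `split_in_proj` is a hypothetical helper written for illustration, not a function from this repository.

import torch

def split_in_proj(param: torch.Tensor, target_names):
    # Hypothetical helper: q, k, v are stacked along dim 0 of the fused
    # OpenCLIP in_proj tensor, so each target receives one third of it.
    length = param.shape[0] // 3
    return {name: param[i * length:(i + 1) * length] for i, name in enumerate(target_names)}

# Example with a fused in_proj_weight of a 1280-dim block (shape assumed for illustration).
fused = torch.randn(3 * 1280, 1280)
parts = split_in_proj(fused, [
    "encoders.9.attn.to_q.weight",
    "encoders.9.attn.to_k.weight",
    "encoders.9.attn.to_v.weight",
])
print({k: tuple(v.shape) for k, v in parts.items()})  # each part is (1280, 1280)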