wenmengzhou committed on
Commit
703e263
1 Parent(s): 359b5e8

add code and adapt to zero gpus
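
The "adapt to zero gpus" part of the message refers to Hugging Face ZeroGPU Spaces: app.py below imports the `spaces` package and wraps its generation function with `@spaces.GPU`, so a GPU is attached only while inference runs. A minimal, illustrative sketch of that pattern (the demo function is a placeholder, not code from this commit):

```python
# Sketch of the ZeroGPU pattern used in app.py (illustrative only).
# On a ZeroGPU Space the process starts without a GPU; `@spaces.GPU`
# requests one just for the duration of the decorated call.
import spaces        # Hugging Face Spaces helper available on ZeroGPU hardware
import torch
import gradio as gr

@spaces.GPU  # a GPU is allocated when this function is called and released afterwards
def gpu_task(seed):
    torch.manual_seed(int(seed))
    x = torch.randn(4, 4, device="cuda")  # CUDA is usable only inside the decorated call
    return f"sum = {x.sum().item():.4f}"

demo = gr.Interface(fn=gpu_task, inputs="number", outputs="text")
demo.launch()
```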

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +1 -0
  2. app.py +255 -0
  3. diffsynth/__init__.py +6 -0
  4. diffsynth/configs/__init__.py +0 -0
  5. diffsynth/configs/model_config.py +275 -0
  6. diffsynth/controlnets/__init__.py +2 -0
  7. diffsynth/controlnets/controlnet_unit.py +54 -0
  8. diffsynth/controlnets/processors.py +51 -0
  9. diffsynth/data/__init__.py +1 -0
  10. diffsynth/data/simple_text_image.py +35 -0
  11. diffsynth/data/video.py +148 -0
  12. diffsynth/extensions/ESRGAN/__init__.py +118 -0
  13. diffsynth/extensions/FastBlend/__init__.py +63 -0
  14. diffsynth/extensions/FastBlend/api.py +397 -0
  15. diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
  16. diffsynth/extensions/FastBlend/data.py +146 -0
  17. diffsynth/extensions/FastBlend/patch_match.py +298 -0
  18. diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
  19. diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
  20. diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
  21. diffsynth/extensions/FastBlend/runners/fast.py +141 -0
  22. diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
  23. diffsynth/extensions/RIFE/__init__.py +242 -0
  24. diffsynth/extensions/__init__.py +0 -0
  25. diffsynth/models/__init__.py +1 -0
  26. diffsynth/models/attention.py +89 -0
  27. diffsynth/models/downloader.py +66 -0
  28. diffsynth/models/flux_dit.py +575 -0
  29. diffsynth/models/flux_text_encoder.py +93 -0
  30. diffsynth/models/flux_vae.py +303 -0
  31. diffsynth/models/hunyuan_dit.py +451 -0
  32. diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
  33. diffsynth/models/kolors_text_encoder.py +1552 -0
  34. diffsynth/models/lora.py +195 -0
  35. diffsynth/models/model_manager.py +543 -0
  36. diffsynth/models/sd3_dit.py +798 -0
  37. diffsynth/models/sd3_text_encoder.py +0 -0
  38. diffsynth/models/sd3_vae_decoder.py +81 -0
  39. diffsynth/models/sd3_vae_encoder.py +95 -0
  40. diffsynth/models/sd_controlnet.py +589 -0
  41. diffsynth/models/sd_ipadapter.py +57 -0
  42. diffsynth/models/sd_motion.py +199 -0
  43. diffsynth/models/sd_text_encoder.py +321 -0
  44. diffsynth/models/sd_unet.py +0 -0
  45. diffsynth/models/sd_vae_decoder.py +336 -0
  46. diffsynth/models/sd_vae_encoder.py +282 -0
  47. diffsynth/models/sdxl_controlnet.py +318 -0
  48. diffsynth/models/sdxl_ipadapter.py +122 -0
  49. diffsynth/models/sdxl_motion.py +104 -0
  50. diffsynth/models/sdxl_text_encoder.py +759 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,255 @@
1
+ import spaces
2
+ import os
3
+ os.system("pip install -r requirements.txt")
4
+ from diffsynth import download_models
5
+ download_models(["Kolors", "FLUX.1-dev"])
6
+
7
+ import gradio as gr
8
+ from diffsynth import ModelManager, SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
9
+ import os, torch
10
+ from PIL import Image
11
+ import numpy as np
12
+
13
+
14
+ config = {
15
+ "model_config": {
16
+ "Stable Diffusion": {
17
+ "model_folder": "models/stable_diffusion",
18
+ "pipeline_class": SDImagePipeline,
19
+ "default_parameters": {
20
+ "cfg_scale": 7.0,
21
+ "height": 512,
22
+ "width": 512,
23
+ }
24
+ },
25
+ "Stable Diffusion XL": {
26
+ "model_folder": "models/stable_diffusion_xl",
27
+ "pipeline_class": SDXLImagePipeline,
28
+ "default_parameters": {
29
+ "cfg_scale": 7.0,
30
+ }
31
+ },
32
+ "Stable Diffusion 3": {
33
+ "model_folder": "models/stable_diffusion_3",
34
+ "pipeline_class": SD3ImagePipeline,
35
+ "default_parameters": {
36
+ "cfg_scale": 7.0,
37
+ }
38
+ },
39
+ "Stable Diffusion XL Turbo": {
40
+ "model_folder": "models/stable_diffusion_xl_turbo",
41
+ "pipeline_class": SDXLImagePipeline,
42
+ "default_parameters": {
43
+ "negative_prompt": "",
44
+ "cfg_scale": 1.0,
45
+ "num_inference_steps": 1,
46
+ "height": 512,
47
+ "width": 512,
48
+ }
49
+ },
50
+ "Kolors": {
51
+ "model_folder": "models/kolors",
52
+ "pipeline_class": SDXLImagePipeline,
53
+ "default_parameters": {
54
+ "cfg_scale": 7.0,
55
+ }
56
+ },
57
+ "HunyuanDiT": {
58
+ "model_folder": "models/HunyuanDiT",
59
+ "pipeline_class": HunyuanDiTImagePipeline,
60
+ "default_parameters": {
61
+ "cfg_scale": 7.0,
62
+ }
63
+ },
64
+ "FLUX": {
65
+ "model_folder": "models/FLUX",
66
+ "pipeline_class": FluxImagePipeline,
67
+ "default_parameters": {
68
+ "cfg_scale": 1.0,
69
+ }
70
+ }
71
+ },
72
+ "max_num_painter_layers": 3,
73
+ "max_num_model_cache": 2,
74
+ }
75
+
76
+
77
+ def load_model_list(model_type):
78
+ if model_type is None:
79
+ return []
80
+ folder = config["model_config"][model_type]["model_folder"]
81
+ file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
82
+ if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
83
+ file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
84
+ file_list = sorted(file_list)
85
+ return file_list
86
+
87
+
88
+ def load_model(model_type, model_path):
89
+ global model_dict
90
+ model_key = f"{model_type}:{model_path}"
91
+ if model_key in model_dict:
92
+ return model_dict[model_key]
93
+ model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
94
+ model_manager = ModelManager()
95
+ if model_type == "HunyuanDiT":
96
+ model_manager.load_models([
97
+ os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
98
+ os.path.join(model_path, "mt5/pytorch_model.bin"),
99
+ os.path.join(model_path, "model/pytorch_model_ema.pt"),
100
+ os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
101
+ ])
102
+ elif model_type == "Kolors":
103
+ model_manager.load_models([
104
+ os.path.join(model_path, "text_encoder"),
105
+ os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
106
+ os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
107
+ ])
108
+ elif model_type == "FLUX":
109
+ model_manager.torch_dtype = torch.bfloat16
110
+ file_list = [
111
+ os.path.join(model_path, "text_encoder/model.safetensors"),
112
+ os.path.join(model_path, "text_encoder_2"),
113
+ ]
114
+ for file_name in os.listdir(model_path):
115
+ if file_name.endswith(".safetensors"):
116
+ file_list.append(os.path.join(model_path, file_name))
117
+ model_manager.load_models(file_list)
118
+ else:
119
+ model_manager.load_model(model_path)
120
+ pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
121
+ while len(model_dict) + 1 > config["max_num_model_cache"]:
122
+ key = next(iter(model_dict.keys()))
123
+ model_manager_to_release, _ = model_dict[key]
124
+ model_manager_to_release.to("cpu")
125
+ del model_dict[key]
126
+ torch.cuda.empty_cache()
127
+ model_dict[model_key] = model_manager, pipe
128
+ return model_manager, pipe
129
+
130
+
131
+ model_dict = {}
132
+
133
+ with gr.Blocks() as app:
134
+ gr.Markdown("# DiffSynth-Studio Painter")
135
+ with gr.Row():
136
+ with gr.Column(scale=382, min_width=100):
137
+
138
+ with gr.Accordion(label="Model"):
139
+ model_type = gr.Dropdown(choices=["Kolors", "FLUX"], label="Model type", value="Kolors")
140
+ model_path = gr.Dropdown(choices=["Kolors"], interactive=True, label="Model path", value="Kolors")
141
+
142
+ @gr.on(inputs=model_type, outputs=model_path, triggers=model_type.change)
143
+ def model_type_to_model_path(model_type):
144
+ return gr.Dropdown(choices=load_model_list(model_type))
145
+
146
+ with gr.Accordion(label="Prompt"):
147
+ prompt = gr.Textbox(label="Prompt", lines=3)
148
+ negative_prompt = gr.Textbox(label="Negative prompt", lines=1)
149
+ cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=7.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
150
+ embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, interactive=True, label="Embedded guidance scale (only for FLUX)")
151
+
152
+ with gr.Accordion(label="Image"):
153
+ num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, interactive=True, label="Inference steps")
154
+ height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
155
+ width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
156
+ with gr.Column():
157
+ use_fixed_seed = gr.Checkbox(value=True, interactive=False, label="Use fixed seed")
158
+ seed = gr.Number(minimum=0, maximum=10**9, value=0, interactive=True, label="Random seed", show_label=False)
159
+
160
+ @gr.on(
161
+ inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
162
+ outputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
163
+ triggers=model_path.change
164
+ )
165
+ def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width):
166
+ load_model(model_type, model_path)
167
+ cfg_scale = config["model_config"][model_type]["default_parameters"].get("cfg_scale", cfg_scale)
168
+ embedded_guidance = config["model_config"][model_type]["default_parameters"].get("embedded_guidance", embedded_guidance)
169
+ num_inference_steps = config["model_config"][model_type]["default_parameters"].get("num_inference_steps", num_inference_steps)
170
+ height = config["model_config"][model_type]["default_parameters"].get("height", height)
171
+ width = config["model_config"][model_type]["default_parameters"].get("width", width)
172
+ return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width
173
+
174
+
175
+ with gr.Column(scale=618, min_width=100):
176
+ with gr.Accordion(label="Painter"):
177
+ enable_local_prompt_list = []
178
+ local_prompt_list = []
179
+ mask_scale_list = []
180
+ canvas_list = []
181
+ for painter_layer_id in range(config["max_num_painter_layers"]):
182
+ with gr.Tab(label=f"Layer {painter_layer_id}"):
183
+ enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}")
184
+ local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
185
+ mask_scale = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Mask scale", key=f"mask_scale_{painter_layer_id}")
186
+ canvas = gr.ImageEditor(canvas_size=(512, 1), sources=None, layers=False, interactive=True, image_mode="RGBA",
187
+ brush=gr.Brush(default_size=100, default_color="#000000", colors=["#000000"]),
188
+ label="Painter", key=f"canvas_{painter_layer_id}")
189
+ @gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear, enable_local_prompt.change], show_progress="hidden")
190
+ def resize_canvas(height, width, canvas):
191
+ h, w = canvas["background"].shape[:2]
192
+ if h != height or width != w:
193
+ return np.ones((height, width, 3), dtype=np.uint8) * 255
194
+ else:
195
+ return canvas
196
+
197
+ enable_local_prompt_list.append(enable_local_prompt)
198
+ local_prompt_list.append(local_prompt)
199
+ mask_scale_list.append(mask_scale)
200
+ canvas_list.append(canvas)
201
+ with gr.Accordion(label="Results"):
202
+ run_button = gr.Button(value="Generate", variant="primary")
203
+ output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
204
+ output_to_painter_button = gr.Button(value="Set as painter's background")
205
+ painter_background = gr.State(None)
206
+ input_background = gr.State(None)
207
+ @gr.on(
208
+ inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list,
209
+ outputs=[output_image],
210
+ triggers=run_button.click
211
+ )
212
+ @spaces.GPU
213
+ def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed, *args, progress=gr.Progress()):
214
+ _, pipe = load_model(model_type, model_path)
215
+ input_params = {
216
+ "prompt": prompt,
217
+ "negative_prompt": negative_prompt,
218
+ "cfg_scale": cfg_scale,
219
+ "num_inference_steps": num_inference_steps,
220
+ "height": height,
221
+ "width": width,
222
+ "progress_bar_cmd": progress.tqdm,
223
+ }
224
+ if isinstance(pipe, FluxImagePipeline):
225
+ input_params["embedded_guidance"] = embedded_guidance
226
+ enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = (
227
+ args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
228
+ args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
229
+ args[2 * config["max_num_painter_layers"]: 3 * config["max_num_painter_layers"]],
230
+ args[3 * config["max_num_painter_layers"]: 4 * config["max_num_painter_layers"]]
231
+ )
232
+ local_prompts, masks, mask_scales = [], [], []
233
+ for enable_local_prompt, local_prompt, mask_scale, canvas in zip(
234
+ enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list
235
+ ):
236
+ if enable_local_prompt:
237
+ local_prompts.append(local_prompt)
238
+ masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
239
+ mask_scales.append(mask_scale)
240
+ input_params.update({
241
+ "local_prompts": local_prompts,
242
+ "masks": masks,
243
+ "mask_scales": mask_scales,
244
+ })
245
+ torch.manual_seed(seed)
246
+ image = pipe(**input_params)
247
+ return image
248
+
249
+ @gr.on(inputs=[output_image] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
250
+ def send_output_to_painter_background(output_image, *canvas_list):
251
+ for canvas in canvas_list:
252
+ h, w = canvas["background"].shape[:2]
253
+ canvas["background"] = output_image.resize((w, h))
254
+ return tuple(canvas_list)
255
+ app.launch()
diffsynth/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .data import *
2
+ from .models import *
3
+ from .prompters import *
4
+ from .schedulers import *
5
+ from .pipelines import *
6
+ from .controlnets import *
diffsynth/configs/__init__.py ADDED
File without changes
diffsynth/configs/model_config.py ADDED
@@ -0,0 +1,275 @@
1
+ from typing_extensions import Literal, TypeAlias
2
+
3
+ from ..models.sd_text_encoder import SDTextEncoder
4
+ from ..models.sd_unet import SDUNet
5
+ from ..models.sd_vae_encoder import SDVAEEncoder
6
+ from ..models.sd_vae_decoder import SDVAEDecoder
7
+
8
+ from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
9
+ from ..models.sdxl_unet import SDXLUNet
10
+ from ..models.sdxl_vae_decoder import SDXLVAEDecoder
11
+ from ..models.sdxl_vae_encoder import SDXLVAEEncoder
12
+
13
+ from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
14
+ from ..models.sd3_dit import SD3DiT
15
+ from ..models.sd3_vae_decoder import SD3VAEDecoder
16
+ from ..models.sd3_vae_encoder import SD3VAEEncoder
17
+
18
+ from ..models.sd_controlnet import SDControlNet
19
+ from ..models.sdxl_controlnet import SDXLControlNetUnion
20
+
21
+ from ..models.sd_motion import SDMotionModel
22
+ from ..models.sdxl_motion import SDXLMotionModel
23
+
24
+ from ..models.svd_image_encoder import SVDImageEncoder
25
+ from ..models.svd_unet import SVDUNet
26
+ from ..models.svd_vae_decoder import SVDVAEDecoder
27
+ from ..models.svd_vae_encoder import SVDVAEEncoder
28
+
29
+ from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
30
+ from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
31
+
32
+ from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
33
+ from ..models.hunyuan_dit import HunyuanDiT
34
+
35
+
36
+ from ..models.flux_dit import FluxDiT
37
+ from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
38
+ from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
39
+
40
+
41
+
42
+ model_loader_configs = [
43
+ # These configs are provided for detecting model type automatically.
44
+ # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
45
+ (None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
46
+ (None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
47
+ (None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
48
+ (None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
49
+ (None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
50
+ (None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
51
+ (None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
52
+ (None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
53
+ (None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
54
+ (None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
55
+ (None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
56
+ (None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
57
+ (None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
58
+ (None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
59
+ (None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
60
+ (None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
61
+ (None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
62
+ (None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
63
+ (None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
64
+ (None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
65
+ (None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
66
+ (None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
67
+ (None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
68
+ (None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
69
+ (None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
70
+ (None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["flux_text_encoder_1"], [FluxTextEncoder1], "diffusers"),
71
+ (None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
72
+ (None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
73
+ (None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai")
74
+ ]
75
+ huggingface_model_loader_configs = [
76
+ # These configs are provided for detecting model type automatically.
77
+ # The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
78
+ ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
79
+ ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
80
+ ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
81
+ ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
82
+ ]
83
+ patch_model_loader_configs = [
84
+ # These configs are provided for detecting model type automatically.
85
+ # The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
86
+ ("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
87
+ ]
88
+
89
+ preset_models_on_huggingface = {
90
+ "HunyuanDiT": [
91
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
92
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
93
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
94
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
95
+ ],
96
+ "stable-video-diffusion-img2vid-xt": [
97
+ ("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
98
+ ],
99
+ "ExVideo-SVD-128f-v1": [
100
+ ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
101
+ ],
102
+ }
103
+ preset_models_on_modelscope = {
104
+ # Hunyuan DiT
105
+ "HunyuanDiT": [
106
+ ("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
107
+ ("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
108
+ ("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
109
+ ("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
110
+ ],
111
+ # Stable Video Diffusion
112
+ "stable-video-diffusion-img2vid-xt": [
113
+ ("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
114
+ ],
115
+ # ExVideo
116
+ "ExVideo-SVD-128f-v1": [
117
+ ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
118
+ ],
119
+ # Stable Diffusion
120
+ "StableDiffusion_v15": [
121
+ ("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
122
+ ],
123
+ "DreamShaper_8": [
124
+ ("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
125
+ ],
126
+ "AingDiffusion_v12": [
127
+ ("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
128
+ ],
129
+ "Flat2DAnimerge_v45Sharp": [
130
+ ("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
131
+ ],
132
+ # Textual Inversion
133
+ "TextualInversion_VeryBadImageNegative_v1.3": [
134
+ ("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
135
+ ],
136
+ # Stable Diffusion XL
137
+ "StableDiffusionXL_v1": [
138
+ ("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
139
+ ],
140
+ "BluePencilXL_v200": [
141
+ ("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
142
+ ],
143
+ "StableDiffusionXL_Turbo": [
144
+ ("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
145
+ ],
146
+ "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
147
+ ("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
148
+ ],
149
+ # Stable Diffusion 3
150
+ "StableDiffusion3": [
151
+ ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
152
+ ],
153
+ "StableDiffusion3_without_T5": [
154
+ ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
155
+ ],
156
+ # ControlNet
157
+ "ControlNet_v11f1p_sd15_depth": [
158
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
159
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
160
+ ],
161
+ "ControlNet_v11p_sd15_softedge": [
162
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
163
+ ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
164
+ ],
165
+ "ControlNet_v11f1e_sd15_tile": [
166
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
167
+ ],
168
+ "ControlNet_v11p_sd15_lineart": [
169
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
170
+ ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
171
+ ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
172
+ ],
173
+ "ControlNet_union_sdxl_promax": [
174
+ ("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
175
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
176
+ ],
177
+ # AnimateDiff
178
+ "AnimateDiff_v2": [
179
+ ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
180
+ ],
181
+ "AnimateDiff_xl_beta": [
182
+ ("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
183
+ ],
184
+ # RIFE
185
+ "RIFE": [
186
+ ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
187
+ ],
188
+ # Beautiful Prompt
189
+ "BeautifulPrompt": [
190
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
191
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
192
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
193
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
194
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
195
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
196
+ ],
197
+ # Translator
198
+ "opus-mt-zh-en": [
199
+ ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
200
+ ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
201
+ ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
202
+ ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
203
+ ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
204
+ ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
205
+ ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
206
+ ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
207
+ ],
208
+ # IP-Adapter
209
+ "IP-Adapter-SD": [
210
+ ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
211
+ ("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
212
+ ],
213
+ "IP-Adapter-SDXL": [
214
+ ("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
215
+ ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
216
+ ],
217
+ # Kolors
218
+ "Kolors": [
219
+ ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
220
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
221
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
222
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
223
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
224
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
225
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
226
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
227
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
228
+ ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
229
+ ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
230
+ ],
231
+ "SDXL-vae-fp16-fix": [
232
+ ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
233
+ ],
234
+ # FLUX
235
+ "FLUX.1-dev": [
236
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
237
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
238
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
239
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
240
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
241
+ ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
242
+ ("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
243
+ ]
244
+ }
245
+ Preset_model_id: TypeAlias = Literal[
246
+ "HunyuanDiT",
247
+ "stable-video-diffusion-img2vid-xt",
248
+ "ExVideo-SVD-128f-v1",
249
+ "StableDiffusion_v15",
250
+ "DreamShaper_8",
251
+ "AingDiffusion_v12",
252
+ "Flat2DAnimerge_v45Sharp",
253
+ "TextualInversion_VeryBadImageNegative_v1.3",
254
+ "StableDiffusionXL_v1",
255
+ "BluePencilXL_v200",
256
+ "StableDiffusionXL_Turbo",
257
+ "ControlNet_v11f1p_sd15_depth",
258
+ "ControlNet_v11p_sd15_softedge",
259
+ "ControlNet_v11f1e_sd15_tile",
260
+ "ControlNet_v11p_sd15_lineart",
261
+ "AnimateDiff_v2",
262
+ "AnimateDiff_xl_beta",
263
+ "RIFE",
264
+ "BeautifulPrompt",
265
+ "opus-mt-zh-en",
266
+ "IP-Adapter-SD",
267
+ "IP-Adapter-SDXL",
268
+ "StableDiffusion3",
269
+ "StableDiffusion3_without_T5",
270
+ "Kolors",
271
+ "SDXL-vae-fp16-fix",
272
+ "ControlNet_union_sdxl_promax",
273
+ "FLUX.1-dev",
274
+ "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
275
+ ]
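
The preset tables and loader configs above back the `download_models` call at the top of app.py: each ID in `Preset_model_id` maps to a list of (repo, file path in repo, local target folder) tuples, and the hash-based loader configs let `ModelManager` detect the type of whatever ends up in those folders. A minimal usage sketch, mirroring the call already present in app.py:

```python
from diffsynth import download_models

# The IDs must be Preset_model_id literals defined above; files are fetched
# from ModelScope (or Hugging Face) into the "models/..." folders listed in
# the preset tables, e.g. models/kolors/Kolors and models/FLUX/FLUX.1-dev.
download_models(["Kolors", "FLUX.1-dev"])
```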
diffsynth/controlnets/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
2
+ from .processors import Annotator
diffsynth/controlnets/controlnet_unit.py ADDED
@@ -0,0 +1,54 @@
1
+ import torch
2
+ import numpy as np
3
+ from .processors import Processor_id
4
+
5
+
6
+ class ControlNetConfigUnit:
7
+ def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
8
+ self.processor_id = processor_id
9
+ self.model_path = model_path
10
+ self.scale = scale
11
+
12
+
13
+ class ControlNetUnit:
14
+ def __init__(self, processor, model, scale=1.0):
15
+ self.processor = processor
16
+ self.model = model
17
+ self.scale = scale
18
+
19
+
20
+ class MultiControlNetManager:
21
+ def __init__(self, controlnet_units=[]):
22
+ self.processors = [unit.processor for unit in controlnet_units]
23
+ self.models = [unit.model for unit in controlnet_units]
24
+ self.scales = [unit.scale for unit in controlnet_units]
25
+
26
+ def process_image(self, image, processor_id=None):
27
+ if processor_id is None:
28
+ processed_image = [processor(image) for processor in self.processors]
29
+ else:
30
+ processed_image = [self.processors[processor_id](image)]
31
+ processed_image = torch.concat([
32
+ torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
33
+ for image_ in processed_image
34
+ ], dim=0)
35
+ return processed_image
36
+
37
+ def __call__(
38
+ self,
39
+ sample, timestep, encoder_hidden_states, conditionings,
40
+ tiled=False, tile_size=64, tile_stride=32, **kwargs
41
+ ):
42
+ res_stack = None
43
+ for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
44
+ res_stack_ = model(
45
+ sample, timestep, encoder_hidden_states, conditioning, **kwargs,
46
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
47
+ processor_id=processor.processor_id
48
+ )
49
+ res_stack_ = [res * scale for res in res_stack_]
50
+ if res_stack is None:
51
+ res_stack = res_stack_
52
+ else:
53
+ res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
54
+ return res_stack
diffsynth/controlnets/processors.py ADDED
@@ -0,0 +1,51 @@
1
+ from typing_extensions import Literal, TypeAlias
2
+ import warnings
3
+ with warnings.catch_warnings():
4
+ warnings.simplefilter("ignore")
5
+ from controlnet_aux.processor import (
6
+ CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
7
+ )
8
+
9
+
10
+ Processor_id: TypeAlias = Literal[
11
+ "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
12
+ ]
13
+
14
+ class Annotator:
15
+ def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'):
16
+ if processor_id == "canny":
17
+ self.processor = CannyDetector()
18
+ elif processor_id == "depth":
19
+ self.processor = MidasDetector.from_pretrained(model_path).to(device)
20
+ elif processor_id == "softedge":
21
+ self.processor = HEDdetector.from_pretrained(model_path).to(device)
22
+ elif processor_id == "lineart":
23
+ self.processor = LineartDetector.from_pretrained(model_path).to(device)
24
+ elif processor_id == "lineart_anime":
25
+ self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
26
+ elif processor_id == "openpose":
27
+ self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
28
+ elif processor_id == "tile":
29
+ self.processor = None
30
+ else:
31
+ raise ValueError(f"Unsupported processor_id: {processor_id}")
32
+
33
+ self.processor_id = processor_id
34
+ self.detect_resolution = detect_resolution
35
+
36
+ def __call__(self, image):
37
+ width, height = image.size
38
+ if self.processor_id == "openpose":
39
+ kwargs = {
40
+ "include_body": True,
41
+ "include_hand": True,
42
+ "include_face": True
43
+ }
44
+ else:
45
+ kwargs = {}
46
+ if self.processor is not None:
47
+ detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
48
+ image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
49
+ image = image.resize((width, height))
50
+ return image
51
+
diffsynth/data/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .video import VideoData, save_video, save_frames
diffsynth/data/simple_text_image.py ADDED
@@ -0,0 +1,35 @@
1
+ import torch, os
2
+ from torchvision import transforms
3
+ import pandas as pd
4
+ from PIL import Image
5
+
6
+
7
+
8
+ class TextImageDataset(torch.utils.data.Dataset):
9
+ def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
10
+ self.steps_per_epoch = steps_per_epoch
11
+ metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
12
+ self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
13
+ self.text = metadata["text"].to_list()
14
+ self.image_processor = transforms.Compose(
15
+ [
16
+ transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
17
+ transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
18
+ transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize([0.5], [0.5]),
21
+ ]
22
+ )
23
+
24
+
25
+ def __getitem__(self, index):
26
+ data_id = torch.randint(0, len(self.path), (1,))[0]
27
+ data_id = (data_id + index) % len(self.path) # For fixed seed.
28
+ text = self.text[data_id]
29
+ image = Image.open(self.path[data_id]).convert("RGB")
30
+ image = self.image_processor(image)
31
+ return {"text": text, "image": image}
32
+
33
+
34
+ def __len__(self):
35
+ return self.steps_per_epoch
diffsynth/data/video.py ADDED
@@ -0,0 +1,148 @@
1
+ import imageio, os
2
+ import numpy as np
3
+ from PIL import Image
4
+ from tqdm import tqdm
5
+
6
+
7
+ class LowMemoryVideo:
8
+ def __init__(self, file_name):
9
+ self.reader = imageio.get_reader(file_name)
10
+
11
+ def __len__(self):
12
+ return self.reader.count_frames()
13
+
14
+ def __getitem__(self, item):
15
+ return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")
16
+
17
+ def __del__(self):
18
+ self.reader.close()
19
+
20
+
21
+ def split_file_name(file_name):
22
+ result = []
23
+ number = -1
24
+ for i in file_name:
25
+ if ord(i)>=ord("0") and ord(i)<=ord("9"):
26
+ if number == -1:
27
+ number = 0
28
+ number = number*10 + ord(i) - ord("0")
29
+ else:
30
+ if number != -1:
31
+ result.append(number)
32
+ number = -1
33
+ result.append(i)
34
+ if number != -1:
35
+ result.append(number)
36
+ result = tuple(result)
37
+ return result
38
+
39
+
40
+ def search_for_images(folder):
41
+ file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
42
+ file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
43
+ file_list = [i[1] for i in sorted(file_list)]
44
+ file_list = [os.path.join(folder, i) for i in file_list]
45
+ return file_list
46
+
47
+
48
+ class LowMemoryImageFolder:
49
+ def __init__(self, folder, file_list=None):
50
+ if file_list is None:
51
+ self.file_list = search_for_images(folder)
52
+ else:
53
+ self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
54
+
55
+ def __len__(self):
56
+ return len(self.file_list)
57
+
58
+ def __getitem__(self, item):
59
+ return Image.open(self.file_list[item]).convert("RGB")
60
+
61
+ def __del__(self):
62
+ pass
63
+
64
+
65
+ def crop_and_resize(image, height, width):
66
+ image = np.array(image)
67
+ image_height, image_width, _ = image.shape
68
+ if image_height / image_width < height / width:
69
+ croped_width = int(image_height / height * width)
70
+ left = (image_width - croped_width) // 2
71
+ image = image[:, left: left+croped_width]
72
+ image = Image.fromarray(image).resize((width, height))
73
+ else:
74
+ croped_height = int(image_width / width * height)
75
+ left = (image_height - croped_height) // 2
76
+ image = image[left: left+croped_height, :]
77
+ image = Image.fromarray(image).resize((width, height))
78
+ return image
79
+
80
+
81
+ class VideoData:
82
+ def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
83
+ if video_file is not None:
84
+ self.data_type = "video"
85
+ self.data = LowMemoryVideo(video_file, **kwargs)
86
+ elif image_folder is not None:
87
+ self.data_type = "images"
88
+ self.data = LowMemoryImageFolder(image_folder, **kwargs)
89
+ else:
90
+ raise ValueError("Cannot open video or image folder")
91
+ self.length = None
92
+ self.set_shape(height, width)
93
+
94
+ def raw_data(self):
95
+ frames = []
96
+ for i in range(self.__len__()):
97
+ frames.append(self.__getitem__(i))
98
+ return frames
99
+
100
+ def set_length(self, length):
101
+ self.length = length
102
+
103
+ def set_shape(self, height, width):
104
+ self.height = height
105
+ self.width = width
106
+
107
+ def __len__(self):
108
+ if self.length is None:
109
+ return len(self.data)
110
+ else:
111
+ return self.length
112
+
113
+ def shape(self):
114
+ if self.height is not None and self.width is not None:
115
+ return self.height, self.width
116
+ else:
117
+ height, width, _ = self.__getitem__(0).shape
118
+ return height, width
119
+
120
+ def __getitem__(self, item):
121
+ frame = self.data.__getitem__(item)
122
+ width, height = frame.size
123
+ if self.height is not None and self.width is not None:
124
+ if self.height != height or self.width != width:
125
+ frame = crop_and_resize(frame, self.height, self.width)
126
+ return frame
127
+
128
+ def __del__(self):
129
+ pass
130
+
131
+ def save_images(self, folder):
132
+ os.makedirs(folder, exist_ok=True)
133
+ for i in tqdm(range(self.__len__()), desc="Saving images"):
134
+ frame = self.__getitem__(i)
135
+ frame.save(os.path.join(folder, f"{i}.png"))
136
+
137
+
138
+ def save_video(frames, save_path, fps, quality=9):
139
+ writer = imageio.get_writer(save_path, fps=fps, quality=quality)
140
+ for frame in tqdm(frames, desc="Saving video"):
141
+ frame = np.array(frame)
142
+ writer.append_data(frame)
143
+ writer.close()
144
+
145
+ def save_frames(frames, save_path):
146
+ os.makedirs(save_path, exist_ok=True)
147
+ for i, frame in enumerate(tqdm(frames, desc="Saving images")):
148
+ frame.save(os.path.join(save_path, f"{i}.png"))
diffsynth/extensions/ESRGAN/__init__.py ADDED
@@ -0,0 +1,118 @@
1
+ import torch
2
+ from einops import repeat
3
+ from PIL import Image
4
+ import numpy as np
5
+
6
+
7
+ class ResidualDenseBlock(torch.nn.Module):
8
+
9
+ def __init__(self, num_feat=64, num_grow_ch=32):
10
+ super(ResidualDenseBlock, self).__init__()
11
+ self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
12
+ self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
13
+ self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
14
+ self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
15
+ self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
16
+ self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
17
+
18
+ def forward(self, x):
19
+ x1 = self.lrelu(self.conv1(x))
20
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
21
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
22
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
23
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
24
+ return x5 * 0.2 + x
25
+
26
+
27
+ class RRDB(torch.nn.Module):
28
+
29
+ def __init__(self, num_feat, num_grow_ch=32):
30
+ super(RRDB, self).__init__()
31
+ self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
32
+ self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
33
+ self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
34
+
35
+ def forward(self, x):
36
+ out = self.rdb1(x)
37
+ out = self.rdb2(out)
38
+ out = self.rdb3(out)
39
+ return out * 0.2 + x
40
+
41
+
42
+ class RRDBNet(torch.nn.Module):
43
+
44
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32):
45
+ super(RRDBNet, self).__init__()
46
+ self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
47
+ self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
48
+ self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
49
+ # upsample
50
+ self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
51
+ self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
52
+ self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
53
+ self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
54
+ self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
55
+
56
+ def forward(self, x):
57
+ feat = x
58
+ feat = self.conv_first(feat)
59
+ body_feat = self.conv_body(self.body(feat))
60
+ feat = feat + body_feat
61
+ # upsample
62
+ feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
63
+ feat = self.lrelu(self.conv_up1(feat))
64
+ feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
65
+ feat = self.lrelu(self.conv_up2(feat))
66
+ out = self.conv_last(self.lrelu(self.conv_hr(feat)))
67
+ return out
68
+
69
+
70
+ class ESRGAN(torch.nn.Module):
71
+ def __init__(self, model):
72
+ super().__init__()
73
+ self.model = model
74
+
75
+ @staticmethod
76
+ def from_pretrained(model_path):
77
+ model = RRDBNet()
78
+ state_dict = torch.load(model_path, map_location="cpu")["params_ema"]
79
+ model.load_state_dict(state_dict)
80
+ model.eval()
81
+ return ESRGAN(model)
82
+
83
+ def process_image(self, image):
84
+ image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
85
+ return image
86
+
87
+ def process_images(self, images):
88
+ images = [self.process_image(image) for image in images]
89
+ images = torch.stack(images)
90
+ return images
91
+
92
+ def decode_images(self, images):
93
+ images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
94
+ images = [Image.fromarray(image) for image in images]
95
+ return images
96
+
97
+ @torch.no_grad()
98
+ def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
99
+ # Preprocess
100
+ input_tensor = self.process_images(images)
101
+
102
+ # Interpolate
103
+ output_tensor = []
104
+ for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
105
+ batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
106
+ batch_input_tensor = input_tensor[batch_id: batch_id_]
107
+ batch_input_tensor = batch_input_tensor.to(
108
+ device=self.model.conv_first.weight.device,
109
+ dtype=self.model.conv_first.weight.dtype)
110
+ batch_output_tensor = self.model(batch_input_tensor)
111
+ output_tensor.append(batch_output_tensor.cpu())
112
+
113
+ # Output
114
+ output_tensor = torch.concat(output_tensor, dim=0)
115
+
116
+ # To images
117
+ output_images = self.decode_images(output_tensor)
118
+ return output_images
diffsynth/extensions/FastBlend/__init__.py ADDED
@@ -0,0 +1,63 @@
1
+ from .runners.fast import TableManager, PyramidPatchMatcher
2
+ from PIL import Image
3
+ import numpy as np
4
+ import cupy as cp
5
+
6
+
7
+ class FastBlendSmoother:
8
+ def __init__(self):
9
+ self.batch_size = 8
10
+ self.window_size = 64
11
+ self.ebsynth_config = {
12
+ "minimum_patch_size": 5,
13
+ "threads_per_block": 8,
14
+ "num_iter": 5,
15
+ "gpu_id": 0,
16
+ "guide_weight": 10.0,
17
+ "initialize": "identity",
18
+ "tracking_window_size": 0,
19
+ }
20
+
21
+ @staticmethod
22
+ def from_model_manager(model_manager):
23
+ # TODO: fetch GPU ID from model_manager
24
+ return FastBlendSmoother()
25
+
26
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
27
+ frames_guide = [np.array(frame) for frame in frames_guide]
28
+ frames_style = [np.array(frame) for frame in frames_style]
29
+ table_manager = TableManager()
30
+ patch_match_engine = PyramidPatchMatcher(
31
+ image_height=frames_style[0].shape[0],
32
+ image_width=frames_style[0].shape[1],
33
+ channel=3,
34
+ **ebsynth_config
35
+ )
36
+ # left part
37
+ table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
38
+ table_l = table_manager.remapping_table_to_blending_table(table_l)
39
+ table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
40
+ # right part
41
+ table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
42
+ table_r = table_manager.remapping_table_to_blending_table(table_r)
43
+ table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
44
+ # merge
45
+ frames = []
46
+ for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
47
+ weight_m = -1
48
+ weight = weight_l + weight_m + weight_r
49
+ frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
50
+ frames.append(frame)
51
+ frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
52
+ return frames
53
+
54
+ def __call__(self, rendered_frames, original_frames=None, **kwargs):
55
+ frames = self.run(
56
+ original_frames, rendered_frames,
57
+ self.batch_size, self.window_size, self.ebsynth_config
58
+ )
59
+ mempool = cp.get_default_memory_pool()
60
+ pinned_mempool = cp.get_default_pinned_memory_pool()
61
+ mempool.free_all_blocks()
62
+ pinned_mempool.free_all_blocks()
63
+ return frames
diffsynth/extensions/FastBlend/api.py ADDED
@@ -0,0 +1,397 @@
1
+ from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner
2
+ from .data import VideoData, get_video_fps, save_video, search_for_images
3
+ import os
4
+ import gradio as gr
5
+
6
+
7
+ def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder):
8
+ frames_guide = VideoData(video_guide, video_guide_folder)
9
+ frames_style = VideoData(video_style, video_style_folder)
10
+ message = ""
11
+ if len(frames_guide) < len(frames_style):
12
+ message += f"The number of frames mismatches. Only the first {len(frames_guide)} frames of style video will be used.\n"
13
+ frames_style.set_length(len(frames_guide))
14
+ elif len(frames_guide) > len(frames_style):
15
+ message += f"The number of frames mismatches. Only the first {len(frames_style)} frames of guide video will be used.\n"
16
+ frames_guide.set_length(len(frames_style))
17
+ height_guide, width_guide = frames_guide.shape()
18
+ height_style, width_style = frames_style.shape()
19
+ if height_guide != height_style or width_guide != width_style:
20
+ message += f"The shape of frames mismatches. The frames in style video will be resized to (height: {height_guide}, width: {width_guide})\n"
21
+ frames_style.set_shape(height_guide, width_guide)
22
+ return frames_guide, frames_style, message
23
+
24
+
25
+ def smooth_video(
26
+ video_guide,
27
+ video_guide_folder,
28
+ video_style,
29
+ video_style_folder,
30
+ mode,
31
+ window_size,
32
+ batch_size,
33
+ tracking_window_size,
34
+ output_path,
35
+ fps,
36
+ minimum_patch_size,
37
+ num_iter,
38
+ guide_weight,
39
+ initialize,
40
+ progress = None,
41
+ ):
42
+ # input
43
+ frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder)
44
+ if len(message) > 0:
45
+ print(message)
46
+ # output
47
+ if output_path == "":
48
+ if video_style is None:
49
+ output_path = os.path.join(video_style_folder, "output")
50
+ else:
51
+ output_path = os.path.join(os.path.split(video_style)[0], "output")
52
+ os.makedirs(output_path, exist_ok=True)
53
+ print("No valid output_path. Your video will be saved here:", output_path)
54
+ elif not os.path.exists(output_path):
55
+ os.makedirs(output_path, exist_ok=True)
56
+ print("Your video will be saved here:", output_path)
57
+ frames_path = os.path.join(output_path, "frames")
58
+ video_path = os.path.join(output_path, "video.mp4")
59
+ os.makedirs(frames_path, exist_ok=True)
60
+ # process
61
+ if mode == "Fast" or mode == "Balanced":
62
+ tracking_window_size = 0
63
+ ebsynth_config = {
64
+ "minimum_patch_size": minimum_patch_size,
65
+ "threads_per_block": 8,
66
+ "num_iter": num_iter,
67
+ "gpu_id": 0,
68
+ "guide_weight": guide_weight,
69
+ "initialize": initialize,
70
+ "tracking_window_size": tracking_window_size,
71
+ }
72
+ if mode == "Fast":
73
+ FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
74
+ elif mode == "Balanced":
75
+ BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
76
+ elif mode == "Accurate":
77
+ AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
78
+ # output
79
+ try:
80
+ fps = int(fps)
81
+ except:
82
+ fps = get_video_fps(video_style) if video_style is not None else 30
83
+ print("Fps:", fps)
84
+ print("Saving video...")
85
+ video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps)
86
+ print("Success!")
87
+ print("Your frames are here:", frames_path)
88
+ print("Your video is here:", video_path)
89
+ return output_path, fps, video_path
90
+
91
+
92
+ class KeyFrameMatcher:
93
+ def __init__(self):
94
+ pass
95
+
96
+ def extract_number_from_filename(self, file_name):
97
+ result = []
98
+ number = -1
99
+ for i in file_name:
100
+ if ord(i)>=ord("0") and ord(i)<=ord("9"):
101
+ if number == -1:
102
+ number = 0
103
+ number = number*10 + ord(i) - ord("0")
104
+ else:
105
+ if number != -1:
106
+ result.append(number)
107
+ number = -1
108
+ if number != -1:
109
+ result.append(number)
110
+ result = tuple(result)
111
+ return result
112
+
113
+ def extract_number_from_filenames(self, file_names):
114
+ numbers = [self.extract_number_from_filename(file_name) for file_name in file_names]
115
+ min_length = min(len(i) for i in numbers)
116
+ for i in range(min_length-1, -1, -1):
117
+ if len(set(number[i] for number in numbers))==len(file_names):
118
+ return [number[i] for number in numbers]
119
+ return list(range(len(file_names)))
120
+
121
+ def match_using_filename(self, file_names_a, file_names_b):
122
+ file_names_b_set = set(file_names_b)
123
+ matched_file_name = []
124
+ for file_name in file_names_a:
125
+ if file_name not in file_names_b_set:
126
+ matched_file_name.append(None)
127
+ else:
128
+ matched_file_name.append(file_name)
129
+ return matched_file_name
130
+
131
+ def match_using_numbers(self, file_names_a, file_names_b):
132
+ numbers_a = self.extract_number_from_filenames(file_names_a)
133
+ numbers_b = self.extract_number_from_filenames(file_names_b)
134
+ numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)}
135
+ matched_file_name = []
136
+ for number in numbers_a:
137
+ if number in numbers_b_dict:
138
+ matched_file_name.append(numbers_b_dict[number])
139
+ else:
140
+ matched_file_name.append(None)
141
+ return matched_file_name
142
+
143
+ def match_filenames(self, file_names_a, file_names_b):
144
+ matched_file_name = self.match_using_filename(file_names_a, file_names_b)
145
+ if sum([i is not None for i in matched_file_name]) > 0:
146
+ return matched_file_name
147
+ matched_file_name = self.match_using_numbers(file_names_a, file_names_b)
148
+ return matched_file_name
149
+
150
+
151
+ def detect_frames(frames_path, keyframes_path):
152
+ if not os.path.exists(frames_path) and not os.path.exists(keyframes_path):
153
+ return "Please input the directory of guide video and rendered frames"
154
+ elif not os.path.exists(frames_path):
155
+ return "Please input the directory of guide video"
156
+ elif not os.path.exists(keyframes_path):
157
+ return "Please input the directory of rendered frames"
158
+ frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
159
+ keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
160
+ if len(frames)==0:
161
+ return f"No images detected in {frames_path}"
162
+ if len(keyframes)==0:
163
+ return f"No images detected in {keyframes_path}"
164
+ matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
165
+ max_filename_length = max([len(i) for i in frames])
166
+ if sum([i is not None for i in matched_keyframes])==0:
167
+ message = ""
168
+ for frame, matched_keyframe in zip(frames, matched_keyframes):
169
+ message += frame + " " * (max_filename_length - len(frame) + 1)
170
+ message += "--> No matched keyframes\n"
171
+ else:
172
+ message = ""
173
+ for frame, matched_keyframe in zip(frames, matched_keyframes):
174
+ message += frame + " " * (max_filename_length - len(frame) + 1)
175
+ if matched_keyframe is None:
176
+ message += "--> [to be rendered]\n"
177
+ else:
178
+ message += f"--> {matched_keyframe}\n"
179
+ return message
180
+
181
+
182
+ def check_input_for_interpolating(frames_path, keyframes_path):
183
+ # search for images
184
+ frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
185
+ keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
186
+ # match frames
187
+ matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
188
+ file_list = [file_name for file_name in matched_keyframes if file_name is not None]
189
+ index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None]
190
+ frames_guide = VideoData(None, frames_path)
191
+ frames_style = VideoData(None, keyframes_path, file_list=file_list)
192
+ # match shape
193
+ message = ""
194
+ height_guide, width_guide = frames_guide.shape()
195
+ height_style, width_style = frames_style.shape()
196
+ if height_guide != height_style or width_guide != width_style:
197
+ message += f"The shape of frames mismatches. The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide})\n"
198
+ frames_style.set_shape(height_guide, width_guide)
199
+ return frames_guide, frames_style, index_style, message
200
+
201
+
202
+ def interpolate_video(
203
+ frames_path,
204
+ keyframes_path,
205
+ output_path,
206
+ fps,
207
+ batch_size,
208
+ tracking_window_size,
209
+ minimum_patch_size,
210
+ num_iter,
211
+ guide_weight,
212
+ initialize,
213
+ progress = None,
214
+ ):
215
+ # input
216
+ frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path)
217
+ if len(message) > 0:
218
+ print(message)
219
+ # output
220
+ if output_path == "":
221
+ output_path = os.path.join(keyframes_path, "output")
222
+ os.makedirs(output_path, exist_ok=True)
223
+ print("No valid output_path. Your video will be saved here:", output_path)
224
+ elif not os.path.exists(output_path):
225
+ os.makedirs(output_path, exist_ok=True)
226
+ print("Your video will be saved here:", output_path)
227
+ output_frames_path = os.path.join(output_path, "frames")
228
+ output_video_path = os.path.join(output_path, "video.mp4")
229
+ os.makedirs(output_frames_path, exist_ok=True)
230
+ # process
231
+ ebsynth_config = {
232
+ "minimum_patch_size": minimum_patch_size,
233
+ "threads_per_block": 8,
234
+ "num_iter": num_iter,
235
+ "gpu_id": 0,
236
+ "guide_weight": guide_weight,
237
+ "initialize": initialize,
238
+ "tracking_window_size": tracking_window_size
239
+ }
240
+ if len(index_style)==1:
241
+ InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
242
+ else:
243
+ InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
244
+ try:
245
+ fps = int(fps)
246
+ except (ValueError, TypeError):
247
+ fps = 30
248
+ print("Fps:", fps)
249
+ print("Saving video...")
250
+ video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps)
251
+ print("Success!")
252
+ print("Your frames are here:", output_frames_path)
253
+ print("Your video is here:", video_path)
254
+ return output_path, fps, video_path
255
+
256
+
257
+ def on_ui_tabs():
258
+ with gr.Blocks(analytics_enabled=False) as ui_component:
259
+ with gr.Tab("Blend"):
260
+ gr.Markdown("""
261
+ # Blend
262
+
263
+ Given a guide video and a style video, this algorithm makes the style video temporally consistent according to the motion features of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see an example. Note that this extension doesn't support long videos; please use short clips (e.g., a few seconds). The algorithm is mainly designed for 512*512 resolution; use a larger `Minimum patch size` for higher resolutions.
264
+ """)
265
+ with gr.Row():
266
+ with gr.Column():
267
+ with gr.Tab("Guide video"):
268
+ video_guide = gr.Video(label="Guide video")
269
+ with gr.Tab("Guide video (images format)"):
270
+ video_guide_folder = gr.Textbox(label="Guide video (images format)", value="")
271
+ with gr.Column():
272
+ with gr.Tab("Style video"):
273
+ video_style = gr.Video(label="Style video")
274
+ with gr.Tab("Style video (images format)"):
275
+ video_style_folder = gr.Textbox(label="Style video (images format)", value="")
276
+ with gr.Column():
277
+ output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video")
278
+ fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
279
+ video_output = gr.Video(label="Output video", interactive=False, show_share_button=True)
280
+ btn = gr.Button(value="Blend")
281
+ with gr.Row():
282
+ with gr.Column():
283
+ gr.Markdown("# Settings")
284
+ mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True)
285
+ window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True)
286
+ batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
287
+ tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True)
288
+ gr.Markdown("## Advanced Settings")
289
+ minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True)
290
+ num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
291
+ guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
292
+ initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
293
+ with gr.Column():
294
+ gr.Markdown("""
295
+ # Reference
296
+
297
+ * Output directory: the directory to save the video.
298
+ * Inference mode
299
+
300
+ |Mode|Time|Memory|Quality|Frame by frame output|Description|
301
+ |-|-|-|-|-|-|
302
+ |Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires a lot of RAM but is fast.|
303
+ |Balanced|■■|■|■■|Yes|Blend the frames naively.|
304
+ |Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. Performance is best when [batch size] >= [sliding window size] * 2 + 1.|
305
+
306
+ * Sliding window size: the algorithm blends frames within a sliding window. If the size is n, each frame is blended with the previous n frames and the next n frames. A large sliding window makes the video smoother but can make it look blurry.
307
+ * Batch size: a larger batch size makes the program faster but requires more VRAM.
308
+ * Tracking window size (only for accurate mode): the size of the window in which the algorithm tracks moving objects. Empirically, 1 is enough.
309
+ * Advanced settings
310
+ * Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5)
311
+ * Number of iterations: the number of iterations of patch matching. (Default: 5)
312
+ * Guide weight: a parameter that determines how strongly the guide video's motion features are applied to the style video. (Default: 10)
313
+ * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
314
+ """)
315
+ btn.click(
316
+ smooth_video,
317
+ inputs=[
318
+ video_guide,
319
+ video_guide_folder,
320
+ video_style,
321
+ video_style_folder,
322
+ mode,
323
+ window_size,
324
+ batch_size,
325
+ tracking_window_size,
326
+ output_path,
327
+ fps,
328
+ minimum_patch_size,
329
+ num_iter,
330
+ guide_weight,
331
+ initialize
332
+ ],
333
+ outputs=[output_path, fps, video_output]
334
+ )
335
+ with gr.Tab("Interpolate"):
336
+ gr.Markdown("""
337
+ # Interpolate
338
+
339
+ Given a guide video and some rendered keyframes, this algorithm renders the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see an example. The algorithm is experimental and has only been tested at 512*512 resolution.
340
+ """)
341
+ with gr.Row():
342
+ with gr.Column():
343
+ with gr.Row():
344
+ with gr.Column():
345
+ video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="")
346
+ with gr.Column():
347
+ rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="")
348
+ with gr.Row():
349
+ detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False)
350
+ video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
351
+ rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
352
+ with gr.Column():
353
+ output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes")
354
+ fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
355
+ video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True)
356
+ btn_ = gr.Button(value="Interpolate")
357
+ with gr.Row():
358
+ with gr.Column():
359
+ gr.Markdown("# Settings")
360
+ batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
361
+ tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True)
362
+ gr.Markdown("## Advanced Settings")
363
+ minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True)
364
+ num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
365
+ guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
366
+ initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
367
+ with gr.Column():
368
+ gr.Markdown("""
369
+ # Reference
370
+
371
+ * Output directory: the directory to save the video.
372
+ * Batch size: a larger batch size makes the program faster but requires more VRAM.
373
+ * Tracking window size: the size of the window in which the algorithm tracks moving objects. Empirically, 1 is enough.
374
+ * Advanced settings
375
+ * Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than the one used for blending. (Default: 15)**
376
+ * Number of iterations: the number of iterations of patch matching. (Default: 5)
377
+ * Guide weight: a parameter that determines how strongly the guide video's motion features are applied to the style video. (Default: 10)
378
+ * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
379
+ """)
380
+ btn_.click(
381
+ interpolate_video,
382
+ inputs=[
383
+ video_guide_folder_,
384
+ rendered_keyframes_,
385
+ output_path_,
386
+ fps_,
387
+ batch_size_,
388
+ tracking_window_size_,
389
+ minimum_patch_size_,
390
+ num_iter_,
391
+ guide_weight_,
392
+ initialize_,
393
+ ],
394
+ outputs=[output_path_, fps_, video_output_]
395
+ )
396
+
397
+ return [(ui_component, "FastBlend", "FastBlend_ui")]
diffsynth/extensions/FastBlend/cupy_kernels.py ADDED
@@ -0,0 +1,119 @@
1
+ import cupy as cp
2
+
3
+ remapping_kernel = cp.RawKernel(r'''
4
+ extern "C" __global__
5
+ void remap(
6
+ const int height,
7
+ const int width,
8
+ const int channel,
9
+ const int patch_size,
10
+ const int pad_size,
11
+ const float* source_style,
12
+ const int* nnf,
13
+ float* target_style
14
+ ) {
15
+ const int r = (patch_size - 1) / 2;
16
+ const int x = blockDim.x * blockIdx.x + threadIdx.x;
17
+ const int y = blockDim.y * blockIdx.y + threadIdx.y;
18
+ if (x >= height or y >= width) return;
19
+ const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
20
+ const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
21
+ const int min_px = x < r ? -x : -r;
22
+ const int max_px = x + r > height - 1 ? height - 1 - x : r;
23
+ const int min_py = y < r ? -y : -r;
24
+ const int max_py = y + r > width - 1 ? width - 1 - y : r;
25
+ int num = 0;
26
+ for (int px = min_px; px <= max_px; px++){
27
+ for (int py = min_py; py <= max_py; py++){
28
+ const int nid = (x + px) * width + y + py;
29
+ const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
30
+ const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
31
+ if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue;
32
+ const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
33
+ num++;
34
+ for (int c = 0; c < channel; c++){
35
+ target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
36
+ }
37
+ }
38
+ }
39
+ for (int c = 0; c < channel; c++){
40
+ target_style[z + pid * channel + c] /= num;
41
+ }
42
+ }
43
+ ''', 'remap')
44
+
45
+
46
+ patch_error_kernel = cp.RawKernel(r'''
47
+ extern "C" __global__
48
+ void patch_error(
49
+ const int height,
50
+ const int width,
51
+ const int channel,
52
+ const int patch_size,
53
+ const int pad_size,
54
+ const float* source,
55
+ const int* nnf,
56
+ const float* target,
57
+ float* error
58
+ ) {
59
+ const int r = (patch_size - 1) / 2;
60
+ const int x = blockDim.x * blockIdx.x + threadIdx.x;
61
+ const int y = blockDim.y * blockIdx.y + threadIdx.y;
62
+ const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
63
+ if (x >= height or y >= width) return;
64
+ const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
65
+ const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
66
+ float e = 0;
67
+ for (int px = -r; px <= r; px++){
68
+ for (int py = -r; py <= r; py++){
69
+ const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
70
+ const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
71
+ for (int c = 0; c < channel; c++){
72
+ const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
73
+ e += diff * diff;
74
+ }
75
+ }
76
+ }
77
+ error[blockIdx.z * height * width + x * width + y] = e;
78
+ }
79
+ ''', 'patch_error')
80
+
81
+
82
+ pairwise_patch_error_kernel = cp.RawKernel(r'''
83
+ extern "C" __global__
84
+ void pairwise_patch_error(
85
+ const int height,
86
+ const int width,
87
+ const int channel,
88
+ const int patch_size,
89
+ const int pad_size,
90
+ const float* source_a,
91
+ const int* nnf_a,
92
+ const float* source_b,
93
+ const int* nnf_b,
94
+ float* error
95
+ ) {
96
+ const int r = (patch_size - 1) / 2;
97
+ const int x = blockDim.x * blockIdx.x + threadIdx.x;
98
+ const int y = blockDim.y * blockIdx.y + threadIdx.y;
99
+ const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
100
+ if (x >= height or y >= width) return;
101
+ const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
102
+ const int x_a = nnf_a[z_nnf + 0];
103
+ const int y_a = nnf_a[z_nnf + 1];
104
+ const int x_b = nnf_b[z_nnf + 0];
105
+ const int y_b = nnf_b[z_nnf + 1];
106
+ float e = 0;
107
+ for (int px = -r; px <= r; px++){
108
+ for (int py = -r; py <= r; py++){
109
+ const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
110
+ const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
111
+ for (int c = 0; c < channel; c++){
112
+ const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
113
+ e += diff * diff;
114
+ }
115
+ }
116
+ }
117
+ error[blockIdx.z * height * width + x * width + y] = e;
118
+ }
119
+ ''', 'pairwise_patch_error')
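These raw CUDA kernels are launched by `PatchMatcher` later in this diff with the convention `kernel(grid, block, args)`, one z-block per image in the batch. Below is a minimal standalone sketch of the remapping kernel, mirroring how `apply_nnf_to_image` calls it; the shapes and values are illustrative assumptions:

import cupy as cp

height, width, channel, patch_size, batch = 64, 64, 3, 5, 2
pad_size = patch_size // 2
threads = 8

# Padded source style frames, an NNF that maps every target pixel to source pixel (0, 0),
# and an output buffer of the same padded shape.
source_style = cp.random.rand(batch, height + 2 * pad_size, width + 2 * pad_size, channel).astype(cp.float32)
nnf = cp.zeros((batch, height, width, 2), dtype=cp.int32)
target_style = cp.zeros_like(source_style)

grid = ((height + threads - 1) // threads, (width + threads - 1) // threads, batch)
remapping_kernel(grid, (threads, threads),
                 (height, width, channel, patch_size, pad_size, source_style, nnf, target_style))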
diffsynth/extensions/FastBlend/data.py ADDED
@@ -0,0 +1,146 @@
1
+ import imageio, os
2
+ import numpy as np
3
+ from PIL import Image
4
+
5
+
6
+ def read_video(file_name):
7
+ reader = imageio.get_reader(file_name)
8
+ video = []
9
+ for frame in reader:
10
+ frame = np.array(frame)
11
+ video.append(frame)
12
+ reader.close()
13
+ return video
14
+
15
+
16
+ def get_video_fps(file_name):
17
+ reader = imageio.get_reader(file_name)
18
+ fps = reader.get_meta_data()["fps"]
19
+ reader.close()
20
+ return fps
21
+
22
+
23
+ def save_video(frames_path, video_path, num_frames, fps):
24
+ writer = imageio.get_writer(video_path, fps=fps, quality=9)
25
+ for i in range(num_frames):
26
+ frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
27
+ writer.append_data(frame)
28
+ writer.close()
29
+ return video_path
30
+
31
+
32
+ class LowMemoryVideo:
33
+ def __init__(self, file_name):
34
+ self.reader = imageio.get_reader(file_name)
35
+
36
+ def __len__(self):
37
+ return self.reader.count_frames()
38
+
39
+ def __getitem__(self, item):
40
+ return np.array(self.reader.get_data(item))
41
+
42
+ def __del__(self):
43
+ self.reader.close()
44
+
45
+
46
+ def split_file_name(file_name):
47
+ result = []
48
+ number = -1
49
+ for i in file_name:
50
+ if ord(i)>=ord("0") and ord(i)<=ord("9"):
51
+ if number == -1:
52
+ number = 0
53
+ number = number*10 + ord(i) - ord("0")
54
+ else:
55
+ if number != -1:
56
+ result.append(number)
57
+ number = -1
58
+ result.append(i)
59
+ if number != -1:
60
+ result.append(number)
61
+ result = tuple(result)
62
+ return result
63
+
64
+
65
+ def search_for_images(folder):
66
+ file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
67
+ file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
68
+ file_list = [i[1] for i in sorted(file_list)]
69
+ file_list = [os.path.join(folder, i) for i in file_list]
70
+ return file_list
71
+
72
+
73
+ def read_images(folder):
74
+ file_list = search_for_images(folder)
75
+ frames = [np.array(Image.open(i)) for i in file_list]
76
+ return frames
77
+
78
+
79
+ class LowMemoryImageFolder:
80
+ def __init__(self, folder, file_list=None):
81
+ if file_list is None:
82
+ self.file_list = search_for_images(folder)
83
+ else:
84
+ self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
85
+
86
+ def __len__(self):
87
+ return len(self.file_list)
88
+
89
+ def __getitem__(self, item):
90
+ return np.array(Image.open(self.file_list[item]))
91
+
92
+ def __del__(self):
93
+ pass
94
+
95
+
96
+ class VideoData:
97
+ def __init__(self, video_file, image_folder, **kwargs):
98
+ if video_file is not None:
99
+ self.data_type = "video"
100
+ self.data = LowMemoryVideo(video_file, **kwargs)
101
+ elif image_folder is not None:
102
+ self.data_type = "images"
103
+ self.data = LowMemoryImageFolder(image_folder, **kwargs)
104
+ else:
105
+ raise ValueError("Cannot open video or image folder")
106
+ self.length = None
107
+ self.height = None
108
+ self.width = None
109
+
110
+ def raw_data(self):
111
+ frames = []
112
+ for i in range(self.__len__()):
113
+ frames.append(self.__getitem__(i))
114
+ return frames
115
+
116
+ def set_length(self, length):
117
+ self.length = length
118
+
119
+ def set_shape(self, height, width):
120
+ self.height = height
121
+ self.width = width
122
+
123
+ def __len__(self):
124
+ if self.length is None:
125
+ return len(self.data)
126
+ else:
127
+ return self.length
128
+
129
+ def shape(self):
130
+ if self.height is not None and self.width is not None:
131
+ return self.height, self.width
132
+ else:
133
+ height, width, _ = self.__getitem__(0).shape
134
+ return height, width
135
+
136
+ def __getitem__(self, item):
137
+ frame = self.data.__getitem__(item)
138
+ height, width, _ = frame.shape
139
+ if self.height is not None and self.width is not None:
140
+ if self.height != height or self.width != width:
141
+ frame = Image.fromarray(frame).resize((self.width, self.height))
142
+ frame = np.array(frame)
143
+ return frame
144
+
145
+ def __del__(self):
146
+ pass
diffsynth/extensions/FastBlend/patch_match.py ADDED
@@ -0,0 +1,298 @@
1
+ from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
2
+ import numpy as np
3
+ import cupy as cp
4
+ import cv2
5
+
6
+
7
+ class PatchMatcher:
8
+ def __init__(
9
+ self, height, width, channel, minimum_patch_size,
10
+ threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
11
+ random_search_steps=3, random_search_range=4,
12
+ use_mean_target_style=False, use_pairwise_patch_error=False,
13
+ tracking_window_size=0
14
+ ):
15
+ self.height = height
16
+ self.width = width
17
+ self.channel = channel
18
+ self.minimum_patch_size = minimum_patch_size
19
+ self.threads_per_block = threads_per_block
20
+ self.num_iter = num_iter
21
+ self.gpu_id = gpu_id
22
+ self.guide_weight = guide_weight
23
+ self.random_search_steps = random_search_steps
24
+ self.random_search_range = random_search_range
25
+ self.use_mean_target_style = use_mean_target_style
26
+ self.use_pairwise_patch_error = use_pairwise_patch_error
27
+ self.tracking_window_size = tracking_window_size
28
+
29
+ self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
30
+ self.pad_size = self.patch_size_list[0] // 2
31
+ self.grid = (
32
+ (height + threads_per_block - 1) // threads_per_block,
33
+ (width + threads_per_block - 1) // threads_per_block
34
+ )
35
+ self.block = (threads_per_block, threads_per_block)
36
+
37
+ def pad_image(self, image):
38
+ return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))
39
+
40
+ def unpad_image(self, image):
41
+ return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]
42
+
43
+ def apply_nnf_to_image(self, nnf, source):
44
+ batch_size = source.shape[0]
45
+ target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
46
+ remapping_kernel(
47
+ self.grid + (batch_size,),
48
+ self.block,
49
+ (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
50
+ )
51
+ return target
52
+
53
+ def get_patch_error(self, source, nnf, target):
54
+ batch_size = source.shape[0]
55
+ error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
56
+ patch_error_kernel(
57
+ self.grid + (batch_size,),
58
+ self.block,
59
+ (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
60
+ )
61
+ return error
62
+
63
+ def get_pairwise_patch_error(self, source, nnf):
64
+ batch_size = source.shape[0]//2
65
+ error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
66
+ source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
67
+ source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
68
+ pairwise_patch_error_kernel(
69
+ self.grid + (batch_size,),
70
+ self.block,
71
+ (self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
72
+ )
73
+ error = error.repeat(2, axis=0)
74
+ return error
75
+
76
+ def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
77
+ error_guide = self.get_patch_error(source_guide, nnf, target_guide)
78
+ if self.use_mean_target_style:
79
+ target_style = self.apply_nnf_to_image(nnf, source_style)
80
+ target_style = target_style.mean(axis=0, keepdims=True)
81
+ target_style = target_style.repeat(source_guide.shape[0], axis=0)
82
+ if self.use_pairwise_patch_error:
83
+ error_style = self.get_pairwise_patch_error(source_style, nnf)
84
+ else:
85
+ error_style = self.get_patch_error(source_style, nnf, target_style)
86
+ error = error_guide * self.guide_weight + error_style
87
+ return error
88
+
89
+ def clamp_bound(self, nnf):
90
+ nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
91
+ nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
92
+ return nnf
93
+
94
+ def random_step(self, nnf, r):
95
+ batch_size = nnf.shape[0]
96
+ step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
97
+ upd_nnf = self.clamp_bound(nnf + step)
98
+ return upd_nnf
99
+
100
+ def neighboor_step(self, nnf, d):
101
+ if d==0:
102
+ upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
103
+ upd_nnf[:, :, :, 0] += 1
104
+ elif d==1:
105
+ upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
106
+ upd_nnf[:, :, :, 1] += 1
107
+ elif d==2:
108
+ upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
109
+ upd_nnf[:, :, :, 0] -= 1
110
+ elif d==3:
111
+ upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
112
+ upd_nnf[:, :, :, 1] -= 1
113
+ upd_nnf = self.clamp_bound(upd_nnf)
114
+ return upd_nnf
115
+
116
+ def shift_nnf(self, nnf, d):
117
+ if d>0:
118
+ d = min(nnf.shape[0], d)
119
+ upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
120
+ else:
121
+ d = max(-nnf.shape[0], d)
122
+ upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
123
+ return upd_nnf
124
+
125
+ def track_step(self, nnf, d):
126
+ if self.use_pairwise_patch_error:
127
+ upd_nnf = cp.zeros_like(nnf)
128
+ upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
129
+ upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
130
+ else:
131
+ upd_nnf = self.shift_nnf(nnf, d)
132
+ return upd_nnf
133
+
134
+ def C(self, n, m):
135
+ # not used
136
+ c = 1
137
+ for i in range(1, n+1):
138
+ c *= i
139
+ for i in range(1, m+1):
140
+ c //= i
141
+ for i in range(1, n-m+1):
142
+ c //= i
143
+ return c
144
+
145
+ def bezier_step(self, nnf, r):
146
+ # not used
147
+ n = r * 2 - 1
148
+ upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
149
+ for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
150
+ if d>0:
151
+ ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
152
+ elif d<0:
153
+ ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
154
+ upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
155
+ upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
156
+ return upd_nnf
157
+
158
+ def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
159
+ upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
160
+ upd_idx = (upd_err < err)
161
+ nnf[upd_idx] = upd_nnf[upd_idx]
162
+ err[upd_idx] = upd_err[upd_idx]
163
+ return nnf, err
164
+
165
+ def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
166
+ for d in cp.random.permutation(4):
167
+ upd_nnf = self.neighboor_step(nnf, d)
168
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
169
+ return nnf, err
170
+
171
+ def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
172
+ for i in range(self.random_search_steps):
173
+ upd_nnf = self.random_step(nnf, self.random_search_range)
174
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
175
+ return nnf, err
176
+
177
+ def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
178
+ for d in range(1, self.tracking_window_size + 1):
179
+ upd_nnf = self.track_step(nnf, d)
180
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
181
+ upd_nnf = self.track_step(nnf, -d)
182
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
183
+ return nnf, err
184
+
185
+ def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
186
+ nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
187
+ nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
188
+ nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
189
+ return nnf, err
190
+
191
+ def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
192
+ with cp.cuda.Device(self.gpu_id):
193
+ source_guide = self.pad_image(source_guide)
194
+ target_guide = self.pad_image(target_guide)
195
+ source_style = self.pad_image(source_style)
196
+ for it in range(self.num_iter):
197
+ self.patch_size = self.patch_size_list[it]
198
+ target_style = self.apply_nnf_to_image(nnf, source_style)
199
+ err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
200
+ nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
201
+ target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
202
+ return nnf, target_style
203
+
204
+
205
+ class PyramidPatchMatcher:
206
+ def __init__(
207
+ self, image_height, image_width, channel, minimum_patch_size,
208
+ threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
209
+ use_mean_target_style=False, use_pairwise_patch_error=False,
210
+ tracking_window_size=0,
211
+ initialize="identity"
212
+ ):
213
+ maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
214
+ self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
215
+ self.pyramid_heights = []
216
+ self.pyramid_widths = []
217
+ self.patch_matchers = []
218
+ self.minimum_patch_size = minimum_patch_size
219
+ self.num_iter = num_iter
220
+ self.gpu_id = gpu_id
221
+ self.initialize = initialize
222
+ for level in range(self.pyramid_level):
223
+ height = image_height//(2**(self.pyramid_level - 1 - level))
224
+ width = image_width//(2**(self.pyramid_level - 1 - level))
225
+ self.pyramid_heights.append(height)
226
+ self.pyramid_widths.append(width)
227
+ self.patch_matchers.append(PatchMatcher(
228
+ height, width, channel, minimum_patch_size=minimum_patch_size,
229
+ threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
230
+ use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
231
+ tracking_window_size=tracking_window_size
232
+ ))
233
+
234
+ def resample_image(self, images, level):
235
+ height, width = self.pyramid_heights[level], self.pyramid_widths[level]
236
+ images = images.get()
237
+ images_resample = []
238
+ for image in images:
239
+ image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
240
+ images_resample.append(image_resample)
241
+ images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
242
+ return images_resample
243
+
244
+ def initialize_nnf(self, batch_size):
245
+ if self.initialize == "random":
246
+ height, width = self.pyramid_heights[0], self.pyramid_widths[0]
247
+ nnf = cp.stack([
248
+ cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
249
+ cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
250
+ ], axis=3)
251
+ elif self.initialize == "identity":
252
+ height, width = self.pyramid_heights[0], self.pyramid_widths[0]
253
+ nnf = cp.stack([
254
+ cp.repeat(cp.arange(height), width).reshape(height, width),
255
+ cp.tile(cp.arange(width), height).reshape(height, width)
256
+ ], axis=2)
257
+ nnf = cp.stack([nnf] * batch_size)
258
+ else:
259
+ raise NotImplementedError()
260
+ return nnf
261
+
262
+ def update_nnf(self, nnf, level):
263
+ # upscale
264
+ nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
265
+ nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1
266
+ nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1
267
+ # check if scale is 2
268
+ height, width = self.pyramid_heights[level], self.pyramid_widths[level]
269
+ if height != nnf.shape[1] or width != nnf.shape[2]:  # nnf has already been upscaled by 2 above
270
+ nnf = nnf.get().astype(np.float32)
271
+ nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
272
+ nnf = cp.array(np.stack(nnf), dtype=cp.int32)
273
+ nnf = self.patch_matchers[level].clamp_bound(nnf)
274
+ return nnf
275
+
276
+ def apply_nnf_to_image(self, nnf, image):
277
+ with cp.cuda.Device(self.gpu_id):
278
+ image = self.patch_matchers[-1].pad_image(image)
279
+ image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
280
+ return image
281
+
282
+ def estimate_nnf(self, source_guide, target_guide, source_style):
283
+ with cp.cuda.Device(self.gpu_id):
284
+ if not isinstance(source_guide, cp.ndarray):
285
+ source_guide = cp.array(source_guide, dtype=cp.float32)
286
+ if not isinstance(target_guide, cp.ndarray):
287
+ target_guide = cp.array(target_guide, dtype=cp.float32)
288
+ if not isinstance(source_style, cp.ndarray):
289
+ source_style = cp.array(source_style, dtype=cp.float32)
290
+ for level in range(self.pyramid_level):
291
+ nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
292
+ source_guide_ = self.resample_image(source_guide, level)
293
+ target_guide_ = self.resample_image(target_guide, level)
294
+ source_style_ = self.resample_image(source_style, level)
295
+ nnf, target_style = self.patch_matchers[level].estimate_nnf(
296
+ source_guide_, target_guide_, source_style_, nnf
297
+ )
298
+ return nnf.get(), target_style.get()
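`PyramidPatchMatcher.estimate_nnf` is the core primitive the runners below build on: it accepts batches of HWC frames (NumPy or CuPy) and returns the nearest-neighbour field together with the remapped style frames. A minimal sketch with random data, assuming a CUDA device and CuPy are available (not part of this commit):

import numpy as np

matcher = PyramidPatchMatcher(
    image_height=512, image_width=512, channel=3,
    minimum_patch_size=5, num_iter=5, guide_weight=10.0,
)
batch = 4
source_guide = np.random.rand(batch, 512, 512, 3).astype(np.float32) * 255
target_guide = np.random.rand(batch, 512, 512, 3).astype(np.float32) * 255
source_style = np.random.rand(batch, 512, 512, 3).astype(np.float32) * 255

# nnf: (batch, 512, 512, 2) int32; target_style: (batch, 512, 512, 3) float32
nnf, target_style = matcher.estimate_nnf(source_guide, target_guide, source_style)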
diffsynth/extensions/FastBlend/runners/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .accurate import AccurateModeRunner
2
+ from .fast import FastModeRunner
3
+ from .balanced import BalancedModeRunner
4
+ from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner
diffsynth/extensions/FastBlend/runners/accurate.py ADDED
@@ -0,0 +1,35 @@
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class AccurateModeRunner:
9
+ def __init__(self):
10
+ pass
11
+
12
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
13
+ patch_match_engine = PyramidPatchMatcher(
14
+ image_height=frames_style[0].shape[0],
15
+ image_width=frames_style[0].shape[1],
16
+ channel=3,
17
+ use_mean_target_style=True,
18
+ **ebsynth_config
19
+ )
20
+ # run
21
+ n = len(frames_style)
22
+ for target in tqdm(range(n), desc=desc):
23
+ l, r = max(target - window_size, 0), min(target + window_size + 1, n)
24
+ remapped_frames = []
25
+ for i in range(l, r, batch_size):
26
+ j = min(i + batch_size, r)
27
+ source_guide = np.stack([frames_guide[source] for source in range(i, j)])
28
+ target_guide = np.stack([frames_guide[target]] * (j - i))
29
+ source_style = np.stack([frames_style[source] for source in range(i, j)])
30
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
31
+ remapped_frames.append(target_style)
32
+ frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
33
+ frame = frame.clip(0, 255).astype("uint8")
34
+ if save_path is not None:
35
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
diffsynth/extensions/FastBlend/runners/balanced.py ADDED
@@ -0,0 +1,46 @@
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class BalancedModeRunner:
9
+ def __init__(self):
10
+ pass
11
+
12
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
13
+ patch_match_engine = PyramidPatchMatcher(
14
+ image_height=frames_style[0].shape[0],
15
+ image_width=frames_style[0].shape[1],
16
+ channel=3,
17
+ **ebsynth_config
18
+ )
19
+ # tasks
20
+ n = len(frames_style)
21
+ tasks = []
22
+ for target in range(n):
23
+ for source in range(target - window_size, target + window_size + 1):
24
+ if source >= 0 and source < n and source != target:
25
+ tasks.append((source, target))
26
+ # run
27
+ frames = [(None, 1) for i in range(n)]
28
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
29
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
30
+ source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
31
+ target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
32
+ source_style = np.stack([frames_style[source] for source, target in tasks_batch])
33
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
34
+ for (source, target), result in zip(tasks_batch, target_style):
35
+ frame, weight = frames[target]
36
+ if frame is None:
37
+ frame = frames_style[target]
38
+ frames[target] = (
39
+ frame * (weight / (weight + 1)) + result / (weight + 1),
40
+ weight + 1
41
+ )
42
+ if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
43
+ frame = frame.clip(0, 255).astype("uint8")
44
+ if save_path is not None:
45
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
46
+ frames[target] = (None, 1)
diffsynth/extensions/FastBlend/runners/fast.py ADDED
@@ -0,0 +1,141 @@
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import functools, os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class TableManager:
9
+ def __init__(self):
10
+ pass
11
+
12
+ def task_list(self, n):
13
+ tasks = []
14
+ max_level = 1
15
+ while (1<<max_level)<=n:
16
+ max_level += 1
17
+ for i in range(n):
18
+ j = i
19
+ for level in range(max_level):
20
+ if i&(1<<level):
21
+ continue
22
+ j |= 1<<level
23
+ if j>=n:
24
+ break
25
+ meta_data = {
26
+ "source": i,
27
+ "target": j,
28
+ "level": level + 1
29
+ }
30
+ tasks.append(meta_data)
31
+ tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
32
+ return tasks
33
+
34
+ def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
35
+ n = len(frames_guide)
36
+ tasks = self.task_list(n)
37
+ remapping_table = [[(frames_style[i], 1)] for i in range(n)]
38
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
39
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
40
+ source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
41
+ target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
42
+ source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
43
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
44
+ for task, result in zip(tasks_batch, target_style):
45
+ target, level = task["target"], task["level"]
46
+ if len(remapping_table[target])==level:
47
+ remapping_table[target].append((result, 1))
48
+ else:
49
+ frame, weight = remapping_table[target][level]
50
+ remapping_table[target][level] = (
51
+ frame * (weight / (weight + 1)) + result / (weight + 1),
52
+ weight + 1
53
+ )
54
+ return remapping_table
55
+
56
+ def remapping_table_to_blending_table(self, table):
57
+ for i in range(len(table)):
58
+ for j in range(1, len(table[i])):
59
+ frame_1, weight_1 = table[i][j-1]
60
+ frame_2, weight_2 = table[i][j]
61
+ frame = (frame_1 + frame_2) / 2
62
+ weight = weight_1 + weight_2
63
+ table[i][j] = (frame, weight)
64
+ return table
65
+
66
+ def tree_query(self, leftbound, rightbound):
67
+ node_list = []
68
+ node_index = rightbound
69
+ while node_index>=leftbound:
70
+ node_level = 0
71
+ while (1<<node_level)&node_index and node_index-(1<<node_level+1)+1>=leftbound:
72
+ node_level += 1
73
+ node_list.append((node_index, node_level))
74
+ node_index -= 1<<node_level
75
+ return node_list
76
+
77
+ def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
78
+ n = len(blending_table)
79
+ tasks = []
80
+ frames_result = []
81
+ for target in range(n):
82
+ node_list = self.tree_query(max(target-window_size, 0), target)
83
+ for source, level in node_list:
84
+ if source!=target:
85
+ meta_data = {
86
+ "source": source,
87
+ "target": target,
88
+ "level": level
89
+ }
90
+ tasks.append(meta_data)
91
+ else:
92
+ frames_result.append(blending_table[target][level])
93
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
94
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
95
+ source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
96
+ target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
97
+ source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in tasks_batch])
98
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
99
+ for task, frame_2 in zip(tasks_batch, target_style):
100
+ source, target, level = task["source"], task["target"], task["level"]
101
+ frame_1, weight_1 = frames_result[target]
102
+ weight_2 = blending_table[source][level][1]
103
+ weight = weight_1 + weight_2
104
+ frame = frame_1 * (weight_1 / weight) + frame_2 * (weight_2 / weight)
105
+ frames_result[target] = (frame, weight)
106
+ return frames_result
107
+
108
+
109
+ class FastModeRunner:
110
+ def __init__(self):
111
+ pass
112
+
113
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
114
+ frames_guide = frames_guide.raw_data()
115
+ frames_style = frames_style.raw_data()
116
+ table_manager = TableManager()
117
+ patch_match_engine = PyramidPatchMatcher(
118
+ image_height=frames_style[0].shape[0],
119
+ image_width=frames_style[0].shape[1],
120
+ channel=3,
121
+ **ebsynth_config
122
+ )
123
+ # left part
124
+ table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="Fast Mode Step 1/4")
125
+ table_l = table_manager.remapping_table_to_blending_table(table_l)
126
+ table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 2/4")
127
+ # right part
128
+ table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="Fast Mode Step 3/4")
129
+ table_r = table_manager.remapping_table_to_blending_table(table_r)
130
+ table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 4/4")[::-1]
131
+ # merge
132
+ frames = []
133
+ for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
134
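+ # weight_m is -1 because the original style frame is counted once in both the left and the right window sums; subtracting one copy keeps the average correctly normalized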
+ weight_m = -1
135
+ weight = weight_l + weight_m + weight_r
136
+ frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
137
+ frames.append(frame)
138
+ frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
139
+ if save_path is not None:
140
+ for target, frame in enumerate(frames):
141
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
diffsynth/extensions/FastBlend/runners/interpolation.py ADDED
@@ -0,0 +1,121 @@
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class InterpolationModeRunner:
9
+ def __init__(self):
10
+ pass
11
+
12
+ def get_index_dict(self, index_style):
13
+ index_dict = {}
14
+ for i, index in enumerate(index_style):
15
+ index_dict[index] = i
16
+ return index_dict
17
+
18
+ def get_weight(self, l, m, r):
19
+ weight_l, weight_r = abs(m - r), abs(m - l)
20
+ if weight_l + weight_r == 0:
21
+ weight_l, weight_r = 0.5, 0.5
22
+ else:
23
+ weight_l, weight_r = weight_l / (weight_l + weight_r), weight_r / (weight_l + weight_r)
24
+ return weight_l, weight_r
25
+
26
+ def get_task_group(self, index_style, n):
27
+ task_group = []
28
+ index_style = sorted(index_style)
29
+ # first frame
30
+ if index_style[0]>0:
31
+ tasks = []
32
+ for m in range(index_style[0]):
33
+ tasks.append((index_style[0], m, index_style[0]))
34
+ task_group.append(tasks)
35
+ # middle frames
36
+ for l, r in zip(index_style[:-1], index_style[1:]):
37
+ tasks = []
38
+ for m in range(l, r):
39
+ tasks.append((l, m, r))
40
+ task_group.append(tasks)
41
+ # last frame
42
+ tasks = []
43
+ for m in range(index_style[-1], n):
44
+ tasks.append((index_style[-1], m, index_style[-1]))
45
+ task_group.append(tasks)
46
+ return task_group
47
+
48
+ def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
49
+ patch_match_engine = PyramidPatchMatcher(
50
+ image_height=frames_style[0].shape[0],
51
+ image_width=frames_style[0].shape[1],
52
+ channel=3,
53
+ use_mean_target_style=False,
54
+ use_pairwise_patch_error=True,
55
+ **ebsynth_config
56
+ )
57
+ # task
58
+ index_dict = self.get_index_dict(index_style)
59
+ task_group = self.get_task_group(index_style, len(frames_guide))
60
+ # run
61
+ for tasks in task_group:
62
+ index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
63
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
64
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
65
+ source_guide, target_guide, source_style = [], [], []
66
+ for l, m, r in tasks_batch:
67
+ # l -> m
68
+ source_guide.append(frames_guide[l])
69
+ target_guide.append(frames_guide[m])
70
+ source_style.append(frames_style[index_dict[l]])
71
+ # r -> m
72
+ source_guide.append(frames_guide[r])
73
+ target_guide.append(frames_guide[m])
74
+ source_style.append(frames_style[index_dict[r]])
75
+ source_guide = np.stack(source_guide)
76
+ target_guide = np.stack(target_guide)
77
+ source_style = np.stack(source_style)
78
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
79
+ if save_path is not None:
80
+ for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
81
+ weight_l, weight_r = self.get_weight(l, m, r)
82
+ frame = frame_l * weight_l + frame_r * weight_r
83
+ frame = frame.clip(0, 255).astype("uint8")
84
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))
85
+
86
+
87
+ class InterpolationModeSingleFrameRunner:
88
+ def __init__(self):
89
+ pass
90
+
91
+ def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
92
+ # check input
93
+ tracking_window_size = ebsynth_config["tracking_window_size"]
94
+ if tracking_window_size * 2 >= batch_size:
95
+ raise ValueError("batch_size should be larger than track_window_size * 2")
96
+ frame_style = frames_style[0]
97
+ frame_guide = frames_guide[index_style[0]]
98
+ patch_match_engine = PyramidPatchMatcher(
99
+ image_height=frame_style.shape[0],
100
+ image_width=frame_style.shape[1],
101
+ channel=3,
102
+ **ebsynth_config
103
+ )
104
+ # run
105
+ frame_id, n = 0, len(frames_guide)
106
+ for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
107
+ if i + batch_size > n:
108
+ l, r = max(n - batch_size, 0), n
109
+ else:
110
+ l, r = i, i + batch_size
111
+ source_guide = np.stack([frame_guide] * (r-l))
112
+ target_guide = np.stack([frames_guide[i] for i in range(l, r)])
113
+ source_style = np.stack([frame_style] * (r-l))
114
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
115
+ for j, frame in zip(range(l, r), target_style):  # j is the absolute frame index; avoid shadowing the outer loop variable
116
+ if j==frame_id:
117
+ frame = frame.clip(0, 255).astype("uint8")
118
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
119
+ frame_id += 1
120
+ if r < n and r-frame_id <= tracking_window_size:
121
+ break
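`interpolate_video` in the UI module drives these runners with a full guide sequence, the rendered keyframes, and the keyframe positions. A minimal sketch of calling `InterpolationModeRunner` directly, not part of this commit; paths, file names, and indices are placeholder assumptions:

frames_guide = VideoData(None, "guide_frames")
frames_style = VideoData(None, "keyframes", file_list=["00000.png", "00030.png"])
index_style = [0, 30]   # positions of the rendered keyframes within the guide sequence

ebsynth_config = {
    "minimum_patch_size": 15,   # interpolation favours a larger patch size than blending
    "threads_per_block": 8,
    "num_iter": 5,
    "gpu_id": 0,
    "guide_weight": 10.0,
    "initialize": "identity",
    "tracking_window_size": 0,
}
InterpolationModeRunner().run(
    frames_guide, frames_style, index_style,
    batch_size=8, ebsynth_config=ebsynth_config, save_path="output/frames",
)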
diffsynth/extensions/RIFE/__init__.py ADDED
@@ -0,0 +1,242 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+
8
+ def warp(tenInput, tenFlow, device):
9
+ backwarp_tenGrid = {}
10
+ k = (str(tenFlow.device), str(tenFlow.size()))
11
+ if k not in backwarp_tenGrid:
12
+ tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
13
+ 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
14
+ tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
15
+ 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
16
+ backwarp_tenGrid[k] = torch.cat(
17
+ [tenHorizontal, tenVertical], 1).to(device)
18
+
19
+ tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
20
+ tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
21
+
22
+ g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
23
+ return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
24
+
25
+
26
+ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
27
+ return nn.Sequential(
28
+ nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
29
+ padding=padding, dilation=dilation, bias=True),
30
+ nn.PReLU(out_planes)
31
+ )
32
+
33
+
34
+ class IFBlock(nn.Module):
35
+ def __init__(self, in_planes, c=64):
36
+ super(IFBlock, self).__init__()
37
+ self.conv0 = nn.Sequential(conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1),)
38
+ self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
39
+ self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
40
+ self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
41
+ self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
42
+ self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 4, 4, 2, 1))
43
+ self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 1, 4, 2, 1))
44
+
45
+ def forward(self, x, flow, scale=1):
46
+ x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
47
+ flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
48
+ feat = self.conv0(torch.cat((x, flow), 1))
49
+ feat = self.convblock0(feat) + feat
50
+ feat = self.convblock1(feat) + feat
51
+ feat = self.convblock2(feat) + feat
52
+ feat = self.convblock3(feat) + feat
53
+ flow = self.conv1(feat)
54
+ mask = self.conv2(feat)
55
+ flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale
56
+ mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
57
+ return flow, mask
58
+
59
+
60
+ class IFNet(nn.Module):
61
+ def __init__(self):
62
+ super(IFNet, self).__init__()
63
+ self.block0 = IFBlock(7+4, c=90)
64
+ self.block1 = IFBlock(7+4, c=90)
65
+ self.block2 = IFBlock(7+4, c=90)
66
+ self.block_tea = IFBlock(10+4, c=90)
67
+
68
+ def forward(self, x, scale_list=[4, 2, 1], training=False):
69
+ if training == False:
70
+ channel = x.shape[1] // 2
71
+ img0 = x[:, :channel]
72
+ img1 = x[:, channel:]
73
+ flow_list = []
74
+ merged = []
75
+ mask_list = []
76
+ warped_img0 = img0
77
+ warped_img1 = img1
78
+ flow = (x[:, :4]).detach() * 0
79
+ mask = (x[:, :1]).detach() * 0
80
+ block = [self.block0, self.block1, self.block2]
81
+ for i in range(3):
82
+ f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
83
+ f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
84
+ flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
85
+ mask = mask + (m0 + (-m1)) / 2
86
+ mask_list.append(mask)
87
+ flow_list.append(flow)
88
+ warped_img0 = warp(img0, flow[:, :2], device=x.device)
89
+ warped_img1 = warp(img1, flow[:, 2:4], device=x.device)
90
+ merged.append((warped_img0, warped_img1))
91
+ '''
92
+ c0 = self.contextnet(img0, flow[:, :2])
93
+ c1 = self.contextnet(img1, flow[:, 2:4])
94
+ tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
95
+ res = tmp[:, 1:4] * 2 - 1
96
+ '''
97
+ for i in range(3):
98
+ mask_list[i] = torch.sigmoid(mask_list[i])
99
+ merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
100
+ return flow_list, mask_list[2], merged
101
+
102
+ @staticmethod
103
+ def state_dict_converter():
104
+ return IFNetStateDictConverter()
105
+
106
+
107
+ class IFNetStateDictConverter:
108
+ def __init__(self):
109
+ pass
110
+
111
+ def from_diffusers(self, state_dict):
112
+ state_dict_ = {k.replace("module.", ""): v for k, v in state_dict.items()}
113
+ return state_dict_
114
+
115
+ def from_civitai(self, state_dict):
116
+ return self.from_diffusers(state_dict)
117
+
118
+
119
+ class RIFEInterpolater:
120
+ def __init__(self, model, device="cuda"):
121
+ self.model = model
122
+ self.device = device
123
+ # IFNet does not support float16, so float32 is used here
124
+ self.torch_dtype = torch.float32
125
+
126
+ @staticmethod
127
+ def from_model_manager(model_manager):
128
+ return RIFEInterpolater(model_manager.RIFE, device=model_manager.device)
129
+
130
+ def process_image(self, image):
131
+ width, height = image.size
132
+ if width % 32 != 0 or height % 32 != 0:
133
+ width = (width + 31) // 32 * 32  # round up to a multiple of 32
134
+ height = (height + 31) // 32 * 32  # round up to a multiple of 32
135
+ image = image.resize((width, height))
136
+ image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
137
+ return image
138
+
139
+ def process_images(self, images):
140
+ images = [self.process_image(image) for image in images]
141
+ images = torch.stack(images)
142
+ return images
143
+
144
+ def decode_images(self, images):
145
+ images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
146
+ images = [Image.fromarray(image) for image in images]
147
+ return images
148
+
149
+ def add_interpolated_images(self, images, interpolated_images):
150
+ output_images = []
151
+ for image, interpolated_image in zip(images, interpolated_images):
152
+ output_images.append(image)
153
+ output_images.append(interpolated_image)
154
+ output_images.append(images[-1])
155
+ return output_images
156
+
157
+
158
+ @torch.no_grad()
159
+ def interpolate_(self, images, scale=1.0):
160
+ input_tensor = self.process_images(images)
161
+ input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1)
162
+ input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype)
163
+ flow, mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale])
164
+ output_images = self.decode_images(merged[2].cpu())
165
+ if output_images[0].size != images[0].size:
166
+ output_images = [image.resize(images[0].size) for image in output_images]
167
+ return output_images
168
+
169
+
170
+ @torch.no_grad()
171
+ def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, progress_bar=lambda x:x):
172
+ # Preprocess
173
+ processed_images = self.process_images(images)
174
+
175
+ for iter in range(num_iter):
176
+ # Input
177
+ input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1)
178
+
179
+ # Interpolate
180
+ output_tensor = []
181
+ for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
182
+ batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
183
+ batch_input_tensor = input_tensor[batch_id: batch_id_]
184
+ batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
185
+ flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
186
+ output_tensor.append(merged[2].cpu())
187
+
188
+ # Output
189
+ output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1)
190
+ processed_images = self.add_interpolated_images(processed_images, output_tensor)
191
+ processed_images = torch.stack(processed_images)
192
+
193
+ # To images
194
+ output_images = self.decode_images(processed_images)
195
+ if output_images[0].size != images[0].size:
196
+ output_images = [image.resize(images[0].size) for image in output_images]
197
+ return output_images
198
+
199
+
200
+ class RIFESmoother(RIFEInterpolater):
201
+ def __init__(self, model, device="cuda"):
202
+ super(RIFESmoother, self).__init__(model, device=device)
203
+
204
+ @staticmethod
205
+ def from_model_manager(model_manager):
206
+ return RIFESmoother(model_manager.RIFE, device=model_manager.device)
207
+
208
+ def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
209
+ output_tensor = []
210
+ for batch_id in range(0, input_tensor.shape[0], batch_size):
211
+ batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
212
+ batch_input_tensor = input_tensor[batch_id: batch_id_]
213
+ batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
214
+ flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
215
+ output_tensor.append(merged[2].cpu())
216
+ output_tensor = torch.concat(output_tensor, dim=0)
217
+ return output_tensor
218
+
219
+ @torch.no_grad()
220
+ def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs):
221
+ # Preprocess
222
+ processed_images = self.process_images(rendered_frames)
223
+
224
+ for iter in range(num_iter):
225
+ # Input
226
+ input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)
227
+
228
+ # Interpolate
229
+ output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
230
+
231
+ # Blend
232
+ input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
233
+ output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
234
+
235
+ # Add to frames
236
+ processed_images[1:-1] = output_tensor
237
+
238
+ # To images
239
+ output_images = self.decode_images(processed_images)
240
+ if output_images[0].size != rendered_frames[0].size:
241
+ output_images = [image.resize(rendered_frames[0].size) for image in output_images]
242
+ return output_images
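For reference, a minimal usage sketch of the RIFE frame interpolater added above. The checkpoint path and the `ModelManager.load_models` call are assumptions for illustration; only `RIFEInterpolater` itself comes from this commit.

```python
# Illustrative sketch: interpolate extra frames between a list of PIL images.
from PIL import Image
from diffsynth import ModelManager
from diffsynth.extensions.RIFE import RIFEInterpolater

model_manager = ModelManager()
model_manager.load_models(["models/RIFE/flownet.pkl"])  # hypothetical path and loading call
interpolater = RIFEInterpolater.from_model_manager(model_manager)

frames = [Image.open(f"frame_{i:04d}.png") for i in range(8)]
# With num_iter=1, one interpolated frame is inserted between each pair: 8 frames -> 15 frames.
new_frames = interpolater.interpolate(frames, scale=1.0, batch_size=4, num_iter=1)
```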
diffsynth/extensions/__init__.py ADDED
File without changes
diffsynth/models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .model_manager import *
diffsynth/models/attention.py ADDED
@@ -0,0 +1,89 @@
1
+ import torch
2
+ from einops import rearrange
3
+
4
+
5
+ def low_version_attention(query, key, value, attn_bias=None):
6
+ scale = 1 / query.shape[-1] ** 0.5
7
+ query = query * scale
8
+ attn = torch.matmul(query, key.transpose(-2, -1))
9
+ if attn_bias is not None:
10
+ attn = attn + attn_bias
11
+ attn = attn.softmax(-1)
12
+ return attn @ value
13
+
14
+
15
+ class Attention(torch.nn.Module):
16
+
17
+ def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
18
+ super().__init__()
19
+ dim_inner = head_dim * num_heads
20
+ kv_dim = kv_dim if kv_dim is not None else q_dim
21
+ self.num_heads = num_heads
22
+ self.head_dim = head_dim
23
+
24
+ self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
25
+ self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
26
+ self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
27
+ self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
28
+
29
+ def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
30
+ batch_size = q.shape[0]
31
+ ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
32
+ ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
33
+ ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
34
+ hidden_states = hidden_states + scale * ip_hidden_states
35
+ return hidden_states
36
+
37
+ def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
38
+ if encoder_hidden_states is None:
39
+ encoder_hidden_states = hidden_states
40
+
41
+ batch_size = encoder_hidden_states.shape[0]
42
+
43
+ q = self.to_q(hidden_states)
44
+ k = self.to_k(encoder_hidden_states)
45
+ v = self.to_v(encoder_hidden_states)
46
+
47
+ q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
48
+ k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
49
+ v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
50
+
51
+ if qkv_preprocessor is not None:
52
+ q, k, v = qkv_preprocessor(q, k, v)
53
+
54
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
55
+ if ipadapter_kwargs is not None:
56
+ hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
57
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
58
+ hidden_states = hidden_states.to(q.dtype)
59
+
60
+ hidden_states = self.to_out(hidden_states)
61
+
62
+ return hidden_states
63
+
64
+ def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
65
+ if encoder_hidden_states is None:
66
+ encoder_hidden_states = hidden_states
67
+
68
+ q = self.to_q(hidden_states)
69
+ k = self.to_k(encoder_hidden_states)
70
+ v = self.to_v(encoder_hidden_states)
71
+
72
+ q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
73
+ k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
74
+ v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
75
+
76
+ if attn_mask is not None:
77
+ hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
78
+ else:
79
+ import xformers.ops as xops
80
+ hidden_states = xops.memory_efficient_attention(q, k, v)
81
+ hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
82
+
83
+ hidden_states = hidden_states.to(q.dtype)
84
+ hidden_states = self.to_out(hidden_states)
85
+
86
+ return hidden_states
87
+
88
+ def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
89
+ return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
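A quick shape check for the `Attention` module above (illustrative only; the dimensions are arbitrary):

```python
import torch
from diffsynth.models.attention import Attention

# Cross-attention: queries come from a 320-dim stream, keys/values from 768-dim text features.
attn = Attention(q_dim=320, num_heads=8, head_dim=40, kv_dim=768)
hidden_states = torch.randn(2, 77, 320)          # (batch, query tokens, q_dim)
encoder_hidden_states = torch.randn(2, 50, 768)  # (batch, context tokens, kv_dim)
out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # torch.Size([2, 77, 320])
```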
diffsynth/models/downloader.py ADDED
@@ -0,0 +1,66 @@
1
+ from huggingface_hub import hf_hub_download
2
+ from modelscope import snapshot_download
3
+ import os, shutil
4
+ from typing_extensions import Literal, TypeAlias
5
+ from typing import List
6
+ from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id
7
+
8
+
9
+ def download_from_modelscope(model_id, origin_file_path, local_dir):
10
+ os.makedirs(local_dir, exist_ok=True)
11
+ if os.path.basename(origin_file_path) in os.listdir(local_dir):
12
+ print(f" {os.path.basename(origin_file_path)} is already in {local_dir}.")
13
+ return
14
+ else:
15
+ print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
16
+ snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
17
+ downloaded_file_path = os.path.join(local_dir, origin_file_path)
18
+ target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
19
+ if downloaded_file_path != target_file_path:
20
+ shutil.move(downloaded_file_path, target_file_path)
21
+ shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
22
+
23
+
24
+ def download_from_huggingface(model_id, origin_file_path, local_dir):
25
+ os.makedirs(local_dir, exist_ok=True)
26
+ if os.path.basename(origin_file_path) in os.listdir(local_dir):
27
+ print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
28
+ return
29
+ else:
30
+ print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
31
+ hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
32
+
33
+
34
+ Preset_model_website: TypeAlias = Literal[
35
+ "HuggingFace",
36
+ "ModelScope",
37
+ ]
38
+ website_to_preset_models = {
39
+ "HuggingFace": preset_models_on_huggingface,
40
+ "ModelScope": preset_models_on_modelscope,
41
+ }
42
+ website_to_download_fn = {
43
+ "HuggingFace": download_from_huggingface,
44
+ "ModelScope": download_from_modelscope,
45
+ }
46
+
47
+
48
+ def download_models(
49
+ model_id_list: List[Preset_model_id] = [],
50
+ downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
51
+ ):
52
+ print(f"Downloading models: {model_id_list}")
53
+ downloaded_files = []
54
+ for model_id in model_id_list:
55
+ for website in downloading_priority:
56
+ if model_id in website_to_preset_models[website]:
57
+ for model_id, origin_file_path, local_dir in website_to_preset_models[website][model_id]:
58
+ # Check if the file is downloaded.
59
+ file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
60
+ if file_to_download in downloaded_files:
61
+ continue
62
+ # Download
63
+ website_to_download_fn[website](model_id, origin_file_path, local_dir)
64
+ if os.path.basename(origin_file_path) in os.listdir(local_dir):
65
+ downloaded_files.append(file_to_download)
66
+ return downloaded_files
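A usage sketch for `download_models`. The model id shown is an assumption; valid ids are defined in `diffsynth/configs/model_config.py` (`preset_models_on_modelscope` / `preset_models_on_huggingface`):

```python
from diffsynth.models.downloader import download_models

# "FLUX.1-dev" is a hypothetical preset id; replace with an id from model_config.py.
downloaded_files = download_models(
    model_id_list=["FLUX.1-dev"],
    downloading_priority=["ModelScope", "HuggingFace"],
)
print(downloaded_files)
```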
diffsynth/models/flux_dit.py ADDED
@@ -0,0 +1,575 @@
1
+ import torch
2
+ from .sd3_dit import TimestepEmbeddings, AdaLayerNorm
3
+ from einops import rearrange
4
+ from .tiler import TileWorker
5
+
6
+
7
+
8
+ class RoPEEmbedding(torch.nn.Module):
9
+ def __init__(self, dim, theta, axes_dim):
10
+ super().__init__()
11
+ self.dim = dim
12
+ self.theta = theta
13
+ self.axes_dim = axes_dim
14
+
15
+
16
+ def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
17
+ assert dim % 2 == 0, "The dimension must be even."
18
+
19
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
20
+ omega = 1.0 / (theta**scale)
21
+
22
+ batch_size, seq_length = pos.shape
23
+ out = torch.einsum("...n,d->...nd", pos, omega)
24
+ cos_out = torch.cos(out)
25
+ sin_out = torch.sin(out)
26
+
27
+ stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
28
+ out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
29
+ return out.float()
30
+
31
+
32
+ def forward(self, ids):
33
+ n_axes = ids.shape[-1]
34
+ emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
35
+ return emb.unsqueeze(1)
36
+
37
+
38
+
39
+ class RMSNorm(torch.nn.Module):
40
+ def __init__(self, dim, eps):
41
+ super().__init__()
42
+ self.weight = torch.nn.Parameter(torch.ones((dim,)))
43
+ self.eps = eps
44
+
45
+ def forward(self, hidden_states):
46
+ input_dtype = hidden_states.dtype
47
+ variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
48
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
49
+ hidden_states = hidden_states.to(input_dtype) * self.weight
50
+ return hidden_states
51
+
52
+
53
+
54
+ class FluxJointAttention(torch.nn.Module):
55
+ def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
56
+ super().__init__()
57
+ self.num_heads = num_heads
58
+ self.head_dim = head_dim
59
+ self.only_out_a = only_out_a
60
+
61
+ self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
62
+ self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)
63
+
64
+ self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
65
+ self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
66
+ self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
67
+ self.norm_k_b = RMSNorm(head_dim, eps=1e-6)
68
+
69
+ self.a_to_out = torch.nn.Linear(dim_a, dim_a)
70
+ if not only_out_a:
71
+ self.b_to_out = torch.nn.Linear(dim_b, dim_b)
72
+
73
+
74
+ def apply_rope(self, xq, xk, freqs_cis):
75
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
76
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
77
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
78
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
79
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
80
+
81
+
82
+ def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb):
83
+ batch_size = hidden_states_a.shape[0]
84
+
85
+ # Part A
86
+ qkv_a = self.a_to_qkv(hidden_states_a)
87
+ qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
88
+ q_a, k_a, v_a = qkv_a.chunk(3, dim=1)
89
+ q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
90
+
91
+ # Part B
92
+ qkv_b = self.b_to_qkv(hidden_states_b)
93
+ qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
94
+ q_b, k_b, v_b = qkv_b.chunk(3, dim=1)
95
+ q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b)
96
+
97
+ q = torch.concat([q_b, q_a], dim=2)
98
+ k = torch.concat([k_b, k_a], dim=2)
99
+ v = torch.concat([v_b, v_a], dim=2)
100
+
101
+ q, k = self.apply_rope(q, k, image_rotary_emb)
102
+
103
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
104
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
105
+ hidden_states = hidden_states.to(q.dtype)
106
+ hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
107
+ hidden_states_a = self.a_to_out(hidden_states_a)
108
+ if self.only_out_a:
109
+ return hidden_states_a
110
+ else:
111
+ hidden_states_b = self.b_to_out(hidden_states_b)
112
+ return hidden_states_a, hidden_states_b
113
+
114
+
115
+
116
+ class FluxJointTransformerBlock(torch.nn.Module):
117
+ def __init__(self, dim, num_attention_heads):
118
+ super().__init__()
119
+ self.norm1_a = AdaLayerNorm(dim)
120
+ self.norm1_b = AdaLayerNorm(dim)
121
+
122
+ self.attn = FluxJointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
123
+
124
+ self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
125
+ self.ff_a = torch.nn.Sequential(
126
+ torch.nn.Linear(dim, dim*4),
127
+ torch.nn.GELU(approximate="tanh"),
128
+ torch.nn.Linear(dim*4, dim)
129
+ )
130
+
131
+ self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
132
+ self.ff_b = torch.nn.Sequential(
133
+ torch.nn.Linear(dim, dim*4),
134
+ torch.nn.GELU(approximate="tanh"),
135
+ torch.nn.Linear(dim*4, dim)
136
+ )
137
+
138
+
139
+ def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
140
+ norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
141
+ norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
142
+
143
+ # Attention
144
+ attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb)
145
+
146
+ # Part A
147
+ hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
148
+ norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
149
+ hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
150
+
151
+ # Part B
152
+ hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
153
+ norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
154
+ hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
155
+
156
+ return hidden_states_a, hidden_states_b
157
+
158
+
159
+
160
+ class FluxSingleAttention(torch.nn.Module):
161
+ def __init__(self, dim_a, dim_b, num_heads, head_dim):
162
+ super().__init__()
163
+ self.num_heads = num_heads
164
+ self.head_dim = head_dim
165
+
166
+ self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
167
+
168
+ self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
169
+ self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
170
+
171
+
172
+ def apply_rope(self, xq, xk, freqs_cis):
173
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
174
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
175
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
176
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
177
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
178
+
179
+
180
+ def forward(self, hidden_states, image_rotary_emb):
181
+ batch_size = hidden_states.shape[0]
182
+
183
+ qkv_a = self.a_to_qkv(hidden_states)
184
+ qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
185
+ q_a, k_a, v = qkv_a.chunk(3, dim=1)
186
+ q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
187
+
188
+ q, k = self.apply_rope(q_a, k_a, image_rotary_emb)
189
+
190
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
191
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
192
+ hidden_states = hidden_states.to(q.dtype)
193
+ return hidden_states
194
+
195
+
196
+
197
+ class AdaLayerNormSingle(torch.nn.Module):
198
+ def __init__(self, dim):
199
+ super().__init__()
200
+ self.silu = torch.nn.SiLU()
201
+ self.linear = torch.nn.Linear(dim, 3 * dim, bias=True)
202
+ self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
203
+
204
+
205
+ def forward(self, x, emb):
206
+ emb = self.linear(self.silu(emb))
207
+ shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
208
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
209
+ return x, gate_msa
210
+
211
+
212
+
213
+ class FluxSingleTransformerBlock(torch.nn.Module):
214
+ def __init__(self, dim, num_attention_heads):
215
+ super().__init__()
216
+ self.num_heads = num_attention_heads
217
+ self.head_dim = dim // num_attention_heads
218
+ self.dim = dim
219
+
220
+ self.norm = AdaLayerNormSingle(dim)
221
+ # self.proj_in = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), torch.nn.GELU(approximate="tanh"))
222
+ # self.attn = FluxSingleAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
223
+ self.linear = torch.nn.Linear(dim, dim * (3 + 4))
224
+ self.norm_q_a = RMSNorm(self.head_dim, eps=1e-6)
225
+ self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6)
226
+
227
+ self.proj_out = torch.nn.Linear(dim * 5, dim)
228
+
229
+
230
+ def apply_rope(self, xq, xk, freqs_cis):
231
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
232
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
233
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
234
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
235
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
236
+
237
+
238
+ def process_attention(self, hidden_states, image_rotary_emb):
239
+ batch_size = hidden_states.shape[0]
240
+
241
+ qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
242
+ q, k, v = qkv.chunk(3, dim=1)
243
+ q, k = self.norm_q_a(q), self.norm_k_a(k)
244
+
245
+ q, k = self.apply_rope(q, k, image_rotary_emb)
246
+
247
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
248
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
249
+ hidden_states = hidden_states.to(q.dtype)
250
+ return hidden_states
251
+
252
+
253
+ def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
254
+ residual = hidden_states_a
255
+ norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
256
+ hidden_states_a = self.linear(norm_hidden_states)
257
+ attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
258
+
259
+ attn_output = self.process_attention(attn_output, image_rotary_emb)
260
+ mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
261
+
262
+ hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
263
+ hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
264
+ hidden_states_a = residual + hidden_states_a
265
+
266
+ return hidden_states_a, hidden_states_b
267
+
268
+
269
+
270
+ class AdaLayerNormContinuous(torch.nn.Module):
271
+ def __init__(self, dim):
272
+ super().__init__()
273
+ self.silu = torch.nn.SiLU()
274
+ self.linear = torch.nn.Linear(dim, dim * 2, bias=True)
275
+ self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
276
+
277
+ def forward(self, x, conditioning):
278
+ emb = self.linear(self.silu(conditioning))
279
+ scale, shift = torch.chunk(emb, 2, dim=1)
280
+ x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
281
+ return x
282
+
283
+
284
+
285
+ class FluxDiT(torch.nn.Module):
286
+ def __init__(self):
287
+ super().__init__()
288
+ self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
289
+ self.time_embedder = TimestepEmbeddings(256, 3072)
290
+ self.guidance_embedder = TimestepEmbeddings(256, 3072)
291
+ self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
292
+ self.context_embedder = torch.nn.Linear(4096, 3072)
293
+ self.x_embedder = torch.nn.Linear(64, 3072)
294
+
295
+ self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
296
+ self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
297
+
298
+ self.norm_out = AdaLayerNormContinuous(3072)
299
+ self.proj_out = torch.nn.Linear(3072, 64)
300
+
301
+
302
+ def patchify(self, hidden_states):
303
+ hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
304
+ return hidden_states
305
+
306
+
307
+ def unpatchify(self, hidden_states, height, width):
308
+ hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
309
+ return hidden_states
310
+
311
+
312
+ def prepare_image_ids(self, latents):
313
+ batch_size, _, height, width = latents.shape
314
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
315
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
316
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
317
+
318
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
319
+
320
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
321
+ latent_image_ids = latent_image_ids.reshape(
322
+ batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
323
+ )
324
+ latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
325
+
326
+ return latent_image_ids
327
+
328
+
329
+ def tiled_forward(
330
+ self,
331
+ hidden_states,
332
+ timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
333
+ tile_size=128, tile_stride=64,
334
+ **kwargs
335
+ ):
336
+ # Due to the global positional embedding, we cannot implement layer-wise tiled forward.
337
+ hidden_states = TileWorker().tiled_forward(
338
+ lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None),
339
+ hidden_states,
340
+ tile_size,
341
+ tile_stride,
342
+ tile_device=hidden_states.device,
343
+ tile_dtype=hidden_states.dtype
344
+ )
345
+ return hidden_states
346
+
347
+
348
+ def forward(
349
+ self,
350
+ hidden_states,
351
+ timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
352
+ tiled=False, tile_size=128, tile_stride=64,
353
+ **kwargs
354
+ ):
355
+ if tiled:
356
+ return self.tiled_forward(
357
+ hidden_states,
358
+ timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
359
+ tile_size=tile_size, tile_stride=tile_stride,
360
+ **kwargs
361
+ )
362
+
363
+ if image_ids is None:
364
+ image_ids = self.prepare_image_ids(hidden_states)
365
+
366
+ conditioning = self.time_embedder(timestep, hidden_states.dtype)\
367
+ + self.guidance_embedder(guidance, hidden_states.dtype)\
368
+ + self.pooled_text_embedder(pooled_prompt_emb)
369
+ prompt_emb = self.context_embedder(prompt_emb)
370
+ image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
371
+
372
+ height, width = hidden_states.shape[-2:]
373
+ hidden_states = self.patchify(hidden_states)
374
+ hidden_states = self.x_embedder(hidden_states)
375
+
376
+ for block in self.blocks:
377
+ hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
378
+
379
+ hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
380
+ for block in self.single_blocks:
381
+ hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
382
+ hidden_states = hidden_states[:, prompt_emb.shape[1]:]
383
+
384
+ hidden_states = self.norm_out(hidden_states, conditioning)
385
+ hidden_states = self.proj_out(hidden_states)
386
+ hidden_states = self.unpatchify(hidden_states, height, width)
387
+
388
+ return hidden_states
389
+
390
+
391
+ @staticmethod
392
+ def state_dict_converter():
393
+ return FluxDiTStateDictConverter()
394
+
395
+
396
+
397
+ class FluxDiTStateDictConverter:
398
+ def __init__(self):
399
+ pass
400
+
401
+ def from_diffusers(self, state_dict):
402
+ rename_dict = {
403
+ "context_embedder": "context_embedder",
404
+ "x_embedder": "x_embedder",
405
+ "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
406
+ "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
407
+ "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
408
+ "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
409
+ "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
410
+ "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
411
+ "norm_out.linear": "norm_out.linear",
412
+ "proj_out": "proj_out",
413
+
414
+ "norm1.linear": "norm1_a.linear",
415
+ "norm1_context.linear": "norm1_b.linear",
416
+ "attn.to_q": "attn.a_to_q",
417
+ "attn.to_k": "attn.a_to_k",
418
+ "attn.to_v": "attn.a_to_v",
419
+ "attn.to_out.0": "attn.a_to_out",
420
+ "attn.add_q_proj": "attn.b_to_q",
421
+ "attn.add_k_proj": "attn.b_to_k",
422
+ "attn.add_v_proj": "attn.b_to_v",
423
+ "attn.to_add_out": "attn.b_to_out",
424
+ "ff.net.0.proj": "ff_a.0",
425
+ "ff.net.2": "ff_a.2",
426
+ "ff_context.net.0.proj": "ff_b.0",
427
+ "ff_context.net.2": "ff_b.2",
428
+ "attn.norm_q": "attn.norm_q_a",
429
+ "attn.norm_k": "attn.norm_k_a",
430
+ "attn.norm_added_q": "attn.norm_q_b",
431
+ "attn.norm_added_k": "attn.norm_k_b",
432
+ }
433
+ rename_dict_single = {
434
+ "attn.to_q": "a_to_q",
435
+ "attn.to_k": "a_to_k",
436
+ "attn.to_v": "a_to_v",
437
+ "attn.norm_q": "norm_q_a",
438
+ "attn.norm_k": "norm_k_a",
439
+ "norm.linear": "norm.linear",
440
+ "proj_mlp": "proj_in_besides_attn",
441
+ "proj_out": "proj_out",
442
+ }
443
+ state_dict_ = {}
444
+ for name, param in state_dict.items():
445
+ if name in rename_dict:
446
+ state_dict_[rename_dict[name]] = param
447
+ elif name.endswith(".weight") or name.endswith(".bias"):
448
+ suffix = ".weight" if name.endswith(".weight") else ".bias"
449
+ prefix = name[:-len(suffix)]
450
+ if prefix in rename_dict:
451
+ state_dict_[rename_dict[prefix] + suffix] = param
452
+ elif prefix.startswith("transformer_blocks."):
453
+ names = prefix.split(".")
454
+ names[0] = "blocks"
455
+ middle = ".".join(names[2:])
456
+ if middle in rename_dict:
457
+ name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
458
+ state_dict_[name_] = param
459
+ elif prefix.startswith("single_transformer_blocks."):
460
+ names = prefix.split(".")
461
+ names[0] = "single_blocks"
462
+ middle = ".".join(names[2:])
463
+ if middle in rename_dict_single:
464
+ name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
465
+ state_dict_[name_] = param
466
+ else:
467
+ print(name)
468
+ else:
469
+ print(name)
470
+ for name in list(state_dict_.keys()):
471
+ if ".proj_in_besides_attn." in name:
472
+ name_ = name.replace(".proj_in_besides_attn.", ".linear.")
473
+ param = torch.concat([
474
+ state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
475
+ state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
476
+ state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
477
+ state_dict_[name],
478
+ ], dim=0)
479
+ state_dict_[name_] = param
480
+ state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
481
+ state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
482
+ state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
483
+ state_dict_.pop(name)
484
+ for name in list(state_dict_.keys()):
485
+ for component in ["a", "b"]:
486
+ if f".{component}_to_q." in name:
487
+ name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
488
+ param = torch.concat([
489
+ state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
490
+ state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
491
+ state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
492
+ ], dim=0)
493
+ state_dict_[name_] = param
494
+ state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
495
+ state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
496
+ state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
497
+ return state_dict_
498
+
499
+ def from_civitai(self, state_dict):
500
+ rename_dict = {
501
+ "time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
502
+ "time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
503
+ "time_in.out_layer.bias": "time_embedder.timestep_embedder.2.bias",
504
+ "time_in.out_layer.weight": "time_embedder.timestep_embedder.2.weight",
505
+ "txt_in.bias": "context_embedder.bias",
506
+ "txt_in.weight": "context_embedder.weight",
507
+ "vector_in.in_layer.bias": "pooled_text_embedder.0.bias",
508
+ "vector_in.in_layer.weight": "pooled_text_embedder.0.weight",
509
+ "vector_in.out_layer.bias": "pooled_text_embedder.2.bias",
510
+ "vector_in.out_layer.weight": "pooled_text_embedder.2.weight",
511
+ "final_layer.linear.bias": "proj_out.bias",
512
+ "final_layer.linear.weight": "proj_out.weight",
513
+ "guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias",
514
+ "guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight",
515
+ "guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias",
516
+ "guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight",
517
+ "img_in.bias": "x_embedder.bias",
518
+ "img_in.weight": "x_embedder.weight",
519
+ "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
520
+ "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
521
+ }
522
+ suffix_rename_dict = {
523
+ "img_attn.norm.key_norm.scale": "attn.norm_k_a.weight",
524
+ "img_attn.norm.query_norm.scale": "attn.norm_q_a.weight",
525
+ "img_attn.proj.bias": "attn.a_to_out.bias",
526
+ "img_attn.proj.weight": "attn.a_to_out.weight",
527
+ "img_attn.qkv.bias": "attn.a_to_qkv.bias",
528
+ "img_attn.qkv.weight": "attn.a_to_qkv.weight",
529
+ "img_mlp.0.bias": "ff_a.0.bias",
530
+ "img_mlp.0.weight": "ff_a.0.weight",
531
+ "img_mlp.2.bias": "ff_a.2.bias",
532
+ "img_mlp.2.weight": "ff_a.2.weight",
533
+ "img_mod.lin.bias": "norm1_a.linear.bias",
534
+ "img_mod.lin.weight": "norm1_a.linear.weight",
535
+ "txt_attn.norm.key_norm.scale": "attn.norm_k_b.weight",
536
+ "txt_attn.norm.query_norm.scale": "attn.norm_q_b.weight",
537
+ "txt_attn.proj.bias": "attn.b_to_out.bias",
538
+ "txt_attn.proj.weight": "attn.b_to_out.weight",
539
+ "txt_attn.qkv.bias": "attn.b_to_qkv.bias",
540
+ "txt_attn.qkv.weight": "attn.b_to_qkv.weight",
541
+ "txt_mlp.0.bias": "ff_b.0.bias",
542
+ "txt_mlp.0.weight": "ff_b.0.weight",
543
+ "txt_mlp.2.bias": "ff_b.2.bias",
544
+ "txt_mlp.2.weight": "ff_b.2.weight",
545
+ "txt_mod.lin.bias": "norm1_b.linear.bias",
546
+ "txt_mod.lin.weight": "norm1_b.linear.weight",
547
+
548
+ "linear1.bias": "linear.bias",
549
+ "linear1.weight": "linear.weight",
550
+ "linear2.bias": "proj_out.bias",
551
+ "linear2.weight": "proj_out.weight",
552
+ "modulation.lin.bias": "norm.linear.bias",
553
+ "modulation.lin.weight": "norm.linear.weight",
554
+ "norm.key_norm.scale": "norm_k_a.weight",
555
+ "norm.query_norm.scale": "norm_q_a.weight",
556
+ }
557
+ state_dict_ = {}
558
+ for name, param in state_dict.items():
559
+ names = name.split(".")
560
+ if name in rename_dict:
561
+ rename = rename_dict[name]
562
+ if name.startswith("final_layer.adaLN_modulation.1."):
563
+ param = torch.concat([param[3072:], param[:3072]], dim=0)
564
+ state_dict_[rename] = param
565
+ elif names[0] == "double_blocks":
566
+ rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
567
+ state_dict_[rename] = param
568
+ elif names[0] == "single_blocks":
569
+ if ".".join(names[2:]) in suffix_rename_dict:
570
+ rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
571
+ state_dict_[rename] = param
572
+ else:
573
+ print(name)
574
+ return state_dict_
575
+
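A sanity-check sketch of the 2×2 patchify/unpatchify pair used by `FluxDiT` above; the 16-channel latent shape matches the `x_embedder = Linear(64, 3072)` input (16 × 2 × 2 = 64):

```python
import torch
from einops import rearrange

latents = torch.randn(1, 16, 64, 64)  # (B, C, H, W) with H, W even
patches = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
print(patches.shape)  # torch.Size([1, 1024, 64])
restored = rearrange(patches, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=32, W=32)
assert torch.equal(latents, restored)  # the rearrange pair is an exact inverse
```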
diffsynth/models/flux_text_encoder.py ADDED
@@ -0,0 +1,93 @@
1
+ import torch
2
+ from transformers import T5EncoderModel, T5Config
3
+ from .sd_text_encoder import SDTextEncoder
4
+
5
+
6
+ class FluxTextEncoder1(SDTextEncoder):
7
+ def __init__(self, vocab_size=49408):
8
+ super().__init__(vocab_size=vocab_size)
9
+
10
+ def forward(self, input_ids, clip_skip=2):
11
+ embeds = self.token_embedding(input_ids) + self.position_embeds
12
+ attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
13
+ for encoder_id, encoder in enumerate(self.encoders):
14
+ embeds = encoder(embeds, attn_mask=attn_mask)
15
+ if encoder_id + clip_skip == len(self.encoders):
16
+ hidden_states = embeds
17
+ embeds = self.final_layer_norm(embeds)
18
+ pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
19
+ return embeds, pooled_embeds
20
+
21
+ @staticmethod
22
+ def state_dict_converter():
23
+ return FluxTextEncoder1StateDictConverter()
24
+
25
+
26
+
27
+ class FluxTextEncoder2(T5EncoderModel):
28
+ def __init__(self, config):
29
+ super().__init__(config)
30
+ self.eval()
31
+
32
+ def forward(self, input_ids):
33
+ outputs = super().forward(input_ids=input_ids)
34
+ prompt_emb = outputs.last_hidden_state
35
+ return prompt_emb
36
+
37
+ @staticmethod
38
+ def state_dict_converter():
39
+ return FluxTextEncoder2StateDictConverter()
40
+
41
+
42
+
43
+ class FluxTextEncoder1StateDictConverter:
44
+ def __init__(self):
45
+ pass
46
+
47
+ def from_diffusers(self, state_dict):
48
+ rename_dict = {
49
+ "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
50
+ "text_model.embeddings.position_embedding.weight": "position_embeds",
51
+ "text_model.final_layer_norm.weight": "final_layer_norm.weight",
52
+ "text_model.final_layer_norm.bias": "final_layer_norm.bias"
53
+ }
54
+ attn_rename_dict = {
55
+ "self_attn.q_proj": "attn.to_q",
56
+ "self_attn.k_proj": "attn.to_k",
57
+ "self_attn.v_proj": "attn.to_v",
58
+ "self_attn.out_proj": "attn.to_out",
59
+ "layer_norm1": "layer_norm1",
60
+ "layer_norm2": "layer_norm2",
61
+ "mlp.fc1": "fc1",
62
+ "mlp.fc2": "fc2",
63
+ }
64
+ state_dict_ = {}
65
+ for name in state_dict:
66
+ if name in rename_dict:
67
+ param = state_dict[name]
68
+ if name == "text_model.embeddings.position_embedding.weight":
69
+ param = param.reshape((1, param.shape[0], param.shape[1]))
70
+ state_dict_[rename_dict[name]] = param
71
+ elif name.startswith("text_model.encoder.layers."):
72
+ param = state_dict[name]
73
+ names = name.split(".")
74
+ layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
75
+ name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
76
+ state_dict_[name_] = param
77
+ return state_dict_
78
+
79
+ def from_civitai(self, state_dict):
80
+ return self.from_diffusers(state_dict)
81
+
82
+
83
+
84
+ class FluxTextEncoder2StateDictConverter():
85
+ def __init__(self):
86
+ pass
87
+
88
+ def from_diffusers(self, state_dict):
89
+ state_dict_ = state_dict
90
+ return state_dict_
91
+
92
+ def from_civitai(self, state_dict):
93
+ return self.from_diffusers(state_dict)
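The converters above rewrite checkpoint keys into this repository's module names. A small illustrative check of the CLIP-layer renaming in `FluxTextEncoder1StateDictConverter.from_diffusers` (the dummy key and tensor shape are made up for the example):

```python
import torch
from diffsynth.models.flux_text_encoder import FluxTextEncoder1StateDictConverter

converter = FluxTextEncoder1StateDictConverter()
dummy = {"text_model.encoder.layers.0.self_attn.q_proj.weight": torch.zeros(768, 768)}
print(list(converter.from_diffusers(dummy)))  # ['encoders.0.attn.to_q.weight']
```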
diffsynth/models/flux_vae.py ADDED
@@ -0,0 +1,303 @@
1
+ from .sd3_vae_encoder import SD3VAEEncoder, SDVAEEncoderStateDictConverter
2
+ from .sd3_vae_decoder import SD3VAEDecoder, SDVAEDecoderStateDictConverter
3
+
4
+
5
+ class FluxVAEEncoder(SD3VAEEncoder):
6
+ def __init__(self):
7
+ super().__init__()
8
+ self.scaling_factor = 0.3611
9
+ self.shift_factor = 0.1159
10
+
11
+ @staticmethod
12
+ def state_dict_converter():
13
+ return FluxVAEEncoderStateDictConverter()
14
+
15
+
16
+ class FluxVAEDecoder(SD3VAEDecoder):
17
+ def __init__(self):
18
+ super().__init__()
19
+ self.scaling_factor = 0.3611
20
+ self.shift_factor = 0.1159
21
+
22
+ @staticmethod
23
+ def state_dict_converter():
24
+ return FluxVAEDecoderStateDictConverter()
25
+
26
+
27
+ class FluxVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
28
+ def __init__(self):
29
+ pass
30
+
31
+ def from_civitai(self, state_dict):
32
+ rename_dict = {
33
+ "encoder.conv_in.bias": "conv_in.bias",
34
+ "encoder.conv_in.weight": "conv_in.weight",
35
+ "encoder.conv_out.bias": "conv_out.bias",
36
+ "encoder.conv_out.weight": "conv_out.weight",
37
+ "encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
38
+ "encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
39
+ "encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
40
+ "encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
41
+ "encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
42
+ "encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
43
+ "encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
44
+ "encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
45
+ "encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
46
+ "encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
47
+ "encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
48
+ "encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
49
+ "encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
50
+ "encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
51
+ "encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
52
+ "encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
53
+ "encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
54
+ "encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
55
+ "encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
56
+ "encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
57
+ "encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
58
+ "encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
59
+ "encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
60
+ "encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
61
+ "encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
62
+ "encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
63
+ "encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
64
+ "encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
65
+ "encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
66
+ "encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
67
+ "encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
68
+ "encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
69
+ "encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
70
+ "encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
71
+ "encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
72
+ "encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
73
+ "encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
74
+ "encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
75
+ "encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
76
+ "encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
77
+ "encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
78
+ "encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
79
+ "encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
80
+ "encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
81
+ "encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
82
+ "encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
83
+ "encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
84
+ "encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
85
+ "encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
86
+ "encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
87
+ "encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
88
+ "encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
89
+ "encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
90
+ "encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
91
+ "encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
92
+ "encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
93
+ "encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
94
+ "encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
95
+ "encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
96
+ "encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
97
+ "encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
98
+ "encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
99
+ "encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
100
+ "encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
101
+ "encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
102
+ "encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
103
+ "encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
104
+ "encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
105
+ "encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
106
+ "encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
107
+ "encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
108
+ "encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
109
+ "encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
110
+ "encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
111
+ "encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
112
+ "encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
113
+ "encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
114
+ "encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
115
+ "encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
116
+ "encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
117
+ "encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
118
+ "encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
119
+ "encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
120
+ "encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
121
+ "encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
122
+ "encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
123
+ "encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
124
+ "encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
125
+ "encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
126
+ "encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
127
+ "encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
128
+ "encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
129
+ "encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
130
+ "encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
131
+ "encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
132
+ "encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
133
+ "encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
134
+ "encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
135
+ "encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
136
+ "encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
137
+ "encoder.norm_out.bias": "conv_norm_out.bias",
138
+ "encoder.norm_out.weight": "conv_norm_out.weight",
139
+ }
140
+ state_dict_ = {}
141
+ for name in state_dict:
142
+ if name in rename_dict:
143
+ param = state_dict[name]
144
+ if "transformer_blocks" in rename_dict[name]:
145
+ param = param.squeeze()
146
+ state_dict_[rename_dict[name]] = param
147
+ return state_dict_
148
+
149
+
150
+
151
+ class FluxVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
152
+ def __init__(self):
153
+ pass
154
+
155
+ def from_civitai(self, state_dict):
156
+ rename_dict = {
157
+ "decoder.conv_in.bias": "conv_in.bias",
158
+ "decoder.conv_in.weight": "conv_in.weight",
159
+ "decoder.conv_out.bias": "conv_out.bias",
160
+ "decoder.conv_out.weight": "conv_out.weight",
161
+ "decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
162
+ "decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
163
+ "decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
164
+ "decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
165
+ "decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
166
+ "decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
167
+ "decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
168
+ "decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
169
+ "decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
170
+ "decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
171
+ "decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
172
+ "decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
173
+ "decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
174
+ "decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
175
+ "decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
176
+ "decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
177
+ "decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
178
+ "decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
179
+ "decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
180
+ "decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
181
+ "decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
182
+ "decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
183
+ "decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
184
+ "decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
185
+ "decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
186
+ "decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
187
+ "decoder.norm_out.bias": "conv_norm_out.bias",
188
+ "decoder.norm_out.weight": "conv_norm_out.weight",
189
+ "decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
190
+ "decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
191
+ "decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
192
+ "decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
193
+ "decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
194
+ "decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
195
+ "decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
196
+ "decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
197
+ "decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
198
+ "decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
199
+ "decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
200
+ "decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
201
+ "decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
202
+ "decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
203
+ "decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
204
+ "decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
205
+ "decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
206
+ "decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
207
+ "decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
208
+ "decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
209
+ "decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
210
+ "decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
211
+ "decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
212
+ "decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
213
+ "decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
214
+ "decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
215
+ "decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
216
+ "decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
217
+ "decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
218
+ "decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
219
+ "decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
220
+ "decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
221
+ "decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
222
+ "decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
223
+ "decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
224
+ "decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
225
+ "decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
226
+ "decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
227
+ "decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
228
+ "decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
229
+ "decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
230
+ "decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
231
+ "decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
232
+ "decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
233
+ "decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
234
+ "decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
235
+ "decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
236
+ "decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
237
+ "decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
238
+ "decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
239
+ "decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
240
+ "decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
241
+ "decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
242
+ "decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
243
+ "decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
244
+ "decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
245
+ "decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
246
+ "decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
247
+ "decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
248
+ "decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
249
+ "decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
250
+ "decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
251
+ "decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
252
+ "decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
253
+ "decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
254
+ "decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
255
+ "decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
256
+ "decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
257
+ "decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
258
+ "decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
259
+ "decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
260
+ "decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
261
+ "decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
262
+ "decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
263
+ "decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
264
+ "decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
265
+ "decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
266
+ "decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
267
+ "decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
268
+ "decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
269
+ "decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
270
+ "decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
271
+ "decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
272
+ "decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
273
+ "decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
274
+ "decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
275
+ "decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
276
+ "decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
277
+ "decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
278
+ "decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
279
+ "decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
280
+ "decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
281
+ "decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
282
+ "decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
283
+ "decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
284
+ "decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
285
+ "decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
286
+ "decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
287
+ "decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
288
+ "decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
289
+ "decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
290
+ "decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
291
+ "decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
292
+ "decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
293
+ "decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
294
+ "decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
295
+ }
296
+ state_dict_ = {}
297
+ for name in state_dict:
298
+ if name in rename_dict:
299
+ param = state_dict[name]
300
+ if "transformer_blocks" in rename_dict[name]:
301
+ param = param.squeeze()
302
+ state_dict_[rename_dict[name]] = param
303
+ return state_dict_
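The loop above keeps only the parameters listed in the rename map, stores them under the target names, and squeezes the ones headed for transformer blocks. A minimal standalone sketch of that pattern follows; the checkpoint path and loader in the commented lines are illustrative assumptions, not part of this repository.

import torch

def convert_with_rename_map(state_dict, rename_dict):
    # Keep only parameters that appear in the rename map, store them under the
    # target names, and squeeze weights destined for transformer blocks,
    # mirroring the loop that closes the converter above.
    converted = {}
    for name, param in state_dict.items():
        if name in rename_dict:
            if "transformer_blocks" in rename_dict[name]:
                param = param.squeeze()
            converted[rename_dict[name]] = param
    return converted

# state_dict = torch.load("vae_checkpoint.pth", map_location="cpu")  # hypothetical path
# converted = convert_with_rename_map(state_dict, rename_dict)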
diffsynth/models/hunyuan_dit.py ADDED
@@ -0,0 +1,451 @@
1
+ from .attention import Attention
2
+ from einops import repeat, rearrange
3
+ import math
4
+ import torch
5
+
6
+
7
+ class HunyuanDiTRotaryEmbedding(torch.nn.Module):
8
+
9
+ def __init__(self, q_norm_shape=88, k_norm_shape=88, rotary_emb_on_k=True):
10
+ super().__init__()
11
+ self.q_norm = torch.nn.LayerNorm((q_norm_shape,), elementwise_affine=True, eps=1e-06)
12
+ self.k_norm = torch.nn.LayerNorm((k_norm_shape,), elementwise_affine=True, eps=1e-06)
13
+ self.rotary_emb_on_k = rotary_emb_on_k
14
+ self.k_cache, self.v_cache = [], []
15
+
16
+ def reshape_for_broadcast(self, freqs_cis, x):
17
+ ndim = x.ndim
18
+ shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
19
+ return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
20
+
21
+ def rotate_half(self, x):
22
+ x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
23
+ return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
24
+
25
+ def apply_rotary_emb(self, xq, xk, freqs_cis):
26
+ xk_out = None
27
+ cos, sin = self.reshape_for_broadcast(freqs_cis, xq)
28
+ cos, sin = cos.to(xq.device), sin.to(xq.device)
29
+ xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq)
30
+ if xk is not None:
31
+ xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk)
32
+ return xq_out, xk_out
33
+
34
+ def forward(self, q, k, v, freqs_cis_img, to_cache=False):
35
+ # norm
36
+ q = self.q_norm(q)
37
+ k = self.k_norm(k)
38
+
39
+ # RoPE
40
+ if self.rotary_emb_on_k:
41
+ q, k = self.apply_rotary_emb(q, k, freqs_cis_img)
42
+ else:
43
+ q, _ = self.apply_rotary_emb(q, None, freqs_cis_img)
44
+
45
+ if to_cache:
46
+ self.k_cache.append(k)
47
+ self.v_cache.append(v)
48
+ elif len(self.k_cache) > 0 and len(self.v_cache) > 0:
49
+ k = torch.concat([k] + self.k_cache, dim=2)
50
+ v = torch.concat([v] + self.v_cache, dim=2)
51
+ self.k_cache, self.v_cache = [], []
52
+ return q, k, v
53
+
54
+
55
+ class FP32_Layernorm(torch.nn.LayerNorm):
56
+ def forward(self, inputs):
57
+ origin_dtype = inputs.dtype
58
+ return torch.nn.functional.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).to(origin_dtype)
59
+
60
+
61
+ class FP32_SiLU(torch.nn.SiLU):
62
+ def forward(self, inputs):
63
+ origin_dtype = inputs.dtype
64
+ return torch.nn.functional.silu(inputs.float(), inplace=False).to(origin_dtype)
65
+
66
+
67
+ class HunyuanDiTFinalLayer(torch.nn.Module):
68
+ def __init__(self, final_hidden_size=1408, condition_dim=1408, patch_size=2, out_channels=8):
69
+ super().__init__()
70
+ self.norm_final = torch.nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
71
+ self.linear = torch.nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
72
+ self.adaLN_modulation = torch.nn.Sequential(
73
+ FP32_SiLU(),
74
+ torch.nn.Linear(condition_dim, 2 * final_hidden_size, bias=True)
75
+ )
76
+
77
+ def modulate(self, x, shift, scale):
78
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
79
+
80
+ def forward(self, hidden_states, condition_emb):
81
+ shift, scale = self.adaLN_modulation(condition_emb).chunk(2, dim=1)
82
+ hidden_states = self.modulate(self.norm_final(hidden_states), shift, scale)
83
+ hidden_states = self.linear(hidden_states)
84
+ return hidden_states
85
+
86
+
87
+ class HunyuanDiTBlock(torch.nn.Module):
88
+
89
+ def __init__(
90
+ self,
91
+ hidden_dim=1408,
92
+ condition_dim=1408,
93
+ num_heads=16,
94
+ mlp_ratio=4.3637,
95
+ text_dim=1024,
96
+ skip_connection=False
97
+ ):
98
+ super().__init__()
99
+ self.norm1 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
100
+ self.rota1 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads)
101
+ self.attn1 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
102
+ self.norm2 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
103
+ self.rota2 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads, rotary_emb_on_k=False)
104
+ self.attn2 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, kv_dim=text_dim, bias_q=True, bias_kv=True, bias_out=True)
105
+ self.norm3 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
106
+ self.modulation = torch.nn.Sequential(FP32_SiLU(), torch.nn.Linear(condition_dim, hidden_dim, bias=True))
107
+ self.mlp = torch.nn.Sequential(
108
+ torch.nn.Linear(hidden_dim, int(hidden_dim*mlp_ratio), bias=True),
109
+ torch.nn.GELU(approximate="tanh"),
110
+ torch.nn.Linear(int(hidden_dim*mlp_ratio), hidden_dim, bias=True)
111
+ )
112
+ if skip_connection:
113
+ self.skip_norm = FP32_Layernorm((hidden_dim * 2,), eps=1e-6, elementwise_affine=True)
114
+ self.skip_linear = torch.nn.Linear(hidden_dim * 2, hidden_dim, bias=True)
115
+ else:
116
+ self.skip_norm, self.skip_linear = None, None
117
+
118
+ def forward(self, hidden_states, condition_emb, text_emb, freq_cis_img, residual=None, to_cache=False):
119
+ # Long Skip Connection
120
+ if self.skip_norm is not None and self.skip_linear is not None:
121
+ hidden_states = torch.cat([hidden_states, residual], dim=-1)
122
+ hidden_states = self.skip_norm(hidden_states)
123
+ hidden_states = self.skip_linear(hidden_states)
124
+
125
+ # Self-Attention
126
+ shift_msa = self.modulation(condition_emb).unsqueeze(dim=1)
127
+ attn_input = self.norm1(hidden_states) + shift_msa
128
+ hidden_states = hidden_states + self.attn1(attn_input, qkv_preprocessor=lambda q, k, v: self.rota1(q, k, v, freq_cis_img, to_cache=to_cache))
129
+
130
+ # Cross-Attention
131
+ attn_input = self.norm3(hidden_states)
132
+ hidden_states = hidden_states + self.attn2(attn_input, text_emb, qkv_preprocessor=lambda q, k, v: self.rota2(q, k, v, freq_cis_img))
133
+
134
+ # FFN Layer
135
+ mlp_input = self.norm2(hidden_states)
136
+ hidden_states = hidden_states + self.mlp(mlp_input)
137
+ return hidden_states
138
+
139
+
140
+ class AttentionPool(torch.nn.Module):
141
+ def __init__(self, spacial_dim, embed_dim, num_heads, output_dim = None):
142
+ super().__init__()
143
+ self.positional_embedding = torch.nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
144
+ self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
145
+ self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
146
+ self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
147
+ self.c_proj = torch.nn.Linear(embed_dim, output_dim or embed_dim)
148
+ self.num_heads = num_heads
149
+
150
+ def forward(self, x):
151
+ x = x.permute(1, 0, 2) # NLC -> LNC
152
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
153
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
154
+ x, _ = torch.nn.functional.multi_head_attention_forward(
155
+ query=x[:1], key=x, value=x,
156
+ embed_dim_to_check=x.shape[-1],
157
+ num_heads=self.num_heads,
158
+ q_proj_weight=self.q_proj.weight,
159
+ k_proj_weight=self.k_proj.weight,
160
+ v_proj_weight=self.v_proj.weight,
161
+ in_proj_weight=None,
162
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
163
+ bias_k=None,
164
+ bias_v=None,
165
+ add_zero_attn=False,
166
+ dropout_p=0,
167
+ out_proj_weight=self.c_proj.weight,
168
+ out_proj_bias=self.c_proj.bias,
169
+ use_separate_proj_weight=True,
170
+ training=self.training,
171
+ need_weights=False
172
+ )
173
+ return x.squeeze(0)
174
+
175
+
176
+ class PatchEmbed(torch.nn.Module):
177
+ def __init__(
178
+ self,
179
+ patch_size=(2, 2),
180
+ in_chans=4,
181
+ embed_dim=1408,
182
+ bias=True,
183
+ ):
184
+ super().__init__()
185
+ self.proj = torch.nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
186
+
187
+ def forward(self, x):
188
+ x = self.proj(x)
189
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
190
+ return x
191
+
192
+
193
+ def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
194
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
195
+ if not repeat_only:
196
+ half = dim // 2
197
+ freqs = torch.exp(
198
+ -math.log(max_period)
199
+ * torch.arange(start=0, end=half, dtype=torch.float32)
200
+ / half
201
+ ).to(device=t.device) # size: [dim/2], an exponentially decaying curve
202
+ args = t[:, None].float() * freqs[None]
203
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
204
+ if dim % 2:
205
+ embedding = torch.cat(
206
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
207
+ )
208
+ else:
209
+ embedding = repeat(t, "b -> b d", d=dim)
210
+ return embedding
211
+
212
+
213
+ class TimestepEmbedder(torch.nn.Module):
214
+ def __init__(self, hidden_size=1408, frequency_embedding_size=256):
215
+ super().__init__()
216
+ self.mlp = torch.nn.Sequential(
217
+ torch.nn.Linear(frequency_embedding_size, hidden_size, bias=True),
218
+ torch.nn.SiLU(),
219
+ torch.nn.Linear(hidden_size, hidden_size, bias=True),
220
+ )
221
+ self.frequency_embedding_size = frequency_embedding_size
222
+
223
+ def forward(self, t):
224
+ t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
225
+ t_emb = self.mlp(t_freq)
226
+ return t_emb
227
+
228
+
229
+ class HunyuanDiT(torch.nn.Module):
230
+ def __init__(self, num_layers_down=21, num_layers_up=19, in_channels=4, out_channels=8, hidden_dim=1408, text_dim=1024, t5_dim=2048, text_length=77, t5_length=256):
231
+ super().__init__()
232
+
233
+ # Embedders
234
+ self.text_emb_padding = torch.nn.Parameter(torch.randn(text_length + t5_length, text_dim, dtype=torch.float32))
235
+ self.t5_embedder = torch.nn.Sequential(
236
+ torch.nn.Linear(t5_dim, t5_dim * 4, bias=True),
237
+ FP32_SiLU(),
238
+ torch.nn.Linear(t5_dim * 4, text_dim, bias=True),
239
+ )
240
+ self.t5_pooler = AttentionPool(t5_length, t5_dim, num_heads=8, output_dim=1024)
241
+ self.style_embedder = torch.nn.Parameter(torch.randn(hidden_dim))
242
+ self.patch_embedder = PatchEmbed(in_chans=in_channels)
243
+ self.timestep_embedder = TimestepEmbedder()
244
+ self.extra_embedder = torch.nn.Sequential(
245
+ torch.nn.Linear(256 * 6 + 1024 + hidden_dim, hidden_dim * 4),
246
+ FP32_SiLU(),
247
+ torch.nn.Linear(hidden_dim * 4, hidden_dim),
248
+ )
249
+
250
+ # Transformer blocks
251
+ self.num_layers_down = num_layers_down
252
+ self.num_layers_up = num_layers_up
253
+ self.blocks = torch.nn.ModuleList(
254
+ [HunyuanDiTBlock(skip_connection=False) for _ in range(num_layers_down)] + \
255
+ [HunyuanDiTBlock(skip_connection=True) for _ in range(num_layers_up)]
256
+ )
257
+
258
+ # Output layers
259
+ self.final_layer = HunyuanDiTFinalLayer()
260
+ self.out_channels = out_channels
261
+
262
+ def prepare_text_emb(self, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5):
263
+ text_emb_mask = text_emb_mask.bool()
264
+ text_emb_mask_t5 = text_emb_mask_t5.bool()
265
+ text_emb_t5 = self.t5_embedder(text_emb_t5)
266
+ text_emb = torch.cat([text_emb, text_emb_t5], dim=1)
267
+ text_emb_mask = torch.cat([text_emb_mask, text_emb_mask_t5], dim=-1)
268
+ text_emb = torch.where(text_emb_mask.unsqueeze(2), text_emb, self.text_emb_padding.to(text_emb))
269
+ return text_emb
270
+
271
+ def prepare_extra_emb(self, text_emb_t5, timestep, size_emb, dtype, batch_size):
272
+ # Text embedding
273
+ pooled_text_emb_t5 = self.t5_pooler(text_emb_t5)
274
+
275
+ # Timestep embedding
276
+ timestep_emb = self.timestep_embedder(timestep)
277
+
278
+ # Size embedding
279
+ size_emb = timestep_embedding(size_emb.view(-1), 256).to(dtype)
280
+ size_emb = size_emb.view(-1, 6 * 256)
281
+
282
+ # Style embedding
283
+ style_emb = repeat(self.style_embedder, "D -> B D", B=batch_size)
284
+
285
+ # Concatenate all extra vectors
286
+ extra_emb = torch.cat([pooled_text_emb_t5, size_emb, style_emb], dim=1)
287
+ condition_emb = timestep_emb + self.extra_embedder(extra_emb)
288
+
289
+ return condition_emb
290
+
291
+ def unpatchify(self, x, h, w):
292
+ return rearrange(x, "B (H W) (P Q C) -> B C (H P) (W Q)", H=h, W=w, P=2, Q=2)
293
+
294
+ def build_mask(self, data, is_bound):
295
+ _, _, H, W = data.shape
296
+ h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
297
+ w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
298
+ border_width = (H + W) // 4
299
+ pad = torch.ones_like(h) * border_width
300
+ mask = torch.stack([
301
+ pad if is_bound[0] else h + 1,
302
+ pad if is_bound[1] else H - h,
303
+ pad if is_bound[2] else w + 1,
304
+ pad if is_bound[3] else W - w
305
+ ]).min(dim=0).values
306
+ mask = mask.clip(1, border_width)
307
+ mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
308
+ mask = rearrange(mask, "H W -> 1 H W")
309
+ return mask
310
+
311
+ def tiled_block_forward(self, block, hidden_states, condition_emb, text_emb, freq_cis_img, residual, torch_dtype, data_device, computation_device, tile_size, tile_stride):
312
+ B, C, H, W = hidden_states.shape
313
+
314
+ weight = torch.zeros((1, 1, H, W), dtype=torch_dtype, device=data_device)
315
+ values = torch.zeros((B, C, H, W), dtype=torch_dtype, device=data_device)
316
+
317
+ # Split tasks
318
+ tasks = []
319
+ for h in range(0, H, tile_stride):
320
+ for w in range(0, W, tile_stride):
321
+ if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
322
+ continue
323
+ h_, w_ = h + tile_size, w + tile_size
324
+ if h_ > H: h, h_ = H - tile_size, H
325
+ if w_ > W: w, w_ = W - tile_size, W
326
+ tasks.append((h, h_, w, w_))
327
+
328
+ # Run
329
+ for hl, hr, wl, wr in tasks:
330
+ hidden_states_batch = hidden_states[:, :, hl:hr, wl:wr].to(computation_device)
331
+ hidden_states_batch = rearrange(hidden_states_batch, "B C H W -> B (H W) C")
332
+ if residual is not None:
333
+ residual_batch = residual[:, :, hl:hr, wl:wr].to(computation_device)
334
+ residual_batch = rearrange(residual_batch, "B C H W -> B (H W) C")
335
+ else:
336
+ residual_batch = None
337
+
338
+ # Forward
339
+ hidden_states_batch = block(hidden_states_batch, condition_emb, text_emb, freq_cis_img, residual_batch).to(data_device)
340
+ hidden_states_batch = rearrange(hidden_states_batch, "B (H W) C -> B C H W", H=hr-hl)
341
+
342
+ mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
343
+ values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
344
+ weight[:, :, hl:hr, wl:wr] += mask
345
+ values /= weight
346
+ return values
347
+
348
+ def forward(
349
+ self, hidden_states, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5, timestep, size_emb, freq_cis_img,
350
+ tiled=False, tile_size=64, tile_stride=32,
351
+ to_cache=False,
352
+ use_gradient_checkpointing=False,
353
+ ):
354
+ # Embeddings
355
+ text_emb = self.prepare_text_emb(text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5)
356
+ condition_emb = self.prepare_extra_emb(text_emb_t5, timestep, size_emb, hidden_states.dtype, hidden_states.shape[0])
357
+
358
+ # Input
359
+ height, width = hidden_states.shape[-2], hidden_states.shape[-1]
360
+ hidden_states = self.patch_embedder(hidden_states)
361
+
362
+ # Blocks
363
+ def create_custom_forward(module):
364
+ def custom_forward(*inputs):
365
+ return module(*inputs)
366
+ return custom_forward
367
+ if tiled:
368
+ hidden_states = rearrange(hidden_states, "B (H W) C -> B C H W", H=height//2)
369
+ residuals = []
370
+ for block_id, block in enumerate(self.blocks):
371
+ residual = residuals.pop() if block_id >= self.num_layers_down else None
372
+ hidden_states = self.tiled_block_forward(
373
+ block, hidden_states, condition_emb, text_emb, freq_cis_img, residual,
374
+ torch_dtype=hidden_states.dtype, data_device=hidden_states.device, computation_device=hidden_states.device,
375
+ tile_size=tile_size, tile_stride=tile_stride
376
+ )
377
+ if block_id < self.num_layers_down - 2:
378
+ residuals.append(hidden_states)
379
+ hidden_states = rearrange(hidden_states, "B C H W -> B (H W) C")
380
+ else:
381
+ residuals = []
382
+ for block_id, block in enumerate(self.blocks):
383
+ residual = residuals.pop() if block_id >= self.num_layers_down else None
384
+ if self.training and use_gradient_checkpointing:
385
+ hidden_states = torch.utils.checkpoint.checkpoint(
386
+ create_custom_forward(block),
387
+ hidden_states, condition_emb, text_emb, freq_cis_img, residual,
388
+ use_reentrant=False,
389
+ )
390
+ else:
391
+ hidden_states = block(hidden_states, condition_emb, text_emb, freq_cis_img, residual, to_cache=to_cache)
392
+ if block_id < self.num_layers_down - 2:
393
+ residuals.append(hidden_states)
394
+
395
+ # Output
396
+ hidden_states = self.final_layer(hidden_states, condition_emb)
397
+ hidden_states = self.unpatchify(hidden_states, height//2, width//2)
398
+ hidden_states, _ = hidden_states.chunk(2, dim=1)
399
+ return hidden_states
400
+
401
+ @staticmethod
402
+ def state_dict_converter():
403
+ return HunyuanDiTStateDictConverter()
404
+
405
+
406
+
407
+ class HunyuanDiTStateDictConverter():
408
+ def __init__(self):
409
+ pass
410
+
411
+ def from_diffusers(self, state_dict):
412
+ state_dict_ = {}
413
+ for name, param in state_dict.items():
414
+ name_ = name
415
+ name_ = name_.replace(".default_modulation.", ".modulation.")
416
+ name_ = name_.replace(".mlp.fc1.", ".mlp.0.")
417
+ name_ = name_.replace(".mlp.fc2.", ".mlp.2.")
418
+ name_ = name_.replace(".attn1.q_norm.", ".rota1.q_norm.")
419
+ name_ = name_.replace(".attn2.q_norm.", ".rota2.q_norm.")
420
+ name_ = name_.replace(".attn1.k_norm.", ".rota1.k_norm.")
421
+ name_ = name_.replace(".attn2.k_norm.", ".rota2.k_norm.")
422
+ name_ = name_.replace(".q_proj.", ".to_q.")
423
+ name_ = name_.replace(".out_proj.", ".to_out.")
424
+ name_ = name_.replace("text_embedding_padding", "text_emb_padding")
425
+ name_ = name_.replace("mlp_t5.0.", "t5_embedder.0.")
426
+ name_ = name_.replace("mlp_t5.2.", "t5_embedder.2.")
427
+ name_ = name_.replace("pooler.", "t5_pooler.")
428
+ name_ = name_.replace("x_embedder.", "patch_embedder.")
429
+ name_ = name_.replace("t_embedder.", "timestep_embedder.")
430
+ name_ = name_.replace("t5_pooler.to_q.", "t5_pooler.q_proj.")
431
+ name_ = name_.replace("style_embedder.weight", "style_embedder")
432
+ if ".kv_proj." in name_:
433
+ param_k = param[:param.shape[0]//2]
434
+ param_v = param[param.shape[0]//2:]
435
+ state_dict_[name_.replace(".kv_proj.", ".to_k.")] = param_k
436
+ state_dict_[name_.replace(".kv_proj.", ".to_v.")] = param_v
437
+ elif ".Wqkv." in name_:
438
+ param_q = param[:param.shape[0]//3]
439
+ param_k = param[param.shape[0]//3:param.shape[0]//3*2]
440
+ param_v = param[param.shape[0]//3*2:]
441
+ state_dict_[name_.replace(".Wqkv.", ".to_q.")] = param_q
442
+ state_dict_[name_.replace(".Wqkv.", ".to_k.")] = param_k
443
+ state_dict_[name_.replace(".Wqkv.", ".to_v.")] = param_v
444
+ elif "style_embedder" in name_:
445
+ state_dict_[name_] = param.squeeze()
446
+ else:
447
+ state_dict_[name_] = param
448
+ return state_dict_
449
+
450
+ def from_civitai(self, state_dict):
451
+ return self.from_diffusers(state_dict)
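For reference, a minimal sketch of the fused-projection split that from_diffusers performs above; the tensor shapes here are illustrative assumptions rather than values read from a real checkpoint. A fused Wqkv weight is cut into equal thirds for to_q/to_k/to_v, and a fused kv_proj weight is cut into halves for to_k/to_v.

import torch

hidden_dim = 1408                                  # HunyuanDiT hidden size used in this file
wqkv = torch.randn(3 * hidden_dim, hidden_dim)     # fused self-attention projection (illustrative shape)
to_q, to_k, to_v = wqkv.chunk(3, dim=0)            # equivalent to the param[:n//3] slicing above
kv_proj = torch.randn(2 * hidden_dim, 1024)        # fused cross-attention k/v projection (illustrative shape)
to_k_cross, to_v_cross = kv_proj.chunk(2, dim=0)   # equivalent to the halving above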
diffsynth/models/hunyuan_dit_text_encoder.py ADDED
@@ -0,0 +1,163 @@
1
+ from transformers import BertModel, BertConfig, T5EncoderModel, T5Config
2
+ import torch
3
+
4
+
5
+
6
+ class HunyuanDiTCLIPTextEncoder(BertModel):
7
+ def __init__(self):
8
+ config = BertConfig(
9
+ _name_or_path = "",
10
+ architectures = ["BertModel"],
11
+ attention_probs_dropout_prob = 0.1,
12
+ bos_token_id = 0,
13
+ classifier_dropout = None,
14
+ directionality = "bidi",
15
+ eos_token_id = 2,
16
+ hidden_act = "gelu",
17
+ hidden_dropout_prob = 0.1,
18
+ hidden_size = 1024,
19
+ initializer_range = 0.02,
20
+ intermediate_size = 4096,
21
+ layer_norm_eps = 1e-12,
22
+ max_position_embeddings = 512,
23
+ model_type = "bert",
24
+ num_attention_heads = 16,
25
+ num_hidden_layers = 24,
26
+ output_past = True,
27
+ pad_token_id = 0,
28
+ pooler_fc_size = 768,
29
+ pooler_num_attention_heads = 12,
30
+ pooler_num_fc_layers = 3,
31
+ pooler_size_per_head = 128,
32
+ pooler_type = "first_token_transform",
33
+ position_embedding_type = "absolute",
34
+ torch_dtype = "float32",
35
+ transformers_version = "4.37.2",
36
+ type_vocab_size = 2,
37
+ use_cache = True,
38
+ vocab_size = 47020
39
+ )
40
+ super().__init__(config, add_pooling_layer=False)
41
+ self.eval()
42
+
43
+ def forward(self, input_ids, attention_mask, clip_skip=1):
44
+ input_shape = input_ids.size()
45
+
46
+ batch_size, seq_length = input_shape
47
+ device = input_ids.device
48
+
49
+ past_key_values_length = 0
50
+
51
+ if attention_mask is None:
52
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
53
+
54
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
55
+
56
+ embedding_output = self.embeddings(
57
+ input_ids=input_ids,
58
+ position_ids=None,
59
+ token_type_ids=None,
60
+ inputs_embeds=None,
61
+ past_key_values_length=0,
62
+ )
63
+ encoder_outputs = self.encoder(
64
+ embedding_output,
65
+ attention_mask=extended_attention_mask,
66
+ head_mask=None,
67
+ encoder_hidden_states=None,
68
+ encoder_attention_mask=None,
69
+ past_key_values=None,
70
+ use_cache=False,
71
+ output_attentions=False,
72
+ output_hidden_states=True,
73
+ return_dict=True,
74
+ )
75
+ all_hidden_states = encoder_outputs.hidden_states
76
+ prompt_emb = all_hidden_states[-clip_skip]
77
+ if clip_skip > 1:
78
+ mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std()
79
+ prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
80
+ return prompt_emb
81
+
82
+ @staticmethod
83
+ def state_dict_converter():
84
+ return HunyuanDiTCLIPTextEncoderStateDictConverter()
85
+
86
+
87
+
88
+ class HunyuanDiTT5TextEncoder(T5EncoderModel):
89
+ def __init__(self):
90
+ config = T5Config(
91
+ _name_or_path = "../HunyuanDiT/t2i/mt5",
92
+ architectures = ["MT5ForConditionalGeneration"],
93
+ classifier_dropout = 0.0,
94
+ d_ff = 5120,
95
+ d_kv = 64,
96
+ d_model = 2048,
97
+ decoder_start_token_id = 0,
98
+ dense_act_fn = "gelu_new",
99
+ dropout_rate = 0.1,
100
+ eos_token_id = 1,
101
+ feed_forward_proj = "gated-gelu",
102
+ initializer_factor = 1.0,
103
+ is_encoder_decoder = True,
104
+ is_gated_act = True,
105
+ layer_norm_epsilon = 1e-06,
106
+ model_type = "t5",
107
+ num_decoder_layers = 24,
108
+ num_heads = 32,
109
+ num_layers = 24,
110
+ output_past = True,
111
+ pad_token_id = 0,
112
+ relative_attention_max_distance = 128,
113
+ relative_attention_num_buckets = 32,
114
+ tie_word_embeddings = False,
115
+ tokenizer_class = "T5Tokenizer",
116
+ transformers_version = "4.37.2",
117
+ use_cache = True,
118
+ vocab_size = 250112
119
+ )
120
+ super().__init__(config)
121
+ self.eval()
122
+
123
+ def forward(self, input_ids, attention_mask, clip_skip=1):
124
+ outputs = super().forward(
125
+ input_ids=input_ids,
126
+ attention_mask=attention_mask,
127
+ output_hidden_states=True,
128
+ )
129
+ prompt_emb = outputs.hidden_states[-clip_skip]
130
+ if clip_skip > 1:
131
+ mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std()
132
+ prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
133
+ return prompt_emb
134
+
135
+ @staticmethod
136
+ def state_dict_converter():
137
+ return HunyuanDiTT5TextEncoderStateDictConverter()
138
+
139
+
140
+
141
+ class HunyuanDiTCLIPTextEncoderStateDictConverter():
142
+ def __init__(self):
143
+ pass
144
+
145
+ def from_diffusers(self, state_dict):
146
+ state_dict_ = {name[5:]: param for name, param in state_dict.items() if name.startswith("bert.")}
147
+ return state_dict_
148
+
149
+ def from_civitai(self, state_dict):
150
+ return self.from_diffusers(state_dict)
151
+
152
+
153
+ class HunyuanDiTT5TextEncoderStateDictConverter():
154
+ def __init__(self):
155
+ pass
156
+
157
+ def from_diffusers(self, state_dict):
158
+ state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("encoder.")}
159
+ state_dict_["shared.weight"] = state_dict["shared.weight"]
160
+ return state_dict_
161
+
162
+ def from_civitai(self, state_dict):
163
+ return self.from_diffusers(state_dict)
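A minimal usage sketch of the CLIP-style text encoder defined above, assuming the class is importable and that a compatible BERT tokenizer is available locally; the tokenizer path below is a placeholder, since the real pipeline wires up the matching tokenizer config elsewhere.

import torch
from transformers import BertTokenizer

# Placeholder tokenizer path; the real pipeline loads the matching tokenizer config elsewhere.
tokenizer = BertTokenizer.from_pretrained("path/to/hunyuan_clip_tokenizer")
text_encoder = HunyuanDiTCLIPTextEncoder()

tokens = tokenizer("a photo of a cat", padding="max_length", max_length=77, return_tensors="pt")
with torch.no_grad():
    prompt_emb = text_encoder(tokens["input_ids"], tokens["attention_mask"], clip_skip=2)
# prompt_emb has shape (batch, sequence_length, 1024); with clip_skip > 1 the forward pass
# takes an earlier hidden layer and rescales it to the statistics of the final layer,
# as implemented in the forward method above.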
diffsynth/models/kolors_text_encoder.py ADDED
@@ -0,0 +1,1552 @@
1
+ """
2
+ This model is copied from https://github.com/Kwai-Kolors/Kolors/tree/master/kolors/models.
3
+ We didn't modify this model.
4
+ The tensor operation is performed in the prompter.
5
+ """
6
+
7
+
8
+ """ PyTorch ChatGLM model. """
9
+
10
+ import math
11
+ import copy
12
+ import warnings
13
+ import re
14
+ import sys
15
+
16
+ import torch
17
+ import torch.utils.checkpoint
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+ from torch.nn import CrossEntropyLoss, LayerNorm
21
+ from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
22
+ from torch.nn.utils import skip_init
23
+ from typing import Optional, Tuple, Union, List, Callable, Dict, Any
24
+ from copy import deepcopy
25
+
26
+ from transformers.modeling_outputs import (
27
+ BaseModelOutputWithPast,
28
+ CausalLMOutputWithPast,
29
+ SequenceClassifierOutputWithPast,
30
+ )
31
+ from transformers.modeling_utils import PreTrainedModel
32
+ from transformers.utils import logging
33
+ from transformers.generation.logits_process import LogitsProcessor
34
+ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
35
+ from transformers import PretrainedConfig
36
+ from torch.nn.parameter import Parameter
37
+ import bz2
38
+ import torch
39
+ import base64
40
+ import ctypes
41
+ from transformers.utils import logging
42
+ from typing import List
43
+
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+ try:
49
+ from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
50
+
51
+
52
+ class Kernel:
53
+ def __init__(self, code: bytes, function_names: List[str]):
54
+ self.code = code
55
+ self._function_names = function_names
56
+ self._cmodule = LazyKernelCModule(self.code)
57
+
58
+ for name in self._function_names:
59
+ setattr(self, name, KernelFunction(self._cmodule, name))
60
+
61
+
62
+ quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+t
e+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyH
pix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ"
63
+
64
+ kernels = Kernel(
65
+ bz2.decompress(base64.b64decode(quantization_code)),
66
+ [
67
+ "int4WeightCompression",
68
+ "int4WeightExtractionFloat",
69
+ "int4WeightExtractionHalf",
70
+ "int8WeightExtractionFloat",
71
+ "int8WeightExtractionHalf",
72
+ ],
73
+ )
74
+ except Exception as exception:
75
+ kernels = None
76
+ logger.warning("Failed to load cpm_kernels:" + str(exception))
77
+
78
+
79
+ class W8A16Linear(torch.autograd.Function):
80
+ @staticmethod
81
+ def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
82
+ ctx.inp_shape = inp.size()
83
+ ctx.weight_bit_width = weight_bit_width
84
+ out_features = quant_w.size(0)
85
+ inp = inp.contiguous().view(-1, inp.size(-1))
86
+ weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
87
+ ctx.weight_shape = weight.size()
88
+ output = inp.mm(weight.t())
89
+ ctx.save_for_backward(inp, quant_w, scale_w)
90
+ return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
91
+
92
+ @staticmethod
93
+ def backward(ctx, grad_output: torch.Tensor):
94
+ inp, quant_w, scale_w = ctx.saved_tensors
95
+ weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)
96
+ grad_output = grad_output.contiguous().view(-1, weight.size(0))
97
+ grad_input = grad_output.mm(weight)
98
+ grad_weight = grad_output.t().mm(inp)
99
+ return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
100
+
101
+
102
+ def compress_int4_weight(weight: torch.Tensor): # (n, m)
103
+ with torch.cuda.device(weight.device):
104
+ n, m = weight.size(0), weight.size(1)
105
+ assert m % 2 == 0
106
+ m = m // 2
107
+ out = torch.empty(n, m, dtype=torch.int8, device="cuda")
108
+ stream = torch.cuda.current_stream()
109
+
110
+ gridDim = (n, 1, 1)
111
+ blockDim = (min(round_up(m, 32), 1024), 1, 1)
112
+
113
+ kernels.int4WeightCompression(
114
+ gridDim,
115
+ blockDim,
116
+ 0,
117
+ stream,
118
+ [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
119
+ )
120
+ return out
121
+
122
+
123
+ def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
124
+ assert scale_list.dtype in [torch.half, torch.bfloat16]
125
+ assert weight.dtype in [torch.int8]
126
+ if source_bit_width == 8:
127
+ return weight.to(scale_list.dtype) * scale_list[:, None]
128
+ elif source_bit_width == 4:
129
+ func = (
130
+ kernels.int4WeightExtractionHalf if scale_list.dtype == torch.half else kernels.int4WeightExtractionBFloat16
131
+ )
132
+ else:
133
+ assert False, "Unsupported bit-width"
134
+
135
+ with torch.cuda.device(weight.device):
136
+ n, m = weight.size(0), weight.size(1)
137
+ out = torch.empty(n, m * (8 // source_bit_width), dtype=scale_list.dtype, device="cuda")
138
+ stream = torch.cuda.current_stream()
139
+
140
+ gridDim = (n, 1, 1)
141
+ blockDim = (min(round_up(m, 32), 1024), 1, 1)
142
+
143
+ func(
144
+ gridDim,
145
+ blockDim,
146
+ 0,
147
+ stream,
148
+ [
149
+ ctypes.c_void_p(weight.data_ptr()),
150
+ ctypes.c_void_p(scale_list.data_ptr()),
151
+ ctypes.c_void_p(out.data_ptr()),
152
+ ctypes.c_int32(n),
153
+ ctypes.c_int32(m),
154
+ ],
155
+ )
156
+ return out
157
+
158
+
159
+ class QuantizedLinear(torch.nn.Module):
160
+ def __init__(self, weight_bit_width: int, weight, bias=None, device="cuda", dtype=None, empty_init=False):
161
+ super().__init__()
162
+ weight = weight.to(device) # ensure the weight is on the cuda device
163
+ assert str(weight.device).startswith(
164
+ 'cuda'), 'The weights that need to be quantified should be on the CUDA device'
165
+ self.weight_bit_width = weight_bit_width
166
+ shape = weight.shape
167
+
168
+ if weight is None or empty_init:
169
+ self.weight = torch.empty(shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=device)
170
+ self.weight_scale = torch.empty(shape[0], dtype=dtype, device=device)
171
+ else:
172
+ self.weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)
173
+ self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
174
+ if weight_bit_width == 4:
175
+ self.weight = compress_int4_weight(self.weight)
176
+
177
+ self.weight = Parameter(self.weight.to(device), requires_grad=False)
178
+ self.weight_scale = Parameter(self.weight_scale.to(device), requires_grad=False)
179
+ self.bias = Parameter(bias.to(device), requires_grad=False) if bias is not None else None
180
+
181
+ def forward(self, input):
182
+ output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width)
183
+ if self.bias is not None:
184
+ output = output + self.bias
185
+ return output
186
+
187
+
188
+ def quantize(model, weight_bit_width, empty_init=False, device=None):
189
+ """Replace fp16 linear with quantized linear"""
190
+ for layer in model.layers:
191
+ layer.self_attention.query_key_value = QuantizedLinear(
192
+ weight_bit_width=weight_bit_width,
193
+ weight=layer.self_attention.query_key_value.weight,
194
+ bias=layer.self_attention.query_key_value.bias,
195
+ dtype=layer.self_attention.query_key_value.weight.dtype,
196
+ device=layer.self_attention.query_key_value.weight.device if device is None else device,
197
+ empty_init=empty_init
198
+ )
199
+ layer.self_attention.dense = QuantizedLinear(
200
+ weight_bit_width=weight_bit_width,
201
+ weight=layer.self_attention.dense.weight,
202
+ bias=layer.self_attention.dense.bias,
203
+ dtype=layer.self_attention.dense.weight.dtype,
204
+ device=layer.self_attention.dense.weight.device if device is None else device,
205
+ empty_init=empty_init
206
+ )
207
+ layer.mlp.dense_h_to_4h = QuantizedLinear(
208
+ weight_bit_width=weight_bit_width,
209
+ weight=layer.mlp.dense_h_to_4h.weight,
210
+ bias=layer.mlp.dense_h_to_4h.bias,
211
+ dtype=layer.mlp.dense_h_to_4h.weight.dtype,
212
+ device=layer.mlp.dense_h_to_4h.weight.device if device is None else device,
213
+ empty_init=empty_init
214
+ )
215
+ layer.mlp.dense_4h_to_h = QuantizedLinear(
216
+ weight_bit_width=weight_bit_width,
217
+ weight=layer.mlp.dense_4h_to_h.weight,
218
+ bias=layer.mlp.dense_4h_to_h.bias,
219
+ dtype=layer.mlp.dense_4h_to_h.weight.dtype,
220
+ device=layer.mlp.dense_4h_to_h.weight.device if device is None else device,
221
+ empty_init=empty_init
222
+ )
223
+
224
+ return model
225
+
226
+
227
+
228
+ class ChatGLMConfig(PretrainedConfig):
229
+ model_type = "chatglm"
230
+ def __init__(
231
+ self,
232
+ num_layers=28,
233
+ padded_vocab_size=65024,
234
+ hidden_size=4096,
235
+ ffn_hidden_size=13696,
236
+ kv_channels=128,
237
+ num_attention_heads=32,
238
+ seq_length=2048,
239
+ hidden_dropout=0.0,
240
+ classifier_dropout=None,
241
+ attention_dropout=0.0,
242
+ layernorm_epsilon=1e-5,
243
+ rmsnorm=True,
244
+ apply_residual_connection_post_layernorm=False,
245
+ post_layer_norm=True,
246
+ add_bias_linear=False,
247
+ add_qkv_bias=False,
248
+ bias_dropout_fusion=True,
249
+ multi_query_attention=False,
250
+ multi_query_group_num=1,
251
+ apply_query_key_layer_scaling=True,
252
+ attention_softmax_in_fp32=True,
253
+ fp32_residual_connection=False,
254
+ quantization_bit=0,
255
+ pre_seq_len=None,
256
+ prefix_projection=False,
257
+ **kwargs
258
+ ):
259
+ self.num_layers = num_layers
260
+ self.vocab_size = padded_vocab_size
261
+ self.padded_vocab_size = padded_vocab_size
262
+ self.hidden_size = hidden_size
263
+ self.ffn_hidden_size = ffn_hidden_size
264
+ self.kv_channels = kv_channels
265
+ self.num_attention_heads = num_attention_heads
266
+ self.seq_length = seq_length
267
+ self.hidden_dropout = hidden_dropout
268
+ self.classifier_dropout = classifier_dropout
269
+ self.attention_dropout = attention_dropout
270
+ self.layernorm_epsilon = layernorm_epsilon
271
+ self.rmsnorm = rmsnorm
272
+ self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
273
+ self.post_layer_norm = post_layer_norm
274
+ self.add_bias_linear = add_bias_linear
275
+ self.add_qkv_bias = add_qkv_bias
276
+ self.bias_dropout_fusion = bias_dropout_fusion
277
+ self.multi_query_attention = multi_query_attention
278
+ self.multi_query_group_num = multi_query_group_num
279
+ self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
280
+ self.attention_softmax_in_fp32 = attention_softmax_in_fp32
281
+ self.fp32_residual_connection = fp32_residual_connection
282
+ self.quantization_bit = quantization_bit
283
+ self.pre_seq_len = pre_seq_len
284
+ self.prefix_projection = prefix_projection
285
+ super().__init__(**kwargs)
286
+
287
+
288
+
289
+ # flags required to enable jit fusion kernels
290
+
291
+ if sys.platform != 'darwin':
292
+ torch._C._jit_set_profiling_mode(False)
293
+ torch._C._jit_set_profiling_executor(False)
294
+ torch._C._jit_override_can_fuse_on_cpu(True)
295
+ torch._C._jit_override_can_fuse_on_gpu(True)
296
+
297
+ logger = logging.get_logger(__name__)
298
+
299
+ _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
300
+ _CONFIG_FOR_DOC = "ChatGLM6BConfig"
301
+
302
+ CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
303
+ "THUDM/chatglm3-6b-base",
304
+ # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
305
+ ]
306
+
307
+
308
+ def default_init(cls, *args, **kwargs):
309
+ return cls(*args, **kwargs)
310
+
311
+
312
+ class InvalidScoreLogitsProcessor(LogitsProcessor):
313
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
314
+ if torch.isnan(scores).any() or torch.isinf(scores).any():
315
+ scores.zero_()
316
+ scores[..., 5] = 5e4
317
+ return scores
318
+
319
+
320
+ class PrefixEncoder(torch.nn.Module):
321
+ """
322
+ The torch.nn model to encode the prefix
323
+ Input shape: (batch-size, prefix-length)
324
+ Output shape: (batch-size, prefix-length, 2*layers*hidden)
325
+ """
326
+
327
+ def __init__(self, config: ChatGLMConfig):
328
+ super().__init__()
329
+ self.prefix_projection = config.prefix_projection
330
+ if self.prefix_projection:
331
+ # Use a two-layer MLP to encode the prefix
332
+ kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
333
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
334
+ self.trans = torch.nn.Sequential(
335
+ torch.nn.Linear(kv_size, config.hidden_size),
336
+ torch.nn.Tanh(),
337
+ torch.nn.Linear(config.hidden_size, kv_size)
338
+ )
339
+ else:
340
+ self.embedding = torch.nn.Embedding(config.pre_seq_len,
341
+ config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
342
+
343
+ def forward(self, prefix: torch.Tensor):
344
+ if self.prefix_projection:
345
+ prefix_tokens = self.embedding(prefix)
346
+ past_key_values = self.trans(prefix_tokens)
347
+ else:
348
+ past_key_values = self.embedding(prefix)
349
+ return past_key_values
350
+
351
+
352
+ def split_tensor_along_last_dim(
353
+ tensor: torch.Tensor,
354
+ num_partitions: int,
355
+ contiguous_split_chunks: bool = False,
356
+ ) -> List[torch.Tensor]:
357
+ """Split a tensor along its last dimension.
358
+
359
+ Arguments:
360
+ tensor: input tensor.
361
+ num_partitions: number of partitions to split the tensor
362
+ contiguous_split_chunks: If True, make each chunk contiguous
363
+ in memory.
364
+
365
+ Returns:
366
+ A list of Tensors
367
+ """
368
+ # Get the size and dimension.
369
+ last_dim = tensor.dim() - 1
370
+ last_dim_size = tensor.size()[last_dim] // num_partitions
371
+ # Split.
372
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
373
+ # Note: torch.split does not create contiguous tensors by default.
374
+ if contiguous_split_chunks:
375
+ return tuple(chunk.contiguous() for chunk in tensor_list)
376
+
377
+ return tensor_list
378
+
379
+
380
+ class RotaryEmbedding(nn.Module):
381
+ def __init__(self, dim, original_impl=False, device=None, dtype=None):
382
+ super().__init__()
383
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
384
+ self.register_buffer("inv_freq", inv_freq)
385
+ self.dim = dim
386
+ self.original_impl = original_impl
387
+
388
+ def forward_impl(
389
+ self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
390
+ ):
391
+ """Enhanced Transformer with Rotary Position Embedding.
392
+
393
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
394
+ transformers/rope/__init__.py. MIT License:
395
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
396
+ """
397
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
398
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
399
+
400
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
401
+ seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)
402
+
403
+ # Calculate the product of position index and $\theta_i$
404
+ idx_theta = torch.outer(seq_idx, theta).float()
405
+
406
+ cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
407
+
408
+ # this is to mimic the behaviour of complex32, else we will get different results
409
+ if dtype in (torch.float16, torch.bfloat16, torch.int8):
410
+ cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
411
+ return cache
412
+
413
+ def forward(self, max_seq_len, offset=0):
414
+ return self.forward_impl(
415
+ max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
416
+ )
417
+
418
+
419
+ @torch.jit.script
420
+ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
421
+ # x: [sq, b, np, hn]
422
+ sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
423
+ rot_dim = rope_cache.shape[-2] * 2
424
+ x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
425
+ # truncate to support variable sizes
426
+ rope_cache = rope_cache[:sq]
427
+ xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
428
+ rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
429
+ x_out2 = torch.stack(
430
+ [
431
+ xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
432
+ xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
433
+ ],
434
+ -1,
435
+ )
436
+ x_out2 = x_out2.flatten(3)
437
+ return torch.cat((x_out2, x_pass), dim=-1)
438
+
439
+
440
+ class RMSNorm(torch.nn.Module):
441
+ def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
442
+ super().__init__()
443
+ self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
444
+ self.eps = eps
445
+
446
+ def forward(self, hidden_states: torch.Tensor):
447
+ input_dtype = hidden_states.dtype
448
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
449
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
450
+
451
+ return (self.weight * hidden_states).to(input_dtype)
452
+
453
+
454
+ class CoreAttention(torch.nn.Module):
455
+ def __init__(self, config: ChatGLMConfig, layer_number):
456
+ super(CoreAttention, self).__init__()
457
+
458
+ self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
459
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
460
+ if self.apply_query_key_layer_scaling:
461
+ self.attention_softmax_in_fp32 = True
462
+ self.layer_number = max(1, layer_number)
463
+
464
+ projection_size = config.kv_channels * config.num_attention_heads
465
+
466
+ # Per attention head and per partition values.
467
+ self.hidden_size_per_partition = projection_size
468
+ self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
469
+ self.num_attention_heads_per_partition = config.num_attention_heads
470
+
471
+ coeff = None
472
+ self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
473
+ if self.apply_query_key_layer_scaling:
474
+ coeff = self.layer_number
475
+ self.norm_factor *= coeff
476
+ self.coeff = coeff
477
+
478
+ self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
479
+
480
+ def forward(self, query_layer, key_layer, value_layer, attention_mask):
481
+ pytorch_major_version = int(torch.__version__.split('.')[0])
482
+ if pytorch_major_version >= 2:
483
+ query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
484
+ if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
485
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
486
+ is_causal=True)
487
+ else:
488
+ if attention_mask is not None:
489
+ attention_mask = ~attention_mask
490
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
491
+ attention_mask)
492
+ context_layer = context_layer.permute(2, 0, 1, 3)
493
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
494
+ context_layer = context_layer.reshape(*new_context_layer_shape)
495
+ else:
496
+ # Raw attention scores
497
+
498
+ # [b, np, sq, sk]
499
+ output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
500
+
501
+ # [sq, b, np, hn] -> [sq, b * np, hn]
502
+ query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
503
+ # [sk, b, np, hn] -> [sk, b * np, hn]
504
+ key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
505
+
506
+ # preallocating input tensor: [b * np, sq, sk]
507
+ matmul_input_buffer = torch.empty(
508
+ output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
509
+ device=query_layer.device
510
+ )
511
+
512
+ # Raw attention scores. [b * np, sq, sk]
513
+ matmul_result = torch.baddbmm(
514
+ matmul_input_buffer,
515
+ query_layer.transpose(0, 1), # [b * np, sq, hn]
516
+ key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
517
+ beta=0.0,
518
+ alpha=(1.0 / self.norm_factor),
519
+ )
520
+
521
+ # change view to [b, np, sq, sk]
522
+ attention_scores = matmul_result.view(*output_size)
523
+
524
+ # ===========================
525
+ # Attention probs and dropout
526
+ # ===========================
527
+
528
+ # attention scores and attention mask [b, np, sq, sk]
529
+ if self.attention_softmax_in_fp32:
530
+ attention_scores = attention_scores.float()
531
+ if self.coeff is not None:
532
+ attention_scores = attention_scores * self.coeff
533
+ if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
534
+ attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
535
+ device=attention_scores.device, dtype=torch.bool)
536
+ attention_mask.tril_()
537
+ attention_mask = ~attention_mask
538
+ if attention_mask is not None:
539
+ attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
540
+ attention_probs = F.softmax(attention_scores, dim=-1)
541
+ attention_probs = attention_probs.type_as(value_layer)
542
+
543
+ # This is actually dropping out entire tokens to attend to, which might
544
+ # seem a bit unusual, but is taken from the original Transformer paper.
545
+ attention_probs = self.attention_dropout(attention_probs)
546
+ # =========================
547
+ # Context layer. [sq, b, hp]
548
+ # =========================
549
+
550
+ # value_layer -> context layer.
551
+ # [sk, b, np, hn] --> [b, np, sq, hn]
552
+
553
+ # context layer shape: [b, np, sq, hn]
554
+ output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
555
+ # change view [sk, b * np, hn]
556
+ value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
557
+ # change view [b * np, sq, sk]
558
+ attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
559
+ # matmul: [b * np, sq, hn]
560
+ context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
561
+ # change view [b, np, sq, hn]
562
+ context_layer = context_layer.view(*output_size)
563
+ # [b, np, sq, hn] --> [sq, b, np, hn]
564
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
565
+ # [sq, b, np, hn] --> [sq, b, hp]
566
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
567
+ context_layer = context_layer.view(*new_context_layer_shape)
568
+
569
+ return context_layer
570
+
571
+
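The two branches of CoreAttention.forward implement the same computation. A minimal sketch (not part of the diff, illustrative sizes, independent of the class) showing that the fused kernel and the manual scale/mask/softmax path agree on a causal example:

import math
import torch
import torch.nn.functional as F

b, heads, sq, hn = 1, 2, 4, 8
q, k, v = (torch.randn(b, heads, sq, hn) for _ in range(3))

# Fused path used when the PyTorch major version is >= 2.
out_fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Manual path: scaled scores, causal mask, softmax, weighted sum.
scores = q @ k.transpose(-2, -1) / math.sqrt(hn)
scores = scores.masked_fill(~torch.ones(sq, sq, dtype=torch.bool).tril(), float("-inf"))
out_manual = F.softmax(scores, dim=-1) @ v

assert torch.allclose(out_fused, out_manual, atol=1e-5)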
572
+ class SelfAttention(torch.nn.Module):
573
+ """Parallel self-attention layer abstract class.
574
+
575
+ Self-attention layer takes input with size [s, b, h]
576
+ and returns output of the same size.
577
+ """
578
+
579
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
580
+ super(SelfAttention, self).__init__()
581
+ self.layer_number = max(1, layer_number)
582
+
583
+ self.projection_size = config.kv_channels * config.num_attention_heads
584
+
585
+ # Per attention head and per partition values.
586
+ self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
587
+ self.num_attention_heads_per_partition = config.num_attention_heads
588
+
589
+ self.multi_query_attention = config.multi_query_attention
590
+ self.qkv_hidden_size = 3 * self.projection_size
591
+ if self.multi_query_attention:
592
+ self.num_multi_query_groups_per_partition = config.multi_query_group_num
593
+ self.qkv_hidden_size = (
594
+ self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
595
+ )
596
+ self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
597
+ bias=config.add_bias_linear or config.add_qkv_bias,
598
+ device=device, **_config_to_kwargs(config)
599
+ )
600
+
601
+ self.core_attention = CoreAttention(config, self.layer_number)
602
+
603
+ # Output.
604
+ self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
605
+ device=device, **_config_to_kwargs(config)
606
+ )
607
+
608
+ def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
609
+ if self.multi_query_attention:
610
+ num_attention_heads = self.num_multi_query_groups_per_partition
611
+ else:
612
+ num_attention_heads = self.num_attention_heads_per_partition
613
+ return torch.empty(
614
+ inference_max_sequence_len,
615
+ batch_size,
616
+ num_attention_heads,
617
+ self.hidden_size_per_attention_head,
618
+ dtype=dtype,
619
+ device=device,
620
+ )
621
+
622
+ def forward(
623
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
624
+ ):
625
+ # hidden_states: [sq, b, h]
626
+
627
+ # =================================================
628
+ # Pre-allocate memory for key-values for inference.
629
+ # =================================================
630
+ # =====================
631
+ # Query, Key, and Value
632
+ # =====================
633
+
634
+ # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
635
+ mixed_x_layer = self.query_key_value(hidden_states)
636
+
637
+ if self.multi_query_attention:
638
+ (query_layer, key_layer, value_layer) = mixed_x_layer.split(
639
+ [
640
+ self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
641
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
642
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
643
+ ],
644
+ dim=-1,
645
+ )
646
+ query_layer = query_layer.view(
647
+ query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
648
+ )
649
+ key_layer = key_layer.view(
650
+ key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
651
+ )
652
+ value_layer = value_layer.view(
653
+ value_layer.size()[:-1]
654
+ + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
655
+ )
656
+ else:
657
+ new_tensor_shape = mixed_x_layer.size()[:-1] + \
658
+ (self.num_attention_heads_per_partition,
659
+ 3 * self.hidden_size_per_attention_head)
660
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
661
+
662
+ # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
663
+ (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
664
+
665
+ # apply relative positional encoding (rotary embedding)
666
+ if rotary_pos_emb is not None:
667
+ query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
668
+ key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
669
+
670
+ # adjust key and value for inference
671
+ if kv_cache is not None:
672
+ cache_k, cache_v = kv_cache
673
+ key_layer = torch.cat((cache_k, key_layer), dim=0)
674
+ value_layer = torch.cat((cache_v, value_layer), dim=0)
675
+ if use_cache:
676
+ kv_cache = (key_layer, value_layer)
677
+ else:
678
+ kv_cache = None
679
+
680
+ if self.multi_query_attention:
681
+ key_layer = key_layer.unsqueeze(-2)
682
+ key_layer = key_layer.expand(
683
+ -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
684
+ )
685
+ key_layer = key_layer.contiguous().view(
686
+ key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
687
+ )
688
+ value_layer = value_layer.unsqueeze(-2)
689
+ value_layer = value_layer.expand(
690
+ -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
691
+ )
692
+ value_layer = value_layer.contiguous().view(
693
+ value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
694
+ )
695
+
696
+ # ==================================
697
+ # core attention computation
698
+ # ==================================
699
+
700
+ context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
701
+
702
+ # =================
703
+ # Output. [sq, b, h]
704
+ # =================
705
+
706
+ output = self.dense(context_layer)
707
+
708
+ return output, kv_cache
709
+
710
+
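A minimal sketch (not part of the diff, illustrative sizes) of the multi-query expansion in SelfAttention.forward: each of the few key/value heads is repeated so that every query head has a matching KV head before core attention runs:

import torch

sq, b, n_query_heads, n_kv_groups, hn = 6, 1, 8, 2, 16
key = torch.randn(sq, b, n_kv_groups, hn)
key = key.unsqueeze(-2).expand(-1, -1, -1, n_query_heads // n_kv_groups, -1)
key = key.contiguous().view(sq, b, n_query_heads, hn)   # one KV head per query head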
711
+ def _config_to_kwargs(args):
712
+ common_kwargs = {
713
+ "dtype": args.torch_dtype,
714
+ }
715
+ return common_kwargs
716
+
717
+
718
+ class MLP(torch.nn.Module):
719
+ """MLP.
720
+
721
+ MLP will take the input with h hidden state, project it to 4*h
722
+ hidden dimension, perform nonlinear transformation, and project the
723
+ state back into h hidden dimension.
724
+ """
725
+
726
+ def __init__(self, config: ChatGLMConfig, device=None):
727
+ super(MLP, self).__init__()
728
+
729
+ self.add_bias = config.add_bias_linear
730
+
731
+ # Project to 4h. If using swiglu, double the output width; see https://arxiv.org/pdf/2002.05202.pdf
732
+ self.dense_h_to_4h = nn.Linear(
733
+ config.hidden_size,
734
+ config.ffn_hidden_size * 2,
735
+ bias=self.add_bias,
736
+ device=device,
737
+ **_config_to_kwargs(config)
738
+ )
739
+
740
+ def swiglu(x):
741
+ x = torch.chunk(x, 2, dim=-1)
742
+ return F.silu(x[0]) * x[1]
743
+
744
+ self.activation_func = swiglu
745
+
746
+ # Project back to h.
747
+ self.dense_4h_to_h = nn.Linear(
748
+ config.ffn_hidden_size,
749
+ config.hidden_size,
750
+ bias=self.add_bias,
751
+ device=device,
752
+ **_config_to_kwargs(config)
753
+ )
754
+
755
+ def forward(self, hidden_states):
756
+ # [s, b, 4hp]
757
+ intermediate_parallel = self.dense_h_to_4h(hidden_states)
758
+ intermediate_parallel = self.activation_func(intermediate_parallel)
759
+ # [s, b, h]
760
+ output = self.dense_4h_to_h(intermediate_parallel)
761
+ return output
762
+
763
+
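A minimal sketch (not part of the diff, illustrative sizes) of the SwiGLU activation used by MLP: the doubled dense_h_to_4h output is split in half and one half gates the other:

import torch
import torch.nn.functional as F

x = torch.randn(3, 2 * 5)               # stands in for the dense_h_to_4h output
gate, value = torch.chunk(x, 2, dim=-1)
out = F.silu(gate) * value              # shape (3, 5), then projected back by dense_4h_to_h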
764
+ class GLMBlock(torch.nn.Module):
765
+ """A single transformer layer.
766
+
767
+ Transformer layer takes input with size [s, b, h] and returns an
768
+ output of the same size.
769
+ """
770
+
771
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
772
+ super(GLMBlock, self).__init__()
773
+ self.layer_number = layer_number
774
+
775
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
776
+
777
+ self.fp32_residual_connection = config.fp32_residual_connection
778
+
779
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
780
+ # Layernorm on the input data.
781
+ self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
782
+ dtype=config.torch_dtype)
783
+
784
+ # Self attention.
785
+ self.self_attention = SelfAttention(config, layer_number, device=device)
786
+ self.hidden_dropout = config.hidden_dropout
787
+
788
+ # Layernorm on the attention output
789
+ self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
790
+ dtype=config.torch_dtype)
791
+
792
+ # MLP
793
+ self.mlp = MLP(config, device=device)
794
+
795
+ def forward(
796
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
797
+ ):
798
+ # hidden_states: [s, b, h]
799
+
800
+ # Layer norm at the beginning of the transformer layer.
801
+ layernorm_output = self.input_layernorm(hidden_states)
802
+ # Self attention.
803
+ attention_output, kv_cache = self.self_attention(
804
+ layernorm_output,
805
+ attention_mask,
806
+ rotary_pos_emb,
807
+ kv_cache=kv_cache,
808
+ use_cache=use_cache
809
+ )
810
+
811
+ # Residual connection.
812
+ if self.apply_residual_connection_post_layernorm:
813
+ residual = layernorm_output
814
+ else:
815
+ residual = hidden_states
816
+
817
+ layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
818
+ layernorm_input = residual + layernorm_input
819
+
820
+ # Layer norm post the self attention.
821
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
822
+
823
+ # MLP.
824
+ mlp_output = self.mlp(layernorm_output)
825
+
826
+ # Second residual connection.
827
+ if self.apply_residual_connection_post_layernorm:
828
+ residual = layernorm_output
829
+ else:
830
+ residual = layernorm_input
831
+
832
+ output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
833
+ output = residual + output
834
+
835
+ return output, kv_cache
836
+
837
+
838
+ class GLMTransformer(torch.nn.Module):
839
+ """Transformer class."""
840
+
841
+ def __init__(self, config: ChatGLMConfig, device=None):
842
+ super(GLMTransformer, self).__init__()
843
+
844
+ self.fp32_residual_connection = config.fp32_residual_connection
845
+ self.post_layer_norm = config.post_layer_norm
846
+
847
+ # Number of layers.
848
+ self.num_layers = config.num_layers
849
+
850
+ # Transformer layers.
851
+ def build_layer(layer_number):
852
+ return GLMBlock(config, layer_number, device=device)
853
+
854
+ self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
855
+
856
+ if self.post_layer_norm:
857
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
858
+ # Final layer norm before output.
859
+ self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
860
+ dtype=config.torch_dtype)
861
+
862
+ self.gradient_checkpointing = False
863
+
864
+ def _get_layer(self, layer_number):
865
+ return self.layers[layer_number]
866
+
867
+ def forward(
868
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
869
+ use_cache: Optional[bool] = True,
870
+ output_hidden_states: Optional[bool] = False,
871
+ ):
872
+ if not kv_caches:
873
+ kv_caches = [None for _ in range(self.num_layers)]
874
+ presents = () if use_cache else None
875
+ if self.gradient_checkpointing and self.training:
876
+ if use_cache:
877
+ logger.warning_once(
878
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
879
+ )
880
+ use_cache = False
881
+
882
+ all_self_attentions = None
883
+ all_hidden_states = () if output_hidden_states else None
884
+ for index in range(self.num_layers):
885
+ if output_hidden_states:
886
+ all_hidden_states = all_hidden_states + (hidden_states,)
887
+
888
+ layer = self._get_layer(index)
889
+ if self.gradient_checkpointing and self.training:
890
+ layer_ret = torch.utils.checkpoint.checkpoint(
891
+ layer,
892
+ hidden_states,
893
+ attention_mask,
894
+ rotary_pos_emb,
895
+ kv_caches[index],
896
+ use_cache
897
+ )
898
+ else:
899
+ layer_ret = layer(
900
+ hidden_states,
901
+ attention_mask,
902
+ rotary_pos_emb,
903
+ kv_cache=kv_caches[index],
904
+ use_cache=use_cache
905
+ )
906
+ hidden_states, kv_cache = layer_ret
907
+ if use_cache:
908
+ presents = presents + (kv_cache,)
909
+
910
+ if output_hidden_states:
911
+ all_hidden_states = all_hidden_states + (hidden_states,)
912
+
913
+ # Final layer norm.
914
+ if self.post_layer_norm:
915
+ hidden_states = self.final_layernorm(hidden_states)
916
+
917
+ return hidden_states, presents, all_hidden_states, all_self_attentions
918
+
919
+
920
+ class ChatGLMPreTrainedModel(PreTrainedModel):
921
+ """
922
+ An abstract class to handle weights initialization and
923
+ a simple interface for downloading and loading pretrained models.
924
+ """
925
+
926
+ is_parallelizable = False
927
+ supports_gradient_checkpointing = True
928
+ config_class = ChatGLMConfig
929
+ base_model_prefix = "transformer"
930
+ _no_split_modules = ["GLMBlock"]
931
+
932
+ def _init_weights(self, module: nn.Module):
933
+ """Initialize the weights."""
934
+ return
935
+
936
+ def get_masks(self, input_ids, past_key_values, padding_mask=None):
937
+ batch_size, seq_length = input_ids.shape
938
+ full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
939
+ full_attention_mask.tril_()
940
+ past_length = 0
941
+ if past_key_values:
942
+ past_length = past_key_values[0][0].shape[0]
943
+ if past_length:
944
+ full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
945
+ device=input_ids.device), full_attention_mask), dim=-1)
946
+ if padding_mask is not None:
947
+ full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
948
+ if not past_length and padding_mask is not None:
949
+ full_attention_mask -= padding_mask.unsqueeze(-1) - 1
950
+ full_attention_mask = (full_attention_mask < 0.5).bool()
951
+ full_attention_mask.unsqueeze_(1)
952
+ return full_attention_mask
953
+
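A minimal sketch (not part of the diff) of the mask convention produced by get_masks when there is no cache and no padding: True marks positions that must not be attended to, and CoreAttention re-inverts it before the fused kernel:

import torch

seq_length = 4
full = torch.ones(1, seq_length, seq_length).tril()
full = (full < 0.5).bool().unsqueeze(1)   # [1, 1, sq, sq]; the upper triangle is True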
954
+ def get_position_ids(self, input_ids, device):
955
+ batch_size, seq_length = input_ids.shape
956
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
957
+ return position_ids
958
+
959
+ def _set_gradient_checkpointing(self, module, value=False):
960
+ if isinstance(module, GLMTransformer):
961
+ module.gradient_checkpointing = value
962
+
963
+
964
+ class Embedding(torch.nn.Module):
965
+ """Language model embeddings."""
966
+
967
+ def __init__(self, config: ChatGLMConfig, device=None):
968
+ super(Embedding, self).__init__()
969
+
970
+ self.hidden_size = config.hidden_size
971
+ # Word embeddings (parallel).
972
+ self.word_embeddings = nn.Embedding(
973
+ config.padded_vocab_size,
974
+ self.hidden_size,
975
+ dtype=config.torch_dtype,
976
+ device=device
977
+ )
978
+ self.fp32_residual_connection = config.fp32_residual_connection
979
+
980
+ def forward(self, input_ids):
981
+ # Embeddings.
982
+ words_embeddings = self.word_embeddings(input_ids)
983
+ embeddings = words_embeddings
984
+ # Data format change to avoid explicit transposes: [b s h] --> [s b h].
985
+ embeddings = embeddings.transpose(0, 1).contiguous()
986
+ # If the fp32 residual connection flag is set, cast the embeddings to float32.
987
+ if self.fp32_residual_connection:
988
+ embeddings = embeddings.float()
989
+ return embeddings
990
+
991
+
992
+ class ChatGLMModel(ChatGLMPreTrainedModel):
993
+ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
994
+ super().__init__(config)
995
+ if empty_init:
996
+ init_method = skip_init
997
+ else:
998
+ init_method = default_init
999
+ init_kwargs = {}
1000
+ if device is not None:
1001
+ init_kwargs["device"] = device
1002
+ self.embedding = init_method(Embedding, config, **init_kwargs)
1003
+ self.num_layers = config.num_layers
1004
+ self.multi_query_group_num = config.multi_query_group_num
1005
+ self.kv_channels = config.kv_channels
1006
+
1007
+ # Rotary positional embeddings
1008
+ self.seq_length = config.seq_length
1009
+ rotary_dim = (
1010
+ config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
1011
+ )
1012
+
1013
+ self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
1014
+ dtype=config.torch_dtype)
1015
+ self.encoder = init_method(GLMTransformer, config, **init_kwargs)
1016
+ self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
1017
+ dtype=config.torch_dtype, **init_kwargs)
1018
+ self.pre_seq_len = config.pre_seq_len
1019
+ self.prefix_projection = config.prefix_projection
1020
+ if self.pre_seq_len is not None:
1021
+ for param in self.parameters():
1022
+ param.requires_grad = False
1023
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
1024
+ self.prefix_encoder = PrefixEncoder(config)
1025
+ self.dropout = torch.nn.Dropout(0.1)
1026
+
1027
+ def get_input_embeddings(self):
1028
+ return self.embedding.word_embeddings
1029
+
1030
+ def get_prompt(self, batch_size, device, dtype=torch.half):
1031
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
1032
+ past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
1033
+ past_key_values = past_key_values.view(
1034
+ batch_size,
1035
+ self.pre_seq_len,
1036
+ self.num_layers * 2,
1037
+ self.multi_query_group_num,
1038
+ self.kv_channels
1039
+ )
1040
+ # seq_len, b, nh, hidden_size
1041
+ past_key_values = self.dropout(past_key_values)
1042
+ past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
1043
+ return past_key_values
1044
+
1045
+ def forward(
1046
+ self,
1047
+ input_ids,
1048
+ position_ids: Optional[torch.Tensor] = None,
1049
+ attention_mask: Optional[torch.BoolTensor] = None,
1050
+ full_attention_mask: Optional[torch.BoolTensor] = None,
1051
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1052
+ inputs_embeds: Optional[torch.Tensor] = None,
1053
+ use_cache: Optional[bool] = None,
1054
+ output_hidden_states: Optional[bool] = None,
1055
+ return_dict: Optional[bool] = None,
1056
+ ):
1057
+ output_hidden_states = (
1058
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1059
+ )
1060
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1061
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1062
+
1063
+ batch_size, seq_length = input_ids.shape
1064
+
1065
+ if inputs_embeds is None:
1066
+ inputs_embeds = self.embedding(input_ids)
1067
+
1068
+ if self.pre_seq_len is not None:
1069
+ if past_key_values is None:
1070
+ past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
1071
+ dtype=inputs_embeds.dtype)
1072
+ if attention_mask is not None:
1073
+ attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
1074
+ attention_mask], dim=-1)
1075
+
1076
+ if full_attention_mask is None:
1077
+ if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
1078
+ full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
1079
+
1080
+ # Rotary positional embeddings
1081
+ rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
1082
+ if position_ids is not None:
1083
+ rotary_pos_emb = rotary_pos_emb[position_ids]
1084
+ else:
1085
+ rotary_pos_emb = rotary_pos_emb[None, :seq_length]
1086
+ rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
1087
+
1088
+ # Run encoder.
1089
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
1090
+ inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
1091
+ kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
1092
+ )
1093
+
1094
+ if not return_dict:
1095
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
1096
+
1097
+ return BaseModelOutputWithPast(
1098
+ last_hidden_state=hidden_states,
1099
+ past_key_values=presents,
1100
+ hidden_states=all_hidden_states,
1101
+ attentions=all_self_attentions,
1102
+ )
1103
+
1104
+ def quantize(self, weight_bit_width: int):
1105
+ # from .quantization import quantize
1106
+ quantize(self.encoder, weight_bit_width)
1107
+ return self
1108
+
1109
+
1110
+ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1111
+ def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
1112
+ super().__init__(config)
1113
+
1114
+ self.max_sequence_length = config.max_length
1115
+ self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
1116
+ self.config = config
1117
+ self.quantized = False
1118
+
1119
+ if self.config.quantization_bit:
1120
+ self.quantize(self.config.quantization_bit, empty_init=True)
1121
+
1122
+ def _update_model_kwargs_for_generation(
1123
+ self,
1124
+ outputs: ModelOutput,
1125
+ model_kwargs: Dict[str, Any],
1126
+ is_encoder_decoder: bool = False,
1127
+ standardize_cache_format: bool = False,
1128
+ ) -> Dict[str, Any]:
1129
+ # update past_key_values
1130
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
1131
+ outputs, standardize_cache_format=standardize_cache_format
1132
+ )
1133
+
1134
+ # update attention mask
1135
+ if "attention_mask" in model_kwargs:
1136
+ attention_mask = model_kwargs["attention_mask"]
1137
+ model_kwargs["attention_mask"] = torch.cat(
1138
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
1139
+ )
1140
+
1141
+ # update position ids
1142
+ if "position_ids" in model_kwargs:
1143
+ position_ids = model_kwargs["position_ids"]
1144
+ new_position_id = position_ids[..., -1:].clone()
1145
+ new_position_id += 1
1146
+ model_kwargs["position_ids"] = torch.cat(
1147
+ [position_ids, new_position_id], dim=-1
1148
+ )
1149
+
1150
+ model_kwargs["is_first_forward"] = False
1151
+ return model_kwargs
1152
+
1153
+ def prepare_inputs_for_generation(
1154
+ self,
1155
+ input_ids: torch.LongTensor,
1156
+ past_key_values: Optional[torch.Tensor] = None,
1157
+ attention_mask: Optional[torch.Tensor] = None,
1158
+ position_ids: Optional[torch.Tensor] = None,
1159
+ use_cache: Optional[bool] = None,
1160
+ is_first_forward: bool = True,
1161
+ **kwargs
1162
+ ) -> dict:
1163
+ # only last token for input_ids if past is not None
1164
+ if position_ids is None:
1165
+ position_ids = self.get_position_ids(input_ids, device=input_ids.device)
1166
+ if not is_first_forward:
1167
+ if past_key_values is not None:
1168
+ position_ids = position_ids[..., -1:]
1169
+ input_ids = input_ids[:, -1:]
1170
+ return {
1171
+ "input_ids": input_ids,
1172
+ "past_key_values": past_key_values,
1173
+ "position_ids": position_ids,
1174
+ "attention_mask": attention_mask,
1175
+ "return_last_logit": True,
1176
+ "use_cache": use_cache
1177
+ }
1178
+
1179
+ def forward(
1180
+ self,
1181
+ input_ids: Optional[torch.Tensor] = None,
1182
+ position_ids: Optional[torch.Tensor] = None,
1183
+ attention_mask: Optional[torch.Tensor] = None,
1184
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
1185
+ inputs_embeds: Optional[torch.Tensor] = None,
1186
+ labels: Optional[torch.Tensor] = None,
1187
+ use_cache: Optional[bool] = None,
1188
+ output_attentions: Optional[bool] = None,
1189
+ output_hidden_states: Optional[bool] = None,
1190
+ return_dict: Optional[bool] = None,
1191
+ return_last_logit: Optional[bool] = False,
1192
+ ):
1193
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1194
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1195
+
1196
+ transformer_outputs = self.transformer(
1197
+ input_ids=input_ids,
1198
+ position_ids=position_ids,
1199
+ attention_mask=attention_mask,
1200
+ past_key_values=past_key_values,
1201
+ inputs_embeds=inputs_embeds,
1202
+ use_cache=use_cache,
1203
+ output_hidden_states=output_hidden_states,
1204
+ return_dict=return_dict,
1205
+ )
1206
+
1207
+ hidden_states = transformer_outputs[0]
1208
+ if return_last_logit:
1209
+ hidden_states = hidden_states[-1:]
1210
+ lm_logits = self.transformer.output_layer(hidden_states)
1211
+ lm_logits = lm_logits.transpose(0, 1).contiguous()
1212
+
1213
+ loss = None
1214
+ if labels is not None:
1215
+ lm_logits = lm_logits.to(torch.float32)
1216
+
1217
+ # Shift so that tokens < n predict n
1218
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1219
+ shift_labels = labels[..., 1:].contiguous()
1220
+ # Flatten the tokens
1221
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
1222
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1223
+
1224
+ lm_logits = lm_logits.to(hidden_states.dtype)
1225
+ loss = loss.to(hidden_states.dtype)
1226
+
1227
+ if not return_dict:
1228
+ output = (lm_logits,) + transformer_outputs[1:]
1229
+ return ((loss,) + output) if loss is not None else output
1230
+
1231
+ return CausalLMOutputWithPast(
1232
+ loss=loss,
1233
+ logits=lm_logits,
1234
+ past_key_values=transformer_outputs.past_key_values,
1235
+ hidden_states=transformer_outputs.hidden_states,
1236
+ attentions=transformer_outputs.attentions,
1237
+ )
1238
+
1239
+ @staticmethod
1240
+ def _reorder_cache(
1241
+ past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
1242
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
1243
+ """
1244
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1245
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1246
+ beam_idx at every generation step.
1247
+
1248
+ Output shares the same memory storage as `past`.
1249
+ """
1250
+ return tuple(
1251
+ (
1252
+ layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
1253
+ layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
1254
+ )
1255
+ for layer_past in past
1256
+ )
1257
+
1258
+ def process_response(self, output, history):
1259
+ content = ""
1260
+ history = deepcopy(history)
1261
+ for response in output.split("<|assistant|>"):
1262
+ metadata, content = response.split("\n", maxsplit=1)
1263
+ if not metadata.strip():
1264
+ content = content.strip()
1265
+ history.append({"role": "assistant", "metadata": metadata, "content": content})
1266
+ content = content.replace("[[训练时间]]", "2023年")
1267
+ else:
1268
+ history.append({"role": "assistant", "metadata": metadata, "content": content})
1269
+ if history[0]["role"] == "system" and "tools" in history[0]:
1270
+ content = "\n".join(content.split("\n")[1:-1])
1271
+ def tool_call(**kwargs):
1272
+ return kwargs
1273
+ parameters = eval(content)
1274
+ content = {"name": metadata.strip(), "parameters": parameters}
1275
+ else:
1276
+ content = {"name": metadata.strip(), "content": content}
1277
+ return content, history
1278
+
1279
+ @torch.inference_mode()
1280
+ def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
1281
+ max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
1282
+ **kwargs):
1283
+ if history is None:
1284
+ history = []
1285
+ if logits_processor is None:
1286
+ logits_processor = LogitsProcessorList()
1287
+ logits_processor.append(InvalidScoreLogitsProcessor())
1288
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1289
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1290
+ inputs = tokenizer.build_chat_input(query, history=history, role=role)
1291
+ inputs = inputs.to(self.device)
1292
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
1293
+ tokenizer.get_command("<|observation|>")]
1294
+ outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
1295
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1296
+ response = tokenizer.decode(outputs)
1297
+ history.append({"role": role, "content": query})
1298
+ response, history = self.process_response(response, history)
1299
+ return response, history
1300
+
1301
+ @torch.inference_mode()
1302
+ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
1303
+ past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
1304
+ logits_processor=None, return_past_key_values=False, **kwargs):
1305
+ if history is None:
1306
+ history = []
1307
+ if logits_processor is None:
1308
+ logits_processor = LogitsProcessorList()
1309
+ logits_processor.append(InvalidScoreLogitsProcessor())
1310
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
1311
+ tokenizer.get_command("<|observation|>")]
1312
+ gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
1313
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1314
+ if past_key_values is None:
1315
+ inputs = tokenizer.build_chat_input(query, history=history, role=role)
1316
+ else:
1317
+ inputs = tokenizer.build_chat_input(query, role=role)
1318
+ inputs = inputs.to(self.device)
1319
+ if past_key_values is not None:
1320
+ past_length = past_key_values[0][0].shape[0]
1321
+ if self.transformer.pre_seq_len is not None:
1322
+ past_length -= self.transformer.pre_seq_len
1323
+ inputs.position_ids += past_length
1324
+ attention_mask = inputs.attention_mask
1325
+ attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
1326
+ inputs['attention_mask'] = attention_mask
1327
+ history.append({"role": role, "content": query})
1328
+ for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
1329
+ eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
1330
+ **gen_kwargs):
1331
+ if return_past_key_values:
1332
+ outputs, past_key_values = outputs
1333
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1334
+ response = tokenizer.decode(outputs)
1335
+ if response and response[-1] != "�":
1336
+ response, new_history = self.process_response(response, history)
1337
+ if return_past_key_values:
1338
+ yield response, new_history, past_key_values
1339
+ else:
1340
+ yield response, new_history
1341
+
1342
+ @torch.inference_mode()
1343
+ def stream_generate(
1344
+ self,
1345
+ input_ids,
1346
+ generation_config: Optional[GenerationConfig] = None,
1347
+ logits_processor: Optional[LogitsProcessorList] = None,
1348
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
1349
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
1350
+ return_past_key_values=False,
1351
+ **kwargs,
1352
+ ):
1353
+ batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
1354
+
1355
+ if generation_config is None:
1356
+ generation_config = self.generation_config
1357
+ generation_config = copy.deepcopy(generation_config)
1358
+ model_kwargs = generation_config.update(**kwargs)
1359
+ model_kwargs["use_cache"] = generation_config.use_cache
1360
+ bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
1361
+
1362
+ if isinstance(eos_token_id, int):
1363
+ eos_token_id = [eos_token_id]
1364
+ eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
1365
+
1366
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
1367
+ if has_default_max_length and generation_config.max_new_tokens is None:
1368
+ warnings.warn(
1369
+ f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
1370
+ "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
1371
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
1372
+ UserWarning,
1373
+ )
1374
+ elif generation_config.max_new_tokens is not None:
1375
+ generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
1376
+ if not has_default_max_length:
1377
+ logger.warn(
1378
+ f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
1379
+ f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
1380
+ "Please refer to the documentation for more information. "
1381
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
1382
+ UserWarning,
1383
+ )
1384
+
1385
+ if input_ids_seq_length >= generation_config.max_length:
1386
+ input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
1387
+ logger.warning(
1388
+ f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
1389
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
1390
+ " increasing `max_new_tokens`."
1391
+ )
1392
+
1393
+ # 2. Set generation parameters if not already defined
1394
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
1395
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
1396
+
1397
+ logits_processor = self._get_logits_processor(
1398
+ generation_config=generation_config,
1399
+ input_ids_seq_length=input_ids_seq_length,
1400
+ encoder_input_ids=input_ids,
1401
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
1402
+ logits_processor=logits_processor,
1403
+ )
1404
+
1405
+ stopping_criteria = self._get_stopping_criteria(
1406
+ generation_config=generation_config, stopping_criteria=stopping_criteria
1407
+ )
1408
+ logits_warper = self._get_logits_warper(generation_config)
1409
+
1410
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
1411
+ scores = None
1412
+ while True:
1413
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1414
+ # forward pass to get next token
1415
+ outputs = self(
1416
+ **model_inputs,
1417
+ return_dict=True,
1418
+ output_attentions=False,
1419
+ output_hidden_states=False,
1420
+ )
1421
+
1422
+ next_token_logits = outputs.logits[:, -1, :]
1423
+
1424
+ # pre-process distribution
1425
+ next_token_scores = logits_processor(input_ids, next_token_logits)
1426
+ next_token_scores = logits_warper(input_ids, next_token_scores)
1427
+
1428
+ # sample
1429
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
1430
+ if generation_config.do_sample:
1431
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1432
+ else:
1433
+ next_tokens = torch.argmax(probs, dim=-1)
1434
+ # update generated ids, model inputs, and length for next step
1435
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
1436
+ model_kwargs = self._update_model_kwargs_for_generation(
1437
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
1438
+ )
1439
+ unfinished_sequences = unfinished_sequences.mul(
1440
+ next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
1441
+ )
1442
+ if return_past_key_values:
1443
+ yield input_ids, outputs.past_key_values
1444
+ else:
1445
+ yield input_ids
1446
+ # stop when each sentence is finished, or if we exceed the maximum length
1447
+ if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
1448
+ break
1449
+
1450
+ def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
1451
+ if bits == 0:
1452
+ return
1453
+
1454
+ # from .quantization import quantize
1455
+
1456
+ if self.quantized:
1457
+ logger.info("Already quantized.")
1458
+ return self
1459
+
1460
+ self.quantized = True
1461
+
1462
+ self.config.quantization_bit = bits
1463
+
1464
+ self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,
1465
+ **kwargs)
1466
+ return self
1467
+
1468
+
1469
+ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
1470
+ def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
1471
+ super().__init__(config)
1472
+
1473
+ self.num_labels = config.num_labels
1474
+ self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
1475
+
1476
+ self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
1477
+ if config.classifier_dropout is not None:
1478
+ self.dropout = nn.Dropout(config.classifier_dropout)
1479
+ else:
1480
+ self.dropout = None
1481
+ self.config = config
1482
+
1483
+ if self.config.quantization_bit:
1484
+ self.quantize(self.config.quantization_bit, empty_init=True)
1485
+
1486
+ def forward(
1487
+ self,
1488
+ input_ids: Optional[torch.LongTensor] = None,
1489
+ position_ids: Optional[torch.LongTensor] = None,
1490
+ attention_mask: Optional[torch.Tensor] = None,
1491
+ full_attention_mask: Optional[torch.Tensor] = None,
1492
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1493
+ inputs_embeds: Optional[torch.LongTensor] = None,
1494
+ labels: Optional[torch.LongTensor] = None,
1495
+ use_cache: Optional[bool] = None,
1496
+ output_hidden_states: Optional[bool] = None,
1497
+ return_dict: Optional[bool] = None,
1498
+ ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
1499
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1500
+
1501
+ transformer_outputs = self.transformer(
1502
+ input_ids=input_ids,
1503
+ position_ids=position_ids,
1504
+ attention_mask=attention_mask,
1505
+ full_attention_mask=full_attention_mask,
1506
+ past_key_values=past_key_values,
1507
+ inputs_embeds=inputs_embeds,
1508
+ use_cache=use_cache,
1509
+ output_hidden_states=output_hidden_states,
1510
+ return_dict=return_dict,
1511
+ )
1512
+
1513
+ hidden_states = transformer_outputs[0]
1514
+ pooled_hidden_states = hidden_states[-1]
1515
+ if self.dropout is not None:
1516
+ pooled_hidden_states = self.dropout(pooled_hidden_states)
1517
+ logits = self.classifier_head(pooled_hidden_states)
1518
+
1519
+ loss = None
1520
+ if labels is not None:
1521
+ if self.config.problem_type is None:
1522
+ if self.num_labels == 1:
1523
+ self.config.problem_type = "regression"
1524
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1525
+ self.config.problem_type = "single_label_classification"
1526
+ else:
1527
+ self.config.problem_type = "multi_label_classification"
1528
+
1529
+ if self.config.problem_type == "regression":
1530
+ loss_fct = MSELoss()
1531
+ if self.num_labels == 1:
1532
+ loss = loss_fct(logits.squeeze().float(), labels.squeeze())
1533
+ else:
1534
+ loss = loss_fct(logits.float(), labels)
1535
+ elif self.config.problem_type == "single_label_classification":
1536
+ loss_fct = CrossEntropyLoss()
1537
+ loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
1538
+ elif self.config.problem_type == "multi_label_classification":
1539
+ loss_fct = BCEWithLogitsLoss()
1540
+ loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))
1541
+
1542
+ if not return_dict:
1543
+ output = (logits,) + transformer_outputs[1:]
1544
+ return ((loss,) + output) if loss is not None else output
1545
+
1546
+ return SequenceClassifierOutputWithPast(
1547
+ loss=loss,
1548
+ logits=logits,
1549
+ past_key_values=transformer_outputs.past_key_values,
1550
+ hidden_states=transformer_outputs.hidden_states,
1551
+ attentions=transformer_outputs.attentions,
1552
+ )
diffsynth/models/lora.py ADDED
@@ -0,0 +1,195 @@
1
+ import torch
2
+ from .sd_unet import SDUNet
3
+ from .sdxl_unet import SDXLUNet
4
+ from .sd_text_encoder import SDTextEncoder
5
+ from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
6
+ from .sd3_dit import SD3DiT
7
+ from .hunyuan_dit import HunyuanDiT
8
+
9
+
10
+
11
+ class LoRAFromCivitai:
12
+ def __init__(self):
13
+ self.supported_model_classes = []
14
+ self.lora_prefix = []
15
+ self.renamed_lora_prefix = {}
16
+ self.special_keys = {}
17
+
18
+
19
+ def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
20
+ renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
21
+ state_dict_ = {}
22
+ for key in state_dict:
23
+ if ".lora_up" not in key:
24
+ continue
25
+ if not key.startswith(lora_prefix):
26
+ continue
27
+ weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
28
+ weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
29
+ if len(weight_up.shape) == 4:
30
+ weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
31
+ weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
32
+ lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
33
+ else:
34
+ lora_weight = alpha * torch.mm(weight_up, weight_down)
35
+ target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
36
+ for special_key in self.special_keys:
37
+ target_name = target_name.replace(special_key, self.special_keys[special_key])
38
+ state_dict_[target_name] = lora_weight.cpu()
39
+ return state_dict_
40
+
41
+
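A minimal sketch (not part of the diff, illustrative sizes) of the merge performed by convert_state_dict: each LoRA up/down pair collapses into a single delta that load() then adds onto the matching base weight:

import torch

rank, d_out, d_in, alpha = 4, 32, 32, 1.0
weight_up = torch.randn(d_out, rank)
weight_down = torch.randn(rank, d_in)
delta = alpha * torch.mm(weight_up, weight_down)   # same shape as the base Linear weight
base_weight = torch.randn(d_out, d_in)
merged = base_weight + delta                       # what load() writes back via load_state_dict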
42
+ def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
43
+ state_dict_model = model.state_dict()
44
+ state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
45
+ if model_resource == "diffusers":
46
+ state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
47
+ elif model_resource == "civitai":
48
+ state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
49
+ if len(state_dict_lora) > 0:
50
+ print(f" {len(state_dict_lora)} tensors are updated.")
51
+ for name in state_dict_lora:
52
+ state_dict_model[name] += state_dict_lora[name].to(
53
+ dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
54
+ model.load_state_dict(state_dict_model)
55
+
56
+
57
+ def match(self, model, state_dict_lora):
58
+ for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
59
+ if not isinstance(model, model_class):
60
+ continue
61
+ state_dict_model = model.state_dict()
62
+ for model_resource in ["diffusers", "civitai"]:
63
+ try:
64
+ state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
65
+ converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
66
+ else model.__class__.state_dict_converter().from_civitai
67
+ state_dict_lora_ = converter_fn(state_dict_lora_)
68
+ if len(state_dict_lora_) == 0:
69
+ continue
70
+ for name in state_dict_lora_:
71
+ if name not in state_dict_model:
72
+ break
73
+ else:
74
+ return lora_prefix, model_resource
75
+ except:
76
+ pass
77
+ return None
78
+
79
+
80
+
81
+ class SDLoRAFromCivitai(LoRAFromCivitai):
82
+ def __init__(self):
83
+ super().__init__()
84
+ self.supported_model_classes = [SDUNet, SDTextEncoder]
85
+ self.lora_prefix = ["lora_unet_", "lora_te_"]
86
+ self.special_keys = {
87
+ "down.blocks": "down_blocks",
88
+ "up.blocks": "up_blocks",
89
+ "mid.block": "mid_block",
90
+ "proj.in": "proj_in",
91
+ "proj.out": "proj_out",
92
+ "transformer.blocks": "transformer_blocks",
93
+ "to.q": "to_q",
94
+ "to.k": "to_k",
95
+ "to.v": "to_v",
96
+ "to.out": "to_out",
97
+ "text.model": "text_model",
98
+ "self.attn.q.proj": "self_attn.q_proj",
99
+ "self.attn.k.proj": "self_attn.k_proj",
100
+ "self.attn.v.proj": "self_attn.v_proj",
101
+ "self.attn.out.proj": "self_attn.out_proj",
102
+ "input.blocks": "model.diffusion_model.input_blocks",
103
+ "middle.block": "model.diffusion_model.middle_block",
104
+ "output.blocks": "model.diffusion_model.output_blocks",
105
+ }
106
+
107
+
108
+ class SDXLLoRAFromCivitai(LoRAFromCivitai):
109
+ def __init__(self):
110
+ super().__init__()
111
+ self.supported_model_classes = [SDXLUNet, SDXLTextEncoder, SDXLTextEncoder2]
112
+ self.lora_prefix = ["lora_unet_", "lora_te1_", "lora_te2_"]
113
+ self.renamed_lora_prefix = {"lora_te2_": "2"}
114
+ self.special_keys = {
115
+ "down.blocks": "down_blocks",
116
+ "up.blocks": "up_blocks",
117
+ "mid.block": "mid_block",
118
+ "proj.in": "proj_in",
119
+ "proj.out": "proj_out",
120
+ "transformer.blocks": "transformer_blocks",
121
+ "to.q": "to_q",
122
+ "to.k": "to_k",
123
+ "to.v": "to_v",
124
+ "to.out": "to_out",
125
+ "text.model": "conditioner.embedders.0.transformer.text_model",
126
+ "self.attn.q.proj": "self_attn.q_proj",
127
+ "self.attn.k.proj": "self_attn.k_proj",
128
+ "self.attn.v.proj": "self_attn.v_proj",
129
+ "self.attn.out.proj": "self_attn.out_proj",
130
+ "input.blocks": "model.diffusion_model.input_blocks",
131
+ "middle.block": "model.diffusion_model.middle_block",
132
+ "output.blocks": "model.diffusion_model.output_blocks",
133
+ "2conditioner.embedders.0.transformer.text_model.encoder.layers": "text_model.encoder.layers"
134
+ }
135
+
136
+
137
+
138
+ class GeneralLoRAFromPeft:
139
+ def __init__(self):
140
+ self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT]
141
+
142
+
143
+ def convert_state_dict(self, state_dict, alpha=1.0, device="cuda", torch_dtype=torch.float16):
144
+ state_dict_ = {}
145
+ for key in state_dict:
146
+ if ".lora_B." not in key:
147
+ continue
148
+ weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
149
+ weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
150
+ if len(weight_up.shape) == 4:
151
+ weight_up = weight_up.squeeze(3).squeeze(2)
152
+ weight_down = weight_down.squeeze(3).squeeze(2)
153
+ lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
154
+ else:
155
+ lora_weight = alpha * torch.mm(weight_up, weight_down)
156
+ keys = key.split(".")
157
+ keys.pop(keys.index("lora_B") + 1)
158
+ keys.pop(keys.index("lora_B"))
159
+ target_name = ".".join(keys)
160
+ state_dict_[target_name] = lora_weight.cpu()
161
+ return state_dict_
162
+
163
+
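A minimal sketch (not part of the diff) of the key surgery in GeneralLoRAFromPeft.convert_state_dict; the key below is hypothetical, shown only to illustrate how a PEFT-style name is mapped back to the plain parameter name:

key = "blocks.0.attn.to_q.lora_B.default.weight"   # hypothetical PEFT-style key
keys = key.split(".")
keys.pop(keys.index("lora_B") + 1)                 # drop the segment after "lora_B" (the adapter name here)
keys.pop(keys.index("lora_B"))                     # drop "lora_B" itself
target_name = ".".join(keys)                       # "blocks.0.attn.to_q.weight"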
164
+ def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
165
+ state_dict_model = model.state_dict()
166
+ for name, param in state_dict_model.items():
167
+ torch_dtype = param.dtype
168
+ device = param.device
169
+ break
170
+ state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, device=device, torch_dtype=torch_dtype)
171
+ if len(state_dict_lora) > 0:
172
+ print(f" {len(state_dict_lora)} tensors are updated.")
173
+ for name in state_dict_lora:
174
+ state_dict_model[name] += state_dict_lora[name].to(
175
+ dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
176
+ model.load_state_dict(state_dict_model)
177
+
178
+
179
+ def match(self, model, state_dict_lora):
180
+ for model_class in self.supported_model_classes:
181
+ if not isinstance(model, model_class):
182
+ continue
183
+ state_dict_model = model.state_dict()
184
+ try:
185
+ state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0)
186
+ if len(state_dict_lora_) == 0:
187
+ continue
188
+ for name in state_dict_lora_:
189
+ if name not in state_dict_model:
190
+ break
191
+ else:
192
+ return "", ""
193
+ except:
194
+ pass
195
+ return None
diffsynth/models/model_manager.py ADDED
@@ -0,0 +1,543 @@
1
+ import os, torch, hashlib, json, importlib
2
+ from safetensors import safe_open
3
+ from torch import Tensor
4
+ from typing_extensions import Literal, TypeAlias
5
+ from typing import List
6
+
7
+ from .downloader import download_models, Preset_model_id, Preset_model_website
8
+
9
+ from .sd_text_encoder import SDTextEncoder
10
+ from .sd_unet import SDUNet
11
+ from .sd_vae_encoder import SDVAEEncoder
12
+ from .sd_vae_decoder import SDVAEDecoder
13
+ from .lora import SDLoRAFromCivitai, SDXLLoRAFromCivitai, GeneralLoRAFromPeft
14
+
15
+ from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
16
+ from .sdxl_unet import SDXLUNet
17
+ from .sdxl_vae_decoder import SDXLVAEDecoder
18
+ from .sdxl_vae_encoder import SDXLVAEEncoder
19
+
20
+ from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
21
+ from .sd3_dit import SD3DiT
22
+ from .sd3_vae_decoder import SD3VAEDecoder
23
+ from .sd3_vae_encoder import SD3VAEEncoder
24
+
25
+ from .sd_controlnet import SDControlNet
26
+ from .sdxl_controlnet import SDXLControlNetUnion
27
+
28
+ from .sd_motion import SDMotionModel
29
+ from .sdxl_motion import SDXLMotionModel
30
+
31
+ from .svd_image_encoder import SVDImageEncoder
32
+ from .svd_unet import SVDUNet
33
+ from .svd_vae_decoder import SVDVAEDecoder
34
+ from .svd_vae_encoder import SVDVAEEncoder
35
+
36
+ from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
37
+ from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
38
+
39
+ from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
40
+ from .hunyuan_dit import HunyuanDiT
41
+
42
+ from .flux_dit import FluxDiT
43
+ from .flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
44
+ from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
45
+
46
+ from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
47
+
48
+
49
+
50
+ def load_state_dict(file_path, torch_dtype=None):
51
+ if file_path.endswith(".safetensors"):
52
+ return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
53
+ else:
54
+ return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
55
+
56
+
57
+ def load_state_dict_from_safetensors(file_path, torch_dtype=None):
58
+ state_dict = {}
59
+ with safe_open(file_path, framework="pt", device="cpu") as f:
60
+ for k in f.keys():
61
+ state_dict[k] = f.get_tensor(k)
62
+ if torch_dtype is not None:
63
+ state_dict[k] = state_dict[k].to(torch_dtype)
64
+ return state_dict
65
+
66
+
67
+ def load_state_dict_from_bin(file_path, torch_dtype=None):
68
+ state_dict = torch.load(file_path, map_location="cpu")
69
+ if torch_dtype is not None:
70
+ for i in state_dict:
71
+ if isinstance(state_dict[i], torch.Tensor):
72
+ state_dict[i] = state_dict[i].to(torch_dtype)
73
+ return state_dict
74
+
75
+
76
+ def search_for_embeddings(state_dict):
77
+ embeddings = []
78
+ for k in state_dict:
79
+ if isinstance(state_dict[k], torch.Tensor):
80
+ embeddings.append(state_dict[k])
81
+ elif isinstance(state_dict[k], dict):
82
+ embeddings += search_for_embeddings(state_dict[k])
83
+ return embeddings
84
+
85
+
86
+ def search_parameter(param, state_dict):
87
+ for name, param_ in state_dict.items():
88
+ if param.numel() == param_.numel():
89
+ if param.shape == param_.shape:
90
+ if torch.dist(param, param_) < 1e-3:
91
+ return name
92
+ else:
93
+ if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
94
+ return name
95
+ return None
96
+
97
+
98
+ def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
99
+ matched_keys = set()
100
+ with torch.no_grad():
101
+ for name in source_state_dict:
102
+ rename = search_parameter(source_state_dict[name], target_state_dict)
103
+ if rename is not None:
104
+ print(f'"{name}": "{rename}",')
105
+ matched_keys.add(rename)
106
+ elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
107
+ length = source_state_dict[name].shape[0] // 3
108
+ rename = []
109
+ for i in range(3):
110
+ rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
111
+ if None not in rename:
112
+ print(f'"{name}": {rename},')
113
+ for rename_ in rename:
114
+ matched_keys.add(rename_)
115
+ for name in target_state_dict:
116
+ if name not in matched_keys:
117
+ print("Cannot find", name, target_state_dict[name].shape)
118
+
119
+
120
+ def search_for_files(folder, extensions):
121
+ files = []
122
+ if os.path.isdir(folder):
123
+ for file in sorted(os.listdir(folder)):
124
+ files += search_for_files(os.path.join(folder, file), extensions)
125
+ elif os.path.isfile(folder):
126
+ for extension in extensions:
127
+ if folder.endswith(extension):
128
+ files.append(folder)
129
+ break
130
+ return files
131
+
132
+
133
+ def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
134
+ keys = []
135
+ for key, value in state_dict.items():
136
+ if isinstance(key, str):
137
+ if isinstance(value, Tensor):
138
+ if with_shape:
139
+ shape = "_".join(map(str, list(value.shape)))
140
+ keys.append(key + ":" + shape)
141
+ keys.append(key)
142
+ elif isinstance(value, dict):
143
+ keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
144
+ keys.sort()
145
+ keys_str = ",".join(keys)
146
+ return keys_str
147
+
148
+
149
+ def split_state_dict_with_prefix(state_dict):
150
+ keys = sorted([key for key in state_dict if isinstance(key, str)])
151
+ prefix_dict = {}
152
+ for key in keys:
153
+ prefix = key if "." not in key else key.split(".")[0]
154
+ if prefix not in prefix_dict:
155
+ prefix_dict[prefix] = []
156
+ prefix_dict[prefix].append(key)
157
+ state_dicts = []
158
+ for prefix, keys in prefix_dict.items():
159
+ sub_state_dict = {key: state_dict[key] for key in keys}
160
+ state_dicts.append(sub_state_dict)
161
+ return state_dicts
162
+
163
+
164
+ def hash_state_dict_keys(state_dict, with_shape=True):
165
+ keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
166
+ keys_str = keys_str.encode(encoding="UTF-8")
167
+ return hashlib.md5(keys_str).hexdigest()
168
+
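+ # hash_state_dict_keys fingerprints a checkpoint by its key names (optionally including tensor
+ # shapes). The detector classes below compare this fingerprint against the precomputed hashes
+ # carried by the entries of model_loader_configs to decide which model classes a file contains.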
169
+
170
+ def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
171
+ loaded_model_names, loaded_models = [], []
172
+ for model_name, model_class in zip(model_names, model_classes):
173
+ print(f" model_name: {model_name} model_class: {model_class.__name__}")
174
+ state_dict_converter = model_class.state_dict_converter()
175
+ if model_resource == "civitai":
176
+ state_dict_results = state_dict_converter.from_civitai(state_dict)
177
+ elif model_resource == "diffusers":
178
+ state_dict_results = state_dict_converter.from_diffusers(state_dict)
179
+ if isinstance(state_dict_results, tuple):
180
+ model_state_dict, extra_kwargs = state_dict_results
181
+ print(f" This model is initialized with extra kwargs: {extra_kwargs}")
182
+ else:
183
+ model_state_dict, extra_kwargs = state_dict_results, {}
184
+ torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
185
+ model = model_class(**extra_kwargs).to(dtype=torch_dtype, device=device)
186
+ model.load_state_dict(model_state_dict)
187
+ loaded_model_names.append(model_name)
188
+ loaded_models.append(model)
189
+ return loaded_model_names, loaded_models
190
+
191
+
192
+ def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
193
+ loaded_model_names, loaded_models = [], []
194
+ for model_name, model_class in zip(model_names, model_classes):
195
+ model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
196
+ if torch_dtype == torch.float16 and hasattr(model, "half"):
197
+ model = model.half()
198
+ model = model.to(device=device)
199
+ loaded_model_names.append(model_name)
200
+ loaded_models.append(model)
201
+ return loaded_model_names, loaded_models
202
+
203
+
204
+ def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
205
+ print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
206
+ base_state_dict = base_model.state_dict()
207
+ base_model.to("cpu")
208
+ del base_model
209
+ model = model_class(**extra_kwargs)
210
+ model.load_state_dict(base_state_dict, strict=False)
211
+ model.load_state_dict(state_dict, strict=False)
212
+ model.to(dtype=torch_dtype, device=device)
213
+ return model
214
+
215
+
216
+ def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
217
+ loaded_model_names, loaded_models = [], []
218
+ for model_name, model_class in zip(model_names, model_classes):
219
+ while True:
220
+ for model_id in range(len(model_manager.model)):
221
+ base_model_name = model_manager.model_name[model_id]
222
+ if base_model_name == model_name:
223
+ base_model_path = model_manager.model_path[model_id]
224
+ base_model = model_manager.model[model_id]
225
+ print(f" Adding patch model to {base_model_name} ({base_model_path})")
226
+ patched_model = load_single_patch_model_from_single_file(
227
+ state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
228
+ loaded_model_names.append(base_model_name)
229
+ loaded_models.append(patched_model)
230
+ model_manager.model.pop(model_id)
231
+ model_manager.model_path.pop(model_id)
232
+ model_manager.model_name.pop(model_id)
233
+ break
234
+ else:
235
+ break
236
+ return loaded_model_names, loaded_models
237
+
238
+
239
+
240
+ class ModelDetectorTemplate:
241
+ def __init__(self):
242
+ pass
243
+
244
+ def match(self, file_path="", state_dict={}):
245
+ return False
246
+
247
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
248
+ return [], []
249
+
250
+
251
+
252
+ class ModelDetectorFromSingleFile:
253
+ def __init__(self, model_loader_configs=[]):
254
+ self.keys_hash_with_shape_dict = {}
255
+ self.keys_hash_dict = {}
256
+ for metadata in model_loader_configs:
257
+ self.add_model_metadata(*metadata)
258
+
259
+
260
+ def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
261
+ self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
262
+ if keys_hash is not None:
263
+ self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
264
+
265
+
266
+ def match(self, file_path="", state_dict={}):
267
+ if os.path.isdir(file_path):
268
+ return False
269
+ if len(state_dict) == 0:
270
+ state_dict = load_state_dict(file_path)
271
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
272
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
273
+ return True
274
+ keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
275
+ if keys_hash in self.keys_hash_dict:
276
+ return True
277
+ return False
278
+
279
+
280
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
281
+ if len(state_dict) == 0:
282
+ state_dict = load_state_dict(file_path)
283
+
284
+ # Load models with strict matching
285
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
286
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
287
+ model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
288
+ loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
289
+ return loaded_model_names, loaded_models
290
+
291
+ # Load models without strict matching
292
+ # (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
293
+ keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
294
+ if keys_hash in self.keys_hash_dict:
295
+ model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
296
+ loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
297
+ return loaded_model_names, loaded_models
298
+
299
+ return [], []
300
+
301
+
302
+
303
+ class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
304
+ def __init__(self, model_loader_configs=[]):
305
+ super().__init__(model_loader_configs)
306
+
307
+
308
+ def match(self, file_path="", state_dict={}):
309
+ if os.path.isdir(file_path):
310
+ return False
311
+ if len(state_dict) == 0:
312
+ state_dict = load_state_dict(file_path)
313
+ splited_state_dict = split_state_dict_with_prefix(state_dict)
314
+ for sub_state_dict in splited_state_dict:
315
+ if super().match(file_path, sub_state_dict):
316
+ return True
317
+ return False
318
+
319
+
320
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
321
+ # Split the state_dict and load from each component
322
+ splited_state_dict = split_state_dict_with_prefix(state_dict)
323
+ valid_state_dict = {}
324
+ for sub_state_dict in splited_state_dict:
325
+ if super().match(file_path, sub_state_dict):
326
+ valid_state_dict.update(sub_state_dict)
327
+ if super().match(file_path, valid_state_dict):
328
+ loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
329
+ else:
330
+ loaded_model_names, loaded_models = [], []
331
+ for sub_state_dict in splited_state_dict:
332
+ if super().match(file_path, sub_state_dict):
333
+ loaded_model_names_, loaded_models_ = super().load(file_path, sub_state_dict, device, torch_dtype)
334
+ loaded_model_names += loaded_model_names_
335
+ loaded_models += loaded_models_
336
+ return loaded_model_names, loaded_models
337
+
338
+
339
+
340
+ class ModelDetectorFromHuggingfaceFolder:
341
+ def __init__(self, model_loader_configs=[]):
342
+ self.architecture_dict = {}
343
+ for metadata in model_loader_configs:
344
+ self.add_model_metadata(*metadata)
345
+
346
+
347
+ def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
348
+ self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
349
+
350
+
351
+ def match(self, file_path="", state_dict={}):
352
+ if os.path.isfile(file_path):
353
+ return False
354
+ file_list = os.listdir(file_path)
355
+ if "config.json" not in file_list:
356
+ return False
357
+ with open(os.path.join(file_path, "config.json"), "r") as f:
358
+ config = json.load(f)
359
+ if "architectures" not in config:
360
+ return False
361
+ return True
362
+
363
+
364
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
365
+ with open(os.path.join(file_path, "config.json"), "r") as f:
366
+ config = json.load(f)
367
+ loaded_model_names, loaded_models = [], []
368
+ for architecture in config["architectures"]:
369
+ huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
370
+ if redirected_architecture is not None:
371
+ architecture = redirected_architecture
372
+ model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
373
+ loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
374
+ loaded_model_names += loaded_model_names_
375
+ loaded_models += loaded_models_
376
+ return loaded_model_names, loaded_models
377
+
378
+
379
+
380
+ class ModelDetectorFromPatchedSingleFile:
381
+ def __init__(self, model_loader_configs=[]):
382
+ self.keys_hash_with_shape_dict = {}
383
+ for metadata in model_loader_configs:
384
+ self.add_model_metadata(*metadata)
385
+
386
+
387
+ def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
388
+ self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
389
+
390
+
391
+ def match(self, file_path="", state_dict={}):
392
+ if os.path.isdir(file_path):
393
+ return False
394
+ if len(state_dict) == 0:
395
+ state_dict = load_state_dict(file_path)
396
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
397
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
398
+ return True
399
+ return False
400
+
401
+
402
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
403
+ if len(state_dict) == 0:
404
+ state_dict = load_state_dict(file_path)
405
+
406
+ # Load models with strict matching
407
+ loaded_model_names, loaded_models = [], []
408
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
409
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
410
+ model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
411
+ loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
412
+ state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
413
+ loaded_model_names += loaded_model_names_
414
+ loaded_models += loaded_models_
415
+ return loaded_model_names, loaded_models
416
+
417
+
418
+
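+ # ModelManager below tries its detectors in order: ModelDetectorFromSingleFile recognizes a
+ # checkpoint by the hash of its state-dict keys, ModelDetectorFromSplitedSingleFile applies the
+ # same test to prefix-split sub-dicts, ModelDetectorFromHuggingfaceFolder matches a folder via
+ # the "architectures" field of its config.json, and ModelDetectorFromPatchedSingleFile loads a
+ # patch on top of an already-loaded base model.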
419
+ class ModelManager:
420
+ def __init__(
421
+ self,
422
+ torch_dtype=torch.float16,
423
+ device="cuda",
424
+ model_id_list: List[Preset_model_id] = [],
425
+ downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
426
+ file_path_list: List[str] = [],
427
+ ):
428
+ self.torch_dtype = torch_dtype
429
+ self.device = device
430
+ self.model = []
431
+ self.model_path = []
432
+ self.model_name = []
433
+ downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
434
+ self.model_detector = [
435
+ ModelDetectorFromSingleFile(model_loader_configs),
436
+ ModelDetectorFromSplitedSingleFile(model_loader_configs),
437
+ ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
438
+ ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
439
+ ]
440
+ self.load_models(downloaded_files + file_path_list)
441
+
442
+
443
+ def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
444
+ print(f"Loading models from file: {file_path}")
445
+ if len(state_dict) == 0:
446
+ state_dict = load_state_dict(file_path)
447
+ model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
448
+ for model_name, model in zip(model_names, models):
449
+ self.model.append(model)
450
+ self.model_path.append(file_path)
451
+ self.model_name.append(model_name)
452
+ print(f" The following models are loaded: {model_names}.")
453
+
454
+
455
+ def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
456
+ print(f"Loading models from folder: {file_path}")
457
+ model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
458
+ for model_name, model in zip(model_names, models):
459
+ self.model.append(model)
460
+ self.model_path.append(file_path)
461
+ self.model_name.append(model_name)
462
+ print(f" The following models are loaded: {model_names}.")
463
+
464
+
465
+ def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
466
+ print(f"Loading patch models from file: {file_path}")
467
+ model_names, models = load_patch_model_from_single_file(
468
+ state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
469
+ for model_name, model in zip(model_names, models):
470
+ self.model.append(model)
471
+ self.model_path.append(file_path)
472
+ self.model_name.append(model_name)
473
+ print(f" The following patched models are loaded: {model_names}.")
474
+
475
+
476
+ def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
477
+ print(f"Loading LoRA models from file: {file_path}")
478
+ if len(state_dict) == 0:
479
+ state_dict = load_state_dict(file_path)
480
+ for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
481
+ for lora in [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft()]:
482
+ match_results = lora.match(model, state_dict)
483
+ if match_results is not None:
484
+ print(f" Adding LoRA to {model_name} ({model_path}).")
485
+ lora_prefix, model_resource = match_results
486
+ lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
487
+ break
488
+
489
+
490
+ def load_model(self, file_path, model_names=None):
491
+ print(f"Loading models from: {file_path}")
492
+ if os.path.isfile(file_path):
493
+ state_dict = load_state_dict(file_path)
494
+ else:
495
+ state_dict = None
496
+ for model_detector in self.model_detector:
497
+ if model_detector.match(file_path, state_dict):
498
+ model_names, models = model_detector.load(
499
+ file_path, state_dict,
500
+ device=self.device, torch_dtype=self.torch_dtype,
501
+ allowed_model_names=model_names, model_manager=self
502
+ )
503
+ for model_name, model in zip(model_names, models):
504
+ self.model.append(model)
505
+ self.model_path.append(file_path)
506
+ self.model_name.append(model_name)
507
+ print(f" The following models are loaded: {model_names}.")
508
+ break
509
+ else:
510
+ print(f" We cannot detect the model type. No models are loaded.")
511
+
512
+
513
+ def load_models(self, file_path_list, model_names=None):
514
+ for file_path in file_path_list:
515
+ self.load_model(file_path, model_names)
516
+
517
+
518
+ def fetch_model(self, model_name, file_path=None, require_model_path=False):
519
+ fetched_models = []
520
+ fetched_model_paths = []
521
+ for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
522
+ if file_path is not None and file_path != model_path:
523
+ continue
524
+ if model_name == model_name_:
525
+ fetched_models.append(model)
526
+ fetched_model_paths.append(model_path)
527
+ if len(fetched_models) == 0:
528
+ print(f"No {model_name} models available.")
529
+ return None
530
+ if len(fetched_models) == 1:
531
+ print(f"Using {model_name} from {fetched_model_paths[0]}.")
532
+ else:
533
+ print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
534
+ if require_model_path:
535
+ return fetched_models[0], fetched_model_paths[0]
536
+ else:
537
+ return fetched_models[0]
538
+
539
+
540
+ def to(self, device):
541
+ for model in self.model:
542
+ model.to(device)
543
+
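A minimal usage sketch of the ModelManager API defined above, assuming the module path diffsynth/models/model_manager.py shown in this commit; the checkpoint path, LoRA path, and fetched model name below are hypothetical examples (the real names come from the entries in configs/model_config.py):

    import torch
    from diffsynth.models.model_manager import ModelManager

    # Load one or more checkpoint files; each file is routed through the
    # registered detectors and converted into DiffSynth model classes.
    manager = ModelManager(
        torch_dtype=torch.float16,
        device="cuda",
        file_path_list=["models/example_checkpoint.safetensors"],  # hypothetical path
    )

    # Optionally merge a LoRA into whichever loaded model it matches.
    manager.load_lora("models/example_lora.safetensors", lora_alpha=1.0)  # hypothetical path

    # Retrieve a loaded component by its registered name; returns None if nothing matched.
    dit = manager.fetch_model("sd3_dit")  # hypothetical model name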
diffsynth/models/sd3_dit.py ADDED
@@ -0,0 +1,798 @@
1
+ import torch
2
+ from einops import rearrange
3
+ from .svd_unet import TemporalTimesteps
4
+ from .tiler import TileWorker
5
+
6
+
7
+
8
+ class PatchEmbed(torch.nn.Module):
9
+ def __init__(self, patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192):
10
+ super().__init__()
11
+ self.pos_embed_max_size = pos_embed_max_size
12
+ self.patch_size = patch_size
13
+
14
+ self.proj = torch.nn.Conv2d(in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size)
15
+ self.pos_embed = torch.nn.Parameter(torch.zeros(1, self.pos_embed_max_size, self.pos_embed_max_size, embed_dim))
16
+
17
+ def cropped_pos_embed(self, height, width):
18
+ height = height // self.patch_size
19
+ width = width // self.patch_size
20
+ top = (self.pos_embed_max_size - height) // 2
21
+ left = (self.pos_embed_max_size - width) // 2
22
+ spatial_pos_embed = self.pos_embed[:, top : top + height, left : left + width, :].flatten(1, 2)
23
+ return spatial_pos_embed
24
+
25
+ def forward(self, latent):
26
+ height, width = latent.shape[-2:]
27
+ latent = self.proj(latent)
28
+ latent = latent.flatten(2).transpose(1, 2)
29
+ pos_embed = self.cropped_pos_embed(height, width)
30
+ return latent + pos_embed
31
+
32
+
33
+
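+ # PatchEmbed above patchifies the 16-channel latent into 2x2 patches with a strided Conv2d and
+ # adds a positional embedding center-cropped from a 192x192 table, so token positions are tied
+ # to a global spatial layout rather than to individual tiles.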
34
+ class TimestepEmbeddings(torch.nn.Module):
35
+ def __init__(self, dim_in, dim_out):
36
+ super().__init__()
37
+ self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
38
+ self.timestep_embedder = torch.nn.Sequential(
39
+ torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
40
+ )
41
+
42
+ def forward(self, timestep, dtype):
43
+ time_emb = self.time_proj(timestep).to(dtype)
44
+ time_emb = self.timestep_embedder(time_emb)
45
+ return time_emb
46
+
47
+
48
+
49
+ class AdaLayerNorm(torch.nn.Module):
50
+ def __init__(self, dim, single=False):
51
+ super().__init__()
52
+ self.single = single
53
+ self.linear = torch.nn.Linear(dim, dim * (2 if single else 6))
54
+ self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
55
+
56
+ def forward(self, x, emb):
57
+ emb = self.linear(torch.nn.functional.silu(emb))
58
+ if self.single:
59
+ scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
60
+ x = self.norm(x) * (1 + scale) + shift
61
+ return x
62
+ else:
63
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
64
+ x = self.norm(x) * (1 + scale_msa) + shift_msa
65
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
66
+
67
+
68
+
69
+ class JointAttention(torch.nn.Module):
70
+ def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
71
+ super().__init__()
72
+ self.num_heads = num_heads
73
+ self.head_dim = head_dim
74
+ self.only_out_a = only_out_a
75
+
76
+ self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
77
+ self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)
78
+
79
+ self.a_to_out = torch.nn.Linear(dim_a, dim_a)
80
+ if not only_out_a:
81
+ self.b_to_out = torch.nn.Linear(dim_b, dim_b)
82
+
83
+ def forward(self, hidden_states_a, hidden_states_b):
84
+ batch_size = hidden_states_a.shape[0]
85
+
86
+ qkv = torch.concat([self.a_to_qkv(hidden_states_a), self.b_to_qkv(hidden_states_b)], dim=1)
87
+ qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
88
+ q, k, v = qkv.chunk(3, dim=1)
89
+
90
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
91
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
92
+ hidden_states = hidden_states.to(q.dtype)
93
+ hidden_states_a, hidden_states_b = hidden_states[:, :hidden_states_a.shape[1]], hidden_states[:, hidden_states_a.shape[1]:]
94
+ hidden_states_a = self.a_to_out(hidden_states_a)
95
+ if self.only_out_a:
96
+ return hidden_states_a
97
+ else:
98
+ hidden_states_b = self.b_to_out(hidden_states_b)
99
+ return hidden_states_a, hidden_states_b
100
+
101
+
102
+
103
+ class JointTransformerBlock(torch.nn.Module):
104
+ def __init__(self, dim, num_attention_heads):
105
+ super().__init__()
106
+ self.norm1_a = AdaLayerNorm(dim)
107
+ self.norm1_b = AdaLayerNorm(dim)
108
+
109
+ self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
110
+
111
+ self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
112
+ self.ff_a = torch.nn.Sequential(
113
+ torch.nn.Linear(dim, dim*4),
114
+ torch.nn.GELU(approximate="tanh"),
115
+ torch.nn.Linear(dim*4, dim)
116
+ )
117
+
118
+ self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
119
+ self.ff_b = torch.nn.Sequential(
120
+ torch.nn.Linear(dim, dim*4),
121
+ torch.nn.GELU(approximate="tanh"),
122
+ torch.nn.Linear(dim*4, dim)
123
+ )
124
+
125
+
126
+ def forward(self, hidden_states_a, hidden_states_b, temb):
127
+ norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
128
+ norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
129
+
130
+ # Attention
131
+ attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)
132
+
133
+ # Part A
134
+ hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
135
+ norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
136
+ hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
137
+
138
+ # Part B
139
+ hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
140
+ norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
141
+ hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
142
+
143
+ return hidden_states_a, hidden_states_b
144
+
145
+
146
+
147
+ class JointTransformerFinalBlock(torch.nn.Module):
148
+ def __init__(self, dim, num_attention_heads):
149
+ super().__init__()
150
+ self.norm1_a = AdaLayerNorm(dim)
151
+ self.norm1_b = AdaLayerNorm(dim, single=True)
152
+
153
+ self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, only_out_a=True)
154
+
155
+ self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
156
+ self.ff_a = torch.nn.Sequential(
157
+ torch.nn.Linear(dim, dim*4),
158
+ torch.nn.GELU(approximate="tanh"),
159
+ torch.nn.Linear(dim*4, dim)
160
+ )
161
+
162
+
163
+ def forward(self, hidden_states_a, hidden_states_b, temb):
164
+ norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
165
+ norm_hidden_states_b = self.norm1_b(hidden_states_b, emb=temb)
166
+
167
+ # Attention
168
+ attn_output_a = self.attn(norm_hidden_states_a, norm_hidden_states_b)
169
+
170
+ # Part A
171
+ hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
172
+ norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
173
+ hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
174
+
175
+ return hidden_states_a, hidden_states_b
176
+
177
+
178
+
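+ # In the two block classes above, image tokens (stream "a") and text tokens (stream "b") are
+ # modulated by AdaLayerNorm, projected to q/k/v per stream, concatenated along the sequence
+ # axis, attended to jointly with scaled_dot_product_attention, then split back and gated into
+ # their residual streams. JointTransformerFinalBlock updates only the image stream
+ # (only_out_a=True) and returns the text stream unchanged.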
179
+ class SD3DiT(torch.nn.Module):
180
+ def __init__(self):
181
+ super().__init__()
182
+ self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192)
183
+ self.time_embedder = TimestepEmbeddings(256, 1536)
184
+ self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(2048, 1536), torch.nn.SiLU(), torch.nn.Linear(1536, 1536))
185
+ self.context_embedder = torch.nn.Linear(4096, 1536)
186
+ self.blocks = torch.nn.ModuleList([JointTransformerBlock(1536, 24) for _ in range(23)] + [JointTransformerFinalBlock(1536, 24)])
187
+ self.norm_out = AdaLayerNorm(1536, single=True)
188
+ self.proj_out = torch.nn.Linear(1536, 64)
189
+
190
+ def tiled_forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size=128, tile_stride=64):
191
+ # Due to the global positional embedding, we cannot implement layer-wise tiled forward.
192
+ hidden_states = TileWorker().tiled_forward(
193
+ lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb),
194
+ hidden_states,
195
+ tile_size,
196
+ tile_stride,
197
+ tile_device=hidden_states.device,
198
+ tile_dtype=hidden_states.dtype
199
+ )
200
+ return hidden_states
201
+
202
+ def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tiled=False, tile_size=128, tile_stride=64, use_gradient_checkpointing=False):
203
+ if tiled:
204
+ return self.tiled_forward(hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size, tile_stride)
205
+ conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
206
+ prompt_emb = self.context_embedder(prompt_emb)
207
+
208
+ height, width = hidden_states.shape[-2:]
209
+ hidden_states = self.pos_embedder(hidden_states)
210
+
211
+ def create_custom_forward(module):
212
+ def custom_forward(*inputs):
213
+ return module(*inputs)
214
+ return custom_forward
215
+
216
+ for block in self.blocks:
217
+ if self.training and use_gradient_checkpointing:
218
+ hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
219
+ create_custom_forward(block),
220
+ hidden_states, prompt_emb, conditioning,
221
+ use_reentrant=False,
222
+ )
223
+ else:
224
+ hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning)
225
+
226
+ hidden_states = self.norm_out(hidden_states, conditioning)
227
+ hidden_states = self.proj_out(hidden_states)
228
+ hidden_states = rearrange(hidden_states, "B (H W) (P Q C) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
229
+ return hidden_states
230
+
231
+ @staticmethod
232
+ def state_dict_converter():
233
+ return SD3DiTStateDictConverter()
234
+
235
+
236
+
237
+ class SD3DiTStateDictConverter:
238
+ def __init__(self):
239
+ pass
240
+
241
+ def from_diffusers(self, state_dict):
242
+ rename_dict = {
243
+ "context_embedder": "context_embedder",
244
+ "pos_embed.pos_embed": "pos_embedder.pos_embed",
245
+ "pos_embed.proj": "pos_embedder.proj",
246
+ "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
247
+ "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
248
+ "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
249
+ "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
250
+ "norm_out.linear": "norm_out.linear",
251
+ "proj_out": "proj_out",
252
+
253
+ "norm1.linear": "norm1_a.linear",
254
+ "norm1_context.linear": "norm1_b.linear",
255
+ "attn.to_q": "attn.a_to_q",
256
+ "attn.to_k": "attn.a_to_k",
257
+ "attn.to_v": "attn.a_to_v",
258
+ "attn.to_out.0": "attn.a_to_out",
259
+ "attn.add_q_proj": "attn.b_to_q",
260
+ "attn.add_k_proj": "attn.b_to_k",
261
+ "attn.add_v_proj": "attn.b_to_v",
262
+ "attn.to_add_out": "attn.b_to_out",
263
+ "ff.net.0.proj": "ff_a.0",
264
+ "ff.net.2": "ff_a.2",
265
+ "ff_context.net.0.proj": "ff_b.0",
266
+ "ff_context.net.2": "ff_b.2",
267
+ }
268
+ state_dict_ = {}
269
+ for name, param in state_dict.items():
270
+ if name in rename_dict:
271
+ if name == "pos_embed.pos_embed":
272
+ param = param.reshape((1, 192, 192, 1536))
273
+ state_dict_[rename_dict[name]] = param
274
+ elif name.endswith(".weight") or name.endswith(".bias"):
275
+ suffix = ".weight" if name.endswith(".weight") else ".bias"
276
+ prefix = name[:-len(suffix)]
277
+ if prefix in rename_dict:
278
+ state_dict_[rename_dict[prefix] + suffix] = param
279
+ elif prefix.startswith("transformer_blocks."):
280
+ names = prefix.split(".")
281
+ names[0] = "blocks"
282
+ middle = ".".join(names[2:])
283
+ if middle in rename_dict:
284
+ name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
285
+ state_dict_[name_] = param
286
+ return state_dict_
287
+
288
+ def from_civitai(self, state_dict):
289
+ rename_dict = {
290
+ "model.diffusion_model.context_embedder.bias": "context_embedder.bias",
291
+ "model.diffusion_model.context_embedder.weight": "context_embedder.weight",
292
+ "model.diffusion_model.final_layer.linear.bias": "proj_out.bias",
293
+ "model.diffusion_model.final_layer.linear.weight": "proj_out.weight",
294
+ "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias": "blocks.0.norm1_b.linear.bias",
295
+ "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.weight": "blocks.0.norm1_b.linear.weight",
296
+ "model.diffusion_model.joint_blocks.0.context_block.attn.proj.bias": "blocks.0.attn.b_to_out.bias",
297
+ "model.diffusion_model.joint_blocks.0.context_block.attn.proj.weight": "blocks.0.attn.b_to_out.weight",
298
+ "model.diffusion_model.joint_blocks.0.context_block.attn.qkv.bias": ['blocks.0.attn.b_to_q.bias', 'blocks.0.attn.b_to_k.bias', 'blocks.0.attn.b_to_v.bias'],
299
+ "model.diffusion_model.joint_blocks.0.context_block.attn.qkv.weight": ['blocks.0.attn.b_to_q.weight', 'blocks.0.attn.b_to_k.weight', 'blocks.0.attn.b_to_v.weight'],
300
+ "model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.bias": "blocks.0.ff_b.0.bias",
301
+ "model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.weight": "blocks.0.ff_b.0.weight",
302
+ "model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.bias": "blocks.0.ff_b.2.bias",
303
+ "model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.weight": "blocks.0.ff_b.2.weight",
304
+ "model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.bias": "blocks.0.norm1_a.linear.bias",
305
+ "model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.weight": "blocks.0.norm1_a.linear.weight",
306
+ "model.diffusion_model.joint_blocks.0.x_block.attn.proj.bias": "blocks.0.attn.a_to_out.bias",
307
+ "model.diffusion_model.joint_blocks.0.x_block.attn.proj.weight": "blocks.0.attn.a_to_out.weight",
308
+ "model.diffusion_model.joint_blocks.0.x_block.attn.qkv.bias": ['blocks.0.attn.a_to_q.bias', 'blocks.0.attn.a_to_k.bias', 'blocks.0.attn.a_to_v.bias'],
309
+ "model.diffusion_model.joint_blocks.0.x_block.attn.qkv.weight": ['blocks.0.attn.a_to_q.weight', 'blocks.0.attn.a_to_k.weight', 'blocks.0.attn.a_to_v.weight'],
310
+ "model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.bias": "blocks.0.ff_a.0.bias",
311
+ "model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.weight": "blocks.0.ff_a.0.weight",
312
+ "model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.bias": "blocks.0.ff_a.2.bias",
313
+ "model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.weight": "blocks.0.ff_a.2.weight",
314
+ "model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.bias": "blocks.1.norm1_b.linear.bias",
315
+ "model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.weight": "blocks.1.norm1_b.linear.weight",
316
+ "model.diffusion_model.joint_blocks.1.context_block.attn.proj.bias": "blocks.1.attn.b_to_out.bias",
317
+ "model.diffusion_model.joint_blocks.1.context_block.attn.proj.weight": "blocks.1.attn.b_to_out.weight",
318
+ "model.diffusion_model.joint_blocks.1.context_block.attn.qkv.bias": ['blocks.1.attn.b_to_q.bias', 'blocks.1.attn.b_to_k.bias', 'blocks.1.attn.b_to_v.bias'],
319
+ "model.diffusion_model.joint_blocks.1.context_block.attn.qkv.weight": ['blocks.1.attn.b_to_q.weight', 'blocks.1.attn.b_to_k.weight', 'blocks.1.attn.b_to_v.weight'],
320
+ "model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.bias": "blocks.1.ff_b.0.bias",
321
+ "model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.weight": "blocks.1.ff_b.0.weight",
322
+ "model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.bias": "blocks.1.ff_b.2.bias",
323
+ "model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.weight": "blocks.1.ff_b.2.weight",
324
+ "model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.bias": "blocks.1.norm1_a.linear.bias",
325
+ "model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.weight": "blocks.1.norm1_a.linear.weight",
326
+ "model.diffusion_model.joint_blocks.1.x_block.attn.proj.bias": "blocks.1.attn.a_to_out.bias",
327
+ "model.diffusion_model.joint_blocks.1.x_block.attn.proj.weight": "blocks.1.attn.a_to_out.weight",
328
+ "model.diffusion_model.joint_blocks.1.x_block.attn.qkv.bias": ['blocks.1.attn.a_to_q.bias', 'blocks.1.attn.a_to_k.bias', 'blocks.1.attn.a_to_v.bias'],
329
+ "model.diffusion_model.joint_blocks.1.x_block.attn.qkv.weight": ['blocks.1.attn.a_to_q.weight', 'blocks.1.attn.a_to_k.weight', 'blocks.1.attn.a_to_v.weight'],
330
+ "model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.bias": "blocks.1.ff_a.0.bias",
331
+ "model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.weight": "blocks.1.ff_a.0.weight",
332
+ "model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.bias": "blocks.1.ff_a.2.bias",
333
+ "model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.weight": "blocks.1.ff_a.2.weight",
334
+ "model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.bias": "blocks.10.norm1_b.linear.bias",
335
+ "model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.weight": "blocks.10.norm1_b.linear.weight",
336
+ "model.diffusion_model.joint_blocks.10.context_block.attn.proj.bias": "blocks.10.attn.b_to_out.bias",
337
+ "model.diffusion_model.joint_blocks.10.context_block.attn.proj.weight": "blocks.10.attn.b_to_out.weight",
338
+ "model.diffusion_model.joint_blocks.10.context_block.attn.qkv.bias": ['blocks.10.attn.b_to_q.bias', 'blocks.10.attn.b_to_k.bias', 'blocks.10.attn.b_to_v.bias'],
339
+ "model.diffusion_model.joint_blocks.10.context_block.attn.qkv.weight": ['blocks.10.attn.b_to_q.weight', 'blocks.10.attn.b_to_k.weight', 'blocks.10.attn.b_to_v.weight'],
340
+ "model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.bias": "blocks.10.ff_b.0.bias",
341
+ "model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.weight": "blocks.10.ff_b.0.weight",
342
+ "model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.bias": "blocks.10.ff_b.2.bias",
343
+ "model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.weight": "blocks.10.ff_b.2.weight",
344
+ "model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.bias": "blocks.10.norm1_a.linear.bias",
345
+ "model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.weight": "blocks.10.norm1_a.linear.weight",
346
+ "model.diffusion_model.joint_blocks.10.x_block.attn.proj.bias": "blocks.10.attn.a_to_out.bias",
347
+ "model.diffusion_model.joint_blocks.10.x_block.attn.proj.weight": "blocks.10.attn.a_to_out.weight",
348
+ "model.diffusion_model.joint_blocks.10.x_block.attn.qkv.bias": ['blocks.10.attn.a_to_q.bias', 'blocks.10.attn.a_to_k.bias', 'blocks.10.attn.a_to_v.bias'],
349
+ "model.diffusion_model.joint_blocks.10.x_block.attn.qkv.weight": ['blocks.10.attn.a_to_q.weight', 'blocks.10.attn.a_to_k.weight', 'blocks.10.attn.a_to_v.weight'],
350
+ "model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.bias": "blocks.10.ff_a.0.bias",
351
+ "model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.weight": "blocks.10.ff_a.0.weight",
352
+ "model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.bias": "blocks.10.ff_a.2.bias",
353
+ "model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.weight": "blocks.10.ff_a.2.weight",
354
+ "model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.bias": "blocks.11.norm1_b.linear.bias",
355
+ "model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.weight": "blocks.11.norm1_b.linear.weight",
356
+ "model.diffusion_model.joint_blocks.11.context_block.attn.proj.bias": "blocks.11.attn.b_to_out.bias",
357
+ "model.diffusion_model.joint_blocks.11.context_block.attn.proj.weight": "blocks.11.attn.b_to_out.weight",
358
+ "model.diffusion_model.joint_blocks.11.context_block.attn.qkv.bias": ['blocks.11.attn.b_to_q.bias', 'blocks.11.attn.b_to_k.bias', 'blocks.11.attn.b_to_v.bias'],
359
+ "model.diffusion_model.joint_blocks.11.context_block.attn.qkv.weight": ['blocks.11.attn.b_to_q.weight', 'blocks.11.attn.b_to_k.weight', 'blocks.11.attn.b_to_v.weight'],
360
+ "model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.bias": "blocks.11.ff_b.0.bias",
361
+ "model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.weight": "blocks.11.ff_b.0.weight",
362
+ "model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.bias": "blocks.11.ff_b.2.bias",
363
+ "model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.weight": "blocks.11.ff_b.2.weight",
364
+ "model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.bias": "blocks.11.norm1_a.linear.bias",
365
+ "model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.weight": "blocks.11.norm1_a.linear.weight",
366
+ "model.diffusion_model.joint_blocks.11.x_block.attn.proj.bias": "blocks.11.attn.a_to_out.bias",
367
+ "model.diffusion_model.joint_blocks.11.x_block.attn.proj.weight": "blocks.11.attn.a_to_out.weight",
368
+ "model.diffusion_model.joint_blocks.11.x_block.attn.qkv.bias": ['blocks.11.attn.a_to_q.bias', 'blocks.11.attn.a_to_k.bias', 'blocks.11.attn.a_to_v.bias'],
369
+ "model.diffusion_model.joint_blocks.11.x_block.attn.qkv.weight": ['blocks.11.attn.a_to_q.weight', 'blocks.11.attn.a_to_k.weight', 'blocks.11.attn.a_to_v.weight'],
370
+ "model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.bias": "blocks.11.ff_a.0.bias",
371
+ "model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.weight": "blocks.11.ff_a.0.weight",
372
+ "model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.bias": "blocks.11.ff_a.2.bias",
373
+ "model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.weight": "blocks.11.ff_a.2.weight",
374
+ "model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.bias": "blocks.12.norm1_b.linear.bias",
375
+ "model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.weight": "blocks.12.norm1_b.linear.weight",
376
+ "model.diffusion_model.joint_blocks.12.context_block.attn.proj.bias": "blocks.12.attn.b_to_out.bias",
377
+ "model.diffusion_model.joint_blocks.12.context_block.attn.proj.weight": "blocks.12.attn.b_to_out.weight",
378
+ "model.diffusion_model.joint_blocks.12.context_block.attn.qkv.bias": ['blocks.12.attn.b_to_q.bias', 'blocks.12.attn.b_to_k.bias', 'blocks.12.attn.b_to_v.bias'],
379
+ "model.diffusion_model.joint_blocks.12.context_block.attn.qkv.weight": ['blocks.12.attn.b_to_q.weight', 'blocks.12.attn.b_to_k.weight', 'blocks.12.attn.b_to_v.weight'],
380
+ "model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.bias": "blocks.12.ff_b.0.bias",
381
+ "model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.weight": "blocks.12.ff_b.0.weight",
382
+ "model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.bias": "blocks.12.ff_b.2.bias",
383
+ "model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.weight": "blocks.12.ff_b.2.weight",
384
+ "model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.bias": "blocks.12.norm1_a.linear.bias",
385
+ "model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.weight": "blocks.12.norm1_a.linear.weight",
386
+ "model.diffusion_model.joint_blocks.12.x_block.attn.proj.bias": "blocks.12.attn.a_to_out.bias",
387
+ "model.diffusion_model.joint_blocks.12.x_block.attn.proj.weight": "blocks.12.attn.a_to_out.weight",
388
+ "model.diffusion_model.joint_blocks.12.x_block.attn.qkv.bias": ['blocks.12.attn.a_to_q.bias', 'blocks.12.attn.a_to_k.bias', 'blocks.12.attn.a_to_v.bias'],
389
+ "model.diffusion_model.joint_blocks.12.x_block.attn.qkv.weight": ['blocks.12.attn.a_to_q.weight', 'blocks.12.attn.a_to_k.weight', 'blocks.12.attn.a_to_v.weight'],
390
+ "model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.bias": "blocks.12.ff_a.0.bias",
391
+ "model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.weight": "blocks.12.ff_a.0.weight",
392
+ "model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.bias": "blocks.12.ff_a.2.bias",
393
+ "model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.weight": "blocks.12.ff_a.2.weight",
394
+ "model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.bias": "blocks.13.norm1_b.linear.bias",
395
+ "model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.weight": "blocks.13.norm1_b.linear.weight",
396
+ "model.diffusion_model.joint_blocks.13.context_block.attn.proj.bias": "blocks.13.attn.b_to_out.bias",
397
+ "model.diffusion_model.joint_blocks.13.context_block.attn.proj.weight": "blocks.13.attn.b_to_out.weight",
398
+ "model.diffusion_model.joint_blocks.13.context_block.attn.qkv.bias": ['blocks.13.attn.b_to_q.bias', 'blocks.13.attn.b_to_k.bias', 'blocks.13.attn.b_to_v.bias'],
399
+ "model.diffusion_model.joint_blocks.13.context_block.attn.qkv.weight": ['blocks.13.attn.b_to_q.weight', 'blocks.13.attn.b_to_k.weight', 'blocks.13.attn.b_to_v.weight'],
400
+ "model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.bias": "blocks.13.ff_b.0.bias",
401
+ "model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.weight": "blocks.13.ff_b.0.weight",
402
+ "model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.bias": "blocks.13.ff_b.2.bias",
403
+ "model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.weight": "blocks.13.ff_b.2.weight",
404
+ "model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.bias": "blocks.13.norm1_a.linear.bias",
405
+ "model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.weight": "blocks.13.norm1_a.linear.weight",
406
+ "model.diffusion_model.joint_blocks.13.x_block.attn.proj.bias": "blocks.13.attn.a_to_out.bias",
407
+ "model.diffusion_model.joint_blocks.13.x_block.attn.proj.weight": "blocks.13.attn.a_to_out.weight",
408
+ "model.diffusion_model.joint_blocks.13.x_block.attn.qkv.bias": ['blocks.13.attn.a_to_q.bias', 'blocks.13.attn.a_to_k.bias', 'blocks.13.attn.a_to_v.bias'],
409
+ "model.diffusion_model.joint_blocks.13.x_block.attn.qkv.weight": ['blocks.13.attn.a_to_q.weight', 'blocks.13.attn.a_to_k.weight', 'blocks.13.attn.a_to_v.weight'],
410
+ "model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.bias": "blocks.13.ff_a.0.bias",
411
+ "model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.weight": "blocks.13.ff_a.0.weight",
412
+ "model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.bias": "blocks.13.ff_a.2.bias",
413
+ "model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.weight": "blocks.13.ff_a.2.weight",
414
+ "model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.bias": "blocks.14.norm1_b.linear.bias",
415
+ "model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.weight": "blocks.14.norm1_b.linear.weight",
416
+ "model.diffusion_model.joint_blocks.14.context_block.attn.proj.bias": "blocks.14.attn.b_to_out.bias",
417
+ "model.diffusion_model.joint_blocks.14.context_block.attn.proj.weight": "blocks.14.attn.b_to_out.weight",
418
+ "model.diffusion_model.joint_blocks.14.context_block.attn.qkv.bias": ['blocks.14.attn.b_to_q.bias', 'blocks.14.attn.b_to_k.bias', 'blocks.14.attn.b_to_v.bias'],
419
+ "model.diffusion_model.joint_blocks.14.context_block.attn.qkv.weight": ['blocks.14.attn.b_to_q.weight', 'blocks.14.attn.b_to_k.weight', 'blocks.14.attn.b_to_v.weight'],
420
+ "model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.bias": "blocks.14.ff_b.0.bias",
421
+ "model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.weight": "blocks.14.ff_b.0.weight",
422
+ "model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.bias": "blocks.14.ff_b.2.bias",
423
+ "model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.weight": "blocks.14.ff_b.2.weight",
424
+ "model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.bias": "blocks.14.norm1_a.linear.bias",
425
+ "model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.weight": "blocks.14.norm1_a.linear.weight",
426
+ "model.diffusion_model.joint_blocks.14.x_block.attn.proj.bias": "blocks.14.attn.a_to_out.bias",
427
+ "model.diffusion_model.joint_blocks.14.x_block.attn.proj.weight": "blocks.14.attn.a_to_out.weight",
428
+ "model.diffusion_model.joint_blocks.14.x_block.attn.qkv.bias": ['blocks.14.attn.a_to_q.bias', 'blocks.14.attn.a_to_k.bias', 'blocks.14.attn.a_to_v.bias'],
429
+ "model.diffusion_model.joint_blocks.14.x_block.attn.qkv.weight": ['blocks.14.attn.a_to_q.weight', 'blocks.14.attn.a_to_k.weight', 'blocks.14.attn.a_to_v.weight'],
430
+ "model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.bias": "blocks.14.ff_a.0.bias",
431
+ "model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.weight": "blocks.14.ff_a.0.weight",
432
+ "model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.bias": "blocks.14.ff_a.2.bias",
433
+ "model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.weight": "blocks.14.ff_a.2.weight",
434
+ "model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.bias": "blocks.15.norm1_b.linear.bias",
435
+ "model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.weight": "blocks.15.norm1_b.linear.weight",
436
+ "model.diffusion_model.joint_blocks.15.context_block.attn.proj.bias": "blocks.15.attn.b_to_out.bias",
437
+ "model.diffusion_model.joint_blocks.15.context_block.attn.proj.weight": "blocks.15.attn.b_to_out.weight",
438
+ "model.diffusion_model.joint_blocks.15.context_block.attn.qkv.bias": ['blocks.15.attn.b_to_q.bias', 'blocks.15.attn.b_to_k.bias', 'blocks.15.attn.b_to_v.bias'],
439
+ "model.diffusion_model.joint_blocks.15.context_block.attn.qkv.weight": ['blocks.15.attn.b_to_q.weight', 'blocks.15.attn.b_to_k.weight', 'blocks.15.attn.b_to_v.weight'],
440
+ "model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.bias": "blocks.15.ff_b.0.bias",
441
+ "model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.weight": "blocks.15.ff_b.0.weight",
442
+ "model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.bias": "blocks.15.ff_b.2.bias",
443
+ "model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.weight": "blocks.15.ff_b.2.weight",
444
+ "model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.bias": "blocks.15.norm1_a.linear.bias",
445
+ "model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.weight": "blocks.15.norm1_a.linear.weight",
446
+ "model.diffusion_model.joint_blocks.15.x_block.attn.proj.bias": "blocks.15.attn.a_to_out.bias",
447
+ "model.diffusion_model.joint_blocks.15.x_block.attn.proj.weight": "blocks.15.attn.a_to_out.weight",
448
+ "model.diffusion_model.joint_blocks.15.x_block.attn.qkv.bias": ['blocks.15.attn.a_to_q.bias', 'blocks.15.attn.a_to_k.bias', 'blocks.15.attn.a_to_v.bias'],
449
+ "model.diffusion_model.joint_blocks.15.x_block.attn.qkv.weight": ['blocks.15.attn.a_to_q.weight', 'blocks.15.attn.a_to_k.weight', 'blocks.15.attn.a_to_v.weight'],
450
+ "model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.bias": "blocks.15.ff_a.0.bias",
451
+ "model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.weight": "blocks.15.ff_a.0.weight",
452
+ "model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.bias": "blocks.15.ff_a.2.bias",
453
+ "model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.weight": "blocks.15.ff_a.2.weight",
454
+ "model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.bias": "blocks.16.norm1_b.linear.bias",
455
+ "model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.weight": "blocks.16.norm1_b.linear.weight",
456
+ "model.diffusion_model.joint_blocks.16.context_block.attn.proj.bias": "blocks.16.attn.b_to_out.bias",
457
+ "model.diffusion_model.joint_blocks.16.context_block.attn.proj.weight": "blocks.16.attn.b_to_out.weight",
458
+ "model.diffusion_model.joint_blocks.16.context_block.attn.qkv.bias": ['blocks.16.attn.b_to_q.bias', 'blocks.16.attn.b_to_k.bias', 'blocks.16.attn.b_to_v.bias'],
459
+ "model.diffusion_model.joint_blocks.16.context_block.attn.qkv.weight": ['blocks.16.attn.b_to_q.weight', 'blocks.16.attn.b_to_k.weight', 'blocks.16.attn.b_to_v.weight'],
460
+ "model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.bias": "blocks.16.ff_b.0.bias",
461
+ "model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.weight": "blocks.16.ff_b.0.weight",
462
+ "model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.bias": "blocks.16.ff_b.2.bias",
463
+ "model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.weight": "blocks.16.ff_b.2.weight",
464
+ "model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.bias": "blocks.16.norm1_a.linear.bias",
465
+ "model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.weight": "blocks.16.norm1_a.linear.weight",
466
+ "model.diffusion_model.joint_blocks.16.x_block.attn.proj.bias": "blocks.16.attn.a_to_out.bias",
467
+ "model.diffusion_model.joint_blocks.16.x_block.attn.proj.weight": "blocks.16.attn.a_to_out.weight",
468
+ "model.diffusion_model.joint_blocks.16.x_block.attn.qkv.bias": ['blocks.16.attn.a_to_q.bias', 'blocks.16.attn.a_to_k.bias', 'blocks.16.attn.a_to_v.bias'],
469
+ "model.diffusion_model.joint_blocks.16.x_block.attn.qkv.weight": ['blocks.16.attn.a_to_q.weight', 'blocks.16.attn.a_to_k.weight', 'blocks.16.attn.a_to_v.weight'],
470
+ "model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.bias": "blocks.16.ff_a.0.bias",
471
+ "model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.weight": "blocks.16.ff_a.0.weight",
472
+ "model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.bias": "blocks.16.ff_a.2.bias",
473
+ "model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.weight": "blocks.16.ff_a.2.weight",
474
+ "model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.bias": "blocks.17.norm1_b.linear.bias",
475
+ "model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.weight": "blocks.17.norm1_b.linear.weight",
476
+ "model.diffusion_model.joint_blocks.17.context_block.attn.proj.bias": "blocks.17.attn.b_to_out.bias",
477
+ "model.diffusion_model.joint_blocks.17.context_block.attn.proj.weight": "blocks.17.attn.b_to_out.weight",
478
+ "model.diffusion_model.joint_blocks.17.context_block.attn.qkv.bias": ['blocks.17.attn.b_to_q.bias', 'blocks.17.attn.b_to_k.bias', 'blocks.17.attn.b_to_v.bias'],
479
+ "model.diffusion_model.joint_blocks.17.context_block.attn.qkv.weight": ['blocks.17.attn.b_to_q.weight', 'blocks.17.attn.b_to_k.weight', 'blocks.17.attn.b_to_v.weight'],
480
+ "model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.bias": "blocks.17.ff_b.0.bias",
481
+ "model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.weight": "blocks.17.ff_b.0.weight",
482
+ "model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.bias": "blocks.17.ff_b.2.bias",
483
+ "model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.weight": "blocks.17.ff_b.2.weight",
484
+ "model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.bias": "blocks.17.norm1_a.linear.bias",
485
+ "model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.weight": "blocks.17.norm1_a.linear.weight",
486
+ "model.diffusion_model.joint_blocks.17.x_block.attn.proj.bias": "blocks.17.attn.a_to_out.bias",
487
+ "model.diffusion_model.joint_blocks.17.x_block.attn.proj.weight": "blocks.17.attn.a_to_out.weight",
488
+ "model.diffusion_model.joint_blocks.17.x_block.attn.qkv.bias": ['blocks.17.attn.a_to_q.bias', 'blocks.17.attn.a_to_k.bias', 'blocks.17.attn.a_to_v.bias'],
489
+ "model.diffusion_model.joint_blocks.17.x_block.attn.qkv.weight": ['blocks.17.attn.a_to_q.weight', 'blocks.17.attn.a_to_k.weight', 'blocks.17.attn.a_to_v.weight'],
490
+ "model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.bias": "blocks.17.ff_a.0.bias",
491
+ "model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.weight": "blocks.17.ff_a.0.weight",
492
+ "model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.bias": "blocks.17.ff_a.2.bias",
493
+ "model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.weight": "blocks.17.ff_a.2.weight",
494
+ "model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.bias": "blocks.18.norm1_b.linear.bias",
495
+ "model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.weight": "blocks.18.norm1_b.linear.weight",
496
+ "model.diffusion_model.joint_blocks.18.context_block.attn.proj.bias": "blocks.18.attn.b_to_out.bias",
497
+ "model.diffusion_model.joint_blocks.18.context_block.attn.proj.weight": "blocks.18.attn.b_to_out.weight",
498
+ "model.diffusion_model.joint_blocks.18.context_block.attn.qkv.bias": ['blocks.18.attn.b_to_q.bias', 'blocks.18.attn.b_to_k.bias', 'blocks.18.attn.b_to_v.bias'],
499
+ "model.diffusion_model.joint_blocks.18.context_block.attn.qkv.weight": ['blocks.18.attn.b_to_q.weight', 'blocks.18.attn.b_to_k.weight', 'blocks.18.attn.b_to_v.weight'],
500
+ "model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.bias": "blocks.18.ff_b.0.bias",
501
+ "model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.weight": "blocks.18.ff_b.0.weight",
502
+ "model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.bias": "blocks.18.ff_b.2.bias",
503
+ "model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.weight": "blocks.18.ff_b.2.weight",
504
+ "model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.bias": "blocks.18.norm1_a.linear.bias",
505
+ "model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.weight": "blocks.18.norm1_a.linear.weight",
506
+ "model.diffusion_model.joint_blocks.18.x_block.attn.proj.bias": "blocks.18.attn.a_to_out.bias",
507
+ "model.diffusion_model.joint_blocks.18.x_block.attn.proj.weight": "blocks.18.attn.a_to_out.weight",
508
+ "model.diffusion_model.joint_blocks.18.x_block.attn.qkv.bias": ['blocks.18.attn.a_to_q.bias', 'blocks.18.attn.a_to_k.bias', 'blocks.18.attn.a_to_v.bias'],
509
+ "model.diffusion_model.joint_blocks.18.x_block.attn.qkv.weight": ['blocks.18.attn.a_to_q.weight', 'blocks.18.attn.a_to_k.weight', 'blocks.18.attn.a_to_v.weight'],
510
+ "model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.bias": "blocks.18.ff_a.0.bias",
511
+ "model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.weight": "blocks.18.ff_a.0.weight",
512
+ "model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.bias": "blocks.18.ff_a.2.bias",
513
+ "model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.weight": "blocks.18.ff_a.2.weight",
514
+ "model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.bias": "blocks.19.norm1_b.linear.bias",
515
+ "model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.weight": "blocks.19.norm1_b.linear.weight",
516
+ "model.diffusion_model.joint_blocks.19.context_block.attn.proj.bias": "blocks.19.attn.b_to_out.bias",
517
+ "model.diffusion_model.joint_blocks.19.context_block.attn.proj.weight": "blocks.19.attn.b_to_out.weight",
518
+ "model.diffusion_model.joint_blocks.19.context_block.attn.qkv.bias": ['blocks.19.attn.b_to_q.bias', 'blocks.19.attn.b_to_k.bias', 'blocks.19.attn.b_to_v.bias'],
519
+ "model.diffusion_model.joint_blocks.19.context_block.attn.qkv.weight": ['blocks.19.attn.b_to_q.weight', 'blocks.19.attn.b_to_k.weight', 'blocks.19.attn.b_to_v.weight'],
520
+ "model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.bias": "blocks.19.ff_b.0.bias",
521
+ "model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.weight": "blocks.19.ff_b.0.weight",
522
+ "model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.bias": "blocks.19.ff_b.2.bias",
523
+ "model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.weight": "blocks.19.ff_b.2.weight",
524
+ "model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.bias": "blocks.19.norm1_a.linear.bias",
525
+ "model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.weight": "blocks.19.norm1_a.linear.weight",
526
+ "model.diffusion_model.joint_blocks.19.x_block.attn.proj.bias": "blocks.19.attn.a_to_out.bias",
527
+ "model.diffusion_model.joint_blocks.19.x_block.attn.proj.weight": "blocks.19.attn.a_to_out.weight",
528
+ "model.diffusion_model.joint_blocks.19.x_block.attn.qkv.bias": ['blocks.19.attn.a_to_q.bias', 'blocks.19.attn.a_to_k.bias', 'blocks.19.attn.a_to_v.bias'],
529
+ "model.diffusion_model.joint_blocks.19.x_block.attn.qkv.weight": ['blocks.19.attn.a_to_q.weight', 'blocks.19.attn.a_to_k.weight', 'blocks.19.attn.a_to_v.weight'],
530
+ "model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.bias": "blocks.19.ff_a.0.bias",
531
+ "model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.weight": "blocks.19.ff_a.0.weight",
532
+ "model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.bias": "blocks.19.ff_a.2.bias",
533
+ "model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.weight": "blocks.19.ff_a.2.weight",
534
+ "model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.bias": "blocks.2.norm1_b.linear.bias",
535
+ "model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.weight": "blocks.2.norm1_b.linear.weight",
536
+ "model.diffusion_model.joint_blocks.2.context_block.attn.proj.bias": "blocks.2.attn.b_to_out.bias",
537
+ "model.diffusion_model.joint_blocks.2.context_block.attn.proj.weight": "blocks.2.attn.b_to_out.weight",
538
+ "model.diffusion_model.joint_blocks.2.context_block.attn.qkv.bias": ['blocks.2.attn.b_to_q.bias', 'blocks.2.attn.b_to_k.bias', 'blocks.2.attn.b_to_v.bias'],
539
+ "model.diffusion_model.joint_blocks.2.context_block.attn.qkv.weight": ['blocks.2.attn.b_to_q.weight', 'blocks.2.attn.b_to_k.weight', 'blocks.2.attn.b_to_v.weight'],
540
+ "model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.bias": "blocks.2.ff_b.0.bias",
541
+ "model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.weight": "blocks.2.ff_b.0.weight",
542
+ "model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.bias": "blocks.2.ff_b.2.bias",
543
+ "model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.weight": "blocks.2.ff_b.2.weight",
544
+ "model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.bias": "blocks.2.norm1_a.linear.bias",
545
+ "model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.weight": "blocks.2.norm1_a.linear.weight",
546
+ "model.diffusion_model.joint_blocks.2.x_block.attn.proj.bias": "blocks.2.attn.a_to_out.bias",
547
+ "model.diffusion_model.joint_blocks.2.x_block.attn.proj.weight": "blocks.2.attn.a_to_out.weight",
548
+ "model.diffusion_model.joint_blocks.2.x_block.attn.qkv.bias": ['blocks.2.attn.a_to_q.bias', 'blocks.2.attn.a_to_k.bias', 'blocks.2.attn.a_to_v.bias'],
549
+ "model.diffusion_model.joint_blocks.2.x_block.attn.qkv.weight": ['blocks.2.attn.a_to_q.weight', 'blocks.2.attn.a_to_k.weight', 'blocks.2.attn.a_to_v.weight'],
550
+ "model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.bias": "blocks.2.ff_a.0.bias",
551
+ "model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.weight": "blocks.2.ff_a.0.weight",
552
+ "model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.bias": "blocks.2.ff_a.2.bias",
553
+ "model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.weight": "blocks.2.ff_a.2.weight",
554
+ "model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.bias": "blocks.20.norm1_b.linear.bias",
555
+ "model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.weight": "blocks.20.norm1_b.linear.weight",
556
+ "model.diffusion_model.joint_blocks.20.context_block.attn.proj.bias": "blocks.20.attn.b_to_out.bias",
557
+ "model.diffusion_model.joint_blocks.20.context_block.attn.proj.weight": "blocks.20.attn.b_to_out.weight",
558
+ "model.diffusion_model.joint_blocks.20.context_block.attn.qkv.bias": ['blocks.20.attn.b_to_q.bias', 'blocks.20.attn.b_to_k.bias', 'blocks.20.attn.b_to_v.bias'],
559
+ "model.diffusion_model.joint_blocks.20.context_block.attn.qkv.weight": ['blocks.20.attn.b_to_q.weight', 'blocks.20.attn.b_to_k.weight', 'blocks.20.attn.b_to_v.weight'],
560
+ "model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.bias": "blocks.20.ff_b.0.bias",
561
+ "model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.weight": "blocks.20.ff_b.0.weight",
562
+ "model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.bias": "blocks.20.ff_b.2.bias",
563
+ "model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.weight": "blocks.20.ff_b.2.weight",
564
+ "model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.bias": "blocks.20.norm1_a.linear.bias",
565
+ "model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.weight": "blocks.20.norm1_a.linear.weight",
566
+ "model.diffusion_model.joint_blocks.20.x_block.attn.proj.bias": "blocks.20.attn.a_to_out.bias",
567
+ "model.diffusion_model.joint_blocks.20.x_block.attn.proj.weight": "blocks.20.attn.a_to_out.weight",
568
+ "model.diffusion_model.joint_blocks.20.x_block.attn.qkv.bias": ['blocks.20.attn.a_to_q.bias', 'blocks.20.attn.a_to_k.bias', 'blocks.20.attn.a_to_v.bias'],
569
+ "model.diffusion_model.joint_blocks.20.x_block.attn.qkv.weight": ['blocks.20.attn.a_to_q.weight', 'blocks.20.attn.a_to_k.weight', 'blocks.20.attn.a_to_v.weight'],
570
+ "model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.bias": "blocks.20.ff_a.0.bias",
571
+ "model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.weight": "blocks.20.ff_a.0.weight",
572
+ "model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.bias": "blocks.20.ff_a.2.bias",
573
+ "model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.weight": "blocks.20.ff_a.2.weight",
574
+ "model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.bias": "blocks.21.norm1_b.linear.bias",
575
+ "model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.weight": "blocks.21.norm1_b.linear.weight",
576
+ "model.diffusion_model.joint_blocks.21.context_block.attn.proj.bias": "blocks.21.attn.b_to_out.bias",
577
+ "model.diffusion_model.joint_blocks.21.context_block.attn.proj.weight": "blocks.21.attn.b_to_out.weight",
578
+ "model.diffusion_model.joint_blocks.21.context_block.attn.qkv.bias": ['blocks.21.attn.b_to_q.bias', 'blocks.21.attn.b_to_k.bias', 'blocks.21.attn.b_to_v.bias'],
579
+ "model.diffusion_model.joint_blocks.21.context_block.attn.qkv.weight": ['blocks.21.attn.b_to_q.weight', 'blocks.21.attn.b_to_k.weight', 'blocks.21.attn.b_to_v.weight'],
580
+ "model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.bias": "blocks.21.ff_b.0.bias",
581
+ "model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.weight": "blocks.21.ff_b.0.weight",
582
+ "model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.bias": "blocks.21.ff_b.2.bias",
583
+ "model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.weight": "blocks.21.ff_b.2.weight",
584
+ "model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.bias": "blocks.21.norm1_a.linear.bias",
585
+ "model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.weight": "blocks.21.norm1_a.linear.weight",
586
+ "model.diffusion_model.joint_blocks.21.x_block.attn.proj.bias": "blocks.21.attn.a_to_out.bias",
587
+ "model.diffusion_model.joint_blocks.21.x_block.attn.proj.weight": "blocks.21.attn.a_to_out.weight",
588
+ "model.diffusion_model.joint_blocks.21.x_block.attn.qkv.bias": ['blocks.21.attn.a_to_q.bias', 'blocks.21.attn.a_to_k.bias', 'blocks.21.attn.a_to_v.bias'],
589
+ "model.diffusion_model.joint_blocks.21.x_block.attn.qkv.weight": ['blocks.21.attn.a_to_q.weight', 'blocks.21.attn.a_to_k.weight', 'blocks.21.attn.a_to_v.weight'],
590
+ "model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.bias": "blocks.21.ff_a.0.bias",
591
+ "model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.weight": "blocks.21.ff_a.0.weight",
592
+ "model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.bias": "blocks.21.ff_a.2.bias",
593
+ "model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.weight": "blocks.21.ff_a.2.weight",
594
+ "model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.bias": "blocks.22.norm1_b.linear.bias",
595
+ "model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.weight": "blocks.22.norm1_b.linear.weight",
596
+ "model.diffusion_model.joint_blocks.22.context_block.attn.proj.bias": "blocks.22.attn.b_to_out.bias",
597
+ "model.diffusion_model.joint_blocks.22.context_block.attn.proj.weight": "blocks.22.attn.b_to_out.weight",
598
+ "model.diffusion_model.joint_blocks.22.context_block.attn.qkv.bias": ['blocks.22.attn.b_to_q.bias', 'blocks.22.attn.b_to_k.bias', 'blocks.22.attn.b_to_v.bias'],
599
+ "model.diffusion_model.joint_blocks.22.context_block.attn.qkv.weight": ['blocks.22.attn.b_to_q.weight', 'blocks.22.attn.b_to_k.weight', 'blocks.22.attn.b_to_v.weight'],
600
+ "model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.bias": "blocks.22.ff_b.0.bias",
601
+ "model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.weight": "blocks.22.ff_b.0.weight",
602
+ "model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.bias": "blocks.22.ff_b.2.bias",
603
+ "model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.weight": "blocks.22.ff_b.2.weight",
604
+ "model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.bias": "blocks.22.norm1_a.linear.bias",
605
+ "model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.weight": "blocks.22.norm1_a.linear.weight",
606
+ "model.diffusion_model.joint_blocks.22.x_block.attn.proj.bias": "blocks.22.attn.a_to_out.bias",
607
+ "model.diffusion_model.joint_blocks.22.x_block.attn.proj.weight": "blocks.22.attn.a_to_out.weight",
608
+ "model.diffusion_model.joint_blocks.22.x_block.attn.qkv.bias": ['blocks.22.attn.a_to_q.bias', 'blocks.22.attn.a_to_k.bias', 'blocks.22.attn.a_to_v.bias'],
609
+ "model.diffusion_model.joint_blocks.22.x_block.attn.qkv.weight": ['blocks.22.attn.a_to_q.weight', 'blocks.22.attn.a_to_k.weight', 'blocks.22.attn.a_to_v.weight'],
610
+ "model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.bias": "blocks.22.ff_a.0.bias",
611
+ "model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.weight": "blocks.22.ff_a.0.weight",
612
+ "model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.bias": "blocks.22.ff_a.2.bias",
613
+ "model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.weight": "blocks.22.ff_a.2.weight",
614
+ "model.diffusion_model.joint_blocks.23.context_block.attn.qkv.bias": ['blocks.23.attn.b_to_q.bias', 'blocks.23.attn.b_to_k.bias', 'blocks.23.attn.b_to_v.bias'],
615
+ "model.diffusion_model.joint_blocks.23.context_block.attn.qkv.weight": ['blocks.23.attn.b_to_q.weight', 'blocks.23.attn.b_to_k.weight', 'blocks.23.attn.b_to_v.weight'],
616
+ "model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.bias": "blocks.23.norm1_a.linear.bias",
617
+ "model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.weight": "blocks.23.norm1_a.linear.weight",
618
+ "model.diffusion_model.joint_blocks.23.x_block.attn.proj.bias": "blocks.23.attn.a_to_out.bias",
619
+ "model.diffusion_model.joint_blocks.23.x_block.attn.proj.weight": "blocks.23.attn.a_to_out.weight",
620
+ "model.diffusion_model.joint_blocks.23.x_block.attn.qkv.bias": ['blocks.23.attn.a_to_q.bias', 'blocks.23.attn.a_to_k.bias', 'blocks.23.attn.a_to_v.bias'],
621
+ "model.diffusion_model.joint_blocks.23.x_block.attn.qkv.weight": ['blocks.23.attn.a_to_q.weight', 'blocks.23.attn.a_to_k.weight', 'blocks.23.attn.a_to_v.weight'],
622
+ "model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.bias": "blocks.23.ff_a.0.bias",
623
+ "model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.weight": "blocks.23.ff_a.0.weight",
624
+ "model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.bias": "blocks.23.ff_a.2.bias",
625
+ "model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.weight": "blocks.23.ff_a.2.weight",
626
+ "model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.bias": "blocks.3.norm1_b.linear.bias",
627
+ "model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.weight": "blocks.3.norm1_b.linear.weight",
628
+ "model.diffusion_model.joint_blocks.3.context_block.attn.proj.bias": "blocks.3.attn.b_to_out.bias",
629
+ "model.diffusion_model.joint_blocks.3.context_block.attn.proj.weight": "blocks.3.attn.b_to_out.weight",
630
+ "model.diffusion_model.joint_blocks.3.context_block.attn.qkv.bias": ['blocks.3.attn.b_to_q.bias', 'blocks.3.attn.b_to_k.bias', 'blocks.3.attn.b_to_v.bias'],
631
+ "model.diffusion_model.joint_blocks.3.context_block.attn.qkv.weight": ['blocks.3.attn.b_to_q.weight', 'blocks.3.attn.b_to_k.weight', 'blocks.3.attn.b_to_v.weight'],
632
+ "model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.bias": "blocks.3.ff_b.0.bias",
633
+ "model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.weight": "blocks.3.ff_b.0.weight",
634
+ "model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.bias": "blocks.3.ff_b.2.bias",
635
+ "model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.weight": "blocks.3.ff_b.2.weight",
636
+ "model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.bias": "blocks.3.norm1_a.linear.bias",
637
+ "model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.weight": "blocks.3.norm1_a.linear.weight",
638
+ "model.diffusion_model.joint_blocks.3.x_block.attn.proj.bias": "blocks.3.attn.a_to_out.bias",
639
+ "model.diffusion_model.joint_blocks.3.x_block.attn.proj.weight": "blocks.3.attn.a_to_out.weight",
640
+ "model.diffusion_model.joint_blocks.3.x_block.attn.qkv.bias": ['blocks.3.attn.a_to_q.bias', 'blocks.3.attn.a_to_k.bias', 'blocks.3.attn.a_to_v.bias'],
641
+ "model.diffusion_model.joint_blocks.3.x_block.attn.qkv.weight": ['blocks.3.attn.a_to_q.weight', 'blocks.3.attn.a_to_k.weight', 'blocks.3.attn.a_to_v.weight'],
642
+ "model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.bias": "blocks.3.ff_a.0.bias",
643
+ "model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.weight": "blocks.3.ff_a.0.weight",
644
+ "model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.bias": "blocks.3.ff_a.2.bias",
645
+ "model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.weight": "blocks.3.ff_a.2.weight",
646
+ "model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.bias": "blocks.4.norm1_b.linear.bias",
647
+ "model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.weight": "blocks.4.norm1_b.linear.weight",
648
+ "model.diffusion_model.joint_blocks.4.context_block.attn.proj.bias": "blocks.4.attn.b_to_out.bias",
649
+ "model.diffusion_model.joint_blocks.4.context_block.attn.proj.weight": "blocks.4.attn.b_to_out.weight",
650
+ "model.diffusion_model.joint_blocks.4.context_block.attn.qkv.bias": ['blocks.4.attn.b_to_q.bias', 'blocks.4.attn.b_to_k.bias', 'blocks.4.attn.b_to_v.bias'],
651
+ "model.diffusion_model.joint_blocks.4.context_block.attn.qkv.weight": ['blocks.4.attn.b_to_q.weight', 'blocks.4.attn.b_to_k.weight', 'blocks.4.attn.b_to_v.weight'],
652
+ "model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.bias": "blocks.4.ff_b.0.bias",
653
+ "model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.weight": "blocks.4.ff_b.0.weight",
654
+ "model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.bias": "blocks.4.ff_b.2.bias",
655
+ "model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.weight": "blocks.4.ff_b.2.weight",
656
+ "model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.bias": "blocks.4.norm1_a.linear.bias",
657
+ "model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.weight": "blocks.4.norm1_a.linear.weight",
658
+ "model.diffusion_model.joint_blocks.4.x_block.attn.proj.bias": "blocks.4.attn.a_to_out.bias",
659
+ "model.diffusion_model.joint_blocks.4.x_block.attn.proj.weight": "blocks.4.attn.a_to_out.weight",
660
+ "model.diffusion_model.joint_blocks.4.x_block.attn.qkv.bias": ['blocks.4.attn.a_to_q.bias', 'blocks.4.attn.a_to_k.bias', 'blocks.4.attn.a_to_v.bias'],
661
+ "model.diffusion_model.joint_blocks.4.x_block.attn.qkv.weight": ['blocks.4.attn.a_to_q.weight', 'blocks.4.attn.a_to_k.weight', 'blocks.4.attn.a_to_v.weight'],
662
+ "model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.bias": "blocks.4.ff_a.0.bias",
663
+ "model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.weight": "blocks.4.ff_a.0.weight",
664
+ "model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.bias": "blocks.4.ff_a.2.bias",
665
+ "model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.weight": "blocks.4.ff_a.2.weight",
666
+ "model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.bias": "blocks.5.norm1_b.linear.bias",
667
+ "model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.weight": "blocks.5.norm1_b.linear.weight",
668
+ "model.diffusion_model.joint_blocks.5.context_block.attn.proj.bias": "blocks.5.attn.b_to_out.bias",
669
+ "model.diffusion_model.joint_blocks.5.context_block.attn.proj.weight": "blocks.5.attn.b_to_out.weight",
670
+ "model.diffusion_model.joint_blocks.5.context_block.attn.qkv.bias": ['blocks.5.attn.b_to_q.bias', 'blocks.5.attn.b_to_k.bias', 'blocks.5.attn.b_to_v.bias'],
671
+ "model.diffusion_model.joint_blocks.5.context_block.attn.qkv.weight": ['blocks.5.attn.b_to_q.weight', 'blocks.5.attn.b_to_k.weight', 'blocks.5.attn.b_to_v.weight'],
672
+ "model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.bias": "blocks.5.ff_b.0.bias",
673
+ "model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.weight": "blocks.5.ff_b.0.weight",
674
+ "model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.bias": "blocks.5.ff_b.2.bias",
675
+ "model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.weight": "blocks.5.ff_b.2.weight",
676
+ "model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.bias": "blocks.5.norm1_a.linear.bias",
677
+ "model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.weight": "blocks.5.norm1_a.linear.weight",
678
+ "model.diffusion_model.joint_blocks.5.x_block.attn.proj.bias": "blocks.5.attn.a_to_out.bias",
679
+ "model.diffusion_model.joint_blocks.5.x_block.attn.proj.weight": "blocks.5.attn.a_to_out.weight",
680
+ "model.diffusion_model.joint_blocks.5.x_block.attn.qkv.bias": ['blocks.5.attn.a_to_q.bias', 'blocks.5.attn.a_to_k.bias', 'blocks.5.attn.a_to_v.bias'],
681
+ "model.diffusion_model.joint_blocks.5.x_block.attn.qkv.weight": ['blocks.5.attn.a_to_q.weight', 'blocks.5.attn.a_to_k.weight', 'blocks.5.attn.a_to_v.weight'],
682
+ "model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.bias": "blocks.5.ff_a.0.bias",
683
+ "model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.weight": "blocks.5.ff_a.0.weight",
684
+ "model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.bias": "blocks.5.ff_a.2.bias",
685
+ "model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.weight": "blocks.5.ff_a.2.weight",
686
+ "model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.bias": "blocks.6.norm1_b.linear.bias",
687
+ "model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.weight": "blocks.6.norm1_b.linear.weight",
688
+ "model.diffusion_model.joint_blocks.6.context_block.attn.proj.bias": "blocks.6.attn.b_to_out.bias",
689
+ "model.diffusion_model.joint_blocks.6.context_block.attn.proj.weight": "blocks.6.attn.b_to_out.weight",
690
+ "model.diffusion_model.joint_blocks.6.context_block.attn.qkv.bias": ['blocks.6.attn.b_to_q.bias', 'blocks.6.attn.b_to_k.bias', 'blocks.6.attn.b_to_v.bias'],
691
+ "model.diffusion_model.joint_blocks.6.context_block.attn.qkv.weight": ['blocks.6.attn.b_to_q.weight', 'blocks.6.attn.b_to_k.weight', 'blocks.6.attn.b_to_v.weight'],
692
+ "model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.bias": "blocks.6.ff_b.0.bias",
693
+ "model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.weight": "blocks.6.ff_b.0.weight",
694
+ "model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.bias": "blocks.6.ff_b.2.bias",
695
+ "model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.weight": "blocks.6.ff_b.2.weight",
696
+ "model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.bias": "blocks.6.norm1_a.linear.bias",
697
+ "model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.weight": "blocks.6.norm1_a.linear.weight",
698
+ "model.diffusion_model.joint_blocks.6.x_block.attn.proj.bias": "blocks.6.attn.a_to_out.bias",
699
+ "model.diffusion_model.joint_blocks.6.x_block.attn.proj.weight": "blocks.6.attn.a_to_out.weight",
700
+ "model.diffusion_model.joint_blocks.6.x_block.attn.qkv.bias": ['blocks.6.attn.a_to_q.bias', 'blocks.6.attn.a_to_k.bias', 'blocks.6.attn.a_to_v.bias'],
701
+ "model.diffusion_model.joint_blocks.6.x_block.attn.qkv.weight": ['blocks.6.attn.a_to_q.weight', 'blocks.6.attn.a_to_k.weight', 'blocks.6.attn.a_to_v.weight'],
702
+ "model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.bias": "blocks.6.ff_a.0.bias",
703
+ "model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.weight": "blocks.6.ff_a.0.weight",
704
+ "model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.bias": "blocks.6.ff_a.2.bias",
705
+ "model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.weight": "blocks.6.ff_a.2.weight",
706
+ "model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.bias": "blocks.7.norm1_b.linear.bias",
707
+ "model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.weight": "blocks.7.norm1_b.linear.weight",
708
+ "model.diffusion_model.joint_blocks.7.context_block.attn.proj.bias": "blocks.7.attn.b_to_out.bias",
709
+ "model.diffusion_model.joint_blocks.7.context_block.attn.proj.weight": "blocks.7.attn.b_to_out.weight",
710
+ "model.diffusion_model.joint_blocks.7.context_block.attn.qkv.bias": ['blocks.7.attn.b_to_q.bias', 'blocks.7.attn.b_to_k.bias', 'blocks.7.attn.b_to_v.bias'],
711
+ "model.diffusion_model.joint_blocks.7.context_block.attn.qkv.weight": ['blocks.7.attn.b_to_q.weight', 'blocks.7.attn.b_to_k.weight', 'blocks.7.attn.b_to_v.weight'],
712
+ "model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.bias": "blocks.7.ff_b.0.bias",
713
+ "model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.weight": "blocks.7.ff_b.0.weight",
714
+ "model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.bias": "blocks.7.ff_b.2.bias",
715
+ "model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.weight": "blocks.7.ff_b.2.weight",
716
+ "model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.bias": "blocks.7.norm1_a.linear.bias",
717
+ "model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.weight": "blocks.7.norm1_a.linear.weight",
718
+ "model.diffusion_model.joint_blocks.7.x_block.attn.proj.bias": "blocks.7.attn.a_to_out.bias",
719
+ "model.diffusion_model.joint_blocks.7.x_block.attn.proj.weight": "blocks.7.attn.a_to_out.weight",
720
+ "model.diffusion_model.joint_blocks.7.x_block.attn.qkv.bias": ['blocks.7.attn.a_to_q.bias', 'blocks.7.attn.a_to_k.bias', 'blocks.7.attn.a_to_v.bias'],
721
+ "model.diffusion_model.joint_blocks.7.x_block.attn.qkv.weight": ['blocks.7.attn.a_to_q.weight', 'blocks.7.attn.a_to_k.weight', 'blocks.7.attn.a_to_v.weight'],
722
+ "model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.bias": "blocks.7.ff_a.0.bias",
723
+ "model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.weight": "blocks.7.ff_a.0.weight",
724
+ "model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.bias": "blocks.7.ff_a.2.bias",
725
+ "model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.weight": "blocks.7.ff_a.2.weight",
726
+ "model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.bias": "blocks.8.norm1_b.linear.bias",
727
+ "model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.weight": "blocks.8.norm1_b.linear.weight",
728
+ "model.diffusion_model.joint_blocks.8.context_block.attn.proj.bias": "blocks.8.attn.b_to_out.bias",
729
+ "model.diffusion_model.joint_blocks.8.context_block.attn.proj.weight": "blocks.8.attn.b_to_out.weight",
730
+ "model.diffusion_model.joint_blocks.8.context_block.attn.qkv.bias": ['blocks.8.attn.b_to_q.bias', 'blocks.8.attn.b_to_k.bias', 'blocks.8.attn.b_to_v.bias'],
731
+ "model.diffusion_model.joint_blocks.8.context_block.attn.qkv.weight": ['blocks.8.attn.b_to_q.weight', 'blocks.8.attn.b_to_k.weight', 'blocks.8.attn.b_to_v.weight'],
732
+ "model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.bias": "blocks.8.ff_b.0.bias",
733
+ "model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.weight": "blocks.8.ff_b.0.weight",
734
+ "model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.bias": "blocks.8.ff_b.2.bias",
735
+ "model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.weight": "blocks.8.ff_b.2.weight",
736
+ "model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.bias": "blocks.8.norm1_a.linear.bias",
737
+ "model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.weight": "blocks.8.norm1_a.linear.weight",
738
+ "model.diffusion_model.joint_blocks.8.x_block.attn.proj.bias": "blocks.8.attn.a_to_out.bias",
739
+ "model.diffusion_model.joint_blocks.8.x_block.attn.proj.weight": "blocks.8.attn.a_to_out.weight",
740
+ "model.diffusion_model.joint_blocks.8.x_block.attn.qkv.bias": ['blocks.8.attn.a_to_q.bias', 'blocks.8.attn.a_to_k.bias', 'blocks.8.attn.a_to_v.bias'],
741
+ "model.diffusion_model.joint_blocks.8.x_block.attn.qkv.weight": ['blocks.8.attn.a_to_q.weight', 'blocks.8.attn.a_to_k.weight', 'blocks.8.attn.a_to_v.weight'],
742
+ "model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.bias": "blocks.8.ff_a.0.bias",
743
+ "model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.weight": "blocks.8.ff_a.0.weight",
744
+ "model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.bias": "blocks.8.ff_a.2.bias",
745
+ "model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.weight": "blocks.8.ff_a.2.weight",
746
+ "model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.bias": "blocks.9.norm1_b.linear.bias",
747
+ "model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.weight": "blocks.9.norm1_b.linear.weight",
748
+ "model.diffusion_model.joint_blocks.9.context_block.attn.proj.bias": "blocks.9.attn.b_to_out.bias",
749
+ "model.diffusion_model.joint_blocks.9.context_block.attn.proj.weight": "blocks.9.attn.b_to_out.weight",
750
+ "model.diffusion_model.joint_blocks.9.context_block.attn.qkv.bias": ['blocks.9.attn.b_to_q.bias', 'blocks.9.attn.b_to_k.bias', 'blocks.9.attn.b_to_v.bias'],
751
+ "model.diffusion_model.joint_blocks.9.context_block.attn.qkv.weight": ['blocks.9.attn.b_to_q.weight', 'blocks.9.attn.b_to_k.weight', 'blocks.9.attn.b_to_v.weight'],
752
+ "model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.bias": "blocks.9.ff_b.0.bias",
753
+ "model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.weight": "blocks.9.ff_b.0.weight",
754
+ "model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.bias": "blocks.9.ff_b.2.bias",
755
+ "model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.weight": "blocks.9.ff_b.2.weight",
756
+ "model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.bias": "blocks.9.norm1_a.linear.bias",
757
+ "model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.weight": "blocks.9.norm1_a.linear.weight",
758
+ "model.diffusion_model.joint_blocks.9.x_block.attn.proj.bias": "blocks.9.attn.a_to_out.bias",
759
+ "model.diffusion_model.joint_blocks.9.x_block.attn.proj.weight": "blocks.9.attn.a_to_out.weight",
760
+ "model.diffusion_model.joint_blocks.9.x_block.attn.qkv.bias": ['blocks.9.attn.a_to_q.bias', 'blocks.9.attn.a_to_k.bias', 'blocks.9.attn.a_to_v.bias'],
761
+ "model.diffusion_model.joint_blocks.9.x_block.attn.qkv.weight": ['blocks.9.attn.a_to_q.weight', 'blocks.9.attn.a_to_k.weight', 'blocks.9.attn.a_to_v.weight'],
762
+ "model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.bias": "blocks.9.ff_a.0.bias",
763
+ "model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.weight": "blocks.9.ff_a.0.weight",
764
+ "model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.bias": "blocks.9.ff_a.2.bias",
765
+ "model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.weight": "blocks.9.ff_a.2.weight",
766
+ "model.diffusion_model.pos_embed": "pos_embedder.pos_embed",
767
+ "model.diffusion_model.t_embedder.mlp.0.bias": "time_embedder.timestep_embedder.0.bias",
768
+ "model.diffusion_model.t_embedder.mlp.0.weight": "time_embedder.timestep_embedder.0.weight",
769
+ "model.diffusion_model.t_embedder.mlp.2.bias": "time_embedder.timestep_embedder.2.bias",
770
+ "model.diffusion_model.t_embedder.mlp.2.weight": "time_embedder.timestep_embedder.2.weight",
771
+ "model.diffusion_model.x_embedder.proj.bias": "pos_embedder.proj.bias",
772
+ "model.diffusion_model.x_embedder.proj.weight": "pos_embedder.proj.weight",
773
+ "model.diffusion_model.y_embedder.mlp.0.bias": "pooled_text_embedder.0.bias",
774
+ "model.diffusion_model.y_embedder.mlp.0.weight": "pooled_text_embedder.0.weight",
775
+ "model.diffusion_model.y_embedder.mlp.2.bias": "pooled_text_embedder.2.bias",
776
+ "model.diffusion_model.y_embedder.mlp.2.weight": "pooled_text_embedder.2.weight",
777
+
778
+ "model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.weight": "blocks.23.norm1_b.linear.weight",
779
+ "model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.bias": "blocks.23.norm1_b.linear.bias",
780
+ "model.diffusion_model.final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
781
+ "model.diffusion_model.final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
782
+ }
783
+ state_dict_ = {}
784
+ for name in state_dict:
785
+ if name in rename_dict:
786
+ param = state_dict[name]
787
+ if name.startswith("model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1."):
788
+ param = torch.concat([param[1536:], param[:1536]], axis=0)
789
+ elif name.startswith("model.diffusion_model.final_layer.adaLN_modulation.1."):
790
+ param = torch.concat([param[1536:], param[:1536]], axis=0)
791
+ elif name == "model.diffusion_model.pos_embed":
792
+ param = param.reshape((1, 192, 192, 1536))
793
+ if isinstance(rename_dict[name], str):
794
+ state_dict_[rename_dict[name]] = param
795
+ else:
796
+ name_ = rename_dict[name][0].replace(".a_to_q.", ".a_to_qkv.").replace(".b_to_q.", ".b_to_qkv.")
797
+ state_dict_[name_] = param
798
+ return state_dict_
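
The converter above only rewrites parameter names: fused qkv tensors from the Stability-format checkpoint keep a single fused target key (`a_to_qkv` / `b_to_qkv`), and the adaLN parameters of the last context block and the final layer have their two 1536-wide halves swapped. A minimal self-contained sketch of that logic (not part of the commit; the toy shapes and the two-entry `rename_dict` below are illustrative only):

```python
import torch

# Toy demonstration of the conversion loop above. The shapes and the tiny
# rename_dict are made up; the full table and the 1536-wide hidden size
# come from the diff.
rename_dict = {
    "model.diffusion_model.joint_blocks.0.x_block.attn.qkv.weight": [
        "blocks.0.attn.a_to_q.weight",
        "blocks.0.attn.a_to_k.weight",
        "blocks.0.attn.a_to_v.weight",
    ],
    "model.diffusion_model.final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
}
state_dict = {
    "model.diffusion_model.joint_blocks.0.x_block.attn.qkv.weight": torch.randn(3 * 8, 8),
    "model.diffusion_model.final_layer.adaLN_modulation.1.bias": torch.arange(4.0),
}

state_dict_ = {}
for name, param in state_dict.items():
    if name not in rename_dict:
        continue
    if name.startswith("model.diffusion_model.final_layer.adaLN_modulation.1."):
        half = param.shape[0] // 2  # 1536 in the real checkpoint
        param = torch.concat([param[half:], param[:half]], axis=0)  # swap the two halves
    if isinstance(rename_dict[name], str):
        state_dict_[rename_dict[name]] = param
    else:
        # Fused qkv stays fused; only the key is rewritten to "...a_to_qkv..." / "...b_to_qkv...".
        name_ = rename_dict[name][0].replace(".a_to_q.", ".a_to_qkv.").replace(".b_to_q.", ".b_to_qkv.")
        state_dict_[name_] = param

print(sorted(state_dict_))  # ['blocks.0.attn.a_to_qkv.weight', 'norm_out.linear.bias']
```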
diffsynth/models/sd3_text_encoder.py ADDED
The diff for this file is too large to render. See raw diff
 
diffsynth/models/sd3_vae_decoder.py ADDED
@@ -0,0 +1,81 @@
1
+ import torch
2
+ from .sd_vae_decoder import VAEAttentionBlock, SDVAEDecoderStateDictConverter
3
+ from .sd_unet import ResnetBlock, UpSampler
4
+ from .tiler import TileWorker
5
+
6
+
7
+
8
+ class SD3VAEDecoder(torch.nn.Module):
9
+ def __init__(self):
10
+ super().__init__()
11
+ self.scaling_factor = 1.5305 # Different from SD 1.x
12
+ self.shift_factor = 0.0609 # Different from SD 1.x
13
+ self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x
14
+
15
+ self.blocks = torch.nn.ModuleList([
16
+ # UNetMidBlock2D
17
+ ResnetBlock(512, 512, eps=1e-6),
18
+ VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
19
+ ResnetBlock(512, 512, eps=1e-6),
20
+ # UpDecoderBlock2D
21
+ ResnetBlock(512, 512, eps=1e-6),
22
+ ResnetBlock(512, 512, eps=1e-6),
23
+ ResnetBlock(512, 512, eps=1e-6),
24
+ UpSampler(512),
25
+ # UpDecoderBlock2D
26
+ ResnetBlock(512, 512, eps=1e-6),
27
+ ResnetBlock(512, 512, eps=1e-6),
28
+ ResnetBlock(512, 512, eps=1e-6),
29
+ UpSampler(512),
30
+ # UpDecoderBlock2D
31
+ ResnetBlock(512, 256, eps=1e-6),
32
+ ResnetBlock(256, 256, eps=1e-6),
33
+ ResnetBlock(256, 256, eps=1e-6),
34
+ UpSampler(256),
35
+ # UpDecoderBlock2D
36
+ ResnetBlock(256, 128, eps=1e-6),
37
+ ResnetBlock(128, 128, eps=1e-6),
38
+ ResnetBlock(128, 128, eps=1e-6),
39
+ ])
40
+
41
+ self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
42
+ self.conv_act = torch.nn.SiLU()
43
+ self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
44
+
45
+ def tiled_forward(self, sample, tile_size=64, tile_stride=32):
46
+ hidden_states = TileWorker().tiled_forward(
47
+ lambda x: self.forward(x),
48
+ sample,
49
+ tile_size,
50
+ tile_stride,
51
+ tile_device=sample.device,
52
+ tile_dtype=sample.dtype
53
+ )
54
+ return hidden_states
55
+
56
+ def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
57
+ # For VAE Decoder, we do not need to apply the tiler on each layer.
58
+ if tiled:
59
+ return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
60
+
61
+ # 1. pre-process
62
+ hidden_states = sample / self.scaling_factor + self.shift_factor
63
+ hidden_states = self.conv_in(hidden_states)
64
+ time_emb = None
65
+ text_emb = None
66
+ res_stack = None
67
+
68
+ # 2. blocks
69
+ for i, block in enumerate(self.blocks):
70
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
71
+
72
+ # 3. output
73
+ hidden_states = self.conv_norm_out(hidden_states)
74
+ hidden_states = self.conv_act(hidden_states)
75
+ hidden_states = self.conv_out(hidden_states)
76
+
77
+ return hidden_states
78
+
79
+ @staticmethod
80
+ def state_dict_converter():
81
+ return SDVAEDecoderStateDictConverter()
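
A minimal usage sketch for the decoder added above (not part of the commit). It assumes the `diffsynth` package from this commit is importable; the weights here are random, whereas real use would load a converted checkpoint through the state dict converter:

```python
import torch
from diffsynth.models.sd3_vae_decoder import SD3VAEDecoder  # module added in this commit

decoder = SD3VAEDecoder().eval()  # random weights here; real use loads a converted checkpoint

# SD3 latents have 16 channels and 1/8 of the output resolution.
latent = torch.randn(1, 16, 64, 64)

with torch.no_grad():
    image = decoder(latent)  # scaling_factor / shift_factor are applied inside forward

print(image.shape)  # torch.Size([1, 3, 512, 512])

# A tiled path is also exposed for large latents (signature taken from the diff):
# image = decoder(latent, tiled=True, tile_size=64, tile_stride=32)
```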
diffsynth/models/sd3_vae_encoder.py ADDED
@@ -0,0 +1,95 @@
1
+ import torch
2
+ from .sd_unet import ResnetBlock, DownSampler
3
+ from .sd_vae_encoder import VAEAttentionBlock, SDVAEEncoderStateDictConverter
4
+ from .tiler import TileWorker
5
+ from einops import rearrange
6
+
7
+
8
+ class SD3VAEEncoder(torch.nn.Module):
9
+ def __init__(self):
10
+ super().__init__()
11
+ self.scaling_factor = 1.5305 # Different from SD 1.x
12
+ self.shift_factor = 0.0609 # Different from SD 1.x
13
+ self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
14
+
15
+ self.blocks = torch.nn.ModuleList([
16
+ # DownEncoderBlock2D
17
+ ResnetBlock(128, 128, eps=1e-6),
18
+ ResnetBlock(128, 128, eps=1e-6),
19
+ DownSampler(128, padding=0, extra_padding=True),
20
+ # DownEncoderBlock2D
21
+ ResnetBlock(128, 256, eps=1e-6),
22
+ ResnetBlock(256, 256, eps=1e-6),
23
+ DownSampler(256, padding=0, extra_padding=True),
24
+ # DownEncoderBlock2D
25
+ ResnetBlock(256, 512, eps=1e-6),
26
+ ResnetBlock(512, 512, eps=1e-6),
27
+ DownSampler(512, padding=0, extra_padding=True),
28
+ # DownEncoderBlock2D
29
+ ResnetBlock(512, 512, eps=1e-6),
30
+ ResnetBlock(512, 512, eps=1e-6),
31
+ # UNetMidBlock2D
32
+ ResnetBlock(512, 512, eps=1e-6),
33
+ VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
34
+ ResnetBlock(512, 512, eps=1e-6),
35
+ ])
36
+
37
+ self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
38
+ self.conv_act = torch.nn.SiLU()
39
+ self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)
40
+
41
+ def tiled_forward(self, sample, tile_size=64, tile_stride=32):
42
+ hidden_states = TileWorker().tiled_forward(
43
+ lambda x: self.forward(x),
44
+ sample,
45
+ tile_size,
46
+ tile_stride,
47
+ tile_device=sample.device,
48
+ tile_dtype=sample.dtype
49
+ )
50
+ return hidden_states
51
+
52
+ def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
53
+ # For the VAE encoder, we do not need to apply the tiler on each layer.
54
+ if tiled:
55
+ return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
56
+
57
+ # 1. pre-process
58
+ hidden_states = self.conv_in(sample)
59
+ time_emb = None
60
+ text_emb = None
61
+ res_stack = None
62
+
63
+ # 2. blocks
64
+ for i, block in enumerate(self.blocks):
65
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
66
+
67
+ # 3. output
68
+ hidden_states = self.conv_norm_out(hidden_states)
69
+ hidden_states = self.conv_act(hidden_states)
70
+ hidden_states = self.conv_out(hidden_states)
71
+ hidden_states = hidden_states[:, :16]
72
+ hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor
73
+
74
+ return hidden_states
75
+
76
+ def encode_video(self, sample, batch_size=8):
77
+ B = sample.shape[0]
78
+ hidden_states = []
79
+
80
+ for i in range(0, sample.shape[2], batch_size):
81
+
82
+ j = min(i + batch_size, sample.shape[2])
83
+ sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
84
+
85
+ hidden_states_batch = self(sample_batch)
86
+ hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
87
+
88
+ hidden_states.append(hidden_states_batch)
89
+
90
+ hidden_states = torch.concat(hidden_states, dim=2)
91
+ return hidden_states
92
+
93
+ @staticmethod
94
+ def state_dict_converter():
95
+ return SDVAEEncoderStateDictConverter()
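
Similarly, a hedged sketch of the encoder added above (not part of the commit), including `encode_video`, which folds frames into the batch dimension in chunks of `batch_size` and restacks the latents along the time axis:

```python
import torch
from diffsynth.models.sd3_vae_encoder import SD3VAEEncoder  # module added in this commit

encoder = SD3VAEEncoder().eval()  # random weights here; real use loads a converted checkpoint

image = torch.randn(1, 3, 256, 256)      # B C H W
video = torch.randn(1, 3, 12, 256, 256)  # B C T H W

with torch.no_grad():
    latent = encoder(image)                                    # (1, 16, 32, 32), scaled and shifted
    video_latent = encoder.encode_video(video, batch_size=8)   # frames encoded in chunks of 8

print(latent.shape, video_latent.shape)  # (1, 16, 32, 32) (1, 16, 12, 32, 32)
```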
diffsynth/models/sd_controlnet.py ADDED
@@ -0,0 +1,589 @@
1
+ import torch
2
+ from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
3
+ from .tiler import TileWorker
4
+
5
+
6
+ class ControlNetConditioningLayer(torch.nn.Module):
7
+ def __init__(self, channels = (3, 16, 32, 96, 256, 320)):
8
+ super().__init__()
9
+ self.blocks = torch.nn.ModuleList([])
10
+ self.blocks.append(torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1))
11
+ self.blocks.append(torch.nn.SiLU())
12
+ for i in range(1, len(channels) - 2):
13
+ self.blocks.append(torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1))
14
+ self.blocks.append(torch.nn.SiLU())
15
+ self.blocks.append(torch.nn.Conv2d(channels[i], channels[i+1], kernel_size=3, padding=1, stride=2))
16
+ self.blocks.append(torch.nn.SiLU())
17
+ self.blocks.append(torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1))
18
+
19
+ def forward(self, conditioning):
20
+ for block in self.blocks:
21
+ conditioning = block(conditioning)
22
+ return conditioning
23
+
24
+
25
+ class SDControlNet(torch.nn.Module):
26
+ def __init__(self, global_pool=False):
27
+ super().__init__()
28
+ self.time_proj = Timesteps(320)
29
+ self.time_embedding = torch.nn.Sequential(
30
+ torch.nn.Linear(320, 1280),
31
+ torch.nn.SiLU(),
32
+ torch.nn.Linear(1280, 1280)
33
+ )
34
+ self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
35
+
36
+ self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
37
+
38
+ self.blocks = torch.nn.ModuleList([
39
+ # CrossAttnDownBlock2D
40
+ ResnetBlock(320, 320, 1280),
41
+ AttentionBlock(8, 40, 320, 1, 768),
42
+ PushBlock(),
43
+ ResnetBlock(320, 320, 1280),
44
+ AttentionBlock(8, 40, 320, 1, 768),
45
+ PushBlock(),
46
+ DownSampler(320),
47
+ PushBlock(),
48
+ # CrossAttnDownBlock2D
49
+ ResnetBlock(320, 640, 1280),
50
+ AttentionBlock(8, 80, 640, 1, 768),
51
+ PushBlock(),
52
+ ResnetBlock(640, 640, 1280),
53
+ AttentionBlock(8, 80, 640, 1, 768),
54
+ PushBlock(),
55
+ DownSampler(640),
56
+ PushBlock(),
57
+ # CrossAttnDownBlock2D
58
+ ResnetBlock(640, 1280, 1280),
59
+ AttentionBlock(8, 160, 1280, 1, 768),
60
+ PushBlock(),
61
+ ResnetBlock(1280, 1280, 1280),
62
+ AttentionBlock(8, 160, 1280, 1, 768),
63
+ PushBlock(),
64
+ DownSampler(1280),
65
+ PushBlock(),
66
+ # DownBlock2D
67
+ ResnetBlock(1280, 1280, 1280),
68
+ PushBlock(),
69
+ ResnetBlock(1280, 1280, 1280),
70
+ PushBlock(),
71
+ # UNetMidBlock2DCrossAttn
72
+ ResnetBlock(1280, 1280, 1280),
73
+ AttentionBlock(8, 160, 1280, 1, 768),
74
+ ResnetBlock(1280, 1280, 1280),
75
+ PushBlock()
76
+ ])
77
+
78
+ self.controlnet_blocks = torch.nn.ModuleList([
79
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
80
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
81
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
82
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
83
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
84
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
85
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
86
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
87
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
88
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
89
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
90
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
91
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
92
+ ])
93
+
94
+ self.global_pool = global_pool
95
+
96
+ def forward(
97
+ self,
98
+ sample, timestep, encoder_hidden_states, conditioning,
99
+ tiled=False, tile_size=64, tile_stride=32,
100
+ **kwargs
101
+ ):
102
+ # 1. time
103
+ time_emb = self.time_proj(timestep).to(sample.dtype)
104
+ time_emb = self.time_embedding(time_emb)
105
+ time_emb = time_emb.repeat(sample.shape[0], 1)
106
+
107
+ # 2. pre-process
108
+ height, width = sample.shape[2], sample.shape[3]
109
+ hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning)
110
+ text_emb = encoder_hidden_states
111
+ res_stack = [hidden_states]
112
+
113
+ # 3. blocks
114
+ for i, block in enumerate(self.blocks):
115
+ if tiled and not isinstance(block, PushBlock):
116
+ _, _, inter_height, _ = hidden_states.shape
117
+ resize_scale = inter_height / height
118
+ hidden_states = TileWorker().tiled_forward(
119
+ lambda x: block(x, time_emb, text_emb, res_stack)[0],
120
+ hidden_states,
121
+ int(tile_size * resize_scale),
122
+ int(tile_stride * resize_scale),
123
+ tile_device=hidden_states.device,
124
+ tile_dtype=hidden_states.dtype
125
+ )
126
+ else:
127
+ hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
128
+
129
+ # 4. ControlNet blocks
130
+ controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
131
+
132
+ # pool
133
+ if self.global_pool:
134
+ controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]
135
+
136
+ return controlnet_res_stack
137
+
138
+ @staticmethod
139
+ def state_dict_converter():
140
+ return SDControlNetStateDictConverter()
141
+
142
+
143
+ class SDControlNetStateDictConverter:
144
+ def __init__(self):
145
+ pass
146
+
147
+ def from_diffusers(self, state_dict):
148
+ # architecture
149
+ block_types = [
150
+ 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
151
+ 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
152
+ 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
153
+ 'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock',
154
+ 'ResnetBlock', 'AttentionBlock', 'ResnetBlock',
155
+ 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'UpSampler',
156
+ 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
157
+ 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
158
+ 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock'
159
+ ]
160
+
161
+ # controlnet_rename_dict
162
+ controlnet_rename_dict = {
163
+ "controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
164
+ "controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
165
+ "controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
166
+ "controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
167
+ "controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
168
+ "controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
169
+ "controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
170
+ "controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
171
+ "controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
172
+ "controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
173
+ "controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
174
+ "controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
175
+ "controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
176
+ "controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
177
+ "controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
178
+ "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
179
+ }
180
+
181
+ # Rename each parameter
182
+ name_list = sorted([name for name in state_dict])
183
+ rename_dict = {}
184
+ block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
185
+ last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
186
+ for name in name_list:
187
+ names = name.split(".")
188
+ if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
189
+ pass
190
+ elif name in controlnet_rename_dict:
191
+ names = controlnet_rename_dict[name].split(".")
192
+ elif names[0] == "controlnet_down_blocks":
193
+ names[0] = "controlnet_blocks"
194
+ elif names[0] == "controlnet_mid_block":
195
+ names = ["controlnet_blocks", "12", names[-1]]
196
+ elif names[0] in ["time_embedding", "add_embedding"]:
197
+ if names[0] == "add_embedding":
198
+ names[0] = "add_time_embedding"
199
+ names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
200
+ elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
201
+ if names[0] == "mid_block":
202
+ names.insert(1, "0")
203
+ block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
204
+ block_type_with_id = ".".join(names[:4])
205
+ if block_type_with_id != last_block_type_with_id[block_type]:
206
+ block_id[block_type] += 1
207
+ last_block_type_with_id[block_type] = block_type_with_id
208
+ while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
209
+ block_id[block_type] += 1
210
+ block_type_with_id = ".".join(names[:4])
211
+ names = ["blocks", str(block_id[block_type])] + names[4:]
212
+ if "ff" in names:
213
+ ff_index = names.index("ff")
214
+ component = ".".join(names[ff_index:ff_index+3])
215
+ component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
216
+ names = names[:ff_index] + [component] + names[ff_index+3:]
217
+ if "to_out" in names:
218
+ names.pop(names.index("to_out") + 1)
219
+ else:
220
+ raise ValueError(f"Unknown parameters: {name}")
221
+ rename_dict[name] = ".".join(names)
222
+
223
+ # Convert state_dict
224
+ state_dict_ = {}
225
+ for name, param in state_dict.items():
226
+ if ".proj_in." in name or ".proj_out." in name:
227
+ param = param.squeeze()
228
+ if rename_dict[name] in [
229
+ "controlnet_blocks.1.bias", "controlnet_blocks.2.bias", "controlnet_blocks.3.bias", "controlnet_blocks.5.bias", "controlnet_blocks.6.bias",
230
+ "controlnet_blocks.8.bias", "controlnet_blocks.9.bias", "controlnet_blocks.10.bias", "controlnet_blocks.11.bias", "controlnet_blocks.12.bias"
231
+ ]:
232
+ continue
233
+ state_dict_[rename_dict[name]] = param
234
+ return state_dict_
235
+
236
+ def from_civitai(self, state_dict):
237
+ if "mid_block.resnets.1.time_emb_proj.weight" in state_dict:
238
+ # For controlnets in diffusers format
239
+ return self.from_diffusers(state_dict)
240
+ rename_dict = {
241
+ "control_model.time_embed.0.weight": "time_embedding.0.weight",
242
+ "control_model.time_embed.0.bias": "time_embedding.0.bias",
243
+ "control_model.time_embed.2.weight": "time_embedding.2.weight",
244
+ "control_model.time_embed.2.bias": "time_embedding.2.bias",
245
+ "control_model.input_blocks.0.0.weight": "conv_in.weight",
246
+ "control_model.input_blocks.0.0.bias": "conv_in.bias",
247
+ "control_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight",
248
+ "control_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias",
249
+ "control_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight",
250
+ "control_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias",
251
+ "control_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight",
252
+ "control_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias",
253
+ "control_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight",
254
+ "control_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias",
255
+ "control_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight",
256
+ "control_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias",
257
+ "control_model.input_blocks.1.1.norm.weight": "blocks.1.norm.weight",
258
+ "control_model.input_blocks.1.1.norm.bias": "blocks.1.norm.bias",
259
+ "control_model.input_blocks.1.1.proj_in.weight": "blocks.1.proj_in.weight",
260
+ "control_model.input_blocks.1.1.proj_in.bias": "blocks.1.proj_in.bias",
261
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.1.transformer_blocks.0.attn1.to_q.weight",
262
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.1.transformer_blocks.0.attn1.to_k.weight",
263
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.1.transformer_blocks.0.attn1.to_v.weight",
264
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.1.transformer_blocks.0.attn1.to_out.weight",
265
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.1.transformer_blocks.0.attn1.to_out.bias",
266
+ "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.1.transformer_blocks.0.act_fn.proj.weight",
267
+ "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.1.transformer_blocks.0.act_fn.proj.bias",
268
+ "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.1.transformer_blocks.0.ff.weight",
269
+ "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "blocks.1.transformer_blocks.0.ff.bias",
270
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.1.transformer_blocks.0.attn2.to_q.weight",
271
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.1.transformer_blocks.0.attn2.to_k.weight",
272
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.1.transformer_blocks.0.attn2.to_v.weight",
273
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.1.transformer_blocks.0.attn2.to_out.weight",
274
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.1.transformer_blocks.0.attn2.to_out.bias",
275
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.1.transformer_blocks.0.norm1.weight",
276
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.1.transformer_blocks.0.norm1.bias",
277
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.1.transformer_blocks.0.norm2.weight",
278
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.1.transformer_blocks.0.norm2.bias",
279
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.1.transformer_blocks.0.norm3.weight",
280
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.1.transformer_blocks.0.norm3.bias",
281
+ "control_model.input_blocks.1.1.proj_out.weight": "blocks.1.proj_out.weight",
282
+ "control_model.input_blocks.1.1.proj_out.bias": "blocks.1.proj_out.bias",
283
+ "control_model.input_blocks.2.0.in_layers.0.weight": "blocks.3.norm1.weight",
284
+ "control_model.input_blocks.2.0.in_layers.0.bias": "blocks.3.norm1.bias",
285
+ "control_model.input_blocks.2.0.in_layers.2.weight": "blocks.3.conv1.weight",
286
+ "control_model.input_blocks.2.0.in_layers.2.bias": "blocks.3.conv1.bias",
287
+ "control_model.input_blocks.2.0.emb_layers.1.weight": "blocks.3.time_emb_proj.weight",
288
+ "control_model.input_blocks.2.0.emb_layers.1.bias": "blocks.3.time_emb_proj.bias",
289
+ "control_model.input_blocks.2.0.out_layers.0.weight": "blocks.3.norm2.weight",
290
+ "control_model.input_blocks.2.0.out_layers.0.bias": "blocks.3.norm2.bias",
291
+ "control_model.input_blocks.2.0.out_layers.3.weight": "blocks.3.conv2.weight",
292
+ "control_model.input_blocks.2.0.out_layers.3.bias": "blocks.3.conv2.bias",
293
+ "control_model.input_blocks.2.1.norm.weight": "blocks.4.norm.weight",
294
+ "control_model.input_blocks.2.1.norm.bias": "blocks.4.norm.bias",
295
+ "control_model.input_blocks.2.1.proj_in.weight": "blocks.4.proj_in.weight",
296
+ "control_model.input_blocks.2.1.proj_in.bias": "blocks.4.proj_in.bias",
297
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.4.transformer_blocks.0.attn1.to_q.weight",
298
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.4.transformer_blocks.0.attn1.to_k.weight",
299
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.4.transformer_blocks.0.attn1.to_v.weight",
300
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.4.transformer_blocks.0.attn1.to_out.weight",
301
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.4.transformer_blocks.0.attn1.to_out.bias",
302
+ "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.4.transformer_blocks.0.act_fn.proj.weight",
303
+ "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.4.transformer_blocks.0.act_fn.proj.bias",
304
+ "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.4.transformer_blocks.0.ff.weight",
305
+ "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.4.transformer_blocks.0.ff.bias",
306
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.4.transformer_blocks.0.attn2.to_q.weight",
307
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.4.transformer_blocks.0.attn2.to_k.weight",
308
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.4.transformer_blocks.0.attn2.to_v.weight",
309
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.4.transformer_blocks.0.attn2.to_out.weight",
310
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.4.transformer_blocks.0.attn2.to_out.bias",
311
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.4.transformer_blocks.0.norm1.weight",
312
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.4.transformer_blocks.0.norm1.bias",
313
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.4.transformer_blocks.0.norm2.weight",
314
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.4.transformer_blocks.0.norm2.bias",
315
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.4.transformer_blocks.0.norm3.weight",
316
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.4.transformer_blocks.0.norm3.bias",
317
+ "control_model.input_blocks.2.1.proj_out.weight": "blocks.4.proj_out.weight",
318
+ "control_model.input_blocks.2.1.proj_out.bias": "blocks.4.proj_out.bias",
319
+ "control_model.input_blocks.3.0.op.weight": "blocks.6.conv.weight",
320
+ "control_model.input_blocks.3.0.op.bias": "blocks.6.conv.bias",
321
+ "control_model.input_blocks.4.0.in_layers.0.weight": "blocks.8.norm1.weight",
322
+ "control_model.input_blocks.4.0.in_layers.0.bias": "blocks.8.norm1.bias",
323
+ "control_model.input_blocks.4.0.in_layers.2.weight": "blocks.8.conv1.weight",
324
+ "control_model.input_blocks.4.0.in_layers.2.bias": "blocks.8.conv1.bias",
325
+ "control_model.input_blocks.4.0.emb_layers.1.weight": "blocks.8.time_emb_proj.weight",
326
+ "control_model.input_blocks.4.0.emb_layers.1.bias": "blocks.8.time_emb_proj.bias",
327
+ "control_model.input_blocks.4.0.out_layers.0.weight": "blocks.8.norm2.weight",
328
+ "control_model.input_blocks.4.0.out_layers.0.bias": "blocks.8.norm2.bias",
329
+ "control_model.input_blocks.4.0.out_layers.3.weight": "blocks.8.conv2.weight",
330
+ "control_model.input_blocks.4.0.out_layers.3.bias": "blocks.8.conv2.bias",
331
+ "control_model.input_blocks.4.0.skip_connection.weight": "blocks.8.conv_shortcut.weight",
332
+ "control_model.input_blocks.4.0.skip_connection.bias": "blocks.8.conv_shortcut.bias",
333
+ "control_model.input_blocks.4.1.norm.weight": "blocks.9.norm.weight",
334
+ "control_model.input_blocks.4.1.norm.bias": "blocks.9.norm.bias",
335
+ "control_model.input_blocks.4.1.proj_in.weight": "blocks.9.proj_in.weight",
336
+ "control_model.input_blocks.4.1.proj_in.bias": "blocks.9.proj_in.bias",
337
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.9.transformer_blocks.0.attn1.to_q.weight",
338
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.9.transformer_blocks.0.attn1.to_k.weight",
339
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.9.transformer_blocks.0.attn1.to_v.weight",
340
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.9.transformer_blocks.0.attn1.to_out.weight",
341
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.9.transformer_blocks.0.attn1.to_out.bias",
342
+ "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.9.transformer_blocks.0.act_fn.proj.weight",
343
+ "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.9.transformer_blocks.0.act_fn.proj.bias",
344
+ "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.9.transformer_blocks.0.ff.weight",
345
+ "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.9.transformer_blocks.0.ff.bias",
346
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.9.transformer_blocks.0.attn2.to_q.weight",
347
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.9.transformer_blocks.0.attn2.to_k.weight",
348
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.9.transformer_blocks.0.attn2.to_v.weight",
349
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.9.transformer_blocks.0.attn2.to_out.weight",
350
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.9.transformer_blocks.0.attn2.to_out.bias",
351
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.9.transformer_blocks.0.norm1.weight",
352
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.9.transformer_blocks.0.norm1.bias",
353
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.9.transformer_blocks.0.norm2.weight",
354
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.9.transformer_blocks.0.norm2.bias",
355
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.9.transformer_blocks.0.norm3.weight",
356
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.9.transformer_blocks.0.norm3.bias",
357
+ "control_model.input_blocks.4.1.proj_out.weight": "blocks.9.proj_out.weight",
358
+ "control_model.input_blocks.4.1.proj_out.bias": "blocks.9.proj_out.bias",
359
+ "control_model.input_blocks.5.0.in_layers.0.weight": "blocks.11.norm1.weight",
360
+ "control_model.input_blocks.5.0.in_layers.0.bias": "blocks.11.norm1.bias",
361
+ "control_model.input_blocks.5.0.in_layers.2.weight": "blocks.11.conv1.weight",
362
+ "control_model.input_blocks.5.0.in_layers.2.bias": "blocks.11.conv1.bias",
363
+ "control_model.input_blocks.5.0.emb_layers.1.weight": "blocks.11.time_emb_proj.weight",
364
+ "control_model.input_blocks.5.0.emb_layers.1.bias": "blocks.11.time_emb_proj.bias",
365
+ "control_model.input_blocks.5.0.out_layers.0.weight": "blocks.11.norm2.weight",
366
+ "control_model.input_blocks.5.0.out_layers.0.bias": "blocks.11.norm2.bias",
367
+ "control_model.input_blocks.5.0.out_layers.3.weight": "blocks.11.conv2.weight",
368
+ "control_model.input_blocks.5.0.out_layers.3.bias": "blocks.11.conv2.bias",
369
+ "control_model.input_blocks.5.1.norm.weight": "blocks.12.norm.weight",
370
+ "control_model.input_blocks.5.1.norm.bias": "blocks.12.norm.bias",
371
+ "control_model.input_blocks.5.1.proj_in.weight": "blocks.12.proj_in.weight",
372
+ "control_model.input_blocks.5.1.proj_in.bias": "blocks.12.proj_in.bias",
373
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.12.transformer_blocks.0.attn1.to_q.weight",
374
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.12.transformer_blocks.0.attn1.to_k.weight",
375
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.12.transformer_blocks.0.attn1.to_v.weight",
376
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.12.transformer_blocks.0.attn1.to_out.weight",
377
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.12.transformer_blocks.0.attn1.to_out.bias",
378
+ "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.12.transformer_blocks.0.act_fn.proj.weight",
379
+ "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.12.transformer_blocks.0.act_fn.proj.bias",
380
+ "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.12.transformer_blocks.0.ff.weight",
381
+ "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.12.transformer_blocks.0.ff.bias",
382
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.12.transformer_blocks.0.attn2.to_q.weight",
383
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.12.transformer_blocks.0.attn2.to_k.weight",
384
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.12.transformer_blocks.0.attn2.to_v.weight",
385
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.12.transformer_blocks.0.attn2.to_out.weight",
386
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.12.transformer_blocks.0.attn2.to_out.bias",
387
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.12.transformer_blocks.0.norm1.weight",
388
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.12.transformer_blocks.0.norm1.bias",
389
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.12.transformer_blocks.0.norm2.weight",
390
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.12.transformer_blocks.0.norm2.bias",
391
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.12.transformer_blocks.0.norm3.weight",
392
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.12.transformer_blocks.0.norm3.bias",
393
+ "control_model.input_blocks.5.1.proj_out.weight": "blocks.12.proj_out.weight",
394
+ "control_model.input_blocks.5.1.proj_out.bias": "blocks.12.proj_out.bias",
395
+ "control_model.input_blocks.6.0.op.weight": "blocks.14.conv.weight",
396
+ "control_model.input_blocks.6.0.op.bias": "blocks.14.conv.bias",
397
+ "control_model.input_blocks.7.0.in_layers.0.weight": "blocks.16.norm1.weight",
398
+ "control_model.input_blocks.7.0.in_layers.0.bias": "blocks.16.norm1.bias",
399
+ "control_model.input_blocks.7.0.in_layers.2.weight": "blocks.16.conv1.weight",
400
+ "control_model.input_blocks.7.0.in_layers.2.bias": "blocks.16.conv1.bias",
401
+ "control_model.input_blocks.7.0.emb_layers.1.weight": "blocks.16.time_emb_proj.weight",
402
+ "control_model.input_blocks.7.0.emb_layers.1.bias": "blocks.16.time_emb_proj.bias",
403
+ "control_model.input_blocks.7.0.out_layers.0.weight": "blocks.16.norm2.weight",
404
+ "control_model.input_blocks.7.0.out_layers.0.bias": "blocks.16.norm2.bias",
405
+ "control_model.input_blocks.7.0.out_layers.3.weight": "blocks.16.conv2.weight",
406
+ "control_model.input_blocks.7.0.out_layers.3.bias": "blocks.16.conv2.bias",
407
+ "control_model.input_blocks.7.0.skip_connection.weight": "blocks.16.conv_shortcut.weight",
408
+ "control_model.input_blocks.7.0.skip_connection.bias": "blocks.16.conv_shortcut.bias",
409
+ "control_model.input_blocks.7.1.norm.weight": "blocks.17.norm.weight",
410
+ "control_model.input_blocks.7.1.norm.bias": "blocks.17.norm.bias",
411
+ "control_model.input_blocks.7.1.proj_in.weight": "blocks.17.proj_in.weight",
412
+ "control_model.input_blocks.7.1.proj_in.bias": "blocks.17.proj_in.bias",
413
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.17.transformer_blocks.0.attn1.to_q.weight",
414
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.17.transformer_blocks.0.attn1.to_k.weight",
415
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.17.transformer_blocks.0.attn1.to_v.weight",
416
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.17.transformer_blocks.0.attn1.to_out.weight",
417
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.17.transformer_blocks.0.attn1.to_out.bias",
418
+ "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.17.transformer_blocks.0.act_fn.proj.weight",
419
+ "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.17.transformer_blocks.0.act_fn.proj.bias",
420
+ "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.17.transformer_blocks.0.ff.weight",
421
+ "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.17.transformer_blocks.0.ff.bias",
422
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.17.transformer_blocks.0.attn2.to_q.weight",
423
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.17.transformer_blocks.0.attn2.to_k.weight",
424
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.17.transformer_blocks.0.attn2.to_v.weight",
425
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.17.transformer_blocks.0.attn2.to_out.weight",
426
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.17.transformer_blocks.0.attn2.to_out.bias",
427
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.17.transformer_blocks.0.norm1.weight",
428
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.17.transformer_blocks.0.norm1.bias",
429
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.17.transformer_blocks.0.norm2.weight",
430
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.17.transformer_blocks.0.norm2.bias",
431
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.17.transformer_blocks.0.norm3.weight",
432
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.17.transformer_blocks.0.norm3.bias",
433
+ "control_model.input_blocks.7.1.proj_out.weight": "blocks.17.proj_out.weight",
434
+ "control_model.input_blocks.7.1.proj_out.bias": "blocks.17.proj_out.bias",
435
+ "control_model.input_blocks.8.0.in_layers.0.weight": "blocks.19.norm1.weight",
436
+ "control_model.input_blocks.8.0.in_layers.0.bias": "blocks.19.norm1.bias",
437
+ "control_model.input_blocks.8.0.in_layers.2.weight": "blocks.19.conv1.weight",
438
+ "control_model.input_blocks.8.0.in_layers.2.bias": "blocks.19.conv1.bias",
439
+ "control_model.input_blocks.8.0.emb_layers.1.weight": "blocks.19.time_emb_proj.weight",
440
+ "control_model.input_blocks.8.0.emb_layers.1.bias": "blocks.19.time_emb_proj.bias",
441
+ "control_model.input_blocks.8.0.out_layers.0.weight": "blocks.19.norm2.weight",
442
+ "control_model.input_blocks.8.0.out_layers.0.bias": "blocks.19.norm2.bias",
443
+ "control_model.input_blocks.8.0.out_layers.3.weight": "blocks.19.conv2.weight",
444
+ "control_model.input_blocks.8.0.out_layers.3.bias": "blocks.19.conv2.bias",
445
+ "control_model.input_blocks.8.1.norm.weight": "blocks.20.norm.weight",
446
+ "control_model.input_blocks.8.1.norm.bias": "blocks.20.norm.bias",
447
+ "control_model.input_blocks.8.1.proj_in.weight": "blocks.20.proj_in.weight",
448
+ "control_model.input_blocks.8.1.proj_in.bias": "blocks.20.proj_in.bias",
449
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.20.transformer_blocks.0.attn1.to_q.weight",
450
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.20.transformer_blocks.0.attn1.to_k.weight",
451
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.20.transformer_blocks.0.attn1.to_v.weight",
452
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.20.transformer_blocks.0.attn1.to_out.weight",
453
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.20.transformer_blocks.0.attn1.to_out.bias",
454
+ "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.20.transformer_blocks.0.act_fn.proj.weight",
455
+ "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.20.transformer_blocks.0.act_fn.proj.bias",
456
+ "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.20.transformer_blocks.0.ff.weight",
457
+ "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.20.transformer_blocks.0.ff.bias",
458
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.20.transformer_blocks.0.attn2.to_q.weight",
459
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.20.transformer_blocks.0.attn2.to_k.weight",
460
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.20.transformer_blocks.0.attn2.to_v.weight",
461
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.20.transformer_blocks.0.attn2.to_out.weight",
462
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.20.transformer_blocks.0.attn2.to_out.bias",
463
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.20.transformer_blocks.0.norm1.weight",
464
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.20.transformer_blocks.0.norm1.bias",
465
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.20.transformer_blocks.0.norm2.weight",
466
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.20.transformer_blocks.0.norm2.bias",
467
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.20.transformer_blocks.0.norm3.weight",
468
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.20.transformer_blocks.0.norm3.bias",
469
+ "control_model.input_blocks.8.1.proj_out.weight": "blocks.20.proj_out.weight",
470
+ "control_model.input_blocks.8.1.proj_out.bias": "blocks.20.proj_out.bias",
471
+ "control_model.input_blocks.9.0.op.weight": "blocks.22.conv.weight",
472
+ "control_model.input_blocks.9.0.op.bias": "blocks.22.conv.bias",
473
+ "control_model.input_blocks.10.0.in_layers.0.weight": "blocks.24.norm1.weight",
474
+ "control_model.input_blocks.10.0.in_layers.0.bias": "blocks.24.norm1.bias",
475
+ "control_model.input_blocks.10.0.in_layers.2.weight": "blocks.24.conv1.weight",
476
+ "control_model.input_blocks.10.0.in_layers.2.bias": "blocks.24.conv1.bias",
477
+ "control_model.input_blocks.10.0.emb_layers.1.weight": "blocks.24.time_emb_proj.weight",
478
+ "control_model.input_blocks.10.0.emb_layers.1.bias": "blocks.24.time_emb_proj.bias",
479
+ "control_model.input_blocks.10.0.out_layers.0.weight": "blocks.24.norm2.weight",
480
+ "control_model.input_blocks.10.0.out_layers.0.bias": "blocks.24.norm2.bias",
481
+ "control_model.input_blocks.10.0.out_layers.3.weight": "blocks.24.conv2.weight",
482
+ "control_model.input_blocks.10.0.out_layers.3.bias": "blocks.24.conv2.bias",
483
+ "control_model.input_blocks.11.0.in_layers.0.weight": "blocks.26.norm1.weight",
484
+ "control_model.input_blocks.11.0.in_layers.0.bias": "blocks.26.norm1.bias",
485
+ "control_model.input_blocks.11.0.in_layers.2.weight": "blocks.26.conv1.weight",
486
+ "control_model.input_blocks.11.0.in_layers.2.bias": "blocks.26.conv1.bias",
487
+ "control_model.input_blocks.11.0.emb_layers.1.weight": "blocks.26.time_emb_proj.weight",
488
+ "control_model.input_blocks.11.0.emb_layers.1.bias": "blocks.26.time_emb_proj.bias",
489
+ "control_model.input_blocks.11.0.out_layers.0.weight": "blocks.26.norm2.weight",
490
+ "control_model.input_blocks.11.0.out_layers.0.bias": "blocks.26.norm2.bias",
491
+ "control_model.input_blocks.11.0.out_layers.3.weight": "blocks.26.conv2.weight",
492
+ "control_model.input_blocks.11.0.out_layers.3.bias": "blocks.26.conv2.bias",
493
+ "control_model.zero_convs.0.0.weight": "controlnet_blocks.0.weight",
494
+ "control_model.zero_convs.0.0.bias": "controlnet_blocks.0.bias",
495
+ "control_model.zero_convs.1.0.weight": "controlnet_blocks.1.weight",
496
+ "control_model.zero_convs.1.0.bias": "controlnet_blocks.1.bias",
497
+ "control_model.zero_convs.2.0.weight": "controlnet_blocks.2.weight",
498
+ "control_model.zero_convs.2.0.bias": "controlnet_blocks.2.bias",
499
+ "control_model.zero_convs.3.0.weight": "controlnet_blocks.3.weight",
500
+ "control_model.zero_convs.3.0.bias": "controlnet_blocks.3.bias",
501
+ "control_model.zero_convs.4.0.weight": "controlnet_blocks.4.weight",
502
+ "control_model.zero_convs.4.0.bias": "controlnet_blocks.4.bias",
503
+ "control_model.zero_convs.5.0.weight": "controlnet_blocks.5.weight",
504
+ "control_model.zero_convs.5.0.bias": "controlnet_blocks.5.bias",
505
+ "control_model.zero_convs.6.0.weight": "controlnet_blocks.6.weight",
506
+ "control_model.zero_convs.6.0.bias": "controlnet_blocks.6.bias",
507
+ "control_model.zero_convs.7.0.weight": "controlnet_blocks.7.weight",
508
+ "control_model.zero_convs.7.0.bias": "controlnet_blocks.7.bias",
509
+ "control_model.zero_convs.8.0.weight": "controlnet_blocks.8.weight",
510
+ "control_model.zero_convs.8.0.bias": "controlnet_blocks.8.bias",
511
+ "control_model.zero_convs.9.0.weight": "controlnet_blocks.9.weight",
512
+ "control_model.zero_convs.9.0.bias": "controlnet_blocks.9.bias",
513
+ "control_model.zero_convs.10.0.weight": "controlnet_blocks.10.weight",
514
+ "control_model.zero_convs.10.0.bias": "controlnet_blocks.10.bias",
515
+ "control_model.zero_convs.11.0.weight": "controlnet_blocks.11.weight",
516
+ "control_model.zero_convs.11.0.bias": "controlnet_blocks.11.bias",
517
+ "control_model.input_hint_block.0.weight": "controlnet_conv_in.blocks.0.weight",
518
+ "control_model.input_hint_block.0.bias": "controlnet_conv_in.blocks.0.bias",
519
+ "control_model.input_hint_block.2.weight": "controlnet_conv_in.blocks.2.weight",
520
+ "control_model.input_hint_block.2.bias": "controlnet_conv_in.blocks.2.bias",
521
+ "control_model.input_hint_block.4.weight": "controlnet_conv_in.blocks.4.weight",
522
+ "control_model.input_hint_block.4.bias": "controlnet_conv_in.blocks.4.bias",
523
+ "control_model.input_hint_block.6.weight": "controlnet_conv_in.blocks.6.weight",
524
+ "control_model.input_hint_block.6.bias": "controlnet_conv_in.blocks.6.bias",
525
+ "control_model.input_hint_block.8.weight": "controlnet_conv_in.blocks.8.weight",
526
+ "control_model.input_hint_block.8.bias": "controlnet_conv_in.blocks.8.bias",
527
+ "control_model.input_hint_block.10.weight": "controlnet_conv_in.blocks.10.weight",
528
+ "control_model.input_hint_block.10.bias": "controlnet_conv_in.blocks.10.bias",
529
+ "control_model.input_hint_block.12.weight": "controlnet_conv_in.blocks.12.weight",
530
+ "control_model.input_hint_block.12.bias": "controlnet_conv_in.blocks.12.bias",
531
+ "control_model.input_hint_block.14.weight": "controlnet_conv_in.blocks.14.weight",
532
+ "control_model.input_hint_block.14.bias": "controlnet_conv_in.blocks.14.bias",
533
+ "control_model.middle_block.0.in_layers.0.weight": "blocks.28.norm1.weight",
534
+ "control_model.middle_block.0.in_layers.0.bias": "blocks.28.norm1.bias",
535
+ "control_model.middle_block.0.in_layers.2.weight": "blocks.28.conv1.weight",
536
+ "control_model.middle_block.0.in_layers.2.bias": "blocks.28.conv1.bias",
537
+ "control_model.middle_block.0.emb_layers.1.weight": "blocks.28.time_emb_proj.weight",
538
+ "control_model.middle_block.0.emb_layers.1.bias": "blocks.28.time_emb_proj.bias",
539
+ "control_model.middle_block.0.out_layers.0.weight": "blocks.28.norm2.weight",
540
+ "control_model.middle_block.0.out_layers.0.bias": "blocks.28.norm2.bias",
541
+ "control_model.middle_block.0.out_layers.3.weight": "blocks.28.conv2.weight",
542
+ "control_model.middle_block.0.out_layers.3.bias": "blocks.28.conv2.bias",
543
+ "control_model.middle_block.1.norm.weight": "blocks.29.norm.weight",
544
+ "control_model.middle_block.1.norm.bias": "blocks.29.norm.bias",
545
+ "control_model.middle_block.1.proj_in.weight": "blocks.29.proj_in.weight",
546
+ "control_model.middle_block.1.proj_in.bias": "blocks.29.proj_in.bias",
547
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.29.transformer_blocks.0.attn1.to_q.weight",
548
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.29.transformer_blocks.0.attn1.to_k.weight",
549
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.29.transformer_blocks.0.attn1.to_v.weight",
550
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.29.transformer_blocks.0.attn1.to_out.weight",
551
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.29.transformer_blocks.0.attn1.to_out.bias",
552
+ "control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.29.transformer_blocks.0.act_fn.proj.weight",
553
+ "control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.29.transformer_blocks.0.act_fn.proj.bias",
554
+ "control_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.29.transformer_blocks.0.ff.weight",
555
+ "control_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.29.transformer_blocks.0.ff.bias",
556
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.29.transformer_blocks.0.attn2.to_q.weight",
557
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.29.transformer_blocks.0.attn2.to_k.weight",
558
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.29.transformer_blocks.0.attn2.to_v.weight",
559
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.29.transformer_blocks.0.attn2.to_out.weight",
560
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.29.transformer_blocks.0.attn2.to_out.bias",
561
+ "control_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.29.transformer_blocks.0.norm1.weight",
562
+ "control_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.29.transformer_blocks.0.norm1.bias",
563
+ "control_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.29.transformer_blocks.0.norm2.weight",
564
+ "control_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.29.transformer_blocks.0.norm2.bias",
565
+ "control_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.29.transformer_blocks.0.norm3.weight",
566
+ "control_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.29.transformer_blocks.0.norm3.bias",
567
+ "control_model.middle_block.1.proj_out.weight": "blocks.29.proj_out.weight",
568
+ "control_model.middle_block.1.proj_out.bias": "blocks.29.proj_out.bias",
569
+ "control_model.middle_block.2.in_layers.0.weight": "blocks.30.norm1.weight",
570
+ "control_model.middle_block.2.in_layers.0.bias": "blocks.30.norm1.bias",
571
+ "control_model.middle_block.2.in_layers.2.weight": "blocks.30.conv1.weight",
572
+ "control_model.middle_block.2.in_layers.2.bias": "blocks.30.conv1.bias",
573
+ "control_model.middle_block.2.emb_layers.1.weight": "blocks.30.time_emb_proj.weight",
574
+ "control_model.middle_block.2.emb_layers.1.bias": "blocks.30.time_emb_proj.bias",
575
+ "control_model.middle_block.2.out_layers.0.weight": "blocks.30.norm2.weight",
576
+ "control_model.middle_block.2.out_layers.0.bias": "blocks.30.norm2.bias",
577
+ "control_model.middle_block.2.out_layers.3.weight": "blocks.30.conv2.weight",
578
+ "control_model.middle_block.2.out_layers.3.bias": "blocks.30.conv2.bias",
579
+ "control_model.middle_block_out.0.weight": "controlnet_blocks.12.weight",
580
+ "control_model.middle_block_out.0.bias": "controlnet_blocks.12.bias",
581
+ }
582
+ state_dict_ = {}
583
+ for name in state_dict:
584
+ if name in rename_dict:
585
+ param = state_dict[name]
586
+ if ".proj_in." in name or ".proj_out." in name:
587
+ param = param.squeeze()
588
+ state_dict_[rename_dict[name]] = param
589
+ return state_dict_
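The from_civitai converter above is a plain key-rename pass with one shape fix: proj_in / proj_out weights are stored as 1x1 convolutions in the original checkpoint and have to be squeezed down to 2-D linear weights. A minimal, self-contained sketch of that pattern (the toy keys and shapes below are illustrative, not taken from a real checkpoint):

import torch

# Illustrative subset of a rename dict in the style used above (toy example).
rename_dict = {
    "control_model.input_blocks.1.1.proj_in.weight": "blocks.1.proj_in.weight",
    "control_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight",
}

def convert(state_dict):
    state_dict_ = {}
    for name, param in state_dict.items():
        if name in rename_dict:
            # proj_in / proj_out arrive as 1x1 convs with shape (out, in, 1, 1);
            # squeeze() turns them into the (out, in) linear weights the target expects.
            if ".proj_in." in name or ".proj_out." in name:
                param = param.squeeze()
            state_dict_[rename_dict[name]] = param
    return state_dict_

toy_checkpoint = {
    "control_model.input_blocks.1.1.proj_in.weight": torch.zeros(320, 320, 1, 1),
    "control_model.input_blocks.1.0.in_layers.0.weight": torch.zeros(320),
}
print(convert(toy_checkpoint)["blocks.1.proj_in.weight"].shape)  # torch.Size([320, 320])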
diffsynth/models/sd_ipadapter.py ADDED
@@ -0,0 +1,57 @@
1
+ from .svd_image_encoder import SVDImageEncoder
2
+ from .sdxl_ipadapter import IpAdapterImageProjModel, IpAdapterModule, SDXLIpAdapterStateDictConverter
3
+ from transformers import CLIPImageProcessor
4
+ import torch
5
+
6
+
7
+ class IpAdapterCLIPImageEmbedder(SVDImageEncoder):
8
+ def __init__(self):
9
+ super().__init__()
10
+ self.image_processor = CLIPImageProcessor()
11
+
12
+ def forward(self, image):
13
+ pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
14
+ pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype)
15
+ return super().forward(pixel_values)
16
+
17
+
18
+ class SDIpAdapter(torch.nn.Module):
19
+ def __init__(self):
20
+ super().__init__()
21
+ shape_list = [(768, 320)] * 2 + [(768, 640)] * 2 + [(768, 1280)] * 5 + [(768, 640)] * 3 + [(768, 320)] * 3 + [(768, 1280)] * 1
22
+ self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
23
+ self.image_proj = IpAdapterImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4)
24
+ self.set_full_adapter()
25
+
26
+ def set_full_adapter(self):
27
+ block_ids = [1, 4, 9, 12, 17, 20, 40, 43, 46, 50, 53, 56, 60, 63, 66, 29]
28
+ self.call_block_id = {(i, 0): j for j, i in enumerate(block_ids)}
29
+
30
+ def set_less_adapter(self):
31
+ # IP-Adapter for SD v1.5 doesn't support this feature.
32
+ self.set_full_adapter()
33
+
34
+ def forward(self, hidden_states, scale=1.0):
35
+ hidden_states = self.image_proj(hidden_states)
36
+ hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
37
+ ip_kv_dict = {}
38
+ for (block_id, transformer_id) in self.call_block_id:
39
+ ipadapter_id = self.call_block_id[(block_id, transformer_id)]
40
+ ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
41
+ if block_id not in ip_kv_dict:
42
+ ip_kv_dict[block_id] = {}
43
+ ip_kv_dict[block_id][transformer_id] = {
44
+ "ip_k": ip_k,
45
+ "ip_v": ip_v,
46
+ "scale": scale
47
+ }
48
+ return ip_kv_dict
49
+
50
+ @staticmethod
51
+ def state_dict_converter():
52
+ return SDIpAdapterStateDictConverter()
53
+
54
+
55
+ class SDIpAdapterStateDictConverter(SDXLIpAdapterStateDictConverter):
56
+ def __init__(self):
57
+ pass
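SDIpAdapter.forward above projects the CLIP image embedding and then asks each IpAdapterModule for an extra key/value pair, filing it under the UNet block that will consume it. The sketch below mirrors only that bookkeeping: ToyIpAdapterModule is a stand-in for the real IpAdapterModule (defined in sdxl_ipadapter.py, not shown here), and the two entries of call_block_id are just a subset of the mapping used above.

import torch

class ToyIpAdapterModule(torch.nn.Module):
    # Stand-in module: maps a (1, N, 768) token sequence to one key/value pair.
    def __init__(self, cross_attention_dim, hidden_dim):
        super().__init__()
        self.to_k_ip = torch.nn.Linear(cross_attention_dim, hidden_dim, bias=False)
        self.to_v_ip = torch.nn.Linear(cross_attention_dim, hidden_dim, bias=False)

    def forward(self, hidden_states):
        return self.to_k_ip(hidden_states), self.to_v_ip(hidden_states)

call_block_id = {(1, 0): 0, (4, 0): 1}  # subset of the real mapping
modules = torch.nn.ModuleList([ToyIpAdapterModule(768, 320), ToyIpAdapterModule(768, 640)])

hidden_states = torch.randn(1, 4, 768)  # 4 image tokens after image_proj
ip_kv_dict = {}
for (block_id, transformer_id), ipadapter_id in call_block_id.items():
    ip_k, ip_v = modules[ipadapter_id](hidden_states)
    ip_kv_dict.setdefault(block_id, {})[transformer_id] = {"ip_k": ip_k, "ip_v": ip_v, "scale": 1.0}
print(ip_kv_dict[1][0]["ip_k"].shape)  # torch.Size([1, 4, 320])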
diffsynth/models/sd_motion.py ADDED
@@ -0,0 +1,199 @@
1
+ from .sd_unet import SDUNet, Attention, GEGLU
2
+ import torch
3
+ from einops import rearrange, repeat
4
+
5
+
6
+ class TemporalTransformerBlock(torch.nn.Module):
7
+
8
+ def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32):
9
+ super().__init__()
10
+
11
+ # 1. Self-Attn
12
+ self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
13
+ self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
14
+ self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
15
+
16
+ # 2. Cross-Attn
17
+ self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
18
+ self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
19
+ self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
20
+
21
+ # 3. Feed-forward
22
+ self.norm3 = torch.nn.LayerNorm(dim, elementwise_affine=True)
23
+ self.act_fn = GEGLU(dim, dim * 4)
24
+ self.ff = torch.nn.Linear(dim * 4, dim)
25
+
26
+
27
+ def forward(self, hidden_states, batch_size=1):
28
+
29
+ # 1. Self-Attention
30
+ norm_hidden_states = self.norm1(hidden_states)
31
+ norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
32
+ attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]])
33
+ attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
34
+ hidden_states = attn_output + hidden_states
35
+
36
+ # 2. Cross-Attention
37
+ norm_hidden_states = self.norm2(hidden_states)
38
+ norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
39
+ attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]])
40
+ attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
41
+ hidden_states = attn_output + hidden_states
42
+
43
+ # 3. Feed-forward
44
+ norm_hidden_states = self.norm3(hidden_states)
45
+ ff_output = self.act_fn(norm_hidden_states)
46
+ ff_output = self.ff(ff_output)
47
+ hidden_states = ff_output + hidden_states
48
+
49
+ return hidden_states
50
+
51
+
52
+ class TemporalBlock(torch.nn.Module):
53
+
54
+ def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
55
+ super().__init__()
56
+ inner_dim = num_attention_heads * attention_head_dim
57
+
58
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
59
+ self.proj_in = torch.nn.Linear(in_channels, inner_dim)
60
+
61
+ self.transformer_blocks = torch.nn.ModuleList([
62
+ TemporalTransformerBlock(
63
+ inner_dim,
64
+ num_attention_heads,
65
+ attention_head_dim
66
+ )
67
+ for d in range(num_layers)
68
+ ])
69
+
70
+ self.proj_out = torch.nn.Linear(inner_dim, in_channels)
71
+
72
+ def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1):
73
+ batch, _, height, width = hidden_states.shape
74
+ residual = hidden_states
75
+
76
+ hidden_states = self.norm(hidden_states)
77
+ inner_dim = hidden_states.shape[1]
78
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
79
+ hidden_states = self.proj_in(hidden_states)
80
+
81
+ for block in self.transformer_blocks:
82
+ hidden_states = block(
83
+ hidden_states,
84
+ batch_size=batch_size
85
+ )
86
+
87
+ hidden_states = self.proj_out(hidden_states)
88
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
89
+ hidden_states = hidden_states + residual
90
+
91
+ return hidden_states, time_emb, text_emb, res_stack
92
+
93
+
94
+ class SDMotionModel(torch.nn.Module):
95
+ def __init__(self):
96
+ super().__init__()
97
+ self.motion_modules = torch.nn.ModuleList([
98
+ TemporalBlock(8, 40, 320, eps=1e-6),
99
+ TemporalBlock(8, 40, 320, eps=1e-6),
100
+ TemporalBlock(8, 80, 640, eps=1e-6),
101
+ TemporalBlock(8, 80, 640, eps=1e-6),
102
+ TemporalBlock(8, 160, 1280, eps=1e-6),
103
+ TemporalBlock(8, 160, 1280, eps=1e-6),
104
+ TemporalBlock(8, 160, 1280, eps=1e-6),
105
+ TemporalBlock(8, 160, 1280, eps=1e-6),
106
+ TemporalBlock(8, 160, 1280, eps=1e-6),
107
+ TemporalBlock(8, 160, 1280, eps=1e-6),
108
+ TemporalBlock(8, 160, 1280, eps=1e-6),
109
+ TemporalBlock(8, 160, 1280, eps=1e-6),
110
+ TemporalBlock(8, 160, 1280, eps=1e-6),
111
+ TemporalBlock(8, 160, 1280, eps=1e-6),
112
+ TemporalBlock(8, 160, 1280, eps=1e-6),
113
+ TemporalBlock(8, 80, 640, eps=1e-6),
114
+ TemporalBlock(8, 80, 640, eps=1e-6),
115
+ TemporalBlock(8, 80, 640, eps=1e-6),
116
+ TemporalBlock(8, 40, 320, eps=1e-6),
117
+ TemporalBlock(8, 40, 320, eps=1e-6),
118
+ TemporalBlock(8, 40, 320, eps=1e-6),
119
+ ])
120
+ self.call_block_id = {
121
+ 1: 0,
122
+ 4: 1,
123
+ 9: 2,
124
+ 12: 3,
125
+ 17: 4,
126
+ 20: 5,
127
+ 24: 6,
128
+ 26: 7,
129
+ 29: 8,
130
+ 32: 9,
131
+ 34: 10,
132
+ 36: 11,
133
+ 40: 12,
134
+ 43: 13,
135
+ 46: 14,
136
+ 50: 15,
137
+ 53: 16,
138
+ 56: 17,
139
+ 60: 18,
140
+ 63: 19,
141
+ 66: 20
142
+ }
143
+
144
+ def forward(self):
145
+ pass
146
+
147
+ @staticmethod
148
+ def state_dict_converter():
149
+ return SDMotionModelStateDictConverter()
150
+
151
+
152
+ class SDMotionModelStateDictConverter:
153
+ def __init__(self):
154
+ pass
155
+
156
+ def from_diffusers(self, state_dict):
157
+ rename_dict = {
158
+ "norm": "norm",
159
+ "proj_in": "proj_in",
160
+ "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
161
+ "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
162
+ "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
163
+ "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
164
+ "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
165
+ "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
166
+ "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
167
+ "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
168
+ "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
169
+ "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
170
+ "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
171
+ "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
172
+ "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
173
+ "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
174
+ "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
175
+ "proj_out": "proj_out",
176
+ }
177
+ name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
178
+ name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
179
+ name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
180
+ state_dict_ = {}
181
+ last_prefix, module_id = "", -1
182
+ for name in name_list:
183
+ names = name.split(".")
184
+ prefix_index = names.index("temporal_transformer") + 1
185
+ prefix = ".".join(names[:prefix_index])
186
+ if prefix != last_prefix:
187
+ last_prefix = prefix
188
+ module_id += 1
189
+ middle_name = ".".join(names[prefix_index:-1])
190
+ suffix = names[-1]
191
+ if "pos_encoder" in names:
192
+ rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
193
+ else:
194
+ rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
195
+ state_dict_[rename] = state_dict[name]
196
+ return state_dict_
197
+
198
+ def from_civitai(self, state_dict):
199
+ return self.from_diffusers(state_dict)
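The temporal blocks above attend across frames rather than across pixels: the spatial axis is folded into the batch so the frame axis becomes the attention sequence, and the layout is restored afterwards for the surrounding spatial UNet blocks. A small sketch of just that reshaping, with illustrative sizes:

import torch
from einops import rearrange

batch_size, frames, tokens, channels = 2, 16, 64, 320
x = torch.randn(batch_size * frames, tokens, channels)              # "(b f) h c", as inside the UNet

# Fold spatial tokens into the batch so attention runs over the 16 frames.
x_temporal = rearrange(x, "(b f) h c -> (b h) f c", b=batch_size)
print(x_temporal.shape)                                             # torch.Size([128, 16, 320])

# Restore the spatial layout for the purely spatial blocks that follow.
x_back = rearrange(x_temporal, "(b h) f c -> (b f) h c", b=batch_size)
assert x_back.shape == x.shape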
diffsynth/models/sd_text_encoder.py ADDED
@@ -0,0 +1,321 @@
1
+ import torch
2
+ from .attention import Attention
3
+
4
+
5
+ class CLIPEncoderLayer(torch.nn.Module):
6
+ def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
7
+ super().__init__()
8
+ self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
9
+ self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
10
+ self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
11
+ self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
12
+ self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
13
+
14
+ self.use_quick_gelu = use_quick_gelu
15
+
16
+ def quickGELU(self, x):
17
+ return x * torch.sigmoid(1.702 * x)
18
+
19
+ def forward(self, hidden_states, attn_mask=None):
20
+ residual = hidden_states
21
+
22
+ hidden_states = self.layer_norm1(hidden_states)
23
+ hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
24
+ hidden_states = residual + hidden_states
25
+
26
+ residual = hidden_states
27
+ hidden_states = self.layer_norm2(hidden_states)
28
+ hidden_states = self.fc1(hidden_states)
29
+ if self.use_quick_gelu:
30
+ hidden_states = self.quickGELU(hidden_states)
31
+ else:
32
+ hidden_states = torch.nn.functional.gelu(hidden_states)
33
+ hidden_states = self.fc2(hidden_states)
34
+ hidden_states = residual + hidden_states
35
+
36
+ return hidden_states
37
+
38
+
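QuickGELU, used above when use_quick_gelu=True, is the sigmoid approximation x * sigmoid(1.702 * x) of GELU from the original CLIP implementation. A quick numeric check against torch's exact GELU (the size of the discrepancy noted in the comment is an observation at these sample points, not a guarantee):

import torch

x = torch.linspace(-3, 3, 7)
quick = x * torch.sigmoid(1.702 * x)       # QuickGELU as implemented above
exact = torch.nn.functional.gelu(x)        # exact GELU
print((quick - exact).abs().max())         # on the order of 1e-2 at these points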
39
+ class SDTextEncoder(torch.nn.Module):
40
+ def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
41
+ super().__init__()
42
+
43
+ # token_embedding
44
+ self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
45
+
46
+ # position_embeds (This is a fixed tensor)
47
+ self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
48
+
49
+ # encoders
50
+ self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
51
+
52
+ # attn_mask
53
+ self.attn_mask = self.attention_mask(max_position_embeddings)
54
+
55
+ # final_layer_norm
56
+ self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
57
+
58
+ def attention_mask(self, length):
59
+ mask = torch.empty(length, length)
60
+ mask.fill_(float("-inf"))
61
+ mask.triu_(1)
62
+ return mask
63
+
64
+ def forward(self, input_ids, clip_skip=1):
65
+ embeds = self.token_embedding(input_ids) + self.position_embeds
66
+ attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
67
+ for encoder_id, encoder in enumerate(self.encoders):
68
+ embeds = encoder(embeds, attn_mask=attn_mask)
69
+ if encoder_id + clip_skip == len(self.encoders):
70
+ break
71
+ embeds = self.final_layer_norm(embeds)
72
+ return embeds
73
+
74
+ @staticmethod
75
+ def state_dict_converter():
76
+ return SDTextEncoderStateDictConverter()
77
+
78
+
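In the forward pass above, each encoder layer runs before the clip_skip test, so with the default 12 layers clip_skip=1 runs all 12, while clip_skip=2 stops after layer index 10, i.e. 11 layers run and the final layer is skipped. A small worked example of that counting:

num_encoders = 12                            # default number of CLIPEncoderLayer blocks
for clip_skip in (1, 2):
    layers_run = 0
    for encoder_id in range(num_encoders):
        layers_run += 1                      # the encoder layer runs first ...
        if encoder_id + clip_skip == num_encoders:
            break                            # ... then the loop may stop early
    print(clip_skip, layers_run)             # clip_skip=1 -> 12, clip_skip=2 -> 11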
79
+ class SDTextEncoderStateDictConverter:
80
+ def __init__(self):
81
+ pass
82
+
83
+ def from_diffusers(self, state_dict):
84
+ rename_dict = {
85
+ "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
86
+ "text_model.embeddings.position_embedding.weight": "position_embeds",
87
+ "text_model.final_layer_norm.weight": "final_layer_norm.weight",
88
+ "text_model.final_layer_norm.bias": "final_layer_norm.bias"
89
+ }
90
+ attn_rename_dict = {
91
+ "self_attn.q_proj": "attn.to_q",
92
+ "self_attn.k_proj": "attn.to_k",
93
+ "self_attn.v_proj": "attn.to_v",
94
+ "self_attn.out_proj": "attn.to_out",
95
+ "layer_norm1": "layer_norm1",
96
+ "layer_norm2": "layer_norm2",
97
+ "mlp.fc1": "fc1",
98
+ "mlp.fc2": "fc2",
99
+ }
100
+ state_dict_ = {}
101
+ for name in state_dict:
102
+ if name in rename_dict:
103
+ param = state_dict[name]
104
+ if name == "text_model.embeddings.position_embedding.weight":
105
+ param = param.reshape((1, param.shape[0], param.shape[1]))
106
+ state_dict_[rename_dict[name]] = param
107
+ elif name.startswith("text_model.encoder.layers."):
108
+ param = state_dict[name]
109
+ names = name.split(".")
110
+ layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
111
+ name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
112
+ state_dict_[name_] = param
113
+ return state_dict_
114
+
115
+ def from_civitai(self, state_dict):
116
+ rename_dict = {
117
+ "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
118
+ "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
119
+ "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
120
+ "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
121
+ "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
122
+ "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
123
+ "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
124
+ "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
125
+ "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
126
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
127
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
128
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
129
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
130
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
131
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
132
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
133
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
134
+ "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
135
+ "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
136
+ "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
137
+ "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
138
+ "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
139
+ "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
140
+ "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
141
+ "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
142
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
143
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
144
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
145
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
146
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
147
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
148
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
149
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
150
+ "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
151
+ "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
152
+ "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
153
+ "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
154
+ "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
155
+ "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
156
+ "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
157
+ "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
158
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
159
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
160
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
161
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
162
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
163
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
164
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
165
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
166
+ "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias",
167
+ "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight",
168
+ "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias",
169
+ "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight",
170
+ "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias",
171
+ "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight",
172
+ "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias",
173
+ "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight",
174
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias",
175
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight",
176
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias",
177
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight",
178
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias",
179
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight",
180
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias",
181
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight",
182
+ "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
183
+ "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
184
+ "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
185
+ "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
186
+ "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
187
+ "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
188
+ "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
189
+ "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
190
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
191
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
192
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
193
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
194
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
195
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
196
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
197
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
198
+ "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
199
+ "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
200
+ "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
201
+ "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
202
+ "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
203
+ "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
204
+ "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
205
+ "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
206
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
207
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
208
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
209
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
210
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
211
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
212
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
213
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
214
+ "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
215
+ "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
216
+ "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
217
+ "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
218
+ "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
219
+ "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
220
+ "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
221
+ "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
222
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
223
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
224
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
225
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
226
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
227
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
228
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
229
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
230
+ "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
231
+ "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
232
+ "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
233
+ "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
234
+ "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
235
+ "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
236
+ "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
237
+ "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
238
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
239
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
240
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
241
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
242
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
243
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
244
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
245
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
246
+ "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
247
+ "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
248
+ "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
249
+ "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
250
+ "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
251
+ "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
252
+ "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
253
+ "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
254
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
255
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
256
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
257
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
258
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
259
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
260
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
261
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
262
+ "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
263
+ "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
264
+ "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
265
+ "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
266
+ "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
267
+ "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
268
+ "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
269
+ "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
270
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
271
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
272
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
273
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
274
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
275
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
276
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
277
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
278
+ "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
279
+ "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
280
+ "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
281
+ "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
282
+ "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
283
+ "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
284
+ "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
285
+ "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
286
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
287
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
288
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
289
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
290
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
291
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
292
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
293
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
294
+ "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
295
+ "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
296
+ "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
297
+ "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
298
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
299
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
300
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
301
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
302
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
303
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
304
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
305
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
306
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
307
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
308
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
309
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
310
+ "cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias",
311
+ "cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight",
312
+ "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds"
313
+ }
314
+ state_dict_ = {}
315
+ for name in state_dict:
316
+ if name in rename_dict:
317
+ param = state_dict[name]
318
+ if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight":
319
+ param = param.reshape((1, param.shape[0], param.shape[1]))
320
+ state_dict_[rename_dict[name]] = param
321
+ return state_dict_
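Note: the civitai-format conversion above is a straight key rename, with one reshape applied to the CLIP position embedding. A minimal sketch of that reshape follows (the (77, 768) shape is an assumption based on the standard SD1.x text encoder; the converter itself only requires a 2-D tensor):

import torch

param = torch.randn(77, 768)  # stored as [seq_len, dim] in civitai-style checkpoints
param = param.reshape((1, param.shape[0], param.shape[1]))  # -> [1, seq_len, dim], saved under "position_embeds"
assert param.shape == (1, 77, 768)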
diffsynth/models/sd_unet.py ADDED
The diff for this file is too large to render. See raw diff
 
diffsynth/models/sd_vae_decoder.py ADDED
@@ -0,0 +1,336 @@
1
+ import torch
2
+ from .attention import Attention
3
+ from .sd_unet import ResnetBlock, UpSampler
4
+ from .tiler import TileWorker
5
+
6
+
7
+ class VAEAttentionBlock(torch.nn.Module):
8
+
9
+ def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
10
+ super().__init__()
11
+ inner_dim = num_attention_heads * attention_head_dim
12
+
13
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
14
+
15
+ self.transformer_blocks = torch.nn.ModuleList([
16
+ Attention(
17
+ inner_dim,
18
+ num_attention_heads,
19
+ attention_head_dim,
20
+ bias_q=True,
21
+ bias_kv=True,
22
+ bias_out=True
23
+ )
24
+ for d in range(num_layers)
25
+ ])
26
+
27
+ def forward(self, hidden_states, time_emb, text_emb, res_stack):
28
+ batch, _, height, width = hidden_states.shape
29
+ residual = hidden_states
30
+
31
+ hidden_states = self.norm(hidden_states)
32
+ inner_dim = hidden_states.shape[1]
33
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
34
+
35
+ for block in self.transformer_blocks:
36
+ hidden_states = block(hidden_states)
37
+
38
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
39
+ hidden_states = hidden_states + residual
40
+
41
+ return hidden_states, time_emb, text_emb, res_stack
42
+
43
+
44
+ class SDVAEDecoder(torch.nn.Module):
45
+ def __init__(self):
46
+ super().__init__()
47
+ self.scaling_factor = 0.18215
48
+ self.post_quant_conv = torch.nn.Conv2d(4, 4, kernel_size=1)
49
+ self.conv_in = torch.nn.Conv2d(4, 512, kernel_size=3, padding=1)
50
+
51
+ self.blocks = torch.nn.ModuleList([
52
+ # UNetMidBlock2D
53
+ ResnetBlock(512, 512, eps=1e-6),
54
+ VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
55
+ ResnetBlock(512, 512, eps=1e-6),
56
+ # UpDecoderBlock2D
57
+ ResnetBlock(512, 512, eps=1e-6),
58
+ ResnetBlock(512, 512, eps=1e-6),
59
+ ResnetBlock(512, 512, eps=1e-6),
60
+ UpSampler(512),
61
+ # UpDecoderBlock2D
62
+ ResnetBlock(512, 512, eps=1e-6),
63
+ ResnetBlock(512, 512, eps=1e-6),
64
+ ResnetBlock(512, 512, eps=1e-6),
65
+ UpSampler(512),
66
+ # UpDecoderBlock2D
67
+ ResnetBlock(512, 256, eps=1e-6),
68
+ ResnetBlock(256, 256, eps=1e-6),
69
+ ResnetBlock(256, 256, eps=1e-6),
70
+ UpSampler(256),
71
+ # UpDecoderBlock2D
72
+ ResnetBlock(256, 128, eps=1e-6),
73
+ ResnetBlock(128, 128, eps=1e-6),
74
+ ResnetBlock(128, 128, eps=1e-6),
75
+ ])
76
+
77
+ self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-5)
78
+ self.conv_act = torch.nn.SiLU()
79
+ self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
80
+
81
+ def tiled_forward(self, sample, tile_size=64, tile_stride=32):
82
+ hidden_states = TileWorker().tiled_forward(
83
+ lambda x: self.forward(x),
84
+ sample,
85
+ tile_size,
86
+ tile_stride,
87
+ tile_device=sample.device,
88
+ tile_dtype=sample.dtype
89
+ )
90
+ return hidden_states
91
+
92
+ def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
93
+ original_dtype = sample.dtype
94
+ sample = sample.to(dtype=next(iter(self.parameters())).dtype)
95
+ # For VAE Decoder, we do not need to apply the tiler on each layer.
96
+ if tiled:
97
+ return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
98
+
99
+ # 1. pre-process
100
+ sample = sample / self.scaling_factor
101
+ hidden_states = self.post_quant_conv(sample)
102
+ hidden_states = self.conv_in(hidden_states)
103
+ time_emb = None
104
+ text_emb = None
105
+ res_stack = None
106
+
107
+ # 2. blocks
108
+ for i, block in enumerate(self.blocks):
109
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
110
+
111
+ # 3. output
112
+ hidden_states = self.conv_norm_out(hidden_states)
113
+ hidden_states = self.conv_act(hidden_states)
114
+ hidden_states = self.conv_out(hidden_states)
115
+ hidden_states = hidden_states.to(original_dtype)
116
+
117
+ return hidden_states
118
+
119
+ @staticmethod
120
+ def state_dict_converter():
121
+ return SDVAEDecoderStateDictConverter()
122
+
123
+
124
+ class SDVAEDecoderStateDictConverter:
125
+ def __init__(self):
126
+ pass
127
+
128
+ def from_diffusers(self, state_dict):
129
+ # architecture
130
+ block_types = [
131
+ 'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock',
132
+ 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
133
+ 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
134
+ 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
135
+ 'ResnetBlock', 'ResnetBlock', 'ResnetBlock'
136
+ ]
137
+
138
+ # Rename each parameter
139
+ local_rename_dict = {
140
+ "post_quant_conv": "post_quant_conv",
141
+ "decoder.conv_in": "conv_in",
142
+ "decoder.mid_block.attentions.0.group_norm": "blocks.1.norm",
143
+ "decoder.mid_block.attentions.0.to_q": "blocks.1.transformer_blocks.0.to_q",
144
+ "decoder.mid_block.attentions.0.to_k": "blocks.1.transformer_blocks.0.to_k",
145
+ "decoder.mid_block.attentions.0.to_v": "blocks.1.transformer_blocks.0.to_v",
146
+ "decoder.mid_block.attentions.0.to_out.0": "blocks.1.transformer_blocks.0.to_out",
147
+ "decoder.mid_block.resnets.0.norm1": "blocks.0.norm1",
148
+ "decoder.mid_block.resnets.0.conv1": "blocks.0.conv1",
149
+ "decoder.mid_block.resnets.0.norm2": "blocks.0.norm2",
150
+ "decoder.mid_block.resnets.0.conv2": "blocks.0.conv2",
151
+ "decoder.mid_block.resnets.1.norm1": "blocks.2.norm1",
152
+ "decoder.mid_block.resnets.1.conv1": "blocks.2.conv1",
153
+ "decoder.mid_block.resnets.1.norm2": "blocks.2.norm2",
154
+ "decoder.mid_block.resnets.1.conv2": "blocks.2.conv2",
155
+ "decoder.conv_norm_out": "conv_norm_out",
156
+ "decoder.conv_out": "conv_out",
157
+ }
158
+ name_list = sorted([name for name in state_dict])
159
+ rename_dict = {}
160
+ block_id = {"ResnetBlock": 2, "DownSampler": 2, "UpSampler": 2}
161
+ last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
162
+ for name in name_list:
163
+ names = name.split(".")
164
+ name_prefix = ".".join(names[:-1])
165
+ if name_prefix in local_rename_dict:
166
+ rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
167
+ elif name.startswith("decoder.up_blocks"):
168
+ block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
169
+ block_type_with_id = ".".join(names[:5])
170
+ if block_type_with_id != last_block_type_with_id[block_type]:
171
+ block_id[block_type] += 1
172
+ last_block_type_with_id[block_type] = block_type_with_id
173
+ while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
174
+ block_id[block_type] += 1
175
+ block_type_with_id = ".".join(names[:5])
176
+ names = ["blocks", str(block_id[block_type])] + names[5:]
177
+ rename_dict[name] = ".".join(names)
178
+
179
+ # Convert state_dict
180
+ state_dict_ = {}
181
+ for name, param in state_dict.items():
182
+ if name in rename_dict:
183
+ state_dict_[rename_dict[name]] = param
184
+ return state_dict_
185
+
186
+ def from_civitai(self, state_dict):
187
+ rename_dict = {
188
+ "first_stage_model.decoder.conv_in.bias": "conv_in.bias",
189
+ "first_stage_model.decoder.conv_in.weight": "conv_in.weight",
190
+ "first_stage_model.decoder.conv_out.bias": "conv_out.bias",
191
+ "first_stage_model.decoder.conv_out.weight": "conv_out.weight",
192
+ "first_stage_model.decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
193
+ "first_stage_model.decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
194
+ "first_stage_model.decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
195
+ "first_stage_model.decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
196
+ "first_stage_model.decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
197
+ "first_stage_model.decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
198
+ "first_stage_model.decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
199
+ "first_stage_model.decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
200
+ "first_stage_model.decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
201
+ "first_stage_model.decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
202
+ "first_stage_model.decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
203
+ "first_stage_model.decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
204
+ "first_stage_model.decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
205
+ "first_stage_model.decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
206
+ "first_stage_model.decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
207
+ "first_stage_model.decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
208
+ "first_stage_model.decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
209
+ "first_stage_model.decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
210
+ "first_stage_model.decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
211
+ "first_stage_model.decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
212
+ "first_stage_model.decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
213
+ "first_stage_model.decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
214
+ "first_stage_model.decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
215
+ "first_stage_model.decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
216
+ "first_stage_model.decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
217
+ "first_stage_model.decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
218
+ "first_stage_model.decoder.norm_out.bias": "conv_norm_out.bias",
219
+ "first_stage_model.decoder.norm_out.weight": "conv_norm_out.weight",
220
+ "first_stage_model.decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
221
+ "first_stage_model.decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
222
+ "first_stage_model.decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
223
+ "first_stage_model.decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
224
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
225
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
226
+ "first_stage_model.decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
227
+ "first_stage_model.decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
228
+ "first_stage_model.decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
229
+ "first_stage_model.decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
230
+ "first_stage_model.decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
231
+ "first_stage_model.decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
232
+ "first_stage_model.decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
233
+ "first_stage_model.decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
234
+ "first_stage_model.decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
235
+ "first_stage_model.decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
236
+ "first_stage_model.decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
237
+ "first_stage_model.decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
238
+ "first_stage_model.decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
239
+ "first_stage_model.decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
240
+ "first_stage_model.decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
241
+ "first_stage_model.decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
242
+ "first_stage_model.decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
243
+ "first_stage_model.decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
244
+ "first_stage_model.decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
245
+ "first_stage_model.decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
246
+ "first_stage_model.decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
247
+ "first_stage_model.decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
248
+ "first_stage_model.decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
249
+ "first_stage_model.decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
250
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
251
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
252
+ "first_stage_model.decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
253
+ "first_stage_model.decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
254
+ "first_stage_model.decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
255
+ "first_stage_model.decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
256
+ "first_stage_model.decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
257
+ "first_stage_model.decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
258
+ "first_stage_model.decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
259
+ "first_stage_model.decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
260
+ "first_stage_model.decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
261
+ "first_stage_model.decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
262
+ "first_stage_model.decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
263
+ "first_stage_model.decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
264
+ "first_stage_model.decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
265
+ "first_stage_model.decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
266
+ "first_stage_model.decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
267
+ "first_stage_model.decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
268
+ "first_stage_model.decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
269
+ "first_stage_model.decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
270
+ "first_stage_model.decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
271
+ "first_stage_model.decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
272
+ "first_stage_model.decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
273
+ "first_stage_model.decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
274
+ "first_stage_model.decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
275
+ "first_stage_model.decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
276
+ "first_stage_model.decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
277
+ "first_stage_model.decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
278
+ "first_stage_model.decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
279
+ "first_stage_model.decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
280
+ "first_stage_model.decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
281
+ "first_stage_model.decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
282
+ "first_stage_model.decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
283
+ "first_stage_model.decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
284
+ "first_stage_model.decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
285
+ "first_stage_model.decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
286
+ "first_stage_model.decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
287
+ "first_stage_model.decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
288
+ "first_stage_model.decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
289
+ "first_stage_model.decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
290
+ "first_stage_model.decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
291
+ "first_stage_model.decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
292
+ "first_stage_model.decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
293
+ "first_stage_model.decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
294
+ "first_stage_model.decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
295
+ "first_stage_model.decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
296
+ "first_stage_model.decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
297
+ "first_stage_model.decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
298
+ "first_stage_model.decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
299
+ "first_stage_model.decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
300
+ "first_stage_model.decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
301
+ "first_stage_model.decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
302
+ "first_stage_model.decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
303
+ "first_stage_model.decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
304
+ "first_stage_model.decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
305
+ "first_stage_model.decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
306
+ "first_stage_model.decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
307
+ "first_stage_model.decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
308
+ "first_stage_model.decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
309
+ "first_stage_model.decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
310
+ "first_stage_model.decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
311
+ "first_stage_model.decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
312
+ "first_stage_model.decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
313
+ "first_stage_model.decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
314
+ "first_stage_model.decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
315
+ "first_stage_model.decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
316
+ "first_stage_model.decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
317
+ "first_stage_model.decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
318
+ "first_stage_model.decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
319
+ "first_stage_model.decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
320
+ "first_stage_model.decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
321
+ "first_stage_model.decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
322
+ "first_stage_model.decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
323
+ "first_stage_model.decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
324
+ "first_stage_model.decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
325
+ "first_stage_model.decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
326
+ "first_stage_model.post_quant_conv.bias": "post_quant_conv.bias",
327
+ "first_stage_model.post_quant_conv.weight": "post_quant_conv.weight",
328
+ }
329
+ state_dict_ = {}
330
+ for name in state_dict:
331
+ if name in rename_dict:
332
+ param = state_dict[name]
333
+ if "transformer_blocks" in rename_dict[name]:
334
+ param = param.squeeze()
335
+ state_dict_[rename_dict[name]] = param
336
+ return state_dict_
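Note: a minimal usage sketch for the decoder above, assuming a local SD1.x checkpoint has already been loaded into raw_state_dict (the file name and loading step are illustrative, not part of this diff):

import torch
from diffsynth.models.sd_vae_decoder import SDVAEDecoder

raw_state_dict = torch.load("sd_checkpoint.ckpt", map_location="cpu")  # hypothetical checkpoint path
raw_state_dict = raw_state_dict.get("state_dict", raw_state_dict)      # some checkpoints nest the weights

decoder = SDVAEDecoder()
decoder.load_state_dict(SDVAEDecoder.state_dict_converter().from_civitai(raw_state_dict))

latent = torch.randn(1, 4, 64, 64)
with torch.no_grad():
    image = decoder(latent)  # scaling_factor (0.18215) is divided out inside forward(); expected shape [1, 3, 512, 512]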
diffsynth/models/sd_vae_encoder.py ADDED
@@ -0,0 +1,282 @@
1
+ import torch
2
+ from .sd_unet import ResnetBlock, DownSampler
3
+ from .sd_vae_decoder import VAEAttentionBlock
4
+ from .tiler import TileWorker
5
+ from einops import rearrange
6
+
7
+
8
+ class SDVAEEncoder(torch.nn.Module):
9
+ def __init__(self):
10
+ super().__init__()
11
+ self.scaling_factor = 0.18215
12
+ self.quant_conv = torch.nn.Conv2d(8, 8, kernel_size=1)
13
+ self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
14
+
15
+ self.blocks = torch.nn.ModuleList([
16
+ # DownEncoderBlock2D
17
+ ResnetBlock(128, 128, eps=1e-6),
18
+ ResnetBlock(128, 128, eps=1e-6),
19
+ DownSampler(128, padding=0, extra_padding=True),
20
+ # DownEncoderBlock2D
21
+ ResnetBlock(128, 256, eps=1e-6),
22
+ ResnetBlock(256, 256, eps=1e-6),
23
+ DownSampler(256, padding=0, extra_padding=True),
24
+ # DownEncoderBlock2D
25
+ ResnetBlock(256, 512, eps=1e-6),
26
+ ResnetBlock(512, 512, eps=1e-6),
27
+ DownSampler(512, padding=0, extra_padding=True),
28
+ # DownEncoderBlock2D
29
+ ResnetBlock(512, 512, eps=1e-6),
30
+ ResnetBlock(512, 512, eps=1e-6),
31
+ # UNetMidBlock2D
32
+ ResnetBlock(512, 512, eps=1e-6),
33
+ VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
34
+ ResnetBlock(512, 512, eps=1e-6),
35
+ ])
36
+
37
+ self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
38
+ self.conv_act = torch.nn.SiLU()
39
+ self.conv_out = torch.nn.Conv2d(512, 8, kernel_size=3, padding=1)
40
+
41
+ def tiled_forward(self, sample, tile_size=64, tile_stride=32):
42
+ hidden_states = TileWorker().tiled_forward(
43
+ lambda x: self.forward(x),
44
+ sample,
45
+ tile_size,
46
+ tile_stride,
47
+ tile_device=sample.device,
48
+ tile_dtype=sample.dtype
49
+ )
50
+ return hidden_states
51
+
52
+ def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
53
+ original_dtype = sample.dtype
54
+ sample = sample.to(dtype=next(iter(self.parameters())).dtype)
55
+ # For the VAE Encoder, we do not need to apply the tiler on each layer.
56
+ if tiled:
57
+ return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
58
+
59
+ # 1. pre-process
60
+ hidden_states = self.conv_in(sample)
61
+ time_emb = None
62
+ text_emb = None
63
+ res_stack = None
64
+
65
+ # 2. blocks
66
+ for i, block in enumerate(self.blocks):
67
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
68
+
69
+ # 3. output
70
+ hidden_states = self.conv_norm_out(hidden_states)
71
+ hidden_states = self.conv_act(hidden_states)
72
+ hidden_states = self.conv_out(hidden_states)
73
+ hidden_states = self.quant_conv(hidden_states)
74
+ hidden_states = hidden_states[:, :4]
75
+ hidden_states *= self.scaling_factor
76
+ hidden_states = hidden_states.to(original_dtype)
77
+
78
+ return hidden_states
79
+
80
+ def encode_video(self, sample, batch_size=8):
81
+ B = sample.shape[0]
82
+ hidden_states = []
83
+
84
+ for i in range(0, sample.shape[2], batch_size):
85
+
86
+ j = min(i + batch_size, sample.shape[2])
87
+ sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
88
+
89
+ hidden_states_batch = self(sample_batch)
90
+ hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
91
+
92
+ hidden_states.append(hidden_states_batch)
93
+
94
+ hidden_states = torch.concat(hidden_states, dim=2)
95
+ return hidden_states
96
+
97
+ @staticmethod
98
+ def state_dict_converter():
99
+ return SDVAEEncoderStateDictConverter()
100
+
101
+
102
+ class SDVAEEncoderStateDictConverter:
103
+ def __init__(self):
104
+ pass
105
+
106
+ def from_diffusers(self, state_dict):
107
+ # architecture
108
+ block_types = [
109
+ 'ResnetBlock', 'ResnetBlock', 'DownSampler',
110
+ 'ResnetBlock', 'ResnetBlock', 'DownSampler',
111
+ 'ResnetBlock', 'ResnetBlock', 'DownSampler',
112
+ 'ResnetBlock', 'ResnetBlock',
113
+ 'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock'
114
+ ]
115
+
116
+ # Rename each parameter
117
+ local_rename_dict = {
118
+ "quant_conv": "quant_conv",
119
+ "encoder.conv_in": "conv_in",
120
+ "encoder.mid_block.attentions.0.group_norm": "blocks.12.norm",
121
+ "encoder.mid_block.attentions.0.to_q": "blocks.12.transformer_blocks.0.to_q",
122
+ "encoder.mid_block.attentions.0.to_k": "blocks.12.transformer_blocks.0.to_k",
123
+ "encoder.mid_block.attentions.0.to_v": "blocks.12.transformer_blocks.0.to_v",
124
+ "encoder.mid_block.attentions.0.to_out.0": "blocks.12.transformer_blocks.0.to_out",
125
+ "encoder.mid_block.resnets.0.norm1": "blocks.11.norm1",
126
+ "encoder.mid_block.resnets.0.conv1": "blocks.11.conv1",
127
+ "encoder.mid_block.resnets.0.norm2": "blocks.11.norm2",
128
+ "encoder.mid_block.resnets.0.conv2": "blocks.11.conv2",
129
+ "encoder.mid_block.resnets.1.norm1": "blocks.13.norm1",
130
+ "encoder.mid_block.resnets.1.conv1": "blocks.13.conv1",
131
+ "encoder.mid_block.resnets.1.norm2": "blocks.13.norm2",
132
+ "encoder.mid_block.resnets.1.conv2": "blocks.13.conv2",
133
+ "encoder.conv_norm_out": "conv_norm_out",
134
+ "encoder.conv_out": "conv_out",
135
+ }
136
+ name_list = sorted([name for name in state_dict])
137
+ rename_dict = {}
138
+ block_id = {"ResnetBlock": -1, "DownSampler": -1, "UpSampler": -1}
139
+ last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
140
+ for name in name_list:
141
+ names = name.split(".")
142
+ name_prefix = ".".join(names[:-1])
143
+ if name_prefix in local_rename_dict:
144
+ rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
145
+ elif name.startswith("encoder.down_blocks"):
146
+ block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
147
+ block_type_with_id = ".".join(names[:5])
148
+ if block_type_with_id != last_block_type_with_id[block_type]:
149
+ block_id[block_type] += 1
150
+ last_block_type_with_id[block_type] = block_type_with_id
151
+ while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
152
+ block_id[block_type] += 1
153
+ block_type_with_id = ".".join(names[:5])
154
+ names = ["blocks", str(block_id[block_type])] + names[5:]
155
+ rename_dict[name] = ".".join(names)
156
+
157
+ # Convert state_dict
158
+ state_dict_ = {}
159
+ for name, param in state_dict.items():
160
+ if name in rename_dict:
161
+ state_dict_[rename_dict[name]] = param
162
+ return state_dict_
163
+
164
+ def from_civitai(self, state_dict):
165
+ rename_dict = {
166
+ "first_stage_model.encoder.conv_in.bias": "conv_in.bias",
167
+ "first_stage_model.encoder.conv_in.weight": "conv_in.weight",
168
+ "first_stage_model.encoder.conv_out.bias": "conv_out.bias",
169
+ "first_stage_model.encoder.conv_out.weight": "conv_out.weight",
170
+ "first_stage_model.encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
171
+ "first_stage_model.encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
172
+ "first_stage_model.encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
173
+ "first_stage_model.encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
174
+ "first_stage_model.encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
175
+ "first_stage_model.encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
176
+ "first_stage_model.encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
177
+ "first_stage_model.encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
178
+ "first_stage_model.encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
179
+ "first_stage_model.encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
180
+ "first_stage_model.encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
181
+ "first_stage_model.encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
182
+ "first_stage_model.encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
183
+ "first_stage_model.encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
184
+ "first_stage_model.encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
185
+ "first_stage_model.encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
186
+ "first_stage_model.encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
187
+ "first_stage_model.encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
188
+ "first_stage_model.encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
189
+ "first_stage_model.encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
190
+ "first_stage_model.encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
191
+ "first_stage_model.encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
192
+ "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
193
+ "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
194
+ "first_stage_model.encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
195
+ "first_stage_model.encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
196
+ "first_stage_model.encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
197
+ "first_stage_model.encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
198
+ "first_stage_model.encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
199
+ "first_stage_model.encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
200
+ "first_stage_model.encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
201
+ "first_stage_model.encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
202
+ "first_stage_model.encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
203
+ "first_stage_model.encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
204
+ "first_stage_model.encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
205
+ "first_stage_model.encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
206
+ "first_stage_model.encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
207
+ "first_stage_model.encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
208
+ "first_stage_model.encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
209
+ "first_stage_model.encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
210
+ "first_stage_model.encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
211
+ "first_stage_model.encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
212
+ "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
213
+ "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
214
+ "first_stage_model.encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
215
+ "first_stage_model.encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
216
+ "first_stage_model.encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
217
+ "first_stage_model.encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
218
+ "first_stage_model.encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
219
+ "first_stage_model.encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
220
+ "first_stage_model.encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
221
+ "first_stage_model.encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
222
+ "first_stage_model.encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
223
+ "first_stage_model.encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
224
+ "first_stage_model.encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
225
+ "first_stage_model.encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
226
+ "first_stage_model.encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
227
+ "first_stage_model.encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
228
+ "first_stage_model.encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
229
+ "first_stage_model.encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
230
+ "first_stage_model.encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
231
+ "first_stage_model.encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
232
+ "first_stage_model.encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
233
+ "first_stage_model.encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
234
+ "first_stage_model.encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
235
+ "first_stage_model.encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
236
+ "first_stage_model.encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
237
+ "first_stage_model.encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
238
+ "first_stage_model.encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
239
+ "first_stage_model.encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
240
+ "first_stage_model.encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
241
+ "first_stage_model.encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
242
+ "first_stage_model.encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
243
+ "first_stage_model.encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
244
+ "first_stage_model.encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
245
+ "first_stage_model.encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
246
+ "first_stage_model.encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
247
+ "first_stage_model.encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
248
+ "first_stage_model.encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
249
+ "first_stage_model.encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
250
+ "first_stage_model.encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
251
+ "first_stage_model.encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
252
+ "first_stage_model.encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
253
+ "first_stage_model.encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
254
+ "first_stage_model.encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
255
+ "first_stage_model.encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
256
+ "first_stage_model.encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
257
+ "first_stage_model.encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
258
+ "first_stage_model.encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
259
+ "first_stage_model.encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
260
+ "first_stage_model.encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
261
+ "first_stage_model.encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
262
+ "first_stage_model.encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
263
+ "first_stage_model.encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
264
+ "first_stage_model.encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
265
+ "first_stage_model.encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
266
+ "first_stage_model.encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
267
+ "first_stage_model.encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
268
+ "first_stage_model.encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
269
+ "first_stage_model.encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
270
+ "first_stage_model.encoder.norm_out.bias": "conv_norm_out.bias",
271
+ "first_stage_model.encoder.norm_out.weight": "conv_norm_out.weight",
272
+ "first_stage_model.quant_conv.bias": "quant_conv.bias",
273
+ "first_stage_model.quant_conv.weight": "quant_conv.weight",
274
+ }
275
+ state_dict_ = {}
276
+ for name in state_dict:
277
+ if name in rename_dict:
278
+ param = state_dict[name]
279
+ if "transformer_blocks" in rename_dict[name]:
280
+ param = param.squeeze()
281
+ state_dict_[rename_dict[name]] = param
282
+ return state_dict_
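Note: the encoder mirrors the decoder, so a quick round-trip sketch (modules are randomly initialized here; in practice the converted weights above would be loaded first). The three stride-2 DownSamplers give an 8x spatial reduction, and the latent is already multiplied by scaling_factor inside forward():

import torch
from diffsynth.models.sd_vae_encoder import SDVAEEncoder
from diffsynth.models.sd_vae_decoder import SDVAEDecoder

encoder, decoder = SDVAEEncoder(), SDVAEDecoder()
image = torch.randn(1, 3, 512, 512)  # RGB image scaled to roughly [-1, 1]
with torch.no_grad():
    latent = encoder(image)   # -> [1, 4, 64, 64]
    recon = decoder(latent)   # -> [1, 3, 512, 512]
print(latent.shape, recon.shape)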
diffsynth/models/sdxl_controlnet.py ADDED
@@ -0,0 +1,318 @@
1
+ import torch
2
+ from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
3
+ from .sdxl_unet import SDXLUNet
4
+ from .tiler import TileWorker
5
+ from .sd_controlnet import ControlNetConditioningLayer
6
+ from collections import OrderedDict
7
+
8
+
9
+
10
+ class QuickGELU(torch.nn.Module):
11
+
12
+ def forward(self, x: torch.Tensor):
13
+ return x * torch.sigmoid(1.702 * x)
14
+
15
+
16
+
17
+ class ResidualAttentionBlock(torch.nn.Module):
18
+
19
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
20
+ super().__init__()
21
+
22
+ self.attn = torch.nn.MultiheadAttention(d_model, n_head)
23
+ self.ln_1 = torch.nn.LayerNorm(d_model)
24
+ self.mlp = torch.nn.Sequential(OrderedDict([
25
+ ("c_fc", torch.nn.Linear(d_model, d_model * 4)),
26
+ ("gelu", QuickGELU()),
27
+ ("c_proj", torch.nn.Linear(d_model * 4, d_model))
28
+ ]))
29
+ self.ln_2 = torch.nn.LayerNorm(d_model)
30
+ self.attn_mask = attn_mask
31
+
32
+ def attention(self, x: torch.Tensor):
33
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
34
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
35
+
36
+ def forward(self, x: torch.Tensor):
37
+ x = x + self.attention(self.ln_1(x))
38
+ x = x + self.mlp(self.ln_2(x))
39
+ return x
40
+
41
+
42
+
43
+ class SDXLControlNetUnion(torch.nn.Module):
44
+ def __init__(self, global_pool=False):
45
+ super().__init__()
46
+ self.time_proj = Timesteps(320)
47
+ self.time_embedding = torch.nn.Sequential(
48
+ torch.nn.Linear(320, 1280),
49
+ torch.nn.SiLU(),
50
+ torch.nn.Linear(1280, 1280)
51
+ )
52
+ self.add_time_proj = Timesteps(256)
53
+ self.add_time_embedding = torch.nn.Sequential(
54
+ torch.nn.Linear(2816, 1280),
55
+ torch.nn.SiLU(),
56
+ torch.nn.Linear(1280, 1280)
57
+ )
58
+ self.control_type_proj = Timesteps(256)
59
+ self.control_type_embedding = torch.nn.Sequential(
60
+ torch.nn.Linear(256 * 8, 1280),
61
+ torch.nn.SiLU(),
62
+ torch.nn.Linear(1280, 1280)
63
+ )
64
+ self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
65
+
66
+ self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
67
+ self.controlnet_transformer = ResidualAttentionBlock(320, 8)
68
+ self.task_embedding = torch.nn.Parameter(torch.randn(8, 320))
69
+ self.spatial_ch_projs = torch.nn.Linear(320, 320)
70
+
71
+ self.blocks = torch.nn.ModuleList([
72
+ # DownBlock2D
73
+ ResnetBlock(320, 320, 1280),
74
+ PushBlock(),
75
+ ResnetBlock(320, 320, 1280),
76
+ PushBlock(),
77
+ DownSampler(320),
78
+ PushBlock(),
79
+ # CrossAttnDownBlock2D
80
+ ResnetBlock(320, 640, 1280),
81
+ AttentionBlock(10, 64, 640, 2, 2048),
82
+ PushBlock(),
83
+ ResnetBlock(640, 640, 1280),
84
+ AttentionBlock(10, 64, 640, 2, 2048),
85
+ PushBlock(),
86
+ DownSampler(640),
87
+ PushBlock(),
88
+ # CrossAttnDownBlock2D
89
+ ResnetBlock(640, 1280, 1280),
90
+ AttentionBlock(20, 64, 1280, 10, 2048),
91
+ PushBlock(),
92
+ ResnetBlock(1280, 1280, 1280),
93
+ AttentionBlock(20, 64, 1280, 10, 2048),
94
+ PushBlock(),
95
+ # UNetMidBlock2DCrossAttn
96
+ ResnetBlock(1280, 1280, 1280),
97
+ AttentionBlock(20, 64, 1280, 10, 2048),
98
+ ResnetBlock(1280, 1280, 1280),
99
+ PushBlock()
100
+ ])
101
+
102
+ self.controlnet_blocks = torch.nn.ModuleList([
103
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
104
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
105
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
106
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
107
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
108
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
109
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
110
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
111
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
112
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
113
+ ])
114
+
115
+ self.global_pool = global_pool
116
+
117
+ # 0 -- openpose
118
+ # 1 -- depth
119
+ # 2 -- hed/pidi/scribble/ted
120
+ # 3 -- canny/lineart/anime_lineart/mlsd
121
+ # 4 -- normal
122
+ # 5 -- segment
123
+ # 6 -- tile
124
+ # 7 -- repaint
125
+ self.task_id = {
126
+ "openpose": 0,
127
+ "depth": 1,
128
+ "softedge": 2,
129
+ "canny": 3,
130
+ "lineart": 3,
131
+ "lineart_anime": 3,
132
+ "tile": 6,
133
+ "inpaint": 7
134
+ }
135
+
136
+
137
+ def fuse_condition_to_input(self, hidden_states, task_id, conditioning):
138
+ controlnet_cond = self.controlnet_conv_in(conditioning)
139
+ feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
140
+ feat_seq = feat_seq + self.task_embedding[task_id]
141
+ x = torch.stack([feat_seq, torch.mean(hidden_states, dim=(2, 3))], dim=1)
142
+ x = self.controlnet_transformer(x)
143
+
144
+ alpha = self.spatial_ch_projs(x[:,0]).unsqueeze(-1).unsqueeze(-1)
145
+ controlnet_cond_fuser = controlnet_cond + alpha
146
+
147
+ hidden_states = hidden_states + controlnet_cond_fuser
148
+ return hidden_states
149
+
150
+
151
+ def forward(
152
+ self,
153
+ sample, timestep, encoder_hidden_states,
154
+ conditioning, processor_id, add_time_id, add_text_embeds,
155
+ tiled=False, tile_size=64, tile_stride=32,
156
+ unet:SDXLUNet=None,
157
+ **kwargs
158
+ ):
159
+ task_id = self.task_id[processor_id]
160
+
161
+ # 1. time
162
+ t_emb = self.time_proj(timestep).to(sample.dtype)
163
+ t_emb = self.time_embedding(t_emb)
164
+
165
+ time_embeds = self.add_time_proj(add_time_id)
166
+ time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
167
+ add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
168
+ add_embeds = add_embeds.to(sample.dtype)
169
+ if unet is not None and unet.is_kolors:
170
+ add_embeds = unet.add_time_embedding(add_embeds)
171
+ else:
172
+ add_embeds = self.add_time_embedding(add_embeds)
173
+
174
+ control_type = torch.zeros((sample.shape[0], 8), dtype=sample.dtype, device=sample.device)
175
+ control_type[:, task_id] = 1
176
+ control_embeds = self.control_type_proj(control_type.flatten())
177
+ control_embeds = control_embeds.reshape((sample.shape[0], -1))
178
+ control_embeds = control_embeds.to(sample.dtype)
179
+ control_embeds = self.control_type_embedding(control_embeds)
180
+ time_emb = t_emb + add_embeds + control_embeds
181
+
182
+ # 2. pre-process
183
+ height, width = sample.shape[2], sample.shape[3]
184
+ hidden_states = self.conv_in(sample)
185
+ hidden_states = self.fuse_condition_to_input(hidden_states, task_id, conditioning)
186
+ text_emb = encoder_hidden_states
187
+ if unet is not None and unet.is_kolors:
188
+ text_emb = unet.text_intermediate_proj(text_emb)
189
+ res_stack = [hidden_states]
190
+
191
+ # 3. blocks
192
+ for i, block in enumerate(self.blocks):
193
+ if tiled and not isinstance(block, PushBlock):
194
+ _, _, inter_height, _ = hidden_states.shape
195
+ resize_scale = inter_height / height
196
+ hidden_states = TileWorker().tiled_forward(
197
+ lambda x: block(x, time_emb, text_emb, res_stack)[0],
198
+ hidden_states,
199
+ int(tile_size * resize_scale),
200
+ int(tile_stride * resize_scale),
201
+ tile_device=hidden_states.device,
202
+ tile_dtype=hidden_states.dtype
203
+ )
204
+ else:
205
+ hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
206
+
207
+ # 4. ControlNet blocks
208
+ controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
209
+
210
+ # pool
211
+ if self.global_pool:
212
+ controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]
213
+
214
+ return controlnet_res_stack
215
+
216
+ @staticmethod
217
+ def state_dict_converter():
218
+ return SDXLControlNetUnionStateDictConverter()
219
+
220
+
221
+
222
+ class SDXLControlNetUnionStateDictConverter:
223
+ def __init__(self):
224
+ pass
225
+
226
+ def from_diffusers(self, state_dict):
227
+ # architecture
228
+ block_types = [
229
+ "ResnetBlock", "PushBlock", "ResnetBlock", "PushBlock", "DownSampler", "PushBlock",
230
+ "ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock", "DownSampler", "PushBlock",
231
+ "ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock",
232
+ "ResnetBlock", "AttentionBlock", "ResnetBlock", "PushBlock"
233
+ ]
234
+
235
+ # controlnet_rename_dict
236
+ controlnet_rename_dict = {
237
+ "controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
238
+ "controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
239
+ "controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
240
+ "controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
241
+ "controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
242
+ "controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
243
+ "controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
244
+ "controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
245
+ "controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
246
+ "controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
247
+ "controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
248
+ "controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
249
+ "controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
250
+ "controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
251
+ "controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
252
+ "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
253
+ "control_add_embedding.linear_1.weight": "control_type_embedding.0.weight",
254
+ "control_add_embedding.linear_1.bias": "control_type_embedding.0.bias",
255
+ "control_add_embedding.linear_2.weight": "control_type_embedding.2.weight",
256
+ "control_add_embedding.linear_2.bias": "control_type_embedding.2.bias",
257
+ }
258
+
259
+ # Rename each parameter
260
+ name_list = sorted([name for name in state_dict])
261
+ rename_dict = {}
262
+ block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
263
+ last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
264
+ for name in name_list:
265
+ names = name.split(".")
266
+ if names[0] in ["conv_in", "conv_norm_out", "conv_out", "task_embedding", "spatial_ch_projs"]:
267
+ pass
268
+ elif name in controlnet_rename_dict:
269
+ names = controlnet_rename_dict[name].split(".")
270
+ elif names[0] == "controlnet_down_blocks":
271
+ names[0] = "controlnet_blocks"
272
+ elif names[0] == "controlnet_mid_block":
273
+ names = ["controlnet_blocks", "9", names[-1]]
274
+ elif names[0] in ["time_embedding", "add_embedding"]:
275
+ if names[0] == "add_embedding":
276
+ names[0] = "add_time_embedding"
277
+ names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
278
+ elif names[0] == "control_add_embedding":
279
+ names[0] = "control_type_embedding"
280
+ elif names[0] == "transformer_layes":
281
+ names[0] = "controlnet_transformer"
282
+ names.pop(1)
283
+ elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
284
+ if names[0] == "mid_block":
285
+ names.insert(1, "0")
286
+ block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
287
+ block_type_with_id = ".".join(names[:4])
288
+ if block_type_with_id != last_block_type_with_id[block_type]:
289
+ block_id[block_type] += 1
290
+ last_block_type_with_id[block_type] = block_type_with_id
291
+ while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
292
+ block_id[block_type] += 1
293
+ block_type_with_id = ".".join(names[:4])
294
+ names = ["blocks", str(block_id[block_type])] + names[4:]
295
+ if "ff" in names:
296
+ ff_index = names.index("ff")
297
+ component = ".".join(names[ff_index:ff_index+3])
298
+ component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
299
+ names = names[:ff_index] + [component] + names[ff_index+3:]
300
+ if "to_out" in names:
301
+ names.pop(names.index("to_out") + 1)
302
+ else:
303
+ print(name, state_dict[name].shape)
304
+ # raise ValueError(f"Unknown parameters: {name}")
305
+ rename_dict[name] = ".".join(names)
306
+
307
+ # Convert state_dict
308
+ state_dict_ = {}
309
+ for name, param in state_dict.items():
310
+ if name not in rename_dict:
311
+ continue
312
+ if ".proj_in." in name or ".proj_out." in name:
313
+ param = param.squeeze()
314
+ state_dict_[rename_dict[name]] = param
315
+ return state_dict_
316
+
317
+ def from_civitai(self, state_dict):
318
+ return self.from_diffusers(state_dict)
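The task_id table and the one-hot control_type vector above are the whole routing mechanism of the union ControlNet: every supported processor collapses onto one of eight task slots. A minimal standalone illustration of that routing (plain PyTorch, not part of this commit):

import torch

task_id = {"openpose": 0, "depth": 1, "softedge": 2, "canny": 3,
           "lineart": 3, "lineart_anime": 3, "tile": 6, "inpaint": 7}
control_type = torch.zeros(1, 8)           # one row per sample, eight task slots
control_type[:, task_id["canny"]] = 1      # -> tensor([[0., 0., 0., 1., 0., 0., 0., 0.]])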
diffsynth/models/sdxl_ipadapter.py ADDED
@@ -0,0 +1,122 @@
1
+ from .svd_image_encoder import SVDImageEncoder
2
+ from transformers import CLIPImageProcessor
3
+ import torch
4
+
5
+
6
+ class IpAdapterXLCLIPImageEmbedder(SVDImageEncoder):
7
+ def __init__(self):
8
+ super().__init__(embed_dim=1664, encoder_intermediate_size=8192, projection_dim=1280, num_encoder_layers=48, num_heads=16, head_dim=104)
9
+ self.image_processor = CLIPImageProcessor()
10
+
11
+ def forward(self, image):
12
+ pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
13
+ pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype)
14
+ return super().forward(pixel_values)
15
+
16
+
17
+ class IpAdapterImageProjModel(torch.nn.Module):
18
+ def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4):
19
+ super().__init__()
20
+ self.cross_attention_dim = cross_attention_dim
21
+ self.clip_extra_context_tokens = clip_extra_context_tokens
22
+ self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
23
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
24
+
25
+ def forward(self, image_embeds):
26
+ clip_extra_context_tokens = self.proj(image_embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
27
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
28
+ return clip_extra_context_tokens
29
+
30
+
31
+ class IpAdapterModule(torch.nn.Module):
32
+ def __init__(self, input_dim, output_dim):
33
+ super().__init__()
34
+ self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
35
+ self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
36
+
37
+ def forward(self, hidden_states):
38
+ ip_k = self.to_k_ip(hidden_states)
39
+ ip_v = self.to_v_ip(hidden_states)
40
+ return ip_k, ip_v
41
+
42
+
43
+ class SDXLIpAdapter(torch.nn.Module):
44
+ def __init__(self):
45
+ super().__init__()
46
+ shape_list = [(2048, 640)] * 4 + [(2048, 1280)] * 50 + [(2048, 640)] * 6 + [(2048, 1280)] * 10
47
+ self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
48
+ self.image_proj = IpAdapterImageProjModel()
49
+ self.set_full_adapter()
50
+
51
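+ # call_block_id maps a (UNet block index, transformer layer index) pair to the
+ # IpAdapterModule that supplies the extra key/value projections for that attention layer.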
+ def set_full_adapter(self):
52
+ map_list = sum([
53
+ [(7, i) for i in range(2)],
54
+ [(10, i) for i in range(2)],
55
+ [(15, i) for i in range(10)],
56
+ [(18, i) for i in range(10)],
57
+ [(25, i) for i in range(10)],
58
+ [(28, i) for i in range(10)],
59
+ [(31, i) for i in range(10)],
60
+ [(35, i) for i in range(2)],
61
+ [(38, i) for i in range(2)],
62
+ [(41, i) for i in range(2)],
63
+ [(21, i) for i in range(10)],
64
+ ], [])
65
+ self.call_block_id = {i: j for j, i in enumerate(map_list)}
66
+
67
+ def set_less_adapter(self):
68
+ map_list = sum([
69
+ [(7, i) for i in range(2)],
70
+ [(10, i) for i in range(2)],
71
+ [(15, i) for i in range(10)],
72
+ [(18, i) for i in range(10)],
73
+ [(25, i) for i in range(10)],
74
+ [(28, i) for i in range(10)],
75
+ [(31, i) for i in range(10)],
76
+ [(35, i) for i in range(2)],
77
+ [(38, i) for i in range(2)],
78
+ [(41, i) for i in range(2)],
79
+ [(21, i) for i in range(10)],
80
+ ], [])
81
+ self.call_block_id = {i: j for j, i in enumerate(map_list) if j>=34 and j<44}
82
+
83
+ def forward(self, hidden_states, scale=1.0):
84
+ hidden_states = self.image_proj(hidden_states)
85
+ hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
86
+ ip_kv_dict = {}
87
+ for (block_id, transformer_id) in self.call_block_id:
88
+ ipadapter_id = self.call_block_id[(block_id, transformer_id)]
89
+ ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
90
+ if block_id not in ip_kv_dict:
91
+ ip_kv_dict[block_id] = {}
92
+ ip_kv_dict[block_id][transformer_id] = {
93
+ "ip_k": ip_k,
94
+ "ip_v": ip_v,
95
+ "scale": scale
96
+ }
97
+ return ip_kv_dict
98
+
99
+ @staticmethod
100
+ def state_dict_converter():
101
+ return SDXLIpAdapterStateDictConverter()
102
+
103
+
104
+ class SDXLIpAdapterStateDictConverter:
105
+ def __init__(self):
106
+ pass
107
+
108
+ def from_diffusers(self, state_dict):
109
+ state_dict_ = {}
110
+ for name in state_dict["ip_adapter"]:
111
+ names = name.split(".")
112
+ layer_id = str(int(names[0]) // 2)
113
+ name_ = ".".join(["ipadapter_modules"] + [layer_id] + names[1:])
114
+ state_dict_[name_] = state_dict["ip_adapter"][name]
115
+ for name in state_dict["image_proj"]:
116
+ name_ = "image_proj." + name
117
+ state_dict_[name_] = state_dict["image_proj"][name]
118
+ return state_dict_
119
+
120
+ def from_civitai(self, state_dict):
121
+ return self.from_diffusers(state_dict)
122
+
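The converter classes in this commit all follow the same pattern: state_dict_converter() returns an object whose from_diffusers / from_civitai methods remap checkpoint keys into this repository's module layout. A minimal usage sketch for the IP-Adapter above, assuming a diffusers-style checkpoint that stores "ip_adapter" and "image_proj" sub-dicts (the file name is illustrative, not part of this commit):

import torch
from diffsynth.models.sdxl_ipadapter import SDXLIpAdapter

raw = torch.load("ip-adapter_sdxl.bin", map_location="cpu")  # assumed layout: {"ip_adapter": {...}, "image_proj": {...}}
ip_adapter = SDXLIpAdapter()
ip_adapter.load_state_dict(SDXLIpAdapter.state_dict_converter().from_diffusers(raw))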
diffsynth/models/sdxl_motion.py ADDED
@@ -0,0 +1,104 @@
1
+ from .sd_motion import TemporalBlock
2
+ import torch
3
+
4
+
5
+
6
+ class SDXLMotionModel(torch.nn.Module):
7
+ def __init__(self):
8
+ super().__init__()
9
+ self.motion_modules = torch.nn.ModuleList([
10
+ TemporalBlock(8, 320//8, 320, eps=1e-6),
11
+ TemporalBlock(8, 320//8, 320, eps=1e-6),
12
+
13
+ TemporalBlock(8, 640//8, 640, eps=1e-6),
14
+ TemporalBlock(8, 640//8, 640, eps=1e-6),
15
+
16
+ TemporalBlock(8, 1280//8, 1280, eps=1e-6),
17
+ TemporalBlock(8, 1280//8, 1280, eps=1e-6),
18
+
19
+ TemporalBlock(8, 1280//8, 1280, eps=1e-6),
20
+ TemporalBlock(8, 1280//8, 1280, eps=1e-6),
21
+ TemporalBlock(8, 1280//8, 1280, eps=1e-6),
22
+
23
+ TemporalBlock(8, 640//8, 640, eps=1e-6),
24
+ TemporalBlock(8, 640//8, 640, eps=1e-6),
25
+ TemporalBlock(8, 640//8, 640, eps=1e-6),
26
+
27
+ TemporalBlock(8, 320//8, 320, eps=1e-6),
28
+ TemporalBlock(8, 320//8, 320, eps=1e-6),
29
+ TemporalBlock(8, 320//8, 320, eps=1e-6),
30
+ ])
31
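+ # call_block_id maps a UNet block index to the index of the TemporalBlock used at that block.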
+ self.call_block_id = {
32
+ 0: 0,
33
+ 2: 1,
34
+ 7: 2,
35
+ 10: 3,
36
+ 15: 4,
37
+ 18: 5,
38
+ 25: 6,
39
+ 28: 7,
40
+ 31: 8,
41
+ 35: 9,
42
+ 38: 10,
43
+ 41: 11,
44
+ 44: 12,
45
+ 46: 13,
46
+ 48: 14,
47
+ }
48
+
49
+ def forward(self):
50
+ pass
51
+
52
+ @staticmethod
53
+ def state_dict_converter():
54
+ return SDMotionModelStateDictConverter()
55
+
56
+
57
+ class SDMotionModelStateDictConverter:
58
+ def __init__(self):
59
+ pass
60
+
61
+ def from_diffusers(self, state_dict):
62
+ rename_dict = {
63
+ "norm": "norm",
64
+ "proj_in": "proj_in",
65
+ "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
66
+ "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
67
+ "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
68
+ "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
69
+ "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
70
+ "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
71
+ "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
72
+ "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
73
+ "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
74
+ "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
75
+ "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
76
+ "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
77
+ "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
78
+ "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
79
+ "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
80
+ "proj_out": "proj_out",
81
+ }
82
+ name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
83
+ name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
84
+ name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
85
+ state_dict_ = {}
86
+ last_prefix, module_id = "", -1
87
+ for name in name_list:
88
+ names = name.split(".")
89
+ prefix_index = names.index("temporal_transformer") + 1
90
+ prefix = ".".join(names[:prefix_index])
91
+ if prefix != last_prefix:
92
+ last_prefix = prefix
93
+ module_id += 1
94
+ middle_name = ".".join(names[prefix_index:-1])
95
+ suffix = names[-1]
96
+ if "pos_encoder" in names:
97
+ rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
98
+ else:
99
+ rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
100
+ state_dict_[rename] = state_dict[name]
101
+ return state_dict_
102
+
103
+ def from_civitai(self, state_dict):
104
+ return self.from_diffusers(state_dict)
diffsynth/models/sdxl_text_encoder.py ADDED
@@ -0,0 +1,759 @@
1
+ import torch
2
+ from .sd_text_encoder import CLIPEncoderLayer
3
+
4
+
5
+ class SDXLTextEncoder(torch.nn.Module):
6
+ def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=11, encoder_intermediate_size=3072):
7
+ super().__init__()
8
+
9
+ # token_embedding
10
+ self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
11
+
12
+ # position_embeds (This is a fixed tensor)
13
+ self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
14
+
15
+ # encoders
16
+ self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
17
+
18
+ # attn_mask
19
+ self.attn_mask = self.attention_mask(max_position_embeddings)
20
+
21
+ # The text encoder here differs from the one in Stable Diffusion 1.x:
22
+ # it does not include final_layer_norm.
23
+
24
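+ # Causal attention mask: -inf strictly above the diagonal and 0 elsewhere, so each token
+ # attends only to itself and earlier positions.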
+ def attention_mask(self, length):
25
+ mask = torch.empty(length, length)
26
+ mask.fill_(float("-inf"))
27
+ mask.triu_(1)
28
+ return mask
29
+
30
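+ # clip_skip stops the encoder stack early: the loop breaks once encoder_id + clip_skip
+ # reaches len(self.encoders), so larger values return an earlier layer's hidden states.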
+ def forward(self, input_ids, clip_skip=1):
31
+ embeds = self.token_embedding(input_ids) + self.position_embeds
32
+ attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
33
+ for encoder_id, encoder in enumerate(self.encoders):
34
+ embeds = encoder(embeds, attn_mask=attn_mask)
35
+ if encoder_id + clip_skip == len(self.encoders):
36
+ break
37
+ return embeds
38
+
39
+ @staticmethod
40
+ def state_dict_converter():
41
+ return SDXLTextEncoderStateDictConverter()
42
+
43
+
44
+ class SDXLTextEncoder2(torch.nn.Module):
45
+ def __init__(self, embed_dim=1280, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=32, encoder_intermediate_size=5120):
46
+ super().__init__()
47
+
48
+ # token_embedding
49
+ self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
50
+
51
+ # position_embeds (This is a fixed tensor)
52
+ self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
53
+
54
+ # encoders
55
+ self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=20, head_dim=64, use_quick_gelu=False) for _ in range(num_encoder_layers)])
56
+
57
+ # attn_mask
58
+ self.attn_mask = self.attention_mask(max_position_embeddings)
59
+
60
+ # final_layer_norm
61
+ self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
62
+
63
+ # text_projection
64
+ self.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
65
+
66
+ def attention_mask(self, length):
67
+ mask = torch.empty(length, length)
68
+ mask.fill_(float("-inf"))
69
+ mask.triu_(1)
70
+ return mask
71
+
72
+ def forward(self, input_ids, clip_skip=2):
73
+ embeds = self.token_embedding(input_ids) + self.position_embeds
74
+ attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
75
+ for encoder_id, encoder in enumerate(self.encoders):
76
+ embeds = encoder(embeds, attn_mask=attn_mask)
77
+ if encoder_id + clip_skip == len(self.encoders):
78
+ hidden_states = embeds
79
+ embeds = self.final_layer_norm(embeds)
80
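+ # The pooled embedding is read at the end-of-text token, found via argmax because the
+ # EOT token has the largest id in the CLIP vocabulary.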
+ pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
81
+ pooled_embeds = self.text_projection(pooled_embeds)
82
+ return pooled_embeds, hidden_states
83
+
84
+ @staticmethod
85
+ def state_dict_converter():
86
+ return SDXLTextEncoder2StateDictConverter()
87
+
88
+
89
+ class SDXLTextEncoderStateDictConverter:
90
+ def __init__(self):
91
+ pass
92
+
93
+ def from_diffusers(self, state_dict):
94
+ rename_dict = {
95
+ "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
96
+ "text_model.embeddings.position_embedding.weight": "position_embeds",
97
+ "text_model.final_layer_norm.weight": "final_layer_norm.weight",
98
+ "text_model.final_layer_norm.bias": "final_layer_norm.bias"
99
+ }
100
+ attn_rename_dict = {
101
+ "self_attn.q_proj": "attn.to_q",
102
+ "self_attn.k_proj": "attn.to_k",
103
+ "self_attn.v_proj": "attn.to_v",
104
+ "self_attn.out_proj": "attn.to_out",
105
+ "layer_norm1": "layer_norm1",
106
+ "layer_norm2": "layer_norm2",
107
+ "mlp.fc1": "fc1",
108
+ "mlp.fc2": "fc2",
109
+ }
110
+ state_dict_ = {}
111
+ for name in state_dict:
112
+ if name in rename_dict:
113
+ param = state_dict[name]
114
+ if name == "text_model.embeddings.position_embedding.weight":
115
+ param = param.reshape((1, param.shape[0], param.shape[1]))
116
+ state_dict_[rename_dict[name]] = param
117
+ elif name.startswith("text_model.encoder.layers."):
118
+ param = state_dict[name]
119
+ names = name.split(".")
120
+ layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
121
+ name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
122
+ state_dict_[name_] = param
123
+ return state_dict_
124
+
125
+ def from_civitai(self, state_dict):
126
+ rename_dict = {
127
+ "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight": "position_embeds",
128
+ "conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
129
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
130
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
131
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
132
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
133
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
134
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
135
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
136
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
137
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
138
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
139
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
140
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
141
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
142
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
143
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
144
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
145
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
146
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
147
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
148
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
149
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
150
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
151
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
152
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
153
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
154
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
155
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
156
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
157
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
158
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
159
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
160
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
161
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
162
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
163
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
164
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
165
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
166
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
167
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
168
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
169
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
170
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
171
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
172
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
173
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
174
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
175
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
176
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
177
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
178
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
179
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
180
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
181
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
182
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
183
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
184
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
185
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
186
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
187
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
188
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
189
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
190
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
191
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
192
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
193
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
194
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
195
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
196
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
197
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
198
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
199
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
200
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
201
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
202
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
203
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
204
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
205
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
206
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
207
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
208
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
209
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
210
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
211
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
212
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
213
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
214
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
215
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
216
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
217
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
218
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
219
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
220
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
221
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
222
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
223
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
224
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
225
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
226
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
227
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
228
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
229
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
230
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
231
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
232
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
233
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
234
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
235
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
236
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
237
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
238
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
239
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
240
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
241
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
242
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
243
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
244
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
245
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
246
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
247
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
248
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
249
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
250
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
251
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
252
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
253
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
254
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
255
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
256
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
257
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
258
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
259
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
260
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
261
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
262
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
263
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
264
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
265
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
266
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
267
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
268
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
269
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
270
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
271
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
272
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
273
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
274
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
275
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
276
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
277
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
278
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
279
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
280
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
281
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
282
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
283
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
284
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
285
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
286
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
287
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
288
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
289
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
290
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
291
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
292
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
293
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
294
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
295
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
296
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
297
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
298
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
299
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
300
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
301
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
302
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
303
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
304
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
305
+ }
306
+ state_dict_ = {}
307
+ for name in state_dict:
308
+ if name in rename_dict:
309
+ param = state_dict[name]
310
+ if name == "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight":
311
+ param = param.reshape((1, param.shape[0], param.shape[1]))
312
+ state_dict_[rename_dict[name]] = param
313
+ return state_dict_
314
+
315
+
316
+ class SDXLTextEncoder2StateDictConverter:
317
+ def __init__(self):
318
+ pass
319
+
320
+ def from_diffusers(self, state_dict):
321
+ rename_dict = {
322
+ "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
323
+ "text_model.embeddings.position_embedding.weight": "position_embeds",
324
+ "text_model.final_layer_norm.weight": "final_layer_norm.weight",
325
+ "text_model.final_layer_norm.bias": "final_layer_norm.bias",
326
+ "text_projection.weight": "text_projection.weight"
327
+ }
328
+ attn_rename_dict = {
329
+ "self_attn.q_proj": "attn.to_q",
330
+ "self_attn.k_proj": "attn.to_k",
331
+ "self_attn.v_proj": "attn.to_v",
332
+ "self_attn.out_proj": "attn.to_out",
333
+ "layer_norm1": "layer_norm1",
334
+ "layer_norm2": "layer_norm2",
335
+ "mlp.fc1": "fc1",
336
+ "mlp.fc2": "fc2",
337
+ }
338
+ state_dict_ = {}
339
+ for name in state_dict:
340
+ if name in rename_dict:
341
+ param = state_dict[name]
342
+ if name == "text_model.embeddings.position_embedding.weight":
343
+ param = param.reshape((1, param.shape[0], param.shape[1]))
344
+ state_dict_[rename_dict[name]] = param
345
+ elif name.startswith("text_model.encoder.layers."):
346
+ param = state_dict[name]
347
+ names = name.split(".")
348
+ layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
349
+ name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
350
+ state_dict_[name_] = param
351
+ return state_dict_
352
+
353
+ def from_civitai(self, state_dict):
354
+ rename_dict = {
355
+ "conditioner.embedders.1.model.ln_final.bias": "final_layer_norm.bias",
356
+ "conditioner.embedders.1.model.ln_final.weight": "final_layer_norm.weight",
357
+ "conditioner.embedders.1.model.positional_embedding": "position_embeds",
358
+ "conditioner.embedders.1.model.token_embedding.weight": "token_embedding.weight",
359
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias": ['encoders.0.attn.to_q.bias', 'encoders.0.attn.to_k.bias', 'encoders.0.attn.to_v.bias'],
360
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": ['encoders.0.attn.to_q.weight', 'encoders.0.attn.to_k.weight', 'encoders.0.attn.to_v.weight'],
361
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.bias": "encoders.0.attn.to_out.bias",
362
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.weight": "encoders.0.attn.to_out.weight",
363
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias": "encoders.0.layer_norm1.bias",
364
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight": "encoders.0.layer_norm1.weight",
365
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.bias": "encoders.0.layer_norm2.bias",
366
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.weight": "encoders.0.layer_norm2.weight",
367
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.bias": "encoders.0.fc1.bias",
368
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.weight": "encoders.0.fc1.weight",
369
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.bias": "encoders.0.fc2.bias",
370
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.weight": "encoders.0.fc2.weight",
371
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias": ['encoders.1.attn.to_q.bias', 'encoders.1.attn.to_k.bias', 'encoders.1.attn.to_v.bias'],
372
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight": ['encoders.1.attn.to_q.weight', 'encoders.1.attn.to_k.weight', 'encoders.1.attn.to_v.weight'],
373
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.bias": "encoders.1.attn.to_out.bias",
374
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.weight": "encoders.1.attn.to_out.weight",
375
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.bias": "encoders.1.layer_norm1.bias",
376
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.weight": "encoders.1.layer_norm1.weight",
377
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.bias": "encoders.1.layer_norm2.bias",
378
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.weight": "encoders.1.layer_norm2.weight",
379
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.bias": "encoders.1.fc1.bias",
380
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.weight": "encoders.1.fc1.weight",
381
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.bias": "encoders.1.fc2.bias",
382
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.weight": "encoders.1.fc2.weight",
383
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias": ['encoders.10.attn.to_q.bias', 'encoders.10.attn.to_k.bias', 'encoders.10.attn.to_v.bias'],
384
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight": ['encoders.10.attn.to_q.weight', 'encoders.10.attn.to_k.weight', 'encoders.10.attn.to_v.weight'],
385
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.bias": "encoders.10.attn.to_out.bias",
386
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.weight": "encoders.10.attn.to_out.weight",
387
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.bias": "encoders.10.layer_norm1.bias",
388
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.weight": "encoders.10.layer_norm1.weight",
389
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.bias": "encoders.10.layer_norm2.bias",
390
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.weight": "encoders.10.layer_norm2.weight",
391
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.bias": "encoders.10.fc1.bias",
392
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.weight": "encoders.10.fc1.weight",
393
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.bias": "encoders.10.fc2.bias",
394
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.weight": "encoders.10.fc2.weight",
395
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias": ['encoders.11.attn.to_q.bias', 'encoders.11.attn.to_k.bias', 'encoders.11.attn.to_v.bias'],
396
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight": ['encoders.11.attn.to_q.weight', 'encoders.11.attn.to_k.weight', 'encoders.11.attn.to_v.weight'],
397
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.bias": "encoders.11.attn.to_out.bias",
398
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.weight": "encoders.11.attn.to_out.weight",
399
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.bias": "encoders.11.layer_norm1.bias",
400
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.weight": "encoders.11.layer_norm1.weight",
401
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.bias": "encoders.11.layer_norm2.bias",
402
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.weight": "encoders.11.layer_norm2.weight",
403
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.bias": "encoders.11.fc1.bias",
404
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.weight": "encoders.11.fc1.weight",
405
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.bias": "encoders.11.fc2.bias",
406
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.weight": "encoders.11.fc2.weight",
407
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias": ['encoders.12.attn.to_q.bias', 'encoders.12.attn.to_k.bias', 'encoders.12.attn.to_v.bias'],
408
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight": ['encoders.12.attn.to_q.weight', 'encoders.12.attn.to_k.weight', 'encoders.12.attn.to_v.weight'],
409
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.bias": "encoders.12.attn.to_out.bias",
410
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.weight": "encoders.12.attn.to_out.weight",
411
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.bias": "encoders.12.layer_norm1.bias",
412
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.weight": "encoders.12.layer_norm1.weight",
413
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.bias": "encoders.12.layer_norm2.bias",
414
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.weight": "encoders.12.layer_norm2.weight",
415
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.bias": "encoders.12.fc1.bias",
416
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.weight": "encoders.12.fc1.weight",
417
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.bias": "encoders.12.fc2.bias",
418
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.weight": "encoders.12.fc2.weight",
419
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias": ['encoders.13.attn.to_q.bias', 'encoders.13.attn.to_k.bias', 'encoders.13.attn.to_v.bias'],
420
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight": ['encoders.13.attn.to_q.weight', 'encoders.13.attn.to_k.weight', 'encoders.13.attn.to_v.weight'],
421
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.bias": "encoders.13.attn.to_out.bias",
422
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.weight": "encoders.13.attn.to_out.weight",
423
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.bias": "encoders.13.layer_norm1.bias",
424
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.weight": "encoders.13.layer_norm1.weight",
425
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.bias": "encoders.13.layer_norm2.bias",
426
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.weight": "encoders.13.layer_norm2.weight",
427
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.bias": "encoders.13.fc1.bias",
428
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.weight": "encoders.13.fc1.weight",
429
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.bias": "encoders.13.fc2.bias",
430
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.weight": "encoders.13.fc2.weight",
431
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias": ['encoders.14.attn.to_q.bias', 'encoders.14.attn.to_k.bias', 'encoders.14.attn.to_v.bias'],
432
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight": ['encoders.14.attn.to_q.weight', 'encoders.14.attn.to_k.weight', 'encoders.14.attn.to_v.weight'],
433
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.bias": "encoders.14.attn.to_out.bias",
434
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.weight": "encoders.14.attn.to_out.weight",
435
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.bias": "encoders.14.layer_norm1.bias",
436
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.weight": "encoders.14.layer_norm1.weight",
437
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.bias": "encoders.14.layer_norm2.bias",
438
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.weight": "encoders.14.layer_norm2.weight",
439
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.bias": "encoders.14.fc1.bias",
440
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.weight": "encoders.14.fc1.weight",
441
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.bias": "encoders.14.fc2.bias",
442
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.weight": "encoders.14.fc2.weight",
443
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias": ['encoders.15.attn.to_q.bias', 'encoders.15.attn.to_k.bias', 'encoders.15.attn.to_v.bias'],
444
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight": ['encoders.15.attn.to_q.weight', 'encoders.15.attn.to_k.weight', 'encoders.15.attn.to_v.weight'],
445
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.bias": "encoders.15.attn.to_out.bias",
446
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.weight": "encoders.15.attn.to_out.weight",
447
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.bias": "encoders.15.layer_norm1.bias",
448
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.weight": "encoders.15.layer_norm1.weight",
449
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.bias": "encoders.15.layer_norm2.bias",
450
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.weight": "encoders.15.layer_norm2.weight",
451
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.bias": "encoders.15.fc1.bias",
452
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.weight": "encoders.15.fc1.weight",
453
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.bias": "encoders.15.fc2.bias",
454
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.weight": "encoders.15.fc2.weight",
455
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias": ['encoders.16.attn.to_q.bias', 'encoders.16.attn.to_k.bias', 'encoders.16.attn.to_v.bias'],
456
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight": ['encoders.16.attn.to_q.weight', 'encoders.16.attn.to_k.weight', 'encoders.16.attn.to_v.weight'],
457
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.bias": "encoders.16.attn.to_out.bias",
458
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.weight": "encoders.16.attn.to_out.weight",
459
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.bias": "encoders.16.layer_norm1.bias",
460
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.weight": "encoders.16.layer_norm1.weight",
461
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.bias": "encoders.16.layer_norm2.bias",
462
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.weight": "encoders.16.layer_norm2.weight",
463
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.bias": "encoders.16.fc1.bias",
464
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.weight": "encoders.16.fc1.weight",
465
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.bias": "encoders.16.fc2.bias",
466
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.weight": "encoders.16.fc2.weight",
467
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias": ['encoders.17.attn.to_q.bias', 'encoders.17.attn.to_k.bias', 'encoders.17.attn.to_v.bias'],
468
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight": ['encoders.17.attn.to_q.weight', 'encoders.17.attn.to_k.weight', 'encoders.17.attn.to_v.weight'],
469
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.bias": "encoders.17.attn.to_out.bias",
470
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.weight": "encoders.17.attn.to_out.weight",
471
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.bias": "encoders.17.layer_norm1.bias",
472
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.weight": "encoders.17.layer_norm1.weight",
473
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.bias": "encoders.17.layer_norm2.bias",
474
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.weight": "encoders.17.layer_norm2.weight",
475
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.bias": "encoders.17.fc1.bias",
476
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.weight": "encoders.17.fc1.weight",
477
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.bias": "encoders.17.fc2.bias",
478
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.weight": "encoders.17.fc2.weight",
479
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias": ['encoders.18.attn.to_q.bias', 'encoders.18.attn.to_k.bias', 'encoders.18.attn.to_v.bias'],
480
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight": ['encoders.18.attn.to_q.weight', 'encoders.18.attn.to_k.weight', 'encoders.18.attn.to_v.weight'],
481
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.bias": "encoders.18.attn.to_out.bias",
482
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.weight": "encoders.18.attn.to_out.weight",
483
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.bias": "encoders.18.layer_norm1.bias",
484
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.weight": "encoders.18.layer_norm1.weight",
485
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.bias": "encoders.18.layer_norm2.bias",
486
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.weight": "encoders.18.layer_norm2.weight",
487
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.bias": "encoders.18.fc1.bias",
488
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.weight": "encoders.18.fc1.weight",
489
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.bias": "encoders.18.fc2.bias",
490
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.weight": "encoders.18.fc2.weight",
491
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias": ['encoders.19.attn.to_q.bias', 'encoders.19.attn.to_k.bias', 'encoders.19.attn.to_v.bias'],
492
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight": ['encoders.19.attn.to_q.weight', 'encoders.19.attn.to_k.weight', 'encoders.19.attn.to_v.weight'],
493
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.bias": "encoders.19.attn.to_out.bias",
494
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.weight": "encoders.19.attn.to_out.weight",
495
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.bias": "encoders.19.layer_norm1.bias",
496
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.weight": "encoders.19.layer_norm1.weight",
497
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.bias": "encoders.19.layer_norm2.bias",
498
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.weight": "encoders.19.layer_norm2.weight",
499
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.bias": "encoders.19.fc1.bias",
500
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.weight": "encoders.19.fc1.weight",
501
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.bias": "encoders.19.fc2.bias",
502
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.weight": "encoders.19.fc2.weight",
503
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias": ['encoders.2.attn.to_q.bias', 'encoders.2.attn.to_k.bias', 'encoders.2.attn.to_v.bias'],
504
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight": ['encoders.2.attn.to_q.weight', 'encoders.2.attn.to_k.weight', 'encoders.2.attn.to_v.weight'],
505
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.bias": "encoders.2.attn.to_out.bias",
506
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.weight": "encoders.2.attn.to_out.weight",
507
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.bias": "encoders.2.layer_norm1.bias",
508
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.weight": "encoders.2.layer_norm1.weight",
509
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.bias": "encoders.2.layer_norm2.bias",
510
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.weight": "encoders.2.layer_norm2.weight",
511
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.bias": "encoders.2.fc1.bias",
512
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.weight": "encoders.2.fc1.weight",
513
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.bias": "encoders.2.fc2.bias",
514
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.weight": "encoders.2.fc2.weight",
515
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias": ['encoders.20.attn.to_q.bias', 'encoders.20.attn.to_k.bias', 'encoders.20.attn.to_v.bias'],
516
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight": ['encoders.20.attn.to_q.weight', 'encoders.20.attn.to_k.weight', 'encoders.20.attn.to_v.weight'],
517
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.bias": "encoders.20.attn.to_out.bias",
518
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.weight": "encoders.20.attn.to_out.weight",
519
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.bias": "encoders.20.layer_norm1.bias",
520
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.weight": "encoders.20.layer_norm1.weight",
521
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.bias": "encoders.20.layer_norm2.bias",
522
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.weight": "encoders.20.layer_norm2.weight",
523
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.bias": "encoders.20.fc1.bias",
524
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.weight": "encoders.20.fc1.weight",
525
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.bias": "encoders.20.fc2.bias",
526
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.weight": "encoders.20.fc2.weight",
527
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias": ['encoders.21.attn.to_q.bias', 'encoders.21.attn.to_k.bias', 'encoders.21.attn.to_v.bias'],
528
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight": ['encoders.21.attn.to_q.weight', 'encoders.21.attn.to_k.weight', 'encoders.21.attn.to_v.weight'],
529
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.bias": "encoders.21.attn.to_out.bias",
530
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.weight": "encoders.21.attn.to_out.weight",
531
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.bias": "encoders.21.layer_norm1.bias",
532
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.weight": "encoders.21.layer_norm1.weight",
533
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.bias": "encoders.21.layer_norm2.bias",
534
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.weight": "encoders.21.layer_norm2.weight",
535
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.bias": "encoders.21.fc1.bias",
536
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.weight": "encoders.21.fc1.weight",
537
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.bias": "encoders.21.fc2.bias",
538
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.weight": "encoders.21.fc2.weight",
539
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias": ['encoders.22.attn.to_q.bias', 'encoders.22.attn.to_k.bias', 'encoders.22.attn.to_v.bias'],
540
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight": ['encoders.22.attn.to_q.weight', 'encoders.22.attn.to_k.weight', 'encoders.22.attn.to_v.weight'],
541
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.bias": "encoders.22.attn.to_out.bias",
542
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.weight": "encoders.22.attn.to_out.weight",
543
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.bias": "encoders.22.layer_norm1.bias",
544
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.weight": "encoders.22.layer_norm1.weight",
545
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.bias": "encoders.22.layer_norm2.bias",
546
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.weight": "encoders.22.layer_norm2.weight",
547
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.bias": "encoders.22.fc1.bias",
548
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.weight": "encoders.22.fc1.weight",
549
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.bias": "encoders.22.fc2.bias",
550
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.weight": "encoders.22.fc2.weight",
551
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias": ['encoders.23.attn.to_q.bias', 'encoders.23.attn.to_k.bias', 'encoders.23.attn.to_v.bias'],
552
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight": ['encoders.23.attn.to_q.weight', 'encoders.23.attn.to_k.weight', 'encoders.23.attn.to_v.weight'],
553
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.bias": "encoders.23.attn.to_out.bias",
554
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.weight": "encoders.23.attn.to_out.weight",
555
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.bias": "encoders.23.layer_norm1.bias",
556
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.weight": "encoders.23.layer_norm1.weight",
557
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.bias": "encoders.23.layer_norm2.bias",
558
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.weight": "encoders.23.layer_norm2.weight",
559
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.bias": "encoders.23.fc1.bias",
560
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.weight": "encoders.23.fc1.weight",
561
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.bias": "encoders.23.fc2.bias",
562
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.weight": "encoders.23.fc2.weight",
563
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias": ['encoders.24.attn.to_q.bias', 'encoders.24.attn.to_k.bias', 'encoders.24.attn.to_v.bias'],
564
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight": ['encoders.24.attn.to_q.weight', 'encoders.24.attn.to_k.weight', 'encoders.24.attn.to_v.weight'],
565
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.bias": "encoders.24.attn.to_out.bias",
566
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.weight": "encoders.24.attn.to_out.weight",
567
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.bias": "encoders.24.layer_norm1.bias",
568
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.weight": "encoders.24.layer_norm1.weight",
569
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.bias": "encoders.24.layer_norm2.bias",
570
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.weight": "encoders.24.layer_norm2.weight",
571
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.bias": "encoders.24.fc1.bias",
572
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.weight": "encoders.24.fc1.weight",
573
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.bias": "encoders.24.fc2.bias",
574
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.weight": "encoders.24.fc2.weight",
575
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias": ['encoders.25.attn.to_q.bias', 'encoders.25.attn.to_k.bias', 'encoders.25.attn.to_v.bias'],
576
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight": ['encoders.25.attn.to_q.weight', 'encoders.25.attn.to_k.weight', 'encoders.25.attn.to_v.weight'],
577
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.bias": "encoders.25.attn.to_out.bias",
578
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.weight": "encoders.25.attn.to_out.weight",
579
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.bias": "encoders.25.layer_norm1.bias",
580
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.weight": "encoders.25.layer_norm1.weight",
581
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.bias": "encoders.25.layer_norm2.bias",
582
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.weight": "encoders.25.layer_norm2.weight",
583
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.bias": "encoders.25.fc1.bias",
584
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.weight": "encoders.25.fc1.weight",
585
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.bias": "encoders.25.fc2.bias",
586
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.weight": "encoders.25.fc2.weight",
587
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias": ['encoders.26.attn.to_q.bias', 'encoders.26.attn.to_k.bias', 'encoders.26.attn.to_v.bias'],
588
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight": ['encoders.26.attn.to_q.weight', 'encoders.26.attn.to_k.weight', 'encoders.26.attn.to_v.weight'],
589
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.bias": "encoders.26.attn.to_out.bias",
590
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.weight": "encoders.26.attn.to_out.weight",
591
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.bias": "encoders.26.layer_norm1.bias",
592
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.weight": "encoders.26.layer_norm1.weight",
593
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.bias": "encoders.26.layer_norm2.bias",
594
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.weight": "encoders.26.layer_norm2.weight",
595
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.bias": "encoders.26.fc1.bias",
596
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.weight": "encoders.26.fc1.weight",
597
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.bias": "encoders.26.fc2.bias",
598
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.weight": "encoders.26.fc2.weight",
599
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias": ['encoders.27.attn.to_q.bias', 'encoders.27.attn.to_k.bias', 'encoders.27.attn.to_v.bias'],
600
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight": ['encoders.27.attn.to_q.weight', 'encoders.27.attn.to_k.weight', 'encoders.27.attn.to_v.weight'],
601
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.bias": "encoders.27.attn.to_out.bias",
602
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.weight": "encoders.27.attn.to_out.weight",
603
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.bias": "encoders.27.layer_norm1.bias",
604
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.weight": "encoders.27.layer_norm1.weight",
605
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.bias": "encoders.27.layer_norm2.bias",
606
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.weight": "encoders.27.layer_norm2.weight",
607
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.bias": "encoders.27.fc1.bias",
608
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.weight": "encoders.27.fc1.weight",
609
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.bias": "encoders.27.fc2.bias",
610
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.weight": "encoders.27.fc2.weight",
611
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias": ['encoders.28.attn.to_q.bias', 'encoders.28.attn.to_k.bias', 'encoders.28.attn.to_v.bias'],
612
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight": ['encoders.28.attn.to_q.weight', 'encoders.28.attn.to_k.weight', 'encoders.28.attn.to_v.weight'],
613
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.bias": "encoders.28.attn.to_out.bias",
614
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.weight": "encoders.28.attn.to_out.weight",
615
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.bias": "encoders.28.layer_norm1.bias",
616
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.weight": "encoders.28.layer_norm1.weight",
617
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.bias": "encoders.28.layer_norm2.bias",
618
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.weight": "encoders.28.layer_norm2.weight",
619
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.bias": "encoders.28.fc1.bias",
620
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.weight": "encoders.28.fc1.weight",
621
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.bias": "encoders.28.fc2.bias",
622
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.weight": "encoders.28.fc2.weight",
623
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias": ['encoders.29.attn.to_q.bias', 'encoders.29.attn.to_k.bias', 'encoders.29.attn.to_v.bias'],
624
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight": ['encoders.29.attn.to_q.weight', 'encoders.29.attn.to_k.weight', 'encoders.29.attn.to_v.weight'],
625
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.bias": "encoders.29.attn.to_out.bias",
626
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.weight": "encoders.29.attn.to_out.weight",
627
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.bias": "encoders.29.layer_norm1.bias",
628
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.weight": "encoders.29.layer_norm1.weight",
629
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.bias": "encoders.29.layer_norm2.bias",
630
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.weight": "encoders.29.layer_norm2.weight",
631
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.bias": "encoders.29.fc1.bias",
632
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.weight": "encoders.29.fc1.weight",
633
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.bias": "encoders.29.fc2.bias",
634
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.weight": "encoders.29.fc2.weight",
635
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias": ['encoders.3.attn.to_q.bias', 'encoders.3.attn.to_k.bias', 'encoders.3.attn.to_v.bias'],
636
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight": ['encoders.3.attn.to_q.weight', 'encoders.3.attn.to_k.weight', 'encoders.3.attn.to_v.weight'],
637
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.bias": "encoders.3.attn.to_out.bias",
638
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.weight": "encoders.3.attn.to_out.weight",
639
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.bias": "encoders.3.layer_norm1.bias",
640
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.weight": "encoders.3.layer_norm1.weight",
641
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.bias": "encoders.3.layer_norm2.bias",
642
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.weight": "encoders.3.layer_norm2.weight",
643
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.bias": "encoders.3.fc1.bias",
644
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.weight": "encoders.3.fc1.weight",
645
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.bias": "encoders.3.fc2.bias",
646
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.weight": "encoders.3.fc2.weight",
647
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias": ['encoders.30.attn.to_q.bias', 'encoders.30.attn.to_k.bias', 'encoders.30.attn.to_v.bias'],
648
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight": ['encoders.30.attn.to_q.weight', 'encoders.30.attn.to_k.weight', 'encoders.30.attn.to_v.weight'],
649
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.bias": "encoders.30.attn.to_out.bias",
650
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.weight": "encoders.30.attn.to_out.weight",
651
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.bias": "encoders.30.layer_norm1.bias",
652
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.weight": "encoders.30.layer_norm1.weight",
653
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.bias": "encoders.30.layer_norm2.bias",
654
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.weight": "encoders.30.layer_norm2.weight",
655
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.bias": "encoders.30.fc1.bias",
656
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.weight": "encoders.30.fc1.weight",
657
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.bias": "encoders.30.fc2.bias",
658
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.weight": "encoders.30.fc2.weight",
659
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias": ['encoders.31.attn.to_q.bias', 'encoders.31.attn.to_k.bias', 'encoders.31.attn.to_v.bias'],
660
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight": ['encoders.31.attn.to_q.weight', 'encoders.31.attn.to_k.weight', 'encoders.31.attn.to_v.weight'],
661
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.bias": "encoders.31.attn.to_out.bias",
662
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.weight": "encoders.31.attn.to_out.weight",
663
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.bias": "encoders.31.layer_norm1.bias",
664
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.weight": "encoders.31.layer_norm1.weight",
665
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.bias": "encoders.31.layer_norm2.bias",
666
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.weight": "encoders.31.layer_norm2.weight",
667
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.bias": "encoders.31.fc1.bias",
668
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.weight": "encoders.31.fc1.weight",
669
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.bias": "encoders.31.fc2.bias",
670
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.weight": "encoders.31.fc2.weight",
671
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias": ['encoders.4.attn.to_q.bias', 'encoders.4.attn.to_k.bias', 'encoders.4.attn.to_v.bias'],
672
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight": ['encoders.4.attn.to_q.weight', 'encoders.4.attn.to_k.weight', 'encoders.4.attn.to_v.weight'],
673
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.bias": "encoders.4.attn.to_out.bias",
674
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.weight": "encoders.4.attn.to_out.weight",
675
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.bias": "encoders.4.layer_norm1.bias",
676
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.weight": "encoders.4.layer_norm1.weight",
677
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.bias": "encoders.4.layer_norm2.bias",
678
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.weight": "encoders.4.layer_norm2.weight",
679
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.bias": "encoders.4.fc1.bias",
680
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.weight": "encoders.4.fc1.weight",
681
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.bias": "encoders.4.fc2.bias",
682
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.weight": "encoders.4.fc2.weight",
683
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias": ['encoders.5.attn.to_q.bias', 'encoders.5.attn.to_k.bias', 'encoders.5.attn.to_v.bias'],
684
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight": ['encoders.5.attn.to_q.weight', 'encoders.5.attn.to_k.weight', 'encoders.5.attn.to_v.weight'],
685
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.bias": "encoders.5.attn.to_out.bias",
686
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.weight": "encoders.5.attn.to_out.weight",
687
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.bias": "encoders.5.layer_norm1.bias",
688
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.weight": "encoders.5.layer_norm1.weight",
689
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.bias": "encoders.5.layer_norm2.bias",
690
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.weight": "encoders.5.layer_norm2.weight",
691
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.bias": "encoders.5.fc1.bias",
692
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.weight": "encoders.5.fc1.weight",
693
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.bias": "encoders.5.fc2.bias",
694
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.weight": "encoders.5.fc2.weight",
695
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias": ['encoders.6.attn.to_q.bias', 'encoders.6.attn.to_k.bias', 'encoders.6.attn.to_v.bias'],
696
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight": ['encoders.6.attn.to_q.weight', 'encoders.6.attn.to_k.weight', 'encoders.6.attn.to_v.weight'],
697
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.bias": "encoders.6.attn.to_out.bias",
698
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.weight": "encoders.6.attn.to_out.weight",
699
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.bias": "encoders.6.layer_norm1.bias",
700
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.weight": "encoders.6.layer_norm1.weight",
701
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.bias": "encoders.6.layer_norm2.bias",
702
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.weight": "encoders.6.layer_norm2.weight",
703
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.bias": "encoders.6.fc1.bias",
704
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.weight": "encoders.6.fc1.weight",
705
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.bias": "encoders.6.fc2.bias",
706
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.weight": "encoders.6.fc2.weight",
707
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias": ['encoders.7.attn.to_q.bias', 'encoders.7.attn.to_k.bias', 'encoders.7.attn.to_v.bias'],
708
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight": ['encoders.7.attn.to_q.weight', 'encoders.7.attn.to_k.weight', 'encoders.7.attn.to_v.weight'],
709
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.bias": "encoders.7.attn.to_out.bias",
710
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.weight": "encoders.7.attn.to_out.weight",
711
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.bias": "encoders.7.layer_norm1.bias",
712
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.weight": "encoders.7.layer_norm1.weight",
713
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.bias": "encoders.7.layer_norm2.bias",
714
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.weight": "encoders.7.layer_norm2.weight",
715
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.bias": "encoders.7.fc1.bias",
716
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.weight": "encoders.7.fc1.weight",
717
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.bias": "encoders.7.fc2.bias",
718
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.weight": "encoders.7.fc2.weight",
719
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias": ['encoders.8.attn.to_q.bias', 'encoders.8.attn.to_k.bias', 'encoders.8.attn.to_v.bias'],
720
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight": ['encoders.8.attn.to_q.weight', 'encoders.8.attn.to_k.weight', 'encoders.8.attn.to_v.weight'],
721
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.bias": "encoders.8.attn.to_out.bias",
722
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.weight": "encoders.8.attn.to_out.weight",
723
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.bias": "encoders.8.layer_norm1.bias",
724
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.weight": "encoders.8.layer_norm1.weight",
725
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.bias": "encoders.8.layer_norm2.bias",
726
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.weight": "encoders.8.layer_norm2.weight",
727
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.bias": "encoders.8.fc1.bias",
728
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.weight": "encoders.8.fc1.weight",
729
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.bias": "encoders.8.fc2.bias",
730
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.weight": "encoders.8.fc2.weight",
731
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias": ['encoders.9.attn.to_q.bias', 'encoders.9.attn.to_k.bias', 'encoders.9.attn.to_v.bias'],
732
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight": ['encoders.9.attn.to_q.weight', 'encoders.9.attn.to_k.weight', 'encoders.9.attn.to_v.weight'],
733
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.bias": "encoders.9.attn.to_out.bias",
734
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.weight": "encoders.9.attn.to_out.weight",
735
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.bias": "encoders.9.layer_norm1.bias",
736
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.weight": "encoders.9.layer_norm1.weight",
737
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.bias": "encoders.9.layer_norm2.bias",
738
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.weight": "encoders.9.layer_norm2.weight",
739
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.bias": "encoders.9.fc1.bias",
740
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.weight": "encoders.9.fc1.weight",
741
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias": "encoders.9.fc2.bias",
742
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight": "encoders.9.fc2.weight",
743
+ "conditioner.embedders.1.model.text_projection": "text_projection.weight",
744
+ }
745
+ state_dict_ = {}
746
+ for name in state_dict:
747
+ if name in rename_dict:
748
+ param = state_dict[name]
749
+ if name == "conditioner.embedders.1.model.positional_embedding":
750
+ param = param.reshape((1, param.shape[0], param.shape[1]))
751
+ elif name == "conditioner.embedders.1.model.text_projection":
752
+ param = param.T
753
+ if isinstance(rename_dict[name], str):
754
+ state_dict_[rename_dict[name]] = param
755
+ else:
756
+ length = param.shape[0] // 3
757
+ for i, rename in enumerate(rename_dict[name]):
758
+ state_dict_[rename] = param[i*length: i*length+length]
759
+ return state_dict_
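
The conversion loop above handles two checkpoint-format quirks: OpenCLIP stores each attention block's q/k/v projections as a single fused in_proj tensor (the list-valued entries in rename_dict), and text_projection is stored transposed relative to an nn.Linear weight. Below is a minimal standalone sketch of the same splitting and transposition, not part of the patch; embed_dim and the random tensors are illustrative stand-ins rather than values read from a real checkpoint.

# Sketch only: demonstrates the fused-to-split mapping used by the loop above.
import torch

embed_dim = 1280                                        # hypothetical hidden size
fused_weight = torch.randn(3 * embed_dim, embed_dim)    # plays the role of ...attn.in_proj_weight

# Equivalent to: length = param.shape[0] // 3; param[i*length : i*length+length] for i in 0..2
to_q, to_k, to_v = torch.chunk(fused_weight, 3, dim=0)
assert to_q.shape == (embed_dim, embed_dim)

# text_projection is stored as (in_features, out_features), so it is transposed
# before being assigned to an nn.Linear-style weight, as in the loop above.
text_projection = torch.randn(embed_dim, embed_dim)
linear_weight = text_projection.T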