Korakoe, anton-l committed on

Commit 9447174 (0 parents)

Duplicate from diffusers/convert-sd-ckpt

Co-authored-by: Anton Lozhkov <anton-l@users.noreply.huggingface.co>

Files changed (6):
  1. .gitattributes +34 -0
  2. README.md +14 -0
  3. app.py +279 -0
  4. convert.py +878 -0
  5. original_config.yaml +70 -0
  6. requirements.txt +6 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
1
+ ---
2
+ title: Convert to Diffusers
3
+ emoji: 🤖
4
+ colorFrom: indigo
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 3.9.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ duplicated_from: diffusers/convert-sd-ckpt
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,279 @@
1
+ import io
2
+ import os
3
+ import shutil
4
+ import zipfile
5
+
6
+ import gradio as gr
7
+ import requests
8
+ from huggingface_hub import create_repo, upload_folder, whoami
9
+
10
+ from convert import convert_full_checkpoint
11
+
12
+ MODELS_DIR = "models/"
13
+ CKPT_FILE = MODELS_DIR + "model.ckpt"
14
+ HF_MODEL_DIR = MODELS_DIR + "diffusers_model"
15
+ ZIP_FILE = MODELS_DIR + "model.zip"
16
+
17
+
18
+ def download_ckpt(url, out_path):
19
+ with open(out_path, "wb") as out_file:
20
+ with requests.get(url, stream=True) as r:
21
+ r.raise_for_status()
22
+ for chunk in r.iter_content(chunk_size=8192):
23
+ out_file.write(chunk)
24
+
25
+
26
+ def zip_model(model_path, zip_path):
27
+ with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as zip_file:
28
+ for root, dirs, files in os.walk(model_path):
29
+ for file in files:
30
+ zip_file.write(
31
+ os.path.join(root, file),
32
+ os.path.relpath(
33
+ os.path.join(root, file), os.path.join(model_path, "..")
34
+ ),
35
+ )
36
+
37
+
38
+ def download_checkpoint_and_config(ckpt_url, config_url):
39
+ ckpt_url = ckpt_url.strip()
40
+ config_url = config_url.strip()
41
+
42
+ if not ckpt_url.startswith("http://") and not ckpt_url.startswith("https://"):
43
+ raise ValueError("Invalid checkpoint URL")
44
+
45
+ if config_url.startswith("http://") or config_url.startswith("https://"):
46
+ response = requests.get(config_url)
47
+ response.raise_for_status()
48
+ config_file = io.BytesIO(response.content)
49
+ elif config_url != "":
50
+ raise ValueError("Invalid config URL")
51
+ else:
52
+ config_file = open("original_config.yaml", "r")
53
+
54
+ download_ckpt(ckpt_url, CKPT_FILE)
55
+
56
+ return CKPT_FILE, config_file
57
+
58
+
59
+ def convert_and_download(ckpt_url, config_url, scheduler_type, extract_ema):
60
+ shutil.rmtree(MODELS_DIR, ignore_errors=True)
61
+ os.makedirs(HF_MODEL_DIR)
62
+
63
+ ckpt_path, config_file = download_checkpoint_and_config(ckpt_url, config_url)
64
+
65
+ convert_full_checkpoint(
66
+ ckpt_path,
67
+ config_file,
68
+ scheduler_type=scheduler_type,
69
+ extract_ema=(extract_ema == "EMA"),
70
+ output_path=HF_MODEL_DIR,
71
+ )
72
+ zip_model(HF_MODEL_DIR, ZIP_FILE)
73
+
74
+ return ZIP_FILE
75
+
76
+
77
+ def convert_and_upload(
78
+ ckpt_url, config_url, scheduler_type, extract_ema, token, model_name
79
+ ):
80
+ shutil.rmtree(MODELS_DIR, ignore_errors=True)
81
+ os.makedirs(HF_MODEL_DIR)
82
+
83
+ try:
84
+ ckpt_path, config_file = download_checkpoint_and_config(ckpt_url, config_url)
85
+
86
+ username = whoami(token)["name"]
87
+ repo_name = f"{username}/{model_name}"
88
+ repo_url = create_repo(repo_name, token=token, exist_ok=True)
89
+ convert_full_checkpoint(
90
+ ckpt_path,
91
+ config_file,
92
+ scheduler_type=scheduler_type,
93
+ extract_ema=(extract_ema == "EMA"),
94
+ output_path=HF_MODEL_DIR,
95
+ )
96
+ upload_folder(repo_id=repo_name, folder_path=HF_MODEL_DIR, token=token, commit_message="Upload diffusers weights")
97
+ except Exception as e:
98
+ return f"#### Error: {e}"
99
+ return f"#### Success! Model uploaded to [{repo_url}]({repo_url})"
100
+
101
+
102
+ TITLE_IMAGE = """
103
+ <div
104
+ style="
105
+ display: block;
106
+ margin-left: auto;
107
+ margin-right: auto;
108
+ width: 50%;
109
+ "
110
+ >
111
+ <img src="https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg"/>
112
+ </div>
113
+ """
114
+
115
+ TITLE = """
116
+ <div
117
+ style="
118
+ display: inline-flex;
119
+ align-items: center;
120
+ text-align: center;
121
+ max-width: 1400px;
122
+ gap: 0.8rem;
123
+ font-size: 2.2rem;
124
+ "
125
+ >
126
+ <h1 style="font-weight: 900; margin-bottom: 10px; margin-top: 10px;">
127
+ Convert Stable Diffusion `.ckpt` files to Hugging Face Diffusers 🔥
128
+ </h1>
129
+ </div>
130
+ """
131
+
132
+ with gr.Blocks() as interface:
133
+ gr.HTML(TITLE_IMAGE)
134
+ gr.HTML(TITLE)
135
+ gr.Markdown("We will perform all of the checkpoint surgery for you, and create a clean diffusers model!")
136
+ gr.Markdown("This converter will also remove any pickled code from third-party checkpoints.")
137
+
138
+ with gr.Row():
139
+ with gr.Column(scale=50):
140
+ gr.Markdown("### 1. Paste a URL to your <model>.ckpt file")
141
+ ckpt_url = gr.Textbox(
142
+ max_lines=1,
143
+ label="URL to <model>.ckpt",
144
+ placeholder="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt",
145
+ )
146
+
147
+ with gr.Column(scale=50):
148
+ gr.Markdown("### (Optional) paste a URL to your <config>.yaml file")
149
+ config_url = gr.Textbox(
150
+ max_lines=1,
151
+ label="URL to <config>.yaml",
152
+ placeholder="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-inference.yaml",
153
+ )
154
+ gr.Markdown(
155
+ "*If you don't provide a config file, we'll try to use"
156
+ " [v1-inference.yaml](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-inference.yaml).*"
157
+ )
158
+ with gr.Accordion("Advanced Settings"):
159
+ scheduler_type = gr.Dropdown(
160
+ label="Choose a scheduler type (if not sure, keep the PNDM default)",
161
+ choices=["PNDM", "K-LMS", "Euler", "EulerAncestral", "DDIM"],
162
+ value="PNDM",
163
+ )
164
+ extract_ema = gr.Radio(
165
+ label=(
166
+ "EMA weights usually yield higher quality images for inference."
167
+ " Non-EMA weights are usually better to continue fine-tuning."
168
+ ),
169
+ choices=["EMA", "Non-EMA"],
170
+ value="EMA",
171
+ interactive=True,
172
+ )
173
+
174
+ gr.Markdown("### 2. Choose what to do with the converted model")
175
+ model_choice = gr.Radio(
176
+ show_label=False,
177
+ choices=[
178
+ "Download the model as an archive",
179
+ "Host the model on the Hugging Face Hub",
180
+ # "Submit a PR with the model for an existing Hub repository",
181
+ ],
182
+ type="index",
183
+ value="Download the model as an archive",
184
+ interactive=True,
185
+ )
186
+
187
+ download_panel = gr.Column(visible=True)
188
+ upload_panel = gr.Column(visible=False)
189
+ # pr_panel = gr.Column(visible=False)
190
+
191
+ model_choice.change(
192
+ fn=lambda i: gr.update(visible=(i == 0)),
193
+ inputs=model_choice,
194
+ outputs=download_panel,
195
+ )
196
+ model_choice.change(
197
+ fn=lambda i: gr.update(visible=(i == 1)),
198
+ inputs=model_choice,
199
+ outputs=upload_panel,
200
+ )
201
+ # model_choice.change(
202
+ # fn=lambda i: gr.update(visible=(i == 2)),
203
+ # inputs=model_choice,
204
+ # outputs=pr_panel,
205
+ # )
206
+
207
+ with download_panel:
208
+ gr.Markdown("### 3. Convert and download")
209
+
210
+ down_btn = gr.Button("Convert")
211
+ output_file = gr.File(
212
+ label="Download the converted model",
213
+ type="binary",
214
+ interactive=False,
215
+ visible=True,
216
+ )
217
+
218
+ down_btn.click(
219
+ fn=convert_and_download,
220
+ inputs=[ckpt_url, config_url, scheduler_type, extract_ema],
221
+ outputs=output_file,
222
+ )
223
+
224
+ with upload_panel:
225
+ gr.Markdown("### 3. Convert and host on the Hub")
226
+ gr.Markdown(
227
+ "This will create a new repository if it doesn't exist yet, and upload the model to the Hugging Face Hub.\n\n"
228
+ "Paste a WRITE token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)"
229
+ " and make up a model name."
230
+ )
231
+ up_token = gr.Textbox(
232
+ max_lines=1,
233
+ label="Hugging Face token",
234
+ )
235
+ up_model_name = gr.Textbox(
236
+ max_lines=1,
237
+ label="Hub model name (e.g. `artistic-diffusion-v1`)",
238
+ placeholder="my-awesome-model",
239
+ )
240
+
241
+ upload_btn = gr.Button("Convert and upload")
242
+ with gr.Box():
243
+ output_text = gr.Markdown()
244
+ upload_btn.click(
245
+ fn=convert_and_upload,
246
+ inputs=[
247
+ ckpt_url,
248
+ config_url,
249
+ scheduler_type,
250
+ extract_ema,
251
+ up_token,
252
+ up_model_name,
253
+ ],
254
+ outputs=output_text,
255
+ )
256
+
257
+ # with pr_panel:
258
+ # gr.Markdown("### 3. Convert and submit as a PR")
259
+ # gr.Markdown(
260
+ # "This will open a Pull Request on the original model repository, if it already exists on the Hub.\n\n"
261
+ # "Paste a write-access token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)"
262
+ # " and paste an existing model id from the Hub in the `username/model-name` form."
263
+ # )
264
+ # pr_token = gr.Textbox(
265
+ # max_lines=1,
266
+ # label="Hugging Face token",
267
+ # )
268
+ # pr_model_name = gr.Textbox(
269
+ # max_lines=1,
270
+ # label="Hub model name (e.g. `diffuser/artistic-diffusion-v1`)",
271
+ # placeholder="diffuser/my-awesome-model",
272
+ # )
273
+ #
274
+ # btn = gr.Button("Convert and open a PR")
275
+ # output = gr.Markdown(label="Output")
276
+
277
+
278
+ interface.queue(concurrency_count=1)
279
+ interface.launch()
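
Note: the Gradio UI above is a thin wrapper around `convert.convert_full_checkpoint` (defined in convert.py below). A minimal sketch of calling it directly from Python, with illustrative file names and assuming a local checkpoint:

```python
# Sketch only: bypasses the Gradio UI and calls the converter directly.
# "model.ckpt" and "diffusers_model" are illustrative placeholders.
from convert import convert_full_checkpoint

convert_full_checkpoint(
    "model.ckpt",                   # local Stable Diffusion v1 checkpoint
    open("original_config.yaml"),   # the bundled v1 inference config
    scheduler_type="PNDM",          # one of: PNDM, K-LMS, Euler, EulerAncestral, DDIM
    extract_ema=True,               # EMA weights, as recommended for inference
    output_path="diffusers_model",  # directory written by pipe.save_pretrained()
)
```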
convert.py ADDED
@@ -0,0 +1,878 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Conversion script for the Stable Diffusion checkpoints. """
16
+
17
+ import torch
18
+
19
+ try:
20
+ from omegaconf import OmegaConf
21
+ except ImportError:
22
+ raise ImportError(
23
+ "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
24
+ )
25
+
26
+ from diffusers import (AutoencoderKL, DDIMScheduler,
27
+ EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
28
+ LMSDiscreteScheduler, PNDMScheduler,
29
+ StableDiffusionPipeline, UNet2DConditionModel)
30
+ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
31
+ LDMBertConfig, LDMBertModel)
32
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
33
+ from transformers import AutoFeatureExtractor, CLIPTextModel, CLIPTokenizer
34
+
35
+
36
+ def shave_segments(path, n_shave_prefix_segments=1):
37
+ """
38
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
39
+ """
40
+ if n_shave_prefix_segments >= 0:
41
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
42
+ else:
43
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
44
+
45
+
46
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
47
+ """
48
+ Updates paths inside resnets to the new naming scheme (local renaming)
49
+ """
50
+ mapping = []
51
+ for old_item in old_list:
52
+ new_item = old_item.replace("in_layers.0", "norm1")
53
+ new_item = new_item.replace("in_layers.2", "conv1")
54
+
55
+ new_item = new_item.replace("out_layers.0", "norm2")
56
+ new_item = new_item.replace("out_layers.3", "conv2")
57
+
58
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
59
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
60
+
61
+ new_item = shave_segments(
62
+ new_item, n_shave_prefix_segments=n_shave_prefix_segments
63
+ )
64
+
65
+ mapping.append({"old": old_item, "new": new_item})
66
+
67
+ return mapping
68
+
69
+
70
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
71
+ """
72
+ Updates paths inside resnets to the new naming scheme (local renaming)
73
+ """
74
+ mapping = []
75
+ for old_item in old_list:
76
+ new_item = old_item
77
+
78
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
79
+ new_item = shave_segments(
80
+ new_item, n_shave_prefix_segments=n_shave_prefix_segments
81
+ )
82
+
83
+ mapping.append({"old": old_item, "new": new_item})
84
+
85
+ return mapping
86
+
87
+
88
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
89
+ """
90
+ Updates paths inside attentions to the new naming scheme (local renaming)
91
+ """
92
+ mapping = []
93
+ for old_item in old_list:
94
+ new_item = old_item
95
+
96
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
97
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
98
+
99
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
100
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
101
+
102
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
103
+
104
+ mapping.append({"old": old_item, "new": new_item})
105
+
106
+ return mapping
107
+
108
+
109
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
110
+ """
111
+ Updates paths inside attentions to the new naming scheme (local renaming)
112
+ """
113
+ mapping = []
114
+ for old_item in old_list:
115
+ new_item = old_item
116
+
117
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
118
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
119
+
120
+ new_item = new_item.replace("q.weight", "query.weight")
121
+ new_item = new_item.replace("q.bias", "query.bias")
122
+
123
+ new_item = new_item.replace("k.weight", "key.weight")
124
+ new_item = new_item.replace("k.bias", "key.bias")
125
+
126
+ new_item = new_item.replace("v.weight", "value.weight")
127
+ new_item = new_item.replace("v.bias", "value.bias")
128
+
129
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
130
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
131
+
132
+ new_item = shave_segments(
133
+ new_item, n_shave_prefix_segments=n_shave_prefix_segments
134
+ )
135
+
136
+ mapping.append({"old": old_item, "new": new_item})
137
+
138
+ return mapping
139
+
140
+
141
+ def assign_to_checkpoint(
142
+ paths,
143
+ checkpoint,
144
+ old_checkpoint,
145
+ attention_paths_to_split=None,
146
+ additional_replacements=None,
147
+ config=None,
148
+ ):
149
+ """
150
+ This does the final conversion step: take locally converted weights and apply a global renaming
151
+ to them. It splits attention layers, and takes into account additional replacements
152
+ that may arise.
153
+ Assigns the weights to the new checkpoint.
154
+ """
155
+ assert isinstance(
156
+ paths, list
157
+ ), "Paths should be a list of dicts containing 'old' and 'new' keys."
158
+
159
+ # Splits the attention layers into three variables.
160
+ if attention_paths_to_split is not None:
161
+ for path, path_map in attention_paths_to_split.items():
162
+ old_tensor = old_checkpoint[path]
163
+ channels = old_tensor.shape[0] // 3
164
+
165
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
166
+
167
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
168
+
169
+ old_tensor = old_tensor.reshape(
170
+ (num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]
171
+ )
172
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
173
+
174
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
175
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
176
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
177
+
178
+ for path in paths:
179
+ new_path = path["new"]
180
+
181
+ # These have already been assigned
182
+ if (
183
+ attention_paths_to_split is not None
184
+ and new_path in attention_paths_to_split
185
+ ):
186
+ continue
187
+
188
+ # Global renaming happens here
189
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
190
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
191
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
192
+
193
+ if additional_replacements is not None:
194
+ for replacement in additional_replacements:
195
+ new_path = new_path.replace(replacement["old"], replacement["new"])
196
+
197
+ # proj_attn.weight has to be converted from conv 1D to linear
198
+ if "proj_attn.weight" in new_path:
199
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
200
+ else:
201
+ checkpoint[new_path] = old_checkpoint[path["old"]]
202
+
203
+
204
+ def conv_attn_to_linear(checkpoint):
205
+ keys = list(checkpoint.keys())
206
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
207
+ for key in keys:
208
+ if ".".join(key.split(".")[-2:]) in attn_keys:
209
+ if checkpoint[key].ndim > 2:
210
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
211
+ elif "proj_attn.weight" in key:
212
+ if checkpoint[key].ndim > 2:
213
+ checkpoint[key] = checkpoint[key][:, :, 0]
214
+
215
+
216
+ def create_unet_diffusers_config(original_config):
217
+ """
218
+ Creates a diffusers config based on the config of the LDM model.
219
+ """
220
+ unet_params = original_config.model.params.unet_config.params
221
+
222
+ block_out_channels = [
223
+ unet_params.model_channels * mult for mult in unet_params.channel_mult
224
+ ]
225
+
226
+ down_block_types = []
227
+ resolution = 1
228
+ for i in range(len(block_out_channels)):
229
+ block_type = (
230
+ "CrossAttnDownBlock2D"
231
+ if resolution in unet_params.attention_resolutions
232
+ else "DownBlock2D"
233
+ )
234
+ down_block_types.append(block_type)
235
+ if i != len(block_out_channels) - 1:
236
+ resolution *= 2
237
+
238
+ up_block_types = []
239
+ for i in range(len(block_out_channels)):
240
+ block_type = (
241
+ "CrossAttnUpBlock2D"
242
+ if resolution in unet_params.attention_resolutions
243
+ else "UpBlock2D"
244
+ )
245
+ up_block_types.append(block_type)
246
+ resolution //= 2
247
+
248
+ config = dict(
249
+ sample_size=unet_params.image_size,
250
+ in_channels=unet_params.in_channels,
251
+ out_channels=unet_params.out_channels,
252
+ down_block_types=tuple(down_block_types),
253
+ up_block_types=tuple(up_block_types),
254
+ block_out_channels=tuple(block_out_channels),
255
+ layers_per_block=unet_params.num_res_blocks,
256
+ cross_attention_dim=unet_params.context_dim,
257
+ attention_head_dim=unet_params.num_heads,
258
+ )
259
+
260
+ return config
261
+
262
+
263
+ def create_vae_diffusers_config(original_config):
264
+ """
265
+ Creates a diffusers config based on the config of the LDM model.
266
+ """
267
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
268
+ _ = original_config.model.params.first_stage_config.params.embed_dim
269
+
270
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
271
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
272
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
273
+
274
+ config = dict(
275
+ sample_size=vae_params.resolution,
276
+ in_channels=vae_params.in_channels,
277
+ out_channels=vae_params.out_ch,
278
+ down_block_types=tuple(down_block_types),
279
+ up_block_types=tuple(up_block_types),
280
+ block_out_channels=tuple(block_out_channels),
281
+ latent_channels=vae_params.z_channels,
282
+ layers_per_block=vae_params.num_res_blocks,
283
+ )
284
+ return config
285
+
286
+
287
+ def create_diffusers_scheduler(original_config):
288
+ scheduler = DDIMScheduler(
289
+ num_train_timesteps=original_config.model.params.timesteps,
290
+ beta_start=original_config.model.params.linear_start,
291
+ beta_end=original_config.model.params.linear_end,
292
+ beta_schedule="scaled_linear",
293
+ )
294
+ return scheduler
295
+
296
+
297
+ def create_ldm_bert_config(original_config):
298
+ bert_params = original_config.model.parms.cond_stage_config.params
299
+ config = LDMBertConfig(
300
+ d_model=bert_params.n_embed,
301
+ encoder_layers=bert_params.n_layer,
302
+ encoder_ffn_dim=bert_params.n_embed * 4,
303
+ )
304
+ return config
305
+
306
+
307
+ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
308
+ """
309
+ Takes a state dict and a config, and returns a converted checkpoint.
310
+ """
311
+
312
+ # extract state_dict for UNet
313
+ unet_state_dict = {}
314
+ keys = list(checkpoint.keys())
315
+
316
+ unet_key = "model.diffusion_model."
317
+ # at least 100 parameters have to start with `model_ema` for the checkpoint to be considered EMA
318
+ if sum(k.startswith("model_ema") for k in keys) > 100:
319
+ print("Checkpoint has both EMA and non-EMA weights.")
320
+ if extract_ema:
321
+ print(
322
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
323
+ " weights (useful to continue fine-tuning), please select the Non-EMA option instead."
324
+ )
325
+ for key in keys:
326
+ if key.startswith("model.diffusion_model"):
327
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
328
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
329
+ flat_ema_key
330
+ )
331
+ else:
332
+ print(
333
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
334
+ " weights (usually better for inference), please select the EMA option instead."
335
+ )
336
+
337
+ for key in keys:
338
+ if key.startswith(unet_key):
339
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
340
+
341
+ new_checkpoint = {}
342
+
343
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict[
344
+ "time_embed.0.weight"
345
+ ]
346
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict[
347
+ "time_embed.0.bias"
348
+ ]
349
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict[
350
+ "time_embed.2.weight"
351
+ ]
352
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict[
353
+ "time_embed.2.bias"
354
+ ]
355
+
356
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
357
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
358
+
359
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
360
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
361
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
362
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
363
+
364
+ # Retrieves the keys for the input blocks only
365
+ num_input_blocks = len(
366
+ {
367
+ ".".join(layer.split(".")[:2])
368
+ for layer in unet_state_dict
369
+ if "input_blocks" in layer
370
+ }
371
+ )
372
+ input_blocks = {
373
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
374
+ for layer_id in range(num_input_blocks)
375
+ }
376
+
377
+ # Retrieves the keys for the middle blocks only
378
+ num_middle_blocks = len(
379
+ {
380
+ ".".join(layer.split(".")[:2])
381
+ for layer in unet_state_dict
382
+ if "middle_block" in layer
383
+ }
384
+ )
385
+ middle_blocks = {
386
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
387
+ for layer_id in range(num_middle_blocks)
388
+ }
389
+
390
+ # Retrieves the keys for the output blocks only
391
+ num_output_blocks = len(
392
+ {
393
+ ".".join(layer.split(".")[:2])
394
+ for layer in unet_state_dict
395
+ if "output_blocks" in layer
396
+ }
397
+ )
398
+ output_blocks = {
399
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
400
+ for layer_id in range(num_output_blocks)
401
+ }
402
+
403
+ for i in range(1, num_input_blocks):
404
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
405
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
406
+
407
+ resnets = [
408
+ key
409
+ for key in input_blocks[i]
410
+ if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
411
+ ]
412
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
413
+
414
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
415
+ new_checkpoint[
416
+ f"down_blocks.{block_id}.downsamplers.0.conv.weight"
417
+ ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.weight")
418
+ new_checkpoint[
419
+ f"down_blocks.{block_id}.downsamplers.0.conv.bias"
420
+ ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
421
+
422
+ paths = renew_resnet_paths(resnets)
423
+ meta_path = {
424
+ "old": f"input_blocks.{i}.0",
425
+ "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}",
426
+ }
427
+ assign_to_checkpoint(
428
+ paths,
429
+ new_checkpoint,
430
+ unet_state_dict,
431
+ additional_replacements=[meta_path],
432
+ config=config,
433
+ )
434
+
435
+ if len(attentions):
436
+ paths = renew_attention_paths(attentions)
437
+ meta_path = {
438
+ "old": f"input_blocks.{i}.1",
439
+ "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}",
440
+ }
441
+ assign_to_checkpoint(
442
+ paths,
443
+ new_checkpoint,
444
+ unet_state_dict,
445
+ additional_replacements=[meta_path],
446
+ config=config,
447
+ )
448
+
449
+ resnet_0 = middle_blocks[0]
450
+ attentions = middle_blocks[1]
451
+ resnet_1 = middle_blocks[2]
452
+
453
+ resnet_0_paths = renew_resnet_paths(resnet_0)
454
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
455
+
456
+ resnet_1_paths = renew_resnet_paths(resnet_1)
457
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
458
+
459
+ attentions_paths = renew_attention_paths(attentions)
460
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
461
+ assign_to_checkpoint(
462
+ attentions_paths,
463
+ new_checkpoint,
464
+ unet_state_dict,
465
+ additional_replacements=[meta_path],
466
+ config=config,
467
+ )
468
+
469
+ for i in range(num_output_blocks):
470
+ block_id = i // (config["layers_per_block"] + 1)
471
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
472
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
473
+ output_block_list = {}
474
+
475
+ for layer in output_block_layers:
476
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
477
+ if layer_id in output_block_list:
478
+ output_block_list[layer_id].append(layer_name)
479
+ else:
480
+ output_block_list[layer_id] = [layer_name]
481
+
482
+ if len(output_block_list) > 1:
483
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
484
+ attentions = [
485
+ key for key in output_blocks[i] if f"output_blocks.{i}.1" in key
486
+ ]
487
+
488
+ resnet_0_paths = renew_resnet_paths(resnets)
489
+ paths = renew_resnet_paths(resnets)
490
+
491
+ meta_path = {
492
+ "old": f"output_blocks.{i}.0",
493
+ "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}",
494
+ }
495
+ assign_to_checkpoint(
496
+ paths,
497
+ new_checkpoint,
498
+ unet_state_dict,
499
+ additional_replacements=[meta_path],
500
+ config=config,
501
+ )
502
+
503
+ if ["conv.weight", "conv.bias"] in output_block_list.values():
504
+ index = list(output_block_list.values()).index(
505
+ ["conv.weight", "conv.bias"]
506
+ )
507
+ new_checkpoint[
508
+ f"up_blocks.{block_id}.upsamplers.0.conv.weight"
509
+ ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"]
510
+ new_checkpoint[
511
+ f"up_blocks.{block_id}.upsamplers.0.conv.bias"
512
+ ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"]
513
+
514
+ # Clear attentions as they have been attributed above.
515
+ if len(attentions) == 2:
516
+ attentions = []
517
+
518
+ if len(attentions):
519
+ paths = renew_attention_paths(attentions)
520
+ meta_path = {
521
+ "old": f"output_blocks.{i}.1",
522
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
523
+ }
524
+ assign_to_checkpoint(
525
+ paths,
526
+ new_checkpoint,
527
+ unet_state_dict,
528
+ additional_replacements=[meta_path],
529
+ config=config,
530
+ )
531
+ else:
532
+ resnet_0_paths = renew_resnet_paths(
533
+ output_block_layers, n_shave_prefix_segments=1
534
+ )
535
+ for path in resnet_0_paths:
536
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
537
+ new_path = ".".join(
538
+ [
539
+ "up_blocks",
540
+ str(block_id),
541
+ "resnets",
542
+ str(layer_in_block_id),
543
+ path["new"],
544
+ ]
545
+ )
546
+
547
+ new_checkpoint[new_path] = unet_state_dict[old_path]
548
+
549
+ return new_checkpoint
550
+
551
+
552
+ def convert_ldm_vae_checkpoint(checkpoint, config):
553
+ # extract state dict for VAE
554
+ vae_state_dict = {}
555
+ vae_key = "first_stage_model."
556
+ keys = list(checkpoint.keys())
557
+ for key in keys:
558
+ if key.startswith(vae_key):
559
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
560
+
561
+ new_checkpoint = {}
562
+
563
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
564
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
565
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict[
566
+ "encoder.conv_out.weight"
567
+ ]
568
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
569
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict[
570
+ "encoder.norm_out.weight"
571
+ ]
572
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict[
573
+ "encoder.norm_out.bias"
574
+ ]
575
+
576
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
577
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
578
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict[
579
+ "decoder.conv_out.weight"
580
+ ]
581
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
582
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict[
583
+ "decoder.norm_out.weight"
584
+ ]
585
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict[
586
+ "decoder.norm_out.bias"
587
+ ]
588
+
589
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
590
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
591
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
592
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
593
+
594
+ # Retrieves the keys for the encoder down blocks only
595
+ num_down_blocks = len(
596
+ {
597
+ ".".join(layer.split(".")[:3])
598
+ for layer in vae_state_dict
599
+ if "encoder.down" in layer
600
+ }
601
+ )
602
+ down_blocks = {
603
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key]
604
+ for layer_id in range(num_down_blocks)
605
+ }
606
+
607
+ # Retrieves the keys for the decoder up blocks only
608
+ num_up_blocks = len(
609
+ {
610
+ ".".join(layer.split(".")[:3])
611
+ for layer in vae_state_dict
612
+ if "decoder.up" in layer
613
+ }
614
+ )
615
+ up_blocks = {
616
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key]
617
+ for layer_id in range(num_up_blocks)
618
+ }
619
+
620
+ for i in range(num_down_blocks):
621
+ resnets = [
622
+ key
623
+ for key in down_blocks[i]
624
+ if f"down.{i}" in key and f"down.{i}.downsample" not in key
625
+ ]
626
+
627
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
628
+ new_checkpoint[
629
+ f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"
630
+ ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
631
+ new_checkpoint[
632
+ f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"
633
+ ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")
634
+
635
+ paths = renew_vae_resnet_paths(resnets)
636
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
637
+ assign_to_checkpoint(
638
+ paths,
639
+ new_checkpoint,
640
+ vae_state_dict,
641
+ additional_replacements=[meta_path],
642
+ config=config,
643
+ )
644
+
645
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
646
+ num_mid_res_blocks = 2
647
+ for i in range(1, num_mid_res_blocks + 1):
648
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
649
+
650
+ paths = renew_vae_resnet_paths(resnets)
651
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
652
+ assign_to_checkpoint(
653
+ paths,
654
+ new_checkpoint,
655
+ vae_state_dict,
656
+ additional_replacements=[meta_path],
657
+ config=config,
658
+ )
659
+
660
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
661
+ paths = renew_vae_attention_paths(mid_attentions)
662
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
663
+ assign_to_checkpoint(
664
+ paths,
665
+ new_checkpoint,
666
+ vae_state_dict,
667
+ additional_replacements=[meta_path],
668
+ config=config,
669
+ )
670
+ conv_attn_to_linear(new_checkpoint)
671
+
672
+ for i in range(num_up_blocks):
673
+ block_id = num_up_blocks - 1 - i
674
+ resnets = [
675
+ key
676
+ for key in up_blocks[block_id]
677
+ if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
678
+ ]
679
+
680
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
681
+ new_checkpoint[
682
+ f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"
683
+ ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
684
+ new_checkpoint[
685
+ f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"
686
+ ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]
687
+
688
+ paths = renew_vae_resnet_paths(resnets)
689
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
690
+ assign_to_checkpoint(
691
+ paths,
692
+ new_checkpoint,
693
+ vae_state_dict,
694
+ additional_replacements=[meta_path],
695
+ config=config,
696
+ )
697
+
698
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
699
+ num_mid_res_blocks = 2
700
+ for i in range(1, num_mid_res_blocks + 1):
701
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
702
+
703
+ paths = renew_vae_resnet_paths(resnets)
704
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
705
+ assign_to_checkpoint(
706
+ paths,
707
+ new_checkpoint,
708
+ vae_state_dict,
709
+ additional_replacements=[meta_path],
710
+ config=config,
711
+ )
712
+
713
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
714
+ paths = renew_vae_attention_paths(mid_attentions)
715
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
716
+ assign_to_checkpoint(
717
+ paths,
718
+ new_checkpoint,
719
+ vae_state_dict,
720
+ additional_replacements=[meta_path],
721
+ config=config,
722
+ )
723
+ conv_attn_to_linear(new_checkpoint)
724
+ return new_checkpoint
725
+
726
+
727
+ def convert_ldm_bert_checkpoint(checkpoint, config):
728
+ def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
729
+ hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
730
+ hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
731
+ hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
732
+
733
+ hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
734
+ hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
735
+
736
+ def _copy_linear(hf_linear, pt_linear):
737
+ hf_linear.weight = pt_linear.weight
738
+ hf_linear.bias = pt_linear.bias
739
+
740
+ def _copy_layer(hf_layer, pt_layer):
741
+ # copy layer norms
742
+ _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
743
+ _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
744
+
745
+ # copy attn
746
+ _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
747
+
748
+ # copy MLP
749
+ pt_mlp = pt_layer[1][1]
750
+ _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
751
+ _copy_linear(hf_layer.fc2, pt_mlp.net[2])
752
+
753
+ def _copy_layers(hf_layers, pt_layers):
754
+ for i, hf_layer in enumerate(hf_layers):
755
+ if i != 0:
756
+ i += i
757
+ pt_layer = pt_layers[i : i + 2]
758
+ _copy_layer(hf_layer, pt_layer)
759
+
760
+ hf_model = LDMBertModel(config).eval()
761
+
762
+ # copy embeds
763
+ hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
764
+ hf_model.model.embed_positions.weight.data = (
765
+ checkpoint.transformer.pos_emb.emb.weight
766
+ )
767
+
768
+ # copy layer norm
769
+ _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
770
+
771
+ # copy hidden layers
772
+ _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
773
+
774
+ _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
775
+
776
+ return hf_model
777
+
778
+
779
+ def convert_ldm_clip_checkpoint(checkpoint):
780
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
781
+
782
+ keys = list(checkpoint.keys())
783
+
784
+ text_model_dict = {}
785
+
786
+ for key in keys:
787
+ if key.startswith("cond_stage_model.transformer"):
788
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[
789
+ key
790
+ ]
791
+
792
+ text_model.load_state_dict(text_model_dict)
793
+
794
+ return text_model
795
+
796
+
797
+ def convert_full_checkpoint(
798
+ checkpoint_path: str, config_file, scheduler_type, extract_ema, output_path=None
799
+ ):
800
+ original_config = OmegaConf.load(config_file)
801
+ checkpoint = torch.load(checkpoint_path, weights_only=False)
802
+ checkpoint = checkpoint["state_dict"]
803
+
804
+ num_train_timesteps = original_config.model.params.timesteps
805
+ beta_start = original_config.model.params.linear_start
806
+ beta_end = original_config.model.params.linear_end
807
+ if scheduler_type == "PNDM":
808
+ scheduler = PNDMScheduler(
809
+ beta_end=beta_end,
810
+ beta_schedule="scaled_linear",
811
+ beta_start=beta_start,
812
+ num_train_timesteps=num_train_timesteps,
813
+ skip_prk_steps=True,
814
+ )
815
+ elif scheduler_type == "K-LMS":
816
+ scheduler = LMSDiscreteScheduler(
817
+ beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
818
+ )
819
+ elif scheduler_type == "Euler":
820
+ scheduler = EulerDiscreteScheduler(
821
+ beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
822
+ )
823
+ elif scheduler_type == "EulerAncestral":
824
+ scheduler = EulerAncestralDiscreteScheduler(
825
+ beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
826
+ )
827
+ elif scheduler_type == "DDIM":
828
+ scheduler = DDIMScheduler(
829
+ beta_start=beta_start,
830
+ beta_end=beta_end,
831
+ beta_schedule="scaled_linear",
832
+ clip_sample=False,
833
+ set_alpha_to_one=False,
834
+ )
835
+ else:
836
+ raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
837
+
838
+ # Convert the UNet2DConditionModel model.
839
+ unet_config = create_unet_diffusers_config(original_config)
840
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(
841
+ checkpoint, unet_config, extract_ema=extract_ema
842
+ )
843
+
844
+ # Convert the VAE model.
845
+ vae_config = create_vae_diffusers_config(original_config)
846
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
847
+
848
+ # Convert the text model.
849
+ text_model = convert_ldm_clip_checkpoint(checkpoint)
850
+
851
+ del checkpoint
852
+
853
+ unet = UNet2DConditionModel(**unet_config)
854
+ unet.load_state_dict(converted_unet_checkpoint)
855
+ del converted_unet_checkpoint
856
+
857
+ vae = AutoencoderKL(**vae_config)
858
+ vae.load_state_dict(converted_vae_checkpoint)
859
+ del converted_vae_checkpoint
860
+
861
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
862
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(
863
+ "CompVis/stable-diffusion-safety-checker", device_map="auto"
864
+ )
865
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
866
+ "CompVis/stable-diffusion-safety-checker"
867
+ )
868
+ pipe = StableDiffusionPipeline(
869
+ vae=vae,
870
+ text_encoder=text_model,
871
+ tokenizer=tokenizer,
872
+ unet=unet,
873
+ scheduler=scheduler,
874
+ safety_checker=safety_checker,
875
+ feature_extractor=feature_extractor,
876
+ )
877
+
878
+ pipe.save_pretrained(output_path)
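
Note: the directory written by `pipe.save_pretrained(output_path)` is a standard diffusers model folder. A minimal loading sketch (the folder name and prompt are illustrative, and a CUDA GPU is assumed):

```python
# Sketch only: load the converted folder with the regular diffusers API.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("diffusers_model", torch_dtype=torch.float16)
pipe = pipe.to("cuda")  # assumes a CUDA GPU; drop torch_dtype and .to("cuda") for CPU
image = pipe("an astronaut riding a horse").images[0]
image.save("astronaut.png")
```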
original_config.yaml ADDED
@@ -0,0 +1,70 @@
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [ 10000 ]
24
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
+ f_start: [ 1.e-6 ]
26
+ f_max: [ 1. ]
27
+ f_min: [ 1. ]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32 # unused
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [ 4, 2, 1 ]
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1, 2, 4, 4 ]
39
+ num_heads: 8
40
+ use_spatial_transformer: True
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ use_checkpoint: True
44
+ legacy: False
45
+
46
+ first_stage_config:
47
+ target: ldm.models.autoencoder.AutoencoderKL
48
+ params:
49
+ embed_dim: 4
50
+ monitor: val/rec_loss
51
+ ddconfig:
52
+ double_z: true
53
+ z_channels: 4
54
+ resolution: 256
55
+ in_channels: 3
56
+ out_ch: 3
57
+ ch: 128
58
+ ch_mult:
59
+ - 1
60
+ - 2
61
+ - 4
62
+ - 4
63
+ num_res_blocks: 2
64
+ attn_resolutions: []
65
+ dropout: 0.0
66
+ lossconfig:
67
+ target: torch.nn.Identity
68
+
69
+ cond_stage_config:
70
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
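
Note: this YAML is the stock Stable Diffusion v1 inference config; convert.py reads it with OmegaConf and derives the diffusers configs from it. A small sketch of that mapping (values follow from the config above):

```python
# Sketch only: how convert.py consumes the config above.
from omegaconf import OmegaConf
from convert import create_unet_diffusers_config, create_vae_diffusers_config

original_config = OmegaConf.load("original_config.yaml")
unet_config = create_unet_diffusers_config(original_config)
# model_channels=320 with channel_mult=[1, 2, 4, 4] -> block_out_channels=(320, 640, 1280, 1280)
vae_config = create_vae_diffusers_config(original_config)
# ddconfig: z_channels=4 -> latent_channels=4, resolution=256 -> sample_size=256
```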
requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ OmegaConf
2
+ pytorch_lightning
3
+ accelerate
4
+ diffusers[torch]
5
+ transformers
6
+ scipy