sanchit-gandhi HF staff commited on
Commit
318f5a3
1 Parent(s): 9c0400e

Convert weights and config

Browse files
convert_original_audioldm_to_diffusers.py ADDED
@@ -0,0 +1,1015 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Conversion script for the AudioLDM checkpoints."""
16
+
17
+ import argparse
18
+ import re
19
+
20
+ import torch
21
+ from transformers import (
22
+ AutoTokenizer,
23
+ ClapTextConfig,
24
+ ClapTextModelWithProjection,
25
+ SpeechT5HifiGan,
26
+ SpeechT5HifiGanConfig,
27
+ )
28
+
29
+ from diffusers import (
30
+ AudioLDMPipeline,
31
+ AutoencoderKL,
32
+ DDIMScheduler,
33
+ DPMSolverMultistepScheduler,
34
+ EulerAncestralDiscreteScheduler,
35
+ EulerDiscreteScheduler,
36
+ HeunDiscreteScheduler,
37
+ LMSDiscreteScheduler,
38
+ PNDMScheduler,
39
+ UNet2DConditionModel,
40
+ )
41
+ from diffusers.utils import is_omegaconf_available, is_safetensors_available
42
+ from diffusers.utils.import_utils import BACKENDS_MAPPING
43
+
44
+
45
+ # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments
46
+ def shave_segments(path, n_shave_prefix_segments=1):
47
+ """
48
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
49
+ """
50
+ if n_shave_prefix_segments >= 0:
51
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
52
+ else:
53
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
54
+
55
+
56
+ # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_resnet_paths
57
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
58
+ """
59
+ Updates paths inside resnets to the new naming scheme (local renaming)
60
+ """
61
+ mapping = []
62
+ for old_item in old_list:
63
+ new_item = old_item.replace("in_layers.0", "norm1")
64
+ new_item = new_item.replace("in_layers.2", "conv1")
65
+
66
+ new_item = new_item.replace("out_layers.0", "norm2")
67
+ new_item = new_item.replace("out_layers.3", "conv2")
68
+
69
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
70
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
71
+
72
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
73
+
74
+ mapping.append({"old": old_item, "new": new_item})
75
+
76
+ return mapping
77
+
78
+
79
+ # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_resnet_paths
80
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
81
+ """
82
+ Updates paths inside resnets to the new naming scheme (local renaming)
83
+ """
84
+ mapping = []
85
+ for old_item in old_list:
86
+ new_item = old_item
87
+
88
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
89
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
90
+
91
+ mapping.append({"old": old_item, "new": new_item})
92
+
93
+ return mapping
94
+
95
+
96
+ # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_attention_paths
97
+ def renew_attention_paths(old_list):
98
+ """
99
+ Updates paths inside attentions to the new naming scheme (local renaming)
100
+ """
101
+ mapping = []
102
+ for old_item in old_list:
103
+ new_item = old_item
104
+
105
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
106
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
107
+
108
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
109
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
110
+
111
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
112
+
113
+ mapping.append({"old": old_item, "new": new_item})
114
+
115
+ return mapping
116
+
117
+
118
+ # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_attention_paths
119
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
120
+ """
121
+ Updates paths inside attentions to the new naming scheme (local renaming)
122
+ """
123
+ mapping = []
124
+ for old_item in old_list:
125
+ new_item = old_item
126
+
127
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
128
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
129
+
130
+ new_item = new_item.replace("q.weight", "query.weight")
131
+ new_item = new_item.replace("q.bias", "query.bias")
132
+
133
+ new_item = new_item.replace("k.weight", "key.weight")
134
+ new_item = new_item.replace("k.bias", "key.bias")
135
+
136
+ new_item = new_item.replace("v.weight", "value.weight")
137
+ new_item = new_item.replace("v.bias", "value.bias")
138
+
139
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
140
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
141
+
142
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
143
+
144
+ mapping.append({"old": old_item, "new": new_item})
145
+
146
+ return mapping
147
+
148
+
149
+ # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint
150
+ def assign_to_checkpoint(
151
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
152
+ ):
153
+ """
154
+ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
155
+ attention layers, and takes into account additional replacements that may arise.
156
+
157
+ Assigns the weights to the new checkpoint.
158
+ """
159
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
160
+
161
+ # Splits the attention layers into three variables.
162
+ if attention_paths_to_split is not None:
163
+ for path, path_map in attention_paths_to_split.items():
164
+ old_tensor = old_checkpoint[path]
165
+ channels = old_tensor.shape[0] // 3
166
+
167
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
168
+
169
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
170
+
171
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
172
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
173
+
174
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
175
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
176
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
177
+
178
+ for path in paths:
179
+ new_path = path["new"]
180
+
181
+ # These have already been assigned
182
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
183
+ continue
184
+
185
+ # Global renaming happens here
186
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
187
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
188
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
189
+
190
+ if additional_replacements is not None:
191
+ for replacement in additional_replacements:
192
+ new_path = new_path.replace(replacement["old"], replacement["new"])
193
+
194
+ # proj_attn.weight has to be converted from conv 1D to linear
195
+ if "proj_attn.weight" in new_path:
196
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
197
+ else:
198
+ checkpoint[new_path] = old_checkpoint[path["old"]]
199
+
200
+
201
+ # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
202
+ def conv_attn_to_linear(checkpoint):
203
+ keys = list(checkpoint.keys())
204
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
205
+ for key in keys:
206
+ if ".".join(key.split(".")[-2:]) in attn_keys:
207
+ if checkpoint[key].ndim > 2:
208
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
209
+ elif "proj_attn.weight" in key:
210
+ if checkpoint[key].ndim > 2:
211
+ checkpoint[key] = checkpoint[key][:, :, 0]
212
+
213
+
214
+ def create_unet_diffusers_config(original_config, image_size: int):
215
+ """
216
+ Creates a UNet config for diffusers based on the config of the original AudioLDM model.
217
+ """
218
+ unet_params = original_config.model.params.unet_config.params
219
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
220
+
221
+ block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
222
+
223
+ down_block_types = []
224
+ resolution = 1
225
+ for i in range(len(block_out_channels)):
226
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
227
+ down_block_types.append(block_type)
228
+ if i != len(block_out_channels) - 1:
229
+ resolution *= 2
230
+
231
+ up_block_types = []
232
+ for i in range(len(block_out_channels)):
233
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
234
+ up_block_types.append(block_type)
235
+ resolution //= 2
236
+
237
+ vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
238
+
239
+ cross_attention_dim = (
240
+ unet_params.cross_attention_dim if "cross_attention_dim" in unet_params else block_out_channels
241
+ )
242
+
243
+ class_embed_type = "simple_projection" if "extra_film_condition_dim" in unet_params else None
244
+ projection_class_embeddings_input_dim = (
245
+ unet_params.extra_film_condition_dim if "extra_film_condition_dim" in unet_params else None
246
+ )
247
+ class_embeddings_concat = unet_params.extra_film_use_concat if "extra_film_use_concat" in unet_params else None
248
+
249
+ config = {
250
+ "sample_size": image_size // vae_scale_factor,
251
+ "in_channels": unet_params.in_channels,
252
+ "out_channels": unet_params.out_channels,
253
+ "down_block_types": tuple(down_block_types),
254
+ "up_block_types": tuple(up_block_types),
255
+ "block_out_channels": tuple(block_out_channels),
256
+ "layers_per_block": unet_params.num_res_blocks,
257
+ "cross_attention_dim": cross_attention_dim,
258
+ "class_embed_type": class_embed_type,
259
+ "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
260
+ "class_embeddings_concat": class_embeddings_concat,
261
+ }
262
+
263
+ return config
264
+
265
+
266
+ # Adapted from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_vae_diffusers_config
267
+ def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
268
+ """
269
+ Creates a VAE config for diffusers based on the config of the original AudioLDM model. Compared to the original
270
+ Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE.
271
+ """
272
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
273
+ _ = original_config.model.params.first_stage_config.params.embed_dim
274
+
275
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
276
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
277
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
278
+
279
+ scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config.model.params else 0.18215
280
+
281
+ config = {
282
+ "sample_size": image_size,
283
+ "in_channels": vae_params.in_channels,
284
+ "out_channels": vae_params.out_ch,
285
+ "down_block_types": tuple(down_block_types),
286
+ "up_block_types": tuple(up_block_types),
287
+ "block_out_channels": tuple(block_out_channels),
288
+ "latent_channels": vae_params.z_channels,
289
+ "layers_per_block": vae_params.num_res_blocks,
290
+ "scaling_factor": float(scaling_factor),
291
+ }
292
+ return config
293
+
294
+
295
+ # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular
296
+ def create_diffusers_schedular(original_config):
297
+ schedular = DDIMScheduler(
298
+ num_train_timesteps=original_config.model.params.timesteps,
299
+ beta_start=original_config.model.params.linear_start,
300
+ beta_end=original_config.model.params.linear_end,
301
+ beta_schedule="scaled_linear",
302
+ )
303
+ return schedular
304
+
305
+
306
+ # Adapted from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_unet_checkpoint
307
+ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
308
+ """
309
+ Takes a state dict and a config, and returns a converted checkpoint. Compared to the original Stable Diffusion
310
+ conversion, this function additionally converts the learnt film embedding linear layer.
311
+ """
312
+
313
+ # extract state_dict for UNet
314
+ unet_state_dict = {}
315
+ keys = list(checkpoint.keys())
316
+
317
+ unet_key = "model.diffusion_model."
318
+ # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
319
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
320
+ print(f"Checkpoint {path} has both EMA and non-EMA weights.")
321
+ print(
322
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
323
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
324
+ )
325
+ for key in keys:
326
+ if key.startswith("model.diffusion_model"):
327
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
328
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
329
+ else:
330
+ if sum(k.startswith("model_ema") for k in keys) > 100:
331
+ print(
332
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
333
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
334
+ )
335
+
336
+ for key in keys:
337
+ if key.startswith(unet_key):
338
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
339
+
340
+ new_checkpoint = {}
341
+
342
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
343
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
344
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
345
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
346
+
347
+ new_checkpoint["class_embedding.weight"] = unet_state_dict["film_emb.weight"]
348
+ new_checkpoint["class_embedding.bias"] = unet_state_dict["film_emb.bias"]
349
+
350
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
351
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
352
+
353
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
354
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
355
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
356
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
357
+
358
+ # Retrieves the keys for the input blocks only
359
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
360
+ input_blocks = {
361
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
362
+ for layer_id in range(num_input_blocks)
363
+ }
364
+
365
+ # Retrieves the keys for the middle blocks only
366
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
367
+ middle_blocks = {
368
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
369
+ for layer_id in range(num_middle_blocks)
370
+ }
371
+
372
+ # Retrieves the keys for the output blocks only
373
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
374
+ output_blocks = {
375
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
376
+ for layer_id in range(num_output_blocks)
377
+ }
378
+
379
+ for i in range(1, num_input_blocks):
380
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
381
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
382
+
383
+ resnets = [
384
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
385
+ ]
386
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
387
+
388
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
389
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
390
+ f"input_blocks.{i}.0.op.weight"
391
+ )
392
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
393
+ f"input_blocks.{i}.0.op.bias"
394
+ )
395
+
396
+ paths = renew_resnet_paths(resnets)
397
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
398
+ assign_to_checkpoint(
399
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
400
+ )
401
+
402
+ if len(attentions):
403
+ paths = renew_attention_paths(attentions)
404
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
405
+ assign_to_checkpoint(
406
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
407
+ )
408
+
409
+ resnet_0 = middle_blocks[0]
410
+ attentions = middle_blocks[1]
411
+ resnet_1 = middle_blocks[2]
412
+
413
+ resnet_0_paths = renew_resnet_paths(resnet_0)
414
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
415
+
416
+ resnet_1_paths = renew_resnet_paths(resnet_1)
417
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
418
+
419
+ attentions_paths = renew_attention_paths(attentions)
420
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
421
+ assign_to_checkpoint(
422
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
423
+ )
424
+
425
+ for i in range(num_output_blocks):
426
+ block_id = i // (config["layers_per_block"] + 1)
427
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
428
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
429
+ output_block_list = {}
430
+
431
+ for layer in output_block_layers:
432
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
433
+ if layer_id in output_block_list:
434
+ output_block_list[layer_id].append(layer_name)
435
+ else:
436
+ output_block_list[layer_id] = [layer_name]
437
+
438
+ if len(output_block_list) > 1:
439
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
440
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
441
+
442
+ resnet_0_paths = renew_resnet_paths(resnets)
443
+ paths = renew_resnet_paths(resnets)
444
+
445
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
446
+ assign_to_checkpoint(
447
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
448
+ )
449
+
450
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
451
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
452
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
453
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
454
+ f"output_blocks.{i}.{index}.conv.weight"
455
+ ]
456
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
457
+ f"output_blocks.{i}.{index}.conv.bias"
458
+ ]
459
+
460
+ # Clear attentions as they have been attributed above.
461
+ if len(attentions) == 2:
462
+ attentions = []
463
+
464
+ if len(attentions):
465
+ paths = renew_attention_paths(attentions)
466
+ meta_path = {
467
+ "old": f"output_blocks.{i}.1",
468
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
469
+ }
470
+ assign_to_checkpoint(
471
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
472
+ )
473
+ else:
474
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
475
+ for path in resnet_0_paths:
476
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
477
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
478
+
479
+ new_checkpoint[new_path] = unet_state_dict[old_path]
480
+
481
+ return new_checkpoint
482
+
483
+
484
+ # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_vae_checkpoint
485
+ def convert_ldm_vae_checkpoint(checkpoint, config):
486
+ # extract state dict for VAE
487
+ vae_state_dict = {}
488
+ vae_key = "first_stage_model."
489
+ keys = list(checkpoint.keys())
490
+ for key in keys:
491
+ if key.startswith(vae_key):
492
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
493
+
494
+ new_checkpoint = {}
495
+
496
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
497
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
498
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
499
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
500
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
501
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
502
+
503
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
504
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
505
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
506
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
507
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
508
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
509
+
510
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
511
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
512
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
513
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
514
+
515
+ # Retrieves the keys for the encoder down blocks only
516
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
517
+ down_blocks = {
518
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
519
+ }
520
+
521
+ # Retrieves the keys for the decoder up blocks only
522
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
523
+ up_blocks = {
524
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
525
+ }
526
+
527
+ for i in range(num_down_blocks):
528
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
529
+
530
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
531
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
532
+ f"encoder.down.{i}.downsample.conv.weight"
533
+ )
534
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
535
+ f"encoder.down.{i}.downsample.conv.bias"
536
+ )
537
+
538
+ paths = renew_vae_resnet_paths(resnets)
539
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
540
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
541
+
542
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
543
+ num_mid_res_blocks = 2
544
+ for i in range(1, num_mid_res_blocks + 1):
545
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
546
+
547
+ paths = renew_vae_resnet_paths(resnets)
548
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
549
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
550
+
551
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
552
+ paths = renew_vae_attention_paths(mid_attentions)
553
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
554
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
555
+ conv_attn_to_linear(new_checkpoint)
556
+
557
+ for i in range(num_up_blocks):
558
+ block_id = num_up_blocks - 1 - i
559
+ resnets = [
560
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
561
+ ]
562
+
563
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
564
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
565
+ f"decoder.up.{block_id}.upsample.conv.weight"
566
+ ]
567
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
568
+ f"decoder.up.{block_id}.upsample.conv.bias"
569
+ ]
570
+
571
+ paths = renew_vae_resnet_paths(resnets)
572
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
573
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
574
+
575
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
576
+ num_mid_res_blocks = 2
577
+ for i in range(1, num_mid_res_blocks + 1):
578
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
579
+
580
+ paths = renew_vae_resnet_paths(resnets)
581
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
582
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
583
+
584
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
585
+ paths = renew_vae_attention_paths(mid_attentions)
586
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
587
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
588
+ conv_attn_to_linear(new_checkpoint)
589
+ return new_checkpoint
590
+
591
+
592
+ CLAP_KEYS_TO_MODIFY_MAPPING = {
593
+ "text_branch": "text_model",
594
+ "attn": "attention.self",
595
+ "self.proj": "output.dense",
596
+ "attention.self_mask": "attn_mask",
597
+ "mlp.fc1": "intermediate.dense",
598
+ "mlp.fc2": "output.dense",
599
+ "norm1": "layernorm_before",
600
+ "norm2": "layernorm_after",
601
+ "bn0": "batch_norm",
602
+ }
603
+
604
+ CLAP_KEYS_TO_IGNORE = ["text_transform"]
605
+
606
+ CLAP_EXPECTED_MISSING_KEYS = ["text_model.embeddings.token_type_ids"]
607
+
608
+
609
+ def convert_open_clap_checkpoint(checkpoint):
610
+ """
611
+ Takes a state dict and returns a converted CLAP checkpoint.
612
+ """
613
+ # extract state dict for CLAP text embedding model, discarding the audio component
614
+ model_state_dict = {}
615
+ model_key = "cond_stage_model.model.text_"
616
+ keys = list(checkpoint.keys())
617
+ for key in keys:
618
+ if key.startswith(model_key):
619
+ model_state_dict[key.replace(model_key, "text_")] = checkpoint.get(key)
620
+
621
+ new_checkpoint = {}
622
+
623
+ sequential_layers_pattern = r".*sequential.(\d+).*"
624
+ text_projection_pattern = r".*_projection.(\d+).*"
625
+
626
+ for key, value in model_state_dict.items():
627
+ # check if key should be ignored in mapping
628
+ if key.split(".")[0] in CLAP_KEYS_TO_IGNORE:
629
+ continue
630
+
631
+ # check if any key needs to be modified
632
+ for key_to_modify, new_key in CLAP_KEYS_TO_MODIFY_MAPPING.items():
633
+ if key_to_modify in key:
634
+ key = key.replace(key_to_modify, new_key)
635
+
636
+ if re.match(sequential_layers_pattern, key):
637
+ # replace sequential layers with list
638
+ sequential_layer = re.match(sequential_layers_pattern, key).group(1)
639
+
640
+ key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
641
+ elif re.match(text_projection_pattern, key):
642
+ projecton_layer = int(re.match(text_projection_pattern, key).group(1))
643
+
644
+ # Because in CLAP they use `nn.Sequential`...
645
+ transformers_projection_layer = 1 if projecton_layer == 0 else 2
646
+
647
+ key = key.replace(f"_projection.{projecton_layer}.", f"_projection.linear{transformers_projection_layer}.")
648
+
649
+ if "audio" and "qkv" in key:
650
+ # split qkv into query key and value
651
+ mixed_qkv = value
652
+ qkv_dim = mixed_qkv.size(0) // 3
653
+
654
+ query_layer = mixed_qkv[:qkv_dim]
655
+ key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
656
+ value_layer = mixed_qkv[qkv_dim * 2 :]
657
+
658
+ new_checkpoint[key.replace("qkv", "query")] = query_layer
659
+ new_checkpoint[key.replace("qkv", "key")] = key_layer
660
+ new_checkpoint[key.replace("qkv", "value")] = value_layer
661
+ else:
662
+ new_checkpoint[key] = value
663
+
664
+ return new_checkpoint
665
+
666
+
667
+ def create_transformers_vocoder_config(original_config):
668
+ """
669
+ Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model.
670
+ """
671
+ vocoder_params = original_config.model.params.vocoder_config.params
672
+
673
+ config = {
674
+ "model_in_dim": vocoder_params.num_mels,
675
+ "sampling_rate": vocoder_params.sampling_rate,
676
+ "upsample_initial_channel": vocoder_params.upsample_initial_channel,
677
+ "upsample_rates": list(vocoder_params.upsample_rates),
678
+ "upsample_kernel_sizes": list(vocoder_params.upsample_kernel_sizes),
679
+ "resblock_kernel_sizes": list(vocoder_params.resblock_kernel_sizes),
680
+ "resblock_dilation_sizes": [
681
+ list(resblock_dilation) for resblock_dilation in vocoder_params.resblock_dilation_sizes
682
+ ],
683
+ "normalize_before": False,
684
+ }
685
+
686
+ return config
687
+
688
+
689
+ def convert_hifigan_checkpoint(checkpoint, config):
690
+ """
691
+ Takes a state dict and config, and returns a converted HiFiGAN vocoder checkpoint.
692
+ """
693
+ # extract state dict for vocoder
694
+ vocoder_state_dict = {}
695
+ vocoder_key = "first_stage_model.vocoder."
696
+ keys = list(checkpoint.keys())
697
+ for key in keys:
698
+ if key.startswith(vocoder_key):
699
+ vocoder_state_dict[key.replace(vocoder_key, "")] = checkpoint.get(key)
700
+
701
+ # fix upsampler keys, everything else is correct already
702
+ for i in range(len(config.upsample_rates)):
703
+ vocoder_state_dict[f"upsampler.{i}.weight"] = vocoder_state_dict.pop(f"ups.{i}.weight")
704
+ vocoder_state_dict[f"upsampler.{i}.bias"] = vocoder_state_dict.pop(f"ups.{i}.bias")
705
+
706
+ if not config.normalize_before:
707
+ # if we don't set normalize_before then these variables are unused, so we set them to their initialised values
708
+ vocoder_state_dict["mean"] = torch.zeros(config.model_in_dim)
709
+ vocoder_state_dict["scale"] = torch.ones(config.model_in_dim)
710
+
711
+ return vocoder_state_dict
712
+
713
+
714
+ # Adapted from https://huggingface.co/spaces/haoheliu/audioldm-text-to-audio-generation/blob/84a0384742a22bd80c44e903e241f0623e874f1d/audioldm/utils.py#L72-L73
715
+ DEFAULT_CONFIG = {
716
+ "model": {
717
+ "params": {
718
+ "linear_start": 0.0015,
719
+ "linear_end": 0.0195,
720
+ "timesteps": 1000,
721
+ "channels": 8,
722
+ "scale_by_std": True,
723
+ "unet_config": {
724
+ "target": "audioldm.latent_diffusion.openaimodel.UNetModel",
725
+ "params": {
726
+ "extra_film_condition_dim": 512,
727
+ "extra_film_use_concat": True,
728
+ "in_channels": 8,
729
+ "out_channels": 8,
730
+ "model_channels": 256,
731
+ "attention_resolutions": [8, 4, 2],
732
+ "num_res_blocks": 2,
733
+ "channel_mult": [1, 2, 3, 5],
734
+ "num_head_channels": 64,
735
+ },
736
+ },
737
+ "first_stage_config": {
738
+ "target": "audioldm.variational_autoencoder.autoencoder.AutoencoderKL",
739
+ "params": {
740
+ "embed_dim": 8,
741
+ "ddconfig": {
742
+ "z_channels": 8,
743
+ "resolution": 256,
744
+ "in_channels": 1,
745
+ "out_ch": 1,
746
+ "ch": 128,
747
+ "ch_mult": [1, 2, 4],
748
+ "num_res_blocks": 2,
749
+ },
750
+ },
751
+ },
752
+ "vocoder_config": {
753
+ "target": "audioldm.first_stage_model.vocoder",
754
+ "params": {
755
+ "upsample_rates": [5, 4, 2, 2, 2],
756
+ "upsample_kernel_sizes": [16, 16, 8, 4, 4],
757
+ "upsample_initial_channel": 1024,
758
+ "resblock_kernel_sizes": [3, 7, 11],
759
+ "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
760
+ "num_mels": 64,
761
+ "sampling_rate": 16000,
762
+ },
763
+ },
764
+ },
765
+ },
766
+ }
767
+
768
+
769
+ def load_pipeline_from_original_audioldm_ckpt(
770
+ checkpoint_path: str,
771
+ original_config_file: str = None,
772
+ image_size: int = 512,
773
+ prediction_type: str = None,
774
+ extract_ema: bool = False,
775
+ scheduler_type: str = "ddim",
776
+ num_in_channels: int = None,
777
+ device: str = None,
778
+ from_safetensors: bool = False,
779
+ ) -> AudioLDMPipeline:
780
+ """
781
+ Load an AudioLDM pipeline object from a `.ckpt`/`.safetensors` file and (ideally) a `.yaml` config file.
782
+
783
+ Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the
784
+ global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
785
+ recommended that you override the default values and/or supply an `original_config_file` wherever possible.
786
+
787
+ :param checkpoint_path: Path to `.ckpt` file. :param original_config_file: Path to `.yaml` config file
788
+ corresponding to the original architecture.
789
+ If `None`, will be automatically instantiated based on default values.
790
+ :param image_size: The image size that the model was trained on. Use 512 for original AudioLDM checkpoints. :param
791
+ prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for original
792
+ AudioLDM checkpoints.
793
+ :param num_in_channels: The number of input channels. If `None` number of input channels will be automatically
794
+ inferred.
795
+ :param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
796
+ "euler-ancestral", "dpm", "ddim"]`.
797
+ :param extract_ema: Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract
798
+ the EMA weights or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually
799
+ yield higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
800
+ :param device: The device to use. Pass `None` to determine automatically. :param from_safetensors: If
801
+ `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors
802
+ instead of PyTorch.
803
+ :return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
804
+ """
805
+
806
+ if not is_omegaconf_available():
807
+ raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
808
+
809
+ from omegaconf import OmegaConf
810
+
811
+ if from_safetensors:
812
+ if not is_safetensors_available():
813
+ raise ValueError(BACKENDS_MAPPING["safetensors"][1])
814
+
815
+ from safetensors import safe_open
816
+
817
+ checkpoint = {}
818
+ with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
819
+ for key in f.keys():
820
+ checkpoint[key] = f.get_tensor(key)
821
+ else:
822
+ if device is None:
823
+ device = "cuda" if torch.cuda.is_available() else "cpu"
824
+ checkpoint = torch.load(checkpoint_path, map_location=device)
825
+ else:
826
+ checkpoint = torch.load(checkpoint_path, map_location=device)
827
+
828
+ if "state_dict" in checkpoint:
829
+ checkpoint = checkpoint["state_dict"]
830
+
831
+ if original_config_file is None:
832
+ original_config = DEFAULT_CONFIG
833
+ original_config = OmegaConf.create(original_config)
834
+ else:
835
+ original_config = OmegaConf.load(original_config_file)
836
+
837
+ if num_in_channels is not None:
838
+ original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
839
+
840
+ if (
841
+ "parameterization" in original_config["model"]["params"]
842
+ and original_config["model"]["params"]["parameterization"] == "v"
843
+ ):
844
+ if prediction_type is None:
845
+ prediction_type = "v_prediction"
846
+ else:
847
+ if prediction_type is None:
848
+ prediction_type = "epsilon"
849
+
850
+ if image_size is None:
851
+ image_size = 512
852
+
853
+ num_train_timesteps = original_config.model.params.timesteps
854
+ beta_start = original_config.model.params.linear_start
855
+ beta_end = original_config.model.params.linear_end
856
+
857
+ scheduler = DDIMScheduler(
858
+ beta_end=beta_end,
859
+ beta_schedule="scaled_linear",
860
+ beta_start=beta_start,
861
+ num_train_timesteps=num_train_timesteps,
862
+ steps_offset=1,
863
+ clip_sample=False,
864
+ set_alpha_to_one=False,
865
+ prediction_type=prediction_type,
866
+ )
867
+ # make sure scheduler works correctly with DDIM
868
+ scheduler.register_to_config(clip_sample=False)
869
+
870
+ if scheduler_type == "pndm":
871
+ config = dict(scheduler.config)
872
+ config["skip_prk_steps"] = True
873
+ scheduler = PNDMScheduler.from_config(config)
874
+ elif scheduler_type == "lms":
875
+ scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
876
+ elif scheduler_type == "heun":
877
+ scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
878
+ elif scheduler_type == "euler":
879
+ scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
880
+ elif scheduler_type == "euler-ancestral":
881
+ scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
882
+ elif scheduler_type == "dpm":
883
+ scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
884
+ elif scheduler_type == "ddim":
885
+ scheduler = scheduler
886
+ else:
887
+ raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
888
+
889
+ # Convert the UNet2DModel
890
+ unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
891
+ unet = UNet2DConditionModel(**unet_config)
892
+
893
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(
894
+ checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
895
+ )
896
+
897
+ unet.load_state_dict(converted_unet_checkpoint)
898
+
899
+ # Convert the VAE model
900
+ vae_config = create_vae_diffusers_config(original_config, checkpoint=checkpoint, image_size=image_size)
901
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
902
+
903
+ vae = AutoencoderKL(**vae_config)
904
+ vae.load_state_dict(converted_vae_checkpoint)
905
+
906
+ # Convert the text model
907
+ # AudioLDM uses the same configuration and tokenizer as the original CLAP model
908
+ config = ClapTextConfig.from_pretrained("laion/clap-htsat-unfused")
909
+ tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
910
+
911
+ converted_text_model = convert_open_clap_checkpoint(checkpoint)
912
+ text_model = ClapTextModelWithProjection(config)
913
+
914
+ missing_keys, unexpected_keys = text_model.load_state_dict(converted_text_model, strict=False)
915
+ # we expect not to have token_type_ids in our original state dict so let's ignore them
916
+ missing_keys = list(set(missing_keys) - set(CLAP_EXPECTED_MISSING_KEYS))
917
+
918
+ if len(unexpected_keys) > 0:
919
+ raise ValueError(f"Unexpected keys when loading CLAP model: {unexpected_keys}")
920
+
921
+ if len(missing_keys) > 0:
922
+ raise ValueError(f"Missing keys when loading CLAP model: {missing_keys}")
923
+
924
+ # Convert the vocoder model
925
+ vocoder_config = create_transformers_vocoder_config(original_config)
926
+ vocoder_config = SpeechT5HifiGanConfig(**vocoder_config)
927
+ converted_vocoder_checkpoint = convert_hifigan_checkpoint(checkpoint, vocoder_config)
928
+
929
+ vocoder = SpeechT5HifiGan(vocoder_config)
930
+ vocoder.load_state_dict(converted_vocoder_checkpoint)
931
+
932
+ # Instantiate the diffusers pipeline
933
+ pipe = AudioLDMPipeline(
934
+ vae=vae,
935
+ text_encoder=text_model,
936
+ tokenizer=tokenizer,
937
+ unet=unet,
938
+ scheduler=scheduler,
939
+ vocoder=vocoder,
940
+ )
941
+
942
+ return pipe
943
+
944
+
945
+ if __name__ == "__main__":
946
+ parser = argparse.ArgumentParser()
947
+
948
+ parser.add_argument(
949
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
950
+ )
951
+ parser.add_argument(
952
+ "--original_config_file",
953
+ default=None,
954
+ type=str,
955
+ help="The YAML config file corresponding to the original architecture.",
956
+ )
957
+ parser.add_argument(
958
+ "--num_in_channels",
959
+ default=None,
960
+ type=int,
961
+ help="The number of input channels. If `None` number of input channels will be automatically inferred.",
962
+ )
963
+ parser.add_argument(
964
+ "--scheduler_type",
965
+ default="ddim",
966
+ type=str,
967
+ help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']",
968
+ )
969
+ parser.add_argument(
970
+ "--image_size",
971
+ default=None,
972
+ type=int,
973
+ help=("The image size that the model was trained on."),
974
+ )
975
+ parser.add_argument(
976
+ "--prediction_type",
977
+ default=None,
978
+ type=str,
979
+ help=("The prediction type that the model was trained on."),
980
+ )
981
+ parser.add_argument(
982
+ "--extract_ema",
983
+ action="store_true",
984
+ help=(
985
+ "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
986
+ " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
987
+ " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
988
+ ),
989
+ )
990
+ parser.add_argument(
991
+ "--from_safetensors",
992
+ action="store_true",
993
+ help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
994
+ )
995
+ parser.add_argument(
996
+ "--to_safetensors",
997
+ action="store_true",
998
+ help="Whether to store pipeline in safetensors format or not.",
999
+ )
1000
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
1001
+ parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
1002
+ args = parser.parse_args()
1003
+
1004
+ pipe = load_pipeline_from_original_audioldm_ckpt(
1005
+ checkpoint_path=args.checkpoint_path,
1006
+ original_config_file=args.original_config_file,
1007
+ image_size=args.image_size,
1008
+ prediction_type=args.prediction_type,
1009
+ extract_ema=args.extract_ema,
1010
+ scheduler_type=args.scheduler_type,
1011
+ num_in_channels=args.num_in_channels,
1012
+ from_safetensors=args.from_safetensors,
1013
+ device=args.device,
1014
+ )
1015
+ pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
model_index.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AudioLDMPipeline",
3
+ "_diffusers_version": "0.15.0.dev0",
4
+ "scheduler": [
5
+ "diffusers",
6
+ "DDIMScheduler"
7
+ ],
8
+ "text_encoder": [
9
+ "transformers",
10
+ "ClapTextModelWithProjection"
11
+ ],
12
+ "tokenizer": [
13
+ "transformers",
14
+ "RobertaTokenizerFast"
15
+ ],
16
+ "unet": [
17
+ "diffusers",
18
+ "UNet2DConditionModel"
19
+ ],
20
+ "vae": [
21
+ "diffusers",
22
+ "AutoencoderKL"
23
+ ],
24
+ "vocoder": [
25
+ "transformers",
26
+ "SpeechT5HifiGan"
27
+ ]
28
+ }
run_conversion.sh ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ python convert_original_audioldm_to_diffusers.py \
4
+ --checkpoint_path "/home/sanchit_huggingface_co/.cache/audioldm/audioldm-l-full.ckpt" \
5
+ --extract_ema \
6
+ --dump_path "./"
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DDIMScheduler",
3
+ "_diffusers_version": "0.15.0.dev0",
4
+ "beta_end": 0.0195,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.0015,
7
+ "clip_sample": false,
8
+ "clip_sample_range": 1.0,
9
+ "dynamic_thresholding_ratio": 0.995,
10
+ "num_train_timesteps": 1000,
11
+ "prediction_type": "epsilon",
12
+ "sample_max_value": 1.0,
13
+ "set_alpha_to_one": false,
14
+ "steps_offset": 1,
15
+ "thresholding": false,
16
+ "trained_betas": null
17
+ }
text_encoder/config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ClapTextModelWithProjection"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "bos_token_id": 0,
7
+ "classifier_dropout": null,
8
+ "eos_token_id": 2,
9
+ "fusion_hidden_size": 768,
10
+ "fusion_num_hidden_layers": 2,
11
+ "hidden_act": "gelu",
12
+ "hidden_dropout_prob": 0.1,
13
+ "hidden_size": 768,
14
+ "initializer_factor": 1.0,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": 3072,
17
+ "layer_norm_eps": 1e-12,
18
+ "max_position_embeddings": 514,
19
+ "model_type": "clap_text_model",
20
+ "num_attention_heads": 12,
21
+ "num_hidden_layers": 12,
22
+ "pad_token_id": 1,
23
+ "position_embedding_type": "absolute",
24
+ "projection_dim": 512,
25
+ "projection_hidden_act": "relu",
26
+ "projection_hidden_size": 768,
27
+ "torch_dtype": "float32",
28
+ "transformers_version": "4.28.0.dev0",
29
+ "type_vocab_size": 1,
30
+ "use_cache": true,
31
+ "vocab_size": 50265
32
+ }
text_encoder/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0292e65e99024f2d21f0a59b718f0e4914546e4287e2dabdd2d7e4a95c169f7b
3
+ size 501284353
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<s>",
3
+ "cls_token": "<s>",
4
+ "eos_token": "</s>",
5
+ "mask_token": {
6
+ "content": "<mask>",
7
+ "lstrip": true,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "pad_token": "<pad>",
13
+ "sep_token": "</s>",
14
+ "unk_token": "<unk>"
15
+ }
tokenizer/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": "<s>",
4
+ "clean_up_tokenization_spaces": true,
5
+ "cls_token": "<s>",
6
+ "eos_token": "</s>",
7
+ "errors": "replace",
8
+ "mask_token": "<mask>",
9
+ "model_max_length": 512,
10
+ "pad_token": "<pad>",
11
+ "processor_class": "ClapProcessor",
12
+ "sep_token": "</s>",
13
+ "special_tokens_map_file": null,
14
+ "tokenizer_class": "RobertaTokenizer",
15
+ "trim_offsets": true,
16
+ "unk_token": "<unk>"
17
+ }
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
unet/config.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "UNet2DConditionModel",
3
+ "_diffusers_version": "0.15.0.dev0",
4
+ "act_fn": "silu",
5
+ "attention_head_dim": 8,
6
+ "block_out_channels": [
7
+ 256,
8
+ 512,
9
+ 768,
10
+ 1280
11
+ ],
12
+ "center_input_sample": false,
13
+ "class_embed_type": "simple_projection",
14
+ "class_embeddings_concat": true,
15
+ "conv_in_kernel": 3,
16
+ "conv_out_kernel": 3,
17
+ "cross_attention_dim": [
18
+ 256,
19
+ 512,
20
+ 768,
21
+ 1280
22
+ ],
23
+ "down_block_types": [
24
+ "DownBlock2D",
25
+ "CrossAttnDownBlock2D",
26
+ "CrossAttnDownBlock2D",
27
+ "CrossAttnDownBlock2D"
28
+ ],
29
+ "downsample_padding": 1,
30
+ "dual_cross_attention": false,
31
+ "flip_sin_to_cos": true,
32
+ "freq_shift": 0,
33
+ "in_channels": 8,
34
+ "layers_per_block": 2,
35
+ "mid_block_scale_factor": 1,
36
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
37
+ "norm_eps": 1e-05,
38
+ "norm_num_groups": 32,
39
+ "num_class_embeds": null,
40
+ "only_cross_attention": false,
41
+ "out_channels": 8,
42
+ "projection_class_embeddings_input_dim": 512,
43
+ "resnet_time_scale_shift": "default",
44
+ "sample_size": 128,
45
+ "time_cond_proj_dim": null,
46
+ "time_embedding_type": "positional",
47
+ "timestep_post_act": null,
48
+ "up_block_types": [
49
+ "CrossAttnUpBlock2D",
50
+ "CrossAttnUpBlock2D",
51
+ "CrossAttnUpBlock2D",
52
+ "UpBlock2D"
53
+ ],
54
+ "upcast_attention": false,
55
+ "use_linear_projection": false
56
+ }
unet/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:060bbc49e8afc88664d6871ef0dd465fc2788226a07315ab2d114b5c0ee6d8a5
3
+ size 2956840221
vae/config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.15.0.dev0",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512
9
+ ],
10
+ "down_block_types": [
11
+ "DownEncoderBlock2D",
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D"
14
+ ],
15
+ "in_channels": 1,
16
+ "latent_channels": 8,
17
+ "layers_per_block": 2,
18
+ "norm_num_groups": 32,
19
+ "out_channels": 1,
20
+ "sample_size": 512,
21
+ "scaling_factor": 0.9654927849769592,
22
+ "up_block_types": [
23
+ "UpDecoderBlock2D",
24
+ "UpDecoderBlock2D",
25
+ "UpDecoderBlock2D"
26
+ ]
27
+ }
vae/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e3fc8ccecb1849c8a23cd4f9dd959eb7aaa203cc010386288418dbc551cdaaf7
3
+ size 221586505
vocoder/config.json ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "SpeechT5HifiGan"
4
+ ],
5
+ "initializer_range": 0.01,
6
+ "leaky_relu_slope": 0.1,
7
+ "model_in_dim": 64,
8
+ "model_type": "hifigan",
9
+ "normalize_before": false,
10
+ "resblock_dilation_sizes": [
11
+ [
12
+ 1,
13
+ 3,
14
+ 5
15
+ ],
16
+ [
17
+ 1,
18
+ 3,
19
+ 5
20
+ ],
21
+ [
22
+ 1,
23
+ 3,
24
+ 5
25
+ ]
26
+ ],
27
+ "resblock_kernel_sizes": [
28
+ 3,
29
+ 7,
30
+ 11
31
+ ],
32
+ "sampling_rate": 16000,
33
+ "torch_dtype": "float32",
34
+ "transformers_version": "4.28.0.dev0",
35
+ "upsample_initial_channel": 1024,
36
+ "upsample_kernel_sizes": [
37
+ 16,
38
+ 16,
39
+ 8,
40
+ 4,
41
+ 4
42
+ ],
43
+ "upsample_rates": [
44
+ 5,
45
+ 4,
46
+ 2,
47
+ 2,
48
+ 2
49
+ ]
50
+ }
vocoder/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9fbefc2b31c85d1dabe98e53d09ac88039af411162a7e641040a9c2b5f62364
3
+ size 221120349