teticio committed
Commit 8aa7c27 • 1 Parent(s): d76bdef

convert to hf model

.gitignore CHANGED
@@ -3,9 +3,10 @@ __pycache__
 .ipynb_checkpoints
 data*
 ddpm-ema-audio-*
-flagged/
-build/
+flagged
+build
 audiodiffusion.egg-info
-lightning_logs/
-taming/
-checkpoints/
+lightning_logs
+taming
+checkpoints
+vae_model
README.md CHANGED
@@ -45,20 +45,23 @@ You can play around with some pretrained models on [Google Colab](https://colab.
 ---
 
 ## Generate Mel spectrogram dataset from directory of audio files
+#### Install
+```bash
+pip install .
+```
 #### Training can be run with Mel spectrograms of resolution 64x64 on a single commercial grade GPU (e.g. RTX 2080 Ti). The `hop_length` should be set to 1024 for better results.
 
 ```bash
-python audio_to_images.py \
+python scripts/audio_to_images.py \
   --resolution 64 \
   --hop_length 1024 \
   --input_dir path-to-audio-files \
   --output_dir data-test
 ```
-
 #### Generate dataset of 256x256 Mel spectrograms and push to hub (you will need to be authenticated with `huggingface-cli login`).
 
 ```bash
-python audio_to_images.py \
+python scripts/audio_to_images.py \
   --resolution 256 \
   --input_dir path-to-audio-files \
   --output_dir data-256 \
@@ -66,10 +69,9 @@ python audio_to_images.py \
 ```
 ## Train model
 #### Run training on local machine.
-
 ```bash
-accelerate launch --config_file accelerate_local.yaml \
-train_unconditional.py \
+accelerate launch --config_file config/accelerate_local.yaml \
+scripts/train_unconditional.py \
   --dataset_name data-64 \
   --resolution 64 \
   --hop_length 1024 \
@@ -81,12 +83,10 @@ accelerate launch --config_file accelerate_local.yaml \
   --lr_warmup_steps 500 \
   --mixed_precision no
 ```
-
 #### Run training on local machine with `batch_size` of 2 and `gradient_accumulation_steps` 8 to compensate, so that 256x256 resolution model fits on commercial grade GPU and push to hub.
-
 ```bash
-accelerate launch --config_file accelerate_local.yaml \
-train_unconditional.py \
+accelerate launch --config_file config/accelerate_local.yaml \
+scripts/train_unconditional.py \
   --dataset_name teticio/audio-diffusion-256 \
   --resolution 256 \
   --output_dir latent-audio-diffusion-256 \
@@ -101,12 +101,10 @@ accelerate launch --config_file accelerate_local.yaml \
   --hub_model_id latent-audio-diffusion-256 \
   --hub_token $(cat $HOME/.huggingface/token)
 ```
-
 #### Run training on SageMaker.
-
 ```bash
-accelerate launch --config_file accelerate_sagemaker.yaml \
-strain_unconditional.py \
+accelerate launch --config_file config/accelerate_sagemaker.yaml \
+scripts/train_unconditional.py \
   --dataset_name teticio/audio-diffusion-256 \
   --resolution 256 \
   --output_dir ddpm-ema-audio-256 \
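The README now assumes the package is installed with `pip install .` before the relocated scripts are run. As a minimal sketch of a sanity check, the two imports below are the ones this commit itself uses in `scripts/train_vae.py` and `audiodiffusion/utils.py`; nothing beyond those names is assumed.

```python
# Quick check that `pip install .` exposes the package used by the relocated scripts.
from audiodiffusion.mel import Mel
from audiodiffusion.utils import convert_ldm_to_hf_vae

print(Mel)                    # Mel spectrogram helper imported by the training scripts
print(convert_ldm_to_hf_vae)  # LDM -> diffusers VAE converter added in this commit
```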
audiodiffusion/utils.py ADDED
@@ -0,0 +1,363 @@
+# adapted from https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py
+
+import torch
+from diffusers import AutoencoderKL
+
+
+def shave_segments(path, n_shave_prefix_segments=1):
+    """
+    Removes segments. Positive values shave the first segments, negative shave the last segments.
+    """
+    if n_shave_prefix_segments >= 0:
+        return ".".join(path.split(".")[n_shave_prefix_segments:])
+    else:
+        return ".".join(path.split(".")[:n_shave_prefix_segments])
+
+
+def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
+    """
+    Updates paths inside resnets to the new naming scheme (local renaming)
+    """
+    mapping = []
+    for old_item in old_list:
+        new_item = old_item
+
+        new_item = new_item.replace("nin_shortcut", "conv_shortcut")
+        new_item = shave_segments(
+            new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+        mapping.append({"old": old_item, "new": new_item})
+
+    return mapping
+
+
+def renew_attention_paths(old_list, n_shave_prefix_segments=0):
+    """
+    Updates paths inside attentions to the new naming scheme (local renaming)
+    """
+    mapping = []
+    for old_item in old_list:
+        new_item = old_item
+
+        # new_item = new_item.replace('norm.weight', 'group_norm.weight')
+        # new_item = new_item.replace('norm.bias', 'group_norm.bias')
+
+        # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
+        # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
+
+        # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+        mapping.append({"old": old_item, "new": new_item})
+
+    return mapping
+
+
+def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
+    """
+    Updates paths inside attentions to the new naming scheme (local renaming)
+    """
+    mapping = []
+    for old_item in old_list:
+        new_item = old_item
+
+        new_item = new_item.replace("norm.weight", "group_norm.weight")
+        new_item = new_item.replace("norm.bias", "group_norm.bias")
+
+        new_item = new_item.replace("q.weight", "query.weight")
+        new_item = new_item.replace("q.bias", "query.bias")
+
+        new_item = new_item.replace("k.weight", "key.weight")
+        new_item = new_item.replace("k.bias", "key.bias")
+
+        new_item = new_item.replace("v.weight", "value.weight")
+        new_item = new_item.replace("v.bias", "value.bias")
+
+        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
+        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
+
+        new_item = shave_segments(
+            new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+        mapping.append({"old": old_item, "new": new_item})
+
+    return mapping
+
+
+def assign_to_checkpoint(paths,
+                         checkpoint,
+                         old_checkpoint,
+                         attention_paths_to_split=None,
+                         additional_replacements=None,
+                         config=None):
+    """
+    This does the final conversion step: take locally converted weights and apply a global renaming
+    to them. It splits attention layers, and takes into account additional replacements
+    that may arise.
+
+    Assigns the weights to the new checkpoint.
+    """
+    assert isinstance(
+        paths, list
+    ), "Paths should be a list of dicts containing 'old' and 'new' keys."
+
+    # Splits the attention layers into three variables.
+    if attention_paths_to_split is not None:
+        for path, path_map in attention_paths_to_split.items():
+            old_tensor = old_checkpoint[path]
+            channels = old_tensor.shape[0] // 3
+
+            target_shape = (-1,
+                            channels) if len(old_tensor.shape) == 3 else (-1)
+
+            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
+
+            old_tensor = old_tensor.reshape((num_heads, 3 * channels //
+                                             num_heads) + old_tensor.shape[1:])
+            query, key, value = old_tensor.split(channels // num_heads, dim=1)
+
+            checkpoint[path_map["query"]] = query.reshape(target_shape)
+            checkpoint[path_map["key"]] = key.reshape(target_shape)
+            checkpoint[path_map["value"]] = value.reshape(target_shape)
+
+    for path in paths:
+        new_path = path["new"]
+
+        # These have already been assigned
+        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
+            continue
+
+        # Global renaming happens here
+        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
+        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
+        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
+
+        if additional_replacements is not None:
+            for replacement in additional_replacements:
+                new_path = new_path.replace(replacement["old"],
+                                            replacement["new"])
+
+        # proj_attn.weight has to be converted from conv 1D to linear
+        if "proj_attn.weight" in new_path:
+            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+        else:
+            checkpoint[new_path] = old_checkpoint[path["old"]]
+
+
+def conv_attn_to_linear(checkpoint):
+    keys = list(checkpoint.keys())
+    attn_keys = ["query.weight", "key.weight", "value.weight"]
+    for key in keys:
+        if ".".join(key.split(".")[-2:]) in attn_keys:
+            if checkpoint[key].ndim > 2:
+                checkpoint[key] = checkpoint[key][:, :, 0, 0]
+        elif "proj_attn.weight" in key:
+            if checkpoint[key].ndim > 2:
+                checkpoint[key] = checkpoint[key][:, :, 0]
+
+
+def create_vae_diffusers_config(original_config):
+    """
+    Creates a config for the diffusers based on the config of the LDM model.
+    """
+    vae_params = original_config.model.params.ddconfig
+    _ = original_config.model.params.embed_dim
+
+    block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
+    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
+    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
+
+    config = dict(
+        sample_size=vae_params.resolution,
+        in_channels=vae_params.in_channels,
+        out_channels=vae_params.out_ch,
+        down_block_types=tuple(down_block_types),
+        up_block_types=tuple(up_block_types),
+        block_out_channels=tuple(block_out_channels),
+        latent_channels=vae_params.z_channels,
+        layers_per_block=vae_params.num_res_blocks,
+    )
+    return config
+
+
+def convert_ldm_vae_checkpoint(checkpoint, config):
+    # extract state dict for VAE
+    vae_state_dict = checkpoint
+
+    new_checkpoint = {}
+
+    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict[
+        "encoder.conv_in.weight"]
+    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict[
+        "encoder.conv_in.bias"]
+    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict[
+        "encoder.conv_out.weight"]
+    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict[
+        "encoder.conv_out.bias"]
+    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict[
+        "encoder.norm_out.weight"]
+    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict[
+        "encoder.norm_out.bias"]
+
+    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict[
+        "decoder.conv_in.weight"]
+    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict[
+        "decoder.conv_in.bias"]
+    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict[
+        "decoder.conv_out.weight"]
+    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict[
+        "decoder.conv_out.bias"]
+    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict[
+        "decoder.norm_out.weight"]
+    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict[
+        "decoder.norm_out.bias"]
+
+    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
+    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
+    new_checkpoint["post_quant_conv.weight"] = vae_state_dict[
+        "post_quant_conv.weight"]
+    new_checkpoint["post_quant_conv.bias"] = vae_state_dict[
+        "post_quant_conv.bias"]
+
+    # Retrieves the keys for the encoder down blocks only
+    num_down_blocks = len({
+        ".".join(layer.split(".")[:3])
+        for layer in vae_state_dict if "encoder.down" in layer
+    })
+    down_blocks = {
+        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key]
+        for layer_id in range(num_down_blocks)
+    }
+
+    # Retrieves the keys for the decoder up blocks only
+    num_up_blocks = len({
+        ".".join(layer.split(".")[:3])
+        for layer in vae_state_dict if "decoder.up" in layer
+    })
+    up_blocks = {
+        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key]
+        for layer_id in range(num_up_blocks)
+    }
+
+    for i in range(num_down_blocks):
+        resnets = [
+            key for key in down_blocks[i]
+            if f"down.{i}" in key and f"down.{i}.downsample" not in key
+        ]
+
+        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
+            new_checkpoint[
+                f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
+                    f"encoder.down.{i}.downsample.conv.weight")
+            new_checkpoint[
+                f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
+                    f"encoder.down.{i}.downsample.conv.bias")
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {
+            "old": f"down.{i}.block",
+            "new": f"down_blocks.{i}.resnets"
+        }
+        assign_to_checkpoint(paths,
+                             new_checkpoint,
+                             vae_state_dict,
+                             additional_replacements=[meta_path],
+                             config=config)
+
+    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
+    num_mid_res_blocks = 2
+    for i in range(1, num_mid_res_blocks + 1):
+        resnets = [
+            key for key in mid_resnets if f"encoder.mid.block_{i}" in key
+        ]
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {
+            "old": f"mid.block_{i}",
+            "new": f"mid_block.resnets.{i - 1}"
+        }
+        assign_to_checkpoint(paths,
+                             new_checkpoint,
+                             vae_state_dict,
+                             additional_replacements=[meta_path],
+                             config=config)
+
+    mid_attentions = [
+        key for key in vae_state_dict if "encoder.mid.attn" in key
+    ]
+    paths = renew_vae_attention_paths(mid_attentions)
+    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+    assign_to_checkpoint(paths,
+                         new_checkpoint,
+                         vae_state_dict,
+                         additional_replacements=[meta_path],
+                         config=config)
+    conv_attn_to_linear(new_checkpoint)
+
+    for i in range(num_up_blocks):
+        block_id = num_up_blocks - 1 - i
+        resnets = [
+            key for key in up_blocks[block_id]
+            if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+        ]
+
+        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
+            new_checkpoint[
+                f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
+                    f"decoder.up.{block_id}.upsample.conv.weight"]
+            new_checkpoint[
+                f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
+                    f"decoder.up.{block_id}.upsample.conv.bias"]
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {
+            "old": f"up.{block_id}.block",
+            "new": f"up_blocks.{i}.resnets"
+        }
+        assign_to_checkpoint(paths,
+                             new_checkpoint,
+                             vae_state_dict,
+                             additional_replacements=[meta_path],
+                             config=config)
+
+    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
+    num_mid_res_blocks = 2
+    for i in range(1, num_mid_res_blocks + 1):
+        resnets = [
+            key for key in mid_resnets if f"decoder.mid.block_{i}" in key
+        ]
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {
+            "old": f"mid.block_{i}",
+            "new": f"mid_block.resnets.{i - 1}"
+        }
+        assign_to_checkpoint(paths,
+                             new_checkpoint,
+                             vae_state_dict,
+                             additional_replacements=[meta_path],
+                             config=config)
+
+    mid_attentions = [
+        key for key in vae_state_dict if "decoder.mid.attn" in key
+    ]
+    paths = renew_vae_attention_paths(mid_attentions)
+    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+    assign_to_checkpoint(paths,
+                         new_checkpoint,
+                         vae_state_dict,
+                         additional_replacements=[meta_path],
+                         config=config)
+    conv_attn_to_linear(new_checkpoint)
+    return new_checkpoint
+
+
+def convert_ldm_to_hf_vae(ldm_checkpoint, ldm_config, hf_checkpoint):
+    checkpoint = torch.load(ldm_checkpoint)["state_dict"]
+
+    # Convert the VAE model.
+    vae_config = create_vae_diffusers_config(ldm_config)
+    converted_vae_checkpoint = convert_ldm_vae_checkpoint(
+        checkpoint, vae_config)
+
+    vae = AutoencoderKL(**vae_config)
+    vae.load_state_dict(converted_vae_checkpoint)
+    vae.save_pretrained(hf_checkpoint)
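The new `convert_ldm_to_hf_vae` helper is what `scripts/train_vae.py` calls at the end of each epoch (see `HFModelCheckpoint` below). As a minimal sketch of standalone use, assuming a Lightning checkpoint saved under `checkpoints/` with `save_last=True` and the config shipped in this commit (the exact checkpoint filename depends on your `ModelCheckpoint` settings):

```python
from omegaconf import OmegaConf

from audiodiffusion.utils import convert_ldm_to_hf_vae

# Illustrative paths: "checkpoints/last.ckpt" assumes save_last=True was used, and
# "vae_model" mirrors the default output directory of scripts/train_vae.py.
ldm_config = OmegaConf.load("config/ldm_autoencoder_kl.yaml")
convert_ldm_to_hf_vae("checkpoints/last.ckpt", ldm_config, "vae_model")
```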
accelerate_deepspeed.yaml β†’ config/accelerate_deepspeed.yaml RENAMED
File without changes
accelerate_local.yaml β†’ config/accelerate_local.yaml RENAMED
File without changes
accelerate_sagemaker.yaml β†’ config/accelerate_sagemaker.yaml RENAMED
File without changes
ldm_autoencoder_kl.yaml β†’ config/ldm_autoencoder_kl.yaml RENAMED
@@ -27,6 +27,5 @@ model:
 lightning:
   trainer:
     benchmark: True
-    accumulate_grad_batches: 24
     accelerator: gpu
     devices: 1
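Removing `accumulate_grad_batches` from the YAML goes hand in hand with the `--gradient_accumulation_steps` flag added to `scripts/train_vae.py` below: the value is now injected into the trainer config at runtime rather than hard-coded. A minimal sketch of that override pattern, mirroring the script's own OmegaConf handling (the value 24 is just an example):

```python
from omegaconf import OmegaConf

# Load the LDM config and pull out the Lightning trainer section, as the script does.
config = OmegaConf.load("config/ldm_autoencoder_kl.yaml")
lightning_config = config.pop("lightning", OmegaConf.create())
trainer_config = lightning_config.get("trainer", OmegaConf.create())

# The value formerly hard-coded in the YAML now comes from the command line,
# e.g. --gradient_accumulation_steps 24.
trainer_config.accumulate_grad_batches = 24
```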
audio_to_images.py β†’ scripts/audio_to_images.py RENAMED
File without changes
train_unconditional.py β†’ scripts/train_unconditional.py RENAMED
File without changes
train_vae.py β†’ scripts/train_vae.py RENAMED
@@ -4,7 +4,8 @@
 
 # TODO
 # grayscale
-# convert to huggingface / train huggingface
+# add vae to train_uncond (no_grad)
+# update README
 
 import os
 import argparse
@@ -15,21 +16,26 @@ import numpy as np
 from PIL import Image
 import pytorch_lightning as pl
 from omegaconf import OmegaConf
-from datasets import load_dataset
 from librosa.util import normalize
 from ldm.util import instantiate_from_config
 from pytorch_lightning.trainer import Trainer
 from torch.utils.data import DataLoader, Dataset
+from datasets import load_from_disk, load_dataset
 from pytorch_lightning.callbacks import Callback, ModelCheckpoint
+from pytorch_lightning.utilities.distributed import rank_zero_only
 
 from audiodiffusion.mel import Mel
+from audiodiffusion.utils import convert_ldm_to_hf_vae
 
 
 class AudioDiffusion(Dataset):
 
     def __init__(self, model_id):
         super().__init__()
-        self.hf_dataset = load_dataset(model_id)['train']
+        if os.path.exists(model_id):
+            self.hf_dataset = load_from_disk(model_id)['train']
+        else:
+            self.hf_dataset = load_dataset(model_id)['train']
 
     def __len__(self):
         return len(self.hf_dataset)
@@ -65,11 +71,8 @@ class ImageLogger(Callback):
                        hop_length=hop_length)
         self.every = every
 
-    def on_train_batch_end(self, trainer, pl_module, outputs, batch,
-                           batch_idx):
-        if (batch_idx + 1) % self.every != 0:
-            return
-
+    @rank_zero_only
+    def log_images_and_audios(self, pl_module, batch):
         pl_module.eval()
         with torch.no_grad():
             images = pl_module.log_images(batch, split='train')
@@ -96,27 +99,69 @@ class ImageLogger(Callback):
                                  global_step=pl_module.global_step,
                                  sample_rate=self.mel.get_sample_rate())
 
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch,
+                           batch_idx):
+        if (batch_idx + 1) % self.every != 0:
+            return
+        self.log_images_and_audios(pl_module, batch)
+
+
+class HFModelCheckpoint(ModelCheckpoint):
+
+    def __init__(self, ldm_config, hf_checkpoint='vae_model', *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.ldm_config = ldm_config
+        self.hf_checkpoint = hf_checkpoint
+
+    def on_train_epoch_end(self, trainer, pl_module):
+        super().on_train_epoch_end(trainer, pl_module)
+        ldm_checkpoint = self.format_checkpoint_name(
+            {'epoch': trainer.current_epoch})
+        convert_ldm_to_hf_vae(ldm_checkpoint, self.ldm_config,
+                              self.hf_checkpoint)
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Train VAE using ldm.")
-    parser.add_argument("--batch_size", type=int, default=1)
+    parser.add_argument("-d", "--dataset_name", type=str, default=None)
+    parser.add_argument("-b", "--batch_size", type=int, default=1)
+    parser.add_argument("-c",
+                        "--ldm_config_file",
+                        type=str,
+                        default="config/ldm_autoencoder_kl.yaml")
+    parser.add_argument("--ldm_checkpoint_dir",
+                        type=str,
+                        default="checkpoints")
+    parser.add_argument("--hf_checkpoint_dir", type=str, default="vae_model")
+    parser.add_argument("-r",
+                        "--resume_from_checkpoint",
+                        type=str,
+                        default=None)
+    parser.add_argument("-g",
+                        "--gradient_accumulation_steps",
+                        type=int,
+                        default=1)
     args = parser.parse_args()
 
-    config = OmegaConf.load('ldm_autoencoder_kl.yaml')
+    config = OmegaConf.load(args.ldm_config_file)
     lightning_config = config.pop("lightning", OmegaConf.create())
     trainer_config = lightning_config.get("trainer", OmegaConf.create())
+    trainer_config.accumulate_grad_batches = args.gradient_accumulation_steps
    trainer_opt = argparse.Namespace(**trainer_config)
-    trainer = Trainer.from_argparse_args(trainer_opt,
-                                         callbacks=[
-                                             ImageLogger(),
-                                             ModelCheckpoint(
-                                                 dirpath='checkpoints',
-                                                 filename='{epoch:06}',
-                                                 verbose=True,
-                                                 save_last=True)
-                                         ])
+    trainer = Trainer.from_argparse_args(
+        trainer_opt,
+        resume_from_checkpoint=args.resume_from_checkpoint,
+        callbacks=[
+            ImageLogger(),
+            HFModelCheckpoint(ldm_config=config,
+                              hf_checkpoint=args.hf_checkpoint_dir,
+                              dirpath=args.ldm_checkpoint_dir,
+                              filename='{epoch:06}',
+                              verbose=True,
+                              save_last=True)
+        ])
     model = instantiate_from_config(config.model)
     model.learning_rate = config.model.base_learning_rate
-    data = AudioDiffusionDataModule('teticio/audio-diffusion-256',
+    data = AudioDiffusionDataModule(args.dataset_name,
                                     batch_size=args.batch_size)
     trainer.fit(model, data)
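After a training run with these changes, each epoch end leaves both a Lightning checkpoint under `--ldm_checkpoint_dir` and a converted diffusers VAE under `--hf_checkpoint_dir` (default `vae_model`). A minimal sketch of loading the converted model, assuming the default directory and using a dummy 3-channel 256x256 tensor as a stand-in for a batch of Mel spectrogram images (the real channel count and resolution come from `config/ldm_autoencoder_kl.yaml`):

```python
import torch
from diffusers import AutoencoderKL

# Load the VAE written out by HFModelCheckpoint; "vae_model" is the default directory.
vae = AutoencoderKL.from_pretrained("vae_model")
vae.eval()

# Dummy batch standing in for Mel spectrogram images.
images = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    latents = vae.encode(images).latent_dist.sample()
    reconstruction = vae.decode(latents).sample
print(latents.shape, reconstruction.shape)
```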