Uniaff committed
Commit 0ab4532 · verified
Parent(s): cd81475

Upload 101 files

This view is limited to 50 files because it contains too many changes.
Files changed (50):
  1. .gitattributes +43 -36
  2. campplus_cn_common.bin +3 -0
  3. configs/config_dit_mel_seed.yml +79 -0
  4. configs/config_dit_mel_seed_facodec_small.yml +97 -0
  5. configs/config_dit_mel_seed_wavenet.yml +79 -0
  6. configs/hifigan.yml +25 -0
  7. dac/__init__.py +16 -0
  8. dac/__main__.py +36 -0
  9. dac/model/__init__.py +4 -0
  10. dac/model/base.py +294 -0
  11. dac/model/dac.py +400 -0
  12. dac/model/discriminator.py +228 -0
  13. dac/model/encodec.py +320 -0
  14. dac/nn/__init__.py +3 -0
  15. dac/nn/layers.py +33 -0
  16. dac/nn/loss.py +368 -0
  17. dac/nn/quantize.py +339 -0
  18. dac/utils/__init__.py +123 -0
  19. dac/utils/decode.py +95 -0
  20. dac/utils/encode.py +94 -0
  21. examples/reference/azuma_0.wav +0 -0
  22. examples/reference/dingzhen_0.wav +3 -0
  23. examples/reference/kobe_0.wav +0 -0
  24. examples/reference/s1p1.wav +0 -0
  25. examples/reference/s1p2.wav +0 -0
  26. examples/reference/s2p1.wav +0 -0
  27. examples/reference/s2p2.wav +0 -0
  28. examples/reference/s3p1.wav +0 -0
  29. examples/reference/s3p2.wav +3 -0
  30. examples/reference/s4p1.wav +0 -0
  31. examples/reference/s4p2.wav +0 -0
  32. examples/reference/teio_0.wav +0 -0
  33. examples/reference/trump_0.wav +3 -0
  34. examples/source/TECHNOPOLIS - 2085 [vocals]_[cut_14sec].wav +3 -0
  35. examples/source/Wiz Khalifa,Charlie Puth - See You Again [vocals]_[cut_28sec].wav +3 -0
  36. examples/source/glados_0.wav +0 -0
  37. examples/source/jay_0.wav +3 -0
  38. examples/source/source_s1.wav +0 -0
  39. examples/source/source_s2.wav +0 -0
  40. examples/source/source_s3.wav +3 -0
  41. examples/source/source_s4.wav +3 -0
  42. examples/source/yae_0.wav +0 -0
  43. hf_utils.py +12 -0
  44. modules/__pycache__/audio.cpython-310.pyc +0 -0
  45. modules/__pycache__/commons.cpython-310.pyc +0 -0
  46. modules/__pycache__/diffusion_transformer.cpython-310.pyc +0 -0
  47. modules/__pycache__/encodec.cpython-310.pyc +0 -0
  48. modules/__pycache__/flow_matching.cpython-310.pyc +0 -0
  49. modules/__pycache__/length_regulator.cpython-310.pyc +0 -0
  50. modules/__pycache__/wavenet.cpython-310.pyc +0 -0
.gitattributes CHANGED
@@ -1,36 +1,43 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
- images/comparison.png filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ examples/reference/dingzhen_0.wav filter=lfs diff=lfs merge=lfs -text
+ examples/reference/s3p2.wav filter=lfs diff=lfs merge=lfs -text
+ examples/source/source_s3.wav filter=lfs diff=lfs merge=lfs -text
+ examples/source/source_s4.wav filter=lfs diff=lfs merge=lfs -text
+ examples/source/Wiz[[:space:]]Khalifa,Charlie[[:space:]]Puth[[:space:]]-[[:space:]]See[[:space:]]You[[:space:]]Again[[:space:]]\[vocals\]_\[cut_28sec\].wav filter=lfs diff=lfs merge=lfs -text
+ examples/reference/trump_0.wav filter=lfs diff=lfs merge=lfs -text
+ examples/source/jay_0.wav filter=lfs diff=lfs merge=lfs -text
+ examples/source/TECHNOPOLIS[[:space:]]-[[:space:]]2085[[:space:]]\[vocals\]_\[cut_14sec\].wav filter=lfs diff=lfs merge=lfs -text
campplus_cn_common.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3388cf5fd3493c9ac9c69851d8e7a8badcfb4f3dc631020c4961371646d5ada8
+ size 28036335
configs/config_dit_mel_seed.yml ADDED
@@ -0,0 +1,79 @@
+ log_dir: "./runs/run_dit_mel_seed"
+ save_freq: 1
+ log_interval: 10
+ save_interval: 1000
+ device: "cuda"
+ epochs: 1000 # number of epochs for first stage training (pre-training)
+ batch_size: 4
+ batch_length: 100 # maximum duration of audio in a batch (in seconds)
+ max_len: 80 # maximum number of frames
+ pretrained_model: ""
+ pretrained_encoder: ""
+ load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters
+
+ F0_path: "modules/JDC/bst.t7"
+
+ preprocess_params:
+   sr: 22050
+   spect_params:
+     n_fft: 1024
+     win_length: 1024
+     hop_length: 256
+     n_mels: 80
+
+ model_params:
+   dit_type: "DiT" # uDiT or DiT
+   reg_loss_type: "l2" # l1 or l2
+
+   speech_tokenizer:
+     path: "speech_tokenizer_v1.onnx"
+
+   style_encoder:
+     dim: 192
+     campplus_path: "campplus_cn_common.bin"
+
+   DAC:
+     encoder_dim: 64
+     encoder_rates: [2, 5, 5, 6]
+     decoder_dim: 1536
+     decoder_rates: [ 6, 5, 5, 2 ]
+     sr: 24000
+
+   length_regulator:
+     channels: 768
+     is_discrete: true
+     content_codebook_size: 4096
+     in_frame_rate: 50
+     out_frame_rate: 80
+     sampling_ratios: [1, 1, 1, 1]
+
+   DiT:
+     hidden_dim: 768
+     num_heads: 12
+     depth: 12
+     class_dropout_prob: 0.1
+     block_size: 4096
+     in_channels: 80
+     style_condition: true
+     final_layer_type: 'wavenet'
+     target: 'mel' # mel or codec
+     content_dim: 768
+     content_codebook_size: 1024
+     content_type: 'discrete'
+     f0_condition: false
+     n_f0_bins: 512
+     content_codebooks: 1
+     is_causal: false
+     long_skip_connection: true
+     zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
+
+   wavenet:
+     hidden_dim: 768
+     num_layers: 8
+     kernel_size: 5
+     dilation_rate: 1
+     p_dropout: 0.2
+     style_condition: true
+
+ loss_params:
+   base_lr: 0.0001
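The configs in this commit are plain YAML, so a training or inference script can read them with PyYAML and index into the nested blocks. A minimal sketch, assuming PyYAML is installed and the nesting shown above (preprocess_params, model_params with its DiT/wavenet sub-blocks); the variable names are only illustrative:

import yaml  # assumption: PyYAML is available in the environment

with open("configs/config_dit_mel_seed.yml") as f:
    cfg = yaml.safe_load(f)

# Pull out a few of the values defined above.
sr = cfg["preprocess_params"]["sr"]                          # 22050
n_mels = cfg["preprocess_params"]["spect_params"]["n_mels"]  # 80
dit_cfg = cfg["model_params"]["DiT"]                         # hidden_dim, depth, ...
print(sr, n_mels, dit_cfg["hidden_dim"], cfg["loss_params"]["base_lr"])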
configs/config_dit_mel_seed_facodec_small.yml ADDED
@@ -0,0 +1,97 @@
+ log_dir: "./runs/run_dit_mel_seed_facodec_small"
+ save_freq: 1
+ log_interval: 10
+ save_interval: 1000
+ device: "cuda"
+ epochs: 1000 # number of epochs for first stage training (pre-training)
+ batch_size: 2
+ batch_length: 100 # maximum duration of audio in a batch (in seconds)
+ max_len: 80 # maximum number of frames
+ pretrained_model: ""
+ pretrained_encoder: ""
+ load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters
+
+ F0_path: "modules/JDC/bst.t7"
+
+ data_params:
+   train_data: "./data/train.txt"
+   val_data: "./data/val.txt"
+   root_path: "./data/"
+
+ preprocess_params:
+   sr: 22050
+   spect_params:
+     n_fft: 1024
+     win_length: 1024
+     hop_length: 256
+     n_mels: 80
+
+ model_params:
+   dit_type: "DiT" # uDiT or DiT
+   reg_loss_type: "l1" # l1 or l2
+
+   speech_tokenizer:
+     type: 'facodec' # facodec or cosyvoice
+     path: "checkpoints/speech_tokenizer_v1.onnx"
+
+   style_encoder:
+     dim: 192
+     campplus_path: "checkpoints/campplus_cn_common.bin"
+
+   DAC:
+     encoder_dim: 64
+     encoder_rates: [2, 5, 5, 6]
+     decoder_dim: 1536
+     decoder_rates: [ 6, 5, 5, 2 ]
+     sr: 24000
+
+   length_regulator:
+     channels: 512
+     is_discrete: true
+     content_codebook_size: 1024
+     in_frame_rate: 80
+     out_frame_rate: 80
+     sampling_ratios: [1, 1, 1, 1]
+     token_dropout_prob: 0.3 # probability of performing token dropout
+     token_dropout_range: 1.0 # maximum percentage of tokens to drop out
+     n_codebooks: 3
+     quantizer_dropout: 0.5
+     f0_condition: false
+     n_f0_bins: 512
+
+   DiT:
+     hidden_dim: 512
+     num_heads: 8
+     depth: 13
+     class_dropout_prob: 0.1
+     block_size: 8192
+     in_channels: 80
+     style_condition: true
+     final_layer_type: 'wavenet'
+     target: 'mel' # mel or codec
+     content_dim: 512
+     content_codebook_size: 1024
+     content_type: 'discrete'
+     f0_condition: true
+     n_f0_bins: 512
+     content_codebooks: 1
+     is_causal: false
+     long_skip_connection: true
+     zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
+     time_as_token: false
+     style_as_token: false
+     uvit_skip_connection: true
+     add_resblock_in_transformer: false
+
+   wavenet:
+     hidden_dim: 512
+     num_layers: 8
+     kernel_size: 5
+     dilation_rate: 1
+     p_dropout: 0.2
+     style_condition: true
+
+ loss_params:
+   base_lr: 0.0001
+   lambda_mel: 45
+   lambda_kl: 1.0
configs/config_dit_mel_seed_wavenet.yml ADDED
@@ -0,0 +1,79 @@
+ log_dir: "./runs/run_dit_mel_seed"
+ save_freq: 1
+ log_interval: 10
+ save_interval: 1000
+ device: "cuda"
+ epochs: 1000 # number of epochs for first stage training (pre-training)
+ batch_size: 4
+ batch_length: 100 # maximum duration of audio in a batch (in seconds)
+ max_len: 80 # maximum number of frames
+ pretrained_model: ""
+ pretrained_encoder: ""
+ load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters
+
+ F0_path: "modules/JDC/bst.t7"
+
+ preprocess_params:
+   sr: 22050
+   spect_params:
+     n_fft: 1024
+     win_length: 1024
+     hop_length: 256
+     n_mels: 80
+
+ model_params:
+   dit_type: "DiT" # uDiT or DiT
+   reg_loss_type: "l2" # l1 or l2
+
+   speech_tokenizer:
+     path: "checkpoints/speech_tokenizer_v1.onnx"
+
+   style_encoder:
+     dim: 192
+     campplus_path: "campplus_cn_common.bin"
+
+   DAC:
+     encoder_dim: 64
+     encoder_rates: [2, 5, 5, 6]
+     decoder_dim: 1536
+     decoder_rates: [ 6, 5, 5, 2 ]
+     sr: 24000
+
+   length_regulator:
+     channels: 768
+     is_discrete: true
+     content_codebook_size: 4096
+     in_frame_rate: 50
+     out_frame_rate: 80
+     sampling_ratios: [1, 1, 1, 1]
+
+   DiT:
+     hidden_dim: 768
+     num_heads: 12
+     depth: 12
+     class_dropout_prob: 0.1
+     block_size: 8192
+     in_channels: 80
+     style_condition: true
+     final_layer_type: 'wavenet'
+     target: 'mel' # mel or codec
+     content_dim: 768
+     content_codebook_size: 1024
+     content_type: 'discrete'
+     f0_condition: false
+     n_f0_bins: 512
+     content_codebooks: 1
+     is_causal: false
+     long_skip_connection: true
+     zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
+
+   wavenet:
+     hidden_dim: 768
+     num_layers: 8
+     kernel_size: 5
+     dilation_rate: 1
+     p_dropout: 0.2
+     style_condition: true
+
+ loss_params:
+   base_lr: 0.0001
configs/hifigan.yml ADDED
@@ -0,0 +1,25 @@
+ hift:
+   in_channels: 80
+   base_channels: 512
+   nb_harmonics: 8
+   sampling_rate: 22050
+   nsf_alpha: 0.1
+   nsf_sigma: 0.003
+   nsf_voiced_threshold: 10
+   upsample_rates: [8, 8]
+   upsample_kernel_sizes: [16, 16]
+   istft_params:
+     n_fft: 16
+     hop_len: 4
+   resblock_kernel_sizes: [3, 7, 11]
+   resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+   source_resblock_kernel_sizes: [7, 11]
+   source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
+   lrelu_slope: 0.1
+   audio_limit: 0.99
+   f0_predictor:
+     num_class: 1
+     in_channels: 80
+     cond_channels: 512
+
+ pretrained_model_path: "checkpoints/hift.pt"
dac/__init__.py ADDED
@@ -0,0 +1,16 @@
+ __version__ = "1.0.0"
+
+ # preserved here for legacy reasons
+ __model_version__ = "latest"
+
+ import audiotools
+
+ audiotools.ml.BaseModel.INTERN += ["dac.**"]
+ audiotools.ml.BaseModel.EXTERN += ["einops"]
+
+
+ from . import nn
+ from . import model
+ from . import utils
+ from .model import DAC
+ from .model import DACFile
dac/__main__.py ADDED
@@ -0,0 +1,36 @@
+ import sys
+
+ import argbind
+
+ from dac.utils import download
+ from dac.utils.decode import decode
+ from dac.utils.encode import encode
+
+ STAGES = ["encode", "decode", "download"]
+
+
+ def run(stage: str):
+     """Run stages.
+
+     Parameters
+     ----------
+     stage : str
+         Stage to run
+     """
+     if stage not in STAGES:
+         raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
+     stage_fn = globals()[stage]
+
+     if stage == "download":
+         stage_fn()
+         return
+
+     stage_fn()
+
+
+ if __name__ == "__main__":
+     group = sys.argv.pop(1)
+     args = argbind.parse_args(group=group)
+
+     with argbind.scope(args):
+         run(group)
dac/model/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .base import CodecMixin
+ from .base import DACFile
+ from .dac import DAC
+ from .discriminator import Discriminator
dac/model/base.py ADDED
@@ -0,0 +1,294 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import tqdm
9
+ from audiotools import AudioSignal
10
+ from torch import nn
11
+
12
+ SUPPORTED_VERSIONS = ["1.0.0"]
13
+
14
+
15
+ @dataclass
16
+ class DACFile:
17
+ codes: torch.Tensor
18
+
19
+ # Metadata
20
+ chunk_length: int
21
+ original_length: int
22
+ input_db: float
23
+ channels: int
24
+ sample_rate: int
25
+ padding: bool
26
+ dac_version: str
27
+
28
+ def save(self, path):
29
+ artifacts = {
30
+ "codes": self.codes.numpy().astype(np.uint16),
31
+ "metadata": {
32
+ "input_db": self.input_db.numpy().astype(np.float32),
33
+ "original_length": self.original_length,
34
+ "sample_rate": self.sample_rate,
35
+ "chunk_length": self.chunk_length,
36
+ "channels": self.channels,
37
+ "padding": self.padding,
38
+ "dac_version": SUPPORTED_VERSIONS[-1],
39
+ },
40
+ }
41
+ path = Path(path).with_suffix(".dac")
42
+ with open(path, "wb") as f:
43
+ np.save(f, artifacts)
44
+ return path
45
+
46
+ @classmethod
47
+ def load(cls, path):
48
+ artifacts = np.load(path, allow_pickle=True)[()]
49
+ codes = torch.from_numpy(artifacts["codes"].astype(int))
50
+ if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
51
+ raise RuntimeError(
52
+ f"Given file {path} can't be loaded with this version of descript-audio-codec."
53
+ )
54
+ return cls(codes=codes, **artifacts["metadata"])
55
+
56
+
57
+ class CodecMixin:
58
+ @property
59
+ def padding(self):
60
+ if not hasattr(self, "_padding"):
61
+ self._padding = True
62
+ return self._padding
63
+
64
+ @padding.setter
65
+ def padding(self, value):
66
+ assert isinstance(value, bool)
67
+
68
+ layers = [
69
+ l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
70
+ ]
71
+
72
+ for layer in layers:
73
+ if value:
74
+ if hasattr(layer, "original_padding"):
75
+ layer.padding = layer.original_padding
76
+ else:
77
+ layer.original_padding = layer.padding
78
+ layer.padding = tuple(0 for _ in range(len(layer.padding)))
79
+
80
+ self._padding = value
81
+
82
+ def get_delay(self):
83
+ # Any number works here, delay is invariant to input length
84
+ l_out = self.get_output_length(0)
85
+ L = l_out
86
+
87
+ layers = []
88
+ for layer in self.modules():
89
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
90
+ layers.append(layer)
91
+
92
+ for layer in reversed(layers):
93
+ d = layer.dilation[0]
94
+ k = layer.kernel_size[0]
95
+ s = layer.stride[0]
96
+
97
+ if isinstance(layer, nn.ConvTranspose1d):
98
+ L = ((L - d * (k - 1) - 1) / s) + 1
99
+ elif isinstance(layer, nn.Conv1d):
100
+ L = (L - 1) * s + d * (k - 1) + 1
101
+
102
+ L = math.ceil(L)
103
+
104
+ l_in = L
105
+
106
+ return (l_in - l_out) // 2
107
+
108
+ def get_output_length(self, input_length):
109
+ L = input_length
110
+ # Calculate output length
111
+ for layer in self.modules():
112
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
113
+ d = layer.dilation[0]
114
+ k = layer.kernel_size[0]
115
+ s = layer.stride[0]
116
+
117
+ if isinstance(layer, nn.Conv1d):
118
+ L = ((L - d * (k - 1) - 1) / s) + 1
119
+ elif isinstance(layer, nn.ConvTranspose1d):
120
+ L = (L - 1) * s + d * (k - 1) + 1
121
+
122
+ L = math.floor(L)
123
+ return L
124
+
125
+ @torch.no_grad()
126
+ def compress(
127
+ self,
128
+ audio_path_or_signal: Union[str, Path, AudioSignal],
129
+ win_duration: float = 1.0,
130
+ verbose: bool = False,
131
+ normalize_db: float = -16,
132
+ n_quantizers: int = None,
133
+ ) -> DACFile:
134
+ """Processes an audio signal from a file or AudioSignal object into
135
+ discrete codes. This function processes the signal in short windows,
136
+ using constant GPU memory.
137
+
138
+ Parameters
139
+ ----------
140
+ audio_path_or_signal : Union[str, Path, AudioSignal]
141
+ audio signal to reconstruct
142
+ win_duration : float, optional
143
+ window duration in seconds, by default 1.0
144
+ verbose : bool, optional
145
+ by default False
146
+ normalize_db : float, optional
147
+ normalize db, by default -16
148
+
149
+ Returns
150
+ -------
151
+ DACFile
152
+ Object containing compressed codes and metadata
153
+ required for decompression
154
+ """
155
+ audio_signal = audio_path_or_signal
156
+ if isinstance(audio_signal, (str, Path)):
157
+ audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
158
+
159
+ self.eval()
160
+ original_padding = self.padding
161
+ original_device = audio_signal.device
162
+
163
+ audio_signal = audio_signal.clone()
164
+ original_sr = audio_signal.sample_rate
165
+
166
+ resample_fn = audio_signal.resample
167
+ loudness_fn = audio_signal.loudness
168
+
169
+ # If audio is longer than 10 hours, use the ffmpeg versions
170
+ if audio_signal.signal_duration >= 10 * 60 * 60:
171
+ resample_fn = audio_signal.ffmpeg_resample
172
+ loudness_fn = audio_signal.ffmpeg_loudness
173
+
174
+ original_length = audio_signal.signal_length
175
+ resample_fn(self.sample_rate)
176
+ input_db = loudness_fn()
177
+
178
+ if normalize_db is not None:
179
+ audio_signal.normalize(normalize_db)
180
+ audio_signal.ensure_max_of_audio()
181
+
182
+ nb, nac, nt = audio_signal.audio_data.shape
183
+ audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
184
+ win_duration = (
185
+ audio_signal.signal_duration if win_duration is None else win_duration
186
+ )
187
+
188
+ if audio_signal.signal_duration <= win_duration:
189
+ # Unchunked compression (used if signal length < win duration)
190
+ self.padding = True
191
+ n_samples = nt
192
+ hop = nt
193
+ else:
194
+ # Chunked inference
195
+ self.padding = False
196
+ # Zero-pad signal on either side by the delay
197
+ audio_signal.zero_pad(self.delay, self.delay)
198
+ n_samples = int(win_duration * self.sample_rate)
199
+ # Round n_samples to nearest hop length multiple
200
+ n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
201
+ hop = self.get_output_length(n_samples)
202
+
203
+ codes = []
204
+ range_fn = range if not verbose else tqdm.trange
205
+
206
+ for i in range_fn(0, nt, hop):
207
+ x = audio_signal[..., i : i + n_samples]
208
+ x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
209
+
210
+ audio_data = x.audio_data.to(self.device)
211
+ audio_data = self.preprocess(audio_data, self.sample_rate)
212
+ _, c, _, _, _ = self.encode(audio_data, n_quantizers)
213
+ codes.append(c.to(original_device))
214
+ chunk_length = c.shape[-1]
215
+
216
+ codes = torch.cat(codes, dim=-1)
217
+
218
+ dac_file = DACFile(
219
+ codes=codes,
220
+ chunk_length=chunk_length,
221
+ original_length=original_length,
222
+ input_db=input_db,
223
+ channels=nac,
224
+ sample_rate=original_sr,
225
+ padding=self.padding,
226
+ dac_version=SUPPORTED_VERSIONS[-1],
227
+ )
228
+
229
+ if n_quantizers is not None:
230
+ codes = codes[:, :n_quantizers, :]
231
+
232
+ self.padding = original_padding
233
+ return dac_file
234
+
235
+ @torch.no_grad()
236
+ def decompress(
237
+ self,
238
+ obj: Union[str, Path, DACFile],
239
+ verbose: bool = False,
240
+ ) -> AudioSignal:
241
+ """Reconstruct audio from a given .dac file
242
+
243
+ Parameters
244
+ ----------
245
+ obj : Union[str, Path, DACFile]
246
+ .dac file location or corresponding DACFile object.
247
+ verbose : bool, optional
248
+ Prints progress if True, by default False
249
+
250
+ Returns
251
+ -------
252
+ AudioSignal
253
+ Object with the reconstructed audio
254
+ """
255
+ self.eval()
256
+ if isinstance(obj, (str, Path)):
257
+ obj = DACFile.load(obj)
258
+
259
+ original_padding = self.padding
260
+ self.padding = obj.padding
261
+
262
+ range_fn = range if not verbose else tqdm.trange
263
+ codes = obj.codes
264
+ original_device = codes.device
265
+ chunk_length = obj.chunk_length
266
+ recons = []
267
+
268
+ for i in range_fn(0, codes.shape[-1], chunk_length):
269
+ c = codes[..., i : i + chunk_length].to(self.device)
270
+ z = self.quantizer.from_codes(c)[0]
271
+ r = self.decode(z)
272
+ recons.append(r.to(original_device))
273
+
274
+ recons = torch.cat(recons, dim=-1)
275
+ recons = AudioSignal(recons, self.sample_rate)
276
+
277
+ resample_fn = recons.resample
278
+ loudness_fn = recons.loudness
279
+
280
+ # If audio is longer than 10 hours, use the ffmpeg versions
281
+ if recons.signal_duration >= 10 * 60 * 60:
282
+ resample_fn = recons.ffmpeg_resample
283
+ loudness_fn = recons.ffmpeg_loudness
284
+
285
+ recons.normalize(obj.input_db)
286
+ resample_fn(obj.sample_rate)
287
+ recons = recons[..., : obj.original_length]
288
+ loudness_fn()
289
+ recons.audio_data = recons.audio_data.reshape(
290
+ -1, obj.channels, obj.original_length
291
+ )
292
+
293
+ self.padding = original_padding
294
+ return recons
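A usage sketch for the compress/decompress API defined in this file: the calls below follow the signatures above, but the untrained DAC() instance, the random input signal, and the output path are placeholders rather than anything shipped in this commit.

import torch
from audiotools import AudioSignal
import dac

model = dac.DAC()   # placeholder instance; a real run would load trained weights
model.eval()

signal = AudioSignal(torch.randn(1, 1, 44100), 44100)  # one second of noise as a stand-in
dac_file = model.compress(signal, win_duration=1.0)     # -> DACFile (codes + metadata)
dac_file.save("example.dac")                            # stored as a .dac numpy archive
restored = model.decompress("example.dac")              # -> AudioSignal
print(dac_file.codes.shape, restored.signal_length)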
dac/model/dac.py ADDED
@@ -0,0 +1,400 @@
1
+ import math
2
+ from typing import List
3
+ from typing import Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from audiotools import AudioSignal
8
+ from audiotools.ml import BaseModel
9
+ from torch import nn
10
+
11
+ from .base import CodecMixin
12
+ from dac.nn.layers import Snake1d
13
+ from dac.nn.layers import WNConv1d
14
+ from dac.nn.layers import WNConvTranspose1d
15
+ from dac.nn.quantize import ResidualVectorQuantize
16
+ from .encodec import SConv1d, SConvTranspose1d, SLSTM
17
+
18
+
19
+ def init_weights(m):
20
+ if isinstance(m, nn.Conv1d):
21
+ nn.init.trunc_normal_(m.weight, std=0.02)
22
+ nn.init.constant_(m.bias, 0)
23
+
24
+
25
+ class ResidualUnit(nn.Module):
26
+ def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
27
+ super().__init__()
28
+ conv1d_type = SConv1d# if causal else WNConv1d
29
+ pad = ((7 - 1) * dilation) // 2
30
+ self.block = nn.Sequential(
31
+ Snake1d(dim),
32
+ conv1d_type(dim, dim, kernel_size=7, dilation=dilation, padding=pad, causal=causal, norm='weight_norm'),
33
+ Snake1d(dim),
34
+ conv1d_type(dim, dim, kernel_size=1, causal=causal, norm='weight_norm'),
35
+ )
36
+
37
+ def forward(self, x):
38
+ y = self.block(x)
39
+ pad = (x.shape[-1] - y.shape[-1]) // 2
40
+ if pad > 0:
41
+ x = x[..., pad:-pad]
42
+ return x + y
43
+
44
+
45
+ class EncoderBlock(nn.Module):
46
+ def __init__(self, dim: int = 16, stride: int = 1, causal: bool = False):
47
+ super().__init__()
48
+ conv1d_type = SConv1d# if causal else WNConv1d
49
+ self.block = nn.Sequential(
50
+ ResidualUnit(dim // 2, dilation=1, causal=causal),
51
+ ResidualUnit(dim // 2, dilation=3, causal=causal),
52
+ ResidualUnit(dim // 2, dilation=9, causal=causal),
53
+ Snake1d(dim // 2),
54
+ conv1d_type(
55
+ dim // 2,
56
+ dim,
57
+ kernel_size=2 * stride,
58
+ stride=stride,
59
+ padding=math.ceil(stride / 2),
60
+ causal=causal,
61
+ norm='weight_norm',
62
+ ),
63
+ )
64
+
65
+ def forward(self, x):
66
+ return self.block(x)
67
+
68
+
69
+ class Encoder(nn.Module):
70
+ def __init__(
71
+ self,
72
+ d_model: int = 64,
73
+ strides: list = [2, 4, 8, 8],
74
+ d_latent: int = 64,
75
+ causal: bool = False,
76
+ lstm: int = 2,
77
+ ):
78
+ super().__init__()
79
+ conv1d_type = SConv1d# if causal else WNConv1d
80
+ # Create first convolution
81
+ self.block = [conv1d_type(1, d_model, kernel_size=7, padding=3, causal=causal, norm='weight_norm')]
82
+
83
+ # Create EncoderBlocks that double channels as they downsample by `stride`
84
+ for stride in strides:
85
+ d_model *= 2
86
+ self.block += [EncoderBlock(d_model, stride=stride, causal=causal)]
87
+
88
+ # Add LSTM if needed
89
+ self.use_lstm = lstm
90
+ if lstm:
91
+ self.block += [SLSTM(d_model, lstm)]
92
+
93
+ # Create last convolution
94
+ self.block += [
95
+ Snake1d(d_model),
96
+ conv1d_type(d_model, d_latent, kernel_size=3, padding=1, causal=causal, norm='weight_norm'),
97
+ ]
98
+
99
+ # Wrap block into nn.Sequential
100
+ self.block = nn.Sequential(*self.block)
101
+ self.enc_dim = d_model
102
+
103
+ def forward(self, x):
104
+ return self.block(x)
105
+
106
+ def reset_cache(self):
107
+ # recursively find all submodules named SConv1d in self.block and use their reset_cache method
108
+ def reset_cache(m):
109
+ if isinstance(m, SConv1d) or isinstance(m, SLSTM):
110
+ m.reset_cache()
111
+ return
112
+ for child in m.children():
113
+ reset_cache(child)
114
+
115
+ reset_cache(self.block)
116
+
117
+
118
+ class DecoderBlock(nn.Module):
119
+ def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, causal: bool = False):
120
+ super().__init__()
121
+ conv1d_type = SConvTranspose1d #if causal else WNConvTranspose1d
122
+ self.block = nn.Sequential(
123
+ Snake1d(input_dim),
124
+ conv1d_type(
125
+ input_dim,
126
+ output_dim,
127
+ kernel_size=2 * stride,
128
+ stride=stride,
129
+ padding=math.ceil(stride / 2),
130
+ causal=causal,
131
+ norm='weight_norm'
132
+ ),
133
+ ResidualUnit(output_dim, dilation=1, causal=causal),
134
+ ResidualUnit(output_dim, dilation=3, causal=causal),
135
+ ResidualUnit(output_dim, dilation=9, causal=causal),
136
+ )
137
+
138
+ def forward(self, x):
139
+ return self.block(x)
140
+
141
+
142
+ class Decoder(nn.Module):
143
+ def __init__(
144
+ self,
145
+ input_channel,
146
+ channels,
147
+ rates,
148
+ d_out: int = 1,
149
+ causal: bool = False,
150
+ lstm: int = 2,
151
+ ):
152
+ super().__init__()
153
+ conv1d_type = SConv1d# if causal else WNConv1d
154
+ # Add first conv layer
155
+ layers = [conv1d_type(input_channel, channels, kernel_size=7, padding=3, causal=causal, norm='weight_norm')]
156
+
157
+ if lstm:
158
+ layers += [SLSTM(channels, num_layers=lstm)]
159
+
160
+ # Add upsampling + MRF blocks
161
+ for i, stride in enumerate(rates):
162
+ input_dim = channels // 2**i
163
+ output_dim = channels // 2 ** (i + 1)
164
+ layers += [DecoderBlock(input_dim, output_dim, stride, causal=causal)]
165
+
166
+ # Add final conv layer
167
+ layers += [
168
+ Snake1d(output_dim),
169
+ conv1d_type(output_dim, d_out, kernel_size=7, padding=3, causal=causal, norm='weight_norm'),
170
+ nn.Tanh(),
171
+ ]
172
+
173
+ self.model = nn.Sequential(*layers)
174
+
175
+ def forward(self, x):
176
+ return self.model(x)
177
+
178
+
179
+ class DAC(BaseModel, CodecMixin):
180
+ def __init__(
181
+ self,
182
+ encoder_dim: int = 64,
183
+ encoder_rates: List[int] = [2, 4, 8, 8],
184
+ latent_dim: int = None,
185
+ decoder_dim: int = 1536,
186
+ decoder_rates: List[int] = [8, 8, 4, 2],
187
+ n_codebooks: int = 9,
188
+ codebook_size: int = 1024,
189
+ codebook_dim: Union[int, list] = 8,
190
+ quantizer_dropout: bool = False,
191
+ sample_rate: int = 44100,
192
+ lstm: int = 2,
193
+ causal: bool = False,
194
+ ):
195
+ super().__init__()
196
+
197
+ self.encoder_dim = encoder_dim
198
+ self.encoder_rates = encoder_rates
199
+ self.decoder_dim = decoder_dim
200
+ self.decoder_rates = decoder_rates
201
+ self.sample_rate = sample_rate
202
+
203
+ if latent_dim is None:
204
+ latent_dim = encoder_dim * (2 ** len(encoder_rates))
205
+
206
+ self.latent_dim = latent_dim
207
+
208
+ self.hop_length = np.prod(encoder_rates)
209
+ self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim, causal=causal, lstm=lstm)
210
+
211
+ self.n_codebooks = n_codebooks
212
+ self.codebook_size = codebook_size
213
+ self.codebook_dim = codebook_dim
214
+ self.quantizer = ResidualVectorQuantize(
215
+ input_dim=latent_dim,
216
+ n_codebooks=n_codebooks,
217
+ codebook_size=codebook_size,
218
+ codebook_dim=codebook_dim,
219
+ quantizer_dropout=quantizer_dropout,
220
+ )
221
+
222
+ self.decoder = Decoder(
223
+ latent_dim,
224
+ decoder_dim,
225
+ decoder_rates,
226
+ lstm=lstm,
227
+ causal=causal,
228
+ )
229
+ self.sample_rate = sample_rate
230
+ self.apply(init_weights)
231
+
232
+ self.delay = self.get_delay()
233
+
234
+ def preprocess(self, audio_data, sample_rate):
235
+ if sample_rate is None:
236
+ sample_rate = self.sample_rate
237
+ assert sample_rate == self.sample_rate
238
+
239
+ length = audio_data.shape[-1]
240
+ right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
241
+ audio_data = nn.functional.pad(audio_data, (0, right_pad))
242
+
243
+ return audio_data
244
+
245
+ def encode(
246
+ self,
247
+ audio_data: torch.Tensor,
248
+ n_quantizers: int = None,
249
+ ):
250
+ """Encode given audio data and return quantized latent codes
251
+
252
+ Parameters
253
+ ----------
254
+ audio_data : Tensor[B x 1 x T]
255
+ Audio data to encode
256
+ n_quantizers : int, optional
257
+ Number of quantizers to use, by default None
258
+ If None, all quantizers are used.
259
+
260
+ Returns
261
+ -------
262
+ dict
263
+ A dictionary with the following keys:
264
+ "z" : Tensor[B x D x T]
265
+ Quantized continuous representation of input
266
+ "codes" : Tensor[B x N x T]
267
+ Codebook indices for each codebook
268
+ (quantized discrete representation of input)
269
+ "latents" : Tensor[B x N*D x T]
270
+ Projected latents (continuous representation of input before quantization)
271
+ "vq/commitment_loss" : Tensor[1]
272
+ Commitment loss to train encoder to predict vectors closer to codebook
273
+ entries
274
+ "vq/codebook_loss" : Tensor[1]
275
+ Codebook loss to update the codebook
276
+ "length" : int
277
+ Number of samples in input audio
278
+ """
279
+ z = self.encoder(audio_data)
280
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
281
+ z, n_quantizers
282
+ )
283
+ return z, codes, latents, commitment_loss, codebook_loss
284
+
285
+ def decode(self, z: torch.Tensor):
286
+ """Decode given latent codes and return audio data
287
+
288
+ Parameters
289
+ ----------
290
+ z : Tensor[B x D x T]
291
+ Quantized continuous representation of input
292
+ length : int, optional
293
+ Number of samples in output audio, by default None
294
+
295
+ Returns
296
+ -------
297
+ dict
298
+ A dictionary with the following keys:
299
+ "audio" : Tensor[B x 1 x length]
300
+ Decoded audio data.
301
+ """
302
+ return self.decoder(z)
303
+
304
+ def forward(
305
+ self,
306
+ audio_data: torch.Tensor,
307
+ sample_rate: int = None,
308
+ n_quantizers: int = None,
309
+ ):
310
+ """Model forward pass
311
+
312
+ Parameters
313
+ ----------
314
+ audio_data : Tensor[B x 1 x T]
315
+ Audio data to encode
316
+ sample_rate : int, optional
317
+ Sample rate of audio data in Hz, by default None
318
+ If None, defaults to `self.sample_rate`
319
+ n_quantizers : int, optional
320
+ Number of quantizers to use, by default None.
321
+ If None, all quantizers are used.
322
+
323
+ Returns
324
+ -------
325
+ dict
326
+ A dictionary with the following keys:
327
+ "z" : Tensor[B x D x T]
328
+ Quantized continuous representation of input
329
+ "codes" : Tensor[B x N x T]
330
+ Codebook indices for each codebook
331
+ (quantized discrete representation of input)
332
+ "latents" : Tensor[B x N*D x T]
333
+ Projected latents (continuous representation of input before quantization)
334
+ "vq/commitment_loss" : Tensor[1]
335
+ Commitment loss to train encoder to predict vectors closer to codebook
336
+ entries
337
+ "vq/codebook_loss" : Tensor[1]
338
+ Codebook loss to update the codebook
339
+ "length" : int
340
+ Number of samples in input audio
341
+ "audio" : Tensor[B x 1 x length]
342
+ Decoded audio data.
343
+ """
344
+ length = audio_data.shape[-1]
345
+ audio_data = self.preprocess(audio_data, sample_rate)
346
+ z, codes, latents, commitment_loss, codebook_loss = self.encode(
347
+ audio_data, n_quantizers
348
+ )
349
+
350
+ x = self.decode(z)
351
+ return {
352
+ "audio": x[..., :length],
353
+ "z": z,
354
+ "codes": codes,
355
+ "latents": latents,
356
+ "vq/commitment_loss": commitment_loss,
357
+ "vq/codebook_loss": codebook_loss,
358
+ }
359
+
360
+
361
+ if __name__ == "__main__":
362
+ import numpy as np
363
+ from functools import partial
364
+
365
+ model = DAC().to("cpu")
366
+
367
+ for n, m in model.named_modules():
368
+ o = m.extra_repr()
369
+ p = sum([np.prod(p.size()) for p in m.parameters()])
370
+ fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
371
+ setattr(m, "extra_repr", partial(fn, o=o, p=p))
372
+ print(model)
373
+ print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
374
+
375
+ length = 88200 * 2
376
+ x = torch.randn(1, 1, length).to(model.device)
377
+ x.requires_grad_(True)
378
+ x.retain_grad()
379
+
380
+ # Make a forward pass
381
+ out = model(x)["audio"]
382
+ print("Input shape:", x.shape)
383
+ print("Output shape:", out.shape)
384
+
385
+ # Create gradient variable
386
+ grad = torch.zeros_like(out)
387
+ grad[:, :, grad.shape[-1] // 2] = 1
388
+
389
+ # Make a backward pass
390
+ out.backward(grad)
391
+
392
+ # Check non-zero values
393
+ gradmap = x.grad.squeeze(0)
394
+ gradmap = (gradmap != 0).sum(0) # sum across features
395
+ rf = (gradmap != 0).sum()
396
+
397
+ print(f"Receptive field: {rf.item()}")
398
+
399
+ x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
400
+ model.decompress(model.compress(x, verbose=True), verbose=True)
dac/model/discriminator.py ADDED
@@ -0,0 +1,228 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from audiotools import AudioSignal
5
+ from audiotools import ml
6
+ from audiotools import STFTParams
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+
11
+ def WNConv1d(*args, **kwargs):
12
+ act = kwargs.pop("act", True)
13
+ conv = weight_norm(nn.Conv1d(*args, **kwargs))
14
+ if not act:
15
+ return conv
16
+ return nn.Sequential(conv, nn.LeakyReLU(0.1))
17
+
18
+
19
+ def WNConv2d(*args, **kwargs):
20
+ act = kwargs.pop("act", True)
21
+ conv = weight_norm(nn.Conv2d(*args, **kwargs))
22
+ if not act:
23
+ return conv
24
+ return nn.Sequential(conv, nn.LeakyReLU(0.1))
25
+
26
+
27
+ class MPD(nn.Module):
28
+ def __init__(self, period):
29
+ super().__init__()
30
+ self.period = period
31
+ self.convs = nn.ModuleList(
32
+ [
33
+ WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
34
+ WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
35
+ WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
36
+ WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
37
+ WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
38
+ ]
39
+ )
40
+ self.conv_post = WNConv2d(
41
+ 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
42
+ )
43
+
44
+ def pad_to_period(self, x):
45
+ t = x.shape[-1]
46
+ x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
47
+ return x
48
+
49
+ def forward(self, x):
50
+ fmap = []
51
+
52
+ x = self.pad_to_period(x)
53
+ x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
54
+
55
+ for layer in self.convs:
56
+ x = layer(x)
57
+ fmap.append(x)
58
+
59
+ x = self.conv_post(x)
60
+ fmap.append(x)
61
+
62
+ return fmap
63
+
64
+
65
+ class MSD(nn.Module):
66
+ def __init__(self, rate: int = 1, sample_rate: int = 44100):
67
+ super().__init__()
68
+ self.convs = nn.ModuleList(
69
+ [
70
+ WNConv1d(1, 16, 15, 1, padding=7),
71
+ WNConv1d(16, 64, 41, 4, groups=4, padding=20),
72
+ WNConv1d(64, 256, 41, 4, groups=16, padding=20),
73
+ WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
74
+ WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
75
+ WNConv1d(1024, 1024, 5, 1, padding=2),
76
+ ]
77
+ )
78
+ self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
79
+ self.sample_rate = sample_rate
80
+ self.rate = rate
81
+
82
+ def forward(self, x):
83
+ x = AudioSignal(x, self.sample_rate)
84
+ x.resample(self.sample_rate // self.rate)
85
+ x = x.audio_data
86
+
87
+ fmap = []
88
+
89
+ for l in self.convs:
90
+ x = l(x)
91
+ fmap.append(x)
92
+ x = self.conv_post(x)
93
+ fmap.append(x)
94
+
95
+ return fmap
96
+
97
+
98
+ BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
99
+
100
+
101
+ class MRD(nn.Module):
102
+ def __init__(
103
+ self,
104
+ window_length: int,
105
+ hop_factor: float = 0.25,
106
+ sample_rate: int = 44100,
107
+ bands: list = BANDS,
108
+ ):
109
+ """Complex multi-band spectrogram discriminator.
110
+ Parameters
111
+ ----------
112
+ window_length : int
113
+ Window length of STFT.
114
+ hop_factor : float, optional
115
+ Hop factor of the STFT, defaults to ``0.25 * window_length``.
116
+ sample_rate : int, optional
117
+ Sampling rate of audio in Hz, by default 44100
118
+ bands : list, optional
119
+ Bands to run discriminator over.
120
+ """
121
+ super().__init__()
122
+
123
+ self.window_length = window_length
124
+ self.hop_factor = hop_factor
125
+ self.sample_rate = sample_rate
126
+ self.stft_params = STFTParams(
127
+ window_length=window_length,
128
+ hop_length=int(window_length * hop_factor),
129
+ match_stride=True,
130
+ )
131
+
132
+ n_fft = window_length // 2 + 1
133
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
134
+ self.bands = bands
135
+
136
+ ch = 32
137
+ convs = lambda: nn.ModuleList(
138
+ [
139
+ WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
140
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
141
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
142
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
143
+ WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
144
+ ]
145
+ )
146
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
147
+ self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
148
+
149
+ def spectrogram(self, x):
150
+ x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
151
+ x = torch.view_as_real(x.stft())
152
+ x = rearrange(x, "b 1 f t c -> (b 1) c t f")
153
+ # Split into bands
154
+ x_bands = [x[..., b[0] : b[1]] for b in self.bands]
155
+ return x_bands
156
+
157
+ def forward(self, x):
158
+ x_bands = self.spectrogram(x)
159
+ fmap = []
160
+
161
+ x = []
162
+ for band, stack in zip(x_bands, self.band_convs):
163
+ for layer in stack:
164
+ band = layer(band)
165
+ fmap.append(band)
166
+ x.append(band)
167
+
168
+ x = torch.cat(x, dim=-1)
169
+ x = self.conv_post(x)
170
+ fmap.append(x)
171
+
172
+ return fmap
173
+
174
+
175
+ class Discriminator(nn.Module):
176
+ def __init__(
177
+ self,
178
+ rates: list = [],
179
+ periods: list = [2, 3, 5, 7, 11],
180
+ fft_sizes: list = [2048, 1024, 512],
181
+ sample_rate: int = 44100,
182
+ bands: list = BANDS,
183
+ ):
184
+ """Discriminator that combines multiple discriminators.
185
+
186
+ Parameters
187
+ ----------
188
+ rates : list, optional
189
+ sampling rates (in Hz) to run MSD at, by default []
190
+ If empty, MSD is not used.
191
+ periods : list, optional
192
+ periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
193
+ fft_sizes : list, optional
194
+ Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
195
+ sample_rate : int, optional
196
+ Sampling rate of audio in Hz, by default 44100
197
+ bands : list, optional
198
+ Bands to run MRD at, by default `BANDS`
199
+ """
200
+ super().__init__()
201
+ discs = []
202
+ discs += [MPD(p) for p in periods]
203
+ discs += [MSD(r, sample_rate=sample_rate) for r in rates]
204
+ discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
205
+ self.discriminators = nn.ModuleList(discs)
206
+
207
+ def preprocess(self, y):
208
+ # Remove DC offset
209
+ y = y - y.mean(dim=-1, keepdims=True)
210
+ # Peak normalize the volume of input audio
211
+ y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
212
+ return y
213
+
214
+ def forward(self, x):
215
+ x = self.preprocess(x)
216
+ fmaps = [d(x) for d in self.discriminators]
217
+ return fmaps
218
+
219
+
220
+ if __name__ == "__main__":
221
+ disc = Discriminator()
222
+ x = torch.zeros(1, 1, 44100)
223
+ results = disc(x)
224
+ for i, result in enumerate(results):
225
+ print(f"disc{i}")
226
+ for i, r in enumerate(result):
227
+ print(r.shape, r.mean(), r.min(), r.max())
228
+ print()
dac/model/encodec.py ADDED
@@ -0,0 +1,320 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Convolutional layers wrappers and utilities."""
8
+
9
+ import math
10
+ import typing as tp
11
+ import warnings
12
+
13
+ import torch
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+ from torch.nn.utils import spectral_norm, weight_norm
17
+
18
+ import typing as tp
19
+
20
+ import einops
21
+
22
+
23
+ class ConvLayerNorm(nn.LayerNorm):
24
+ """
25
+ Convolution-friendly LayerNorm that moves channels to last dimensions
26
+ before running the normalization and moves them back to original position right after.
27
+ """
28
+ def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
29
+ super().__init__(normalized_shape, **kwargs)
30
+
31
+ def forward(self, x):
32
+ x = einops.rearrange(x, 'b ... t -> b t ...')
33
+ x = super().forward(x)
34
+ x = einops.rearrange(x, 'b t ... -> b ... t')
35
+ return x
36
+
37
+
38
+ CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
39
+ 'time_layer_norm', 'layer_norm', 'time_group_norm'])
40
+
41
+
42
+ def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
43
+ assert norm in CONV_NORMALIZATIONS
44
+ if norm == 'weight_norm':
45
+ return weight_norm(module)
46
+ elif norm == 'spectral_norm':
47
+ return spectral_norm(module)
48
+ else:
49
+ # We already check was in CONV_NORMALIZATION, so any other choice
50
+ # doesn't need reparametrization.
51
+ return module
52
+
53
+
54
+ def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
55
+ """Return the proper normalization module. If causal is True, this will ensure the returned
56
+ module is causal, or return an error if the normalization doesn't support causal evaluation.
57
+ """
58
+ assert norm in CONV_NORMALIZATIONS
59
+ if norm == 'layer_norm':
60
+ assert isinstance(module, nn.modules.conv._ConvNd)
61
+ return ConvLayerNorm(module.out_channels, **norm_kwargs)
62
+ elif norm == 'time_group_norm':
63
+ if causal:
64
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
65
+ assert isinstance(module, nn.modules.conv._ConvNd)
66
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
67
+ else:
68
+ return nn.Identity()
69
+
70
+
71
+ def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
72
+ padding_total: int = 0) -> int:
73
+ """See `pad_for_conv1d`.
74
+ """
75
+ length = x.shape[-1]
76
+ n_frames = (length - kernel_size + padding_total) / stride + 1
77
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
78
+ return ideal_length - length
79
+
80
+
81
+ def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
82
+ """Pad for a convolution to make sure that the last window is full.
83
+ Extra padding is added at the end. This is required to ensure that we can rebuild
84
+ an output of the same length, as otherwise, even with padding, some time steps
85
+ might get removed.
86
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
87
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
88
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
89
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
90
+ 1 2 3 4 # once you removed padding, we are missing one time step !
91
+ """
92
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
93
+ return F.pad(x, (0, extra_padding))
94
+
95
+
96
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
97
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
98
+ If this is the case, we insert extra 0 padding to the right before the reflection happens.
99
+ """
100
+ length = x.shape[-1]
101
+ padding_left, padding_right = paddings
102
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
103
+ if mode == 'reflect':
104
+ max_pad = max(padding_left, padding_right)
105
+ extra_pad = 0
106
+ if length <= max_pad:
107
+ extra_pad = max_pad - length + 1
108
+ x = F.pad(x, (0, extra_pad))
109
+ padded = F.pad(x, paddings, mode, value)
110
+ end = padded.shape[-1] - extra_pad
111
+ return padded[..., :end]
112
+ else:
113
+ return F.pad(x, paddings, mode, value)
114
+
115
+
116
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
117
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
118
+ padding_left, padding_right = paddings
119
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
120
+ assert (padding_left + padding_right) <= x.shape[-1]
121
+ end = x.shape[-1] - padding_right
122
+ return x[..., padding_left: end]
123
+
124
+
125
+ class NormConv1d(nn.Module):
126
+ """Wrapper around Conv1d and normalization applied to this conv
127
+ to provide a uniform interface across normalization approaches.
128
+ """
129
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
130
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
131
+ super().__init__()
132
+ self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
133
+ self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
134
+ self.norm_type = norm
135
+
136
+ def forward(self, x):
137
+ x = self.conv(x)
138
+ x = self.norm(x)
139
+ return x
140
+
141
+
142
+ class NormConv2d(nn.Module):
143
+ """Wrapper around Conv2d and normalization applied to this conv
144
+ to provide a uniform interface across normalization approaches.
145
+ """
146
+ def __init__(self, *args, norm: str = 'none',
147
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
148
+ super().__init__()
149
+ self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
150
+ self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
151
+ self.norm_type = norm
152
+
153
+ def forward(self, x):
154
+ x = self.conv(x)
155
+ x = self.norm(x)
156
+ return x
157
+
158
+
159
+ class NormConvTranspose1d(nn.Module):
160
+ """Wrapper around ConvTranspose1d and normalization applied to this conv
161
+ to provide a uniform interface across normalization approaches.
162
+ """
163
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
164
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
165
+ super().__init__()
166
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
167
+ self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
168
+ self.norm_type = norm
169
+
170
+ def forward(self, x):
171
+ x = self.convtr(x)
172
+ x = self.norm(x)
173
+ return x
174
+
175
+
176
+ class NormConvTranspose2d(nn.Module):
177
+ """Wrapper around ConvTranspose2d and normalization applied to this conv
178
+ to provide a uniform interface across normalization approaches.
179
+ """
180
+ def __init__(self, *args, norm: str = 'none',
181
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
182
+ super().__init__()
183
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
184
+ self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
185
+
186
+ def forward(self, x):
187
+ x = self.convtr(x)
188
+ x = self.norm(x)
189
+ return x
190
+
191
+
192
+ class SConv1d(nn.Module):
193
+ """Conv1d with some builtin handling of asymmetric or causal padding
194
+ and normalization.
195
+ """
196
+ def __init__(self, in_channels: int, out_channels: int,
197
+ kernel_size: int, stride: int = 1, dilation: int = 1,
198
+ groups: int = 1, bias: bool = True, causal: bool = False,
199
+ norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
200
+ pad_mode: str = 'reflect', **kwargs):
201
+ super().__init__()
202
+ # warn user on unusual setup between dilation and stride
203
+ if stride > 1 and dilation > 1:
204
+ warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
205
+ f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
206
+ self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
207
+ dilation=dilation, groups=groups, bias=bias, causal=causal,
208
+ norm=norm, norm_kwargs=norm_kwargs)
209
+ self.causal = causal
210
+ self.pad_mode = pad_mode
211
+
212
+ self.cache_enabled = False
213
+
214
+ def reset_cache(self):
215
+ """Reset the cache when starting a new stream."""
216
+ self.cache = None
217
+ self.cache_enabled = True
218
+
219
+ def forward(self, x):
220
+ B, C, T = x.shape
221
+ kernel_size = self.conv.conv.kernel_size[0]
222
+ stride = self.conv.conv.stride[0]
223
+ dilation = self.conv.conv.dilation[0]
224
+ kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
225
+ padding_total = kernel_size - stride
226
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
227
+
228
+ if self.causal:
229
+ # Left padding for causal
230
+ if self.cache_enabled and self.cache is not None:
231
+ # Concatenate the cache (previous inputs) with the new input for streaming
232
+ x = torch.cat([self.cache, x], dim=2)
233
+ else:
234
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
235
+ else:
236
+ # Asymmetric padding required for odd strides
237
+ padding_right = padding_total // 2
238
+ padding_left = padding_total - padding_right
239
+ x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
240
+
241
+ # Store the most recent input frames for future cache use
242
+ if self.cache_enabled:
243
+ if self.cache is None:
244
+ # Initialize cache with zeros (at the start of streaming)
245
+ self.cache = torch.zeros(B, C, kernel_size - 1, device=x.device)
246
+ # Update the cache by storing the latest input frames
247
+ if kernel_size > 1:
248
+ self.cache = x[:, :, -kernel_size + 1:].detach() # Only store the necessary frames
249
+
250
+ return self.conv(x)
251
+
252
+
253
+
254
+ class SConvTranspose1d(nn.Module):
255
+ """ConvTranspose1d with some builtin handling of asymmetric or causal padding
256
+ and normalization.
257
+ """
258
+ def __init__(self, in_channels: int, out_channels: int,
259
+ kernel_size: int, stride: int = 1, causal: bool = False,
260
+ norm: str = 'none', trim_right_ratio: float = 1.,
261
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
262
+ super().__init__()
263
+ self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
264
+ causal=causal, norm=norm, norm_kwargs=norm_kwargs)
265
+ self.causal = causal
266
+ self.trim_right_ratio = trim_right_ratio
267
+ assert self.causal or self.trim_right_ratio == 1., \
268
+ "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
269
+ assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
270
+
271
+ def forward(self, x):
272
+ kernel_size = self.convtr.convtr.kernel_size[0]
273
+ stride = self.convtr.convtr.stride[0]
274
+ padding_total = kernel_size - stride
275
+
276
+ y = self.convtr(x)
277
+
278
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
279
+ # removed at the very end, when keeping only the right length for the output,
280
+ # as removing it here would require also passing the length at the matching layer
281
+ # in the encoder.
282
+ if self.causal:
283
+ # Trim the padding on the right according to the specified ratio
284
+ # if trim_right_ratio = 1.0, trim everything from right
285
+ padding_right = math.ceil(padding_total * self.trim_right_ratio)
286
+ padding_left = padding_total - padding_right
287
+ y = unpad1d(y, (padding_left, padding_right))
288
+ else:
289
+ # Asymmetric padding required for odd strides
290
+ padding_right = padding_total // 2
291
+ padding_left = padding_total - padding_right
292
+ y = unpad1d(y, (padding_left, padding_right))
293
+ return y
294
+
295
+ class SLSTM(nn.Module):
296
+ """
297
+ LSTM without worrying about the hidden state, nor the layout of the data.
298
+ Expects input as convolutional layout.
299
+ """
300
+ def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
301
+ super().__init__()
302
+ self.skip = skip
303
+ self.lstm = nn.LSTM(dimension, dimension, num_layers)
304
+ self.hidden = None
305
+ self.cache_enabled = False
306
+
307
+ def forward(self, x):
308
+ x = x.permute(2, 0, 1)
309
+ if self.training or not self.cache_enabled:
310
+ y, _ = self.lstm(x)
311
+ else:
312
+ y, self.hidden = self.lstm(x, self.hidden)
313
+ if self.skip:
314
+ y = y + x
315
+ y = y.permute(1, 2, 0)
316
+ return y
317
+
318
+ def reset_cache(self):
319
+ self.hidden = None
320
+ self.cache_enabled = True
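
A minimal usage sketch of the streaming-friendly causal convolution added above (assuming the class lives in dac/model/encodec.py, as the file list suggests; channel sizes and lengths are illustrative):

    import torch
    from dac.model.encodec import SConv1d

    conv = SConv1d(in_channels=1, out_channels=8, kernel_size=7, stride=2, causal=True)
    x = torch.randn(1, 1, 16000)   # (batch, channels, samples)
    y = conv(x)                    # causal left-padding, so no future samples are used
    print(y.shape)                 # roughly (1, 8, 16000 // 2)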
dac/nn/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from . import layers
2
+ from . import loss
3
+ from . import quantize
dac/nn/layers.py ADDED
@@ -0,0 +1,33 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from torch.nn.utils import weight_norm
7
+
8
+
9
+ def WNConv1d(*args, **kwargs):
10
+ return weight_norm(nn.Conv1d(*args, **kwargs))
11
+
12
+
13
+ def WNConvTranspose1d(*args, **kwargs):
14
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
15
+
16
+
17
+ # Scripting this brings model speed up 1.4x
18
+ @torch.jit.script
19
+ def snake(x, alpha):
20
+ shape = x.shape
21
+ x = x.reshape(shape[0], shape[1], -1)
22
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
23
+ x = x.reshape(shape)
24
+ return x
25
+
26
+
27
+ class Snake1d(nn.Module):
28
+ def __init__(self, channels):
29
+ super().__init__()
30
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
31
+
32
+ def forward(self, x):
33
+ return snake(x, self.alpha)
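
A quick sketch of the Snake activation defined above (shapes are illustrative):

    import torch
    from dac.nn.layers import Snake1d

    act = Snake1d(channels=64)     # learnable per-channel alpha, initialized to 1
    x = torch.randn(2, 64, 100)    # (batch, channels, time)
    y = act(x)                     # y = x + (1 / alpha) * sin(alpha * x) ** 2
    assert y.shape == x.shape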
dac/nn/loss.py ADDED
@@ -0,0 +1,368 @@
1
+ import typing
2
+ from typing import List
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from audiotools import AudioSignal
7
+ from audiotools import STFTParams
8
+ from torch import nn
9
+
10
+
11
+ class L1Loss(nn.L1Loss):
12
+ """L1 Loss between AudioSignals. Defaults
13
+ to comparing ``audio_data``, but any
14
+ attribute of an AudioSignal can be used.
15
+
16
+ Parameters
17
+ ----------
18
+ attribute : str, optional
19
+ Attribute of signal to compare, defaults to ``audio_data``.
20
+ weight : float, optional
21
+ Weight of this loss, defaults to 1.0.
22
+
23
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
24
+ """
25
+
26
+ def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
27
+ self.attribute = attribute
28
+ self.weight = weight
29
+ super().__init__(**kwargs)
30
+
31
+ def forward(self, x: AudioSignal, y: AudioSignal):
32
+ """
33
+ Parameters
34
+ ----------
35
+ x : AudioSignal
36
+ Estimate AudioSignal
37
+ y : AudioSignal
38
+ Reference AudioSignal
39
+
40
+ Returns
41
+ -------
42
+ torch.Tensor
43
+ L1 loss between AudioSignal attributes.
44
+ """
45
+ if isinstance(x, AudioSignal):
46
+ x = getattr(x, self.attribute)
47
+ y = getattr(y, self.attribute)
48
+ return super().forward(x, y)
49
+
50
+
51
+ class SISDRLoss(nn.Module):
52
+ """
53
+ Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
54
+ of estimated and reference audio signals or aligned features.
55
+
56
+ Parameters
57
+ ----------
58
+ scaling : int, optional
59
+ Whether to use scale-invariant (True) or
60
+ signal-to-noise ratio (False), by default True
61
+ reduction : str, optional
62
+ How to reduce across the batch (either 'mean',
63
+ 'sum', or 'none'), by default 'mean'
64
+ zero_mean : int, optional
65
+ Zero mean the references and estimates before
66
+ computing the loss, by default True
67
+ clip_min : int, optional
68
+ The minimum possible loss value. Helps network
69
+ to not focus on making already good examples better, by default None
70
+ weight : float, optional
71
+ Weight of this loss, defaults to 1.0.
72
+
73
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ scaling: int = True,
79
+ reduction: str = "mean",
80
+ zero_mean: int = True,
81
+ clip_min: int = None,
82
+ weight: float = 1.0,
83
+ ):
84
+ self.scaling = scaling
85
+ self.reduction = reduction
86
+ self.zero_mean = zero_mean
87
+ self.clip_min = clip_min
88
+ self.weight = weight
89
+ super().__init__()
90
+
91
+ def forward(self, x: AudioSignal, y: AudioSignal):
92
+ eps = 1e-8
93
+ # nb, nc, nt
94
+ if isinstance(x, AudioSignal):
95
+ references = x.audio_data
96
+ estimates = y.audio_data
97
+ else:
98
+ references = x
99
+ estimates = y
100
+
101
+ nb = references.shape[0]
102
+ references = references.reshape(nb, 1, -1).permute(0, 2, 1)
103
+ estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
104
+
105
+ # samples now on axis 1
106
+ if self.zero_mean:
107
+ mean_reference = references.mean(dim=1, keepdim=True)
108
+ mean_estimate = estimates.mean(dim=1, keepdim=True)
109
+ else:
110
+ mean_reference = 0
111
+ mean_estimate = 0
112
+
113
+ _references = references - mean_reference
114
+ _estimates = estimates - mean_estimate
115
+
116
+ references_projection = (_references**2).sum(dim=-2) + eps
117
+ references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
118
+
119
+ scale = (
120
+ (references_on_estimates / references_projection).unsqueeze(1)
121
+ if self.scaling
122
+ else 1
123
+ )
124
+
125
+ e_true = scale * _references
126
+ e_res = _estimates - e_true
127
+
128
+ signal = (e_true**2).sum(dim=1)
129
+ noise = (e_res**2).sum(dim=1)
130
+ sdr = -10 * torch.log10(signal / noise + eps)
131
+
132
+ if self.clip_min is not None:
133
+ sdr = torch.clamp(sdr, min=self.clip_min)
134
+
135
+ if self.reduction == "mean":
136
+ sdr = sdr.mean()
137
+ elif self.reduction == "sum":
138
+ sdr = sdr.sum()
139
+ return sdr
140
+
141
+
142
+ class MultiScaleSTFTLoss(nn.Module):
143
+ """Computes the multi-scale STFT loss from [1].
144
+
145
+ Parameters
146
+ ----------
147
+ window_lengths : List[int], optional
148
+ Length of each window of each STFT, by default [2048, 512]
149
+ loss_fn : typing.Callable, optional
150
+ How to compare each loss, by default nn.L1Loss()
151
+ clamp_eps : float, optional
152
+ Clamp on the log magnitude, below, by default 1e-5
153
+ mag_weight : float, optional
154
+ Weight of raw magnitude portion of loss, by default 1.0
155
+ log_weight : float, optional
156
+ Weight of log magnitude portion of loss, by default 1.0
157
+ pow : float, optional
158
+ Power to raise magnitude to before taking log, by default 2.0
159
+ weight : float, optional
160
+ Weight of this loss, by default 1.0
161
+ match_stride : bool, optional
162
+ Whether to match the stride of convolutional layers, by default False
163
+
164
+ References
165
+ ----------
166
+
167
+ 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
168
+ "DDSP: Differentiable Digital Signal Processing."
169
+ International Conference on Learning Representations. 2020.
170
+
171
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
172
+ """
173
+
174
+ def __init__(
175
+ self,
176
+ window_lengths: List[int] = [2048, 512],
177
+ loss_fn: typing.Callable = nn.L1Loss(),
178
+ clamp_eps: float = 1e-5,
179
+ mag_weight: float = 1.0,
180
+ log_weight: float = 1.0,
181
+ pow: float = 2.0,
182
+ weight: float = 1.0,
183
+ match_stride: bool = False,
184
+ window_type: str = None,
185
+ ):
186
+ super().__init__()
187
+ self.stft_params = [
188
+ STFTParams(
189
+ window_length=w,
190
+ hop_length=w // 4,
191
+ match_stride=match_stride,
192
+ window_type=window_type,
193
+ )
194
+ for w in window_lengths
195
+ ]
196
+ self.loss_fn = loss_fn
197
+ self.log_weight = log_weight
198
+ self.mag_weight = mag_weight
199
+ self.clamp_eps = clamp_eps
200
+ self.weight = weight
201
+ self.pow = pow
202
+
203
+ def forward(self, x: AudioSignal, y: AudioSignal):
204
+ """Computes the multi-scale STFT loss between an estimate and a reference
205
+ signal.
206
+
207
+ Parameters
208
+ ----------
209
+ x : AudioSignal
210
+ Estimate signal
211
+ y : AudioSignal
212
+ Reference signal
213
+
214
+ Returns
215
+ -------
216
+ torch.Tensor
217
+ Multi-scale STFT loss.
218
+ """
219
+ loss = 0.0
220
+ for s in self.stft_params:
221
+ x.stft(s.window_length, s.hop_length, s.window_type)
222
+ y.stft(s.window_length, s.hop_length, s.window_type)
223
+ loss += self.log_weight * self.loss_fn(
224
+ x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
225
+ y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
226
+ )
227
+ loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
228
+ return loss
229
+
230
+
231
+ class MelSpectrogramLoss(nn.Module):
232
+ """Compute distance between mel spectrograms. Can be used
233
+ in a multi-scale way.
234
+
235
+ Parameters
236
+ ----------
237
+ n_mels : List[int]
238
+ Number of mels per STFT, by default [150, 80],
239
+ window_lengths : List[int], optional
240
+ Length of each window of each STFT, by default [2048, 512]
241
+ loss_fn : typing.Callable, optional
242
+ How to compare each loss, by default nn.L1Loss()
243
+ clamp_eps : float, optional
244
+ Clamp on the log magnitude, below, by default 1e-5
245
+ mag_weight : float, optional
246
+ Weight of raw magnitude portion of loss, by default 1.0
247
+ log_weight : float, optional
248
+ Weight of log magnitude portion of loss, by default 1.0
249
+ pow : float, optional
250
+ Power to raise magnitude to before taking log, by default 2.0
251
+ weight : float, optional
252
+ Weight of this loss, by default 1.0
253
+ match_stride : bool, optional
254
+ Whether to match the stride of convolutional layers, by default False
255
+
256
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
257
+ """
258
+
259
+ def __init__(
260
+ self,
261
+ n_mels: List[int] = [150, 80],
262
+ window_lengths: List[int] = [2048, 512],
263
+ loss_fn: typing.Callable = nn.L1Loss(),
264
+ clamp_eps: float = 1e-5,
265
+ mag_weight: float = 1.0,
266
+ log_weight: float = 1.0,
267
+ pow: float = 2.0,
268
+ weight: float = 1.0,
269
+ match_stride: bool = False,
270
+ mel_fmin: List[float] = [0.0, 0.0],
271
+ mel_fmax: List[float] = [None, None],
272
+ window_type: str = None,
273
+ ):
274
+ super().__init__()
275
+ self.stft_params = [
276
+ STFTParams(
277
+ window_length=w,
278
+ hop_length=w // 4,
279
+ match_stride=match_stride,
280
+ window_type=window_type,
281
+ )
282
+ for w in window_lengths
283
+ ]
284
+ self.n_mels = n_mels
285
+ self.loss_fn = loss_fn
286
+ self.clamp_eps = clamp_eps
287
+ self.log_weight = log_weight
288
+ self.mag_weight = mag_weight
289
+ self.weight = weight
290
+ self.mel_fmin = mel_fmin
291
+ self.mel_fmax = mel_fmax
292
+ self.pow = pow
293
+
294
+ def forward(self, x: AudioSignal, y: AudioSignal):
295
+ """Computes mel loss between an estimate and a reference
296
+ signal.
297
+
298
+ Parameters
299
+ ----------
300
+ x : AudioSignal
301
+ Estimate signal
302
+ y : AudioSignal
303
+ Reference signal
304
+
305
+ Returns
306
+ -------
307
+ torch.Tensor
308
+ Mel loss.
309
+ """
310
+ loss = 0.0
311
+ for n_mels, fmin, fmax, s in zip(
312
+ self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
313
+ ):
314
+ kwargs = {
315
+ "window_length": s.window_length,
316
+ "hop_length": s.hop_length,
317
+ "window_type": s.window_type,
318
+ }
319
+ x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
320
+ y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
321
+
322
+ loss += self.log_weight * self.loss_fn(
323
+ x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
324
+ y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
325
+ )
326
+ loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
327
+ return loss
328
+
329
+
330
+ class GANLoss(nn.Module):
331
+ """
332
+ Computes a discriminator loss, given a discriminator on
333
+ generated waveforms/spectrograms compared to ground truth
334
+ waveforms/spectrograms. Computes the loss for both the
335
+ discriminator and the generator in separate functions.
336
+ """
337
+
338
+ def __init__(self, discriminator):
339
+ super().__init__()
340
+ self.discriminator = discriminator
341
+
342
+ def forward(self, fake, real):
343
+ d_fake = self.discriminator(fake.audio_data)
344
+ d_real = self.discriminator(real.audio_data)
345
+ return d_fake, d_real
346
+
347
+ def discriminator_loss(self, fake, real):
348
+ d_fake, d_real = self.forward(fake.clone().detach(), real)
349
+
350
+ loss_d = 0
351
+ for x_fake, x_real in zip(d_fake, d_real):
352
+ loss_d += torch.mean(x_fake[-1] ** 2)
353
+ loss_d += torch.mean((1 - x_real[-1]) ** 2)
354
+ return loss_d
355
+
356
+ def generator_loss(self, fake, real):
357
+ d_fake, d_real = self.forward(fake, real)
358
+
359
+ loss_g = 0
360
+ for x_fake in d_fake:
361
+ loss_g += torch.mean((1 - x_fake[-1]) ** 2)
362
+
363
+ loss_feature = 0
364
+
365
+ for i in range(len(d_fake)):
366
+ for j in range(len(d_fake[i]) - 1):
367
+ loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
368
+ return loss_g, loss_feature
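
A minimal sketch of how the spectral losses above are typically called (assumes descript-audiotools is installed; the random signals are placeholders for real estimate/reference audio):

    import torch
    from audiotools import AudioSignal
    from dac.nn.loss import MultiScaleSTFTLoss, MelSpectrogramLoss

    sr = 44100
    est = AudioSignal(torch.randn(1, 1, sr), sample_rate=sr)   # estimate
    ref = AudioSignal(torch.randn(1, 1, sr), sample_rate=sr)   # reference

    stft_loss = MultiScaleSTFTLoss()(est, ref)   # L1 on log and raw magnitudes at 2048/512 windows
    mel_loss = MelSpectrogramLoss()(est, ref)    # same comparison on 150- and 80-bin mel spectrograms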
dac/nn/quantize.py ADDED
@@ -0,0 +1,339 @@
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+ from dac.nn.layers import WNConv1d
11
+
12
+ class VectorQuantizeLegacy(nn.Module):
13
+ """
14
+ Implementation of VQ similar to Karpathy's repo:
15
+ https://github.com/karpathy/deep-vector-quantization
16
+ removed in-out projection
17
+ """
18
+
19
+ def __init__(self, input_dim: int, codebook_size: int):
20
+ super().__init__()
21
+ self.codebook_size = codebook_size
22
+ self.codebook = nn.Embedding(codebook_size, input_dim)
23
+
24
+ def forward(self, z, z_mask=None):
25
+ """Quantizes the input tensor using a fixed codebook and returns
26
+ the corresponding codebook vectors
27
+
28
+ Parameters
29
+ ----------
30
+ z : Tensor[B x D x T]
31
+
32
+ Returns
33
+ -------
34
+ Tensor[B x D x T]
35
+ Quantized continuous representation of input
36
+ Tensor[B x T]
37
+ Codebook indices (quantized discrete representation of input)
38
+ Tensor[B x D x T]
39
+ Projected latents (continuous representation of input before quantization)
40
+ Tensor[1]
41
+ Commitment loss to train encoder to predict vectors closer to codebook
42
+ entries
43
+ Tensor[1]
44
+ Codebook loss to update the codebook
45
+ """
46
+
47
+ z_e = z
48
+ z_q, indices = self.decode_latents(z)
49
+
50
+ if z_mask is not None:
51
+ commitment_loss = (F.mse_loss(z_e, z_q.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
52
+ codebook_loss = (F.mse_loss(z_q, z_e.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
53
+ else:
54
+ commitment_loss = F.mse_loss(z_e, z_q.detach())
55
+ codebook_loss = F.mse_loss(z_q, z_e.detach())
56
+ z_q = (
57
+ z_e + (z_q - z_e).detach()
58
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
59
+
60
+ return z_q, indices, z_e, commitment_loss, codebook_loss
61
+
62
+ def embed_code(self, embed_id):
63
+ return F.embedding(embed_id, self.codebook.weight)
64
+
65
+ def decode_code(self, embed_id):
66
+ return self.embed_code(embed_id).transpose(1, 2)
67
+
68
+ def decode_latents(self, latents):
69
+ encodings = rearrange(latents, "b d t -> (b t) d")
70
+ codebook = self.codebook.weight # codebook: (N x D)
71
+
72
+ # L2 normalize encodings and codebook (ViT-VQGAN)
73
+ encodings = F.normalize(encodings)
74
+ codebook = F.normalize(codebook)
75
+
76
+ # Compute euclidean distance with codebook
77
+ dist = (
78
+ encodings.pow(2).sum(1, keepdim=True)
79
+ - 2 * encodings @ codebook.t()
80
+ + codebook.pow(2).sum(1, keepdim=True).t()
81
+ )
82
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
83
+ z_q = self.decode_code(indices)
84
+ return z_q, indices
85
+
86
+ class VectorQuantize(nn.Module):
87
+ """
88
+ Implementation of VQ similar to Karpathy's repo:
89
+ https://github.com/karpathy/deep-vector-quantization
90
+ Additionally uses following tricks from Improved VQGAN
91
+ (https://arxiv.org/pdf/2110.04627.pdf):
92
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
93
+ for improved codebook usage
94
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
95
+ improves training stability
96
+ """
97
+
98
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
99
+ super().__init__()
100
+ self.codebook_size = codebook_size
101
+ self.codebook_dim = codebook_dim
102
+
103
+ self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
104
+ self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
105
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
106
+
107
+ def forward(self, z, z_mask=None):
108
+ """Quantizes the input tensor using a fixed codebook and returns
109
+ the corresponding codebook vectors
110
+
111
+ Parameters
112
+ ----------
113
+ z : Tensor[B x D x T]
114
+
115
+ Returns
116
+ -------
117
+ Tensor[B x D x T]
118
+ Quantized continuous representation of input
119
+ Tensor[1]
120
+ Commitment loss to train encoder to predict vectors closer to codebook
121
+ entries
122
+ Tensor[1]
123
+ Codebook loss to update the codebook
124
+ Tensor[B x T]
125
+ Codebook indices (quantized discrete representation of input)
126
+ Tensor[B x D x T]
127
+ Projected latents (continuous representation of input before quantization)
128
+ """
129
+
130
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
131
+ z_e = self.in_proj(z) # z_e : (B x D x T)
132
+ z_q, indices = self.decode_latents(z_e)
133
+
134
+ if z_mask is not None:
135
+ commitment_loss = (F.mse_loss(z_e, z_q.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
136
+ codebook_loss = (F.mse_loss(z_q, z_e.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
137
+ else:
138
+ commitment_loss = F.mse_loss(z_e, z_q.detach())
139
+ codebook_loss = F.mse_loss(z_q, z_e.detach())
140
+
141
+ z_q = (
142
+ z_e + (z_q - z_e).detach()
143
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
144
+
145
+ z_q = self.out_proj(z_q)
146
+
147
+ return z_q, commitment_loss, codebook_loss, indices, z_e
148
+
149
+ def embed_code(self, embed_id):
150
+ return F.embedding(embed_id, self.codebook.weight)
151
+
152
+ def decode_code(self, embed_id):
153
+ return self.embed_code(embed_id).transpose(1, 2)
154
+
155
+ def decode_latents(self, latents):
156
+ encodings = rearrange(latents, "b d t -> (b t) d")
157
+ codebook = self.codebook.weight # codebook: (N x D)
158
+
159
+ # L2 normalize encodings and codebook (ViT-VQGAN)
160
+ encodings = F.normalize(encodings)
161
+ codebook = F.normalize(codebook)
162
+
163
+ # Compute euclidean distance with codebook
164
+ dist = (
165
+ encodings.pow(2).sum(1, keepdim=True)
166
+ - 2 * encodings @ codebook.t()
167
+ + codebook.pow(2).sum(1, keepdim=True).t()
168
+ )
169
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
170
+ z_q = self.decode_code(indices)
171
+ return z_q, indices
172
+
173
+
174
+ class ResidualVectorQuantize(nn.Module):
175
+ """
176
+ Introduced in SoundStream: An end2end neural audio codec
177
+ https://arxiv.org/abs/2107.03312
178
+ """
179
+
180
+ def __init__(
181
+ self,
182
+ input_dim: int = 512,
183
+ n_codebooks: int = 9,
184
+ codebook_size: int = 1024,
185
+ codebook_dim: Union[int, list] = 8,
186
+ quantizer_dropout: float = 0.0,
187
+ ):
188
+ super().__init__()
189
+ if isinstance(codebook_dim, int):
190
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
191
+
192
+ self.n_codebooks = n_codebooks
193
+ self.codebook_dim = codebook_dim
194
+ self.codebook_size = codebook_size
195
+
196
+ self.quantizers = nn.ModuleList(
197
+ [
198
+ VectorQuantize(input_dim, codebook_size, codebook_dim[i])
199
+ for i in range(n_codebooks)
200
+ ]
201
+ )
202
+ self.quantizer_dropout = quantizer_dropout
203
+
204
+ def forward(self, z, n_quantizers: int = None):
205
+ """Quantizes the input tensor using a fixed set of `n` codebooks and returns
206
+ the corresponding codebook vectors
207
+ Parameters
208
+ ----------
209
+ z : Tensor[B x D x T]
210
+ n_quantizers : int, optional
211
+ No. of quantizers to use
212
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
213
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
214
+ when in training mode, and a random number of quantizers is used.
215
+ Returns
216
+ -------
217
+ dict
218
+ A dictionary with the following keys:
219
+
220
+ "z" : Tensor[B x D x T]
221
+ Quantized continuous representation of input
222
+ "codes" : Tensor[B x N x T]
223
+ Codebook indices for each codebook
224
+ (quantized discrete representation of input)
225
+ "latents" : Tensor[B x N*D x T]
226
+ Projected latents (continuous representation of input before quantization)
227
+ "vq/commitment_loss" : Tensor[1]
228
+ Commitment loss to train encoder to predict vectors closer to codebook
229
+ entries
230
+ "vq/codebook_loss" : Tensor[1]
231
+ Codebook loss to update the codebook
232
+ """
233
+ z_q = 0
234
+ residual = z
235
+ commitment_loss = 0
236
+ codebook_loss = 0
237
+
238
+ codebook_indices = []
239
+ latents = []
240
+
241
+ if n_quantizers is None:
242
+ n_quantizers = self.n_codebooks
243
+ if self.training:
244
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
245
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
246
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
247
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
248
+ n_quantizers = n_quantizers.to(z.device)
249
+
250
+ for i, quantizer in enumerate(self.quantizers):
251
+ if self.training is False and i >= n_quantizers:
252
+ break
253
+
254
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
255
+ residual
256
+ )
257
+
258
+ # Create mask to apply quantizer dropout
259
+ mask = (
260
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
261
+ )
262
+ z_q = z_q + z_q_i * mask[:, None, None]
263
+ residual = residual - z_q_i
264
+
265
+ # Sum losses
266
+ commitment_loss += (commitment_loss_i * mask).mean()
267
+ codebook_loss += (codebook_loss_i * mask).mean()
268
+
269
+ codebook_indices.append(indices_i)
270
+ latents.append(z_e_i)
271
+
272
+ codes = torch.stack(codebook_indices, dim=1)
273
+ latents = torch.cat(latents, dim=1)
274
+
275
+ return z_q, codes, latents, commitment_loss, codebook_loss
276
+
277
+ def from_codes(self, codes: torch.Tensor):
278
+ """Given the quantized codes, reconstruct the continuous representation
279
+ Parameters
280
+ ----------
281
+ codes : Tensor[B x N x T]
282
+ Quantized discrete representation of input
283
+ Returns
284
+ -------
285
+ Tensor[B x D x T]
286
+ Quantized continuous representation of input
287
+ """
288
+ z_q = 0.0
289
+ z_p = []
290
+ n_codebooks = codes.shape[1]
291
+ for i in range(n_codebooks):
292
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
293
+ z_p.append(z_p_i)
294
+
295
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
296
+ z_q = z_q + z_q_i
297
+ return z_q, torch.cat(z_p, dim=1), codes
298
+
299
+ def from_latents(self, latents: torch.Tensor):
300
+ """Given the unquantized latents, reconstruct the
301
+ continuous representation after quantization.
302
+
303
+ Parameters
304
+ ----------
305
+ latents : Tensor[B x N x T]
306
+ Continuous representation of input after projection
307
+
308
+ Returns
309
+ -------
310
+ Tensor[B x D x T]
311
+ Quantized representation of full-projected space
312
+ Tensor[B x D x T]
313
+ Quantized representation of latent space
314
+ """
315
+ z_q = 0
316
+ z_p = []
317
+ codes = []
318
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
319
+
320
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
321
+ 0
322
+ ]
323
+ for i in range(n_codebooks):
324
+ j, k = dims[i], dims[i + 1]
325
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
326
+ z_p.append(z_p_i)
327
+ codes.append(codes_i)
328
+
329
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
330
+ z_q = z_q + z_q_i
331
+
332
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
333
+
334
+
335
+ if __name__ == "__main__":
336
+ rvq = ResidualVectorQuantize(quantizer_dropout=True)
337
+ x = torch.randn(16, 512, 80)
338
+ y = rvq(x)
339
+ print(y[2].shape)  # latents
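
For reference, a small sketch of the residual VQ defined above (dimensions are illustrative):

    import torch
    from dac.nn.quantize import ResidualVectorQuantize

    rvq = ResidualVectorQuantize(input_dim=512, n_codebooks=9, codebook_size=1024, codebook_dim=8)
    x = torch.randn(4, 512, 80)                          # (batch, latent dim, frames)
    z_q, codes, latents, commit_loss, cb_loss = rvq(x)
    print(codes.shape)                                   # (4, 9, 80): one index per codebook per frame
    z_hat, z_p, _ = rvq.from_codes(codes)                # rebuild continuous latents from indices only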
dac/utils/__init__.py ADDED
@@ -0,0 +1,123 @@
1
+ from pathlib import Path
2
+
3
+ import argbind
4
+ from audiotools import ml
5
+
6
+ import dac
7
+
8
+ DAC = dac.model.DAC
9
+ Accelerator = ml.Accelerator
10
+
11
+ __MODEL_LATEST_TAGS__ = {
12
+ ("44khz", "8kbps"): "0.0.1",
13
+ ("24khz", "8kbps"): "0.0.4",
14
+ ("16khz", "8kbps"): "0.0.5",
15
+ ("44khz", "16kbps"): "1.0.0",
16
+ }
17
+
18
+ __MODEL_URLS__ = {
19
+ (
20
+ "44khz",
21
+ "0.0.1",
22
+ "8kbps",
23
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
24
+ (
25
+ "24khz",
26
+ "0.0.4",
27
+ "8kbps",
28
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
29
+ (
30
+ "16khz",
31
+ "0.0.5",
32
+ "8kbps",
33
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
34
+ (
35
+ "44khz",
36
+ "1.0.0",
37
+ "16kbps",
38
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
39
+ }
40
+
41
+
42
+ @argbind.bind(group="download", positional=True, without_prefix=True)
43
+ def download(
44
+ model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
45
+ ):
46
+ """
47
+ Function that downloads the weights file from URL if a local cache is not found.
48
+
49
+ Parameters
50
+ ----------
51
+ model_type : str
52
+ The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz".
53
+ model_bitrate: str
54
+ Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
55
+ Only 44khz model supports 16kbps.
56
+ tag : str
57
+ The tag of the model to download. Defaults to "latest".
58
+
59
+ Returns
60
+ -------
61
+ Path
62
+ Directory path required to load model via audiotools.
63
+ """
64
+ model_type = model_type.lower()
65
+ tag = tag.lower()
66
+
67
+ assert model_type in [
68
+ "44khz",
69
+ "24khz",
70
+ "16khz",
71
+ ], "model_type must be one of '44khz', '24khz', or '16khz'"
72
+
73
+ assert model_bitrate in [
74
+ "8kbps",
75
+ "16kbps",
76
+ ], "model_bitrate must be one of '8kbps', or '16kbps'"
77
+
78
+ if tag == "latest":
79
+ tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)]
80
+
81
+ download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)
82
+
83
+ if download_link is None:
84
+ raise ValueError(
85
+ f"Could not find model with tag {tag} and model type {model_type}"
86
+ )
87
+
88
+ local_path = (
89
+ Path.home()
90
+ / ".cache"
91
+ / "descript"
92
+ / "dac"
93
+ / f"weights_{model_type}_{model_bitrate}_{tag}.pth"
94
+ )
95
+ if not local_path.exists():
96
+ local_path.parent.mkdir(parents=True, exist_ok=True)
97
+
98
+ # Download the model
99
+ import requests
100
+
101
+ response = requests.get(download_link)
102
+
103
+ if response.status_code != 200:
104
+ raise ValueError(
105
+ f"Could not download model. Received response code {response.status_code}"
106
+ )
107
+ local_path.write_bytes(response.content)
108
+
109
+ return local_path
110
+
111
+
112
+ def load_model(
113
+ model_type: str = "44khz",
114
+ model_bitrate: str = "8kbps",
115
+ tag: str = "latest",
116
+ load_path: str = None,
117
+ ):
118
+ if not load_path:
119
+ load_path = download(
120
+ model_type=model_type, model_bitrate=model_bitrate, tag=tag
121
+ )
122
+ generator = DAC.load(load_path)
123
+ return generator
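
A minimal sketch of the loader above; on first use it downloads the weights and caches them under ~/.cache/descript/dac/:

    from dac.utils import load_model

    model = load_model(model_type="44khz", model_bitrate="8kbps", tag="latest")
    model.eval()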
dac/utils/decode.py ADDED
@@ -0,0 +1,95 @@
1
+ import warnings
2
+ from pathlib import Path
3
+
4
+ import argbind
5
+ import numpy as np
6
+ import torch
7
+ from audiotools import AudioSignal
8
+ from tqdm import tqdm
9
+
10
+ from dac import DACFile
11
+ from dac.utils import load_model
12
+
13
+ warnings.filterwarnings("ignore", category=UserWarning)
14
+
15
+
16
+ @argbind.bind(group="decode", positional=True, without_prefix=True)
17
+ @torch.inference_mode()
18
+ @torch.no_grad()
19
+ def decode(
20
+ input: str,
21
+ output: str = "",
22
+ weights_path: str = "",
23
+ model_tag: str = "latest",
24
+ model_bitrate: str = "8kbps",
25
+ device: str = "cuda",
26
+ model_type: str = "44khz",
27
+ verbose: bool = False,
28
+ ):
29
+ """Decode audio from codes.
30
+
31
+ Parameters
32
+ ----------
33
+ input : str
34
+ Path to input directory or file
35
+ output : str, optional
36
+ Path to output directory, by default "".
37
+ If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
38
+ weights_path : str, optional
39
+ Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
40
+ model_tag and model_type.
41
+ model_tag : str, optional
42
+ Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
43
+ model_bitrate: str
44
+ Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
45
+ device : str, optional
46
+ Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
47
+ model_type : str, optional
48
+ The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
49
+ """
50
+ generator = load_model(
51
+ model_type=model_type,
52
+ model_bitrate=model_bitrate,
53
+ tag=model_tag,
54
+ load_path=weights_path,
55
+ )
56
+ generator.to(device)
57
+ generator.eval()
58
+
59
+ # Find all .dac files in input directory
60
+ _input = Path(input)
61
+ input_files = list(_input.glob("**/*.dac"))
62
+
63
+ # If input is a .dac file, add it to the list
64
+ if _input.suffix == ".dac":
65
+ input_files.append(_input)
66
+
67
+ # Create output directory
68
+ output = Path(output)
69
+ output.mkdir(parents=True, exist_ok=True)
70
+
71
+ for i in tqdm(range(len(input_files)), desc=f"Decoding files"):
72
+ # Load file
73
+ artifact = DACFile.load(input_files[i])
74
+
75
+ # Reconstruct audio from codes
76
+ recons = generator.decompress(artifact, verbose=verbose)
77
+
78
+ # Compute output path
79
+ relative_path = input_files[i].relative_to(input)
80
+ output_dir = output / relative_path.parent
81
+ if not relative_path.name:
82
+ output_dir = output
83
+ relative_path = input_files[i]
84
+ output_name = relative_path.with_suffix(".wav").name
85
+ output_path = output_dir / output_name
86
+ output_path.parent.mkdir(parents=True, exist_ok=True)
87
+
88
+ # Write to file
89
+ recons.write(output_path)
90
+
91
+
92
+ if __name__ == "__main__":
93
+ args = argbind.parse_args()
94
+ with argbind.scope(args):
95
+ decode()
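
The script above is argbind-driven, but decode can also be called directly from Python; a sketch with placeholder directory names:

    from dac.utils.decode import decode

    # Decodes every .dac artifact found under "codes/" and writes .wav files under "reconstructed/".
    decode(input="codes/", output="reconstructed/", model_type="44khz", device="cuda")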
dac/utils/encode.py ADDED
@@ -0,0 +1,94 @@
1
+ import math
2
+ import warnings
3
+ from pathlib import Path
4
+
5
+ import argbind
6
+ import numpy as np
7
+ import torch
8
+ from audiotools import AudioSignal
9
+ from audiotools.core import util
10
+ from tqdm import tqdm
11
+
12
+ from dac.utils import load_model
13
+
14
+ warnings.filterwarnings("ignore", category=UserWarning)
15
+
16
+
17
+ @argbind.bind(group="encode", positional=True, without_prefix=True)
18
+ @torch.inference_mode()
19
+ @torch.no_grad()
20
+ def encode(
21
+ input: str,
22
+ output: str = "",
23
+ weights_path: str = "",
24
+ model_tag: str = "latest",
25
+ model_bitrate: str = "8kbps",
26
+ n_quantizers: int = None,
27
+ device: str = "cuda",
28
+ model_type: str = "44khz",
29
+ win_duration: float = 5.0,
30
+ verbose: bool = False,
31
+ ):
32
+ """Encode audio files in input path to .dac format.
33
+
34
+ Parameters
35
+ ----------
36
+ input : str
37
+ Path to input audio file or directory
38
+ output : str, optional
39
+ Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
40
+ weights_path : str, optional
41
+ Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
42
+ model_tag and model_type.
43
+ model_tag : str, optional
44
+ Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
45
+ model_bitrate: str
46
+ Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
47
+ n_quantizers : int, optional
48
+ Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
49
+ device : str, optional
50
+ Device to use, by default "cuda"
51
+ model_type : str, optional
52
+ The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
53
+ """
54
+ generator = load_model(
55
+ model_type=model_type,
56
+ model_bitrate=model_bitrate,
57
+ tag=model_tag,
58
+ load_path=weights_path,
59
+ )
60
+ generator.to(device)
61
+ generator.eval()
62
+ kwargs = {"n_quantizers": n_quantizers}
63
+
64
+ # Find all audio files in input path
65
+ input = Path(input)
66
+ audio_files = util.find_audio(input)
67
+
68
+ output = Path(output)
69
+ output.mkdir(parents=True, exist_ok=True)
70
+
71
+ for i in tqdm(range(len(audio_files)), desc="Encoding files"):
72
+ # Load file
73
+ signal = AudioSignal(audio_files[i])
74
+
75
+ # Encode audio to .dac format
76
+ artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs)
77
+
78
+ # Compute output path
79
+ relative_path = audio_files[i].relative_to(input)
80
+ output_dir = output / relative_path.parent
81
+ if not relative_path.name:
82
+ output_dir = output
83
+ relative_path = audio_files[i]
84
+ output_name = relative_path.with_suffix(".dac").name
85
+ output_path = output_dir / output_name
86
+ output_path.parent.mkdir(parents=True, exist_ok=True)
87
+
88
+ artifact.save(output_path)
89
+
90
+
91
+ if __name__ == "__main__":
92
+ args = argbind.parse_args()
93
+ with argbind.scope(args):
94
+ encode()
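
Likewise, encode can be invoked from Python; a sketch with placeholder paths:

    from dac.utils.encode import encode

    # Compresses every audio file found under "audio/" into .dac artifacts written to "codes/".
    encode(input="audio/", output="codes/", model_type="44khz", win_duration=5.0)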
examples/reference/azuma_0.wav ADDED
Binary file (629 kB). View file
 
examples/reference/dingzhen_0.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3db260824d11f56cdf2fccf2b84ad83c95a732ddfa2f8cb8a20b68ca06ea9ff8
3
+ size 1088420
examples/reference/kobe_0.wav ADDED
Binary file (643 kB). View file
 
examples/reference/s1p1.wav ADDED
Binary file (701 kB). View file
 
examples/reference/s1p2.wav ADDED
Binary file (526 kB). View file
 
examples/reference/s2p1.wav ADDED
Binary file (665 kB). View file
 
examples/reference/s2p2.wav ADDED
Binary file (564 kB). View file
 
examples/reference/s3p1.wav ADDED
Binary file (557 kB). View file
 
examples/reference/s3p2.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d28df338203ad8b3c7485474fac41d9ee2891cf27bc0c0239e3249e6c0efadb
3
+ size 1140390
examples/reference/s4p1.wav ADDED
Binary file (619 kB). View file
 
examples/reference/s4p2.wav ADDED
Binary file (651 kB). View file
 
examples/reference/teio_0.wav ADDED
Binary file (366 kB). View file
 
examples/reference/trump_0.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:716becc9daf00351dfe324398edea9e8378f9453408b27612d92b6721f80ddbc
3
+ size 1379484
examples/source/TECHNOPOLIS - 2085 [vocals]_[cut_14sec].wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:87087ca5260ce96659b01a647edb30bb08527ed7d0c074fb5ae1e8338cc733e5
3
+ size 2796016
examples/source/Wiz Khalifa,Charlie Puth - See You Again [vocals]_[cut_28sec].wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7ff68178df50fc5f4f497099af8efe0b4508e7ff13665ee72d780159e6d1875
3
+ size 5411452
examples/source/glados_0.wav ADDED
Binary file (640 kB). View file
 
examples/source/jay_0.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d30f1500acacb597c3b27d7a5937dd088b8029b27e9db8bf5982085f26f4457
3
+ size 1270124
examples/source/source_s1.wav ADDED
Binary file (599 kB). View file
 
examples/source/source_s2.wav ADDED
Binary file (675 kB). View file
 
examples/source/source_s3.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e784fc21f6eade4633db4050528110a9fc307686487fcec257ad3db35a65c8d
3
+ size 1168780
examples/source/source_s4.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee7ccf45748bf8918c17e9bf80f79af423cefd78d9c8eebb717a395305272b9f
3
+ size 1068360
examples/source/yae_0.wav ADDED
Binary file (528 kB). View file
 
hf_utils.py ADDED
@@ -0,0 +1,12 @@
1
+ import os
2
+ from huggingface_hub import hf_hub_download
3
+
4
+
5
+ def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename="config.yml"):
6
+ os.makedirs("./checkpoints", exist_ok=True)
7
+ model_path = hf_hub_download(repo_id=repo_id, filename=model_filename, cache_dir="./checkpoints")
8
+ if config_filename is None:
9
+ return model_path
10
+ config_path = hf_hub_download(repo_id=repo_id, filename=config_filename, cache_dir="./checkpoints")
11
+
12
+ return model_path, config_path
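
A short sketch of the helper above; the repo id is a placeholder for the actual Hugging Face repository hosting the checkpoint:

    from hf_utils import load_custom_model_from_hf

    model_path, config_path = load_custom_model_from_hf(
        "<user>/<repo>", model_filename="pytorch_model.bin", config_filename="config.yml"
    )
    # Both files are cached under ./checkpoints and their local paths are returned.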
modules/__pycache__/audio.cpython-310.pyc ADDED
Binary file (2.43 kB). View file
 
modules/__pycache__/commons.cpython-310.pyc ADDED
Binary file (12.6 kB). View file
 
modules/__pycache__/diffusion_transformer.cpython-310.pyc ADDED
Binary file (7.76 kB). View file
 
modules/__pycache__/encodec.cpython-310.pyc ADDED
Binary file (10.8 kB). View file
 
modules/__pycache__/flow_matching.cpython-310.pyc ADDED
Binary file (5.11 kB). View file
 
modules/__pycache__/length_regulator.cpython-310.pyc ADDED
Binary file (1.58 kB). View file
 
modules/__pycache__/wavenet.cpython-310.pyc ADDED
Binary file (5.15 kB). View file