Localsong committed on
Commit d0831da · verified · 1 Parent(s): 58dff48

Upload 15 files
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,43 @@
- ---
- license: apache-2.0
- ---
+ # LocalSong
+
+ LocalSong is an audio generation model focused on melodic instrumental music, using tag-based conditioning to generate audio.
+
+ ## Installation
+
+ ### Prerequisites
+
+ - Python 3.10 or higher
+ - CUDA-capable GPU recommended
+
+ ### Setup
+
+ git clone https://huggingface.co/Localsong/LocalSong
+ cd LocalSong
+ python3 -m venv venv
+ source venv/bin/activate
+ pip install -r requirements.txt
+
+ ### Run
+
+ python gradio_app.py
+
+ The interface will be available at `http://localhost:7860`.
+
+ ### Generation Advice
+
+ Generations should use one of the soundtrack, soundtrack1, or soundtrack2 tags, as well as at least one other tag. Up to 8 tags can be combined; try mixing genres and instruments.
+ The default settings (CFG 3.5, 200 steps) have been tested as optimal.
+ The first generation will be slower due to torch.compile; subsequent generations will be faster.
+ The model was trained on vocals but not lyrics, so vocals will not contain recognizable words.
+
+ ## Credits
+
+ This project builds upon the following open-source projects:
+
+ - **Model Architecture**: Adapted from [DDT](https://github.com/MCG-NJU/DDT)
+ - **Flow Matching**: Adapted from [minRF](https://github.com/cloneofsimo/minRF)
+ - **Audio VAE**: [ACE-Step](https://github.com/ACE-Step/ACE-Step)
+
+ ## License
+
+ This project is licensed under the Apache License 2.0.
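The tag rules in the Generation Advice above (one of the three base soundtrack tags, at least one other tag, at most 8 total) can be expressed as a small validator. This is a hypothetical helper for illustration, not part of the repository:

```python
# Hypothetical sketch of the tag rules described in "Generation Advice":
# one base soundtrack tag required, plus at least one other tag, max 8 total.
BASE_TAGS = {"soundtrack", "soundtrack1", "soundtrack2"}

def valid_tags(tags):
    tag_set = set(tags)
    return (
        len(tags) <= 8                      # up to 8 tags
        and bool(BASE_TAGS & tag_set)       # at least one base soundtrack tag
        and bool(tag_set - BASE_TAGS)       # at least one other tag
    )

print(valid_tags(["soundtrack", "piano", "ambient"]))  # True
print(valid_tags(["piano"]))                           # False: no base tag
```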
acestep/checkpoints/music_dcae_f8c8/config.json ADDED
@@ -0,0 +1,69 @@
+ {
+   "_class_name": "AutoencoderDC",
+   "_diffusers_version": "0.32.2",
+   "_name_or_path": "checkpoints/music_dcae_f8c8",
+   "attention_head_dim": 32,
+   "decoder_act_fns": "silu",
+   "decoder_block_out_channels": [
+     128,
+     256,
+     512,
+     1024
+   ],
+   "decoder_block_types": [
+     "ResBlock",
+     "ResBlock",
+     "ResBlock",
+     "EfficientViTBlock"
+   ],
+   "decoder_layers_per_block": [
+     3,
+     3,
+     3,
+     3
+   ],
+   "decoder_norm_types": "rms_norm",
+   "decoder_qkv_multiscales": [
+     [],
+     [],
+     [
+       5
+     ],
+     [
+       5
+     ]
+   ],
+   "downsample_block_type": "Conv",
+   "encoder_block_out_channels": [
+     128,
+     256,
+     512,
+     1024
+   ],
+   "encoder_block_types": [
+     "ResBlock",
+     "ResBlock",
+     "ResBlock",
+     "EfficientViTBlock"
+   ],
+   "encoder_layers_per_block": [
+     2,
+     2,
+     3,
+     3
+   ],
+   "encoder_qkv_multiscales": [
+     [],
+     [],
+     [
+       5
+     ],
+     [
+       5
+     ]
+   ],
+   "in_channels": 2,
+   "latent_channels": 8,
+   "scaling_factor": 0.41407,
+   "upsample_block_type": "interpolate"
+ }
acestep/checkpoints/music_dcae_f8c8/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2b0cb469307ac50659d1880db2a99bae47d0df335cbb36853964662d4b80e8ee
+ size 313646516
acestep/checkpoints/music_vocoder/config.json ADDED
@@ -0,0 +1,79 @@
+ {
+   "_class_name": "ADaMoSHiFiGANV1",
+   "_diffusers_version": "0.32.2",
+   "depths": [
+     3,
+     3,
+     9,
+     3
+   ],
+   "dims": [
+     128,
+     256,
+     384,
+     512
+   ],
+   "drop_path_rate": 0.0,
+   "f_max": 16000,
+   "f_min": 40,
+   "hop_length": 512,
+   "input_channels": 128,
+   "kernel_sizes": [
+     7
+   ],
+   "n_fft": 2048,
+   "n_mels": 128,
+   "num_mels": 512,
+   "post_conv_kernel_size": 13,
+   "pre_conv_kernel_size": 13,
+   "resblock_dilation_sizes": [
+     [
+       1,
+       3,
+       5
+     ],
+     [
+       1,
+       3,
+       5
+     ],
+     [
+       1,
+       3,
+       5
+     ],
+     [
+       1,
+       3,
+       5
+     ]
+   ],
+   "resblock_kernel_sizes": [
+     3,
+     7,
+     11,
+     13
+   ],
+   "sampling_rate": 44100,
+   "upsample_initial_channel": 1024,
+   "upsample_kernel_sizes": [
+     8,
+     8,
+     4,
+     4,
+     4,
+     4,
+     4
+   ],
+   "upsample_rates": [
+     4,
+     4,
+     2,
+     2,
+     2,
+     2,
+     2
+   ],
+   "use_template": false,
+   "win_length": 2048
+ }
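A quick sanity check on the vocoder config above: the product of `upsample_rates` must equal `hop_length`, since each mel frame is upsampled back into exactly one hop of audio samples:

```python
# Check that the vocoder's upsample stages reconstruct one hop of audio
# per mel frame (values taken from the config above).
from math import prod

upsample_rates = [4, 4, 2, 2, 2, 2, 2]
hop_length = 512

samples_per_mel_frame = prod(upsample_rates)
print(samples_per_mel_frame)                  # 512
print(samples_per_mel_frame == hop_length)    # True
```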
acestep/checkpoints/music_vocoder/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c92c9b46e28ab7b37b777780cf4308ad7ddac869636bb77aa61599358c4bc1c0
+ size 206350988
acestep/music_dcae/__init__.py ADDED
File without changes
acestep/music_dcae/music_dcae_pipeline.py ADDED
@@ -0,0 +1,379 @@
+ """
+ ACE-Step: A Step Towards Music Generation Foundation Model
+
+ https://github.com/ace-step/ACE-Step
+
+ Apache 2.0 License
+ """
+
+ import os
+ import torch
+ from diffusers import AutoencoderDC
+ import torchaudio
+ import torchvision.transforms as transforms
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.loaders import FromOriginalModelMixin
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from tqdm import tqdm
+
+ from acestep.music_dcae.music_vocoder import ADaMoSHiFiGANV1
+
+
+ root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ DEFAULT_PRETRAINED_PATH = os.path.join(root_dir, "checkpoints", "music_dcae_f8c8")
+ VOCODER_PRETRAINED_PATH = os.path.join(root_dir, "checkpoints", "music_vocoder")
+
+
+ class MusicDCAE(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+     @register_to_config
+     def __init__(
+         self,
+         source_sample_rate=None,
+         dcae_checkpoint_path=DEFAULT_PRETRAINED_PATH,
+         vocoder_checkpoint_path=VOCODER_PRETRAINED_PATH,
+     ):
+         super(MusicDCAE, self).__init__()
+
+         self.dcae = AutoencoderDC.from_pretrained(dcae_checkpoint_path)
+         self.vocoder = ADaMoSHiFiGANV1.from_pretrained(vocoder_checkpoint_path)
+
+         if source_sample_rate is None:
+             source_sample_rate = 48000
+
+         self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
+
+         self.transform = transforms.Compose(
+             [
+                 transforms.Normalize(0.5, 0.5),
+             ]
+         )
+         self.min_mel_value = -11.0
+         self.max_mel_value = 3.0
+         self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
+         self.mel_chunk_size = 1024
+         self.time_dimention_multiple = 8
+         self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
+         self.scale_factor = 0.1786
+         self.shift_factor = -1.9091
+
+     def load_audio(self, audio_path):
+         audio, sr = torchaudio.load(audio_path)
+         if audio.shape[0] == 1:
+             audio = audio.repeat(2, 1)
+         return audio, sr
+
+     def forward_mel(self, audios):
+         mels = []
+         for i in range(len(audios)):
+             image = self.vocoder.mel_transform(audios[i])
+             mels.append(image)
+         mels = torch.stack(mels)
+         return mels
+
+     @torch.no_grad()
+     def encode(self, audios, audio_lengths=None, sr=None):
+         if audio_lengths is None:
+             audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
+             audio_lengths = audio_lengths.to(audios.device)
+
+         # audios: N x 2 x T, 48kHz
+         device = audios.device
+         dtype = audios.dtype
+
+         if sr is None:
+             sr = 48000
+             resampler = self.resampler
+         else:
+             resampler = torchaudio.transforms.Resample(sr, 44100).to(device).to(dtype)
+
+         audio = resampler(audios)
+
+         max_audio_len = audio.shape[-1]
+         if max_audio_len % (8 * 512) != 0:
+             audio = torch.nn.functional.pad(
+                 audio, (0, 8 * 512 - max_audio_len % (8 * 512))
+             )
+
+         mels = self.forward_mel(audio)
+         mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
+         mels = self.transform(mels)
+         latents = []
+         for mel in mels:
+             latent = self.dcae.encoder(mel.unsqueeze(0))
+             latents.append(latent)
+         latents = torch.cat(latents, dim=0)
+         latent_lengths = (
+             audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple
+         ).long()
+         latents = (latents - self.shift_factor) * self.scale_factor
+         return latents, latent_lengths
+
+     @torch.no_grad()
+     def decode(self, latents, audio_lengths=None, sr=None):
+         latents = latents / self.scale_factor + self.shift_factor
+
+         pred_wavs = []
+
+         for latent in latents:
+             mels = self.dcae.decoder(latent.unsqueeze(0))
+             mels = mels * 0.5 + 0.5
+             mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
+
+             # wav = self.vocoder.decode(mels[0]).squeeze(1)
+             # Decode the waveform for each channel separately to reduce VRAM footprint
+             wav_ch1 = self.vocoder.decode(mels[:, 0, :, :]).squeeze(1).cpu()
+             wav_ch2 = self.vocoder.decode(mels[:, 1, :, :]).squeeze(1).cpu()
+             wav = torch.cat([wav_ch1, wav_ch2], dim=0)
+
+             if sr is not None:
+                 resampler = torchaudio.transforms.Resample(44100, sr)
+                 wav = resampler(wav.cpu().float())
+             else:
+                 sr = 44100
+             pred_wavs.append(wav)
+
+         if audio_lengths is not None:
+             pred_wavs = [
+                 wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)
+             ]
+         return sr, pred_wavs
+
+     @torch.no_grad()
+     def decode_overlap(self, latents, audio_lengths=None, sr=None):
+         """
+         Decodes latents into waveforms using an overlapped DCAE and Vocoder.
+         """
+         print("Using Overlapped DCAE and Vocoder")
+
+         MODEL_INTERNAL_SR = 44100
+         DCAE_LATENT_TO_MEL_STRIDE = 8
+         VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME = 512
+
+         pred_wavs = []
+         final_output_sr = sr if sr is not None else MODEL_INTERNAL_SR
+
+         # --- DCAE Parameters ---
+         # dcae_win_len_latent: Window length in the latent domain for DCAE processing
+         dcae_win_len_latent = 512
+         # dcae_mel_win_len: Expected mel window length from DCAE decoder output (latent_win * stride)
+         dcae_mel_win_len = dcae_win_len_latent * 8
+         # dcae_anchor_offset: Offset from anchor point to actual start of latent window slice
+         dcae_anchor_offset = dcae_win_len_latent // 4
+         # dcae_anchor_hop: Hop size for anchor points in latent domain
+         dcae_anchor_hop = dcae_win_len_latent // 2
+         # dcae_mel_overlap_len: Overlap length in the mel domain to be trimmed/blended
+         dcae_mel_overlap_len = dcae_mel_win_len // 4
+
+         # --- Vocoder Parameters ---
+         # vocoder_win_len_audio: Audio samples per vocoder processing window
+         vocoder_win_len_audio = 512 * 512  # 262144 samples
+         # vocoder_overlap_len_audio: Audio samples for overlap between vocoder windows
+         vocoder_overlap_len_audio = 1024
+         # vocoder_hop_len_audio: Hop size in audio samples for vocoder processing
+         vocoder_hop_len_audio = vocoder_win_len_audio - 2 * vocoder_overlap_len_audio
+         # vocoder_input_mel_frames_per_block: Number of mel frames fed to the vocoder in one go
+         vocoder_input_mel_frames_per_block = vocoder_win_len_audio // VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
+
+         crossfade_len_audio = 128  # Audio samples for crossfading vocoder outputs
+         cf_win_tail = torch.linspace(1, 0, crossfade_len_audio, device=self.device).unsqueeze(0).unsqueeze(0)
+         cf_win_head = torch.linspace(0, 1, crossfade_len_audio, device=self.device).unsqueeze(0).unsqueeze(0)
+
+         for latent_idx, latent_item in enumerate(latents):
+             latent_item = latent_item.to(self.device)
+             current_latent = (latent_item / self.scale_factor + self.shift_factor).unsqueeze(0)  # (1, C, H, W_latent)
+             latent_len = current_latent.shape[3]
+
+             # 1. DCAE: Latent to Mel Spectrogram (Overlapped)
+             mels_segments = []
+             if latent_len == 0:
+                 pass  # No mel segments to generate
+             else:
+                 # Determine anchor points for DCAE windows.
+                 # An anchor marks a reference point for a window slice:
+                 #   current_latent[..., anchor - offset : anchor - offset + win_len]
+                 # The first anchor ensures the window starts at 0; the last ensures the tail is covered.
+                 dcae_anchors = list(range(dcae_anchor_offset, latent_len - dcae_anchor_offset, dcae_anchor_hop))
+                 if not dcae_anchors:  # If the latent is too short for the range, use one anchor
+                     dcae_anchors = [dcae_anchor_offset]
+
+                 for i, anchor in enumerate(dcae_anchors):
+                     win_start_idx = max(0, anchor - dcae_anchor_offset)
+                     win_end_idx = min(latent_len, win_start_idx + dcae_win_len_latent)
+
+                     dcae_input_segment = current_latent[:, :, :, win_start_idx:win_end_idx]
+                     if dcae_input_segment.shape[3] == 0:
+                         continue
+
+                     mel_output_full = self.dcae.decoder(dcae_input_segment)  # (1, C, H_mel, W_mel_fixed_from_dcae)
+
+                     is_first = (i == 0)
+                     is_last = (i == len(dcae_anchors) - 1)
+
+                     if is_first and is_last:  # Only one segment
+                         # Use the mel region corresponding to the actual input latent length
+                         true_mel_content_len = dcae_input_segment.shape[3] * DCAE_LATENT_TO_MEL_STRIDE
+                         mel_to_keep = mel_output_full[:, :, :, :min(true_mel_content_len, mel_output_full.shape[3])]
+                     elif is_first:  # First segment, trim end overlap
+                         mel_to_keep = mel_output_full[:, :, :, :-dcae_mel_overlap_len]
+                     elif is_last:  # Last segment, trim start overlap
+                         # mel_output_full has fixed length; the useful part starts after the overlap.
+                         # Its length depends on how much of dcae_input_segment was actual content;
+                         # if the segment was shorter than dcae_win_len_latent, the output may include
+                         # padding effects. Standard overlap-add keeps the corresponding tail, so for
+                         # simplicity we trim the fixed overlap.
+                         mel_to_keep = mel_output_full[:, :, :, dcae_mel_overlap_len:]
+                     else:  # Middle segment, trim both overlaps
+                         mel_to_keep = mel_output_full[:, :, :, dcae_mel_overlap_len:-dcae_mel_overlap_len]
+
+                     if mel_to_keep.shape[3] > 0:
+                         mels_segments.append(mel_to_keep)
+
+             if not mels_segments:
+                 num_mel_channels = current_latent.shape[1]
+                 mel_height = self.dcae.decoder_output_mel_height
+                 concatenated_mels = torch.empty(
+                     (1, num_mel_channels, mel_height, 0),
+                     device=current_latent.device, dtype=current_latent.dtype
+                 )
+             else:
+                 concatenated_mels = torch.cat(mels_segments, dim=3)
+
+             # Denormalize mels
+             concatenated_mels = concatenated_mels * 0.5 + 0.5
+             concatenated_mels = concatenated_mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
+
+             mel_total_frames = concatenated_mels.shape[3]
+
+             # 2. Vocoder: Mel Spectrogram to Waveform (Overlapped)
+             if mel_total_frames == 0:
+                 # Assume mono output when there is nothing to decode
+                 num_audio_channels = 1  # Or determine from vocoder capabilities / mel channels
+                 final_wav = torch.zeros((num_audio_channels, 0), device=self.device, dtype=torch.float32)
+             else:
+                 # Initial vocoder window; the vocoder expects (C_mel, H_mel, W_mel_block)
+                 mel_block = concatenated_mels[0, :, :, :vocoder_input_mel_frames_per_block].to(self.device)
+
+                 # Pad mel_block if it is shorter than vocoder_input_mel_frames_per_block (e.g. very short audio)
+                 if 0 < mel_block.shape[2] < vocoder_input_mel_frames_per_block:
+                     pad_len = vocoder_input_mel_frames_per_block - mel_block.shape[2]
+                     mel_block = torch.nn.functional.pad(mel_block, (0, pad_len), mode='constant', value=0)  # Pad last dim
+
+                 current_audio_output = self.vocoder.decode(mel_block)  # (C_audio, 1, Samples)
+                 current_audio_output = current_audio_output[:, :, :-vocoder_overlap_len_audio]  # Remove end overlap
+
+                 # p_audio_samples tracks the start of the *next* audio segment to generate
+                 # (in conceptual total audio samples)
+                 p_audio_samples = vocoder_hop_len_audio
+                 conceptual_total_audio_len_native_sr = mel_total_frames * VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
+
+                 pbar_total = 1 + max(0, (conceptual_total_audio_len_native_sr - (vocoder_win_len_audio - vocoder_overlap_len_audio))) // vocoder_hop_len_audio
+
+                 # Use tqdm if you want a progress bar for the vocoder part:
+                 # with tqdm(total=pbar_total, desc=f"Vocoder {latent_idx+1}/{len(latents)}", leave=False) as pbar:
+                 #     pbar.update(1)  # For the initial window
+                 # The loop for subsequent windows
+                 while p_audio_samples < conceptual_total_audio_len_native_sr:
+                     mel_frame_start = p_audio_samples // VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
+                     mel_frame_end = mel_frame_start + vocoder_input_mel_frames_per_block
+
+                     if mel_frame_start >= mel_total_frames:
+                         break  # No more mel frames
+
+                     mel_block = concatenated_mels[0, :, :, mel_frame_start:min(mel_frame_end, mel_total_frames)].to(self.device)
+
+                     if mel_block.shape[2] == 0:
+                         break  # Should not happen if mel_frame_start is valid
+
+                     # Pad if the current mel_block is too short (end of sequence)
+                     if mel_block.shape[2] < vocoder_input_mel_frames_per_block:
+                         pad_len = vocoder_input_mel_frames_per_block - mel_block.shape[2]
+                         mel_block = torch.nn.functional.pad(mel_block, (0, pad_len), mode='constant', value=0)
+
+                     new_audio_win = self.vocoder.decode(mel_block)  # (C_audio, 1, Samples)
+
+                     # Crossfade: determine the actual crossfade length based on available audio
+                     actual_cf_len = min(crossfade_len_audio, current_audio_output.shape[2], new_audio_win.shape[2] - (vocoder_overlap_len_audio - crossfade_len_audio))
+                     if actual_cf_len > 0:  # Ensure valid slice lengths for crossfade
+                         tail_part = current_audio_output[:, :, -actual_cf_len:]
+                         head_part = new_audio_win[:, :, vocoder_overlap_len_audio - actual_cf_len : vocoder_overlap_len_audio]
+
+                         crossfaded_segment = tail_part * cf_win_tail[:, :, :actual_cf_len] + \
+                             head_part * cf_win_head[:, :, :actual_cf_len]
+
+                         current_audio_output = torch.cat([current_audio_output[:, :, :-actual_cf_len], crossfaded_segment], dim=2)
+
+                     # Append the non-overlapping part of new_audio_win
+                     is_final_append = (p_audio_samples + vocoder_hop_len_audio >= conceptual_total_audio_len_native_sr)
+                     if is_final_append:
+                         segment_to_append = new_audio_win[:, :, vocoder_overlap_len_audio:]
+                     else:
+                         segment_to_append = new_audio_win[:, :, vocoder_overlap_len_audio:-vocoder_overlap_len_audio]
+
+                     current_audio_output = torch.cat([current_audio_output, segment_to_append], dim=2)
+
+                     p_audio_samples += vocoder_hop_len_audio
+                     # pbar.update(1)  # if using tqdm
+
+                 final_wav = current_audio_output.squeeze(1)  # (C_audio, Samples)
+
+             # 3. Resampling (if necessary)
+             if final_output_sr != MODEL_INTERNAL_SR and final_wav.numel() > 0:
+                 # Resample expects a CPU tensor with some torchaudio versions/backends
+                 resampler = torchaudio.transforms.Resample(
+                     MODEL_INTERNAL_SR, final_output_sr, dtype=final_wav.dtype
+                 )
+                 final_wav = resampler(final_wav.cpu()).to(self.device)  # Move back to device if needed later
+
+             pred_wavs.append(final_wav)
+
+         # 4. Final Truncation
+         processed_pred_wavs = []
+         for i, wav in enumerate(pred_wavs):
+             # Calculate the expected length based on the original latent, at the FINAL output sample rate
+             _num_latent_frames = latents[i].shape[-1]  # Use original latent item for shape
+             _num_mel_frames = _num_latent_frames * DCAE_LATENT_TO_MEL_STRIDE
+             _conceptual_native_audio_len = _num_mel_frames * VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
+             max_possible_len = int(_conceptual_native_audio_len * final_output_sr / MODEL_INTERNAL_SR)
+
+             current_wav_len = wav.shape[1]
+
+             if audio_lengths is not None:
+                 # User-provided length is the primary target, capped by actual and max possible
+                 target_len = min(audio_lengths[i], current_wav_len, max_possible_len)
+             else:
+                 # No user length; use max possible, capped by actual
+                 target_len = min(max_possible_len, current_wav_len)
+
+             processed_pred_wavs.append(wav[:, :max(0, target_len)].cpu())  # Ensure length is non-negative
+
+         return final_output_sr, processed_pred_wavs
+
+     def forward(self, audios, audio_lengths=None, sr=None):
+         latents, latent_lengths = self.encode(
+             audios=audios, audio_lengths=audio_lengths, sr=sr
+         )
+         sr, pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
+         return sr, pred_wavs, latents, latent_lengths
+
+
+ if __name__ == "__main__":
+
+     audio, sr = torchaudio.load("test.wav")
+     audio_lengths = torch.tensor([audio.shape[1]])
+     audios = audio.unsqueeze(0)
+
+     # test encode only
+     model = MusicDCAE()
+     # latents, latent_lengths = model.encode(audios, audio_lengths)
+     # print("latents shape: ", latents.shape)
+     # print("latent_lengths: ", latent_lengths)
371
+
372
+ # test encode and decode
373
+ sr, pred_wavs, latents, latent_lengths = model(audios, audio_lengths, sr)
374
+ print("reconstructed wavs: ", pred_wavs[0].shape)
375
+ print("latents shape: ", latents.shape)
376
+ print("latent_lengths: ", latent_lengths)
377
+ print("sr: ", sr)
378
+ torchaudio.save("test_reconstructed.wav", pred_wavs[0], sr)
379
+ print("test_reconstructed.wav")
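The decode loop above stitches vocoder windows with an overlap-and-crossfade: the tail of the accumulated audio is blended with the head of the new window using complementary fade windows (`cf_win_tail` / `cf_win_head`, defined earlier in `MusicDCAE`). A minimal pure-Python sketch of that blend, assuming simple linear ramps (the actual window shape may differ):

```python
# Sketch of the crossfade used when stitching adjacent vocoder windows.
# Assumes complementary linear fade-out/fade-in ramps; the real
# cf_win_tail/cf_win_head windows in MusicDCAE may use a different shape.

def crossfade(tail, head):
    """Blend the tail of the previous window with the head of the next one."""
    n = len(tail)
    assert len(head) == n
    out = []
    for i in range(n):
        w_head = i / n          # fade-in ramp: 0 -> 1
        w_tail = 1.0 - w_head   # fade-out ramp: 1 -> 0
        out.append(tail[i] * w_tail + head[i] * w_head)
    return out

# A constant signal crossfaded with itself stays constant, because the two
# ramps sum to 1 at every sample -- which is why the stitch is seam-free.
blended = crossfade([1.0] * 4, [1.0] * 4)
```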
acestep/music_dcae/music_log_mel.py ADDED
@@ -0,0 +1,115 @@
1
+ """
2
+ ACE-Step: A Step Towards Music Generation Foundation Model
3
+
4
+ https://github.com/ace-step/ACE-Step
5
+
6
+ Apache 2.0 License
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch import Tensor
12
+ from torchaudio.transforms import MelScale
13
+
14
+
15
+ class LinearSpectrogram(nn.Module):
16
+ def __init__(
17
+ self,
18
+ n_fft=2048,
19
+ win_length=2048,
20
+ hop_length=512,
21
+ center=False,
22
+ mode="pow2_sqrt",
23
+ ):
24
+ super().__init__()
25
+
26
+ self.n_fft = n_fft
27
+ self.win_length = win_length
28
+ self.hop_length = hop_length
29
+ self.center = center
30
+ self.mode = mode
31
+
32
+ self.register_buffer("window", torch.hann_window(win_length))
33
+
34
+ def forward(self, y: Tensor) -> Tensor:
35
+ if y.ndim == 3:
36
+ y = y.squeeze(1)
37
+
38
+ y = torch.nn.functional.pad(
39
+ y.unsqueeze(1),
40
+ (
41
+ (self.win_length - self.hop_length) // 2,
42
+ (self.win_length - self.hop_length + 1) // 2,
43
+ ),
44
+ mode="reflect",
45
+ ).squeeze(1)
46
+ dtype = y.dtype
47
+ spec = torch.stft(
48
+ y.float(),
49
+ self.n_fft,
50
+ hop_length=self.hop_length,
51
+ win_length=self.win_length,
52
+ window=self.window,
53
+ center=self.center,
54
+ pad_mode="reflect",
55
+ normalized=False,
56
+ onesided=True,
57
+ return_complex=True,
58
+ )
59
+ spec = torch.view_as_real(spec)
60
+
61
+ if self.mode == "pow2_sqrt":
62
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
63
+ spec = spec.to(dtype)
64
+ return spec
65
+
66
+
67
+ class LogMelSpectrogram(nn.Module):
68
+ def __init__(
69
+ self,
70
+ sample_rate=44100,
71
+ n_fft=2048,
72
+ win_length=2048,
73
+ hop_length=512,
74
+ n_mels=128,
75
+ center=False,
76
+ f_min=0.0,
77
+ f_max=None,
78
+ ):
79
+ super().__init__()
80
+
81
+ self.sample_rate = sample_rate
82
+ self.n_fft = n_fft
83
+ self.win_length = win_length
84
+ self.hop_length = hop_length
85
+ self.center = center
86
+ self.n_mels = n_mels
87
+ self.f_min = f_min
88
+ self.f_max = f_max or sample_rate // 2
89
+
90
+ self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
91
+ self.mel_scale = MelScale(
92
+ self.n_mels,
93
+ self.sample_rate,
94
+ self.f_min,
95
+ self.f_max,
96
+ self.n_fft // 2 + 1,
97
+ "slaney",
98
+ "slaney",
99
+ )
100
+
101
+ def compress(self, x: Tensor) -> Tensor:
102
+ return torch.log(torch.clamp(x, min=1e-5))
103
+
104
+ def decompress(self, x: Tensor) -> Tensor:
105
+ return torch.exp(x)
106
+
107
+ def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
108
+ linear = self.spectrogram(x)
109
+ x = self.mel_scale(linear)
110
+ x = self.compress(x)
111
+ # print(x.shape)
112
+ if return_linear:
113
+ return x, self.compress(linear)
114
+
115
+ return x
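With `center=False`, `LinearSpectrogram` pads `(win_length - hop_length)` samples in total before calling `torch.stft`, so the frame count works out to exactly `num_samples // hop_length`. A quick sanity-check of that arithmetic (pure Python, no torch needed):

```python
# Frame-count arithmetic for LinearSpectrogram above: the module reflect-pads
# (win_length - hop_length) samples in total, so torch.stft sees
# L + win - hop samples and emits (L + win - hop - win) // hop + 1 frames,
# which simplifies to L // hop.

def num_spectrogram_frames(num_samples, win_length=2048, hop_length=512):
    padded = num_samples + (win_length - hop_length)
    return (padded - win_length) // hop_length + 1

# One second at 44.1 kHz with hop 512 -> 86 frames (44100 // 512).
frames = num_spectrogram_frames(44100)
```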
acestep/music_dcae/music_vocoder.py ADDED
@@ -0,0 +1,587 @@
1
+ """
2
+ ACE-Step: A Step Towards Music Generation Foundation Model
3
+
4
+ https://github.com/ace-step/ACE-Step
5
+
6
+ Apache 2.0 License
7
+ """
8
+
9
+ import librosa
10
+ import torch
11
+ from torch import nn
12
+
13
+ from functools import partial
14
+ from math import prod
15
+ from typing import Callable, Tuple, List
16
+
17
+ import numpy as np
18
+ import torch.nn.functional as F
19
+ from torch.nn import Conv1d
20
+ from torch.nn.utils import weight_norm
21
+ from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
22
+ from diffusers.models.modeling_utils import ModelMixin
23
+ from diffusers.loaders import FromOriginalModelMixin
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+
26
+
27
+ try:
28
+ from music_log_mel import LogMelSpectrogram
29
+ except ImportError:
30
+ from .music_log_mel import LogMelSpectrogram
31
+
32
+
33
+ def drop_path(
34
+ x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
35
+ ):
36
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
37
+
38
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
39
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
40
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
41
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
42
+ 'survival rate' as the argument.
43
+
44
+ """ # noqa: E501
45
+
46
+ if drop_prob == 0.0 or not training:
47
+ return x
48
+ keep_prob = 1 - drop_prob
49
+ shape = (x.shape[0],) + (1,) * (
50
+ x.ndim - 1
51
+ ) # work with diff dim tensors, not just 2D ConvNets
52
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
53
+ if keep_prob > 0.0 and scale_by_keep:
54
+ random_tensor.div_(keep_prob)
55
+ return x * random_tensor
56
+
57
+
58
+ class DropPath(nn.Module):
59
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
60
+
61
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
62
+ super(DropPath, self).__init__()
63
+ self.drop_prob = drop_prob
64
+ self.scale_by_keep = scale_by_keep
65
+
66
+ def forward(self, x):
67
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
68
+
69
+ def extra_repr(self):
70
+ return f"drop_prob={round(self.drop_prob,3):0.3f}"
71
+
72
+
73
+ class LayerNorm(nn.Module):
74
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
75
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
76
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
77
+ with shape (batch_size, channels, height, width).
78
+ """ # noqa: E501
79
+
80
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
81
+ super().__init__()
82
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
83
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
84
+ self.eps = eps
85
+ self.data_format = data_format
86
+ if self.data_format not in ["channels_last", "channels_first"]:
87
+ raise NotImplementedError
88
+ self.normalized_shape = (normalized_shape,)
89
+
90
+ def forward(self, x):
91
+ if self.data_format == "channels_last":
92
+ return F.layer_norm(
93
+ x, self.normalized_shape, self.weight, self.bias, self.eps
94
+ )
95
+ elif self.data_format == "channels_first":
96
+ u = x.mean(1, keepdim=True)
97
+ s = (x - u).pow(2).mean(1, keepdim=True)
98
+ x = (x - u) / torch.sqrt(s + self.eps)
99
+ x = self.weight[:, None] * x + self.bias[:, None]
100
+ return x
101
+
102
+
103
+ class ConvNeXtBlock(nn.Module):
104
+ r"""ConvNeXt Block. There are two equivalent implementations:
105
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
106
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
107
+ We use (2) as we find it slightly faster in PyTorch
108
+
109
+ Args:
110
+ dim (int): Number of input channels.
111
+ drop_path (float): Stochastic depth rate. Default: 0.0
112
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
113
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
114
+ kernel_size (int): Kernel size for depthwise conv. Default: 7.
115
+ dilation (int): Dilation for depthwise conv. Default: 1.
116
+ """ # noqa: E501
117
+
118
+ def __init__(
119
+ self,
120
+ dim: int,
121
+ drop_path: float = 0.0,
122
+ layer_scale_init_value: float = 1e-6,
123
+ mlp_ratio: float = 4.0,
124
+ kernel_size: int = 7,
125
+ dilation: int = 1,
126
+ ):
127
+ super().__init__()
128
+
129
+ self.dwconv = nn.Conv1d(
130
+ dim,
131
+ dim,
132
+ kernel_size=kernel_size,
133
+ padding=int(dilation * (kernel_size - 1) / 2),
134
+ groups=dim,
135
+ ) # depthwise conv
136
+ self.norm = LayerNorm(dim, eps=1e-6)
137
+ self.pwconv1 = nn.Linear(
138
+ dim, int(mlp_ratio * dim)
139
+ ) # pointwise/1x1 convs, implemented with linear layers
140
+ self.act = nn.GELU()
141
+ self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
142
+ self.gamma = (
143
+ nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
144
+ if layer_scale_init_value > 0
145
+ else None
146
+ )
147
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
148
+
149
+ def forward(self, x, apply_residual: bool = True):
150
+ input = x
151
+
152
+ x = self.dwconv(x)
153
+ x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
154
+ x = self.norm(x)
155
+ x = self.pwconv1(x)
156
+ x = self.act(x)
157
+ x = self.pwconv2(x)
158
+
159
+ if self.gamma is not None:
160
+ x = self.gamma * x
161
+
162
+ x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
163
+ x = self.drop_path(x)
164
+
165
+ if apply_residual:
166
+ x = input + x
167
+
168
+ return x
169
+
170
+
171
+ class ParallelConvNeXtBlock(nn.Module):
172
+ def __init__(self, kernel_sizes: List[int], *args, **kwargs):
173
+ super().__init__()
174
+ self.blocks = nn.ModuleList(
175
+ [
176
+ ConvNeXtBlock(kernel_size=kernel_size, *args, **kwargs)
177
+ for kernel_size in kernel_sizes
178
+ ]
179
+ )
180
+
181
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
182
+ return torch.stack(
183
+ [block(x, apply_residual=False) for block in self.blocks] + [x],
184
+ dim=1,
185
+ ).sum(dim=1)
186
+
187
+
188
+ class ConvNeXtEncoder(nn.Module):
189
+ def __init__(
190
+ self,
191
+ input_channels=3,
192
+ depths=[3, 3, 9, 3],
193
+ dims=[96, 192, 384, 768],
194
+ drop_path_rate=0.0,
195
+ layer_scale_init_value=1e-6,
196
+ kernel_sizes: Tuple[int] = (7,),
197
+ ):
198
+ super().__init__()
199
+ assert len(depths) == len(dims)
200
+
201
+ self.channel_layers = nn.ModuleList()
202
+ stem = nn.Sequential(
203
+ nn.Conv1d(
204
+ input_channels,
205
+ dims[0],
206
+ kernel_size=7,
207
+ padding=3,
208
+ padding_mode="replicate",
209
+ ),
210
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
211
+ )
212
+ self.channel_layers.append(stem)
213
+
214
+ for i in range(len(depths) - 1):
215
+ mid_layer = nn.Sequential(
216
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
217
+ nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
218
+ )
219
+ self.channel_layers.append(mid_layer)
220
+
221
+ block_fn = (
222
+ partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
223
+ if len(kernel_sizes) == 1
224
+ else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
225
+ )
226
+
227
+ self.stages = nn.ModuleList()
228
+ drop_path_rates = [
229
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
230
+ ]
231
+
232
+ cur = 0
233
+ for i in range(len(depths)):
234
+ stage = nn.Sequential(
235
+ *[
236
+ block_fn(
237
+ dim=dims[i],
238
+ drop_path=drop_path_rates[cur + j],
239
+ layer_scale_init_value=layer_scale_init_value,
240
+ )
241
+ for j in range(depths[i])
242
+ ]
243
+ )
244
+ self.stages.append(stage)
245
+ cur += depths[i]
246
+
247
+ self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
248
+ self.apply(self._init_weights)
249
+
250
+ def _init_weights(self, m):
251
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
252
+ nn.init.trunc_normal_(m.weight, std=0.02)
253
+ nn.init.constant_(m.bias, 0)
254
+
255
+ def forward(
256
+ self,
257
+ x: torch.Tensor,
258
+ ) -> torch.Tensor:
259
+ for channel_layer, stage in zip(self.channel_layers, self.stages):
260
+ x = channel_layer(x)
261
+ x = stage(x)
262
+
263
+ return self.norm(x)
264
+
265
+
266
+ def init_weights(m, mean=0.0, std=0.01):
267
+ classname = m.__class__.__name__
268
+ if classname.find("Conv") != -1:
269
+ m.weight.data.normal_(mean, std)
270
+
271
+
272
+ def get_padding(kernel_size, dilation=1):
273
+ return (kernel_size * dilation - dilation) // 2
274
+
275
+
276
+ class ResBlock1(torch.nn.Module):
277
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
278
+ super().__init__()
279
+
280
+ self.convs1 = nn.ModuleList(
281
+ [
282
+ weight_norm(
283
+ Conv1d(
284
+ channels,
285
+ channels,
286
+ kernel_size,
287
+ 1,
288
+ dilation=dilation[0],
289
+ padding=get_padding(kernel_size, dilation[0]),
290
+ )
291
+ ),
292
+ weight_norm(
293
+ Conv1d(
294
+ channels,
295
+ channels,
296
+ kernel_size,
297
+ 1,
298
+ dilation=dilation[1],
299
+ padding=get_padding(kernel_size, dilation[1]),
300
+ )
301
+ ),
302
+ weight_norm(
303
+ Conv1d(
304
+ channels,
305
+ channels,
306
+ kernel_size,
307
+ 1,
308
+ dilation=dilation[2],
309
+ padding=get_padding(kernel_size, dilation[2]),
310
+ )
311
+ ),
312
+ ]
313
+ )
314
+ self.convs1.apply(init_weights)
315
+
316
+ self.convs2 = nn.ModuleList(
317
+ [
318
+ weight_norm(
319
+ Conv1d(
320
+ channels,
321
+ channels,
322
+ kernel_size,
323
+ 1,
324
+ dilation=1,
325
+ padding=get_padding(kernel_size, 1),
326
+ )
327
+ ),
328
+ weight_norm(
329
+ Conv1d(
330
+ channels,
331
+ channels,
332
+ kernel_size,
333
+ 1,
334
+ dilation=1,
335
+ padding=get_padding(kernel_size, 1),
336
+ )
337
+ ),
338
+ weight_norm(
339
+ Conv1d(
340
+ channels,
341
+ channels,
342
+ kernel_size,
343
+ 1,
344
+ dilation=1,
345
+ padding=get_padding(kernel_size, 1),
346
+ )
347
+ ),
348
+ ]
349
+ )
350
+ self.convs2.apply(init_weights)
351
+
352
+ def forward(self, x):
353
+ for c1, c2 in zip(self.convs1, self.convs2):
354
+ xt = F.silu(x)
355
+ xt = c1(xt)
356
+ xt = F.silu(xt)
357
+ xt = c2(xt)
358
+ x = xt + x
359
+ return x
360
+
361
+ def remove_weight_norm(self):
362
+ for conv in self.convs1:
363
+ remove_weight_norm(conv)
364
+ for conv in self.convs2:
365
+ remove_weight_norm(conv)
366
+
367
+
368
+ class HiFiGANGenerator(nn.Module):
369
+ def __init__(
370
+ self,
371
+ *,
372
+ hop_length: int = 512,
373
+ upsample_rates: Tuple[int] = (8, 8, 2, 2, 2),
374
+ upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2),
375
+ resblock_kernel_sizes: Tuple[int] = (3, 7, 11),
376
+ resblock_dilation_sizes: Tuple[Tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
377
+ num_mels: int = 128,
378
+ upsample_initial_channel: int = 512,
379
+ use_template: bool = True,
380
+ pre_conv_kernel_size: int = 7,
381
+ post_conv_kernel_size: int = 7,
382
+ post_activation: Callable = partial(nn.SiLU, inplace=True),
383
+ ):
384
+ super().__init__()
385
+
386
+ assert (
387
+ prod(upsample_rates) == hop_length
388
+ ), f"hop_length must be {prod(upsample_rates)}"
389
+
390
+ self.conv_pre = weight_norm(
391
+ nn.Conv1d(
392
+ num_mels,
393
+ upsample_initial_channel,
394
+ pre_conv_kernel_size,
395
+ 1,
396
+ padding=get_padding(pre_conv_kernel_size),
397
+ )
398
+ )
399
+
400
+ self.num_upsamples = len(upsample_rates)
401
+ self.num_kernels = len(resblock_kernel_sizes)
402
+
403
+ self.noise_convs = nn.ModuleList()
404
+ self.use_template = use_template
405
+ self.ups = nn.ModuleList()
406
+
407
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
408
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
409
+ self.ups.append(
410
+ weight_norm(
411
+ nn.ConvTranspose1d(
412
+ upsample_initial_channel // (2**i),
413
+ upsample_initial_channel // (2 ** (i + 1)),
414
+ k,
415
+ u,
416
+ padding=(k - u) // 2,
417
+ )
418
+ )
419
+ )
420
+
421
+ if not use_template:
422
+ continue
423
+
424
+ if i + 1 < len(upsample_rates):
425
+ stride_f0 = np.prod(upsample_rates[i + 1 :])
426
+ self.noise_convs.append(
427
+ Conv1d(
428
+ 1,
429
+ c_cur,
430
+ kernel_size=stride_f0 * 2,
431
+ stride=stride_f0,
432
+ padding=stride_f0 // 2,
433
+ )
434
+ )
435
+ else:
436
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
437
+
438
+ self.resblocks = nn.ModuleList()
439
+ for i in range(len(self.ups)):
440
+ ch = upsample_initial_channel // (2 ** (i + 1))
441
+ for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
442
+ self.resblocks.append(ResBlock1(ch, k, d))
443
+
444
+ self.activation_post = post_activation()
445
+ self.conv_post = weight_norm(
446
+ nn.Conv1d(
447
+ ch,
448
+ 1,
449
+ post_conv_kernel_size,
450
+ 1,
451
+ padding=get_padding(post_conv_kernel_size),
452
+ )
453
+ )
454
+ self.ups.apply(init_weights)
455
+ self.conv_post.apply(init_weights)
456
+
457
+ def forward(self, x, template=None):
458
+ x = self.conv_pre(x)
459
+
460
+ for i in range(self.num_upsamples):
461
+ x = F.silu(x, inplace=True)
462
+ x = self.ups[i](x)
463
+
464
+ if self.use_template:
465
+ x = x + self.noise_convs[i](template)
466
+
467
+ xs = None
468
+
469
+ for j in range(self.num_kernels):
470
+ if xs is None:
471
+ xs = self.resblocks[i * self.num_kernels + j](x)
472
+ else:
473
+ xs += self.resblocks[i * self.num_kernels + j](x)
474
+
475
+ x = xs / self.num_kernels
476
+
477
+ x = self.activation_post(x)
478
+ x = self.conv_post(x)
479
+ x = torch.tanh(x)
480
+
481
+ return x
482
+
483
+ def remove_weight_norm(self):
484
+ for up in self.ups:
485
+ remove_weight_norm(up)
486
+ for block in self.resblocks:
487
+ block.remove_weight_norm()
488
+ remove_weight_norm(self.conv_pre)
489
+ remove_weight_norm(self.conv_post)
490
+
491
+
492
+ class ADaMoSHiFiGANV1(ModelMixin, ConfigMixin, FromOriginalModelMixin):
493
+
494
+ @register_to_config
495
+ def __init__(
496
+ self,
497
+ input_channels: int = 128,
498
+ depths: List[int] = [3, 3, 9, 3],
499
+ dims: List[int] = [128, 256, 384, 512],
500
+ drop_path_rate: float = 0.0,
501
+ kernel_sizes: Tuple[int] = (7,),
502
+ upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2),
503
+ upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4),
504
+ resblock_kernel_sizes: Tuple[int] = (3, 7, 11, 13),
505
+ resblock_dilation_sizes: Tuple[Tuple[int]] = (
506
+ (1, 3, 5),
507
+ (1, 3, 5),
508
+ (1, 3, 5),
509
+ (1, 3, 5),
510
+ ),
511
+ num_mels: int = 512,
512
+ upsample_initial_channel: int = 1024,
513
+ use_template: bool = False,
514
+ pre_conv_kernel_size: int = 13,
515
+ post_conv_kernel_size: int = 13,
516
+ sampling_rate: int = 44100,
517
+ n_fft: int = 2048,
518
+ win_length: int = 2048,
519
+ hop_length: int = 512,
520
+ f_min: int = 40,
521
+ f_max: int = 16000,
522
+ n_mels: int = 128,
523
+ ):
524
+ super().__init__()
525
+
526
+ self.backbone = ConvNeXtEncoder(
527
+ input_channels=input_channels,
528
+ depths=depths,
529
+ dims=dims,
530
+ drop_path_rate=drop_path_rate,
531
+ kernel_sizes=kernel_sizes,
532
+ )
533
+
534
+ self.head = HiFiGANGenerator(
535
+ hop_length=hop_length,
536
+ upsample_rates=upsample_rates,
537
+ upsample_kernel_sizes=upsample_kernel_sizes,
538
+ resblock_kernel_sizes=resblock_kernel_sizes,
539
+ resblock_dilation_sizes=resblock_dilation_sizes,
540
+ num_mels=num_mels,
541
+ upsample_initial_channel=upsample_initial_channel,
542
+ use_template=use_template,
543
+ pre_conv_kernel_size=pre_conv_kernel_size,
544
+ post_conv_kernel_size=post_conv_kernel_size,
545
+ )
546
+ self.sampling_rate = sampling_rate
547
+ self.mel_transform = LogMelSpectrogram(
548
+ sample_rate=sampling_rate,
549
+ n_fft=n_fft,
550
+ win_length=win_length,
551
+ hop_length=hop_length,
552
+ f_min=f_min,
553
+ f_max=f_max,
554
+ n_mels=n_mels,
555
+ )
556
+ self.eval()
557
+
558
+ @torch.no_grad()
559
+ def decode(self, mel):
560
+ y = self.backbone(mel)
561
+ y = self.head(y)
562
+ return y
563
+
564
+ @torch.no_grad()
565
+ def encode(self, x):
566
+ return self.mel_transform(x)
567
+
568
+ def forward(self, mel):
569
+ y = self.backbone(mel)
570
+ y = self.head(y)
571
+ return y
572
+
573
+
574
+ if __name__ == "__main__":
575
+ import soundfile as sf
576
+
577
+ x = "test_audio.wav"
578
+ model = ADaMoSHiFiGANV1.from_pretrained(
579
+ "./checkpoints/music_vocoder", local_files_only=True
580
+ )
581
+
582
+ wav, sr = librosa.load(x, sr=44100, mono=True)
583
+ wav = torch.from_numpy(wav).float()[None]
584
+ mel = model.encode(wav)
585
+
586
+ wav = model.decode(mel)[0].mT
587
+ sf.write("test_audio_vocoder_rec.wav", wav.cpu().numpy(), 44100)
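`HiFiGANGenerator` asserts `prod(upsample_rates) == hop_length`: the stack of transposed convolutions must upsample each mel frame to exactly one hop of audio samples. A quick check for the default rates passed in by `ADaMoSHiFiGANV1`:

```python
# Each upsample stage multiplies the time resolution by its rate; the product
# of all rates must equal hop_length so one mel frame maps to one audio hop.
from math import prod

upsample_rates = (4, 4, 2, 2, 2, 2, 2)  # ADaMoSHiFiGANV1 defaults
hop_length = 512

samples_per_mel_frame = prod(upsample_rates)
assert samples_per_mel_frame == hop_length
```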
checkpoints/checkpoint_461260.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:796a66a9a098ec75554897e830868c8eb4a9a90c35bb4f972ce317420bb1bbb5
3
+ size 2920814816
checkpoints/tag_mapping.json ADDED
@@ -0,0 +1,858 @@
1
+ {
2
+ "rock": 1697,
3
+ "male vocalist": 1698,
4
+ "pop": 1699,
5
+ "energetic": 1700,
6
+ "instrumental": 1701,
7
+ "electronic": 1702,
8
+ "rhythmic": 1703,
9
+ "female vocalist": 1704,
10
+ "passionate": 1705,
11
+ "atmospheric": 1706,
12
+ "rap": 1707,
13
+ "hip hop": 1708,
14
+ "uplifting": 1709,
15
+ "metal": 1710,
16
+ "alternative rock": 1711,
17
+ "pop rock": 1712,
18
+ "dark": 1713,
19
+ "anthemic": 1714,
20
+ "male vocals": 1715,
21
+ "melancholic": 1716,
22
+ "epic": 1717,
23
+ "bittersweet": 1718,
24
+ "love": 1719,
25
+ "dance": 1720,
26
+ "warm": 1721,
27
+ "electronic dance music": 1722,
28
+ "female vocals": 1723,
29
+ "lush": 1724,
30
+ "trap": 1725,
31
+ "introspective": 1726,
32
+ "aggressive": 1727,
33
+ "r&b": 1728,
34
+ "playful": 1729,
35
+ "regional music": 1730,
36
+ "dance-pop": 1731,
37
+ "hard rock": 1732,
38
+ "ambient": 1733,
39
+ "ethereal": 1734,
40
+ "emotional": 1735,
41
+ "heavy": 1736,
42
+ "piano": 1737,
43
+ "mellow": 1738,
44
+ "jazz": 1739,
45
+ "folk": 1740,
46
+ "country": 1741,
47
+ "house": 1742,
48
+ "party": 1743,
49
+ "romantic": 1744,
50
+ "orchestral": 1745,
51
+ "pop rap": 1746,
52
+ "acoustic": 1747,
53
+ "electropop": 1748,
54
+ "electro": 1749,
55
+ "nocturnal": 1750,
56
+ "bass": 1751,
57
+ "guitar": 1752,
58
+ "urban": 1753,
59
+ "soul": 1754,
60
+ "psychedelic": 1755,
61
+ "edm": 1756,
62
+ "experimental": 1757,
63
+ "funk": 1758,
64
+ "futuristic": 1759,
65
+ "boastful": 1760,
66
+ "hypnotic": 1761,
67
+ "heavy metal": 1762,
68
+ "contemporary r&b": 1763,
69
+ "techno": 1764,
70
+ "eclectic": 1765,
71
+ "longing": 1766,
72
+ "violin": 1767,
73
+ "sentimental": 1768,
74
+ "synthpop": 1769,
75
+ "cinematic": 1770,
76
+ "happy": 1771,
77
+ "repetitive": 1772,
78
+ "progressive": 1773,
79
+ "catchy": 1774,
80
+ "sad": 1775,
81
+ "indie pop": 1776,
82
+ "indie rock": 1777,
83
+ "singer-songwriter": 1778,
84
+ "classical music": 1779,
85
+ "slow": 1780,
86
+ "northern american music": 1781,
87
+ "sampling": 1782,
88
+ "trance": 1783,
89
+ "western classical music": 1784,
90
+ "upbeat": 1785,
91
+ "blues": 1786,
92
+ "hip-hop": 1787,
93
+ "ballad": 1788,
94
+ "soothing": 1789,
95
+ "synthwave": 1790,
96
+ "electric guitar": 1791,
97
+ "calm": 1792,
98
+ "raw": 1793,
99
+ "downtempo": 1794,
100
+ "hardcore hip hop": 1795,
101
+ "soft": 1796,
102
+ "dubstep": 1797,
103
+ "classical": 1798,
104
+ "film score": 1799,
105
+ "synth": 1800,
106
+ "triumphant": 1801,
107
+ "drums": 1802,
108
+ "punk": 1803,
109
+ "female voice": 1804,
110
+ "angry": 1805,
111
+ "alternative metal": 1806,
112
+ "acoustic guitar": 1807,
113
+ "lo-fi": 1808,
114
+ "male voice": 1809,
115
+ "dense": 1810,
116
+ "progressive rock": 1811,
117
+ "optimistic": 1812,
118
+ "ominous": 1813,
119
+ "reggae": 1814,
120
+ "sombre": 1815,
121
+ "mysterious": 1816,
122
+ "complex": 1817,
123
+ "contemporary folk": 1818,
124
+ "disco": 1819,
125
+ "drum and bass": 1820,
126
+ "new wave": 1821,
127
+ "nu metal": 1822,
128
+ "summer": 1823,
129
+ "sensual": 1824,
130
+ "powerful": 1825,
131
+ "folk rock": 1826,
132
+ "glitch": 1827,
133
+ "symphonic metal": 1828,
134
+ "emo": 1829,
135
+ "power metal": 1830,
136
+ "conscious": 1831,
137
+ "technical": 1832,
138
+ "suspenseful": 1833,
139
+ "dramatic": 1834,
140
+ "electro house": 1835,
141
+ "deep": 1836,
142
+ "swing": 1837,
143
+ "punk rock": 1838,
144
+ "gangsta rap": 1839,
145
+ "soulful": 1840,
146
+ "intense": 1841,
147
+ "industrial": 1842,
148
+ "cinematic classical": 1843,
149
+ "k-pop": 1844,
150
+ "new age": 1845,
151
+ "hedonistic": 1846,
152
+ "synth-pop": 1847,
153
+ "meditative": 1848,
154
+ "cello": 1849,
155
+ "pop punk": 1850,
156
+ "chillout": 1851,
157
+ "metalcore": 1852,
158
+ "dreamy": 1853,
159
+ "rebellious": 1854,
160
+ "east coast hip hop": 1855,
161
+ "progressive metal": 1856,
162
+ "lonely": 1857,
163
+ "conscious hip hop": 1858,
164
+ "flute": 1859,
165
+ "chill": 1860,
166
+ "phonk": 1861,
167
+ "blues rock": 1862,
168
+ "drum": 1863,
169
+ "quirky": 1864,
170
+ "pop soul": 1865,
171
+ "j-pop": 1866,
172
+ "groovy": 1867,
173
+ "trip hop": 1868,
174
+ "fantasy": 1869,
175
+ "dream pop": 1870,
176
+ "psychedelic rock": 1871,
177
+ "beat": 1872,
178
+ "country rock": 1873,
179
+ "surreal": 1874,
180
+ "gospel": 1875,
181
+ "fast": 1876,
182
+ "soft rock": 1877,
183
+ "smooth": 1878,
184
+ "peaceful": 1879,
185
+ "poetic": 1880,
186
+ "opera": 1881,
+ "power pop": 1882,
+ "indie folk": 1883,
+ "indie": 1884,
+ "mechanical": 1885,
+ "breakbeat": 1886,
+ "anxious": 1887,
+ "female vocal": 1888,
+ "deep bass": 1889,
+ "post-punk": 1890,
+ "grunge": 1891,
+ "breakup": 1892,
+ "choir": 1893,
+ "orchestra": 1894,
+ "avant-garde": 1895,
+ "deep house": 1896,
+ "boom bap": 1897,
+ "folk pop": 1898,
+ "pastoral": 1899,
+ "jazz fusion": 1900,
+ "progressive house": 1901,
+ "synthesizer": 1902,
+ "nostalgic": 1903,
+ "funky": 1904,
+ "country pop": 1905,
+ "death": 1906,
+ "spiritual": 1907,
+ "soundtrack": 1908,
+ "2000s": 1909,
+ "choral": 1910,
+ "strings": 1911,
+ "fun": 1912,
+ "electric": 1913,
+ "post-grunge": 1914,
+ "female singer": 1915,
+ "male vocal": 1916,
+ "modern classical": 1917,
+ "death metal": 1918,
+ "post-hardcore": 1919,
+ "humorous": 1920,
+ "heartfelt": 1921,
+ "psychedelia": 1922,
+ "haunting": 1923,
+ "afrobeat": 1924,
+ "medieval": 1925,
+ "progressive electronic": 1926,
+ "adult contemporary": 1927,
+ "reggaeton": 1928,
+ "dynamic": 1929,
+ "contemporary country": 1930,
+ "beats": 1931,
+ "idm": 1932,
+ "southern hip hop": 1933,
+ "80s": 1934,
+ "cold": 1935,
+ "big band": 1936,
+ "saxophone": 1937,
+ "future bass": 1938,
+ "noisy": 1939,
+ "gritty": 1940,
+ "dark ambient": 1941,
+ "trumpet": 1942,
+ "art rock": 1943,
+ "chaotic": 1944,
+ "smooth soul": 1945,
+ "post-industrial": 1946,
+ "bluegrass": 1947,
+ "industrial & noise": 1948,
+ "anime": 1949,
+ "drill": 1950,
+ "electro swing": 1951,
+ "dancehall": 1952,
+ "epic music": 1953,
+ "witch house": 1954,
+ "minimalistic": 1955,
+ "hispanic american music": 1956,
+ "electronica": 1957,
+ "americana": 1958,
+ "political": 1959,
+ "latin": 1960,
+ "tech house": 1961,
+ "neo-soul": 1962,
+ "hispanic music": 1963,
+ "heavy bass": 1964,
+ "knee surgery": 1965,
+ "horror": 1966,
+ "psychedelic pop": 1967,
+ "industrial metal": 1968,
+ "space": 1969,
+ "dub": 1970,
+ "art pop": 1971,
+ "spoken word": 1972,
+ "reverb": 1973,
+ "caribbean music": 1974,
+ "alternative": 1975,
+ "symphonic": 1976,
+ "cloud rap": 1977,
+ "neo-psychedelia": 1978,
+ "gothic metal": 1979,
+ "classic rock": 1980,
+ "female": 1981,
+ "bossa nova": 1982,
+ "thrash metal": 1983,
+ "djent": 1984,
+ "teen pop": 1985,
+ "cyberpunk": 1986,
+ "hardcore": 1987,
+ "glam rock": 1988,
+ "slow tempo": 1989,
+ "jazz rap": 1990,
+ "sexy": 1991,
+ "harp": 1992,
+ "outlaw country": 1993,
+ "progressive trance": 1994,
+ "european music": 1995,
+ "west coast hip hop": 1996,
+ "vocal": 1997,
+ "alternative dance": 1998,
+ "accordion": 1999,
+ "minimal": 2000,
+ "tribal": 2001,
+ "sarcastic": 2002,
+ "vocal jazz": 2003,
+ "jamaican music": 2004,
+ "alternative r&b": 2005,
+ "smooth jazz": 2006,
+ "gothic": 2007,
+ "ska": 2008,
+ "manic": 2009,
+ "bass guitar": 2010,
+ "chillwave": 2011,
+ "improvisation": 2012,
+ "melancholy": 2013,
+ "shoegaze": 2014,
+ "big beat": 2015,
+ "keyboard": 2016,
+ "groove metal": 2017,
+ "90s": 2018,
+ "latin pop": 2019,
+ "hardcore [punk]": 2020,
+ "darkwave": 2021,
+ "modern": 2022,
+ "glam metal": 2023,
+ "reflective": 2024,
+ "eerie": 2025,
+ "chamber pop": 2026,
+ "martial": 2027,
+ "flamenco": 2028,
+ "male singer": 2029,
+ "indietronica": 2030,
+ "beautiful": 2031,
+ "gothic rock": 2032,
+ "vocaloid": 2033,
+ "world": 2034,
+ "math rock": 2035,
+ "dark pop": 2036,
+ "jazz-funk": 2037,
+ "symphonic rock": 2038,
+ "club": 2039,
+ "bouncy": 2040,
+ "easy listening": 2041,
+ "j-rock": 2042,
+ "baroque": 2043,
+ "percussion": 2044,
+ "acid jazz": 2045,
+ "hardstyle": 2046,
+ "rock & roll": 2047,
+ "hymn": 2048,
+ "dissonant": 2049,
+ "ambient pop": 2050,
+ "eurodance": 2051,
+ "danceable": 2052,
+ "turntablism": 2053,
+ "dolby atmos": 2054,
+ "depressive": 2055,
+ "doom metal": 2056,
+ "hyperpop": 2057,
+ "existential": 2058,
+ "melodic metalcore": 2059,
+ "male": 2060,
+ "chanson": 2061,
+ "vaporwave": 2062,
+ "salsa": 2063,
+ "war": 2064,
+ "melodic": 972,
+ "fiddle": 2065,
+ "film soundtrack": 2066,
+ "inspirational": 2067,
+ "nu jazz": 2068,
+ "vulgar": 2069,
+ "abstract": 2070,
+ "brass": 2071,
+ "confident": 2072,
+ "black metal": 2073,
+ "video game music": 2074,
+ "creepy": 2075,
+ "uncommon time signatures": 2076,
+ "intimate": 2077,
+ "relaxing": 2078,
+ "post-rock": 2079,
+ "lofi": 2080,
+ "roots reggae": 2081,
+ "industrial rock": 2082,
+ "remix": 2083,
+ "storytelling": 2084,
+ "funny": 2085,
+ "ambient techno": 2086,
+ "high-energy": 2087,
+ "experimental rock": 2088,
+ "southern rock": 2089,
+ "celtic": 2090,
+ "banjo": 2091,
+ "rockabilly": 2092,
+ "tabla": 2093,
+ "melodic death metal": 2094,
+ "minor key": 2095,
+ "rap rock": 2096,
+ "synth funk": 2097,
+ "harmonies": 2098,
+ "fast tempo": 2099,
+ "garage rock": 2100,
+ "breakcore": 2101,
+ "harmony": 2102,
+ "uptempo": 2103,
+ "harmonica": 2104,
+ "duet": 2105,
+ "alt-pop": 2106,
+ "bounce": 2107,
+ "hiphop": 2108,
+ "funk rock": 2109,
+ "jungle": 2110,
+ "acoustic rock": 2111,
+ "tropical house": 2112,
+ "piano rock": 2113,
+ "sound effects": 2114,
+ "glitch hop": 2115,
+ "dance pop": 2116,
+ "aquatic": 2117,
+ "organ": 2118,
+ "baroque pop": 2119,
+ "comedy": 2120,
+ "theatrical": 2121,
+ "sparse": 2122,
+ "bassline": 2123,
+ "scary": 2124,
+ "cute": 2125,
+ "drone": 2126,
+ "horrorcore": 2127,
+ "bass house": 2128,
+ "emo rap": 2129,
+ "moody": 2130,
+ "drums (drum set)": 2131,
+ "fast-paced": 2132,
+ "double bass": 2133,
+ "progressive pop": 2134,
+ "apocalyptic": 2135,
+ "hardcore punk": 2136,
+ "anthem": 2137,
+ "europop": 2138,
+ "upright bass": 2139,
+ "groove": 2140,
+ "psytrance": 2141,
+ "dark wave": 2142,
+ "kpop": 2143,
+ "minimal techno": 2144,
+ "rock and roll": 2145,
+ "grime": 2146,
+ "lively": 2147,
+ "rave": 2148,
+ "syncopated": 2149,
+ "show tunes": 2150,
+ "autotune": 2151,
+ "sitar": 2152,
+ "nu-disco": 2153,
+ "folk metal": 2154,
+ "traditional pop": 2155,
+ "surf rock": 2156,
+ "noise": 2157,
+ "brostep": 2158,
+ "serious": 2159,
+ "traditional": 2160,
+ "pessimistic": 2161,
+ "ebm": 2162,
+ "female vocalists": 2163,
+ "speed metal": 2164,
+ "classic": 2165,
+ "post-punk revival": 2166,
+ "lounge": 2167,
+ "electric blues": 2168,
+ "winter": 2169,
+ "clear vocals": 2170,
+ "retro": 2171,
+ "raspy": 2172,
+ "progressive country": 2173,
+ "vibrant": 2174,
+ "mystical": 2175,
+ "deathcore": 2176,
+ "alt-country": 2177,
+ "theme": 2178,
+ "8-bit": 2179,
+ "jangle pop": 2180,
+ "aor": 2181,
+ "delta blues": 2182,
+ "light": 2183,
+ "lyrical": 2184,
+ "distorted guitars": 2185,
+ "jazz-rock": 2186,
+ "classical crossover": 2187,
+ "fusion": 2188,
+ "doo-wop": 2189,
+ "television music": 2190,
+ "clean": 2191,
+ "symphony": 2192,
+ "whimsical": 2193,
+ "honky tonk": 2194,
+ "chamber music": 2195,
+ "breathy": 2196,
+ "echo": 2197,
+ "uk garage": 2198,
+ "acid techno": 2199,
+ "ritualistic": 2200,
+ "scratch": 2201,
+ "darksynth": 2202,
+ "edgy": 2203,
+ "layered harmonies": 2204,
+ "rhythm & blues": 2205,
+ "80's": 2206,
+ "experimental hip hop": 2207,
+ "808": 2208,
+ "expressive": 2209,
+ "1960s": 2210,
+ "cryptic": 2211,
+ "g-funk": 2212,
+ "oud": 2213,
+ "male vocalists": 2214,
+ "uk drill": 2215,
+ "gentle": 2216,
+ "musical": 2217,
+ "sultry": 2218,
+ "samba": 2219,
+ "violins": 2220,
+ "soul jazz": 2221,
+ "alienation": 2222,
+ "deep voice": 2223,
+ "layered": 2224,
+ "screamo": 2225,
+ "drift phonk": 2226,
+ "shamisen": 2227,
+ "rap metal": 2228,
+ "strong": 2229,
+ "062 final fantasy ii": 3,
+ "063 final fantasy iii": 4,
+ "064 final fantasy iii remake": 5,
+ "066 final fantasy iv": 7,
+ "067 final fantasy iv remake": 8,
+ "068 final fantasy v": 9,
+ "069 final fantasy vi": 10,
+ "070 final fantasy vii": 11,
+ "071 final fantasy vii remake": 12,
+ "072 final fantasy viii": 13,
+ "073 final fantasy ix": 14,
+ "075 final fantasy x": 15,
+ "076 final fantasy xi": 16,
+ "077 final fantasy xii": 17,
+ "078 final fantasy xiii": 18,
+ "079 final fantasy xiv": 19,
+ "081 final fantasy xv": 20,
+ "082 final fantasy 0": 21,
+ "089 final fantasy x2": 26,
+ "093 final fantasy xiii2": 29,
+ "094 final fantasy xiii3": 30,
+ "097 dissidia final fantasy": 33,
+ "13 sentinels aegis rim": 40,
+ "ace combat 7": 143,
+ "advance wars": 144,
+ "advance wars days of ruin": 145,
+ "advance wars dual strike": 146,
+ "advance wars 2 black hole rising": 148,
+ "advance wars dual strike": 149,
+ "animal crossing wild world": 166,
+ "animal crossing new horizons": 167,
+ "ar tonelico": 171,
+ "armored core": 173,
+ "atelier escher and logy": 182,
+ "atelier iris": 183,
+ "atelier iris 2": 184,
+ "atelier iris 3": 185,
+ "atelier marie": 186,
+ "atelier resleriana": 187,
+ "atelier rorona": 188,
+ "atelier ryza": 189,
+ "atelier ryza 2": 190,
+ "atelier ryza 3": 191,
+ "atelier totori": 192,
+ "atlantis kitsune": 193,
+ "attack on titan": 194,
+ "azur lane": 198,
+ "baldurs gate 3": 255,
+ "banjo kazooie": 261,
+ "banjo tooie": 262,
+ "black clover": 272,
+ "black myth wukong": 273,
+ "blazblue": 277,
+ "bleach": 278,
+ "blue reflection": 285,
+ "bocchi the rock": 289,
+ "castlevania": 329,
+ "castlevania dawn of sorrow": 330,
+ "castlevania order of ecclesia": 331,
+ "castlevania portrait of ruin": 332,
+ "castlevania symphony of the night": 333,
+ "castlevania aria of sorrow": 334,
+ "cave story": 337,
+ "celeste": 339,
+ "chiptune": 350,
+ "chrono cross": 359,
+ "chrono trigger": 360,
+ "clair obscur": 367,
+ "clannad": 368,
+ "contra": 374,
+ "crosscode": 382,
+ "cuphead": 384,
+ "dmc4": 397,
+ "dmcv": 398,
+ "danganronpa": 414,
+ "danganronpa 2": 415,
+ "deltarune 2": 423,
+ "deltarune34": 424,
+ "diddy kong racing": 428,
+ "gb sounds": 431,
+ "disgaea 5": 432,
+ "doki doki literature club": 435,
+ "donkey kong 64": 436,
+ "donkey kong country": 437,
+ "donkey kong country 2": 438,
+ "donkey kong country 3": 439,
+ "doom": 442,
+ "dragalia lost": 445,
+ "dragon quest ix": 448,
+ "drakengard 3": 449,
+ "elder scrolls 3 morrowind": 474,
+ "etrian odyssey ii": 479,
+ "etrian odyssey iii": 480,
+ "fzero": 495,
+ "fzero maximum velocity": 496,
+ "fzero gx": 497,
+ "fzero x": 498,
+ "fairy tail": 499,
+ "far cry 6": 500,
+ "fate grand order": 501,
+ "fate stay night": 505,
+ "fire emblem": 512,
+ "fire emblem awakening": 513,
+ "fire emblem three houses": 515,
+ "fruits basket": 521,
+ "fuga melodies of steel": 523,
+ "fullmetal alchemist": 524,
+ "fullmetal alchemist brotherhood": 525,
+ "genshin impact": 554,
+ "ghost in the shell": 555,
+ "goldeneye 007": 565,
+ "granblue fantasy": 566,
+ "granblue fantasy versus": 567,
+ "gurren lagann": 574,
+ "gust": 575,
+ "hades": 616,
+ "haikyuu": 617,
+ "harvest moon": 621,
+ "hearthstone": 625,
+ "hollow knight": 638,
+ "hololive": 639,
+ "homestuck": 640,
+ "homestuck alternia": 645,
+ "homestuck alterniabound": 646,
+ "homestuck cherubim": 647,
+ "honkai impact 3rd": 658,
+ "honkai star rail": 661,
+ "jojos bizarre adventure": 755,
+ "journey": 756,
+ "kid icarus uprising": 800,
+ "kill la kill": 803,
+ "kingdom hearts 3582 days": 816,
+ "kingdom hearts 3d dream drop distance": 817,
+ "kingdom hearts recoded": 818,
+ "kirby": 819,
+ "kirby 64 the crystal shards": 820,
+ "kirby ds": 821,
+ "kirbys dream land 3": 822,
+ "konosuba": 827,
+ "lamulana": 859,
+ "legend of zelda the": 878,
+ "legend of zelda the a link to the past": 879,
+ "legend of zelda the majoras mask": 880,
+ "legend of zelda the ocarina of time": 881,
+ "legend of zelda the phantom hourglass": 882,
+ "legend of zelda the spirit tracks": 883,
+ "legend of zelda the twilight princess": 884,
+ "mana khemia": 936,
+ "maple story": 937,
+ "mario luigi bowsers inside story": 938,
+ "mario luigi dream team": 939,
+ "mario luigi partners in time": 940,
+ "mario luigi superstar saga": 941,
+ "mario 3d land": 942,
+ "mario golf": 943,
+ "mario kart super circuit": 944,
+ "mario kart 64": 945,
+ "mario kart 7": 946,
+ "mario kart ds": 947,
+ "mario kart wii": 948,
+ "mario kart 8": 949,
+ "mario party 3": 952,
+ "mario party 4": 953,
+ "mario party 5": 954,
+ "mario tennis": 955,
+ "mega man": 961,
+ "mega man 3": 962,
+ "mega man 4": 963,
+ "mega man 7": 964,
+ "mega man battle network": 965,
+ "mega man x": 966,
+ "mega man x2": 967,
+ "mega man x3": 968,
+ "mega man x4": 969,
+ "mega man zero zx": 970,
+ "metal gear solid 2": 978,
+ "metroid": 979,
+ "metroid zero mission": 980,
+ "metroid prime 2 echoes": 981,
+ "metroid prime 3": 982,
+ "metroid prime": 983,
+ "minecraft": 989,
+ "monogatari": 1411,
+ "my hero academia": 1006,
+ "nausicaa valley of the wind": 1034,
+ "neon genesis evangelion": 1039,
+ "neon white": 1040,
+ "new super mario bros": 1045,
+ "new super mario bros wii": 1046,
+ "ni no kuni": 1050,
+ "ni no kuni 2": 1051,
+ "nier automata": 1053,
+ "night in the woods": 1054,
+ "ninja gaiden 1": 1055,
+ "ninja gaiden 2": 1056,
+ "omori": 1080,
+ "one piece": 1081,
+ "outer wilds": 1084,
+ "parasite eve": 1114,
+ "perfect dark": 1122,
+ "phoenix wright ace attorney": 1124,
+ "phoenix wright ace attorney 2": 1125,
+ "pokemon anime": 1131,
+ "pokemon black and white": 1133,
+ "pokemon crystal": 1134,
+ "pokemon diamond": 1135,
+ "pokemon fire red and leaf green": 1137,
+ "pokemon heart gold soul silver": 1138,
+ "pokemon mystery dungeon blue rescue team": 1140,
+ "pokemon mystery dungeon explorers of sky": 1141,
+ "pokemon mystery dungeon gates to infinity": 1142,
+ "pokemon omega ruby": 1143,
+ "pokemon red": 1144,
+ "pokemon ruby": 1145,
+ "pokemon scarlet": 1147,
+ "pokemon sun and moon": 1148,
+ "pokemon super mystery dungeon": 1149,
+ "pokemon x and y": 1150,
+ "pokemon xd gale of darkness": 1151,
+ "professor layton and the curious village": 1173,
+ "resident evil": 1208,
+ "scottpilgrim": 1277,
+ "secret of mana": 1284,
+ "shin megami tensei iv": 1298,
+ "shovelknight": 1300,
+ "skyrim": 1320,
+ "sonic advance 3": 1335,
+ "sonic adventure 2": 1337,
+ "sonic mania": 1338,
+ "sonic the hedgehog": 1339,
+ "sonic the hedgehog 2": 1340,
+ "sonic the hedgehog 3": 1341,
+ "spirited away": 1347,
+ "star fox": 1348,
+ "star ocean": 1349,
+ "starcraft 2": 1350,
+ "stardew valley": 1351,
+ "stellaris": 1354,
+ "street fighter ii": 1359,
+ "super mario 64": 1370,
+ "super mario bros 3": 1374,
+ "super mario galaxy": 1375,
+ "super mario rpg": 1377,
+ "super mario sunshine": 1378,
+ "super monkey ball 2": 1379,
+ "super smash bros brawl": 1381,
+ "tales of symphonia": 1391,
+ "the sims 2": 1395,
+ "totalwar": 1396,
+ "touhou 10": 1397,
+ "touhou 11": 1398,
+ "touhou 12": 1399,
+ "touhou 14": 1400,
+ "touhou 15": 1401,
+ "touhou 6": 1402,
+ "touhou 7": 1403,
+ "touhou 8": 1405,
+ "touhou 9": 1406,
+ "tunic": 1407,
+ "undertale": 1410,
+ "violet evergarden": 1413,
+ "wild arms 2": 1417,
+ "witcher 3": 1418,
+ "wow": 1419,
+ "wuthering waves": 1421,
+ "xenoblade chronicles": 1423,
+ "xenoblade chronicles 2": 1424,
+ "xenoblade chronicles 2 torna": 1425,
+ "xenoblade chronicles 3": 1426,
+ "xenogears": 1427,
+ "ys": 1429,
+ "zenless zone zero": 1434,
+ "beatmania": 1459,
+ "berserk": 1460,
+ "castle crashers": 1465,
+ "everquest": 1470,
+ "mortal kombat": 1482,
+ "nier": 1483,
+ "persona": 1484,
+ "sayonara wild hearts": 1487,
+ "touhou remixes": 1550,
+ "yakuza": 1659,
+ "sea shanty": 2230,
+ "emo-pop": 2231,
+ "skate punk": 2232,
+ "bright": 2233,
+ "cumbia": 2234,
+ "world music": 2235,
+ "synth pop": 2236,
+ "chorus": 2237,
+ "japanese": 2238,
+ "schlager": 2239,
+ "asian music": 2240,
+ "glam pop": 2241,
+ "lute": 2242,
+ "misanthropic": 2243,
+ "christian": 2244,
+ "bubblegum pop": 2245,
+ "808s": 2246,
+ "remastered": 2247,
+ "christmas music": 2248,
+ "wave": 2249,
+ "tango": 2250,
+ "hateful": 2251,
+ "high energy": 2252,
+ "neoclassical darkwave": 2253,
+ "electroclash": 2254,
+ "seductive": 2255,
+ "dungeon synth": 2256,
+ "city pop": 2257,
+ "heroic": 2258,
+ "freestyle": 2259,
+ "space ambient": 2260,
+ "bounce drop": 2261,
+ "afrobeats": 2262,
+ "power ballad": 2263,
+ "trombone": 2264,
+ "guitar solo": 2265,
+ "battle": 2266,
+ "ending": 2267,
+ "soundtrack1": 2268,
+ "soundtrack2": 2269
+ }
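The mapping above pairs each lowercase tag string with an integer class id (ids need not be contiguous or unique; note "melodic" reuses id 972). A minimal sketch of the lookup the demo app performs, using a tiny stand-in dict rather than the real `checkpoints/tag_mapping.json`:

```python
import json

# Toy stand-in for checkpoints/tag_mapping.json (the real file has ~2300 entries).
toy_mapping = {"opera": 1881, "power pop": 1882, "melodic": 972}

def tags_to_indices(tags, mapping):
    """Lower-case and strip each tag; silently skip tags absent from the mapping."""
    indices = []
    for tag in tags:
        key = tag.lower().strip()
        if key in mapping:
            indices.append(mapping[key])
    return indices

# Round-trip through JSON to mirror how the app loads the mapping from disk.
mapping = json.loads(json.dumps(toy_mapping))
print(tags_to_indices(["Opera", " melodic ", "unknown tag"], mapping))  # [1881, 972]
```

Unknown tags are dropped rather than raising, which matches `_tags_to_indices` in gradio_app.py below.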
gradio_app.py ADDED
@@ -0,0 +1,265 @@
+ import os
+ from pathlib import Path
+ from typing import List, Tuple
+ import uuid
+ import json
+ import gradio as gr
+ import torch
+ import torchaudio
+ from safetensors.torch import load_file
+
+ from model import LocalSongModel
+ from acestep.music_dcae.music_dcae_pipeline import MusicDCAE
+
+ class TagEmbedder:
+     def __init__(self, mapping_file: str = "checkpoints/tag_mapping.json"):
+
+         with open(mapping_file, 'r', encoding='utf-8') as f:
+             self.tag_mapping = json.load(f)
+
+         print(f"Loaded {len(self.tag_mapping)} tags from {mapping_file}")
+         self.num_classes = 2304
+
+ class AudioVAE:
+     def __init__(self, device: torch.device):
+         self.model = MusicDCAE().to(device)
+         self.model.eval()
+         self.device = device
+         self.latent_mean = torch.tensor(
+             [0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526],
+             device=device,
+         ).view(1, -1, 1, 1)
+         self.latent_std = torch.tensor(
+             [0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707],
+             device=device,
+         ).view(1, -1, 1, 1)
+
+     def decode(self, latents: torch.Tensor) -> torch.Tensor:
+         with torch.no_grad():
+             latents = latents * self.latent_std + self.latent_mean
+             sr, audio_list = self.model.decode(latents, sr=48000)
+             audio_batch = torch.stack(audio_list).to(self.device)
+         return audio_batch
+
+ class RF:
+     def __init__(self, model: torch.nn.Module):
+         self.model = model
+
+     def sample(
+         self,
+         z: torch.Tensor,
+         cond: List[List[int]],
+         null_cond: List[List[int]] | None = None,
+         sample_steps: int = 100,
+         cfg: float = 3.0,
+     ) -> List[torch.Tensor]:
+         batch = z.size(0)
+         dt = 1.0 / sample_steps
+         dt = torch.tensor([dt] * batch, device=z.device).view([batch, *([1] * len(z.shape[1:]))])
+         images = [z]
+         for i in range(sample_steps, 0, -1):
+             t = torch.tensor([i / sample_steps] * batch, device=z.device)
+
+             if null_cond is not None:
+
+                 z_batched = torch.cat([z, z], dim=0)
+                 t_batched = torch.cat([t, t], dim=0)
+                 cond_batched = cond + null_cond
+                 v_batched = self.model(z_batched, t_batched, cond_batched)
+                 vc, vu = v_batched.chunk(2, dim=0)
+                 vc = vu + cfg * (vc - vu)
+
+             else:
+                 vc = self.model(z, t, cond)
+
+             z = z - dt * vc
+             images.append(z)
+         return images
+
+ model: torch.nn.Module | None = None
+ vae: AudioVAE | None = None
+ tag_embedder: TagEmbedder | None = None
+ rf_sampler: RF | None = None
+ device: torch.device | None = None
+ _available_tags: List[str] | None = None
+
+ def load_resources() -> List[str]:
+
+     torch.set_float32_matmul_precision('high')
+
+     global model, vae, tag_embedder, rf_sampler, device, _available_tags
+
+     if _available_tags is not None:
+         return _available_tags
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     tag_embedder = TagEmbedder()
+
+     model = LocalSongModel(
+         in_channels=8,
+         num_groups=16,
+         hidden_size=1024,
+         decoder_hidden_size=2048,
+         num_blocks=36,
+         patch_size=(16, 1),
+         num_classes=tag_embedder.num_classes,
+         max_tags=8,
+     ).to(device)
+
+     checkpoint_path = "checkpoints/checkpoint_461260.safetensors"
+     print(f"Loading checkpoint: {checkpoint_path}")
+
+     state_dict = load_file(checkpoint_path, device=str(device))
+     model.load_state_dict(state_dict, strict=True)
+     model.eval()
+
+     vae = AudioVAE(device)
+     rf_sampler = RF(model)
+
+     _available_tags = sorted(tag_embedder.tag_mapping.keys())
+     return _available_tags
+
+
+ def _tags_to_indices(tags: List[str]) -> List[int]:
+     assert tag_embedder is not None
+     indices = []
+
+     for tag in tags:
+         tag_lower = tag.lower().strip()
+         if tag_lower in tag_embedder.tag_mapping:
+             indices.append(tag_embedder.tag_mapping[tag_lower])
+
+     return indices
+
+
+ def generate_audio(
+     tags: List[str],
+     cfg: float,
+     sample_steps: int,
+ ) -> Tuple[Tuple[int, object], str]:
+
+     load_resources()
+     assert model is not None and vae is not None and rf_sampler is not None and device is not None
+
+     if not tags:
+         tags = []
+     if len(tags) > 8:
+         raise gr.Error("A maximum of 8 tags is supported.")
+
+     tag_indices = _tags_to_indices(tags)
+
+     batch = 1
+     channels = 8
+     height = 16
+     width = 512
+
+     z = torch.randn(batch, channels, height, width, device=device)
+     cond = [tag_indices]
+     null_cond = [[]]
+
+     with torch.no_grad():
+         sampled_latents = rf_sampler.sample(
+             z=z,
+             cond=cond,
+             null_cond=null_cond,
+             sample_steps=sample_steps,
+             cfg=cfg,
+         )[-1]
+         audio = vae.decode(sampled_latents)
+
+     audio_tensor = audio[0].cpu()
+     sr = 48000
+     audio_numpy = audio_tensor.transpose(0, 1).numpy()
+
+     os.makedirs("generated", exist_ok=True)
+     output_path = f"generated/generated_{uuid.uuid4().hex}.wav"
+     torchaudio.save(str(output_path), audio_tensor, sr)
+
+     return (sr, audio_numpy), str(output_path)
+
+ def build_interface() -> gr.Blocks:
+     available_tags = load_resources()
+
+     # Define preset tag combinations
+     presets = [
+         ["soundtrack1", "female vocalist", "rock", "melodic"],
+         ["soundtrack", "chrono trigger", "emotional", "piano", "strings"],
+         ["soundtrack", "touhou 10", "trumpet"],
+         ["soundtrack", "christmas music", "winter", "melodic"],
+         ["soundtrack2", "male vocalist", "pop", "melodic", "acoustic guitar", "ballad"],
+     ]
+
+     with gr.Blocks(title="LocalSong") as demo:
+         gr.Markdown("# LocalSong")
+
+         with gr.Row():
+             tag_input = gr.Dropdown(
+                 label="Tags (select up to 8)",
+                 choices=available_tags,
+                 multiselect=True,
+                 max_choices=8,
+                 value=presets[0],
+             )
+
+         gr.Markdown("**Presets:**")
+         with gr.Row():
+             for preset in presets:
+                 btn = gr.Button(f"{' + '.join(preset)}", size="sm")
+                 def make_preset_fn(p):
+                     return lambda: p
+                 btn.click(fn=make_preset_fn(preset), inputs=None, outputs=tag_input)
+
+         with gr.Row():
+             cfg_slider = gr.Slider(
+                 label="CFG Scale",
+                 minimum=1.0,
+                 maximum=7.0,
+                 step=0.5,
+                 value=3.5,
+             )
+             sample_steps_slider = gr.Slider(
+                 label="Sample Steps",
+                 minimum=50,
+                 maximum=200,
+                 step=10,
+                 value=200,
+             )
+
+         with gr.Row():
+             seed_input = gr.Number(
+                 label="Seed",
+                 value=45,
+                 precision=0,
+             )
+
+         generate_button = gr.Button("Generate Audio", variant="primary")
+         audio_output = gr.Audio(label="Generated Audio", type="numpy")
+         download_output = gr.File(label="Download WAV")
+
+         def generate_wrapper(tags, cfg, steps, seed):
+             torch.manual_seed(seed)
+             if torch.cuda.is_available():
+                 torch.cuda.manual_seed(seed)
+             return generate_audio(tags, cfg, steps)
+
+         generate_button.click(
+             fn=generate_wrapper,
+             inputs=[
+                 tag_input,
+                 cfg_slider,
+                 sample_steps_slider,
+                 seed_input,
+             ],
+             outputs=[
+                 audio_output,
+                 download_output,
+             ],
+         )
+
+     return demo
+
+ demo = build_interface()
+
+ if __name__ == "__main__":
+     demo.launch()
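`RF.sample` above is an Euler solver for a rectified-flow ODE with classifier-free guidance: it walks t from 1 down to 0, forms the guided velocity v = v_u + cfg * (v_c - v_u), and steps z ← z − dt·v (the batched `cond + null_cond` trick just evaluates both branches in one forward pass). A scalar, pure-Python sketch of the same update rule, with a toy velocity field standing in for the model:

```python
def euler_cfg_sample(velocity, z, cond, null_cond, sample_steps=4, cfg=3.0):
    """Integrate dz/dt = v from t=1 down to t=0 with Euler steps,
    guiding with v = v_u + cfg * (v_c - v_u), as in RF.sample."""
    dt = 1.0 / sample_steps
    for i in range(sample_steps, 0, -1):
        t = i / sample_steps
        v_c = velocity(z, t, cond)       # conditional velocity
        v_u = velocity(z, t, null_cond)  # unconditional velocity
        v = v_u + cfg * (v_c - v_u)      # classifier-free guidance
        z = z - dt * v
    return z

# With a toy velocity field v(z) = z, both branches agree, so guidance is a
# no-op and each step multiplies z by (1 - dt): 4 steps give 0.75**4.
toy_velocity = lambda z, t, c: z
print(euler_cfg_sample(toy_velocity, 1.0, cond=[1], null_cond=[], sample_steps=4, cfg=3.0))  # 0.31640625
```

The real sampler does the same thing per tensor element, with `cond` a list of tag-index lists and `null_cond=[[]]` as the unconditional branch.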
model.py ADDED
@@ -0,0 +1,490 @@
+ from typing import Tuple
2
+ import torch
3
+ import torch.nn as nn
4
+ import math
5
+ from einops import rearrange
6
+ from torch.nn.functional import scaled_dot_product_attention
7
+
8
+ def modulate(x, shift, scale):
9
+ return x * (1 + scale) + shift
10
+
11
+ class Embed(nn.Module):
12
+ def __init__(
13
+ self,
14
+ in_chans: int = 3,
15
+ embed_dim: int = 768,
16
+ norm_layer = None,
17
+ bias: bool = True,
18
+ ):
19
+ super().__init__()
20
+ self.in_chans = in_chans
21
+ self.embed_dim = embed_dim
22
+ self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
23
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
24
+ def forward(self, x):
25
+ x = self.proj(x)
26
+ x = self.norm(x)
27
+ return x
28
+
29
+ class PatchEmbed(nn.Module):
30
+ def __init__(
31
+ self,
32
+ in_channels=8,
33
+ embed_dim=1152,
34
+ bias=True,
35
+ patch_size=1,
36
+ ):
37
+ super().__init__()
38
+
39
+ self.patch_h, self.patch_w = patch_size
40
+
41
+ self.patch_size = patch_size
42
+ self.proj = nn.Linear(in_channels * self.patch_h * self.patch_w, embed_dim, bias=bias)
43
+ self.in_channels = in_channels
44
+ self.embed_dim = embed_dim
45
+
46
+ def forward(self, latent):
47
+ x = rearrange(latent, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1=self.patch_h, p2=self.patch_w)
48
+ x = self.proj(x)
49
+ return x
50
+
51
+ class FinalLayer(nn.Module):
52
+ """Final layer with configurable patch_size support"""
53
+
54
+ def __init__(self, hidden_size, out_channels=8, patch_size=1):
55
+ super().__init__()
56
+ self.patch_h, self.patch_w = patch_size
57
+
58
+ self.linear = nn.Linear(hidden_size, out_channels * self.patch_h * self.patch_w, bias=True)
59
+ self.out_channels = out_channels
60
+ self.patch_size = patch_size
61
+
62
+ def forward(self, x, target_height, target_width):
63
+
64
+ x = self.linear(x)
65
+
66
+ x = rearrange(x, 'b (h w) (c p1 p2) -> b c (h p1) (w p2)',
67
+ h=target_height, w=target_width,
68
+ p1=self.patch_h, p2=self.patch_w, c=self.out_channels)
69
+ return x
70
+
71
+ class TimestepEmbedder(nn.Module):
72
+
73
+ def __init__(self, hidden_size, frequency_embedding_size=256):
74
+ super().__init__()
75
+ self.mlp = nn.Sequential(
76
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
77
+ nn.SiLU(),
78
+ nn.Linear(hidden_size, hidden_size, bias=True),
79
+ )
80
+ self.frequency_embedding_size = frequency_embedding_size
81
+
82
+ @staticmethod
83
+ def timestep_embedding(t, dim, max_period=10):
84
+ half = dim // 2
85
+ freqs = torch.exp(
86
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
87
+ )
88
+ args = t[..., None].float() * freqs[None, ...]
89
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
90
+ if dim % 2:
91
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
92
+ return embedding
93
+
94
+ def forward(self, t):
95
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
96
+ t_emb = self.mlp(t_freq)
97
+ return t_emb
98
+
99
+ class RMSNorm(nn.Module):
100
+ def __init__(self, hidden_size, eps=1e-6):
101
+ super().__init__()
102
+ self.weight = nn.Parameter(torch.ones(hidden_size))
103
+ self.variance_epsilon = eps
104
+
105
+ def forward(self, hidden_states):
106
+ input_dtype = hidden_states.dtype
107
+ hidden_states = hidden_states.to(torch.float32)
108
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
109
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
110
+ return self.weight * hidden_states.to(input_dtype)
111
+
112
+ class FeedForward(nn.Module):
113
+ def __init__(
114
+ self,
115
+ dim: int,
116
+ hidden_dim: int,
117
+ ):
118
+ super().__init__()
119
+ hidden_dim = int(2 * hidden_dim / 3)
120
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
121
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
122
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
123
+ def forward(self, x):
124
+ x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
125
+ return x
126
+
127
+ def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float = 10000.0, scale=1.0):
128
+
129
+ if isinstance(scale, float):
130
+ scale = (scale, scale)
131
+ x_pos = torch.linspace(0, width * scale[0], width)
132
+ y_pos = torch.linspace(0, height * scale[1], height)
133
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
134
+ y_pos = y_pos.reshape(-1)
135
+ x_pos = x_pos.reshape(-1)
136
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
137
+ x_freqs = torch.outer(x_pos, freqs).float()
138
+ y_freqs = torch.outer(y_pos, freqs).float()
139
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
140
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
141
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1)
142
+ freqs_cis = freqs_cis.reshape(height * width, -1)
143
+ return freqs_cis
144
+
145
+ @torch.compiler.disable
146
+ def apply_rotary_emb_2d(
147
+ xq: torch.Tensor,
148
+ xk: torch.Tensor,
149
+ freqs_cis: torch.Tensor,
150
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
151
+
152
+ freqs_cis = freqs_cis[None, None, :, :]
153
+
154
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
155
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
156
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
157
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
158
+ return xq_out.type_as(xq), xk_out.type_as(xk)
159
+
160
+ class RAttention(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 8,
+         qkv_bias: bool = False,
+         qk_norm: bool = True,
+         attn_drop: float = 0.,
+         proj_drop: float = 0.,
+         norm_layer: nn.Module = RMSNorm,
+     ) -> None:
+         super().__init__()
+         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+
+         self.dim = dim
+         self.num_heads = num_heads
+         self.head_dim = dim // num_heads
+         self.scale = self.head_dim ** -0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
+         B, N, C = x.shape
+
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv[0], qkv[1], qkv[2]
+         q = self.q_norm(q.contiguous())
+         k = self.k_norm(k.contiguous())
+         q, k = apply_rotary_emb_2d(q, k, freqs_cis=pos)
+
+         q = q.view(B, self.num_heads, -1, C // self.num_heads)
+         k = k.view(B, self.num_heads, -1, C // self.num_heads).contiguous()
+         v = v.view(B, self.num_heads, -1, C // self.num_heads).contiguous()
+
+         x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.attn_drop.p if self.training else 0.0)
+
+         x = x.transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
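`scaled_dot_product_attention` above is PyTorch's fused kernel; mathematically it computes softmax(q·kᵀ/√d)·v per head. A minimal single-head reference on plain Python lists (a sketch for reading, not a replacement for the fused op; the name `sdpa` is illustrative):

```python
import math

def sdpa(q, k, v):
    # q, k, v: lists of vectors (lists of floats), one head, no batching
    d = len(q[0])
    out = []
    for qi in q:
        # scaled dot-product scores against every key
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        # numerically stable softmax
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        # weighted average of the values
        out.append([sum(wi * vj[t] for wi, vj in zip(w, v)) / z
                    for t in range(len(v[0]))])
    return out
```

With a zero query all keys score equally, so the output is the plain mean of the values: `sdpa([[0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]])` returns `[[2.0, 3.0]]`.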
+ class CrossAttention(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         context_dim: int,
+         num_heads: int,
+         qkv_bias: bool = False,
+         proj_drop: float = 0.0,
+     ):
+         super().__init__()
+         self.num_heads = num_heads
+         self.head_dim = dim // num_heads
+         self.scale = self.head_dim ** -0.5
+
+         self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
+         self.kv_proj = nn.Linear(context_dim, dim * 2, bias=qkv_bias)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x: torch.Tensor, context: torch.Tensor, context_mask: torch.Tensor = None) -> torch.Tensor:
+         B, N, C = x.shape
+         B_ctx, M, C_ctx = context.shape
+
+         q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+         kv = self.kv_proj(context).reshape(B_ctx, M, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+         k, v = kv[0], kv[1]
+
+         attn_mask = None
+         if context_mask is not None:
+             attn_mask = torch.zeros(B, 1, 1, M, dtype=q.dtype, device=q.device)
+             attn_mask.masked_fill_(~context_mask.unsqueeze(1).unsqueeze(2), float('-inf'))
+
+         attn = scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=self.proj_drop.p if self.training else 0.0)
+
+         x = attn.permute(0, 2, 1, 3).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
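The `masked_fill_` step above converts a boolean padding mask into the additive form `scaled_dot_product_attention` expects: 0 where a key may be attended to, -inf where it is padding, so softmax assigns padded keys zero weight. The conversion in isolation (the helper `additive_mask` is illustrative, not from this file):

```python
def additive_mask(keep):
    # boolean keep-mask -> additive attention mask:
    # 0.0 where attention is allowed, -inf where the key is padding
    return [0.0 if k else float("-inf") for k in keep]
```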
+ class DDTBlock(nn.Module):
+     def __init__(self, hidden_size, groups, mlp_ratio=4.0, context_dim=None, is_encoder_block=False):
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.norm1 = RMSNorm(hidden_size, eps=1e-6)
+         self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
+
+         self.norm_cross = RMSNorm(hidden_size, eps=1e-6) if context_dim else nn.Identity()
+         self.cross_attn = CrossAttention(hidden_size, context_dim, groups) if context_dim else None
+
+         self.norm2 = RMSNorm(hidden_size, eps=1e-6)
+         mlp_hidden_dim = int(hidden_size * mlp_ratio)
+         self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
+
+         self.is_encoder_block = is_encoder_block
+         if not is_encoder_block:
+             self.adaLN_modulation = nn.Sequential(
+                 nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+             )
+
+     def forward(self, x, c, pos, mask=None, context=None, context_mask=None, shared_adaLN=None):
+         if self.is_encoder_block:
+             adaLN_output = shared_adaLN(c)
+         else:
+             adaLN_output = self.adaLN_modulation(c)
+
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = adaLN_output.chunk(6, dim=-1)
+
+         x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
+
+         if self.cross_attn is not None and context is not None:
+             x = x + self.cross_attn(self.norm_cross(x), context=context, context_mask=context_mask)
+
+         x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+         return x
+
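`DDTBlock` relies on a `modulate` helper defined elsewhere in this file; in DiT-family models it is conventionally `x * (1 + scale) + shift`. A per-channel sketch of that assumed convention (verify against the actual helper before relying on it):

```python
def modulate(x, shift, scale):
    # adaLN modulation: the conditioning network predicts a per-channel
    # shift and scale; scale is applied as (1 + scale) so a zero-initialized
    # conditioning head leaves the activations unchanged
    return [xi * (1.0 + sc) + sh for xi, sh, sc in zip(x, shift, scale)]
```

The `(1 + scale)` form is also why `shared_encoder_adaLN` / `shared_decoder_adaLN` are zero-initialized below: at init, every block starts as (roughly) an identity-plus-residual.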
+ class LocalSongModel(nn.Module):
+     def __init__(
+         self,
+         in_channels=8,
+         num_groups=16,
+         hidden_size=1024,
+         decoder_hidden_size=2048,
+         num_blocks=36,
+         patch_size=(16, 1),
+         num_classes=2304,
+         max_tags=8,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = in_channels
+         self.hidden_size = hidden_size
+         self.decoder_hidden_size = decoder_hidden_size
+         self.num_groups = num_groups
+         self.num_blocks = num_blocks
+         self.patch_size = patch_size
+         self.num_classes = num_classes
+         self.max_tags = max_tags
+
+         self.patch_h, self.patch_w = patch_size
+
+         self.x_embedder = PatchEmbed(
+             in_channels=in_channels,
+             embed_dim=decoder_hidden_size,
+             bias=True,
+             patch_size=patch_size
+         )
+
+         self.s_embedder = PatchEmbed(
+             in_channels=in_channels,
+             embed_dim=decoder_hidden_size,
+             bias=True,
+             patch_size=patch_size
+         )
+
+         self.encoder_to_decoder = nn.Linear(hidden_size, decoder_hidden_size, bias=False)
+
+         self.a_to_b_proj = nn.Linear(decoder_hidden_size, hidden_size, bias=False)
+
+         self.t_embedder = TimestepEmbedder(hidden_size)
+
+         self.y_embedder = nn.Embedding(num_classes + 1, hidden_size, padding_idx=0)
+
+         self.final_layer = FinalLayer(
+             decoder_hidden_size,
+             out_channels=in_channels,
+             patch_size=patch_size
+         )
+
+         self.shared_encoder_adaLN = nn.Sequential(
+             nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+         )
+
+         self.shared_decoder_adaLN = nn.Sequential(
+             nn.Linear(hidden_size, 6 * decoder_hidden_size, bias=True)
+         )
+
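With `patch_size=(16, 1)`, each token produced by `PatchEmbed` covers 16 steps along the height axis and 1 along the width, so the sequence length is `(H // 16) * W`. A quick sketch of the token-count arithmetic (the helper name is illustrative):

```python
def token_count(h, w, patch=(16, 1)):
    # number of tokens after patch embedding; both dims must divide evenly
    ph, pw = patch
    assert h % ph == 0 and w % pw == 0
    return (h // ph) * (w // pw)
```

For example, a 64x128 latent yields `token_count(64, 128) == 512` tokens.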
+         self.blocks = nn.ModuleList()
+         for i in range(self.num_blocks):
+             # every block is built with is_encoder_block=True (the original
+             # condition `i < self.num_blocks` is always true), so all blocks
+             # take the shared-adaLN path in DDTBlock.forward
+             is_encoder = True
+
+             # the first block and the last three run at the decoder width;
+             # the middle blocks run at the (narrower) encoder width
+             if i < 1 or i >= self.num_blocks - 3:
+                 block_hidden_size = decoder_hidden_size
+             else:
+                 block_hidden_size = hidden_size
+             num_heads = self.num_groups
+
+             # cross-attention to the tag embeddings on every other block
+             context_dim = hidden_size if i % 2 == 0 else None
+
+             self.blocks.append(
+                 DDTBlock(
+                     block_hidden_size,
+                     num_heads,
+                     context_dim=context_dim,
+                     is_encoder_block=is_encoder
+                 )
+             )
+
+         self.bc_projection = nn.Linear(decoder_hidden_size + hidden_size, decoder_hidden_size, bias=False)
+
+         self.initialize_weights()
+         self.precompute_encoder_pos = dict()
+         self.precompute_decoder_pos = dict()
+
+     def fetch_encoder_pos(self, height, width, device):
+         # memoized by hand via self.precompute_encoder_pos; the original
+         # additionally stacked functools.lru_cache on top, which is redundant
+         # here and pins `self` (and the device argument) in a hidden cache
+         key = (height, width)
+         if key in self.precompute_encoder_pos:
+             return self.precompute_encoder_pos[key].to(device)
+         pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
+         self.precompute_encoder_pos[key] = pos
+         return pos
+
+     def fetch_decoder_pos(self, height, width, device):
+         key = (height, width)
+         if key in self.precompute_decoder_pos:
+             return self.precompute_decoder_pos[key].to(device)
+         pos = precompute_freqs_cis_2d(self.decoder_hidden_size // self.num_groups, height, width).to(device)
+         self.precompute_decoder_pos[key] = pos
+         return pos
+
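The `fetch_*_pos` methods are plain dictionary memoization keyed on the grid size: build once per `(height, width)`, then reuse. The pattern in isolation (names here are hypothetical):

```python
def make_fetcher(build):
    cache = {}
    def fetch(height, width):
        key = (height, width)
        if key not in cache:
            cache[key] = build(height, width)  # computed once per grid size
        return cache[key]
    return fetch

calls = []
fetch = make_fetcher(lambda h, w: calls.append((h, w)) or h * w)
```

Repeated calls with the same grid size hit the cache, so the expensive `precompute_freqs_cis_2d` runs only on the first forward pass for each resolution.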
+     def initialize_weights(self):
+         for embedder in [self.x_embedder, self.s_embedder]:
+             nn.init.xavier_uniform_(embedder.proj.weight)
+             if embedder.proj.bias is not None:
+                 nn.init.constant_(embedder.proj.bias, 0)
+
+         nn.init.xavier_uniform_(self.encoder_to_decoder.weight)
+         nn.init.xavier_uniform_(self.a_to_b_proj.weight)
+
+         nn.init.normal_(self.y_embedder.weight, std=0.02)
+
+         with torch.no_grad():
+             self.y_embedder.weight[0].fill_(0)
+
+         nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+         nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+         nn.init.constant_(self.shared_encoder_adaLN[-1].weight, 0)
+         nn.init.constant_(self.shared_encoder_adaLN[-1].bias, 0)
+         nn.init.constant_(self.shared_decoder_adaLN[-1].weight, 0)
+         nn.init.constant_(self.shared_decoder_adaLN[-1].bias, 0)
+
+         nn.init.constant_(self.final_layer.linear.weight, 0)
+         nn.init.constant_(self.final_layer.linear.bias, 0)
+
+         nn.init.xavier_uniform_(self.bc_projection.weight)
+
+     def embed_condition(self, cond):
+         device = self.y_embedder.weight.device
+
+         max_len = self.max_tags
+         batch_size = len(cond)
+
+         # index 0 is the padding id (padding_idx of y_embedder)
+         padded_tags = torch.zeros(batch_size, max_len, dtype=torch.long, device=device)
+
+         for i, tags in enumerate(cond):
+             truncated_tags = tags[:max_len]
+             padded_tags[i, :len(truncated_tags)] = torch.tensor(truncated_tags, dtype=torch.long, device=device)
+
+         padding_mask = (padded_tags != 0)
+
+         embedded = self.y_embedder(padded_tags)
+
+         return embedded, padding_mask
+
+     def forward(self, x, t, y):
+         y_emb, padding_mask = self.embed_condition(y)
+         return self.forward_emb(x, t, y_emb, padding_mask)
+
+     @torch.compile()
+     def forward_emb(self, x, t, y_emb, padding_mask=None):
+         B, _, H, W = x.shape
+
+         h_patches = H // self.patch_h
+         w_patches = W // self.patch_w
+         encoder_pos = self.fetch_encoder_pos(h_patches, w_patches, x.device)
+         decoder_pos = self.fetch_decoder_pos(h_patches, w_patches, x.device)
+
+         t_emb = self.t_embedder(t.view(-1)).view(B, 1, self.hidden_size)
+         t_cond = nn.functional.silu(t_emb)
+
+         s = self.s_embedder(x)
+
+         # section A: the first block runs at decoder width
+         s_section_a = s
+         for i in range(min(1, self.num_blocks)):
+             block_context = y_emb if i % 2 == 0 else None
+             s_section_a = self.blocks[i](s_section_a, t_cond, decoder_pos, None, context=block_context, context_mask=padding_mask, shared_adaLN=self.shared_decoder_adaLN)
+
+         # section B: project down and run the middle blocks at encoder width
+         s_section_a_projected = self.a_to_b_proj(s_section_a)
+         s_section_b = s_section_a_projected
+         for i in range(1, self.num_blocks - 3):
+             block_context = y_emb if i % 2 == 0 else None
+             s_section_b = self.blocks[i](s_section_b, t_cond, encoder_pos, None, context=block_context, context_mask=padding_mask, shared_adaLN=self.shared_encoder_adaLN)
+
+         # section C: concatenate A and B, project back to decoder width,
+         # and run the last three decoder-width blocks
+         s_concat = torch.cat([s_section_a, s_section_b], dim=-1)
+         s = self.bc_projection(s_concat)
+
+         for i in range(max(1, self.num_blocks - 3), self.num_blocks):
+             block_context = y_emb if i % 2 == 0 else None
+             s = self.blocks[i](s, t_cond, decoder_pos, None, context=block_context, context_mask=padding_mask, shared_adaLN=self.shared_decoder_adaLN)
+
+         s = self.final_layer(s, H // self.patch_h, W // self.patch_w)
+
+         return s
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch>=2.8.0
+ torchaudio>=2.8.0
+ torchvision>=0.23.0
+ torchcodec>=0.8.0
+ accelerate>=1.9.0
+ diffusers>=0.34.0
+ einops>=0.8.1
+ librosa>=0.11.0
+ safetensors>=0.4.0
+ gradio>=5.45.0
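These pins install in the usual way; a fresh virtual environment is assumed here (exact resolved versions will vary by platform):

```shell
# from the repository root
python -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt
```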