Upload 15 files
Browse files- LICENSE +201 -0
- README.md +43 -3
- acestep/checkpoints/music_dcae_f8c8/config.json +69 -0
- acestep/checkpoints/music_dcae_f8c8/diffusion_pytorch_model.safetensors +3 -0
- acestep/checkpoints/music_vocoder/config.json +79 -0
- acestep/checkpoints/music_vocoder/diffusion_pytorch_model.safetensors +3 -0
- acestep/music_dcae/__init__.py +0 -0
- acestep/music_dcae/music_dcae_pipeline.py +379 -0
- acestep/music_dcae/music_log_mel.py +115 -0
- acestep/music_dcae/music_vocoder.py +587 -0
- checkpoints/checkpoint_461260.safetensors +3 -0
- checkpoints/tag_mapping.json +858 -0
- gradio_app.py +265 -0
- model.py +490 -0
- requirements.txt +10 -0
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
CHANGED
|
@@ -1,3 +1,43 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LocalSong
|
| 2 |
+
|
| 3 |
+
LocalSong is an audio generation model focused on melodic instrumental music that uses tag-based conditioning to generate audio.
|
| 4 |
+
|
| 5 |
+
## Installation
|
| 6 |
+
|
| 7 |
+
### Prerequisites
|
| 8 |
+
|
| 9 |
+
- Python 3.10 or higher
|
| 10 |
+
- CUDA-capable GPU recommended
|
| 11 |
+
|
| 12 |
+
### Setup
|
| 13 |
+
|
| 14 |
+
git clone https://huggingface.co/Localsong/LocalSong
|
| 15 |
+
cd localsong
|
| 16 |
+
python3 -m venv venv
|
| 17 |
+
source venv/bin/activate
|
| 18 |
+
pip install -r requirements.txt
|
| 19 |
+
|
| 20 |
+
### Run
|
| 21 |
+
|
| 22 |
+
python gradio_app.py
|
| 23 |
+
|
| 24 |
+
The interface will be available at `http://localhost:7860`
|
| 25 |
+
|
| 26 |
+
### Generation Advice
|
| 27 |
+
|
| 28 |
+
Generations should use one of the soundtrack, soundtrack1 or soundtrack2 tags, as well as at least one other tag. They can use up to 8 tags; try combining genres and instruments.
|
| 29 |
+
The default settings (CFG 3.5, steps 200) have been tested as optimal.
|
| 30 |
+
The first generation will be slower due to torch.compile, then speed will increase.
|
| 31 |
+
The model was trained on vocals but not lyrics. Vocals will not have recognizable words.
|
| 32 |
+
|
| 33 |
+
## Credits
|
| 34 |
+
|
| 35 |
+
This project builds upon the following open-source projects:
|
| 36 |
+
|
| 37 |
+
- **Model Architecture**: Adapted from [DDT](https://github.com/MCG-NJU/DDT)
|
| 38 |
+
- **Flow Matching**: Adapted from [minRF](https://github.com/cloneofsimo/minRF)
|
| 39 |
+
- **Audio VAE**: [ACE-Step](https://github.com/ACE-Step/ACE-Step)
|
| 40 |
+
|
| 41 |
+
## License
|
| 42 |
+
|
| 43 |
+
This project is licensed under the Apache License 2.0
|
acestep/checkpoints/music_dcae_f8c8/config.json
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderDC",
|
| 3 |
+
"_diffusers_version": "0.32.2",
|
| 4 |
+
"_name_or_path": "checkpoints/music_dcae_f8c8",
|
| 5 |
+
"attention_head_dim": 32,
|
| 6 |
+
"decoder_act_fns": "silu",
|
| 7 |
+
"decoder_block_out_channels": [
|
| 8 |
+
128,
|
| 9 |
+
256,
|
| 10 |
+
512,
|
| 11 |
+
1024
|
| 12 |
+
],
|
| 13 |
+
"decoder_block_types": [
|
| 14 |
+
"ResBlock",
|
| 15 |
+
"ResBlock",
|
| 16 |
+
"ResBlock",
|
| 17 |
+
"EfficientViTBlock"
|
| 18 |
+
],
|
| 19 |
+
"decoder_layers_per_block": [
|
| 20 |
+
3,
|
| 21 |
+
3,
|
| 22 |
+
3,
|
| 23 |
+
3
|
| 24 |
+
],
|
| 25 |
+
"decoder_norm_types": "rms_norm",
|
| 26 |
+
"decoder_qkv_multiscales": [
|
| 27 |
+
[],
|
| 28 |
+
[],
|
| 29 |
+
[
|
| 30 |
+
5
|
| 31 |
+
],
|
| 32 |
+
[
|
| 33 |
+
5
|
| 34 |
+
]
|
| 35 |
+
],
|
| 36 |
+
"downsample_block_type": "Conv",
|
| 37 |
+
"encoder_block_out_channels": [
|
| 38 |
+
128,
|
| 39 |
+
256,
|
| 40 |
+
512,
|
| 41 |
+
1024
|
| 42 |
+
],
|
| 43 |
+
"encoder_block_types": [
|
| 44 |
+
"ResBlock",
|
| 45 |
+
"ResBlock",
|
| 46 |
+
"ResBlock",
|
| 47 |
+
"EfficientViTBlock"
|
| 48 |
+
],
|
| 49 |
+
"encoder_layers_per_block": [
|
| 50 |
+
2,
|
| 51 |
+
2,
|
| 52 |
+
3,
|
| 53 |
+
3
|
| 54 |
+
],
|
| 55 |
+
"encoder_qkv_multiscales": [
|
| 56 |
+
[],
|
| 57 |
+
[],
|
| 58 |
+
[
|
| 59 |
+
5
|
| 60 |
+
],
|
| 61 |
+
[
|
| 62 |
+
5
|
| 63 |
+
]
|
| 64 |
+
],
|
| 65 |
+
"in_channels": 2,
|
| 66 |
+
"latent_channels": 8,
|
| 67 |
+
"scaling_factor": 0.41407,
|
| 68 |
+
"upsample_block_type": "interpolate"
|
| 69 |
+
}
|
acestep/checkpoints/music_dcae_f8c8/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2b0cb469307ac50659d1880db2a99bae47d0df335cbb36853964662d4b80e8ee
|
| 3 |
+
size 313646516
|
acestep/checkpoints/music_vocoder/config.json
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "ADaMoSHiFiGANV1",
|
| 3 |
+
"_diffusers_version": "0.32.2",
|
| 4 |
+
"depths": [
|
| 5 |
+
3,
|
| 6 |
+
3,
|
| 7 |
+
9,
|
| 8 |
+
3
|
| 9 |
+
],
|
| 10 |
+
"dims": [
|
| 11 |
+
128,
|
| 12 |
+
256,
|
| 13 |
+
384,
|
| 14 |
+
512
|
| 15 |
+
],
|
| 16 |
+
"drop_path_rate": 0.0,
|
| 17 |
+
"f_max": 16000,
|
| 18 |
+
"f_min": 40,
|
| 19 |
+
"hop_length": 512,
|
| 20 |
+
"input_channels": 128,
|
| 21 |
+
"kernel_sizes": [
|
| 22 |
+
7
|
| 23 |
+
],
|
| 24 |
+
"n_fft": 2048,
|
| 25 |
+
"n_mels": 128,
|
| 26 |
+
"num_mels": 512,
|
| 27 |
+
"post_conv_kernel_size": 13,
|
| 28 |
+
"pre_conv_kernel_size": 13,
|
| 29 |
+
"resblock_dilation_sizes": [
|
| 30 |
+
[
|
| 31 |
+
1,
|
| 32 |
+
3,
|
| 33 |
+
5
|
| 34 |
+
],
|
| 35 |
+
[
|
| 36 |
+
1,
|
| 37 |
+
3,
|
| 38 |
+
5
|
| 39 |
+
],
|
| 40 |
+
[
|
| 41 |
+
1,
|
| 42 |
+
3,
|
| 43 |
+
5
|
| 44 |
+
],
|
| 45 |
+
[
|
| 46 |
+
1,
|
| 47 |
+
3,
|
| 48 |
+
5
|
| 49 |
+
]
|
| 50 |
+
],
|
| 51 |
+
"resblock_kernel_sizes": [
|
| 52 |
+
3,
|
| 53 |
+
7,
|
| 54 |
+
11,
|
| 55 |
+
13
|
| 56 |
+
],
|
| 57 |
+
"sampling_rate": 44100,
|
| 58 |
+
"upsample_initial_channel": 1024,
|
| 59 |
+
"upsample_kernel_sizes": [
|
| 60 |
+
8,
|
| 61 |
+
8,
|
| 62 |
+
4,
|
| 63 |
+
4,
|
| 64 |
+
4,
|
| 65 |
+
4,
|
| 66 |
+
4
|
| 67 |
+
],
|
| 68 |
+
"upsample_rates": [
|
| 69 |
+
4,
|
| 70 |
+
4,
|
| 71 |
+
2,
|
| 72 |
+
2,
|
| 73 |
+
2,
|
| 74 |
+
2,
|
| 75 |
+
2
|
| 76 |
+
],
|
| 77 |
+
"use_template": false,
|
| 78 |
+
"win_length": 2048
|
| 79 |
+
}
|
acestep/checkpoints/music_vocoder/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c92c9b46e28ab7b37b777780cf4308ad7ddac869636bb77aa61599358c4bc1c0
|
| 3 |
+
size 206350988
|
acestep/music_dcae/__init__.py
ADDED
|
File without changes
|
acestep/music_dcae/music_dcae_pipeline.py
ADDED
|
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ACE-Step: A Step Towards Music Generation Foundation Model
|
| 3 |
+
|
| 4 |
+
https://github.com/ace-step/ACE-Step
|
| 5 |
+
|
| 6 |
+
Apache 2.0 License
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import torch
|
| 11 |
+
from diffusers import AutoencoderDC
|
| 12 |
+
import torchaudio
|
| 13 |
+
import torchvision.transforms as transforms
|
| 14 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 15 |
+
from diffusers.loaders import FromOriginalModelMixin
|
| 16 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
|
| 19 |
+
from acestep.music_dcae.music_vocoder import ADaMoSHiFiGANV1
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 23 |
+
DEFAULT_PRETRAINED_PATH = os.path.join(root_dir, "checkpoints", "music_dcae_f8c8")
|
| 24 |
+
VOCODER_PRETRAINED_PATH = os.path.join(root_dir, "checkpoints", "music_vocoder")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class MusicDCAE(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
| 28 |
+
@register_to_config
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
source_sample_rate=None,
|
| 32 |
+
dcae_checkpoint_path=DEFAULT_PRETRAINED_PATH,
|
| 33 |
+
vocoder_checkpoint_path=VOCODER_PRETRAINED_PATH,
|
| 34 |
+
):
|
| 35 |
+
super(MusicDCAE, self).__init__()
|
| 36 |
+
|
| 37 |
+
self.dcae = AutoencoderDC.from_pretrained(dcae_checkpoint_path)
|
| 38 |
+
self.vocoder = ADaMoSHiFiGANV1.from_pretrained(vocoder_checkpoint_path)
|
| 39 |
+
|
| 40 |
+
if source_sample_rate is None:
|
| 41 |
+
source_sample_rate = 48000
|
| 42 |
+
|
| 43 |
+
self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
|
| 44 |
+
|
| 45 |
+
self.transform = transforms.Compose(
|
| 46 |
+
[
|
| 47 |
+
transforms.Normalize(0.5, 0.5),
|
| 48 |
+
]
|
| 49 |
+
)
|
| 50 |
+
self.min_mel_value = -11.0
|
| 51 |
+
self.max_mel_value = 3.0
|
| 52 |
+
self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
|
| 53 |
+
self.mel_chunk_size = 1024
|
| 54 |
+
self.time_dimention_multiple = 8
|
| 55 |
+
self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
|
| 56 |
+
self.scale_factor = 0.1786
|
| 57 |
+
self.shift_factor = -1.9091
|
| 58 |
+
|
| 59 |
+
def load_audio(self, audio_path):
|
| 60 |
+
audio, sr = torchaudio.load(audio_path)
|
| 61 |
+
if audio.shape[0] == 1:
|
| 62 |
+
audio = audio.repeat(2, 1)
|
| 63 |
+
return audio, sr
|
| 64 |
+
|
| 65 |
+
def forward_mel(self, audios):
|
| 66 |
+
mels = []
|
| 67 |
+
for i in range(len(audios)):
|
| 68 |
+
image = self.vocoder.mel_transform(audios[i])
|
| 69 |
+
mels.append(image)
|
| 70 |
+
mels = torch.stack(mels)
|
| 71 |
+
return mels
|
| 72 |
+
|
| 73 |
+
@torch.no_grad()
|
| 74 |
+
def encode(self, audios, audio_lengths=None, sr=None):
|
| 75 |
+
if audio_lengths is None:
|
| 76 |
+
audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
|
| 77 |
+
audio_lengths = audio_lengths.to(audios.device)
|
| 78 |
+
|
| 79 |
+
# audios: N x 2 x T, 48kHz
|
| 80 |
+
device = audios.device
|
| 81 |
+
dtype = audios.dtype
|
| 82 |
+
|
| 83 |
+
if sr is None:
|
| 84 |
+
sr = 48000
|
| 85 |
+
resampler = self.resampler
|
| 86 |
+
else:
|
| 87 |
+
resampler = torchaudio.transforms.Resample(sr, 44100).to(device).to(dtype)
|
| 88 |
+
|
| 89 |
+
audio = resampler(audios)
|
| 90 |
+
|
| 91 |
+
max_audio_len = audio.shape[-1]
|
| 92 |
+
if max_audio_len % (8 * 512) != 0:
|
| 93 |
+
audio = torch.nn.functional.pad(
|
| 94 |
+
audio, (0, 8 * 512 - max_audio_len % (8 * 512))
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
mels = self.forward_mel(audio)
|
| 98 |
+
mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
|
| 99 |
+
mels = self.transform(mels)
|
| 100 |
+
latents = []
|
| 101 |
+
for mel in mels:
|
| 102 |
+
latent = self.dcae.encoder(mel.unsqueeze(0))
|
| 103 |
+
latents.append(latent)
|
| 104 |
+
latents = torch.cat(latents, dim=0)
|
| 105 |
+
latent_lengths = (
|
| 106 |
+
audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple
|
| 107 |
+
).long()
|
| 108 |
+
latents = (latents - self.shift_factor) * self.scale_factor
|
| 109 |
+
return latents, latent_lengths
|
| 110 |
+
|
| 111 |
+
@torch.no_grad()
|
| 112 |
+
def decode(self, latents, audio_lengths=None, sr=None):
|
| 113 |
+
latents = latents / self.scale_factor + self.shift_factor
|
| 114 |
+
|
| 115 |
+
pred_wavs = []
|
| 116 |
+
|
| 117 |
+
for latent in latents:
|
| 118 |
+
mels = self.dcae.decoder(latent.unsqueeze(0))
|
| 119 |
+
mels = mels * 0.5 + 0.5
|
| 120 |
+
mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
|
| 121 |
+
|
| 122 |
+
# wav = self.vocoder.decode(mels[0]).squeeze(1)
|
| 123 |
+
# decode waveform for each channels to reduce vram footprint
|
| 124 |
+
wav_ch1 = self.vocoder.decode(mels[:,0,:,:]).squeeze(1).cpu()
|
| 125 |
+
wav_ch2 = self.vocoder.decode(mels[:,1,:,:]).squeeze(1).cpu()
|
| 126 |
+
wav = torch.cat([wav_ch1, wav_ch2],dim=0)
|
| 127 |
+
|
| 128 |
+
if sr is not None:
|
| 129 |
+
resampler = (
|
| 130 |
+
torchaudio.transforms.Resample(44100, sr)
|
| 131 |
+
)
|
| 132 |
+
wav = resampler(wav.cpu().float())
|
| 133 |
+
else:
|
| 134 |
+
sr = 44100
|
| 135 |
+
pred_wavs.append(wav)
|
| 136 |
+
|
| 137 |
+
if audio_lengths is not None:
|
| 138 |
+
pred_wavs = [
|
| 139 |
+
wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)
|
| 140 |
+
]
|
| 141 |
+
return sr, pred_wavs
|
| 142 |
+
|
| 143 |
+
    @torch.no_grad()
    def decode_overlap(self, latents, audio_lengths=None, sr=None):
        """
        Decodes latents into waveforms using an overlapped DCAE and Vocoder.

        Both the DCAE decoder and the vocoder are run on fixed-size windows
        with overlapping edges; overlaps are trimmed (DCAE) or crossfaded
        (vocoder) so very long latents can be decoded with bounded VRAM.

        Args:
            latents: iterable of latent tensors shaped (C, H, W_latent)
                (each item is unsqueezed to a batch of 1 internally).
            audio_lengths: optional per-item sample counts (at the FINAL
                output sample rate) used to truncate the decoded waveforms.
            sr: optional target sample rate; None keeps the native 44100 Hz.

        Returns:
            (final_output_sr, processed_pred_wavs): output sample rate and a
            list of CPU waveform tensors, one per latent.
        """
        print("Using Overlapped DCAE and Vocoder")

        MODEL_INTERNAL_SR = 44100
        # One latent frame expands to 8 mel frames through the DCAE decoder.
        DCAE_LATENT_TO_MEL_STRIDE = 8
        # One mel frame expands to 512 audio samples through the vocoder.
        VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME = 512

        pred_wavs = []
        final_output_sr = sr if sr is not None else MODEL_INTERNAL_SR

        # --- DCAE Parameters ---
        # dcae_win_len_latent: Window length in the latent domain for DCAE processing
        dcae_win_len_latent = 512
        # dcae_mel_win_len: Expected mel window length from DCAE decoder output (latent_win * stride)
        dcae_mel_win_len = dcae_win_len_latent * 8
        # dcae_anchor_offset: Offset from anchor point to actual start of latent window slice
        dcae_anchor_offset = dcae_win_len_latent // 4
        # dcae_anchor_hop: Hop size for anchor points in latent domain
        dcae_anchor_hop = dcae_win_len_latent // 2
        # dcae_mel_overlap_len: Overlap length in the mel domain to be trimmed/blended
        dcae_mel_overlap_len = dcae_mel_win_len // 4

        # --- Vocoder Parameters ---
        # vocoder_win_len_audio: Audio samples per vocoder processing window
        vocoder_win_len_audio = 512 * 512  # Example: 262144 samples
        # vocoder_overlap_len_audio: Audio samples for overlap between vocoder windows
        vocoder_overlap_len_audio = 1024
        # vocoder_hop_len_audio: Hop size in audio samples for vocoder processing
        vocoder_hop_len_audio = vocoder_win_len_audio - 2 * vocoder_overlap_len_audio
        # vocoder_input_mel_frames_per_block: Number of mel frames fed to vocoder in one go
        vocoder_input_mel_frames_per_block = vocoder_win_len_audio // VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME

        # Linear fade-out / fade-in ramps used to blend adjacent vocoder windows.
        crossfade_len_audio = 128  # Audio samples for crossfading vocoder outputs
        cf_win_tail = torch.linspace(1, 0, crossfade_len_audio, device=self.device).unsqueeze(0).unsqueeze(0)
        cf_win_head = torch.linspace(0, 1, crossfade_len_audio, device=self.device).unsqueeze(0).unsqueeze(0)

        for latent_idx, latent_item in enumerate(latents):
            latent_item = latent_item.to(self.device)
            # Undo encode-time normalization (inverse of encode's shift/scale).
            current_latent = (latent_item / self.scale_factor + self.shift_factor).unsqueeze(0)  # (1, C, H, W_latent)
            latent_len = current_latent.shape[3]

            # 1. DCAE: Latent to Mel Spectrogram (Overlapped)
            mels_segments = []
            if latent_len == 0:
                pass  # No mel segments to generate
            else:
                # Determine anchor points for DCAE windows
                # An anchor marks a reference point for a window slice.
                # Window slice: current_latent[..., anchor - offset : anchor - offset + win_len]
                # First anchor ensures window starts at 0. Last anchor ensures tail is covered.
                dcae_anchors = list(range(dcae_anchor_offset, latent_len - dcae_anchor_offset, dcae_anchor_hop))
                if not dcae_anchors:  # If latent is too short for the range, use one anchor
                    dcae_anchors = [dcae_anchor_offset]

                for i, anchor in enumerate(dcae_anchors):
                    win_start_idx = max(0, anchor - dcae_anchor_offset)
                    win_end_idx = min(latent_len, win_start_idx + dcae_win_len_latent)

                    dcae_input_segment = current_latent[:, :, :, win_start_idx:win_end_idx]
                    if dcae_input_segment.shape[3] == 0: continue

                    mel_output_full = self.dcae.decoder(dcae_input_segment)  # (1, C, H_mel, W_mel_fixed_from_dcae)

                    is_first = (i == 0)
                    is_last = (i == len(dcae_anchors) - 1)

                    if is_first and is_last:  # Only one segment
                        # Use mel corresponding to actual input latent length
                        true_mel_content_len = dcae_input_segment.shape[3] * DCAE_LATENT_TO_MEL_STRIDE
                        mel_to_keep = mel_output_full[:, :, :, :min(true_mel_content_len, mel_output_full.shape[3])]
                    elif is_first:  # First segment, trim end overlap
                        mel_to_keep = mel_output_full[:, :, :, :-dcae_mel_overlap_len]
                    elif is_last:  # Last segment, trim start overlap
                        # And ensure we only take content relevant to the (potentially partial) last latent window
                        # The mel_output_full is fixed length. The useful part starts after overlap.
                        # The length of the useful part depends on how much of dcae_input_segment was actual content.
                        # For simplicity in overlap-add, typically trim fixed overlap.
                        # If dcae_input_segment was shorter than dcae_win_len_latent, mel_output_full might contain padding effects.
                        # Standard OLA keeps the corresponding tail.
                        mel_to_keep = mel_output_full[:, :, :, dcae_mel_overlap_len:]
                    else:  # Middle segment, trim both overlaps
                        mel_to_keep = mel_output_full[:, :, :, dcae_mel_overlap_len:-dcae_mel_overlap_len]

                    if mel_to_keep.shape[3] > 0:
                        mels_segments.append(mel_to_keep)

            if not mels_segments:
                num_mel_channels = current_latent.shape[1]
                # NOTE(review): assumes the DCAE exposes its decoder's output
                # mel height as `decoder_output_mel_height` -- confirm this
                # attribute exists on the dcae module.
                mel_height = self.dcae.decoder_output_mel_height
                concatenated_mels = torch.empty(
                    (1, num_mel_channels, mel_height, 0),
                    device=current_latent.device, dtype=current_latent.dtype
                )
            else:
                concatenated_mels = torch.cat(mels_segments, dim=3)

            # Denormalize mels (inverse of the [0, 1] -> [-1, 1] transform used at encode time).
            concatenated_mels = concatenated_mels * 0.5 + 0.5
            concatenated_mels = concatenated_mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value

            mel_total_frames = concatenated_mels.shape[3]

            # 2. Vocoder: Mel Spectrogram to Waveform (Overlapped)
            if mel_total_frames == 0:
                # Assuming mono or stereo output based on mel channels (typically mono for vocoder from single mel)
                num_audio_channels = 1  # Or determine from vocoder capabilities / mel channels
                final_wav = torch.zeros((num_audio_channels, 0), device=self.device, dtype=torch.float32)
            else:
                # Initial vocoder window
                # Vocoder expects (C_mel, H_mel, W_mel_block)
                mel_block = concatenated_mels[0, :, :, :vocoder_input_mel_frames_per_block].to(self.device)

                # Pad mel_block if it's shorter than vocoder_input_mel_frames_per_block (e.g. very short audio)
                if 0 < mel_block.shape[2] < vocoder_input_mel_frames_per_block:
                    pad_len = vocoder_input_mel_frames_per_block - mel_block.shape[2]
                    mel_block = torch.nn.functional.pad(mel_block, (0, pad_len), mode='constant', value=0)  # Pad last dim

                current_audio_output = self.vocoder.decode(mel_block)  # (C_audio, 1, Samples)
                current_audio_output = current_audio_output[:, :, :-vocoder_overlap_len_audio]  # Remove end overlap

                # p_audio_samples tracks the start of the *next* audio segment to generate (in conceptual total audio samples)
                p_audio_samples = vocoder_hop_len_audio
                conceptual_total_audio_len_native_sr = mel_total_frames * VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME

                pbar_total = 1 + max(0, (conceptual_total_audio_len_native_sr - (vocoder_win_len_audio - vocoder_overlap_len_audio))) // vocoder_hop_len_audio

                # Use tqdm if you want a progress bar for the vocoder part
                # with tqdm(total=pbar_total, desc=f"Vocoder {latent_idx+1}/{len(latents)}", leave=False) as pbar:
                # pbar.update(1)  # For initial window
                # The loop for subsequent windows
                while p_audio_samples < conceptual_total_audio_len_native_sr:
                    mel_frame_start = p_audio_samples // VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
                    mel_frame_end = mel_frame_start + vocoder_input_mel_frames_per_block

                    if mel_frame_start >= mel_total_frames: break  # No more mel frames

                    mel_block = concatenated_mels[0, :, :, mel_frame_start:min(mel_frame_end, mel_total_frames)].to(self.device)

                    if mel_block.shape[2] == 0: break  # Should not happen if mel_frame_start is valid

                    # Pad if current mel_block is too short (end of sequence)
                    if mel_block.shape[2] < vocoder_input_mel_frames_per_block:
                        pad_len = vocoder_input_mel_frames_per_block - mel_block.shape[2]
                        mel_block = torch.nn.functional.pad(mel_block, (0, pad_len), mode='constant', value=0)

                    new_audio_win = self.vocoder.decode(mel_block)  # (C_audio, 1, Samples)

                    # Crossfade
                    # Determine actual crossfade length based on available audio
                    actual_cf_len = min(crossfade_len_audio, current_audio_output.shape[2], new_audio_win.shape[2] - (vocoder_overlap_len_audio - crossfade_len_audio))
                    if actual_cf_len > 0:  # Ensure valid slice lengths for crossfade
                        tail_part = current_audio_output[:, :, -actual_cf_len:]
                        head_part = new_audio_win[:, :, vocoder_overlap_len_audio - actual_cf_len : vocoder_overlap_len_audio]

                        crossfaded_segment = tail_part * cf_win_tail[:,:,:actual_cf_len] + \
                                             head_part * cf_win_head[:,:,:actual_cf_len]

                        current_audio_output = torch.cat([current_audio_output[:, :, :-actual_cf_len], crossfaded_segment], dim=2)

                    # Append non-overlapping part of new_audio_win
                    is_final_append = (p_audio_samples + vocoder_hop_len_audio >= conceptual_total_audio_len_native_sr)
                    if is_final_append:
                        segment_to_append = new_audio_win[:, :, vocoder_overlap_len_audio:]
                    else:
                        segment_to_append = new_audio_win[:, :, vocoder_overlap_len_audio:-vocoder_overlap_len_audio]

                    current_audio_output = torch.cat([current_audio_output, segment_to_append], dim=2)

                    p_audio_samples += vocoder_hop_len_audio
                    # pbar.update(1)  # if using tqdm

                final_wav = current_audio_output.squeeze(1)  # (C_audio, Samples)

            # 3. Resampling (if necessary)
            if final_output_sr != MODEL_INTERNAL_SR and final_wav.numel() > 0:
                # Resample expects CPU tensor if using torchaudio.transforms on older versions or for some backends
                resampler = torchaudio.transforms.Resample(
                    MODEL_INTERNAL_SR, final_output_sr, dtype=final_wav.dtype
                )
                final_wav = resampler(final_wav.cpu()).to(self.device)  # Move back to device if needed later

            pred_wavs.append(final_wav)

        # 4. Final Truncation
        processed_pred_wavs = []
        for i, wav in enumerate(pred_wavs):
            # Calculate expected length based on original latent, at the FINAL output sample rate
            _num_latent_frames = latents[i].shape[-1]  # Use original latent item for shape
            _num_mel_frames = _num_latent_frames * DCAE_LATENT_TO_MEL_STRIDE
            _conceptual_native_audio_len = _num_mel_frames * VOCODER_AUDIO_SAMPLES_PER_MEL_FRAME
            max_possible_len = int(_conceptual_native_audio_len * final_output_sr / MODEL_INTERNAL_SR)

            current_wav_len = wav.shape[1]

            if audio_lengths is not None:
                # User-provided length is the primary target, capped by actual and max possible
                target_len = min(audio_lengths[i], current_wav_len, max_possible_len)
            else:
                # No user length, use max possible capped by actual
                target_len = min(max_possible_len, current_wav_len)

            processed_pred_wavs.append(wav[:, :max(0, target_len)].cpu())  # Ensure length is non-negative

        return final_output_sr, processed_pred_wavs
|
| 351 |
+
|
| 352 |
+
def forward(self, audios, audio_lengths=None, sr=None):
|
| 353 |
+
latents, latent_lengths = self.encode(
|
| 354 |
+
audios=audios, audio_lengths=audio_lengths, sr=sr
|
| 355 |
+
)
|
| 356 |
+
sr, pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
|
| 357 |
+
return sr, pred_wavs, latents, latent_lengths
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
if __name__ == "__main__":
    # Smoke test: round-trip a local file through the autoencoder.
    audio, sr = torchaudio.load("test.wav")
    audio_lengths = torch.tensor([audio.shape[1]])
    audios = audio.unsqueeze(0)

    model = MusicDCAE()

    # Encoder-only check, kept for reference:
    # latents, latent_lengths = model.encode(audios, audio_lengths)
    # print("latents shape: ", latents.shape)
    # print("latent_lengths: ", latent_lengths)

    # Full encode/decode round trip.
    sr, pred_wavs, latents, latent_lengths = model(audios, audio_lengths, sr)

    print("reconstructed wavs: ", pred_wavs[0].shape)
    print("latents shape: ", latents.shape)
    print("latent_lengths: ", latent_lengths)
    print("sr: ", sr)

    torchaudio.save("test_reconstructed.wav", pred_wavs[0], sr)
    print("test_reconstructed.wav")
|
acestep/music_dcae/music_log_mel.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ACE-Step: A Step Towards Music Generation Foundation Model
|
| 3 |
+
|
| 4 |
+
https://github.com/ace-step/ACE-Step
|
| 5 |
+
|
| 6 |
+
Apache 2.0 License
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch import Tensor
|
| 12 |
+
from torchaudio.transforms import MelScale
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class LinearSpectrogram(nn.Module):
    """Magnitude STFT front-end.

    With the default ``mode="pow2_sqrt"`` the output is the magnitude
    spectrogram ``sqrt(re^2 + im^2 + 1e-6)``; the input is reflect-padded so
    that frame centers align with hop boundaries even though ``center=False``.
    """

    def __init__(
        self,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        center=False,
        mode="pow2_sqrt",
    ):
        super().__init__()

        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.mode = mode

        # Hann analysis window, registered so it follows .to(device)/.half().
        self.register_buffer("window", torch.hann_window(win_length))

    def forward(self, y: Tensor) -> Tensor:
        # Accept (B, 1, T) as well as (B, T).
        if y.ndim == 3:
            y = y.squeeze(1)

        # Manual reflect padding replacing torch.stft's center padding.
        pad_left = (self.win_length - self.hop_length) // 2
        pad_right = (self.win_length - self.hop_length + 1) // 2
        y = torch.nn.functional.pad(
            y.unsqueeze(1), (pad_left, pad_right), mode="reflect"
        ).squeeze(1)

        input_dtype = y.dtype
        # Run the STFT in float32 for numerical stability, cast back after.
        spec = torch.stft(
            y.float(),
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)

        if self.mode == "pow2_sqrt":
            # Magnitude; the epsilon keeps the sqrt gradient finite at zero.
            spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

        return spec.to(input_dtype)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class LogMelSpectrogram(nn.Module):
    """Log-compressed mel spectrogram: LinearSpectrogram -> MelScale -> log.

    Uses Slaney-style mel filters/normalization (librosa-compatible).
    """

    def __init__(
        self,
        sample_rate=44100,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        n_mels=128,
        center=False,
        f_min=0.0,
        f_max=None,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.n_mels = n_mels
        self.f_min = f_min
        # Default the upper band edge to Nyquist.
        self.f_max = f_max or sample_rate // 2

        self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
        self.mel_scale = MelScale(
            self.n_mels,
            self.sample_rate,
            self.f_min,
            self.f_max,
            self.n_fft // 2 + 1,
            "slaney",
            "slaney",
        )

    def compress(self, x: Tensor) -> Tensor:
        # Clamp before log so silence does not produce -inf.
        return torch.log(torch.clamp(x, min=1e-5))

    def decompress(self, x: Tensor) -> Tensor:
        # Inverse of compress (up to the clamp).
        return torch.exp(x)

    def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
        linear = self.spectrogram(x)
        mel = self.mel_scale(linear)
        mel = self.compress(mel)
        if return_linear:
            # Also hand back the log-compressed linear spectrogram.
            return mel, self.compress(linear)
        return mel
|
acestep/music_dcae/music_vocoder.py
ADDED
|
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ACE-Step: A Step Towards Music Generation Foundation Model
|
| 3 |
+
|
| 4 |
+
https://github.com/ace-step/ACE-Step
|
| 5 |
+
|
| 6 |
+
Apache 2.0 License
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import librosa
|
| 10 |
+
import torch
|
| 11 |
+
from torch import nn
|
| 12 |
+
|
| 13 |
+
from functools import partial
|
| 14 |
+
from math import prod
|
| 15 |
+
from typing import Callable, Tuple, List
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
from torch.nn import Conv1d
|
| 20 |
+
from torch.nn.utils import weight_norm
|
| 21 |
+
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
|
| 22 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 23 |
+
from diffusers.loaders import FromOriginalModelMixin
|
| 24 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
from music_log_mel import LogMelSpectrogram
|
| 29 |
+
except ImportError:
|
| 30 |
+
from .music_log_mel import LogMelSpectrogram
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Stochastic Depth: randomly zero whole samples of the batch.

    During training each sample is kept with probability ``1 - drop_prob``;
    kept samples are rescaled by ``1 / keep_prob`` when ``scale_by_keep`` is
    set so the expected activation is unchanged.  At eval time (or when
    ``drop_prob == 0``) the input is returned untouched.
    """  # noqa: E501
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        mask.div_(keep_prob)
    return x * mask
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class DropPath(nn.Module):
    """Module wrapper around :func:`drop_path` (Stochastic Depth)."""  # noqa: E501

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # Delegate to the functional form; self.training selects train/eval.
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f"drop_prob={round(self.drop_prob,3):0.3f}"
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class LayerNorm(nn.Module):
    r"""LayerNorm supporting two memory layouts.

    ``channels_last`` (default) expects the channel dimension last, e.g.
    (batch, ..., channels), and dispatches to ``F.layer_norm``.
    ``channels_first`` expects (batch, channels, ...) and normalizes dim 1
    manually with a biased variance estimate.
    """  # noqa: E501

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        # channels_first: hand-rolled normalization over dim 1.
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.eps)
        # weight/bias broadcast as (C, 1) over the trailing spatial dim.
        return self.weight[:, None] * x + self.bias[:, None]
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class ConvNeXtBlock(nn.Module):
|
| 104 |
+
r"""ConvNeXt Block. There are two equivalent implementations:
|
| 105 |
+
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
| 106 |
+
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
| 107 |
+
We use (2) as we find it slightly faster in PyTorch
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
dim (int): Number of input channels.
|
| 111 |
+
drop_path (float): Stochastic depth rate. Default: 0.0
|
| 112 |
+
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
| 113 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
|
| 114 |
+
kernel_size (int): Kernel size for depthwise conv. Default: 7.
|
| 115 |
+
dilation (int): Dilation for depthwise conv. Default: 1.
|
| 116 |
+
""" # noqa: E501
|
| 117 |
+
|
| 118 |
+
def __init__(
|
| 119 |
+
self,
|
| 120 |
+
dim: int,
|
| 121 |
+
drop_path: float = 0.0,
|
| 122 |
+
layer_scale_init_value: float = 1e-6,
|
| 123 |
+
mlp_ratio: float = 4.0,
|
| 124 |
+
kernel_size: int = 7,
|
| 125 |
+
dilation: int = 1,
|
| 126 |
+
):
|
| 127 |
+
super().__init__()
|
| 128 |
+
|
| 129 |
+
self.dwconv = nn.Conv1d(
|
| 130 |
+
dim,
|
| 131 |
+
dim,
|
| 132 |
+
kernel_size=kernel_size,
|
| 133 |
+
padding=int(dilation * (kernel_size - 1) / 2),
|
| 134 |
+
groups=dim,
|
| 135 |
+
) # depthwise conv
|
| 136 |
+
self.norm = LayerNorm(dim, eps=1e-6)
|
| 137 |
+
self.pwconv1 = nn.Linear(
|
| 138 |
+
dim, int(mlp_ratio * dim)
|
| 139 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
| 140 |
+
self.act = nn.GELU()
|
| 141 |
+
self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
|
| 142 |
+
self.gamma = (
|
| 143 |
+
nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
|
| 144 |
+
if layer_scale_init_value > 0
|
| 145 |
+
else None
|
| 146 |
+
)
|
| 147 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 148 |
+
|
| 149 |
+
def forward(self, x, apply_residual: bool = True):
|
| 150 |
+
input = x
|
| 151 |
+
|
| 152 |
+
x = self.dwconv(x)
|
| 153 |
+
x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
|
| 154 |
+
x = self.norm(x)
|
| 155 |
+
x = self.pwconv1(x)
|
| 156 |
+
x = self.act(x)
|
| 157 |
+
x = self.pwconv2(x)
|
| 158 |
+
|
| 159 |
+
if self.gamma is not None:
|
| 160 |
+
x = self.gamma * x
|
| 161 |
+
|
| 162 |
+
x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
|
| 163 |
+
x = self.drop_path(x)
|
| 164 |
+
|
| 165 |
+
if apply_residual:
|
| 166 |
+
x = input + x
|
| 167 |
+
|
| 168 |
+
return x
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class ParallelConvNeXtBlock(nn.Module):
    """Several ConvNeXtBlocks with different kernel sizes applied in parallel.

    The branch outputs (computed without their internal residuals) are summed
    together with the input, giving a single shared skip connection.
    """

    def __init__(self, kernel_sizes: List[int], *args, **kwargs):
        super().__init__()
        self.blocks = nn.ModuleList(
            ConvNeXtBlock(kernel_size=ks, *args, **kwargs) for ks in kernel_sizes
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stack every branch output plus the identity path, then reduce.
        branch_outputs = [block(x, apply_residual=False) for block in self.blocks]
        return torch.stack(branch_outputs + [x], dim=1).sum(dim=1)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class ConvNeXtEncoder(nn.Module):
    """Multi-stage 1-D ConvNeXt encoder.

    Alternates channel-projection layers (stem / 1x1 downprojections) with
    stages of ConvNeXt blocks.  With a single kernel size a plain
    ConvNeXtBlock is used per block; with several, a ParallelConvNeXtBlock.
    Temporal resolution is preserved throughout (no striding).
    """

    def __init__(
        self,
        input_channels=3,
        depths=[3, 3, 9, 3],          # blocks per stage (mutable default: never mutated here)
        dims=[96, 192, 384, 768],     # channel width per stage
        drop_path_rate=0.0,           # max stochastic-depth rate (linearly ramped over blocks)
        layer_scale_init_value=1e-6,
        kernel_sizes: Tuple[int] = (7,),
    ):
        super().__init__()
        assert len(depths) == len(dims)

        # channel_layers[i] adapts channels before stage i.
        self.channel_layers = nn.ModuleList()
        # Stem: 7-tap conv into dims[0] followed by channels-first LayerNorm.
        stem = nn.Sequential(
            nn.Conv1d(
                input_channels,
                dims[0],
                kernel_size=7,
                padding=3,
                padding_mode="replicate",
            ),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.channel_layers.append(stem)

        # Between stages: norm + 1x1 conv to the next stage's width.
        for i in range(len(depths) - 1):
            mid_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
            )
            self.channel_layers.append(mid_layer)

        # Choose the per-block constructor based on how many kernel sizes
        # were requested.
        block_fn = (
            partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
            if len(kernel_sizes) == 1
            else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
        )

        self.stages = nn.ModuleList()
        # Linear drop-path schedule: 0 at the first block, drop_path_rate at
        # the last, across ALL blocks of ALL stages.
        drop_path_rates = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]

        cur = 0
        for i in range(len(depths)):
            stage = nn.Sequential(
                *[
                    block_fn(
                        dim=dims[i],
                        drop_path=drop_path_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value,
                    )
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        # Final normalization over the last stage's channels.
        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
        # Re-initialize all conv/linear weights (overrides PyTorch defaults).
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal weights, zero bias for every Conv1d/Linear.
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        # (N, input_channels, L) -> (N, dims[-1], L); length is unchanged.
        for channel_layer, stage in zip(self.channel_layers, self.stages):
            x = channel_layer(x)
            x = stage(x)

        return self.norm(x)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def init_weights(m, mean=0.0, std=0.01):
    """Gaussian-initialize the weights of conv-like modules.

    Matches any module whose class name contains "Conv"; all other modules
    are left untouched.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def get_padding(kernel_size, dilation=1):
    """Return the 'same' padding for a dilated conv with the given kernel."""
    return dilation * (kernel_size - 1) // 2
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class ResBlock1(torch.nn.Module):
    """HiFi-GAN residual block.

    Each step pairs a dilated convolution with a dilation-1 convolution,
    with SiLU activations before each conv and a skip connection around
    the pair. All convs are weight-normalized and re-initialized with
    ``init_weights``.

    The original implementation spelled out exactly three conv pairs by
    hand (indexing ``dilation[0..2]``); this version builds them from the
    ``dilation`` tuple, so any number of dilations works. Behavior for
    the default 3-tuple is unchanged (same modules, same creation order).
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()

        def _make_conv(d):
            # "Same"-padded, stride-1, weight-normalized conv.
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=d,
                    padding=get_padding(kernel_size, d),
                )
            )

        # First conv of each pair uses the given (increasing) dilation.
        self.convs1 = nn.ModuleList([_make_conv(d) for d in dilation])
        self.convs1.apply(init_weights)

        # Second conv of each pair always runs at dilation 1.
        self.convs2 = nn.ModuleList([_make_conv(1) for _ in dilation])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.silu(x)
            xt = c1(xt)
            xt = F.silu(xt)
            xt = c2(xt)
            x = xt + x  # residual connection around the conv pair
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from every conv (pre-export/inference)."""
        for conv in self.convs1:
            remove_weight_norm(conv)
        for conv in self.convs2:
            remove_weight_norm(conv)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
class HiFiGANGenerator(nn.Module):
    """HiFi-GAN-style generator that upsamples mel features to a waveform.

    The input is projected by a wide pre-conv, then repeatedly upsampled
    with transposed convolutions; after each upsampling step the outputs
    of a bank of multi-kernel ``ResBlock1`` modules are averaged. A final
    conv followed by ``tanh`` produces audio in [-1, 1].
    """

    def __init__(
        self,
        *,
        hop_length: int = 512,
        upsample_rates: Tuple[int] = (8, 8, 2, 2, 2),
        upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2),
        resblock_kernel_sizes: Tuple[int] = (3, 7, 11),
        resblock_dilation_sizes: Tuple[Tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        num_mels: int = 128,
        upsample_initial_channel: int = 512,
        use_template: bool = True,
        pre_conv_kernel_size: int = 7,
        post_conv_kernel_size: int = 7,
        post_activation: Callable = partial(nn.SiLU, inplace=True),
    ):
        super().__init__()

        # The combined upsampling factor must synthesize exactly one STFT
        # hop of audio per input frame.
        assert (
            prod(upsample_rates) == hop_length
        ), f"hop_length must be {prod(upsample_rates)}"

        self.conv_pre = weight_norm(
            nn.Conv1d(
                num_mels,
                upsample_initial_channel,
                pre_conv_kernel_size,
                1,
                padding=get_padding(pre_conv_kernel_size),
            )
        )

        self.num_upsamples = len(upsample_rates)
        self.num_kernels = len(resblock_kernel_sizes)

        # Optional per-scale convs that inject an external excitation
        # signal ("template", e.g. f0-derived) after each upsample.
        self.noise_convs = nn.ModuleList()
        self.use_template = use_template
        self.ups = nn.ModuleList()

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            # Channel count halves at every scale.
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(
                weight_norm(
                    nn.ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

            if not use_template:
                continue

            if i + 1 < len(upsample_rates):
                # Stride the (full-rate) template down to this scale's
                # temporal resolution so it can be added to the features.
                stride_f0 = np.prod(upsample_rates[i + 1 :])
                self.noise_convs.append(
                    Conv1d(
                        1,
                        c_cur,
                        kernel_size=stride_f0 * 2,
                        stride=stride_f0,
                        padding=stride_f0 // 2,
                    )
                )
            else:
                # Final scale already matches the waveform rate.
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))

        # One ResBlock1 per (scale, kernel size) combination, stored flat;
        # forward() indexes them as i * num_kernels + j.
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(ResBlock1(ch, k, d))

        self.activation_post = post_activation()
        # NOTE: `ch` intentionally carries the channel count of the last
        # (smallest) scale out of the loop above.
        self.conv_post = weight_norm(
            nn.Conv1d(
                ch,
                1,
                post_conv_kernel_size,
                1,
                padding=get_padding(post_conv_kernel_size),
            )
        )
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x, template=None):
        """Map mel features to a waveform.

        ``template`` must be provided when ``use_template`` is True; it is
        consumed by ``noise_convs`` at every scale. Assumes x is
        (batch, num_mels, frames) — TODO confirm against callers.
        """
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            x = F.silu(x, inplace=True)
            x = self.ups[i](x)

            if self.use_template:
                x = x + self.noise_convs[i](template)

            # Average the parallel residual blocks attached to this scale.
            xs = None

            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)

            x = xs / self.num_kernels

        x = self.activation_post(x)
        x = self.conv_post(x)
        # Bound output samples to [-1, 1].
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Strip weight normalization from all convs (call before export)."""
        for up in self.ups:
            remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
class ADaMoSHiFiGANV1(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    """Diffusers-compatible music vocoder: ConvNeXt backbone + HiFi-GAN head.

    ``encode`` converts a waveform into a log-mel spectrogram via
    ``LogMelSpectrogram``; ``decode`` / ``forward`` synthesize a waveform
    from mel features (backbone then generator head). The module is put
    into eval mode at construction time.
    """

    @register_to_config
    def __init__(
        self,
        input_channels: int = 128,
        # NOTE(review): mutable list defaults are shared across calls;
        # presumably register_to_config copies them — tuples would be safer.
        depths: List[int] = [3, 3, 9, 3],
        dims: List[int] = [128, 256, 384, 512],
        drop_path_rate: float = 0.0,
        kernel_sizes: Tuple[int] = (7,),
        upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2),
        upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4),
        resblock_kernel_sizes: Tuple[int] = (3, 7, 11, 13),
        resblock_dilation_sizes: Tuple[Tuple[int]] = (
            (1, 3, 5),
            (1, 3, 5),
            (1, 3, 5),
            (1, 3, 5),
        ),
        num_mels: int = 512,
        upsample_initial_channel: int = 1024,
        use_template: bool = False,
        pre_conv_kernel_size: int = 13,
        post_conv_kernel_size: int = 13,
        sampling_rate: int = 44100,
        n_fft: int = 2048,
        win_length: int = 2048,
        hop_length: int = 512,
        f_min: int = 40,
        f_max: int = 16000,
        n_mels: int = 128,
    ):
        super().__init__()

        # Feature extractor over mel-like inputs (channels-first 1D).
        self.backbone = ConvNeXtEncoder(
            input_channels=input_channels,
            depths=depths,
            dims=dims,
            drop_path_rate=drop_path_rate,
            kernel_sizes=kernel_sizes,
        )

        # Upsampling generator; num_mels here is the backbone's output
        # width (dims[-1]), not the analysis mel count (n_mels).
        self.head = HiFiGANGenerator(
            hop_length=hop_length,
            upsample_rates=upsample_rates,
            upsample_kernel_sizes=upsample_kernel_sizes,
            resblock_kernel_sizes=resblock_kernel_sizes,
            resblock_dilation_sizes=resblock_dilation_sizes,
            num_mels=num_mels,
            upsample_initial_channel=upsample_initial_channel,
            use_template=use_template,
            pre_conv_kernel_size=pre_conv_kernel_size,
            post_conv_kernel_size=post_conv_kernel_size,
        )
        self.sampling_rate = sampling_rate
        # Analysis transform used by encode().
        self.mel_transform = LogMelSpectrogram(
            sample_rate=sampling_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            f_min=f_min,
            f_max=f_max,
            n_mels=n_mels,
        )
        # Inference-oriented module: default to eval mode.
        self.eval()

    @torch.no_grad()
    def decode(self, mel):
        """Synthesize a waveform from mel features (no gradients)."""
        y = self.backbone(mel)
        y = self.head(y)
        return y

    @torch.no_grad()
    def encode(self, x):
        """Return the log-mel spectrogram of waveform ``x`` (no gradients)."""
        return self.mel_transform(x)

    def forward(self, mel):
        """Gradient-enabled variant of decode() for training."""
        y = self.backbone(mel)
        y = self.head(y)
        return y
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
if __name__ == "__main__":
    # Round-trip smoke test: waveform -> log-mel -> reconstructed waveform.
    import soundfile as sf

    audio_path = "test_audio.wav"
    vocoder = ADaMoSHiFiGANV1.from_pretrained(
        "./checkpoints/music_vocoder", local_files_only=True
    )

    samples, _sr = librosa.load(audio_path, sr=44100, mono=True)
    batch = torch.from_numpy(samples).float()[None]
    mel_spec = vocoder.encode(batch)

    reconstructed = vocoder.decode(mel_spec)[0].mT
    sf.write("test_audio_vocoder_rec.wav", reconstructed.cpu().numpy(), 44100)
|
checkpoints/checkpoint_461260.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:796a66a9a098ec75554897e830868c8eb4a9a90c35bb4f972ce317420bb1bbb5
|
| 3 |
+
size 2920814816
|
checkpoints/tag_mapping.json
ADDED
|
@@ -0,0 +1,858 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"rock": 1697,
|
| 3 |
+
"male vocalist": 1698,
|
| 4 |
+
"pop": 1699,
|
| 5 |
+
"energetic": 1700,
|
| 6 |
+
"instrumental": 1701,
|
| 7 |
+
"electronic": 1702,
|
| 8 |
+
"rhythmic": 1703,
|
| 9 |
+
"female vocalist": 1704,
|
| 10 |
+
"passionate": 1705,
|
| 11 |
+
"atmospheric": 1706,
|
| 12 |
+
"rap": 1707,
|
| 13 |
+
"hip hop": 1708,
|
| 14 |
+
"uplifting": 1709,
|
| 15 |
+
"metal": 1710,
|
| 16 |
+
"alternative rock": 1711,
|
| 17 |
+
"pop rock": 1712,
|
| 18 |
+
"dark": 1713,
|
| 19 |
+
"anthemic": 1714,
|
| 20 |
+
"male vocals": 1715,
|
| 21 |
+
"melancholic": 1716,
|
| 22 |
+
"epic": 1717,
|
| 23 |
+
"bittersweet": 1718,
|
| 24 |
+
"love": 1719,
|
| 25 |
+
"dance": 1720,
|
| 26 |
+
"warm": 1721,
|
| 27 |
+
"electronic dance music": 1722,
|
| 28 |
+
"female vocals": 1723,
|
| 29 |
+
"lush": 1724,
|
| 30 |
+
"trap": 1725,
|
| 31 |
+
"introspective": 1726,
|
| 32 |
+
"aggressive": 1727,
|
| 33 |
+
"r&b": 1728,
|
| 34 |
+
"playful": 1729,
|
| 35 |
+
"regional music": 1730,
|
| 36 |
+
"dance-pop": 1731,
|
| 37 |
+
"hard rock": 1732,
|
| 38 |
+
"ambient": 1733,
|
| 39 |
+
"ethereal": 1734,
|
| 40 |
+
"emotional": 1735,
|
| 41 |
+
"heavy": 1736,
|
| 42 |
+
"piano": 1737,
|
| 43 |
+
"mellow": 1738,
|
| 44 |
+
"jazz": 1739,
|
| 45 |
+
"folk": 1740,
|
| 46 |
+
"country": 1741,
|
| 47 |
+
"house": 1742,
|
| 48 |
+
"party": 1743,
|
| 49 |
+
"romantic": 1744,
|
| 50 |
+
"orchestral": 1745,
|
| 51 |
+
"pop rap": 1746,
|
| 52 |
+
"acoustic": 1747,
|
| 53 |
+
"electropop": 1748,
|
| 54 |
+
"electro": 1749,
|
| 55 |
+
"nocturnal": 1750,
|
| 56 |
+
"bass": 1751,
|
| 57 |
+
"guitar": 1752,
|
| 58 |
+
"urban": 1753,
|
| 59 |
+
"soul": 1754,
|
| 60 |
+
"psychedelic": 1755,
|
| 61 |
+
"edm": 1756,
|
| 62 |
+
"experimental": 1757,
|
| 63 |
+
"funk": 1758,
|
| 64 |
+
"futuristic": 1759,
|
| 65 |
+
"boastful": 1760,
|
| 66 |
+
"hypnotic": 1761,
|
| 67 |
+
"heavy metal": 1762,
|
| 68 |
+
"contemporary r&b": 1763,
|
| 69 |
+
"techno": 1764,
|
| 70 |
+
"eclectic": 1765,
|
| 71 |
+
"longing": 1766,
|
| 72 |
+
"violin": 1767,
|
| 73 |
+
"sentimental": 1768,
|
| 74 |
+
"synthpop": 1769,
|
| 75 |
+
"cinematic": 1770,
|
| 76 |
+
"happy": 1771,
|
| 77 |
+
"repetitive": 1772,
|
| 78 |
+
"progressive": 1773,
|
| 79 |
+
"catchy": 1774,
|
| 80 |
+
"sad": 1775,
|
| 81 |
+
"indie pop": 1776,
|
| 82 |
+
"indie rock": 1777,
|
| 83 |
+
"singer-songwriter": 1778,
|
| 84 |
+
"classical music": 1779,
|
| 85 |
+
"slow": 1780,
|
| 86 |
+
"northern american music": 1781,
|
| 87 |
+
"sampling": 1782,
|
| 88 |
+
"trance": 1783,
|
| 89 |
+
"western classical music": 1784,
|
| 90 |
+
"upbeat": 1785,
|
| 91 |
+
"blues": 1786,
|
| 92 |
+
"hip-hop": 1787,
|
| 93 |
+
"ballad": 1788,
|
| 94 |
+
"soothing": 1789,
|
| 95 |
+
"synthwave": 1790,
|
| 96 |
+
"electric guitar": 1791,
|
| 97 |
+
"calm": 1792,
|
| 98 |
+
"raw": 1793,
|
| 99 |
+
"downtempo": 1794,
|
| 100 |
+
"hardcore hip hop": 1795,
|
| 101 |
+
"soft": 1796,
|
| 102 |
+
"dubstep": 1797,
|
| 103 |
+
"classical": 1798,
|
| 104 |
+
"film score": 1799,
|
| 105 |
+
"synth": 1800,
|
| 106 |
+
"triumphant": 1801,
|
| 107 |
+
"drums": 1802,
|
| 108 |
+
"punk": 1803,
|
| 109 |
+
"female voice": 1804,
|
| 110 |
+
"angry": 1805,
|
| 111 |
+
"alternative metal": 1806,
|
| 112 |
+
"acoustic guitar": 1807,
|
| 113 |
+
"lo-fi": 1808,
|
| 114 |
+
"male voice": 1809,
|
| 115 |
+
"dense": 1810,
|
| 116 |
+
"progressive rock": 1811,
|
| 117 |
+
"optimistic": 1812,
|
| 118 |
+
"ominous": 1813,
|
| 119 |
+
"reggae": 1814,
|
| 120 |
+
"sombre": 1815,
|
| 121 |
+
"mysterious": 1816,
|
| 122 |
+
"complex": 1817,
|
| 123 |
+
"contemporary folk": 1818,
|
| 124 |
+
"disco": 1819,
|
| 125 |
+
"drum and bass": 1820,
|
| 126 |
+
"new wave": 1821,
|
| 127 |
+
"nu metal": 1822,
|
| 128 |
+
"summer": 1823,
|
| 129 |
+
"sensual": 1824,
|
| 130 |
+
"powerful": 1825,
|
| 131 |
+
"folk rock": 1826,
|
| 132 |
+
"glitch": 1827,
|
| 133 |
+
"symphonic metal": 1828,
|
| 134 |
+
"emo": 1829,
|
| 135 |
+
"power metal": 1830,
|
| 136 |
+
"conscious": 1831,
|
| 137 |
+
"technical": 1832,
|
| 138 |
+
"suspenseful": 1833,
|
| 139 |
+
"dramatic": 1834,
|
| 140 |
+
"electro house": 1835,
|
| 141 |
+
"deep": 1836,
|
| 142 |
+
"swing": 1837,
|
| 143 |
+
"punk rock": 1838,
|
| 144 |
+
"gangsta rap": 1839,
|
| 145 |
+
"soulful": 1840,
|
| 146 |
+
"intense": 1841,
|
| 147 |
+
"industrial": 1842,
|
| 148 |
+
"cinematic classical": 1843,
|
| 149 |
+
"k-pop": 1844,
|
| 150 |
+
"new age": 1845,
|
| 151 |
+
"hedonistic": 1846,
|
| 152 |
+
"synth-pop": 1847,
|
| 153 |
+
"meditative": 1848,
|
| 154 |
+
"cello": 1849,
|
| 155 |
+
"pop punk": 1850,
|
| 156 |
+
"chillout": 1851,
|
| 157 |
+
"metalcore": 1852,
|
| 158 |
+
"dreamy": 1853,
|
| 159 |
+
"rebellious": 1854,
|
| 160 |
+
"east coast hip hop": 1855,
|
| 161 |
+
"progressive metal": 1856,
|
| 162 |
+
"lonely": 1857,
|
| 163 |
+
"conscious hip hop": 1858,
|
| 164 |
+
"flute": 1859,
|
| 165 |
+
"chill": 1860,
|
| 166 |
+
"phonk": 1861,
|
| 167 |
+
"blues rock": 1862,
|
| 168 |
+
"drum": 1863,
|
| 169 |
+
"quirky": 1864,
|
| 170 |
+
"pop soul": 1865,
|
| 171 |
+
"j-pop": 1866,
|
| 172 |
+
"groovy": 1867,
|
| 173 |
+
"trip hop": 1868,
|
| 174 |
+
"fantasy": 1869,
|
| 175 |
+
"dream pop": 1870,
|
| 176 |
+
"psychedelic rock": 1871,
|
| 177 |
+
"beat": 1872,
|
| 178 |
+
"country rock": 1873,
|
| 179 |
+
"surreal": 1874,
|
| 180 |
+
"gospel": 1875,
|
| 181 |
+
"fast": 1876,
|
| 182 |
+
"soft rock": 1877,
|
| 183 |
+
"smooth": 1878,
|
| 184 |
+
"peaceful": 1879,
|
| 185 |
+
"poetic": 1880,
|
| 186 |
+
"opera": 1881,
|
| 187 |
+
"power pop": 1882,
|
| 188 |
+
"indie folk": 1883,
|
| 189 |
+
"indie": 1884,
|
| 190 |
+
"mechanical": 1885,
|
| 191 |
+
"breakbeat": 1886,
|
| 192 |
+
"anxious": 1887,
|
| 193 |
+
"female vocal": 1888,
|
| 194 |
+
"deep bass": 1889,
|
| 195 |
+
"post-punk": 1890,
|
| 196 |
+
"grunge": 1891,
|
| 197 |
+
"breakup": 1892,
|
| 198 |
+
"choir": 1893,
|
| 199 |
+
"orchestra": 1894,
|
| 200 |
+
"avant-garde": 1895,
|
| 201 |
+
"deep house": 1896,
|
| 202 |
+
"boom bap": 1897,
|
| 203 |
+
"folk pop": 1898,
|
| 204 |
+
"pastoral": 1899,
|
| 205 |
+
"jazz fusion": 1900,
|
| 206 |
+
"progressive house": 1901,
|
| 207 |
+
"synthesizer": 1902,
|
| 208 |
+
"nostalgic": 1903,
|
| 209 |
+
"funky": 1904,
|
| 210 |
+
"country pop": 1905,
|
| 211 |
+
"death": 1906,
|
| 212 |
+
"spiritual": 1907,
|
| 213 |
+
"soundtrack": 1908,
|
| 214 |
+
"2000s": 1909,
|
| 215 |
+
"choral": 1910,
|
| 216 |
+
"strings": 1911,
|
| 217 |
+
"fun": 1912,
|
| 218 |
+
"electric": 1913,
|
| 219 |
+
"post-grunge": 1914,
|
| 220 |
+
"female singer": 1915,
|
| 221 |
+
"male vocal": 1916,
|
| 222 |
+
"modern classical": 1917,
|
| 223 |
+
"death metal": 1918,
|
| 224 |
+
"post-hardcore": 1919,
|
| 225 |
+
"humorous": 1920,
|
| 226 |
+
"heartfelt": 1921,
|
| 227 |
+
"psychedelia": 1922,
|
| 228 |
+
"haunting": 1923,
|
| 229 |
+
"afrobeat": 1924,
|
| 230 |
+
"medieval": 1925,
|
| 231 |
+
"progressive electronic": 1926,
|
| 232 |
+
"adult contemporary": 1927,
|
| 233 |
+
"reggaeton": 1928,
|
| 234 |
+
"dynamic": 1929,
|
| 235 |
+
"contemporary country": 1930,
|
| 236 |
+
"beats": 1931,
|
| 237 |
+
"idm": 1932,
|
| 238 |
+
"southern hip hop": 1933,
|
| 239 |
+
"80s": 1934,
|
| 240 |
+
"cold": 1935,
|
| 241 |
+
"big band": 1936,
|
| 242 |
+
"saxophone": 1937,
|
| 243 |
+
"future bass": 1938,
|
| 244 |
+
"noisy": 1939,
|
| 245 |
+
"gritty": 1940,
|
| 246 |
+
"dark ambient": 1941,
|
| 247 |
+
"trumpet": 1942,
|
| 248 |
+
"art rock": 1943,
|
| 249 |
+
"chaotic": 1944,
|
| 250 |
+
"smooth soul": 1945,
|
| 251 |
+
"post-industrial": 1946,
|
| 252 |
+
"bluegrass": 1947,
|
| 253 |
+
"industrial & noise": 1948,
|
| 254 |
+
"anime": 1949,
|
| 255 |
+
"drill": 1950,
|
| 256 |
+
"electro swing": 1951,
|
| 257 |
+
"dancehall": 1952,
|
| 258 |
+
"epic music": 1953,
|
| 259 |
+
"witch house": 1954,
|
| 260 |
+
"minimalistic": 1955,
|
| 261 |
+
"hispanic american music": 1956,
|
| 262 |
+
"electronica": 1957,
|
| 263 |
+
"americana": 1958,
|
| 264 |
+
"political": 1959,
|
| 265 |
+
"latin": 1960,
|
| 266 |
+
"tech house": 1961,
|
| 267 |
+
"neo-soul": 1962,
|
| 268 |
+
"hispanic music": 1963,
|
| 269 |
+
"heavy bass": 1964,
|
| 270 |
+
"knee surgery": 1965,
|
| 271 |
+
"horror": 1966,
|
| 272 |
+
"psychedelic pop": 1967,
|
| 273 |
+
"industrial metal": 1968,
|
| 274 |
+
"space": 1969,
|
| 275 |
+
"dub": 1970,
|
| 276 |
+
"art pop": 1971,
|
| 277 |
+
"spoken word": 1972,
|
| 278 |
+
"reverb": 1973,
|
| 279 |
+
"caribbean music": 1974,
|
| 280 |
+
"alternative": 1975,
|
| 281 |
+
"symphonic": 1976,
|
| 282 |
+
"cloud rap": 1977,
|
| 283 |
+
"neo-psychedelia": 1978,
|
| 284 |
+
"gothic metal": 1979,
|
| 285 |
+
"classic rock": 1980,
|
| 286 |
+
"female": 1981,
|
| 287 |
+
"bossa nova": 1982,
|
| 288 |
+
"thrash metal": 1983,
|
| 289 |
+
"djent": 1984,
|
| 290 |
+
"teen pop": 1985,
|
| 291 |
+
"cyberpunk": 1986,
|
| 292 |
+
"hardcore": 1987,
|
| 293 |
+
"glam rock": 1988,
|
| 294 |
+
"slow tempo": 1989,
|
| 295 |
+
"jazz rap": 1990,
|
| 296 |
+
"sexy": 1991,
|
| 297 |
+
"harp": 1992,
|
| 298 |
+
"outlaw country": 1993,
|
| 299 |
+
"progressive trance": 1994,
|
| 300 |
+
"european music": 1995,
|
| 301 |
+
"west coast hip hop": 1996,
|
| 302 |
+
"vocal": 1997,
|
| 303 |
+
"alternative dance": 1998,
|
| 304 |
+
"accordion": 1999,
|
| 305 |
+
"minimal": 2000,
|
| 306 |
+
"tribal": 2001,
|
| 307 |
+
"sarcastic": 2002,
|
| 308 |
+
"vocal jazz": 2003,
|
| 309 |
+
"jamaican music": 2004,
|
| 310 |
+
"alternative r&b": 2005,
|
| 311 |
+
"smooth jazz": 2006,
|
| 312 |
+
"gothic": 2007,
|
| 313 |
+
"ska": 2008,
|
| 314 |
+
"manic": 2009,
|
| 315 |
+
"bass guitar": 2010,
|
| 316 |
+
"chillwave": 2011,
|
| 317 |
+
"improvisation": 2012,
|
| 318 |
+
"melancholy": 2013,
|
| 319 |
+
"shoegaze": 2014,
|
| 320 |
+
"big beat": 2015,
|
| 321 |
+
"keyboard": 2016,
|
| 322 |
+
"groove metal": 2017,
|
| 323 |
+
"90s": 2018,
|
| 324 |
+
"latin pop": 2019,
|
| 325 |
+
"hardcore [punk]": 2020,
|
| 326 |
+
"darkwave": 2021,
|
| 327 |
+
"modern": 2022,
|
| 328 |
+
"glam metal": 2023,
|
| 329 |
+
"reflective": 2024,
|
| 330 |
+
"eerie": 2025,
|
| 331 |
+
"chamber pop": 2026,
|
| 332 |
+
"martial": 2027,
|
| 333 |
+
"flamenco": 2028,
|
| 334 |
+
"male singer": 2029,
|
| 335 |
+
"indietronica": 2030,
|
| 336 |
+
"beautiful": 2031,
|
| 337 |
+
"gothic rock": 2032,
|
| 338 |
+
"vocaloid": 2033,
|
| 339 |
+
"world": 2034,
|
| 340 |
+
"math rock": 2035,
|
| 341 |
+
"dark pop": 2036,
|
| 342 |
+
"jazz-funk": 2037,
|
| 343 |
+
"symphonic rock": 2038,
|
| 344 |
+
"club": 2039,
|
| 345 |
+
"bouncy": 2040,
|
| 346 |
+
"easy listening": 2041,
|
| 347 |
+
"j-rock": 2042,
|
| 348 |
+
"baroque": 2043,
|
| 349 |
+
"percussion": 2044,
|
| 350 |
+
"acid jazz": 2045,
|
| 351 |
+
"hardstyle": 2046,
|
| 352 |
+
"rock & roll": 2047,
|
| 353 |
+
"hymn": 2048,
|
| 354 |
+
"dissonant": 2049,
|
| 355 |
+
"ambient pop": 2050,
|
| 356 |
+
"eurodance": 2051,
|
| 357 |
+
"danceable": 2052,
|
| 358 |
+
"turntablism": 2053,
|
| 359 |
+
"dolby atmos": 2054,
|
| 360 |
+
"depressive": 2055,
|
| 361 |
+
"doom metal": 2056,
|
| 362 |
+
"hyperpop": 2057,
|
| 363 |
+
"existential": 2058,
|
| 364 |
+
"melodic metalcore": 2059,
|
| 365 |
+
"male": 2060,
|
| 366 |
+
"chanson": 2061,
|
| 367 |
+
"vaporwave": 2062,
|
| 368 |
+
"salsa": 2063,
|
| 369 |
+
"war": 2064,
|
| 370 |
+
"melodic": 972,
|
| 371 |
+
"fiddle": 2065,
|
| 372 |
+
"film soundtrack": 2066,
|
| 373 |
+
"inspirational": 2067,
|
| 374 |
+
"nu jazz": 2068,
|
| 375 |
+
"vulgar": 2069,
|
| 376 |
+
"abstract": 2070,
|
| 377 |
+
"brass": 2071,
|
| 378 |
+
"confident": 2072,
|
| 379 |
+
"black metal": 2073,
|
| 380 |
+
"video game music": 2074,
|
| 381 |
+
"creepy": 2075,
|
| 382 |
+
"uncommon time signatures": 2076,
|
| 383 |
+
"intimate": 2077,
|
| 384 |
+
"relaxing": 2078,
|
| 385 |
+
"post-rock": 2079,
|
| 386 |
+
"lofi": 2080,
|
| 387 |
+
"roots reggae": 2081,
|
| 388 |
+
"industrial rock": 2082,
|
| 389 |
+
"remix": 2083,
|
| 390 |
+
"storytelling": 2084,
|
| 391 |
+
"funny": 2085,
|
| 392 |
+
"ambient techno": 2086,
|
| 393 |
+
"high-energy": 2087,
|
| 394 |
+
"experimental rock": 2088,
|
| 395 |
+
"southern rock": 2089,
|
| 396 |
+
"celtic": 2090,
|
| 397 |
+
"banjo": 2091,
|
| 398 |
+
"rockabilly": 2092,
|
| 399 |
+
"tabla": 2093,
|
| 400 |
+
"melodic death metal": 2094,
|
| 401 |
+
"minor key": 2095,
|
| 402 |
+
"rap rock": 2096,
|
| 403 |
+
"synth funk": 2097,
|
| 404 |
+
"harmonies": 2098,
|
| 405 |
+
"fast tempo": 2099,
|
| 406 |
+
"garage rock": 2100,
|
| 407 |
+
"breakcore": 2101,
|
| 408 |
+
"harmony": 2102,
|
| 409 |
+
"uptempo": 2103,
|
| 410 |
+
"harmonica": 2104,
|
| 411 |
+
"duet": 2105,
|
| 412 |
+
"alt-pop": 2106,
|
| 413 |
+
"bounce": 2107,
|
| 414 |
+
"hiphop": 2108,
|
| 415 |
+
"funk rock": 2109,
|
| 416 |
+
"jungle": 2110,
|
| 417 |
+
"acoustic rock": 2111,
|
| 418 |
+
"tropical house": 2112,
|
| 419 |
+
"piano rock": 2113,
|
| 420 |
+
"sound effects": 2114,
|
| 421 |
+
"glitch hop": 2115,
|
| 422 |
+
"dance pop": 2116,
|
| 423 |
+
"aquatic": 2117,
|
| 424 |
+
"organ": 2118,
|
| 425 |
+
"baroque pop": 2119,
|
| 426 |
+
"comedy": 2120,
|
| 427 |
+
"theatrical": 2121,
|
| 428 |
+
"sparse": 2122,
|
| 429 |
+
"bassline": 2123,
|
| 430 |
+
"scary": 2124,
|
| 431 |
+
"cute": 2125,
|
| 432 |
+
"drone": 2126,
|
| 433 |
+
"horrorcore": 2127,
|
| 434 |
+
"bass house": 2128,
|
| 435 |
+
"emo rap": 2129,
|
| 436 |
+
"moody": 2130,
|
| 437 |
+
"drums (drum set)": 2131,
|
| 438 |
+
"fast-paced": 2132,
|
| 439 |
+
"double bass": 2133,
|
| 440 |
+
"progressive pop": 2134,
|
| 441 |
+
"apocalyptic": 2135,
|
| 442 |
+
"hardcore punk": 2136,
|
| 443 |
+
"anthem": 2137,
|
| 444 |
+
"europop": 2138,
|
| 445 |
+
"upright bass": 2139,
|
| 446 |
+
"groove": 2140,
|
| 447 |
+
"psytrance": 2141,
|
| 448 |
+
"dark wave": 2142,
|
| 449 |
+
"kpop": 2143,
|
| 450 |
+
"minimal techno": 2144,
|
| 451 |
+
"rock and roll": 2145,
|
| 452 |
+
"grime": 2146,
|
| 453 |
+
"lively": 2147,
|
| 454 |
+
"rave": 2148,
|
| 455 |
+
"syncopated": 2149,
|
| 456 |
+
"show tunes": 2150,
|
| 457 |
+
"autotune": 2151,
|
| 458 |
+
"sitar": 2152,
|
| 459 |
+
"nu-disco": 2153,
|
| 460 |
+
"folk metal": 2154,
|
| 461 |
+
"traditional pop": 2155,
|
| 462 |
+
"surf rock": 2156,
|
| 463 |
+
"noise": 2157,
|
| 464 |
+
"brostep": 2158,
|
| 465 |
+
"serious": 2159,
|
| 466 |
+
"traditional": 2160,
|
| 467 |
+
"pessimistic": 2161,
|
| 468 |
+
"ebm": 2162,
|
| 469 |
+
"female vocalists": 2163,
|
| 470 |
+
"speed metal": 2164,
|
| 471 |
+
"classic": 2165,
|
| 472 |
+
"post-punk revival": 2166,
|
| 473 |
+
"lounge": 2167,
|
| 474 |
+
"electric blues": 2168,
|
| 475 |
+
"winter": 2169,
|
| 476 |
+
"clear vocals": 2170,
|
| 477 |
+
"retro": 2171,
|
| 478 |
+
"raspy": 2172,
|
| 479 |
+
"progressive country": 2173,
|
| 480 |
+
"vibrant": 2174,
|
| 481 |
+
"mystical": 2175,
|
| 482 |
+
"deathcore": 2176,
|
| 483 |
+
"alt-country": 2177,
|
| 484 |
+
"theme": 2178,
|
| 485 |
+
"8-bit": 2179,
|
| 486 |
+
"jangle pop": 2180,
|
| 487 |
+
"aor": 2181,
|
| 488 |
+
"delta blues": 2182,
|
| 489 |
+
"light": 2183,
|
| 490 |
+
"lyrical": 2184,
|
| 491 |
+
"distorted guitars": 2185,
|
| 492 |
+
"jazz-rock": 2186,
|
| 493 |
+
"classical crossover": 2187,
|
| 494 |
+
"fusion": 2188,
|
| 495 |
+
"doo-wop": 2189,
|
| 496 |
+
"television music": 2190,
|
| 497 |
+
"clean": 2191,
|
| 498 |
+
"symphony": 2192,
|
| 499 |
+
"whimsical": 2193,
|
| 500 |
+
"honky tonk": 2194,
|
| 501 |
+
"chamber music": 2195,
|
| 502 |
+
"breathy": 2196,
|
| 503 |
+
"echo": 2197,
|
| 504 |
+
"uk garage": 2198,
|
| 505 |
+
"acid techno": 2199,
|
| 506 |
+
"ritualistic": 2200,
|
| 507 |
+
"scratch": 2201,
|
| 508 |
+
"darksynth": 2202,
|
| 509 |
+
"edgy": 2203,
|
| 510 |
+
"layered harmonies": 2204,
|
| 511 |
+
"rhythm & blues": 2205,
|
| 512 |
+
"80's": 2206,
|
| 513 |
+
"experimental hip hop": 2207,
|
| 514 |
+
"808": 2208,
|
| 515 |
+
"expressive": 2209,
|
| 516 |
+
"1960s": 2210,
|
| 517 |
+
"cryptic": 2211,
|
| 518 |
+
"g-funk": 2212,
|
| 519 |
+
"oud": 2213,
|
| 520 |
+
"male vocalists": 2214,
|
| 521 |
+
"uk drill": 2215,
|
| 522 |
+
"gentle": 2216,
|
| 523 |
+
"musical": 2217,
|
| 524 |
+
"sultry": 2218,
|
| 525 |
+
"samba": 2219,
|
| 526 |
+
"violins": 2220,
|
| 527 |
+
"soul jazz": 2221,
|
| 528 |
+
"alienation": 2222,
|
| 529 |
+
"deep voice": 2223,
|
| 530 |
+
"layered": 2224,
|
| 531 |
+
"screamo": 2225,
|
| 532 |
+
"drift phonk": 2226,
|
| 533 |
+
"shamisen": 2227,
|
| 534 |
+
"rap metal": 2228,
|
| 535 |
+
"strong": 2229,
|
| 536 |
+
"062 final fantasy ii": 3,
|
| 537 |
+
"063 final fantasy iii": 4,
|
| 538 |
+
"064 final fantasy iii remake": 5,
|
| 539 |
+
"066 final fantasy iv": 7,
|
| 540 |
+
"067 final fantasy iv remake": 8,
|
| 541 |
+
"068 final fantasy v": 9,
|
| 542 |
+
"069 final fantasy vi": 10,
|
| 543 |
+
"070 final fantasy vii": 11,
|
| 544 |
+
"071 final fantasy vii remake": 12,
|
| 545 |
+
"072 final fantasy viii": 13,
|
| 546 |
+
"073 final fantasy ix": 14,
|
| 547 |
+
"075 final fantasy x": 15,
|
| 548 |
+
"076 final fantasy xi": 16,
|
| 549 |
+
"077 final fantasy xii": 17,
|
| 550 |
+
"078 final fantasy xiii": 18,
|
| 551 |
+
"079 final fantasy xiv": 19,
|
| 552 |
+
"081 final fantasy xv": 20,
|
| 553 |
+
"082 final fantasy 0": 21,
|
| 554 |
+
"089 final fantasy x2": 26,
|
| 555 |
+
"093 final fantasy xiii2": 29,
|
| 556 |
+
"094 final fantasy xiii3": 30,
|
| 557 |
+
"097 dissidia final fantasy": 33,
|
| 558 |
+
"13 sentinels aegis rim": 40,
|
| 559 |
+
"ace combat 7": 143,
|
| 560 |
+
"advance wars": 144,
|
| 561 |
+
"advance wars days of ruin": 145,
|
| 562 |
+
"advance wars dual strike": 146,
|
| 563 |
+
"advance wars 2 black hole rising": 148,
|
| 564 |
+
"advance wars dual strike": 149,
|
| 565 |
+
"animal crossing wild world": 166,
|
| 566 |
+
"animal crossing new horizons": 167,
|
| 567 |
+
"ar tonelico": 171,
|
| 568 |
+
"armored core": 173,
|
| 569 |
+
"atelier escher and logy": 182,
|
| 570 |
+
"atelier iris": 183,
|
| 571 |
+
"atelier iris 2": 184,
|
| 572 |
+
"atelier iris 3": 185,
|
| 573 |
+
"atelier marie": 186,
|
| 574 |
+
"atelier resleriana": 187,
|
| 575 |
+
"atelier rorona": 188,
|
| 576 |
+
"atelier ryza": 189,
|
| 577 |
+
"atelier ryza 2": 190,
|
| 578 |
+
"atelier ryza 3": 191,
|
| 579 |
+
"atelier totori": 192,
|
| 580 |
+
"atlantis kitsune": 193,
|
| 581 |
+
"attack on titan": 194,
|
| 582 |
+
"azur lane": 198,
|
| 583 |
+
"baldurs gate 3": 255,
|
| 584 |
+
"banjo kazooie": 261,
|
| 585 |
+
"banjo tooie": 262,
|
| 586 |
+
"black clover": 272,
|
| 587 |
+
"black myth wukong": 273,
|
| 588 |
+
"blazblue": 277,
|
| 589 |
+
"bleach": 278,
|
| 590 |
+
"blue reflection": 285,
|
| 591 |
+
"bocchi the rock": 289,
|
| 592 |
+
"castlevania": 329,
|
| 593 |
+
"castlevania dawn of sorrow": 330,
|
| 594 |
+
"castlevania order of ecclesia": 331,
|
| 595 |
+
"castlevania portrait of ruin": 332,
|
| 596 |
+
"castlevania symphony of the night": 333,
|
| 597 |
+
"castlevania aria of sorrow": 334,
|
| 598 |
+
"cave story": 337,
|
| 599 |
+
"celeste": 339,
|
| 600 |
+
"chiptune": 350,
|
| 601 |
+
"chrono cross": 359,
|
| 602 |
+
"chrono trigger": 360,
|
| 603 |
+
"clair obscur": 367,
|
| 604 |
+
"clannad": 368,
|
| 605 |
+
"contra": 374,
|
| 606 |
+
"crosscode": 382,
|
| 607 |
+
"cuphead": 384,
|
| 608 |
+
"dmc4": 397,
|
| 609 |
+
"dmcv": 398,
|
| 610 |
+
"danganronpa": 414,
|
| 611 |
+
"danganronpa 2": 415,
|
| 612 |
+
"deltarune 2": 423,
|
| 613 |
+
"deltarune34": 424,
|
| 614 |
+
"diddy kong racing": 428,
|
| 615 |
+
"gb sounds": 431,
|
| 616 |
+
"disgaea 5": 432,
|
| 617 |
+
"doki doki literature club": 435,
|
| 618 |
+
"donkey kong 64": 436,
|
| 619 |
+
"donkey kong country": 437,
|
| 620 |
+
"donkey kong country 2": 438,
|
| 621 |
+
"donkey kong country 3": 439,
|
| 622 |
+
"doom": 442,
|
| 623 |
+
"dragalia lost": 445,
|
| 624 |
+
"dragon quest ix": 448,
|
| 625 |
+
"drakengard 3": 449,
|
| 626 |
+
"elder scrolls 3 morrowind": 474,
|
| 627 |
+
"etrian odyssey ii": 479,
|
| 628 |
+
"etrian odyssey iii": 480,
|
| 629 |
+
"fzero": 495,
|
| 630 |
+
"fzero maximum velocity": 496,
|
| 631 |
+
"fzero gx": 497,
|
| 632 |
+
"fzero x": 498,
|
| 633 |
+
"fairy tail": 499,
|
| 634 |
+
"far cry 6": 500,
|
| 635 |
+
"fate grand order": 501,
|
| 636 |
+
"fate stay night": 505,
|
| 637 |
+
"fire emblem": 512,
|
| 638 |
+
"fire emblem awakening": 513,
|
| 639 |
+
"fire emblem three houses": 515,
|
| 640 |
+
"fruits basket": 521,
|
| 641 |
+
"fuga melodies of steel": 523,
|
| 642 |
+
"fullmetal alchemist": 524,
|
| 643 |
+
"fullmetal alchemist brotherhood": 525,
|
| 644 |
+
"genshin impact": 554,
|
| 645 |
+
"ghost in the shell": 555,
|
| 646 |
+
"goldeneye 007": 565,
|
| 647 |
+
"granblue fantasy": 566,
|
| 648 |
+
"granblue fantasy versus": 567,
|
| 649 |
+
"gurren lagann": 574,
|
| 650 |
+
"gust": 575,
|
| 651 |
+
"hades": 616,
|
| 652 |
+
"haikyuu": 617,
|
| 653 |
+
"harvest moon": 621,
|
| 654 |
+
"hearthstone": 625,
|
| 655 |
+
"hollow knight": 638,
|
| 656 |
+
"hololive": 639,
|
| 657 |
+
"homestuck": 640,
|
| 658 |
+
"homestuck alternia": 645,
|
| 659 |
+
"homestuck alterniabound": 646,
|
| 660 |
+
"homestuck cherubim": 647,
|
| 661 |
+
"honkai impact 3rd": 658,
|
| 662 |
+
"honkai star rail": 661,
|
| 663 |
+
"jojos bizarre adventure": 755,
|
| 664 |
+
"journey": 756,
|
| 665 |
+
"kid icarus uprising": 800,
|
| 666 |
+
"kill la kill": 803,
|
| 667 |
+
"kingdom hearts 3582 days": 816,
|
| 668 |
+
"kingdom hearts 3d dream drop distance": 817,
|
| 669 |
+
"kingdom hearts recoded": 818,
|
| 670 |
+
"kirby": 819,
|
| 671 |
+
"kirby 64 the crystal shards": 820,
|
| 672 |
+
"kirby ds": 821,
|
| 673 |
+
"kirbys dream land 3": 822,
|
| 674 |
+
"konosuba": 827,
|
| 675 |
+
"lamulana": 859,
|
| 676 |
+
"legend of zelda the": 878,
|
| 677 |
+
"legend of zelda the a link to the past": 879,
|
| 678 |
+
"legend of zelda the majoras mask": 880,
|
| 679 |
+
"legend of zelda the ocarina of time": 881,
|
| 680 |
+
"legend of zelda the phantom hourglass": 882,
|
| 681 |
+
"legend of zelda the spirit tracks": 883,
|
| 682 |
+
"legend of zelda the twilight princess": 884,
|
| 683 |
+
"mana khemia": 936,
|
| 684 |
+
"maple story": 937,
|
| 685 |
+
"mario luigi bowsers inside story": 938,
|
| 686 |
+
"mario luigi dream team": 939,
|
| 687 |
+
"mario luigi partners in time": 940,
|
| 688 |
+
"mario luigi superstar saga": 941,
|
| 689 |
+
"mario 3d land": 942,
|
| 690 |
+
"mario golf": 943,
|
| 691 |
+
"mario kart super circuit": 944,
|
| 692 |
+
"mario kart 64": 945,
|
| 693 |
+
"mario kart 7": 946,
|
| 694 |
+
"mario kart ds": 947,
|
| 695 |
+
"mario kart wii": 948,
|
| 696 |
+
"mario kart 8": 949,
|
| 697 |
+
"mario party 3": 952,
|
| 698 |
+
"mario party 4": 953,
|
| 699 |
+
"mario party 5": 954,
|
| 700 |
+
"mario tennis": 955,
|
| 701 |
+
"mega man": 961,
|
| 702 |
+
"mega man 3": 962,
|
| 703 |
+
"mega man 4": 963,
|
| 704 |
+
"mega man 7": 964,
|
| 705 |
+
"mega man battle network": 965,
|
| 706 |
+
"mega man x": 966,
|
| 707 |
+
"mega man x2": 967,
|
| 708 |
+
"mega man x3": 968,
|
| 709 |
+
"mega man x4": 969,
|
| 710 |
+
"mega man zero zx": 970,
|
| 711 |
+
"metal gear solid 2": 978,
|
| 712 |
+
"metroid": 979,
|
| 713 |
+
"metroid zero mission": 980,
|
| 714 |
+
"metroid prime 2 echoes": 981,
|
| 715 |
+
"metroid prime 3": 982,
|
| 716 |
+
"metroid prime": 983,
|
| 717 |
+
"minecraft": 989,
|
| 718 |
+
"monogatari": 1411,
|
| 719 |
+
"my hero academia": 1006,
|
| 720 |
+
"nausicaa valley of the wind": 1034,
|
| 721 |
+
"neon genesis evangelion": 1039,
|
| 722 |
+
"neon white": 1040,
|
| 723 |
+
"new super mario bros": 1045,
|
| 724 |
+
"new super mario bros wii": 1046,
|
| 725 |
+
"ni no kuni": 1050,
|
| 726 |
+
"ni no kuni 2": 1051,
|
| 727 |
+
"nier automata": 1053,
|
| 728 |
+
"night in the woods": 1054,
|
| 729 |
+
"ninja gaiden 1": 1055,
|
| 730 |
+
"ninja gaiden 2": 1056,
|
| 731 |
+
"omori": 1080,
|
| 732 |
+
"one piece": 1081,
|
| 733 |
+
"outer wilds": 1084,
|
| 734 |
+
"parasite eve": 1114,
|
| 735 |
+
"perfect dark": 1122,
|
| 736 |
+
"phoenix wright ace attorney": 1124,
|
| 737 |
+
"phoenix wright ace attorney 2": 1125,
|
| 738 |
+
"pokemon anime": 1131,
|
| 739 |
+
"pokemon black and white": 1133,
|
| 740 |
+
"pokemon crystal": 1134,
|
| 741 |
+
"pokemon diamond": 1135,
|
| 742 |
+
"pokemon fire red and leaf green": 1137,
|
| 743 |
+
"pokemon heart gold soul silver": 1138,
|
| 744 |
+
"pokemon mystery dungeon blue rescue team": 1140,
|
| 745 |
+
"pokemon mystery dungeon explorers of sky": 1141,
|
| 746 |
+
"pokemon mystery dungeon gates to infinity": 1142,
|
| 747 |
+
"pokemon omega ruby": 1143,
|
| 748 |
+
"pokemon red": 1144,
|
| 749 |
+
"pokemon ruby": 1145,
|
| 750 |
+
"pokemon scarlet": 1147,
|
| 751 |
+
"pokemon sun and moon": 1148,
|
| 752 |
+
"pokemon super mystery dungeon": 1149,
|
| 753 |
+
"pokemon x and y": 1150,
|
| 754 |
+
"pokemon xd gale of darkness": 1151,
|
| 755 |
+
"professor layton and the curious village": 1173,
|
| 756 |
+
"resident evil": 1208,
|
| 757 |
+
"scottpilgrim": 1277,
|
| 758 |
+
"secret of mana": 1284,
|
| 759 |
+
"shin megami tensei iv": 1298,
|
| 760 |
+
"shovelknight": 1300,
|
| 761 |
+
"skyrim": 1320,
|
| 762 |
+
"sonic advance 3": 1335,
|
| 763 |
+
"sonic adventure 2": 1337,
|
| 764 |
+
"sonic mania": 1338,
|
| 765 |
+
"sonic the hedgehog": 1339,
|
| 766 |
+
"sonic the hedgehog 2": 1340,
|
| 767 |
+
"sonic the hedgehog 3": 1341,
|
| 768 |
+
"spirited away": 1347,
|
| 769 |
+
"star fox": 1348,
|
| 770 |
+
"star ocean": 1349,
|
| 771 |
+
"starcraft 2": 1350,
|
| 772 |
+
"stardew valley": 1351,
|
| 773 |
+
"stellaris": 1354,
|
| 774 |
+
"street fighter ii": 1359,
|
| 775 |
+
"super mario 64": 1370,
|
| 776 |
+
"super mario bros 3": 1374,
|
| 777 |
+
"super mario galaxy": 1375,
|
| 778 |
+
"super mario rpg": 1377,
|
| 779 |
+
"super mario sunshine": 1378,
|
| 780 |
+
"super monkey ball 2": 1379,
|
| 781 |
+
"super smash bros brawl": 1381,
|
| 782 |
+
"tales of symphonia": 1391,
|
| 783 |
+
"the sims 2": 1395,
|
| 784 |
+
"totalwar": 1396,
|
| 785 |
+
"touhou 10": 1397,
|
| 786 |
+
"touhou 11": 1398,
|
| 787 |
+
"touhou 12": 1399,
|
| 788 |
+
"touhou 14": 1400,
|
| 789 |
+
"touhou 15": 1401,
|
| 790 |
+
"touhou 6": 1402,
|
| 791 |
+
"touhou 7": 1403,
|
| 792 |
+
"touhou 8": 1405,
|
| 793 |
+
"touhou 9": 1406,
|
| 794 |
+
"tunic": 1407,
|
| 795 |
+
"undertale": 1410,
|
| 796 |
+
"violet evergarden": 1413,
|
| 797 |
+
"wild arms 2": 1417,
|
| 798 |
+
"witcher 3": 1418,
|
| 799 |
+
"wow": 1419,
|
| 800 |
+
"wuthering waves": 1421,
|
| 801 |
+
"xenoblade chronicles": 1423,
|
| 802 |
+
"xenoblade chronicles 2": 1424,
|
| 803 |
+
"xenoblade chronicles 2 torna": 1425,
|
| 804 |
+
"xenoblade chronicles 3": 1426,
|
| 805 |
+
"xenogears": 1427,
|
| 806 |
+
"ys": 1429,
|
| 807 |
+
"zenless zone zero": 1434,
|
| 808 |
+
"beatmania": 1459,
|
| 809 |
+
"berserk": 1460,
|
| 810 |
+
"castle crashers": 1465,
|
| 811 |
+
"everquest": 1470,
|
| 812 |
+
"mortal kombat": 1482,
|
| 813 |
+
"nier": 1483,
|
| 814 |
+
"persona": 1484,
|
| 815 |
+
"sayonara wild hearts": 1487,
|
| 816 |
+
"touhou remixes": 1550,
|
| 817 |
+
"yakuza": 1659,
|
| 818 |
+
"sea shanty": 2230,
|
| 819 |
+
"emo-pop": 2231,
|
| 820 |
+
"skate punk": 2232,
|
| 821 |
+
"bright": 2233,
|
| 822 |
+
"cumbia": 2234,
|
| 823 |
+
"world music": 2235,
|
| 824 |
+
"synth pop": 2236,
|
| 825 |
+
"chorus": 2237,
|
| 826 |
+
"japanese": 2238,
|
| 827 |
+
"schlager": 2239,
|
| 828 |
+
"asian music": 2240,
|
| 829 |
+
"glam pop": 2241,
|
| 830 |
+
"lute": 2242,
|
| 831 |
+
"misanthropic": 2243,
|
| 832 |
+
"christian": 2244,
|
| 833 |
+
"bubblegum pop": 2245,
|
| 834 |
+
"808s": 2246,
|
| 835 |
+
"remastered": 2247,
|
| 836 |
+
"christmas music": 2248,
|
| 837 |
+
"wave": 2249,
|
| 838 |
+
"tango": 2250,
|
| 839 |
+
"hateful": 2251,
|
| 840 |
+
"high energy": 2252,
|
| 841 |
+
"neoclassical darkwave": 2253,
|
| 842 |
+
"electroclash": 2254,
|
| 843 |
+
"seductive": 2255,
|
| 844 |
+
"dungeon synth": 2256,
|
| 845 |
+
"city pop": 2257,
|
| 846 |
+
"heroic": 2258,
|
| 847 |
+
"freestyle": 2259,
|
| 848 |
+
"space ambient": 2260,
|
| 849 |
+
"bounce drop": 2261,
|
| 850 |
+
"afrobeats": 2262,
|
| 851 |
+
"power ballad": 2263,
|
| 852 |
+
"trombone": 2264,
|
| 853 |
+
"guitar solo": 2265,
|
| 854 |
+
"battle": 2266,
|
| 855 |
+
"ending": 2267,
|
| 856 |
+
"soundtrack1": 2268,
|
| 857 |
+
"soundtrack2": 2269
|
| 858 |
+
}
|
gradio_app.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import List, Tuple
|
| 4 |
+
import uuid
|
| 5 |
+
import json
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import torch
|
| 8 |
+
import torchaudio
|
| 9 |
+
from safetensors.torch import load_file
|
| 10 |
+
|
| 11 |
+
from model import LocalSongModel
|
| 12 |
+
from acestep.music_dcae.music_dcae_pipeline import MusicDCAE
|
| 13 |
+
|
| 14 |
+
class TagEmbedder:
    """Loads the tag-name -> class-index mapping used to condition the model."""

    def __init__(self, mapping_file: str = "checkpoints/tag_mapping.json"):
        # Read the JSON mapping of lowercase tag strings to integer class ids.
        with open(mapping_file, 'r', encoding='utf-8') as f:
            self.tag_mapping = json.load(f)

        print(f"Loaded {len(self.tag_mapping)} tags from {mapping_file}")
        # Fixed embedding-table size expected by the checkpoint; may exceed
        # the number of tags actually present in the mapping file.
        self.num_classes = 2304
class AudioVAE:
    """Thin wrapper around MusicDCAE that undoes latent normalisation on decode."""

    def __init__(self, device: torch.device):
        self.device = device
        self.model = MusicDCAE().to(device)
        self.model.eval()
        # Per-channel statistics used to de-normalise latents before decoding.
        # NOTE(review): presumably computed over the training latent distribution
        # of this specific checkpoint — confirm before reuse elsewhere.
        mean = [0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526]
        std = [0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707]
        self.latent_mean = torch.tensor(mean, device=device).view(1, -1, 1, 1)
        self.latent_std = torch.tensor(std, device=device).view(1, -1, 1, 1)

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """De-normalise `latents` and decode them to a batch of waveforms at 48 kHz."""
        with torch.no_grad():
            denorm = latents * self.latent_std + self.latent_mean
            sr, audio_list = self.model.decode(denorm, sr=48000)
            return torch.stack(audio_list).to(self.device)
class RF:
|
| 45 |
+
def __init__(self, model: torch.nn.Module):
|
| 46 |
+
self.model = model
|
| 47 |
+
|
| 48 |
+
def sample(
|
| 49 |
+
self,
|
| 50 |
+
z: torch.Tensor,
|
| 51 |
+
cond: List[List[int]],
|
| 52 |
+
null_cond: List[List[int]] | None = None,
|
| 53 |
+
sample_steps: int = 100,
|
| 54 |
+
cfg: float = 3.0,
|
| 55 |
+
) -> List[torch.Tensor]:
|
| 56 |
+
batch = z.size(0)
|
| 57 |
+
dt = 1.0 / sample_steps
|
| 58 |
+
dt = torch.tensor([dt] * batch, device=z.device).view([batch, *([1] * len(z.shape[1:]))])
|
| 59 |
+
images = [z]
|
| 60 |
+
for i in range(sample_steps, 0, -1):
|
| 61 |
+
t = torch.tensor([i / sample_steps] * batch, device=z.device)
|
| 62 |
+
|
| 63 |
+
if null_cond is not None:
|
| 64 |
+
|
| 65 |
+
z_batched = torch.cat([z, z], dim=0)
|
| 66 |
+
t_batched = torch.cat([t, t], dim=0)
|
| 67 |
+
cond_batched = cond + null_cond
|
| 68 |
+
v_batched = self.model(z_batched, t_batched, cond_batched)
|
| 69 |
+
vc, vu = v_batched.chunk(2, dim=0)
|
| 70 |
+
vc = vu + cfg * (vc - vu)
|
| 71 |
+
|
| 72 |
+
else:
|
| 73 |
+
vc = self.model(z, t, cond)
|
| 74 |
+
|
| 75 |
+
z = z - dt * vc
|
| 76 |
+
images.append(z)
|
| 77 |
+
return images
|
| 78 |
+
|
| 79 |
+
# Lazily-initialised module-level singletons; load_resources() fills these in
# on first use so importing this module stays cheap.
model: torch.nn.Module | None = None
vae: AudioVAE | None = None
tag_embedder: TagEmbedder | None = None
rf_sampler: RF | None = None
device: torch.device | None = None
# Cached sorted tag list; doubles as the "resources already loaded" flag.
_available_tags: List[str] | None = None
def load_resources() -> List[str]:
    """Initialise the model, VAE, tag embedder and sampler exactly once.

    Returns the sorted list of available tag names; repeat calls return the
    cached list without touching the checkpoints again.
    """
    torch.set_float32_matmul_precision('high')

    global model, vae, tag_embedder, rf_sampler, device, _available_tags

    # Fast path: everything was loaded on a previous call.
    if _available_tags is not None:
        return _available_tags

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tag_embedder = TagEmbedder()

    # Architecture hyper-parameters must match the shipped checkpoint exactly.
    model = LocalSongModel(
        in_channels=8,
        num_groups=16,
        hidden_size=1024,
        decoder_hidden_size=2048,
        num_blocks=36,
        patch_size=(16, 1),
        num_classes=tag_embedder.num_classes,
        max_tags=8,
    ).to(device)

    checkpoint_path = "checkpoints/checkpoint_461260.safetensors"
    print(f"Loading checkpoint: {checkpoint_path}")
    state_dict = load_file(checkpoint_path, device=str(device))
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    vae = AudioVAE(device)
    rf_sampler = RF(model)

    _available_tags = sorted(tag_embedder.tag_mapping.keys())
    return _available_tags
def _tags_to_indices(tags: List[str]) -> List[int]:
    """Map tag names to class indices; unknown tags are silently skipped."""
    assert tag_embedder is not None
    mapping = tag_embedder.tag_mapping
    # Normalise the same way the mapping keys are stored: lowercase, stripped.
    normalised = (tag.lower().strip() for tag in tags)
    return [mapping[t] for t in normalised if t in mapping]
def generate_audio(
    tags: List[str],
    cfg: float,
    sample_steps: int,
) -> Tuple[Tuple[int, object], str]:
    """Sample a latent conditioned on `tags`, decode it and save a WAV file.

    Returns ((sample_rate, numpy_audio), wav_path) matching the Gradio
    Audio/File output components.
    """
    load_resources()
    assert model is not None and vae is not None and rf_sampler is not None and device is not None

    tags = tags or []
    if len(tags) > 8:
        raise gr.Error("A maximum of 8 tags is supported.")

    tag_indices = _tags_to_indices(tags)

    # Fixed latent geometry (B, C, H, W) expected by the checkpoint.
    z = torch.randn(1, 8, 16, 512, device=device)

    with torch.no_grad():
        trajectory = rf_sampler.sample(
            z=z,
            cond=[tag_indices],
            null_cond=[[]],  # empty tag list serves as the unconditional branch
            sample_steps=sample_steps,
            cfg=cfg,
        )
        audio = vae.decode(trajectory[-1])

    audio_tensor = audio[0].cpu()
    sr = 48000
    # Gradio expects (samples, channels); torchaudio stores (channels, samples).
    audio_numpy = audio_tensor.transpose(0, 1).numpy()

    os.makedirs("generated", exist_ok=True)
    output_path = f"generated/generated_{uuid.uuid4().hex}.wav"
    torchaudio.save(str(output_path), audio_tensor, sr)

    return (sr, audio_numpy), str(output_path)
def build_interface() -> gr.Blocks:
    """Construct the Gradio UI and wire its controls to generate_audio()."""
    available_tags = load_resources()

    # Curated tag combinations exposed as one-click preset buttons.
    presets = [
        ["soundtrack1", "female vocalist","rock","melodic"],
        ["soundtrack", "chrono trigger", "emotional", "piano", "strings"],
        ["soundtrack", "touhou 10", "trumpet"],
        ["soundtrack", "christmas music","winter","melodic"],
        ["soundtrack2", "male vocalist","pop","melodic","acoustic guitar","ballad"],
    ]

    with gr.Blocks(title="LocalSong") as demo:
        gr.Markdown("# LocalSong")

        with gr.Row():
            tag_input = gr.Dropdown(
                label="Tags (select up to 8)",
                choices=available_tags,
                multiselect=True,
                max_choices=8,
                value=presets[0],
            )

        gr.Markdown("**Presets:**")
        with gr.Row():
            for preset in presets:
                btn = gr.Button(f"{' + '.join(preset)}", size="sm")

                # Bind `preset` eagerly; a bare lambda would capture the loop
                # variable and every button would apply the final preset.
                def make_preset_fn(p):
                    return lambda: p

                btn.click(fn=make_preset_fn(preset), inputs=None, outputs=tag_input)

        with gr.Row():
            cfg_slider = gr.Slider(
                label="CFG Scale",
                minimum=1.0,
                maximum=7.0,
                step=0.5,
                value=3.5,
            )
            sample_steps_slider = gr.Slider(
                label="Sample Steps",
                minimum=50,
                maximum=200,
                step=10,
                value=200,
            )

        with gr.Row():
            seed_input = gr.Number(
                label="Seed",
                value=45,
                precision=0,
            )

        generate_button = gr.Button("Generate Audio", variant="primary")
        audio_output = gr.Audio(label="Generated Audio", type="numpy")
        download_output = gr.File(label="Download WAV")

        def generate_wrapper(tags, cfg, steps, seed):
            # Seed both CPU and CUDA RNGs so a given seed reproduces its output.
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
            return generate_audio(tags, cfg, steps)

        generate_button.click(
            fn=generate_wrapper,
            inputs=[tag_input, cfg_slider, sample_steps_slider, seed_input],
            outputs=[audio_output, download_output],
        )

    return demo
# Build the UI at import time so hosting tools that look for `demo` find it.
demo = build_interface()

if __name__ == "__main__":
    demo.launch()
model.py
ADDED
|
@@ -0,0 +1,490 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import math
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
from torch.nn.functional import scaled_dot_product_attention
|
| 7 |
+
|
| 8 |
+
def modulate(x, shift, scale):
    """AdaLN-style modulation: scale x around 1, then add shift."""
    return x * (scale + 1) + shift
class Embed(nn.Module):
    """Linear projection of per-token features with an optional norm layer."""

    def __init__(
        self,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer = None,
        bias: bool = True,
    ):
        super().__init__()
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
        # Fall back to a no-op when no normalisation layer is supplied.
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        return self.norm(self.proj(x))
class PatchEmbed(nn.Module):
    """Splits a (B, C, H, W) latent into patches and projects each to embed_dim.

    `patch_size` is a (patch_h, patch_w) pair; H and W must be divisible by it.
    """

    def __init__(
        self,
        in_channels=8,
        embed_dim=1152,
        bias=True,
        patch_size=1,
    ):
        super().__init__()
        self.patch_h, self.patch_w = patch_size
        self.patch_size = patch_size
        self.proj = nn.Linear(in_channels * self.patch_h * self.patch_w, embed_dim, bias=bias)
        self.in_channels = in_channels
        self.embed_dim = embed_dim

    def forward(self, latent):
        # Equivalent to einops: 'b c (h p1) (w p2) -> b (h w) (c p1 p2)',
        # written with plain torch view/permute ops.
        b, c, full_h, full_w = latent.shape
        h = full_h // self.patch_h
        w = full_w // self.patch_w
        x = latent.view(b, c, h, self.patch_h, w, self.patch_w)
        x = x.permute(0, 2, 4, 1, 3, 5).reshape(
            b, h * w, c * self.patch_h * self.patch_w
        )
        return self.proj(x)
class FinalLayer(nn.Module):
    """Projects token embeddings back to patch pixels and reassembles the image."""

    def __init__(self, hidden_size, out_channels=8, patch_size=1):
        super().__init__()
        self.patch_h, self.patch_w = patch_size
        self.linear = nn.Linear(hidden_size, out_channels * self.patch_h * self.patch_w, bias=True)
        self.out_channels = out_channels
        self.patch_size = patch_size

    def forward(self, x, target_height, target_width):
        # Project each token to a flattened (c, p1, p2) patch.
        x = self.linear(x)
        # Equivalent to einops: 'b (h w) (c p1 p2) -> b c (h p1) (w p2)',
        # written with plain torch view/permute ops.
        b = x.shape[0]
        x = x.view(
            b, target_height, target_width, self.out_channels, self.patch_h, self.patch_w
        )
        x = x.permute(0, 3, 1, 4, 2, 5).reshape(
            b,
            self.out_channels,
            target_height * self.patch_h,
            target_width * self.patch_w,
        )
        return x
class TimestepEmbedder(nn.Module):
    """Embeds scalar diffusion timesteps via sinusoidal features and an MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10):
        """Sinusoidal embedding of t, cos half followed by sin half.

        NOTE(review): max_period=10 is unusually small (DiT uses 10000) but is
        what the shipped checkpoint was trained with — do not change.
        """
        half = dim // 2
        exponents = torch.arange(half, dtype=torch.float32, device=t.device) / half
        freqs = torch.exp(-math.log(max_period) * exponents)
        args = t[..., None].float() * freqs[None, ...]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Pad odd dims with a zero column.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
class RMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction), computed in fp32."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        # Normalise in float32 for numerical stability, then cast back.
        h = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        return self.weight * (h * inv_rms).to(orig_dtype)
class FeedForward(nn.Module):
    """SwiGLU feed-forward block: w2(silu(w1(x)) * w3(x))."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
    ):
        super().__init__()
        # Shrink by 2/3 so the gated variant matches a plain FFN's parameter count.
        inner = int(2 * hidden_dim / 3)
        self.w1 = nn.Linear(dim, inner, bias=False)
        self.w3 = nn.Linear(dim, inner, bias=False)
        self.w2 = nn.Linear(inner, dim, bias=False)

    def forward(self, x):
        gate = torch.nn.functional.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float = 10000.0, scale=1.0):
    """Precompute 2D rotary-embedding frequencies for a height x width grid.

    Returns a complex tensor of shape (height * width, dim // 2): each flattened
    grid position carries interleaved x-axis and y-axis rotations per frequency.
    """
    if isinstance(scale, float):
        scale = (scale, scale)
    # Flattened row-major grid coordinates, optionally rescaled per axis.
    xs = torch.linspace(0, width * scale[0], width)
    ys = torch.linspace(0, height * scale[1], height)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    grid_y = grid_y.reshape(-1)
    grid_x = grid_x.reshape(-1)
    # One frequency per 4 channels: two axes x two components (cos/sin) each.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
    x_angles = torch.outer(grid_x, freqs).float()
    y_angles = torch.outer(grid_y, freqs).float()
    x_cis = torch.polar(torch.ones_like(x_angles), x_angles)
    y_cis = torch.polar(torch.ones_like(y_angles), y_angles)
    paired = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1)
    return paired.reshape(height * width, -1)
@torch.compiler.disable
|
| 146 |
+
def apply_rotary_emb_2d(
|
| 147 |
+
xq: torch.Tensor,
|
| 148 |
+
xk: torch.Tensor,
|
| 149 |
+
freqs_cis: torch.Tensor,
|
| 150 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 151 |
+
|
| 152 |
+
freqs_cis = freqs_cis[None, None, :, :]
|
| 153 |
+
|
| 154 |
+
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
| 155 |
+
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
| 156 |
+
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
| 157 |
+
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
| 158 |
+
return xq_out.type_as(xq), xk_out.type_as(xk)
|
| 159 |
+
|
| 160 |
+
class RAttention(nn.Module):
    """Multi-head self-attention with QK RMS-norm and 2D rotary embeddings."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = True,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: nn.Module = RMSNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Kept for interface parity; SDPA applies its own default scaling.
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
        batch, seq, channels = x.shape
        head_dim = channels // self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(batch, seq, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q = self.q_norm(q.contiguous())
        k = self.k_norm(k.contiguous())
        q, k = apply_rotary_emb_2d(q, k, freqs_cis=pos)

        q = q.view(batch, self.num_heads, -1, head_dim)
        k = k.view(batch, self.num_heads, -1, head_dim).contiguous()
        v = v.view(batch, self.num_heads, -1, head_dim).contiguous()

        # Dropout is only active while training, matching nn.Dropout semantics.
        drop_p = self.attn_drop.p if self.training else 0.0
        out = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=drop_p)

        out = out.transpose(1, 2).reshape(batch, seq, channels)
        return self.proj_drop(self.proj(out))
| 206 |
+
class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries come from ``x``, keys/values from a
    separate ``context`` sequence, with optional boolean context masking."""

    def __init__(
        self,
        dim: int,
        context_dim: int,
        num_heads: int,
        qkv_bias: bool = False,
        proj_drop: float = 0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5

        # Separate projections: Q from the query stream, fused KV from context.
        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv_proj = nn.Linear(context_dim, dim * 2, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, context: torch.Tensor, context_mask: torch.Tensor = None) -> torch.Tensor:
        """Cross-attend ``x`` (B, N, C) over ``context`` (B, M, C_ctx).

        ``context_mask`` is a boolean (B, M) tensor; False positions are
        excluded from attention. Returns a (B, N, C) tensor.
        """
        bsz, q_len, dim = x.shape
        ctx_bsz, kv_len, _ = context.shape

        # (B, N, C) -> (B, heads, N, head_dim)
        q = self.q_proj(x).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        # (B, M, 2C) -> 2 x (B, heads, M, head_dim)
        kv = self.kv_proj(context).view(ctx_bsz, kv_len, 2, self.num_heads, self.head_dim)
        k, v = kv.permute(2, 0, 3, 1, 4).unbind(0)

        # Convert the boolean keep-mask into an additive bias with -inf on
        # masked-out context positions.
        attn_bias = None
        if context_mask is not None:
            keep = context_mask.unsqueeze(1).unsqueeze(2)
            attn_bias = torch.zeros(
                bsz, 1, 1, kv_len, dtype=q.dtype, device=q.device
            ).masked_fill(~keep, float('-inf'))

        drop_p = self.proj_drop.p if self.training else 0.0
        out = scaled_dot_product_attention(q, k, v, attn_mask=attn_bias, dropout_p=drop_p)

        out = out.transpose(1, 2).reshape(bsz, q_len, dim)
        return self.proj_drop(self.proj(out))
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class DDTBlock(nn.Module):
    """Transformer block with adaLN (DiT-style) conditioning: self-attention,
    optional cross-attention to a context sequence, and a gated MLP.

    Encoder blocks have no modulation network of their own and must receive a
    shared adaLN module at call time; decoder blocks own ``adaLN_modulation``.
    """

    def __init__(self, hidden_size, groups, mlp_ratio=4.0, context_dim=None, is_encoder_block=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.norm1 = RMSNorm(hidden_size, eps=1e-6)
        self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)

        # Cross-attention path only exists when a context dimension is given.
        if context_dim:
            self.norm_cross = RMSNorm(hidden_size, eps=1e-6)
            self.cross_attn = CrossAttention(hidden_size, context_dim, groups)
        else:
            self.norm_cross = nn.Identity()
            self.cross_attn = None

        self.norm2 = RMSNorm(hidden_size, eps=1e-6)
        self.mlp = FeedForward(hidden_size, int(hidden_size * mlp_ratio))

        self.is_encoder_block = is_encoder_block
        if not is_encoder_block:
            # Produces the six adaLN terms (shift/scale/gate for attn and MLP).
            self.adaLN_modulation = nn.Sequential(
                nn.Linear(hidden_size, 6 * hidden_size, bias=True)
            )

    def forward(self, x, c, pos, mask=None, context=None, context_mask=None, shared_adaLN=None):
        """Run the block on ``x`` conditioned on ``c``; ``pos`` carries rotary
        phases, ``context``/``context_mask`` feed the optional cross-attention."""
        if self.is_encoder_block:
            modulation = shared_adaLN(c)
        else:
            modulation = self.adaLN_modulation(c)

        (shift_attn, scale_attn, gate_attn,
         shift_ffn, scale_ffn, gate_ffn) = modulation.chunk(6, dim=-1)

        # Gated, modulated self-attention residual.
        attn_in = modulate(self.norm1(x), shift_attn, scale_attn)
        x = x + gate_attn * self.attn(attn_in, pos, mask=mask)

        # Ungated cross-attention residual, only when both the module and a
        # context are available.
        if self.cross_attn is not None and context is not None:
            x = x + self.cross_attn(self.norm_cross(x), context=context, context_mask=context_mask)

        # Gated, modulated MLP residual.
        ffn_in = modulate(self.norm2(x), shift_ffn, scale_ffn)
        return x + gate_ffn * self.mlp(ffn_in)
|
| 281 |
+
|
| 282 |
+
class LocalSongModel(nn.Module):
    """Diffusion transformer over patchified latents, conditioned on a timestep
    and a variable-length set of tag IDs (embedded via ``y_embedder``).

    The block stack is split into three sections in ``forward_emb``: block 0
    runs at ``decoder_hidden_size``, the middle blocks at ``hidden_size``, and
    the last three again at ``decoder_hidden_size``, with linear projections
    bridging the width changes.
    """

    def __init__(
        self,
        in_channels=8,
        num_groups=16,
        hidden_size=1024,
        decoder_hidden_size=2048,
        num_blocks=36,
        patch_size=(16,1),
        num_classes=2304,
        max_tags=8,
    ):
        super().__init__()
        self.in_channels = in_channels
        # Output channels mirror the input latent channels.
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.decoder_hidden_size = decoder_hidden_size
        self.num_groups = num_groups
        # NOTE(review): duplicate assignment of num_groups (harmless no-op).
        self.num_groups = num_groups
        self.num_blocks = num_blocks
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.max_tags = max_tags

        self.patch_h, self.patch_w = patch_size

        # NOTE(review): x_embedder is initialized but not referenced in
        # forward_emb below (only s_embedder is used) — confirm intent.
        self.x_embedder = PatchEmbed(
            in_channels=in_channels,
            embed_dim=decoder_hidden_size,
            bias=True,
            patch_size=patch_size
        )

        self.s_embedder = PatchEmbed(
            in_channels=in_channels,
            embed_dim=decoder_hidden_size,
            bias=True,
            patch_size=patch_size
        )

        # NOTE(review): encoder_to_decoder is also unused in forward_emb;
        # only a_to_b_proj and bc_projection bridge the width changes.
        self.encoder_to_decoder = nn.Linear(hidden_size, decoder_hidden_size, bias=False)

        # Projects section-A activations (decoder width) down to encoder width.
        self.a_to_b_proj = nn.Linear(decoder_hidden_size, hidden_size, bias=False)

        self.t_embedder = TimestepEmbedder(hidden_size)

        # Index 0 is the padding tag; its embedding row is zeroed in
        # initialize_weights.
        self.y_embedder = nn.Embedding(num_classes + 1, hidden_size, padding_idx=0)

        self.final_layer = FinalLayer(
            decoder_hidden_size,
            out_channels=in_channels,
            patch_size=patch_size
        )

        # Shared adaLN heads: one for hidden_size blocks, one for
        # decoder_hidden_size blocks (all blocks are built as encoder blocks,
        # so none owns its own modulation network — see the loop below).
        self.shared_encoder_adaLN = nn.Sequential(
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

        self.shared_decoder_adaLN = nn.Sequential(
            nn.Linear(hidden_size, 6 * decoder_hidden_size, bias=True)
        )

        self.blocks = nn.ModuleList()
        for i in range(self.num_blocks):
            # NOTE(review): this condition is always True, so every block is an
            # encoder block and the trailing else-branch is dead — confirm
            # whether a different split was intended.
            is_encoder = i < self.num_blocks

            if is_encoder:
                if i < 1:
                    # Block 0: full decoder width.
                    block_hidden_size = decoder_hidden_size
                    num_heads = self.num_groups
                elif i >= self.num_blocks - 3:
                    # Last three blocks: decoder width.
                    block_hidden_size = decoder_hidden_size
                    num_heads = self.num_groups
                else:
                    # Middle blocks: narrower encoder width.
                    block_hidden_size = hidden_size
                    num_heads = self.num_groups
            else:
                block_hidden_size = decoder_hidden_size
                num_heads = self.num_groups

            # Cross-attention to the tag embeddings on even-indexed blocks only.
            context_dim = hidden_size if i % 2 == 0 and is_encoder else None

            self.blocks.append(
                DDTBlock(
                    block_hidden_size,
                    num_heads,
                    context_dim=context_dim,
                    is_encoder_block=is_encoder
                )
            )

        # Fuses the concatenated section-A (decoder width) and section-B
        # (encoder width) streams back to decoder width.
        self.bc_projection = nn.Linear(decoder_hidden_size + hidden_size, decoder_hidden_size, bias=False)

        self.initialize_weights()
        self.precompute_encoder_pos = dict()
        self.precompute_decoder_pos = dict()

    # NOTE(review): class-body import so lru_cache is visible to the
    # decorators below; lru_cache on instance methods keeps every instance
    # alive for the cache's lifetime (ruff B019) and is redundant with the
    # manual precompute_* dicts — consider consolidating.
    from functools import lru_cache

    @lru_cache
    def fetch_encoder_pos(self, height, width, device):
        """Return (caching by grid size) rotary phases for encoder-width blocks."""
        key = (height, width)
        if key in self.precompute_encoder_pos:
            return self.precompute_encoder_pos[key].to(device)
        else:
            pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
            self.precompute_encoder_pos[key] = pos
            return pos

    @lru_cache
    def fetch_decoder_pos(self, height, width, device):
        """Return (caching by grid size) rotary phases for decoder-width blocks."""
        key = (height, width)
        if key in self.precompute_decoder_pos:
            return self.precompute_decoder_pos[key].to(device)
        else:
            pos = precompute_freqs_cis_2d(self.decoder_hidden_size // self.num_groups, height, width).to(device)
            self.precompute_decoder_pos[key] = pos
            return pos

    def initialize_weights(self):
        """Apply the DiT-style init: Xavier on projections, zeros on the adaLN
        heads and final layer, normal init on embedders."""
        for embedder in [self.x_embedder, self.s_embedder]:
            nn.init.xavier_uniform_(embedder.proj.weight)
            if embedder.proj.bias is not None:
                nn.init.constant_(embedder.proj.bias, 0)

        nn.init.xavier_uniform_(self.encoder_to_decoder.weight)
        nn.init.xavier_uniform_(self.a_to_b_proj.weight)

        nn.init.normal_(self.y_embedder.weight, std=0.02)

        # Zero the padding-tag embedding so padded slots contribute nothing.
        with torch.no_grad():
            self.y_embedder.weight[0].fill_(0)

        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-init adaLN heads so each block starts as (near-)identity.
        nn.init.constant_(self.shared_encoder_adaLN[-1].weight, 0)
        nn.init.constant_(self.shared_encoder_adaLN[-1].bias, 0)
        nn.init.constant_(self.shared_decoder_adaLN[-1].weight, 0)
        nn.init.constant_(self.shared_decoder_adaLN[-1].bias, 0)

        # Zero-init the output head so the model initially predicts zeros.
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

        nn.init.xavier_uniform_(self.bc_projection.weight)

    def embed_condition(self, cond):
        """Embed a batch of variable-length tag-ID lists.

        Args:
            cond: Sequence of per-sample tag-ID sequences; each is truncated to
                ``max_tags`` and right-padded with 0 (the padding index).

        Returns:
            Tuple of (embeddings of shape (B, max_tags, hidden_size),
            boolean mask of shape (B, max_tags) that is True on real tags).
        """
        device = self.y_embedder.weight.device

        max_len = self.max_tags
        batch_size = len(cond)

        padded_tags = torch.zeros(batch_size, max_len, dtype=torch.long, device=device)

        for i, tags in enumerate(cond):
            truncated_tags = tags[:max_len]
            padded_tags[i, :len(truncated_tags)] = torch.tensor(truncated_tags, dtype=torch.long, device=device)

        # True where a real (non-padding) tag is present.
        padding_mask = (padded_tags != 0)

        embedded = self.y_embedder(padded_tags)

        return embedded, padding_mask

    def forward(self, x, t, y):
        """Embed raw tag lists ``y`` then run the compiled core on latents
        ``x`` at timesteps ``t``."""
        y_emb, padding_mask = self.embed_condition(y)

        return self.forward_emb(x, t, y_emb, padding_mask)

    @torch.compile()
    def forward_emb(self, x, t, y_emb, padding_mask=None):
        """Core denoiser on pre-embedded conditions.

        Args:
            x: Latent batch of shape (B, C, H, W).
            t: Timestep tensor, flattened to (B,) for embedding.
            y_emb: Tag embeddings (B, max_tags, hidden_size).
            padding_mask: Boolean (B, max_tags) mask; True marks real tags.

        Returns:
            Denoised output of shape (B, C, H, W) from the final layer.
        """
        B, _, H, W = x.shape

        h_patches = H // self.patch_h
        w_patches = W // self.patch_w
        encoder_pos = self.fetch_encoder_pos(h_patches, w_patches, x.device)
        decoder_pos = self.fetch_decoder_pos(h_patches, w_patches, x.device)

        t_emb = self.t_embedder(t.view(-1)).view(B, 1, self.hidden_size)

        t_cond = nn.functional.silu(t_emb)

        s = self.s_embedder(x)

        # Section A: block 0 at decoder width (min() guards num_blocks == 0).
        s_section_a = s
        for i in range(min(1, self.num_blocks)):
            block_context = y_emb if i % 2 == 0 else None
            s_section_a = self.blocks[i](s_section_a, t_cond, decoder_pos, None, context=block_context, context_mask=padding_mask, shared_adaLN=self.shared_decoder_adaLN)

        # Bridge A -> B: project down to encoder width.
        s_section_a_projected = self.a_to_b_proj(s_section_a)

        s_section_b = s_section_a_projected

        # Section B: middle blocks at encoder width.
        for i in range(1, self.num_blocks - 3):
            block_context = y_emb if i % 2 == 0 else None
            s_section_b = self.blocks[i](s_section_b, t_cond, encoder_pos, None, context=block_context, context_mask=padding_mask, shared_adaLN=self.shared_encoder_adaLN)

        # Long skip: concatenate section-A and section-B streams, fuse back
        # to decoder width.
        s_concat = torch.cat([s_section_a, s_section_b], dim=-1)

        s = self.bc_projection(s_concat)

        # Section C: last three blocks at decoder width.
        for i in range(max(1, self.num_blocks - 3), self.num_blocks):
            block_context = y_emb if i % 2 == 0 else None
            s = self.blocks[i](s, t_cond, decoder_pos, None, context=block_context, context_mask=padding_mask, shared_adaLN=self.shared_decoder_adaLN)

        s = self.final_layer(s, H // self.patch_h, W // self.patch_w)

        return s
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.8.0
|
| 2 |
+
torchaudio>=2.8.0
|
| 3 |
+
torchvision>=0.23.0
|
| 4 |
+
torchcodec>=0.8.0
|
| 5 |
+
accelerate>=1.9.0
|
| 6 |
+
diffusers>=0.34.0
|
| 7 |
+
einops>=0.8.1
|
| 8 |
+
librosa>=0.11.0
|
| 9 |
+
safetensors>=0.4.0
|
| 10 |
+
gradio>=5.45.0
|