Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
ef96930
0
Parent(s):
init
Browse files- .gitattributes +9 -0
- .gitignore +10 -0
- LICENSE +201 -0
- README.md +33 -0
- app.py +98 -0
- inference_example_pretrain.py +44 -0
- inference_example_sft.py +71 -0
- requirements.txt +9 -0
- requirements_space.txt +9 -0
- run_mimo_audio.py +764 -0
- src/mimo_audio/mimo_audio.py +1292 -0
- src/mimo_audio/modeling_mimo_audio.py +835 -0
- src/mimo_audio/process_speechdata.py +289 -0
- src/mimo_audio/templates.py +54 -0
- src/mimo_audio_tokenizer/__init__.py +6 -0
- src/mimo_audio_tokenizer/configuration_audio_tokenizer.py +104 -0
- src/mimo_audio_tokenizer/modeling_audio_tokenizer.py +857 -0
- src/mimo_audio_tokenizer/modeling_rope_utils.py +878 -0
- src/mimo_audio_tokenizer/quantization.py +480 -0
.gitattributes
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
build/
|
| 2 |
+
dist/
|
| 3 |
+
checkpoints/
|
| 4 |
+
*.egg-info/
|
| 5 |
+
*.egg
|
| 6 |
+
*.pyc
|
| 7 |
+
*.pyo
|
| 8 |
+
*.pyd
|
| 9 |
+
*.pyw
|
| 10 |
+
*.pyz
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright 2025 Xiaomi Corporation.
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: MiMo-Audio TTS
|
| 3 |
+
emoji: 🎵
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.46.1
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
python_version: 3.12
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# MiMo-Audio Text-to-Speech
|
| 15 |
+
|
| 16 |
+
A simple text-to-speech interface powered by Xiaomi's MiMo-Audio model.
|
| 17 |
+
|
| 18 |
+
## Features
|
| 19 |
+
|
| 20 |
+
- Convert text to natural-sounding speech
|
| 21 |
+
- Optional style descriptions to control voice characteristics
|
| 22 |
+
- Powered by MiMo-Audio-7B-Instruct model
|
| 23 |
+
|
| 24 |
+
## Usage
|
| 25 |
+
|
| 26 |
+
1. Enter your text in the input box
|
| 27 |
+
2. Optionally add a style description (e.g., "a calm, gentle voice")
|
| 28 |
+
3. Click "Generate Speech"
|
| 29 |
+
4. Listen to or download the generated audio
|
| 30 |
+
|
| 31 |
+
## Model
|
| 32 |
+
|
| 33 |
+
This Space uses the [MiMo-Audio-7B-Instruct](https://huggingface.co/XiaomiMiMo/MiMo-Audio-7B-Instruct) model, a 7B parameter audio language model developed by Xiaomi.
|
app.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import spaces
|
| 3 |
+
import torch
|
| 4 |
+
from huggingface_hub import snapshot_download
|
| 5 |
+
from src.mimo_audio.mimo_audio import MimoAudio
|
| 6 |
+
import tempfile
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
# Download models from Hugging Face
|
| 10 |
+
print("Downloading MiMo-Audio models from Hugging Face...")
|
| 11 |
+
model_path = snapshot_download(repo_id="XiaomiMiMo/MiMo-Audio-7B-Instruct")
|
| 12 |
+
tokenizer_path = snapshot_download(repo_id="XiaomiMiMo/MiMo-Audio-Tokenizer")
|
| 13 |
+
print(f"Models downloaded to: {model_path} and {tokenizer_path}")
|
| 14 |
+
|
| 15 |
+
# Initialize model
|
| 16 |
+
print("Loading MiMo-Audio model...")
|
| 17 |
+
model = MimoAudio(
|
| 18 |
+
model_path=model_path,
|
| 19 |
+
tokenizer_path=tokenizer_path
|
| 20 |
+
)
|
| 21 |
+
print("Model loaded successfully!")
|
| 22 |
+
|
| 23 |
+
@spaces.GPU
|
| 24 |
+
def generate_speech(text, style_description=""):
|
| 25 |
+
"""Generate speech from text using MiMo-Audio TTS"""
|
| 26 |
+
if not text or not text.strip():
|
| 27 |
+
return None, "Please enter some text to convert to speech."
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
# Create temporary file for output
|
| 31 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
|
| 32 |
+
output_path = tmp_file.name
|
| 33 |
+
|
| 34 |
+
# Generate TTS
|
| 35 |
+
instruct = style_description if style_description.strip() else None
|
| 36 |
+
model.tts_sft(
|
| 37 |
+
text=text.strip(),
|
| 38 |
+
output_audio_path=output_path,
|
| 39 |
+
instruct=instruct
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
return output_path, "✅ Speech generated successfully!"
|
| 43 |
+
|
| 44 |
+
except Exception as e:
|
| 45 |
+
return None, f"❌ Error: {str(e)}"
|
| 46 |
+
|
| 47 |
+
# Create Gradio interface
|
| 48 |
+
with gr.Blocks(title="MiMo-Audio TTS") as demo:
|
| 49 |
+
gr.Markdown("""
|
| 50 |
+
# 🎵 MiMo-Audio Text-to-Speech
|
| 51 |
+
|
| 52 |
+
Convert text to natural-sounding speech using Xiaomi's MiMo-Audio model.
|
| 53 |
+
Optionally add a style description to control the voice characteristics.
|
| 54 |
+
""")
|
| 55 |
+
|
| 56 |
+
with gr.Row():
|
| 57 |
+
with gr.Column():
|
| 58 |
+
text_input = gr.Textbox(
|
| 59 |
+
label="Input Text",
|
| 60 |
+
placeholder="Enter text to convert to speech...",
|
| 61 |
+
lines=5
|
| 62 |
+
)
|
| 63 |
+
style_input = gr.Textbox(
|
| 64 |
+
label="Style Description (Optional)",
|
| 65 |
+
placeholder="e.g., 'a calm, gentle female voice' or 'an energetic male speaker'",
|
| 66 |
+
lines=2
|
| 67 |
+
)
|
| 68 |
+
generate_btn = gr.Button("Generate Speech", variant="primary")
|
| 69 |
+
|
| 70 |
+
with gr.Column():
|
| 71 |
+
audio_output = gr.Audio(
|
| 72 |
+
label="Generated Speech",
|
| 73 |
+
type="filepath"
|
| 74 |
+
)
|
| 75 |
+
status_output = gr.Textbox(
|
| 76 |
+
label="Status",
|
| 77 |
+
interactive=False
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
# Examples
|
| 81 |
+
gr.Examples(
|
| 82 |
+
examples=[
|
| 83 |
+
["Hello! This is MiMo-Audio text-to-speech synthesis.", ""],
|
| 84 |
+
["The quick brown fox jumps over the lazy dog.", "a clear, professional voice"],
|
| 85 |
+
["Welcome to the world of artificial intelligence and natural language processing.", "an enthusiastic, friendly tone"]
|
| 86 |
+
],
|
| 87 |
+
inputs=[text_input, style_input]
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
# Event handler
|
| 91 |
+
generate_btn.click(
|
| 92 |
+
fn=generate_speech,
|
| 93 |
+
inputs=[text_input, style_input],
|
| 94 |
+
outputs=[audio_output, status_output]
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
if __name__ == "__main__":
|
| 98 |
+
demo.launch()
|
inference_example_pretrain.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Xiaomi Corporation.
|
| 2 |
+
from src.mimo_audio.mimo_audio import MimoAudio
|
| 3 |
+
|
| 4 |
+
model_path = "models/MiMo-Audio-7B-Base"
|
| 5 |
+
tokenizer_path = "models/MiMo-Audio-Tokenizer"
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
model = MimoAudio(model_path, tokenizer_path)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# in context learning: speech-to-speech generation
|
| 12 |
+
instruction = "Convert the timbre of the input speech to target timbre."
|
| 13 |
+
|
| 14 |
+
input_audio = "examples/ESD/0013_000200.wav"
|
| 15 |
+
prompt_examples = [
|
| 16 |
+
{
|
| 17 |
+
"input_audio": "examples/ESD/0013_000139.wav",
|
| 18 |
+
"output_audio": "examples/ESD/0019_000139.wav",
|
| 19 |
+
"output_transcription": "Cuckoos is downheaded and crying.",
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
"input_audio": "examples/ESD/0013_000963.wav",
|
| 23 |
+
"output_audio": "examples/ESD/0019_000963.wav",
|
| 24 |
+
"output_transcription": "She said in subdued voice.",
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"input_audio": "examples/ESD/0013_000559.wav",
|
| 28 |
+
"output_audio": "examples/ESD/0019_000559.wav",
|
| 29 |
+
"output_transcription": "A raging fire was-in his eyes.",
|
| 30 |
+
},
|
| 31 |
+
{
|
| 32 |
+
"input_audio": "examples/ESD/0013_001142.wav",
|
| 33 |
+
"output_audio": "examples/ESD/0019_001142.wav",
|
| 34 |
+
"output_transcription": "Does the one that wins get the crowned?",
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"input_audio": "examples/ESD/0013_000769.wav",
|
| 38 |
+
"output_audio": "examples/ESD/0019_000769.wav",
|
| 39 |
+
"output_transcription": "Not much use is it, sam?",
|
| 40 |
+
},
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
output_audio_path = "examples/in_context_learning_s2s.wav"
|
| 44 |
+
text_channel_output = model.in_context_learning_s2s(instruction, prompt_examples, input_audio, max_new_tokens=8192, output_audio_path=output_audio_path)
|
inference_example_sft.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Xiaomi Corporation.
|
| 2 |
+
from src.mimo_audio.mimo_audio import MimoAudio
|
| 3 |
+
|
| 4 |
+
model_path = "models/MiMo-Audio-7B-Instruct"
|
| 5 |
+
tokenizer_path = "models/MiMo-Audio-Tokenizer"
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
model = MimoAudio(model_path, tokenizer_path)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# tts
|
| 12 |
+
text = "今天天气真好"
|
| 13 |
+
output_audio_path = "examples/tts.wav"
|
| 14 |
+
text_channel_output = model.tts_sft(text, output_audio_path)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# instruct tts
|
| 18 |
+
text = "今天天气真好"
|
| 19 |
+
output_audio_path = "examples/instruct_tts.wav"
|
| 20 |
+
instruct = "用小孩子的声音开心的说"
|
| 21 |
+
text_channel_output = model.tts_sft(text, output_audio_path, instruct=instruct)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# natural instruction tts
|
| 25 |
+
text = "用气喘吁吁的年轻男性声音说:我跑不动了,你等等我!"
|
| 26 |
+
output_audio_path = "examples/natural_instruction_tts.wav"
|
| 27 |
+
text_channel_output = model.tts_sft(text, output_audio_path, read_text_only=False)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# audio understanding
|
| 31 |
+
audio_path = "examples/spoken_dialogue_assistant_turn_1.wav"
|
| 32 |
+
text = "Summarize the audio."
|
| 33 |
+
text_channel_output = model.audio_understanding_sft(audio_path, text)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# audio understanding with thinking
|
| 37 |
+
audio_path = "examples/spoken_dialogue_assistant_turn_1.wav"
|
| 38 |
+
text = "Summarize the audio."
|
| 39 |
+
text_channel_output = model.audio_understanding_sft(audio_path, text, thinking=True)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# spoken dialogue
|
| 43 |
+
first_turn_text_response = "我没办法获取实时的天气信息。不过呢,你可以试试几个方法来查看今天的天气。首先,你可以用手机自带的天气功能,比如苹果手机的天气应用,或者直接在系统设置里查看。其次,你也可以用一些专业的天气服务,像是国外的AccuWeather、Weather.com,或者国内的中国天气网、墨迹天气等等。再有就是,你还可以在谷歌或者百度里直接搜索你所在的城市加上天气这两个字。如果你能告诉我你所在的城市,我也可以帮你分析一下历史天气趋势,不过最新的数据还是需要你通过官方渠道去获取哦。"
|
| 44 |
+
message_list = [
|
| 45 |
+
{"role": "user", "content": "examples/今天天气如何.mp3"},
|
| 46 |
+
{"role": "assistant", "content": {"text": first_turn_text_response, "audio": "examples/spoken_dialogue_assistant_turn_1.wav"}},
|
| 47 |
+
{"role": "user", "content": "examples/北京.mp3"},
|
| 48 |
+
]
|
| 49 |
+
output_audio_path = "examples/spoken_dialogue_assistant_turn_2.wav"
|
| 50 |
+
text_channel_output = model.spoken_dialogue_sft_multiturn(message_list, output_audio_path=output_audio_path, system_prompt=None, prompt_speech="examples/prompt_speech_zh_m.wav")
|
| 51 |
+
text_channel_output = text_channel_output.split("<|eot|>")[0].replace(".....", "")
|
| 52 |
+
print(text_channel_output)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# speech-to-text dialogue
|
| 56 |
+
message_list = [
|
| 57 |
+
{"role": "user", "content": "./examples/今天天气如何.mp3"},
|
| 58 |
+
{"role": "assistant", "content": "你好,我没办法获取实时的天气信息。如果你能告诉我你所在的城市,我也可以帮你分析一下历史天气趋势,不过最新的数据还是需要你通过官方渠道去获取哦。"},
|
| 59 |
+
{"role": "user", "content": "./examples/北京.mp3"},
|
| 60 |
+
]
|
| 61 |
+
text_channel_output = model.speech2text_dialogue_sft_multiturn(message_list, thinking=True)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# text dialogue
|
| 65 |
+
|
| 66 |
+
message_list = [
|
| 67 |
+
{"role": "user", "content": "可以给我介绍一些中国的旅游景点吗?"},
|
| 68 |
+
{"role": "assistant", "content": "你好,您想去哪个城市旅游呢?"},
|
| 69 |
+
{"role": "user", "content": "北京"},
|
| 70 |
+
]
|
| 71 |
+
text_channel_output = model.text_dialogue_sft_multiturn(message_list, thinking=True)
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate>=1.9.0
|
| 2 |
+
torch==2.6.0
|
| 3 |
+
torchaudio==2.6.0
|
| 4 |
+
transformers==4.49.0
|
| 5 |
+
librosa>=0.11.0
|
| 6 |
+
scipy>=1.14.0
|
| 7 |
+
gradio==5.46.1
|
| 8 |
+
flash-attn==2.7.4.post1
|
| 9 |
+
spaces
|
requirements_space.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate>=1.9.0
|
| 2 |
+
torch==2.6.0
|
| 3 |
+
torchaudio==2.6.0
|
| 4 |
+
transformers==4.49.0
|
| 5 |
+
librosa>=0.11.0
|
| 6 |
+
scipy>=1.14.0
|
| 7 |
+
gradio==5.46.1
|
| 8 |
+
flash-attn==2.7.4.post1
|
| 9 |
+
spaces
|
run_mimo_audio.py
ADDED
|
@@ -0,0 +1,764 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Xiaomi Corporation.
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import torch
|
| 4 |
+
import os
|
| 5 |
+
import tempfile
|
| 6 |
+
import argparse
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from src.mimo_audio.mimo_audio import MimoAudio
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TTSGenerator:
|
| 12 |
+
def __init__(self, model, device=None):
|
| 13 |
+
self.model = model
|
| 14 |
+
self.device = device
|
| 15 |
+
|
| 16 |
+
def generate(self, text, instruct, output_audio_path):
|
| 17 |
+
path = Path(output_audio_path)
|
| 18 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 19 |
+
text_output = self.model.tts_sft(text, output_audio_path, instruct)
|
| 20 |
+
return text_output
|
| 21 |
+
|
| 22 |
+
class AudioUnderstandingGenerator:
|
| 23 |
+
def __init__(self, model, device=None):
|
| 24 |
+
self.model = model
|
| 25 |
+
self.device = device
|
| 26 |
+
|
| 27 |
+
def generate(self, input_speech, input_text, thinking=False):
|
| 28 |
+
text = self.model.audio_understanding_sft(input_speech, input_text, thinking=thinking)
|
| 29 |
+
return text
|
| 30 |
+
|
| 31 |
+
class SpokenDialogueGenerator:
|
| 32 |
+
def __init__(self, model, device=None):
|
| 33 |
+
self.model = model
|
| 34 |
+
self.device = device
|
| 35 |
+
|
| 36 |
+
def generate(self, input_speech, output_audio_path, system_prompt="You are MiMo-Audio, a friendly AI assistant and your response needs to be concise.", prompt_speech=None, add_history=False):
|
| 37 |
+
|
| 38 |
+
path = Path(output_audio_path)
|
| 39 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 40 |
+
text_response = self.model.spoken_dialogue_sft(input_speech, output_audio_path, system_prompt=system_prompt, prompt_speech=prompt_speech, add_history=add_history)
|
| 41 |
+
return text_response
|
| 42 |
+
|
| 43 |
+
def clear_history(self):
|
| 44 |
+
self.model.clear_history()
|
| 45 |
+
|
| 46 |
+
class Speech2TextDialogueGenerator:
|
| 47 |
+
def __init__(self, model, device=None):
|
| 48 |
+
self.model = model
|
| 49 |
+
self.device = device
|
| 50 |
+
|
| 51 |
+
def generate(self, input_speech, thinking=False, add_history=False):
|
| 52 |
+
text = self.model.speech2text_dialogue_sft(input_speech, thinking=thinking, add_history=add_history)
|
| 53 |
+
return text
|
| 54 |
+
|
| 55 |
+
def clear_history(self):
|
| 56 |
+
self.model.clear_history()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class TextDialogueGenerator:
|
| 60 |
+
def __init__(self, model, device=None):
|
| 61 |
+
self.model = model
|
| 62 |
+
self.device = device
|
| 63 |
+
|
| 64 |
+
def generate(self, input_text, thinking=False, add_history=False):
|
| 65 |
+
text = self.model.text_dialogue_sft(input_text, thinking=thinking, add_history=add_history)
|
| 66 |
+
return text
|
| 67 |
+
|
| 68 |
+
def clear_history(self):
|
| 69 |
+
self.model.clear_history()
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class MultiModalSpeechInterface:
|
| 73 |
+
def __init__(self):
|
| 74 |
+
self.model = None
|
| 75 |
+
self.tts_generator = None
|
| 76 |
+
self.audio_understanding_generator = None
|
| 77 |
+
self.spoken_dialogue_generator = None
|
| 78 |
+
self.speech2text_dialogue_generator = None
|
| 79 |
+
self.text_dialogue_generator = None
|
| 80 |
+
|
| 81 |
+
self.device = None
|
| 82 |
+
self.model_initialized = False
|
| 83 |
+
|
| 84 |
+
def initialize_model(self, model_path=None, tokenizer_path=None):
|
| 85 |
+
|
| 86 |
+
try:
|
| 87 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 88 |
+
|
| 89 |
+
if model_path is None:
|
| 90 |
+
model_path = "./models/MiMo-Audio-7B-Instruct"
|
| 91 |
+
if tokenizer_path is None:
|
| 92 |
+
tokenizer_path = "./models/MiMo-Audio-Tokenizer"
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
print(f"Model path: {model_path}")
|
| 96 |
+
print(f"Tokenizer path: {tokenizer_path}")
|
| 97 |
+
|
| 98 |
+
self.model = MimoAudio(model_path, tokenizer_path)
|
| 99 |
+
self.tts_generator = TTSGenerator(self.model, self.device)
|
| 100 |
+
self.audio_understanding_generator = AudioUnderstandingGenerator(self.model, self.device)
|
| 101 |
+
self.spoken_dialogue_generator = SpokenDialogueGenerator(self.model, self.device)
|
| 102 |
+
self.speech2text_dialogue_generator = Speech2TextDialogueGenerator(self.model, self.device)
|
| 103 |
+
self.text_dialogue_generator = TextDialogueGenerator(self.model, self.device)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
self.model_initialized = True
|
| 107 |
+
print("Model loaded successfully!")
|
| 108 |
+
return "✅ Model loaded successfully!"
|
| 109 |
+
|
| 110 |
+
except Exception as e:
|
| 111 |
+
error_msg = f"❌ Model loading failed: {str(e)}"
|
| 112 |
+
print(error_msg)
|
| 113 |
+
return error_msg
|
| 114 |
+
|
| 115 |
+
def generate_tts_audio(self, input_text, instruct="", use_instruct=False):
|
| 116 |
+
if not self.model_initialized:
|
| 117 |
+
return None, "❌ Error: Model not initialized, please load the model first"
|
| 118 |
+
|
| 119 |
+
if not input_text.strip():
|
| 120 |
+
return None, "❌ Error: Please input text content"
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
try:
|
| 124 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
|
| 125 |
+
output_path = tmp_file.name
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
if not (use_instruct and instruct.strip()):
|
| 129 |
+
instruct = None
|
| 130 |
+
|
| 131 |
+
print(f"Generating TTS audio: {input_text}")
|
| 132 |
+
|
| 133 |
+
text_channel = self.tts_generator.generate(input_text, instruct, output_path)
|
| 134 |
+
status_msg = f"✅ TTS audio generated successfully!\n📝 Input text: {input_text}"
|
| 135 |
+
if use_instruct and instruct is not None and instruct.strip():
|
| 136 |
+
status_msg += f"\n🎭 Style description: {instruct}"
|
| 137 |
+
status_msg += f"\n🎵 Output text channel: {text_channel}"
|
| 138 |
+
|
| 139 |
+
return output_path, status_msg, gr.update(value=output_path, visible=True)
|
| 140 |
+
|
| 141 |
+
except Exception as e:
|
| 142 |
+
error_msg = f"❌ Error generating TTS audio: {str(e)}"
|
| 143 |
+
print(error_msg)
|
| 144 |
+
return None, error_msg, gr.update(visible=False)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def generate_audio_understanding_response(self, input_audio, input_text, thinking=False):
|
| 148 |
+
if not self.model_initialized:
|
| 149 |
+
return "", "❌ Error: Model not initialized, please load the model first"
|
| 150 |
+
|
| 151 |
+
if input_audio is None and not input_text.strip():
|
| 152 |
+
return "", "❌ Error: Please provide either audio input or text question"
|
| 153 |
+
|
| 154 |
+
if input_audio is None:
|
| 155 |
+
return "", "❌ Error: Please upload an audio file for Audio Understanding task"
|
| 156 |
+
|
| 157 |
+
if not input_text.strip():
|
| 158 |
+
return "", "❌ Error: Please input your question"
|
| 159 |
+
|
| 160 |
+
try:
|
| 161 |
+
print(f"Performing Audio Understanding task:")
|
| 162 |
+
print(f"Audio input: {input_audio}")
|
| 163 |
+
print(f"Text question: {input_text}")
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
audio_understanding_response = self.audio_understanding_generator.generate(input_audio, input_text.strip(), thinking=thinking)
|
| 167 |
+
|
| 168 |
+
status_msg = f"✅ Audio Understanding task completed successfully!\n🎵 Audio input: {os.path.basename(input_audio)}\n❓ Question: {input_text}\n💬 Answer: {audio_understanding_response}"
|
| 169 |
+
|
| 170 |
+
return audio_understanding_response, status_msg
|
| 171 |
+
|
| 172 |
+
except Exception as e:
|
| 173 |
+
error_msg = f"❌ Error performing Audio Understanding task: {str(e)}"
|
| 174 |
+
print(error_msg)
|
| 175 |
+
return "", error_msg
|
| 176 |
+
|
| 177 |
+
def generate_spoken_dialogue_response(self, input_audio, system_prompt=None, prompt_speech=None, add_history=False):
|
| 178 |
+
if not self.model_initialized:
|
| 179 |
+
return "", "❌ Error: Model not initialized, please load the model first"
|
| 180 |
+
|
| 181 |
+
if input_audio is None:
|
| 182 |
+
return "", "❌ Error: Please upload an audio file"
|
| 183 |
+
|
| 184 |
+
try:
|
| 185 |
+
|
| 186 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
|
| 187 |
+
output_audio_path = tmp_file.name
|
| 188 |
+
|
| 189 |
+
print(f"Performing spoken dialogue task:")
|
| 190 |
+
print(f"Audio input: {input_audio}")
|
| 191 |
+
print(f"Output path: {output_audio_path}")
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
dialogue_response = self.spoken_dialogue_generator.generate(input_audio, output_audio_path, system_prompt=system_prompt, prompt_speech=prompt_speech, add_history=add_history)
|
| 195 |
+
|
| 196 |
+
status_msg = f"✅ Spoken dialogue task completed successfully!\n🎵 Audio input: {os.path.basename(input_audio)}\n💬 Response: {dialogue_response[:300]}..."
|
| 197 |
+
|
| 198 |
+
return output_audio_path, dialogue_response, status_msg
|
| 199 |
+
|
| 200 |
+
except Exception as e:
|
| 201 |
+
error_msg = f"❌ Error performing spoken dialogue task: {str(e)}"
|
| 202 |
+
print(error_msg)
|
| 203 |
+
return None, None, error_msg
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def generate_speech2text_dialogue_response(self, input_audio, thinking=False, add_history=False):
|
| 207 |
+
if not self.model_initialized:
|
| 208 |
+
return "", "❌ Error: Model not initialized, please load the model first"
|
| 209 |
+
|
| 210 |
+
if input_audio is None:
|
| 211 |
+
return "", "❌ Error: Please upload an audio file for S2T Dialogue task"
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
try:
|
| 215 |
+
print(f"Performing S2T Dialogue task:")
|
| 216 |
+
print(f"Audio input: {input_audio}")
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
s2t_response = self.speech2text_dialogue_generator.generate(input_audio, thinking=thinking, add_history=add_history)
|
| 220 |
+
|
| 221 |
+
status_msg = f"✅ S2T dialogue task completed successfully!\n🎵 Audio input: {os.path.basename(input_audio)}\n❓💬 Answer: {s2t_response}"
|
| 222 |
+
|
| 223 |
+
return s2t_response, status_msg
|
| 224 |
+
|
| 225 |
+
except Exception as e:
|
| 226 |
+
error_msg = f"❌ Error performing QA task: {str(e)}"
|
| 227 |
+
print(error_msg)
|
| 228 |
+
return "", error_msg
|
| 229 |
+
|
| 230 |
+
def generate_text_dialogue_response(self, input_text, thinking=False, add_history=False):
|
| 231 |
+
if not self.model_initialized:
|
| 232 |
+
return "", "❌ Error: Model not initialized, please load the model first"
|
| 233 |
+
|
| 234 |
+
if not input_text or not input_text.strip():
|
| 235 |
+
return "", "❌ Error: Please input your text"
|
| 236 |
+
|
| 237 |
+
try:
|
| 238 |
+
print(f"Performing Text Dialogue task:")
|
| 239 |
+
print(f"Text input: {input_text}")
|
| 240 |
+
print(f"Thinking mode: {thinking}")
|
| 241 |
+
print(f"Add history: {add_history}")
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
t2t_response = self.text_dialogue_generator.generate(input_text.strip(), thinking=thinking, add_history=add_history)
|
| 245 |
+
|
| 246 |
+
status_msg = f"✅ T2T dialogue task completed successfully!\n💬 Input: {input_text}"
|
| 247 |
+
if thinking:
|
| 248 |
+
status_msg += f"\n🧠 Thinking mode: Enabled"
|
| 249 |
+
status_msg += f"\n💬 Answer: {t2t_response}"
|
| 250 |
+
|
| 251 |
+
return t2t_response, status_msg
|
| 252 |
+
|
| 253 |
+
except Exception as e:
|
| 254 |
+
error_msg = f"❌ Error performing T2T dialogue task: {str(e)}"
|
| 255 |
+
print(error_msg)
|
| 256 |
+
return "", error_msg
|
| 257 |
+
|
| 258 |
+
def clear_spoken_dialogue_history(self):
|
| 259 |
+
if not self.model_initialized:
|
| 260 |
+
return None, "", "❌ Error: Model not initialized, please load the model first"
|
| 261 |
+
|
| 262 |
+
try:
|
| 263 |
+
self.spoken_dialogue_generator.clear_history()
|
| 264 |
+
return None, "", "✅ Spoken dialogue history cleared successfully!"
|
| 265 |
+
except Exception as e:
|
| 266 |
+
error_msg = f"❌ Error clearing spoken dialogue history: {str(e)}"
|
| 267 |
+
print(error_msg)
|
| 268 |
+
return None, "", error_msg
|
| 269 |
+
|
| 270 |
+
def clear_speech2text_dialogue_history(self):
|
| 271 |
+
if not self.model_initialized:
|
| 272 |
+
return "", "❌ Error: Model not initialized, please load the model first"
|
| 273 |
+
|
| 274 |
+
try:
|
| 275 |
+
self.speech2text_dialogue_generator.clear_history()
|
| 276 |
+
return "", "✅ Speech-to-text dialogue history cleared successfully!"
|
| 277 |
+
except Exception as e:
|
| 278 |
+
error_msg = f"❌ Error clearing S2T dialogue history: {str(e)}"
|
| 279 |
+
print(error_msg)
|
| 280 |
+
return "", error_msg
|
| 281 |
+
|
| 282 |
+
def clear_text_dialogue_history(self):
|
| 283 |
+
if not self.model_initialized:
|
| 284 |
+
return "", "❌ Error: Model not initialized, please load the model first"
|
| 285 |
+
|
| 286 |
+
try:
|
| 287 |
+
self.text_dialogue_generator.clear_history()
|
| 288 |
+
return "", "✅ Text dialogue history cleared successfully!"
|
| 289 |
+
except Exception as e:
|
| 290 |
+
error_msg = f"❌ Error clearing T2T dialogue history: {str(e)}"
|
| 291 |
+
print(error_msg)
|
| 292 |
+
return "", error_msg
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def create_interface(self):
|
| 297 |
+
|
| 298 |
+
with gr.Blocks(title="MiMo-Audio Multimodal Speech Processing System", theme=gr.themes.Soft()) as iface:
|
| 299 |
+
gr.Markdown("# 🎵 MiMo-Audio Multimodal Speech Processing System")
|
| 300 |
+
gr.Markdown("Supports audio understanding, text-to-speech, spoken dialogue, speech-to-text dialogue and text-to-text dialogue")
|
| 301 |
+
|
| 302 |
+
with gr.Tabs():
|
| 303 |
+
|
| 304 |
+
with gr.TabItem("⚙️ Model Configuration"):
|
| 305 |
+
gr.Markdown("### 📋 Model initialization configuration")
|
| 306 |
+
|
| 307 |
+
with gr.Row():
|
| 308 |
+
with gr.Column():
|
| 309 |
+
|
| 310 |
+
model_path = gr.Textbox(
|
| 311 |
+
label="Model path",
|
| 312 |
+
placeholder="Leave blank to use default path: ./models/MiMo-Audio-7B-Instruct",
|
| 313 |
+
lines=3
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
tokenizer_path = gr.Textbox(
|
| 317 |
+
label="Tokenizer path",
|
| 318 |
+
placeholder="Leave blank to use default path: ./models/MiMo-Audio-Tokenizer",
|
| 319 |
+
lines=3
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
init_btn = gr.Button("🔄 Initialize model", variant="primary", size="lg")
|
| 323 |
+
|
| 324 |
+
with gr.Column():
|
| 325 |
+
init_status = gr.Textbox(
|
| 326 |
+
label="Initialization status",
|
| 327 |
+
interactive=False,
|
| 328 |
+
lines=6,
|
| 329 |
+
placeholder="Click the initialize model button to start..."
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
gr.Markdown("### 💻 System information")
|
| 334 |
+
device_info = gr.Textbox(
|
| 335 |
+
label="Device information",
|
| 336 |
+
value=f"GPU available: {'Yes' if torch.cuda.is_available() else 'No'}",
|
| 337 |
+
interactive=False
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
with gr.TabItem("🔊 Audio Understanding"):
|
| 342 |
+
gr.Markdown("### 🎯 Audio Understanding")
|
| 343 |
+
|
| 344 |
+
with gr.Row():
|
| 345 |
+
with gr.Column():
|
| 346 |
+
audio_understanding_input_audio = gr.Audio(
|
| 347 |
+
label="Upload Audio File",
|
| 348 |
+
type="filepath",
|
| 349 |
+
interactive=True,
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
audio_understanding_input_text = gr.Textbox(
|
| 353 |
+
label="Input Question",
|
| 354 |
+
placeholder="Please input your question...",
|
| 355 |
+
lines=3,
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
audio_understanding_thinking = gr.Checkbox(
|
| 359 |
+
label="Enable Thinking Mode",
|
| 360 |
+
value=False,
|
| 361 |
+
info="Enable thinking mode, AI will perform a deeper analysis and thinking"
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
audio_understanding_generate_btn = gr.Button("🤖 Start Audio Understanding", variant="primary", size="lg")
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
with gr.Column():
|
| 369 |
+
audio_understanding_output_text = gr.Textbox(
|
| 370 |
+
label="Answer Result",
|
| 371 |
+
lines=8,
|
| 372 |
+
interactive=False,
|
| 373 |
+
placeholder="AI's answer will be displayed here...",
|
| 374 |
+
elem_id="audio_understanding_output_text"
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
audio_understanding_status = gr.Textbox(
|
| 378 |
+
label="Processing Status",
|
| 379 |
+
lines=6,
|
| 380 |
+
interactive=False,
|
| 381 |
+
placeholder="Processing status information will be displayed here..."
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
with gr.Row():
|
| 385 |
+
audio_understanding_copy_btn = gr.Button("📋 Copy Answer", size="sm")
|
| 386 |
+
audio_understanding_clear_btn = gr.Button("🗑️ Clear Result", size="sm")
|
| 387 |
+
|
| 388 |
+
gr.Markdown("### 🌟 Audio Understanding Examples")
|
| 389 |
+
audio_understanding_examples = gr.Examples(
|
| 390 |
+
examples=[
|
| 391 |
+
[None, "这段音频的主要内容是什么?"],
|
| 392 |
+
[None, "说话者的情感状态如何?"],
|
| 393 |
+
[None, "音频中提到了哪些关键信息?"],
|
| 394 |
+
[None, "Please summarize the main points of this conversation."],
|
| 395 |
+
[None, "What viewpoint does the speaker want to express?"]
|
| 396 |
+
],
|
| 397 |
+
inputs=[audio_understanding_input_audio, audio_understanding_input_text],
|
| 398 |
+
label="Click the example to automatically fill the question"
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
with gr.TabItem("🎵 Text-to-Speech"):
|
| 405 |
+
gr.Markdown("### 🎵 Text-to-Speech")
|
| 406 |
+
|
| 407 |
+
with gr.Row():
|
| 408 |
+
with gr.Column():
|
| 409 |
+
|
| 410 |
+
tts_input_text = gr.Textbox(
|
| 411 |
+
label="Input Text",
|
| 412 |
+
placeholder="Please input the text you want to convert to speech...",
|
| 413 |
+
lines=4,
|
| 414 |
+
max_lines=8
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
tts_instruct = gr.Textbox(
|
| 418 |
+
label="Style Description (Optional)",
|
| 419 |
+
placeholder="Please input the style description (optional)...",
|
| 420 |
+
lines=3,
|
| 421 |
+
max_lines=5
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
tts_use_instruct = gr.Checkbox(
|
| 425 |
+
label="Use Style Description",
|
| 426 |
+
value=True,
|
| 427 |
+
info="Enable to use InstructTTS for style-controlled speech generation"
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
tts_generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
|
| 431 |
+
|
| 432 |
+
with gr.Column():
|
| 433 |
+
|
| 434 |
+
tts_output_audio = gr.Audio(
|
| 435 |
+
label="Generated Speech",
|
| 436 |
+
type="filepath"
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
tts_status = gr.Textbox(
|
| 440 |
+
label="Generation Status",
|
| 441 |
+
lines=6,
|
| 442 |
+
interactive=False
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
tts_download_btn = gr.DownloadButton(
|
| 447 |
+
label="Download Generated Audio",
|
| 448 |
+
visible=False
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
with gr.TabItem("🎤 Spoken Dialogue"):
|
| 455 |
+
gr.Markdown("### 🎯 Spoken Dialogue")
|
| 456 |
+
|
| 457 |
+
with gr.Row():
|
| 458 |
+
with gr.Column():
|
| 459 |
+
|
| 460 |
+
dialogue_input_audio = gr.Audio(
|
| 461 |
+
label="Upload User Speech",
|
| 462 |
+
type="filepath",
|
| 463 |
+
interactive=True
|
| 464 |
+
)
|
| 465 |
+
system_prompt = gr.Textbox(
|
| 466 |
+
label="System Prompt (Optional)",
|
| 467 |
+
placeholder="e.g.: You are MiMo-Audio, a friendly AI assistant and your response needs to be concise.",
|
| 468 |
+
lines=1
|
| 469 |
+
)
|
| 470 |
+
prompt_speech = gr.Audio(
|
| 471 |
+
label="Prompt Speech (Optional, MiMo-Audio speaks with the same timbre as your prompt.)",
|
| 472 |
+
type="filepath",
|
| 473 |
+
interactive=True
|
| 474 |
+
)
|
| 475 |
+
spoken_dialogue_add_history = gr.Checkbox(
|
| 476 |
+
label="Enable History Record",
|
| 477 |
+
value=True,
|
| 478 |
+
info="Enable to remember the previous dialogue context"
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
with gr.Row():
|
| 482 |
+
dialogue_generate_btn = gr.Button("💬 Start Dialogue", variant="primary", size="lg")
|
| 483 |
+
|
| 484 |
+
with gr.Row():
|
| 485 |
+
dialogue_clear_history_btn = gr.Button("🗑️ Clear Dialogue History", size="sm", variant="secondary")
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
with gr.Column():
|
| 492 |
+
|
| 493 |
+
dialogue_output_audio = gr.Audio(
|
| 494 |
+
label="Output Audio",
|
| 495 |
+
type="filepath"
|
| 496 |
+
)
|
| 497 |
+
dialogue_output_text = gr.Textbox(
|
| 498 |
+
label="Dialogue Response",
|
| 499 |
+
lines=5,
|
| 500 |
+
interactive=False,
|
| 501 |
+
)
|
| 502 |
+
dialogue_status = gr.Textbox(
|
| 503 |
+
label="Dialogue Status",
|
| 504 |
+
lines=5,
|
| 505 |
+
interactive=False,
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
with gr.TabItem("💬 S2T Dialogue"):
|
| 513 |
+
gr.Markdown("### 🎯 S2T Dialogue")
|
| 514 |
+
|
| 515 |
+
with gr.Row():
|
| 516 |
+
with gr.Column():
|
| 517 |
+
|
| 518 |
+
s2t_dialogue_input_audio = gr.Audio(
|
| 519 |
+
label="Upload User Speech",
|
| 520 |
+
type="filepath",
|
| 521 |
+
interactive=True
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
s2t_dialogue_add_history = gr.Checkbox(
|
| 526 |
+
label="Enable History Record",
|
| 527 |
+
value=True,
|
| 528 |
+
info="Enable to remember the previous dialogue context"
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
s2t_dialogue_thinking = gr.Checkbox(
|
| 532 |
+
label="Enable Thinking Mode (think mode)",
|
| 533 |
+
value=False,
|
| 534 |
+
info="Enable to perform a deeper analysis and reasoning"
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
with gr.Row():
|
| 538 |
+
s2t_dialogue_generate_btn = gr.Button("🎧 Start S2T Dialogue", variant="primary", size="lg")
|
| 539 |
+
|
| 540 |
+
with gr.Row():
|
| 541 |
+
s2t_dialogue_clear_history_btn = gr.Button("🗑️ Clear Dialogue History", size="sm", variant="secondary")
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
with gr.Column():
|
| 545 |
+
|
| 546 |
+
s2t_dialogue_output_text = gr.Textbox(
|
| 547 |
+
label="Dialogue Response",
|
| 548 |
+
lines=8,
|
| 549 |
+
interactive=False,
|
| 550 |
+
placeholder="AI's dialogue response will be displayed here..."
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
s2t_dialogue_status = gr.Textbox(
|
| 554 |
+
label="Dialogue Status",
|
| 555 |
+
lines=5,
|
| 556 |
+
interactive=False,
|
| 557 |
+
placeholder="Dialogue status information will be displayed here..."
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
with gr.TabItem("📝 T2T Dialogue"):
|
| 563 |
+
gr.Markdown("### 🎯 T2T Dialogue")
|
| 564 |
+
|
| 565 |
+
with gr.Row():
|
| 566 |
+
with gr.Column():
|
| 567 |
+
|
| 568 |
+
t2t_dialogue_input_text = gr.Textbox(
|
| 569 |
+
label="Input Dialogue Content",
|
| 570 |
+
placeholder="Please input the text content you want to dialogue...",
|
| 571 |
+
lines=4,
|
| 572 |
+
max_lines=8
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
t2t_dialogue_add_history = gr.Checkbox(
|
| 576 |
+
label="Enable History Record",
|
| 577 |
+
value=True,
|
| 578 |
+
info="Enable to remember the previous dialogue context"
|
| 579 |
+
)
|
| 580 |
+
|
| 581 |
+
t2t_dialogue_thinking = gr.Checkbox(
|
| 582 |
+
label="Enable Thinking Mode (Thinking)",
|
| 583 |
+
value=False,
|
| 584 |
+
info="Enable thinking mode, AI will perform a deeper analysis and thinking"
|
| 585 |
+
)
|
| 586 |
+
|
| 587 |
+
with gr.Row():
|
| 588 |
+
t2t_dialogue_generate_btn = gr.Button("💬 Start T2T Dialogue", variant="primary", size="lg")
|
| 589 |
+
|
| 590 |
+
with gr.Row():
|
| 591 |
+
t2t_dialogue_clear_history_btn = gr.Button("🗑️ Clear Dialogue History", size="sm", variant="secondary")
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
with gr.Column():
|
| 596 |
+
t2t_dialogue_output_text = gr.Textbox(
|
| 597 |
+
label="Dialogue Response",
|
| 598 |
+
lines=8,
|
| 599 |
+
interactive=False,
|
| 600 |
+
placeholder="AI's dialogue response will be displayed here..."
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
t2t_dialogue_status = gr.Textbox(
|
| 604 |
+
label="Dialogue Status",
|
| 605 |
+
lines=5,
|
| 606 |
+
interactive=False,
|
| 607 |
+
placeholder="Dialogue status information will be displayed here..."
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
gr.Markdown("### 🌟 T2T Dialogue Examples")
|
| 611 |
+
t2t_dialogue_examples = gr.Examples(
|
| 612 |
+
examples=[
|
| 613 |
+
["Hello, how are you?"],
|
| 614 |
+
["I want to know the history of the development of artificial intelligence"],
|
| 615 |
+
["Please recommend some good movies"],
|
| 616 |
+
["Can you help me explain the basic concepts of quantum physics?"],
|
| 617 |
+
["I'm learning programming recently, any suggestions?"]
|
| 618 |
+
],
|
| 619 |
+
inputs=[t2t_dialogue_input_text],
|
| 620 |
+
label="Click the example to automatically fill the dialogue content"
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
def copy_text_to_clipboard(text):
|
| 626 |
+
return text
|
| 627 |
+
|
| 628 |
+
def clear_audio_understanding_results():
|
| 629 |
+
return "", "🗑️ Audio Understanding Result Cleared"
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
init_btn.click(
|
| 633 |
+
fn=lambda path, tok_path: self.initialize_model(path or None, tok_path or None),
|
| 634 |
+
inputs=[model_path, tokenizer_path],
|
| 635 |
+
outputs=[init_status]
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
audio_understanding_generate_btn.click(
|
| 640 |
+
fn=self.generate_audio_understanding_response,
|
| 641 |
+
inputs=[audio_understanding_input_audio, audio_understanding_input_text, audio_understanding_thinking],
|
| 642 |
+
outputs=[audio_understanding_output_text, audio_understanding_status]
|
| 643 |
+
)
|
| 644 |
+
|
| 645 |
+
audio_understanding_copy_btn.click(
|
| 646 |
+
fn=None,
|
| 647 |
+
inputs=[audio_understanding_output_text],
|
| 648 |
+
js="(text) => {navigator.clipboard.writeText(text); alert('Copied to clipboard!')}"
|
| 649 |
+
)
|
| 650 |
+
|
| 651 |
+
tts_generate_btn.click(
|
| 652 |
+
fn=self.generate_tts_audio,
|
| 653 |
+
inputs=[tts_input_text, tts_instruct, tts_use_instruct],
|
| 654 |
+
outputs=[tts_output_audio, tts_status, tts_download_btn]
|
| 655 |
+
)
|
| 656 |
+
|
| 657 |
+
dialogue_generate_btn.click(
|
| 658 |
+
fn=self.generate_spoken_dialogue_response,
|
| 659 |
+
inputs=[dialogue_input_audio, system_prompt, prompt_speech, spoken_dialogue_add_history],
|
| 660 |
+
outputs=[dialogue_output_audio, dialogue_output_text, dialogue_status]
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
dialogue_clear_history_btn.click(
|
| 666 |
+
fn=self.clear_spoken_dialogue_history,
|
| 667 |
+
outputs=[dialogue_output_audio, dialogue_output_text, dialogue_status]
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
s2t_dialogue_generate_btn.click(
|
| 672 |
+
fn=self.generate_speech2text_dialogue_response,
|
| 673 |
+
inputs=[s2t_dialogue_input_audio, s2t_dialogue_thinking, s2t_dialogue_add_history],
|
| 674 |
+
outputs=[s2t_dialogue_output_text, s2t_dialogue_status]
|
| 675 |
+
)
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
s2t_dialogue_clear_history_btn.click(
|
| 680 |
+
fn=self.clear_speech2text_dialogue_history,
|
| 681 |
+
outputs=[s2t_dialogue_output_text, s2t_dialogue_status]
|
| 682 |
+
)
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
t2t_dialogue_generate_btn.click(
|
| 686 |
+
fn=self.generate_text_dialogue_response,
|
| 687 |
+
inputs=[t2t_dialogue_input_text, t2t_dialogue_thinking, t2t_dialogue_add_history],
|
| 688 |
+
outputs=[t2t_dialogue_output_text, t2t_dialogue_status]
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
|
| 692 |
+
t2t_dialogue_clear_history_btn.click(
|
| 693 |
+
fn=self.clear_text_dialogue_history,
|
| 694 |
+
outputs=[t2t_dialogue_output_text, t2t_dialogue_status]
|
| 695 |
+
)
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
|
| 700 |
+
audio_understanding_clear_btn.click(
|
| 701 |
+
fn=clear_audio_understanding_results,
|
| 702 |
+
outputs=[audio_understanding_output_text, audio_understanding_status]
|
| 703 |
+
)
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
|
| 710 |
+
tts_input_text.submit(
|
| 711 |
+
fn=self.generate_tts_audio,
|
| 712 |
+
inputs=[tts_input_text, tts_instruct, tts_use_instruct],
|
| 713 |
+
outputs=[tts_output_audio, tts_status, tts_download_btn]
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
audio_understanding_input_text.submit(
|
| 718 |
+
fn=self.generate_audio_understanding_response,
|
| 719 |
+
inputs=[audio_understanding_input_audio, audio_understanding_input_text, audio_understanding_thinking],
|
| 720 |
+
outputs=[audio_understanding_output_text, audio_understanding_status]
|
| 721 |
+
)
|
| 722 |
+
|
| 723 |
+
t2t_dialogue_input_text.submit(
|
| 724 |
+
fn=self.generate_text_dialogue_response,
|
| 725 |
+
inputs=[t2t_dialogue_input_text, t2t_dialogue_thinking, t2t_dialogue_add_history],
|
| 726 |
+
outputs=[t2t_dialogue_output_text, t2t_dialogue_status]
|
| 727 |
+
)
|
| 728 |
+
|
| 729 |
+
|
| 730 |
+
return iface
|
| 731 |
+
|
| 732 |
+
def main():
|
| 733 |
+
parser = argparse.ArgumentParser(description="MiMo-Audio")
|
| 734 |
+
parser.add_argument("--host", default="0.0.0.0", help="Server Address")
|
| 735 |
+
parser.add_argument("--port", type=int, default=7897, help="Port")
|
| 736 |
+
parser.add_argument("--share", action="store_true", help="Create Public Link")
|
| 737 |
+
parser.add_argument("--debug", action="store_true", help="Debug Mode")
|
| 738 |
+
|
| 739 |
+
args = parser.parse_args()
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
print("🚀 Launch MiMo-Audio...")
|
| 744 |
+
|
| 745 |
+
|
| 746 |
+
speech_interface = MultiModalSpeechInterface()
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
print("🎨 Create Gradio Interface...")
|
| 751 |
+
iface = speech_interface.create_interface()
|
| 752 |
+
|
| 753 |
+
|
| 754 |
+
print(f"🌐 Launch Service - Address: {args.host}:{args.port}")
|
| 755 |
+
|
| 756 |
+
iface.launch(
|
| 757 |
+
server_name=args.host,
|
| 758 |
+
server_port=args.port,
|
| 759 |
+
share=args.share,
|
| 760 |
+
debug=args.debug
|
| 761 |
+
)
|
| 762 |
+
|
| 763 |
+
if __name__ == "__main__":
|
| 764 |
+
main()
|
src/mimo_audio/mimo_audio.py
ADDED
|
@@ -0,0 +1,1292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Xiaomi Corporation.
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import time
|
| 5 |
+
import random
|
| 6 |
+
import torch
|
| 7 |
+
import torchaudio
|
| 8 |
+
import soundfile as sf
|
| 9 |
+
|
| 10 |
+
from typing import Union
|
| 11 |
+
from torchaudio.transforms import MelSpectrogram
|
| 12 |
+
from transformers import (
|
| 13 |
+
AutoTokenizer,
|
| 14 |
+
GenerationConfig
|
| 15 |
+
)
|
| 16 |
+
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
| 17 |
+
|
| 18 |
+
from .process_speechdata import InputSegment, StreamingInputSegment
|
| 19 |
+
from ..mimo_audio_tokenizer import MiMoAudioTokenizer
|
| 20 |
+
from .templates import asr_en_templates, asr_zh_templates, tts_en_templates, tts_zh_templates
|
| 21 |
+
from .modeling_mimo_audio import (
|
| 22 |
+
MiMoAudioArguments,
|
| 23 |
+
MiMoAudioForCausalLM,
|
| 24 |
+
MiMoSampler,
|
| 25 |
+
MiMoStopper,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def detect_language(text):
|
| 30 |
+
if re.search(r'[\u4e00-\u9fff]', text):
|
| 31 |
+
return 'zh'
|
| 32 |
+
else:
|
| 33 |
+
return 'en'
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class MimoAudio:
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
model_path: str,
|
| 41 |
+
mimo_audio_tokenizer_path: str,
|
| 42 |
+
device: str | None = None,
|
| 43 |
+
) -> None:
|
| 44 |
+
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 45 |
+
|
| 46 |
+
self.path = model_path
|
| 47 |
+
self.mimo_audio_tokenizer_path = mimo_audio_tokenizer_path
|
| 48 |
+
|
| 49 |
+
self.tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(
|
| 50 |
+
self.path
|
| 51 |
+
)
|
| 52 |
+
self.padding_idx = int(self.tokenizer.pad_token_id)
|
| 53 |
+
|
| 54 |
+
special_tokens = [
|
| 55 |
+
"<|sosp|>",
|
| 56 |
+
"<|eosp|>",
|
| 57 |
+
"<|empty|>",
|
| 58 |
+
"<|Human|>",
|
| 59 |
+
"<|SpeechLM|>",
|
| 60 |
+
"<|sostm|>",
|
| 61 |
+
"<|eostm|>",
|
| 62 |
+
"<|eot|>",
|
| 63 |
+
]
|
| 64 |
+
for token in special_tokens:
|
| 65 |
+
if token not in self.tokenizer.get_vocab():
|
| 66 |
+
print(f"Add special tokens {token} to tokenizer.vocab")
|
| 67 |
+
self.tokenizer.add_tokens([token], special_tokens=True)
|
| 68 |
+
|
| 69 |
+
self.sosp_idx = self.tokenizer.convert_tokens_to_ids("<|sosp|>")
|
| 70 |
+
self.eosp_idx = self.tokenizer.convert_tokens_to_ids("<|eosp|>")
|
| 71 |
+
self.empty_token = self.tokenizer.convert_tokens_to_ids("<|empty|>")
|
| 72 |
+
self.sostm_idx = self.tokenizer.convert_tokens_to_ids("<|sostm|>")
|
| 73 |
+
self.eostm_idx = self.tokenizer.convert_tokens_to_ids("<|eostm|>")
|
| 74 |
+
self.eot_idx = self.tokenizer.convert_tokens_to_ids("<|eot|>")
|
| 75 |
+
self.im_start_idx = self.tokenizer.convert_tokens_to_ids("<|im_start|>")
|
| 76 |
+
self.im_end_idx = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
|
| 77 |
+
|
| 78 |
+
model_args = MiMoAudioArguments(
|
| 79 |
+
model_name_or_path=self.path,
|
| 80 |
+
sosp_idx=self.sosp_idx,
|
| 81 |
+
eosp_idx=self.eosp_idx,
|
| 82 |
+
empty_idx=self.empty_token,
|
| 83 |
+
sostm_idx=self.sostm_idx,
|
| 84 |
+
eostm_idx=self.eostm_idx,
|
| 85 |
+
eot_idx=self.eot_idx,
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
start_loading_time = time.monotonic()
|
| 89 |
+
self.model = MiMoAudioForCausalLM.from_pretrained(
|
| 90 |
+
self.path,
|
| 91 |
+
args=model_args,
|
| 92 |
+
torch_dtype=torch.bfloat16,
|
| 93 |
+
device_map={"": self.device},
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
self.group_size=self.model.config.group_size
|
| 97 |
+
self.audio_channels=self.model.config.audio_channels
|
| 98 |
+
self.delay_pattern = self.model.config.delay_pattern
|
| 99 |
+
self.vocab_size = self.model.config.vocab_size
|
| 100 |
+
|
| 101 |
+
self.speech_zeroemb_idx = self.model.speech_empty_ids
|
| 102 |
+
|
| 103 |
+
self.model.eval()
|
| 104 |
+
print(
|
| 105 |
+
f"Model loaded in {time.monotonic() - start_loading_time:.2f} seconds, device: {self.device}"
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
self.generate_kwargs = {
|
| 109 |
+
"max_length": 8192,
|
| 110 |
+
"eos_token_id": self.tokenizer.eos_token_id,
|
| 111 |
+
"pad_token_id": self.tokenizer.pad_token_id,
|
| 112 |
+
}
|
| 113 |
+
self.default_global_sampler = MiMoSampler(
|
| 114 |
+
do_sample=True, temperature=0.6, top_k=50, top_p=0.95
|
| 115 |
+
)
|
| 116 |
+
self.default_local_sampler = MiMoSampler(
|
| 117 |
+
do_sample=True, temperature=0.9, top_k=50, top_p=0.95
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
self.task_sampler_configs = {
|
| 121 |
+
"asr": {
|
| 122 |
+
"global": MiMoSampler(do_sample=False, temperature=1.0, top_p=1.0),
|
| 123 |
+
"local": MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95)
|
| 124 |
+
},
|
| 125 |
+
"tts": {
|
| 126 |
+
"global": MiMoSampler(do_sample=True, temperature=0.6, top_p=1.0),
|
| 127 |
+
"local": MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95)
|
| 128 |
+
},
|
| 129 |
+
"spoken_dialogue": {
|
| 130 |
+
"global": MiMoSampler(do_sample=True, temperature=0.6, top_p=0.95),
|
| 131 |
+
"local": MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95)
|
| 132 |
+
},
|
| 133 |
+
"audio_understanding": {
|
| 134 |
+
"global": MiMoSampler(do_sample=True, temperature=0.3, top_p=0.95),
|
| 135 |
+
"local": MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95)
|
| 136 |
+
},
|
| 137 |
+
"text_chat": {
|
| 138 |
+
"global": MiMoSampler(do_sample=True, temperature=0.4, top_p=0.95),
|
| 139 |
+
"local": MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95)
|
| 140 |
+
},
|
| 141 |
+
"in_context_learning_s2s": {
|
| 142 |
+
"global": MiMoSampler(do_sample=False, temperature=1.0, top_p=1.0),
|
| 143 |
+
"local": MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95)
|
| 144 |
+
},
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
start_loading_mimo_audio_tokenizer_time = time.monotonic()
|
| 148 |
+
self.mimo_audio_tokenizer = MiMoAudioTokenizer.from_pretrained(self.mimo_audio_tokenizer_path)
|
| 149 |
+
|
| 150 |
+
self.mimo_audio_tokenizer.eval().bfloat16().to(self.device)
|
| 151 |
+
print(
|
| 152 |
+
f"MiMo-Audio Tokenizer loaded in {time.monotonic() - start_loading_mimo_audio_tokenizer_time:.2f} seconds, device: {self.device}"
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
# Initialize mel spectrogram transform for consistent processing
|
| 156 |
+
self.mel_transform = MelSpectrogram(
|
| 157 |
+
sample_rate=self.mimo_audio_tokenizer.config.sampling_rate,
|
| 158 |
+
n_fft=self.mimo_audio_tokenizer.config.nfft,
|
| 159 |
+
hop_length=self.mimo_audio_tokenizer.config.hop_length,
|
| 160 |
+
win_length=self.mimo_audio_tokenizer.config.window_size,
|
| 161 |
+
f_min=self.mimo_audio_tokenizer.config.fmin,
|
| 162 |
+
f_max=self.mimo_audio_tokenizer.config.fmax,
|
| 163 |
+
n_mels=self.mimo_audio_tokenizer.config.n_mels,
|
| 164 |
+
power=1.0,
|
| 165 |
+
center=True,
|
| 166 |
+
).to(self.device)
|
| 167 |
+
|
| 168 |
+
self.history = None
|
| 169 |
+
|
| 170 |
+
def get_task_sampler(self, task_name):
|
| 171 |
+
if task_name not in self.task_sampler_configs:
|
| 172 |
+
return {
|
| 173 |
+
"global": self.default_global_sampler,
|
| 174 |
+
"local": self.default_local_sampler
|
| 175 |
+
}
|
| 176 |
+
return self.task_sampler_configs[task_name]
|
| 177 |
+
|
| 178 |
+
def save_wav(self, path, wav):
|
| 179 |
+
sf.write(
|
| 180 |
+
path,
|
| 181 |
+
wav.reshape(-1).detach().cpu().numpy(),
|
| 182 |
+
24000,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
def wav2mel(self, wav):
|
| 186 |
+
spec = self.mel_transform(wav[None, :])
|
| 187 |
+
return torch.log(torch.clip(spec, min=1e-7)).squeeze()
|
| 188 |
+
|
| 189 |
+
def resample_audio_if_needed(self, wav_tensor: torch.Tensor, original_sr: int):
|
| 190 |
+
target_sr = self.mimo_audio_tokenizer.config.sampling_rate
|
| 191 |
+
if original_sr != target_sr:
|
| 192 |
+
wav_tensor = torchaudio.functional.resample(
|
| 193 |
+
wav_tensor, original_sr, target_sr
|
| 194 |
+
)
|
| 195 |
+
return wav_tensor
|
| 196 |
+
|
| 197 |
+
def group_by_length(self, features: torch.Tensor, lengths: torch.Tensor, max_length: int):
|
| 198 |
+
if features.size(0) != lengths.sum().item():
|
| 199 |
+
raise ValueError(f"Feature size mismatch: {features.size(0)} vs {lengths.sum().item()}")
|
| 200 |
+
|
| 201 |
+
split_points = []
|
| 202 |
+
current_sum = 0
|
| 203 |
+
|
| 204 |
+
for i, seq_len in enumerate(lengths):
|
| 205 |
+
if current_sum + seq_len > max_length and current_sum > 0:
|
| 206 |
+
split_points.append(i)
|
| 207 |
+
current_sum = seq_len.item()
|
| 208 |
+
else:
|
| 209 |
+
current_sum += seq_len.item()
|
| 210 |
+
|
| 211 |
+
# Convert split points to group sizes
|
| 212 |
+
group_sizes = []
|
| 213 |
+
prev = 0
|
| 214 |
+
for point in split_points:
|
| 215 |
+
group_sizes.append(point - prev)
|
| 216 |
+
prev = point
|
| 217 |
+
if prev < len(lengths):
|
| 218 |
+
group_sizes.append(len(lengths) - prev)
|
| 219 |
+
|
| 220 |
+
len_groups = torch.split(lengths, group_sizes)
|
| 221 |
+
feature_sizes = [group.sum().item() for group in len_groups]
|
| 222 |
+
feature_groups = torch.split(features, feature_sizes)
|
| 223 |
+
|
| 224 |
+
return feature_groups, len_groups
|
| 225 |
+
|
| 226 |
+
def encode_batch(self, input_features: torch.Tensor, input_lens: torch.Tensor, max_length: int = 256000):
|
| 227 |
+
feature_groups, len_groups = self.group_by_length(input_features, input_lens, max_length)
|
| 228 |
+
|
| 229 |
+
encoded_parts = []
|
| 230 |
+
for features, lengths in zip(feature_groups, len_groups):
|
| 231 |
+
with torch.no_grad():
|
| 232 |
+
codes, _ = self.mimo_audio_tokenizer.encoder.encode(
|
| 233 |
+
input_features=features.to(self.device),
|
| 234 |
+
input_lens=lengths.to(self.device),
|
| 235 |
+
return_codes_only=True
|
| 236 |
+
)
|
| 237 |
+
encoded_parts.append(codes)
|
| 238 |
+
|
| 239 |
+
return torch.cat(encoded_parts, dim=-1)
|
| 240 |
+
|
| 241 |
+
def preprocess_input(
|
| 242 |
+
self,
|
| 243 |
+
input: Union[None, str, torch.Tensor] = None,
|
| 244 |
+
):
|
| 245 |
+
if isinstance(input, torch.Tensor) or (isinstance(input, str) and os.path.isfile(input)):
|
| 246 |
+
if isinstance(input, torch.Tensor):
|
| 247 |
+
wav = input
|
| 248 |
+
else:
|
| 249 |
+
wav, sr = torchaudio.load(input)
|
| 250 |
+
if wav.ndim == 2:
|
| 251 |
+
wav = wav.mean(dim=0)
|
| 252 |
+
wav = self.resample_audio_if_needed(wav, sr)
|
| 253 |
+
wav = wav.to(self.device)
|
| 254 |
+
|
| 255 |
+
mel = self.wav2mel(wav).transpose(0, 1) # (seq_len, n_mels)
|
| 256 |
+
|
| 257 |
+
input_len = mel.size(0)
|
| 258 |
+
segment_size = 6000
|
| 259 |
+
input_len_seg = [segment_size] * (input_len // segment_size)
|
| 260 |
+
if input_len % segment_size > 0:
|
| 261 |
+
input_len_seg.append(input_len % segment_size)
|
| 262 |
+
|
| 263 |
+
codes_packed = self.encode_batch(
|
| 264 |
+
input_features=mel,
|
| 265 |
+
input_lens=torch.tensor(input_len_seg),
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
codes = codes_packed.transpose(0, 1).detach().cpu()
|
| 269 |
+
audio_codes = codes[:, :self.audio_channels]
|
| 270 |
+
|
| 271 |
+
# Pad the sequence to be a multiple of group_size by repeating the last frame
|
| 272 |
+
num_timesteps = audio_codes.shape[0]
|
| 273 |
+
if num_timesteps % self.group_size != 0:
|
| 274 |
+
padding_needed = self.group_size - (num_timesteps % self.group_size)
|
| 275 |
+
last_tokens = audio_codes[-1:, :] # Keep dim for repeat
|
| 276 |
+
padding_tokens = last_tokens.repeat(padding_needed, 1)
|
| 277 |
+
audio_codes = torch.cat([audio_codes, padding_tokens], dim=0)
|
| 278 |
+
|
| 279 |
+
audio_tokenized = audio_codes.reshape(-1)
|
| 280 |
+
|
| 281 |
+
return audio_tokenized
|
| 282 |
+
else:
|
| 283 |
+
text = input
|
| 284 |
+
if (
|
| 285 |
+
text.isupper() or text.islower()
|
| 286 |
+
): # If the text only contains upper-case or lower-case letters, capitalize it.
|
| 287 |
+
text = text.capitalize()
|
| 288 |
+
return text
|
| 289 |
+
|
| 290 |
+
def get_input_ids(self, prompt):
|
| 291 |
+
input_ids = [
|
| 292 |
+
seg.to_input_id(
|
| 293 |
+
self.tokenizer,
|
| 294 |
+
self.group_size,
|
| 295 |
+
self.audio_channels,
|
| 296 |
+
)
|
| 297 |
+
for seg in prompt
|
| 298 |
+
]
|
| 299 |
+
input_ids = torch.cat(input_ids, dim=1)
|
| 300 |
+
return input_ids.to(self.device)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def get_asr_sft_prompt(
|
| 304 |
+
self,
|
| 305 |
+
input: Union[None, str] = None,
|
| 306 |
+
):
|
| 307 |
+
audio_tokenized = self.preprocess_input(input)
|
| 308 |
+
|
| 309 |
+
template = random.choice(asr_zh_templates + asr_en_templates)
|
| 310 |
+
|
| 311 |
+
lm_prompt = [
|
| 312 |
+
InputSegment(
|
| 313 |
+
text=f"<|im_start|>user\n",
|
| 314 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 315 |
+
text_zeroemb_idx=self.empty_token,
|
| 316 |
+
),
|
| 317 |
+
InputSegment(
|
| 318 |
+
audio=audio_tokenized,
|
| 319 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 320 |
+
text_zeroemb_idx=self.empty_token,
|
| 321 |
+
),
|
| 322 |
+
InputSegment(
|
| 323 |
+
text=template,
|
| 324 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 325 |
+
text_zeroemb_idx=self.empty_token,
|
| 326 |
+
),
|
| 327 |
+
InputSegment(
|
| 328 |
+
text=f"<|im_end|>\n",
|
| 329 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 330 |
+
text_zeroemb_idx=self.empty_token,
|
| 331 |
+
),
|
| 332 |
+
InputSegment(
|
| 333 |
+
text=f"<|im_start|>assistant\n",
|
| 334 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 335 |
+
text_zeroemb_idx=self.empty_token,
|
| 336 |
+
),
|
| 337 |
+
InputSegment(
|
| 338 |
+
text="<think>\n\n</think>\n",
|
| 339 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 340 |
+
text_zeroemb_idx=self.empty_token,
|
| 341 |
+
)
|
| 342 |
+
]
|
| 343 |
+
input_ids = self.get_input_ids(lm_prompt)
|
| 344 |
+
return input_ids
|
| 345 |
+
|
| 346 |
+
def get_tts_sft_prompt(
|
| 347 |
+
self,
|
| 348 |
+
input: Union[None, str] = None,
|
| 349 |
+
instruct=None,
|
| 350 |
+
read_text_only=True,
|
| 351 |
+
prompt_speech=None,
|
| 352 |
+
):
|
| 353 |
+
if prompt_speech is not None:
|
| 354 |
+
assistant_prompt_audio_token = self.preprocess_input(prompt_speech)
|
| 355 |
+
else:
|
| 356 |
+
assistant_prompt_audio_token = None
|
| 357 |
+
if not read_text_only:
|
| 358 |
+
text = self.preprocess_input(input)
|
| 359 |
+
if assistant_prompt_audio_token is None:
|
| 360 |
+
lm_prompt = [
|
| 361 |
+
InputSegment(
|
| 362 |
+
text="<|im_start|>system\n",
|
| 363 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 364 |
+
text_zeroemb_idx=self.empty_token,
|
| 365 |
+
),
|
| 366 |
+
InputSegment(
|
| 367 |
+
text=f"你需要根据指定的风格指令和文本内容来生成语音。",
|
| 368 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 369 |
+
text_zeroemb_idx=self.empty_token,
|
| 370 |
+
),
|
| 371 |
+
InputSegment(
|
| 372 |
+
text="<|im_end|>\n",
|
| 373 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 374 |
+
text_zeroemb_idx=self.empty_token,
|
| 375 |
+
),
|
| 376 |
+
InputSegment(
|
| 377 |
+
text=f"<|im_start|>user\n{text}<|im_end|>\n",
|
| 378 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 379 |
+
text_zeroemb_idx=self.empty_token,
|
| 380 |
+
),
|
| 381 |
+
InputSegment(
|
| 382 |
+
text=f"<|im_start|>assistant\n<think>\n",
|
| 383 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 384 |
+
text_zeroemb_idx=self.empty_token,
|
| 385 |
+
),
|
| 386 |
+
]
|
| 387 |
+
else:
|
| 388 |
+
lm_prompt = [
|
| 389 |
+
InputSegment(
|
| 390 |
+
text="<|im_start|>system\n",
|
| 391 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 392 |
+
text_zeroemb_idx=self.empty_token,
|
| 393 |
+
),
|
| 394 |
+
InputSegment(
|
| 395 |
+
text=f"你需要根据指定的风格指令和文本内容来生成和语音prompt具有相同音色的语音。你的音色应该是:",
|
| 396 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 397 |
+
text_zeroemb_idx=self.empty_token,
|
| 398 |
+
),
|
| 399 |
+
InputSegment(
|
| 400 |
+
text="",
|
| 401 |
+
audio=assistant_prompt_audio_token,
|
| 402 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 403 |
+
text_zeroemb_idx=self.empty_token,
|
| 404 |
+
),
|
| 405 |
+
InputSegment(
|
| 406 |
+
text="<|im_end|>\n",
|
| 407 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 408 |
+
text_zeroemb_idx=self.empty_token,
|
| 409 |
+
),
|
| 410 |
+
InputSegment(
|
| 411 |
+
text=f"<|im_start|>user\n{text}<|im_end|>\n",
|
| 412 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 413 |
+
text_zeroemb_idx=self.empty_token,
|
| 414 |
+
),
|
| 415 |
+
InputSegment(
|
| 416 |
+
text=f"<|im_start|>assistant\n<think>\n",
|
| 417 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 418 |
+
text_zeroemb_idx=self.empty_token,
|
| 419 |
+
),
|
| 420 |
+
]
|
| 421 |
+
else:
|
| 422 |
+
language = detect_language(input)
|
| 423 |
+
if language == "zh":
|
| 424 |
+
template = random.choice(tts_zh_templates)
|
| 425 |
+
else:
|
| 426 |
+
template = random.choice(tts_en_templates)
|
| 427 |
+
|
| 428 |
+
text = self.preprocess_input(input)
|
| 429 |
+
if instruct is None:
|
| 430 |
+
lm_prompt = [
|
| 431 |
+
InputSegment(
|
| 432 |
+
text=f"<|im_start|>user\n{template}: {text}<|im_end|>\n",
|
| 433 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 434 |
+
text_zeroemb_idx=self.empty_token,
|
| 435 |
+
),
|
| 436 |
+
InputSegment(
|
| 437 |
+
text=f"<|im_start|>assistant\n<|sostm|>",
|
| 438 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 439 |
+
text_zeroemb_idx=self.empty_token,
|
| 440 |
+
),
|
| 441 |
+
]
|
| 442 |
+
else:
|
| 443 |
+
if assistant_prompt_audio_token is None:
|
| 444 |
+
lm_prompt = [
|
| 445 |
+
InputSegment(
|
| 446 |
+
text="<|im_start|>system\n",
|
| 447 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 448 |
+
text_zeroemb_idx=self.empty_token,
|
| 449 |
+
),
|
| 450 |
+
InputSegment(
|
| 451 |
+
text=f"你需要根据指定的风格指令和文本内容来生成语音。",
|
| 452 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 453 |
+
text_zeroemb_idx=self.empty_token,
|
| 454 |
+
),
|
| 455 |
+
InputSegment(
|
| 456 |
+
text="<|im_end|>\n",
|
| 457 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 458 |
+
text_zeroemb_idx=self.empty_token,
|
| 459 |
+
),
|
| 460 |
+
InputSegment(
|
| 461 |
+
text=f"<|im_start|>user\n{template}: {text}({instruct})<|im_end|>\n",
|
| 462 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 463 |
+
text_zeroemb_idx=self.empty_token,
|
| 464 |
+
),
|
| 465 |
+
InputSegment(
|
| 466 |
+
text=f"<|im_start|>assistant\n<think>\n",
|
| 467 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 468 |
+
text_zeroemb_idx=self.empty_token,
|
| 469 |
+
),
|
| 470 |
+
]
|
| 471 |
+
else:
|
| 472 |
+
lm_prompt = [
|
| 473 |
+
InputSegment(
|
| 474 |
+
text="<|im_start|>system\n",
|
| 475 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 476 |
+
text_zeroemb_idx=self.empty_token,
|
| 477 |
+
),
|
| 478 |
+
InputSegment(
|
| 479 |
+
text=f"你需要根据指定的风格指令和文本内容来生成和语音prompt具有相同音色的语音。你的音色应该是:",
|
| 480 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 481 |
+
text_zeroemb_idx=self.empty_token,
|
| 482 |
+
),
|
| 483 |
+
InputSegment(
|
| 484 |
+
text="",
|
| 485 |
+
audio=assistant_prompt_audio_token,
|
| 486 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 487 |
+
text_zeroemb_idx=self.empty_token,
|
| 488 |
+
),
|
| 489 |
+
InputSegment(
|
| 490 |
+
text="<|im_end|>\n",
|
| 491 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 492 |
+
text_zeroemb_idx=self.empty_token,
|
| 493 |
+
),
|
| 494 |
+
InputSegment(
|
| 495 |
+
text=f"<|im_start|>user\n{template}: {text}({instruct})<|im_end|>\n",
|
| 496 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 497 |
+
text_zeroemb_idx=self.empty_token,
|
| 498 |
+
),
|
| 499 |
+
InputSegment(
|
| 500 |
+
text=f"<|im_start|>assistant\n<think>\n",
|
| 501 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 502 |
+
text_zeroemb_idx=self.empty_token,
|
| 503 |
+
),
|
| 504 |
+
]
|
| 505 |
+
|
| 506 |
+
input_ids = self.get_input_ids(lm_prompt)
|
| 507 |
+
return input_ids
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
def get_audio_understanding_sft_prompt(
|
| 511 |
+
self,
|
| 512 |
+
input_speech,
|
| 513 |
+
input_text,
|
| 514 |
+
thinking=False,
|
| 515 |
+
):
|
| 516 |
+
audio_tokenized = self.preprocess_input(input_speech)
|
| 517 |
+
|
| 518 |
+
lm_prompt = [
|
| 519 |
+
InputSegment(
|
| 520 |
+
text=f"<|im_start|>user\n",
|
| 521 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 522 |
+
text_zeroemb_idx=self.empty_token,
|
| 523 |
+
),
|
| 524 |
+
InputSegment(
|
| 525 |
+
audio=audio_tokenized,
|
| 526 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 527 |
+
text_zeroemb_idx=self.empty_token,
|
| 528 |
+
),
|
| 529 |
+
InputSegment(
|
| 530 |
+
text=input_text,
|
| 531 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 532 |
+
text_zeroemb_idx=self.empty_token,
|
| 533 |
+
),
|
| 534 |
+
InputSegment(
|
| 535 |
+
text=f"<|im_end|>\n",
|
| 536 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 537 |
+
text_zeroemb_idx=self.empty_token,
|
| 538 |
+
),
|
| 539 |
+
InputSegment(
|
| 540 |
+
text=f"<|im_start|>assistant\n",
|
| 541 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 542 |
+
text_zeroemb_idx=self.empty_token,
|
| 543 |
+
),
|
| 544 |
+
]
|
| 545 |
+
if not thinking:
|
| 546 |
+
lm_prompt.append(
|
| 547 |
+
InputSegment(
|
| 548 |
+
text="<think>\n\n</think>\n",
|
| 549 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 550 |
+
text_zeroemb_idx=self.empty_token,
|
| 551 |
+
)
|
| 552 |
+
)
|
| 553 |
+
else:
|
| 554 |
+
lm_prompt.append(
|
| 555 |
+
InputSegment(
|
| 556 |
+
text="<think>\n",
|
| 557 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 558 |
+
text_zeroemb_idx=self.empty_token,
|
| 559 |
+
)
|
| 560 |
+
)
|
| 561 |
+
|
| 562 |
+
input_ids = self.get_input_ids(lm_prompt)
|
| 563 |
+
return input_ids
|
| 564 |
+
|
| 565 |
+
def get_spoken_dialogue_sft_prompt(
|
| 566 |
+
self,
|
| 567 |
+
input_speech,
|
| 568 |
+
system_prompt=None,
|
| 569 |
+
prompt_speech=None,
|
| 570 |
+
add_history=False,
|
| 571 |
+
):
|
| 572 |
+
audio_tokenized = self.preprocess_input(input_speech)
|
| 573 |
+
|
| 574 |
+
lm_prompt = []
|
| 575 |
+
|
| 576 |
+
if add_history and self.history is not None:
|
| 577 |
+
lm_prompt += [
|
| 578 |
+
InputSegment(
|
| 579 |
+
text=f"<|im_start|>user\n",
|
| 580 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 581 |
+
text_zeroemb_idx=self.empty_token,
|
| 582 |
+
),
|
| 583 |
+
InputSegment(
|
| 584 |
+
audio=audio_tokenized,
|
| 585 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 586 |
+
text_zeroemb_idx=self.empty_token,
|
| 587 |
+
),
|
| 588 |
+
InputSegment(
|
| 589 |
+
text=f"<|im_end|>\n",
|
| 590 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 591 |
+
text_zeroemb_idx=self.empty_token,
|
| 592 |
+
),
|
| 593 |
+
InputSegment(
|
| 594 |
+
text=f"<|im_start|>assistant\n<|sostm|>",
|
| 595 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 596 |
+
text_zeroemb_idx=self.empty_token,
|
| 597 |
+
),
|
| 598 |
+
]
|
| 599 |
+
else:
|
| 600 |
+
if prompt_speech:
|
| 601 |
+
lm_prompt += [
|
| 602 |
+
InputSegment(
|
| 603 |
+
text="<|im_start|>system\n",
|
| 604 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 605 |
+
text_zeroemb_idx=self.empty_token,
|
| 606 |
+
),
|
| 607 |
+
InputSegment(
|
| 608 |
+
text=f"Your voice should be:",
|
| 609 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 610 |
+
text_zeroemb_idx=self.empty_token,
|
| 611 |
+
),
|
| 612 |
+
InputSegment(
|
| 613 |
+
audio=self.preprocess_input(prompt_speech),
|
| 614 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 615 |
+
text_zeroemb_idx=self.empty_token,
|
| 616 |
+
),
|
| 617 |
+
InputSegment(
|
| 618 |
+
text="<|im_end|>\n",
|
| 619 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 620 |
+
text_zeroemb_idx=self.empty_token,
|
| 621 |
+
),
|
| 622 |
+
]
|
| 623 |
+
|
| 624 |
+
lm_prompt += [
|
| 625 |
+
InputSegment(
|
| 626 |
+
text=f"<|im_start|>user\n",
|
| 627 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 628 |
+
text_zeroemb_idx=self.empty_token,
|
| 629 |
+
)
|
| 630 |
+
]
|
| 631 |
+
|
| 632 |
+
if system_prompt:
|
| 633 |
+
lm_prompt += [
|
| 634 |
+
InputSegment(
|
| 635 |
+
text=system_prompt,
|
| 636 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 637 |
+
text_zeroemb_idx=self.empty_token,
|
| 638 |
+
),
|
| 639 |
+
InputSegment(
|
| 640 |
+
text="\n\n",
|
| 641 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 642 |
+
text_zeroemb_idx=self.empty_token,
|
| 643 |
+
)
|
| 644 |
+
]
|
| 645 |
+
lm_prompt += [
|
| 646 |
+
InputSegment(
|
| 647 |
+
audio=audio_tokenized,
|
| 648 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 649 |
+
text_zeroemb_idx=self.empty_token,
|
| 650 |
+
),
|
| 651 |
+
InputSegment(
|
| 652 |
+
text=f"<|im_end|>\n",
|
| 653 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 654 |
+
text_zeroemb_idx=self.empty_token,
|
| 655 |
+
),
|
| 656 |
+
InputSegment(
|
| 657 |
+
text=f"<|im_start|>assistant\n<|sostm|>",
|
| 658 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 659 |
+
text_zeroemb_idx=self.empty_token,
|
| 660 |
+
),
|
| 661 |
+
]
|
| 662 |
+
|
| 663 |
+
input_ids = self.get_input_ids(lm_prompt)
|
| 664 |
+
return input_ids
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
def get_spoken_dialogue_sft_multiturn_prompt(
|
| 668 |
+
self,
|
| 669 |
+
message_list,
|
| 670 |
+
system_prompt=None,
|
| 671 |
+
prompt_speech=None,
|
| 672 |
+
):
|
| 673 |
+
lm_prompt = []
|
| 674 |
+
|
| 675 |
+
if prompt_speech:
|
| 676 |
+
lm_prompt += [
|
| 677 |
+
InputSegment(
|
| 678 |
+
text="<|im_start|>system\n",
|
| 679 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 680 |
+
text_zeroemb_idx=self.empty_token,
|
| 681 |
+
),
|
| 682 |
+
InputSegment(
|
| 683 |
+
text=f"Your voice should be:",
|
| 684 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 685 |
+
text_zeroemb_idx=self.empty_token,
|
| 686 |
+
),
|
| 687 |
+
InputSegment(
|
| 688 |
+
audio=self.preprocess_input(prompt_speech),
|
| 689 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 690 |
+
text_zeroemb_idx=self.empty_token,
|
| 691 |
+
),
|
| 692 |
+
InputSegment(
|
| 693 |
+
text="<|im_end|>\n",
|
| 694 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 695 |
+
text_zeroemb_idx=self.empty_token,
|
| 696 |
+
)
|
| 697 |
+
]
|
| 698 |
+
|
| 699 |
+
for i in range(len(message_list)):
|
| 700 |
+
if message_list[i]['role'] == 'user':
|
| 701 |
+
lm_prompt += [
|
| 702 |
+
InputSegment(
|
| 703 |
+
text=f"<|im_start|>user\n",
|
| 704 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 705 |
+
text_zeroemb_idx=self.empty_token,
|
| 706 |
+
)
|
| 707 |
+
]
|
| 708 |
+
if system_prompt and i == 0:
|
| 709 |
+
lm_prompt += [
|
| 710 |
+
InputSegment(
|
| 711 |
+
text=system_prompt,
|
| 712 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 713 |
+
text_zeroemb_idx=self.empty_token,
|
| 714 |
+
),
|
| 715 |
+
InputSegment(
|
| 716 |
+
text="\n\n",
|
| 717 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 718 |
+
text_zeroemb_idx=self.empty_token,
|
| 719 |
+
)
|
| 720 |
+
]
|
| 721 |
+
lm_prompt += [
|
| 722 |
+
InputSegment(
|
| 723 |
+
audio=self.preprocess_input(message_list[i]['content']),
|
| 724 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 725 |
+
text_zeroemb_idx=self.empty_token,
|
| 726 |
+
),
|
| 727 |
+
InputSegment(
|
| 728 |
+
text=f"<|im_end|>\n",
|
| 729 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 730 |
+
text_zeroemb_idx=self.empty_token,
|
| 731 |
+
)
|
| 732 |
+
]
|
| 733 |
+
elif message_list[i]['role'] == 'assistant':
|
| 734 |
+
lm_prompt += [
|
| 735 |
+
InputSegment(
|
| 736 |
+
text=f"<|im_start|>assistant\n",
|
| 737 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 738 |
+
text_zeroemb_idx=self.empty_token,
|
| 739 |
+
),
|
| 740 |
+
StreamingInputSegment(
|
| 741 |
+
text=message_list[i]['content']["text"],
|
| 742 |
+
audio=self.preprocess_input(message_list[i]['content']["audio"]),
|
| 743 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 744 |
+
text_zeroemb_idx=self.empty_token,
|
| 745 |
+
tokenizer=self.tokenizer,
|
| 746 |
+
group_size=self.group_size,
|
| 747 |
+
audio_channels=self.audio_channels,
|
| 748 |
+
),
|
| 749 |
+
InputSegment(
|
| 750 |
+
text=f"<|im_end|>\n",
|
| 751 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 752 |
+
text_zeroemb_idx=self.empty_token,
|
| 753 |
+
)
|
| 754 |
+
]
|
| 755 |
+
else:
|
| 756 |
+
raise ValueError(f"Invalid role: {message_list[i]['role']}")
|
| 757 |
+
|
| 758 |
+
lm_prompt += [
|
| 759 |
+
InputSegment(
|
| 760 |
+
text=f"<|im_start|>assistant\n<|sostm|>",
|
| 761 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 762 |
+
text_zeroemb_idx=self.empty_token,
|
| 763 |
+
),
|
| 764 |
+
]
|
| 765 |
+
|
| 766 |
+
input_ids = self.get_input_ids(lm_prompt)
|
| 767 |
+
return input_ids
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
def get_s2t_dialogue_sft_prompt(
|
| 771 |
+
self,
|
| 772 |
+
input_speech,
|
| 773 |
+
thinking=False,
|
| 774 |
+
):
|
| 775 |
+
audio_tokenized = self.preprocess_input(input_speech)
|
| 776 |
+
lm_prompt = [
|
| 777 |
+
InputSegment(
|
| 778 |
+
text=f"<|im_start|>user\n",
|
| 779 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 780 |
+
text_zeroemb_idx=self.empty_token,
|
| 781 |
+
),
|
| 782 |
+
InputSegment(
|
| 783 |
+
audio=audio_tokenized,
|
| 784 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 785 |
+
text_zeroemb_idx=self.empty_token,
|
| 786 |
+
),
|
| 787 |
+
InputSegment(
|
| 788 |
+
text=f"<|im_end|>\n",
|
| 789 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 790 |
+
text_zeroemb_idx=self.empty_token,
|
| 791 |
+
),
|
| 792 |
+
InputSegment(
|
| 793 |
+
text=f"<|im_start|>assistant\n",
|
| 794 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 795 |
+
text_zeroemb_idx=self.empty_token,
|
| 796 |
+
)
|
| 797 |
+
]
|
| 798 |
+
if not thinking:
|
| 799 |
+
lm_prompt.append(
|
| 800 |
+
InputSegment(
|
| 801 |
+
text="<think>\n\n</think>\n",
|
| 802 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 803 |
+
text_zeroemb_idx=self.empty_token,
|
| 804 |
+
)
|
| 805 |
+
)
|
| 806 |
+
else:
|
| 807 |
+
lm_prompt.append(
|
| 808 |
+
InputSegment(
|
| 809 |
+
text="<think>\n",
|
| 810 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 811 |
+
text_zeroemb_idx=self.empty_token,
|
| 812 |
+
)
|
| 813 |
+
)
|
| 814 |
+
input_ids = self.get_input_ids(lm_prompt)
|
| 815 |
+
return input_ids
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
def get_s2t_dialogue_sft_multiturn_prompt(self, message_list, thinking=False):
|
| 819 |
+
lm_prompt = []
|
| 820 |
+
for i in range(len(message_list)):
|
| 821 |
+
if message_list[i]['role'] == 'user':
|
| 822 |
+
lm_prompt += [
|
| 823 |
+
InputSegment(
|
| 824 |
+
text=f"<|im_start|>user\n",
|
| 825 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 826 |
+
text_zeroemb_idx=self.empty_token,
|
| 827 |
+
),
|
| 828 |
+
InputSegment(
|
| 829 |
+
audio=self.preprocess_input(message_list[i]['content']),
|
| 830 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 831 |
+
text_zeroemb_idx=self.empty_token,
|
| 832 |
+
),
|
| 833 |
+
InputSegment(
|
| 834 |
+
text=f"<|im_end|>\n",
|
| 835 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 836 |
+
text_zeroemb_idx=self.empty_token,
|
| 837 |
+
)
|
| 838 |
+
]
|
| 839 |
+
elif message_list[i]['role'] == 'assistant':
|
| 840 |
+
lm_prompt += [
|
| 841 |
+
InputSegment(
|
| 842 |
+
text=f"<|im_start|>assistant\n",
|
| 843 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 844 |
+
text_zeroemb_idx=self.empty_token,
|
| 845 |
+
),
|
| 846 |
+
InputSegment(
|
| 847 |
+
text=message_list[i]['content'],
|
| 848 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 849 |
+
text_zeroemb_idx=self.empty_token,
|
| 850 |
+
),
|
| 851 |
+
InputSegment(
|
| 852 |
+
text=f"<|im_end|>\n",
|
| 853 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 854 |
+
text_zeroemb_idx=self.empty_token,
|
| 855 |
+
)
|
| 856 |
+
]
|
| 857 |
+
else:
|
| 858 |
+
raise ValueError(f"Invalid role: {message_list[i]['role']}")
|
| 859 |
+
|
| 860 |
+
lm_prompt.append(
|
| 861 |
+
InputSegment(
|
| 862 |
+
text=f"<|im_start|>assistant\n",
|
| 863 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 864 |
+
text_zeroemb_idx=self.empty_token,
|
| 865 |
+
)
|
| 866 |
+
)
|
| 867 |
+
if not thinking:
|
| 868 |
+
lm_prompt.append(
|
| 869 |
+
InputSegment(
|
| 870 |
+
text="<think>\n\n</think>\n",
|
| 871 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 872 |
+
text_zeroemb_idx=self.empty_token,
|
| 873 |
+
)
|
| 874 |
+
)
|
| 875 |
+
else:
|
| 876 |
+
lm_prompt.append(
|
| 877 |
+
InputSegment(
|
| 878 |
+
text="<think>\n",
|
| 879 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 880 |
+
text_zeroemb_idx=self.empty_token,
|
| 881 |
+
)
|
| 882 |
+
)
|
| 883 |
+
input_ids = self.get_input_ids(lm_prompt)
|
| 884 |
+
return input_ids
|
| 885 |
+
|
| 886 |
+
|
| 887 |
+
def get_text_dialogue_sft_prompt(
|
| 888 |
+
self,
|
| 889 |
+
input_text,
|
| 890 |
+
thinking=False,
|
| 891 |
+
):
|
| 892 |
+
lm_prompt = [
|
| 893 |
+
InputSegment(
|
| 894 |
+
text=f"<|im_start|>user\n",
|
| 895 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 896 |
+
text_zeroemb_idx=self.empty_token,
|
| 897 |
+
),
|
| 898 |
+
InputSegment(
|
| 899 |
+
text=input_text,
|
| 900 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 901 |
+
text_zeroemb_idx=self.empty_token,
|
| 902 |
+
),
|
| 903 |
+
InputSegment(
|
| 904 |
+
text=f"<|im_end|>\n",
|
| 905 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 906 |
+
text_zeroemb_idx=self.empty_token,
|
| 907 |
+
),
|
| 908 |
+
InputSegment(
|
| 909 |
+
text=f"<|im_start|>assistant\n",
|
| 910 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 911 |
+
text_zeroemb_idx=self.empty_token,
|
| 912 |
+
),
|
| 913 |
+
]
|
| 914 |
+
if not thinking:
|
| 915 |
+
lm_prompt.append(
|
| 916 |
+
InputSegment(
|
| 917 |
+
text="<think>\n\n</think>\n",
|
| 918 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 919 |
+
text_zeroemb_idx=self.empty_token,
|
| 920 |
+
)
|
| 921 |
+
)
|
| 922 |
+
else:
|
| 923 |
+
lm_prompt.append(
|
| 924 |
+
InputSegment(
|
| 925 |
+
text="<think>\n",
|
| 926 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 927 |
+
text_zeroemb_idx=self.empty_token,
|
| 928 |
+
)
|
| 929 |
+
)
|
| 930 |
+
input_ids = self.get_input_ids(lm_prompt)
|
| 931 |
+
return input_ids
|
| 932 |
+
|
| 933 |
+
def get_text_dialogue_sft_multiturn_prompt(
|
| 934 |
+
self,
|
| 935 |
+
message_list,
|
| 936 |
+
thinking=False,
|
| 937 |
+
):
|
| 938 |
+
lm_prompt = []
|
| 939 |
+
for i in range(len(message_list)):
|
| 940 |
+
if message_list[i]['role'] == 'user':
|
| 941 |
+
lm_prompt += [
|
| 942 |
+
InputSegment(
|
| 943 |
+
text=f"<|im_start|>user\n",
|
| 944 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 945 |
+
text_zeroemb_idx=self.empty_token,
|
| 946 |
+
),
|
| 947 |
+
InputSegment(
|
| 948 |
+
text=message_list[i]['content'],
|
| 949 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 950 |
+
text_zeroemb_idx=self.empty_token,
|
| 951 |
+
),
|
| 952 |
+
InputSegment(
|
| 953 |
+
text=f"<|im_end|>\n",
|
| 954 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 955 |
+
text_zeroemb_idx=self.empty_token,
|
| 956 |
+
)
|
| 957 |
+
]
|
| 958 |
+
elif message_list[i]['role'] == 'assistant':
|
| 959 |
+
lm_prompt += [
|
| 960 |
+
InputSegment(
|
| 961 |
+
text=f"<|im_start|>assistant\n",
|
| 962 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 963 |
+
text_zeroemb_idx=self.empty_token,
|
| 964 |
+
),
|
| 965 |
+
InputSegment(
|
| 966 |
+
text=message_list[i]['content'],
|
| 967 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 968 |
+
text_zeroemb_idx=self.empty_token,
|
| 969 |
+
),
|
| 970 |
+
InputSegment(
|
| 971 |
+
text=f"<|im_end|>\n",
|
| 972 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 973 |
+
text_zeroemb_idx=self.empty_token,
|
| 974 |
+
)
|
| 975 |
+
]
|
| 976 |
+
else:
|
| 977 |
+
raise ValueError(f"Invalid role: {message_list[i]['role']}")
|
| 978 |
+
|
| 979 |
+
lm_prompt.append(
|
| 980 |
+
InputSegment(
|
| 981 |
+
text=f"<|im_start|>assistant\n",
|
| 982 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 983 |
+
text_zeroemb_idx=self.empty_token,
|
| 984 |
+
)
|
| 985 |
+
)
|
| 986 |
+
if not thinking:
|
| 987 |
+
lm_prompt.append(
|
| 988 |
+
InputSegment(
|
| 989 |
+
text="<think>\n\n</think>\n",
|
| 990 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 991 |
+
text_zeroemb_idx=self.empty_token,
|
| 992 |
+
)
|
| 993 |
+
)
|
| 994 |
+
else:
|
| 995 |
+
lm_prompt.append(
|
| 996 |
+
InputSegment(
|
| 997 |
+
text="<think>\n",
|
| 998 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 999 |
+
text_zeroemb_idx=self.empty_token,
|
| 1000 |
+
)
|
| 1001 |
+
)
|
| 1002 |
+
input_ids = self.get_input_ids(lm_prompt)
|
| 1003 |
+
return input_ids
|
| 1004 |
+
|
| 1005 |
+
|
| 1006 |
+
def get_in_context_learning_s2s_prompt(self, instruction, prompt_examples, audio):
|
| 1007 |
+
prompt = [
|
| 1008 |
+
InputSegment(
|
| 1009 |
+
text=f"[Int]:{instruction}\n",
|
| 1010 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 1011 |
+
text_zeroemb_idx=self.empty_token,
|
| 1012 |
+
)
|
| 1013 |
+
]
|
| 1014 |
+
|
| 1015 |
+
for i in range(len(prompt_examples)):
|
| 1016 |
+
prompt += [
|
| 1017 |
+
InputSegment(
|
| 1018 |
+
audio=self.preprocess_input(prompt_examples[i]["input_audio"]),
|
| 1019 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 1020 |
+
text_zeroemb_idx=self.empty_token,
|
| 1021 |
+
),
|
| 1022 |
+
InputSegment(
|
| 1023 |
+
text="\n",
|
| 1024 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 1025 |
+
text_zeroemb_idx=self.empty_token,
|
| 1026 |
+
),
|
| 1027 |
+
StreamingInputSegment(
|
| 1028 |
+
text=prompt_examples[i]["output_transcription"],
|
| 1029 |
+
audio=self.preprocess_input(prompt_examples[i]["output_audio"]),
|
| 1030 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 1031 |
+
text_zeroemb_idx=self.empty_token,
|
| 1032 |
+
tokenizer=self.tokenizer,
|
| 1033 |
+
group_size=self.group_size,
|
| 1034 |
+
audio_channels=self.audio_channels,
|
| 1035 |
+
),
|
| 1036 |
+
InputSegment(
|
| 1037 |
+
text=" \n\n",
|
| 1038 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 1039 |
+
text_zeroemb_idx=self.empty_token,
|
| 1040 |
+
),
|
| 1041 |
+
]
|
| 1042 |
+
|
| 1043 |
+
prompt += [
|
| 1044 |
+
InputSegment(
|
| 1045 |
+
audio=self.preprocess_input(audio),
|
| 1046 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 1047 |
+
text_zeroemb_idx=self.empty_token,
|
| 1048 |
+
),
|
| 1049 |
+
InputSegment(
|
| 1050 |
+
text="\n",
|
| 1051 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 1052 |
+
text_zeroemb_idx=self.empty_token,
|
| 1053 |
+
),
|
| 1054 |
+
InputSegment(
|
| 1055 |
+
text="<|sostm|>",
|
| 1056 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 1057 |
+
text_zeroemb_idx=self.empty_token,
|
| 1058 |
+
),
|
| 1059 |
+
]
|
| 1060 |
+
input_ids = self.get_input_ids(prompt)
|
| 1061 |
+
return input_ids
|
| 1062 |
+
|
| 1063 |
+
@torch.no_grad()
|
| 1064 |
+
def forward(
|
| 1065 |
+
self,
|
| 1066 |
+
input_ids,
|
| 1067 |
+
return_audio=False,
|
| 1068 |
+
output_audio_path=None,
|
| 1069 |
+
stopping_criteria=None,
|
| 1070 |
+
min_new_tokens=0,
|
| 1071 |
+
max_new_tokens=8192,
|
| 1072 |
+
add_history=False,
|
| 1073 |
+
task_name=None,
|
| 1074 |
+
):
|
| 1075 |
+
|
| 1076 |
+
task_sampler = self.get_task_sampler(task_name)
|
| 1077 |
+
|
| 1078 |
+
generation_kwargs = self.generate_kwargs.copy()
|
| 1079 |
+
generation_config = GenerationConfig(**generation_kwargs)
|
| 1080 |
+
|
| 1081 |
+
input_ids = input_ids.T.reshape(1, -1) # [B, flattened(T, audio_channels + 1)]
|
| 1082 |
+
if add_history and self.history is not None:
|
| 1083 |
+
input_ids = torch.cat([self.history, input_ids], dim=1)
|
| 1084 |
+
|
| 1085 |
+
prompt_length = input_ids.shape[1] // (self.audio_channels+1)
|
| 1086 |
+
|
| 1087 |
+
max_length = prompt_length // self.group_size + max_new_tokens
|
| 1088 |
+
min_length = prompt_length // self.group_size + min_new_tokens
|
| 1089 |
+
|
| 1090 |
+
if stopping_criteria is not None:
|
| 1091 |
+
for criterion in stopping_criteria:
|
| 1092 |
+
if isinstance(criterion, MiMoStopper):
|
| 1093 |
+
criterion.max_length = max_length
|
| 1094 |
+
criterion.min_length = min_length
|
| 1095 |
+
|
| 1096 |
+
generated_ids = self.model.generate(
|
| 1097 |
+
input_ids,
|
| 1098 |
+
generation_config,
|
| 1099 |
+
stopping_criteria=stopping_criteria,
|
| 1100 |
+
global_sampler=task_sampler["global"],
|
| 1101 |
+
local_sampler=task_sampler["local"],
|
| 1102 |
+
)
|
| 1103 |
+
|
| 1104 |
+
self.history = generated_ids
|
| 1105 |
+
generated_ids = generated_ids.int().cpu().reshape(-1, self.audio_channels+1).T[:, prompt_length:]
|
| 1106 |
+
|
| 1107 |
+
text = generated_ids[0, ::self.group_size][:-1]
|
| 1108 |
+
detokenized_text = self.tokenizer.decode(text, skip_special_tokens=False).strip().replace("<|empty|>", "").replace("<|eot|>", "").replace("<|eostm|>", "")
|
| 1109 |
+
print("Text channel:\t", detokenized_text)
|
| 1110 |
+
|
| 1111 |
+
if output_audio_path:
|
| 1112 |
+
return_audio = True
|
| 1113 |
+
|
| 1114 |
+
if not return_audio:
|
| 1115 |
+
return detokenized_text
|
| 1116 |
+
|
| 1117 |
+
sosp_idx_locations = (text == self.sostm_idx).nonzero(as_tuple=True)[0]
|
| 1118 |
+
eosp_idx_locations = (text == self.eostm_idx).nonzero(as_tuple=True)[0]
|
| 1119 |
+
if len(sosp_idx_locations) == 0:
|
| 1120 |
+
start_location = 0
|
| 1121 |
+
else:
|
| 1122 |
+
start_location = sosp_idx_locations[0] * self.group_size + self.group_size
|
| 1123 |
+
if len(eosp_idx_locations) == 0:
|
| 1124 |
+
end_location = text.shape[0] * self.group_size
|
| 1125 |
+
else:
|
| 1126 |
+
end_location = eosp_idx_locations[0] * self.group_size
|
| 1127 |
+
audio_sequence = generated_ids[:, start_location:end_location] #[audio_channels+1, audio_length]
|
| 1128 |
+
speech_sequence = audio_sequence[1:]
|
| 1129 |
+
|
| 1130 |
+
mask = speech_sequence[0] != (self.speech_zeroemb_idx[0] if isinstance(self.speech_zeroemb_idx, list) else self.speech_zeroemb_idx)
|
| 1131 |
+
speech_sequence = speech_sequence[:, mask]
|
| 1132 |
+
|
| 1133 |
+
assert (speech_sequence < torch.tensor(self.speech_zeroemb_idx).unsqueeze(1)).all()
|
| 1134 |
+
|
| 1135 |
+
speech_sequence = speech_sequence.T.flatten()
|
| 1136 |
+
|
| 1137 |
+
speech_str = "".join([f"<{i}>" for i in speech_sequence])
|
| 1138 |
+
tokens = torch.tensor(
|
| 1139 |
+
[int(num) for num in re.findall(r"(\d+)>", speech_str)]
|
| 1140 |
+
)
|
| 1141 |
+
|
| 1142 |
+
if tokens.numel() == 0:
|
| 1143 |
+
wav = torch.zeros(24000)
|
| 1144 |
+
self.save_wav(output_audio_path, wav)
|
| 1145 |
+
return detokenized_text
|
| 1146 |
+
|
| 1147 |
+
codes = tokens.reshape(-1, self.audio_channels).T
|
| 1148 |
+
codes = codes.type(torch.LongTensor).to(self.device)
|
| 1149 |
+
|
| 1150 |
+
segment_len = 1500
|
| 1151 |
+
wav_list=[]
|
| 1152 |
+
for start in range(0, codes.shape[-1], segment_len):
|
| 1153 |
+
wav = self.mimo_audio_tokenizer.decode(codes[:,start:start+segment_len]).float()
|
| 1154 |
+
wav_list.append(wav)
|
| 1155 |
+
wav_concat = torch.cat(wav_list, dim=-1)
|
| 1156 |
+
|
| 1157 |
+
#wav = self.mimo_audio_tokenizer.decode(codes).float()
|
| 1158 |
+
if output_audio_path is not None:
|
| 1159 |
+
self.save_wav(output_audio_path, wav_concat)
|
| 1160 |
+
return detokenized_text
|
| 1161 |
+
else:
|
| 1162 |
+
return wav_concat
|
| 1163 |
+
|
| 1164 |
+
def asr_sft(self, audio):
|
| 1165 |
+
stopping_criteria = [
|
| 1166 |
+
MiMoStopper(
|
| 1167 |
+
stop_tokens=[self.tokenizer.eos_token_id, self.im_end_idx],
|
| 1168 |
+
group_size=self.group_size,
|
| 1169 |
+
audio_channels=self.audio_channels,
|
| 1170 |
+
)
|
| 1171 |
+
]
|
| 1172 |
+
input_ids = self.get_asr_sft_prompt(audio)
|
| 1173 |
+
result = self.forward(input_ids, stopping_criteria=stopping_criteria, task_name="asr")
|
| 1174 |
+
return result
|
| 1175 |
+
|
| 1176 |
+
def tts_sft(self, text, output_path, instruct=None, read_text_only=True, prompt_speech=None):
|
| 1177 |
+
stopping_criteria = [
|
| 1178 |
+
MiMoStopper(
|
| 1179 |
+
stop_tokens=[self.tokenizer.eos_token_id, self.eostm_idx, self.im_end_idx],
|
| 1180 |
+
group_size=self.group_size,
|
| 1181 |
+
audio_channels=self.audio_channels,
|
| 1182 |
+
)
|
| 1183 |
+
]
|
| 1184 |
+
input_ids = self.get_tts_sft_prompt(text, instruct=instruct, read_text_only=read_text_only, prompt_speech=prompt_speech)
|
| 1185 |
+
text_output = self.forward(input_ids, output_audio_path=output_path, stopping_criteria=stopping_criteria, task_name="tts")
|
| 1186 |
+
return text_output
|
| 1187 |
+
|
| 1188 |
+
def audio_understanding_sft(self, input_speech, input_text, thinking=False):
|
| 1189 |
+
stopping_criteria = [
|
| 1190 |
+
MiMoStopper(
|
| 1191 |
+
stop_tokens=[self.tokenizer.eos_token_id, self.im_end_idx],
|
| 1192 |
+
group_size=self.group_size,
|
| 1193 |
+
audio_channels=self.audio_channels,
|
| 1194 |
+
)
|
| 1195 |
+
]
|
| 1196 |
+
input_ids = self.get_audio_understanding_sft_prompt(input_speech, input_text, thinking=thinking)
|
| 1197 |
+
result = self.forward(input_ids, stopping_criteria=stopping_criteria, task_name="audio_understanding")
|
| 1198 |
+
return result
|
| 1199 |
+
|
| 1200 |
+
def spoken_dialogue_sft(self, input_speech, output_audio_path=None, system_prompt=None, prompt_speech=None, add_history=False):
|
| 1201 |
+
stopping_criteria = [
|
| 1202 |
+
MiMoStopper(
|
| 1203 |
+
stop_tokens=[self.tokenizer.eos_token_id, self.eostm_idx, self.im_end_idx],
|
| 1204 |
+
group_size=self.group_size,
|
| 1205 |
+
audio_channels=self.audio_channels,
|
| 1206 |
+
)
|
| 1207 |
+
]
|
| 1208 |
+
input_ids = self.get_spoken_dialogue_sft_prompt(input_speech, system_prompt=system_prompt, prompt_speech=prompt_speech, add_history=add_history)
|
| 1209 |
+
text = self.forward(input_ids, output_audio_path=output_audio_path, stopping_criteria=stopping_criteria, task_name="spoken_dialogue", add_history=add_history)
|
| 1210 |
+
return text
|
| 1211 |
+
|
| 1212 |
+
# interface for message list interaction
|
| 1213 |
+
def spoken_dialogue_sft_multiturn(self, message_list, output_audio_path=None, system_prompt=None, prompt_speech=None):
|
| 1214 |
+
stopping_criteria = [
|
| 1215 |
+
MiMoStopper(
|
| 1216 |
+
stop_tokens=[self.tokenizer.eos_token_id, self.eostm_idx, self.im_end_idx],
|
| 1217 |
+
group_size=self.group_size,
|
| 1218 |
+
audio_channels=self.audio_channels,
|
| 1219 |
+
)
|
| 1220 |
+
]
|
| 1221 |
+
input_ids = self.get_spoken_dialogue_sft_multiturn_prompt(message_list, system_prompt=system_prompt, prompt_speech=prompt_speech)
|
| 1222 |
+
text = self.forward(input_ids, output_audio_path=output_audio_path, stopping_criteria=stopping_criteria, task_name="spoken_dialogue", add_history=False)
|
| 1223 |
+
return text
|
| 1224 |
+
|
| 1225 |
+
|
| 1226 |
+
def speech2text_dialogue_sft(self, input_speech, thinking=False, add_history=False):
|
| 1227 |
+
stopping_criteria = [
|
| 1228 |
+
MiMoStopper(
|
| 1229 |
+
stop_tokens=[self.tokenizer.eos_token_id, self.im_end_idx],
|
| 1230 |
+
group_size=self.group_size,
|
| 1231 |
+
audio_channels=self.audio_channels,
|
| 1232 |
+
)
|
| 1233 |
+
]
|
| 1234 |
+
input_ids = self.get_s2t_dialogue_sft_prompt(input_speech, thinking=thinking)
|
| 1235 |
+
text = self.forward(input_ids, stopping_criteria=stopping_criteria, task_name="spoken_dialogue", add_history=add_history)
|
| 1236 |
+
return text
|
| 1237 |
+
|
| 1238 |
+
|
| 1239 |
+
# interface for message list interaction
|
| 1240 |
+
def speech2text_dialogue_sft_multiturn(self, message_list, thinking=False):
|
| 1241 |
+
stopping_criteria = [
|
| 1242 |
+
MiMoStopper(
|
| 1243 |
+
stop_tokens=[self.tokenizer.eos_token_id, self.im_end_idx],
|
| 1244 |
+
group_size=self.group_size,
|
| 1245 |
+
audio_channels=self.audio_channels,
|
| 1246 |
+
)
|
| 1247 |
+
]
|
| 1248 |
+
input_ids = self.get_s2t_dialogue_sft_multiturn_prompt(message_list, thinking=thinking)
|
| 1249 |
+
text = self.forward(input_ids, stopping_criteria=stopping_criteria, task_name="spoken_dialogue", add_history=False)
|
| 1250 |
+
return text
|
| 1251 |
+
|
| 1252 |
+
def text_dialogue_sft(self, input_text, thinking=False, add_history=False):
|
| 1253 |
+
stopping_criteria = [
|
| 1254 |
+
MiMoStopper(
|
| 1255 |
+
stop_tokens=[self.tokenizer.eos_token_id, self.im_end_idx],
|
| 1256 |
+
group_size=self.group_size,
|
| 1257 |
+
audio_channels=self.audio_channels,
|
| 1258 |
+
)
|
| 1259 |
+
]
|
| 1260 |
+
input_ids = self.get_text_dialogue_sft_prompt(input_text, thinking=thinking)
|
| 1261 |
+
text = self.forward(input_ids, stopping_criteria=stopping_criteria, task_name="text_chat", add_history=add_history)
|
| 1262 |
+
return text
|
| 1263 |
+
|
| 1264 |
+
# interface for message list interaction
|
| 1265 |
+
def text_dialogue_sft_multiturn(self, message_list, thinking=False):
|
| 1266 |
+
stopping_criteria = [
|
| 1267 |
+
MiMoStopper(
|
| 1268 |
+
stop_tokens=[self.tokenizer.eos_token_id, self.im_end_idx],
|
| 1269 |
+
group_size=self.group_size,
|
| 1270 |
+
audio_channels=self.audio_channels,
|
| 1271 |
+
)
|
| 1272 |
+
]
|
| 1273 |
+
input_ids = self.get_text_dialogue_sft_multiturn_prompt(message_list, thinking=thinking)
|
| 1274 |
+
text = self.forward(input_ids, stopping_criteria=stopping_criteria, task_name="text_chat", add_history=False)
|
| 1275 |
+
return text
|
| 1276 |
+
|
| 1277 |
+
def clear_history(self):
|
| 1278 |
+
self.history = None
|
| 1279 |
+
print("History cleared")
|
| 1280 |
+
|
| 1281 |
+
def in_context_learning_s2s(self, instruction, prompt_examples, audio, max_new_tokens=None, output_audio_path=None):
|
| 1282 |
+
stopping_criteria = [
|
| 1283 |
+
MiMoStopper(
|
| 1284 |
+
stop_tokens=[self.tokenizer.eos_token_id, self.eostm_idx],
|
| 1285 |
+
group_size=self.group_size,
|
| 1286 |
+
audio_channels=self.audio_channels,
|
| 1287 |
+
)
|
| 1288 |
+
]
|
| 1289 |
+
input_ids = self.get_in_context_learning_s2s_prompt(instruction, prompt_examples, audio)
|
| 1290 |
+
self.forward(input_ids, output_audio_path=output_audio_path, stopping_criteria=stopping_criteria, max_new_tokens=max_new_tokens, task_name="in_context_learning_s2s")
|
| 1291 |
+
|
| 1292 |
+
|
src/mimo_audio/modeling_mimo_audio.py
ADDED
|
@@ -0,0 +1,835 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Xiaomi Corporation.
|
| 2 |
+
import copy
|
| 3 |
+
import logging
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import List, Optional, Union, cast
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.distributed as dist
|
| 9 |
+
from torch import nn
|
| 10 |
+
from transformers import StoppingCriteria
|
| 11 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 12 |
+
from transformers.generation.streamers import BaseStreamer
|
| 13 |
+
from transformers.generation.utils import (
|
| 14 |
+
GenerateOutput,
|
| 15 |
+
GenerationConfig,
|
| 16 |
+
StoppingCriteriaList,
|
| 17 |
+
is_deepspeed_zero3_enabled,
|
| 18 |
+
)
|
| 19 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
| 20 |
+
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
|
| 21 |
+
from transformers.models.qwen2.modeling_qwen2 import (
|
| 22 |
+
Qwen2Model,
|
| 23 |
+
Qwen2PreTrainedModel,
|
| 24 |
+
)
|
| 25 |
+
from transformers.utils import is_torchdynamo_compiling
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class MiMoStopper(StoppingCriteria):
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
group_size: int,
|
| 35 |
+
audio_channels: int,
|
| 36 |
+
stop_tokens: list[int] | None = None,
|
| 37 |
+
max_length: int | None = None,
|
| 38 |
+
min_length: int | None = None,
|
| 39 |
+
) -> None:
|
| 40 |
+
super().__init__()
|
| 41 |
+
self.group_size = group_size
|
| 42 |
+
self.audio_channels = audio_channels
|
| 43 |
+
self.step = (audio_channels + 1) * group_size
|
| 44 |
+
|
| 45 |
+
self.stop_token_ids = set(stop_tokens or [])
|
| 46 |
+
|
| 47 |
+
self.max_length = max_length
|
| 48 |
+
self.min_length = min_length or 0
|
| 49 |
+
|
| 50 |
+
def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor):
|
| 51 |
+
is_done = False
|
| 52 |
+
cur_len = input_ids.shape[-1] // self.step
|
| 53 |
+
|
| 54 |
+
if self.max_length:
|
| 55 |
+
is_done |= cur_len >= self.max_length
|
| 56 |
+
|
| 57 |
+
if (self.stop_token_ids and
|
| 58 |
+
input_ids.shape[1] >= self.step and
|
| 59 |
+
cur_len >= self.min_length):
|
| 60 |
+
last_token = input_ids[0, -self.step].item()
|
| 61 |
+
is_done |= last_token in self.stop_token_ids
|
| 62 |
+
|
| 63 |
+
return torch.full(
|
| 64 |
+
(input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@dataclass
|
| 69 |
+
class MiMoSampler:
|
| 70 |
+
do_sample: bool | None = None
|
| 71 |
+
temperature: float | None = None
|
| 72 |
+
top_k: int | None = None
|
| 73 |
+
top_p: float | None = None
|
| 74 |
+
|
| 75 |
+
def process(self, scores: torch.Tensor):
|
| 76 |
+
if self.temperature is not None:
|
| 77 |
+
scores = scores / self.temperature
|
| 78 |
+
|
| 79 |
+
if self.top_k is not None and self.top_k > 0:
|
| 80 |
+
top_k = min(self.top_k, scores.shape[-1])
|
| 81 |
+
indices_to_remove = scores < torch.topk(scores, top_k)[0][:, -1]
|
| 82 |
+
scores = scores.masked_fill(indices_to_remove, float("-inf"))
|
| 83 |
+
|
| 84 |
+
if self.top_p is not None and 0.0 < self.top_p <= 1.0:
|
| 85 |
+
top_p = self.top_p if 0.0 < self.top_p <= 1.0 else 1.0
|
| 86 |
+
sorted_logits, sorted_indices = torch.sort(scores)
|
| 87 |
+
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
| 88 |
+
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
|
| 89 |
+
sorted_indices_to_remove[:, -1] = 0
|
| 90 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
| 91 |
+
1, sorted_indices, sorted_indices_to_remove
|
| 92 |
+
)
|
| 93 |
+
scores = scores.masked_fill(indices_to_remove, float("-inf"))
|
| 94 |
+
|
| 95 |
+
return scores
|
| 96 |
+
|
| 97 |
+
def sample(self, scores: torch.Tensor, removed_tokens: list[int] | None = None):
|
| 98 |
+
scores = self.process(scores)
|
| 99 |
+
for t in removed_tokens or []:
|
| 100 |
+
scores[:, t] = float("-inf")
|
| 101 |
+
|
| 102 |
+
if self.do_sample:
|
| 103 |
+
probs = scores.softmax(dim=-1)
|
| 104 |
+
return torch.multinomial(probs, num_samples=1).squeeze(-1)
|
| 105 |
+
|
| 106 |
+
return torch.argmax(scores, dim=-1)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@dataclass
|
| 110 |
+
class MiMoAudioOutput(ModelOutput):
|
| 111 |
+
text_logits: torch.FloatTensor | None = None
|
| 112 |
+
local_hidden_states: torch.FloatTensor | None = None
|
| 113 |
+
past_key_values: Cache | None = None
|
| 114 |
+
"""Downcast hidden states for local transformer generation"""
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@dataclass
|
| 118 |
+
class MiMoAudioConfig(Qwen2Config):
|
| 119 |
+
def __init__(
|
| 120 |
+
self,
|
| 121 |
+
*,
|
| 122 |
+
speech_vocab_size: str | int = "1025-1025-129-129-129-129-129-129",
|
| 123 |
+
speech_zeroemb_idx: str | int = "1024-1024-128-128-128-128-128-128",
|
| 124 |
+
delay_pattern: str = "0-1-2-3-4-5-6-7",
|
| 125 |
+
head_dim: int = 128,
|
| 126 |
+
group_size: int = 4,
|
| 127 |
+
audio_channels: int = 8,
|
| 128 |
+
local_dim: int = 1024,
|
| 129 |
+
local_layers: int = 16,
|
| 130 |
+
local_attn_heads: int = 64,
|
| 131 |
+
local_ffn_dim: int = 4096,
|
| 132 |
+
local_attn_dropout: float = 0.1,
|
| 133 |
+
input_local_layers: int = 6,
|
| 134 |
+
input_local_dim: int | None = None,
|
| 135 |
+
input_full_attention: bool | None = None,
|
| 136 |
+
**kwargs,
|
| 137 |
+
):
|
| 138 |
+
super().__init__(
|
| 139 |
+
**kwargs,
|
| 140 |
+
)
|
| 141 |
+
self.speech_vocab_size = speech_vocab_size
|
| 142 |
+
self.speech_zeroemb_idx = speech_zeroemb_idx
|
| 143 |
+
self.delay_pattern = delay_pattern
|
| 144 |
+
|
| 145 |
+
self.head_dim = head_dim
|
| 146 |
+
|
| 147 |
+
self.group_size = group_size
|
| 148 |
+
self.audio_channels = audio_channels
|
| 149 |
+
|
| 150 |
+
self.local_dim = local_dim
|
| 151 |
+
self.local_layers = local_layers
|
| 152 |
+
self.local_attn_heads = local_attn_heads
|
| 153 |
+
self.local_ffn_dim = local_ffn_dim
|
| 154 |
+
self.local_attn_dropout = local_attn_dropout
|
| 155 |
+
|
| 156 |
+
self.input_local_layers = input_local_layers
|
| 157 |
+
self.input_local_dim = input_local_dim or local_dim
|
| 158 |
+
|
| 159 |
+
self.input_full_attention = input_full_attention
|
| 160 |
+
|
| 161 |
+
def _parse_maybe_list(self, value: str | int, length: int) -> List[int]:
|
| 162 |
+
if isinstance(value, str) and "-" in value:
|
| 163 |
+
return [int(s) for s in value.split("-")]
|
| 164 |
+
return [int(value)] * length
|
| 165 |
+
|
| 166 |
+
def parsed_speech_empty_ids(self):
|
| 167 |
+
return self._parse_maybe_list(self.speech_zeroemb_idx, self.audio_channels)
|
| 168 |
+
|
| 169 |
+
def parsed_speech_vocab_sizes(self):
|
| 170 |
+
return self._parse_maybe_list(self.speech_vocab_size, self.audio_channels)
|
| 171 |
+
|
| 172 |
+
def parsed_delay_pattern(self):
|
| 173 |
+
return self._parse_maybe_list(self.delay_pattern, self.audio_channels)
|
| 174 |
+
|
| 175 |
+
def local_config(self):
|
| 176 |
+
config = copy.deepcopy(self)
|
| 177 |
+
|
| 178 |
+
config.hidden_size = self.local_dim
|
| 179 |
+
config.num_hidden_layers = self.local_layers
|
| 180 |
+
config.num_attention_heads = self.local_attn_heads
|
| 181 |
+
config.num_key_value_heads = self.local_attn_heads
|
| 182 |
+
config.head_dim = config.hidden_size // self.local_attn_heads
|
| 183 |
+
config.intermediate_size = self.local_ffn_dim
|
| 184 |
+
config.attention_dropout = self.local_attn_dropout
|
| 185 |
+
|
| 186 |
+
return config
|
| 187 |
+
|
| 188 |
+
def input_local_config(self):
|
| 189 |
+
config = copy.deepcopy(self)
|
| 190 |
+
|
| 191 |
+
config.hidden_size = self.input_local_dim
|
| 192 |
+
config.num_hidden_layers = self.input_local_layers
|
| 193 |
+
config.num_attention_heads = self.local_attn_heads
|
| 194 |
+
config.num_key_value_heads = self.local_attn_heads
|
| 195 |
+
config.head_dim = config.hidden_size // self.local_attn_heads
|
| 196 |
+
config.intermediate_size = config.hidden_size * 4
|
| 197 |
+
config.attention_dropout = self.local_attn_dropout
|
| 198 |
+
|
| 199 |
+
return config
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
@dataclass
|
| 203 |
+
class MiMoAudioArguments:
|
| 204 |
+
model_name_or_path: str
|
| 205 |
+
sosp_idx: int
|
| 206 |
+
eosp_idx: int
|
| 207 |
+
sostm_idx: int
|
| 208 |
+
eostm_idx: int
|
| 209 |
+
eot_idx: int
|
| 210 |
+
empty_idx: int
|
| 211 |
+
|
| 212 |
+
def to_dict(self):
|
| 213 |
+
return {
|
| 214 |
+
"model_name_or_path": self.model_name_or_path,
|
| 215 |
+
"sosp_idx": self.sosp_idx,
|
| 216 |
+
"eosp_idx": self.eosp_idx,
|
| 217 |
+
"sostm_idx": self.sostm_idx,
|
| 218 |
+
"eostm_idx": self.eostm_idx,
|
| 219 |
+
"eot_idx": self.eot_idx,
|
| 220 |
+
"empty_idx": self.empty_idx,
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class MiMoAudioForCausalLM(Qwen2PreTrainedModel):
|
| 225 |
+
def __init__(
|
| 226 |
+
self,
|
| 227 |
+
config: MiMoAudioConfig | Qwen2Config,
|
| 228 |
+
args: MiMoAudioArguments | dict,
|
| 229 |
+
):
|
| 230 |
+
super().__init__(config)
|
| 231 |
+
config = (
|
| 232 |
+
MiMoAudioConfig(**vars(config))
|
| 233 |
+
if isinstance(config, Qwen2Config)
|
| 234 |
+
else config
|
| 235 |
+
)
|
| 236 |
+
args = MiMoAudioArguments(**args) if isinstance(args, dict) else args
|
| 237 |
+
self.config = config
|
| 238 |
+
self.args = args
|
| 239 |
+
|
| 240 |
+
self.model = Qwen2Model(config)
|
| 241 |
+
|
| 242 |
+
self.speech_vocab_sizes = config.parsed_speech_vocab_sizes()
|
| 243 |
+
self.speech_empty_ids = config.parsed_speech_empty_ids()
|
| 244 |
+
self.delay_pattern = config.parsed_delay_pattern()
|
| 245 |
+
|
| 246 |
+
self.group_size = config.group_size
|
| 247 |
+
self.audio_channels = config.audio_channels
|
| 248 |
+
|
| 249 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 250 |
+
|
| 251 |
+
# Construct local transformer
|
| 252 |
+
self.local_config = config.local_config()
|
| 253 |
+
self.local_transformer = Qwen2Model(self.local_config)
|
| 254 |
+
self.local_transformer.embed_tokens = None
|
| 255 |
+
|
| 256 |
+
# Add input local transformer if configured
|
| 257 |
+
self.input_local_config = config.input_local_config()
|
| 258 |
+
self.input_local_transformer = Qwen2Model(self.input_local_config)
|
| 259 |
+
self.input_local_transformer.embed_tokens = None
|
| 260 |
+
|
| 261 |
+
self.local_transformer_lm_heads = nn.ModuleList(
|
| 262 |
+
[
|
| 263 |
+
nn.Linear(
|
| 264 |
+
self.local_config.hidden_size,
|
| 265 |
+
self.speech_vocab_sizes[i],
|
| 266 |
+
bias=False,
|
| 267 |
+
)
|
| 268 |
+
for i in range(self.audio_channels)
|
| 269 |
+
]
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
self.speech_embeddings = nn.ModuleList(
|
| 273 |
+
[
|
| 274 |
+
nn.Embedding(
|
| 275 |
+
self.speech_vocab_sizes[i],
|
| 276 |
+
self.input_local_config.hidden_size,
|
| 277 |
+
padding_idx=self.speech_empty_ids[i],
|
| 278 |
+
)
|
| 279 |
+
for i in range(self.audio_channels)
|
| 280 |
+
]
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
if self.input_local_config.hidden_size != self.local_config.hidden_size:
|
| 284 |
+
self.speech_embeddings_to_local = nn.Linear(
|
| 285 |
+
self.input_local_config.hidden_size,
|
| 286 |
+
self.local_config.hidden_size,
|
| 287 |
+
bias=False,
|
| 288 |
+
)
|
| 289 |
+
else:
|
| 290 |
+
self.speech_embeddings_to_local = None
|
| 291 |
+
|
| 292 |
+
# Create speech_group_downcast_first for group_first_in_global_context
|
| 293 |
+
self.speech_group_downcast = nn.Linear(
|
| 294 |
+
self.input_local_config.hidden_size * config.group_size,
|
| 295 |
+
config.hidden_size,
|
| 296 |
+
bias=False,
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
self.hidden_states_downcast = nn.Linear(
|
| 300 |
+
config.hidden_size,
|
| 301 |
+
self.local_config.hidden_size,
|
| 302 |
+
bias=False,
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
# Initialize weights and apply final processing
|
| 306 |
+
self.post_init()
|
| 307 |
+
|
| 308 |
+
def apply_input_local_transformer(self, speech_embeddings: torch.Tensor):
|
| 309 |
+
B, T_groups, group_size, hidden_size = speech_embeddings.shape
|
| 310 |
+
|
| 311 |
+
# Process each group independently: [B*T//group_size, group_size, hidden_size]
|
| 312 |
+
input_embeddings = speech_embeddings.reshape(
|
| 313 |
+
B * T_groups, group_size, hidden_size
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
output: BaseModelOutputWithPast = self.input_local_transformer(
|
| 317 |
+
inputs_embeds=input_embeddings,
|
| 318 |
+
return_dict=True,
|
| 319 |
+
is_causal=not self.config.input_full_attention, # for SDPA
|
| 320 |
+
)
|
| 321 |
+
encoded_embeddings = output.last_hidden_state
|
| 322 |
+
|
| 323 |
+
# Reshape back to original format
|
| 324 |
+
# [B*T//group_size, group_size, hidden_size] -> [B, T//group_size, group_size, hidden_size]
|
| 325 |
+
encoded_embeddings = encoded_embeddings.reshape(
|
| 326 |
+
B, T_groups, group_size, hidden_size
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
return encoded_embeddings
|
| 330 |
+
|
| 331 |
+
def _prepare_input_embeds(
|
| 332 |
+
self,
|
| 333 |
+
input_ids: torch.LongTensor, # [B, audio_channels + 1, new_T]
|
| 334 |
+
):
|
| 335 |
+
B = input_ids.shape[0]
|
| 336 |
+
|
| 337 |
+
input_ids = input_ids.int()
|
| 338 |
+
group_size = self.config.group_size
|
| 339 |
+
|
| 340 |
+
text_input_ids = input_ids[:, 0, ::group_size]
|
| 341 |
+
speech_input_ids = (
|
| 342 |
+
input_ids[:, 1:, :]
|
| 343 |
+
.view(B, self.audio_channels, -1, group_size)
|
| 344 |
+
.transpose(1, 2)
|
| 345 |
+
) # [B, T//group_size, audio_channels, group_size]
|
| 346 |
+
|
| 347 |
+
is_speech = text_input_ids == self.args.empty_idx # [B, T//group_size]
|
| 348 |
+
|
| 349 |
+
speech_embeds = torch.zeros(
|
| 350 |
+
(
|
| 351 |
+
B,
|
| 352 |
+
is_speech.shape[1],
|
| 353 |
+
group_size,
|
| 354 |
+
self.input_local_config.hidden_size,
|
| 355 |
+
),
|
| 356 |
+
device=input_ids.device,
|
| 357 |
+
dtype=torch.bfloat16,
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
for idx in range(self.audio_channels):
|
| 361 |
+
cur_empty = self.speech_empty_ids[idx]
|
| 362 |
+
cur_embed = self.speech_embeddings[idx]
|
| 363 |
+
cur_speech_ids = speech_input_ids[:, :, idx, :]
|
| 364 |
+
cur_speech_embeds: torch.Tensor = cur_embed(cur_speech_ids)
|
| 365 |
+
# [B, T_groups, group_size, hidden_size]
|
| 366 |
+
|
| 367 |
+
cur_mask = cur_speech_ids == cur_empty
|
| 368 |
+
cur_speech_embeds.masked_fill_(cur_mask.unsqueeze(-1), 0.0)
|
| 369 |
+
|
| 370 |
+
speech_embeds += cur_speech_embeds
|
| 371 |
+
|
| 372 |
+
speech_embeds = speech_embeds * is_speech.unsqueeze(-1).unsqueeze(-1)
|
| 373 |
+
|
| 374 |
+
# Apply input local transformer if configured
|
| 375 |
+
speech_embeds = self.apply_input_local_transformer(speech_embeds)
|
| 376 |
+
speech_embeds = speech_embeds * is_speech.unsqueeze(-1).unsqueeze(-1)
|
| 377 |
+
|
| 378 |
+
T_groups = speech_embeds.shape[1]
|
| 379 |
+
speech_grouped_embeds: torch.Tensor = self.speech_group_downcast(
|
| 380 |
+
speech_embeds.view(B, T_groups, -1)
|
| 381 |
+
) # [B, T_groups, hidden_size]
|
| 382 |
+
|
| 383 |
+
text_embeds: torch.Tensor = self.model.embed_tokens(text_input_ids)
|
| 384 |
+
text_zero_mask = text_input_ids == self.args.empty_idx
|
| 385 |
+
text_embeds.masked_fill_(text_zero_mask.unsqueeze(-1), 0.0)
|
| 386 |
+
|
| 387 |
+
return text_embeds + speech_grouped_embeds
|
| 388 |
+
|
| 389 |
+
def forward(
|
| 390 |
+
self,
|
| 391 |
+
input_ids: torch.LongTensor, # [B, audio_channels + 1, new_T]
|
| 392 |
+
attention_mask: torch.Tensor, # [B, T_group]
|
| 393 |
+
position_ids: torch.LongTensor, # [B, new_T_group]
|
| 394 |
+
past_key_values: Cache | None = None,
|
| 395 |
+
cache_position: torch.LongTensor | None = None, # [new_T_group]
|
| 396 |
+
**_kwargs,
|
| 397 |
+
):
|
| 398 |
+
inputs_embeds = self._prepare_input_embeds(input_ids)
|
| 399 |
+
|
| 400 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 401 |
+
attention_mask=attention_mask,
|
| 402 |
+
position_ids=position_ids,
|
| 403 |
+
past_key_values=past_key_values,
|
| 404 |
+
inputs_embeds=inputs_embeds,
|
| 405 |
+
use_cache=True,
|
| 406 |
+
return_dict=True,
|
| 407 |
+
cache_position=cache_position,
|
| 408 |
+
)
|
| 409 |
+
hidden_states = outputs.last_hidden_state # [B, new_T_group, hidden_size]
|
| 410 |
+
|
| 411 |
+
text_logits: torch.Tensor = self.lm_head(
|
| 412 |
+
hidden_states[:, -1:, :]
|
| 413 |
+
) # [B, 1, vocab_size]
|
| 414 |
+
shift_hidden_states: torch.Tensor = self.hidden_states_downcast(
|
| 415 |
+
hidden_states[:, -1:, :]
|
| 416 |
+
) # [B, 1, hidden_size]
|
| 417 |
+
|
| 418 |
+
return MiMoAudioOutput(
|
| 419 |
+
text_logits=text_logits,
|
| 420 |
+
local_hidden_states=shift_hidden_states,
|
| 421 |
+
past_key_values=outputs.past_key_values,
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
def local_forward(
|
| 425 |
+
self,
|
| 426 |
+
local_embeds: torch.FloatTensor, # [B, 1, hidden_size]
|
| 427 |
+
tokens_dtype: torch.dtype,
|
| 428 |
+
tokens_device: torch.device,
|
| 429 |
+
local_sampler: MiMoSampler | None = None,
|
| 430 |
+
):
|
| 431 |
+
B = local_embeds.shape[0]
|
| 432 |
+
delay_iters = self.group_size + max(self.delay_pattern)
|
| 433 |
+
past_key_values = DynamicCache()
|
| 434 |
+
local_tokens = torch.zeros(
|
| 435 |
+
(B, self.group_size, self.audio_channels),
|
| 436 |
+
dtype=tokens_dtype,
|
| 437 |
+
device=tokens_device,
|
| 438 |
+
)
|
| 439 |
+
if local_sampler is None:
|
| 440 |
+
local_sampler = MiMoSampler()
|
| 441 |
+
|
| 442 |
+
for t in range(delay_iters):
|
| 443 |
+
output: BaseModelOutputWithPast = self.local_transformer(
|
| 444 |
+
inputs_embeds=local_embeds,
|
| 445 |
+
past_key_values=past_key_values,
|
| 446 |
+
return_dict=True,
|
| 447 |
+
use_cache=True,
|
| 448 |
+
)
|
| 449 |
+
hidden_state = output.last_hidden_state
|
| 450 |
+
past_key_values = output.past_key_values
|
| 451 |
+
|
| 452 |
+
local_embeds = torch.zeros_like(local_embeds)
|
| 453 |
+
for idx in range(self.audio_channels):
|
| 454 |
+
cur_start = self.delay_pattern[idx]
|
| 455 |
+
cur_end = cur_start + self.group_size
|
| 456 |
+
cur_empty = self.speech_empty_ids[idx]
|
| 457 |
+
if cur_start <= t < cur_end:
|
| 458 |
+
cur_lm_head = self.local_transformer_lm_heads[idx]
|
| 459 |
+
cur_scores: torch.Tensor = cur_lm_head(hidden_state)[:, -1, :]
|
| 460 |
+
# [B, vocab_size]
|
| 461 |
+
cur_token = local_sampler.sample(
|
| 462 |
+
cur_scores,
|
| 463 |
+
[cur_empty],
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
local_tokens[:, t - cur_start, idx] = cur_token
|
| 467 |
+
cur_input_embed = self.speech_embeddings[idx](
|
| 468 |
+
cur_token.unsqueeze(1)
|
| 469 |
+
)
|
| 470 |
+
if self.speech_embeddings_to_local is not None:
|
| 471 |
+
cur_input_embed = self.speech_embeddings_to_local(
|
| 472 |
+
cur_input_embed
|
| 473 |
+
)
|
| 474 |
+
local_embeds += cur_input_embed
|
| 475 |
+
|
| 476 |
+
return local_tokens # [B, group_size, audio_channels]
|
| 477 |
+
|
| 478 |
+
def _prepare_attention_mask(
|
| 479 |
+
self, inputs: torch.Tensor, input_ids_length: int
|
| 480 |
+
) -> torch.Tensor:
|
| 481 |
+
# No information for attention mask inference -> return default attention mask
|
| 482 |
+
return torch.ones(
|
| 483 |
+
(inputs.shape[0], input_ids_length),
|
| 484 |
+
dtype=torch.bool,
|
| 485 |
+
device=inputs.device,
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
def prepare_inputs_for_generation(
|
| 489 |
+
self,
|
| 490 |
+
input_ids: torch.LongTensor,
|
| 491 |
+
past_key_values: Optional[Cache] = None,
|
| 492 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 493 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 494 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 495 |
+
**kwargs,
|
| 496 |
+
):
|
| 497 |
+
"""
|
| 498 |
+
Prepare the model inputs for generation. In includes operations like computing the 4D attention mask or
|
| 499 |
+
slicing inputs given the existing cache.
|
| 500 |
+
|
| 501 |
+
See the forward pass in the model documentation for expected arguments (different models might have different
|
| 502 |
+
requirements for e.g. `past_key_values`). This function should work as is for most LLMs.
|
| 503 |
+
"""
|
| 504 |
+
|
| 505 |
+
# 1. Handle BC:
|
| 506 |
+
model_inputs = {}
|
| 507 |
+
input_ids = input_ids.reshape(
|
| 508 |
+
input_ids.shape[0], -1, (self.audio_channels + 1) * self.config.group_size
|
| 509 |
+
).transpose(1, 2) # [B, audio_channels*group_size, T]
|
| 510 |
+
# - some models don't have `Cache` support (which implies they don't expect `cache_position` in `forward`)
|
| 511 |
+
if self._supports_cache_class:
|
| 512 |
+
model_inputs["cache_position"] = cache_position
|
| 513 |
+
# - `cache_position` was not a mandatory input in `prepare_inputs_for_generation` for those models, and this
|
| 514 |
+
# function may be called outside of `generate`. Handle most use cases by creating `cache_position` on the fly
|
| 515 |
+
# (this alternative is not as robust as calling `generate` and letting it create `cache_position`)
|
| 516 |
+
elif cache_position is None:
|
| 517 |
+
past_length = (
|
| 518 |
+
past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
| 519 |
+
)
|
| 520 |
+
cache_position = torch.arange(
|
| 521 |
+
past_length,
|
| 522 |
+
input_ids.shape[2],
|
| 523 |
+
dtype=torch.long,
|
| 524 |
+
device=input_ids.device,
|
| 525 |
+
)
|
| 526 |
+
|
| 527 |
+
# 2. Generic cache-dependent input preparation
|
| 528 |
+
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
| 529 |
+
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
| 530 |
+
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
| 531 |
+
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
|
| 532 |
+
if past_key_values is not None:
|
| 533 |
+
model_inputs["past_key_values"] = past_key_values
|
| 534 |
+
if (
|
| 535 |
+
inputs_embeds is not None or cache_position[-1] >= input_ids.shape[2]
|
| 536 |
+
): # Exception 1 or Exception 3
|
| 537 |
+
input_ids = input_ids[:, :, -cache_position.shape[0] :]
|
| 538 |
+
elif (
|
| 539 |
+
input_ids.shape[2] != cache_position.shape[0]
|
| 540 |
+
): # Default case (the "else", a no op, is Exception 2)
|
| 541 |
+
input_ids = input_ids[:, :, cache_position]
|
| 542 |
+
|
| 543 |
+
# 3. Prepare base model inputs
|
| 544 |
+
input_ids_key = (
|
| 545 |
+
"decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
| 546 |
+
)
|
| 547 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 548 |
+
if not self.config.is_encoder_decoder:
|
| 549 |
+
if inputs_embeds is not None and cache_position[0] == 0:
|
| 550 |
+
model_inputs[input_ids_key] = None
|
| 551 |
+
model_inputs["inputs_embeds"] = inputs_embeds
|
| 552 |
+
else:
|
| 553 |
+
# `clone` calls in this function ensure a consistent stride. See #32227
|
| 554 |
+
model_inputs[input_ids_key] = input_ids.clone(
|
| 555 |
+
memory_format=torch.contiguous_format
|
| 556 |
+
)
|
| 557 |
+
model_inputs["inputs_embeds"] = None
|
| 558 |
+
else:
|
| 559 |
+
model_inputs[input_ids_key] = input_ids.clone(
|
| 560 |
+
memory_format=torch.contiguous_format
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
# 4. Create missing `position_ids` on the fly
|
| 564 |
+
if attention_mask is not None and kwargs.get("position_ids") is None:
|
| 565 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 566 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 567 |
+
kwargs["position_ids"] = (
|
| 568 |
+
position_ids # placed in kwargs for further processing (see below)
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
# 5. Slice model inputs if it's an input that should have the same length as `input_ids`
|
| 572 |
+
for model_input_name in ["position_ids", "token_type_ids"]:
|
| 573 |
+
model_input: torch.Tensor = kwargs.get(model_input_name)
|
| 574 |
+
if model_input is not None:
|
| 575 |
+
if past_key_values:
|
| 576 |
+
model_input = model_input[:, -input_ids.shape[2] :]
|
| 577 |
+
model_input = model_input.clone(
|
| 578 |
+
memory_format=torch.contiguous_format
|
| 579 |
+
)
|
| 580 |
+
model_inputs[model_input_name] = model_input
|
| 581 |
+
|
| 582 |
+
if attention_mask is not None:
|
| 583 |
+
model_inputs["attention_mask"] = attention_mask
|
| 584 |
+
|
| 585 |
+
# 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
|
| 586 |
+
for key, value in kwargs.items():
|
| 587 |
+
if key not in model_inputs:
|
| 588 |
+
model_inputs[key] = value
|
| 589 |
+
|
| 590 |
+
if model_inputs[input_ids_key] is not None:
|
| 591 |
+
model_inputs[input_ids_key] = (
|
| 592 |
+
cast(torch.Tensor, model_inputs[input_ids_key])
|
| 593 |
+
.transpose(1, 2)
|
| 594 |
+
.reshape(input_ids.shape[0], -1, (self.audio_channels + 1))
|
| 595 |
+
.transpose(1, 2)
|
| 596 |
+
) # [B, audio_channels, T*group_size]
|
| 597 |
+
|
| 598 |
+
# 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
|
| 599 |
+
model_inputs.pop("labels", None)
|
| 600 |
+
return model_inputs
|
| 601 |
+
|
| 602 |
+
def _get_initial_cache_position(self, input_ids: torch.Tensor, model_kwargs: dict):
|
| 603 |
+
"""Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
|
| 604 |
+
# `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
|
| 605 |
+
if "inputs_embeds" in model_kwargs:
|
| 606 |
+
cache_position = (
|
| 607 |
+
torch.ones_like(
|
| 608 |
+
model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64
|
| 609 |
+
).cumsum(0)
|
| 610 |
+
- 1
|
| 611 |
+
)
|
| 612 |
+
else:
|
| 613 |
+
cache_position = (
|
| 614 |
+
torch.ones(
|
| 615 |
+
(
|
| 616 |
+
input_ids.shape[1]
|
| 617 |
+
// (self.audio_channels + 1)
|
| 618 |
+
// self.config.group_size,
|
| 619 |
+
),
|
| 620 |
+
dtype=torch.int64,
|
| 621 |
+
device=input_ids.device,
|
| 622 |
+
).cumsum(0)
|
| 623 |
+
- 1
|
| 624 |
+
)
|
| 625 |
+
|
| 626 |
+
past_length = 0
|
| 627 |
+
if model_kwargs.get("past_key_values") is not None:
|
| 628 |
+
cache = model_kwargs["past_key_values"]
|
| 629 |
+
past_length = 0
|
| 630 |
+
if not isinstance(cache, Cache):
|
| 631 |
+
past_length = cache[0][0].shape[2]
|
| 632 |
+
elif (
|
| 633 |
+
hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None
|
| 634 |
+
):
|
| 635 |
+
past_length = cache.get_seq_length()
|
| 636 |
+
|
| 637 |
+
# TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty,
|
| 638 |
+
# end-to-end compilation will yield bad results because `cache_position` will be incorrect.
|
| 639 |
+
if not is_torchdynamo_compiling():
|
| 640 |
+
cache_position = cache_position[past_length:]
|
| 641 |
+
|
| 642 |
+
model_kwargs["cache_position"] = cache_position
|
| 643 |
+
|
| 644 |
+
return model_kwargs
|
| 645 |
+
|
| 646 |
+
@torch.inference_mode()
|
| 647 |
+
def generate(
|
| 648 |
+
self,
|
| 649 |
+
inputs: torch.Tensor | None = None,
|
| 650 |
+
generation_config: GenerationConfig | None = None,
|
| 651 |
+
stopping_criteria: StoppingCriteriaList | list | None = None,
|
| 652 |
+
streamer: BaseStreamer | None = None,
|
| 653 |
+
synced_gpus: bool | None = None,
|
| 654 |
+
global_sampler: MiMoSampler | None = None,
|
| 655 |
+
local_sampler: MiMoSampler | None = None,
|
| 656 |
+
warmup_run: bool | None = None,
|
| 657 |
+
**kwargs,
|
| 658 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
| 659 |
+
generation_config, model_kwargs = self._prepare_generation_config(
|
| 660 |
+
generation_config, **kwargs
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
self._validate_model_kwargs(model_kwargs.copy())
|
| 664 |
+
|
| 665 |
+
# 2. Set generation parameters if not already defined
|
| 666 |
+
if synced_gpus is None:
|
| 667 |
+
if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
|
| 668 |
+
synced_gpus = True
|
| 669 |
+
else:
|
| 670 |
+
synced_gpus = False
|
| 671 |
+
|
| 672 |
+
# 3. Define model inputs
|
| 673 |
+
input_ids, _model_input_name, model_kwargs = self._prepare_model_inputs(
|
| 674 |
+
inputs, generation_config.bos_token_id, model_kwargs
|
| 675 |
+
)
|
| 676 |
+
input_ids_length = input_ids.shape[-1]
|
| 677 |
+
input_ids_length //= self.group_size * (self.audio_channels + 1)
|
| 678 |
+
|
| 679 |
+
if streamer is not None:
|
| 680 |
+
streamer.put(input_ids.cpu())
|
| 681 |
+
|
| 682 |
+
if "attention_mask" not in model_kwargs:
|
| 683 |
+
model_kwargs["attention_mask"] = self._prepare_attention_mask(
|
| 684 |
+
inputs, input_ids_length
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
device = input_ids.device
|
| 688 |
+
self._prepare_special_tokens(generation_config, True, device=device)
|
| 689 |
+
|
| 690 |
+
model_kwargs["use_cache"] = True
|
| 691 |
+
model_kwargs["past_key_values"] = DynamicCache()
|
| 692 |
+
|
| 693 |
+
prepared_stopping_criteria = StoppingCriteriaList(
|
| 694 |
+
stopping_criteria if stopping_criteria is not None else []
|
| 695 |
+
)
|
| 696 |
+
prepared_stopping_criteria.append(
|
| 697 |
+
MiMoStopper(
|
| 698 |
+
self.group_size,
|
| 699 |
+
self.audio_channels,
|
| 700 |
+
max_length=generation_config.max_length,
|
| 701 |
+
)
|
| 702 |
+
)
|
| 703 |
+
stance = "default" if warmup_run else "eager_on_recompile"
|
| 704 |
+
with torch.compiler.set_stance(stance):
|
| 705 |
+
return self.slm_sample(
|
| 706 |
+
input_ids,
|
| 707 |
+
stopping_criteria=prepared_stopping_criteria,
|
| 708 |
+
generation_config=generation_config,
|
| 709 |
+
synced_gpus=synced_gpus,
|
| 710 |
+
streamer=streamer,
|
| 711 |
+
global_sampler=global_sampler,
|
| 712 |
+
local_sampler=local_sampler,
|
| 713 |
+
**model_kwargs,
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
def slm_sample(
|
| 717 |
+
self,
|
| 718 |
+
input_ids: torch.LongTensor,
|
| 719 |
+
stopping_criteria: StoppingCriteriaList,
|
| 720 |
+
generation_config: GenerationConfig,
|
| 721 |
+
synced_gpus: bool,
|
| 722 |
+
streamer: BaseStreamer | None,
|
| 723 |
+
global_sampler: MiMoSampler | None = None,
|
| 724 |
+
local_sampler: MiMoSampler | None = None,
|
| 725 |
+
**model_kwargs,
|
| 726 |
+
) -> torch.LongTensor:
|
| 727 |
+
max_length = generation_config.max_length
|
| 728 |
+
|
| 729 |
+
B, cur_len = input_ids.shape
|
| 730 |
+
cur_len //= self.group_size * (self.audio_channels + 1)
|
| 731 |
+
initial_len = cur_len
|
| 732 |
+
this_peer_finished = False
|
| 733 |
+
unfinished_sequences = torch.ones(B, dtype=torch.long, device=input_ids.device)
|
| 734 |
+
|
| 735 |
+
min_length = 0
|
| 736 |
+
stop_token_ids = set()
|
| 737 |
+
for criterion in stopping_criteria:
|
| 738 |
+
if isinstance(criterion, MiMoStopper):
|
| 739 |
+
if criterion.min_length is not None:
|
| 740 |
+
min_length = max(min_length, criterion.min_length)
|
| 741 |
+
stop_token_ids.update(criterion.stop_token_ids)
|
| 742 |
+
|
| 743 |
+
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
|
| 744 |
+
|
| 745 |
+
while self._has_unfinished_sequences(
|
| 746 |
+
this_peer_finished,
|
| 747 |
+
synced_gpus,
|
| 748 |
+
device=input_ids.device,
|
| 749 |
+
cur_len=cur_len,
|
| 750 |
+
max_length=max_length,
|
| 751 |
+
):
|
| 752 |
+
# prepare model inputs
|
| 753 |
+
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
| 754 |
+
|
| 755 |
+
# forward pass to get next token
|
| 756 |
+
if (
|
| 757 |
+
cast(torch.Tensor, model_inputs["input_ids"]).shape[2]
|
| 758 |
+
!= self.group_size
|
| 759 |
+
):
|
| 760 |
+
# prefill run
|
| 761 |
+
with torch.compiler.set_stance("force_eager"):
|
| 762 |
+
outputs: MiMoAudioOutput = self(**model_inputs)
|
| 763 |
+
else:
|
| 764 |
+
outputs: MiMoAudioOutput = self(**model_inputs)
|
| 765 |
+
|
| 766 |
+
if synced_gpus and this_peer_finished:
|
| 767 |
+
continue # don't waste resources running the code we don't need
|
| 768 |
+
|
| 769 |
+
text_logits: torch.Tensor = outputs.text_logits[:, -1, :].clone()
|
| 770 |
+
# [B, vocab_size]
|
| 771 |
+
|
| 772 |
+
removed_tokens = None
|
| 773 |
+
if cur_len < min_length:
|
| 774 |
+
removed_tokens = list(stop_token_ids)
|
| 775 |
+
|
| 776 |
+
next_text_tokens = global_sampler.sample(text_logits, removed_tokens=removed_tokens)
|
| 777 |
+
# [B]
|
| 778 |
+
|
| 779 |
+
local_hidden_states = outputs.local_hidden_states
|
| 780 |
+
|
| 781 |
+
# Only Supports batch_size=1 here
|
| 782 |
+
if next_text_tokens[0] != self.args.empty_idx:
|
| 783 |
+
zero_embed_tensor = torch.tensor(
|
| 784 |
+
self.speech_empty_ids,
|
| 785 |
+
device=next_text_tokens.device,
|
| 786 |
+
dtype=input_ids.dtype,
|
| 787 |
+
)
|
| 788 |
+
next_speech_tokens = zero_embed_tensor.view(
|
| 789 |
+
1, 1, self.audio_channels
|
| 790 |
+
).expand(B, self.config.group_size, -1)
|
| 791 |
+
else:
|
| 792 |
+
next_speech_tokens = self.local_forward(
|
| 793 |
+
local_embeds=local_hidden_states,
|
| 794 |
+
tokens_dtype=next_text_tokens.dtype,
|
| 795 |
+
tokens_device=next_text_tokens.device,
|
| 796 |
+
local_sampler=local_sampler,
|
| 797 |
+
)
|
| 798 |
+
|
| 799 |
+
next_text_tokens = next_text_tokens.reshape(B, 1, 1).expand(
|
| 800 |
+
-1, self.group_size, -1
|
| 801 |
+
) # [B, group_size, 1]
|
| 802 |
+
|
| 803 |
+
# generate speech tokens
|
| 804 |
+
next_tokens = torch.cat(
|
| 805 |
+
(next_text_tokens, next_speech_tokens), dim=-1
|
| 806 |
+
).reshape(B, -1) # [B, group_size * (audio_channels + 1)]
|
| 807 |
+
|
| 808 |
+
input_ids = torch.cat(
|
| 809 |
+
[input_ids, next_tokens], dim=-1
|
| 810 |
+
) # [B, T*group_size*vq]
|
| 811 |
+
|
| 812 |
+
if streamer is not None:
|
| 813 |
+
streamer.put(next_tokens.cpu())
|
| 814 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
| 815 |
+
outputs,
|
| 816 |
+
model_kwargs,
|
| 817 |
+
is_encoder_decoder=self.config.is_encoder_decoder,
|
| 818 |
+
)
|
| 819 |
+
|
| 820 |
+
unfinished_sequences = unfinished_sequences & ~stopping_criteria(
|
| 821 |
+
input_ids, None
|
| 822 |
+
)
|
| 823 |
+
this_peer_finished = unfinished_sequences.max() == 0
|
| 824 |
+
cur_len += 1
|
| 825 |
+
|
| 826 |
+
# This is needed to properly delete outputs.logits which may be very large for first iteration
|
| 827 |
+
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
|
| 828 |
+
del outputs
|
| 829 |
+
|
| 830 |
+
if streamer is not None:
|
| 831 |
+
streamer.end()
|
| 832 |
+
|
| 833 |
+
input_ids = input_ids[:B]
|
| 834 |
+
|
| 835 |
+
return input_ids
|
src/mimo_audio/process_speechdata.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2025 Xiaomi Corporation.
|
| 3 |
+
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from typing import Tuple, Union, List
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class InputSegment:
|
| 11 |
+
|
| 12 |
+
def __init__(
|
| 13 |
+
self,
|
| 14 |
+
text: str = "",
|
| 15 |
+
audio: torch.Tensor = None,
|
| 16 |
+
tokenized_text: torch.Tensor = None,
|
| 17 |
+
speech_zeroemb_idx: Union[int, List[int]] = 1024,
|
| 18 |
+
text_zeroemb_idx: int = 152067,
|
| 19 |
+
add_sosp_eosp=True,
|
| 20 |
+
) -> None:
|
| 21 |
+
has_text = text is not None
|
| 22 |
+
has_tokenized_text = tokenized_text is not None
|
| 23 |
+
assert has_text or has_tokenized_text, "Text or tokenized text must be provided"
|
| 24 |
+
|
| 25 |
+
self.audio = audio
|
| 26 |
+
self.text = text
|
| 27 |
+
self.tokenized_text = tokenized_text
|
| 28 |
+
self.speech_zeroemb_idx = speech_zeroemb_idx
|
| 29 |
+
self.text_zeroemb_idx = text_zeroemb_idx
|
| 30 |
+
self.add_sosp_eosp = add_sosp_eosp
|
| 31 |
+
|
| 32 |
+
@staticmethod
|
| 33 |
+
def insert_between(tensor, i, value=-1):
|
| 34 |
+
return torch.scatter(
|
| 35 |
+
torch.full(
|
| 36 |
+
(1, tensor.shape[1] + (tensor.shape[1] - 1) * i + i),
|
| 37 |
+
value,
|
| 38 |
+
dtype=tensor.dtype,
|
| 39 |
+
),
|
| 40 |
+
1,
|
| 41 |
+
torch.arange(0, tensor.shape[1], dtype=torch.int64)[None] * (i + 1),
|
| 42 |
+
tensor,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
def to_input_id(
|
| 46 |
+
self,
|
| 47 |
+
tokenizer,
|
| 48 |
+
group_size: int,
|
| 49 |
+
audio_channels: int = 8,
|
| 50 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 51 |
+
if self.audio is None:
|
| 52 |
+
if self.tokenized_text is None:
|
| 53 |
+
tokenized_text = tokenizer(
|
| 54 |
+
self.text,
|
| 55 |
+
return_tensors="pt",
|
| 56 |
+
truncation=True,
|
| 57 |
+
max_length=999999,
|
| 58 |
+
padding=False,
|
| 59 |
+
add_special_tokens=False,
|
| 60 |
+
)["input_ids"].int()
|
| 61 |
+
else:
|
| 62 |
+
tokenized_text = self.tokenized_text.unsqueeze(0)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
if group_size > 1:
|
| 66 |
+
tokenized_text = self.insert_between(
|
| 67 |
+
tokenized_text, group_size - 1, value=-100
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
if isinstance(self.speech_zeroemb_idx, list):
|
| 72 |
+
audio_part_input_id = torch.zeros((audio_channels, tokenized_text.shape[1]), dtype=torch.int)
|
| 73 |
+
for i, idx in enumerate(self.speech_zeroemb_idx):
|
| 74 |
+
audio_part_input_id[i, :] = idx
|
| 75 |
+
else:
|
| 76 |
+
audio_part_input_id = torch.full(
|
| 77 |
+
(audio_channels, tokenized_text.shape[1]), self.speech_zeroemb_idx, dtype=torch.int
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
else:
|
| 82 |
+
sosp_token = (
|
| 83 |
+
tokenizer.convert_tokens_to_ids("<|sosp|>")
|
| 84 |
+
if self.add_sosp_eosp
|
| 85 |
+
else None
|
| 86 |
+
)
|
| 87 |
+
eosp_token = (
|
| 88 |
+
tokenizer.convert_tokens_to_ids("<|eosp|>")
|
| 89 |
+
if self.add_sosp_eosp
|
| 90 |
+
else None
|
| 91 |
+
)
|
| 92 |
+
audio_part = self.audio.reshape(-1, audio_channels).T # [audio_channels, seqlen]
|
| 93 |
+
|
| 94 |
+
assert (
|
| 95 |
+
audio_part.shape[1] % group_size == 0
|
| 96 |
+
), f"Audio shape {audio_part.shape} is not divisible by group_size {group_size}"
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
text_len = audio_part.shape[1] // group_size
|
| 100 |
+
empty_token = self.text_zeroemb_idx
|
| 101 |
+
if empty_token is None:
|
| 102 |
+
empty_token = tokenizer.eod
|
| 103 |
+
tokenized_text = torch.full((1, text_len), empty_token, dtype=torch.int)
|
| 104 |
+
|
| 105 |
+
tokenized_text = (
|
| 106 |
+
torch.cat(
|
| 107 |
+
[
|
| 108 |
+
torch.tensor([[sosp_token]], dtype=torch.int),
|
| 109 |
+
tokenized_text,
|
| 110 |
+
torch.tensor([[eosp_token]], dtype=torch.int),
|
| 111 |
+
],
|
| 112 |
+
dim=1,
|
| 113 |
+
)
|
| 114 |
+
if self.add_sosp_eosp
|
| 115 |
+
else tokenized_text
|
| 116 |
+
)
|
| 117 |
+
tokenized_text = self.insert_between(
|
| 118 |
+
tokenized_text, group_size - 1, value=-100
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
if self.add_sosp_eosp:
|
| 123 |
+
if isinstance(self.speech_zeroemb_idx, list):
|
| 124 |
+
sosp_part = torch.zeros((audio_channels, group_size), dtype=torch.int)
|
| 125 |
+
eosp_part = torch.zeros((audio_channels, group_size), dtype=torch.int)
|
| 126 |
+
for i, idx in enumerate(self.speech_zeroemb_idx):
|
| 127 |
+
sosp_part[i, :] = idx
|
| 128 |
+
eosp_part[i, :] = idx
|
| 129 |
+
audio_part_input_id = torch.cat([sosp_part, audio_part, eosp_part], dim=1)
|
| 130 |
+
else:
|
| 131 |
+
audio_part_input_id = torch.cat(
|
| 132 |
+
[
|
| 133 |
+
torch.full((audio_channels, group_size), self.speech_zeroemb_idx, dtype=torch.int),
|
| 134 |
+
audio_part,
|
| 135 |
+
torch.full((audio_channels, group_size), self.speech_zeroemb_idx, dtype=torch.int),
|
| 136 |
+
],
|
| 137 |
+
dim=1,
|
| 138 |
+
)
|
| 139 |
+
else:
|
| 140 |
+
audio_part_input_id = audio_part
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
input_ids = torch.cat(
|
| 145 |
+
[tokenized_text, audio_part_input_id], dim=0
|
| 146 |
+
) # [n_rvq + 1, seqlen]
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
return input_ids
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class StreamingInputSegment:
|
| 153 |
+
def __init__(
|
| 154 |
+
self,
|
| 155 |
+
text: str = "",
|
| 156 |
+
audio: torch.Tensor = None,
|
| 157 |
+
tokenized_text: torch.Tensor = None,
|
| 158 |
+
speech_zeroemb_idx: Union[int, List[int]] = 1024,
|
| 159 |
+
text_zeroemb_idx: int = 152067,
|
| 160 |
+
text_segment_size: int = 5,
|
| 161 |
+
audio_segment_size: int = 5,
|
| 162 |
+
tokenizer=None,
|
| 163 |
+
group_size=None,
|
| 164 |
+
audio_channels=None,
|
| 165 |
+
) -> None:
|
| 166 |
+
has_text = text is not None
|
| 167 |
+
has_tokenized_text = tokenized_text is not None
|
| 168 |
+
assert has_text or has_tokenized_text, "Text or tokenized text must be provided"
|
| 169 |
+
|
| 170 |
+
self.audio = audio
|
| 171 |
+
self.text = text
|
| 172 |
+
self.tokenized_text = tokenized_text
|
| 173 |
+
self.speech_zeroemb_idx = speech_zeroemb_idx
|
| 174 |
+
self.text_zeroemb_idx = text_zeroemb_idx
|
| 175 |
+
self.text_segment_size = text_segment_size
|
| 176 |
+
self.audio_segment_size = audio_segment_size
|
| 177 |
+
self.tokenizer = tokenizer
|
| 178 |
+
self.group_size = group_size
|
| 179 |
+
self.audio_channels = audio_channels
|
| 180 |
+
|
| 181 |
+
def to_input_id(
|
| 182 |
+
self,
|
| 183 |
+
tokenizer,
|
| 184 |
+
group_size: int,
|
| 185 |
+
audio_channels: int = 8,
|
| 186 |
+
):
|
| 187 |
+
if self.tokenized_text is None:
|
| 188 |
+
tokenized_text = tokenizer(
|
| 189 |
+
self.text,
|
| 190 |
+
return_tensors="pt",
|
| 191 |
+
truncation=True,
|
| 192 |
+
max_length=999999,
|
| 193 |
+
padding=False,
|
| 194 |
+
add_special_tokens=False,
|
| 195 |
+
)["input_ids"].int() # [1, seqlen]
|
| 196 |
+
else:
|
| 197 |
+
tokenized_text = self.tokenized_text.unsqueeze(0)
|
| 198 |
+
|
| 199 |
+
tokenized_text = tokenized_text.squeeze(0)
|
| 200 |
+
|
| 201 |
+
text_segments = tokenized_text.split(self.text_segment_size, dim=0)
|
| 202 |
+
audio_segments = self.audio.split(self.audio_segment_size*group_size*audio_channels, dim=0)
|
| 203 |
+
|
| 204 |
+
tokenized_segments = []
|
| 205 |
+
tokenized_segments.append(
|
| 206 |
+
InputSegment(
|
| 207 |
+
text='<|sostm|>',
|
| 208 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 209 |
+
text_zeroemb_idx=self.text_zeroemb_idx,
|
| 210 |
+
),
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
eot_tokens = tokenizer(
|
| 215 |
+
"<|eot|>",
|
| 216 |
+
return_tensors="pt",
|
| 217 |
+
truncation=True,
|
| 218 |
+
max_length=999999,
|
| 219 |
+
padding=False,
|
| 220 |
+
add_special_tokens=False,
|
| 221 |
+
)["input_ids"][0].to(text_segments[-1])
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
text_segments = text_segments[:-1] + (torch.cat([text_segments[-1], eot_tokens], dim=0),)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
length = min(len(text_segments), len(audio_segments))
|
| 228 |
+
for i in range(length):
|
| 229 |
+
text_segment = text_segments[i]
|
| 230 |
+
audio_segment = audio_segments[i]
|
| 231 |
+
|
| 232 |
+
tokenized_segments.append(
|
| 233 |
+
InputSegment(
|
| 234 |
+
tokenized_text=text_segment,
|
| 235 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 236 |
+
text_zeroemb_idx=self.text_zeroemb_idx,
|
| 237 |
+
),
|
| 238 |
+
)
|
| 239 |
+
tokenized_segments.append(
|
| 240 |
+
InputSegment(
|
| 241 |
+
audio=audio_segment,
|
| 242 |
+
add_sosp_eosp=False,
|
| 243 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 244 |
+
text_zeroemb_idx=self.text_zeroemb_idx,
|
| 245 |
+
),
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
for j in range(length, len(text_segments)):
|
| 249 |
+
tokenized_segments.append(
|
| 250 |
+
InputSegment(
|
| 251 |
+
tokenized_text=text_segments[j],
|
| 252 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 253 |
+
text_zeroemb_idx=self.text_zeroemb_idx,
|
| 254 |
+
),
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
for j in range(length, len(audio_segments)):
|
| 258 |
+
tokenized_segments.append(
|
| 259 |
+
InputSegment(
|
| 260 |
+
audio=audio_segments[j],
|
| 261 |
+
add_sosp_eosp=False,
|
| 262 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 263 |
+
text_zeroemb_idx=self.text_zeroemb_idx,
|
| 264 |
+
),
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
tokenized_segments.append(
|
| 268 |
+
InputSegment(
|
| 269 |
+
text="<|eostm|>",
|
| 270 |
+
speech_zeroemb_idx=self.speech_zeroemb_idx,
|
| 271 |
+
text_zeroemb_idx=self.text_zeroemb_idx,
|
| 272 |
+
),
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
input_ids = [
|
| 277 |
+
seg.to_input_id(
|
| 278 |
+
self.tokenizer,
|
| 279 |
+
self.group_size,
|
| 280 |
+
self.audio_channels,
|
| 281 |
+
)
|
| 282 |
+
for seg in tokenized_segments
|
| 283 |
+
]
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
input_ids = torch.cat(input_ids, dim=1).type(torch.int64) # [n_rvq + 1, seqlen]
|
| 288 |
+
|
| 289 |
+
return input_ids
|
src/mimo_audio/templates.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Xiaomi Corporation.
|
| 2 |
+
asr_zh_templates = [
|
| 3 |
+
"请将这段语音转换为文字",
|
| 4 |
+
"帮我识别这个音频文件中的内容",
|
| 5 |
+
"把这段录音转成文本",
|
| 6 |
+
"请转录这段语音",
|
| 7 |
+
"将音频内容转换成文字格式",
|
| 8 |
+
"识别并转写这段语音",
|
| 9 |
+
"把语音内容写成文字",
|
| 10 |
+
"转录这个音频片段",
|
| 11 |
+
"将这段对话转换为文本",
|
| 12 |
+
"麻烦帮我把这段录音整理成详细的文字记录",
|
| 13 |
+
]
|
| 14 |
+
|
| 15 |
+
asr_en_templates = [
|
| 16 |
+
"Please transcribe this audio file",
|
| 17 |
+
"Convert this speech recording to text",
|
| 18 |
+
"Transcribe the following voice message",
|
| 19 |
+
"Turn this audio into readable text",
|
| 20 |
+
"Please convert the recording to written format",
|
| 21 |
+
"Transcribe what you hear in this audio",
|
| 22 |
+
"Convert this spoken content to text",
|
| 23 |
+
"Please write down what is said in this recording",
|
| 24 |
+
"Transcribe this voice recording",
|
| 25 |
+
"Could you please help me transcribe this important recording?",
|
| 26 |
+
"Would you mind converting this voice message into a readable text format?",
|
| 27 |
+
"I'd really appreciate it if you could turn this audio file into a written document",
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
tts_zh_templates = [
|
| 31 |
+
"请将这段文字转换为语音",
|
| 32 |
+
"帮我把这个文本读出来",
|
| 33 |
+
"将这些文字生成音频",
|
| 34 |
+
"请朗读这段内容",
|
| 35 |
+
"把这段话转换成语音文件",
|
| 36 |
+
"生成这段文字的语音版本",
|
| 37 |
+
"请用语音播报这些内容",
|
| 38 |
+
"将文本转换为可听的音频",
|
| 39 |
+
"帮我朗读这段文字",
|
| 40 |
+
"把这些内容念出来",
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
tts_en_templates = [
|
| 44 |
+
"Please convert this text to speech",
|
| 45 |
+
"Turn this writing into audio",
|
| 46 |
+
"Generate speech from this text",
|
| 47 |
+
"Read this content out loud",
|
| 48 |
+
"Convert these words to voice",
|
| 49 |
+
"Create an audio version of this text",
|
| 50 |
+
"Please vocalize this content",
|
| 51 |
+
"Turn this text into audible format",
|
| 52 |
+
"Help me convert this writing to speech",
|
| 53 |
+
"Make this text into spoken audio",
|
| 54 |
+
]
|
src/mimo_audio_tokenizer/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Xiaomi Corporation.
|
| 2 |
+
from .modeling_audio_tokenizer import MiMoAudioTokenizer, StreamingConfig, StreamingCache
|
| 3 |
+
from .configuration_audio_tokenizer import MiMoAudioTokenizerConfig
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
__all__ = ['MiMoAudioTokenizer', 'StreamingConfig', 'StreamingCache', 'MiMoAudioTokenizerConfig']
|
src/mimo_audio_tokenizer/configuration_audio_tokenizer.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Xiaomi Corporation.
|
| 2 |
+
from transformers import PretrainedConfig
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class MiMoAudioTokenizerConfig(PretrainedConfig):
|
| 6 |
+
model_type = "mimo_audio_tokenizer"
|
| 7 |
+
|
| 8 |
+
def __init__(
|
| 9 |
+
self,
|
| 10 |
+
max_audio_seconds: int = 1800,
|
| 11 |
+
stride_size: int = 2,
|
| 12 |
+
avg_pooler: int = 1,
|
| 13 |
+
d_model: int = 768,
|
| 14 |
+
scale_embedding: bool = True,
|
| 15 |
+
kernel_size: int = 3,
|
| 16 |
+
activation_function: str = "gelu",
|
| 17 |
+
encoder_layers: int = 8,
|
| 18 |
+
encoder_skip_layer_id: int = None,
|
| 19 |
+
encoder_attention_heads: int = 12,
|
| 20 |
+
encoder_ffn_dim: int = 3072,
|
| 21 |
+
encoder_causal: bool = False,
|
| 22 |
+
encoder_attn_window_size: list[int] = None,
|
| 23 |
+
decoder_layers: int = 8,
|
| 24 |
+
decoder_attention_heads: int = 12,
|
| 25 |
+
decoder_ffn_dim: int = 3072,
|
| 26 |
+
decoder_kernel_size: int = 3,
|
| 27 |
+
decoder_stride_size: int = 2,
|
| 28 |
+
decoder_causal: bool = True,
|
| 29 |
+
decoder_attn_window_size: list[int] = None,
|
| 30 |
+
nfft: int = 1024,
|
| 31 |
+
vocoder_dim: int = 512,
|
| 32 |
+
vocoder_intermediate_dim: int = 4096,
|
| 33 |
+
vocoder_num_layers: int = 30,
|
| 34 |
+
n_mels: int = 80,
|
| 35 |
+
sampling_rate: int = 24000,
|
| 36 |
+
hop_length: int = 240,
|
| 37 |
+
window_size: int = 1024,
|
| 38 |
+
vocoder_padding: str = "same",
|
| 39 |
+
fmin: int = 0,
|
| 40 |
+
fmax: int = None,
|
| 41 |
+
num_quantizers: int = 12,
|
| 42 |
+
codebook_size: list[int] = None,
|
| 43 |
+
threshold_ema_dead_code: int = 10,
|
| 44 |
+
position_embedding_type: str = "rope",
|
| 45 |
+
rope_theta: int = 10000,
|
| 46 |
+
rope_type: str = "default",
|
| 47 |
+
ln_type: str = "LayerNorm",
|
| 48 |
+
vocoder_attention_heads: int = 4,
|
| 49 |
+
vocoder_attn_window_size: list[int] = None,
|
| 50 |
+
**kwargs,
|
| 51 |
+
):
|
| 52 |
+
super().__init__(**kwargs)
|
| 53 |
+
self.max_audio_seconds = max_audio_seconds
|
| 54 |
+
self.stride_size = stride_size
|
| 55 |
+
self.avg_pooler = avg_pooler
|
| 56 |
+
self.d_model = d_model
|
| 57 |
+
self.scale_embedding = scale_embedding
|
| 58 |
+
self.kernel_size = kernel_size
|
| 59 |
+
self.activation_function = activation_function
|
| 60 |
+
self.encoder_layers = encoder_layers
|
| 61 |
+
self.encoder_skip_layer_id = encoder_skip_layer_id
|
| 62 |
+
self.encoder_attention_heads = encoder_attention_heads
|
| 63 |
+
self.encoder_ffn_dim = encoder_ffn_dim
|
| 64 |
+
self.encoder_causal = encoder_causal
|
| 65 |
+
self.encoder_attn_window_size = (
|
| 66 |
+
encoder_attn_window_size
|
| 67 |
+
if encoder_attn_window_size is not None
|
| 68 |
+
else [-1, -1]
|
| 69 |
+
)
|
| 70 |
+
self.decoder_layers = decoder_layers
|
| 71 |
+
self.decoder_attention_heads = decoder_attention_heads
|
| 72 |
+
self.decoder_ffn_dim = decoder_ffn_dim
|
| 73 |
+
self.decoder_kernel_size = decoder_kernel_size
|
| 74 |
+
self.decoder_stride_size = decoder_stride_size
|
| 75 |
+
self.decoder_causal = decoder_causal
|
| 76 |
+
self.decoder_attn_window_size = (
|
| 77 |
+
decoder_attn_window_size
|
| 78 |
+
if decoder_attn_window_size is not None
|
| 79 |
+
else [-1, -1]
|
| 80 |
+
)
|
| 81 |
+
self.nfft = nfft
|
| 82 |
+
self.vocoder_dim = vocoder_dim
|
| 83 |
+
self.vocoder_intermediate_dim = vocoder_intermediate_dim
|
| 84 |
+
self.vocoder_num_layers = vocoder_num_layers
|
| 85 |
+
self.n_mels = n_mels
|
| 86 |
+
self.sampling_rate = sampling_rate
|
| 87 |
+
self.hop_length = hop_length
|
| 88 |
+
self.window_size = window_size
|
| 89 |
+
self.vocoder_padding = vocoder_padding
|
| 90 |
+
self.fmin = fmin
|
| 91 |
+
self.fmax = fmax
|
| 92 |
+
self.num_quantizers = num_quantizers
|
| 93 |
+
self.codebook_size = codebook_size if codebook_size is not None else [1024]
|
| 94 |
+
self.threshold_ema_dead_code = threshold_ema_dead_code
|
| 95 |
+
self.position_embedding_type = position_embedding_type
|
| 96 |
+
self.rope_theta = rope_theta
|
| 97 |
+
self.rope_type = rope_type
|
| 98 |
+
self.ln_type = ln_type
|
| 99 |
+
self.vocoder_attention_heads = vocoder_attention_heads
|
| 100 |
+
self.vocoder_attn_window_size = (
|
| 101 |
+
vocoder_attn_window_size
|
| 102 |
+
if vocoder_attn_window_size is not None
|
| 103 |
+
else [40, 10]
|
| 104 |
+
)
|
src/mimo_audio_tokenizer/modeling_audio_tokenizer.py
ADDED
|
@@ -0,0 +1,857 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Xiaomi Corporation.
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from flash_attn import flash_attn_varlen_func
|
| 8 |
+
from torch.nn import functional as F
|
| 9 |
+
from transformers.activations import ACT2FN
|
| 10 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 11 |
+
|
| 12 |
+
from .configuration_audio_tokenizer import MiMoAudioTokenizerConfig
|
| 13 |
+
from .modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update, apply_rotary_pos_emb
|
| 14 |
+
from .quantization import ResidualVectorQuantizer
|
| 15 |
+
from dataclasses import dataclass, field
|
| 16 |
+
from typing import List
|
| 17 |
+
|
| 18 |
+
def get_sequence_mask(inputs, inputs_length):
|
| 19 |
+
if inputs.dim() == 3:
|
| 20 |
+
bsz, tgt_len, _ = inputs.size()
|
| 21 |
+
else:
|
| 22 |
+
bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
|
| 23 |
+
sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
|
| 24 |
+
sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(
|
| 25 |
+
bsz, tgt_len, 1
|
| 26 |
+
)
|
| 27 |
+
unpacking_index = torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1
|
| 28 |
+
return sequence_mask, unpacking_index
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def unpack_hidden_states(
|
| 32 |
+
hidden_states, lengths, sequence_mask=None, unpacking_index=None
|
| 33 |
+
):
|
| 34 |
+
bsz = lengths.shape[0]
|
| 35 |
+
if sequence_mask is None or unpacking_index is None:
|
| 36 |
+
sequence_mask, unpacking_index = get_sequence_mask(hidden_states, lengths)
|
| 37 |
+
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
|
| 38 |
+
bsz, torch.max(lengths), hidden_states.shape[-1]
|
| 39 |
+
)
|
| 40 |
+
hidden_states = torch.where(sequence_mask, hidden_states, 0)
|
| 41 |
+
return hidden_states
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def get_position_ids(lengths):
|
| 45 |
+
total_len = lengths.sum()
|
| 46 |
+
offset = torch.cat([torch.zeros(1).to(lengths), lengths[:-1].cumsum(dim=0)])
|
| 47 |
+
offset = torch.repeat_interleave(offset, lengths)
|
| 48 |
+
position_ids = torch.arange(0, total_len).to(offset) - offset
|
| 49 |
+
return position_ids
|
| 50 |
+
|
| 51 |
+
@dataclass
|
| 52 |
+
class StreamingConfig:
|
| 53 |
+
seg_point: int = field(default=60 * 25)
|
| 54 |
+
process_seg_point: bool = field(default=True)
|
| 55 |
+
left_overlap: int = field(default=10 * 25)
|
| 56 |
+
right_overlap: int = field(default=40)
|
| 57 |
+
seg_point_left_overlap: int = field(default=0)
|
| 58 |
+
|
| 59 |
+
@dataclass
|
| 60 |
+
class StreamingCache:
|
| 61 |
+
hidden_states: List[torch.Tensor] = field(default=None)
|
| 62 |
+
processed_lengths: List[int] = field(default=None)
|
| 63 |
+
|
| 64 |
+
class ISTFT(nn.Module):
|
| 65 |
+
"""
|
| 66 |
+
Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
|
| 67 |
+
windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
|
| 68 |
+
See issue: https://github.com/pytorch/pytorch/issues/62323
|
| 69 |
+
Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
|
| 70 |
+
The NOLA constraint is met as we trim padded samples anyway.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
n_fft (int): Size of Fourier transform.
|
| 74 |
+
hop_length (int): The distance between neighboring sliding window frames.
|
| 75 |
+
win_length (int): The size of window frame and STFT filter.
|
| 76 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
def __init__(
|
| 80 |
+
self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
|
| 81 |
+
):
|
| 82 |
+
super().__init__()
|
| 83 |
+
if padding not in ["center", "same"]:
|
| 84 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
| 85 |
+
self.padding = padding
|
| 86 |
+
self.n_fft = n_fft
|
| 87 |
+
self.hop_length = hop_length
|
| 88 |
+
self.win_length = win_length
|
| 89 |
+
window = torch.hann_window(win_length)
|
| 90 |
+
self.register_buffer("window", window)
|
| 91 |
+
|
| 92 |
+
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
| 93 |
+
"""
|
| 94 |
+
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
|
| 98 |
+
N is the number of frequency bins, and T is the number of time frames.
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
|
| 102 |
+
"""
|
| 103 |
+
if self.padding == "center":
|
| 104 |
+
# Fallback to pytorch native implementation
|
| 105 |
+
return torch.istft(
|
| 106 |
+
spec,
|
| 107 |
+
self.n_fft,
|
| 108 |
+
self.hop_length,
|
| 109 |
+
self.win_length,
|
| 110 |
+
self.window,
|
| 111 |
+
center=True,
|
| 112 |
+
)
|
| 113 |
+
elif self.padding == "same":
|
| 114 |
+
pad = (self.win_length - self.hop_length) // 2
|
| 115 |
+
else:
|
| 116 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
| 117 |
+
|
| 118 |
+
assert spec.dim() == 3, "Expected a 3D tensor as input"
|
| 119 |
+
B, N, T = spec.shape
|
| 120 |
+
|
| 121 |
+
# Inverse FFT
|
| 122 |
+
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
|
| 123 |
+
ifft = ifft * self.window[None, :, None]
|
| 124 |
+
|
| 125 |
+
# Overlap and Add
|
| 126 |
+
output_size = (T - 1) * self.hop_length + self.win_length
|
| 127 |
+
y = torch.nn.functional.fold(
|
| 128 |
+
ifft,
|
| 129 |
+
output_size=(1, output_size),
|
| 130 |
+
kernel_size=(1, self.win_length),
|
| 131 |
+
stride=(1, self.hop_length),
|
| 132 |
+
)[:, 0, 0, pad:-pad]
|
| 133 |
+
|
| 134 |
+
# Window envelope
|
| 135 |
+
window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
|
| 136 |
+
window_envelope = torch.nn.functional.fold(
|
| 137 |
+
window_sq,
|
| 138 |
+
output_size=(1, output_size),
|
| 139 |
+
kernel_size=(1, self.win_length),
|
| 140 |
+
stride=(1, self.hop_length),
|
| 141 |
+
).squeeze()[pad:-pad]
|
| 142 |
+
|
| 143 |
+
# Normalize
|
| 144 |
+
assert (window_envelope > 1e-11).all()
|
| 145 |
+
y = y / window_envelope
|
| 146 |
+
|
| 147 |
+
return y
|
| 148 |
+
|
| 149 |
+
class ISTFTHead(nn.Module):
|
| 150 |
+
"""
|
| 151 |
+
ISTFT Head module for predicting STFT complex coefficients.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
dim (int): Hidden dimension of the model.
|
| 155 |
+
n_fft (int): Size of Fourier transform.
|
| 156 |
+
hop_length (int): The distance between neighboring sliding window frames, which should align with
|
| 157 |
+
the resolution of the input features.
|
| 158 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
| 159 |
+
"""
|
| 160 |
+
|
| 161 |
+
def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
|
| 162 |
+
super().__init__()
|
| 163 |
+
out_dim = n_fft + 2
|
| 164 |
+
self.out = torch.nn.Linear(dim, out_dim)
|
| 165 |
+
self.istft = ISTFT(
|
| 166 |
+
n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 170 |
+
"""
|
| 171 |
+
Forward pass of the ISTFTHead module.
|
| 172 |
+
|
| 173 |
+
Args:
|
| 174 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
| 175 |
+
L is the sequence length, and H denotes the model dimension.
|
| 176 |
+
|
| 177 |
+
Returns:
|
| 178 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
| 179 |
+
"""
|
| 180 |
+
x = self.out(x).transpose(1, 2)
|
| 181 |
+
mag, p = x.chunk(2, dim=1)
|
| 182 |
+
mag = torch.exp(mag)
|
| 183 |
+
mag = torch.clip(
|
| 184 |
+
mag, max=1e2
|
| 185 |
+
) # safeguard to prevent excessively large magnitudes
|
| 186 |
+
# wrapping happens here. These two lines produce real and imaginary value
|
| 187 |
+
x = torch.cos(p)
|
| 188 |
+
y = torch.sin(p)
|
| 189 |
+
# recalculating phase here does not produce anything new
|
| 190 |
+
# only costs time
|
| 191 |
+
# phase = torch.atan2(y, x)
|
| 192 |
+
# S = mag * torch.exp(phase * 1j)
|
| 193 |
+
# better directly produce the complex value
|
| 194 |
+
original_dtype = x.dtype
|
| 195 |
+
S = mag.float() * (x.float() + 1j * y.float())
|
| 196 |
+
audio = self.istft(S)
|
| 197 |
+
audio = audio.to(original_dtype)
|
| 198 |
+
return audio
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class RotaryEmbedding(nn.Module):
|
| 202 |
+
def __init__(self, base, dim, max_seq_len, rope_type="default", device=None):
|
| 203 |
+
super().__init__()
|
| 204 |
+
self.max_seq_len = max_seq_len
|
| 205 |
+
self.rope_type = rope_type
|
| 206 |
+
|
| 207 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 208 |
+
|
| 209 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(
|
| 210 |
+
device=device, base=base, dim=dim
|
| 211 |
+
)
|
| 212 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 213 |
+
self.original_inv_freq = self.inv_freq
|
| 214 |
+
|
| 215 |
+
@torch.no_grad()
|
| 216 |
+
@dynamic_rope_update
|
| 217 |
+
def forward(self, x, position_ids):
|
| 218 |
+
inv_freq_expanded = self.inv_freq[:, None].float().expand(-1, 1).to(x.device)
|
| 219 |
+
position_ids_expanded = position_ids[None, :].float()
|
| 220 |
+
|
| 221 |
+
device_type = (
|
| 222 |
+
x.device.type
|
| 223 |
+
if isinstance(x.device.type, str) and x.device.type != "mps"
|
| 224 |
+
else "cpu"
|
| 225 |
+
)
|
| 226 |
+
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
| 227 |
+
freqs = (
|
| 228 |
+
inv_freq_expanded.float() @ position_ids_expanded.float()
|
| 229 |
+
).transpose(0, 1)
|
| 230 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 231 |
+
cos = emb.cos() * self.attention_scaling
|
| 232 |
+
sin = emb.sin() * self.attention_scaling
|
| 233 |
+
|
| 234 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 235 |
+
|
| 236 |
+
class RMSNorm(nn.Module):
|
| 237 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 238 |
+
"""
|
| 239 |
+
RMSNorm is equivalent to T5LayerNorm
|
| 240 |
+
"""
|
| 241 |
+
super().__init__()
|
| 242 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 243 |
+
self.variance_epsilon = eps
|
| 244 |
+
|
| 245 |
+
def forward(self, hidden_states):
|
| 246 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
| 247 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 248 |
+
|
| 249 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
| 250 |
+
hidden_states = hidden_states.to(self.weight.dtype)
|
| 251 |
+
|
| 252 |
+
return self.weight * hidden_states
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
LAYER_NORM = {"LayerNorm": nn.LayerNorm, "RMSNorm": RMSNorm}
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class Attention(nn.Module):
|
| 259 |
+
def __init__(self, embed_dim, num_heads, window_size=(-1, -1), causal=False):
|
| 260 |
+
super().__init__()
|
| 261 |
+
self.embed_dim = embed_dim
|
| 262 |
+
self.num_heads = num_heads
|
| 263 |
+
self.head_dim = embed_dim // num_heads
|
| 264 |
+
self.window_size = window_size
|
| 265 |
+
|
| 266 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
|
| 267 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 268 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 269 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 270 |
+
|
| 271 |
+
self.causal = causal
|
| 272 |
+
|
| 273 |
+
def forward(
|
| 274 |
+
self,
|
| 275 |
+
hidden_states: torch.Tensor,
|
| 276 |
+
seq_len: torch.Tensor,
|
| 277 |
+
rope_position_embeddings=None,
|
| 278 |
+
):
|
| 279 |
+
bsz, _ = hidden_states.size()
|
| 280 |
+
|
| 281 |
+
query_states = self.q_proj(hidden_states).view(
|
| 282 |
+
bsz, self.num_heads, self.head_dim
|
| 283 |
+
)
|
| 284 |
+
key_states = self.k_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
|
| 285 |
+
value_states = self.v_proj(hidden_states).view(
|
| 286 |
+
bsz, self.num_heads, self.head_dim
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
if rope_position_embeddings is not None:
|
| 290 |
+
cos, sin = rope_position_embeddings
|
| 291 |
+
query_states = apply_rotary_pos_emb(query_states, cos, sin)
|
| 292 |
+
key_states = apply_rotary_pos_emb(key_states, cos, sin)
|
| 293 |
+
|
| 294 |
+
cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(
|
| 295 |
+
torch.int32
|
| 296 |
+
)
|
| 297 |
+
max_seqlen = torch.max(seq_len).to(torch.int32).detach()
|
| 298 |
+
attn_output = flash_attn_varlen_func(
|
| 299 |
+
query_states,
|
| 300 |
+
key_states,
|
| 301 |
+
value_states,
|
| 302 |
+
cu_len,
|
| 303 |
+
cu_len,
|
| 304 |
+
max_seqlen,
|
| 305 |
+
max_seqlen,
|
| 306 |
+
causal=self.causal,
|
| 307 |
+
window_size=self.window_size,
|
| 308 |
+
)
|
| 309 |
+
attn_output = attn_output.reshape(bsz, self.embed_dim)
|
| 310 |
+
attn_output = self.out_proj(attn_output)
|
| 311 |
+
return attn_output
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class TransformerLayer(nn.Module):
|
| 315 |
+
def __init__(
|
| 316 |
+
self,
|
| 317 |
+
act,
|
| 318 |
+
d_model,
|
| 319 |
+
encoder_attention_heads,
|
| 320 |
+
encoder_ffn_dim,
|
| 321 |
+
causal,
|
| 322 |
+
ln_type="LayerNorm",
|
| 323 |
+
attn_window_size=(-1, -1),
|
| 324 |
+
):
|
| 325 |
+
super().__init__()
|
| 326 |
+
self.embed_dim = d_model
|
| 327 |
+
self.self_attn = Attention(
|
| 328 |
+
self.embed_dim, encoder_attention_heads, attn_window_size, causal
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
self.self_attn_layer_norm = LAYER_NORM[ln_type](self.embed_dim)
|
| 332 |
+
|
| 333 |
+
self.activation_fn = act
|
| 334 |
+
self.fc1 = nn.Linear(self.embed_dim, encoder_ffn_dim)
|
| 335 |
+
self.fc2 = nn.Linear(encoder_ffn_dim, self.embed_dim)
|
| 336 |
+
|
| 337 |
+
self.final_layer_norm = LAYER_NORM[ln_type](self.embed_dim)
|
| 338 |
+
|
| 339 |
+
def forward(
|
| 340 |
+
self,
|
| 341 |
+
hidden_states: torch.Tensor,
|
| 342 |
+
seq_len: torch.Tensor,
|
| 343 |
+
rope_position_embeddings: torch.Tensor,
|
| 344 |
+
) -> torch.Tensor:
|
| 345 |
+
residual = hidden_states
|
| 346 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
| 347 |
+
hidden_states = self.self_attn(
|
| 348 |
+
hidden_states, seq_len, rope_position_embeddings=rope_position_embeddings
|
| 349 |
+
)
|
| 350 |
+
hidden_states = residual + hidden_states
|
| 351 |
+
residual = hidden_states
|
| 352 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
| 353 |
+
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
| 354 |
+
hidden_states = self.fc2(hidden_states)
|
| 355 |
+
hidden_states = residual + hidden_states
|
| 356 |
+
|
| 357 |
+
if (
|
| 358 |
+
hidden_states.dtype == torch.float16
|
| 359 |
+
or hidden_states.dtype == torch.bfloat16
|
| 360 |
+
) and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
|
| 361 |
+
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
| 362 |
+
hidden_states = torch.clamp(
|
| 363 |
+
hidden_states, min=-clamp_value, max=clamp_value
|
| 364 |
+
)
|
| 365 |
+
return hidden_states
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
class TransformerVocos(nn.Module):
|
| 369 |
+
def __init__(self, config: MiMoAudioTokenizerConfig):
|
| 370 |
+
super().__init__()
|
| 371 |
+
self.config = config
|
| 372 |
+
self.max_source_positions = (
|
| 373 |
+
self.config.max_audio_seconds
|
| 374 |
+
* self.config.sampling_rate
|
| 375 |
+
// self.config.hop_length
|
| 376 |
+
)
|
| 377 |
+
self.embeddings = nn.Linear(config.n_mels, config.vocoder_dim, bias=False)
|
| 378 |
+
|
| 379 |
+
self.poisition_embedding = RotaryEmbedding(
|
| 380 |
+
config.rope_theta,
|
| 381 |
+
config.vocoder_dim // config.vocoder_attention_heads,
|
| 382 |
+
self.max_source_positions,
|
| 383 |
+
self.config.rope_type,
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
self.layers = nn.ModuleList(
|
| 387 |
+
[
|
| 388 |
+
TransformerLayer(
|
| 389 |
+
ACT2FN[self.config.activation_function],
|
| 390 |
+
self.config.vocoder_dim,
|
| 391 |
+
self.config.vocoder_attention_heads,
|
| 392 |
+
self.config.vocoder_intermediate_dim,
|
| 393 |
+
causal=False,
|
| 394 |
+
ln_type=self.config.ln_type,
|
| 395 |
+
attn_window_size=self.config.vocoder_attn_window_size,
|
| 396 |
+
)
|
| 397 |
+
for _ in range(self.config.vocoder_num_layers)
|
| 398 |
+
]
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
self.layer_norm = LAYER_NORM[self.config.ln_type](self.config.vocoder_dim)
|
| 402 |
+
self.hop_size = self.config.hop_length
|
| 403 |
+
self.head = ISTFTHead(
|
| 404 |
+
self.config.vocoder_dim,
|
| 405 |
+
self.config.nfft,
|
| 406 |
+
self.config.hop_length,
|
| 407 |
+
self.config.vocoder_padding,
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
def forward(self, x: torch.Tensor, input_length):
|
| 411 |
+
x = x.transpose(1, 2)
|
| 412 |
+
attention_mask, unpacking_index = get_sequence_mask(x, input_length)
|
| 413 |
+
x = torch.masked_select(x, attention_mask).view(
|
| 414 |
+
torch.sum(input_length), self.config.n_mels
|
| 415 |
+
)
|
| 416 |
+
x = self.embeddings(x)
|
| 417 |
+
position_ids = torch.arange(0, x.size(0), device=x.device, dtype=torch.long)
|
| 418 |
+
rope_position_embeddings = self.poisition_embedding(x, position_ids)
|
| 419 |
+
for idx, layer in enumerate(self.layers):
|
| 420 |
+
x = layer(
|
| 421 |
+
x, input_length, rope_position_embeddings=rope_position_embeddings
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
x = self.layer_norm(x)
|
| 425 |
+
x = unpack_hidden_states(x, input_length, attention_mask, unpacking_index)
|
| 426 |
+
x = self.head(x)
|
| 427 |
+
output_length = input_length * self.hop_size
|
| 428 |
+
return x[:, None, :], output_length
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
class AudioEncoder(nn.Module):
|
| 432 |
+
def __init__(self, config: MiMoAudioTokenizerConfig):
|
| 433 |
+
super().__init__()
|
| 434 |
+
config._attn_implementation = "flash_attention_2"
|
| 435 |
+
self.config = config
|
| 436 |
+
self.max_source_positions = (
|
| 437 |
+
config.max_audio_seconds * config.sampling_rate // config.hop_length
|
| 438 |
+
) // config.stride_size
|
| 439 |
+
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
| 440 |
+
|
| 441 |
+
self.skip_layer_idx = config.encoder_skip_layer_id
|
| 442 |
+
self.conv1 = nn.Conv1d(
|
| 443 |
+
config.n_mels, config.d_model, kernel_size=config.kernel_size, padding=1
|
| 444 |
+
)
|
| 445 |
+
self.conv2 = nn.Conv1d(
|
| 446 |
+
config.d_model,
|
| 447 |
+
config.d_model,
|
| 448 |
+
kernel_size=config.kernel_size,
|
| 449 |
+
stride=config.stride_size,
|
| 450 |
+
padding=1,
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
self.position_embedding = RotaryEmbedding(
|
| 454 |
+
config.rope_theta,
|
| 455 |
+
config.d_model // config.encoder_attention_heads,
|
| 456 |
+
self.max_source_positions,
|
| 457 |
+
config.rope_type,
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
self.layers = nn.ModuleList(
|
| 461 |
+
[
|
| 462 |
+
TransformerLayer(
|
| 463 |
+
ACT2FN[config.activation_function],
|
| 464 |
+
config.d_model,
|
| 465 |
+
config.encoder_attention_heads,
|
| 466 |
+
config.encoder_ffn_dim,
|
| 467 |
+
causal=self.config.encoder_causal,
|
| 468 |
+
ln_type=self.config.ln_type,
|
| 469 |
+
attn_window_size=self.config.encoder_attn_window_size,
|
| 470 |
+
)
|
| 471 |
+
for _ in range(config.encoder_layers)
|
| 472 |
+
]
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
self.layer_norm = LAYER_NORM[config.ln_type](config.d_model)
|
| 476 |
+
|
| 477 |
+
if self.config.avg_pooler != 1:
|
| 478 |
+
self.down_sample_layer = nn.Sequential(
|
| 479 |
+
nn.Conv1d(
|
| 480 |
+
config.d_model,
|
| 481 |
+
config.d_model,
|
| 482 |
+
config.avg_pooler,
|
| 483 |
+
config.avg_pooler,
|
| 484 |
+
bias=False,
|
| 485 |
+
),
|
| 486 |
+
nn.GELU(),
|
| 487 |
+
)
|
| 488 |
+
self.down_sample_norm = LAYER_NORM[config.ln_type](config.d_model)
|
| 489 |
+
else:
|
| 490 |
+
self.down_sample_layer = None
|
| 491 |
+
|
| 492 |
+
if self.config.num_quantizers != 0:
|
| 493 |
+
self.quantizer = ResidualVectorQuantizer(
|
| 494 |
+
dimension=self.config.d_model,
|
| 495 |
+
n_q=self.config.num_quantizers,
|
| 496 |
+
bins=self.config.codebook_size,
|
| 497 |
+
threshold_ema_dead_code=self.config.threshold_ema_dead_code,
|
| 498 |
+
)
|
| 499 |
+
else:
|
| 500 |
+
self.quantizer = None
|
| 501 |
+
|
| 502 |
+
def get_features(self, input_features, output_length):
|
| 503 |
+
input_features = input_features.to(self.conv1.weight)
|
| 504 |
+
inputs_embeds = nn.functional.gelu(self.conv1(input_features))
|
| 505 |
+
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
|
| 506 |
+
inputs_embeds = inputs_embeds.permute(0, 2, 1)
|
| 507 |
+
bsz, tgt_len, _ = inputs_embeds.size()
|
| 508 |
+
|
| 509 |
+
hidden_states = inputs_embeds
|
| 510 |
+
|
| 511 |
+
position_ids = (
|
| 512 |
+
get_position_ids(output_length).long().to(input_features.device)
|
| 513 |
+
)
|
| 514 |
+
rope_position_embeddings = self.position_embedding(
|
| 515 |
+
input_features, position_ids
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
attention_mask, unpacking_index = get_sequence_mask(
|
| 519 |
+
hidden_states, output_length
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
hidden_states = torch.masked_select(hidden_states, attention_mask).view(
|
| 523 |
+
torch.sum(output_length), self.config.d_model
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
skip_connect_hidden_states = 0.0
|
| 527 |
+
for idx, encoder_layer in enumerate(self.layers):
|
| 528 |
+
hidden_states = encoder_layer(
|
| 529 |
+
hidden_states,
|
| 530 |
+
output_length,
|
| 531 |
+
rope_position_embeddings=rope_position_embeddings,
|
| 532 |
+
)
|
| 533 |
+
if (self.skip_layer_idx is not None) and idx == self.skip_layer_idx - 1:
|
| 534 |
+
skip_connect_hidden_states = hidden_states.clone()
|
| 535 |
+
|
| 536 |
+
hidden_states += skip_connect_hidden_states
|
| 537 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 538 |
+
|
| 539 |
+
if self.down_sample_layer is not None:
|
| 540 |
+
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
|
| 541 |
+
bsz, tgt_len, self.config.d_model
|
| 542 |
+
)
|
| 543 |
+
if hidden_states.size(1) % self.config.avg_pooler:
|
| 544 |
+
pad_len = (
|
| 545 |
+
self.config.avg_pooler
|
| 546 |
+
- hidden_states.size(1) % self.config.avg_pooler
|
| 547 |
+
)
|
| 548 |
+
hidden_states = torch.nn.functional.pad(
|
| 549 |
+
hidden_states, (0, 0, 0, pad_len), mode="constant", value=0.0
|
| 550 |
+
)
|
| 551 |
+
tgt_len += pad_len
|
| 552 |
+
tgt_len = tgt_len // self.config.avg_pooler
|
| 553 |
+
hidden_states = self.down_sample_layer(hidden_states.transpose(1, 2))
|
| 554 |
+
output_length = (
|
| 555 |
+
output_length // self.config.avg_pooler
|
| 556 |
+
+ (output_length % self.config.avg_pooler != 0).int()
|
| 557 |
+
)
|
| 558 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 559 |
+
attention_mask, unpacking_index = get_sequence_mask(
|
| 560 |
+
hidden_states, output_length
|
| 561 |
+
)
|
| 562 |
+
hidden_states = torch.masked_select(hidden_states, attention_mask).view(
|
| 563 |
+
torch.sum(output_length), self.config.d_model
|
| 564 |
+
)
|
| 565 |
+
hidden_states = self.down_sample_norm(hidden_states)
|
| 566 |
+
|
| 567 |
+
return (
|
| 568 |
+
hidden_states,
|
| 569 |
+
output_length,
|
| 570 |
+
attention_mask,
|
| 571 |
+
unpacking_index,
|
| 572 |
+
tgt_len,
|
| 573 |
+
bsz,
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
def get_output_length(self, mel_len):
|
| 577 |
+
tgt_len = mel_len + 3 - self.config.kernel_size
|
| 578 |
+
return (tgt_len + 2 - self.config.kernel_size) // self.config.stride_size + 1
|
| 579 |
+
|
| 580 |
+
@torch.no_grad()
|
| 581 |
+
def encode(
|
| 582 |
+
self,
|
| 583 |
+
input_features,
|
| 584 |
+
input_lens=None,
|
| 585 |
+
output_length=None,
|
| 586 |
+
return_codes_only=False,
|
| 587 |
+
n_q=None,
|
| 588 |
+
use_quantizer=True,
|
| 589 |
+
):
|
| 590 |
+
if output_length is None:
|
| 591 |
+
output_length = self.get_output_length(input_lens)
|
| 592 |
+
input_features = unpack_hidden_states(input_features, input_lens)
|
| 593 |
+
hidden_states, output_length, attention_mask, unpacking_index, tgt_len, bsz = (
|
| 594 |
+
self.get_features(
|
| 595 |
+
input_features=input_features.transpose(1, 2),
|
| 596 |
+
output_length=output_length,
|
| 597 |
+
)
|
| 598 |
+
)
|
| 599 |
+
|
| 600 |
+
dtype = hidden_states.dtype
|
| 601 |
+
|
| 602 |
+
if use_quantizer and self.quantizer is not None:
|
| 603 |
+
self.quantizer.float()
|
| 604 |
+
|
| 605 |
+
codes = self.quantizer.encode(hidden_states.float(), n_q=n_q)
|
| 606 |
+
if return_codes_only:
|
| 607 |
+
return codes, output_length
|
| 608 |
+
hidden_states = self.quantizer.decode(codes)
|
| 609 |
+
hidden_states = hidden_states.to(dtype)
|
| 610 |
+
else:
|
| 611 |
+
codes = None
|
| 612 |
+
|
| 613 |
+
hidden_states_packed = hidden_states.clone()
|
| 614 |
+
|
| 615 |
+
# unpacking
|
| 616 |
+
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
|
| 617 |
+
bsz, tgt_len, self.config.d_model
|
| 618 |
+
)
|
| 619 |
+
hidden_states = torch.where(attention_mask, hidden_states, 0)
|
| 620 |
+
return hidden_states, hidden_states_packed, output_length, codes
|
| 621 |
+
|
| 622 |
+
@torch.no_grad()
|
| 623 |
+
def decode_vq(self, codes):
|
| 624 |
+
self.quantizer.float()
|
| 625 |
+
hidden_states = self.quantizer.decode(codes)
|
| 626 |
+
|
| 627 |
+
return hidden_states
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
class CausalConvTranspose1d(nn.Module):
|
| 631 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride):
|
| 632 |
+
super().__init__()
|
| 633 |
+
self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
|
| 634 |
+
self.norm = nn.GroupNorm(1, out_channels)
|
| 635 |
+
self.in_channels = in_channels
|
| 636 |
+
self.out_channels = out_channels
|
| 637 |
+
|
| 638 |
+
def forward(self, hidden_states, input_length, output_dim=None):
|
| 639 |
+
kernel_size = self.conv.kernel_size[0]
|
| 640 |
+
stride = self.conv.stride[0]
|
| 641 |
+
bsz = input_length.shape[0]
|
| 642 |
+
|
| 643 |
+
if output_dim is None:
|
| 644 |
+
output_dim = hidden_states.dim()
|
| 645 |
+
if hidden_states.dim() <= 2: # unpack sequence to 3d
|
| 646 |
+
sequence_mask, unpacking_index = get_sequence_mask(
|
| 647 |
+
hidden_states, input_length
|
| 648 |
+
)
|
| 649 |
+
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
|
| 650 |
+
bsz, torch.max(input_length), self.in_channels
|
| 651 |
+
)
|
| 652 |
+
hidden_states = torch.where(sequence_mask, hidden_states, 0)
|
| 653 |
+
|
| 654 |
+
hidden_states = hidden_states.transpose(2, 1) # (N, L, C) -> (N, C, L)
|
| 655 |
+
hidden_states = self.conv(hidden_states)
|
| 656 |
+
hidden_states = self.norm(hidden_states)
|
| 657 |
+
hidden_states = hidden_states.transpose(2, 1) # (N, C, L) -> (N, L, C)
|
| 658 |
+
|
| 659 |
+
casual_padding_right = max(0, kernel_size - stride)
|
| 660 |
+
hidden_states = hidden_states[
|
| 661 |
+
:, : hidden_states.shape[1] - casual_padding_right, :
|
| 662 |
+
]
|
| 663 |
+
output_length = (input_length - 1) * stride + kernel_size - casual_padding_right
|
| 664 |
+
sequence_mask, _ = get_sequence_mask(hidden_states, output_length)
|
| 665 |
+
if output_dim <= 2:
|
| 666 |
+
hidden_states = torch.masked_select(hidden_states, sequence_mask).view(
|
| 667 |
+
-1, self.out_channels
|
| 668 |
+
)
|
| 669 |
+
else:
|
| 670 |
+
hidden_states = torch.where(sequence_mask, hidden_states, 0)
|
| 671 |
+
hidden_states = hidden_states[:, : torch.max(output_length), :]
|
| 672 |
+
return hidden_states, output_length
|
| 673 |
+
|
| 674 |
+
|
| 675 |
+
class AudioDecoder(nn.Module):
|
| 676 |
+
def __init__(self, config: MiMoAudioTokenizerConfig):
|
| 677 |
+
super().__init__()
|
| 678 |
+
self.config = config
|
| 679 |
+
self.max_source_positions = (
|
| 680 |
+
self.config.max_audio_seconds
|
| 681 |
+
* self.config.sampling_rate
|
| 682 |
+
// self.config.hop_length
|
| 683 |
+
)
|
| 684 |
+
|
| 685 |
+
if self.config.avg_pooler != 1:
|
| 686 |
+
self.dconv1 = CausalConvTranspose1d(
|
| 687 |
+
self.config.d_model,
|
| 688 |
+
self.config.d_model,
|
| 689 |
+
self.config.avg_pooler,
|
| 690 |
+
self.config.avg_pooler,
|
| 691 |
+
)
|
| 692 |
+
else:
|
| 693 |
+
self.dconv1 = None
|
| 694 |
+
|
| 695 |
+
self.position_embedding = RotaryEmbedding(
|
| 696 |
+
config.rope_theta,
|
| 697 |
+
config.d_model // config.decoder_attention_heads,
|
| 698 |
+
self.max_source_positions,
|
| 699 |
+
config.rope_type,
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
self.layers = nn.ModuleList(
|
| 703 |
+
[
|
| 704 |
+
TransformerLayer(
|
| 705 |
+
ACT2FN[self.config.activation_function],
|
| 706 |
+
self.config.d_model,
|
| 707 |
+
self.config.decoder_attention_heads,
|
| 708 |
+
self.config.decoder_ffn_dim,
|
| 709 |
+
causal=self.config.decoder_causal,
|
| 710 |
+
ln_type=self.config.ln_type,
|
| 711 |
+
attn_window_size=self.config.decoder_attn_window_size,
|
| 712 |
+
)
|
| 713 |
+
for _ in range(self.config.decoder_layers)
|
| 714 |
+
]
|
| 715 |
+
)
|
| 716 |
+
self.layer_norm = LAYER_NORM[config.ln_type](self.config.d_model)
|
| 717 |
+
self.dconv2 = CausalConvTranspose1d(
|
| 718 |
+
self.config.d_model,
|
| 719 |
+
self.config.n_mels,
|
| 720 |
+
self.config.decoder_kernel_size,
|
| 721 |
+
self.config.decoder_stride_size,
|
| 722 |
+
)
|
| 723 |
+
self.vocoder = TransformerVocos(config)
|
| 724 |
+
|
| 725 |
+
def forward(
|
| 726 |
+
self,
|
| 727 |
+
audio_embed,
|
| 728 |
+
input_length,
|
| 729 |
+
):
|
| 730 |
+
assert audio_embed.shape[-1] == self.config.d_model
|
| 731 |
+
audio_embed = audio_embed.to(self.layer_norm.weight)
|
| 732 |
+
|
| 733 |
+
if self.dconv1 is not None:
|
| 734 |
+
audio_embed, output_length = self.dconv1(
|
| 735 |
+
audio_embed, input_length, output_dim=3
|
| 736 |
+
)
|
| 737 |
+
_, tgt_len, _ = audio_embed.size()
|
| 738 |
+
else:
|
| 739 |
+
output_length = input_length
|
| 740 |
+
tgt_len = audio_embed.size(0)
|
| 741 |
+
|
| 742 |
+
hidden_states = audio_embed
|
| 743 |
+
|
| 744 |
+
position_ids = (
|
| 745 |
+
get_position_ids(output_length).long().to(hidden_states.device)
|
| 746 |
+
)
|
| 747 |
+
rope_position_embeddings = self.position_embedding(
|
| 748 |
+
hidden_states, position_ids
|
| 749 |
+
)
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
# packing hidden states
|
| 753 |
+
attention_mask, _ = get_sequence_mask(hidden_states, output_length)
|
| 754 |
+
hidden_states = torch.masked_select(hidden_states, attention_mask).view(
|
| 755 |
+
torch.sum(output_length), self.config.d_model
|
| 756 |
+
)
|
| 757 |
+
|
| 758 |
+
for idx, encoder_layer in enumerate(self.layers):
|
| 759 |
+
hidden_states = encoder_layer(
|
| 760 |
+
hidden_states,
|
| 761 |
+
output_length,
|
| 762 |
+
rope_position_embeddings=rope_position_embeddings,
|
| 763 |
+
)
|
| 764 |
+
|
| 765 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 766 |
+
|
| 767 |
+
coarse_mel, output_length = self.dconv2(
|
| 768 |
+
hidden_states, output_length, output_dim=3
|
| 769 |
+
)
|
| 770 |
+
|
| 771 |
+
recon_wav, wav_length = self.vocoder(
|
| 772 |
+
x=coarse_mel.transpose(1, 2),
|
| 773 |
+
input_length=output_length,
|
| 774 |
+
)
|
| 775 |
+
|
| 776 |
+
return recon_wav
|
| 777 |
+
|
| 778 |
+
|
| 779 |
+
class MiMoAudioTokenizer(PreTrainedModel):
|
| 780 |
+
config_class = MiMoAudioTokenizerConfig
|
| 781 |
+
|
| 782 |
+
def __init__(self, config: MiMoAudioTokenizerConfig):
|
| 783 |
+
super().__init__(config)
|
| 784 |
+
self.config = config
|
| 785 |
+
self.sampling_rate = config.sampling_rate
|
| 786 |
+
self.encoder = AudioEncoder(config=config)
|
| 787 |
+
self.decoder = AudioDecoder(config=config)
|
| 788 |
+
self.downsample_rate = int(self.config.hop_length * 2 * self.config.avg_pooler)
|
| 789 |
+
|
| 790 |
+
def get_output_length(self, mel_len):
|
| 791 |
+
tgt_len = mel_len + 3 - self.config.kernel_size
|
| 792 |
+
return (tgt_len + 2 - self.config.kernel_size) // self.config.stride_size + 1
|
| 793 |
+
|
| 794 |
+
@torch.no_grad()
|
| 795 |
+
def encode(self, mels, input_lens, use_quantizer=True):
|
| 796 |
+
input_features = mels
|
| 797 |
+
encoder_output_length = self.get_output_length(input_lens)
|
| 798 |
+
hidden_states, hidden_states_packed, encoder_output_length, codes = (
|
| 799 |
+
self.encoder.encode(
|
| 800 |
+
input_features, input_lens=input_lens, use_quantizer=use_quantizer
|
| 801 |
+
)
|
| 802 |
+
)
|
| 803 |
+
return hidden_states, hidden_states_packed, encoder_output_length, codes
|
| 804 |
+
|
| 805 |
+
@torch.no_grad()
|
| 806 |
+
def decode(self, codes):
|
| 807 |
+
hidden_states = self.encoder.decode_vq(codes)
|
| 808 |
+
output = self.decoder(
|
| 809 |
+
hidden_states,
|
| 810 |
+
torch.tensor([hidden_states.size(0)], device=hidden_states.device),
|
| 811 |
+
)
|
| 812 |
+
return output
|
| 813 |
+
|
| 814 |
+
@torch.no_grad()
|
| 815 |
+
def streaming_decode(self, codes_chunks, chunk_input_lengths, history_cache=StreamingCache(), streaming_config=StreamingConfig(), last_chunk=False):
|
| 816 |
+
hidden_states = self.encoder.decode_vq(codes_chunks)
|
| 817 |
+
input_lengths = []
|
| 818 |
+
input_hidden_states = []
|
| 819 |
+
start_idx = 0
|
| 820 |
+
cache_hidden_states = []
|
| 821 |
+
for i, input_length in enumerate(chunk_input_lengths):
|
| 822 |
+
sample_hidden_states = hidden_states[start_idx:start_idx + input_length]
|
| 823 |
+
start_idx += input_length
|
| 824 |
+
if history_cache.hidden_states is not None:
|
| 825 |
+
sample_hidden_states = torch.cat([history_cache.hidden_states[i], sample_hidden_states], dim=0)
|
| 826 |
+
input_length += history_cache.hidden_states[i].size(0)
|
| 827 |
+
input_hidden_states.append(sample_hidden_states)
|
| 828 |
+
cache_hidden_states.append(sample_hidden_states.clone())
|
| 829 |
+
input_lengths.append(input_length)
|
| 830 |
+
input_hidden_states = torch.cat(input_hidden_states, dim=0)
|
| 831 |
+
input_lengths = torch.tensor(input_lengths, device=hidden_states.device)
|
| 832 |
+
output = self.decoder(input_hidden_states, input_lengths)
|
| 833 |
+
return_wavs = []
|
| 834 |
+
frames_per_token = self.config.avg_pooler * self.config.stride_size * self.config.hop_length
|
| 835 |
+
processed_lengths = []
|
| 836 |
+
for i, wav in enumerate(output):
|
| 837 |
+
wav = wav.float().detach().cpu()
|
| 838 |
+
start_idx = history_cache.processed_lengths[i] if history_cache.processed_lengths is not None else 0
|
| 839 |
+
if last_chunk:
|
| 840 |
+
return_wavs.append(wav[:, start_idx * frames_per_token:])
|
| 841 |
+
new_processed_length = input_lengths[i].item()
|
| 842 |
+
elif input_lengths[i].item() <= streaming_config.right_overlap:
|
| 843 |
+
return_wavs.append(None)
|
| 844 |
+
new_processed_length = 0
|
| 845 |
+
else:
|
| 846 |
+
end_idx = (input_lengths[i].item() - streaming_config.right_overlap)
|
| 847 |
+
wav = wav[:, start_idx * frames_per_token: end_idx * frames_per_token]
|
| 848 |
+
return_wavs.append(wav)
|
| 849 |
+
new_processed_length = end_idx
|
| 850 |
+
if input_lengths[i].item() > streaming_config.left_overlap:
|
| 851 |
+
cache_hidden_states[i] = cache_hidden_states[i][-streaming_config.left_overlap:]
|
| 852 |
+
new_processed_length -= (input_lengths[i].item() - streaming_config.left_overlap)
|
| 853 |
+
processed_lengths.append(new_processed_length)
|
| 854 |
+
history_cache.hidden_states = cache_hidden_states
|
| 855 |
+
history_cache.processed_lengths = processed_lengths
|
| 856 |
+
|
| 857 |
+
return return_wavs, history_cache
|
src/mimo_audio_tokenizer/modeling_rope_utils.py
ADDED
|
@@ -0,0 +1,878 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Xiaomi Corporation.
|
| 2 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import math
|
| 17 |
+
from functools import wraps
|
| 18 |
+
from typing import Optional
|
| 19 |
+
|
| 20 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 21 |
+
from transformers.utils import is_torch_available, logging
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.get_logger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
if is_torch_available():
|
| 28 |
+
import torch
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def dynamic_rope_update(rope_forward):
|
| 32 |
+
"""
|
| 33 |
+
Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
|
| 34 |
+
(i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
rope_forward (Callable):
|
| 38 |
+
The forward pass of the RoPE implementation.
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
The decorated forward pass.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def longrope_frequency_update(self, position_ids, device):
|
| 45 |
+
"""Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
|
| 46 |
+
seq_len = torch.max(position_ids) + 1
|
| 47 |
+
if hasattr(self.config, "original_max_position_embeddings"):
|
| 48 |
+
original_max_position_embeddings = (
|
| 49 |
+
self.config.original_max_position_embeddings
|
| 50 |
+
)
|
| 51 |
+
else:
|
| 52 |
+
original_max_position_embeddings = self.config.max_position_embeddings
|
| 53 |
+
if seq_len > original_max_position_embeddings:
|
| 54 |
+
if not hasattr(self, "long_inv_freq"):
|
| 55 |
+
self.long_inv_freq, _ = self.rope_init_fn(
|
| 56 |
+
self.config, device, seq_len=original_max_position_embeddings + 1
|
| 57 |
+
)
|
| 58 |
+
self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
|
| 59 |
+
else:
|
| 60 |
+
# This .to() is needed if the model has been moved to a device after being initialized (because
|
| 61 |
+
# the buffer is automatically moved, but not the original copy)
|
| 62 |
+
self.original_inv_freq = self.original_inv_freq.to(device)
|
| 63 |
+
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
| 64 |
+
|
| 65 |
+
def dynamic_frequency_update(self, position_ids, device):
|
| 66 |
+
"""
|
| 67 |
+
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
| 68 |
+
1 - growing beyond the cached sequence length (allow scaling)
|
| 69 |
+
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
| 70 |
+
"""
|
| 71 |
+
seq_len = torch.max(position_ids) + 1
|
| 72 |
+
if seq_len > self.max_seq_len_cached: # growth
|
| 73 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(
|
| 74 |
+
self.config, device, seq_len=seq_len
|
| 75 |
+
)
|
| 76 |
+
self.register_buffer(
|
| 77 |
+
"inv_freq", inv_freq, persistent=False
|
| 78 |
+
) # TODO joao: may break with compilation
|
| 79 |
+
self.max_seq_len_cached = seq_len
|
| 80 |
+
|
| 81 |
+
if (
|
| 82 |
+
seq_len < self.original_max_seq_len
|
| 83 |
+
and self.max_seq_len_cached > self.original_max_seq_len
|
| 84 |
+
): # reset
|
| 85 |
+
# This .to() is needed if the model has been moved to a device after being initialized (because
|
| 86 |
+
# the buffer is automatically moved, but not the original copy)
|
| 87 |
+
self.original_inv_freq = self.original_inv_freq.to(device)
|
| 88 |
+
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
| 89 |
+
self.max_seq_len_cached = self.original_max_seq_len
|
| 90 |
+
|
| 91 |
+
@wraps(rope_forward)
|
| 92 |
+
def wrapper(self, x, position_ids):
|
| 93 |
+
if "dynamic" in self.rope_type:
|
| 94 |
+
dynamic_frequency_update(self, position_ids, device=x.device)
|
| 95 |
+
elif self.rope_type == "longrope":
|
| 96 |
+
longrope_frequency_update(self, position_ids, device=x.device)
|
| 97 |
+
return rope_forward(self, x, position_ids)
|
| 98 |
+
|
| 99 |
+
return wrapper
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _compute_default_rope_parameters(
|
| 103 |
+
config: Optional[PretrainedConfig] = None,
|
| 104 |
+
device: Optional["torch.device"] = None,
|
| 105 |
+
seq_len: Optional[int] = None,
|
| 106 |
+
**rope_kwargs,
|
| 107 |
+
) -> tuple["torch.Tensor", float]:
|
| 108 |
+
"""
|
| 109 |
+
Computes the inverse frequencies according to the original RoPE implementation
|
| 110 |
+
Args:
|
| 111 |
+
config ([`~transformers.PretrainedConfig`]):
|
| 112 |
+
The model configuration.
|
| 113 |
+
device (`torch.device`):
|
| 114 |
+
The device to use for initialization of the inverse frequencies.
|
| 115 |
+
seq_len (`int`, *optional*):
|
| 116 |
+
The current sequence length. Unused for this type of RoPE.
|
| 117 |
+
rope_kwargs (`Dict`, *optional*):
|
| 118 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
| 119 |
+
Returns:
|
| 120 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 121 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
| 122 |
+
"""
|
| 123 |
+
if config is not None and len(rope_kwargs) > 0:
|
| 124 |
+
raise ValueError(
|
| 125 |
+
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
|
| 126 |
+
f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
|
| 127 |
+
)
|
| 128 |
+
if len(rope_kwargs) > 0:
|
| 129 |
+
base = rope_kwargs["base"]
|
| 130 |
+
dim = rope_kwargs["dim"]
|
| 131 |
+
elif config is not None:
|
| 132 |
+
base = config.rope_theta
|
| 133 |
+
partial_rotary_factor = (
|
| 134 |
+
config.partial_rotary_factor
|
| 135 |
+
if hasattr(config, "partial_rotary_factor")
|
| 136 |
+
else 1.0
|
| 137 |
+
)
|
| 138 |
+
head_dim = (
|
| 139 |
+
getattr(config, "head_dim", None)
|
| 140 |
+
or config.hidden_size // config.num_attention_heads
|
| 141 |
+
)
|
| 142 |
+
dim = int(head_dim * partial_rotary_factor)
|
| 143 |
+
|
| 144 |
+
attention_factor = 1.0 # Unused in this type of RoPE
|
| 145 |
+
|
| 146 |
+
# Compute the inverse frequencies
|
| 147 |
+
inv_freq = 1.0 / (
|
| 148 |
+
base
|
| 149 |
+
** (
|
| 150 |
+
torch.arange(0, dim, 2, dtype=torch.int64).to(
|
| 151 |
+
device=device, dtype=torch.float
|
| 152 |
+
)
|
| 153 |
+
/ dim
|
| 154 |
+
)
|
| 155 |
+
)
|
| 156 |
+
return inv_freq, attention_factor
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _compute_linear_scaling_rope_parameters(
|
| 160 |
+
config: Optional[PretrainedConfig] = None,
|
| 161 |
+
device: Optional["torch.device"] = None,
|
| 162 |
+
seq_len: Optional[int] = None,
|
| 163 |
+
**rope_kwargs,
|
| 164 |
+
) -> tuple["torch.Tensor", float]:
|
| 165 |
+
"""
|
| 166 |
+
Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
|
| 167 |
+
Args:
|
| 168 |
+
config ([`~transformers.PretrainedConfig`]):
|
| 169 |
+
The model configuration.
|
| 170 |
+
device (`torch.device`):
|
| 171 |
+
The device to use for initialization of the inverse frequencies.
|
| 172 |
+
seq_len (`int`, *optional*):
|
| 173 |
+
The current sequence length. Unused for this type of RoPE.
|
| 174 |
+
rope_kwargs (`Dict`, *optional*):
|
| 175 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
| 176 |
+
Returns:
|
| 177 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 178 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
| 179 |
+
"""
|
| 180 |
+
if config is not None and len(rope_kwargs) > 0:
|
| 181 |
+
raise ValueError(
|
| 182 |
+
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
|
| 183 |
+
f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
|
| 184 |
+
)
|
| 185 |
+
if len(rope_kwargs) > 0:
|
| 186 |
+
factor = rope_kwargs["factor"]
|
| 187 |
+
elif config is not None:
|
| 188 |
+
factor = config.rope_scaling["factor"]
|
| 189 |
+
|
| 190 |
+
# Gets the default RoPE parameters
|
| 191 |
+
inv_freq, attention_factor = _compute_default_rope_parameters(
|
| 192 |
+
config, device, seq_len, **rope_kwargs
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
# Then applies linear scaling to the frequencies.
|
| 196 |
+
# NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
|
| 197 |
+
# applying scaling to the inverse frequencies is equivalent.
|
| 198 |
+
inv_freq /= factor
|
| 199 |
+
return inv_freq, attention_factor
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def _compute_dynamic_ntk_parameters(
|
| 203 |
+
config: Optional[PretrainedConfig] = None,
|
| 204 |
+
device: Optional["torch.device"] = None,
|
| 205 |
+
seq_len: Optional[int] = None,
|
| 206 |
+
**rope_kwargs,
|
| 207 |
+
) -> tuple["torch.Tensor", float]:
|
| 208 |
+
"""
|
| 209 |
+
Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
|
| 210 |
+
Args:
|
| 211 |
+
config ([`~transformers.PretrainedConfig`]):
|
| 212 |
+
The model configuration.
|
| 213 |
+
device (`torch.device`):
|
| 214 |
+
The device to use for initialization of the inverse frequencies.
|
| 215 |
+
seq_len (`int`, *optional*):
|
| 216 |
+
The current sequence length, used to update the dynamic RoPE at inference time.
|
| 217 |
+
rope_kwargs (`Dict`, *optional*):
|
| 218 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
| 219 |
+
Returns:
|
| 220 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 221 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
| 222 |
+
"""
|
| 223 |
+
# TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
|
| 224 |
+
if config is not None and len(rope_kwargs) > 0:
|
| 225 |
+
raise ValueError(
|
| 226 |
+
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
|
| 227 |
+
f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
|
| 228 |
+
)
|
| 229 |
+
if len(rope_kwargs) > 0:
|
| 230 |
+
base = rope_kwargs["base"]
|
| 231 |
+
dim = rope_kwargs["dim"]
|
| 232 |
+
max_position_embeddings = rope_kwargs["max_position_embeddings"]
|
| 233 |
+
factor = rope_kwargs["factor"]
|
| 234 |
+
elif config is not None:
|
| 235 |
+
base = config.rope_theta
|
| 236 |
+
partial_rotary_factor = (
|
| 237 |
+
config.partial_rotary_factor
|
| 238 |
+
if hasattr(config, "partial_rotary_factor")
|
| 239 |
+
else 1.0
|
| 240 |
+
)
|
| 241 |
+
head_dim = getattr(
|
| 242 |
+
config, "head_dim", config.hidden_size // config.num_attention_heads
|
| 243 |
+
)
|
| 244 |
+
dim = int(head_dim * partial_rotary_factor)
|
| 245 |
+
max_position_embeddings = config.max_position_embeddings
|
| 246 |
+
factor = config.rope_scaling["factor"]
|
| 247 |
+
|
| 248 |
+
attention_factor = 1.0 # Unused in this type of RoPE
|
| 249 |
+
|
| 250 |
+
# seq_len: default to max_position_embeddings, e.g. at init time
|
| 251 |
+
seq_len = (
|
| 252 |
+
seq_len
|
| 253 |
+
if seq_len is not None and seq_len > max_position_embeddings
|
| 254 |
+
else max_position_embeddings
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
# Compute the inverse frequencies
|
| 258 |
+
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (
|
| 259 |
+
dim / (dim - 2)
|
| 260 |
+
)
|
| 261 |
+
inv_freq = 1.0 / (
|
| 262 |
+
base
|
| 263 |
+
** (
|
| 264 |
+
torch.arange(0, dim, 2, dtype=torch.int64).to(
|
| 265 |
+
device=device, dtype=torch.float
|
| 266 |
+
)
|
| 267 |
+
/ dim
|
| 268 |
+
)
|
| 269 |
+
)
|
| 270 |
+
return inv_freq, attention_factor
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def _compute_yarn_parameters(
|
| 274 |
+
config: PretrainedConfig,
|
| 275 |
+
device: "torch.device",
|
| 276 |
+
seq_len: Optional[int] = None,
|
| 277 |
+
**rope_kwargs,
|
| 278 |
+
) -> tuple["torch.Tensor", float]:
|
| 279 |
+
"""
|
| 280 |
+
Computes the inverse frequencies with NTK scaling. Please refer to the
|
| 281 |
+
[original paper](https://huggingface.co/papers/2309.00071)
|
| 282 |
+
Args:
|
| 283 |
+
config ([`~transformers.PretrainedConfig`]):
|
| 284 |
+
The model configuration.
|
| 285 |
+
device (`torch.device`):
|
| 286 |
+
The device to use for initialization of the inverse frequencies.
|
| 287 |
+
seq_len (`int`, *optional*):
|
| 288 |
+
The current sequence length. Unused for this type of RoPE.
|
| 289 |
+
rope_kwargs (`Dict`, *optional*):
|
| 290 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
| 291 |
+
Returns:
|
| 292 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 293 |
+
post-processing scaling factor applied to the computed cos/sin.
|
| 294 |
+
"""
|
| 295 |
+
# No need to keep BC with yarn, unreleased when this new pattern was created.
|
| 296 |
+
if len(rope_kwargs) > 0:
|
| 297 |
+
raise ValueError(
|
| 298 |
+
f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}"
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
base = config.rope_theta
|
| 302 |
+
partial_rotary_factor = (
|
| 303 |
+
config.partial_rotary_factor
|
| 304 |
+
if hasattr(config, "partial_rotary_factor")
|
| 305 |
+
else 1.0
|
| 306 |
+
)
|
| 307 |
+
head_dim = getattr(
|
| 308 |
+
config, "head_dim", config.hidden_size // config.num_attention_heads
|
| 309 |
+
)
|
| 310 |
+
dim = int(head_dim * partial_rotary_factor)
|
| 311 |
+
factor = config.rope_scaling["factor"]
|
| 312 |
+
attention_factor = config.rope_scaling.get("attention_factor")
|
| 313 |
+
mscale = config.rope_scaling.get("mscale")
|
| 314 |
+
mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
|
| 315 |
+
|
| 316 |
+
# NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
|
| 317 |
+
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
|
| 318 |
+
# values to compute the default attention scaling factor, instead of using `factor`.
|
| 319 |
+
if "original_max_position_embeddings" in config.rope_scaling:
|
| 320 |
+
original_max_position_embeddings = config.rope_scaling[
|
| 321 |
+
"original_max_position_embeddings"
|
| 322 |
+
]
|
| 323 |
+
factor = config.max_position_embeddings / original_max_position_embeddings
|
| 324 |
+
else:
|
| 325 |
+
original_max_position_embeddings = config.max_position_embeddings
|
| 326 |
+
|
| 327 |
+
def get_mscale(scale, mscale=1):
|
| 328 |
+
if scale <= 1:
|
| 329 |
+
return 1.0
|
| 330 |
+
return 0.1 * mscale * math.log(scale) + 1.0
|
| 331 |
+
|
| 332 |
+
# Sets the attention factor as suggested in the paper
|
| 333 |
+
if attention_factor is None:
|
| 334 |
+
if mscale and mscale_all_dim:
|
| 335 |
+
attention_factor = float(
|
| 336 |
+
get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim)
|
| 337 |
+
)
|
| 338 |
+
else:
|
| 339 |
+
attention_factor = get_mscale(factor)
|
| 340 |
+
|
| 341 |
+
# Optional config options
|
| 342 |
+
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
|
| 343 |
+
beta_fast = config.rope_scaling.get("beta_fast") or 32
|
| 344 |
+
beta_slow = config.rope_scaling.get("beta_slow") or 1
|
| 345 |
+
|
| 346 |
+
# Compute the inverse frequencies
|
| 347 |
+
def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
|
| 348 |
+
"""Inverse dimension formula to find the dimension based on the number of rotations"""
|
| 349 |
+
return (
|
| 350 |
+
dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
|
| 351 |
+
) / (2 * math.log(base))
|
| 352 |
+
|
| 353 |
+
def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
|
| 354 |
+
"""Find dimension range bounds based on rotations"""
|
| 355 |
+
low = math.floor(
|
| 356 |
+
find_correction_dim(low_rot, dim, base, max_position_embeddings)
|
| 357 |
+
)
|
| 358 |
+
high = math.ceil(
|
| 359 |
+
find_correction_dim(high_rot, dim, base, max_position_embeddings)
|
| 360 |
+
)
|
| 361 |
+
return max(low, 0), min(high, dim - 1)
|
| 362 |
+
|
| 363 |
+
def linear_ramp_factor(min, max, dim):
|
| 364 |
+
if min == max:
|
| 365 |
+
max += 0.001 # Prevent singularity
|
| 366 |
+
|
| 367 |
+
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
|
| 368 |
+
ramp_func = torch.clamp(linear_func, 0, 1)
|
| 369 |
+
return ramp_func
|
| 370 |
+
|
| 371 |
+
# Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
|
| 372 |
+
# to expand the possible context length. In other words, interpolation = apply scaling factor.
|
| 373 |
+
pos_freqs = base ** (
|
| 374 |
+
torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim
|
| 375 |
+
)
|
| 376 |
+
inv_freq_extrapolation = 1.0 / pos_freqs
|
| 377 |
+
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
|
| 378 |
+
|
| 379 |
+
low, high = find_correction_range(
|
| 380 |
+
beta_fast, beta_slow, dim, base, original_max_position_embeddings
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
# Get n-dimensional rotational scaling corrected for extrapolation
|
| 384 |
+
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(
|
| 385 |
+
device=device, dtype=torch.float
|
| 386 |
+
)
|
| 387 |
+
inv_freq = (
|
| 388 |
+
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
|
| 389 |
+
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
|
| 390 |
+
)
|
| 391 |
+
return inv_freq, attention_factor
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
def _compute_longrope_parameters(
|
| 395 |
+
config: PretrainedConfig,
|
| 396 |
+
device: "torch.device",
|
| 397 |
+
seq_len: Optional[int] = None,
|
| 398 |
+
**rope_kwargs,
|
| 399 |
+
) -> tuple["torch.Tensor", float]:
|
| 400 |
+
"""
|
| 401 |
+
Computes the inverse frequencies with LongRoPE scaling. Please refer to the
|
| 402 |
+
[original implementation](https://github.com/microsoft/LongRoPE)
|
| 403 |
+
Args:
|
| 404 |
+
config ([`~transformers.PretrainedConfig`]):
|
| 405 |
+
The model configuration.
|
| 406 |
+
device (`torch.device`):
|
| 407 |
+
The device to use for initialization of the inverse frequencies.
|
| 408 |
+
seq_len (`int`, *optional*):
|
| 409 |
+
The current sequence length.
|
| 410 |
+
rope_kwargs (`Dict`, *optional*):
|
| 411 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
| 412 |
+
Returns:
|
| 413 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 414 |
+
post-processing scaling factor applied to the computed cos/sin.
|
| 415 |
+
"""
|
| 416 |
+
# TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
|
| 417 |
+
# No need to keep BC with longrope, unreleased when this new pattern was created.
|
| 418 |
+
if len(rope_kwargs) > 0:
|
| 419 |
+
raise ValueError(
|
| 420 |
+
"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got "
|
| 421 |
+
f"{rope_kwargs}"
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
base = config.rope_theta
|
| 425 |
+
partial_rotary_factor = (
|
| 426 |
+
config.partial_rotary_factor
|
| 427 |
+
if hasattr(config, "partial_rotary_factor")
|
| 428 |
+
else 1.0
|
| 429 |
+
)
|
| 430 |
+
head_dim = getattr(
|
| 431 |
+
config, "head_dim", config.hidden_size // config.num_attention_heads
|
| 432 |
+
)
|
| 433 |
+
dim = int(head_dim * partial_rotary_factor)
|
| 434 |
+
long_factor = config.rope_scaling["long_factor"]
|
| 435 |
+
short_factor = config.rope_scaling["short_factor"]
|
| 436 |
+
factor = config.rope_scaling.get("factor")
|
| 437 |
+
attention_factor = config.rope_scaling.get("attention_factor")
|
| 438 |
+
|
| 439 |
+
# NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
|
| 440 |
+
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
|
| 441 |
+
# values to compute the default attention scaling factor, instead of using `factor`.
|
| 442 |
+
if hasattr(config, "original_max_position_embeddings"):
|
| 443 |
+
original_max_position_embeddings = config.original_max_position_embeddings
|
| 444 |
+
factor = (
|
| 445 |
+
config.max_position_embeddings / config.original_max_position_embeddings
|
| 446 |
+
)
|
| 447 |
+
else:
|
| 448 |
+
original_max_position_embeddings = config.max_position_embeddings
|
| 449 |
+
|
| 450 |
+
# Sets the attention factor as suggested in the paper
|
| 451 |
+
if attention_factor is None:
|
| 452 |
+
if factor <= 1.0:
|
| 453 |
+
attention_factor = 1.0
|
| 454 |
+
else:
|
| 455 |
+
attention_factor = math.sqrt(
|
| 456 |
+
1 + math.log(factor) / math.log(original_max_position_embeddings)
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
# Compute the inverse frequencies -- scaled based on the target sequence length
|
| 460 |
+
if seq_len and seq_len > original_max_position_embeddings:
|
| 461 |
+
ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
|
| 462 |
+
else:
|
| 463 |
+
ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
|
| 464 |
+
inv_freq_shape = (
|
| 465 |
+
torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
|
| 466 |
+
)
|
| 467 |
+
inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
|
| 468 |
+
|
| 469 |
+
return inv_freq, attention_factor
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def _compute_llama3_parameters(
|
| 473 |
+
config: PretrainedConfig,
|
| 474 |
+
device: "torch.device",
|
| 475 |
+
seq_len: Optional[int] = None,
|
| 476 |
+
**rope_kwargs,
|
| 477 |
+
) -> tuple["torch.Tensor", float]:
|
| 478 |
+
"""
|
| 479 |
+
Computes the inverse frequencies for llama 3.1.
|
| 480 |
+
|
| 481 |
+
Args:
|
| 482 |
+
config ([`~transformers.PretrainedConfig`]):
|
| 483 |
+
The model configuration.
|
| 484 |
+
device (`torch.device`):
|
| 485 |
+
The device to use for initialization of the inverse frequencies.
|
| 486 |
+
seq_len (`int`, *optional*):
|
| 487 |
+
The current sequence length. Unused for this type of RoPE.
|
| 488 |
+
rope_kwargs (`Dict`, *optional*):
|
| 489 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
| 490 |
+
Returns:
|
| 491 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 492 |
+
post-processing scaling factor applied to the computed cos/sin.
|
| 493 |
+
"""
|
| 494 |
+
# Gets the default RoPE parameters
|
| 495 |
+
inv_freq, attention_factor = _compute_default_rope_parameters(
|
| 496 |
+
config, device, seq_len, **rope_kwargs
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
factor = config.rope_scaling["factor"] # `8` in the original implementation
|
| 500 |
+
low_freq_factor = config.rope_scaling[
|
| 501 |
+
"low_freq_factor"
|
| 502 |
+
] # `1` in the original implementation
|
| 503 |
+
high_freq_factor = config.rope_scaling[
|
| 504 |
+
"high_freq_factor"
|
| 505 |
+
] # `4` in the original implementation
|
| 506 |
+
old_context_len = config.rope_scaling[
|
| 507 |
+
"original_max_position_embeddings"
|
| 508 |
+
] # `8192` in the original implementation
|
| 509 |
+
|
| 510 |
+
low_freq_wavelen = old_context_len / low_freq_factor
|
| 511 |
+
high_freq_wavelen = old_context_len / high_freq_factor
|
| 512 |
+
|
| 513 |
+
wavelen = 2 * math.pi / inv_freq
|
| 514 |
+
# wavelen < high_freq_wavelen: do nothing
|
| 515 |
+
# wavelen > low_freq_wavelen: divide by factor
|
| 516 |
+
inv_freq_llama = torch.where(
|
| 517 |
+
wavelen > low_freq_wavelen, inv_freq / factor, inv_freq
|
| 518 |
+
)
|
| 519 |
+
# otherwise: interpolate between the two, using a smooth factor
|
| 520 |
+
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
|
| 521 |
+
high_freq_factor - low_freq_factor
|
| 522 |
+
)
|
| 523 |
+
smoothed_inv_freq = (
|
| 524 |
+
1 - smooth_factor
|
| 525 |
+
) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
|
| 526 |
+
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
|
| 527 |
+
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
|
| 528 |
+
|
| 529 |
+
return inv_freq_llama, attention_factor
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
|
| 533 |
+
# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
|
| 534 |
+
# parameterizations, as long as the callable has the same signature.
|
| 535 |
+
ROPE_INIT_FUNCTIONS = {
|
| 536 |
+
"default": _compute_default_rope_parameters,
|
| 537 |
+
"linear": _compute_linear_scaling_rope_parameters,
|
| 538 |
+
"dynamic": _compute_dynamic_ntk_parameters,
|
| 539 |
+
"yarn": _compute_yarn_parameters,
|
| 540 |
+
"longrope": _compute_longrope_parameters,
|
| 541 |
+
"llama3": _compute_llama3_parameters,
|
| 542 |
+
}
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
def _check_received_keys(
|
| 546 |
+
rope_type: str,
|
| 547 |
+
received_keys: set,
|
| 548 |
+
required_keys: set,
|
| 549 |
+
optional_keys: Optional[set] = None,
|
| 550 |
+
ignore_keys: Optional[set] = None,
|
| 551 |
+
):
|
| 552 |
+
"""Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
|
| 553 |
+
# BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
|
| 554 |
+
if "type" in received_keys:
|
| 555 |
+
received_keys -= {"type"}
|
| 556 |
+
required_keys.add("rope_type")
|
| 557 |
+
|
| 558 |
+
# Some models need to store model-specific keys, and we don't want to throw warning at them
|
| 559 |
+
if ignore_keys is not None:
|
| 560 |
+
received_keys -= ignore_keys
|
| 561 |
+
|
| 562 |
+
missing_keys = required_keys - received_keys
|
| 563 |
+
if missing_keys:
|
| 564 |
+
raise KeyError(
|
| 565 |
+
f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}"
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
if optional_keys is not None:
|
| 569 |
+
unused_keys = received_keys - required_keys - optional_keys
|
| 570 |
+
else:
|
| 571 |
+
unused_keys = received_keys - required_keys
|
| 572 |
+
if unused_keys:
|
| 573 |
+
logger.warning(
|
| 574 |
+
f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}"
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
def _validate_default_rope_parameters(
|
| 579 |
+
config: PretrainedConfig, ignore_keys: Optional[set] = None
|
| 580 |
+
):
|
| 581 |
+
rope_scaling = config.rope_scaling
|
| 582 |
+
rope_type = rope_scaling.get(
|
| 583 |
+
"rope_type", rope_scaling.get("type", None)
|
| 584 |
+
) # BC: "rope_type" was originally "type"
|
| 585 |
+
required_keys = {"rope_type"}
|
| 586 |
+
received_keys = set(rope_scaling.keys())
|
| 587 |
+
_check_received_keys(
|
| 588 |
+
rope_type, received_keys, required_keys, ignore_keys=ignore_keys
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
def _validate_linear_scaling_rope_parameters(
|
| 593 |
+
config: PretrainedConfig, ignore_keys: Optional[set] = None
|
| 594 |
+
):
|
| 595 |
+
rope_scaling = config.rope_scaling
|
| 596 |
+
rope_type = rope_scaling.get(
|
| 597 |
+
"rope_type", rope_scaling.get("type", None)
|
| 598 |
+
) # BC: "rope_type" was originally "type"
|
| 599 |
+
required_keys = {"rope_type", "factor"}
|
| 600 |
+
received_keys = set(rope_scaling.keys())
|
| 601 |
+
_check_received_keys(
|
| 602 |
+
rope_type, received_keys, required_keys, ignore_keys=ignore_keys
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
factor = rope_scaling["factor"]
|
| 606 |
+
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
| 607 |
+
logger.warning(
|
| 608 |
+
f"`rope_scaling`'s factor field must be a float >= 1, got {factor}"
|
| 609 |
+
)
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
def _validate_dynamic_scaling_rope_parameters(
|
| 613 |
+
config: PretrainedConfig, ignore_keys: Optional[set] = None
|
| 614 |
+
):
|
| 615 |
+
rope_scaling = config.rope_scaling
|
| 616 |
+
rope_type = rope_scaling.get(
|
| 617 |
+
"rope_type", rope_scaling.get("type", None)
|
| 618 |
+
) # BC: "rope_type" was originally "type"
|
| 619 |
+
required_keys = {"rope_type", "factor"}
|
| 620 |
+
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
| 621 |
+
optional_keys = {"original_max_position_embeddings"}
|
| 622 |
+
received_keys = set(rope_scaling.keys())
|
| 623 |
+
_check_received_keys(
|
| 624 |
+
rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
factor = rope_scaling["factor"]
|
| 628 |
+
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
| 629 |
+
logger.warning(
|
| 630 |
+
f"`rope_scaling`'s factor field must be a float >= 1, got {factor}"
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
def _validate_yarn_parameters(
|
| 635 |
+
config: PretrainedConfig, ignore_keys: Optional[set] = None
|
| 636 |
+
):
|
| 637 |
+
rope_scaling = config.rope_scaling
|
| 638 |
+
rope_type = rope_scaling.get(
|
| 639 |
+
"rope_type", rope_scaling.get("type", None)
|
| 640 |
+
) # BC: "rope_type" was originally "type"
|
| 641 |
+
required_keys = {"rope_type", "factor"}
|
| 642 |
+
optional_keys = {
|
| 643 |
+
"attention_factor",
|
| 644 |
+
"beta_fast",
|
| 645 |
+
"beta_slow",
|
| 646 |
+
"original_max_position_embeddings",
|
| 647 |
+
"mscale",
|
| 648 |
+
"mscale_all_dim",
|
| 649 |
+
}
|
| 650 |
+
received_keys = set(rope_scaling.keys())
|
| 651 |
+
_check_received_keys(
|
| 652 |
+
rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
factor = rope_scaling["factor"]
|
| 656 |
+
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
| 657 |
+
logger.warning(
|
| 658 |
+
f"`rope_scaling`'s factor field must be a float >= 1, got {factor}"
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
attention_factor = rope_scaling.get("attention_factor")
|
| 662 |
+
if attention_factor is not None and (
|
| 663 |
+
not isinstance(attention_factor, float) or attention_factor < 0
|
| 664 |
+
):
|
| 665 |
+
logger.warning(
|
| 666 |
+
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
|
| 667 |
+
)
|
| 668 |
+
beta_fast = rope_scaling.get("beta_fast")
|
| 669 |
+
if beta_fast is not None and not isinstance(beta_fast, float):
|
| 670 |
+
logger.warning(
|
| 671 |
+
f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}"
|
| 672 |
+
)
|
| 673 |
+
beta_slow = rope_scaling.get("beta_slow")
|
| 674 |
+
if beta_slow is not None and not isinstance(beta_slow, float):
|
| 675 |
+
logger.warning(
|
| 676 |
+
f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}"
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
if (beta_fast or 32) < (beta_slow or 1):
|
| 680 |
+
logger.warning(
|
| 681 |
+
f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
|
| 682 |
+
f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
|
| 683 |
+
)
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
def _validate_longrope_parameters(
|
| 687 |
+
config: PretrainedConfig, ignore_keys: Optional[set] = None
|
| 688 |
+
):
|
| 689 |
+
rope_scaling = config.rope_scaling
|
| 690 |
+
rope_type = rope_scaling.get(
|
| 691 |
+
"rope_type", rope_scaling.get("type", None)
|
| 692 |
+
) # BC: "rope_type" was originally "type"
|
| 693 |
+
required_keys = {"rope_type", "short_factor", "long_factor"}
|
| 694 |
+
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
| 695 |
+
optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
|
| 696 |
+
received_keys = set(rope_scaling.keys())
|
| 697 |
+
_check_received_keys(
|
| 698 |
+
rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys
|
| 699 |
+
)
|
| 700 |
+
|
| 701 |
+
partial_rotary_factor = (
|
| 702 |
+
config.partial_rotary_factor
|
| 703 |
+
if hasattr(config, "partial_rotary_factor")
|
| 704 |
+
else 1.0
|
| 705 |
+
)
|
| 706 |
+
head_dim = getattr(
|
| 707 |
+
config, "head_dim", config.hidden_size // config.num_attention_heads
|
| 708 |
+
)
|
| 709 |
+
dim = int(head_dim * partial_rotary_factor)
|
| 710 |
+
|
| 711 |
+
short_factor = rope_scaling.get("short_factor")
|
| 712 |
+
if not isinstance(short_factor, list) and all(
|
| 713 |
+
isinstance(x, (int, float)) for x in short_factor
|
| 714 |
+
):
|
| 715 |
+
logger.warning(
|
| 716 |
+
f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}"
|
| 717 |
+
)
|
| 718 |
+
if not len(short_factor) == dim // 2:
|
| 719 |
+
logger.warning(
|
| 720 |
+
f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}"
|
| 721 |
+
)
|
| 722 |
+
|
| 723 |
+
long_factor = rope_scaling.get("long_factor")
|
| 724 |
+
if not isinstance(long_factor, list) and all(
|
| 725 |
+
isinstance(x, (int, float)) for x in long_factor
|
| 726 |
+
):
|
| 727 |
+
logger.warning(
|
| 728 |
+
f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}"
|
| 729 |
+
)
|
| 730 |
+
if not len(long_factor) == dim // 2:
|
| 731 |
+
logger.warning(
|
| 732 |
+
f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}"
|
| 733 |
+
)
|
| 734 |
+
|
| 735 |
+
# Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
|
| 736 |
+
# `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
|
| 737 |
+
# unique to longrope (= undesirable)
|
| 738 |
+
if hasattr(config, "original_max_position_embeddings"):
|
| 739 |
+
logger.warning_once(
|
| 740 |
+
"This model has set a `original_max_position_embeddings` field, to be used together with "
|
| 741 |
+
"`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`"
|
| 742 |
+
"with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
|
| 743 |
+
"as it is compatible with most model architectures."
|
| 744 |
+
)
|
| 745 |
+
else:
|
| 746 |
+
factor = rope_scaling.get("factor")
|
| 747 |
+
if factor is None:
|
| 748 |
+
logger.warning("Missing required keys in `rope_scaling`: 'factor'")
|
| 749 |
+
elif not isinstance(factor, float) or factor < 1.0:
|
| 750 |
+
logger.warning(
|
| 751 |
+
f"`rope_scaling`'s factor field must be a float >= 1, got {factor}"
|
| 752 |
+
)
|
| 753 |
+
|
| 754 |
+
attention_factor = rope_scaling.get("attention_factor")
|
| 755 |
+
if attention_factor is not None:
|
| 756 |
+
if not isinstance(attention_factor, float) or attention_factor < 0.0:
|
| 757 |
+
logger.warning(
|
| 758 |
+
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
|
| 759 |
+
)
|
| 760 |
+
|
| 761 |
+
|
| 762 |
+
def _validate_llama3_parameters(
|
| 763 |
+
config: PretrainedConfig, ignore_keys: Optional[set] = None
|
| 764 |
+
):
|
| 765 |
+
rope_scaling = config.rope_scaling
|
| 766 |
+
rope_type = rope_scaling.get(
|
| 767 |
+
"rope_type", rope_scaling.get("type", None)
|
| 768 |
+
) # BC: "rope_type" was originally "type"
|
| 769 |
+
required_keys = {
|
| 770 |
+
"rope_type",
|
| 771 |
+
"factor",
|
| 772 |
+
"original_max_position_embeddings",
|
| 773 |
+
"low_freq_factor",
|
| 774 |
+
"high_freq_factor",
|
| 775 |
+
}
|
| 776 |
+
received_keys = set(rope_scaling.keys())
|
| 777 |
+
_check_received_keys(
|
| 778 |
+
rope_type, received_keys, required_keys, ignore_keys=ignore_keys
|
| 779 |
+
)
|
| 780 |
+
|
| 781 |
+
factor = rope_scaling["factor"]
|
| 782 |
+
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
| 783 |
+
logger.warning(
|
| 784 |
+
f"`rope_scaling`'s factor field must be a float >= 1, got {factor}"
|
| 785 |
+
)
|
| 786 |
+
|
| 787 |
+
low_freq_factor = rope_scaling["low_freq_factor"]
|
| 788 |
+
high_freq_factor = rope_scaling["high_freq_factor"]
|
| 789 |
+
if low_freq_factor is None or not isinstance(low_freq_factor, float):
|
| 790 |
+
logger.warning(
|
| 791 |
+
f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}"
|
| 792 |
+
)
|
| 793 |
+
if high_freq_factor is None or not isinstance(high_freq_factor, float):
|
| 794 |
+
logger.warning(
|
| 795 |
+
f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}"
|
| 796 |
+
)
|
| 797 |
+
if high_freq_factor <= low_freq_factor:
|
| 798 |
+
logger.warning(
|
| 799 |
+
"`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
|
| 800 |
+
f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
|
| 801 |
+
)
|
| 802 |
+
|
| 803 |
+
original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
|
| 804 |
+
if original_max_position_embeddings is None or not isinstance(
|
| 805 |
+
original_max_position_embeddings, int
|
| 806 |
+
):
|
| 807 |
+
logger.warning(
|
| 808 |
+
"`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
|
| 809 |
+
f"{original_max_position_embeddings}"
|
| 810 |
+
)
|
| 811 |
+
if original_max_position_embeddings >= config.max_position_embeddings:
|
| 812 |
+
logger.warning(
|
| 813 |
+
"`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
|
| 814 |
+
f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
|
| 815 |
+
)
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
|
| 819 |
+
ROPE_VALIDATION_FUNCTIONS = {
|
| 820 |
+
"default": _validate_default_rope_parameters,
|
| 821 |
+
"linear": _validate_linear_scaling_rope_parameters,
|
| 822 |
+
"dynamic": _validate_dynamic_scaling_rope_parameters,
|
| 823 |
+
"yarn": _validate_yarn_parameters,
|
| 824 |
+
"longrope": _validate_longrope_parameters,
|
| 825 |
+
"llama3": _validate_llama3_parameters,
|
| 826 |
+
}
|
| 827 |
+
|
| 828 |
+
|
| 829 |
+
def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
| 830 |
+
"""
|
| 831 |
+
Validate the RoPE config arguments, given a `PretrainedConfig` object
|
| 832 |
+
"""
|
| 833 |
+
rope_scaling = getattr(
|
| 834 |
+
config, "rope_scaling", None
|
| 835 |
+
) # not a default parameter in `PretrainedConfig`
|
| 836 |
+
if rope_scaling is None:
|
| 837 |
+
return
|
| 838 |
+
|
| 839 |
+
# BC: "rope_type" was originally "type"
|
| 840 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
|
| 841 |
+
validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
|
| 842 |
+
if validation_fn is not None:
|
| 843 |
+
validation_fn(config, ignore_keys=ignore_keys)
|
| 844 |
+
else:
|
| 845 |
+
logger.warning(
|
| 846 |
+
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
|
| 847 |
+
)
|
| 848 |
+
|
| 849 |
+
|
| 850 |
+
def rotate_half(x):
|
| 851 |
+
"""Rotates half the hidden dims of the input."""
|
| 852 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 853 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 854 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 855 |
+
|
| 856 |
+
def apply_rotary_pos_emb(x, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 857 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 858 |
+
|
| 859 |
+
Args:
|
| 860 |
+
x (`torch.Tensor`): The input tensor.
|
| 861 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 862 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 863 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 864 |
+
Deprecated and unused.
|
| 865 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 866 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 867 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 868 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 869 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 870 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 871 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 872 |
+
Returns:
|
| 873 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 874 |
+
"""
|
| 875 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 876 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 877 |
+
x_embed = (x * cos) + (rotate_half(x) * sin)
|
| 878 |
+
return x_embed
|
src/mimo_audio_tokenizer/quantization.py
ADDED
|
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Xiaomi Corporation.
|
| 2 |
+
import typing as tp
|
| 3 |
+
|
| 4 |
+
from einops import rearrange, repeat
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
import torch.distributed as dist
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def rank():
|
| 13 |
+
if dist.is_initialized():
|
| 14 |
+
return dist.get_rank()
|
| 15 |
+
else:
|
| 16 |
+
return 0
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def world_size():
|
| 20 |
+
if dist.is_initialized():
|
| 21 |
+
return dist.get_world_size()
|
| 22 |
+
else:
|
| 23 |
+
return 1
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def default(val: tp.Any, d: tp.Any) -> tp.Any:
|
| 27 |
+
return val if val is not None else d
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def ema_inplace(moving_avg, new, decay: float):
|
| 31 |
+
if dist.is_initialized():
|
| 32 |
+
dist.all_reduce(new, op=dist.ReduceOp.SUM)
|
| 33 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
|
| 37 |
+
return (x + epsilon) / (x.sum() + n_categories * epsilon)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def uniform_init(*shape: int):
|
| 41 |
+
t = torch.empty(shape)
|
| 42 |
+
nn.init.kaiming_uniform_(t)
|
| 43 |
+
return t
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def sample_vectors(samples, num: int):
|
| 47 |
+
num_samples, device = samples.shape[0], samples.device
|
| 48 |
+
|
| 49 |
+
if num_samples >= num:
|
| 50 |
+
indices = torch.randperm(num_samples, device=device)[:num]
|
| 51 |
+
else:
|
| 52 |
+
indices = torch.randint(0, num_samples, (num,), device=device)
|
| 53 |
+
|
| 54 |
+
selected_samples = samples[indices]
|
| 55 |
+
|
| 56 |
+
if dist.is_initialized():
|
| 57 |
+
|
| 58 |
+
dist.broadcast(selected_samples, src=0)
|
| 59 |
+
|
| 60 |
+
return selected_samples
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def kmeans(samples, num_clusters: int, num_iters: int = 10):
|
| 64 |
+
dim, dtype = samples.shape[-1], samples.dtype
|
| 65 |
+
|
| 66 |
+
means = sample_vectors(samples, num_clusters)
|
| 67 |
+
|
| 68 |
+
for _ in range(num_iters):
|
| 69 |
+
dists = -(
|
| 70 |
+
samples.pow(2).sum(1, keepdim=True)
|
| 71 |
+
- 2 * samples @ means.t()
|
| 72 |
+
+ means.t().pow(2).sum(0, keepdim=True)
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
buckets = dists.max(dim=-1).indices
|
| 76 |
+
bins = torch.bincount(buckets, minlength=num_clusters)
|
| 77 |
+
|
| 78 |
+
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
| 79 |
+
new_means = new_means.scatter_add_(
|
| 80 |
+
0, repeat(buckets, "n -> n d", d=dim), samples
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
if dist.is_initialized():
|
| 84 |
+
dist.all_reduce(bins, op=dist.ReduceOp.SUM)
|
| 85 |
+
dist.all_reduce(new_means, op=dist.ReduceOp.SUM)
|
| 86 |
+
|
| 87 |
+
zero_mask = bins == 0
|
| 88 |
+
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
| 89 |
+
|
| 90 |
+
new_means = new_means / bins_min_clamped[..., None]
|
| 91 |
+
|
| 92 |
+
means = torch.where(zero_mask[..., None], means, new_means)
|
| 93 |
+
|
| 94 |
+
return means, bins
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class EuclideanCodebook(nn.Module):
|
| 98 |
+
"""Codebook with Euclidean distance.
|
| 99 |
+
Args:
|
| 100 |
+
dim (int): Dimension.
|
| 101 |
+
codebook_size (int): Codebook size.
|
| 102 |
+
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
|
| 103 |
+
If set to true, run the k-means algorithm on the first training batch and use
|
| 104 |
+
the learned centroids as initialization.
|
| 105 |
+
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
|
| 106 |
+
decay (float): Decay for exponential moving average over the codebooks.
|
| 107 |
+
epsilon (float): Epsilon value for numerical stability.
|
| 108 |
+
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
| 109 |
+
that have an exponential moving average cluster size less than the specified threshold with
|
| 110 |
+
randomly selected vector from the current batch.
|
| 111 |
+
"""
|
| 112 |
+
|
| 113 |
+
def __init__(
|
| 114 |
+
self,
|
| 115 |
+
dim: int,
|
| 116 |
+
codebook_size: int,
|
| 117 |
+
kmeans_init: int = False,
|
| 118 |
+
kmeans_iters: int = 10,
|
| 119 |
+
decay: float = 0.99,
|
| 120 |
+
epsilon: float = 1e-5,
|
| 121 |
+
threshold_ema_dead_code: int = 2,
|
| 122 |
+
):
|
| 123 |
+
super().__init__()
|
| 124 |
+
self.decay = decay
|
| 125 |
+
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
|
| 126 |
+
uniform_init if not kmeans_init else torch.zeros
|
| 127 |
+
)
|
| 128 |
+
embed = init_fn(codebook_size, dim)
|
| 129 |
+
|
| 130 |
+
self.codebook_size = codebook_size
|
| 131 |
+
|
| 132 |
+
self.kmeans_iters = kmeans_iters
|
| 133 |
+
self.epsilon = epsilon
|
| 134 |
+
self.threshold_ema_dead_code = threshold_ema_dead_code
|
| 135 |
+
|
| 136 |
+
self.register_buffer("inited", torch.Tensor([not kmeans_init]))
|
| 137 |
+
self.register_buffer("cluster_size", torch.zeros(codebook_size))
|
| 138 |
+
self.register_buffer("embed", embed)
|
| 139 |
+
self.register_buffer("embed_avg", embed.clone())
|
| 140 |
+
|
| 141 |
+
@torch.jit.ignore
|
| 142 |
+
def init_embed_(self, data):
|
| 143 |
+
if self.inited:
|
| 144 |
+
return
|
| 145 |
+
|
| 146 |
+
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
| 147 |
+
self.embed.data.copy_(embed)
|
| 148 |
+
self.embed_avg.data.copy_(embed.clone())
|
| 149 |
+
self.cluster_size.data.copy_(cluster_size)
|
| 150 |
+
self.inited.data.copy_(torch.Tensor([True]))
|
| 151 |
+
|
| 152 |
+
def replace_(self, samples, mask):
|
| 153 |
+
# modified_codebook = torch.where(
|
| 154 |
+
# mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
|
| 155 |
+
# )
|
| 156 |
+
replace_num = mask.sum()
|
| 157 |
+
modified_codebook = self.embed.clone()
|
| 158 |
+
modified_codebook[mask] = sample_vectors(samples, replace_num)
|
| 159 |
+
self.embed.data.copy_(modified_codebook)
|
| 160 |
+
|
| 161 |
+
def expire_codes_(self, batch_samples):
|
| 162 |
+
if self.threshold_ema_dead_code == 0:
|
| 163 |
+
return
|
| 164 |
+
|
| 165 |
+
expired_codes = self.cluster_size < self.threshold_ema_dead_code
|
| 166 |
+
if not torch.any(expired_codes):
|
| 167 |
+
return
|
| 168 |
+
|
| 169 |
+
batch_samples = rearrange(batch_samples, "... d -> (...) d")
|
| 170 |
+
self.replace_(batch_samples, mask=expired_codes)
|
| 171 |
+
|
| 172 |
+
def preprocess(self, x):
|
| 173 |
+
x = rearrange(x, "... d -> (...) d")
|
| 174 |
+
return x
|
| 175 |
+
|
| 176 |
+
def quantize(self, x):
|
| 177 |
+
embed = self.embed.t()
|
| 178 |
+
dist = -(
|
| 179 |
+
x.pow(2).sum(1, keepdim=True)
|
| 180 |
+
- 2 * x @ embed
|
| 181 |
+
+ embed.pow(2).sum(0, keepdim=True)
|
| 182 |
+
)
|
| 183 |
+
embed_ind = dist.max(dim=-1).indices
|
| 184 |
+
return embed_ind
|
| 185 |
+
|
| 186 |
+
def postprocess_emb(self, embed_ind, shape):
|
| 187 |
+
return embed_ind.view(*shape[:-1])
|
| 188 |
+
|
| 189 |
+
def dequantize(self, embed_ind):
|
| 190 |
+
quantize = F.embedding(embed_ind, self.embed)
|
| 191 |
+
return quantize
|
| 192 |
+
|
| 193 |
+
def encode(self, x):
|
| 194 |
+
shape = x.shape
|
| 195 |
+
# pre-process
|
| 196 |
+
x = self.preprocess(x)
|
| 197 |
+
# quantize
|
| 198 |
+
embed_ind = self.quantize(x)
|
| 199 |
+
# post-process
|
| 200 |
+
embed_ind = self.postprocess_emb(embed_ind, shape)
|
| 201 |
+
return embed_ind
|
| 202 |
+
|
| 203 |
+
def decode(self, embed_ind):
|
| 204 |
+
quantize = self.dequantize(embed_ind)
|
| 205 |
+
return quantize
|
| 206 |
+
|
| 207 |
+
def forward(self, x):
|
| 208 |
+
shape, dtype = x.shape, x.dtype
|
| 209 |
+
x = self.preprocess(x)
|
| 210 |
+
|
| 211 |
+
self.init_embed_(x)
|
| 212 |
+
|
| 213 |
+
embed_ind = self.quantize(x)
|
| 214 |
+
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
|
| 215 |
+
embed_ind = self.postprocess_emb(embed_ind, shape)
|
| 216 |
+
quantize = self.dequantize(embed_ind)
|
| 217 |
+
|
| 218 |
+
if self.training:
|
| 219 |
+
# We do the expiry of code at that point as buffers are in sync
|
| 220 |
+
# and all the workers will take the same decision.
|
| 221 |
+
self.expire_codes_(x)
|
| 222 |
+
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
|
| 223 |
+
embed_sum = x.t() @ embed_onehot
|
| 224 |
+
ema_inplace(self.embed_avg, embed_sum.t().contiguous(), self.decay)
|
| 225 |
+
cluster_size = (
|
| 226 |
+
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
|
| 227 |
+
* self.cluster_size.sum()
|
| 228 |
+
)
|
| 229 |
+
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
| 230 |
+
self.embed.data.copy_(embed_normalized)
|
| 231 |
+
|
| 232 |
+
return quantize, embed_ind
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class VectorQuantization(nn.Module):
|
| 236 |
+
"""Vector quantization implementation.
|
| 237 |
+
Currently supports only euclidean distance.
|
| 238 |
+
Args:
|
| 239 |
+
dim (int): Dimension
|
| 240 |
+
codebook_size (int): Codebook size
|
| 241 |
+
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
|
| 242 |
+
decay (float): Decay for exponential moving average over the codebooks.
|
| 243 |
+
epsilon (float): Epsilon value for numerical stability.
|
| 244 |
+
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
| 245 |
+
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
| 246 |
+
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
| 247 |
+
that have an exponential moving average cluster size less than the specified threshold with
|
| 248 |
+
randomly selected vector from the current batch.
|
| 249 |
+
commitment_weight (float): Weight for commitment loss.
|
| 250 |
+
"""
|
| 251 |
+
|
| 252 |
+
def __init__(
|
| 253 |
+
self,
|
| 254 |
+
dim: int,
|
| 255 |
+
codebook_size: int,
|
| 256 |
+
codebook_dim: tp.Optional[int] = None,
|
| 257 |
+
decay: float = 0.99,
|
| 258 |
+
epsilon: float = 1e-5,
|
| 259 |
+
kmeans_init: bool = True,
|
| 260 |
+
kmeans_iters: int = 50,
|
| 261 |
+
threshold_ema_dead_code: int = 2,
|
| 262 |
+
commitment_weight: float = 1.0,
|
| 263 |
+
):
|
| 264 |
+
super().__init__()
|
| 265 |
+
_codebook_dim: int = default(codebook_dim, dim)
|
| 266 |
+
|
| 267 |
+
requires_projection = _codebook_dim != dim
|
| 268 |
+
self.project_in = (
|
| 269 |
+
nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
|
| 270 |
+
)
|
| 271 |
+
self.project_out = (
|
| 272 |
+
nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
self.epsilon = epsilon
|
| 276 |
+
self.commitment_weight = commitment_weight
|
| 277 |
+
|
| 278 |
+
self._codebook = EuclideanCodebook(
|
| 279 |
+
dim=_codebook_dim,
|
| 280 |
+
codebook_size=codebook_size,
|
| 281 |
+
kmeans_init=kmeans_init,
|
| 282 |
+
kmeans_iters=kmeans_iters,
|
| 283 |
+
decay=decay,
|
| 284 |
+
epsilon=epsilon,
|
| 285 |
+
threshold_ema_dead_code=threshold_ema_dead_code,
|
| 286 |
+
)
|
| 287 |
+
self.codebook_size = codebook_size
|
| 288 |
+
|
| 289 |
+
@property
|
| 290 |
+
def codebook(self):
|
| 291 |
+
return self._codebook.embed
|
| 292 |
+
|
| 293 |
+
def encode(self, x):
|
| 294 |
+
# x = rearrange(x, "b d n -> b n d")
|
| 295 |
+
x = self.project_in(x)
|
| 296 |
+
embed_in = self._codebook.encode(x)
|
| 297 |
+
return embed_in
|
| 298 |
+
|
| 299 |
+
def decode(self, embed_ind):
|
| 300 |
+
quantize = self._codebook.decode(embed_ind)
|
| 301 |
+
quantize = self.project_out(quantize)
|
| 302 |
+
# quantize = rearrange(quantize, "b n d -> b d n")
|
| 303 |
+
return quantize
|
| 304 |
+
|
| 305 |
+
def forward(self, x):
|
| 306 |
+
device = x.device
|
| 307 |
+
x = self.project_in(x)
|
| 308 |
+
|
| 309 |
+
quantize, embed_ind = self._codebook(x)
|
| 310 |
+
|
| 311 |
+
if self.training:
|
| 312 |
+
quantize = x + (quantize - x).detach()
|
| 313 |
+
|
| 314 |
+
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
|
| 315 |
+
|
| 316 |
+
if self.training:
|
| 317 |
+
if self.commitment_weight > 0:
|
| 318 |
+
commit_loss = F.mse_loss(quantize.detach(), x)
|
| 319 |
+
loss = loss + commit_loss * self.commitment_weight
|
| 320 |
+
|
| 321 |
+
quantize = self.project_out(quantize)
|
| 322 |
+
# quantize = rearrange(quantize, "b n d -> b d n")
|
| 323 |
+
return quantize, embed_ind, loss
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
class ResidualVectorQuantization(nn.Module):
|
| 327 |
+
"""Residual vector quantization implementation.
|
| 328 |
+
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
|
| 329 |
+
"""
|
| 330 |
+
|
| 331 |
+
def __init__(self, *, num_quantizers, codebook_size, **kwargs):
|
| 332 |
+
super().__init__()
|
| 333 |
+
if isinstance(codebook_size, int):
|
| 334 |
+
codebook_size = [codebook_size] * num_quantizers
|
| 335 |
+
elif len(codebook_size) < num_quantizers:
|
| 336 |
+
codebook_size += [codebook_size[-1]] * (num_quantizers - len(codebook_size))
|
| 337 |
+
self.layers = nn.ModuleList(
|
| 338 |
+
[
|
| 339 |
+
VectorQuantization(codebook_size=codebook_size[i], **kwargs)
|
| 340 |
+
for i in range(num_quantizers)
|
| 341 |
+
]
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
def forward(
|
| 345 |
+
self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
|
| 346 |
+
):
|
| 347 |
+
quantized_out = 0.0
|
| 348 |
+
residual = x
|
| 349 |
+
|
| 350 |
+
all_losses = []
|
| 351 |
+
all_indices = []
|
| 352 |
+
out_quantized = []
|
| 353 |
+
|
| 354 |
+
n_q = n_q or len(self.layers)
|
| 355 |
+
|
| 356 |
+
for i, layer in enumerate(self.layers[:n_q]):
|
| 357 |
+
quantized, indices, loss = layer(residual)
|
| 358 |
+
residual = residual - quantized
|
| 359 |
+
quantized_out = quantized_out + quantized
|
| 360 |
+
|
| 361 |
+
all_indices.append(indices)
|
| 362 |
+
all_losses.append(loss)
|
| 363 |
+
if layers and i in layers:
|
| 364 |
+
out_quantized.append(quantized_out)
|
| 365 |
+
|
| 366 |
+
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
|
| 367 |
+
return quantized_out, out_indices, out_losses, out_quantized
|
| 368 |
+
|
| 369 |
+
def encode(
|
| 370 |
+
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
|
| 371 |
+
) -> torch.Tensor:
|
| 372 |
+
residual = x
|
| 373 |
+
all_indices = []
|
| 374 |
+
n_q = len(self.layers) if n_q is None else n_q
|
| 375 |
+
st = 0 if st is None else st
|
| 376 |
+
for layer in self.layers[st:n_q]:
|
| 377 |
+
indices = layer.encode(residual)
|
| 378 |
+
quantized = layer.decode(indices)
|
| 379 |
+
residual = residual - quantized
|
| 380 |
+
all_indices.append(indices)
|
| 381 |
+
out_indices = torch.stack(all_indices)
|
| 382 |
+
return out_indices
|
| 383 |
+
|
| 384 |
+
def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
|
| 385 |
+
quantized_out = torch.tensor(0.0, device=q_indices.device)
|
| 386 |
+
for i, indices in enumerate(q_indices):
|
| 387 |
+
layer = self.layers[st + i]
|
| 388 |
+
quantized = layer.decode(indices)
|
| 389 |
+
quantized_out = quantized_out + quantized
|
| 390 |
+
return quantized_out
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
class ResidualVectorQuantizer(nn.Module):
|
| 394 |
+
"""Residual Vector Quantizer.
|
| 395 |
+
Args:
|
| 396 |
+
dimension (int): Dimension of the codebooks.
|
| 397 |
+
n_q (int): Number of residual vector quantizers used.
|
| 398 |
+
bins (int): Codebook size.
|
| 399 |
+
decay (float): Decay for exponential moving average over the codebooks.
|
| 400 |
+
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
| 401 |
+
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
| 402 |
+
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
| 403 |
+
that have an exponential moving average cluster size less than the specified threshold with
|
| 404 |
+
randomly selected vector from the current batch.
|
| 405 |
+
"""
|
| 406 |
+
|
| 407 |
+
def __init__(
|
| 408 |
+
self,
|
| 409 |
+
dimension: int = 256,
|
| 410 |
+
n_q: int = 8,
|
| 411 |
+
bins: int | list = 1024,
|
| 412 |
+
decay: float = 0.99,
|
| 413 |
+
kmeans_init: bool = True,
|
| 414 |
+
kmeans_iters: int = 50,
|
| 415 |
+
threshold_ema_dead_code: int = 2,
|
| 416 |
+
):
|
| 417 |
+
super().__init__()
|
| 418 |
+
self.n_q = n_q
|
| 419 |
+
self.dimension = dimension
|
| 420 |
+
self.bins = bins
|
| 421 |
+
self.decay = decay
|
| 422 |
+
self.kmeans_init = kmeans_init
|
| 423 |
+
self.kmeans_iters = kmeans_iters
|
| 424 |
+
self.threshold_ema_dead_code = threshold_ema_dead_code
|
| 425 |
+
self.vq = ResidualVectorQuantization(
|
| 426 |
+
dim=self.dimension,
|
| 427 |
+
codebook_size=self.bins,
|
| 428 |
+
num_quantizers=self.n_q,
|
| 429 |
+
decay=self.decay,
|
| 430 |
+
kmeans_init=self.kmeans_init,
|
| 431 |
+
kmeans_iters=self.kmeans_iters,
|
| 432 |
+
threshold_ema_dead_code=self.threshold_ema_dead_code,
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
def forward(
|
| 436 |
+
self,
|
| 437 |
+
x: torch.Tensor,
|
| 438 |
+
n_q: tp.Optional[int] = None,
|
| 439 |
+
layers: tp.Optional[list] = None,
|
| 440 |
+
):
|
| 441 |
+
"""Residual vector quantization on the given input tensor.
|
| 442 |
+
Args:
|
| 443 |
+
x (torch.Tensor): Input tensor.
|
| 444 |
+
n_q (int): Number of quantizer used to quantize. Default: All quantizers.
|
| 445 |
+
layers (list): Layer that need to return quantized. Defalt: None.
|
| 446 |
+
Returns:
|
| 447 |
+
QuantizedResult:
|
| 448 |
+
The quantized (or approximately quantized) representation with
|
| 449 |
+
the associated numbert quantizers and layer quantized required to return.
|
| 450 |
+
"""
|
| 451 |
+
n_q = n_q if n_q else self.n_q
|
| 452 |
+
quantized, codes, commit_loss, quantized_list = self.vq(
|
| 453 |
+
x, n_q=n_q, layers=layers
|
| 454 |
+
)
|
| 455 |
+
return quantized, codes, torch.mean(commit_loss), quantized_list
|
| 456 |
+
|
| 457 |
+
def encode(
|
| 458 |
+
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
|
| 459 |
+
) -> torch.Tensor:
|
| 460 |
+
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
|
| 461 |
+
The RVQ encode method sets the appropriate number of quantizer to use
|
| 462 |
+
and returns indices for each quantizer.
|
| 463 |
+
Args:
|
| 464 |
+
x (torch.Tensor): Input tensor.
|
| 465 |
+
n_q (int): Number of quantizer used to quantize. Default: All quantizers.
|
| 466 |
+
st (int): Start to encode input from which layers. Default: 0.
|
| 467 |
+
"""
|
| 468 |
+
n_q = n_q if n_q else self.n_q
|
| 469 |
+
st = st or 0
|
| 470 |
+
codes = self.vq.encode(x, n_q=n_q, st=st)
|
| 471 |
+
return codes
|
| 472 |
+
|
| 473 |
+
def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
|
| 474 |
+
"""Decode the given codes to the quantized representation.
|
| 475 |
+
Args:
|
| 476 |
+
codes (torch.Tensor): Input indices for each quantizer.
|
| 477 |
+
st (int): Start to decode input codes from which layers. Default: 0.
|
| 478 |
+
"""
|
| 479 |
+
quantized = self.vq.decode(codes, st=st)
|
| 480 |
+
return quantized
|