mrfakename committed on
Commit ef96930 · 0 Parent(s)
.gitattributes ADDED
@@ -0,0 +1,9 @@
1
+ *.bin filter=lfs diff=lfs merge=lfs -text
2
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
3
+ *.pt filter=lfs diff=lfs merge=lfs -text
4
+ *.pth filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
7
+ *.ot filter=lfs diff=lfs merge=lfs -text
8
+ *.pkl filter=lfs diff=lfs merge=lfs -text
9
+ *.h5 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,10 @@
1
+ build/
2
+ dist/
3
+ checkpoints/
4
+ *.egg-info/
5
+ *.egg
6
+ *.pyc
7
+ *.pyo
8
+ *.pyd
9
+ *.pyw
10
+ *.pyz
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2025 Xiaomi Corporation.
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,33 @@
1
+ ---
2
+ title: MiMo-Audio TTS
3
+ emoji: 🎵
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 5.46.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ python_version: 3.12
12
+ ---
13
+
14
+ # MiMo-Audio Text-to-Speech
15
+
16
+ A simple text-to-speech interface powered by Xiaomi's MiMo-Audio model.
17
+
18
+ ## Features
19
+
20
+ - Convert text to natural-sounding speech
21
+ - Optional style descriptions to control voice characteristics
22
+ - Powered by MiMo-Audio-7B-Instruct model
23
+
24
+ ## Usage
25
+
26
+ 1. Enter your text in the input box
27
+ 2. Optionally add a style description (e.g., "a calm, gentle voice")
28
+ 3. Click "Generate Speech"
29
+ 4. Listen to or download the generated audio
30
+
31
+ ## Model
32
+
33
+ This Space uses the [MiMo-Audio-7B-Instruct](https://huggingface.co/XiaomiMiMo/MiMo-Audio-7B-Instruct) model, a 7B parameter audio language model developed by Xiaomi.
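+ 
+ ## Programmatic Usage
+ 
+ A minimal sketch of calling the model outside the Gradio UI, mirroring `app.py` and `inference_example_sft.py`. It assumes a GPU environment with this repository's `src` package importable and the dependencies from `requirements.txt` installed; the output filenames are placeholders.
+ 
+ ```python
+ from huggingface_hub import snapshot_download
+ from src.mimo_audio.mimo_audio import MimoAudio
+ 
+ # Download the instruct model and the audio tokenizer, as app.py does
+ model_path = snapshot_download(repo_id="XiaomiMiMo/MiMo-Audio-7B-Instruct")
+ tokenizer_path = snapshot_download(repo_id="XiaomiMiMo/MiMo-Audio-Tokenizer")
+ 
+ model = MimoAudio(model_path, tokenizer_path)
+ 
+ # Plain TTS, then TTS with an optional style instruction
+ model.tts_sft("Hello! This is MiMo-Audio text-to-speech.", "tts.wav")
+ model.tts_sft("Hello again!", "tts_styled.wav", instruct="a calm, gentle voice")
+ ```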
app.py ADDED
@@ -0,0 +1,98 @@
1
+ import gradio as gr
2
+ import spaces
3
+ import torch
4
+ from huggingface_hub import snapshot_download
5
+ from src.mimo_audio.mimo_audio import MimoAudio
6
+ import tempfile
7
+ import os
8
+
9
+ # Download models from Hugging Face
10
+ print("Downloading MiMo-Audio models from Hugging Face...")
11
+ model_path = snapshot_download(repo_id="XiaomiMiMo/MiMo-Audio-7B-Instruct")
12
+ tokenizer_path = snapshot_download(repo_id="XiaomiMiMo/MiMo-Audio-Tokenizer")
13
+ print(f"Models downloaded to: {model_path} and {tokenizer_path}")
14
+
15
+ # Initialize model
16
+ print("Loading MiMo-Audio model...")
17
+ model = MimoAudio(
18
+ model_path=model_path,
19
+ mimo_audio_tokenizer_path=tokenizer_path  # keyword must match MimoAudio.__init__
20
+ )
21
+ print("Model loaded successfully!")
22
+
23
+ @spaces.GPU
24
+ def generate_speech(text, style_description=""):
25
+ """Generate speech from text using MiMo-Audio TTS"""
26
+ if not text or not text.strip():
27
+ return None, "Please enter some text to convert to speech."
28
+
29
+ try:
30
+ # Create temporary file for output
31
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
32
+ output_path = tmp_file.name
33
+
34
+ # Generate TTS
35
+ instruct = style_description if style_description.strip() else None
36
+ model.tts_sft(
37
+ text=text.strip(),
38
+ output_audio_path=output_path,
39
+ instruct=instruct
40
+ )
41
+
42
+ return output_path, "✅ Speech generated successfully!"
43
+
44
+ except Exception as e:
45
+ return None, f"❌ Error: {str(e)}"
46
+
47
+ # Create Gradio interface
48
+ with gr.Blocks(title="MiMo-Audio TTS") as demo:
49
+ gr.Markdown("""
50
+ # 🎵 MiMo-Audio Text-to-Speech
51
+
52
+ Convert text to natural-sounding speech using Xiaomi's MiMo-Audio model.
53
+ Optionally add a style description to control the voice characteristics.
54
+ """)
55
+
56
+ with gr.Row():
57
+ with gr.Column():
58
+ text_input = gr.Textbox(
59
+ label="Input Text",
60
+ placeholder="Enter text to convert to speech...",
61
+ lines=5
62
+ )
63
+ style_input = gr.Textbox(
64
+ label="Style Description (Optional)",
65
+ placeholder="e.g., 'a calm, gentle female voice' or 'an energetic male speaker'",
66
+ lines=2
67
+ )
68
+ generate_btn = gr.Button("Generate Speech", variant="primary")
69
+
70
+ with gr.Column():
71
+ audio_output = gr.Audio(
72
+ label="Generated Speech",
73
+ type="filepath"
74
+ )
75
+ status_output = gr.Textbox(
76
+ label="Status",
77
+ interactive=False
78
+ )
79
+
80
+ # Examples
81
+ gr.Examples(
82
+ examples=[
83
+ ["Hello! This is MiMo-Audio text-to-speech synthesis.", ""],
84
+ ["The quick brown fox jumps over the lazy dog.", "a clear, professional voice"],
85
+ ["Welcome to the world of artificial intelligence and natural language processing.", "an enthusiastic, friendly tone"]
86
+ ],
87
+ inputs=[text_input, style_input]
88
+ )
89
+
90
+ # Event handler
91
+ generate_btn.click(
92
+ fn=generate_speech,
93
+ inputs=[text_input, style_input],
94
+ outputs=[audio_output, status_output]
95
+ )
96
+
97
+ if __name__ == "__main__":
98
+ demo.launch()
inference_example_pretrain.py ADDED
@@ -0,0 +1,44 @@
1
+ # Copyright 2025 Xiaomi Corporation.
2
+ from src.mimo_audio.mimo_audio import MimoAudio
3
+
4
+ model_path = "models/MiMo-Audio-7B-Base"
5
+ tokenizer_path = "models/MiMo-Audio-Tokenizer"
6
+
7
+
8
+ model = MimoAudio(model_path, tokenizer_path)
9
+
10
+
11
+ # in context learning: speech-to-speech generation
12
+ instruction = "Convert the timbre of the input speech to target timbre."
13
+
14
+ input_audio = "examples/ESD/0013_000200.wav"
15
+ prompt_examples = [
16
+ {
17
+ "input_audio": "examples/ESD/0013_000139.wav",
18
+ "output_audio": "examples/ESD/0019_000139.wav",
19
+ "output_transcription": "Cuckoos is downheaded and crying.",
20
+ },
21
+ {
22
+ "input_audio": "examples/ESD/0013_000963.wav",
23
+ "output_audio": "examples/ESD/0019_000963.wav",
24
+ "output_transcription": "She said in subdued voice.",
25
+ },
26
+ {
27
+ "input_audio": "examples/ESD/0013_000559.wav",
28
+ "output_audio": "examples/ESD/0019_000559.wav",
29
+ "output_transcription": "A raging fire was-in his eyes.",
30
+ },
31
+ {
32
+ "input_audio": "examples/ESD/0013_001142.wav",
33
+ "output_audio": "examples/ESD/0019_001142.wav",
34
+ "output_transcription": "Does the one that wins get the crowned?",
35
+ },
36
+ {
37
+ "input_audio": "examples/ESD/0013_000769.wav",
38
+ "output_audio": "examples/ESD/0019_000769.wav",
39
+ "output_transcription": "Not much use is it, sam?",
40
+ },
41
+ ]
42
+
43
+ output_audio_path = "examples/in_context_learning_s2s.wav"
44
+ text_channel_output = model.in_context_learning_s2s(instruction, prompt_examples, input_audio, max_new_tokens=8192, output_audio_path=output_audio_path)
inference_example_sft.py ADDED
@@ -0,0 +1,71 @@
1
+ # Copyright 2025 Xiaomi Corporation.
2
+ from src.mimo_audio.mimo_audio import MimoAudio
3
+
4
+ model_path = "models/MiMo-Audio-7B-Instruct"
5
+ tokenizer_path = "models/MiMo-Audio-Tokenizer"
6
+
7
+
8
+ model = MimoAudio(model_path, tokenizer_path)
9
+
10
+
11
+ # tts
12
+ text = "今天天气真好"
13
+ output_audio_path = "examples/tts.wav"
14
+ text_channel_output = model.tts_sft(text, output_audio_path)
15
+
16
+
17
+ # instruct tts
18
+ text = "今天天气真好"
19
+ output_audio_path = "examples/instruct_tts.wav"
20
+ instruct = "用小孩子的声音开心的说"
21
+ text_channel_output = model.tts_sft(text, output_audio_path, instruct=instruct)
22
+
23
+
24
+ # natural instruction tts
25
+ text = "用气喘吁吁的年轻男性声音说:我跑不动了,你等等我!"
26
+ output_audio_path = "examples/natural_instruction_tts.wav"
27
+ text_channel_output = model.tts_sft(text, output_audio_path, read_text_only=False)
28
+
29
+
30
+ # audio understanding
31
+ audio_path = "examples/spoken_dialogue_assistant_turn_1.wav"
32
+ text = "Summarize the audio."
33
+ text_channel_output = model.audio_understanding_sft(audio_path, text)
34
+
35
+
36
+ # audio understanding with thinking
37
+ audio_path = "examples/spoken_dialogue_assistant_turn_1.wav"
38
+ text = "Summarize the audio."
39
+ text_channel_output = model.audio_understanding_sft(audio_path, text, thinking=True)
40
+
41
+
42
+ # spoken dialogue
43
+ first_turn_text_response = "我没办法获取实时的天气信息。不过呢,你可以试试几个方法来查看今天的天气。首先,你可以用手机自带的天气功能,比如苹果手机的天气应用,或者直接在系统设置里查看。其次,你也可以用一些专业的天气服务,像是国外的AccuWeather、Weather.com,或者国内的中国天气网、墨迹天气等等。再有就是,你还可以在谷歌或者百度里直接搜索你所在的城市加上天气这两个字。如果你能告诉我你所在的城市,我也可以帮你分析一下历史天气趋势,不过最新的数据还是需要你通过官方渠道去获取哦。"
44
+ message_list = [
45
+ {"role": "user", "content": "examples/今天天气如何.mp3"},
46
+ {"role": "assistant", "content": {"text": first_turn_text_response, "audio": "examples/spoken_dialogue_assistant_turn_1.wav"}},
47
+ {"role": "user", "content": "examples/北京.mp3"},
48
+ ]
49
+ output_audio_path = "examples/spoken_dialogue_assistant_turn_2.wav"
50
+ text_channel_output = model.spoken_dialogue_sft_multiturn(message_list, output_audio_path=output_audio_path, system_prompt=None, prompt_speech="examples/prompt_speech_zh_m.wav")
51
+ text_channel_output = text_channel_output.split("<|eot|>")[0].replace(".....", "")
52
+ print(text_channel_output)
53
+
54
+
55
+ # speech-to-text dialogue
56
+ message_list = [
57
+ {"role": "user", "content": "./examples/今天天气如何.mp3"},
58
+ {"role": "assistant", "content": "你好,我没办法获取实时的天气信息。如果你能告诉我你所在的城市,我也可以帮你分析一下历史天气趋势,不过最新的数据还是需要你通过官方渠道去获取哦。"},
59
+ {"role": "user", "content": "./examples/北京.mp3"},
60
+ ]
61
+ text_channel_output = model.speech2text_dialogue_sft_multiturn(message_list, thinking=True)
62
+
63
+
64
+ # text dialogue
65
+
66
+ message_list = [
67
+ {"role": "user", "content": "可以给我介绍一些中国的旅游景点吗?"},
68
+ {"role": "assistant", "content": "你好,您想去哪个城市旅游呢?"},
69
+ {"role": "user", "content": "北京"},
70
+ ]
71
+ text_channel_output = model.text_dialogue_sft_multiturn(message_list, thinking=True)
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ accelerate>=1.9.0
2
+ torch==2.6.0
3
+ torchaudio==2.6.0
4
+ transformers==4.49.0
5
+ librosa>=0.11.0
6
+ scipy>=1.14.0
7
+ gradio==5.46.1
8
+ flash-attn==2.7.4.post1
9
+ spaces
requirements_space.txt ADDED
@@ -0,0 +1,9 @@
1
+ accelerate>=1.9.0
2
+ torch==2.6.0
3
+ torchaudio==2.6.0
4
+ transformers==4.49.0
5
+ librosa>=0.11.0
6
+ scipy>=1.14.0
7
+ gradio==5.46.1
8
+ flash-attn==2.7.4.post1
9
+ spaces
run_mimo_audio.py ADDED
@@ -0,0 +1,764 @@
1
+ # Copyright 2025 Xiaomi Corporation.
2
+ import gradio as gr
3
+ import torch
4
+ import os
5
+ import tempfile
6
+ import argparse
7
+ from pathlib import Path
8
+ from src.mimo_audio.mimo_audio import MimoAudio
9
+
10
+
11
+ class TTSGenerator:
12
+ def __init__(self, model, device=None):
13
+ self.model = model
14
+ self.device = device
15
+
16
+ def generate(self, text, instruct, output_audio_path):
17
+ path = Path(output_audio_path)
18
+ path.parent.mkdir(parents=True, exist_ok=True)
19
+ text_output = self.model.tts_sft(text, output_audio_path, instruct)
20
+ return text_output
21
+
22
+ class AudioUnderstandingGenerator:
23
+ def __init__(self, model, device=None):
24
+ self.model = model
25
+ self.device = device
26
+
27
+ def generate(self, input_speech, input_text, thinking=False):
28
+ text = self.model.audio_understanding_sft(input_speech, input_text, thinking=thinking)
29
+ return text
30
+
31
+ class SpokenDialogueGenerator:
32
+ def __init__(self, model, device=None):
33
+ self.model = model
34
+ self.device = device
35
+
36
+ def generate(self, input_speech, output_audio_path, system_prompt="You are MiMo-Audio, a friendly AI assistant and your response needs to be concise.", prompt_speech=None, add_history=False):
37
+
38
+ path = Path(output_audio_path)
39
+ path.parent.mkdir(parents=True, exist_ok=True)
40
+ text_response = self.model.spoken_dialogue_sft(input_speech, output_audio_path, system_prompt=system_prompt, prompt_speech=prompt_speech, add_history=add_history)
41
+ return text_response
42
+
43
+ def clear_history(self):
44
+ self.model.clear_history()
45
+
46
+ class Speech2TextDialogueGenerator:
47
+ def __init__(self, model, device=None):
48
+ self.model = model
49
+ self.device = device
50
+
51
+ def generate(self, input_speech, thinking=False, add_history=False):
52
+ text = self.model.speech2text_dialogue_sft(input_speech, thinking=thinking, add_history=add_history)
53
+ return text
54
+
55
+ def clear_history(self):
56
+ self.model.clear_history()
57
+
58
+
59
+ class TextDialogueGenerator:
60
+ def __init__(self, model, device=None):
61
+ self.model = model
62
+ self.device = device
63
+
64
+ def generate(self, input_text, thinking=False, add_history=False):
65
+ text = self.model.text_dialogue_sft(input_text, thinking=thinking, add_history=add_history)
66
+ return text
67
+
68
+ def clear_history(self):
69
+ self.model.clear_history()
70
+
71
+
72
+ class MultiModalSpeechInterface:
73
+ def __init__(self):
74
+ self.model = None
75
+ self.tts_generator = None
76
+ self.audio_understanding_generator = None
77
+ self.spoken_dialogue_generator = None
78
+ self.speech2text_dialogue_generator = None
79
+ self.text_dialogue_generator = None
80
+
81
+ self.device = None
82
+ self.model_initialized = False
83
+
84
+ def initialize_model(self, model_path=None, tokenizer_path=None):
85
+
86
+ try:
87
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
88
+
89
+ if model_path is None:
90
+ model_path = "./models/MiMo-Audio-7B-Instruct"
91
+ if tokenizer_path is None:
92
+ tokenizer_path = "./models/MiMo-Audio-Tokenizer"
93
+
94
+
95
+ print(f"Model path: {model_path}")
96
+ print(f"Tokenizer path: {tokenizer_path}")
97
+
98
+ self.model = MimoAudio(model_path, tokenizer_path)
99
+ self.tts_generator = TTSGenerator(self.model, self.device)
100
+ self.audio_understanding_generator = AudioUnderstandingGenerator(self.model, self.device)
101
+ self.spoken_dialogue_generator = SpokenDialogueGenerator(self.model, self.device)
102
+ self.speech2text_dialogue_generator = Speech2TextDialogueGenerator(self.model, self.device)
103
+ self.text_dialogue_generator = TextDialogueGenerator(self.model, self.device)
104
+
105
+
106
+ self.model_initialized = True
107
+ print("Model loaded successfully!")
108
+ return "✅ Model loaded successfully!"
109
+
110
+ except Exception as e:
111
+ error_msg = f"❌ Model loading failed: {str(e)}"
112
+ print(error_msg)
113
+ return error_msg
114
+
115
+ def generate_tts_audio(self, input_text, instruct="", use_instruct=False):
116
+ if not self.model_initialized:
117
+ return None, "❌ Error: Model not initialized, please load the model first"
118
+
119
+ if not input_text.strip():
120
+ return None, "❌ Error: Please input text content"
121
+
122
+
123
+ try:
124
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
125
+ output_path = tmp_file.name
126
+
127
+
128
+ if not (use_instruct and instruct.strip()):
129
+ instruct = None
130
+
131
+ print(f"Generating TTS audio: {input_text}")
132
+
133
+ text_channel = self.tts_generator.generate(input_text, instruct, output_path)
134
+ status_msg = f"✅ TTS audio generated successfully!\n📝 Input text: {input_text}"
135
+ if use_instruct and instruct is not None and instruct.strip():
136
+ status_msg += f"\n🎭 Style description: {instruct}"
137
+ status_msg += f"\n🎵 Output text channel: {text_channel}"
138
+
139
+ return output_path, status_msg, gr.update(value=output_path, visible=True)
140
+
141
+ except Exception as e:
142
+ error_msg = f"❌ Error generating TTS audio: {str(e)}"
143
+ print(error_msg)
144
+ return None, error_msg, gr.update(visible=False)
145
+
146
+
147
+ def generate_audio_understanding_response(self, input_audio, input_text, thinking=False):
148
+ if not self.model_initialized:
149
+ return "", "❌ Error: Model not initialized, please load the model first"
150
+
151
+ if input_audio is None and not input_text.strip():
152
+ return "", "❌ Error: Please provide either audio input or text question"
153
+
154
+ if input_audio is None:
155
+ return "", "❌ Error: Please upload an audio file for Audio Understanding task"
156
+
157
+ if not input_text.strip():
158
+ return "", "❌ Error: Please input your question"
159
+
160
+ try:
161
+ print(f"Performing Audio Understanding task:")
162
+ print(f"Audio input: {input_audio}")
163
+ print(f"Text question: {input_text}")
164
+
165
+
166
+ audio_understanding_response = self.audio_understanding_generator.generate(input_audio, input_text.strip(), thinking=thinking)
167
+
168
+ status_msg = f"✅ Audio Understanding task completed successfully!\n🎵 Audio input: {os.path.basename(input_audio)}\n❓ Question: {input_text}\n💬 Answer: {audio_understanding_response}"
169
+
170
+ return audio_understanding_response, status_msg
171
+
172
+ except Exception as e:
173
+ error_msg = f"❌ Error performing Audio Understanding task: {str(e)}"
174
+ print(error_msg)
175
+ return "", error_msg
176
+
177
+ def generate_spoken_dialogue_response(self, input_audio, system_prompt=None, prompt_speech=None, add_history=False):
178
+ if not self.model_initialized:
179
+ return "", "❌ Error: Model not initialized, please load the model first"
180
+
181
+ if input_audio is None:
182
+ return "", "❌ Error: Please upload an audio file"
183
+
184
+ try:
185
+
186
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
187
+ output_audio_path = tmp_file.name
188
+
189
+ print(f"Performing spoken dialogue task:")
190
+ print(f"Audio input: {input_audio}")
191
+ print(f"Output path: {output_audio_path}")
192
+
193
+
194
+ dialogue_response = self.spoken_dialogue_generator.generate(input_audio, output_audio_path, system_prompt=system_prompt, prompt_speech=prompt_speech, add_history=add_history)
195
+
196
+ status_msg = f"✅ Spoken dialogue task completed successfully!\n🎵 Audio input: {os.path.basename(input_audio)}\n💬 Response: {dialogue_response[:300]}..."
197
+
198
+ return output_audio_path, dialogue_response, status_msg
199
+
200
+ except Exception as e:
201
+ error_msg = f"❌ Error performing spoken dialogue task: {str(e)}"
202
+ print(error_msg)
203
+ return None, None, error_msg
204
+
205
+
206
+ def generate_speech2text_dialogue_response(self, input_audio, thinking=False, add_history=False):
207
+ if not self.model_initialized:
208
+ return "", "❌ Error: Model not initialized, please load the model first"
209
+
210
+ if input_audio is None:
211
+ return "", "❌ Error: Please upload an audio file for S2T Dialogue task"
212
+
213
+
214
+ try:
215
+ print(f"Performing S2T Dialogue task:")
216
+ print(f"Audio input: {input_audio}")
217
+
218
+
219
+ s2t_response = self.speech2text_dialogue_generator.generate(input_audio, thinking=thinking, add_history=add_history)
220
+
221
+ status_msg = f"✅ S2T dialogue task completed successfully!\n🎵 Audio input: {os.path.basename(input_audio)}\n❓💬 Answer: {s2t_response}"
222
+
223
+ return s2t_response, status_msg
224
+
225
+ except Exception as e:
226
+ error_msg = f"❌ Error performing QA task: {str(e)}"
227
+ print(error_msg)
228
+ return "", error_msg
229
+
230
+ def generate_text_dialogue_response(self, input_text, thinking=False, add_history=False):
231
+ if not self.model_initialized:
232
+ return "", "❌ Error: Model not initialized, please load the model first"
233
+
234
+ if not input_text or not input_text.strip():
235
+ return "", "❌ Error: Please input your text"
236
+
237
+ try:
238
+ print(f"Performing Text Dialogue task:")
239
+ print(f"Text input: {input_text}")
240
+ print(f"Thinking mode: {thinking}")
241
+ print(f"Add history: {add_history}")
242
+
243
+
244
+ t2t_response = self.text_dialogue_generator.generate(input_text.strip(), thinking=thinking, add_history=add_history)
245
+
246
+ status_msg = f"✅ T2T dialogue task completed successfully!\n💬 Input: {input_text}"
247
+ if thinking:
248
+ status_msg += f"\n🧠 Thinking mode: Enabled"
249
+ status_msg += f"\n💬 Answer: {t2t_response}"
250
+
251
+ return t2t_response, status_msg
252
+
253
+ except Exception as e:
254
+ error_msg = f"❌ Error performing T2T dialogue task: {str(e)}"
255
+ print(error_msg)
256
+ return "", error_msg
257
+
258
+ def clear_spoken_dialogue_history(self):
259
+ if not self.model_initialized:
260
+ return None, "", "❌ Error: Model not initialized, please load the model first"
261
+
262
+ try:
263
+ self.spoken_dialogue_generator.clear_history()
264
+ return None, "", "✅ Spoken dialogue history cleared successfully!"
265
+ except Exception as e:
266
+ error_msg = f"❌ Error clearing spoken dialogue history: {str(e)}"
267
+ print(error_msg)
268
+ return None, "", error_msg
269
+
270
+ def clear_speech2text_dialogue_history(self):
271
+ if not self.model_initialized:
272
+ return "", "❌ Error: Model not initialized, please load the model first"
273
+
274
+ try:
275
+ self.speech2text_dialogue_generator.clear_history()
276
+ return "", "✅ Speech-to-text dialogue history cleared successfully!"
277
+ except Exception as e:
278
+ error_msg = f"❌ Error clearing S2T dialogue history: {str(e)}"
279
+ print(error_msg)
280
+ return "", error_msg
281
+
282
+ def clear_text_dialogue_history(self):
283
+ if not self.model_initialized:
284
+ return "", "❌ Error: Model not initialized, please load the model first"
285
+
286
+ try:
287
+ self.text_dialogue_generator.clear_history()
288
+ return "", "✅ Text dialogue history cleared successfully!"
289
+ except Exception as e:
290
+ error_msg = f"❌ Error clearing T2T dialogue history: {str(e)}"
291
+ print(error_msg)
292
+ return "", error_msg
293
+
294
+
295
+
296
+ def create_interface(self):
297
+
298
+ with gr.Blocks(title="MiMo-Audio Multimodal Speech Processing System", theme=gr.themes.Soft()) as iface:
299
+ gr.Markdown("# 🎵 MiMo-Audio Multimodal Speech Processing System")
300
+ gr.Markdown("Supports audio understanding, text-to-speech, spoken dialogue, speech-to-text dialogue and text-to-text dialogue")
301
+
302
+ with gr.Tabs():
303
+
304
+ with gr.TabItem("⚙️ Model Configuration"):
305
+ gr.Markdown("### 📋 Model initialization configuration")
306
+
307
+ with gr.Row():
308
+ with gr.Column():
309
+
310
+ model_path = gr.Textbox(
311
+ label="Model path",
312
+ placeholder="Leave blank to use default path: ./models/MiMo-Audio-7B-Instruct",
313
+ lines=3
314
+ )
315
+
316
+ tokenizer_path = gr.Textbox(
317
+ label="Tokenizer path",
318
+ placeholder="Leave blank to use default path: ./models/MiMo-Audio-Tokenizer",
319
+ lines=3
320
+ )
321
+
322
+ init_btn = gr.Button("🔄 Initialize model", variant="primary", size="lg")
323
+
324
+ with gr.Column():
325
+ init_status = gr.Textbox(
326
+ label="Initialization status",
327
+ interactive=False,
328
+ lines=6,
329
+ placeholder="Click the initialize model button to start..."
330
+ )
331
+
332
+
333
+ gr.Markdown("### 💻 System information")
334
+ device_info = gr.Textbox(
335
+ label="Device information",
336
+ value=f"GPU available: {'Yes' if torch.cuda.is_available() else 'No'}",
337
+ interactive=False
338
+ )
339
+
340
+
341
+ with gr.TabItem("🔊 Audio Understanding"):
342
+ gr.Markdown("### 🎯 Audio Understanding")
343
+
344
+ with gr.Row():
345
+ with gr.Column():
346
+ audio_understanding_input_audio = gr.Audio(
347
+ label="Upload Audio File",
348
+ type="filepath",
349
+ interactive=True,
350
+ )
351
+
352
+ audio_understanding_input_text = gr.Textbox(
353
+ label="Input Question",
354
+ placeholder="Please input your question...",
355
+ lines=3,
356
+ )
357
+
358
+ audio_understanding_thinking = gr.Checkbox(
359
+ label="Enable Thinking Mode",
360
+ value=False,
361
+ info="Enable thinking mode, AI will perform a deeper analysis and thinking"
362
+ )
363
+
364
+ audio_understanding_generate_btn = gr.Button("🤖 Start Audio Understanding", variant="primary", size="lg")
365
+
366
+
367
+
368
+ with gr.Column():
369
+ audio_understanding_output_text = gr.Textbox(
370
+ label="Answer Result",
371
+ lines=8,
372
+ interactive=False,
373
+ placeholder="AI's answer will be displayed here...",
374
+ elem_id="audio_understanding_output_text"
375
+ )
376
+
377
+ audio_understanding_status = gr.Textbox(
378
+ label="Processing Status",
379
+ lines=6,
380
+ interactive=False,
381
+ placeholder="Processing status information will be displayed here..."
382
+ )
383
+
384
+ with gr.Row():
385
+ audio_understanding_copy_btn = gr.Button("📋 Copy Answer", size="sm")
386
+ audio_understanding_clear_btn = gr.Button("🗑️ Clear Result", size="sm")
387
+
388
+ gr.Markdown("### 🌟 Audio Understanding Examples")
389
+ audio_understanding_examples = gr.Examples(
390
+ examples=[
391
+ [None, "这段音频的主要内容是什么?"],
392
+ [None, "说话者的情感状态如何?"],
393
+ [None, "音频中提到了哪些关键信息?"],
394
+ [None, "Please summarize the main points of this conversation."],
395
+ [None, "What viewpoint does the speaker want to express?"]
396
+ ],
397
+ inputs=[audio_understanding_input_audio, audio_understanding_input_text],
398
+ label="Click the example to automatically fill the question"
399
+ )
400
+
401
+
402
+
403
+
404
+ with gr.TabItem("🎵 Text-to-Speech"):
405
+ gr.Markdown("### 🎵 Text-to-Speech")
406
+
407
+ with gr.Row():
408
+ with gr.Column():
409
+
410
+ tts_input_text = gr.Textbox(
411
+ label="Input Text",
412
+ placeholder="Please input the text you want to convert to speech...",
413
+ lines=4,
414
+ max_lines=8
415
+ )
416
+
417
+ tts_instruct = gr.Textbox(
418
+ label="Style Description (Optional)",
419
+ placeholder="Please input the style description (optional)...",
420
+ lines=3,
421
+ max_lines=5
422
+ )
423
+
424
+ tts_use_instruct = gr.Checkbox(
425
+ label="Use Style Description",
426
+ value=True,
427
+ info="Enable to use InstructTTS for style-controlled speech generation"
428
+ )
429
+
430
+ tts_generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
431
+
432
+ with gr.Column():
433
+
434
+ tts_output_audio = gr.Audio(
435
+ label="Generated Speech",
436
+ type="filepath"
437
+ )
438
+
439
+ tts_status = gr.Textbox(
440
+ label="Generation Status",
441
+ lines=6,
442
+ interactive=False
443
+ )
444
+
445
+
446
+ tts_download_btn = gr.DownloadButton(
447
+ label="Download Generated Audio",
448
+ visible=False
449
+ )
450
+
451
+
452
+
453
+
454
+ with gr.TabItem("🎤 Spoken Dialogue"):
455
+ gr.Markdown("### 🎯 Spoken Dialogue")
456
+
457
+ with gr.Row():
458
+ with gr.Column():
459
+
460
+ dialogue_input_audio = gr.Audio(
461
+ label="Upload User Speech",
462
+ type="filepath",
463
+ interactive=True
464
+ )
465
+ system_prompt = gr.Textbox(
466
+ label="System Prompt (Optional)",
467
+ placeholder="e.g.: You are MiMo-Audio, a friendly AI assistant and your response needs to be concise.",
468
+ lines=1
469
+ )
470
+ prompt_speech = gr.Audio(
471
+ label="Prompt Speech (Optional, MiMo-Audio speaks with the same timbre as your prompt.)",
472
+ type="filepath",
473
+ interactive=True
474
+ )
475
+ spoken_dialogue_add_history = gr.Checkbox(
476
+ label="Enable History Record",
477
+ value=True,
478
+ info="Enable to remember the previous dialogue context"
479
+ )
480
+
481
+ with gr.Row():
482
+ dialogue_generate_btn = gr.Button("💬 Start Dialogue", variant="primary", size="lg")
483
+
484
+ with gr.Row():
485
+ dialogue_clear_history_btn = gr.Button("🗑️ Clear Dialogue History", size="sm", variant="secondary")
486
+
487
+
488
+
489
+
490
+
491
+ with gr.Column():
492
+
493
+ dialogue_output_audio = gr.Audio(
494
+ label="Output Audio",
495
+ type="filepath"
496
+ )
497
+ dialogue_output_text = gr.Textbox(
498
+ label="Dialogue Response",
499
+ lines=5,
500
+ interactive=False,
501
+ )
502
+ dialogue_status = gr.Textbox(
503
+ label="Dialogue Status",
504
+ lines=5,
505
+ interactive=False,
506
+ )
507
+
508
+
509
+
510
+
511
+
512
+ with gr.TabItem("💬 S2T Dialogue"):
513
+ gr.Markdown("### 🎯 S2T Dialogue")
514
+
515
+ with gr.Row():
516
+ with gr.Column():
517
+
518
+ s2t_dialogue_input_audio = gr.Audio(
519
+ label="Upload User Speech",
520
+ type="filepath",
521
+ interactive=True
522
+ )
523
+
524
+
525
+ s2t_dialogue_add_history = gr.Checkbox(
526
+ label="Enable History Record",
527
+ value=True,
528
+ info="Enable to remember the previous dialogue context"
529
+ )
530
+
531
+ s2t_dialogue_thinking = gr.Checkbox(
532
+ label="Enable Thinking Mode (think mode)",
533
+ value=False,
534
+ info="Enable to perform a deeper analysis and reasoning"
535
+ )
536
+
537
+ with gr.Row():
538
+ s2t_dialogue_generate_btn = gr.Button("🎧 Start S2T Dialogue", variant="primary", size="lg")
539
+
540
+ with gr.Row():
541
+ s2t_dialogue_clear_history_btn = gr.Button("🗑️ Clear Dialogue History", size="sm", variant="secondary")
542
+
543
+
544
+ with gr.Column():
545
+
546
+ s2t_dialogue_output_text = gr.Textbox(
547
+ label="Dialogue Response",
548
+ lines=8,
549
+ interactive=False,
550
+ placeholder="AI's dialogue response will be displayed here..."
551
+ )
552
+
553
+ s2t_dialogue_status = gr.Textbox(
554
+ label="Dialogue Status",
555
+ lines=5,
556
+ interactive=False,
557
+ placeholder="Dialogue status information will be displayed here..."
558
+ )
559
+
560
+
561
+
562
+ with gr.TabItem("📝 T2T Dialogue"):
563
+ gr.Markdown("### 🎯 T2T Dialogue")
564
+
565
+ with gr.Row():
566
+ with gr.Column():
567
+
568
+ t2t_dialogue_input_text = gr.Textbox(
569
+ label="Input Dialogue Content",
570
+ placeholder="Please input the text content you want to dialogue...",
571
+ lines=4,
572
+ max_lines=8
573
+ )
574
+
575
+ t2t_dialogue_add_history = gr.Checkbox(
576
+ label="Enable History Record",
577
+ value=True,
578
+ info="Enable to remember the previous dialogue context"
579
+ )
580
+
581
+ t2t_dialogue_thinking = gr.Checkbox(
582
+ label="Enable Thinking Mode (Thinking)",
583
+ value=False,
584
+ info="Enable thinking mode, AI will perform a deeper analysis and thinking"
585
+ )
586
+
587
+ with gr.Row():
588
+ t2t_dialogue_generate_btn = gr.Button("💬 Start T2T Dialogue", variant="primary", size="lg")
589
+
590
+ with gr.Row():
591
+ t2t_dialogue_clear_history_btn = gr.Button("🗑️ Clear Dialogue History", size="sm", variant="secondary")
592
+
593
+
594
+
595
+ with gr.Column():
596
+ t2t_dialogue_output_text = gr.Textbox(
597
+ label="Dialogue Response",
598
+ lines=8,
599
+ interactive=False,
600
+ placeholder="AI's dialogue response will be displayed here..."
601
+ )
602
+
603
+ t2t_dialogue_status = gr.Textbox(
604
+ label="Dialogue Status",
605
+ lines=5,
606
+ interactive=False,
607
+ placeholder="Dialogue status information will be displayed here..."
608
+ )
609
+
610
+ gr.Markdown("### 🌟 T2T Dialogue Examples")
611
+ t2t_dialogue_examples = gr.Examples(
612
+ examples=[
613
+ ["Hello, how are you?"],
614
+ ["I want to know the history of the development of artificial intelligence"],
615
+ ["Please recommend some good movies"],
616
+ ["Can you help me explain the basic concepts of quantum physics?"],
617
+ ["I'm learning programming recently, any suggestions?"]
618
+ ],
619
+ inputs=[t2t_dialogue_input_text],
620
+ label="Click the example to automatically fill the dialogue content"
621
+ )
622
+
623
+
624
+
625
+ def copy_text_to_clipboard(text):
626
+ return text
627
+
628
+ def clear_audio_understanding_results():
629
+ return "", "🗑️ Audio Understanding Result Cleared"
630
+
631
+
632
+ init_btn.click(
633
+ fn=lambda path, tok_path: self.initialize_model(path or None, tok_path or None),
634
+ inputs=[model_path, tokenizer_path],
635
+ outputs=[init_status]
636
+ )
637
+
638
+
639
+ audio_understanding_generate_btn.click(
640
+ fn=self.generate_audio_understanding_response,
641
+ inputs=[audio_understanding_input_audio, audio_understanding_input_text, audio_understanding_thinking],
642
+ outputs=[audio_understanding_output_text, audio_understanding_status]
643
+ )
644
+
645
+ audio_understanding_copy_btn.click(
646
+ fn=None,
647
+ inputs=[audio_understanding_output_text],
648
+ js="(text) => {navigator.clipboard.writeText(text); alert('Copied to clipboard!')}"
649
+ )
650
+
651
+ tts_generate_btn.click(
652
+ fn=self.generate_tts_audio,
653
+ inputs=[tts_input_text, tts_instruct, tts_use_instruct],
654
+ outputs=[tts_output_audio, tts_status, tts_download_btn]
655
+ )
656
+
657
+ dialogue_generate_btn.click(
658
+ fn=self.generate_spoken_dialogue_response,
659
+ inputs=[dialogue_input_audio, system_prompt, prompt_speech, spoken_dialogue_add_history],
660
+ outputs=[dialogue_output_audio, dialogue_output_text, dialogue_status]
661
+ )
662
+
663
+
664
+
665
+ dialogue_clear_history_btn.click(
666
+ fn=self.clear_spoken_dialogue_history,
667
+ outputs=[dialogue_output_audio, dialogue_output_text, dialogue_status]
668
+ )
669
+
670
+
671
+ s2t_dialogue_generate_btn.click(
672
+ fn=self.generate_speech2text_dialogue_response,
673
+ inputs=[s2t_dialogue_input_audio, s2t_dialogue_thinking, s2t_dialogue_add_history],
674
+ outputs=[s2t_dialogue_output_text, s2t_dialogue_status]
675
+ )
676
+
677
+
678
+
679
+ s2t_dialogue_clear_history_btn.click(
680
+ fn=self.clear_speech2text_dialogue_history,
681
+ outputs=[s2t_dialogue_output_text, s2t_dialogue_status]
682
+ )
683
+
684
+
685
+ t2t_dialogue_generate_btn.click(
686
+ fn=self.generate_text_dialogue_response,
687
+ inputs=[t2t_dialogue_input_text, t2t_dialogue_thinking, t2t_dialogue_add_history],
688
+ outputs=[t2t_dialogue_output_text, t2t_dialogue_status]
689
+ )
690
+
691
+
692
+ t2t_dialogue_clear_history_btn.click(
693
+ fn=self.clear_text_dialogue_history,
694
+ outputs=[t2t_dialogue_output_text, t2t_dialogue_status]
695
+ )
696
+
697
+
698
+
699
+
700
+ audio_understanding_clear_btn.click(
701
+ fn=clear_audio_understanding_results,
702
+ outputs=[audio_understanding_output_text, audio_understanding_status]
703
+ )
704
+
705
+
706
+
707
+
708
+
709
+
710
+ tts_input_text.submit(
711
+ fn=self.generate_tts_audio,
712
+ inputs=[tts_input_text, tts_instruct, tts_use_instruct],
713
+ outputs=[tts_output_audio, tts_status, tts_download_btn]
714
+ )
715
+
716
+
717
+ audio_understanding_input_text.submit(
718
+ fn=self.generate_audio_understanding_response,
719
+ inputs=[audio_understanding_input_audio, audio_understanding_input_text, audio_understanding_thinking],
720
+ outputs=[audio_understanding_output_text, audio_understanding_status]
721
+ )
722
+
723
+ t2t_dialogue_input_text.submit(
724
+ fn=self.generate_text_dialogue_response,
725
+ inputs=[t2t_dialogue_input_text, t2t_dialogue_thinking, t2t_dialogue_add_history],
726
+ outputs=[t2t_dialogue_output_text, t2t_dialogue_status]
727
+ )
728
+
729
+
730
+ return iface
731
+
732
+ def main():
733
+ parser = argparse.ArgumentParser(description="MiMo-Audio")
734
+ parser.add_argument("--host", default="0.0.0.0", help="Server Address")
735
+ parser.add_argument("--port", type=int, default=7897, help="Port")
736
+ parser.add_argument("--share", action="store_true", help="Create Public Link")
737
+ parser.add_argument("--debug", action="store_true", help="Debug Mode")
738
+
739
+ args = parser.parse_args()
740
+
741
+
742
+
743
+ print("🚀 Launch MiMo-Audio...")
744
+
745
+
746
+ speech_interface = MultiModalSpeechInterface()
747
+
748
+
749
+
750
+ print("🎨 Create Gradio Interface...")
751
+ iface = speech_interface.create_interface()
752
+
753
+
754
+ print(f"🌐 Launch Service - Address: {args.host}:{args.port}")
755
+
756
+ iface.launch(
757
+ server_name=args.host,
758
+ server_port=args.port,
759
+ share=args.share,
760
+ debug=args.debug
761
+ )
762
+
763
+ if __name__ == "__main__":
764
+ main()
src/mimo_audio/mimo_audio.py ADDED
@@ -0,0 +1,1292 @@
1
+ # Copyright 2025 Xiaomi Corporation.
2
+ import os
3
+ import re
4
+ import time
5
+ import random
6
+ import torch
7
+ import torchaudio
8
+ import soundfile as sf
9
+
10
+ from typing import Union
11
+ from torchaudio.transforms import MelSpectrogram
12
+ from transformers import (
13
+ AutoTokenizer,
14
+ GenerationConfig
15
+ )
16
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
17
+
18
+ from .process_speechdata import InputSegment, StreamingInputSegment
19
+ from ..mimo_audio_tokenizer import MiMoAudioTokenizer
20
+ from .templates import asr_en_templates, asr_zh_templates, tts_en_templates, tts_zh_templates
21
+ from .modeling_mimo_audio import (
22
+ MiMoAudioArguments,
23
+ MiMoAudioForCausalLM,
24
+ MiMoSampler,
25
+ MiMoStopper,
26
+ )
27
+
28
+
29
+ def detect_language(text):
30
+ if re.search(r'[\u4e00-\u9fff]', text):
31
+ return 'zh'
32
+ else:
33
+ return 'en'
34
+
35
+
36
+ class MimoAudio:
37
+
38
+ def __init__(
39
+ self,
40
+ model_path: str,
41
+ mimo_audio_tokenizer_path: str,
42
+ device: str | None = None,
43
+ ) -> None:
44
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
45
+
46
+ self.path = model_path
47
+ self.mimo_audio_tokenizer_path = mimo_audio_tokenizer_path
48
+
49
+ self.tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(
50
+ self.path
51
+ )
52
+ self.padding_idx = int(self.tokenizer.pad_token_id)
53
+
54
+ special_tokens = [
55
+ "<|sosp|>",
56
+ "<|eosp|>",
57
+ "<|empty|>",
58
+ "<|Human|>",
59
+ "<|SpeechLM|>",
60
+ "<|sostm|>",
61
+ "<|eostm|>",
62
+ "<|eot|>",
63
+ ]
64
+ for token in special_tokens:
65
+ if token not in self.tokenizer.get_vocab():
66
+ print(f"Add special tokens {token} to tokenizer.vocab")
67
+ self.tokenizer.add_tokens([token], special_tokens=True)
68
+
69
+ self.sosp_idx = self.tokenizer.convert_tokens_to_ids("<|sosp|>")
70
+ self.eosp_idx = self.tokenizer.convert_tokens_to_ids("<|eosp|>")
71
+ self.empty_token = self.tokenizer.convert_tokens_to_ids("<|empty|>")
72
+ self.sostm_idx = self.tokenizer.convert_tokens_to_ids("<|sostm|>")
73
+ self.eostm_idx = self.tokenizer.convert_tokens_to_ids("<|eostm|>")
74
+ self.eot_idx = self.tokenizer.convert_tokens_to_ids("<|eot|>")
75
+ self.im_start_idx = self.tokenizer.convert_tokens_to_ids("<|im_start|>")
76
+ self.im_end_idx = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
77
+
78
+ model_args = MiMoAudioArguments(
79
+ model_name_or_path=self.path,
80
+ sosp_idx=self.sosp_idx,
81
+ eosp_idx=self.eosp_idx,
82
+ empty_idx=self.empty_token,
83
+ sostm_idx=self.sostm_idx,
84
+ eostm_idx=self.eostm_idx,
85
+ eot_idx=self.eot_idx,
86
+ )
87
+
88
+ start_loading_time = time.monotonic()
89
+ self.model = MiMoAudioForCausalLM.from_pretrained(
90
+ self.path,
91
+ args=model_args,
92
+ torch_dtype=torch.bfloat16,
93
+ device_map={"": self.device},
94
+ )
95
+
96
+ self.group_size=self.model.config.group_size
97
+ self.audio_channels=self.model.config.audio_channels
98
+ self.delay_pattern = self.model.config.delay_pattern
99
+ self.vocab_size = self.model.config.vocab_size
100
+
101
+ self.speech_zeroemb_idx = self.model.speech_empty_ids
102
+
103
+ self.model.eval()
104
+ print(
105
+ f"Model loaded in {time.monotonic() - start_loading_time:.2f} seconds, device: {self.device}"
106
+ )
107
+
108
+ self.generate_kwargs = {
109
+ "max_length": 8192,
110
+ "eos_token_id": self.tokenizer.eos_token_id,
111
+ "pad_token_id": self.tokenizer.pad_token_id,
112
+ }
113
+ self.default_global_sampler = MiMoSampler(
114
+ do_sample=True, temperature=0.6, top_k=50, top_p=0.95
115
+ )
116
+ self.default_local_sampler = MiMoSampler(
117
+ do_sample=True, temperature=0.9, top_k=50, top_p=0.95
118
+ )
119
+
120
+ self.task_sampler_configs = {
121
+ "asr": {
122
+ "global": MiMoSampler(do_sample=False, temperature=1.0, top_p=1.0),
123
+ "local": MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95)
124
+ },
125
+ "tts": {
126
+ "global": MiMoSampler(do_sample=True, temperature=0.6, top_p=1.0),
127
+ "local": MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95)
128
+ },
129
+ "spoken_dialogue": {
130
+ "global": MiMoSampler(do_sample=True, temperature=0.6, top_p=0.95),
131
+ "local": MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95)
132
+ },
133
+ "audio_understanding": {
134
+ "global": MiMoSampler(do_sample=True, temperature=0.3, top_p=0.95),
135
+ "local": MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95)
136
+ },
137
+ "text_chat": {
138
+ "global": MiMoSampler(do_sample=True, temperature=0.4, top_p=0.95),
139
+ "local": MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95)
140
+ },
141
+ "in_context_learning_s2s": {
142
+ "global": MiMoSampler(do_sample=False, temperature=1.0, top_p=1.0),
143
+ "local": MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95)
144
+ },
145
+ }
146
+
147
+ start_loading_mimo_audio_tokenizer_time = time.monotonic()
148
+ self.mimo_audio_tokenizer = MiMoAudioTokenizer.from_pretrained(self.mimo_audio_tokenizer_path)
149
+
150
+ self.mimo_audio_tokenizer.eval().bfloat16().to(self.device)
151
+ print(
152
+ f"MiMo-Audio Tokenizer loaded in {time.monotonic() - start_loading_mimo_audio_tokenizer_time:.2f} seconds, device: {self.device}"
153
+ )
154
+
155
+ # Initialize mel spectrogram transform for consistent processing
156
+ self.mel_transform = MelSpectrogram(
157
+ sample_rate=self.mimo_audio_tokenizer.config.sampling_rate,
158
+ n_fft=self.mimo_audio_tokenizer.config.nfft,
159
+ hop_length=self.mimo_audio_tokenizer.config.hop_length,
160
+ win_length=self.mimo_audio_tokenizer.config.window_size,
161
+ f_min=self.mimo_audio_tokenizer.config.fmin,
162
+ f_max=self.mimo_audio_tokenizer.config.fmax,
163
+ n_mels=self.mimo_audio_tokenizer.config.n_mels,
164
+ power=1.0,
165
+ center=True,
166
+ ).to(self.device)
167
+
168
+ self.history = None
169
+
170
+ def get_task_sampler(self, task_name):
171
+ if task_name not in self.task_sampler_configs:
172
+ return {
173
+ "global": self.default_global_sampler,
174
+ "local": self.default_local_sampler
175
+ }
176
+ return self.task_sampler_configs[task_name]
177
+
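+ # Waveforms are written at 24 kHz, the assumed output rate of the audio
+ # tokenizer's decoder.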
178
+ def save_wav(self, path, wav):
179
+ sf.write(
180
+ path,
181
+ wav.reshape(-1).detach().cpu().numpy(),
182
+ 24000,
183
+ )
184
+
185
+ def wav2mel(self, wav):
186
+ spec = self.mel_transform(wav[None, :])
187
+ return torch.log(torch.clip(spec, min=1e-7)).squeeze()
188
+
189
+ def resample_audio_if_needed(self, wav_tensor: torch.Tensor, original_sr: int):
190
+ target_sr = self.mimo_audio_tokenizer.config.sampling_rate
191
+ if original_sr != target_sr:
192
+ wav_tensor = torchaudio.functional.resample(
193
+ wav_tensor, original_sr, target_sr
194
+ )
195
+ return wav_tensor
196
+
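+ # Pack consecutive mel segments into batches whose total frame count stays
+ # under max_length, so long inputs can be encoded in several bounded passes.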
197
+ def group_by_length(self, features: torch.Tensor, lengths: torch.Tensor, max_length: int):
198
+ if features.size(0) != lengths.sum().item():
199
+ raise ValueError(f"Feature size mismatch: {features.size(0)} vs {lengths.sum().item()}")
200
+
201
+ split_points = []
202
+ current_sum = 0
203
+
204
+ for i, seq_len in enumerate(lengths):
205
+ if current_sum + seq_len > max_length and current_sum > 0:
206
+ split_points.append(i)
207
+ current_sum = seq_len.item()
208
+ else:
209
+ current_sum += seq_len.item()
210
+
211
+ # Convert split points to group sizes
212
+ group_sizes = []
213
+ prev = 0
214
+ for point in split_points:
215
+ group_sizes.append(point - prev)
216
+ prev = point
217
+ if prev < len(lengths):
218
+ group_sizes.append(len(lengths) - prev)
219
+
220
+ len_groups = torch.split(lengths, group_sizes)
221
+ feature_sizes = [group.sum().item() for group in len_groups]
222
+ feature_groups = torch.split(features, feature_sizes)
223
+
224
+ return feature_groups, len_groups
225
+
226
+ def encode_batch(self, input_features: torch.Tensor, input_lens: torch.Tensor, max_length: int = 256000):
227
+ feature_groups, len_groups = self.group_by_length(input_features, input_lens, max_length)
228
+
229
+ encoded_parts = []
230
+ for features, lengths in zip(feature_groups, len_groups):
231
+ with torch.no_grad():
232
+ codes, _ = self.mimo_audio_tokenizer.encoder.encode(
233
+ input_features=features.to(self.device),
234
+ input_lens=lengths.to(self.device),
235
+ return_codes_only=True
236
+ )
237
+ encoded_parts.append(codes)
238
+
239
+ return torch.cat(encoded_parts, dim=-1)
240
+
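+ # Audio inputs (file path or waveform tensor) are converted to log-mel frames,
+ # encoded in <=6000-frame segments by the audio tokenizer, and padded to a
+ # multiple of group_size; plain text is passed through (capitalized if it is
+ # written entirely in one case).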
241
+ def preprocess_input(
242
+ self,
243
+ input: Union[None, str, torch.Tensor] = None,
244
+ ):
245
+ if isinstance(input, torch.Tensor) or (isinstance(input, str) and os.path.isfile(input)):
246
+ if isinstance(input, torch.Tensor):
247
+ wav = input
248
+ else:
249
+ wav, sr = torchaudio.load(input)
250
+ if wav.ndim == 2:
251
+ wav = wav.mean(dim=0)
252
+ wav = self.resample_audio_if_needed(wav, sr)
253
+ wav = wav.to(self.device)
254
+
255
+ mel = self.wav2mel(wav).transpose(0, 1) # (seq_len, n_mels)
256
+
257
+ input_len = mel.size(0)
258
+ segment_size = 6000
259
+ input_len_seg = [segment_size] * (input_len // segment_size)
260
+ if input_len % segment_size > 0:
261
+ input_len_seg.append(input_len % segment_size)
262
+
263
+ codes_packed = self.encode_batch(
264
+ input_features=mel,
265
+ input_lens=torch.tensor(input_len_seg),
266
+ )
267
+
268
+ codes = codes_packed.transpose(0, 1).detach().cpu()
269
+ audio_codes = codes[:, :self.audio_channels]
270
+
271
+ # Pad the sequence to be a multiple of group_size by repeating the last frame
272
+ num_timesteps = audio_codes.shape[0]
273
+ if num_timesteps % self.group_size != 0:
274
+ padding_needed = self.group_size - (num_timesteps % self.group_size)
275
+ last_tokens = audio_codes[-1:, :] # Keep dim for repeat
276
+ padding_tokens = last_tokens.repeat(padding_needed, 1)
277
+ audio_codes = torch.cat([audio_codes, padding_tokens], dim=0)
278
+
279
+ audio_tokenized = audio_codes.reshape(-1)
280
+
281
+ return audio_tokenized
282
+ else:
283
+ text = input
284
+ if (
285
+ text.isupper() or text.islower()
286
+ ): # If the text only contains upper-case or lower-case letters, capitalize it.
287
+ text = text.capitalize()
288
+ return text
289
+
290
+ def get_input_ids(self, prompt):
291
+ input_ids = [
292
+ seg.to_input_id(
293
+ self.tokenizer,
294
+ self.group_size,
295
+ self.audio_channels,
296
+ )
297
+ for seg in prompt
298
+ ]
299
+ input_ids = torch.cat(input_ids, dim=1)
300
+ return input_ids.to(self.device)
301
+
302
+
303
+ def get_asr_sft_prompt(
304
+ self,
305
+ input: Union[None, str] = None,
306
+ ):
307
+ audio_tokenized = self.preprocess_input(input)
308
+
309
+ template = random.choice(asr_zh_templates + asr_en_templates)
310
+
311
+ lm_prompt = [
312
+ InputSegment(
313
+ text=f"<|im_start|>user\n",
314
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
315
+ text_zeroemb_idx=self.empty_token,
316
+ ),
317
+ InputSegment(
318
+ audio=audio_tokenized,
319
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
320
+ text_zeroemb_idx=self.empty_token,
321
+ ),
322
+ InputSegment(
323
+ text=template,
324
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
325
+ text_zeroemb_idx=self.empty_token,
326
+ ),
327
+ InputSegment(
328
+ text=f"<|im_end|>\n",
329
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
330
+ text_zeroemb_idx=self.empty_token,
331
+ ),
332
+ InputSegment(
333
+ text=f"<|im_start|>assistant\n",
334
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
335
+ text_zeroemb_idx=self.empty_token,
336
+ ),
337
+ InputSegment(
338
+ text="<think>\n\n</think>\n",
339
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
340
+ text_zeroemb_idx=self.empty_token,
341
+ )
342
+ ]
343
+ input_ids = self.get_input_ids(lm_prompt)
344
+ return input_ids
345
+
346
+ def get_tts_sft_prompt(
347
+ self,
348
+ input: Union[None, str] = None,
349
+ instruct=None,
350
+ read_text_only=True,
351
+ prompt_speech=None,
352
+ ):
353
+ if prompt_speech is not None:
354
+ assistant_prompt_audio_token = self.preprocess_input(prompt_speech)
355
+ else:
356
+ assistant_prompt_audio_token = None
357
+ if not read_text_only:
358
+ text = self.preprocess_input(input)
359
+ if assistant_prompt_audio_token is None:
360
+ lm_prompt = [
361
+ InputSegment(
362
+ text="<|im_start|>system\n",
363
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
364
+ text_zeroemb_idx=self.empty_token,
365
+ ),
366
+ InputSegment(
367
+ text=f"你需要根据指定的风格指令和文本内容来生成语音。",
368
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
369
+ text_zeroemb_idx=self.empty_token,
370
+ ),
371
+ InputSegment(
372
+ text="<|im_end|>\n",
373
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
374
+ text_zeroemb_idx=self.empty_token,
375
+ ),
376
+ InputSegment(
377
+ text=f"<|im_start|>user\n{text}<|im_end|>\n",
378
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
379
+ text_zeroemb_idx=self.empty_token,
380
+ ),
381
+ InputSegment(
382
+ text=f"<|im_start|>assistant\n<think>\n",
383
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
384
+ text_zeroemb_idx=self.empty_token,
385
+ ),
386
+ ]
387
+ else:
388
+ lm_prompt = [
389
+ InputSegment(
390
+ text="<|im_start|>system\n",
391
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
392
+ text_zeroemb_idx=self.empty_token,
393
+ ),
394
+ InputSegment(
395
+ text=f"你需要根据指定的风格指令和文本内容来生成和语音prompt具有相同音色的语音。你的音色应该是:",
396
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
397
+ text_zeroemb_idx=self.empty_token,
398
+ ),
399
+ InputSegment(
400
+ text="",
401
+ audio=assistant_prompt_audio_token,
402
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
403
+ text_zeroemb_idx=self.empty_token,
404
+ ),
405
+ InputSegment(
406
+ text="<|im_end|>\n",
407
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
408
+ text_zeroemb_idx=self.empty_token,
409
+ ),
410
+ InputSegment(
411
+ text=f"<|im_start|>user\n{text}<|im_end|>\n",
412
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
413
+ text_zeroemb_idx=self.empty_token,
414
+ ),
415
+ InputSegment(
416
+ text=f"<|im_start|>assistant\n<think>\n",
417
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
418
+ text_zeroemb_idx=self.empty_token,
419
+ ),
420
+ ]
421
+ else:
422
+ language = detect_language(input)
423
+ if language == "zh":
424
+ template = random.choice(tts_zh_templates)
425
+ else:
426
+ template = random.choice(tts_en_templates)
427
+
428
+ text = self.preprocess_input(input)
429
+ if instruct is None:
430
+ lm_prompt = [
431
+ InputSegment(
432
+ text=f"<|im_start|>user\n{template}: {text}<|im_end|>\n",
433
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
434
+ text_zeroemb_idx=self.empty_token,
435
+ ),
436
+ InputSegment(
437
+ text=f"<|im_start|>assistant\n<|sostm|>",
438
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
439
+ text_zeroemb_idx=self.empty_token,
440
+ ),
441
+ ]
442
+ else:
443
+ if assistant_prompt_audio_token is None:
444
+ lm_prompt = [
445
+ InputSegment(
446
+ text="<|im_start|>system\n",
447
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
448
+ text_zeroemb_idx=self.empty_token,
449
+ ),
450
+ InputSegment(
451
+ text=f"你需要根据指定的风格指令和文本内容来生成语音。",
452
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
453
+ text_zeroemb_idx=self.empty_token,
454
+ ),
455
+ InputSegment(
456
+ text="<|im_end|>\n",
457
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
458
+ text_zeroemb_idx=self.empty_token,
459
+ ),
460
+ InputSegment(
461
+ text=f"<|im_start|>user\n{template}: {text}({instruct})<|im_end|>\n",
462
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
463
+ text_zeroemb_idx=self.empty_token,
464
+ ),
465
+ InputSegment(
466
+ text=f"<|im_start|>assistant\n<think>\n",
467
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
468
+ text_zeroemb_idx=self.empty_token,
469
+ ),
470
+ ]
471
+ else:
472
+ lm_prompt = [
473
+ InputSegment(
474
+ text="<|im_start|>system\n",
475
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
476
+ text_zeroemb_idx=self.empty_token,
477
+ ),
478
+ InputSegment(
479
+ text=f"你需要根据指定的风格指令和文本内容来生成和语音prompt具有相同音色的语音。你的音色应该是:",
480
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
481
+ text_zeroemb_idx=self.empty_token,
482
+ ),
483
+ InputSegment(
484
+ text="",
485
+ audio=assistant_prompt_audio_token,
486
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
487
+ text_zeroemb_idx=self.empty_token,
488
+ ),
489
+ InputSegment(
490
+ text="<|im_end|>\n",
491
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
492
+ text_zeroemb_idx=self.empty_token,
493
+ ),
494
+ InputSegment(
495
+ text=f"<|im_start|>user\n{template}: {text}({instruct})<|im_end|>\n",
496
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
497
+ text_zeroemb_idx=self.empty_token,
498
+ ),
499
+ InputSegment(
500
+ text=f"<|im_start|>assistant\n<think>\n",
501
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
502
+ text_zeroemb_idx=self.empty_token,
503
+ ),
504
+ ]
505
+
506
+ input_ids = self.get_input_ids(lm_prompt)
507
+ return input_ids
508
+
509
+
510
+ def get_audio_understanding_sft_prompt(
511
+ self,
512
+ input_speech,
513
+ input_text,
514
+ thinking=False,
515
+ ):
516
+ audio_tokenized = self.preprocess_input(input_speech)
517
+
518
+ lm_prompt = [
519
+ InputSegment(
520
+ text=f"<|im_start|>user\n",
521
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
522
+ text_zeroemb_idx=self.empty_token,
523
+ ),
524
+ InputSegment(
525
+ audio=audio_tokenized,
526
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
527
+ text_zeroemb_idx=self.empty_token,
528
+ ),
529
+ InputSegment(
530
+ text=input_text,
531
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
532
+ text_zeroemb_idx=self.empty_token,
533
+ ),
534
+ InputSegment(
535
+ text=f"<|im_end|>\n",
536
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
537
+ text_zeroemb_idx=self.empty_token,
538
+ ),
539
+ InputSegment(
540
+ text=f"<|im_start|>assistant\n",
541
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
542
+ text_zeroemb_idx=self.empty_token,
543
+ ),
544
+ ]
545
+ if not thinking:
546
+ lm_prompt.append(
547
+ InputSegment(
548
+ text="<think>\n\n</think>\n",
549
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
550
+ text_zeroemb_idx=self.empty_token,
551
+ )
552
+ )
553
+ else:
554
+ lm_prompt.append(
555
+ InputSegment(
556
+ text="<think>\n",
557
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
558
+ text_zeroemb_idx=self.empty_token,
559
+ )
560
+ )
561
+
562
+ input_ids = self.get_input_ids(lm_prompt)
563
+ return input_ids
564
+
565
+ def get_spoken_dialogue_sft_prompt(
566
+ self,
567
+ input_speech,
568
+ system_prompt=None,
569
+ prompt_speech=None,
570
+ add_history=False,
571
+ ):
572
+ audio_tokenized = self.preprocess_input(input_speech)
573
+
574
+ lm_prompt = []
575
+
576
+ if add_history and self.history is not None:
577
+ lm_prompt += [
578
+ InputSegment(
579
+ text=f"<|im_start|>user\n",
580
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
581
+ text_zeroemb_idx=self.empty_token,
582
+ ),
583
+ InputSegment(
584
+ audio=audio_tokenized,
585
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
586
+ text_zeroemb_idx=self.empty_token,
587
+ ),
588
+ InputSegment(
589
+ text=f"<|im_end|>\n",
590
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
591
+ text_zeroemb_idx=self.empty_token,
592
+ ),
593
+ InputSegment(
594
+ text=f"<|im_start|>assistant\n<|sostm|>",
595
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
596
+ text_zeroemb_idx=self.empty_token,
597
+ ),
598
+ ]
599
+ else:
600
+ if prompt_speech:
601
+ lm_prompt += [
602
+ InputSegment(
603
+ text="<|im_start|>system\n",
604
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
605
+ text_zeroemb_idx=self.empty_token,
606
+ ),
607
+ InputSegment(
608
+ text=f"Your voice should be:",
609
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
610
+ text_zeroemb_idx=self.empty_token,
611
+ ),
612
+ InputSegment(
613
+ audio=self.preprocess_input(prompt_speech),
614
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
615
+ text_zeroemb_idx=self.empty_token,
616
+ ),
617
+ InputSegment(
618
+ text="<|im_end|>\n",
619
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
620
+ text_zeroemb_idx=self.empty_token,
621
+ ),
622
+ ]
623
+
624
+ lm_prompt += [
625
+ InputSegment(
626
+ text=f"<|im_start|>user\n",
627
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
628
+ text_zeroemb_idx=self.empty_token,
629
+ )
630
+ ]
631
+
632
+ if system_prompt:
633
+ lm_prompt += [
634
+ InputSegment(
635
+ text=system_prompt,
636
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
637
+ text_zeroemb_idx=self.empty_token,
638
+ ),
639
+ InputSegment(
640
+ text="\n\n",
641
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
642
+ text_zeroemb_idx=self.empty_token,
643
+ )
644
+ ]
645
+ lm_prompt += [
646
+ InputSegment(
647
+ audio=audio_tokenized,
648
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
649
+ text_zeroemb_idx=self.empty_token,
650
+ ),
651
+ InputSegment(
652
+ text=f"<|im_end|>\n",
653
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
654
+ text_zeroemb_idx=self.empty_token,
655
+ ),
656
+ InputSegment(
657
+ text=f"<|im_start|>assistant\n<|sostm|>",
658
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
659
+ text_zeroemb_idx=self.empty_token,
660
+ ),
661
+ ]
662
+
663
+ input_ids = self.get_input_ids(lm_prompt)
664
+ return input_ids
665
+
666
+
667
+ def get_spoken_dialogue_sft_multiturn_prompt(
668
+ self,
669
+ message_list,
670
+ system_prompt=None,
671
+ prompt_speech=None,
672
+ ):
673
+ lm_prompt = []
674
+
675
+ if prompt_speech:
676
+ lm_prompt += [
677
+ InputSegment(
678
+ text="<|im_start|>system\n",
679
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
680
+ text_zeroemb_idx=self.empty_token,
681
+ ),
682
+ InputSegment(
683
+ text=f"Your voice should be:",
684
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
685
+ text_zeroemb_idx=self.empty_token,
686
+ ),
687
+ InputSegment(
688
+ audio=self.preprocess_input(prompt_speech),
689
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
690
+ text_zeroemb_idx=self.empty_token,
691
+ ),
692
+ InputSegment(
693
+ text="<|im_end|>\n",
694
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
695
+ text_zeroemb_idx=self.empty_token,
696
+ )
697
+ ]
698
+
699
+ for i in range(len(message_list)):
700
+ if message_list[i]['role'] == 'user':
701
+ lm_prompt += [
702
+ InputSegment(
703
+ text=f"<|im_start|>user\n",
704
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
705
+ text_zeroemb_idx=self.empty_token,
706
+ )
707
+ ]
708
+ if system_prompt and i == 0:
709
+ lm_prompt += [
710
+ InputSegment(
711
+ text=system_prompt,
712
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
713
+ text_zeroemb_idx=self.empty_token,
714
+ ),
715
+ InputSegment(
716
+ text="\n\n",
717
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
718
+ text_zeroemb_idx=self.empty_token,
719
+ )
720
+ ]
721
+ lm_prompt += [
722
+ InputSegment(
723
+ audio=self.preprocess_input(message_list[i]['content']),
724
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
725
+ text_zeroemb_idx=self.empty_token,
726
+ ),
727
+ InputSegment(
728
+ text=f"<|im_end|>\n",
729
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
730
+ text_zeroemb_idx=self.empty_token,
731
+ )
732
+ ]
733
+ elif message_list[i]['role'] == 'assistant':
734
+ lm_prompt += [
735
+ InputSegment(
736
+ text=f"<|im_start|>assistant\n",
737
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
738
+ text_zeroemb_idx=self.empty_token,
739
+ ),
740
+ StreamingInputSegment(
741
+ text=message_list[i]['content']["text"],
742
+ audio=self.preprocess_input(message_list[i]['content']["audio"]),
743
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
744
+ text_zeroemb_idx=self.empty_token,
745
+ tokenizer=self.tokenizer,
746
+ group_size=self.group_size,
747
+ audio_channels=self.audio_channels,
748
+ ),
749
+ InputSegment(
750
+ text=f"<|im_end|>\n",
751
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
752
+ text_zeroemb_idx=self.empty_token,
753
+ )
754
+ ]
755
+ else:
756
+ raise ValueError(f"Invalid role: {message_list[i]['role']}")
757
+
758
+ lm_prompt += [
759
+ InputSegment(
760
+ text=f"<|im_start|>assistant\n<|sostm|>",
761
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
762
+ text_zeroemb_idx=self.empty_token,
763
+ ),
764
+ ]
765
+
766
+ input_ids = self.get_input_ids(lm_prompt)
767
+ return input_ids
768
+
769
+
770
+ def get_s2t_dialogue_sft_prompt(
771
+ self,
772
+ input_speech,
773
+ thinking=False,
774
+ ):
775
+ audio_tokenized = self.preprocess_input(input_speech)
776
+ lm_prompt = [
777
+ InputSegment(
778
+ text=f"<|im_start|>user\n",
779
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
780
+ text_zeroemb_idx=self.empty_token,
781
+ ),
782
+ InputSegment(
783
+ audio=audio_tokenized,
784
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
785
+ text_zeroemb_idx=self.empty_token,
786
+ ),
787
+ InputSegment(
788
+ text=f"<|im_end|>\n",
789
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
790
+ text_zeroemb_idx=self.empty_token,
791
+ ),
792
+ InputSegment(
793
+ text=f"<|im_start|>assistant\n",
794
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
795
+ text_zeroemb_idx=self.empty_token,
796
+ )
797
+ ]
798
+ if not thinking:
799
+ lm_prompt.append(
800
+ InputSegment(
801
+ text="<think>\n\n</think>\n",
802
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
803
+ text_zeroemb_idx=self.empty_token,
804
+ )
805
+ )
806
+ else:
807
+ lm_prompt.append(
808
+ InputSegment(
809
+ text="<think>\n",
810
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
811
+ text_zeroemb_idx=self.empty_token,
812
+ )
813
+ )
814
+ input_ids = self.get_input_ids(lm_prompt)
815
+ return input_ids
816
+
817
+
818
+ def get_s2t_dialogue_sft_multiturn_prompt(self, message_list, thinking=False):
819
+ lm_prompt = []
820
+ for i in range(len(message_list)):
821
+ if message_list[i]['role'] == 'user':
822
+ lm_prompt += [
823
+ InputSegment(
824
+ text=f"<|im_start|>user\n",
825
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
826
+ text_zeroemb_idx=self.empty_token,
827
+ ),
828
+ InputSegment(
829
+ audio=self.preprocess_input(message_list[i]['content']),
830
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
831
+ text_zeroemb_idx=self.empty_token,
832
+ ),
833
+ InputSegment(
834
+ text=f"<|im_end|>\n",
835
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
836
+ text_zeroemb_idx=self.empty_token,
837
+ )
838
+ ]
839
+ elif message_list[i]['role'] == 'assistant':
840
+ lm_prompt += [
841
+ InputSegment(
842
+ text=f"<|im_start|>assistant\n",
843
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
844
+ text_zeroemb_idx=self.empty_token,
845
+ ),
846
+ InputSegment(
847
+ text=message_list[i]['content'],
848
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
849
+ text_zeroemb_idx=self.empty_token,
850
+ ),
851
+ InputSegment(
852
+ text=f"<|im_end|>\n",
853
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
854
+ text_zeroemb_idx=self.empty_token,
855
+ )
856
+ ]
857
+ else:
858
+ raise ValueError(f"Invalid role: {message_list[i]['role']}")
859
+
860
+ lm_prompt.append(
861
+ InputSegment(
862
+ text=f"<|im_start|>assistant\n",
863
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
864
+ text_zeroemb_idx=self.empty_token,
865
+ )
866
+ )
867
+ if not thinking:
868
+ lm_prompt.append(
869
+ InputSegment(
870
+ text="<think>\n\n</think>\n",
871
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
872
+ text_zeroemb_idx=self.empty_token,
873
+ )
874
+ )
875
+ else:
876
+ lm_prompt.append(
877
+ InputSegment(
878
+ text="<think>\n",
879
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
880
+ text_zeroemb_idx=self.empty_token,
881
+ )
882
+ )
883
+ input_ids = self.get_input_ids(lm_prompt)
884
+ return input_ids
885
+
886
+
887
+ def get_text_dialogue_sft_prompt(
888
+ self,
889
+ input_text,
890
+ thinking=False,
891
+ ):
892
+ lm_prompt = [
893
+ InputSegment(
894
+ text=f"<|im_start|>user\n",
895
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
896
+ text_zeroemb_idx=self.empty_token,
897
+ ),
898
+ InputSegment(
899
+ text=input_text,
900
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
901
+ text_zeroemb_idx=self.empty_token,
902
+ ),
903
+ InputSegment(
904
+ text=f"<|im_end|>\n",
905
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
906
+ text_zeroemb_idx=self.empty_token,
907
+ ),
908
+ InputSegment(
909
+ text=f"<|im_start|>assistant\n",
910
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
911
+ text_zeroemb_idx=self.empty_token,
912
+ ),
913
+ ]
914
+ if not thinking:
915
+ lm_prompt.append(
916
+ InputSegment(
917
+ text="<think>\n\n</think>\n",
918
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
919
+ text_zeroemb_idx=self.empty_token,
920
+ )
921
+ )
922
+ else:
923
+ lm_prompt.append(
924
+ InputSegment(
925
+ text="<think>\n",
926
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
927
+ text_zeroemb_idx=self.empty_token,
928
+ )
929
+ )
930
+ input_ids = self.get_input_ids(lm_prompt)
931
+ return input_ids
932
+
933
+ def get_text_dialogue_sft_multiturn_prompt(
934
+ self,
935
+ message_list,
936
+ thinking=False,
937
+ ):
938
+ lm_prompt = []
939
+ for i in range(len(message_list)):
940
+ if message_list[i]['role'] == 'user':
941
+ lm_prompt += [
942
+ InputSegment(
943
+ text=f"<|im_start|>user\n",
944
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
945
+ text_zeroemb_idx=self.empty_token,
946
+ ),
947
+ InputSegment(
948
+ text=message_list[i]['content'],
949
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
950
+ text_zeroemb_idx=self.empty_token,
951
+ ),
952
+ InputSegment(
953
+ text=f"<|im_end|>\n",
954
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
955
+ text_zeroemb_idx=self.empty_token,
956
+ )
957
+ ]
958
+ elif message_list[i]['role'] == 'assistant':
959
+ lm_prompt += [
960
+ InputSegment(
961
+ text=f"<|im_start|>assistant\n",
962
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
963
+ text_zeroemb_idx=self.empty_token,
964
+ ),
965
+ InputSegment(
966
+ text=message_list[i]['content'],
967
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
968
+ text_zeroemb_idx=self.empty_token,
969
+ ),
970
+ InputSegment(
971
+ text=f"<|im_end|>\n",
972
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
973
+ text_zeroemb_idx=self.empty_token,
974
+ )
975
+ ]
976
+ else:
977
+ raise ValueError(f"Invalid role: {message_list[i]['role']}")
978
+
979
+ lm_prompt.append(
980
+ InputSegment(
981
+ text=f"<|im_start|>assistant\n",
982
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
983
+ text_zeroemb_idx=self.empty_token,
984
+ )
985
+ )
986
+ if not thinking:
987
+ lm_prompt.append(
988
+ InputSegment(
989
+ text="<think>\n\n</think>\n",
990
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
991
+ text_zeroemb_idx=self.empty_token,
992
+ )
993
+ )
994
+ else:
995
+ lm_prompt.append(
996
+ InputSegment(
997
+ text="<think>\n",
998
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
999
+ text_zeroemb_idx=self.empty_token,
1000
+ )
1001
+ )
1002
+ input_ids = self.get_input_ids(lm_prompt)
1003
+ return input_ids
1004
+
1005
+
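+ # Few-shot speech-to-speech prompt: an "[Int]:" instruction line, then each
+ # example as input audio followed by its streaming (text + audio) output, and
+ # finally the query audio with an opening <|sostm|> for the answer.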
1006
+ def get_in_context_learning_s2s_prompt(self, instruction, prompt_examples, audio):
1007
+ prompt = [
1008
+ InputSegment(
1009
+ text=f"[Int]:{instruction}\n",
1010
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
1011
+ text_zeroemb_idx=self.empty_token,
1012
+ )
1013
+ ]
1014
+
1015
+ for i in range(len(prompt_examples)):
1016
+ prompt += [
1017
+ InputSegment(
1018
+ audio=self.preprocess_input(prompt_examples[i]["input_audio"]),
1019
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
1020
+ text_zeroemb_idx=self.empty_token,
1021
+ ),
1022
+ InputSegment(
1023
+ text="\n",
1024
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
1025
+ text_zeroemb_idx=self.empty_token,
1026
+ ),
1027
+ StreamingInputSegment(
1028
+ text=prompt_examples[i]["output_transcription"],
1029
+ audio=self.preprocess_input(prompt_examples[i]["output_audio"]),
1030
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
1031
+ text_zeroemb_idx=self.empty_token,
1032
+ tokenizer=self.tokenizer,
1033
+ group_size=self.group_size,
1034
+ audio_channels=self.audio_channels,
1035
+ ),
1036
+ InputSegment(
1037
+ text=" \n\n",
1038
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
1039
+ text_zeroemb_idx=self.empty_token,
1040
+ ),
1041
+ ]
1042
+
1043
+ prompt += [
1044
+ InputSegment(
1045
+ audio=self.preprocess_input(audio),
1046
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
1047
+ text_zeroemb_idx=self.empty_token,
1048
+ ),
1049
+ InputSegment(
1050
+ text="\n",
1051
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
1052
+ text_zeroemb_idx=self.empty_token,
1053
+ ),
1054
+ InputSegment(
1055
+ text="<|sostm|>",
1056
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
1057
+ text_zeroemb_idx=self.empty_token,
1058
+ ),
1059
+ ]
1060
+ input_ids = self.get_input_ids(prompt)
1061
+ return input_ids
1062
+
1063
+ @torch.no_grad()
1064
+ def forward(
1065
+ self,
1066
+ input_ids,
1067
+ return_audio=False,
1068
+ output_audio_path=None,
1069
+ stopping_criteria=None,
1070
+ min_new_tokens=0,
1071
+ max_new_tokens=8192,
1072
+ add_history=False,
1073
+ task_name=None,
1074
+ ):
1075
+
1076
+ task_sampler = self.get_task_sampler(task_name)
1077
+
1078
+ generation_kwargs = self.generate_kwargs.copy()
1079
+ generation_config = GenerationConfig(**generation_kwargs)
1080
+
1081
+ input_ids = input_ids.T.reshape(1, -1) # [B, flattened(T, audio_channels + 1)]
1082
+ if add_history and self.history is not None:
1083
+ input_ids = torch.cat([self.history, input_ids], dim=1)
1084
+
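+ # The flattened sequence interleaves (audio_channels + 1) channels per
+ # timestep, and the backbone advances one group of group_size timesteps at a
+ # time, so lengths below are counted first in timesteps and then in groups.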
1085
+ prompt_length = input_ids.shape[1] // (self.audio_channels+1)
1086
+
1087
+ max_length = prompt_length // self.group_size + max_new_tokens
1088
+ min_length = prompt_length // self.group_size + min_new_tokens
1089
+
1090
+ if stopping_criteria is not None:
1091
+ for criterion in stopping_criteria:
1092
+ if isinstance(criterion, MiMoStopper):
1093
+ criterion.max_length = max_length
1094
+ criterion.min_length = min_length
1095
+
1096
+ generated_ids = self.model.generate(
1097
+ input_ids,
1098
+ generation_config,
1099
+ stopping_criteria=stopping_criteria,
1100
+ global_sampler=task_sampler["global"],
1101
+ local_sampler=task_sampler["local"],
1102
+ )
1103
+
1104
+ self.history = generated_ids
1105
+ generated_ids = generated_ids.int().cpu().reshape(-1, self.audio_channels+1).T[:, prompt_length:]
1106
+
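+ # Channel 0 is the text channel; one text token is produced per group, so read
+ # every group_size-th position and drop the trailing stop token.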
1107
+ text = generated_ids[0, ::self.group_size][:-1]
1108
+ detokenized_text = self.tokenizer.decode(text, skip_special_tokens=False).strip().replace("<|empty|>", "").replace("<|eot|>", "").replace("<|eostm|>", "")
1109
+ print("Text channel:\t", detokenized_text)
1110
+
1111
+ if output_audio_path:
1112
+ return_audio = True
1113
+
1114
+ if not return_audio:
1115
+ return detokenized_text
1116
+
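+ # The speech span is delimited on the text channel by <|sostm|> / <|eostm|>;
+ # convert those group indices back to flattened timestep positions.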
1117
+ sosp_idx_locations = (text == self.sostm_idx).nonzero(as_tuple=True)[0]
1118
+ eosp_idx_locations = (text == self.eostm_idx).nonzero(as_tuple=True)[0]
1119
+ if len(sosp_idx_locations) == 0:
1120
+ start_location = 0
1121
+ else:
1122
+ start_location = sosp_idx_locations[0] * self.group_size + self.group_size
1123
+ if len(eosp_idx_locations) == 0:
1124
+ end_location = text.shape[0] * self.group_size
1125
+ else:
1126
+ end_location = eosp_idx_locations[0] * self.group_size
1127
+ audio_sequence = generated_ids[:, start_location:end_location] #[audio_channels+1, audio_length]
1128
+ speech_sequence = audio_sequence[1:]
1129
+
1130
+ mask = speech_sequence[0] != (self.speech_zeroemb_idx[0] if isinstance(self.speech_zeroemb_idx, list) else self.speech_zeroemb_idx)
1131
+ speech_sequence = speech_sequence[:, mask]
1132
+
1133
+ assert (speech_sequence < torch.tensor(self.speech_zeroemb_idx).unsqueeze(1)).all()
1134
+
1135
+ speech_sequence = speech_sequence.T.flatten()
1136
+
1137
+ speech_str = "".join([f"<{i}>" for i in speech_sequence])
1138
+ tokens = torch.tensor(
1139
+ [int(num) for num in re.findall(r"(\d+)>", speech_str)]
1140
+ )
1141
+
1142
+ if tokens.numel() == 0:
1143
+ wav = torch.zeros(24000)
1144
+ self.save_wav(output_audio_path, wav)
1145
+ return detokenized_text
1146
+
1147
+ codes = tokens.reshape(-1, self.audio_channels).T
1148
+ codes = codes.type(torch.LongTensor).to(self.device)
1149
+
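+ # Decode the audio codes in 1500-frame chunks to bound the tokenizer decoder's
+ # memory use, then concatenate the resulting waveform segments.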
1150
+ segment_len = 1500
1151
+ wav_list = []
1152
+ for start in range(0, codes.shape[-1], segment_len):
1153
+ wav = self.mimo_audio_tokenizer.decode(codes[:,start:start+segment_len]).float()
1154
+ wav_list.append(wav)
1155
+ wav_concat = torch.cat(wav_list, dim=-1)
1156
+
1157
+ #wav = self.mimo_audio_tokenizer.decode(codes).float()
1158
+ if output_audio_path is not None:
1159
+ self.save_wav(output_audio_path, wav_concat)
1160
+ return detokenized_text
1161
+ else:
1162
+ return wav_concat
1163
+
1164
+ def asr_sft(self, audio):
1165
+ stopping_criteria = [
1166
+ MiMoStopper(
1167
+ stop_tokens=[self.tokenizer.eos_token_id, self.im_end_idx],
1168
+ group_size=self.group_size,
1169
+ audio_channels=self.audio_channels,
1170
+ )
1171
+ ]
1172
+ input_ids = self.get_asr_sft_prompt(audio)
1173
+ result = self.forward(input_ids, stopping_criteria=stopping_criteria, task_name="asr")
1174
+ return result
1175
+
1176
+ def tts_sft(self, text, output_path, instruct=None, read_text_only=True, prompt_speech=None):
1177
+ stopping_criteria = [
1178
+ MiMoStopper(
1179
+ stop_tokens=[self.tokenizer.eos_token_id, self.eostm_idx, self.im_end_idx],
1180
+ group_size=self.group_size,
1181
+ audio_channels=self.audio_channels,
1182
+ )
1183
+ ]
1184
+ input_ids = self.get_tts_sft_prompt(text, instruct=instruct, read_text_only=read_text_only, prompt_speech=prompt_speech)
1185
+ text_output = self.forward(input_ids, output_audio_path=output_path, stopping_criteria=stopping_criteria, task_name="tts")
1186
+ return text_output
1187
+
1188
+ def audio_understanding_sft(self, input_speech, input_text, thinking=False):
1189
+ stopping_criteria = [
1190
+ MiMoStopper(
1191
+ stop_tokens=[self.tokenizer.eos_token_id, self.im_end_idx],
1192
+ group_size=self.group_size,
1193
+ audio_channels=self.audio_channels,
1194
+ )
1195
+ ]
1196
+ input_ids = self.get_audio_understanding_sft_prompt(input_speech, input_text, thinking=thinking)
1197
+ result = self.forward(input_ids, stopping_criteria=stopping_criteria, task_name="audio_understanding")
1198
+ return result
1199
+
1200
+ def spoken_dialogue_sft(self, input_speech, output_audio_path=None, system_prompt=None, prompt_speech=None, add_history=False):
1201
+ stopping_criteria = [
1202
+ MiMoStopper(
1203
+ stop_tokens=[self.tokenizer.eos_token_id, self.eostm_idx, self.im_end_idx],
1204
+ group_size=self.group_size,
1205
+ audio_channels=self.audio_channels,
1206
+ )
1207
+ ]
1208
+ input_ids = self.get_spoken_dialogue_sft_prompt(input_speech, system_prompt=system_prompt, prompt_speech=prompt_speech, add_history=add_history)
1209
+ text = self.forward(input_ids, output_audio_path=output_audio_path, stopping_criteria=stopping_criteria, task_name="spoken_dialogue", add_history=add_history)
1210
+ return text
1211
+
1212
+ # interface for message list interaction
1213
+ def spoken_dialogue_sft_multiturn(self, message_list, output_audio_path=None, system_prompt=None, prompt_speech=None):
1214
+ stopping_criteria = [
1215
+ MiMoStopper(
1216
+ stop_tokens=[self.tokenizer.eos_token_id, self.eostm_idx, self.im_end_idx],
1217
+ group_size=self.group_size,
1218
+ audio_channels=self.audio_channels,
1219
+ )
1220
+ ]
1221
+ input_ids = self.get_spoken_dialogue_sft_multiturn_prompt(message_list, system_prompt=system_prompt, prompt_speech=prompt_speech)
1222
+ text = self.forward(input_ids, output_audio_path=output_audio_path, stopping_criteria=stopping_criteria, task_name="spoken_dialogue", add_history=False)
1223
+ return text
1224
+
1225
+
1226
+ def speech2text_dialogue_sft(self, input_speech, thinking=False, add_history=False):
1227
+ stopping_criteria = [
1228
+ MiMoStopper(
1229
+ stop_tokens=[self.tokenizer.eos_token_id, self.im_end_idx],
1230
+ group_size=self.group_size,
1231
+ audio_channels=self.audio_channels,
1232
+ )
1233
+ ]
1234
+ input_ids = self.get_s2t_dialogue_sft_prompt(input_speech, thinking=thinking)
1235
+ text = self.forward(input_ids, stopping_criteria=stopping_criteria, task_name="spoken_dialogue", add_history=add_history)
1236
+ return text
1237
+
1238
+
1239
+ # interface for message list interaction
1240
+ def speech2text_dialogue_sft_multiturn(self, message_list, thinking=False):
1241
+ stopping_criteria = [
1242
+ MiMoStopper(
1243
+ stop_tokens=[self.tokenizer.eos_token_id, self.im_end_idx],
1244
+ group_size=self.group_size,
1245
+ audio_channels=self.audio_channels,
1246
+ )
1247
+ ]
1248
+ input_ids = self.get_s2t_dialogue_sft_multiturn_prompt(message_list, thinking=thinking)
1249
+ text = self.forward(input_ids, stopping_criteria=stopping_criteria, task_name="spoken_dialogue", add_history=False)
1250
+ return text
1251
+
1252
+ def text_dialogue_sft(self, input_text, thinking=False, add_history=False):
1253
+ stopping_criteria = [
1254
+ MiMoStopper(
1255
+ stop_tokens=[self.tokenizer.eos_token_id, self.im_end_idx],
1256
+ group_size=self.group_size,
1257
+ audio_channels=self.audio_channels,
1258
+ )
1259
+ ]
1260
+ input_ids = self.get_text_dialogue_sft_prompt(input_text, thinking=thinking)
1261
+ text = self.forward(input_ids, stopping_criteria=stopping_criteria, task_name="text_chat", add_history=add_history)
1262
+ return text
1263
+
1264
+ # interface for message list interaction
1265
+ def text_dialogue_sft_multiturn(self, message_list, thinking=False):
1266
+ stopping_criteria = [
1267
+ MiMoStopper(
1268
+ stop_tokens=[self.tokenizer.eos_token_id, self.im_end_idx],
1269
+ group_size=self.group_size,
1270
+ audio_channels=self.audio_channels,
1271
+ )
1272
+ ]
1273
+ input_ids = self.get_text_dialogue_sft_multiturn_prompt(message_list, thinking=thinking)
1274
+ text = self.forward(input_ids, stopping_criteria=stopping_criteria, task_name="text_chat", add_history=False)
1275
+ return text
1276
+
1277
+ def clear_history(self):
1278
+ self.history = None
1279
+ print("History cleared")
1280
+
1281
+ def in_context_learning_s2s(self, instruction, prompt_examples, audio, max_new_tokens=None, output_audio_path=None):
1282
+ stopping_criteria = [
1283
+ MiMoStopper(
1284
+ stop_tokens=[self.tokenizer.eos_token_id, self.eostm_idx],
1285
+ group_size=self.group_size,
1286
+ audio_channels=self.audio_channels,
1287
+ )
1288
+ ]
1289
+ input_ids = self.get_in_context_learning_s2s_prompt(instruction, prompt_examples, audio)
1290
+ return self.forward(input_ids, output_audio_path=output_audio_path, stopping_criteria=stopping_criteria, max_new_tokens=max_new_tokens if max_new_tokens is not None else 8192, task_name="in_context_learning_s2s")
1291
+
1292
+
src/mimo_audio/modeling_mimo_audio.py ADDED
@@ -0,0 +1,835 @@
1
+ # Copyright 2025 Xiaomi Corporation.
2
+ import copy
3
+ import logging
4
+ from dataclasses import dataclass
5
+ from typing import List, Optional, Union, cast
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ from torch import nn
10
+ from transformers import StoppingCriteria
11
+ from transformers.cache_utils import Cache, DynamicCache
12
+ from transformers.generation.streamers import BaseStreamer
13
+ from transformers.generation.utils import (
14
+ GenerateOutput,
15
+ GenerationConfig,
16
+ StoppingCriteriaList,
17
+ is_deepspeed_zero3_enabled,
18
+ )
19
+ from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
20
+ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
21
+ from transformers.models.qwen2.modeling_qwen2 import (
22
+ Qwen2Model,
23
+ Qwen2PreTrainedModel,
24
+ )
25
+ from transformers.utils import is_torchdynamo_compiling
26
+
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class MiMoStopper(StoppingCriteria):
32
+ def __init__(
33
+ self,
34
+ group_size: int,
35
+ audio_channels: int,
36
+ stop_tokens: list[int] | None = None,
37
+ max_length: int | None = None,
38
+ min_length: int | None = None,
39
+ ) -> None:
40
+ super().__init__()
41
+ self.group_size = group_size
42
+ self.audio_channels = audio_channels
43
+ self.step = (audio_channels + 1) * group_size
44
+
45
+ self.stop_token_ids = set(stop_tokens or [])
46
+
47
+ self.max_length = max_length
48
+ self.min_length = min_length or 0
49
+
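+ # `step` is the number of flattened tokens per generated group; stopping is
+ # decided from the text-channel token at the start of the most recent group,
+ # subject to the configured min/max lengths (measured in groups).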
50
+ def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor):
51
+ is_done = False
52
+ cur_len = input_ids.shape[-1] // self.step
53
+
54
+ if self.max_length:
55
+ is_done |= cur_len >= self.max_length
56
+
57
+ if (self.stop_token_ids and
58
+ input_ids.shape[1] >= self.step and
59
+ cur_len >= self.min_length):
60
+ last_token = input_ids[0, -self.step].item()
61
+ is_done |= last_token in self.stop_token_ids
62
+
63
+ return torch.full(
64
+ (input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool
65
+ )
66
+
67
+
68
+ @dataclass
69
+ class MiMoSampler:
70
+ do_sample: bool | None = None
71
+ temperature: float | None = None
72
+ top_k: int | None = None
73
+ top_p: float | None = None
74
+
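+ # Standard logits warping: temperature scaling, then top-k filtering, then
+ # nucleus (top-p) filtering; filtered positions are set to -inf before sampling.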
75
+ def process(self, scores: torch.Tensor):
76
+ if self.temperature is not None:
77
+ scores = scores / self.temperature
78
+
79
+ if self.top_k is not None and self.top_k > 0:
80
+ top_k = min(self.top_k, scores.shape[-1])
81
+ indices_to_remove = scores < torch.topk(scores, top_k)[0][:, -1]
82
+ scores = scores.masked_fill(indices_to_remove, float("-inf"))
83
+
84
+ if self.top_p is not None and 0.0 < self.top_p <= 1.0:
85
+ top_p = self.top_p
86
+ sorted_logits, sorted_indices = torch.sort(scores)
87
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
88
+ sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
89
+ sorted_indices_to_remove[:, -1] = 0
90
+ indices_to_remove = sorted_indices_to_remove.scatter(
91
+ 1, sorted_indices, sorted_indices_to_remove
92
+ )
93
+ scores = scores.masked_fill(indices_to_remove, float("-inf"))
94
+
95
+ return scores
96
+
97
+ def sample(self, scores: torch.Tensor, removed_tokens: list[int] | None = None):
98
+ scores = self.process(scores)
99
+ for t in removed_tokens or []:
100
+ scores[:, t] = float("-inf")
101
+
102
+ if self.do_sample:
103
+ probs = scores.softmax(dim=-1)
104
+ return torch.multinomial(probs, num_samples=1).squeeze(-1)
105
+
106
+ return torch.argmax(scores, dim=-1)
107
+
108
+
109
+ @dataclass
110
+ class MiMoAudioOutput(ModelOutput):
111
+ text_logits: torch.FloatTensor | None = None
112
+ local_hidden_states: torch.FloatTensor | None = None
113
+ """Downcast hidden states used to seed local transformer generation."""
114
+ past_key_values: Cache | None = None
115
+
116
+
117
+ @dataclass
118
+ class MiMoAudioConfig(Qwen2Config):
119
+ def __init__(
120
+ self,
121
+ *,
122
+ speech_vocab_size: str | int = "1025-1025-129-129-129-129-129-129",
123
+ speech_zeroemb_idx: str | int = "1024-1024-128-128-128-128-128-128",
124
+ delay_pattern: str = "0-1-2-3-4-5-6-7",
125
+ head_dim: int = 128,
126
+ group_size: int = 4,
127
+ audio_channels: int = 8,
128
+ local_dim: int = 1024,
129
+ local_layers: int = 16,
130
+ local_attn_heads: int = 64,
131
+ local_ffn_dim: int = 4096,
132
+ local_attn_dropout: float = 0.1,
133
+ input_local_layers: int = 6,
134
+ input_local_dim: int | None = None,
135
+ input_full_attention: bool | None = None,
136
+ **kwargs,
137
+ ):
138
+ super().__init__(
139
+ **kwargs,
140
+ )
141
+ self.speech_vocab_size = speech_vocab_size
142
+ self.speech_zeroemb_idx = speech_zeroemb_idx
143
+ self.delay_pattern = delay_pattern
144
+
145
+ self.head_dim = head_dim
146
+
147
+ self.group_size = group_size
148
+ self.audio_channels = audio_channels
149
+
150
+ self.local_dim = local_dim
151
+ self.local_layers = local_layers
152
+ self.local_attn_heads = local_attn_heads
153
+ self.local_ffn_dim = local_ffn_dim
154
+ self.local_attn_dropout = local_attn_dropout
155
+
156
+ self.input_local_layers = input_local_layers
157
+ self.input_local_dim = input_local_dim or local_dim
158
+
159
+ self.input_full_attention = input_full_attention
160
+
161
+ def _parse_maybe_list(self, value: str | int, length: int) -> List[int]:
162
+ if isinstance(value, str) and "-" in value:
163
+ return [int(s) for s in value.split("-")]
164
+ return [int(value)] * length
165
+
166
+ def parsed_speech_empty_ids(self):
167
+ return self._parse_maybe_list(self.speech_zeroemb_idx, self.audio_channels)
168
+
169
+ def parsed_speech_vocab_sizes(self):
170
+ return self._parse_maybe_list(self.speech_vocab_size, self.audio_channels)
171
+
172
+ def parsed_delay_pattern(self):
173
+ return self._parse_maybe_list(self.delay_pattern, self.audio_channels)
174
+
175
+ def local_config(self):
176
+ config = copy.deepcopy(self)
177
+
178
+ config.hidden_size = self.local_dim
179
+ config.num_hidden_layers = self.local_layers
180
+ config.num_attention_heads = self.local_attn_heads
181
+ config.num_key_value_heads = self.local_attn_heads
182
+ config.head_dim = config.hidden_size // self.local_attn_heads
183
+ config.intermediate_size = self.local_ffn_dim
184
+ config.attention_dropout = self.local_attn_dropout
185
+
186
+ return config
187
+
188
+ def input_local_config(self):
189
+ config = copy.deepcopy(self)
190
+
191
+ config.hidden_size = self.input_local_dim
192
+ config.num_hidden_layers = self.input_local_layers
193
+ config.num_attention_heads = self.local_attn_heads
194
+ config.num_key_value_heads = self.local_attn_heads
195
+ config.head_dim = config.hidden_size // self.local_attn_heads
196
+ config.intermediate_size = config.hidden_size * 4
197
+ config.attention_dropout = self.local_attn_dropout
198
+
199
+ return config
200
+
201
+
202
+ @dataclass
203
+ class MiMoAudioArguments:
204
+ model_name_or_path: str
205
+ sosp_idx: int
206
+ eosp_idx: int
207
+ sostm_idx: int
208
+ eostm_idx: int
209
+ eot_idx: int
210
+ empty_idx: int
211
+
212
+ def to_dict(self):
213
+ return {
214
+ "model_name_or_path": self.model_name_or_path,
215
+ "sosp_idx": self.sosp_idx,
216
+ "eosp_idx": self.eosp_idx,
217
+ "sostm_idx": self.sostm_idx,
218
+ "eostm_idx": self.eostm_idx,
219
+ "eot_idx": self.eot_idx,
220
+ "empty_idx": self.empty_idx,
221
+ }
222
+
223
+
224
+ class MiMoAudioForCausalLM(Qwen2PreTrainedModel):
225
+ def __init__(
226
+ self,
227
+ config: MiMoAudioConfig | Qwen2Config,
228
+ args: MiMoAudioArguments | dict,
229
+ ):
230
+ super().__init__(config)
231
+ config = (
232
+ MiMoAudioConfig(**vars(config))
233
+ if isinstance(config, Qwen2Config)
234
+ else config
235
+ )
236
+ args = MiMoAudioArguments(**args) if isinstance(args, dict) else args
237
+ self.config = config
238
+ self.args = args
239
+
240
+ self.model = Qwen2Model(config)
241
+
242
+ self.speech_vocab_sizes = config.parsed_speech_vocab_sizes()
243
+ self.speech_empty_ids = config.parsed_speech_empty_ids()
244
+ self.delay_pattern = config.parsed_delay_pattern()
245
+
246
+ self.group_size = config.group_size
247
+ self.audio_channels = config.audio_channels
248
+
249
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
250
+
251
+ # Construct local transformer
252
+ self.local_config = config.local_config()
253
+ self.local_transformer = Qwen2Model(self.local_config)
254
+ self.local_transformer.embed_tokens = None
255
+
256
+ # Add input local transformer if configured
257
+ self.input_local_config = config.input_local_config()
258
+ self.input_local_transformer = Qwen2Model(self.input_local_config)
259
+ self.input_local_transformer.embed_tokens = None
260
+
261
+ self.local_transformer_lm_heads = nn.ModuleList(
262
+ [
263
+ nn.Linear(
264
+ self.local_config.hidden_size,
265
+ self.speech_vocab_sizes[i],
266
+ bias=False,
267
+ )
268
+ for i in range(self.audio_channels)
269
+ ]
270
+ )
271
+
272
+ self.speech_embeddings = nn.ModuleList(
273
+ [
274
+ nn.Embedding(
275
+ self.speech_vocab_sizes[i],
276
+ self.input_local_config.hidden_size,
277
+ padding_idx=self.speech_empty_ids[i],
278
+ )
279
+ for i in range(self.audio_channels)
280
+ ]
281
+ )
282
+
283
+ if self.input_local_config.hidden_size != self.local_config.hidden_size:
284
+ self.speech_embeddings_to_local = nn.Linear(
285
+ self.input_local_config.hidden_size,
286
+ self.local_config.hidden_size,
287
+ bias=False,
288
+ )
289
+ else:
290
+ self.speech_embeddings_to_local = None
291
+
292
+ # Project each group of concatenated speech-frame embeddings down to the backbone hidden size
293
+ self.speech_group_downcast = nn.Linear(
294
+ self.input_local_config.hidden_size * config.group_size,
295
+ config.hidden_size,
296
+ bias=False,
297
+ )
298
+
299
+ self.hidden_states_downcast = nn.Linear(
300
+ config.hidden_size,
301
+ self.local_config.hidden_size,
302
+ bias=False,
303
+ )
304
+
305
+ # Initialize weights and apply final processing
306
+ self.post_init()
307
+
308
+ def apply_input_local_transformer(self, speech_embeddings: torch.Tensor):
309
+ B, T_groups, group_size, hidden_size = speech_embeddings.shape
310
+
311
+ # Process each group independently: [B*T//group_size, group_size, hidden_size]
312
+ input_embeddings = speech_embeddings.reshape(
313
+ B * T_groups, group_size, hidden_size
314
+ )
315
+
316
+ output: BaseModelOutputWithPast = self.input_local_transformer(
317
+ inputs_embeds=input_embeddings,
318
+ return_dict=True,
319
+ is_causal=not self.config.input_full_attention, # for SDPA
320
+ )
321
+ encoded_embeddings = output.last_hidden_state
322
+
323
+ # Reshape back to original format
324
+ # [B*T//group_size, group_size, hidden_size] -> [B, T//group_size, group_size, hidden_size]
325
+ encoded_embeddings = encoded_embeddings.reshape(
326
+ B, T_groups, group_size, hidden_size
327
+ )
328
+
329
+ return encoded_embeddings
330
+
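+ # Build backbone inputs: the text channel is read once per group, each speech
+ # channel is embedded with its own codebook embedding and summed, the per-frame
+ # sums pass through the input local transformer, and every group of group_size
+ # frames is concatenated and projected to the backbone width before being added
+ # to the text embedding.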
331
+ def _prepare_input_embeds(
332
+ self,
333
+ input_ids: torch.LongTensor, # [B, audio_channels + 1, new_T]
334
+ ):
335
+ B = input_ids.shape[0]
336
+
337
+ input_ids = input_ids.int()
338
+ group_size = self.config.group_size
339
+
340
+ text_input_ids = input_ids[:, 0, ::group_size]
341
+ speech_input_ids = (
342
+ input_ids[:, 1:, :]
343
+ .view(B, self.audio_channels, -1, group_size)
344
+ .transpose(1, 2)
345
+ ) # [B, T//group_size, audio_channels, group_size]
346
+
347
+ is_speech = text_input_ids == self.args.empty_idx # [B, T//group_size]
348
+
349
+ speech_embeds = torch.zeros(
350
+ (
351
+ B,
352
+ is_speech.shape[1],
353
+ group_size,
354
+ self.input_local_config.hidden_size,
355
+ ),
356
+ device=input_ids.device,
357
+ dtype=torch.bfloat16,
358
+ )
359
+
360
+ for idx in range(self.audio_channels):
361
+ cur_empty = self.speech_empty_ids[idx]
362
+ cur_embed = self.speech_embeddings[idx]
363
+ cur_speech_ids = speech_input_ids[:, :, idx, :]
364
+ cur_speech_embeds: torch.Tensor = cur_embed(cur_speech_ids)
365
+ # [B, T_groups, group_size, hidden_size]
366
+
367
+ cur_mask = cur_speech_ids == cur_empty
368
+ cur_speech_embeds.masked_fill_(cur_mask.unsqueeze(-1), 0.0)
369
+
370
+ speech_embeds += cur_speech_embeds
371
+
372
+ speech_embeds = speech_embeds * is_speech.unsqueeze(-1).unsqueeze(-1)
373
+
374
+ # Apply input local transformer if configured
375
+ speech_embeds = self.apply_input_local_transformer(speech_embeds)
376
+ speech_embeds = speech_embeds * is_speech.unsqueeze(-1).unsqueeze(-1)
377
+
378
+ T_groups = speech_embeds.shape[1]
379
+ speech_grouped_embeds: torch.Tensor = self.speech_group_downcast(
380
+ speech_embeds.view(B, T_groups, -1)
381
+ ) # [B, T_groups, hidden_size]
382
+
383
+ text_embeds: torch.Tensor = self.model.embed_tokens(text_input_ids)
384
+ text_zero_mask = text_input_ids == self.args.empty_idx
385
+ text_embeds.masked_fill_(text_zero_mask.unsqueeze(-1), 0.0)
386
+
387
+ return text_embeds + speech_grouped_embeds
388
+
389
+ def forward(
390
+ self,
391
+ input_ids: torch.LongTensor, # [B, audio_channels + 1, new_T]
392
+ attention_mask: torch.Tensor, # [B, T_group]
393
+ position_ids: torch.LongTensor, # [B, new_T_group]
394
+ past_key_values: Cache | None = None,
395
+ cache_position: torch.LongTensor | None = None, # [new_T_group]
396
+ **_kwargs,
397
+ ):
398
+ inputs_embeds = self._prepare_input_embeds(input_ids)
399
+
400
+ outputs: BaseModelOutputWithPast = self.model(
401
+ attention_mask=attention_mask,
402
+ position_ids=position_ids,
403
+ past_key_values=past_key_values,
404
+ inputs_embeds=inputs_embeds,
405
+ use_cache=True,
406
+ return_dict=True,
407
+ cache_position=cache_position,
408
+ )
409
+ hidden_states = outputs.last_hidden_state # [B, new_T_group, hidden_size]
410
+
411
+ text_logits: torch.Tensor = self.lm_head(
412
+ hidden_states[:, -1:, :]
413
+ ) # [B, 1, vocab_size]
414
+ shift_hidden_states: torch.Tensor = self.hidden_states_downcast(
415
+ hidden_states[:, -1:, :]
416
+ ) # [B, 1, hidden_size]
417
+
418
+ return MiMoAudioOutput(
419
+ text_logits=text_logits,
420
+ local_hidden_states=shift_hidden_states,
421
+ past_key_values=outputs.past_key_values,
422
+ )
423
+
424
+ def local_forward(
425
+ self,
426
+ local_embeds: torch.FloatTensor, # [B, 1, hidden_size]
427
+ tokens_dtype: torch.dtype,
428
+ tokens_device: torch.device,
429
+ local_sampler: MiMoSampler | None = None,
430
+ ):
431
+ B = local_embeds.shape[0]
432
+ delay_iters = self.group_size + max(self.delay_pattern)
433
+ past_key_values = DynamicCache()
434
+ local_tokens = torch.zeros(
435
+ (B, self.group_size, self.audio_channels),
436
+ dtype=tokens_dtype,
437
+ device=tokens_device,
438
+ )
439
+ if local_sampler is None:
440
+ local_sampler = MiMoSampler()
441
+
442
+ for t in range(delay_iters):
443
+ output: BaseModelOutputWithPast = self.local_transformer(
444
+ inputs_embeds=local_embeds,
445
+ past_key_values=past_key_values,
446
+ return_dict=True,
447
+ use_cache=True,
448
+ )
449
+ hidden_state = output.last_hidden_state
450
+ past_key_values = output.past_key_values
451
+
452
+ local_embeds = torch.zeros_like(local_embeds)
453
+ for idx in range(self.audio_channels):
454
+ cur_start = self.delay_pattern[idx]
455
+ cur_end = cur_start + self.group_size
456
+ cur_empty = self.speech_empty_ids[idx]
457
+ if cur_start <= t < cur_end:
458
+ cur_lm_head = self.local_transformer_lm_heads[idx]
459
+ cur_scores: torch.Tensor = cur_lm_head(hidden_state)[:, -1, :]
460
+ # [B, vocab_size]
461
+ cur_token = local_sampler.sample(
462
+ cur_scores,
463
+ [cur_empty],
464
+ )
465
+
466
+ local_tokens[:, t - cur_start, idx] = cur_token
467
+ cur_input_embed = self.speech_embeddings[idx](
468
+ cur_token.unsqueeze(1)
469
+ )
470
+ if self.speech_embeddings_to_local is not None:
471
+ cur_input_embed = self.speech_embeddings_to_local(
472
+ cur_input_embed
473
+ )
474
+ local_embeds += cur_input_embed
475
+
476
+ return local_tokens # [B, group_size, audio_channels]
477
+
478
+ def _prepare_attention_mask(
479
+ self, inputs: torch.Tensor, input_ids_length: int
480
+ ) -> torch.Tensor:
481
+ # No information for attention mask inference -> return default attention mask
482
+ return torch.ones(
483
+ (inputs.shape[0], input_ids_length),
484
+ dtype=torch.bool,
485
+ device=inputs.device,
486
+ )
487
+
488
+ def prepare_inputs_for_generation(
489
+ self,
490
+ input_ids: torch.LongTensor,
491
+ past_key_values: Optional[Cache] = None,
492
+ attention_mask: Optional[torch.LongTensor] = None,
493
+ inputs_embeds: Optional[torch.FloatTensor] = None,
494
+ cache_position: Optional[torch.LongTensor] = None,
495
+ **kwargs,
496
+ ):
497
+ """
498
+ Prepare the model inputs for generation. It includes operations like computing the 4D attention mask or
499
+ slicing inputs given the existing cache.
500
+
501
+ See the forward pass in the model documentation for expected arguments (different models might have different
502
+ requirements for e.g. `past_key_values`). This function should work as is for most LLMs.
503
+ """
504
+
505
+ # 1. Handle BC:
506
+ model_inputs = {}
507
+ input_ids = input_ids.reshape(
508
+ input_ids.shape[0], -1, (self.audio_channels + 1) * self.config.group_size
509
+ ).transpose(1, 2) # [B, (audio_channels + 1) * group_size, T_groups]
510
+ # - some models don't have `Cache` support (which implies they don't expect `cache_position` in `forward`)
511
+ if self._supports_cache_class:
512
+ model_inputs["cache_position"] = cache_position
513
+ # - `cache_position` was not a mandatory input in `prepare_inputs_for_generation` for those models, and this
514
+ # function may be called outside of `generate`. Handle most use cases by creating `cache_position` on the fly
515
+ # (this alternative is not as robust as calling `generate` and letting it create `cache_position`)
516
+ elif cache_position is None:
517
+ past_length = (
518
+ past_key_values[0][0].shape[2] if past_key_values is not None else 0
519
+ )
520
+ cache_position = torch.arange(
521
+ past_length,
522
+ input_ids.shape[2],
523
+ dtype=torch.long,
524
+ device=input_ids.device,
525
+ )
526
+
527
+ # 2. Generic cache-dependent input preparation
528
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
529
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
530
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
531
+ # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
532
+ if past_key_values is not None:
533
+ model_inputs["past_key_values"] = past_key_values
534
+ if (
535
+ inputs_embeds is not None or cache_position[-1] >= input_ids.shape[2]
536
+ ): # Exception 1 or Exception 3
537
+ input_ids = input_ids[:, :, -cache_position.shape[0] :]
538
+ elif (
539
+ input_ids.shape[2] != cache_position.shape[0]
540
+ ): # Default case (the "else", a no op, is Exception 2)
541
+ input_ids = input_ids[:, :, cache_position]
542
+
543
+ # 3. Prepare base model inputs
544
+ input_ids_key = (
545
+ "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
546
+ )
547
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
548
+ if not self.config.is_encoder_decoder:
549
+ if inputs_embeds is not None and cache_position[0] == 0:
550
+ model_inputs[input_ids_key] = None
551
+ model_inputs["inputs_embeds"] = inputs_embeds
552
+ else:
553
+ # `clone` calls in this function ensure a consistent stride. See #32227
554
+ model_inputs[input_ids_key] = input_ids.clone(
555
+ memory_format=torch.contiguous_format
556
+ )
557
+ model_inputs["inputs_embeds"] = None
558
+ else:
559
+ model_inputs[input_ids_key] = input_ids.clone(
560
+ memory_format=torch.contiguous_format
561
+ )
562
+
563
+ # 4. Create missing `position_ids` on the fly
564
+ if attention_mask is not None and kwargs.get("position_ids") is None:
565
+ position_ids = attention_mask.long().cumsum(-1) - 1
566
+ position_ids.masked_fill_(attention_mask == 0, 1)
567
+ kwargs["position_ids"] = (
568
+ position_ids # placed in kwargs for further processing (see below)
569
+ )
570
+
571
+ # 5. Slice model inputs that should have the same length as `input_ids`
572
+ for model_input_name in ["position_ids", "token_type_ids"]:
573
+ model_input: torch.Tensor = kwargs.get(model_input_name)
574
+ if model_input is not None:
575
+ if past_key_values:
576
+ model_input = model_input[:, -input_ids.shape[2] :]
577
+ model_input = model_input.clone(
578
+ memory_format=torch.contiguous_format
579
+ )
580
+ model_inputs[model_input_name] = model_input
581
+
582
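+ # 6. Forward the attention mask unchanged when it is provided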
+ if attention_mask is not None:
583
+ model_inputs["attention_mask"] = attention_mask
584
+
585
+ # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
586
+ for key, value in kwargs.items():
587
+ if key not in model_inputs:
588
+ model_inputs[key] = value
589
+
590
+ if model_inputs[input_ids_key] is not None:
591
+ model_inputs[input_ids_key] = (
592
+ cast(torch.Tensor, model_inputs[input_ids_key])
593
+ .transpose(1, 2)
594
+ .reshape(input_ids.shape[0], -1, (self.audio_channels + 1))
595
+ .transpose(1, 2)
596
+ ) # [B, audio_channels, T*group_size]
597
+
598
+ # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
599
+ model_inputs.pop("labels", None)
600
+ return model_inputs
601
+
602
+ def _get_initial_cache_position(self, input_ids: torch.Tensor, model_kwargs: dict):
603
+ """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
604
+ # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
605
+ if "inputs_embeds" in model_kwargs:
606
+ cache_position = (
607
+ torch.ones_like(
608
+ model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64
609
+ ).cumsum(0)
610
+ - 1
611
+ )
612
+ else:
613
+ cache_position = (
614
+ torch.ones(
615
+ (
616
+ input_ids.shape[1]
617
+ // (self.audio_channels + 1)
618
+ // self.config.group_size,
619
+ ),
620
+ dtype=torch.int64,
621
+ device=input_ids.device,
622
+ ).cumsum(0)
623
+ - 1
624
+ )
625
+
626
+ past_length = 0
627
+ if model_kwargs.get("past_key_values") is not None:
628
+ cache = model_kwargs["past_key_values"]
629
+ past_length = 0
630
+ if not isinstance(cache, Cache):
631
+ past_length = cache[0][0].shape[2]
632
+ elif (
633
+ hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None
634
+ ):
635
+ past_length = cache.get_seq_length()
636
+
637
+ # TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty,
638
+ # end-to-end compilation will yield bad results because `cache_position` will be incorrect.
639
+ if not is_torchdynamo_compiling():
640
+ cache_position = cache_position[past_length:]
641
+
642
+ model_kwargs["cache_position"] = cache_position
643
+
644
+ return model_kwargs
645
+
646
+ @torch.inference_mode()
647
+ def generate(
648
+ self,
649
+ inputs: torch.Tensor | None = None,
650
+ generation_config: GenerationConfig | None = None,
651
+ stopping_criteria: StoppingCriteriaList | list | None = None,
652
+ streamer: BaseStreamer | None = None,
653
+ synced_gpus: bool | None = None,
654
+ global_sampler: MiMoSampler | None = None,
655
+ local_sampler: MiMoSampler | None = None,
656
+ warmup_run: bool | None = None,
657
+ **kwargs,
658
+ ) -> Union[GenerateOutput, torch.LongTensor]:
659
+ generation_config, model_kwargs = self._prepare_generation_config(
660
+ generation_config, **kwargs
661
+ )
662
+
663
+ self._validate_model_kwargs(model_kwargs.copy())
664
+
665
+ # 2. Set generation parameters if not already defined
666
+ if synced_gpus is None:
667
+ if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
668
+ synced_gpus = True
669
+ else:
670
+ synced_gpus = False
671
+
672
+ # 3. Define model inputs
673
+ input_ids, _model_input_name, model_kwargs = self._prepare_model_inputs(
674
+ inputs, generation_config.bos_token_id, model_kwargs
675
+ )
676
+ input_ids_length = input_ids.shape[-1]
677
+ input_ids_length //= self.group_size * (self.audio_channels + 1)
678
+
679
+ if streamer is not None:
680
+ streamer.put(input_ids.cpu())
681
+
682
+ if "attention_mask" not in model_kwargs:
683
+ model_kwargs["attention_mask"] = self._prepare_attention_mask(
684
+ inputs, input_ids_length
685
+ )
686
+
687
+ device = input_ids.device
688
+ self._prepare_special_tokens(generation_config, True, device=device)
689
+
690
+ model_kwargs["use_cache"] = True
691
+ model_kwargs["past_key_values"] = DynamicCache()
692
+
693
+ prepared_stopping_criteria = StoppingCriteriaList(
694
+ stopping_criteria if stopping_criteria is not None else []
695
+ )
696
+ prepared_stopping_criteria.append(
697
+ MiMoStopper(
698
+ self.group_size,
699
+ self.audio_channels,
700
+ max_length=generation_config.max_length,
701
+ )
702
+ )
703
+ stance = "default" if warmup_run else "eager_on_recompile"
704
+ with torch.compiler.set_stance(stance):
705
+ return self.slm_sample(
706
+ input_ids,
707
+ stopping_criteria=prepared_stopping_criteria,
708
+ generation_config=generation_config,
709
+ synced_gpus=synced_gpus,
710
+ streamer=streamer,
711
+ global_sampler=global_sampler,
712
+ local_sampler=local_sampler,
713
+ **model_kwargs,
714
+ )
715
+
716
+ def slm_sample(
717
+ self,
718
+ input_ids: torch.LongTensor,
719
+ stopping_criteria: StoppingCriteriaList,
720
+ generation_config: GenerationConfig,
721
+ synced_gpus: bool,
722
+ streamer: BaseStreamer | None,
723
+ global_sampler: MiMoSampler | None = None,
724
+ local_sampler: MiMoSampler | None = None,
725
+ **model_kwargs,
726
+ ) -> torch.LongTensor:
727
+ max_length = generation_config.max_length
728
+
729
+ B, cur_len = input_ids.shape
730
+ cur_len //= self.group_size * (self.audio_channels + 1)
731
+ initial_len = cur_len
732
+ this_peer_finished = False
733
+ unfinished_sequences = torch.ones(B, dtype=torch.long, device=input_ids.device)
734
+
735
+ min_length = 0
736
+ stop_token_ids = set()
737
+ for criterion in stopping_criteria:
738
+ if isinstance(criterion, MiMoStopper):
739
+ if criterion.min_length is not None:
740
+ min_length = max(min_length, criterion.min_length)
741
+ stop_token_ids.update(criterion.stop_token_ids)
742
+
743
+ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
744
+
745
+ while self._has_unfinished_sequences(
746
+ this_peer_finished,
747
+ synced_gpus,
748
+ device=input_ids.device,
749
+ cur_len=cur_len,
750
+ max_length=max_length,
751
+ ):
752
+ # prepare model inputs
753
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
754
+
755
+ # forward pass to get next token
756
+ if (
757
+ cast(torch.Tensor, model_inputs["input_ids"]).shape[2]
758
+ != self.group_size
759
+ ):
760
+ # prefill run
761
+ with torch.compiler.set_stance("force_eager"):
762
+ outputs: MiMoAudioOutput = self(**model_inputs)
763
+ else:
764
+ outputs: MiMoAudioOutput = self(**model_inputs)
765
+
766
+ if synced_gpus and this_peer_finished:
767
+ continue # don't waste resources running the code we don't need
768
+
769
+ text_logits: torch.Tensor = outputs.text_logits[:, -1, :].clone()
770
+ # [B, vocab_size]
771
+
772
+ removed_tokens = None
773
+ if cur_len < min_length:
774
+ removed_tokens = list(stop_token_ids)
775
+
776
+ next_text_tokens = global_sampler.sample(text_logits, removed_tokens=removed_tokens)
777
+ # [B]
778
+
779
+ local_hidden_states = outputs.local_hidden_states
780
+
781
+ # Only supports batch_size=1 here
782
+ if next_text_tokens[0] != self.args.empty_idx:
783
+ zero_embed_tensor = torch.tensor(
784
+ self.speech_empty_ids,
785
+ device=next_text_tokens.device,
786
+ dtype=input_ids.dtype,
787
+ )
788
+ next_speech_tokens = zero_embed_tensor.view(
789
+ 1, 1, self.audio_channels
790
+ ).expand(B, self.config.group_size, -1)
791
+ else:
792
+ next_speech_tokens = self.local_forward(
793
+ local_embeds=local_hidden_states,
794
+ tokens_dtype=next_text_tokens.dtype,
795
+ tokens_device=next_text_tokens.device,
796
+ local_sampler=local_sampler,
797
+ )
798
+
799
+ next_text_tokens = next_text_tokens.reshape(B, 1, 1).expand(
800
+ -1, self.group_size, -1
801
+ ) # [B, group_size, 1]
802
+
803
+ # generate speech tokens
804
+ next_tokens = torch.cat(
805
+ (next_text_tokens, next_speech_tokens), dim=-1
806
+ ).reshape(B, -1) # [B, group_size * (audio_channels + 1)]
807
+
808
+ input_ids = torch.cat(
809
+ [input_ids, next_tokens], dim=-1
810
+ ) # [B, T*group_size*vq]
811
+
812
+ if streamer is not None:
813
+ streamer.put(next_tokens.cpu())
814
+ model_kwargs = self._update_model_kwargs_for_generation(
815
+ outputs,
816
+ model_kwargs,
817
+ is_encoder_decoder=self.config.is_encoder_decoder,
818
+ )
819
+
820
+ unfinished_sequences = unfinished_sequences & ~stopping_criteria(
821
+ input_ids, None
822
+ )
823
+ this_peer_finished = unfinished_sequences.max() == 0
824
+ cur_len += 1
825
+
826
+ # This is needed to properly delete outputs.logits, which may be very large for the first iteration.
827
+ # Otherwise a reference to outputs is kept, keeping the logits alive into the next iteration.
828
+ del outputs
829
+
830
+ if streamer is not None:
831
+ streamer.end()
832
+
833
+ input_ids = input_ids[:B]
834
+
835
+ return input_ids
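The sampling loop above emits, at every step, one global text token (replicated across the group) plus group_size x audio_channels local speech tokens, all flattened into `input_ids`. A minimal usage sketch follows; it assumes an already-loaded `model` instance of the class above, a prepared `input_ids` tensor in the flattened [B, T * group_size * (audio_channels + 1)] layout, and `MiMoSampler` objects whose constructor arguments are not shown in this commit and are therefore left as hypothetical defaults.

global_sampler = MiMoSampler()   # hypothetical constructor call; real arguments are not shown in this file
local_sampler = MiMoSampler()    # hypothetical constructor call

output_ids = model.generate(
    inputs=input_ids,                 # flattened [B, T * group_size * (audio_channels + 1)]
    global_sampler=global_sampler,    # samples the text/global stream
    local_sampler=local_sampler,      # samples the per-channel speech tokens
    warmup_run=False,
)
# output_ids keeps the same flattened layout; each decoding step appends one block of
# group_size * (audio_channels + 1) tokens, exactly as assembled in slm_sample().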
src/mimo_audio/process_speechdata.py ADDED
@@ -0,0 +1,289 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corporation.
3
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from typing import Tuple, Union, List
8
+
9
+
10
+ class InputSegment:
11
+
12
+ def __init__(
13
+ self,
14
+ text: str = "",
15
+ audio: torch.Tensor = None,
16
+ tokenized_text: torch.Tensor = None,
17
+ speech_zeroemb_idx: Union[int, List[int]] = 1024,
18
+ text_zeroemb_idx: int = 152067,
19
+ add_sosp_eosp=True,
20
+ ) -> None:
21
+ has_text = text is not None
22
+ has_tokenized_text = tokenized_text is not None
23
+ assert has_text or has_tokenized_text, "Text or tokenized text must be provided"
24
+
25
+ self.audio = audio
26
+ self.text = text
27
+ self.tokenized_text = tokenized_text
28
+ self.speech_zeroemb_idx = speech_zeroemb_idx
29
+ self.text_zeroemb_idx = text_zeroemb_idx
30
+ self.add_sosp_eosp = add_sosp_eosp
31
+
32
+ @staticmethod
33
+ def insert_between(tensor, i, value=-1):
34
+ return torch.scatter(
35
+ torch.full(
36
+ (1, tensor.shape[1] + (tensor.shape[1] - 1) * i + i),
37
+ value,
38
+ dtype=tensor.dtype,
39
+ ),
40
+ 1,
41
+ torch.arange(0, tensor.shape[1], dtype=torch.int64)[None] * (i + 1),
42
+ tensor,
43
+ )
44
+
45
+ def to_input_id(
46
+ self,
47
+ tokenizer,
48
+ group_size: int,
49
+ audio_channels: int = 8,
50
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
51
+ if self.audio is None:
52
+ if self.tokenized_text is None:
53
+ tokenized_text = tokenizer(
54
+ self.text,
55
+ return_tensors="pt",
56
+ truncation=True,
57
+ max_length=999999,
58
+ padding=False,
59
+ add_special_tokens=False,
60
+ )["input_ids"].int()
61
+ else:
62
+ tokenized_text = self.tokenized_text.unsqueeze(0)
63
+
64
+
65
+ if group_size > 1:
66
+ tokenized_text = self.insert_between(
67
+ tokenized_text, group_size - 1, value=-100
68
+ )
69
+
70
+
71
+ if isinstance(self.speech_zeroemb_idx, list):
72
+ audio_part_input_id = torch.zeros((audio_channels, tokenized_text.shape[1]), dtype=torch.int)
73
+ for i, idx in enumerate(self.speech_zeroemb_idx):
74
+ audio_part_input_id[i, :] = idx
75
+ else:
76
+ audio_part_input_id = torch.full(
77
+ (audio_channels, tokenized_text.shape[1]), self.speech_zeroemb_idx, dtype=torch.int
78
+ )
79
+
80
+
81
+ else:
82
+ sosp_token = (
83
+ tokenizer.convert_tokens_to_ids("<|sosp|>")
84
+ if self.add_sosp_eosp
85
+ else None
86
+ )
87
+ eosp_token = (
88
+ tokenizer.convert_tokens_to_ids("<|eosp|>")
89
+ if self.add_sosp_eosp
90
+ else None
91
+ )
92
+ audio_part = self.audio.reshape(-1, audio_channels).T # [audio_channels, seqlen]
93
+
94
+ assert (
95
+ audio_part.shape[1] % group_size == 0
96
+ ), f"Audio shape {audio_part.shape} is not divisible by group_size {group_size}"
97
+
98
+
99
+ text_len = audio_part.shape[1] // group_size
100
+ empty_token = self.text_zeroemb_idx
101
+ if empty_token is None:
102
+ empty_token = tokenizer.eod
103
+ tokenized_text = torch.full((1, text_len), empty_token, dtype=torch.int)
104
+
105
+ tokenized_text = (
106
+ torch.cat(
107
+ [
108
+ torch.tensor([[sosp_token]], dtype=torch.int),
109
+ tokenized_text,
110
+ torch.tensor([[eosp_token]], dtype=torch.int),
111
+ ],
112
+ dim=1,
113
+ )
114
+ if self.add_sosp_eosp
115
+ else tokenized_text
116
+ )
117
+ tokenized_text = self.insert_between(
118
+ tokenized_text, group_size - 1, value=-100
119
+ )
120
+
121
+
122
+ if self.add_sosp_eosp:
123
+ if isinstance(self.speech_zeroemb_idx, list):
124
+ sosp_part = torch.zeros((audio_channels, group_size), dtype=torch.int)
125
+ eosp_part = torch.zeros((audio_channels, group_size), dtype=torch.int)
126
+ for i, idx in enumerate(self.speech_zeroemb_idx):
127
+ sosp_part[i, :] = idx
128
+ eosp_part[i, :] = idx
129
+ audio_part_input_id = torch.cat([sosp_part, audio_part, eosp_part], dim=1)
130
+ else:
131
+ audio_part_input_id = torch.cat(
132
+ [
133
+ torch.full((audio_channels, group_size), self.speech_zeroemb_idx, dtype=torch.int),
134
+ audio_part,
135
+ torch.full((audio_channels, group_size), self.speech_zeroemb_idx, dtype=torch.int),
136
+ ],
137
+ dim=1,
138
+ )
139
+ else:
140
+ audio_part_input_id = audio_part
141
+
142
+
143
+
144
+ input_ids = torch.cat(
145
+ [tokenized_text, audio_part_input_id], dim=0
146
+ ) # [n_rvq + 1, seqlen]
147
+
148
+
149
+ return input_ids
150
+
151
+
152
+ class StreamingInputSegment:
153
+ def __init__(
154
+ self,
155
+ text: str = "",
156
+ audio: torch.Tensor = None,
157
+ tokenized_text: torch.Tensor = None,
158
+ speech_zeroemb_idx: Union[int, List[int]] = 1024,
159
+ text_zeroemb_idx: int = 152067,
160
+ text_segment_size: int = 5,
161
+ audio_segment_size: int = 5,
162
+ tokenizer=None,
163
+ group_size=None,
164
+ audio_channels=None,
165
+ ) -> None:
166
+ has_text = text is not None
167
+ has_tokenized_text = tokenized_text is not None
168
+ assert has_text or has_tokenized_text, "Text or tokenized text must be provided"
169
+
170
+ self.audio = audio
171
+ self.text = text
172
+ self.tokenized_text = tokenized_text
173
+ self.speech_zeroemb_idx = speech_zeroemb_idx
174
+ self.text_zeroemb_idx = text_zeroemb_idx
175
+ self.text_segment_size = text_segment_size
176
+ self.audio_segment_size = audio_segment_size
177
+ self.tokenizer = tokenizer
178
+ self.group_size = group_size
179
+ self.audio_channels = audio_channels
180
+
181
+ def to_input_id(
182
+ self,
183
+ tokenizer,
184
+ group_size: int,
185
+ audio_channels: int = 8,
186
+ ):
187
+ if self.tokenized_text is None:
188
+ tokenized_text = tokenizer(
189
+ self.text,
190
+ return_tensors="pt",
191
+ truncation=True,
192
+ max_length=999999,
193
+ padding=False,
194
+ add_special_tokens=False,
195
+ )["input_ids"].int() # [1, seqlen]
196
+ else:
197
+ tokenized_text = self.tokenized_text.unsqueeze(0)
198
+
199
+ tokenized_text = tokenized_text.squeeze(0)
200
+
201
+ text_segments = tokenized_text.split(self.text_segment_size, dim=0)
202
+ audio_segments = self.audio.split(self.audio_segment_size*group_size*audio_channels, dim=0)
203
+
204
+ tokenized_segments = []
205
+ tokenized_segments.append(
206
+ InputSegment(
207
+ text='<|sostm|>',
208
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
209
+ text_zeroemb_idx=self.text_zeroemb_idx,
210
+ ),
211
+ )
212
+
213
+
214
+ eot_tokens = tokenizer(
215
+ "<|eot|>",
216
+ return_tensors="pt",
217
+ truncation=True,
218
+ max_length=999999,
219
+ padding=False,
220
+ add_special_tokens=False,
221
+ )["input_ids"][0].to(text_segments[-1])
222
+
223
+
224
+ text_segments = text_segments[:-1] + (torch.cat([text_segments[-1], eot_tokens], dim=0),)
225
+
226
+
227
+ length = min(len(text_segments), len(audio_segments))
228
+ for i in range(length):
229
+ text_segment = text_segments[i]
230
+ audio_segment = audio_segments[i]
231
+
232
+ tokenized_segments.append(
233
+ InputSegment(
234
+ tokenized_text=text_segment,
235
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
236
+ text_zeroemb_idx=self.text_zeroemb_idx,
237
+ ),
238
+ )
239
+ tokenized_segments.append(
240
+ InputSegment(
241
+ audio=audio_segment,
242
+ add_sosp_eosp=False,
243
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
244
+ text_zeroemb_idx=self.text_zeroemb_idx,
245
+ ),
246
+ )
247
+
248
+ for j in range(length, len(text_segments)):
249
+ tokenized_segments.append(
250
+ InputSegment(
251
+ tokenized_text=text_segments[j],
252
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
253
+ text_zeroemb_idx=self.text_zeroemb_idx,
254
+ ),
255
+ )
256
+
257
+ for j in range(length, len(audio_segments)):
258
+ tokenized_segments.append(
259
+ InputSegment(
260
+ audio=audio_segments[j],
261
+ add_sosp_eosp=False,
262
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
263
+ text_zeroemb_idx=self.text_zeroemb_idx,
264
+ ),
265
+ )
266
+
267
+ tokenized_segments.append(
268
+ InputSegment(
269
+ text="<|eostm|>",
270
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
271
+ text_zeroemb_idx=self.text_zeroemb_idx,
272
+ ),
273
+ )
274
+
275
+
276
+ input_ids = [
277
+ seg.to_input_id(
278
+ self.tokenizer,
279
+ self.group_size,
280
+ self.audio_channels,
281
+ )
282
+ for seg in tokenized_segments
283
+ ]
284
+
285
+
286
+
287
+ input_ids = torch.cat(input_ids, dim=1).type(torch.int64) # [n_rvq + 1, seqlen]
288
+
289
+ return input_ids
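A small illustration of `InputSegment.insert_between`, which the classes above use to align one text token with every group of speech frames: with `i = group_size - 1`, each original token is followed by `group_size - 1` filler values (-100), so a row of n tokens becomes a row of n * group_size positions. The numbers below are illustrative only.

import torch

tokens = torch.tensor([[5, 7]])                                    # text row of 2 tokens, shape [1, 2]
spread = InputSegment.insert_between(tokens, i=3 - 1, value=-100)  # group_size = 3
print(spread)  # tensor([[   5, -100, -100,    7, -100, -100]]), shape [1, 6]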
src/mimo_audio/templates.py ADDED
@@ -0,0 +1,54 @@
1
+ # Copyright 2025 Xiaomi Corporation.
2
+ asr_zh_templates = [
3
+ "请将这段语音转换为文字",
4
+ "帮我识别这个音频文件中的内容",
5
+ "把这段录音转成文本",
6
+ "请转录这段语音",
7
+ "将音频内容转换成文字格式",
8
+ "识别并转写这段语音",
9
+ "把语音内容写成文字",
10
+ "转录这个音频片段",
11
+ "将这段对话转换为文本",
12
+ "麻烦帮我把这段录音整理成详细的文字记录",
13
+ ]
14
+
15
+ asr_en_templates = [
16
+ "Please transcribe this audio file",
17
+ "Convert this speech recording to text",
18
+ "Transcribe the following voice message",
19
+ "Turn this audio into readable text",
20
+ "Please convert the recording to written format",
21
+ "Transcribe what you hear in this audio",
22
+ "Convert this spoken content to text",
23
+ "Please write down what is said in this recording",
24
+ "Transcribe this voice recording",
25
+ "Could you please help me transcribe this important recording?",
26
+ "Would you mind converting this voice message into a readable text format?",
27
+ "I'd really appreciate it if you could turn this audio file into a written document",
28
+ ]
29
+
30
+ tts_zh_templates = [
31
+ "请将这段文字转换为语音",
32
+ "帮我把这个文本读出来",
33
+ "将这些文字生成音频",
34
+ "请朗读这段内容",
35
+ "把这段话转换成语音文件",
36
+ "生成这段文字的语音版本",
37
+ "请用语音播报这些内容",
38
+ "将文本转换为可听的音频",
39
+ "帮我朗读这段文字",
40
+ "把这些内容念出来",
41
+ ]
42
+
43
+ tts_en_templates = [
44
+ "Please convert this text to speech",
45
+ "Turn this writing into audio",
46
+ "Generate speech from this text",
47
+ "Read this content out loud",
48
+ "Convert these words to voice",
49
+ "Create an audio version of this text",
50
+ "Please vocalize this content",
51
+ "Turn this text into audible format",
52
+ "Help me convert this writing to speech",
53
+ "Make this text into spoken audio",
54
+ ]
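These lists are plain instruction strings; a data-building script might pick one at random for each ASR or TTS sample. The helper below is a hedged illustration only, not a function defined anywhere in this repository.

import random

def pick_template(task: str, lang: str) -> str:
    # illustrative helper; the name and signature are hypothetical
    table = {
        ("asr", "zh"): asr_zh_templates,
        ("asr", "en"): asr_en_templates,
        ("tts", "zh"): tts_zh_templates,
        ("tts", "en"): tts_en_templates,
    }
    return random.choice(table[(task, lang)])

print(pick_template("asr", "en"))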
src/mimo_audio_tokenizer/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright 2025 Xiaomi Corporation.
2
+ from .modeling_audio_tokenizer import MiMoAudioTokenizer, StreamingConfig, StreamingCache
3
+ from .configuration_audio_tokenizer import MiMoAudioTokenizerConfig
4
+
5
+
6
+ __all__ = ['MiMoAudioTokenizer', 'StreamingConfig', 'StreamingCache', 'MiMoAudioTokenizerConfig']
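For orientation, a hedged sketch of consuming these exports, assuming the `src/` package is importable as `mimo_audio_tokenizer`; the checkpoint path is hypothetical, and `from_pretrained` is the standard loader that `MiMoAudioTokenizer` inherits from `PreTrainedModel` (see `modeling_audio_tokenizer.py` below).

from mimo_audio_tokenizer import MiMoAudioTokenizer, StreamingConfig, StreamingCache

tokenizer = MiMoAudioTokenizer.from_pretrained("path/to/mimo-audio-tokenizer")  # hypothetical path
streaming_config = StreamingConfig()   # defaults defined in modeling_audio_tokenizer.py
streaming_cache = StreamingCache()     # empty cache for a new stream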
src/mimo_audio_tokenizer/configuration_audio_tokenizer.py ADDED
@@ -0,0 +1,104 @@
1
+ # Copyright 2025 Xiaomi Corporation.
2
+ from transformers import PretrainedConfig
3
+
4
+
5
+ class MiMoAudioTokenizerConfig(PretrainedConfig):
6
+ model_type = "mimo_audio_tokenizer"
7
+
8
+ def __init__(
9
+ self,
10
+ max_audio_seconds: int = 1800,
11
+ stride_size: int = 2,
12
+ avg_pooler: int = 1,
13
+ d_model: int = 768,
14
+ scale_embedding: bool = True,
15
+ kernel_size: int = 3,
16
+ activation_function: str = "gelu",
17
+ encoder_layers: int = 8,
18
+ encoder_skip_layer_id: int = None,
19
+ encoder_attention_heads: int = 12,
20
+ encoder_ffn_dim: int = 3072,
21
+ encoder_causal: bool = False,
22
+ encoder_attn_window_size: list[int] = None,
23
+ decoder_layers: int = 8,
24
+ decoder_attention_heads: int = 12,
25
+ decoder_ffn_dim: int = 3072,
26
+ decoder_kernel_size: int = 3,
27
+ decoder_stride_size: int = 2,
28
+ decoder_causal: bool = True,
29
+ decoder_attn_window_size: list[int] = None,
30
+ nfft: int = 1024,
31
+ vocoder_dim: int = 512,
32
+ vocoder_intermediate_dim: int = 4096,
33
+ vocoder_num_layers: int = 30,
34
+ n_mels: int = 80,
35
+ sampling_rate: int = 24000,
36
+ hop_length: int = 240,
37
+ window_size: int = 1024,
38
+ vocoder_padding: str = "same",
39
+ fmin: int = 0,
40
+ fmax: int = None,
41
+ num_quantizers: int = 12,
42
+ codebook_size: list[int] = None,
43
+ threshold_ema_dead_code: int = 10,
44
+ position_embedding_type: str = "rope",
45
+ rope_theta: int = 10000,
46
+ rope_type: str = "default",
47
+ ln_type: str = "LayerNorm",
48
+ vocoder_attention_heads: int = 4,
49
+ vocoder_attn_window_size: list[int] = None,
50
+ **kwargs,
51
+ ):
52
+ super().__init__(**kwargs)
53
+ self.max_audio_seconds = max_audio_seconds
54
+ self.stride_size = stride_size
55
+ self.avg_pooler = avg_pooler
56
+ self.d_model = d_model
57
+ self.scale_embedding = scale_embedding
58
+ self.kernel_size = kernel_size
59
+ self.activation_function = activation_function
60
+ self.encoder_layers = encoder_layers
61
+ self.encoder_skip_layer_id = encoder_skip_layer_id
62
+ self.encoder_attention_heads = encoder_attention_heads
63
+ self.encoder_ffn_dim = encoder_ffn_dim
64
+ self.encoder_causal = encoder_causal
65
+ self.encoder_attn_window_size = (
66
+ encoder_attn_window_size
67
+ if encoder_attn_window_size is not None
68
+ else [-1, -1]
69
+ )
70
+ self.decoder_layers = decoder_layers
71
+ self.decoder_attention_heads = decoder_attention_heads
72
+ self.decoder_ffn_dim = decoder_ffn_dim
73
+ self.decoder_kernel_size = decoder_kernel_size
74
+ self.decoder_stride_size = decoder_stride_size
75
+ self.decoder_causal = decoder_causal
76
+ self.decoder_attn_window_size = (
77
+ decoder_attn_window_size
78
+ if decoder_attn_window_size is not None
79
+ else [-1, -1]
80
+ )
81
+ self.nfft = nfft
82
+ self.vocoder_dim = vocoder_dim
83
+ self.vocoder_intermediate_dim = vocoder_intermediate_dim
84
+ self.vocoder_num_layers = vocoder_num_layers
85
+ self.n_mels = n_mels
86
+ self.sampling_rate = sampling_rate
87
+ self.hop_length = hop_length
88
+ self.window_size = window_size
89
+ self.vocoder_padding = vocoder_padding
90
+ self.fmin = fmin
91
+ self.fmax = fmax
92
+ self.num_quantizers = num_quantizers
93
+ self.codebook_size = codebook_size if codebook_size is not None else [1024]
94
+ self.threshold_ema_dead_code = threshold_ema_dead_code
95
+ self.position_embedding_type = position_embedding_type
96
+ self.rope_theta = rope_theta
97
+ self.rope_type = rope_type
98
+ self.ln_type = ln_type
99
+ self.vocoder_attention_heads = vocoder_attention_heads
100
+ self.vocoder_attn_window_size = (
101
+ vocoder_attn_window_size
102
+ if vocoder_attn_window_size is not None
103
+ else [40, 10]
104
+ )
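A minimal instantiation sketch: every keyword below mirrors a parameter of the `__init__` above, and the printed values follow directly from its defaults and fallbacks.

config = MiMoAudioTokenizerConfig(
    sampling_rate=24000,
    num_quantizers=12,
    codebook_size=[1024],
)
print(config.d_model)                   # 768 (default)
print(config.encoder_attn_window_size)  # [-1, -1] (fallback when None is passed)
print(config.vocoder_attn_window_size)  # [40, 10] (fallback when None is passed)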
src/mimo_audio_tokenizer/modeling_audio_tokenizer.py ADDED
@@ -0,0 +1,857 @@
1
+ # Copyright 2025 Xiaomi Corporation.
2
+ import math
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ from flash_attn import flash_attn_varlen_func
8
+ from torch.nn import functional as F
9
+ from transformers.activations import ACT2FN
10
+ from transformers.modeling_utils import PreTrainedModel
11
+
12
+ from .configuration_audio_tokenizer import MiMoAudioTokenizerConfig
13
+ from .modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update, apply_rotary_pos_emb
14
+ from .quantization import ResidualVectorQuantizer
15
+ from dataclasses import dataclass, field
16
+ from typing import List
17
+
18
+ def get_sequence_mask(inputs, inputs_length):
19
+ if inputs.dim() == 3:
20
+ bsz, tgt_len, _ = inputs.size()
21
+ else:
22
+ bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
23
+ sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
24
+ sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(
25
+ bsz, tgt_len, 1
26
+ )
27
+ unpacking_index = torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1
28
+ return sequence_mask, unpacking_index
29
+
30
+
31
+ def unpack_hidden_states(
32
+ hidden_states, lengths, sequence_mask=None, unpacking_index=None
33
+ ):
34
+ bsz = lengths.shape[0]
35
+ if sequence_mask is None or unpacking_index is None:
36
+ sequence_mask, unpacking_index = get_sequence_mask(hidden_states, lengths)
37
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
38
+ bsz, torch.max(lengths), hidden_states.shape[-1]
39
+ )
40
+ hidden_states = torch.where(sequence_mask, hidden_states, 0)
41
+ return hidden_states
42
+
43
+
44
+ def get_position_ids(lengths):
45
+ total_len = lengths.sum()
46
+ offset = torch.cat([torch.zeros(1).to(lengths), lengths[:-1].cumsum(dim=0)])
47
+ offset = torch.repeat_interleave(offset, lengths)
48
+ position_ids = torch.arange(0, total_len).to(offset) - offset
49
+ return position_ids
50
+
51
+ @dataclass
52
+ class StreamingConfig:
53
+ seg_point: int = field(default=60 * 25)
54
+ process_seg_point: bool = field(default=True)
55
+ left_overlap: int = field(default=10 * 25)
56
+ right_overlap: int = field(default=40)
57
+ seg_point_left_overlap: int = field(default=0)
58
+
59
+ @dataclass
60
+ class StreamingCache:
61
+ hidden_states: List[torch.Tensor] = field(default=None)
62
+ processed_lengths: List[int] = field(default=None)
63
+
64
+ class ISTFT(nn.Module):
65
+ """
66
+ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
67
+ windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
68
+ See issue: https://github.com/pytorch/pytorch/issues/62323
69
+ Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
70
+ The NOLA constraint is met as we trim padded samples anyway.
71
+
72
+ Args:
73
+ n_fft (int): Size of Fourier transform.
74
+ hop_length (int): The distance between neighboring sliding window frames.
75
+ win_length (int): The size of window frame and STFT filter.
76
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
77
+ """
78
+
79
+ def __init__(
80
+ self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
81
+ ):
82
+ super().__init__()
83
+ if padding not in ["center", "same"]:
84
+ raise ValueError("Padding must be 'center' or 'same'.")
85
+ self.padding = padding
86
+ self.n_fft = n_fft
87
+ self.hop_length = hop_length
88
+ self.win_length = win_length
89
+ window = torch.hann_window(win_length)
90
+ self.register_buffer("window", window)
91
+
92
+ def forward(self, spec: torch.Tensor) -> torch.Tensor:
93
+ """
94
+ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
95
+
96
+ Args:
97
+ spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
98
+ N is the number of frequency bins, and T is the number of time frames.
99
+
100
+ Returns:
101
+ Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
102
+ """
103
+ if self.padding == "center":
104
+ # Fallback to pytorch native implementation
105
+ return torch.istft(
106
+ spec,
107
+ self.n_fft,
108
+ self.hop_length,
109
+ self.win_length,
110
+ self.window,
111
+ center=True,
112
+ )
113
+ elif self.padding == "same":
114
+ pad = (self.win_length - self.hop_length) // 2
115
+ else:
116
+ raise ValueError("Padding must be 'center' or 'same'.")
117
+
118
+ assert spec.dim() == 3, "Expected a 3D tensor as input"
119
+ B, N, T = spec.shape
120
+
121
+ # Inverse FFT
122
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
123
+ ifft = ifft * self.window[None, :, None]
124
+
125
+ # Overlap and Add
126
+ output_size = (T - 1) * self.hop_length + self.win_length
127
+ y = torch.nn.functional.fold(
128
+ ifft,
129
+ output_size=(1, output_size),
130
+ kernel_size=(1, self.win_length),
131
+ stride=(1, self.hop_length),
132
+ )[:, 0, 0, pad:-pad]
133
+
134
+ # Window envelope
135
+ window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
136
+ window_envelope = torch.nn.functional.fold(
137
+ window_sq,
138
+ output_size=(1, output_size),
139
+ kernel_size=(1, self.win_length),
140
+ stride=(1, self.hop_length),
141
+ ).squeeze()[pad:-pad]
142
+
143
+ # Normalize
144
+ assert (window_envelope > 1e-11).all()
145
+ y = y / window_envelope
146
+
147
+ return y
148
+
149
+ class ISTFTHead(nn.Module):
150
+ """
151
+ ISTFT Head module for predicting STFT complex coefficients.
152
+
153
+ Args:
154
+ dim (int): Hidden dimension of the model.
155
+ n_fft (int): Size of Fourier transform.
156
+ hop_length (int): The distance between neighboring sliding window frames, which should align with
157
+ the resolution of the input features.
158
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
159
+ """
160
+
161
+ def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
162
+ super().__init__()
163
+ out_dim = n_fft + 2
164
+ self.out = torch.nn.Linear(dim, out_dim)
165
+ self.istft = ISTFT(
166
+ n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
167
+ )
168
+
169
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
170
+ """
171
+ Forward pass of the ISTFTHead module.
172
+
173
+ Args:
174
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
175
+ L is the sequence length, and H denotes the model dimension.
176
+
177
+ Returns:
178
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
179
+ """
180
+ x = self.out(x).transpose(1, 2)
181
+ mag, p = x.chunk(2, dim=1)
182
+ mag = torch.exp(mag)
183
+ mag = torch.clip(
184
+ mag, max=1e2
185
+ ) # safeguard to prevent excessively large magnitudes
186
+ # These two lines produce the real and imaginary parts (the phase wrapping happens here)
187
+ x = torch.cos(p)
188
+ y = torch.sin(p)
189
+ # recalculating the phase here would not produce anything new,
191
+ + # it would only cost time:
192
+ # phase = torch.atan2(y, x)
193
+ # S = mag * torch.exp(phase * 1j)
194
+ # it is better to produce the complex value directly
194
+ original_dtype = x.dtype
195
+ S = mag.float() * (x.float() + 1j * y.float())
196
+ audio = self.istft(S)
197
+ audio = audio.to(original_dtype)
198
+ return audio
199
+
200
+
201
+ class RotaryEmbedding(nn.Module):
202
+ def __init__(self, base, dim, max_seq_len, rope_type="default", device=None):
203
+ super().__init__()
204
+ self.max_seq_len = max_seq_len
205
+ self.rope_type = rope_type
206
+
207
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
208
+
209
+ inv_freq, self.attention_scaling = self.rope_init_fn(
210
+ device=device, base=base, dim=dim
211
+ )
212
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
213
+ self.original_inv_freq = self.inv_freq
214
+
215
+ @torch.no_grad()
216
+ @dynamic_rope_update
217
+ def forward(self, x, position_ids):
218
+ inv_freq_expanded = self.inv_freq[:, None].float().expand(-1, 1).to(x.device)
219
+ position_ids_expanded = position_ids[None, :].float()
220
+
221
+ device_type = (
222
+ x.device.type
223
+ if isinstance(x.device.type, str) and x.device.type != "mps"
224
+ else "cpu"
225
+ )
226
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
227
+ freqs = (
228
+ inv_freq_expanded.float() @ position_ids_expanded.float()
229
+ ).transpose(0, 1)
230
+ emb = torch.cat((freqs, freqs), dim=-1)
231
+ cos = emb.cos() * self.attention_scaling
232
+ sin = emb.sin() * self.attention_scaling
233
+
234
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
235
+
236
+ class RMSNorm(nn.Module):
237
+ def __init__(self, hidden_size, eps=1e-6):
238
+ """
239
+ RMSNorm is equivalent to T5LayerNorm
240
+ """
241
+ super().__init__()
242
+ self.weight = nn.Parameter(torch.ones(hidden_size))
243
+ self.variance_epsilon = eps
244
+
245
+ def forward(self, hidden_states):
246
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
247
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
248
+
249
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
250
+ hidden_states = hidden_states.to(self.weight.dtype)
251
+
252
+ return self.weight * hidden_states
253
+
254
+
255
+ LAYER_NORM = {"LayerNorm": nn.LayerNorm, "RMSNorm": RMSNorm}
256
+
257
+
258
+ class Attention(nn.Module):
259
+ def __init__(self, embed_dim, num_heads, window_size=(-1, -1), causal=False):
260
+ super().__init__()
261
+ self.embed_dim = embed_dim
262
+ self.num_heads = num_heads
263
+ self.head_dim = embed_dim // num_heads
264
+ self.window_size = window_size
265
+
266
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
267
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
268
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
269
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
270
+
271
+ self.causal = causal
272
+
273
+ def forward(
274
+ self,
275
+ hidden_states: torch.Tensor,
276
+ seq_len: torch.Tensor,
277
+ rope_position_embeddings=None,
278
+ ):
279
+ bsz, _ = hidden_states.size()
280
+
281
+ query_states = self.q_proj(hidden_states).view(
282
+ bsz, self.num_heads, self.head_dim
283
+ )
284
+ key_states = self.k_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
285
+ value_states = self.v_proj(hidden_states).view(
286
+ bsz, self.num_heads, self.head_dim
287
+ )
288
+
289
+ if rope_position_embeddings is not None:
290
+ cos, sin = rope_position_embeddings
291
+ query_states = apply_rotary_pos_emb(query_states, cos, sin)
292
+ key_states = apply_rotary_pos_emb(key_states, cos, sin)
293
+
294
+ cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(
295
+ torch.int32
296
+ )
297
+ max_seqlen = torch.max(seq_len).to(torch.int32).detach()
298
+ attn_output = flash_attn_varlen_func(
299
+ query_states,
300
+ key_states,
301
+ value_states,
302
+ cu_len,
303
+ cu_len,
304
+ max_seqlen,
305
+ max_seqlen,
306
+ causal=self.causal,
307
+ window_size=self.window_size,
308
+ )
309
+ attn_output = attn_output.reshape(bsz, self.embed_dim)
310
+ attn_output = self.out_proj(attn_output)
311
+ return attn_output
312
+
313
+
314
+ class TransformerLayer(nn.Module):
315
+ def __init__(
316
+ self,
317
+ act,
318
+ d_model,
319
+ encoder_attention_heads,
320
+ encoder_ffn_dim,
321
+ causal,
322
+ ln_type="LayerNorm",
323
+ attn_window_size=(-1, -1),
324
+ ):
325
+ super().__init__()
326
+ self.embed_dim = d_model
327
+ self.self_attn = Attention(
328
+ self.embed_dim, encoder_attention_heads, attn_window_size, causal
329
+ )
330
+
331
+ self.self_attn_layer_norm = LAYER_NORM[ln_type](self.embed_dim)
332
+
333
+ self.activation_fn = act
334
+ self.fc1 = nn.Linear(self.embed_dim, encoder_ffn_dim)
335
+ self.fc2 = nn.Linear(encoder_ffn_dim, self.embed_dim)
336
+
337
+ self.final_layer_norm = LAYER_NORM[ln_type](self.embed_dim)
338
+
339
+ def forward(
340
+ self,
341
+ hidden_states: torch.Tensor,
342
+ seq_len: torch.Tensor,
343
+ rope_position_embeddings: torch.Tensor,
344
+ ) -> torch.Tensor:
345
+ residual = hidden_states
346
+ hidden_states = self.self_attn_layer_norm(hidden_states)
347
+ hidden_states = self.self_attn(
348
+ hidden_states, seq_len, rope_position_embeddings=rope_position_embeddings
349
+ )
350
+ hidden_states = residual + hidden_states
351
+ residual = hidden_states
352
+ hidden_states = self.final_layer_norm(hidden_states)
353
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
354
+ hidden_states = self.fc2(hidden_states)
355
+ hidden_states = residual + hidden_states
356
+
357
+ if (
358
+ hidden_states.dtype == torch.float16
359
+ or hidden_states.dtype == torch.bfloat16
360
+ ) and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
361
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
362
+ hidden_states = torch.clamp(
363
+ hidden_states, min=-clamp_value, max=clamp_value
364
+ )
365
+ return hidden_states
366
+
367
+
368
+ class TransformerVocos(nn.Module):
369
+ def __init__(self, config: MiMoAudioTokenizerConfig):
370
+ super().__init__()
371
+ self.config = config
372
+ self.max_source_positions = (
373
+ self.config.max_audio_seconds
374
+ * self.config.sampling_rate
375
+ // self.config.hop_length
376
+ )
377
+ self.embeddings = nn.Linear(config.n_mels, config.vocoder_dim, bias=False)
378
+
379
+ self.position_embedding = RotaryEmbedding(
380
+ config.rope_theta,
381
+ config.vocoder_dim // config.vocoder_attention_heads,
382
+ self.max_source_positions,
383
+ self.config.rope_type,
384
+ )
385
+
386
+ self.layers = nn.ModuleList(
387
+ [
388
+ TransformerLayer(
389
+ ACT2FN[self.config.activation_function],
390
+ self.config.vocoder_dim,
391
+ self.config.vocoder_attention_heads,
392
+ self.config.vocoder_intermediate_dim,
393
+ causal=False,
394
+ ln_type=self.config.ln_type,
395
+ attn_window_size=self.config.vocoder_attn_window_size,
396
+ )
397
+ for _ in range(self.config.vocoder_num_layers)
398
+ ]
399
+ )
400
+
401
+ self.layer_norm = LAYER_NORM[self.config.ln_type](self.config.vocoder_dim)
402
+ self.hop_size = self.config.hop_length
403
+ self.head = ISTFTHead(
404
+ self.config.vocoder_dim,
405
+ self.config.nfft,
406
+ self.config.hop_length,
407
+ self.config.vocoder_padding,
408
+ )
409
+
410
+ def forward(self, x: torch.Tensor, input_length):
411
+ x = x.transpose(1, 2)
412
+ attention_mask, unpacking_index = get_sequence_mask(x, input_length)
413
+ x = torch.masked_select(x, attention_mask).view(
414
+ torch.sum(input_length), self.config.n_mels
415
+ )
416
+ x = self.embeddings(x)
417
+ position_ids = torch.arange(0, x.size(0), device=x.device, dtype=torch.long)
418
+ rope_position_embeddings = self.position_embedding(x, position_ids)
419
+ for idx, layer in enumerate(self.layers):
420
+ x = layer(
421
+ x, input_length, rope_position_embeddings=rope_position_embeddings
422
+ )
423
+
424
+ x = self.layer_norm(x)
425
+ x = unpack_hidden_states(x, input_length, attention_mask, unpacking_index)
426
+ x = self.head(x)
427
+ output_length = input_length * self.hop_size
428
+ return x[:, None, :], output_length
429
+
430
+
431
+ class AudioEncoder(nn.Module):
432
+ def __init__(self, config: MiMoAudioTokenizerConfig):
433
+ super().__init__()
434
+ config._attn_implementation = "flash_attention_2"
435
+ self.config = config
436
+ self.max_source_positions = (
437
+ config.max_audio_seconds * config.sampling_rate // config.hop_length
438
+ ) // config.stride_size
439
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
440
+
441
+ self.skip_layer_idx = config.encoder_skip_layer_id
442
+ self.conv1 = nn.Conv1d(
443
+ config.n_mels, config.d_model, kernel_size=config.kernel_size, padding=1
444
+ )
445
+ self.conv2 = nn.Conv1d(
446
+ config.d_model,
447
+ config.d_model,
448
+ kernel_size=config.kernel_size,
449
+ stride=config.stride_size,
450
+ padding=1,
451
+ )
452
+
453
+ self.position_embedding = RotaryEmbedding(
454
+ config.rope_theta,
455
+ config.d_model // config.encoder_attention_heads,
456
+ self.max_source_positions,
457
+ config.rope_type,
458
+ )
459
+
460
+ self.layers = nn.ModuleList(
461
+ [
462
+ TransformerLayer(
463
+ ACT2FN[config.activation_function],
464
+ config.d_model,
465
+ config.encoder_attention_heads,
466
+ config.encoder_ffn_dim,
467
+ causal=self.config.encoder_causal,
468
+ ln_type=self.config.ln_type,
469
+ attn_window_size=self.config.encoder_attn_window_size,
470
+ )
471
+ for _ in range(config.encoder_layers)
472
+ ]
473
+ )
474
+
475
+ self.layer_norm = LAYER_NORM[config.ln_type](config.d_model)
476
+
477
+ if self.config.avg_pooler != 1:
478
+ self.down_sample_layer = nn.Sequential(
479
+ nn.Conv1d(
480
+ config.d_model,
481
+ config.d_model,
482
+ config.avg_pooler,
483
+ config.avg_pooler,
484
+ bias=False,
485
+ ),
486
+ nn.GELU(),
487
+ )
488
+ self.down_sample_norm = LAYER_NORM[config.ln_type](config.d_model)
489
+ else:
490
+ self.down_sample_layer = None
491
+
492
+ if self.config.num_quantizers != 0:
493
+ self.quantizer = ResidualVectorQuantizer(
494
+ dimension=self.config.d_model,
495
+ n_q=self.config.num_quantizers,
496
+ bins=self.config.codebook_size,
497
+ threshold_ema_dead_code=self.config.threshold_ema_dead_code,
498
+ )
499
+ else:
500
+ self.quantizer = None
501
+
502
+ def get_features(self, input_features, output_length):
503
+ input_features = input_features.to(self.conv1.weight)
504
+ inputs_embeds = nn.functional.gelu(self.conv1(input_features))
505
+ inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
506
+ inputs_embeds = inputs_embeds.permute(0, 2, 1)
507
+ bsz, tgt_len, _ = inputs_embeds.size()
508
+
509
+ hidden_states = inputs_embeds
510
+
511
+ position_ids = (
512
+ get_position_ids(output_length).long().to(input_features.device)
513
+ )
514
+ rope_position_embeddings = self.position_embedding(
515
+ input_features, position_ids
516
+ )
517
+
518
+ attention_mask, unpacking_index = get_sequence_mask(
519
+ hidden_states, output_length
520
+ )
521
+
522
+ hidden_states = torch.masked_select(hidden_states, attention_mask).view(
523
+ torch.sum(output_length), self.config.d_model
524
+ )
525
+
526
+ skip_connect_hidden_states = 0.0
527
+ for idx, encoder_layer in enumerate(self.layers):
528
+ hidden_states = encoder_layer(
529
+ hidden_states,
530
+ output_length,
531
+ rope_position_embeddings=rope_position_embeddings,
532
+ )
533
+ if (self.skip_layer_idx is not None) and idx == self.skip_layer_idx - 1:
534
+ skip_connect_hidden_states = hidden_states.clone()
535
+
536
+ hidden_states += skip_connect_hidden_states
537
+ hidden_states = self.layer_norm(hidden_states)
538
+
539
+ if self.down_sample_layer is not None:
540
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
541
+ bsz, tgt_len, self.config.d_model
542
+ )
543
+ if hidden_states.size(1) % self.config.avg_pooler:
544
+ pad_len = (
545
+ self.config.avg_pooler
546
+ - hidden_states.size(1) % self.config.avg_pooler
547
+ )
548
+ hidden_states = torch.nn.functional.pad(
549
+ hidden_states, (0, 0, 0, pad_len), mode="constant", value=0.0
550
+ )
551
+ tgt_len += pad_len
552
+ tgt_len = tgt_len // self.config.avg_pooler
553
+ hidden_states = self.down_sample_layer(hidden_states.transpose(1, 2))
554
+ output_length = (
555
+ output_length // self.config.avg_pooler
556
+ + (output_length % self.config.avg_pooler != 0).int()
557
+ )
558
+ hidden_states = hidden_states.transpose(1, 2)
559
+ attention_mask, unpacking_index = get_sequence_mask(
560
+ hidden_states, output_length
561
+ )
562
+ hidden_states = torch.masked_select(hidden_states, attention_mask).view(
563
+ torch.sum(output_length), self.config.d_model
564
+ )
565
+ hidden_states = self.down_sample_norm(hidden_states)
566
+
567
+ return (
568
+ hidden_states,
569
+ output_length,
570
+ attention_mask,
571
+ unpacking_index,
572
+ tgt_len,
573
+ bsz,
574
+ )
575
+
576
+ def get_output_length(self, mel_len):
577
+ tgt_len = mel_len + 3 - self.config.kernel_size
578
+ return (tgt_len + 2 - self.config.kernel_size) // self.config.stride_size + 1
579
+
580
+ @torch.no_grad()
581
+ def encode(
582
+ self,
583
+ input_features,
584
+ input_lens=None,
585
+ output_length=None,
586
+ return_codes_only=False,
587
+ n_q=None,
588
+ use_quantizer=True,
589
+ ):
590
+ if output_length is None:
591
+ output_length = self.get_output_length(input_lens)
592
+ input_features = unpack_hidden_states(input_features, input_lens)
593
+ hidden_states, output_length, attention_mask, unpacking_index, tgt_len, bsz = (
594
+ self.get_features(
595
+ input_features=input_features.transpose(1, 2),
596
+ output_length=output_length,
597
+ )
598
+ )
599
+
600
+ dtype = hidden_states.dtype
601
+
602
+ if use_quantizer and self.quantizer is not None:
603
+ self.quantizer.float()
604
+
605
+ codes = self.quantizer.encode(hidden_states.float(), n_q=n_q)
606
+ if return_codes_only:
607
+ return codes, output_length
608
+ hidden_states = self.quantizer.decode(codes)
609
+ hidden_states = hidden_states.to(dtype)
610
+ else:
611
+ codes = None
612
+
613
+ hidden_states_packed = hidden_states.clone()
614
+
615
+ # unpacking
616
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
617
+ bsz, tgt_len, self.config.d_model
618
+ )
619
+ hidden_states = torch.where(attention_mask, hidden_states, 0)
620
+ return hidden_states, hidden_states_packed, output_length, codes
621
+
622
+ @torch.no_grad()
623
+ def decode_vq(self, codes):
624
+ self.quantizer.float()
625
+ hidden_states = self.quantizer.decode(codes)
626
+
627
+ return hidden_states
628
+
629
+
630
+ class CausalConvTranspose1d(nn.Module):
631
+ def __init__(self, in_channels, out_channels, kernel_size, stride):
632
+ super().__init__()
633
+ self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
634
+ self.norm = nn.GroupNorm(1, out_channels)
635
+ self.in_channels = in_channels
636
+ self.out_channels = out_channels
637
+
638
+ def forward(self, hidden_states, input_length, output_dim=None):
639
+ kernel_size = self.conv.kernel_size[0]
640
+ stride = self.conv.stride[0]
641
+ bsz = input_length.shape[0]
642
+
643
+ if output_dim is None:
644
+ output_dim = hidden_states.dim()
645
+ if hidden_states.dim() <= 2: # unpack sequence to 3d
646
+ sequence_mask, unpacking_index = get_sequence_mask(
647
+ hidden_states, input_length
648
+ )
649
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
650
+ bsz, torch.max(input_length), self.in_channels
651
+ )
652
+ hidden_states = torch.where(sequence_mask, hidden_states, 0)
653
+
654
+ hidden_states = hidden_states.transpose(2, 1) # (N, L, C) -> (N, C, L)
655
+ hidden_states = self.conv(hidden_states)
656
+ hidden_states = self.norm(hidden_states)
657
+ hidden_states = hidden_states.transpose(2, 1) # (N, C, L) -> (N, L, C)
658
+
659
+ causal_padding_right = max(0, kernel_size - stride)
660
+ hidden_states = hidden_states[
661
+ :, : hidden_states.shape[1] - causal_padding_right, :
662
+ ]
663
+ output_length = (input_length - 1) * stride + kernel_size - causal_padding_right
664
+ sequence_mask, _ = get_sequence_mask(hidden_states, output_length)
665
+ if output_dim <= 2:
666
+ hidden_states = torch.masked_select(hidden_states, sequence_mask).view(
667
+ -1, self.out_channels
668
+ )
669
+ else:
670
+ hidden_states = torch.where(sequence_mask, hidden_states, 0)
671
+ hidden_states = hidden_states[:, : torch.max(output_length), :]
672
+ return hidden_states, output_length
673
+
674
+
675
+ class AudioDecoder(nn.Module):
676
+ def __init__(self, config: MiMoAudioTokenizerConfig):
677
+ super().__init__()
678
+ self.config = config
679
+ self.max_source_positions = (
680
+ self.config.max_audio_seconds
681
+ * self.config.sampling_rate
682
+ // self.config.hop_length
683
+ )
684
+
685
+ if self.config.avg_pooler != 1:
686
+ self.dconv1 = CausalConvTranspose1d(
687
+ self.config.d_model,
688
+ self.config.d_model,
689
+ self.config.avg_pooler,
690
+ self.config.avg_pooler,
691
+ )
692
+ else:
693
+ self.dconv1 = None
694
+
695
+ self.position_embedding = RotaryEmbedding(
696
+ config.rope_theta,
697
+ config.d_model // config.decoder_attention_heads,
698
+ self.max_source_positions,
699
+ config.rope_type,
700
+ )
701
+
702
+ self.layers = nn.ModuleList(
703
+ [
704
+ TransformerLayer(
705
+ ACT2FN[self.config.activation_function],
706
+ self.config.d_model,
707
+ self.config.decoder_attention_heads,
708
+ self.config.decoder_ffn_dim,
709
+ causal=self.config.decoder_causal,
710
+ ln_type=self.config.ln_type,
711
+ attn_window_size=self.config.decoder_attn_window_size,
712
+ )
713
+ for _ in range(self.config.decoder_layers)
714
+ ]
715
+ )
716
+ self.layer_norm = LAYER_NORM[config.ln_type](self.config.d_model)
717
+ self.dconv2 = CausalConvTranspose1d(
718
+ self.config.d_model,
719
+ self.config.n_mels,
720
+ self.config.decoder_kernel_size,
721
+ self.config.decoder_stride_size,
722
+ )
723
+ self.vocoder = TransformerVocos(config)
724
+
725
+ def forward(
726
+ self,
727
+ audio_embed,
728
+ input_length,
729
+ ):
730
+ assert audio_embed.shape[-1] == self.config.d_model
731
+ audio_embed = audio_embed.to(self.layer_norm.weight)
732
+
733
+ if self.dconv1 is not None:
734
+ audio_embed, output_length = self.dconv1(
735
+ audio_embed, input_length, output_dim=3
736
+ )
737
+ _, tgt_len, _ = audio_embed.size()
738
+ else:
739
+ output_length = input_length
740
+ tgt_len = audio_embed.size(0)
741
+
742
+ hidden_states = audio_embed
743
+
744
+ position_ids = (
745
+ get_position_ids(output_length).long().to(hidden_states.device)
746
+ )
747
+ rope_position_embeddings = self.position_embedding(
748
+ hidden_states, position_ids
749
+ )
750
+
751
+
752
+ # packing hidden states
753
+ attention_mask, _ = get_sequence_mask(hidden_states, output_length)
754
+ hidden_states = torch.masked_select(hidden_states, attention_mask).view(
755
+ torch.sum(output_length), self.config.d_model
756
+ )
757
+
758
+ for idx, encoder_layer in enumerate(self.layers):
759
+ hidden_states = encoder_layer(
760
+ hidden_states,
761
+ output_length,
762
+ rope_position_embeddings=rope_position_embeddings,
763
+ )
764
+
765
+ hidden_states = self.layer_norm(hidden_states)
766
+
767
+ coarse_mel, output_length = self.dconv2(
768
+ hidden_states, output_length, output_dim=3
769
+ )
770
+
771
+ recon_wav, wav_length = self.vocoder(
772
+ x=coarse_mel.transpose(1, 2),
773
+ input_length=output_length,
774
+ )
775
+
776
+ return recon_wav
777
+
778
+
779
+ class MiMoAudioTokenizer(PreTrainedModel):
780
+ config_class = MiMoAudioTokenizerConfig
781
+
782
+ def __init__(self, config: MiMoAudioTokenizerConfig):
783
+ super().__init__(config)
784
+ self.config = config
785
+ self.sampling_rate = config.sampling_rate
786
+ self.encoder = AudioEncoder(config=config)
787
+ self.decoder = AudioDecoder(config=config)
788
+ self.downsample_rate = int(self.config.hop_length * 2 * self.config.avg_pooler)
789
+
790
+ def get_output_length(self, mel_len):
791
+ tgt_len = mel_len + 3 - self.config.kernel_size
792
+ return (tgt_len + 2 - self.config.kernel_size) // self.config.stride_size + 1
793
+
794
+ @torch.no_grad()
795
+ def encode(self, mels, input_lens, use_quantizer=True):
796
+ input_features = mels
797
+ encoder_output_length = self.get_output_length(input_lens)
798
+ hidden_states, hidden_states_packed, encoder_output_length, codes = (
799
+ self.encoder.encode(
800
+ input_features, input_lens=input_lens, use_quantizer=use_quantizer
801
+ )
802
+ )
803
+ return hidden_states, hidden_states_packed, encoder_output_length, codes
804
+
805
+ @torch.no_grad()
806
+ def decode(self, codes):
807
+ hidden_states = self.encoder.decode_vq(codes)
808
+ output = self.decoder(
809
+ hidden_states,
810
+ torch.tensor([hidden_states.size(0)], device=hidden_states.device),
811
+ )
812
+ return output
813
+
814
+ @torch.no_grad()
815
+ def streaming_decode(self, codes_chunks, chunk_input_lengths, history_cache=StreamingCache(), streaming_config=StreamingConfig(), last_chunk=False):
816
+ hidden_states = self.encoder.decode_vq(codes_chunks)
817
+ input_lengths = []
818
+ input_hidden_states = []
819
+ start_idx = 0
820
+ cache_hidden_states = []
821
+ for i, input_length in enumerate(chunk_input_lengths):
822
+ sample_hidden_states = hidden_states[start_idx:start_idx + input_length]
823
+ start_idx += input_length
824
+ if history_cache.hidden_states is not None:
825
+ sample_hidden_states = torch.cat([history_cache.hidden_states[i], sample_hidden_states], dim=0)
826
+ input_length += history_cache.hidden_states[i].size(0)
827
+ input_hidden_states.append(sample_hidden_states)
828
+ cache_hidden_states.append(sample_hidden_states.clone())
829
+ input_lengths.append(input_length)
830
+ input_hidden_states = torch.cat(input_hidden_states, dim=0)
831
+ input_lengths = torch.tensor(input_lengths, device=hidden_states.device)
832
+ output = self.decoder(input_hidden_states, input_lengths)
833
+ return_wavs = []
834
+ frames_per_token = self.config.avg_pooler * self.config.stride_size * self.config.hop_length
835
+ processed_lengths = []
836
+ for i, wav in enumerate(output):
837
+ wav = wav.float().detach().cpu()
838
+ start_idx = history_cache.processed_lengths[i] if history_cache.processed_lengths is not None else 0
839
+ if last_chunk:
840
+ return_wavs.append(wav[:, start_idx * frames_per_token:])
841
+ new_processed_length = input_lengths[i].item()
842
+ elif input_lengths[i].item() <= streaming_config.right_overlap:
843
+ return_wavs.append(None)
844
+ new_processed_length = 0
845
+ else:
846
+ end_idx = (input_lengths[i].item() - streaming_config.right_overlap)
847
+ wav = wav[:, start_idx * frames_per_token: end_idx * frames_per_token]
848
+ return_wavs.append(wav)
849
+ new_processed_length = end_idx
850
+ if input_lengths[i].item() > streaming_config.left_overlap:
851
+ cache_hidden_states[i] = cache_hidden_states[i][-streaming_config.left_overlap:]
852
+ new_processed_length -= (input_lengths[i].item() - streaming_config.left_overlap)
853
+ processed_lengths.append(new_processed_length)
854
+ history_cache.hidden_states = cache_hidden_states
855
+ history_cache.processed_lengths = processed_lengths
856
+
857
+ return return_wavs, history_cache
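For orientation, here is a minimal usage sketch of the tokenizer API above (illustrative only, not part of this commit). It assumes `tokenizer` is a loaded `MiMoAudioTokenizer`, that `mels`/`input_lens` come from the repository's mel feature extractor (not shown in this hunk), and that `StreamingCache`/`StreamingConfig` are the helpers defined earlier in this file.

def roundtrip(tokenizer, mels, input_lens):
    # encode() returns the continuous encoder states plus the RVQ codes (use_quantizer=True is the default).
    hidden, hidden_packed, enc_lens, codes = tokenizer.encode(mels, input_lens)
    # decode() maps the codes back through decode_vq -> AudioDecoder -> vocoder waveform.
    recon_wav = tokenizer.decode(codes)
    return codes, recon_wav

def stream(tokenizer, code_chunks, chunk_lengths):
    # Feed code chunks incrementally; the returned cache carries the left/right overlap
    # context between calls (see streaming_decode above).
    cache, cfg, wav_pieces = StreamingCache(), StreamingConfig(), []
    for i, (codes, lengths) in enumerate(zip(code_chunks, chunk_lengths)):
        pieces, cache = tokenizer.streaming_decode(
            codes, lengths,
            history_cache=cache,
            streaming_config=cfg,
            last_chunk=(i == len(code_chunks) - 1),
        )
        wav_pieces.append(pieces)
    return wav_pieces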
src/mimo_audio_tokenizer/modeling_rope_utils.py ADDED
@@ -0,0 +1,878 @@
1
+ # Copyright 2025 Xiaomi Corporation.
2
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from functools import wraps
18
+ from typing import Optional
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import is_torch_available, logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ if is_torch_available():
28
+ import torch
29
+
30
+
31
+ def dynamic_rope_update(rope_forward):
32
+ """
33
+ Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
34
+ (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
35
+
36
+ Args:
37
+ rope_forward (Callable):
38
+ The forward pass of the RoPE implementation.
39
+
40
+ Returns:
41
+ The decorated forward pass.
42
+ """
43
+
44
+ def longrope_frequency_update(self, position_ids, device):
45
+ """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
46
+ seq_len = torch.max(position_ids) + 1
47
+ if hasattr(self.config, "original_max_position_embeddings"):
48
+ original_max_position_embeddings = (
49
+ self.config.original_max_position_embeddings
50
+ )
51
+ else:
52
+ original_max_position_embeddings = self.config.max_position_embeddings
53
+ if seq_len > original_max_position_embeddings:
54
+ if not hasattr(self, "long_inv_freq"):
55
+ self.long_inv_freq, _ = self.rope_init_fn(
56
+ self.config, device, seq_len=original_max_position_embeddings + 1
57
+ )
58
+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
59
+ else:
60
+ # This .to() is needed if the model has been moved to a device after being initialized (because
61
+ # the buffer is automatically moved, but not the original copy)
62
+ self.original_inv_freq = self.original_inv_freq.to(device)
63
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
64
+
65
+ def dynamic_frequency_update(self, position_ids, device):
66
+ """
67
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
68
+ 1 - growing beyond the cached sequence length (allow scaling)
69
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
70
+ """
71
+ seq_len = torch.max(position_ids) + 1
72
+ if seq_len > self.max_seq_len_cached: # growth
73
+ inv_freq, self.attention_scaling = self.rope_init_fn(
74
+ self.config, device, seq_len=seq_len
75
+ )
76
+ self.register_buffer(
77
+ "inv_freq", inv_freq, persistent=False
78
+ ) # TODO joao: may break with compilation
79
+ self.max_seq_len_cached = seq_len
80
+
81
+ if (
82
+ seq_len < self.original_max_seq_len
83
+ and self.max_seq_len_cached > self.original_max_seq_len
84
+ ): # reset
85
+ # This .to() is needed if the model has been moved to a device after being initialized (because
86
+ # the buffer is automatically moved, but not the original copy)
87
+ self.original_inv_freq = self.original_inv_freq.to(device)
88
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
89
+ self.max_seq_len_cached = self.original_max_seq_len
90
+
91
+ @wraps(rope_forward)
92
+ def wrapper(self, x, position_ids):
93
+ if "dynamic" in self.rope_type:
94
+ dynamic_frequency_update(self, position_ids, device=x.device)
95
+ elif self.rope_type == "longrope":
96
+ longrope_frequency_update(self, position_ids, device=x.device)
97
+ return rope_forward(self, x, position_ids)
98
+
99
+ return wrapper
100
+
101
+
102
+ def _compute_default_rope_parameters(
103
+ config: Optional[PretrainedConfig] = None,
104
+ device: Optional["torch.device"] = None,
105
+ seq_len: Optional[int] = None,
106
+ **rope_kwargs,
107
+ ) -> tuple["torch.Tensor", float]:
108
+ """
109
+ Computes the inverse frequencies according to the original RoPE implementation
110
+ Args:
111
+ config ([`~transformers.PretrainedConfig`]):
112
+ The model configuration.
113
+ device (`torch.device`):
114
+ The device to use for initialization of the inverse frequencies.
115
+ seq_len (`int`, *optional*):
116
+ The current sequence length. Unused for this type of RoPE.
117
+ rope_kwargs (`Dict`, *optional*):
118
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
119
+ Returns:
120
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
121
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
122
+ """
123
+ if config is not None and len(rope_kwargs) > 0:
124
+ raise ValueError(
125
+ "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
126
+ f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
127
+ )
128
+ if len(rope_kwargs) > 0:
129
+ base = rope_kwargs["base"]
130
+ dim = rope_kwargs["dim"]
131
+ elif config is not None:
132
+ base = config.rope_theta
133
+ partial_rotary_factor = (
134
+ config.partial_rotary_factor
135
+ if hasattr(config, "partial_rotary_factor")
136
+ else 1.0
137
+ )
138
+ head_dim = (
139
+ getattr(config, "head_dim", None)
140
+ or config.hidden_size // config.num_attention_heads
141
+ )
142
+ dim = int(head_dim * partial_rotary_factor)
143
+
144
+ attention_factor = 1.0 # Unused in this type of RoPE
145
+
146
+ # Compute the inverse frequencies
147
+ inv_freq = 1.0 / (
148
+ base
149
+ ** (
150
+ torch.arange(0, dim, 2, dtype=torch.int64).to(
151
+ device=device, dtype=torch.float
152
+ )
153
+ / dim
154
+ )
155
+ )
156
+ return inv_freq, attention_factor
157
+
158
+
159
+ def _compute_linear_scaling_rope_parameters(
160
+ config: Optional[PretrainedConfig] = None,
161
+ device: Optional["torch.device"] = None,
162
+ seq_len: Optional[int] = None,
163
+ **rope_kwargs,
164
+ ) -> tuple["torch.Tensor", float]:
165
+ """
166
+ Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
167
+ Args:
168
+ config ([`~transformers.PretrainedConfig`]):
169
+ The model configuration.
170
+ device (`torch.device`):
171
+ The device to use for initialization of the inverse frequencies.
172
+ seq_len (`int`, *optional*):
173
+ The current sequence length. Unused for this type of RoPE.
174
+ rope_kwargs (`Dict`, *optional*):
175
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
176
+ Returns:
177
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
178
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
179
+ """
180
+ if config is not None and len(rope_kwargs) > 0:
181
+ raise ValueError(
182
+ "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
183
+ f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
184
+ )
185
+ if len(rope_kwargs) > 0:
186
+ factor = rope_kwargs["factor"]
187
+ elif config is not None:
188
+ factor = config.rope_scaling["factor"]
189
+
190
+ # Gets the default RoPE parameters
191
+ inv_freq, attention_factor = _compute_default_rope_parameters(
192
+ config, device, seq_len, **rope_kwargs
193
+ )
194
+
195
+ # Then applies linear scaling to the frequencies.
196
+ # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
197
+ # applying scaling to the inverse frequencies is equivalent.
198
+ inv_freq /= factor
199
+ return inv_freq, attention_factor
200
+
201
+
202
+ def _compute_dynamic_ntk_parameters(
203
+ config: Optional[PretrainedConfig] = None,
204
+ device: Optional["torch.device"] = None,
205
+ seq_len: Optional[int] = None,
206
+ **rope_kwargs,
207
+ ) -> tuple["torch.Tensor", float]:
208
+ """
209
+ Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
210
+ Args:
211
+ config ([`~transformers.PretrainedConfig`]):
212
+ The model configuration.
213
+ device (`torch.device`):
214
+ The device to use for initialization of the inverse frequencies.
215
+ seq_len (`int`, *optional*):
216
+ The current sequence length, used to update the dynamic RoPE at inference time.
217
+ rope_kwargs (`Dict`, *optional*):
218
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
219
+ Returns:
220
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
221
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
222
+ """
223
+ # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
224
+ if config is not None and len(rope_kwargs) > 0:
225
+ raise ValueError(
226
+ "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
227
+ f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
228
+ )
229
+ if len(rope_kwargs) > 0:
230
+ base = rope_kwargs["base"]
231
+ dim = rope_kwargs["dim"]
232
+ max_position_embeddings = rope_kwargs["max_position_embeddings"]
233
+ factor = rope_kwargs["factor"]
234
+ elif config is not None:
235
+ base = config.rope_theta
236
+ partial_rotary_factor = (
237
+ config.partial_rotary_factor
238
+ if hasattr(config, "partial_rotary_factor")
239
+ else 1.0
240
+ )
241
+ head_dim = getattr(
242
+ config, "head_dim", config.hidden_size // config.num_attention_heads
243
+ )
244
+ dim = int(head_dim * partial_rotary_factor)
245
+ max_position_embeddings = config.max_position_embeddings
246
+ factor = config.rope_scaling["factor"]
247
+
248
+ attention_factor = 1.0 # Unused in this type of RoPE
249
+
250
+ # seq_len: default to max_position_embeddings, e.g. at init time
251
+ seq_len = (
252
+ seq_len
253
+ if seq_len is not None and seq_len > max_position_embeddings
254
+ else max_position_embeddings
255
+ )
256
+
257
+ # Compute the inverse frequencies
258
+ base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (
259
+ dim / (dim - 2)
260
+ )
261
+ inv_freq = 1.0 / (
262
+ base
263
+ ** (
264
+ torch.arange(0, dim, 2, dtype=torch.int64).to(
265
+ device=device, dtype=torch.float
266
+ )
267
+ / dim
268
+ )
269
+ )
270
+ return inv_freq, attention_factor
271
+
272
+
273
+ def _compute_yarn_parameters(
274
+ config: PretrainedConfig,
275
+ device: "torch.device",
276
+ seq_len: Optional[int] = None,
277
+ **rope_kwargs,
278
+ ) -> tuple["torch.Tensor", float]:
279
+ """
280
+ Computes the inverse frequencies with NTK scaling. Please refer to the
281
+ [original paper](https://huggingface.co/papers/2309.00071)
282
+ Args:
283
+ config ([`~transformers.PretrainedConfig`]):
284
+ The model configuration.
285
+ device (`torch.device`):
286
+ The device to use for initialization of the inverse frequencies.
287
+ seq_len (`int`, *optional*):
288
+ The current sequence length. Unused for this type of RoPE.
289
+ rope_kwargs (`Dict`, *optional*):
290
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
291
+ Returns:
292
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
293
+ post-processing scaling factor applied to the computed cos/sin.
294
+ """
295
+ # No need to keep BC with yarn, unreleased when this new pattern was created.
296
+ if len(rope_kwargs) > 0:
297
+ raise ValueError(
298
+ f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}"
299
+ )
300
+
301
+ base = config.rope_theta
302
+ partial_rotary_factor = (
303
+ config.partial_rotary_factor
304
+ if hasattr(config, "partial_rotary_factor")
305
+ else 1.0
306
+ )
307
+ head_dim = getattr(
308
+ config, "head_dim", config.hidden_size // config.num_attention_heads
309
+ )
310
+ dim = int(head_dim * partial_rotary_factor)
311
+ factor = config.rope_scaling["factor"]
312
+ attention_factor = config.rope_scaling.get("attention_factor")
313
+ mscale = config.rope_scaling.get("mscale")
314
+ mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
315
+
316
+ # NOTE: DeepSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
317
+ # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
318
+ # values to compute the default attention scaling factor, instead of using `factor`.
319
+ if "original_max_position_embeddings" in config.rope_scaling:
320
+ original_max_position_embeddings = config.rope_scaling[
321
+ "original_max_position_embeddings"
322
+ ]
323
+ factor = config.max_position_embeddings / original_max_position_embeddings
324
+ else:
325
+ original_max_position_embeddings = config.max_position_embeddings
326
+
327
+ def get_mscale(scale, mscale=1):
328
+ if scale <= 1:
329
+ return 1.0
330
+ return 0.1 * mscale * math.log(scale) + 1.0
331
+
332
+ # Sets the attention factor as suggested in the paper
333
+ if attention_factor is None:
334
+ if mscale and mscale_all_dim:
335
+ attention_factor = float(
336
+ get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim)
337
+ )
338
+ else:
339
+ attention_factor = get_mscale(factor)
340
+
341
+ # Optional config options
342
+ # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
343
+ beta_fast = config.rope_scaling.get("beta_fast") or 32
344
+ beta_slow = config.rope_scaling.get("beta_slow") or 1
345
+
346
+ # Compute the inverse frequencies
347
+ def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
348
+ """Inverse dimension formula to find the dimension based on the number of rotations"""
349
+ return (
350
+ dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
351
+ ) / (2 * math.log(base))
352
+
353
+ def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
354
+ """Find dimension range bounds based on rotations"""
355
+ low = math.floor(
356
+ find_correction_dim(low_rot, dim, base, max_position_embeddings)
357
+ )
358
+ high = math.ceil(
359
+ find_correction_dim(high_rot, dim, base, max_position_embeddings)
360
+ )
361
+ return max(low, 0), min(high, dim - 1)
362
+
363
+ def linear_ramp_factor(min, max, dim):
364
+ if min == max:
365
+ max += 0.001 # Prevent singularity
366
+
367
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
368
+ ramp_func = torch.clamp(linear_func, 0, 1)
369
+ return ramp_func
370
+
371
+ # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
372
+ # to expand the possible context length. In other words, interpolation = apply scaling factor.
373
+ pos_freqs = base ** (
374
+ torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim
375
+ )
376
+ inv_freq_extrapolation = 1.0 / pos_freqs
377
+ inv_freq_interpolation = 1.0 / (factor * pos_freqs)
378
+
379
+ low, high = find_correction_range(
380
+ beta_fast, beta_slow, dim, base, original_max_position_embeddings
381
+ )
382
+
383
+ # Get n-dimensional rotational scaling corrected for extrapolation
384
+ inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(
385
+ device=device, dtype=torch.float
386
+ )
387
+ inv_freq = (
388
+ inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
389
+ + inv_freq_extrapolation * inv_freq_extrapolation_factor
390
+ )
391
+ return inv_freq, attention_factor
392
+
393
+
394
+ def _compute_longrope_parameters(
395
+ config: PretrainedConfig,
396
+ device: "torch.device",
397
+ seq_len: Optional[int] = None,
398
+ **rope_kwargs,
399
+ ) -> tuple["torch.Tensor", float]:
400
+ """
401
+ Computes the inverse frequencies with LongRoPE scaling. Please refer to the
402
+ [original implementation](https://github.com/microsoft/LongRoPE)
403
+ Args:
404
+ config ([`~transformers.PretrainedConfig`]):
405
+ The model configuration.
406
+ device (`torch.device`):
407
+ The device to use for initialization of the inverse frequencies.
408
+ seq_len (`int`, *optional*):
409
+ The current sequence length.
410
+ rope_kwargs (`Dict`, *optional*):
411
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
412
+ Returns:
413
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
414
+ post-processing scaling factor applied to the computed cos/sin.
415
+ """
416
+ # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
417
+ # No need to keep BC with longrope, unreleased when this new pattern was created.
418
+ if len(rope_kwargs) > 0:
419
+ raise ValueError(
420
+ "Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got "
421
+ f"{rope_kwargs}"
422
+ )
423
+
424
+ base = config.rope_theta
425
+ partial_rotary_factor = (
426
+ config.partial_rotary_factor
427
+ if hasattr(config, "partial_rotary_factor")
428
+ else 1.0
429
+ )
430
+ head_dim = getattr(
431
+ config, "head_dim", config.hidden_size // config.num_attention_heads
432
+ )
433
+ dim = int(head_dim * partial_rotary_factor)
434
+ long_factor = config.rope_scaling["long_factor"]
435
+ short_factor = config.rope_scaling["short_factor"]
436
+ factor = config.rope_scaling.get("factor")
437
+ attention_factor = config.rope_scaling.get("attention_factor")
438
+
439
+ # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
440
+ # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
441
+ # values to compute the default attention scaling factor, instead of using `factor`.
442
+ if hasattr(config, "original_max_position_embeddings"):
443
+ original_max_position_embeddings = config.original_max_position_embeddings
444
+ factor = (
445
+ config.max_position_embeddings / config.original_max_position_embeddings
446
+ )
447
+ else:
448
+ original_max_position_embeddings = config.max_position_embeddings
449
+
450
+ # Sets the attention factor as suggested in the paper
451
+ if attention_factor is None:
452
+ if factor <= 1.0:
453
+ attention_factor = 1.0
454
+ else:
455
+ attention_factor = math.sqrt(
456
+ 1 + math.log(factor) / math.log(original_max_position_embeddings)
457
+ )
458
+
459
+ # Compute the inverse frequencies -- scaled based on the target sequence length
460
+ if seq_len and seq_len > original_max_position_embeddings:
461
+ ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
462
+ else:
463
+ ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
464
+ inv_freq_shape = (
465
+ torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
466
+ )
467
+ inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
468
+
469
+ return inv_freq, attention_factor
470
+
471
+
472
+ def _compute_llama3_parameters(
473
+ config: PretrainedConfig,
474
+ device: "torch.device",
475
+ seq_len: Optional[int] = None,
476
+ **rope_kwargs,
477
+ ) -> tuple["torch.Tensor", float]:
478
+ """
479
+ Computes the inverse frequencies for llama 3.1.
480
+
481
+ Args:
482
+ config ([`~transformers.PretrainedConfig`]):
483
+ The model configuration.
484
+ device (`torch.device`):
485
+ The device to use for initialization of the inverse frequencies.
486
+ seq_len (`int`, *optional*):
487
+ The current sequence length. Unused for this type of RoPE.
488
+ rope_kwargs (`Dict`, *optional*):
489
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
490
+ Returns:
491
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
492
+ post-processing scaling factor applied to the computed cos/sin.
493
+ """
494
+ # Gets the default RoPE parameters
495
+ inv_freq, attention_factor = _compute_default_rope_parameters(
496
+ config, device, seq_len, **rope_kwargs
497
+ )
498
+
499
+ factor = config.rope_scaling["factor"] # `8` in the original implementation
500
+ low_freq_factor = config.rope_scaling[
501
+ "low_freq_factor"
502
+ ] # `1` in the original implementation
503
+ high_freq_factor = config.rope_scaling[
504
+ "high_freq_factor"
505
+ ] # `4` in the original implementation
506
+ old_context_len = config.rope_scaling[
507
+ "original_max_position_embeddings"
508
+ ] # `8192` in the original implementation
509
+
510
+ low_freq_wavelen = old_context_len / low_freq_factor
511
+ high_freq_wavelen = old_context_len / high_freq_factor
512
+
513
+ wavelen = 2 * math.pi / inv_freq
514
+ # wavelen < high_freq_wavelen: do nothing
515
+ # wavelen > low_freq_wavelen: divide by factor
516
+ inv_freq_llama = torch.where(
517
+ wavelen > low_freq_wavelen, inv_freq / factor, inv_freq
518
+ )
519
+ # otherwise: interpolate between the two, using a smooth factor
520
+ smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
521
+ high_freq_factor - low_freq_factor
522
+ )
523
+ smoothed_inv_freq = (
524
+ 1 - smooth_factor
525
+ ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
526
+ is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
527
+ inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
528
+
529
+ return inv_freq_llama, attention_factor
530
+
531
+
532
+ # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
533
+ # from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
534
+ # parameterizations, as long as the callable has the same signature.
535
+ ROPE_INIT_FUNCTIONS = {
536
+ "default": _compute_default_rope_parameters,
537
+ "linear": _compute_linear_scaling_rope_parameters,
538
+ "dynamic": _compute_dynamic_ntk_parameters,
539
+ "yarn": _compute_yarn_parameters,
540
+ "longrope": _compute_longrope_parameters,
541
+ "llama3": _compute_llama3_parameters,
542
+ }
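As the comment above notes, this mapping is extensible. A hypothetical sketch of registering a custom type (the `my_linear` name and its halved frequencies are illustrative, not part of this commit):

# Hypothetical custom rope_type: reuse the default frequencies, then halve them.
def _compute_my_linear_parameters(config=None, device=None, seq_len=None, **rope_kwargs):
    inv_freq, attention_factor = _compute_default_rope_parameters(
        config, device, seq_len, **rope_kwargs
    )
    return inv_freq / 2.0, attention_factor

ROPE_INIT_FUNCTIONS["my_linear"] = _compute_my_linear_parameters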
543
+
544
+
545
+ def _check_received_keys(
546
+ rope_type: str,
547
+ received_keys: set,
548
+ required_keys: set,
549
+ optional_keys: Optional[set] = None,
550
+ ignore_keys: Optional[set] = None,
551
+ ):
552
+ """Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
553
+ # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
554
+ if "type" in received_keys:
555
+ received_keys -= {"type"}
556
+ required_keys.add("rope_type")
557
+
558
+ # Some models need to store model-specific keys, and we don't want to throw warning at them
559
+ if ignore_keys is not None:
560
+ received_keys -= ignore_keys
561
+
562
+ missing_keys = required_keys - received_keys
563
+ if missing_keys:
564
+ raise KeyError(
565
+ f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}"
566
+ )
567
+
568
+ if optional_keys is not None:
569
+ unused_keys = received_keys - required_keys - optional_keys
570
+ else:
571
+ unused_keys = received_keys - required_keys
572
+ if unused_keys:
573
+ logger.warning(
574
+ f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}"
575
+ )
576
+
577
+
578
+ def _validate_default_rope_parameters(
579
+ config: PretrainedConfig, ignore_keys: Optional[set] = None
580
+ ):
581
+ rope_scaling = config.rope_scaling
582
+ rope_type = rope_scaling.get(
583
+ "rope_type", rope_scaling.get("type", None)
584
+ ) # BC: "rope_type" was originally "type"
585
+ required_keys = {"rope_type"}
586
+ received_keys = set(rope_scaling.keys())
587
+ _check_received_keys(
588
+ rope_type, received_keys, required_keys, ignore_keys=ignore_keys
589
+ )
590
+
591
+
592
+ def _validate_linear_scaling_rope_parameters(
593
+ config: PretrainedConfig, ignore_keys: Optional[set] = None
594
+ ):
595
+ rope_scaling = config.rope_scaling
596
+ rope_type = rope_scaling.get(
597
+ "rope_type", rope_scaling.get("type", None)
598
+ ) # BC: "rope_type" was originally "type"
599
+ required_keys = {"rope_type", "factor"}
600
+ received_keys = set(rope_scaling.keys())
601
+ _check_received_keys(
602
+ rope_type, received_keys, required_keys, ignore_keys=ignore_keys
603
+ )
604
+
605
+ factor = rope_scaling["factor"]
606
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
607
+ logger.warning(
608
+ f"`rope_scaling`'s factor field must be a float >= 1, got {factor}"
609
+ )
610
+
611
+
612
+ def _validate_dynamic_scaling_rope_parameters(
613
+ config: PretrainedConfig, ignore_keys: Optional[set] = None
614
+ ):
615
+ rope_scaling = config.rope_scaling
616
+ rope_type = rope_scaling.get(
617
+ "rope_type", rope_scaling.get("type", None)
618
+ ) # BC: "rope_type" was originally "type"
619
+ required_keys = {"rope_type", "factor"}
620
+ # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
621
+ optional_keys = {"original_max_position_embeddings"}
622
+ received_keys = set(rope_scaling.keys())
623
+ _check_received_keys(
624
+ rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys
625
+ )
626
+
627
+ factor = rope_scaling["factor"]
628
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
629
+ logger.warning(
630
+ f"`rope_scaling`'s factor field must be a float >= 1, got {factor}"
631
+ )
632
+
633
+
634
+ def _validate_yarn_parameters(
635
+ config: PretrainedConfig, ignore_keys: Optional[set] = None
636
+ ):
637
+ rope_scaling = config.rope_scaling
638
+ rope_type = rope_scaling.get(
639
+ "rope_type", rope_scaling.get("type", None)
640
+ ) # BC: "rope_type" was originally "type"
641
+ required_keys = {"rope_type", "factor"}
642
+ optional_keys = {
643
+ "attention_factor",
644
+ "beta_fast",
645
+ "beta_slow",
646
+ "original_max_position_embeddings",
647
+ "mscale",
648
+ "mscale_all_dim",
649
+ }
650
+ received_keys = set(rope_scaling.keys())
651
+ _check_received_keys(
652
+ rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys
653
+ )
654
+
655
+ factor = rope_scaling["factor"]
656
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
657
+ logger.warning(
658
+ f"`rope_scaling`'s factor field must be a float >= 1, got {factor}"
659
+ )
660
+
661
+ attention_factor = rope_scaling.get("attention_factor")
662
+ if attention_factor is not None and (
663
+ not isinstance(attention_factor, float) or attention_factor < 0
664
+ ):
665
+ logger.warning(
666
+ f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
667
+ )
668
+ beta_fast = rope_scaling.get("beta_fast")
669
+ if beta_fast is not None and not isinstance(beta_fast, float):
670
+ logger.warning(
671
+ f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}"
672
+ )
673
+ beta_slow = rope_scaling.get("beta_slow")
674
+ if beta_slow is not None and not isinstance(beta_slow, float):
675
+ logger.warning(
676
+ f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}"
677
+ )
678
+
679
+ if (beta_fast or 32) < (beta_slow or 1):
680
+ logger.warning(
681
+ f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
682
+ f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
683
+ )
684
+
685
+
686
+ def _validate_longrope_parameters(
687
+ config: PretrainedConfig, ignore_keys: Optional[set] = None
688
+ ):
689
+ rope_scaling = config.rope_scaling
690
+ rope_type = rope_scaling.get(
691
+ "rope_type", rope_scaling.get("type", None)
692
+ ) # BC: "rope_type" was originally "type"
693
+ required_keys = {"rope_type", "short_factor", "long_factor"}
694
+ # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
695
+ optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
696
+ received_keys = set(rope_scaling.keys())
697
+ _check_received_keys(
698
+ rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys
699
+ )
700
+
701
+ partial_rotary_factor = (
702
+ config.partial_rotary_factor
703
+ if hasattr(config, "partial_rotary_factor")
704
+ else 1.0
705
+ )
706
+ head_dim = getattr(
707
+ config, "head_dim", config.hidden_size // config.num_attention_heads
708
+ )
709
+ dim = int(head_dim * partial_rotary_factor)
710
+
711
+ short_factor = rope_scaling.get("short_factor")
712
+ if not isinstance(short_factor, list) or not all(
713
+ isinstance(x, (int, float)) for x in short_factor
714
+ ):
715
+ logger.warning(
716
+ f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}"
717
+ )
718
+ if not len(short_factor) == dim // 2:
719
+ logger.warning(
720
+ f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}"
721
+ )
722
+
723
+ long_factor = rope_scaling.get("long_factor")
724
+ if not isinstance(long_factor, list) or not all(
725
+ isinstance(x, (int, float)) for x in long_factor
726
+ ):
727
+ logger.warning(
728
+ f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}"
729
+ )
730
+ if not len(long_factor) == dim // 2:
731
+ logger.warning(
732
+ f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}"
733
+ )
734
+
735
+ # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
736
+ # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
737
+ # unique to longrope (= undesirable)
738
+ if hasattr(config, "original_max_position_embeddings"):
739
+ logger.warning_once(
740
+ "This model has set a `original_max_position_embeddings` field, to be used together with "
741
+ "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`"
742
+ " with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
743
+ "as it is compatible with most model architectures."
744
+ )
745
+ else:
746
+ factor = rope_scaling.get("factor")
747
+ if factor is None:
748
+ logger.warning("Missing required keys in `rope_scaling`: 'factor'")
749
+ elif not isinstance(factor, float) or factor < 1.0:
750
+ logger.warning(
751
+ f"`rope_scaling`'s factor field must be a float >= 1, got {factor}"
752
+ )
753
+
754
+ attention_factor = rope_scaling.get("attention_factor")
755
+ if attention_factor is not None:
756
+ if not isinstance(attention_factor, float) or attention_factor < 0.0:
757
+ logger.warning(
758
+ f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
759
+ )
760
+
761
+
762
+ def _validate_llama3_parameters(
763
+ config: PretrainedConfig, ignore_keys: Optional[set] = None
764
+ ):
765
+ rope_scaling = config.rope_scaling
766
+ rope_type = rope_scaling.get(
767
+ "rope_type", rope_scaling.get("type", None)
768
+ ) # BC: "rope_type" was originally "type"
769
+ required_keys = {
770
+ "rope_type",
771
+ "factor",
772
+ "original_max_position_embeddings",
773
+ "low_freq_factor",
774
+ "high_freq_factor",
775
+ }
776
+ received_keys = set(rope_scaling.keys())
777
+ _check_received_keys(
778
+ rope_type, received_keys, required_keys, ignore_keys=ignore_keys
779
+ )
780
+
781
+ factor = rope_scaling["factor"]
782
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
783
+ logger.warning(
784
+ f"`rope_scaling`'s factor field must be a float >= 1, got {factor}"
785
+ )
786
+
787
+ low_freq_factor = rope_scaling["low_freq_factor"]
788
+ high_freq_factor = rope_scaling["high_freq_factor"]
789
+ if low_freq_factor is None or not isinstance(low_freq_factor, float):
790
+ logger.warning(
791
+ f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}"
792
+ )
793
+ if high_freq_factor is None or not isinstance(high_freq_factor, float):
794
+ logger.warning(
795
+ f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}"
796
+ )
797
+ if high_freq_factor <= low_freq_factor:
798
+ logger.warning(
799
+ "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
800
+ f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
801
+ )
802
+
803
+ original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
804
+ if original_max_position_embeddings is None or not isinstance(
805
+ original_max_position_embeddings, int
806
+ ):
807
+ logger.warning(
808
+ "`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
809
+ f"{original_max_position_embeddings}"
810
+ )
811
+ if original_max_position_embeddings >= config.max_position_embeddings:
812
+ logger.warning(
813
+ "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
814
+ f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
815
+ )
816
+
817
+
818
+ # Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
819
+ ROPE_VALIDATION_FUNCTIONS = {
820
+ "default": _validate_default_rope_parameters,
821
+ "linear": _validate_linear_scaling_rope_parameters,
822
+ "dynamic": _validate_dynamic_scaling_rope_parameters,
823
+ "yarn": _validate_yarn_parameters,
824
+ "longrope": _validate_longrope_parameters,
825
+ "llama3": _validate_llama3_parameters,
826
+ }
827
+
828
+
829
+ def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
830
+ """
831
+ Validate the RoPE config arguments, given a `PretrainedConfig` object
832
+ """
833
+ rope_scaling = getattr(
834
+ config, "rope_scaling", None
835
+ ) # not a default parameter in `PretrainedConfig`
836
+ if rope_scaling is None:
837
+ return
838
+
839
+ # BC: "rope_type" was originally "type"
840
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
841
+ validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
842
+ if validation_fn is not None:
843
+ validation_fn(config, ignore_keys=ignore_keys)
844
+ else:
845
+ logger.warning(
846
+ f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
847
+ )
848
+
849
+
850
+ def rotate_half(x):
851
+ """Rotates half the hidden dims of the input."""
852
+ x1 = x[..., : x.shape[-1] // 2]
853
+ x2 = x[..., x.shape[-1] // 2 :]
854
+ return torch.cat((-x2, x1), dim=-1)
855
+
856
+ def apply_rotary_pos_emb(x, cos, sin, position_ids=None, unsqueeze_dim=1):
857
+ """Applies Rotary Position Embedding to the input tensor.
858
+
859
+ Args:
860
+ x (`torch.Tensor`): The input tensor.
861
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
862
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
863
+ position_ids (`torch.Tensor`, *optional*):
864
+ Deprecated and unused.
865
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
866
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
867
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
868
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
869
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
870
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
871
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
872
+ Returns:
873
+ `torch.Tensor`: the input tensor rotated using the Rotary Position Embedding.
874
+ """
875
+ cos = cos.unsqueeze(unsqueeze_dim)
876
+ sin = sin.unsqueeze(unsqueeze_dim)
877
+ x_embed = (x * cos) + (rotate_half(x) * sin)
878
+ return x_embed
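To make the shape conventions concrete, here is a small self-contained sketch (arbitrary toy sizes, assumed import path) that builds inverse frequencies via the keyword path of `_compute_default_rope_parameters`, turns them into cos/sin, and rotates a query-shaped tensor with `apply_rotary_pos_emb`:

import torch

# Assumed import path (illustrative):
# from mimo_audio_tokenizer.modeling_rope_utils import (
#     _compute_default_rope_parameters, apply_rotary_pos_emb,
# )

batch, heads, seq_len, head_dim = 2, 4, 16, 64

# Keyword path (config=None) of _compute_default_rope_parameters needs `base` and `dim`.
inv_freq, attention_scaling = _compute_default_rope_parameters(base=10000.0, dim=head_dim)

position_ids = torch.arange(seq_len)[None, :].float()        # (1, seq_len)
freqs = position_ids[..., None] * inv_freq                   # (1, seq_len, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)                      # (1, seq_len, head_dim)
cos, sin = emb.cos() * attention_scaling, emb.sin() * attention_scaling

q = torch.randn(batch, heads, seq_len, head_dim)
# unsqueeze_dim=1 lets (1, seq_len, head_dim) broadcast over the heads axis of q.
q_rot = apply_rotary_pos_emb(q, cos, sin, unsqueeze_dim=1)
print(q_rot.shape)                                           # torch.Size([2, 4, 16, 64])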
src/mimo_audio_tokenizer/quantization.py ADDED
@@ -0,0 +1,480 @@
1
+ # Copyright 2025 Xiaomi Corporation.
2
+ import typing as tp
3
+
4
+ from einops import rearrange, repeat
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+
9
+ import torch.distributed as dist
10
+
11
+
12
+ def rank():
13
+ if dist.is_initialized():
14
+ return dist.get_rank()
15
+ else:
16
+ return 0
17
+
18
+
19
+ def world_size():
20
+ if dist.is_initialized():
21
+ return dist.get_world_size()
22
+ else:
23
+ return 1
24
+
25
+
26
+ def default(val: tp.Any, d: tp.Any) -> tp.Any:
27
+ return val if val is not None else d
28
+
29
+
30
+ def ema_inplace(moving_avg, new, decay: float):
31
+ if dist.is_initialized():
32
+ dist.all_reduce(new, op=dist.ReduceOp.SUM)
33
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
34
+
35
+
36
+ def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
37
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
38
+
39
+
40
+ def uniform_init(*shape: int):
41
+ t = torch.empty(shape)
42
+ nn.init.kaiming_uniform_(t)
43
+ return t
44
+
45
+
46
+ def sample_vectors(samples, num: int):
47
+ num_samples, device = samples.shape[0], samples.device
48
+
49
+ if num_samples >= num:
50
+ indices = torch.randperm(num_samples, device=device)[:num]
51
+ else:
52
+ indices = torch.randint(0, num_samples, (num,), device=device)
53
+
54
+ selected_samples = samples[indices]
55
+
56
+ if dist.is_initialized():
57
+
58
+ dist.broadcast(selected_samples, src=0)
59
+
60
+ return selected_samples
61
+
62
+
63
+ def kmeans(samples, num_clusters: int, num_iters: int = 10):
64
+ dim, dtype = samples.shape[-1], samples.dtype
65
+
66
+ means = sample_vectors(samples, num_clusters)
67
+
68
+ for _ in range(num_iters):
69
+ dists = -(
70
+ samples.pow(2).sum(1, keepdim=True)
71
+ - 2 * samples @ means.t()
72
+ + means.t().pow(2).sum(0, keepdim=True)
73
+ )
74
+
75
+ buckets = dists.max(dim=-1).indices
76
+ bins = torch.bincount(buckets, minlength=num_clusters)
77
+
78
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
79
+ new_means = new_means.scatter_add_(
80
+ 0, repeat(buckets, "n -> n d", d=dim), samples
81
+ )
82
+
83
+ if dist.is_initialized():
84
+ dist.all_reduce(bins, op=dist.ReduceOp.SUM)
85
+ dist.all_reduce(new_means, op=dist.ReduceOp.SUM)
86
+
87
+ zero_mask = bins == 0
88
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
89
+
90
+ new_means = new_means / bins_min_clamped[..., None]
91
+
92
+ means = torch.where(zero_mask[..., None], means, new_means)
93
+
94
+ return means, bins
95
+
96
+
97
+ class EuclideanCodebook(nn.Module):
98
+ """Codebook with Euclidean distance.
99
+ Args:
100
+ dim (int): Dimension.
101
+ codebook_size (int): Codebook size.
102
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
103
+ If set to true, run the k-means algorithm on the first training batch and use
104
+ the learned centroids as initialization.
105
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
106
+ decay (float): Decay for exponential moving average over the codebooks.
107
+ epsilon (float): Epsilon value for numerical stability.
108
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
109
+ that have an exponential moving average cluster size less than the specified threshold with
110
+ randomly selected vector from the current batch.
111
+ """
112
+
113
+ def __init__(
114
+ self,
115
+ dim: int,
116
+ codebook_size: int,
117
+ kmeans_init: bool = False,
118
+ kmeans_iters: int = 10,
119
+ decay: float = 0.99,
120
+ epsilon: float = 1e-5,
121
+ threshold_ema_dead_code: int = 2,
122
+ ):
123
+ super().__init__()
124
+ self.decay = decay
125
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
126
+ uniform_init if not kmeans_init else torch.zeros
127
+ )
128
+ embed = init_fn(codebook_size, dim)
129
+
130
+ self.codebook_size = codebook_size
131
+
132
+ self.kmeans_iters = kmeans_iters
133
+ self.epsilon = epsilon
134
+ self.threshold_ema_dead_code = threshold_ema_dead_code
135
+
136
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
137
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
138
+ self.register_buffer("embed", embed)
139
+ self.register_buffer("embed_avg", embed.clone())
140
+
141
+ @torch.jit.ignore
142
+ def init_embed_(self, data):
143
+ if self.inited:
144
+ return
145
+
146
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
147
+ self.embed.data.copy_(embed)
148
+ self.embed_avg.data.copy_(embed.clone())
149
+ self.cluster_size.data.copy_(cluster_size)
150
+ self.inited.data.copy_(torch.Tensor([True]))
151
+
152
+ def replace_(self, samples, mask):
153
+ # modified_codebook = torch.where(
154
+ # mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
155
+ # )
156
+ replace_num = mask.sum()
157
+ modified_codebook = self.embed.clone()
158
+ modified_codebook[mask] = sample_vectors(samples, replace_num)
159
+ self.embed.data.copy_(modified_codebook)
160
+
161
+ def expire_codes_(self, batch_samples):
162
+ if self.threshold_ema_dead_code == 0:
163
+ return
164
+
165
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
166
+ if not torch.any(expired_codes):
167
+ return
168
+
169
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
170
+ self.replace_(batch_samples, mask=expired_codes)
171
+
172
+ def preprocess(self, x):
173
+ x = rearrange(x, "... d -> (...) d")
174
+ return x
175
+
176
+ def quantize(self, x):
177
+ embed = self.embed.t()
178
+ dist = -(
179
+ x.pow(2).sum(1, keepdim=True)
180
+ - 2 * x @ embed
181
+ + embed.pow(2).sum(0, keepdim=True)
182
+ )
183
+ embed_ind = dist.max(dim=-1).indices
184
+ return embed_ind
185
+
186
+ def postprocess_emb(self, embed_ind, shape):
187
+ return embed_ind.view(*shape[:-1])
188
+
189
+ def dequantize(self, embed_ind):
190
+ quantize = F.embedding(embed_ind, self.embed)
191
+ return quantize
192
+
193
+ def encode(self, x):
194
+ shape = x.shape
195
+ # pre-process
196
+ x = self.preprocess(x)
197
+ # quantize
198
+ embed_ind = self.quantize(x)
199
+ # post-process
200
+ embed_ind = self.postprocess_emb(embed_ind, shape)
201
+ return embed_ind
202
+
203
+ def decode(self, embed_ind):
204
+ quantize = self.dequantize(embed_ind)
205
+ return quantize
206
+
207
+ def forward(self, x):
208
+ shape, dtype = x.shape, x.dtype
209
+ x = self.preprocess(x)
210
+
211
+ self.init_embed_(x)
212
+
213
+ embed_ind = self.quantize(x)
214
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
215
+ embed_ind = self.postprocess_emb(embed_ind, shape)
216
+ quantize = self.dequantize(embed_ind)
217
+
218
+ if self.training:
219
+ # We do the expiry of code at that point as buffers are in sync
220
+ # and all the workers will take the same decision.
221
+ self.expire_codes_(x)
222
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
223
+ embed_sum = x.t() @ embed_onehot
224
+ ema_inplace(self.embed_avg, embed_sum.t().contiguous(), self.decay)
225
+ cluster_size = (
226
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
227
+ * self.cluster_size.sum()
228
+ )
229
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
230
+ self.embed.data.copy_(embed_normalized)
231
+
232
+ return quantize, embed_ind
233
+
234
+
235
+ class VectorQuantization(nn.Module):
236
+ """Vector quantization implementation.
237
+ Currently supports only euclidean distance.
238
+ Args:
239
+ dim (int): Dimension
240
+ codebook_size (int): Codebook size
241
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
242
+ decay (float): Decay for exponential moving average over the codebooks.
243
+ epsilon (float): Epsilon value for numerical stability.
244
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
245
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
246
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
247
+ that have an exponential moving average cluster size less than the specified threshold with
248
+ randomly selected vector from the current batch.
249
+ commitment_weight (float): Weight for commitment loss.
250
+ """
251
+
252
+ def __init__(
253
+ self,
254
+ dim: int,
255
+ codebook_size: int,
256
+ codebook_dim: tp.Optional[int] = None,
257
+ decay: float = 0.99,
258
+ epsilon: float = 1e-5,
259
+ kmeans_init: bool = True,
260
+ kmeans_iters: int = 50,
261
+ threshold_ema_dead_code: int = 2,
262
+ commitment_weight: float = 1.0,
263
+ ):
264
+ super().__init__()
265
+ _codebook_dim: int = default(codebook_dim, dim)
266
+
267
+ requires_projection = _codebook_dim != dim
268
+ self.project_in = (
269
+ nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
270
+ )
271
+ self.project_out = (
272
+ nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
273
+ )
274
+
275
+ self.epsilon = epsilon
276
+ self.commitment_weight = commitment_weight
277
+
278
+ self._codebook = EuclideanCodebook(
279
+ dim=_codebook_dim,
280
+ codebook_size=codebook_size,
281
+ kmeans_init=kmeans_init,
282
+ kmeans_iters=kmeans_iters,
283
+ decay=decay,
284
+ epsilon=epsilon,
285
+ threshold_ema_dead_code=threshold_ema_dead_code,
286
+ )
287
+ self.codebook_size = codebook_size
288
+
289
+ @property
290
+ def codebook(self):
291
+ return self._codebook.embed
292
+
293
+ def encode(self, x):
294
+ # x = rearrange(x, "b d n -> b n d")
295
+ x = self.project_in(x)
296
+ embed_in = self._codebook.encode(x)
297
+ return embed_in
298
+
299
+ def decode(self, embed_ind):
300
+ quantize = self._codebook.decode(embed_ind)
301
+ quantize = self.project_out(quantize)
302
+ # quantize = rearrange(quantize, "b n d -> b d n")
303
+ return quantize
304
+
305
+ def forward(self, x):
306
+ device = x.device
307
+ x = self.project_in(x)
308
+
309
+ quantize, embed_ind = self._codebook(x)
310
+
311
+ if self.training:
312
+ quantize = x + (quantize - x).detach()
313
+
314
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
315
+
316
+ if self.training:
317
+ if self.commitment_weight > 0:
318
+ commit_loss = F.mse_loss(quantize.detach(), x)
319
+ loss = loss + commit_loss * self.commitment_weight
320
+
321
+ quantize = self.project_out(quantize)
322
+ # quantize = rearrange(quantize, "b n d -> b d n")
323
+ return quantize, embed_ind, loss
324
+
325
+
326
+ class ResidualVectorQuantization(nn.Module):
327
+ """Residual vector quantization implementation.
328
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
329
+ """
330
+
331
+ def __init__(self, *, num_quantizers, codebook_size, **kwargs):
332
+ super().__init__()
333
+ if isinstance(codebook_size, int):
334
+ codebook_size = [codebook_size] * num_quantizers
335
+ elif len(codebook_size) < num_quantizers:
336
+ codebook_size += [codebook_size[-1]] * (num_quantizers - len(codebook_size))
337
+ self.layers = nn.ModuleList(
338
+ [
339
+ VectorQuantization(codebook_size=codebook_size[i], **kwargs)
340
+ for i in range(num_quantizers)
341
+ ]
342
+ )
343
+
344
+ def forward(
345
+ self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
346
+ ):
347
+ quantized_out = 0.0
348
+ residual = x
349
+
350
+ all_losses = []
351
+ all_indices = []
352
+ out_quantized = []
353
+
354
+ n_q = n_q or len(self.layers)
355
+
356
+ for i, layer in enumerate(self.layers[:n_q]):
357
+ quantized, indices, loss = layer(residual)
358
+ residual = residual - quantized
359
+ quantized_out = quantized_out + quantized
360
+
361
+ all_indices.append(indices)
362
+ all_losses.append(loss)
363
+ if layers and i in layers:
364
+ out_quantized.append(quantized_out)
365
+
366
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
367
+ return quantized_out, out_indices, out_losses, out_quantized
368
+
369
+ def encode(
370
+ self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
371
+ ) -> torch.Tensor:
372
+ residual = x
373
+ all_indices = []
374
+ n_q = len(self.layers) if n_q is None else n_q
375
+ st = 0 if st is None else st
376
+ for layer in self.layers[st:n_q]:
377
+ indices = layer.encode(residual)
378
+ quantized = layer.decode(indices)
379
+ residual = residual - quantized
380
+ all_indices.append(indices)
381
+ out_indices = torch.stack(all_indices)
382
+ return out_indices
383
+
384
+ def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
385
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
386
+ for i, indices in enumerate(q_indices):
387
+ layer = self.layers[st + i]
388
+ quantized = layer.decode(indices)
389
+ quantized_out = quantized_out + quantized
390
+ return quantized_out
391
+
392
+
393
+ class ResidualVectorQuantizer(nn.Module):
394
+ """Residual Vector Quantizer.
395
+ Args:
396
+ dimension (int): Dimension of the codebooks.
397
+ n_q (int): Number of residual vector quantizers used.
398
+ bins (int): Codebook size.
399
+ decay (float): Decay for exponential moving average over the codebooks.
400
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
401
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
402
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
403
+ that have an exponential moving average cluster size less than the specified threshold with
404
+ randomly selected vector from the current batch.
405
+ """
406
+
407
+ def __init__(
408
+ self,
409
+ dimension: int = 256,
410
+ n_q: int = 8,
411
+ bins: int | list = 1024,
412
+ decay: float = 0.99,
413
+ kmeans_init: bool = True,
414
+ kmeans_iters: int = 50,
415
+ threshold_ema_dead_code: int = 2,
416
+ ):
417
+ super().__init__()
418
+ self.n_q = n_q
419
+ self.dimension = dimension
420
+ self.bins = bins
421
+ self.decay = decay
422
+ self.kmeans_init = kmeans_init
423
+ self.kmeans_iters = kmeans_iters
424
+ self.threshold_ema_dead_code = threshold_ema_dead_code
425
+ self.vq = ResidualVectorQuantization(
426
+ dim=self.dimension,
427
+ codebook_size=self.bins,
428
+ num_quantizers=self.n_q,
429
+ decay=self.decay,
430
+ kmeans_init=self.kmeans_init,
431
+ kmeans_iters=self.kmeans_iters,
432
+ threshold_ema_dead_code=self.threshold_ema_dead_code,
433
+ )
434
+
435
+ def forward(
436
+ self,
437
+ x: torch.Tensor,
438
+ n_q: tp.Optional[int] = None,
439
+ layers: tp.Optional[list] = None,
440
+ ):
441
+ """Residual vector quantization on the given input tensor.
442
+ Args:
443
+ x (torch.Tensor): Input tensor.
444
+ n_q (int): Number of quantizer used to quantize. Default: All quantizers.
445
+ layers (list): Layer that need to return quantized. Defalt: None.
446
+ Returns:
447
+ QuantizedResult:
448
+ The quantized (or approximately quantized) representation with
449
+ the associated numbert quantizers and layer quantized required to return.
450
+ """
451
+ n_q = n_q if n_q else self.n_q
452
+ quantized, codes, commit_loss, quantized_list = self.vq(
453
+ x, n_q=n_q, layers=layers
454
+ )
455
+ return quantized, codes, torch.mean(commit_loss), quantized_list
456
+
457
+ def encode(
458
+ self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
459
+ ) -> torch.Tensor:
460
+ """Encode a given input tensor with the specified sample rate at the given bandwidth.
461
+ The RVQ encode method sets the appropriate number of quantizer to use
462
+ and returns indices for each quantizer.
463
+ Args:
464
+ x (torch.Tensor): Input tensor.
465
+ n_q (int): Number of quantizer used to quantize. Default: All quantizers.
466
+ st (int): Start to encode input from which layers. Default: 0.
467
+ """
468
+ n_q = n_q if n_q else self.n_q
469
+ st = st or 0
470
+ codes = self.vq.encode(x, n_q=n_q, st=st)
471
+ return codes
472
+
473
+ def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
474
+ """Decode the given codes to the quantized representation.
475
+ Args:
476
+ codes (torch.Tensor): Input indices for each quantizer.
477
+ st (int): Start to decode input codes from which layers. Default: 0.
478
+ """
479
+ quantized = self.vq.decode(codes, st=st)
480
+ return quantized
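Finally, a small end-to-end sketch of the residual VQ defined above (toy sizes, assumed import path; `kmeans_init=False` is chosen so the randomly initialized codebooks are usable without a warm-up training batch):

import torch

# Assumed import path (illustrative):
# from mimo_audio_tokenizer.quantization import ResidualVectorQuantizer

torch.manual_seed(0)
rvq = ResidualVectorQuantizer(dimension=64, n_q=2, bins=256, kmeans_init=False)
rvq.eval()

x = torch.randn(10, 64)          # (frames, dim); the codebook flattens leading dims
codes = rvq.encode(x)            # (n_q, frames) integer codebook indices
recon = rvq.decode(codes)        # (frames, dim): sum of per-stage codebook vectors
print(codes.shape, recon.shape)  # torch.Size([2, 10]) torch.Size([10, 64])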