jade committed
Commit bf277fe · 1 Parent(s): def28d3

Base files

.github/workflows/ci.yaml ADDED
@@ -0,0 +1,22 @@
+ name: Continuous Integration
+
+ on:
+   pull_request:
+     branches:
+       - main
+
+ jobs:
+   lint_and_format:
+     runs-on: ubuntu-latest
+     name: Lint and Format
+     steps:
+       - uses: actions/checkout@v4
+       - uses: astral-sh/ruff-action@v3
+         with:
+           version: latest
+
+       - name: Check Lint using Ruff
+         run: ruff check
+
+       - name: Check Format using Ruff
+         run: ruff format --check --diff
.gitignore ADDED
@@ -0,0 +1,22 @@
+ # Python-generated files
+ __pycache__/
+ *.py[oc]
+ build/
+ dist/
+ wheels/
+ *.egg-info
+
+ # Virtual environments
+ .venv
+ .idea/
+ .gradio
+
+ **/*.pth
+ **/*.safetensors
+ **/*.mp3
+ !example_prompt.mp3
+ **/*.txt
+
+ .ruff_cache
+ .ipynb_checkpoints
+ config.json
.python-version ADDED
@@ -0,0 +1 @@
+ 3.10
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2025 Nari Labs
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
app.py ADDED
@@ -0,0 +1,444 @@
1
+ import argparse
2
+ import contextlib
3
+ import io
4
+ import random
5
+ import tempfile
6
+ import time
7
+ from pathlib import Path
8
+ from typing import Optional, Tuple
9
+
10
+ import gradio as gr
11
+ import numpy as np
12
+ import soundfile as sf
13
+ import torch
14
+
15
+ from dia.model import Dia
16
+
17
+
18
+ # --- Global Setup ---
19
+ parser = argparse.ArgumentParser(description="Gradio interface for Nari TTS")
20
+ parser.add_argument("--device", type=str, default=None, help="Force device (e.g., 'cuda', 'mps', 'cpu')")
21
+ parser.add_argument("--share", action="store_true", help="Enable Gradio sharing")
22
+
23
+ args = parser.parse_args()
24
+
25
+
26
+ # Determine device
27
+ if args.device:
28
+ device = torch.device(args.device)
29
+ elif torch.cuda.is_available():
30
+ device = torch.device("cuda")
31
+ # Simplified MPS check for broader compatibility
32
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
33
+ # Basic check is usually sufficient, detailed check can be problematic
34
+ device = torch.device("mps")
35
+ else:
36
+ device = torch.device("cpu")
37
+
38
+ print(f"Using device: {device}")
39
+
40
+ # Load Nari model and config
41
+ print("Loading Nari model...")
42
+ try:
43
+ dtype_map = {
44
+ "cpu": "float32",
45
+ "mps": "float32", # Apple M series – better with float32
46
+ "cuda": "float16", # NVIDIA – better with float16
47
+ }
48
+
49
+ dtype = dtype_map.get(device.type, "float16")
50
+ print(f"Using device: {device}, attempting to load model with {dtype}")
51
+ model = Dia.from_pretrained("nari-labs/Dia-1.6B-0626", compute_dtype=dtype, device=device)
52
+ except Exception as e:
53
+ print(f"Error loading Nari model: {e}")
54
+ raise
55
+
56
+
57
+ def set_seed(seed: int):
58
+ """Sets the random seed for reproducibility."""
59
+ random.seed(seed)
60
+ np.random.seed(seed)
61
+ torch.manual_seed(seed)
62
+ if torch.cuda.is_available():
63
+ torch.cuda.manual_seed(seed)
64
+ torch.cuda.manual_seed_all(seed)
65
+ torch.backends.cudnn.deterministic = True
66
+ torch.backends.cudnn.benchmark = False
67
+
68
+
69
+ def run_inference(
70
+ text_input: str,
71
+ audio_prompt_text_input: str,
72
+ audio_prompt_input: Optional[Tuple[int, np.ndarray]],
73
+ max_new_tokens: int,
74
+ cfg_scale: float,
75
+ temperature: float,
76
+ top_p: float,
77
+ cfg_filter_top_k: int,
78
+ speed_factor: float,
79
+ seed: Optional[int] = None,
80
+ ):
81
+ """
82
+ Runs Nari inference using the globally loaded model and provided inputs.
83
+ Uses temporary files for text and audio prompt compatibility with inference.generate.
84
+ """
85
+ global model, device # Access global model, config, device
86
+ console_output_buffer = io.StringIO()
87
+
88
+ with contextlib.redirect_stdout(console_output_buffer):
89
+ # Prepend transcript text if audio_prompt provided
90
+ if audio_prompt_input and audio_prompt_text_input and not audio_prompt_text_input.isspace():
91
+ text_input = audio_prompt_text_input + "\n" + text_input
92
+ text_input = text_input.strip()
93
+
94
+ if audio_prompt_input and (not audio_prompt_text_input or audio_prompt_text_input.isspace()):
95
+ raise gr.Error("Audio Prompt Text input cannot be empty.")
96
+
97
+ if not text_input or text_input.isspace():
98
+ raise gr.Error("Text input cannot be empty.")
99
+
100
+ # Preprocess Audio
101
+ temp_txt_file_path = None
102
+ temp_audio_prompt_path = None
103
+ output_audio = (44100, np.zeros(1, dtype=np.float32))
104
+
105
+ try:
106
+ prompt_path_for_generate = None
107
+ if audio_prompt_input is not None:
108
+ sr, audio_data = audio_prompt_input
109
+ # Check if audio_data is valid
110
+ if audio_data is None or audio_data.size == 0 or audio_data.max() == 0: # Check for silence/empty
111
+ gr.Warning("Audio prompt seems empty or silent, ignoring prompt.")
112
+ else:
113
+ # Save prompt audio to a temporary WAV file
114
+ with tempfile.NamedTemporaryFile(mode="wb", suffix=".wav", delete=False) as f_audio:
115
+ temp_audio_prompt_path = f_audio.name # Store path for cleanup
116
+
117
+ # Basic audio preprocessing for consistency
118
+ # Convert to float32 in [-1, 1] range if integer type
119
+ if np.issubdtype(audio_data.dtype, np.integer):
120
+ max_val = np.iinfo(audio_data.dtype).max
121
+ audio_data = audio_data.astype(np.float32) / max_val
122
+ elif not np.issubdtype(audio_data.dtype, np.floating):
123
+ gr.Warning(f"Unsupported audio prompt dtype {audio_data.dtype}, attempting conversion.")
124
+ # Attempt conversion, might fail for complex types
125
+ try:
126
+ audio_data = audio_data.astype(np.float32)
127
+ except Exception as conv_e:
128
+ raise gr.Error(f"Failed to convert audio prompt to float32: {conv_e}")
129
+
130
+ # Ensure mono (average channels if stereo)
131
+ if audio_data.ndim > 1:
132
+ if audio_data.shape[0] == 2: # Assume (2, N)
133
+ audio_data = np.mean(audio_data, axis=0)
134
+ elif audio_data.shape[1] == 2: # Assume (N, 2)
135
+ audio_data = np.mean(audio_data, axis=1)
136
+ else:
137
+ gr.Warning(
138
+ f"Audio prompt has unexpected shape {audio_data.shape}, taking first channel/axis."
139
+ )
140
+ audio_data = (
141
+ audio_data[0] if audio_data.shape[0] < audio_data.shape[1] else audio_data[:, 0]
142
+ )
143
+ audio_data = np.ascontiguousarray(audio_data) # Ensure contiguous after slicing/mean
144
+
145
+ # Write using soundfile
146
+ try:
147
+ sf.write(
148
+ temp_audio_prompt_path, audio_data, sr, subtype="FLOAT"
149
+ ) # Explicitly use FLOAT subtype
150
+ prompt_path_for_generate = temp_audio_prompt_path
151
+ print(f"Created temporary audio prompt file: {temp_audio_prompt_path} (orig sr: {sr})")
152
+ except Exception as write_e:
153
+ print(f"Error writing temporary audio file: {write_e}")
154
+ raise gr.Error(f"Failed to save audio prompt: {write_e}")
155
+
156
+ # Set and Display Generation Seed
157
+ if seed is None or seed < 0:
158
+ seed = random.randint(0, 2**32 - 1)
159
+ print(f"\nNo seed provided, generated random seed: {seed}\n")
160
+ else:
161
+ print(f"\nUsing user-selected seed: {seed}\n")
162
+ set_seed(seed)
163
+
164
+ # Run Generation
165
+ print(f'Generating speech: \n"{text_input}"\n')
166
+
167
+ start_time = time.time()
168
+
169
+ # Use torch.inference_mode() context manager for the generation call
170
+ with torch.inference_mode():
171
+ output_audio_np = model.generate(
172
+ text_input,
173
+ max_tokens=max_new_tokens,
174
+ cfg_scale=cfg_scale,
175
+ temperature=temperature,
176
+ top_p=top_p,
177
+ cfg_filter_top_k=cfg_filter_top_k, # Pass the value here
178
+ use_torch_compile=False, # Keep False for Gradio stability
179
+ audio_prompt=prompt_path_for_generate,
180
+ verbose=True,
181
+ )
182
+
183
+ end_time = time.time()
184
+ print(f"Generation finished in {end_time - start_time:.2f} seconds.\n")
185
+
186
+ # 4. Convert Codes to Audio
187
+ if output_audio_np is not None:
188
+ # Get sample rate from the loaded DAC model
189
+ output_sr = 44100
190
+
191
+ # --- Slow down audio ---
192
+ original_len = len(output_audio_np)
193
+ # Ensure speed_factor is positive and not excessively small/large to avoid issues
194
+ speed_factor = max(0.1, min(speed_factor, 5.0))
195
+ target_len = int(original_len / speed_factor) # Target length based on speed_factor
196
+ if target_len != original_len and target_len > 0: # Only interpolate if length changes and is valid
197
+ x_original = np.arange(original_len)
198
+ x_resampled = np.linspace(0, original_len - 1, target_len)
199
+ resampled_audio_np = np.interp(x_resampled, x_original, output_audio_np)
200
+ output_audio = (
201
+ output_sr,
202
+ resampled_audio_np.astype(np.float32),
203
+ ) # Use resampled audio
204
+ print(
205
+ f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed."
206
+ )
207
+ else:
208
+ output_audio = (
209
+ output_sr,
210
+ output_audio_np,
211
+ ) # Keep original if calculation fails or no change
212
+ print(f"Skipping audio speed adjustment (factor: {speed_factor:.2f}).")
213
+ # --- End slowdown ---
214
+
215
+ print(f"Audio conversion successful. Final shape: {output_audio[1].shape}, Sample Rate: {output_sr}")
216
+
217
+ # Explicitly convert to int16 to prevent Gradio warning
218
+ if output_audio[1].dtype == np.float32 or output_audio[1].dtype == np.float64:
219
+ audio_for_gradio = np.clip(output_audio[1], -1.0, 1.0)
220
+ audio_for_gradio = (audio_for_gradio * 32767).astype(np.int16)
221
+ output_audio = (output_sr, audio_for_gradio)
222
+ print("Converted audio to int16 for Gradio output.")
223
+
224
+ else:
225
+ print("\nGeneration finished, but no valid tokens were produced.")
226
+ # Return default silence
227
+ gr.Warning("Generation produced no output.")
228
+
229
+ except Exception as e:
230
+ print(f"Error during inference: {e}")
231
+ import traceback
232
+
233
+ traceback.print_exc()
234
+ # Re-raise as Gradio error to display nicely in the UI
235
+ raise gr.Error(f"Inference failed: {e}")
236
+
237
+ finally:
238
+ # Cleanup Temporary Files defensively
239
+ if temp_txt_file_path and Path(temp_txt_file_path).exists():
240
+ try:
241
+ Path(temp_txt_file_path).unlink()
242
+ print(f"Deleted temporary text file: {temp_txt_file_path}")
243
+ except OSError as e:
244
+ print(f"Warning: Error deleting temporary text file {temp_txt_file_path}: {e}")
245
+ if temp_audio_prompt_path and Path(temp_audio_prompt_path).exists():
246
+ try:
247
+ Path(temp_audio_prompt_path).unlink()
248
+ print(f"Deleted temporary audio prompt file: {temp_audio_prompt_path}")
249
+ except OSError as e:
250
+ print(f"Warning: Error deleting temporary audio prompt file {temp_audio_prompt_path}: {e}")
251
+
252
+ # After generation, capture the printed output
253
+ console_output = console_output_buffer.getvalue()
254
+
255
+ return output_audio, seed, console_output
256
+
257
+
258
+ # --- Create Gradio Interface ---
259
+ css = """
260
+ #col-container {max-width: 90%; margin-left: auto; margin-right: auto;}
261
+ """
262
+ # Attempt to load default text from example.txt
263
+ default_text = "[S1] Dia is an open weights text to dialogue model. \n[S2] You get full control over scripts and voices. \n[S1] Wow. Amazing. (laughs) \n[S2] Try it now on Git hub or Hugging Face."
264
+ example_txt_path = Path("./example.txt")
265
+ if example_txt_path.exists():
266
+ try:
267
+ default_text = example_txt_path.read_text(encoding="utf-8").strip()
268
+ if not default_text: # Handle empty example file
269
+ default_text = "Example text file was empty."
270
+ except Exception as e:
271
+ print(f"Warning: Could not read example.txt: {e}")
272
+
273
+
274
+ # Build Gradio UI
275
+ with gr.Blocks(css=css, theme="gradio/dark") as demo:
276
+ gr.Markdown("# Nari Text-to-Speech Synthesis")
277
+
278
+ with gr.Row(equal_height=False):
279
+ with gr.Column(scale=1):
280
+ with gr.Accordion("Audio Reference Prompt (Optional)", open=False):
281
+ audio_prompt_input = gr.Audio(
282
+ label="Audio Prompt (Optional)",
283
+ show_label=True,
284
+ sources=["upload", "microphone"],
285
+ type="numpy",
286
+ )
287
+ audio_prompt_text_input = gr.Textbox(
288
+ label="Transcript of Audio Prompt (Required if using Audio Prompt)",
289
+ placeholder="Enter text here...",
290
+ value="",
291
+ lines=5, # Increased lines
292
+ )
293
+ text_input = gr.Textbox(
294
+ label="Text To Generate",
295
+ placeholder="Enter text here...",
296
+ value=default_text,
297
+ lines=5, # Increased lines
298
+ )
299
+ with gr.Accordion("Generation Parameters", open=False):
300
+ max_new_tokens = gr.Slider(
301
+ label="Max New Tokens (Audio Length)",
302
+ minimum=860,
303
+ maximum=3072,
304
+ value=model.config.decoder_config.max_position_embeddings, # Use config default if available, else fallback
305
+ step=50,
306
+ info="Controls the maximum length of the generated audio (more tokens = longer audio).",
307
+ )
308
+ cfg_scale = gr.Slider(
309
+ label="CFG Scale (Guidance Strength)",
310
+ minimum=1.0,
311
+ maximum=5.0,
312
+ value=3.0, # Default from inference.py
313
+ step=0.1,
314
+ info="Higher values increase adherence to the text prompt.",
315
+ )
316
+ temperature = gr.Slider(
317
+ label="Temperature (Randomness)",
318
+ minimum=1.0,
319
+ maximum=2.5,
320
+ value=1.8, # Default from inference.py
321
+ step=0.05,
322
+ info="Lower values make the output more deterministic, higher values increase randomness.",
323
+ )
324
+ top_p = gr.Slider(
325
+ label="Top P (Nucleus Sampling)",
326
+ minimum=0.70,
327
+ maximum=1.0,
328
+ value=0.95, # Default from inference.py
329
+ step=0.01,
330
+ info="Filters vocabulary to the most likely tokens cumulatively reaching probability P.",
331
+ )
332
+ cfg_filter_top_k = gr.Slider(
333
+ label="CFG Filter Top K",
334
+ minimum=15,
335
+ maximum=100,
336
+ value=45,
337
+ step=1,
338
+ info="Top k filter for CFG guidance.",
339
+ )
340
+ speed_factor_slider = gr.Slider(
341
+ label="Speed Factor",
342
+ minimum=0.8,
343
+ maximum=1.0,
344
+ value=1.0,
345
+ step=0.02,
346
+ info="Adjusts the speed of the generated audio (1.0 = original speed).",
347
+ )
348
+ seed_input = gr.Number(
349
+ label="Generation Seed (Optional)",
350
+ value=-1,
351
+ precision=0, # No decimal points
352
+ step=1,
353
+ interactive=True,
354
+ info="Set a generation seed for reproducible outputs. Leave empty or -1 for random seed.",
355
+ )
356
+
357
+ run_button = gr.Button("Generate Audio", variant="primary")
358
+
359
+ with gr.Column(scale=1):
360
+ audio_output = gr.Audio(
361
+ label="Generated Audio",
362
+ type="numpy",
363
+ autoplay=False,
364
+ )
365
+ seed_output = gr.Textbox(label="Generation Seed", interactive=False)
366
+ console_output = gr.Textbox(label="Console Output Log", lines=10, interactive=False)
367
+
368
+ # Link button click to function
369
+ run_button.click(
370
+ fn=run_inference,
371
+ inputs=[
372
+ text_input,
373
+ audio_prompt_text_input,
374
+ audio_prompt_input,
375
+ max_new_tokens,
376
+ cfg_scale,
377
+ temperature,
378
+ top_p,
379
+ cfg_filter_top_k,
380
+ speed_factor_slider,
381
+ seed_input,
382
+ ],
383
+ outputs=[
384
+ audio_output,
385
+ seed_output,
386
+ console_output,
387
+ ], # Add status_output here if using it
388
+ api_name="generate_audio",
389
+ )
390
+
391
+ # Add examples (ensure the prompt path is correct or remove it if example file doesn't exist)
392
+ example_prompt_path = "./example_prompt.mp3" # Adjust if needed
393
+ examples_list = [
394
+ [
395
+ "[S1] Oh fire! Oh my goodness! What's the procedure? What to we do people? The smoke could be coming through an air duct! \n[S2] Oh my god! Okay.. it's happening. Everybody stay calm! \n[S1] What's the procedure... \n[S2] Everybody stay fucking calm!!!... Everybody fucking calm down!!!!! \n[S1] No! No! If you touch the handle, if its hot there might be a fire down the hallway! ",
396
+ None,
397
+ 3072,
398
+ 3.0,
399
+ 1.8,
400
+ 0.95,
401
+ 45,
402
+ 1.0,
403
+ ],
404
+ [
405
+ "[S1] Open weights text to dialogue model. \n[S2] You get full control over scripts and voices. \n[S1] I'm biased, but I think we clearly won. \n[S2] Hard to disagree. (laughs) \n[S1] Thanks for listening to this demo. \n[S2] Try it now on Git hub and Hugging Face. \n[S1] If you liked our model, please give us a star and share to your friends. \n[S2] This was Nari Labs.",
406
+ example_prompt_path if Path(example_prompt_path).exists() else None,
407
+ 3072,
408
+ 3.0,
409
+ 1.8,
410
+ 0.95,
411
+ 45,
412
+ 1.0,
413
+ ],
414
+ ]
415
+
416
+ if examples_list:
417
+ gr.Examples(
418
+ examples=examples_list,
419
+ inputs=[
420
+ text_input,
421
+ audio_prompt_input,
422
+ max_new_tokens,
423
+ cfg_scale,
424
+ temperature,
425
+ top_p,
426
+ cfg_filter_top_k,
427
+ speed_factor_slider,
428
+ seed_input,
429
+ ],
430
+ outputs=[audio_output],
431
+ fn=run_inference,
432
+ cache_examples=False,
433
+ label="Examples (Click to Run)",
434
+ )
435
+ else:
436
+ gr.Markdown("_(No examples configured or example prompt file missing)_")
437
+
438
+ # --- Launch the App ---
439
+ if __name__ == "__main__":
440
+ print("Launching Gradio interface...")
441
+
442
+ # set `GRADIO_SERVER_NAME`, `GRADIO_SERVER_PORT` env vars to override default values
443
+ # use `GRADIO_SERVER_NAME=0.0.0.0` for Docker
444
+ demo.launch(share=args.share)
cli.py ADDED
@@ -0,0 +1,144 @@
+ import argparse
+ import os
+ import random
+
+ import numpy as np
+ import soundfile as sf
+ import torch
+
+ from dia.model import Dia
+
+
+ def set_seed(seed: int):
+     """Sets the random seed for reproducibility."""
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(seed)
+         torch.cuda.manual_seed_all(seed)
+     # Ensure deterministic behavior for cuDNN (if used)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Generate audio using the Dia model.")
+
+     parser.add_argument("text", type=str, help="Input text for speech generation.")
+     parser.add_argument(
+         "--output", type=str, required=True, help="Path to save the generated audio file (e.g., output.wav)."
+     )
+
+     parser.add_argument(
+         "--repo-id",
+         type=str,
+         default="nari-labs/Dia-1.6B-0626",
+         help="Hugging Face repository ID (e.g., nari-labs/Dia-1.6B-0626).",
+     )
+     parser.add_argument(
+         "--local-paths", action="store_true", help="Load model from local config and checkpoint files."
+     )
+
+     parser.add_argument(
+         "--config", type=str, help="Path to local config.json file (required if --local-paths is set)."
+     )
+     parser.add_argument(
+         "--checkpoint", type=str, help="Path to local model checkpoint .pth file (required if --local-paths is set)."
+     )
+     parser.add_argument(
+         "--audio-prompt", type=str, default=None, help="Path to an optional audio prompt WAV file for voice cloning."
+     )
+
+     gen_group = parser.add_argument_group("Generation Parameters")
+     gen_group.add_argument(
+         "--max-tokens",
+         type=int,
+         default=None,
+         help="Maximum number of audio tokens to generate (defaults to config value).",
+     )
+     gen_group.add_argument(
+         "--cfg-scale", type=float, default=3.0, help="Classifier-Free Guidance scale (default: 3.0)."
+     )
+     gen_group.add_argument(
+         "--temperature", type=float, default=1.3, help="Sampling temperature (higher is more random, default: 1.3)."
+     )
+     gen_group.add_argument("--top-p", type=float, default=0.95, help="Nucleus sampling probability (default: 0.95).")
+
+     infra_group = parser.add_argument_group("Infrastructure")
+     infra_group.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility.")
+     infra_group.add_argument(
+         "--device",
+         type=str,
+         default="cuda" if torch.cuda.is_available() else "cpu",
+         help="Device to run inference on (e.g., 'cuda', 'cpu', default: auto).",
+     )
+
+     args = parser.parse_args()
+
+     # Validation for local paths
+     if args.local_paths:
+         if not args.config:
+             parser.error("--config is required when --local-paths is set.")
+         if not args.checkpoint:
+             parser.error("--checkpoint is required when --local-paths is set.")
+         if not os.path.exists(args.config):
+             parser.error(f"Config file not found: {args.config}")
+         if not os.path.exists(args.checkpoint):
+             parser.error(f"Checkpoint file not found: {args.checkpoint}")
+
+     # Set seed if provided
+     if args.seed is not None:
+         set_seed(args.seed)
+         print(f"Using user-selected seed: {args.seed}")
+
+     # Determine device
+     device = torch.device(args.device)
+     print(f"Using device: {device}")
+
+     # Load model
+     print("Loading model...")
+     if args.local_paths:
+         print(f"Loading from local paths: config='{args.config}', checkpoint='{args.checkpoint}'")
+         try:
+             model = Dia.from_local(args.config, args.checkpoint, device=device)
+         except Exception as e:
+             print(f"Error loading local model: {e}")
+             exit(1)
+     else:
+         print(f"Loading from Hugging Face Hub: repo_id='{args.repo_id}'")
+         try:
+             model = Dia.from_pretrained(args.repo_id, device=device)
+         except Exception as e:
+             print(f"Error loading model from Hub: {e}")
+             exit(1)
+     print("Model loaded.")
+
+     # Generate audio
+     print("Generating audio...")
+     try:
+         sample_rate = 44100  # Default assumption
+
+         output_audio = model.generate(
+             text=args.text,
+             audio_prompt=args.audio_prompt,
+             max_tokens=args.max_tokens,
+             cfg_scale=args.cfg_scale,
+             temperature=args.temperature,
+             top_p=args.top_p,
+         )
+         print("Audio generation complete.")
+
+         print(f"Saving audio to {args.output}...")
+         os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
+
+         sf.write(args.output, output_audio, sample_rate)
+         print(f"Audio successfully saved to {args.output}")
+
+     except Exception as e:
+         print(f"Error during audio generation or saving: {e}")
+         exit(1)
+
+
+ if __name__ == "__main__":
+     main()
dia/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .model import Dia
+
+
+ __all__ = [
+     "Dia",
+ ]
dia/audio.py ADDED
@@ -0,0 +1,163 @@
+ import typing as tp
+
+ import torch
+
+
+ def build_delay_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Precompute (t_idx_BxTxC, indices_BTCx3) so that out[t, c] = in[t - delay[c], c].
+     Negative t_idx => BOS; t_idx >= T => PAD.
+     """
+     delay_arr = torch.tensor(delay_pattern, dtype=torch.int32)
+
+     t_idx_BxT = torch.broadcast_to(
+         torch.arange(T, dtype=torch.int32)[None, :],
+         [B, T],
+     )
+     t_idx_BxTx1 = t_idx_BxT[..., None]
+     t_idx_BxTxC = t_idx_BxTx1 - delay_arr.view(1, 1, C)
+
+     b_idx_BxTxC = torch.broadcast_to(
+         torch.arange(B, dtype=torch.int32).view(B, 1, 1),
+         [B, T, C],
+     )
+     c_idx_BxTxC = torch.broadcast_to(
+         torch.arange(C, dtype=torch.int32).view(1, 1, C),
+         [B, T, C],
+     )
+
+     # We must clamp time indices to [0..T-1] so gather_nd equivalent won't fail
+     t_clamped_BxTxC = torch.clamp(t_idx_BxTxC, 0, T - 1)
+
+     indices_BTCx3 = torch.stack(
+         [
+             b_idx_BxTxC.reshape(-1),
+             t_clamped_BxTxC.reshape(-1),
+             c_idx_BxTxC.reshape(-1),
+         ],
+         dim=1,
+     ).long()  # Ensure indices are long type for indexing
+
+     return t_idx_BxTxC, indices_BTCx3
+
+
+ def apply_audio_delay(
+     audio_BxTxC: torch.Tensor,
+     pad_value: int,
+     bos_value: int,
+     precomp: tp.Tuple[torch.Tensor, torch.Tensor],
+ ) -> torch.Tensor:
+     """
+     Applies the delay pattern to batched audio tokens using precomputed indices,
+     inserting BOS where t_idx < 0 and PAD where t_idx >= T.
+
+     Args:
+         audio_BxTxC: [B, T, C] int16 audio tokens (or int32/float)
+         pad_value: the padding token
+         bos_value: the BOS token
+         precomp: (t_idx_BxTxC, indices_BTCx3) from build_delay_indices
+
+     Returns:
+         result_BxTxC: [B, T, C] delayed audio tokens
+     """
+     device = audio_BxTxC.device  # Get device from input tensor
+     t_idx_BxTxC, indices_BTCx3 = precomp
+     t_idx_BxTxC = t_idx_BxTxC.to(device)  # Move precomputed indices to device
+     indices_BTCx3 = indices_BTCx3.to(device)
+
+     # Equivalent of tf.gather_nd using advanced indexing
+     # Ensure indices are long type if not already (build_delay_indices should handle this)
+     gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
+     gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape)
+
+     # Create masks on the correct device
+     mask_bos = t_idx_BxTxC < 0  # => place bos_value
+     mask_pad = t_idx_BxTxC >= audio_BxTxC.shape[1]  # => place pad_value
+
+     # Create scalar tensors on the correct device
+     bos_tensor = torch.tensor(bos_value, dtype=audio_BxTxC.dtype, device=device)
+     pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
+
+     # If mask_bos, BOS; else if mask_pad, PAD; else original gather
+     # All tensors should now be on the same device
+     result_BxTxC = torch.where(mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC))
+
+     return result_BxTxC
+
+
+ def build_revert_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Precompute indices for the revert operation using PyTorch.
+
+     Returns:
+         A tuple (t_idx_BxTxC, indices_BTCx3) where:
+             - t_idx_BxTxC is a tensor of shape [B, T, C] computed as time indices plus the delay.
+             - indices_BTCx3 is a tensor of shape [B*T*C, 3] used for gathering, computed from:
+                 batch indices, clamped time indices, and channel indices.
+     """
+     # Use default device unless specified otherwise; assumes inputs might define device later
+     device = None  # Or determine dynamically if needed, e.g., from a model parameter
+
+     delay_arr = torch.tensor(delay_pattern, dtype=torch.int32, device=device)
+
+     t_idx_BT1 = torch.broadcast_to(torch.arange(T, device=device).unsqueeze(0), [B, T])
+     t_idx_BT1 = t_idx_BT1.unsqueeze(-1)
+
+     t_idx_BxTxC = torch.minimum(
+         t_idx_BT1 + delay_arr.view(1, 1, C),
+         torch.tensor(T - 1, device=device),
+     )
+     b_idx_BxTxC = torch.broadcast_to(torch.arange(B, device=device).view(B, 1, 1), [B, T, C])
+     c_idx_BxTxC = torch.broadcast_to(torch.arange(C, device=device).view(1, 1, C), [B, T, C])
+
+     indices_BTCx3 = torch.stack(
+         [
+             b_idx_BxTxC.reshape(-1),
+             t_idx_BxTxC.reshape(-1),
+             c_idx_BxTxC.reshape(-1),
+         ],
+         axis=1,
+     ).long()  # Ensure indices are long type
+
+     return t_idx_BxTxC, indices_BTCx3
+
+
+ def revert_audio_delay(
+     audio_BxTxC: torch.Tensor,
+     pad_value: int,
+     precomp: tp.Tuple[torch.Tensor, torch.Tensor],
+     T: int,
+ ) -> torch.Tensor:
+     """
+     Reverts a delay pattern from batched audio tokens using precomputed indices (PyTorch version).
+
+     Args:
+         audio_BxTxC: Input delayed audio tensor
+         pad_value: Padding value for out-of-bounds indices
+         precomp: Precomputed revert indices tuple containing:
+             - t_idx_BxTxC: Time offset indices tensor
+             - indices_BTCx3: Gather indices tensor for original audio
+         T: Original sequence length before padding
+
+     Returns:
+         Reverted audio tensor with same shape as input
+     """
+     t_idx_BxTxC, indices_BTCx3 = precomp
+     device = audio_BxTxC.device  # Get device from input tensor
+
+     # Move precomputed indices to the same device as audio_BxTxC if they aren't already
+     t_idx_BxTxC = t_idx_BxTxC.to(device)
+     indices_BTCx3 = indices_BTCx3.to(device)
+
+     # Using PyTorch advanced indexing (equivalent to tf.gather_nd or np equivalent)
+     gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
+     gathered_BxTxC = gathered_flat.view(audio_BxTxC.size())  # Use .size() for robust reshaping
+
+     # Create pad_tensor on the correct device
+     pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
+     # Create T tensor on the correct device for comparison
+     T_tensor = torch.tensor(T, device=device)
+
+     result_BxTxC = torch.where(t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC)  # Changed np.where to torch.where
+
+     return result_BxTxC
dia/config.py ADDED
@@ -0,0 +1,176 @@
+ """Configuration management module for the Dia model.
+
+ This module provides comprehensive configuration management for the Dia model,
+ utilizing Pydantic for validation. It defines configurations for data processing,
+ model architecture (encoder and decoder), and training settings.
+
+ Key components:
+     - DataConfig: Parameters for data loading and preprocessing.
+     - EncoderConfig: Architecture details for the encoder module.
+     - DecoderConfig: Architecture details for the decoder module.
+     - ModelConfig: Combined model architecture settings.
+     - TrainingConfig: Training hyperparameters and settings.
+     - DiaConfig: Master configuration combining all components.
+ """
+
+ import os
+
+ from pydantic import BaseModel, Field
+
+
+ class EncoderConfig(BaseModel, frozen=True):
+     """Configuration for the encoder component of the Dia model.
+
+     Attributes:
+         model_type: Type of the model, defaults to "dia_encoder".
+         hidden_size: Size of the encoder layers, defaults to 1024.
+         intermediate_size: Size of the "intermediate" (i.e., feed-forward) layer in the encoder, defaults to 4096.
+         num_hidden_layers: Number of hidden layers in the encoder, defaults to 12.
+         num_attention_heads: Number of attention heads in the encoder, defaults to 16.
+         num_key_value_heads: Number of key-value heads in the encoder, defaults to 16.
+         head_dim: Dimension of each attention head, defaults to 128.
+         hidden_act: Activation function in the encoder, defaults to "silu".
+         max_position_embeddings: Maximum number of position embeddings, defaults to 1024.
+         initializer_range: Range for initializing weights, defaults to 0.02.
+         norm_eps: Epsilon value for normalization layers, defaults to 1e-5.
+         rope_theta: Theta value for RoPE, defaults to 10000.0.
+         rope_scaling: Optional scaling factor for RoPE.
+         vocab_size: Vocabulary size, defaults to 256.
+     """
+
+     head_dim: int = Field(default=128, gt=0)
+     hidden_act: str = Field(default="silu")
+     hidden_size: int = Field(default=1024, gt=0)
+     initializer_range: float = Field(default=0.02)
+     intermediate_size: int = Field(default=4096, gt=0)
+     max_position_embeddings: int = Field(default=1024, gt=0)
+     model_type: str = Field(default="dia_encoder")
+     norm_eps: float = Field(default=1e-5)
+     num_attention_heads: int = Field(default=16, gt=0)
+     num_hidden_layers: int = Field(default=12, gt=0)
+     num_key_value_heads: int = Field(default=16, gt=0)
+     rope_scaling: float | None = Field(default=None)
+     rope_theta: float = Field(default=10000.0)
+     vocab_size: int = Field(default=256, gt=0)
+
+
+ class DecoderConfig(BaseModel, frozen=True):
+     """Configuration for the decoder component of the Dia model.
+
+     Attributes:
+         model_type: Type of the model, defaults to "dia_decoder".
+         hidden_size: Size of the decoder layers, defaults to 2048.
+         intermediate_size: Size of the "intermediate" (i.e., feed-forward) layer in the decoder, defaults to 8192.
+         num_hidden_layers: Number of hidden layers in the decoder, defaults to 18.
+         num_attention_heads: Number of attention heads in the decoder, defaults to 16.
+         num_key_value_heads: Number of key-value heads in the decoder, defaults to 4.
+         head_dim: Dimension of each attention head, defaults to 128.
+         cross_hidden_size: Size of the cross-attention layers, defaults to 1024.
+         cross_num_attention_heads: Number of attention heads in the cross-attention mechanism, defaults to 16.
+         cross_num_key_value_heads: Number of key-value heads in the cross-attention mechanism, defaults to 16.
+         cross_head_dim: Dimension of each cross-attention head, defaults to 128.
+         hidden_act: Activation function in the decoder, defaults to "silu".
+         max_position_embeddings: Maximum number of position embeddings in the decoder, defaults to 3072.
+         initializer_range: Range for initializing weights in the decoder, defaults to 0.02.
+         norm_eps: Epsilon value for normalization layers in the decoder, defaults to 1e-5.
+         rope_theta: Theta value for RoPE in the decoder, defaults to 10000.0.
+         rope_scaling: Optional scaling factor for RoPE in the decoder.
+         vocab_size: Vocabulary size for the decoder, defaults to 1028.
+         num_channels: Number of channels in the decoder, defaults to 9.
+     """
+
+     cross_head_dim: int = Field(default=128, gt=0)
+     cross_hidden_size: int = Field(default=1024, gt=0)
+     cross_num_attention_heads: int = Field(default=16, gt=0)
+     cross_num_key_value_heads: int = Field(default=16, gt=0)
+     head_dim: int = Field(default=128, gt=0)
+     hidden_act: str = Field(default="silu")
+     hidden_size: int = Field(default=2048, gt=0)
+     initializer_range: float = Field(default=0.02)
+     intermediate_size: int = Field(default=8192, gt=0)
+     max_position_embeddings: int = Field(default=3072, gt=0)
+     model_type: str = Field(default="dia_decoder")
+     norm_eps: float = Field(default=1e-5)
+     num_attention_heads: int = Field(default=16, gt=0)
+     num_channels: int = Field(default=9, gt=0)
+     num_hidden_layers: int = Field(default=18, gt=0)
+     num_key_value_heads: int = Field(default=4, gt=0)
+     rope_scaling: float | None = Field(default=None)
+     rope_theta: float = Field(default=10000.0)
+     vocab_size: int = Field(default=1028, gt=0)
+
+
+ class DiaConfig(BaseModel, frozen=True):
+     """Main configuration container for the Dia model architecture.
+
+     Attributes:
+         model_type: Type of the model, defaults to "dia".
+         is_encoder_decoder: Flag indicating if the model is an encoder-decoder type, defaults to True.
+         encoder: Configuration for the encoder component.
+         decoder: Configuration for the decoder component.
+         src_vocab_size: Size of the source (text) vocabulary.
+         tgt_vocab_size: Size of the target (audio code) vocabulary.
+         initializer_range: Range for initializing weights, defaults to 0.02.
+         norm_eps: Epsilon value for normalization layers, defaults to 1e-5.
+         torch_dtype: Data type for model weights in PyTorch, defaults to "float32".
+         bos_token_id: Beginning-of-sequence token ID, defaults to 1026.
+         eos_token_id: End-of-sequence token ID, defaults to 1024.
+         pad_token_id: Padding token ID, defaults to 1025.
+         rope_theta: Theta value for RoPE, defaults to 10000.0.
+         rope_scaling: Optional scaling factor for RoPE.
+         transformers_version: Version of the transformers library, defaults to "4.53.0.dev0".
+         architectures: List of model architectures, defaults to ["DiaForConditionalGeneration"].
+         delay_pattern: List of delay values for each audio channel, defaults to [0,8,9,10,11,12,13,14,15].
+     """
+
+     architectures: list[str] = Field(default_factory=lambda: ["DiaForConditionalGeneration"])
+     bos_token_id: int = Field(default=1026)
+     decoder_config: DecoderConfig
+     delay_pattern: list[int] = Field(default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15])
+     encoder_config: EncoderConfig
+     eos_token_id: int = Field(default=1024)
+     initializer_range: float = Field(default=0.02)
+     is_encoder_decoder: bool = Field(default=True)
+     model_type: str = Field(default="dia")
+     norm_eps: float = Field(default=1e-5)
+     pad_token_id: int = Field(default=1025)
+     torch_dtype: str = Field(default="float32")
+     transformers_version: str = Field(default="4.53.0.dev0")
+
+     def save(self, path: str) -> None:
+         """Save the current configuration instance to a JSON file.
+
+         Ensures the parent directory exists and the file has a .json extension.
+
+         Args:
+             path: The target file path to save the configuration.
+
+         Raises:
+             ValueError: If the path is not a file with a .json extension.
+         """
+         os.makedirs(os.path.dirname(path), exist_ok=True)
+         config_json = self.model_dump_json(indent=2)
+         with open(path, "w") as f:
+             f.write(config_json)
+
+     @classmethod
+     def load(cls, path: str) -> "DiaConfig | None":
+         """Load and validate a Dia configuration from a JSON file.
+
+         Args:
+             path: The path to the configuration file.
+
+         Returns:
+             A validated DiaConfig instance if the file exists and is valid,
+             otherwise None if the file is not found.
+
+         Raises:
+             ValueError: If the path does not point to an existing .json file.
+             pydantic.ValidationError: If the JSON content fails validation against the DiaConfig schema.
+         """
+         try:
+             with open(path, "r") as f:
+                 content = f.read()
+             return cls.model_validate_json(content)
+         except FileNotFoundError:
+             return None
dia/layers.py ADDED
@@ -0,0 +1,888 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from huggingface_hub import PyTorchModelHubMixin
5
+ from torch import Tensor
6
+ from torch.nn import RMSNorm
7
+
8
+ from .config import DecoderConfig, DiaConfig, EncoderConfig
9
+ from .state import DecoderInferenceState, EncoderInferenceState, KVCache
10
+
11
+
12
+ def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
13
+ return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
14
+
15
+
16
+ class DenseGeneral(nn.Module):
17
+ """
18
+ PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.
19
+ Stores weights (`kernel`) in the same layout as Jax and uses torch.tensordot
20
+ for the generalized matrix multiplication. Weight/bias shapes are calculated
21
+ and parameters created during initialization based on config.
22
+ `load_weights` validates shapes and copies data.
23
+ Attributes:
24
+ axis (Tuple[int, ...]): Input axis or axes to contract.
25
+ in_shapes (Tuple[int, ...]): Sizes of the input dimensions specified by `axis`.
26
+ out_features (Tuple[int, ...]): Shape of the output features (non-contracted dims).
27
+ use_bias (bool): Whether to add a bias term.
28
+ weight (nn.Parameter): The kernel parameter.
29
+ bias (Optional[nn.Parameter]): The bias parameter (if use_bias=True).
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ in_shapes: tuple[int, ...],
35
+ out_features: tuple[int, ...],
36
+ axis: tuple[int, ...] = (-1,),
37
+ weight_dtype: torch.dtype | None = None,
38
+ device: torch.device | None = None,
39
+ ):
40
+ super().__init__()
41
+ self.in_shapes = in_shapes
42
+ self.out_features = out_features
43
+ self.axis = axis
44
+ self.kernel_shape = self.in_shapes + self.out_features
45
+
46
+ factory_kwargs = {"device": device, "dtype": weight_dtype}
47
+ self.weight = nn.Parameter(torch.empty(self.kernel_shape, **factory_kwargs))
48
+
49
+ def forward(self, inputs: Tensor) -> Tensor:
50
+ norm_axis = _normalize_axes(self.axis, inputs.ndim)
51
+ kernel_contract_axes = tuple(range(len(norm_axis)))
52
+
53
+ output = torch.tensordot(
54
+ inputs.to(self.weight.dtype),
55
+ self.weight,
56
+ dims=(norm_axis, kernel_contract_axes),
57
+ ).to(inputs.dtype)
58
+ return output
59
+
60
+
61
+ class MlpBlock(nn.Module):
62
+ """MLP block using DenseGeneral."""
63
+
64
+ def __init__(self, embed_dim: int, intermediate_dim: int, compute_dtype: torch.dtype):
65
+ super().__init__()
66
+ self.dtype = compute_dtype
67
+
68
+ self.wi_fused = DenseGeneral(
69
+ in_shapes=(embed_dim,),
70
+ out_features=(2, intermediate_dim),
71
+ axis=(-1,),
72
+ weight_dtype=compute_dtype,
73
+ )
74
+
75
+ self.wo = DenseGeneral(
76
+ in_shapes=(intermediate_dim,),
77
+ out_features=(embed_dim,),
78
+ axis=(-1,),
79
+ weight_dtype=compute_dtype,
80
+ )
81
+
82
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
83
+ """Forward pass."""
84
+ fused_x = self.wi_fused(x)
85
+
86
+ gate = fused_x[..., 0, :]
87
+ up = fused_x[..., 1, :]
88
+
89
+ hidden = torch.mul(F.silu(gate), up).to(self.dtype)
90
+
91
+ output = self.wo(hidden)
92
+ return output
93
+
94
+
95
+ class RotaryEmbedding(nn.Module):
96
+ """Rotary Position Embedding (RoPE) implementation in PyTorch."""
97
+
98
+ def __init__(
99
+ self,
100
+ embedding_dims: int,
101
+ min_timescale: float = 1.0,
102
+ max_timescale: float = 10000.0,
103
+ dtype: torch.dtype = torch.float32,
104
+ ):
105
+ super().__init__()
106
+ if embedding_dims % 2 != 0:
107
+ raise ValueError("Embedding dim must be even for RoPE.")
108
+ self.embedding_dims = embedding_dims
109
+ self.min_timescale = min_timescale
110
+ self.max_timescale = max_timescale
111
+ self.compute_dtype = dtype
112
+
113
+ half_embedding_dim = embedding_dims // 2
114
+ fraction = (2.0 * torch.arange(0, half_embedding_dim)) / embedding_dims
115
+ timescale = (self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction).to(torch.float32)
116
+ self.register_buffer("timescale", timescale, persistent=False)
117
+
118
+ def forward(self, inputs: torch.Tensor, position: torch.Tensor):
119
+ """Applies RoPE."""
120
+ position = position.unsqueeze(-1).unsqueeze(-1)
121
+ sinusoid_inp = position / self.timescale
122
+ sin = torch.sin(sinusoid_inp)
123
+ cos = torch.cos(sinusoid_inp)
124
+ first_half, second_half = torch.chunk(inputs.to(torch.float32), 2, dim=-1)
125
+ first_part = first_half * cos - second_half * sin
126
+ second_part = second_half * cos + first_half * sin
127
+ return torch.cat(
128
+ (first_part.to(self.compute_dtype), second_part.to(self.compute_dtype)),
129
+ dim=-1,
130
+ )
131
+
132
+ def apply_rope(self, inputs: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor):
133
+ first_half, second_half = torch.chunk(inputs.to(torch.float32), 2, dim=-1)
134
+ first_part = first_half * cos - second_half * sin
135
+ second_part = second_half * cos + first_half * sin
136
+ return torch.cat((first_part.to(self.compute_dtype), second_part.to(self.compute_dtype)), dim=-1)
137
+
138
+
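`forward` and `apply_rope` compute the same rotation; the latter just takes precomputed sin/cos, which is how SelfAttention.forward uses it. A small parity check with invented shapes:

import torch

rope = RotaryEmbedding(embedding_dims=8)
x = torch.randn(2, 3, 4, 8)                         # (B, T, heads, head_dim)
pos = torch.arange(3).unsqueeze(0).expand(2, 3)     # (B, T)
sinusoid = pos.unsqueeze(-1).unsqueeze(-1) / rope.timescale
out_a = rope(x, pos)
out_b = rope.apply_rope(x, torch.sin(sinusoid), torch.cos(sinusoid))
assert torch.allclose(out_a, out_b)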
139
+ def custom_scaled_dot_product_attention(
140
+ query: torch.Tensor,
141
+ key: torch.Tensor,
142
+ value: torch.Tensor,
143
+ attn_mask: torch.Tensor | None = None,
144
+ scale: float = 1.0,
145
+ is_causal: bool = False,
146
+ num_gqa_groups: int = 1,
147
+ ) -> torch.Tensor:
148
+ """
149
+ Custom scaled dot-product attention with GQA support for MPS compatibility.
150
+
151
+ Args:
152
+ query: (B, N_q, T, H) - Query tensor, N_q = num_query_heads
153
+ key: (B, N_kv, S, H) - Key tensor, N_kv = num_kv_heads
154
+ value: (B, N_kv, S, H) - Value tensor
155
+ attn_mask: (B, 1, T, S) - Attention mask, optional
156
+ scale: Scaling factor for attention scores
157
+ is_causal: If True, apply causal masking
158
+ num_gqa_groups: Number of query groups per KV head (N_q / N_kv)
159
+
160
+ Returns:
161
+ output: (B, N_q, T, H) - Attention output
162
+ """
163
+ B, N_q, T, H = query.shape
164
+ _, N_kv, S, _ = key.shape
165
+
166
+ # For GQA, repeat key and value tensors to match query heads
167
+ if num_gqa_groups > 1:
168
+ key = key.repeat_interleave(num_gqa_groups, dim=1) # (B, N_q, S, H)
169
+ value = value.repeat_interleave(num_gqa_groups, dim=1) # (B, N_q, S, H)
170
+
171
+ # Compute attention scores: (B, N_q, T, H) @ (B, N_q, H, S) -> (B, N_q, T, S)
172
+ scores = torch.matmul(query, key.transpose(-1, -2)) * scale
173
+
174
+ # Apply causal mask if needed
175
+ if is_causal:
176
+ causal_mask = torch.tril(torch.ones(T, S, dtype=torch.bool, device=query.device))
177
+ scores = scores.masked_fill(~causal_mask, float("-inf"))
178
+
179
+ # Apply attention mask if provided
180
+ if attn_mask is not None:
181
+ scores = scores.masked_fill(~attn_mask, float("-inf"))
182
+
183
+ # Softmax over the last dimension (S)
184
+ attn_weights = F.softmax(scores, dim=-1)
185
+
186
+ # Compute output: (B, N_q, T, S) @ (B, N_q, S, H) -> (B, N_q, T, H)
187
+ output = torch.matmul(attn_weights, value)
188
+
189
+ return output
190
+
191
+
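An illustrative parity check (not part of this commit) of the MPS fallback against the built-in kernel, expanding the KV heads by hand instead of relying on the enable_gqa flag:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, N_q, N_kv, T, S, H = 2, 8, 2, 4, 6, 16
q = torch.randn(B, N_q, T, H)
k = torch.randn(B, N_kv, S, H)
v = torch.randn(B, N_kv, S, H)
groups = N_q // N_kv

out_custom = custom_scaled_dot_product_attention(q, k, v, scale=1.0, num_gqa_groups=groups)
out_ref = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(groups, dim=1),             # expand KV heads to match the query heads
    v.repeat_interleave(groups, dim=1),
    scale=1.0,
)
assert torch.allclose(out_custom, out_ref, atol=1e-5)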
192
+ class CrossAttention(nn.Module):
193
+ """Cross-Attention using DenseGeneral."""
194
+
195
+ def __init__(
196
+ self,
197
+ config: EncoderConfig | DecoderConfig,
198
+ q_embed_dim: int,
199
+ kv_embed_dim: int,
200
+ num_query_heads: int,
201
+ num_kv_heads: int,
202
+ head_dim: int,
203
+ compute_dtype: torch.dtype,
204
+ out_embed_dim: int | None = None,
205
+ ):
206
+ super().__init__()
207
+ self.num_query_heads = num_query_heads
208
+ self.num_kv_heads = num_kv_heads
209
+ self.head_dim = head_dim
210
+ self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
211
+ self.projected_query_dim = num_query_heads * head_dim
212
+ if num_query_heads % num_kv_heads != 0:
213
+ raise ValueError(f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
214
+ self.num_gqa_groups = num_query_heads // num_kv_heads
215
+
216
+ # --- Projection Layers using DenseGeneral ---
217
+ self.q_proj = DenseGeneral(
218
+ in_shapes=(q_embed_dim,),
219
+ out_features=(num_query_heads, head_dim),
220
+ axis=(-1,),
221
+ weight_dtype=compute_dtype,
222
+ )
223
+ self.k_proj = DenseGeneral(
224
+ in_shapes=(kv_embed_dim,),
225
+ out_features=(num_kv_heads, head_dim),
226
+ axis=(-1,),
227
+ weight_dtype=compute_dtype,
228
+ )
229
+ self.v_proj = DenseGeneral(
230
+ in_shapes=(kv_embed_dim,),
231
+ out_features=(num_kv_heads, head_dim),
232
+ axis=(-1,),
233
+ weight_dtype=compute_dtype,
234
+ )
235
+ self.o_proj = DenseGeneral(
236
+ in_shapes=(num_query_heads, head_dim),
237
+ out_features=(self.output_dim,),
238
+ axis=(-2, -1),
239
+ weight_dtype=compute_dtype,
240
+ )
241
+
242
+ # --- Rotary Embedding ---
243
+ self.rotary_emb = RotaryEmbedding(
244
+ embedding_dims=self.head_dim,
245
+ max_timescale=config.rope_theta,
246
+ dtype=compute_dtype,
247
+ )
248
+
249
+ def forward(
250
+ self,
251
+ Xq: torch.Tensor, # (B, T, D) T = 1 in AR generation
252
+ q_positions: torch.Tensor, # (B, T)
253
+ kv_positions: torch.Tensor | None = None, # (B, S)
254
+ attn_mask: torch.Tensor | None = None, # None in Decoder Self Attention, Valid mask in Others
255
+ cache: KVCache | None = None, # None in Encoder, KVCache in Decoder
256
+ is_causal: bool = False,
257
+ ) -> torch.Tensor:
258
+ """
259
+ Performs attention calculation with optional KV caching.
260
+
261
+ Args:
262
+ Xq: Query tensor (B, T, D). T=1 during single-step decoding.
263
+ Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn.
264
+ q_positions: Positions for queries (B, T).
265
+ kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
266
+ attn_mask: Attention mask.
267
+ cache: KVCache.
268
+
269
+ Returns:
270
+ A tuple containing:
271
+ - output: The attention output tensor (B, T, output_dim).
272
+ - present_kv: The K/V state to be cached for the next step ((B, N, S_new, H), (B, N, S_new, H)). For self-attn, S_new = S_past + S. For cross-attn, S_new = S_kv.
273
+ """
274
+ if kv_positions is None:
275
+ kv_positions = q_positions
276
+ original_dtype = Xq.dtype
277
+
278
+ Xq_BxTxNxH = self.q_proj(Xq)
279
+ Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)
280
+
281
+ attn_k: torch.Tensor | None = cache.k if cache is not None else None
282
+ attn_v: torch.Tensor | None = cache.v if cache is not None else None
283
+
284
+ # Use custom attention for MPS backend, otherwise use optimized PyTorch function
285
+ is_mps = Xq.device.type == "mps" and torch.backends.mps.is_available()
286
+ if is_mps:
287
+ attn_output = custom_scaled_dot_product_attention(
288
+ query=Xq_BxNxTxH,
289
+ key=attn_k,
290
+ value=attn_v,
291
+ attn_mask=attn_mask if not is_causal else None,
292
+ scale=1.0,
293
+ is_causal=is_causal,
294
+ num_gqa_groups=self.num_gqa_groups,
295
+ )
296
+ else:
297
+ attn_output = F.scaled_dot_product_attention(
298
+ Xq_BxNxTxH,
299
+ attn_k,
300
+ attn_v,
301
+ attn_mask=attn_mask if not is_causal else None,
302
+ scale=1.0,
303
+ enable_gqa=self.num_gqa_groups > 1,
304
+ is_causal=is_causal,
305
+ )
306
+
307
+ attn_output = attn_output.transpose(1, 2).contiguous() # (B, T, N, H)
308
+ output = self.o_proj(attn_output)
309
+
310
+ return output.to(original_dtype)
311
+
312
+
313
+ class FusedQKV(nn.Module):
314
+ def __init__(
315
+ self,
316
+ in_features: int,
317
+ out_features: int,
318
+ bias: bool = False,
319
+ num_q_heads: int = 1,
320
+ q_head_dim: int = 1,
321
+ num_kv_heads: int = 1,
322
+ kv_head_dim: int = 1,
323
+ ):
324
+ super().__init__()
325
+ self.num_q_heads = num_q_heads
326
+ self.q_head_dim = q_head_dim
327
+ self.num_kv_heads = num_kv_heads
328
+ self.kv_head_dim = kv_head_dim
329
+ self.q_output_dim = num_q_heads * q_head_dim
330
+ self.kv_output_dim = num_kv_heads * kv_head_dim
331
+ self.linear = nn.Linear(in_features, out_features, bias=bias)
332
+
333
+ def forward(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
334
+ x = self.linear(inputs)
335
+
336
+ q, k, v = x.split([self.q_output_dim, self.kv_output_dim, self.kv_output_dim], dim=-1)
337
+
338
+ q = q.reshape(q.shape[:-1] + (self.num_q_heads, self.q_head_dim))
339
+ k = k.reshape(k.shape[:-1] + (self.num_kv_heads, self.kv_head_dim))
340
+ v = v.reshape(v.shape[:-1] + (self.num_kv_heads, self.kv_head_dim))
341
+
342
+ return q, k, v
343
+
344
+
345
+ class SelfAttention(nn.Module):
346
+ """Attention using DenseGeneral."""
347
+
348
+ def __init__(
349
+ self,
350
+ config: EncoderConfig | DecoderConfig,
351
+ q_embed_dim: int,
352
+ kv_embed_dim: int,
353
+ num_query_heads: int,
354
+ num_kv_heads: int,
355
+ head_dim: int,
356
+ compute_dtype: torch.dtype,
357
+ out_embed_dim: int | None = None,
358
+ ):
359
+ super().__init__()
360
+ self.num_query_heads = num_query_heads
361
+ self.num_kv_heads = num_kv_heads
362
+ self.head_dim = head_dim
363
+ self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
364
+ self.projected_query_dim = num_query_heads * head_dim
365
+ if num_query_heads % num_kv_heads != 0:
366
+ raise ValueError(f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
367
+ self.num_gqa_groups = num_query_heads // num_kv_heads
368
+ self.kv_embed_dim = kv_embed_dim
369
+ self.q_embed_dim = q_embed_dim
370
+
371
+ # --- Projection Layers using DenseGeneral ---
372
+ self.q_proj = DenseGeneral(
373
+ in_shapes=(q_embed_dim,),
374
+ out_features=(num_query_heads, head_dim),
375
+ axis=(-1,),
376
+ weight_dtype=compute_dtype,
377
+ )
378
+ self.k_proj = DenseGeneral(
379
+ in_shapes=(kv_embed_dim,),
380
+ out_features=(num_kv_heads, head_dim),
381
+ axis=(-1,),
382
+ weight_dtype=compute_dtype,
383
+ )
384
+ self.v_proj = DenseGeneral(
385
+ in_shapes=(kv_embed_dim,),
386
+ out_features=(num_kv_heads, head_dim),
387
+ axis=(-1,),
388
+ weight_dtype=compute_dtype,
389
+ )
390
+ self.o_proj = DenseGeneral(
391
+ in_shapes=(num_query_heads, head_dim),
392
+ out_features=(self.output_dim,),
393
+ axis=(-2, -1),
394
+ weight_dtype=compute_dtype,
395
+ )
396
+
397
+ # --- Rotary Embedding ---
398
+ self.rotary_emb = RotaryEmbedding(
399
+ embedding_dims=self.head_dim,
400
+ max_timescale=config.rope_theta,
401
+ dtype=compute_dtype,
402
+ )
403
+
404
+ self.is_fused_qkv = False
405
+
406
+ def get_linear_weight(self, dense: DenseGeneral):
407
+ W_dg = dense.weight.data
408
+
409
+ out_features = 1
410
+ input_features = 1
411
+ for dim in dense.out_features:
412
+ out_features *= dim
413
+ for dim in dense.in_shapes:
414
+ input_features *= dim
415
+
416
+ W_dg_reshaped_for_linear_T = W_dg.reshape(input_features, out_features)
417
+ linear_weight = W_dg_reshaped_for_linear_T.transpose(0, 1).contiguous()
418
+ return linear_weight
419
+
420
+ def patch_fused_qkv(self):
421
+ q_proj_weight = self.get_linear_weight(self.q_proj)
422
+ k_proj_weight = self.get_linear_weight(self.k_proj)
423
+ v_proj_weight = self.get_linear_weight(self.v_proj)
424
+
425
+ self.qkv = FusedQKV(
426
+ self.kv_embed_dim,
427
+ (self.num_query_heads * self.head_dim + 2 * (self.num_kv_heads * self.head_dim)),
428
+ bias=False,
429
+ num_q_heads=self.num_query_heads,
430
+ q_head_dim=self.head_dim,
431
+ num_kv_heads=self.num_kv_heads,
432
+ kv_head_dim=self.head_dim,
433
+ )
434
+ self.qkv.linear.weight.data = torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=0)
435
+
436
+ # print(f"qkv.weight.shape: {self.qkv.linear.weight.shape}")
437
+ self.is_fused_qkv = True
438
+
439
+ def forward(
440
+ self,
441
+ X: torch.Tensor, # (B, T, D) T = 1 in AR generation
442
+ q_positions: torch.Tensor, # (B, T)
443
+ kv_positions: torch.Tensor | None = None, # (B, S)
444
+ attn_mask: torch.Tensor | None = None, # None in Decoder Self Attention, Valid mask in Others
445
+ cache: KVCache | None = None, # None in Encoder, KVCache in Decoder
446
+ prefill: bool = False,
447
+ is_causal: bool = False,
448
+ current_idx: torch.Tensor | None = None,
449
+ ) -> torch.Tensor:
450
+ """
451
+ Performs attention calculation with optional KV caching.
452
+ Args:
453
+ Xq: Query tensor (B, T, D). T=1 during single-step decoding.
454
+ Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn.
455
+ q_positions: Positions for queries (B, T).
456
+ kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
457
+ attn_mask: Attention mask.
458
+ cache: KVCache.
459
+ prefill: If True, use prefill mode.
460
+ Returns:
461
+ A tuple containing:
462
+ - output: The attention output tensor (B, T, output_dim).
463
+ - present_kv: The K/V state to be cached for the next step ((B, N, S_new, H), (B, N, S_new, H)). For self-attn, S_new = S_past + S. For cross-attn, S_new = S_kv.
464
+ """
465
+ if kv_positions is None:
466
+ kv_positions = q_positions
467
+
468
+ original_dtype = X.dtype
469
+
470
+ if self.is_fused_qkv:
471
+ Xq_BxTxNxH, Xk_BxSxKxH, Xv_BxSxKxH = self.qkv(X)
472
+ else:
473
+ Xq_BxTxNxH = self.q_proj(X)
474
+ Xk_BxSxKxH = self.k_proj(X)
475
+ Xv_BxSxKxH = self.v_proj(X)
476
+
477
+ position = q_positions.unsqueeze(-1).unsqueeze(-1)
478
+ sinusoid_inp = position / self.rotary_emb.timescale
479
+ sin = torch.sin(sinusoid_inp)
480
+ cos = torch.cos(sinusoid_inp)
481
+
482
+ Xq_BxTxNxH = self.rotary_emb.apply_rope(Xq_BxTxNxH, sin, cos)
483
+ Xk_BxSxKxH = self.rotary_emb.apply_rope(Xk_BxSxKxH, sin, cos)
484
+
485
+ Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)
486
+
487
+ attn_k: torch.Tensor | None = cache.k if cache is not None else None
488
+ attn_v: torch.Tensor | None = cache.v if cache is not None else None
489
+
490
+ Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2) # (B, K, S, H)
491
+ Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2) # (B, K, S, H)
492
+
493
+ if cache is None:
494
+ attn_k = Xk_BxKxSxH
495
+ attn_v = Xv_BxKxSxH
496
+ elif prefill:
497
+ attn_k, attn_v = Xk_BxKxSxH, Xv_BxKxSxH
498
+ cache.prefill(attn_k, attn_v)
499
+ else:
500
+ attn_k, attn_v = cache.update(Xk_BxKxSxH, Xv_BxKxSxH, current_idx)
501
+
502
+ # Use custom attention for MPS backend, otherwise use optimized PyTorch function
503
+ is_mps = Xv_BxSxKxH.device.type == "mps" and torch.backends.mps.is_available()
504
+ if is_mps:
505
+ attn_output = custom_scaled_dot_product_attention(
506
+ query=Xq_BxNxTxH,
507
+ key=attn_k,
508
+ value=attn_v,
509
+ attn_mask=attn_mask if not is_causal else None,
510
+ scale=1.0,
511
+ is_causal=is_causal,
512
+ num_gqa_groups=self.num_gqa_groups,
513
+ )
514
+ else:
515
+ attn_output = F.scaled_dot_product_attention(
516
+ Xq_BxNxTxH,
517
+ attn_k,
518
+ attn_v,
519
+ attn_mask=attn_mask if not is_causal else None,
520
+ scale=1.0,
521
+ enable_gqa=self.num_gqa_groups > 1,
522
+ is_causal=is_causal,
523
+ )
524
+
525
+ attn_output = attn_output.transpose(1, 2).contiguous() # (B, T, N, H)
526
+ output = self.o_proj(attn_output)
527
+
528
+ return output.to(original_dtype)
529
+
530
+
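A sketch (dimensions invented) of the optional QKV fusion: `patch_fused_qkv` folds the three DenseGeneral kernels into one nn.Linear, and the fused projections should reproduce the separate ones. The SimpleNamespace config is a stand-in, since only `rope_theta` is read here:

import types
import torch

torch.manual_seed(0)
cfg = types.SimpleNamespace(rope_theta=10000.0)     # stand-in config; SelfAttention only reads rope_theta
attn = SelfAttention(cfg, q_embed_dim=32, kv_embed_dim=32, num_query_heads=8,
                     num_kv_heads=2, head_dim=16, compute_dtype=torch.float32)
for p in attn.parameters():
    torch.nn.init.normal_(p, std=0.02)
x = torch.randn(2, 5, 32)
q_ref, k_ref, v_ref = attn.q_proj(x), attn.k_proj(x), attn.v_proj(x)
attn.patch_fused_qkv()                              # concatenates the q/k/v kernels into one linear layer
q, k, v = attn.qkv(x)
assert torch.allclose(q, q_ref, atol=1e-5)
assert torch.allclose(k, k_ref, atol=1e-5) and torch.allclose(v, v_ref, atol=1e-5)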
531
+ class EncoderLayer(nn.Module):
532
+ """Transformer Encoder Layer using DenseGeneral."""
533
+
534
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
535
+ super().__init__()
536
+ self.config = config
537
+ enc_config = config.encoder_config
538
+ embed_dim = enc_config.hidden_size
539
+ self.compute_dtype = compute_dtype
540
+
541
+ self.pre_sa_norm = RMSNorm(
542
+ embed_dim,
543
+ eps=enc_config.norm_eps,
544
+ dtype=torch.float32,
545
+ )
546
+ self.self_attention = SelfAttention(
547
+ enc_config,
548
+ q_embed_dim=embed_dim,
549
+ kv_embed_dim=embed_dim,
550
+ num_query_heads=enc_config.num_attention_heads,
551
+ num_kv_heads=enc_config.num_key_value_heads,
552
+ head_dim=enc_config.head_dim,
553
+ compute_dtype=compute_dtype,
554
+ out_embed_dim=embed_dim,
555
+ )
556
+ self.post_sa_norm = RMSNorm(
557
+ embed_dim,
558
+ eps=enc_config.norm_eps,
559
+ dtype=torch.float32,
560
+ )
561
+ self.mlp = MlpBlock(
562
+ embed_dim=embed_dim,
563
+ intermediate_dim=enc_config.intermediate_size,
564
+ compute_dtype=compute_dtype,
565
+ )
566
+
567
+ def forward(
568
+ self,
569
+ x: torch.Tensor,
570
+ state: EncoderInferenceState,
571
+ ) -> torch.Tensor:
572
+ residual = x
573
+ x_norm = self.pre_sa_norm(x).to(self.compute_dtype)
574
+
575
+ sa_out = self.self_attention(
576
+ X=x_norm,
577
+ q_positions=state.positions,
578
+ kv_positions=state.positions,
579
+ attn_mask=state.attn_mask,
580
+ )
581
+ x = residual + sa_out
582
+
583
+ residual = x
584
+ x_norm = self.post_sa_norm(x).to(self.compute_dtype)
585
+ mlp_out = self.mlp(x_norm)
586
+ x = residual + mlp_out
587
+
588
+ return x
589
+
590
+
591
+ class Encoder(nn.Module):
592
+ """Transformer Encoder Stack using DenseGeneral."""
593
+
594
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
595
+ super().__init__()
596
+ self.config = config
597
+ enc_config = config.encoder_config
598
+ self.compute_dtype = compute_dtype
599
+
600
+ self.embedding = nn.Embedding(
601
+ enc_config.vocab_size,
602
+ enc_config.hidden_size,
603
+ dtype=compute_dtype,
604
+ )
605
+ self.layers = nn.ModuleList([EncoderLayer(config, compute_dtype) for _ in range(enc_config.num_hidden_layers)])
606
+ self.norm = RMSNorm(
607
+ enc_config.hidden_size,
608
+ eps=enc_config.norm_eps,
609
+ dtype=torch.float32,
610
+ )
611
+
612
+ def forward(
613
+ self,
614
+ x_ids: torch.Tensor,
615
+ state: EncoderInferenceState,
616
+ ) -> torch.Tensor:
617
+ x = self.embedding(x_ids)
618
+
619
+ for layer in self.layers:
620
+ x = layer(x, state)
621
+
622
+ x = self.norm(x).to(self.compute_dtype)
623
+ return x
624
+
625
+
626
+ class DecoderLayer(nn.Module):
627
+ """Transformer Decoder Layer using DenseGeneral."""
628
+
629
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
630
+ super().__init__()
631
+ self.config = config
632
+ dec_config = config.decoder_config
633
+ enc_config = config.encoder_config
634
+ dec_embed_dim = dec_config.hidden_size
635
+ enc_embed_dim = enc_config.hidden_size
636
+ self.compute_dtype = compute_dtype
637
+
638
+ # Norms
639
+ self.pre_sa_norm = RMSNorm(
640
+ dec_embed_dim,
641
+ eps=dec_config.norm_eps,
642
+ dtype=torch.float32,
643
+ )
644
+ self.pre_ca_norm = RMSNorm(
645
+ dec_embed_dim,
646
+ eps=dec_config.norm_eps,
647
+ dtype=torch.float32,
648
+ )
649
+ self.pre_mlp_norm = RMSNorm(
650
+ dec_embed_dim,
651
+ eps=dec_config.norm_eps,
652
+ dtype=torch.float32,
653
+ )
654
+
655
+ # Self-Attention (GQA) with Causal Masking
656
+ self.self_attention = SelfAttention(
657
+ dec_config,
658
+ q_embed_dim=dec_embed_dim,
659
+ kv_embed_dim=dec_embed_dim,
660
+ num_query_heads=dec_config.num_attention_heads,
661
+ num_kv_heads=dec_config.num_key_value_heads,
662
+ head_dim=dec_config.head_dim,
663
+ compute_dtype=compute_dtype,
664
+ out_embed_dim=dec_embed_dim,
665
+ )
666
+ # Cross-Attention (MHA)
667
+ self.cross_attention = CrossAttention(
668
+ dec_config,
669
+ q_embed_dim=dec_embed_dim,
670
+ kv_embed_dim=enc_embed_dim, # Note kv_embed_dim
671
+ num_query_heads=dec_config.cross_num_attention_heads,
672
+ num_kv_heads=dec_config.cross_num_key_value_heads,
673
+ head_dim=dec_config.cross_head_dim,
674
+ compute_dtype=compute_dtype,
675
+ out_embed_dim=dec_embed_dim,
676
+ )
677
+ # MLP
678
+ self.mlp = MlpBlock(
679
+ embed_dim=dec_embed_dim,
680
+ intermediate_dim=dec_config.intermediate_size,
681
+ compute_dtype=compute_dtype,
682
+ )
683
+
684
+ def forward(
685
+ self,
686
+ x: torch.Tensor,
687
+ state: DecoderInferenceState,
688
+ self_attn_cache: KVCache | None = None,
689
+ cross_attn_cache: KVCache | None = None,
690
+ prefill: bool = False,
691
+ current_idx: int = 0,
692
+ ) -> torch.Tensor:
693
+ residual = x
694
+ x_norm = self.pre_sa_norm(x).to(self.compute_dtype)
695
+
696
+ self_attn_mask = state.casual_attn_mask[None, None, current_idx]
697
+
698
+ sa_out = self.self_attention(
699
+ X=x_norm, # (2, 1, D)
700
+ q_positions=state.dec_positions, # (2, 1)
701
+ kv_positions=state.dec_positions, # (2, 1)
702
+ attn_mask=self_attn_mask,
703
+ cache=self_attn_cache,
704
+ prefill=prefill,
705
+ is_causal=prefill,
706
+ current_idx=current_idx,
707
+ )
708
+
709
+ x = residual + sa_out
710
+
711
+ residual = x
712
+ x_norm = self.pre_ca_norm(x).to(self.compute_dtype)
713
+ ca_out = self.cross_attention(
714
+ Xq=x_norm,
715
+ q_positions=state.dec_positions,
716
+ kv_positions=state.enc_positions,
717
+ attn_mask=state.cross_attn_mask,
718
+ cache=cross_attn_cache,
719
+ )
720
+ x = residual + ca_out
721
+
722
+ residual = x
723
+ x_norm = self.pre_mlp_norm(x).to(self.compute_dtype)
724
+ mlp_out = self.mlp(x_norm)
725
+ x = residual + mlp_out
726
+
727
+ return x
728
+
729
+
730
+ class Decoder(nn.Module):
731
+ """Transformer Decoder Stack using DenseGeneral."""
732
+
733
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
734
+ super().__init__()
735
+ self.config = config
736
+ dec_config = config.decoder_config
737
+ self.num_channels = dec_config.num_channels
738
+ self.num_layers = dec_config.num_hidden_layers
739
+
740
+ self.embeddings = nn.ModuleList(
741
+ [
742
+ nn.Embedding(dec_config.vocab_size, dec_config.hidden_size, dtype=compute_dtype)
743
+ for _ in range(self.num_channels)
744
+ ]
745
+ )
746
+ self.layers = nn.ModuleList(
747
+ [DecoderLayer(config=config, compute_dtype=compute_dtype) for _ in range(self.num_layers)]
748
+ )
749
+
750
+ self.norm = RMSNorm(
751
+ dec_config.hidden_size,
752
+ eps=dec_config.norm_eps,
753
+ dtype=torch.float32,
754
+ )
755
+
756
+ self.logits_dense = DenseGeneral(
757
+ in_shapes=(dec_config.hidden_size,),
758
+ out_features=(self.num_channels, dec_config.vocab_size),
759
+ axis=(-1,),
760
+ weight_dtype=compute_dtype,
761
+ )
762
+
763
+ def precompute_cross_attn_cache(
764
+ self,
765
+ enc_out: torch.Tensor, # (B, S, E)
766
+ ) -> list[KVCache]:
767
+ """
768
+ Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
769
+ """
770
+ per_layer_kv_cache: list[KVCache] = []
771
+
772
+ for layer in self.layers:
773
+ cross_attn_module = layer.cross_attention
774
+ k_proj = cross_attn_module.k_proj(enc_out)
775
+ v_proj = cross_attn_module.v_proj(enc_out)
776
+
777
+ k = k_proj.transpose(1, 2)
778
+ v = v_proj.transpose(1, 2)
779
+
780
+ per_layer_kv_cache.append(KVCache.from_kv(k, v))
781
+
782
+ return per_layer_kv_cache
783
+
784
+ def decode_step(
785
+ self,
786
+ tgt_ids_Bx1xC: torch.Tensor, # [B, 1, C]
787
+ state: DecoderInferenceState,
788
+ current_idx: int,
789
+ ) -> torch.Tensor:
790
+ """
791
+ Performs a single decoding step, managing KV caches layer by layer.
792
+ Returns:
+ logits_Bx1xCxV: The output logits for the current step (B, 1, C, V), cast to float32.
+ """
796
+
797
+ x = None
798
+ for i in range(self.num_channels):
799
+ channel_tokens = tgt_ids_Bx1xC[..., i]
800
+ channel_embed = self.embeddings[i](channel_tokens)
801
+ x = channel_embed if x is None else x + channel_embed
802
+
803
+ for i, layer in enumerate(self.layers):
804
+ self_cache = state.self_attn_cache[i]
805
+ cross_cache = state.cross_attn_cache[i]
806
+ x = layer(
807
+ x, # (2, 1, D)
808
+ state,
809
+ self_attn_cache=self_cache,
810
+ cross_attn_cache=cross_cache,
811
+ current_idx=current_idx,
812
+ )
813
+
814
+ x = self.norm(x)
815
+ logits_Bx1xCxV = self.logits_dense(x)
816
+
817
+ return logits_Bx1xCxV.to(torch.float32)
818
+
819
+ def forward(self, tgt_ids_BxTxC: torch.Tensor, state: DecoderInferenceState) -> torch.Tensor:
820
+ """
821
+ Forward pass for the Decoder stack, managing KV caches.
822
+ Args:
823
+ tgt_ids_BxTxC: Target token IDs (B, T, C).
824
+ encoder_out: Output from the encoder (B, S, E).
825
+ tgt_positions: Positions for target sequence (B, T).
826
+ src_positions: Positions for source sequence (B, S).
827
+ self_attn_mask: Mask for self-attention.
828
+ cross_attn_mask: Mask for cross-attention.
829
+ past_key_values: List containing the self-attention KV cache for each layer
830
+ from the previous decoding step. `len(past_key_values)` should
831
+ equal `num_layers`.
832
+ precomputed_cross_attn_kv: A single tuple containing the pre-computed K/V cache
833
+ derived from `encoder_out`. This is passed identically
834
+ to all layers.
835
+ Returns:
836
+ A tuple containing:
837
+ - logits: The final output logits (B, T, C * V), cast to float32.
838
+ - present_key_values: A list containing the updated self-attention KV cache
839
+ for each layer for the *current* decoding step.
840
+ """
841
+ _, _, num_channels_in = tgt_ids_BxTxC.shape
842
+ assert num_channels_in == self.num_channels, "Input channels mismatch"
843
+
844
+ # Embeddings
845
+ x = None
846
+ for i in range(self.num_channels):
847
+ channel_tokens = tgt_ids_BxTxC[..., i]
848
+ channel_embed = self.embeddings[i](channel_tokens)
849
+ x = channel_embed if x is None else x + channel_embed
850
+
851
+ for i, layer in enumerate(self.layers):
852
+ self_cache = state.self_attn_cache[i]
853
+ cross_cache = state.cross_attn_cache[i]
854
+ x = layer(
855
+ x,
856
+ state,
857
+ self_attn_cache=self_cache,
858
+ cross_attn_cache=cross_cache,
859
+ prefill=True,
860
+ )
861
+
862
+ # Final Norm
863
+ x = self.norm(x)
864
+ logits_BxTxCxV = self.logits_dense(x)
865
+
866
+ return logits_BxTxCxV.to(torch.float32)
867
+
868
+
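The decoder embeds its C codebook channels with separate tables and sums them into one hidden vector per frame. A toy illustration of that summation (dimensions invented):

import torch
import torch.nn as nn

vocab, hidden, C = 10, 8, 3
tables = nn.ModuleList([nn.Embedding(vocab, hidden) for _ in range(C)])
tokens_BxTxC = torch.randint(0, vocab, (2, 4, C))
x = None
for c in range(C):
    emb = tables[c](tokens_BxTxC[..., c])           # embed one channel at a time
    x = emb if x is None else x + emb               # sum across channels
assert x.shape == (2, 4, hidden)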
869
+ class DiaModel(
870
+ nn.Module,
871
+ PyTorchModelHubMixin,
872
+ repo_url="https://github.com/nari-labs/dia",
873
+ pipeline_tag="text-to-speech",
874
+ license="apache-2.0",
875
+ coders={
876
+ DiaConfig: (
877
+ lambda x: x.model_dump(),
878
+ lambda data: DiaConfig.model_validate(data),
879
+ ),
880
+ },
881
+ ):
882
+ """PyTorch Dia Model using DenseGeneral."""
883
+
884
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
885
+ super().__init__()
886
+ self.config = config
887
+ self.encoder = Encoder(config, compute_dtype)
888
+ self.decoder = Decoder(config, compute_dtype)
dia/model.py ADDED
@@ -0,0 +1,802 @@
1
+ import time
2
+ from enum import Enum
3
+ from typing import Callable
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torchaudio
9
+
10
+ from .audio import apply_audio_delay, build_delay_indices, build_revert_indices, revert_audio_delay
11
+ from .config import DiaConfig
12
+ from .layers import DiaModel
13
+ from .state import DecoderInferenceState, DecoderOutput, EncoderInferenceState
14
+
15
+
16
+ DEFAULT_SAMPLE_RATE = 44100
17
+ SAMPLE_RATE_RATIO = 512
18
+
19
+
20
+ def _get_default_device():
21
+ if torch.cuda.is_available():
22
+ return torch.device("cuda")
23
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
24
+ return torch.device("mps")
25
+ return torch.device("cpu")
26
+
27
+
28
+ def _sample_next_token(
29
+ logits_BCxV: torch.Tensor,
30
+ temperature: float,
31
+ top_p: float,
32
+ top_k: int | None,
33
+ audio_eos_value: int,
34
+ ) -> torch.Tensor:
35
+ if temperature == 0.0:
36
+ return torch.argmax(logits_BCxV, dim=-1)
37
+
38
+ logits_BCxV = logits_BCxV / temperature
39
+
40
+ if audio_eos_value is not None and audio_eos_value >= 0:
41
+ top_logit_indices_BC = torch.argmax(logits_BCxV, dim=-1)
42
+ eos_not_highest_mask_BC = top_logit_indices_BC != audio_eos_value
43
+ mask_eos_unless_highest_BCxV = torch.zeros_like(logits_BCxV, dtype=torch.bool)
44
+ mask_eos_unless_highest_BCxV[eos_not_highest_mask_BC, audio_eos_value] = True
45
+ logits_BCxV = logits_BCxV.masked_fill(mask_eos_unless_highest_BCxV, -torch.inf)
46
+ eos_highest_mask_BC = top_logit_indices_BC == audio_eos_value
47
+ mask_eos_highest_BCxV = torch.zeros_like(logits_BCxV, dtype=torch.bool)
48
+ mask_eos_highest_BCxV[eos_highest_mask_BC, :audio_eos_value] = True
49
+ logits_BCxV = logits_BCxV.masked_fill(mask_eos_highest_BCxV, -torch.inf)
50
+
51
+ if top_k is not None:
52
+ _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=top_k, dim=-1)
53
+ mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
54
+ mask = mask.scatter(dim=-1, index=top_k_indices_BCxV, value=False)
55
+ logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)
56
+
57
+ if top_p < 1.0:
58
+ probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
59
+ sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(probs_BCxV, dim=-1, descending=True)
60
+ cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)
61
+
62
+ sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
63
+ sorted_indices_to_remove_BCxV = torch.roll(sorted_indices_to_remove_BCxV, shifts=1, dims=-1)
64
+ sorted_indices_to_remove_BCxV[..., 0] = torch.zeros_like(sorted_indices_to_remove_BCxV[..., 0])
65
+
66
+ indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
67
+ indices_to_remove_BCxV = indices_to_remove_BCxV.scatter(
68
+ dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV
69
+ )
70
+ logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)
71
+
72
+ final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
73
+
74
+ sampled_indices_BC = torch.multinomial(final_probs_BCxV, num_samples=1)
75
+ sampled_indices_C = sampled_indices_BC.squeeze(-1)
76
+ return sampled_indices_C
77
+
78
+
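A toy run (invented logits, vocabulary size 4, EOS id 3) showing the two paths of `_sample_next_token`: temperature 0 short-circuits to greedy argmax, while temperature > 0 masks EOS wherever it is not already the top logit and applies top-k/top-p before multinomial sampling:

import torch

torch.manual_seed(0)
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0, 0.0]])       # (B*C, V)
greedy = _sample_next_token(logits, temperature=0.0, top_p=0.95, top_k=None, audio_eos_value=3)
assert greedy.tolist() == [0, 2]
sampled = _sample_next_token(logits, temperature=1.2, top_p=0.9, top_k=2, audio_eos_value=3)
assert sampled.shape == (2,) and (sampled != 3).all()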
79
+ class ComputeDtype(str, Enum):
80
+ FLOAT32 = "float32"
81
+ FLOAT16 = "float16"
82
+ BFLOAT16 = "bfloat16"
83
+
84
+ def to_dtype(self) -> torch.dtype:
85
+ if self == ComputeDtype.FLOAT32:
86
+ return torch.float32
87
+ elif self == ComputeDtype.FLOAT16:
88
+ return torch.float16
89
+ elif self == ComputeDtype.BFLOAT16:
90
+ return torch.bfloat16
91
+ else:
92
+ raise ValueError(f"Unsupported compute dtype: {self}")
93
+
94
+
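The enum simply maps the strings accepted by the public API onto torch dtypes:

import torch

assert ComputeDtype("bfloat16").to_dtype() == torch.bfloat16
assert ComputeDtype.FLOAT16.to_dtype() == torch.float16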
95
+ class Dia:
96
+ def __init__(
97
+ self,
98
+ config: DiaConfig,
99
+ compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
100
+ device: torch.device | None = None,
101
+ load_dac: bool = True,
102
+ ):
103
+ """Initializes the Dia model.
104
+
105
+ Args:
106
+ config: The configuration object for the model.
107
+ compute_dtype: The computation dtype to use.
108
+ device: The device to load the model onto. If None, will automatically select the best available device.
109
+ load_dac: Whether to load the DAC model.
110
+
111
+ Raises:
112
+ RuntimeError: If there is an error loading the DAC model.
113
+ """
114
+ super().__init__()
115
+ self.config = config
116
+ self.device = device if device is not None else _get_default_device()
117
+ if isinstance(compute_dtype, str):
118
+ compute_dtype = ComputeDtype(compute_dtype)
119
+ self.compute_dtype = compute_dtype.to_dtype()
120
+ self.model: DiaModel = DiaModel(config, self.compute_dtype)
121
+ self.dac_model = None
122
+ self._compiled_step = None
123
+ self.load_dac = load_dac
124
+
125
+ if not self.load_dac:
126
+ print("Warning: DAC model will not be loaded. This is not recommended.")
127
+
128
+ if torch.cuda.is_available():
129
+ torch.backends.cuda.matmul.allow_tf32 = True
130
+
131
+ @classmethod
132
+ def from_local(
133
+ cls,
134
+ config_path: str,
135
+ checkpoint_path: str,
136
+ compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
137
+ device: torch.device | None = None,
138
+ load_dac: bool = True,
139
+ ) -> "Dia":
140
+ """Loads the Dia model from local configuration and checkpoint files.
141
+
142
+ Args:
143
+ config_path: Path to the configuration JSON file.
144
+ checkpoint_path: Path to the model checkpoint (.pth) file.
145
+ compute_dtype: The computation dtype to use.
146
+ device: The device to load the model onto. If None, will automatically select the best available device.
147
+ load_dac: Whether to load the DAC model.
148
+
149
+ Returns:
150
+ An instance of the Dia model loaded with weights and set to eval mode.
151
+
152
+ Raises:
153
+ FileNotFoundError: If the config or checkpoint file is not found.
154
+ RuntimeError: If there is an error loading the checkpoint.
155
+ """
156
+ config = DiaConfig.load(config_path)
157
+ if config is None:
158
+ raise FileNotFoundError(f"Config file not found at {config_path}")
159
+
160
+ dia = cls(config, compute_dtype, device, load_dac)
161
+
162
+ try:
163
+ state_dict = torch.load(checkpoint_path, map_location=dia.device)
164
+ dia.model.load_state_dict(state_dict)
165
+ except FileNotFoundError:
166
+ raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
167
+ except Exception as e:
168
+ raise RuntimeError(f"Error loading checkpoint from {checkpoint_path}") from e
169
+
170
+ dia.model.to(dia.device)
171
+ dia.model.eval()
172
+ if load_dac:
173
+ dia._load_dac_model()
174
+ return dia
175
+
176
+ @classmethod
177
+ def from_pretrained(
178
+ cls,
179
+ model_name: str = "nari-labs/Dia-1.6B-0626",
180
+ compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
181
+ device: torch.device | None = None,
182
+ load_dac: bool = True,
183
+ ) -> "Dia":
184
+ """Loads the Dia model from a Hugging Face Hub repository.
185
+
186
+ Downloads the configuration and checkpoint files from the specified
187
+ repository ID and then loads the model.
188
+
189
+ Args:
190
+ model_name: The Hugging Face Hub repository ID (e.g., "nari-labs/Dia-1.6B-0626").
191
+ compute_dtype: The computation dtype to use.
192
+ device: The device to load the model onto. If None, will automatically select the best available device.
193
+ load_dac: Whether to load the DAC model.
194
+
195
+ Returns:
196
+ An instance of the Dia model loaded with weights and set to eval mode.
197
+
198
+ Raises:
199
+ FileNotFoundError: If config or checkpoint download/loading fails.
200
+ RuntimeError: If there is an error loading the checkpoint.
201
+ """
202
+ if isinstance(compute_dtype, str):
203
+ compute_dtype = ComputeDtype(compute_dtype)
204
+
205
+ # Load model directly using DiaModel's from_pretrained which handles HF download
206
+ try:
207
+ loaded_model = DiaModel.from_pretrained(model_name, compute_dtype=compute_dtype.to_dtype())
208
+ except Exception as e:
209
+ raise RuntimeError(f"Error loading model from Hugging Face Hub ({model_name})") from e
210
+
211
+ config = loaded_model.config # Get config from the loaded model
212
+ dia = cls(config, compute_dtype, device, load_dac)
213
+
214
+ dia.model = loaded_model # Assign the already loaded model
215
+ dia.model.to(dia.device)
216
+ dia.model.eval()
217
+ if load_dac:
218
+ dia._load_dac_model()
219
+ return dia
220
+
221
+ def _load_dac_model(self):
222
+ """Loads the Descript Audio Codec (DAC) model.
223
+
224
+ Downloads the DAC model if necessary and loads it onto the specified device.
225
+ Sets the DAC model to evaluation mode.
226
+
227
+ Raises:
228
+ RuntimeError: If downloading or loading the DAC model fails.
229
+ """
230
+ import dac
231
+
232
+ try:
233
+ dac_model_path = dac.utils.download()
234
+ dac_model = dac.DAC.load(dac_model_path).to(self.device)
235
+ dac_model.eval() # Ensure DAC is in eval mode
236
+ except Exception as e:
237
+ raise RuntimeError("Failed to load DAC model") from e
238
+ self.dac_model = dac_model
239
+
240
+ def _encode_text(self, text: str) -> torch.Tensor:
241
+ """Encodes the input text string into a tensor of token IDs using byte-level encoding.
242
+
243
+ Special tokens [S1] and [S2] are replaced by their byte values. The resulting
244
+ sequence is truncated to the maximum configured text length.
245
+
246
+ Args:
247
+ text: The input text string.
248
+
249
+ Returns:
250
+ A tensor containing the encoded byte token IDs.
251
+ """
252
+ max_len = self.config.encoder_config.max_position_embeddings
253
+
254
+ byte_text = text.encode("utf-8")
255
+ # Replace special tokens with their byte values if needed by the specific tokenizer/config
256
+ # Assuming byte values 1 and 2 are correct placeholders based on original code
257
+ replaced_bytes = byte_text.replace(b"[S1]", b"\x01").replace(b"[S2]", b"\x02")
258
+ text_tokens = list(replaced_bytes)
259
+ return torch.tensor(
260
+ text_tokens[:max_len],
261
+ dtype=torch.long,
262
+ device=self.device,
263
+ )
264
+
265
+ def _pad_text_input(self, text_tokens: list[torch.Tensor]) -> torch.Tensor:
266
+ """Pads the text input to the maximum length."""
267
+ text_pad_value = 0
268
+ max_len = self.config.encoder_config.max_position_embeddings
269
+ batch_size = len(text_tokens)
270
+
271
+ src_tokens = torch.full(
272
+ (batch_size, 1, max_len),
273
+ fill_value=text_pad_value,
274
+ dtype=torch.long,
275
+ device=self.device,
276
+ )
277
+ for i in range(batch_size):
278
+ current_len = len(text_tokens[i])
279
+ src_tokens[i, 0, :current_len] = text_tokens[i]
280
+ return src_tokens
281
+
282
+ def _prepare_audio_prompt(self, audio_prompts: list[torch.Tensor | None]) -> tuple[torch.Tensor, list[int]]:
283
+ """Prepares the audio prompt tensor for the decoder.
284
+
285
+ Handles padding, adds the beginning-of-sequence (BOS) token, applies the
286
+ delay pattern, and determines the number of prefill steps for each item
287
+ in the batch.
288
+
289
+ Args:
290
+ audio_prompts: A list of audio prompt tensors (encoded DAC frames) or None.
291
+ Each tensor should have shape [T, C].
292
+
293
+ Returns:
294
+ A tuple containing:
295
+ - delayed_batch (torch.Tensor): The prepared audio prompt tensor with
296
+ delays applied, shape [B, T_max_padded, C].
297
+ - prefill_steps (list[int]): A list containing the number of valid
298
+ tokens (including BOS) for each prompt in the batch.
299
+ """
300
+ num_channels = self.config.decoder_config.num_channels
301
+ audio_bos_value = self.config.bos_token_id
302
+ delay_pattern = self.config.delay_pattern
303
+ max_delay_pattern = max(delay_pattern)
304
+ batch_size = len(audio_prompts)
305
+
306
+ max_len = max(p.shape[0] if p is not None else 0 for p in audio_prompts) + max_delay_pattern
307
+ prefill_steps = []
308
+
309
+ prefill = torch.full(
310
+ (batch_size, max_len, num_channels),
311
+ fill_value=-1,
312
+ dtype=torch.int,
313
+ device=self.device,
314
+ )
315
+
316
+ prefill[:, 0, :] = audio_bos_value
317
+
318
+ for i in range(batch_size):
319
+ prompt = audio_prompts[i]
320
+ if prompt is not None:
321
+ prompt = prompt.to(device=self.device, dtype=torch.int)
322
+ prefill[i, 1 : prompt.shape[0] + 1, :] = prompt
323
+ prefill_steps.append(prompt.shape[0] + 1)
324
+ else:
325
+ prefill_steps.append(1)
326
+
327
+ delay_precomp = build_delay_indices(
328
+ B=batch_size,
329
+ T=max_len,
330
+ C=num_channels,
331
+ delay_pattern=delay_pattern,
332
+ )
333
+
334
+ delayed_batch = apply_audio_delay(
335
+ audio_BxTxC=prefill,
336
+ pad_value=-1,
337
+ bos_value=audio_bos_value,
338
+ precomp=delay_precomp,
339
+ )
340
+
341
+ return delayed_batch, prefill_steps
342
+
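The delay pattern itself is implemented by build_delay_indices / apply_audio_delay in dia/audio.py (imported at the top of this file). Conceptually, channel c is shifted right by delay_pattern[c] steps and padded with BOS at the front, so at step t channel c refers to frame t - delay[c]. A toy sketch of that idea, not the project's own implementation:

import torch

def toy_apply_delay(frames_TxC: torch.Tensor, delay: list[int], bos: int) -> torch.Tensor:
    T, C = frames_TxC.shape
    out = frames_TxC.new_full((T, C), bos)          # start with BOS everywhere
    for c, d in enumerate(delay):
        out[d:, c] = frames_TxC[: T - d, c]         # shift channel c right by its delay
    return out

frames = torch.arange(12).reshape(4, 3)             # 4 frames, 3 channels
print(toy_apply_delay(frames, delay=[0, 1, 2], bos=-2))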
343
+ def _prepare_generation(
344
+ self,
345
+ text: torch.Tensor,
346
+ audio_prompts: list[torch.Tensor | None],
347
+ max_tokens: int | None = None,
348
+ attn_fn: Callable = F.scaled_dot_product_attention,
349
+ ):
350
+ """Initializes the model state for generation.
351
+
352
+ Encodes the text input (conditional and unconditional), prepares the
353
+ encoder and decoder states (including KV caches and cross-attention),
354
+ prepares the audio prompt, and performs the initial decoder prefill steps
355
+ based on the audio prompts.
356
+
357
+ Args:
358
+ text: The padded text input tensor, shape [B, 1, T_text].
359
+ audio_prompts: A list of prepared audio prompt tensors or None.
360
+
361
+ Returns:
362
+ A tuple containing:
363
+ - dec_state (DecoderInferenceState): The initialized decoder state.
364
+ - dec_output (DecoderOutput): The initialized decoder output manager,
365
+ containing the prefilled audio tokens.
366
+ """
367
+ batch_size = text.shape[0]
368
+
369
+ enc_input_uncond = torch.zeros_like(text)
370
+ enc_input_cond = text
371
+ stacked_inputs = torch.stack([enc_input_uncond, enc_input_cond], dim=1)
372
+ enc_input = stacked_inputs.view(2 * batch_size, -1)
373
+
374
+ enc_state = EncoderInferenceState.new(self.config, enc_input_cond)
375
+ encoder_out = self.model.encoder(enc_input, enc_state)
376
+
377
+ dec_cross_attn_cache = self.model.decoder.precompute_cross_attn_cache(encoder_out)
378
+ dec_state = DecoderInferenceState.new(
379
+ self.config,
380
+ enc_state,
381
+ encoder_out,
382
+ dec_cross_attn_cache,
383
+ self.compute_dtype,
384
+ max_generation_length=max_tokens,
385
+ )
386
+ prefill, prefill_steps = self._prepare_audio_prompt(audio_prompts)
387
+
388
+ dec_output = DecoderOutput.new(batch_size, self.config, self.device)
389
+ dec_output.prefill(prefill, prefill_steps)
390
+
391
+ dec_step = min(prefill_steps) - 1
392
+ if dec_step > 0:
393
+ dec_state.prepare_step(0, dec_step)
394
+ tokens_BxTxC = dec_output.get_tokens_at(0, dec_step).repeat_interleave(2, dim=0)
395
+ self.model.decoder.forward(tokens_BxTxC, dec_state)
396
+
397
+ return dec_state, dec_output
398
+
399
+ def _decoder_step(
400
+ self,
401
+ tokens_Bx1xC: torch.Tensor,
402
+ dec_state: DecoderInferenceState,
403
+ cfg_scale: float,
404
+ temperature: float,
405
+ top_p: float,
406
+ top_k: int,
407
+ current_idx: int,
408
+ ) -> torch.Tensor:
409
+ """Performs a single step of the decoder inference.
410
+
411
+ Takes the tokens from the previous step, runs them through the decoder
412
+ (for both conditional and unconditional paths), applies classifier-free
413
+ guidance (CFG), samples the next token using temperature, top-p, and top-k
414
+ sampling, and applies constraints (e.g., preventing EOS in certain channels).
415
+
416
+ Args:
417
+ tokens_Bx1xC: The input tokens for the current step, shape [2*B, 1, C].
418
+ Repeated for CFG (unconditional and conditional).
419
+ dec_state: The current state of the decoder (KV caches, etc.).
420
+ cfg_scale: The scale factor for classifier-free guidance.
421
+ temperature: The temperature for sampling.
422
+ top_p: The cumulative probability threshold for top-p sampling.
423
+ top_k: The number of top logits to consider for top-k sampling.
424
+ current_idx: The current generation step index.
425
+
426
+ Returns:
427
+ torch.Tensor: The sampled next tokens for each item in the batch,
428
+ shape [B, C].
429
+ """
430
+ B = tokens_Bx1xC.shape[0] // 2
431
+
432
+ audio_eos_value = self.config.eos_token_id
433
+ logits_Bx1xCxV = self.model.decoder.decode_step(tokens_Bx1xC, dec_state, current_idx)
434
+
435
+ logits_last_2BxCxV = logits_Bx1xCxV[:, -1]
436
+ logits_last_Bx2xCxV = logits_last_2BxCxV.view(B, 2, *logits_last_2BxCxV.shape[1:])
437
+
438
+ uncond_logits_BxCxV = logits_last_Bx2xCxV[:, 0, :, :] # Shape [B, C, V]
439
+ cond_logits_BxCxV = logits_last_Bx2xCxV[:, 1, :, :] # Shape [B, C, V]
440
+ logits_BxCxV = cond_logits_BxCxV + cfg_scale * (cond_logits_BxCxV - uncond_logits_BxCxV)
441
+
442
+ _, top_k_indices_BxCxk = torch.topk(logits_BxCxV, k=top_k, dim=-1)
443
+ mask_BxCxV = torch.ones_like(logits_BxCxV, dtype=torch.bool)
444
+ mask_BxCxV = mask_BxCxV.scatter(dim=-1, index=top_k_indices_BxCxk, value=False)
445
+ logits_BxCxV = cond_logits_BxCxV.masked_fill(mask_BxCxV, -torch.inf)
446
+
447
+ logits_BxCxV[:, :, audio_eos_value + 1 :] = torch.full_like(
448
+ logits_BxCxV[:, :, audio_eos_value + 1 :],
449
+ fill_value=-torch.inf,
450
+ )
451
+ logits_BxCxV[:, 1:, audio_eos_value:] = torch.full_like(
452
+ logits_BxCxV[:, 1:, audio_eos_value:],
453
+ fill_value=-torch.inf,
454
+ )
455
+
456
+ flat_logits_BCxV = logits_BxCxV.view(B * self.config.decoder_config.num_channels, -1)
457
+
458
+ pred_BC = _sample_next_token(
459
+ flat_logits_BCxV.float(),
460
+ temperature=temperature,
461
+ top_p=top_p,
462
+ top_k=top_k,
463
+ audio_eos_value=audio_eos_value,
464
+ )
465
+
466
+ pred_BxC = pred_BC.view(B, self.config.decoder_config.num_channels)
467
+ return pred_BxC
468
+
469
+ def _generate_output(self, generated_codes: torch.Tensor, lengths_Bx: torch.Tensor) -> list[np.ndarray]:
470
+ """Converts generated delayed codes into audio waveforms.
471
+
472
+ Reverts the delay pattern applied during generation, decodes the resulting
473
+ codebook using the DAC model (if loaded), and returns a list of audio
474
+ waveforms as NumPy arrays. If DAC is not loaded, returns the raw codebook indices.
475
+
476
+ Args:
477
+ generated_codes: The tensor of generated audio codes with delays,
478
+ shape [B, T_gen, C].
479
+ lengths_Bx: A tensor containing the valid length of generated codes
480
+ (excluding padding and BOS/EOS markers) for each item
481
+ in the batch, shape [B].
482
+
483
+ Returns:
484
+ A list of NumPy arrays, where each array represents the generated audio
485
+ waveform for one item in the batch. If DAC is not loaded, returns the
486
+ raw, reverted codebook indices as NumPy arrays.
487
+ """
488
+ num_channels = self.config.decoder_config.num_channels
489
+ batch_size = generated_codes.shape[0]
490
+ seq_length = generated_codes.shape[1]
491
+ delay_pattern = self.config.delay_pattern
492
+ audio_pad_value = self.config.pad_token_id
493
+ max_delay_pattern = max(delay_pattern)
494
+
495
+ revert_precomp = build_revert_indices(
496
+ B=batch_size,
497
+ T=seq_length,
498
+ C=num_channels,
499
+ delay_pattern=delay_pattern,
500
+ )
501
+
502
+ codebook = revert_audio_delay(
503
+ audio_BxTxC=generated_codes,
504
+ pad_value=audio_pad_value,
505
+ precomp=revert_precomp,
506
+ T=seq_length,
507
+ )[:, :-max_delay_pattern, :]
508
+
509
+ min_valid_index = 0
510
+ max_valid_index = 1023
511
+ invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
512
+ codebook[invalid_mask] = 0
513
+
514
+ audios = []
515
+
516
+ if self.load_dac:
517
+ for i in range(batch_size):
518
+ audio = self._decode(codebook[i, : lengths_Bx[i], :])
519
+ audio_np = audio.cpu().numpy()
520
+ audios.append(audio_np)
521
+ else:
522
+ for i in range(batch_size):
523
+ audios.append(codebook[i, : lengths_Bx[i], :].cpu().numpy())
524
+ return audios
525
+
526
+ @torch.no_grad()
527
+ @torch.inference_mode()
528
+ def _encode(self, audio: torch.Tensor) -> torch.Tensor:
529
+ """
530
+ Encodes the given audio waveform into a tensor of DAC codebook indices
531
+ """
532
+ audio = audio.unsqueeze(0)
533
+ audio_data = self.dac_model.preprocess(audio, DEFAULT_SAMPLE_RATE)
534
+ _, encoded_frame, _, _, _ = self.dac_model.encode(audio_data)
535
+ encoded_frame: torch.Tensor
536
+ return encoded_frame.squeeze(0).transpose(0, 1)
537
+
538
+ @torch.no_grad()
539
+ @torch.inference_mode()
540
+ def _decode(self, audio_codes: torch.Tensor) -> torch.Tensor:
541
+ """
542
+ Decodes the given frames into an output audio waveform
543
+ """
544
+ audio_codes = audio_codes.unsqueeze(0).transpose(1, 2)
545
+ audio_values, _, _ = self.dac_model.quantizer.from_codes(audio_codes)
546
+ audio_values = self.dac_model.decode(audio_values)
547
+ audio_values: torch.Tensor
548
+ return audio_values.squeeze()
549
+
550
+ def load_audio(self, audio_path: str) -> torch.Tensor:
551
+ """Loads and preprocesses an audio file for use as a prompt.
552
+
553
+ Loads the audio file, resamples it to the target sample rate if necessary,
554
+ preprocesses it using the DAC model's preprocessing, and encodes it into
555
+ DAC codebook indices.
556
+
557
+ Args:
558
+ audio_path: Path to the audio file.
559
+
560
+ Returns:
561
+ torch.Tensor: The encoded audio prompt as DAC codebook indices,
562
+ shape [T, C].
563
+
564
+ Raises:
565
+ RuntimeError: If the DAC model is not loaded (`load_dac=False` during init).
566
+ FileNotFoundError: If the audio file cannot be found.
567
+ Exception: If there's an error during loading or processing.
568
+ """
569
+ if self.dac_model is None:
570
+ raise RuntimeError("DAC model is required for loading audio prompts but was not loaded.")
571
+ audio, sr = torchaudio.load(audio_path, channels_first=True) # C, T
572
+ if sr != DEFAULT_SAMPLE_RATE:
573
+ audio = torchaudio.functional.resample(audio, sr, DEFAULT_SAMPLE_RATE)
574
+ # Convert to mono if stereo
575
+ if audio.shape[0] > 1:
576
+ audio = torch.mean(audio, dim=0, keepdim=True) # Average channels to get mono
577
+ return self._encode(audio.to(self.device))
578
+
579
+ def save_audio(self, path: str, audio: np.ndarray):
580
+ """Saves the generated audio waveform to a file.
581
+
582
+ Uses the soundfile library to write the NumPy audio array to the specified
583
+ path with the default sample rate.
584
+
585
+ Args:
586
+ path: The path where the audio file will be saved.
587
+ audio: The audio waveform as a NumPy array.
588
+ """
589
+ import soundfile as sf
590
+
591
+ sf.write(path, audio, DEFAULT_SAMPLE_RATE)
592
+
593
+ @torch.inference_mode()
594
+ def generate(
595
+ self,
596
+ text: str | list[str],
597
+ max_tokens: int = 3072,
598
+ cfg_scale: float = 3.0,
599
+ temperature: float = 1.2,
600
+ top_p: float = 0.95,
601
+ use_torch_compile: bool = False,
602
+ cfg_filter_top_k: int = 45,
603
+ audio_prompt: list[str | torch.Tensor | None] | str | torch.Tensor | None = None,
604
+ audio_prompt_path: list[str | torch.Tensor | None] | str | torch.Tensor | None = None,
605
+ use_cfg_filter: bool | None = None,
606
+ verbose: bool = False,
607
+ ) -> np.ndarray | list[np.ndarray]:
608
+ """Generates audio corresponding to the input text.
609
+
610
+ Args:
611
+ text: The input text prompt, or a list of text prompts for batch generation.
612
+ max_tokens: The maximum number of audio tokens to generate per prompt.
613
+ Defaults to the model's configured audio length if None.
614
+ cfg_scale: The scale factor for classifier-free guidance (CFG). Higher values
615
+ lead to stronger guidance towards the text prompt.
616
+ temperature: The temperature for sampling. Higher values increase randomness.
617
+ top_p: The cumulative probability threshold for nucleus (top-p) sampling.
618
+ use_torch_compile: Whether to compile the generation steps using torch.compile.
619
+ Can significantly speed up generation after the initial
620
+ compilation overhead. Defaults to False.
621
+ cfg_filter_top_k: The number of top logits to consider during CFG filtering.
622
+ (Note: This parameter name might be slightly misleading based
623
+ on the code; it's used in the `_sample_next_token` function.)
624
+ audio_prompt: An audio prompt or list of prompts to condition the generation.
625
+ Can be a file path (str), a pre-loaded tensor (DAC codes), or None.
626
+ If a list, its length must match the batch size of the text input.
627
+ audio_prompt_path: (Deprecated) Use `audio_prompt` instead.
628
+ use_cfg_filter: (Deprecated) This parameter is no longer used.
629
+ verbose: If True, prints progress information during generation, including
630
+ speed metrics.
631
+
632
+ Returns:
633
+ If a single text prompt was provided, returns a NumPy array containing the
634
+ generated audio waveform.
635
+ If a list of text prompts was provided, returns a list of NumPy arrays,
636
+ each corresponding to a prompt in the input list. Returns None for a
637
+ sequence if no audio was generated for it.
638
+ """
639
+ batch_size = len(text) if isinstance(text, list) else 1
640
+ audio_eos_value = self.config.eos_token_id
641
+ audio_pad_value = self.config.pad_token_id
642
+ delay_pattern = self.config.delay_pattern
643
+ max_delay_pattern = max(delay_pattern)
644
+ delay_pattern_Cx = torch.tensor(delay_pattern, device=self.device, dtype=torch.long)
645
+ self.model.eval()
646
+
647
+ if audio_prompt_path:
648
+ print("Warning: audio_prompt_path is deprecated. Use audio_prompt instead.")
649
+ audio_prompt = audio_prompt_path
650
+ if use_cfg_filter is not None:
651
+ print("Warning: use_cfg_filter is deprecated.")
652
+
653
+ if verbose:
654
+ total_start_time = time.time()
655
+
656
+ if use_torch_compile and not hasattr(self, "_compiled"):
657
+ # Compilation can take about a minute.
658
+ self._prepare_generation = torch.compile(self._prepare_generation, dynamic=True, fullgraph=True)
659
+ self._decoder_step = torch.compile(self._decoder_step, fullgraph=True, mode="max-autotune")
660
+ self._compiled = True
661
+
662
+ if isinstance(audio_prompt, list):
663
+ audio_prompt = [self.load_audio(p) if isinstance(p, str) else p for p in audio_prompt]
664
+ elif isinstance(audio_prompt, str):
665
+ audio_prompt = [self.load_audio(audio_prompt)]
666
+ elif isinstance(audio_prompt, torch.Tensor):
667
+ audio_prompt = [audio_prompt]
668
+ elif audio_prompt is None:
669
+ audio_prompt = [None] * batch_size
670
+
671
+ assert len(audio_prompt) == batch_size, "Number of audio prompts must match batch size"
672
+
673
+ if isinstance(text, list):
674
+ text = [self._encode_text(t) for t in text]
675
+ else:
676
+ text = [self._encode_text(text)]
677
+ text = self._pad_text_input(text)
678
+
679
+ dec_state, dec_output = self._prepare_generation(text, audio_prompt, max_tokens=max_tokens)
680
+ dec_step = min(dec_output.prefill_steps) - 1
681
+ current_idx = torch.tensor([dec_step], device=self.device)
682
+
683
+ eos_detected_Bx = torch.zeros((batch_size,), dtype=torch.bool, device=self.device)
684
+ eos_countdown_Bx = torch.full((batch_size,), -1, dtype=torch.long, device=self.device)
685
+ finished_step_Bx = torch.full((batch_size,), -1, dtype=torch.long, device=self.device)
686
+
687
+ bos_over = False
688
+
689
+ if verbose:
690
+ print("generate: starting generation loop")
691
+ if use_torch_compile:
692
+ print("generate: using use_torch_compile=True, the first step may be slow")
693
+ start_time = time.time()
694
+
695
+ # --- Generation Loop ---
696
+ while dec_step < max_tokens:
697
+ if (eos_countdown_Bx == 0).all():
698
+ break
699
+
700
+ current_step_idx = dec_step + 1
701
+ torch.compiler.cudagraph_mark_step_begin()
702
+ dec_state.prepare_step(dec_step)
703
+ tokens_Bx1xC = dec_output.get_tokens_at(dec_step).repeat_interleave(2, dim=0) # Repeat for CFG
704
+
705
+ pred_BxC = self._decoder_step(
706
+ tokens_Bx1xC,
707
+ dec_state,
708
+ cfg_scale,
709
+ temperature,
710
+ top_p,
711
+ cfg_filter_top_k,
712
+ current_idx,
713
+ )
714
+
715
+ current_idx += 1
716
+
717
+ active_mask_Bx = eos_countdown_Bx != 0
718
+ eos_trigger_Bx = torch.zeros_like(active_mask_Bx)
719
+ if active_mask_Bx.any():
720
+ is_eos_token = (~eos_detected_Bx[active_mask_Bx]) & (pred_BxC[active_mask_Bx, 0] == audio_eos_value)
721
+ is_max_len = current_step_idx >= max_tokens - max_delay_pattern
722
+ eos_trigger_Bx[active_mask_Bx] = is_eos_token | is_max_len
723
+ eos_detected_Bx |= eos_trigger_Bx
724
+ start_countdown_mask_Bx = eos_trigger_Bx & (eos_countdown_Bx < 0)
725
+ if start_countdown_mask_Bx.any():
726
+ eos_countdown_Bx[start_countdown_mask_Bx] = max_delay_pattern
727
+ finished_step_Bx[start_countdown_mask_Bx] = current_step_idx
728
+
729
+ padding_mask_Bx = eos_countdown_Bx > 0
730
+ if padding_mask_Bx.any():
731
+ pred_active_BxC = pred_BxC[padding_mask_Bx].clone()
732
+ countdown_active_Bx = eos_countdown_Bx[padding_mask_Bx]
733
+ step_after_eos_Bx = max_delay_pattern - countdown_active_Bx
734
+ step_after_eos_Bx_ = step_after_eos_Bx.unsqueeze(1)
735
+ delay_pattern_Cx_ = delay_pattern_Cx.unsqueeze(0)
736
+ eos_mask_NxC = step_after_eos_Bx_ == delay_pattern_Cx_
737
+ pad_mask_NxC = step_after_eos_Bx_ > delay_pattern_Cx_
738
+ pred_active_BxC[eos_mask_NxC] = audio_eos_value
739
+ pred_active_BxC[pad_mask_NxC] = audio_pad_value
740
+ pred_BxC[padding_mask_Bx] = pred_active_BxC
741
+ eos_countdown_Bx[padding_mask_Bx] -= 1
742
+
743
+ # --- Update BOS flag (Original) ---
744
+ if not bos_over:
745
+ bos_over = all(
746
+ dec_step - prefill_step > max_delay_pattern for prefill_step in dec_output.prefill_steps
747
+ )
748
+
749
+ dec_output.update_one(pred_BxC, current_step_idx, not bos_over)
750
+
751
+ dec_step += 1
752
+
753
+ if verbose and dec_step % 86 == 0:
754
+ duration = time.time() - start_time
755
+ if duration > 0:
756
+ print(
757
+ f"generate step {dec_step}: speed={86 * batch_size / duration:.3f} tokens/s, realtime factor={batch_size / duration:.3f}x"
758
+ )
759
+ start_time = time.time()
760
+
761
+ # --- Finalize and Extract Output ---
762
+ final_step = dec_step + 1
763
+
764
+ finished_step_Bx[finished_step_Bx == -1] = final_step - max_delay_pattern
765
+
766
+ prefill_steps_tensor = torch.tensor(dec_output.prefill_steps, device=self.device)
767
+ lengths_Bx = finished_step_Bx - prefill_steps_tensor
768
+ lengths_Bx = torch.clamp(lengths_Bx, min=0)
769
+
770
+ max_len = lengths_Bx.max().item() + max_delay_pattern
771
+ outputs = []
772
+
773
+ if max_len > 0:
774
+ num_channels = self.config.decoder_config.num_channels
775
+ audio_pad_value = self.config.pad_token_id
776
+ generated_codes = torch.full(
777
+ (batch_size, max_len, num_channels),
778
+ fill_value=audio_pad_value,
779
+ dtype=torch.long,
780
+ device=self.device,
781
+ )
782
+
783
+ for i in range(batch_size):
784
+ start_step = dec_output.prefill_steps[i]
785
+ actual_len = lengths_Bx[i].item() + max_delay_pattern
786
+ if actual_len > 0:
787
+ tokens_to_copy = dec_output.generated_tokens[i, start_step : start_step + actual_len, :]
788
+ generated_codes[i, :actual_len, :] = tokens_to_copy
789
+
790
+ if verbose:
791
+ avg_steps = lengths_Bx.float().mean().item()
792
+ total_duration = time.time() - total_start_time
793
+ print(f"generate: avg steps={avg_steps:.1f}, total duration={total_duration:.3f}s")
794
+
795
+ del dec_state
796
+
797
+ outputs = self._generate_output(generated_codes, lengths_Bx)
798
+ else:
799
+ print("Warning: Nothing generated for any sequence in the batch.")
800
+ outputs = [None] * batch_size
801
+
802
+ return outputs if batch_size > 1 else outputs[0]
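The EOS handling in the loop above staggers end-of-sequence and padding codes across the codebook channels: once a sequence triggers EOS, channel c emits the EOS code exactly delay_pattern[c] steps later and the PAD code on every step after that. A minimal, self-contained sketch of that masking (the EOS/PAD ids and the 4-channel delay pattern below are illustrative assumptions, not the model's real configuration):

import torch

audio_eos_value = 1024                          # assumed EOS code id (illustrative)
audio_pad_value = 1025                          # assumed PAD code id (illustrative)
delay_pattern_Cx = torch.tensor([0, 1, 2, 3])   # assumed per-channel delays (illustrative)

# Two sequences that entered their EOS countdown 1 and 3 steps ago.
step_after_eos_Bx = torch.tensor([1, 3])
pred_BxC = torch.randint(0, 1024, (2, delay_pattern_Cx.shape[0]))

step_ = step_after_eos_Bx.unsqueeze(1)          # [B, 1]
delay_ = delay_pattern_Cx.unsqueeze(0)          # [1, C]
pred_BxC[step_ == delay_] = audio_eos_value     # a channel emits EOS exactly at its delay
pred_BxC[step_ > delay_] = audio_pad_value      # channels past their delay emit PAD
print(pred_BxC)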
dia/state.py ADDED
@@ -0,0 +1,217 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+
6
+ from .config import DiaConfig
7
+
8
+
9
+ def create_attn_mask(
10
+ q_padding_mask_1d: torch.Tensor,
11
+ k_padding_mask_1d: torch.Tensor,
12
+ device: torch.device,
13
+ is_causal: bool = False,
14
+ ) -> torch.Tensor:
15
+ """
16
+ Creates the attention mask (self or cross) mimicking JAX segment ID logic.
17
+ """
18
+ # B1, Tq = q_padding_mask_1d.shape
19
+ # B2, Tk = k_padding_mask_1d.shape
20
+
21
+ p_mask_q = q_padding_mask_1d.unsqueeze(2) # Shape [B, Tq, 1]
22
+ p_mask_k = k_padding_mask_1d.unsqueeze(1) # Shape [B, 1, Tk]
23
+
24
+ # Condition A: Non-padding query attends to non-padding key
25
+ non_pad_attends_non_pad = p_mask_q & p_mask_k # Shape [B, Tq, Tk]
26
+
27
+ # Condition B: Padding query attends to padding key
28
+ pad_attends_pad = (~p_mask_q) & (~p_mask_k) # Shape [B, Tq, Tk]
29
+
30
+ # Combine: True if padding status is compatible (both non-pad OR both pad)
31
+ mask = non_pad_attends_non_pad | pad_attends_pad # Shape [B, Tq, Tk]
32
+
33
+ if is_causal:
34
+ # assert Tq == Tk, "Causal mask requires query and key sequence lengths to be equal"
35
+ causal_mask_2d = torch.tril(torch.ones_like(mask[0], dtype=torch.bool, device=device)) # Shape [Tq, Tk]
36
+ causal_mask = mask & causal_mask_2d # Shape [B, Tq, Tk]
37
+ return causal_mask.unsqueeze(1) # Shape [B, 1, Tq, Tk]
38
+ else:
39
+ return mask.unsqueeze(1) # Shape [B, 1, Tq, Tk]
40
+
41
+
42
+ @dataclass
43
+ class EncoderInferenceState:
44
+ """Parameters specifically for encoder inference."""
45
+
46
+ max_seq_len: int
47
+ device: torch.device
48
+ positions: torch.Tensor
49
+ padding_mask: torch.Tensor
50
+ attn_mask: torch.Tensor
51
+
52
+ @classmethod
53
+ def new(cls, config: DiaConfig, cond_src: torch.Tensor) -> "EncoderInferenceState":
54
+ """Creates EtorchrInferenceParams from DiaConfig and a device."""
55
+ device = cond_src.device
56
+
57
+ positions = torch.arange(
58
+ config.encoder_config.max_position_embeddings, dtype=torch.float32, device=device
59
+ ).unsqueeze(0)
60
+ padding_mask = (cond_src.squeeze(1) != 0).to(device).repeat_interleave(2, dim=0)
61
+ attn_mask = create_attn_mask(padding_mask, padding_mask, device, is_causal=False)
62
+
63
+ return cls(
64
+ max_seq_len=config.encoder_config.max_position_embeddings,
65
+ device=device,
66
+ positions=positions,
67
+ padding_mask=padding_mask,
68
+ attn_mask=attn_mask,
69
+ )
70
+
71
+
72
+ class KVCache(torch.nn.Module):
73
+ k: torch.Tensor
74
+ v: torch.Tensor
75
+
76
+ def __init__(
77
+ self,
78
+ batch_size: int,
79
+ num_heads: int,
80
+ max_len: int,
81
+ head_dim: int,
82
+ dtype: torch.dtype,
83
+ device: torch.device,
84
+ k: torch.Tensor | None = None,
85
+ v: torch.Tensor | None = None,
86
+ ):
87
+ k = torch.zeros((2 * batch_size, num_heads, max_len, head_dim), dtype=dtype, device=device) if k is None else k
88
+ v = torch.zeros((2 * batch_size, num_heads, max_len, head_dim), dtype=dtype, device=device) if v is None else v
89
+ super().__init__()
90
+
91
+ self.register_buffer("k", k)
92
+ self.register_buffer("v", v)
93
+
94
+ @classmethod
95
+ def from_kv(cls, k: torch.Tensor, v: torch.Tensor) -> "KVCache":
96
+ return cls(
97
+ batch_size=k.shape[0] // 2,
98
+ num_heads=k.shape[1],
99
+ max_len=k.shape[2],
100
+ head_dim=k.shape[3],
101
+ dtype=k.dtype,
102
+ device=k.device,
103
+ k=k,
104
+ v=v,
105
+ )
106
+
107
+ def update(self, k: torch.Tensor, v: torch.Tensor, current_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
108
+ k_out, v_out = self.k, self.v
109
+ k_out[:, :, current_idx, :] = k
110
+ v_out[:, :, current_idx, :] = v
111
+ return self.k, self.v
112
+
113
+ def prefill(self, k: torch.Tensor, v: torch.Tensor):
114
+ prefill_len = k.shape[2]
115
+ self.k[:, :, :prefill_len, :] = k
116
+ self.v[:, :, :prefill_len, :] = v
117
+
118
+
119
+ @dataclass
120
+ class DecoderInferenceState:
121
+ """Parameters specifically for decoder inference."""
122
+
123
+ device: torch.device
124
+ dtype: torch.dtype
125
+ enc_out: torch.Tensor
126
+ enc_positions: torch.Tensor
127
+ dec_positions: torch.Tensor
128
+ self_attn_cache: list[KVCache]
129
+ cross_attn_cache: list[KVCache]
130
+ casual_attn_mask: torch.Tensor
131
+ cross_attn_mask: torch.Tensor
132
+
133
+ @classmethod
134
+ def new(
135
+ cls,
136
+ config: DiaConfig,
137
+ enc_state: EncoderInferenceState,
138
+ enc_out: torch.Tensor,
139
+ dec_cross_attn_cache: list[KVCache],
140
+ compute_dtype: torch.dtype,
141
+ max_generation_length: Optional[int] = None,
142
+ ) -> "DecoderInferenceState":
143
+ """Creates DecoderInferenceParams from DiaConfig and a device."""
144
+ device = enc_out.device
145
+ max_audio_len = max_generation_length or config.decoder_config.max_position_embeddings
146
+ batch_size = enc_out.shape[0] // 2
147
+
148
+ dec_positions = torch.full((2 * batch_size, 1), fill_value=0, dtype=torch.int32, device=device)
149
+ causal_mask = torch.tril(torch.ones(max_audio_len, max_audio_len, dtype=torch.bool, device=device))
150
+ dec_mask = torch.ones((2 * batch_size, 1), dtype=torch.bool, device=device)
151
+ cross_attn_mask = create_attn_mask(dec_mask, enc_state.padding_mask, device, is_causal=False)
152
+
153
+ self_attn_cache = [
154
+ KVCache(
155
+ batch_size,
156
+ config.decoder_config.num_key_value_heads,
157
+ max_audio_len,
158
+ config.decoder_config.head_dim,
159
+ compute_dtype,
160
+ device,
161
+ )
162
+ for _ in range(config.decoder_config.num_hidden_layers)
163
+ ]
164
+
165
+ return cls(
166
+ device=device,
167
+ dtype=compute_dtype,
168
+ enc_out=enc_out,
169
+ enc_positions=enc_state.positions,
170
+ dec_positions=dec_positions,
171
+ self_attn_cache=self_attn_cache,
172
+ cross_attn_cache=dec_cross_attn_cache,
173
+ casual_attn_mask=causal_mask,
174
+ cross_attn_mask=cross_attn_mask,
175
+ )
176
+
177
+ def prepare_step(self, step_from: int, step_to: int | None = None) -> None:
178
+ if step_to is None:
179
+ step_to = step_from + 1
180
+ self.dec_positions = torch.arange(step_from, step_to, dtype=torch.int32, device=self.device).unsqueeze(0)
181
+
182
+
183
+ @dataclass
184
+ class DecoderOutput:
185
+ generated_tokens: torch.Tensor
186
+ prefill_steps: list[int]
187
+
188
+ @classmethod
189
+ def new(cls, batch_size: int, config: DiaConfig, device: torch.device) -> "DecoderOutput":
190
+ max_audio_len = config.decoder_config.max_position_embeddings
191
+ return cls(
192
+ generated_tokens=torch.full(
193
+ (batch_size, max_audio_len, config.decoder_config.num_channels),
194
+ fill_value=-1,
195
+ dtype=torch.int,
196
+ device=device,
197
+ ),
198
+ prefill_steps=[],
199
+ )
200
+
201
+ def get_tokens_at(self, step_from: int, step_to: int | None = None) -> torch.Tensor:
202
+ if step_to is None:
203
+ step_to = step_from + 1
204
+ return self.generated_tokens[:, step_from:step_to, :]
205
+
206
+ def update_one(self, dec_out: torch.Tensor, step: int, apply_mask: bool = False):
207
+ dec_out = dec_out.to(self.generated_tokens.dtype)
208
+ if apply_mask:
209
+ mask = self.generated_tokens[:, step, :] == -1
210
+ self.generated_tokens[:, step, :] = torch.where(mask, dec_out, self.generated_tokens[:, step, :])
211
+ else:
212
+ self.generated_tokens[:, step, :] = dec_out
213
+
214
+ def prefill(self, dec_out: torch.Tensor, prefill_steps: list[int]):
215
+ length = dec_out.shape[1]
216
+ self.generated_tokens[:, :length, :] = dec_out
217
+ self.prefill_steps = prefill_steps
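create_attn_mask above only allows attention between positions whose padding status matches (both real tokens, or both padding), optionally intersected with a lower-triangular causal mask. A tiny standalone check of that behaviour (shapes and padding patterns below are illustrative, not taken from the model):

import torch

q_pad = torch.tensor([[True, True, False]])    # query: 2 real tokens, 1 pad
k_pad = torch.tensor([[True, False, False]])   # key:   1 real token, 2 pads

p_q = q_pad.unsqueeze(2)                       # [B, Tq, 1]
p_k = k_pad.unsqueeze(1)                       # [B, 1, Tk]
mask = (p_q & p_k) | (~p_q & ~p_k)             # True where padding status is compatible
causal = torch.tril(torch.ones_like(mask[0], dtype=torch.bool))
print((mask & causal).unsqueeze(1))            # [B, 1, Tq, Tk], as returned for self-attention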
dia/static/images/banner.png ADDED

Git LFS Details

  • SHA256: e6dcd20d2ec2bbb5a6a45d9fd37bafdd404d3d427e2be2e7279ff989ed5935ff
  • Pointer size: 131 Bytes
  • Size of remote file: 181 kB
docker/Dockerfile.cpu ADDED
@@ -0,0 +1,48 @@
1
+ # Dockerfile.cpu - CPU-only deployment for DIA
2
+ # --------------------------------------------------
3
+ # Build: docker build . -f docker/Dockerfile.cpu -t dia-cpu
4
+ # Run: docker run --rm -p 7860:7860 dia-cpu
5
+
6
+ FROM python:3.10-slim
7
+
8
+ # Set non-interactive frontend
9
+ ENV DEBIAN_FRONTEND=noninteractive
10
+
11
+ # Install venv, and system dependencies
12
+ RUN apt-get update && apt-get install -y \
13
+ python3-venv \
14
+ libsndfile1 \
15
+ ffmpeg \
16
+ curl \
17
+ && apt-get clean && rm -rf /var/lib/apt/lists/*
18
+
19
+ # Create non-root user and set up directories
20
+ RUN useradd -m -u 1001 appuser && \
21
+ mkdir -p /app/outputs /app && \
22
+ chown -R appuser:appuser /app
23
+
24
+ USER appuser
25
+ WORKDIR /app
26
+
27
+ # Copy all code (including pyproject.toml)
28
+ COPY --chown=appuser:appuser . .
29
+
30
+ # Create and activate virtual environment
31
+ RUN python3 -m venv /app/venv
32
+ ENV PATH="/app/venv/bin:$PATH"
33
+
34
+ # Install all project dependencies (CPU-only PyTorch)
35
+ RUN pip install --upgrade pip && \
36
+ pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu && \
37
+ pip install --no-cache-dir -e .[dev]
38
+
39
+ # Set environment variables
40
+ ENV PYTHONUNBUFFERED=1 \
41
+ PYTHONPATH=/app
42
+
43
+ # Expose Gradio default port
44
+ ENV GRADIO_SERVER_NAME="0.0.0.0"
45
+ EXPOSE 7860
46
+
47
+ # Entrypoint
48
+ CMD ["python3", "app.py"]
docker/Dockerfile.gpu ADDED
@@ -0,0 +1,49 @@
1
+ # Dockerfile.gpu - GPU deployment for DIA
2
+ # --------------------------------------------------
3
+ # Build: docker build . -f docker/Dockerfile.gpu -t dia-gpu
4
+ # Run: docker run --rm --gpus all -p 7860:7860 dia-gpu
5
+ # Requires NVIDIA Container Toolkit on host.
6
+
7
+ FROM pytorch/pytorch:2.1.2-cuda12.1-cudnn8-runtime
8
+
9
+ # Set non-interactive frontend
10
+ ENV DEBIAN_FRONTEND=noninteractive
11
+
12
+ # Install venv, and system dependencies
13
+ RUN apt-get update && apt-get install -y \
14
+ python3-venv \
15
+ libsndfile1 \
16
+ ffmpeg \
17
+ curl \
18
+ && apt-get clean && rm -rf /var/lib/apt/lists/*
19
+
20
+ # Create non-root user and set up directories
21
+ RUN useradd -m -u 1001 appuser && \
22
+ mkdir -p /app/outputs /app && \
23
+ chown -R appuser:appuser /app
24
+
25
+ USER appuser
26
+ WORKDIR /app
27
+
28
+ # Copy all code (including pyproject.toml)
29
+ COPY --chown=appuser:appuser . .
30
+
31
+ # Create and activate virtual environment
32
+ RUN python3 -m venv /app/venv
33
+ ENV PATH="/app/venv/bin:$PATH"
34
+
35
+ # Install all project dependencies
36
+ RUN pip install --upgrade pip && pip install --no-cache-dir .
37
+
38
+ # Set environment variables
39
+ ENV PYTHONUNBUFFERED=1 \
40
+ PYTHONPATH=/app \
41
+ USE_GPU=true \
42
+ LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/cuda-12.1/lib64:${LD_LIBRARY_PATH}
43
+
44
+ # Expose Gradio default port
45
+ ENV GRADIO_SERVER_NAME="0.0.0.0"
46
+ EXPOSE 7860
47
+
48
+ # Entrypoint
49
+ CMD ["python3", "app.py"]
example/benchmark.py ADDED
@@ -0,0 +1,39 @@
1
+ from random import choice
2
+
3
+ import torch
4
+
5
+ from dia.model import Dia
6
+
7
+
8
+ torch._inductor.config.coordinate_descent_tuning = True
9
+ torch._inductor.config.triton.unique_kernel_names = True
10
+ torch._inductor.config.fx_graph_cache = True
11
+
12
+ # Debug logging: report graph breaks and recompiles from torch.compile
13
+ torch._logging.set_logs(graph_breaks=True, recompiles=True)
14
+
15
+ model_name = "nari-labs/Dia-1.6B-0626"
16
+ compute_dtype = "float16"
17
+
18
+ model = Dia.from_pretrained(model_name, compute_dtype=compute_dtype)
19
+
20
+
21
+ test_cases = [
22
+ "[S1] Dia is an open weights text to dialogue model.",
23
+ "[S1] Dia is an open weights text to dialogue model. [S2] You get full control over scripts and voices. [S1] Wow. Amazing. (laughs) [S2] Try it now on Git hub or Hugging Face.",
24
+ "[S1] torch.compile is a new feature in PyTorch that allows you to compile your model with a single line of code.",
25
+ "[S1] torch.compile is a new feature in PyTorch that allows you to compile your model with a single line of code. [S2] It is a new feature in PyTorch that allows you to compile your model with a single line of code.",
26
+ ]
27
+
28
+
29
+ # Warm up
30
+ for _ in range(2):
31
+ text = choice(test_cases)
32
+ output = model.generate(text, audio_prompt="./example_prompt.mp3", use_torch_compile=True, verbose=True)
33
+ output = model.generate(text, use_torch_compile=True, verbose=True)
34
+
35
+ # Benchmark
36
+ for _ in range(10):
37
+ text = choice(test_cases)
38
+ output = model.generate(text, use_torch_compile=True, verbose=True)
39
+ output = model.generate(text, audio_prompt="./example_prompt.mp3", use_torch_compile=True, verbose=True)
example/simple-cpu.py ADDED
@@ -0,0 +1,19 @@
1
+ import torch
2
+
3
+ from dia.model import Dia
4
+
5
+
6
+ # Select device: CPU
7
+ device = torch.device("cpu")
8
+ print(f"Using device: {device}")
9
+
10
+ # Load model
11
+ model = Dia.from_pretrained(
12
+ "nari-labs/Dia-1.6B-0626", compute_dtype="float32", device=device
13
+ ) # Float32 works better than float16 on CPU - you can also test with float16
14
+
15
+ text = "[S1] Dia is an open weights text to dialogue model. [S2] You get full control over scripts and voices. [S1] Wow. Amazing. (laughs) [S2] Try it now on Git hub or Hugging Face."
16
+
17
+ output = model.generate(text, use_torch_compile=False, verbose=True)
18
+
19
+ model.save_audio("simple.mp3", output)
example/simple-mac.py ADDED
@@ -0,0 +1,12 @@
1
+ from dia.model import Dia
2
+
3
+
4
+ model = Dia.from_pretrained("nari-labs/Dia-1.6B-0626", compute_dtype="float16")
5
+
6
+ text = "[S1] Dia is an open weights text to dialogue model. [S2] You get full control over scripts and voices. [S1] Wow. Amazing. (laughs) [S2] Try it now on Git hub or Hugging Face."
7
+
8
+ # It is important to set the `use_torch_compile` argument to `False` when running Dia on macOS.
9
+ # This is because `torch.compile` is not supported on macOS.
10
+ output = model.generate(text, use_torch_compile=False, verbose=True)
11
+
12
+ model.save_audio("simple.mp3", output)
example/simple.py ADDED
@@ -0,0 +1,18 @@
1
+ from dia.model import Dia
2
+
3
+
4
+ model = Dia.from_pretrained("nari-labs/Dia-1.6B-0626", compute_dtype="float16")
5
+
6
+ text = "[S1] Dia is an open weights text to dialogue model. [S2] You get full control over scripts and voices. [S1] Wow. Amazing. (laughs) [S2] Try it now on Git hub or Hugging Face."
7
+
8
+ output = model.generate(
9
+ text,
10
+ use_torch_compile=False,
11
+ verbose=True,
12
+ cfg_scale=3.0,
13
+ temperature=1.8,
14
+ top_p=0.90,
15
+ cfg_filter_top_k=50,
16
+ )
17
+
18
+ model.save_audio("simple.mp3", output)
example/simple_batch.py ADDED
@@ -0,0 +1,12 @@
1
+ from dia.model import Dia
2
+
3
+
4
+ model = Dia.from_pretrained("nari-labs/Dia-1.6B-0626", compute_dtype="float16")
5
+
6
+ text = "[S1] Dia is an open weights text to dialogue model. [S2] You get full control over scripts and voices. [S1] Wow. Amazing. (laughs) [S2] Try it now on Git hub or Hugging Face."
7
+ texts = [text for _ in range(10)]
8
+
9
+ output = model.generate(texts, use_torch_compile=True, verbose=True, max_tokens=1500)
10
+
11
+ for i, o in enumerate(output):
12
+ model.save_audio(f"simple_{i}.mp3", o)
example/voice_clone.py ADDED
@@ -0,0 +1,31 @@
1
+ from dia.model import Dia
2
+
3
+
4
+ model = Dia.from_pretrained("nari-labs/Dia-1.6B-0626", compute_dtype="float16")
5
+
6
+ # You should put the transcript of the voice you want to clone
7
+ # We will use the audio created by running simple.py as an example.
8
+ # Note that you will be REQUIRED TO RUN simple.py for the script to work as-is.
9
+ clone_from_text = "[S1] Dia is an open weights text to dialogue model. [S2] You get full control over scripts and voices. [S1] Wow. Amazing. (laughs) [S2] Try it now on Git hub or Hugging Face."
10
+ clone_from_audio = "simple.mp3"
11
+
12
+ # For your custom needs, replace above with below and add your audio file to this directory:
13
+ # clone_from_text = "[S1] ... [S2] ... [S1] ... corresponding to your_audio_name.mp3"
14
+ # clone_from_audio = "your_audio_name.mp3"
15
+
16
+ # Text to generate
17
+ text_to_generate = "[S1] Hello, how are you? [S2] I'm good, thank you. [S1] What's your name? [S2] My name is Dia. [S1] Nice to meet you. [S2] Nice to meet you too."
18
+
19
+ # It will only return the audio from the text_to_generate
20
+ output = model.generate(
21
+ clone_from_text + text_to_generate,
22
+ audio_prompt=clone_from_audio,
23
+ use_torch_compile=False,
24
+ verbose=True,
25
+ cfg_scale=4.0,
26
+ temperature=1.8,
27
+ top_p=0.90,
28
+ cfg_filter_top_k=50,
29
+ )
30
+
31
+ model.save_audio("voice_clone.mp3", output)
example/voice_clone_batch.py ADDED
@@ -0,0 +1,26 @@
1
+ from dia.model import Dia
2
+
3
+
4
+ model = Dia.from_pretrained("nari-labs/Dia-1.6B-0626", compute_dtype="float16")
5
+
6
+ # Put the transcript of the audio you want to clone here.
7
+ # We will use the audio created by running simple.py as an example.
8
+ # Note that you will be REQUIRED TO RUN simple.py for the script to work as-is.
9
+ clone_from_text = "[S1] Dia is an open weights text to dialogue model. [S2] You get full control over scripts and voices. [S1] Wow. Amazing. (laughs) [S2] Try it now on Git hub or Hugging Face."
10
+
11
+ # For your custom needs, replace the above with the below and add your audio files to this directory:
12
+ # clone_from_text = "[S1] ... [S2] ... [S1] ... corresponding to your_audio_name.mp3"
13
+ # clone_from_audios = ["your_audio_name_0.mp3", "your_audio_name_1.mp3", ...]
14
+
15
+ # Text to generate
16
+ text_to_generate = "[S1] Dia is an open weights text to dialogue model. [S2] You get full control over scripts and voices. [S1] Wow. Amazing. (laughs) [S2] Try it now on Git hub or Hugging Face."
17
+
18
+ clone_from_audios = [f"simple_{i}.mp3" for i in range(10)]
19
+
20
+ texts = [clone_from_text + text_to_generate for _ in range(10)]
21
+
22
+ # It will only return the audio from the text_to_generate
23
+ output = model.generate(texts, audio_prompt=clone_from_audios, use_torch_compile=True, verbose=True, max_tokens=2000)
24
+
25
+ for i, o in enumerate(output):
26
+ model.save_audio(f"voice_clone_{i}.mp3", o)
example_prompt.mp3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:576f5134df511edffcd0b5c87c91d829811d825c48845b3b9a156e1e7dd730e1
3
+ size 45839
hf.py ADDED
@@ -0,0 +1,17 @@
1
+ from transformers import AutoProcessor, DiaForConditionalGeneration
2
+
3
+
4
+ torch_device = "cuda"
5
+ model_checkpoint = "nari-labs/Dia-1.6B-0626"
6
+
7
+ text = [
8
+ "[S1] Dia is an open weights text to dialogue model. [S2] You get full control over scripts and voices. [S1] Wow. Amazing. (laughs) [S2] Try it now on Git hub or Hugging Face."
9
+ ]
10
+ processor = AutoProcessor.from_pretrained(model_checkpoint)
11
+ inputs = processor(text=text, padding=True, return_tensors="pt").to(torch_device)
12
+
13
+ model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(torch_device)
14
+ outputs = model.generate(**inputs, max_new_tokens=3072, guidance_scale=3.0, temperature=1.8, top_p=0.90, top_k=45)
15
+
16
+ outputs = processor.batch_decode(outputs)
17
+ processor.save_audio(outputs, "example.mp3")
pyproject.toml ADDED
@@ -0,0 +1,66 @@
1
+ [project]
2
+ name = "nari-tts"
3
+ version = "0.1.0"
4
+ description = "Dia - A text-to-speech model for dialogue generation"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ license = {file = "LICENSE"}
8
+ authors = [
9
+ {name = "Nari Labs", email = "contact@narilabs.ai"}
10
+ ]
11
+ dependencies = [
12
+ "descript-audio-codec>=1.0.0",
13
+ "gradio>=5.25.2",
14
+ "huggingface-hub>=0.30.2",
15
+ "numpy>=2.2.4",
16
+ "pydantic>=2.11.3",
17
+ "safetensors>=0.5.3",
18
+ "soundfile>=0.13.1",
19
+ "torch==2.6.0",
20
+ "torchaudio==2.6.0",
21
+ "triton==3.2.0 ; sys_platform == 'linux'",
22
+ "triton-windows==3.2.0.post18 ; sys_platform == 'win32'",
23
+ ]
24
+
25
+ [build-system]
26
+ requires = ["hatchling"]
27
+ build-backend = "hatchling.build"
28
+
29
+ [project.urls]
30
+ "Homepage" = "https://github.com/nari-labs/dia"
31
+ "Bug Tracker" = "https://github.com/nari-labs/dia/issues"
32
+
33
+ [tool.hatch.build.targets.wheel]
34
+ packages = ["dia"]
35
+
36
+ [tool.ruff]
37
+ # Never enforce `E501` (line length violations).
38
+ lint.ignore = ["C901", "E501", "E741", "W605"]
39
+ lint.select = ["C", "E", "F", "I", "W"]
40
+ line-length = 119
41
+
42
+ # Ignore import violations in all `__init__.py` files.
43
+ [tool.ruff.lint.per-file-ignores]
44
+ "__init__.py" = ["E402", "F401", "F403", "F811"]
45
+
46
+ [tool.ruff.lint.isort]
47
+ lines-after-imports = 2
48
+
49
+ [tool.uv.sources]
50
+ torch = [
51
+ { index = "pytorch-cu126", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
52
+ ]
53
+ torchaudio = [
54
+ { index = "pytorch-cu126", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
55
+ ]
56
+
57
+ [[tool.uv.index]]
58
+ name = "pytorch-cu126"
59
+ url = "https://download.pytorch.org/whl/cu126"
60
+ explicit = true
61
+
62
+ [dependency-groups]
63
+ dev = [
64
+ "ninja>=1.11.1.4",
65
+ "packaging>=25.0",
66
+ ]
uv.lock ADDED
The diff for this file is too large to render. See raw diff