NariLabs committed on
Commit
f385dd1
·
verified ·
1 Parent(s): 14b2a6c

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +264 -0
app.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import io
5
+ import os
6
+ from pathlib import Path
7
+ from typing import List, Tuple
8
+
9
+ import gradio as gr
10
+ import torch
11
+ import spaces
12
+
13
+ from dia2 import Dia2, GenerationConfig, SamplingConfig
14
+
15
# Model repo id; can be overridden via the DIA2_DEFAULT_REPO environment variable.
DEFAULT_REPO = os.environ.get("DIA2_DEFAULT_REPO", "nari-labs/Dia2-2B")
# The UI exposes at most MAX_TURNS dialogue boxes; INITIAL_TURNS are visible at start.
MAX_TURNS = 10
INITIAL_TURNS = 2

# Lazily-initialized singleton model handle; populated on first generation.
_dia: "Dia2 | None" = None
20
+
21
+
22
+ def _get_dia() -> Dia2:
23
+ global _dia
24
+ if _dia is None:
25
+ _dia = Dia2.from_repo(DEFAULT_REPO, device="cuda", dtype="bfloat16")
26
+ return _dia
27
+
28
+
29
+ def _concat_script(turn_count: int, turn_values: List[str]) -> str:
30
+ lines: List[str] = []
31
+ for idx in range(min(turn_count, len(turn_values))):
32
+ text = (turn_values[idx] or "").strip()
33
+ if not text:
34
+ continue
35
+ speaker = "[S1]" if idx % 2 == 0 else "[S2]"
36
+ lines.append(f"{speaker} {text}")
37
+ return "\n".join(lines)
38
+
39
+
40
+ EXAMPLES: dict[str, dict[str, List[str] | str | None]] = {
41
+ "Intro": {
42
+ "turns": [
43
+ "Hello Dia2 fans! Today we're unveiling the new open TTS model.",
44
+ "Sounds exciting. Can you show a sample right now?",
45
+ "Absolutely. (laughs) Just press generate.",
46
+ ],
47
+ "voice_s1": "example_prefix1.wav",
48
+ "voice_s2": "example_prefix2.wav",
49
+ },
50
+ "Customer Support": {
51
+ "turns": [
52
+ "Thanks for calling. How can I help you today?",
53
+ "My parcel never arrived and it's been two weeks.",
54
+ "I'm sorry about that. Let me check your tracking number.",
55
+ "Appreciate it. I really need that package soon.",
56
+ ],
57
+ "voice_s1": "example_prefix1.wav",
58
+ "voice_s2": "example_prefix2.wav",
59
+ },
60
+ }
61
+
62
+
63
+ def _apply_turn_visibility(count: int) -> List[gr.Update]:
64
+ return [gr.update(visible=i < count) for i in range(MAX_TURNS)]
65
+
66
+
67
+ def _add_turn(count: int):
68
+ count = min(count + 1, MAX_TURNS)
69
+ return (count, *_apply_turn_visibility(count))
70
+
71
+
72
+ def _remove_turn(count: int):
73
+ count = max(1, count - 1)
74
+ return (count, *_apply_turn_visibility(count))
75
+
76
+
77
+ def _load_example(name: str, count: int):
78
+ data = EXAMPLES.get(name)
79
+ if not data:
80
+ return (count, *_apply_turn_visibility(count), None, None)
81
+ turns = data.get("turns", [])
82
+ voice_s1_path = data.get("voice_s1")
83
+ voice_s2_path = data.get("voice_s2")
84
+ new_count = min(len(turns), MAX_TURNS)
85
+ updates: List[gr.Update] = []
86
+ for idx in range(MAX_TURNS):
87
+ if idx < new_count:
88
+ updates.append(gr.update(value=turns[idx], visible=True))
89
+ else:
90
+ updates.append(gr.update(value="", visible=idx < INITIAL_TURNS))
91
+ return (new_count, *updates, voice_s1_path, voice_s2_path)
92
+
93
+
94
+ def _prepare_prefix(file_path: str | None) -> str | None:
95
+ if not file_path:
96
+ return None
97
+ path = Path(file_path)
98
+ if not path.exists():
99
+ return None
100
+ return str(path)
101
+
102
+
103
+ @spaces.GPU(duration=100)
104
+ def generate_audio(
105
+ turn_count: int,
106
+ *inputs,
107
+ ):
108
+ turn_values = list(inputs[:MAX_TURNS])
109
+ voice_s1 = inputs[MAX_TURNS]
110
+ voice_s2 = inputs[MAX_TURNS + 1]
111
+ cfg_scale = float(inputs[MAX_TURNS + 2])
112
+ text_temperature = float(inputs[MAX_TURNS + 3])
113
+ audio_temperature = float(inputs[MAX_TURNS + 4])
114
+ text_top_k = int(inputs[MAX_TURNS + 5])
115
+ audio_top_k = int(inputs[MAX_TURNS + 6])
116
+ include_prefix = bool(inputs[MAX_TURNS + 7])
117
+
118
+ script = _concat_script(turn_count, turn_values)
119
+ if not script.strip():
120
+ raise gr.Error("Please enter at least one non-empty speaker turn.")
121
+
122
+ dia = _get_dia()
123
+ config = GenerationConfig(
124
+ cfg_scale=cfg_scale,
125
+ text=SamplingConfig(temperature=text_temperature, top_k=text_top_k),
126
+ audio=SamplingConfig(temperature=audio_temperature, top_k=audio_top_k),
127
+ use_cuda_graph=True,
128
+ )
129
+ kwargs = {
130
+ "prefix_speaker_1": _prepare_prefix(voice_s1),
131
+ "prefix_speaker_2": _prepare_prefix(voice_s2),
132
+ "include_prefix": include_prefix,
133
+ }
134
+ buffer = io.StringIO()
135
+ with contextlib.redirect_stdout(buffer):
136
+ result = dia.generate(
137
+ script,
138
+ config=config,
139
+ output_wav=None,
140
+ verbose=True,
141
+ **kwargs,
142
+ )
143
+ waveform = result.waveform.detach().cpu().numpy()
144
+ sample_rate = result.sample_rate
145
+ timestamps = result.timestamps
146
+ log_text = buffer.getvalue().strip()
147
+ table = [[w, round(t, 3)] for w, t in timestamps]
148
+ return (sample_rate, waveform), table, log_text or "Generation finished."
149
+
150
+
151
+ def build_interface() -> gr.Blocks:
152
+ with gr.Blocks(
153
+ title="Dia2 TTS", css=".compact-turn textarea {min-height: 60px}"
154
+ ) as demo:
155
+ gr.Markdown(
156
+ """## Dia2 — Open TTS Model
157
+ Compose dialogue, attach optional voice prompts, and generate audio (CUDA graphs enabled by default)."""
158
+ )
159
+ turn_state = gr.State(INITIAL_TURNS)
160
+ with gr.Row(equal_height=True):
161
+ example_dropdown = gr.Dropdown(
162
+ choices=["(select example)"] + list(EXAMPLES.keys()),
163
+ label="Examples",
164
+ value="(select example)",
165
+ )
166
+ with gr.Row(equal_height=True):
167
+ with gr.Column(scale=1):
168
+ with gr.Group():
169
+ gr.Markdown("### Script")
170
+ controls = []
171
+ for idx in range(MAX_TURNS):
172
+ speaker = "[S1]" if idx % 2 == 0 else "[S2]"
173
+ box = gr.Textbox(
174
+ label=f"{speaker} turn {idx + 1}",
175
+ lines=2,
176
+ elem_classes=["compact-turn"],
177
+ placeholder=f"Enter dialogue for {speaker}…",
178
+ visible=idx < INITIAL_TURNS,
179
+ )
180
+ controls.append(box)
181
+ with gr.Row():
182
+ add_btn = gr.Button("Add Turn")
183
+ remove_btn = gr.Button("Remove Turn")
184
+ with gr.Group():
185
+ gr.Markdown("### Voice Prompts")
186
+ with gr.Row():
187
+ voice_s1 = gr.File(
188
+ label="[S1] voice (wav/mp3)", type="filepath"
189
+ )
190
+ voice_s2 = gr.File(
191
+ label="[S2] voice (wav/mp3)", type="filepath"
192
+ )
193
+ with gr.Group():
194
+ gr.Markdown("### Sampling")
195
+ cfg_scale = gr.Slider(
196
+ 1.0, 8.0, value=6.0, step=0.1, label="CFG Scale"
197
+ )
198
+ with gr.Group():
199
+ gr.Markdown("#### Text Sampling")
200
+ text_temperature = gr.Slider(
201
+ 0.1, 1.5, value=0.6, step=0.05, label="Text Temperature"
202
+ )
203
+ text_top_k = gr.Slider(
204
+ 1, 200, value=50, step=1, label="Text Top-K"
205
+ )
206
+ with gr.Group():
207
+ gr.Markdown("#### Audio Sampling")
208
+ audio_temperature = gr.Slider(
209
+ 0.1, 1.5, value=0.8, step=0.05, label="Audio Temperature"
210
+ )
211
+ audio_top_k = gr.Slider(
212
+ 1, 200, value=50, step=1, label="Audio Top-K"
213
+ )
214
+ include_prefix = gr.Checkbox(
215
+ label="Keep prefix audio in output", value=False
216
+ )
217
+ generate_btn = gr.Button("Generate", variant="primary")
218
+ with gr.Column(scale=1):
219
+ gr.Markdown("### Output")
220
+ audio_out = gr.Audio(label="Waveform", interactive=False)
221
+ timestamps = gr.Dataframe(
222
+ headers=["word", "seconds"], label="Timestamps"
223
+ )
224
+ log_box = gr.Textbox(label="Logs", lines=8)
225
+
226
+ add_btn.click(
227
+ lambda c: _add_turn(c),
228
+ inputs=turn_state,
229
+ outputs=[turn_state, *controls],
230
+ )
231
+ remove_btn.click(
232
+ lambda c: _remove_turn(c),
233
+ inputs=turn_state,
234
+ outputs=[turn_state, *controls],
235
+ )
236
+ example_dropdown.change(
237
+ lambda name, c: _load_example(name, c),
238
+ inputs=[example_dropdown, turn_state],
239
+ outputs=[turn_state, *controls, voice_s1, voice_s2],
240
+ )
241
+
242
+ generate_btn.click(
243
+ generate_audio,
244
+ inputs=[
245
+ turn_state,
246
+ *controls,
247
+ voice_s1,
248
+ voice_s2,
249
+ cfg_scale,
250
+ text_temperature,
251
+ audio_temperature,
252
+ text_top_k,
253
+ audio_top_k,
254
+ include_prefix,
255
+ ],
256
+ outputs=[audio_out, timestamps, log_box],
257
+ )
258
+ return demo
259
+
260
+
261
+ if __name__ == "__main__":
262
+ app = build_interface()
263
+ app.queue(default_concurrency_limit=1)
264
+ app.launch(share=True)