pnnbao-ump committed
Commit ab844ce · 1 Parent(s): 0782303

Add VieNeu-TTS app with LFS audio files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,27 @@
 ---
 title: VieNeuTTS
-emoji: 🐠
-colorFrom: indigo
-colorTo: gray
+emoji: 🎙️
+colorFrom: blue
+colorTo: green
 sdk: gradio
-sdk_version: 5.49.1
+sdk_version: 4.44.0
 app_file: app.py
 pinned: false
-license: apache-2.0
+license: mit
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# VieNeu-TTS: Vietnamese Text-to-Speech
+
+A Vietnamese speech synthesis system built on a large language model and a neural codec.
+
+## Features
+
+- ✅ Natural-sounding Vietnamese voices
+- ✅ Multiple built-in voices (Southern male/female)
+- ✅ Voice cloning from a custom reference audio
+- ✅ High audio quality (24 kHz)
+
+## Links
+
+- [GitHub Repository](https://github.com/pnnbao97/VieNeu-TTS)
+- [Model Card](https://huggingface.co/pnnbao-ump/VieNeu-TTS)
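
For orientation, a minimal sketch of how the files added in this commit fit together outside the Gradio UI, mirroring the calls made in `app.py` below (the output filename and the sentence being synthesized are illustrative, not part of the commit):

```python
# Minimal usage sketch, assuming the sample files from this commit are present:
# clone a bundled sample voice, then synthesize a sentence at 24 kHz.
import soundfile as sf
from vieneutts import VieNeuTTS

tts = VieNeuTTS(
    backbone_repo="pnnbao-ump/VieNeu-TTS",
    backbone_device="cpu",
    codec_repo="neuphonic/neucodec",
    codec_device="cpu",
)

# Reference audio plus its exact transcript (the bundled Southern male sample).
ref_codes = tts.encode_reference("./sample/id_0001.wav")
with open("./sample/id_0001.txt", encoding="utf-8") as f:
    ref_text = f.read()

wav = tts.infer("Xin chào, đây là VieNeu-TTS.", ref_codes, ref_text)  # illustrative sentence
sf.write("output.wav", wav, 24000)  # the model generates 24 kHz audio
```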
app.py ADDED
@@ -0,0 +1,160 @@
+import gradio as gr
+import numpy as np
+import soundfile as sf
+from vieneutts import VieNeuTTS
+import tempfile
+import os
+
+# Initialize the model
+print("Đang tải model...")
+tts = VieNeuTTS(
+    backbone_repo="pnnbao-ump/VieNeu-TTS",
+    backbone_device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu",
+    codec_repo="neuphonic/neucodec",
+    codec_device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
+)
+
+# Built-in voice samples
+VOICE_SAMPLES = {
+    "Nam miền Nam": {
+        "audio": "./sample/id_0001.wav",
+        "text": "./sample/id_0001.txt"
+    },
+    "Nữ miền Nam": {
+        "audio": "./sample/id_0002.wav",
+        "text": "./sample/id_0002.txt"
+    }
+}
+
+def synthesize_speech(text, voice_choice, custom_audio=None, custom_text=None):
+    """
+    Synthesize speech from text.
+
+    Args:
+        text: Text to convert to speech
+        voice_choice: Selected built-in voice
+        custom_audio: Custom reference audio file (optional)
+        custom_text: Transcript matching the custom audio (optional)
+    """
+    try:
+        # Resolve the reference audio and transcript
+        if custom_audio is not None and custom_text:
+            ref_audio_path = custom_audio
+            ref_text = custom_text
+        elif voice_choice in VOICE_SAMPLES:
+            ref_audio_path = VOICE_SAMPLES[voice_choice]["audio"]
+            ref_text_path = VOICE_SAMPLES[voice_choice]["text"]
+            with open(ref_text_path, "r", encoding="utf-8") as f:
+                ref_text = f.read()
+        else:
+            return None, "❌ Vui lòng chọn giọng hoặc tải lên audio tùy chỉnh"
+
+        # Validate the text input
+        if not text or text.strip() == "":
+            return None, "❌ Vui lòng nhập văn bản cần tổng hợp"
+
+        # Encode the reference audio
+        print(f"Đang encode audio tham chiếu: {ref_audio_path}")
+        ref_codes = tts.encode_reference(ref_audio_path)
+
+        # Synthesize speech
+        print(f"Đang tổng hợp giọng nói cho: {text[:50]}...")
+        wav = tts.infer(text, ref_codes, ref_text)
+
+        # Write the result to a temporary file
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
+            sf.write(tmp_file.name, wav, 24000)
+            output_path = tmp_file.name
+
+        return output_path, "✅ Tổng hợp thành công!"
+
+    except Exception as e:
+        print(f"Lỗi: {str(e)}")
+        return None, f"❌ Lỗi: {str(e)}"
+
+# Example inputs
+examples = [
+    ["Các khóa học trực tuyến đang giúp học sinh tiếp cận kiến thức mọi lúc mọi nơi.", "Nam miền Nam"],
+    ["Các nghiên cứu về bệnh Alzheimer cho thấy tác dụng tích cực của các bài tập trí não.", "Nữ miền Nam"],
+    ["Một tiểu thuyết trinh thám hiện đại dẫn dắt độc giả qua những tình tiết phức tạp.", "Nam miền Nam"],
+]
+
+# Build the Gradio interface
+with gr.Blocks(title="VieNeu-TTS: Vietnamese Text-to-Speech") as demo:
+    gr.Markdown("""
+    # 🎙️ VieNeu-TTS: Vietnamese Text-to-Speech
+
+    Hệ thống tổng hợp tiếng nói tiếng Việt sử dụng mô hình ngôn ngữ lớn và neural codec.
+
+    **Hướng dẫn sử dụng:**
+    1. Nhập văn bản tiếng Việt cần chuyển thành giọng nói
+    2. Chọn giọng có sẵn hoặc tải lên audio tùy chỉnh
+    3. Nhấn "Tổng hợp giọng nói"
+    """)
+
+    with gr.Row():
+        with gr.Column():
+            # Input text
+            text_input = gr.Textbox(
+                label="Văn bản đầu vào",
+                placeholder="Nhập văn bản tiếng Việt...",
+                lines=5
+            )
+
+            # Voice selection
+            voice_select = gr.Radio(
+                choices=list(VOICE_SAMPLES.keys()),
+                label="Chọn giọng có sẵn",
+                value="Nam miền Nam"
+            )
+
+            # Custom voice option
+            with gr.Accordion("Hoặc sử dụng giọng tùy chỉnh", open=False):
+                custom_audio = gr.Audio(
+                    label="Tải lên file audio mẫu (.wav)",
+                    type="filepath"
+                )
+                custom_text = gr.Textbox(
+                    label="Nội dung của audio mẫu",
+                    placeholder="Nhập chính xác nội dung trong audio...",
+                    lines=3
+                )
+
+            # Submit button
+            submit_btn = gr.Button("🎵 Tổng hợp giọng nói", variant="primary")
+
+        with gr.Column():
+            # Output
+            audio_output = gr.Audio(label="Kết quả")
+            status_output = gr.Textbox(label="Trạng thái")
+
+    # Examples
+    gr.Examples(
+        examples=examples,
+        inputs=[text_input, voice_select],
+        outputs=[audio_output, status_output],
+        fn=synthesize_speech,
+        cache_examples=True
+    )
+
+    # Event handler
+    submit_btn.click(
+        fn=synthesize_speech,
+        inputs=[text_input, voice_select, custom_audio, custom_text],
+        outputs=[audio_output, status_output]
+    )
+
+    gr.Markdown("""
+    ---
+    **Lưu ý:**
+    - Model hỗ trợ tiếng Việt với các giọng miền Nam
+    - Chất lượng giọng nói phụ thuộc vào audio tham chiếu
+    - Để có kết quả tốt nhất, audio tham chiếu nên rõ ràng, không nhiễu
+
+    **Liên kết:**
+    - [GitHub Repository](https://github.com/pnnbao97/VieNeu-TTS)
+    - [Model on Hugging Face](https://huggingface.co/pnnbao-ump/VieNeu-TTS)
+    """)
+
+if __name__ == "__main__":
+    demo.launch()
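
If it helps while reviewing, the handler above can also be exercised without the UI. A hypothetical smoke test (the input sentence is illustrative; importing `app` loads the model exactly as the Space does):

```python
# Hypothetical local smoke test for the handler defined in app.py above.
# Importing app triggers the model load, so this is as heavy as starting the Space.
from app import synthesize_speech

audio_path, status = synthesize_speech(
    "Hôm nay trời đẹp.",   # illustrative input text
    "Nữ miền Nam",          # one of the built-in voices
)
print(status, audio_path)   # "✅ ..." plus a temporary .wav path on success
```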
packages.txt ADDED
@@ -0,0 +1,3 @@
+espeak-ng
+libespeak-ng1
+ffmpeg
requirements.txt ADDED
@@ -0,0 +1,10 @@
+gradio
+torch
+torchaudio
+transformers
+librosa
+soundfile
+numpy
+phonemizer
+neucodec
+espeak-ng
sample/id_0001.txt ADDED
@@ -0,0 +1 @@
+Đến cuối thế kỷ 19, ngành đánh bắt cá được thương mại hóa.
sample/id_0001.wav ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:632a5c8fa34fe03001cc3c44427b5e0ee70f767377bc788b59a5dc9afa9fba49
+size 164492
sample/id_0002.txt ADDED
@@ -0,0 +1 @@
+Từ nhiều nguồn tài liệu lịch sử, có thể thấy nuôi con theo phong cách Do Thái không chỉ tốt cho đứa trẻ mà còn tốt cho cả các bậc cha mẹ.
sample/id_0002.wav ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8195d3fa6bfc1a049f2b61f2d4052e2d8caf69556b206d2d5eff48314c5ee907
+size 435788
vieneutts.py ADDED
@@ -0,0 +1,388 @@
+from pathlib import Path
+from typing import Generator
+import librosa
+import numpy as np
+import torch
+from neucodec import NeuCodec, DistillNeuCodec
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from phonemizer.backend.espeak.espeak import EspeakWrapper
+from phonemizer import phonemize
+import platform
+import re
+import os
+
+# Configure espeak for the different runtime environments
+if platform.system() == "Windows":
+    EspeakWrapper.set_library(r"C:\Program Files\eSpeak NG\libespeak-ng.dll")
+elif platform.system() == "Linux":
+    # Try the common library paths
+    espeak_paths = [
+        "/usr/lib/x86_64-linux-gnu/libespeak-ng.so",
+        "/usr/lib/libespeak-ng.so",
+        "/usr/lib/x86_64-linux-gnu/libespeak-ng.so.1",
+    ]
+    for path in espeak_paths:
+        if os.path.exists(path):
+            EspeakWrapper.set_library(path)
+            break
+else:
+    raise ValueError("Please set the espeak library path for your platform.")
+
+def _linear_overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray:
+    # original impl --> https://github.com/facebookresearch/encodec/blob/main/encodec/utils.py
+    assert len(frames)
+    dtype = frames[0].dtype
+    shape = frames[0].shape[:-1]
+
+    total_size = 0
+    for i, frame in enumerate(frames):
+        frame_end = stride * i + frame.shape[-1]
+        total_size = max(total_size, frame_end)
+
+    sum_weight = np.zeros(total_size, dtype=dtype)
+    out = np.zeros((*shape, total_size), dtype=dtype)
+
+    offset: int = 0
+    for frame in frames:
+        frame_length = frame.shape[-1]
+        t = np.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
+        weight = np.abs(0.5 - (t - 0.5))
+
+        out[..., offset : offset + frame_length] += weight * frame
+        sum_weight[offset : offset + frame_length] += weight
+        offset += stride
+    assert sum_weight.min() > 0
+    return out / sum_weight
+
+class VieNeuTTS:
+    def __init__(
+        self,
+        backbone_repo="pnnbao-ump/VieNeu-TTS",
+        backbone_device="cpu",
+        codec_repo="neuphonic/neucodec",
+        codec_device="cpu",
+    ):
+
+        # Constants
+        self.sample_rate = 24_000
+        self.max_context = 4096
+        self.hop_length = 480
+        self.streaming_overlap_frames = 1
+        self.streaming_frames_per_chunk = 25
+        self.streaming_lookforward = 5
+        self.streaming_lookback = 50
+        self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length
+
+        # ggml & onnx flags
+        self._is_quantized_model = False
+        self._is_onnx_codec = False
+
+        # HF tokenizer
+        self.tokenizer = None
+
+        # Load models
+        self._load_backbone(backbone_repo, backbone_device)
+        self._load_codec(codec_repo, codec_device)
+
+    def _load_backbone(self, backbone_repo, backbone_device):
+        print(f"Loading backbone from: {backbone_repo} on {backbone_device} ...")
+
+        if backbone_repo.lower().endswith("gguf") or "gguf" in backbone_repo.lower():
+            try:
+                from llama_cpp import Llama
+            except ImportError as e:
+                raise ImportError(
+                    "Failed to import `llama_cpp`. "
+                    "Please install it with:\n"
+                    " pip install llama-cpp-python"
+                ) from e
+            self.backbone = Llama.from_pretrained(
+                repo_id=backbone_repo,
+                filename="*.gguf",
+                verbose=False,
+                n_gpu_layers=-1 if backbone_device == "gpu" else 0,
+                n_ctx=self.max_context,
+                mlock=True,
+                flash_attn=True if backbone_device == "gpu" else False,
+            )
+            self._is_quantized_model = True
+
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(backbone_repo)
+            self.backbone = AutoModelForCausalLM.from_pretrained(backbone_repo).to(
+                torch.device(backbone_device)
+            )
+
+    def _load_codec(self, codec_repo, codec_device):
+        print(f"Loading codec from: {codec_repo} on {codec_device} ...")
+        match codec_repo:
+            case "neuphonic/neucodec":
+                self.codec = NeuCodec.from_pretrained(codec_repo)
+                self.codec.eval().to(codec_device)
+            case "neuphonic/distill-neucodec":
+                self.codec = DistillNeuCodec.from_pretrained(codec_repo)
+                self.codec.eval().to(codec_device)
+            case "neuphonic/neucodec-onnx-decoder":
+                if codec_device != "cpu":
+                    raise ValueError("Onnx decoder only currently runs on CPU.")
+                try:
+                    from neucodec import NeuCodecOnnxDecoder
+                except ImportError as e:
+                    raise ImportError(
+                        "Failed to import the onnx decoder."
+                        " Ensure you have onnxruntime installed as well as neucodec >= 0.0.4."
+                    ) from e
+                self.codec = NeuCodecOnnxDecoder.from_pretrained(codec_repo)
+                self._is_onnx_codec = True
+            case _:
+                raise ValueError(f"Unsupported codec repository: {codec_repo}")
+
+    def infer(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> np.ndarray:
+        """
+        Perform inference to generate speech from text using the TTS model and reference audio.
+
+        Args:
+            text (str): Input text to be converted to speech.
+            ref_codes (np.ndarray | torch.Tensor): Encoded reference audio.
+            ref_text (str): Transcript of the reference audio.
+        Returns:
+            np.ndarray: Generated speech waveform.
+        """
+
+        # Generate tokens
+        if self._is_quantized_model:
+            output_str = self._infer_ggml(ref_codes, ref_text, text)
+        else:
+            prompt_ids = self._apply_chat_template(ref_codes, ref_text, text)
+            output_str = self._infer_torch(prompt_ids)
+
+        # Decode
+        wav = self._decode(output_str)
+
+        return wav
+
+    def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]:
+        """
+        Perform streaming inference to generate speech from text using the TTS model and reference audio.
+
+        Args:
+            text (str): Input text to be converted to speech.
+            ref_codes (np.ndarray | torch.Tensor): Encoded reference audio.
+            ref_text (str): Transcript of the reference audio.
+        Yields:
+            np.ndarray: Generated speech waveform chunks.
+        """
+
+        if self._is_quantized_model:
+            return self._infer_stream_ggml(ref_codes, ref_text, text)
+        else:
+            raise NotImplementedError("Streaming is not implemented for the torch backend!")
+
+    def encode_reference(self, ref_audio_path: str | Path):
+        wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
+        wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0)  # [1, 1, T]
+        with torch.no_grad():
+            ref_codes = self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0)
+        return ref_codes
+
+    def _decode(self, codes: str):
+        """Decode speech tokens to audio waveform."""
+        # Extract speech token IDs using regex
+        speech_ids = [int(num) for num in re.findall(r"<\|speech_(\d+)\|>", codes)]
+
+        if len(speech_ids) == 0:
+            raise ValueError(
+                "No valid speech tokens found in the output. "
+                "The model may not have generated proper speech tokens."
+            )
+
+        # Onnx decode
+        if self._is_onnx_codec:
+            codes = np.array(speech_ids, dtype=np.int32)[np.newaxis, np.newaxis, :]
+            recon = self.codec.decode_code(codes)
+        # Torch decode
+        else:
+            with torch.no_grad():
+                codes = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(
+                    self.codec.device
+                )
+                recon = self.codec.decode_code(codes).cpu().numpy()
+
+        return recon[0, 0, :]
+
+    def _to_phones(self, text: str) -> str:
+        """Convert text to phonemes using phonemizer."""
+        phones = phonemize(
+            text,
+            language="vi",
+            backend="espeak",
+            preserve_punctuation=True,
+            with_stress=True,
+            language_switch="remove-flags"
+        )
+
+        # Handle both string and list returns
+        if isinstance(phones, list):
+            if len(phones) == 0:
+                raise ValueError(f"Phonemization failed for text: {text}")
+            return phones[0]
+        elif isinstance(phones, str):
+            return phones
+        else:
+            raise TypeError(f"Unexpected phonemize return type: {type(phones)}")
+
+    def _apply_chat_template(self, ref_codes: list[int], ref_text: str, input_text: str) -> list[int]:
+
+        input_text = self._to_phones(ref_text) + " " + self._to_phones(input_text)
+        speech_replace = self.tokenizer.convert_tokens_to_ids("<|SPEECH_REPLACE|>")
+        speech_gen_start = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_START|>")
+        text_replace = self.tokenizer.convert_tokens_to_ids("<|TEXT_REPLACE|>")
+        text_prompt_start = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_START|>")
+        text_prompt_end = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_END|>")
+
+        input_ids = self.tokenizer.encode(input_text, add_special_tokens=False)
+        chat = """user: Convert the text to speech:<|TEXT_REPLACE|>\nassistant:<|SPEECH_REPLACE|>"""
+        ids = self.tokenizer.encode(chat)
+
+        text_replace_idx = ids.index(text_replace)
+        ids = (
+            ids[:text_replace_idx]
+            + [text_prompt_start]
+            + input_ids
+            + [text_prompt_end]
+            + ids[text_replace_idx + 1 :]  # noqa
+        )
+
+        speech_replace_idx = ids.index(speech_replace)
+        codes_str = "".join([f"<|speech_{i}|>" for i in ref_codes])
+        codes = self.tokenizer.encode(codes_str, add_special_tokens=False)
+        ids = ids[:speech_replace_idx] + [speech_gen_start] + list(codes)
+
+        return ids
+
+    def _infer_torch(self, prompt_ids: list[int]) -> str:
+        prompt_tensor = torch.tensor(prompt_ids).unsqueeze(0).to(self.backbone.device)
+        speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
+        with torch.no_grad():
+            output_tokens = self.backbone.generate(
+                prompt_tensor,
+                max_length=self.max_context,
+                eos_token_id=speech_end_id,
+                do_sample=True,
+                temperature=1.0,
+                top_k=50,
+                use_cache=True,
+                min_new_tokens=50,
+            )
+        input_length = prompt_tensor.shape[-1]
+        output_str = self.tokenizer.decode(
+            output_tokens[0, input_length:].cpu().numpy().tolist(), add_special_tokens=False
+        )
+        return output_str
+
+    def _infer_ggml(self, ref_codes: list[int], ref_text: str, input_text: str) -> str:
+        ref_text = self._to_phones(ref_text)
+        input_text = self._to_phones(input_text)
+
+        codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
+        prompt = (
+            f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
+            f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
+        )
+        output = self.backbone(
+            prompt,
+            max_tokens=self.max_context,
+            temperature=1.0,
+            top_k=50,
+            stop=["<|SPEECH_GENERATION_END|>"],
+        )
+        output_str = output["choices"][0]["text"]
+        return output_str
+
+    def _infer_stream_ggml(self, ref_codes: torch.Tensor, ref_text: str, input_text: str) -> Generator[np.ndarray, None, None]:
+        ref_text = self._to_phones(ref_text)
+        input_text = self._to_phones(input_text)
+
+        codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
+        prompt = (
+            f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
+            f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
+        )
+
+        audio_cache: list[np.ndarray] = []
+        token_cache: list[str] = [f"<|speech_{idx}|>" for idx in ref_codes]
+        n_decoded_samples: int = 0
+        n_decoded_tokens: int = len(ref_codes)
+
+        for item in self.backbone(
+            prompt,
+            max_tokens=self.max_context,
+            temperature=0.2,
+            top_k=50,
+            stop=["<|SPEECH_GENERATION_END|>"],
+            stream=True
+        ):
+            output_str = item["choices"][0]["text"]
+            token_cache.append(output_str)
+
+            if len(token_cache[n_decoded_tokens:]) >= self.streaming_frames_per_chunk + self.streaming_lookforward:
+
+                # decode chunk
+                tokens_start = max(
+                    n_decoded_tokens
+                    - self.streaming_lookback
+                    - self.streaming_overlap_frames,
+                    0
+                )
+                tokens_end = (
+                    n_decoded_tokens
+                    + self.streaming_frames_per_chunk
+                    + self.streaming_lookforward
+                    + self.streaming_overlap_frames
+                )
+                sample_start = (
+                    n_decoded_tokens - tokens_start
+                ) * self.hop_length
+                sample_end = (
+                    sample_start
+                    + (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length
+                )
+                curr_codes = token_cache[tokens_start:tokens_end]
+                recon = self._decode("".join(curr_codes))
+                recon = recon[sample_start:sample_end]
+                audio_cache.append(recon)
+
+                # postprocess
+                processed_recon = _linear_overlap_add(
+                    audio_cache, stride=self.streaming_stride_samples
+                )
+                new_samples_end = len(audio_cache) * self.streaming_stride_samples
+                processed_recon = processed_recon[
+                    n_decoded_samples:new_samples_end
+                ]
+                n_decoded_samples = new_samples_end
+                n_decoded_tokens += self.streaming_frames_per_chunk
+                yield processed_recon
+
+        # final decoding handled separately as non-constant chunk size
+        remaining_tokens = len(token_cache) - n_decoded_tokens
+        if len(token_cache) > n_decoded_tokens:
+            tokens_start = max(
+                len(token_cache)
+                - (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens),
+                0
+            )
+            sample_start = (
+                len(token_cache)
+                - tokens_start
+                - remaining_tokens
+                - self.streaming_overlap_frames
+            ) * self.hop_length
+            curr_codes = token_cache[tokens_start:]
+            recon = self._decode("".join(curr_codes))
+            recon = recon[sample_start:]
+            audio_cache.append(recon)
+
+            processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples)
+            processed_recon = processed_recon[n_decoded_samples:]
+            yield processed_recon
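
The streaming path above (`infer_stream`, available only with a GGUF backbone) yields overlap-added waveform chunks as tokens arrive. A sketch of consuming it, assuming a GGUF build of the backbone exists; the repo name used here is a placeholder, not something this commit provides:

```python
# Sketch of streaming consumption (assumes a GGUF backbone and llama-cpp-python;
# "pnnbao-ump/VieNeu-TTS-GGUF" is a hypothetical repo id, not part of this commit).
import numpy as np
import soundfile as sf
from vieneutts import VieNeuTTS

tts = VieNeuTTS(backbone_repo="pnnbao-ump/VieNeu-TTS-GGUF", codec_repo="neuphonic/neucodec")

ref_codes = tts.encode_reference("./sample/id_0002.wav")
with open("./sample/id_0002.txt", encoding="utf-8") as f:
    ref_text = f.read()

chunks = []
for chunk in tts.infer_stream("Xin chào các bạn.", ref_codes, ref_text):
    chunks.append(chunk)  # each chunk is a float waveform segment at 24 kHz

sf.write("streamed.wav", np.concatenate(chunks), tts.sample_rate)
```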