Update app.py
app.py CHANGED

@@ -1,4 +1,5 @@
 import gradio as gr
+import io
 import os
 import tempfile
 import time

@@ -154,13 +155,33 @@ def get_kokoro_voices():
         "zf_xiaobei", "zf_xiaoni", "zf_xiaoxiao", "zf_xiaoyi", "zm_yunjian", "zm_yunxi", "zm_yunxia", "zm_yunyang"
     ]

-def kokoro_tts(text: str, speed: float, voice: str) -> str:
-
-
-
-
-
+def _audio_np_to_int16(audio_np: np.ndarray) -> np.ndarray:
+    audio_clipped = np.clip(audio_np, -1.0, 1.0)
+    return (audio_clipped * 32767.0).astype(np.int16)
+
+
+def _write_wav_file(audio_int16: np.ndarray, sample_rate: int = 24_000) -> str:
+    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+        path = tmp.name
+    with wave.open(path, "wb") as wf:
+        wf.setnchannels(1)
+        wf.setsampwidth(2)
+        wf.setframerate(sample_rate)
+        wf.writeframes(audio_int16.tobytes())
+    return path
+
+
+def _wav_bytes_from_int16(audio_int16: np.ndarray, sample_rate: int = 24_000) -> bytes:
+    buffer = io.BytesIO()
+    with wave.open(buffer, "wb") as wf:
+        wf.setnchannels(1)
+        wf.setsampwidth(2)
+        wf.setframerate(sample_rate)
+        wf.writeframes(audio_int16.tobytes())
+    return buffer.getvalue()
+
+
+def _kokoro_segment_generator(text: str, speed: float, voice: str):
     if not text or not text.strip():
         raise gr.Error("Please enter text to synthesize.")

@@ -171,72 +192,80 @@ def kokoro_tts(text: str, speed: float, voice: str) -> str:
     if pipeline is None:
         raise gr.Error("Kokoro English pipeline not initialized.")

-    sr = 24_000
-
-    # Process ALL segments for longer audio generation
-    audio_segments = []
     pack = pipeline.load_voice(voice)

     try:
-
-        total_segments = len(segments)
-
-        for idx, (_, ps, _) in enumerate(segments):
+        for idx, (_, ps, _) in enumerate(pipeline(text, voice, speed)):
             ref_s = pack[len(ps) - 1]
             try:
                 audio = model(ps, ref_s, float(speed))
-
-
-                print(f"Progress: Generated {idx + 1}/{total_segments} segments...")
+                audio_np = audio.detach().cpu().numpy()
+                yield audio_np
             except Exception as e:
                 raise gr.Error(f"Error generating audio for segment {idx + 1}: {str(e)[:200]}...")
-
-        if not audio_segments:
-            raise gr.Error("No audio was generated.")
-
-        # Concatenate all segments to create the complete audio
-        if len(audio_segments) == 1:
-            audio_np = audio_segments[0]
-        else:
-            audio_np = np.concatenate(audio_segments, axis=0)
-        duration = len(audio_np) / sr
-        print(f"Completed: {total_segments} segments concatenated into {duration:.1f} seconds of audio")
-
     except gr.Error:
         raise
     except Exception as e:
         raise gr.Error(f"Error during speech generation: {str(e)[:200]}...")

-    # Convert to 16-bit PCM and write to WAV file
-    audio_clipped = np.clip(audio_np, -1.0, 1.0)
-    audio_int16 = (audio_clipped * 32767.0).astype(np.int16)

-
-
-
-
-
-
-
+def kokoro_tts(text: str, speed: float, voice: str) -> str:
+    sr = 24_000
+    segments = list(_kokoro_segment_generator(text, speed, voice))
+    if not segments:
+        raise gr.Error("No audio was generated.")
+
+    audio_np = segments[0] if len(segments) == 1 else np.concatenate(segments, axis=0)
+    audio_int16 = _audio_np_to_int16(audio_np)
+    return _write_wav_file(audio_int16, sr)

-
+
+def kokoro_tts_stream(text: str, speed: float, voice: str):
+    sr = 24_000
+    produced_any = False
+
+    for audio_np in _kokoro_segment_generator(text, speed, voice):
+        produced_any = True
+        audio_int16 = _audio_np_to_int16(audio_np)
+        chunk_bytes = _wav_bytes_from_int16(audio_int16, sr)
+        yield chunk_bytes
+
+    if not produced_any:
+        raise gr.Error("No audio was generated.")

 # Main dispatcher function to handle all services
+def _read_file_bytes(path: str) -> bytes:
+    with open(path, "rb") as file:
+        data = file.read()
+    return data
+
+
 def generate_tts(text, service, openai_api_key, openai_model, openai_voice,
                  elevenlabs_api_key, elevenlabs_voice, voice_dict,
                  kokoro_speed, kokoro_voice):
     """Route to appropriate TTS service based on selection"""
+    if service == "Kokoro":
+        yield from kokoro_tts_stream(text, kokoro_speed, kokoro_voice)
+        return
+
     if service == "OpenAI":
-
+        file_path = openai_tts(text, openai_model, openai_voice, openai_api_key)
     elif service == "ElevenLabs":
         voice_id = voice_dict.get(elevenlabs_voice, elevenlabs_voice)
-
-    elif service == "Kokoro":
-        return kokoro_tts(text, kokoro_speed, kokoro_voice)
+        file_path = elevenlabs_tts(text, voice_id, elevenlabs_api_key)
     else:
-        # Fallback in case of an unknown service
         raise gr.Error(f"Unknown service selected: {service}")

+    try:
+        audio_bytes = _read_file_bytes(file_path)
+    finally:
+        try:
+            os.remove(file_path)
+        except OSError:
+            pass
+
+    yield audio_bytes
+
 # Function to update ElevenLabs voices when API key changes
 def update_elevenlabs_voices(api_key):
     """Update voice dropdown when API key is entered"""
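Each chunk that kokoro_tts_stream yields above is a complete WAV file: the float waveform is clipped to [-1.0, 1.0], scaled to 16-bit PCM, and wrapped with its own header by _wav_bytes_from_int16, so every chunk can be decoded on its own. A standalone sketch of that round trip (the ramp signal and variable names below are illustrative only; nothing here is imported from app.py):

# Standalone sketch: float samples -> int16 PCM -> in-memory WAV -> decode check.
import io
import wave

import numpy as np

sr = 24_000
audio_np = np.linspace(-1.2, 1.2, sr, dtype=np.float32)  # deliberately out of range to show clipping
audio_int16 = (np.clip(audio_np, -1.0, 1.0) * 32767.0).astype(np.int16)

buffer = io.BytesIO()
with wave.open(buffer, "wb") as wf:
    wf.setnchannels(1)      # mono
    wf.setsampwidth(2)      # 2 bytes per sample = 16-bit PCM
    wf.setframerate(sr)
    wf.writeframes(audio_int16.tobytes())
chunk_bytes = buffer.getvalue()

# Because each chunk carries its own header, it decodes independently of any other chunk.
with wave.open(io.BytesIO(chunk_bytes), "rb") as wf:
    assert wf.getframerate() == sr
    assert wf.getnframes() == len(audio_int16)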
@@ -341,6 +370,9 @@ with gr.Blocks(theme='Nymbo/Alyx_Theme') as demo:

         audio_output = gr.Audio(
             label="Generated Speech",
+            streaming=True,
+            autoplay=True,
+            show_download_button=True,
         )

         # ==========================
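The event wiring that feeds this streaming gr.Audio output is not part of the hunks shown here. Below is a minimal, self-contained sketch of the pattern the commit relies on: a generator yields complete WAV files as bytes and a gr.Audio output with streaming=True plays them as they arrive. Every name is illustrative rather than taken from app.py, and the exact chunk types a streaming gr.Audio output accepts (raw WAV bytes vs. (sample_rate, array) tuples) can vary by Gradio version.

import io
import wave

import gradio as gr


def silent_wav_chunk(seconds: float = 0.3, sr: int = 24_000) -> bytes:
    # One complete (silent) WAV chunk; a real handler would write synthesized PCM here.
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(sr)
        wf.writeframes(b"\x00\x00" * int(sr * seconds))
    return buffer.getvalue()


def fake_tts_stream(text: str):
    # Stand-in for the streaming path of generate_tts: one chunk per segment.
    for _ in range(3):
        yield silent_wav_chunk()


with gr.Blocks() as demo:
    text_box = gr.Textbox(label="Text")
    audio_out = gr.Audio(label="Generated Speech", streaming=True, autoplay=True)
    gr.Button("Generate").click(fn=fake_tts_stream, inputs=text_box, outputs=audio_out)

demo.launch()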