# DMOSpeech2 / app.py
# mrfakename's picture
# Link to GitHub Repo
# 8a739c0
## IMPORTS ##
import os
import tempfile
import time
from pathlib import Path
import gradio as gr
import numpy as np
import spaces
import torch
import torchaudio
from cached_path import cached_path
from huggingface_hub import hf_hub_download
from transformers import pipeline
from infer import DMOInference
## CUDA DEVICE ##
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

## LOAD MODELS ##
# Speech recognizer, loaded once at import time and reused by every request.
asr_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    device=device,
)

# Resolve both checkpoints to local paths before constructing the TTS model.
_student_checkpoint = str(cached_path("hf://yl4579/DMOSpeech2/model_85000.pt"))
_duration_predictor = str(cached_path("hf://yl4579/DMOSpeech2/model_1500.pt"))

model = DMOInference(
    student_checkpoint_path=_student_checkpoint,
    duration_predictor_path=_duration_predictor,
    device=device,
    model_type="F5TTS_Base",
)
def transcribe(ref_audio, language=None):
    """Return the transcript of ``ref_audio`` with surrounding whitespace removed.

    Args:
        ref_audio: Audio input handed unchanged to the module-level
            ``asr_pipe``.
        language: Optional language forwarded to the recognizer; when falsy,
            no ``language`` entry is sent at all.

    Returns:
        The ``"text"`` field of the pipeline result, stripped.
    """
    decode_options = {"task": "transcribe"}
    if language:
        decode_options["language"] = language

    result = asr_pipe(
        ref_audio,
        chunk_length_s=30,
        batch_size=128,
        generate_kwargs=decode_options,
        return_timestamps=False,
    )
    return result["text"].strip()
# Generation presets, keyed by the label shown in the mode selector.
# Each preset carries the three sampling parameters handed to the model plus a
# short human-readable description.
MODES = {
    "Student Only (4 steps)": dict(
        teacher_steps=0,
        teacher_stopping_time=1.0,
        student_start_step=0,
        description="Fastest (4 steps), good quality",
    ),
    "Teacher-Guided (8 steps)": dict(
        teacher_steps=16,
        teacher_stopping_time=0.07,
        student_start_step=1,
        description="Best balance (8 steps), recommended",
    ),
    "High Diversity (16 steps)": dict(
        teacher_steps=24,
        teacher_stopping_time=0.3,
        student_start_step=2,
        description="More natural prosody (16 steps)",
    ),
    # "Custom" has no fixed values: the parameters are None placeholders here
    # and come from the user-controlled sliders at generation time.
    "Custom": dict(
        teacher_steps=None,
        teacher_stopping_time=None,
        student_start_step=None,
        description="Fine-tune all parameters",
    ),
}
@spaces.GPU(duration=120)
def generate_speech(
    prompt_audio,
    prompt_text,
    target_text,
    mode,
    temperature,
    custom_teacher_steps,
    custom_teacher_stopping_time,
    custom_student_start_step,
    verbose,
):
    """Synthesize ``target_text`` in the voice of the reference audio.

    Args:
        prompt_audio: Reference audio, forwarded to ``model.generate`` as
            ``audio_path`` (the UI supplies a file path).
        prompt_text: Transcript of the reference audio. When it is ``None``,
            empty, or only whitespace, the reference is transcribed with
            ``transcribe`` instead.
        target_text: Text to synthesize.
        mode: A key of ``MODES``. ``"Custom"`` takes the three ``custom_*``
            arguments; every other key uses its preset values.
        temperature: Forwarded to ``model.generate`` as ``temperature``.
        custom_teacher_steps: Teacher steps, used only in ``"Custom"`` mode.
        custom_teacher_stopping_time: Teacher stopping time, used only in
            ``"Custom"`` mode.
        custom_student_start_step: Student start step, used only in
            ``"Custom"`` mode.
        verbose: Forwarded to ``model.generate`` as ``verbose``.

    Returns:
        A ``(24000, audio)`` tuple where ``audio`` is a float32 NumPy array
        scaled down to [-1, 1] whenever its peak exceeds 1.0.

    Raises:
        gr.Error: If the reference audio or the target text is missing, or if
            ``mode`` is not a known generation mode.
    """
    if prompt_audio is None:
        raise gr.Error("Please upload a reference audio!")
    if not target_text or not target_text.strip():
        raise gr.Error("Please enter text to generate!")

    # A blank reference text (None, "" or whitespace) means "auto-transcribe".
    # The UI textbox yields "" when left empty, so that case must be included.
    if prompt_text is None or not prompt_text.strip():
        prompt_text = transcribe(prompt_audio)

    if mode not in MODES:
        raise gr.Error(f"Unknown generation mode: {mode!r}")
    if mode == "Custom":
        teacher_steps = custom_teacher_steps
        teacher_stopping_time = custom_teacher_stopping_time
        student_start_step = custom_student_start_step
    else:
        preset = MODES[mode]
        teacher_steps = preset["teacher_steps"]
        teacher_stopping_time = preset["teacher_stopping_time"]
        student_start_step = preset["student_start_step"]

    generated_audio = model.generate(
        gen_text=target_text,
        audio_path=prompt_audio,
        # Transcription can still come back empty; pass None in that case.
        prompt_text=prompt_text if prompt_text else None,
        teacher_steps=teacher_steps,
        teacher_stopping_time=teacher_stopping_time,
        student_start_step=student_start_step,
        temperature=temperature,
        verbose=verbose,
    )

    # Bring the result onto the host as a NumPy array whatever the model returned.
    if isinstance(generated_audio, torch.Tensor):
        audio_np = generated_audio.detach().cpu().numpy()
    else:
        audio_np = np.asarray(generated_audio)

    # Drop a leading singleton dimension so the audio is 1-D.
    if audio_np.ndim == 2 and audio_np.shape[0] == 1:
        audio_np = audio_np.squeeze(0)

    # Scale down only when the signal would clip; an empty array has no peak.
    peak = np.abs(audio_np).max() if audio_np.size else 0.0
    if peak > 1.0:
        audio_np = audio_np / peak

    audio_np = audio_np.astype(np.float32)

    # NOTE(review): assumes the model emits 24 kHz audio — confirm against infer.py.
    return (24000, audio_np)
# Create Gradio interface
# Layout: a two-column row. The left column holds the inputs, mode selector,
# advanced settings and the generate button; the right column holds the output
# player and usage tips.
# NOTE(review): the placement of the advanced controls inside the accordion is
# reconstructed from the component order — confirm against the deployed layout.
with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS") as demo:
    gr.Markdown(
        """
# 🎙️ DMOSpeech 2: Zero-Shot Text-to-Speech
[GitHub Repo](https://github.com/yl4579/DMOSpeech2)
Generate natural speech in any voice with just a short reference audio!
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Reference audio input; delivered to the handler as a file path.
            prompt_audio = gr.Audio(
                label="📎 Reference Audio",
                type="filepath",
                sources=["upload", "microphone"],
            )
            prompt_text = gr.Textbox(
                label="📝 Reference Text (leave empty for auto-transcription)",
                placeholder="The text spoken in the reference audio...",
                lines=2,
            )
            target_text = gr.Textbox(
                label="✍️ Text to Generate",
                placeholder="Enter the text you want to synthesize...",
                lines=4,
            )
            # Generation mode; the choices must match the keys of MODES.
            mode = gr.Radio(
                choices=[
                    "Student Only (4 steps)",
                    "Teacher-Guided (8 steps)",
                    "High Diversity (16 steps)",
                    "Custom",
                ],
                value="Teacher-Guided (8 steps)",
                label="🚀 Generation Mode",
                info="Choose speed vs quality/diversity tradeoff",
            )
            # Advanced settings (collapsible)
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.0,
                    step=0.1,
                    label="Duration Temperature",
                    info="0 = deterministic, >0 = more variation in speech rhythm",
                )
                # Hidden until the "Custom" mode is selected (see mode.change below).
                with gr.Group(visible=False) as custom_settings:
                    gr.Markdown("### Custom Mode Settings")
                    custom_teacher_steps = gr.Slider(
                        minimum=0,
                        maximum=32,
                        value=16,
                        step=1,
                        label="Teacher Steps",
                        info="More steps = higher quality",
                    )
                    custom_teacher_stopping_time = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.07,
                        step=0.01,
                        label="Teacher Stopping Time",
                        info="When to switch to student",
                    )
                    custom_student_start_step = gr.Slider(
                        minimum=0,
                        maximum=4,
                        value=1,
                        step=1,
                        label="Student Start Step",
                        info="Which student step to start from",
                    )
                verbose = gr.Checkbox(
                    value=False,
                    label="Verbose Output",
                    info="Show detailed generation steps",
                )
            generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
        with gr.Column(scale=1):
            # Output
            output_audio = gr.Audio(
                label="🔊 Generated Speech", type="filepath", autoplay=True
            )
            # Tips
            gr.Markdown(
                """
### 💡 Quick Tips:
- **Auto-transcription**: Leave reference text empty to auto-transcribe
- **Student Only**: Fastest (4 steps), good quality
- **Teacher-Guided**: Best balance (8 steps), recommended
- **High Diversity**: More natural prosody (16 steps)
- **Custom Mode**: Fine-tune all parameters
### 📊 Expected RTF (Real-Time Factor):
- Student Only: ~0.05x (20x faster than real-time)
- Teacher-Guided: ~0.10x (10x faster)
- High Diversity: ~0.20x (5x faster)
"""
            )
    # Event handler; the input order must match generate_speech's parameters.
    generate_btn.click(
        generate_speech,
        inputs=[
            prompt_audio,
            prompt_text,
            target_text,
            mode,
            temperature,
            custom_teacher_steps,
            custom_teacher_stopping_time,
            custom_student_start_step,
            verbose,
        ],
        outputs=[output_audio],
    )
    # Show the custom sliders only while the "Custom" mode is selected.
    mode.change(
        lambda x: gr.update(visible=x == "Custom"),
        inputs=[mode],
        outputs=[custom_settings],
    )

# Enable request queuing and start the server.
demo.queue().launch()