# kani-tts-2-pt / app.py
# Last commit: aba236b "Update app.py" (author: Simonlob, verified)
# --- Runtime bootstrap --------------------------------------------------------
# setup_dependencies() is deliberately called BEFORE the remaining imports:
# the third-party modules below (spaces, gradio, util, torch) are imported only
# after it has run, so it presumably installs/prepares packages they rely on.
# TODO(review): confirm in create_env. Do not let an import sorter move it.
from create_env import setup_dependencies
setup_dependencies()
import spaces  # provides the @spaces.GPU decorator used on the generation functions
import gradio as gr
from util import InitModels, load_config, Examples, SpeakerManager
import numpy as np  # NOTE(review): numpy is not referenced elsewhere in this file
import torch

# --- Configuration and model loading (module level: runs once at import) -----
# Model definitions. `models_configs` is used as a mapping: its keys become the
# model dropdown choices and the keys of `models` below.
config = load_config("./model_config.yaml")
models_configs = config.models
# Example rows for gr.Examples. Each row presumably supplies the nine UI inputs
# in order (text, model, mode, speaker, embedding, JSON, temp, top_p,
# repetition penalty) -- verify against examples.yaml.
examples_cfg = load_config("./examples.yaml")
examples_maker = Examples(examples_cfg)
examples = examples_maker()
# Instantiate every configured model; the result is indexed by model name
# (the dropdown value) when generating speech.
init_models = InitModels(models_configs)
models = init_models()
# Shared speaker manager: lists preset speakers, builds embeddings from audio,
# and resolves an embedding for the selected speaker mode.
speaker_manager = SpeakerManager()
@spaces.GPU
def generate_embedding_gpu(audio_data):
    """Extract a speaker embedding from a user-supplied audio clip.

    Runs under ``@spaces.GPU``. The embedding is copied to CPU before it is
    returned because the caller stores it in a ``gr.State``, which pickles
    its value.

    Args:
        audio_data: Value of the ``gr.Audio`` component (``type="numpy"``),
            or ``None`` when nothing was uploaded or recorded.

    Returns:
        Tuple ``(embedding, status)``: the CPU embedding tensor (``None`` on
        failure or missing audio) and a status string shown in the UI.
    """
    # Nothing to process: report it instead of calling the speaker manager.
    if audio_data is None:
        return None, "No audio provided"

    try:
        speaker_embedding = speaker_manager.generate_embedding(audio_data)
        print("Embedding generated successfully!")
        # gr.State pickles its contents, so hand back a CPU copy.
        cpu_embedding = speaker_embedding.cpu()
    except Exception as err:
        message = str(err)
        print(f"Error generating embedding: {message}")
        return None, f"Error: {message}"

    return cpu_embedding, "✅ Embedding ready"
# Sample rate reported for every generated clip. The original code hard-coded
# this value; it presumably matches the models' audio codec -- confirm.
_SAMPLE_RATE = 22050


def _move_to_device(speaker_emb, device):
    """Return ``speaker_emb`` on ``device`` if it is a torch tensor, else unchanged."""
    if torch.is_tensor(speaker_emb):
        return speaker_emb.to(device)
    return speaker_emb


@spaces.GPU
def generate_speech_gpu(text, model_choice, mode, speaker_choice, embedding_from_state, json_input, t, top_p, rp):
    """Generate speech from text using the selected model on GPU.

    Args:
        text: Text to synthesize. ``None``, empty or whitespace-only text
            produces no audio.
        model_choice: Key into ``models`` (the model dropdown value).
        mode: Speaker mode from the radio button:
            ``"select"``   -- preset speaker looked up by ``speaker_choice``;
            ``"generate"`` -- embedding previously extracted from audio and
                              kept in ``gr.State`` (``embedding_from_state``);
            ``"json"``     -- embedding parsed from ``json_input``.
            Any other value is handled like ``"select"``.
        speaker_choice: Preset speaker name (``"select"`` mode).
        embedding_from_state: Stored embedding or ``None`` (``"generate"`` mode).
        json_input: Raw JSON text of an embedding (``"json"`` mode).
        t: Sampling temperature.
        top_p: Top-p (nucleus) sampling threshold.
        rp: Repetition penalty.

    Returns:
        ``(sample_rate, audio)`` for a ``gr.Audio(type="numpy")`` output, or
        ``None`` when input is missing or generation fails. Failures are
        logged together with their traceback.
    """
    # `not text` also covers None: calling .strip() on None would raise here,
    # outside the try/except below, instead of returning None.
    if not text or not text.strip():
        return None
    if not model_choice:
        return None
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")
        selected_model = models[model_choice]

        # Resolve the speaker embedding for the chosen mode. Embeddings from
        # gr.State or JSON may live on CPU, so move tensors to the device.
        if mode == "generate":
            speaker_emb = _move_to_device(embedding_from_state, device)
        elif mode == "json":
            speaker_emb = _move_to_device(
                speaker_manager.get_speaker_emb(mode, json_emb=json_input), device
            )
        else:  # mode == "select": preset speaker from the speaker map
            speaker_emb = speaker_manager.get_speaker_emb(mode, speaker_choice)

        print(f"Generating speech with {model_choice}...")
        audio, _ = selected_model(
            text,
            speaker_emb=speaker_emb,
            temperature=t,
            top_p=top_p,
            repetition_penalty=rp,
        )
        print("Speech generation completed!")
        return (_SAMPLE_RATE, audio)
    except Exception as e:
        # UI boundary: log (with stack trace, which the bare message hides)
        # and return no audio rather than crashing the event handler.
        import traceback

        print(f"Error during generation: {str(e)}")
        traceback.print_exc()
        return None
# Create Gradio interface
# Resolve dropdown choices once; they were previously re-queried per argument.
model_names = list(models_configs.keys())
speaker_names = speaker_manager.get_speaker_names()

with gr.Blocks(title="😻 KaniTTS2 - Text to Speech", theme=gr.themes.Ocean()) as demo:
    gr.Markdown("# 😻 KaniTTS2: Fast and Expressive Speech Generation Model")
    gr.Markdown("Select a model and enter text to generate emotional speech")
    # Holds the embedding produced by "Extract Embedding" (never rendered).
    embedding_state = gr.State(None)

    with gr.Row():
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                choices=model_names,
                # Guard against an empty model config, matching the speaker dropdown.
                value=model_names[0] if model_names else None,
                label="Selected Model"
            )
            # Speaker mode selector: (label shown to user, value passed to handlers).
            speaker_mode = gr.Radio(
                choices=[("Select", "select"), ("Clone", "generate"), ("JSON", "json")],
                value="select",
                label="Speaker Mode"
            )
            # Preset speaker selection (visible in "select" mode).
            speaker_dropdown = gr.Dropdown(
                choices=speaker_names,
                value=speaker_names[0] if speaker_names else None,
                label="Speaker",
                visible=True
            )
            # Audio upload and embedding extraction (visible in "generate" mode).
            with gr.Group(visible=False) as embedding_group:
                audio_input = gr.Audio(
                    label="Upload or Record Audio (will be resampled to 16kHz)",
                    type="numpy",
                    sources=["upload", "microphone"],
                    format="wav",
                )
                with gr.Row():
                    run_embedding_btn = gr.Button("Extract Embedding", variant="primary")
                    clean_embedding_btn = gr.Button("Clear", variant="stop")
                embedding_status = gr.Textbox(
                    label="Embedding Status",
                    value="No embedding generated",
                    interactive=False
                )
            # Raw JSON embedding input (visible in "json" mode).
            with gr.Group(visible=False) as json_group:
                json_input = gr.Textbox(
                    label="Speaker Embedding (JSON)",
                    placeholder='Paste 128-dimensional embedding as JSON array: [0.123, -0.456, ...]',
                    lines=6,
                    info="Enter a list of 128 floating-point numbers"
                )
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="Text",
                placeholder="Enter your text ...",
                lines=3,
                max_lines=10
            )
            with gr.Accordion("Settings", open=False):
                temp = gr.Slider(
                    minimum=0.1, maximum=1.5, value=1.0, step=0.05,
                    label="Temp",
                )
                top_p = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.95, step=0.05,
                    label="Top P",
                )
                rp = gr.Slider(
                    minimum=1.0, maximum=2.0, value=1.1, step=0.05,
                    label="Repetition Penalty",
                )
            generate_btn = gr.Button("Generate", variant="primary", size="lg")
            audio_output = gr.Audio(
                label="Generated Audio",
                type="numpy"
            )

    # Single definition of the generation inputs. The order must match the
    # parameters of generate_speech_gpu and the columns of each example row.
    generation_inputs = [
        text_input, model_dropdown, speaker_mode, speaker_dropdown,
        embedding_state, json_input, temp, top_p, rp,
    ]

    def toggle_speaker_mode(mode):
        """Show only the widgets for ``mode``.

        Returns visibility updates for (speaker_dropdown, embedding_group,
        json_group). Any mode other than "select"/"generate" shows the JSON group.
        """
        if mode == "select":
            return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
        elif mode == "generate":
            return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
        else:  # json
            return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)

    speaker_mode.change(
        fn=toggle_speaker_mode,
        inputs=[speaker_mode],
        outputs=[speaker_dropdown, embedding_group, json_group]
    )

    # Embedding generation events
    run_embedding_btn.click(
        fn=generate_embedding_gpu,
        inputs=[audio_input],
        outputs=[embedding_state, embedding_status]
    )

    def clean_embedding():
        """Clear the stored embedding and reset the status text."""
        return None, "Embedding cleared"

    clean_embedding_btn.click(
        fn=clean_embedding,
        inputs=[],
        outputs=[embedding_state, embedding_status]
    )

    # GPU generation event
    generate_btn.click(
        fn=generate_speech_gpu,
        inputs=generation_inputs,
        outputs=[audio_output]
    )

    def load_example_input(text, model, mode, speaker, emb, json_emb, t, p, r):
        """Load an example into the inputs without generating audio yet.

        Identity function: returns its arguments so gr.Examples writes them
        back to the same components (including embedding_state).
        """
        return text, model, mode, speaker, emb, json_emb, t, p, r

    with gr.Row():
        gr.Examples(
            examples=examples,
            inputs=generation_inputs,
            outputs=generation_inputs,
            fn=load_example_input,
            cache_examples=False,
        )
if __name__ == "__main__":
    # Script entry point: listen on all interfaces on port 7860 and show
    # handler errors in the browser (show_error=True).
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_kwargs)