# Spaces: Running on Zero
| from create_env import setup_dependencies | |
| setup_dependencies() | |
| import spaces | |
| import gradio as gr | |
| from util import InitModels, load_config, Examples, SpeakerManager | |
| import numpy as np | |
| import torch | |
# Model registry: the `models` attribute of the parsed YAML drives both the
# model dropdown choices and model construction below.
config = load_config("./model_config.yaml")
models_configs = config.models

# Example rows shown under the UI (fed to gr.Examples).
examples_cfg = load_config("./examples.yaml")
examples_maker = Examples(examples_cfg)
examples = examples_maker()

# Build the model lookup; generate_speech_gpu indexes it by the selected name.
init_models = InitModels(models_configs)
models = init_models()

# Initialize speaker manager (shared by the embedding and synthesis handlers)
speaker_manager = SpeakerManager()
def generate_embedding_gpu(audio_data):
    """Extract a speaker embedding from an audio clip.

    Delegates to ``speaker_manager.generate_embedding`` and moves the result
    to the CPU before returning it, since the caller stores it in a
    ``gr.State``. Failures are reported through the status string rather
    than raised.

    Args:
        audio_data: Audio payload from the upload/record widget, or ``None``
            when nothing was provided.

    Returns:
        A ``(embedding, status_message)`` pair. ``embedding`` is ``None``
        when no audio was supplied or when extraction failed.
    """
    # Nothing to do without audio; this check cannot raise, so it lives
    # outside the try block.
    if audio_data is None:
        return None, "No audio provided"

    try:
        raw_embedding = speaker_manager.generate_embedding(audio_data)
        print("Embedding generated successfully!")
        # Move to CPU for State storage (gradio State uses pickle)
        return raw_embedding.cpu(), "✅ Embedding ready"
    except Exception as exc:
        # UI boundary: report the failure in the status box instead of raising.
        message = str(exc)
        print(f"Error generating embedding: {message}")
        return None, f"Error: {message}"
def _embedding_to_device(embedding, device):
    """Return ``embedding`` moved to ``device`` if it is a tensor, else unchanged."""
    if torch.is_tensor(embedding):
        return embedding.to(device)
    return embedding


def generate_speech_gpu(text, model_choice, mode, speaker_choice, embedding_from_state, json_input, t, top_p, rp):
    """Generate speech from text using the selected model.

    Args:
        text: Text to synthesize. ``None``, empty, or whitespace-only text
            produces no audio.
        model_choice: Key into the module-level ``models`` mapping.
        mode: Speaker mode. ``"generate"`` uses ``embedding_from_state``,
            ``"json"`` parses ``json_input``, and any other value is handled
            as ``"select"`` and resolved from ``speaker_choice``.
        speaker_choice: Speaker name used in "select" mode.
        embedding_from_state: Previously extracted embedding (or ``None``)
            used in "generate" mode.
        json_input: JSON text holding an embedding, used in "json" mode.
        t: Forwarded to the model as ``temperature``.
        top_p: Forwarded to the model as ``top_p``.
        rp: Forwarded to the model as ``repetition_penalty``.

    Returns:
        ``(sample_rate, audio)`` on success, or ``None`` when the text or
        model is missing or when generation fails (the error is printed).
    """
    # NOTE(review): the function is named *_gpu and `spaces` is imported by
    # this file, but no @spaces.GPU decorator is applied here - confirm
    # whether ZeroGPU allocation is intended for this handler.

    # Reject None as well as blank text: calling .strip() on None would raise
    # before the try block and surface as an unhandled handler error.
    if not text or not text.strip():
        return None
    if not model_choice:
        return None

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")
        selected_model = models[model_choice]

        # Get speaker embedding based on mode
        if mode == "generate":
            # Embedding was stored on the CPU in State; move it back if needed.
            speaker_emb = _embedding_to_device(embedding_from_state, device)
        elif mode == "json":
            # Parse embedding from JSON input
            parsed_emb = speaker_manager.get_speaker_emb(mode, json_emb=json_input)
            speaker_emb = _embedding_to_device(parsed_emb, device)
        else:  # mode == "select"
            # Resolved by the speaker manager; passed through as returned.
            speaker_emb = speaker_manager.get_speaker_emb(mode, speaker_choice)

        print(f"Generating speech with {model_choice}...")
        audio, _ = selected_model(
            text,
            speaker_emb=speaker_emb,
            temperature=t,
            top_p=top_p,
            repetition_penalty=rp,
        )
        # Rate reported to the audio widget; presumably the model's native
        # output rate - TODO confirm.
        sample_rate = 22050
        print("Speech generation completed!")
        return (sample_rate, audio)
    except Exception as e:
        # UI boundary: log and return no audio rather than raising.
        print(f"Error during generation: {str(e)}")
        return None
# Create Gradio interface
#
# Layout: a two-column row (speaker/model controls on the left, text,
# sampling settings and output on the right), followed by event wiring and
# an examples row.
# NOTE(review): widget nesting below follows the Row/Column/Group/Accordion
# headers - confirm it matches the intended layout.
with gr.Blocks(title="😻 KaniTTS2 - Text to Speech", theme=gr.themes.Ocean()) as demo:
    gr.Markdown("# 😻 KaniTTS2: Fast and Expressive Speech Generation Model")
    gr.Markdown("Select a model and enter text to generate emotional speech")

    # State for storing generated embedding (invisible to user)
    embedding_state = gr.State(None)

    # Query the option lists once; an empty list yields no default value
    # instead of raising IndexError while the UI is being built.
    _model_names = list(models_configs.keys())
    _speaker_names = speaker_manager.get_speaker_names()

    with gr.Row():
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                choices=_model_names,
                value=_model_names[0] if _model_names else None,
                label="Selected Model",
            )

            # Speaker mode selector
            speaker_mode = gr.Radio(
                choices=[("Select", "select"), ("Clone", "generate"), ("JSON", "json")],
                value="select",
                label="Speaker Mode",
            )

            # Speaker selection (visible in "select" mode)
            speaker_dropdown = gr.Dropdown(
                choices=_speaker_names,
                value=_speaker_names[0] if _speaker_names else None,
                label="Speaker",
                visible=True,
            )

            # Audio upload and embedding generation (visible in "generate" mode)
            with gr.Group(visible=False) as embedding_group:
                audio_input = gr.Audio(
                    label="Upload or Record Audio (will be resampled to 16kHz)",
                    type="numpy",
                    sources=["upload", "microphone"],
                    format="wav",
                )
                with gr.Row():
                    run_embedding_btn = gr.Button("Extract Embedding", variant="primary")
                    clean_embedding_btn = gr.Button("Clear", variant="stop")
                embedding_status = gr.Textbox(
                    label="Embedding Status",
                    value="No embedding generated",
                    interactive=False,
                )

            # JSON embedding input (visible in "json" mode)
            with gr.Group(visible=False) as json_group:
                json_input = gr.Textbox(
                    label="Speaker Embedding (JSON)",
                    placeholder='Paste 128-dimensional embedding as JSON array: [0.123, -0.456, ...]',
                    lines=6,
                    info="Enter a list of 128 floating-point numbers",
                )

        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="Text",
                placeholder="Enter your text ...",
                lines=3,
                max_lines=10,
            )
            with gr.Accordion("Settings", open=False):
                temp = gr.Slider(
                    minimum=0.1, maximum=1.5, value=1.0, step=0.05,
                    label="Temp",
                )
                top_p = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.95, step=0.05,
                    label="Top P",
                )
                rp = gr.Slider(
                    minimum=1.0, maximum=2.0, value=1.1, step=0.05,
                    label="Repetition Penalty",
                )
            generate_btn = gr.Button("Generate", variant="primary", size="lg")
            audio_output = gr.Audio(
                label="Generated Audio",
                type="numpy",
            )

    # Toggle visibility based on speaker mode
    def toggle_speaker_mode(mode):
        """Show only the widget group that matches the chosen speaker mode.

        Returns updates for (speaker_dropdown, embedding_group, json_group).
        Any mode other than "select" or "generate" reveals the JSON group.
        """
        show_select = mode == "select"
        show_clone = mode == "generate"
        show_json = not (show_select or show_clone)
        return (
            gr.update(visible=show_select),
            gr.update(visible=show_clone),
            gr.update(visible=show_json),
        )

    speaker_mode.change(
        fn=toggle_speaker_mode,
        inputs=[speaker_mode],
        outputs=[speaker_dropdown, embedding_group, json_group],
    )

    # Embedding generation events
    run_embedding_btn.click(
        fn=generate_embedding_gpu,
        inputs=[audio_input],
        outputs=[embedding_state, embedding_status],
    )

    def clean_embedding():
        """Clear embedding state"""
        return None, "Embedding cleared"

    clean_embedding_btn.click(
        fn=clean_embedding,
        inputs=[],
        outputs=[embedding_state, embedding_status],
    )

    # GPU generation event
    generate_btn.click(
        fn=generate_speech_gpu,
        inputs=[text_input, model_dropdown, speaker_mode, speaker_dropdown, embedding_state, json_input, temp, top_p, rp],
        outputs=[audio_output],
    )

    def load_example_input(text, model, mode, speaker, emb, json_emb, t, p, r):
        """Load example and populate inputs without generating audio yet."""
        # Return values for all inputs + embedding_state output
        return text, model, mode, speaker, emb, json_emb, t, p, r

    with gr.Row():
        # NOTE(review): with cache_examples=False and run_on_click not set,
        # Gradio appears to use only `inputs` to populate the components;
        # `fn`/`outputs` look unused - confirm against the Gradio version in use.
        gr.Examples(
            examples=examples,
            inputs=[text_input, model_dropdown, speaker_mode, speaker_dropdown, embedding_state, json_input, temp, top_p, rp],
            outputs=[text_input, model_dropdown, speaker_mode, speaker_dropdown, embedding_state, json_input, temp, top_p, rp],
            fn=load_example_input,
            cache_examples=False,
        )
if __name__ == "__main__":
    # Serve on all interfaces at port 7860 with error display enabled.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)