Spaces:
Running on Zero
Running on Zero
| from kani_tts import KaniTTS | |
| from kani_tts import SpeakerEmbedder | |
| import os | |
| import json | |
| import torch | |
| from omegaconf import OmegaConf | |
| import numpy as np | |
def load_config(config_path: str):
    """Read a YAML configuration file into an OmegaConf object.

    The path is resolved to an absolute path first, so error messages show
    exactly which file was looked for.

    Args:
        config_path (str): Path (relative or absolute) to the YAML file.

    Returns:
        Any: Whatever ``OmegaConf.load`` produces for the file (a DictConfig
        for a mapping-rooted YAML document).

    Raises:
        FileNotFoundError: If nothing exists at the resolved path.
    """
    absolute_path = os.path.abspath(config_path)
    if os.path.exists(absolute_path):
        return OmegaConf.load(absolute_path)
    raise FileNotFoundError(f"Config file not found: {absolute_path}")
class InitModels:
    """
    Callable that builds a mapping of configured model name -> KaniTTS instance.

    Parameters
    ----------
    models_configs : OmegaConf | DictConfig
        The `models` section from `model_config.yaml`. Each entry is keyed by
        a display name; only its `model_name` and `device_map` fields are
        forwarded to `KaniTTS` (any other fields in the entry are not passed).

    Returns
    -------
    dict
        When called, returns a dictionary `{model_name: KaniTTS}` keyed by the
        entry names of `models_configs`.

    Notes
    -----
    - Nothing is constructed in `__init__`; every configured model is built
      eagerly when the instance is called, so the UI can list them and switch
      between them without extra latency.
    """

    # Entry fields forwarded to the KaniTTS constructor, in call order.
    _FORWARDED_KEYS = ("model_name", "device_map")

    def __init__(self, models_configs: OmegaConf):
        self.models_configs = models_configs

    def __call__(self):
        loaded = {}
        for model_name, entry in self.models_configs.items():
            print(f"Loading {model_name}...")
            # Plain-dict view of the entry so absent keys read back as None.
            options = dict(entry)
            ctor_kwargs = {key: options.get(key) for key in self._FORWARDED_KEYS}
            loaded[model_name] = KaniTTS(**ctor_kwargs)
            print(f"{model_name} loaded!")
        print("All models loaded!")
        return loaded
class SpeakerManager:
    """
    Manages speaker embeddings for the TTS application.

    Supports three modes:
    1. Select speaker: look up a pre-saved entry in speaker_map.json
    2. Generate embedding: generate a speaker embedding from uploaded audio
       using SpeakerEmbedder
    3. JSON embedding: parse a speaker embedding from a JSON string
       (a flat list of 128 numbers)

    Parameters
    ----------
    speaker_map_path : str
        Path to speaker_map.json file

    Methods
    -------
    get_speaker_emb(mode, speaker_name=None, json_emb=None) -> str | torch.Tensor | None
        Returns speaker embedding based on mode:
        - "select": Returns the speaker_map entry for the name (a path to a
          .pt file, per the map's convention) or None
        - "generate": Returns cached generated embedding or None
        - "json": Returns embedding tensor parsed from JSON string or None
    generate_embedding(audio_data) -> torch.Tensor
        Generates a speaker embedding from a Gradio ``(sample_rate, array)``
        tuple, or from a bare array assumed to be 16 kHz. The sample rate is
        passed to SpeakerEmbedder, which is relied on for any resampling.
        Caches the result internally.
    parse_json_embedding(json_str, speaker_emb_dim=128) -> torch.Tensor | None
        Parses speaker embedding from JSON string.
        Returns [1, dim] tensor or None if parsing/validation fails.
    clean()
        Clears cached generated embedding.
    get_speaker_names() -> list[str]
        Returns list of available speaker names from speaker_map.json.
    """

    def __init__(self, speaker_map_path: str = "./speakers/speaker_map.json"):
        self.speaker_map_path = speaker_map_path
        self.speaker_map = self._load_speaker_map()
        # Most recent result of generate_embedding(); None until one is made.
        self.cached_embedding = None
        # SpeakerEmbedder is constructed lazily on first generate_embedding().
        self.embedder = None

    def _load_speaker_map(self):
        """Load the speaker map from JSON; an absent file yields ``{}``."""
        if not os.path.exists(self.speaker_map_path):
            return {}
        # Explicit UTF-8 so non-ASCII speaker names load regardless of the
        # platform's default locale encoding.
        with open(self.speaker_map_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def get_speaker_names(self):
        """Get list of available speaker names."""
        return list(self.speaker_map.keys())

    def get_speaker_emb(self, mode: str, speaker_name: str = None, json_emb: str = None):
        """
        Get speaker embedding based on mode.

        Parameters
        ----------
        mode : str
            Either "select", "generate", or "json"; anything else yields None.
        speaker_name : str, optional
            Name of speaker from speaker_map (only used in "select" mode)
        json_emb : str, optional
            JSON string containing embedding list (only used in "json" mode)

        Returns
        -------
        str | torch.Tensor | None
            speaker_map entry (select mode), cached embedding (generate mode),
            parsed tensor (json mode), or None.
        """
        if mode == "select":
            if speaker_name and speaker_name in self.speaker_map:
                return self.speaker_map[speaker_name]
            return None
        elif mode == "generate":
            return self.cached_embedding
        elif mode == "json":
            return self.parse_json_embedding(json_emb)
        return None

    @staticmethod
    def _to_float32(audio_array):
        """Convert PCM samples to float32, normalising integer PCM to about [-1, 1].

        Signed integers are divided by the magnitude of the type's minimum
        (32768 for int16, exactly as before); unsigned integers are re-centred
        around zero; floating-point input is assumed already normalised and
        is only cast, never rescaled.
        """
        audio_array = np.asarray(audio_array)
        dtype = audio_array.dtype
        if np.issubdtype(dtype, np.signedinteger):
            full_scale = np.float32(-np.iinfo(dtype).min)
            return audio_array.astype(np.float32) / full_scale
        if np.issubdtype(dtype, np.unsignedinteger):
            # Unsigned PCM (e.g. 8-bit WAV) stores silence at mid-range.
            offset = (np.iinfo(dtype).max + 1) / 2
            return ((audio_array.astype(np.float64) - offset) / offset).astype(np.float32)
        return audio_array.astype(np.float32)

    def generate_embedding(self, audio_data):
        """
        Generate speaker embedding from audio data.

        Parameters
        ----------
        audio_data : tuple | array-like
            Tuple of (sample_rate, audio_array) from Gradio Audio component,
            or a bare audio array which is then assumed to be 16 kHz.
            A 2-D array is treated as (num_samples, num_channels) and averaged
            down to mono.

        Returns
        -------
        torch.Tensor
            Whatever ``SpeakerEmbedder.embed_audio`` returns (expected to be a
            [1, 128] embedding). Also stored in ``self.cached_embedding``.
        """
        # Initialize embedder lazily
        if self.embedder is None:
            self.embedder = SpeakerEmbedder()
        # Handle Gradio audio format (sr, audio) tuple
        if isinstance(audio_data, tuple):
            sample_rate, audio_array = audio_data
        else:
            # Fallback: assume it's just audio array at 16kHz
            audio_array = audio_data
            sample_rate = 16000
        # Gradio typically yields int16 PCM; scale integer PCM by its dtype's
        # range and leave float audio untouched (dividing float audio by
        # 32768 would make it practically silent).
        audio_array = self._to_float32(audio_array)
        # Gradio returns (num_samples, num_channels) for multi-channel audio,
        # so average over axis 1 to get mono.
        if audio_array.ndim == 2:
            print('Make MONO from STEREO')
            audio_array = audio_array.mean(axis=1)
        # Generate embedding (SpeakerEmbedder will handle resampling if needed)
        embedding = self.embedder.embed_audio(audio_array, sample_rate=sample_rate)
        # Cache the result
        self.cached_embedding = embedding
        return embedding

    def parse_json_embedding(self, json_str: str, speaker_emb_dim: int = 128):
        """
        Parse speaker embedding from JSON string.

        Parameters
        ----------
        json_str : str
            JSON string containing a flat list of numbers [0.123, -0.456, ...]
        speaker_emb_dim : int
            Expected embedding dimension (default: 128)

        Returns
        -------
        torch.Tensor | None
            float32 speaker embedding tensor [1, dim], or None if the input is
            empty, not valid JSON, not a flat list of ``speaker_emb_dim``
            numbers, or contains non-finite values (JSON ``NaN``/``Infinity``
            are accepted by ``json.loads``).
        """
        if not json_str or not json_str.strip():
            print("No JSON embedding provided")
            return None
        try:
            emb_list = json.loads(json_str.strip())
        except json.JSONDecodeError as e:
            print(f"Error parsing JSON: {e}")
            return None
        # Validate it's a list
        if not isinstance(emb_list, list):
            print(f"Error: Speaker embedding must be a JSON array, got {type(emb_list)}")
            return None
        # Validate length
        if len(emb_list) != speaker_emb_dim:
            print(f"Error: Speaker embedding must have {speaker_emb_dim} dimensions, got {len(emb_list)}")
            return None
        # Validate element types: nested lists would otherwise silently produce
        # a higher-rank tensor, and bools would be coerced to 0/1.
        if not all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in emb_list):
            print("Error: Speaker embedding must contain only numbers")
            return None
        try:
            # Convert to torch tensor [1, dim]
            speaker_emb = torch.tensor([emb_list], dtype=torch.float32)
        except Exception as e:
            print(f"Error processing speaker embedding: {e}")
            return None
        if not bool(torch.isfinite(speaker_emb).all()):
            print("Error: Speaker embedding must contain only finite numbers")
            return None
        print(f"Using speaker embedding from JSON: shape {speaker_emb.shape}")
        return speaker_emb

    def clean(self):
        """Clear cached generated embedding."""
        self.cached_embedding = None
        return "Embedding cleared"

    def get_status(self):
        """Get current status of generated embedding."""
        if self.cached_embedding is not None:
            return "✅ Embedding ready"
        return "No embedding generated"
class Examples:
    """
    Turns the entries of `examples.yaml` into rows for Gradio `gr.Examples`.

    Parameters
    ----------
    exam_cfg : OmegaConf | DictConfig
        Parsed contents of `examples.yaml`, expected to expose an `examples`
        sequence of mappings with keys: `text`, `model`, and optionally
        `speaker_mode`, `speaker`, `json_input`, `temperature`, `top_p`,
        `repetition_penalty`.

    Behavior
    --------
    - Each row is a list whose column order must match the `inputs` passed to
      `gr.Examples` in `app.py`:
      `[text, model_dropdown, speaker_mode, speaker_dropdown, embedding_state,
      json_input, temp, top_p, rp]`.
    - Missing optional keys fall back to: speaker_mode "select",
      speaker "Kore (en)", json_input "", temperature 1.0, top_p 0.95,
      repetition_penalty 1.1. The embedding_state column is always None
      (examples never carry a generated embedding).

    Why this exists
    ---------------
    - Keeps format and defaults centralized, so changing the UI inputs
      order only requires a single change here and in `app.py`.
    """

    def __init__(self, exam_cfg: OmegaConf):
        self.exam_cfg = exam_cfg

    @staticmethod
    def _row(entry) -> list:
        """Build one gr.Examples row from a single YAML example entry."""
        return [
            entry.get("text"),
            entry.get("model"),
            entry.get("speaker_mode", "select"),
            entry.get("speaker", "Kore (en)"),
            None,  # embedding_state: no generated embedding for examples
            entry.get("json_input", ""),
            entry.get("temperature", 1.0),
            entry.get("top_p", 0.95),
            entry.get("repetition_penalty", 1.1),
        ]

    def __call__(self) -> list[list]:
        return [self._row(entry) for entry in self.exam_cfg.examples]