# kani-tts-2-pt / util.py
# Last commit ca1f9c1 (verified) by Simonlob: "Update util.py"
import json
import os

import numpy as np
import torch
from omegaconf import DictConfig
from omegaconf import OmegaConf

from kani_tts import KaniTTS
from kani_tts import SpeakerEmbedder
def load_config(config_path: str):
    """Read a YAML configuration file with OmegaConf.

    Args:
        config_path (str): Location of the YAML file. A relative path is
            resolved against the current working directory.

    Returns:
        Any: The config object produced by ``OmegaConf.load``.

    Raises:
        FileNotFoundError: If nothing exists at the resolved path. The
            message contains the absolute path that was checked.
    """
    absolute_path = os.path.abspath(config_path)
    if os.path.exists(absolute_path):
        return OmegaConf.load(absolute_path)
    raise FileNotFoundError(f"Config file not found: {absolute_path}")
class InitModels:
    """
    Callable that builds a ``{model_name: KaniTTS}`` mapping from config.

    Parameters
    ----------
    models_configs : DictConfig
        The ``models`` section of ``model_config.yaml``: a mapping from a
        display name to that model's settings. Only the ``model_name`` and
        ``device_map`` entries of each settings mapping are passed on to
        ``KaniTTS``; any other keys are currently ignored.

    Returns
    -------
    dict
        Calling the instance returns ``{display_name: KaniTTS}``.

    Notes
    -----
    - Construction only stores the config; every model is loaded eagerly
      inside ``__call__``, so the UI can list them and switch between them
      afterwards without extra loading latency.
    - Missing ``model_name``/``device_map`` keys are forwarded as ``None``.
    """

    # The annotation is a string so the class can be defined without
    # evaluating the omegaconf type at import time.
    def __init__(self, models_configs: "DictConfig"):
        self.models_configs = models_configs

    def __call__(self):
        models = {}
        for model_name, config in self.models_configs.items():
            print(f"Loading {model_name}...")
            # dict() yields a plain mapping whether the entry is a DictConfig
            # node or an ordinary dict.
            cfg_dict = dict(config)
            models[model_name] = KaniTTS(
                model_name=cfg_dict.get("model_name"),
                device_map=cfg_dict.get("device_map"),
            )
            print(f"{model_name} loaded!")
        print("All models loaded!")
        return models
class SpeakerManager:
    """
    Manages speaker embeddings for the TTS application.

    Supports three modes:
    1. Select speaker: look up a pre-saved entry in speaker_map.json
    2. Generate embedding: compute an embedding from uploaded audio using
       SpeakerEmbedder
    3. JSON embedding: parse an embedding from a JSON string (a list of
       ``speaker_emb_dim`` numbers, 128 by default)

    Parameters
    ----------
    speaker_map_path : str
        Path to the speaker_map.json file (a JSON object mapping speaker
        names to the value returned in "select" mode). If the file does not
        exist, the map is empty.

    Methods
    -------
    get_speaker_emb(mode, speaker_name=None, json_emb=None) -> str | torch.Tensor | None
        Returns the speaker embedding for ``mode``:
        - "select": the speaker_map value for ``speaker_name`` (path to a .pt file)
        - "generate": the cached generated embedding, or None
        - "json": the tensor parsed from ``json_emb``, or None
    generate_embedding(audio_data) -> torch.Tensor
        Embeds a Gradio ``(sample_rate, array)`` tuple, or a bare array that
        is assumed to be 16 kHz, with SpeakerEmbedder. Caches the result.
    parse_json_embedding(json_str, speaker_emb_dim=128) -> torch.Tensor | None
        Parses a JSON array of numbers into a [1, dim] tensor. Returns None
        on any parse or validation failure.
    clean() -> str
        Clears the cached generated embedding.
    get_status() -> str
        Reports whether a generated embedding is cached.
    get_speaker_names() -> list[str]
        Returns the speaker names from speaker_map.json.
    """

    def __init__(self, speaker_map_path: str = "./speakers/speaker_map.json"):
        self.speaker_map_path = speaker_map_path
        self.speaker_map = self._load_speaker_map()
        self.cached_embedding = None
        # SpeakerEmbedder is created lazily on the first generate_embedding call.
        self.embedder = None

    def _load_speaker_map(self):
        """Load the speaker map from JSON; return {} when the file is missing."""
        if not os.path.exists(self.speaker_map_path):
            return {}
        # Explicit encoding: speaker names may be non-ASCII and must not
        # depend on the platform's default locale encoding.
        with open(self.speaker_map_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def get_speaker_names(self):
        """Get list of available speaker names."""
        return list(self.speaker_map.keys())

    def get_speaker_emb(self, mode: str, speaker_name: str = None, json_emb: str = None):
        """
        Get speaker embedding based on mode.

        Parameters
        ----------
        mode : str
            Either "select", "generate", or "json"
        speaker_name : str, optional
            Name of speaker from speaker_map (only used in "select" mode)
        json_emb : str, optional
            JSON string containing embedding list (only used in "json" mode)

        Returns
        -------
        str | torch.Tensor | None
            Path to .pt file (select mode) or embedding tensor (generate/json
            mode). Returns None for an unknown mode or speaker name, or when
            nothing is available.
        """
        if mode == "select":
            if speaker_name and speaker_name in self.speaker_map:
                return self.speaker_map[speaker_name]
            return None
        elif mode == "generate":
            return self.cached_embedding
        elif mode == "json":
            return self.parse_json_embedding(json_emb)
        return None

    @staticmethod
    def _to_float32(audio_array):
        """Convert an audio array to float32, normalizing integer PCM to [-1, 1]."""
        audio_array = np.asarray(audio_array)
        dtype = audio_array.dtype
        if np.issubdtype(dtype, np.signedinteger):
            # Divide by the magnitude of the most negative value: 32768 for
            # int16 (Gradio's usual output), 2**31 for int32.
            scale = float(-np.iinfo(dtype).min)
            return audio_array.astype(np.float32) / scale
        if np.issubdtype(dtype, np.unsignedinteger):
            # Unsigned PCM (e.g. 8-bit WAV) is centred on its midpoint.
            offset = (np.iinfo(dtype).max + 1) / 2.0
            return (audio_array.astype(np.float32) - offset) / offset
        # Floating-point audio is assumed to already be in [-1, 1].
        return audio_array.astype(np.float32)

    def generate_embedding(self, audio_data):
        """
        Generate speaker embedding from audio data.

        Parameters
        ----------
        audio_data : tuple | array-like
            Either a ``(sample_rate, audio_array)`` tuple from the Gradio
            Audio component, or a bare audio array assumed to be at 16 kHz.
            Integer PCM is normalized by its dtype's range; float audio is
            only cast to float32. A 2-D array is treated as
            (num_samples, num_channels) and averaged to mono.

        Returns
        -------
        torch.Tensor
            The embedding returned by ``SpeakerEmbedder.embed_audio``
            (expected shape [1, 128]). It is also stored in
            ``self.cached_embedding``.
        """
        if self.embedder is None:
            self.embedder = SpeakerEmbedder()

        if isinstance(audio_data, tuple):
            sample_rate, audio_array = audio_data
        else:
            # Fallback: assume it's just an audio array at 16kHz.
            audio_array = audio_data
            sample_rate = 16000

        audio_array = self._to_float32(audio_array)

        # Gradio returns (num_samples, num_channels), unlike the torch
        # (channels, samples) convention, so average over axis 1.
        if audio_array.ndim == 2:
            print('Make MONO from STEREO')
            audio_array = audio_array.mean(axis=1)

        # Resampling to the embedder's rate is left to SpeakerEmbedder.
        embedding = self.embedder.embed_audio(audio_array, sample_rate=sample_rate)
        self.cached_embedding = embedding
        return embedding

    def parse_json_embedding(self, json_str: str, speaker_emb_dim: int = 128):
        """
        Parse speaker embedding from JSON string.

        Parameters
        ----------
        json_str : str
            JSON string containing a flat list of numbers [0.123, -0.456, ...]
        speaker_emb_dim : int
            Expected embedding dimension (default: 128)

        Returns
        -------
        torch.Tensor | None
            Speaker embedding tensor [1, dim] (float32), or None if the input
            is empty, is not valid JSON, is not a flat array of exactly
            ``speaker_emb_dim`` numbers, or contains NaN/infinite values.
        """
        if not json_str or not json_str.strip():
            print("No JSON embedding provided")
            return None
        try:
            emb_list = json.loads(json_str.strip())
            if not isinstance(emb_list, list):
                print(f"Error: Speaker embedding must be a JSON array, got {type(emb_list)}")
                return None
            if len(emb_list) != speaker_emb_dim:
                print(f"Error: Speaker embedding must have {speaker_emb_dim} dimensions, got {len(emb_list)}")
                return None
            # Without this check, nested arrays would build a [1, dim, k]
            # tensor and booleans would silently become 0.0/1.0.
            if not all(
                isinstance(value, (int, float)) and not isinstance(value, bool)
                for value in emb_list
            ):
                print("Error: Speaker embedding must contain only numbers")
                return None
            speaker_emb = torch.tensor([emb_list], dtype=torch.float32)
            # json.loads accepts the non-standard NaN / Infinity literals.
            if not bool(torch.isfinite(speaker_emb).all()):
                print("Error: Speaker embedding contains NaN or infinite values")
                return None
            print(f"Using speaker embedding from JSON: shape {speaker_emb.shape}")
            return speaker_emb
        except json.JSONDecodeError as e:
            print(f"Error parsing JSON: {e}")
            return None
        except Exception as e:
            # Best-effort UI boundary: report and fall back to "no embedding".
            print(f"Error processing speaker embedding: {e}")
            return None

    def clean(self):
        """Clear cached generated embedding."""
        self.cached_embedding = None
        return "Embedding cleared"

    def get_status(self):
        """Get current status of generated embedding."""
        if self.cached_embedding is not None:
            return "✅ Embedding ready"
        return "No embedding generated"
class Examples:
    """
    Adapter that converts YAML examples into Gradio ``gr.Examples`` rows.

    Parameters
    ----------
    exam_cfg : DictConfig
        Parsed contents of ``examples.yaml``. It must expose an ``examples``
        attribute holding a list of mappings with the keys ``text`` and
        ``model``, plus these optional keys and defaults:
        ``speaker_mode`` ("select"), ``speaker`` ("Kore (en)"),
        ``json_input`` (""), ``temperature`` (1.0), ``top_p`` (0.95),
        ``repetition_penalty`` (1.1).

    Behavior
    --------
    - Produces a list-of-lists whose order must match the ``inputs`` order
      used when constructing ``gr.Examples`` in ``app.py``.
    - Current order: ``[text, model_dropdown, speaker_mode, speaker_dropdown,
      embedding_state, json_input, temp, top_p, rp]``.
    - ``embedding_state`` is always ``None``: an example row never carries a
      pre-generated embedding, whatever its ``speaker_mode``.

    Why this exists
    ---------------
    - Keeps format and defaults centralized, so changing the UI inputs
      order only requires a single change here and in ``app.py``.
    """

    # String annotation: avoids evaluating the omegaconf type at class
    # definition time.
    def __init__(self, exam_cfg: "DictConfig"):
        self.exam_cfg = exam_cfg

    @staticmethod
    def _to_row(entry) -> list:
        """Build one gr.Examples row from a single example mapping."""
        # Order must match gr.Examples inputs: [text, model_dropdown,
        # speaker_mode, speaker_dropdown, embedding_state, json_input,
        # temp, top_p, rp]
        return [
            entry.get("text"),
            entry.get("model"),
            entry.get("speaker_mode", "select"),
            entry.get("speaker", "Kore (en)"),
            None,  # embedding_state: examples never ship a generated embedding
            entry.get("json_input", ""),
            entry.get("temperature", 1.0),
            entry.get("top_p", 0.95),
            entry.get("repetition_penalty", 1.1),
        ]

    def __call__(self) -> list[list]:
        return [self._to_row(entry) for entry in self.exam_cfg.examples]