import json
import os
from typing import Optional, Union
from unicodedata import normalize

import numpy as np
import onnxruntime as ort
import soundfile as sf
from huggingface_hub import snapshot_download


class UnicodeProcessor:
    def __init__(self, unicode_indexer_path: str):
        with open(unicode_indexer_path, "r") as f:
            self.indexer = json.load(f)

    def _preprocess_text(self, text: str) -> str:
        # TODO: add more preprocessing
        text = normalize("NFKD", text)
        return text

    def _get_text_mask(self, text_ids_lengths: np.ndarray) -> np.ndarray:
        text_mask = length_to_mask(text_ids_lengths)
        return text_mask

    def _text_to_unicode_values(self, text: str) -> np.ndarray:
        unicode_values = np.array(
            [ord(char) for char in text], dtype=np.uint16
        )  # 2 bytes
        return unicode_values

    def __call__(self, text_list: list[str]) -> tuple[np.ndarray, np.ndarray]:
        text_list = [self._preprocess_text(t) for t in text_list]
        text_ids_lengths = np.array([len(text) for text in text_list], dtype=np.int64)
        text_ids = np.zeros((len(text_list), text_ids_lengths.max()), dtype=np.int64)
        for i, text in enumerate(text_list):
            unicode_vals = self._text_to_unicode_values(text)
            text_ids[i, : len(unicode_vals)] = np.array(
                [self.indexer[val] for val in unicode_vals], dtype=np.int64
            )
        text_mask = self._get_text_mask(text_ids_lengths)
        return text_ids, text_mask


class Style:
    def __init__(self, style_ttl_onnx: np.ndarray, style_dp_onnx: np.ndarray):
        self.ttl = style_ttl_onnx
        self.dp = style_dp_onnx


class TextToSpeech:
    def __init__(
        self,
        cfgs: dict,
        text_processor: UnicodeProcessor,
        dp_ort: ort.InferenceSession,
        text_enc_ort: ort.InferenceSession,
        vector_est_ort: ort.InferenceSession,
        vocoder_ort: ort.InferenceSession,
    ):
        self.cfgs = cfgs
        self.text_processor = text_processor
        self.dp_ort = dp_ort
        self.text_enc_ort = text_enc_ort
        self.vector_est_ort = vector_est_ort
        self.vocoder_ort = vocoder_ort
        self.sample_rate = cfgs["ae"]["sample_rate"]
        self.base_chunk_size = cfgs["ae"]["base_chunk_size"]
        self.chunk_compress_factor = cfgs["ttl"]["chunk_compress_factor"]
        self.ldim = cfgs["ttl"]["latent_dim"]

    def sample_noisy_latent(
        self, duration: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray]:
        bsz = len(duration)
        wav_lengths = (duration * self.sample_rate).astype(np.int64)
        chunk_size = self.base_chunk_size * self.chunk_compress_factor
        # Derive latent_len from the integer wav lengths so it always agrees
        # with the mask produced by get_latent_mask (a float wav length could
        # land on the other side of a chunk boundary).
        latent_len = int((wav_lengths.max() + chunk_size - 1) // chunk_size)
        latent_dim = self.ldim * self.chunk_compress_factor
        noisy_latent = np.random.randn(bsz, latent_dim, latent_len).astype(np.float32)
        latent_mask = get_latent_mask(
            wav_lengths, self.base_chunk_size, self.chunk_compress_factor
        )
        noisy_latent = noisy_latent * latent_mask
        return noisy_latent, latent_mask

    def _infer(
        self,
        text_list: list[str],
        style: Style,
        total_step: int,
        speed: float = 1.05,
        suggested_duration: Optional[Union[float, list[float], np.ndarray]] = None,
        speed_min_factor: float = 0.75,
        speed_max_factor: float = 1.2,
    ) -> tuple[np.ndarray, np.ndarray]:
        assert (
            len(text_list) == style.ttl.shape[0]
        ), "Number of texts must match number of style vectors"
        bsz = len(text_list)
        text_ids, text_mask = self.text_processor(text_list)

        # 1) Predict base duration
        dur_pred, *_ = self.dp_ort.run(
            None, {"text_ids": text_ids, "style_dp": style.dp, "text_mask": text_mask}
        )
        dur_pred = np.array(dur_pred, dtype=np.float32).reshape(bsz)  # (bsz,)

        # 2) Adjust duration based on suggested_duration (if given)
        if suggested_duration is not None:
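            # Worked example (illustrative numbers): with dur_pred = 4.0 s,
            # suggested_duration = 2.0 s and the defaults speed = 1.05,
            # speed_min_factor = 0.75, speed_max_factor = 1.2:
            #   speed_target = 4.0 / 2.0 = 2.0
            #   clipped to [0.7875, 1.26] -> speed_used = 1.26
            #   dur_used = 4.0 / 1.26 ≈ 3.17 s
            # i.e. the suggestion is honored only within the allowed speed range.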
            sugg = np.array(suggested_duration, dtype=np.float32)
            if sugg.ndim == 0:
                # same suggestion for all
                sugg = np.full((bsz,), float(sugg), dtype=np.float32)
            else:
                sugg = sugg.reshape(bsz)

            eps = 1e-3
            sugg = np.clip(sugg, eps, None)

            # we want dur_used ≈ sugg
            # dur_used = dur_pred / speed_used => speed_target = dur_pred / sugg
            speed_target = dur_pred / sugg

            speed_min = speed * speed_min_factor
            speed_max = speed * speed_max_factor
            speed_used = np.clip(speed_target, speed_min, speed_max)

            dur_used = dur_pred / speed_used
        else:
            # default behaviour
            speed_used = np.full((bsz,), speed, dtype=np.float32)
            dur_used = dur_pred / speed_used

        # 3) Continue as before, using dur_used
        text_emb_onnx, *_ = self.text_enc_ort.run(
            None,
            {"text_ids": text_ids, "style_ttl": style.ttl, "text_mask": text_mask},
        )
        xt, latent_mask = self.sample_noisy_latent(dur_used)
        total_step_np = np.array([total_step] * bsz, dtype=np.float32)
        for step in range(total_step):
            current_step = np.array([step] * bsz, dtype=np.float32)
            xt, *_ = self.vector_est_ort.run(
                None,
                {
                    "noisy_latent": xt,
                    "text_emb": text_emb_onnx,
                    "style_ttl": style.ttl,
                    "text_mask": text_mask,
                    "latent_mask": latent_mask,
                    "current_step": current_step,
                    "total_step": total_step_np,
                },
            )
        wav, *_ = self.vocoder_ort.run(None, {"latent": xt})
        return wav, dur_used

    def batch(
        self,
        text_list: list[str],
        style: Style,
        total_step: int,
        speed: float = 1.05,
        suggested_duration: Optional[Union[float, list[float], np.ndarray]] = None,
        speed_min_factor: float = 0.75,
        speed_max_factor: float = 1.2,
    ) -> tuple[np.ndarray, np.ndarray]:
        return self._infer(
            text_list,
            style,
            total_step,
            speed=speed,
            suggested_duration=suggested_duration,
            speed_min_factor=speed_min_factor,
            speed_max_factor=speed_max_factor,
        )

    def __call__(
        self,
        text: str,
        style: Style,
        total_step: int,
        speed: float = 1.05,
        silence_duration: float = 0.3,
    ) -> tuple[np.ndarray, np.ndarray]:
        assert (
            style.ttl.shape[0] == 1
        ), "Single speaker text to speech only supports single style"
        text_list = chunk_text(text)
        wav_cat = None
        dur_cat = None
        for chunk in text_list:
            wav, dur_onnx = self._infer([chunk], style, total_step, speed)
            if wav_cat is None:
                wav_cat = wav
                dur_cat = dur_onnx
            else:
                silence = np.zeros(
                    (1, int(silence_duration * self.sample_rate)), dtype=np.float32
                )
                wav_cat = np.concatenate([wav_cat, silence, wav], axis=1)
                dur_cat += dur_onnx + silence_duration
        return wav_cat, dur_cat


def length_to_mask(lengths: np.ndarray, max_len: Optional[int] = None) -> np.ndarray:
    """Convert lengths to a binary mask.

    Args:
        lengths: (B,)
        max_len: maximum sequence length; defaults to lengths.max()

    Returns:
        mask: (B, 1, max_len)
    """
    max_len = max_len or lengths.max()
    ids = np.arange(0, max_len)
    mask = (ids < np.expand_dims(lengths, axis=1)).astype(np.float32)
    return mask.reshape(-1, 1, max_len)
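

# Example (illustrative): length_to_mask(np.array([2, 3])) returns a float32
# mask of shape (2, 1, 3):
#   [[[1., 1., 0.]],
#    [[1., 1., 1.]]]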


def get_latent_mask(
    wav_lengths: np.ndarray, base_chunk_size: int, chunk_compress_factor: int
) -> np.ndarray:
    latent_size = base_chunk_size * chunk_compress_factor
    latent_lengths = (wav_lengths + latent_size - 1) // latent_size
    latent_mask = length_to_mask(latent_lengths)
    return latent_mask


def load_onnx(
    onnx_path: str, opts: ort.SessionOptions, providers: list[str]
) -> ort.InferenceSession:
    return ort.InferenceSession(onnx_path, sess_options=opts, providers=providers)


def load_onnx_all(
    onnx_dir: str, opts: ort.SessionOptions, providers: list[str]
) -> tuple[
    ort.InferenceSession,
    ort.InferenceSession,
    ort.InferenceSession,
    ort.InferenceSession,
]:
    dp_onnx_path = os.path.join(onnx_dir, "duration_predictor.onnx")
    text_enc_onnx_path = os.path.join(onnx_dir, "text_encoder.onnx")
    vector_est_onnx_path = os.path.join(onnx_dir, "vector_estimator.onnx")
    vocoder_onnx_path = os.path.join(onnx_dir, "vocoder.onnx")
    dp_ort = load_onnx(dp_onnx_path, opts, providers)
    text_enc_ort = load_onnx(text_enc_onnx_path, opts, providers)
    vector_est_ort = load_onnx(vector_est_onnx_path, opts, providers)
    vocoder_ort = load_onnx(vocoder_onnx_path, opts, providers)
    return dp_ort, text_enc_ort, vector_est_ort, vocoder_ort


def load_cfgs(onnx_dir: str) -> dict:
    cfg_path = os.path.join(onnx_dir, "tts.json")
    with open(cfg_path, "r") as f:
        cfgs = json.load(f)
    return cfgs


def load_text_processor(onnx_dir: str) -> UnicodeProcessor:
    unicode_indexer_path = os.path.join(onnx_dir, "unicode_indexer.json")
    text_processor = UnicodeProcessor(unicode_indexer_path)
    return text_processor


model_dir = snapshot_download("Supertone/supertonic")
onnx_dir = f"{model_dir}/onnx"


def load_text_to_speech(use_gpu: bool = False) -> TextToSpeech:
    opts = ort.SessionOptions()
    if use_gpu:
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]
        print("Using CPU for inference")
    cfgs = load_cfgs(onnx_dir)
    dp_ort, text_enc_ort, vector_est_ort, vocoder_ort = load_onnx_all(
        onnx_dir, opts, providers
    )
    text_processor = load_text_processor(onnx_dir)
    return TextToSpeech(
        cfgs, text_processor, dp_ort, text_enc_ort, vector_est_ort, vocoder_ort
    )


def load_voice_style(voice_style_paths: list[str], verbose: bool = False) -> Style:
    bsz = len(voice_style_paths)

    # Read the first file to get dimensions
    with open(voice_style_paths[0], "r") as f:
        first_style = json.load(f)
    ttl_dims = first_style["style_ttl"]["dims"]
    dp_dims = first_style["style_dp"]["dims"]

    # Pre-allocate arrays with the full batch size
    ttl_style = np.zeros([bsz, ttl_dims[1], ttl_dims[2]], dtype=np.float32)
    dp_style = np.zeros([bsz, dp_dims[1], dp_dims[2]], dtype=np.float32)

    # Fill in the data
    for i, voice_style_path in enumerate(voice_style_paths):
        with open(voice_style_path, "r") as f:
            voice_style = json.load(f)
        ttl_data = np.array(
            voice_style["style_ttl"]["data"], dtype=np.float32
        ).flatten()
        ttl_style[i] = ttl_data.reshape(ttl_dims[1], ttl_dims[2])
        dp_data = np.array(voice_style["style_dp"]["data"], dtype=np.float32).flatten()
        dp_style[i] = dp_data.reshape(dp_dims[1], dp_dims[2])

    if verbose:
        print(f"Loaded {bsz} voice styles")
    return Style(ttl_style, dp_style)
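

# For reference, a voice style JSON as consumed by load_voice_style above is
# expected to carry flattened tensors plus their shapes, roughly:
#   {"style_ttl": {"dims": [1, D, T], "data": [...]},
#    "style_dp":  {"dims": [1, D, T], "data": [...]}}
# (D and T are placeholders here; the real sizes come from the exported
# style files.)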


def sanitize_filename(text: str, max_len: int) -> str:
    """Sanitize a filename by replacing non-alphanumeric characters with
    underscores."""
    import re

    prefix = text[:max_len]
    return re.sub(r"[^a-zA-Z0-9]", "_", prefix)


def chunk_text(text: str, max_len: int = 300) -> list[str]:
    """Split text into chunks by paragraphs and sentences.

    Args:
        text: Input text to chunk
        max_len: Maximum length of each chunk (default: 300)

    Returns:
        List of text chunks
    """
    import re

    # Split by paragraph (two or more newlines)
    paragraphs = [p.strip() for p in re.split(r"\n\s*\n+", text.strip()) if p.strip()]

    chunks = []
    for paragraph in paragraphs:
        paragraph = paragraph.strip()
        if not paragraph:
            continue

        # Split by sentence boundaries (period, question mark, exclamation
        # mark followed by whitespace), but exclude common abbreviations like
        # Mr., Mrs., Dr., etc. and single capital letters like F.
        # NOTE: the original pattern was truncated in the source; the regex
        # and the chunk-assembly loop below are a minimal reconstruction of
        # the behaviour described in the comment above.
        pattern = (
            r"(?<!\bMr\.)(?<!\bMrs\.)(?<!\bMs\.)(?<!\bDr\.)(?<!\bSt\.)"
            r"(?<!\bProf\.)(?<!\betc\.)(?<!\b[A-Z]\.)(?<=[.!?])\s+"
        )
        sentences = [s.strip() for s in re.split(pattern, paragraph) if s.strip()]

        # Greedily pack sentences into chunks of at most max_len characters.
        current = ""
        for sentence in sentences:
            if not current:
                current = sentence
            elif len(current) + 1 + len(sentence) <= max_len:
                current = f"{current} {sentence}"
            else:
                chunks.append(current)
                current = sentence
        if current:
            chunks.append(current)

    return chunks
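

if __name__ == "__main__":
    # Minimal end-to-end usage sketch. The voice style filename below is
    # hypothetical (point it at a real style JSON shipped with the model),
    # and total_step=8 is just an illustrative step count.
    tts = load_text_to_speech(use_gpu=False)
    style = load_voice_style([f"{model_dir}/voice_styles/example.json"])
    wav, dur = tts(
        "Hello! This is a quick test of the Supertonic ONNX pipeline.",
        style,
        total_step=8,
    )
    sf.write("output.wav", wav.squeeze(0), tts.sample_rate)
    print(f"Wrote output.wav ({float(dur[0]):.2f} s)")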