Spaces:
Running
Running
import torch | |
from src.Modules import commons | |
from src import utils | |
from src.Voice_Synthesizer import SynthesizerTrn | |
from src.Text.Symbols import symbols | |
from src.Text import text_to_sequence | |
from scipy.io.wavfile import write | |
import logging | |
import os | |
import onnxruntime | |
import numpy as np | |
# Silence matplotlib's noisy log output in case it is used elsewhere.
logging.getLogger("matplotlib").setLevel(logging.WARNING)
class TTS:
    """Unified text-to-speech engine for PyTorch (.pth) and ONNX (.onnx) models.

    The backend is chosen from the model file extension: a path ending in
    ``.onnx`` (matched case-insensitively) is loaded with ONNX Runtime, any
    other path is loaded as a PyTorch checkpoint into ``SynthesizerTrn``.
    """

    def __init__(self, config_path, model_path, device="cuda"):
        """Load the hyper-parameters and initialise the matching backend.

        Args:
            config_path (str): Path to the JSON configuration file.
            model_path (str): Path to the model file (.pth or .onnx).
            device (str): Device to use ("cuda" or "cpu").
        """
        self.device = device
        self.hps = utils.get_hparams_from_file(config_path)
        self.model_path = model_path
        self.model_type = self._detect_model_type(model_path)
        # Only one of these two is populated, depending on ``model_type``.
        self.net_g = None
        self.onnx_session = None

        if self.model_type == "pytorch":
            self._init_pytorch_model()
        else:
            self._init_onnx_model()
        print(f"Motor TTS inicializado en modo: {self.model_type.upper()}")

    @staticmethod
    def _detect_model_type(model_path):
        """Return ``"onnx"`` for a path ending in ``.onnx``, else ``"pytorch"``.

        The comparison ignores letter case, and ``os.fspath`` lets the
        detection accept path-like objects as well as plain strings.
        """
        is_onnx = os.fspath(model_path).lower().endswith(".onnx")
        return "onnx" if is_onnx else "pytorch"

    def _init_pytorch_model(self):
        """Build ``SynthesizerTrn``, move it to the device and load the checkpoint."""
        # Moving the network to a CUDA device fails on hosts without CUDA;
        # degrade to CPU instead of crashing.
        if str(self.device).startswith("cuda") and not torch.cuda.is_available():
            logging.getLogger(__name__).warning(
                "CUDA no está disponible; se usará la CPU."
            )
            self.device = "cpu"

        model_cfg = self.hps.model
        uses_mel_posterior = (
            "use_mel_posterior_encoder" in model_cfg
            and model_cfg.use_mel_posterior_encoder
        )
        if uses_mel_posterior:
            # NOTE(review): 80 is presumably the number of mel bins - confirm
            # against the training configuration.
            posterior_channels = 80
        else:
            posterior_channels = self.hps.data.filter_length // 2 + 1

        self.net_g = SynthesizerTrn(
            len(symbols),
            posterior_channels,
            self.hps.train.segment_size // self.hps.data.hop_length,
            **self.hps.model,
        ).to(self.device)
        self.net_g.eval()
        # No optimizer is passed (third argument is None).
        utils.load_checkpoint(self.model_path, self.net_g, None)

    def _init_onnx_model(self):
        """Create the ONNX Runtime inference session for ``self.model_path``."""
        if self.device == "cuda":
            # CPU is listed second so it is available as a fallback provider.
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]
        self.onnx_session = onnxruntime.InferenceSession(
            self.model_path, providers=providers
        )

    def _get_text(self, text):
        """Convert plain text into an int64 NumPy array of symbol IDs.

        When ``hps.data.add_blank`` is set, ``commons.intersperse`` is applied
        to the sequence with ``0`` as the inserted item.
        """
        text_norm = text_to_sequence(text, self.hps.data.text_cleaners)
        if self.hps.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
        return np.array(text_norm, dtype=np.int64)

    def text_to_speech(self, text, output_path="sample.wav", noise_scale=0.667, noise_scale_w=0.8, length_scale=1.0, sid=None):
        """Synthesize speech for ``text`` and save it as a WAV file.

        Args:
            text (str): Text to synthesize.
            output_path (str): Destination WAV file.
            noise_scale (float): Forwarded to the model as ``noise_scale``.
            noise_scale_w (float): Forwarded to the model as ``noise_scale_w``.
            length_scale (float): Forwarded to the model as ``length_scale``.
            sid (int | None): Optional speaker ID; omitted from the model
                inputs when ``None``.

        Raises:
            ValueError: If ``text`` yields an empty symbol sequence.
        """
        phoneme_ids = self._get_text(text)
        if phoneme_ids.size == 0:
            # Fail early with a clear message instead of an opaque backend error.
            raise ValueError("El texto no produjo ningún fonema para sintetizar.")

        if self.model_type == "pytorch":
            # PyTorch inference: add a batch dimension of size 1.
            stn_tst = torch.LongTensor(phoneme_ids).to(self.device).unsqueeze(0)
            x_tst_lengths = torch.LongTensor([stn_tst.size(1)]).to(self.device)
            sid_tensor = None
            if sid is not None:
                sid_tensor = torch.LongTensor([sid]).to(self.device)
            with torch.no_grad():
                audio = self.net_g.infer(
                    stn_tst,
                    x_tst_lengths,
                    sid=sid_tensor,
                    noise_scale=noise_scale,
                    noise_scale_w=noise_scale_w,
                    length_scale=length_scale,
                )[0][0, 0].data.cpu().float().numpy()
        else:
            # ONNX inference: add a batch dimension of size 1.
            text_input = np.expand_dims(phoneme_ids, 0)
            text_lengths = np.array([text_input.shape[1]], dtype=np.int64)
            # Packed as (noise_scale, length_scale, noise_scale_w);
            # NOTE(review): this order must match the exported graph - confirm.
            scales = np.array(
                [noise_scale, length_scale, noise_scale_w], dtype=np.float32
            )
            feed = {
                "input": text_input,
                "input_lengths": text_lengths,
                "scales": scales,
            }
            # Only feed "sid" when a speaker was requested: a None feed value
            # is not a valid tensor for ONNX Runtime.
            if sid is not None:
                feed["sid"] = np.array([sid], dtype=np.int64)
            audio = self.onnx_session.run(None, feed)[0].squeeze((0, 1))

        write(data=audio, rate=self.hps.data.sampling_rate, filename=output_path)
        print(f"Audio guardado exitosamente en: {output_path}")
# --- Usage example ---
if __name__ == "__main__":
    # Hyper-parameter file shared by both demo runs.
    CONFIG_PATH = "./configs/config.json"
    # Choose the device explicitly so the demo also runs on CPU-only hosts
    # instead of relying on the "cuda" default.
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Demo with the quantized ONNX model (recommended) ---
    ONNX_MODEL_PATH = "./models/LJspeech_quantized.onnx"
    if os.path.exists(ONNX_MODEL_PATH):
        print("\n--- Probando con el modelo ONNX ---")
        tts_onnx_engine = TTS(
            config_path=CONFIG_PATH, model_path=ONNX_MODEL_PATH, device=DEVICE
        )
        tts_onnx_engine.text_to_speech(
            "This is a test using the optimized ONNX model. It should be very fast.",
            "sample_onnx.wav",
        )
    else:
        print(f"No se encontró el modelo ONNX en {ONNX_MODEL_PATH}. Saltando prueba.")

    # --- Demo with the original PyTorch model ---
    PTH_MODEL_PATH = "./models/LJspeech.pth"
    if os.path.exists(PTH_MODEL_PATH):
        print("\n--- Probando con el modelo PyTorch ---")
        tts_pytorch_engine = TTS(
            config_path=CONFIG_PATH, model_path=PTH_MODEL_PATH, device=DEVICE
        )
        tts_pytorch_engine.text_to_speech(
            "This is a test using the original PyTorch model.",
            "sample_pytorch.wav",
        )
    else:
        print(f"No se encontró el modelo PyTorch en {PTH_MODEL_PATH}. Saltando prueba.")