FVOICE_V2 / inference.py
Alazka2503's picture
Clean initial commit (history reset)
d492199
raw
history blame
5.72 kB
import torch
from src.Modules import commons
from src import utils
from src.Voice_Synthesizer import SynthesizerTrn
from src.Text.Symbols import symbols
from src.Text import text_to_sequence
from scipy.io.wavfile import write
import logging
import os
import onnxruntime
import numpy as np
# Desactiva los molestos logs de matplotlib si lo usas en otro lado
logging.getLogger('matplotlib').setLevel(logging.WARNING)
class TTS:
"""
Clase unificada para Texto a Voz (TTS) que soporta tanto modelos
PyTorch (.pth) como ONNX (.onnx).
"""
def __init__(self, config_path, model_path, device="cuda"):
"""
Inicializa el motor TTS. Detecta automáticamente el tipo de modelo.
Args:
config_path (str): Ruta al archivo de configuración JSON.
model_path (str): Ruta al archivo del modelo (.pth o .onnx).
device (str): Dispositivo a usar ("cuda" o "cpu").
"""
self.device = device
self.hps = utils.get_hparams_from_file(config_path)
self.model_path = model_path
self.model_type = "onnx" if model_path.endswith(".onnx") else "pytorch"
self.net_g = None
self.onnx_session = None
if self.model_type == "pytorch":
self._init_pytorch_model()
else:
self._init_onnx_model()
print(f"Motor TTS inicializado en modo: {self.model_type.upper()}")
def _init_pytorch_model(self):
"""Inicializa el modelo usando PyTorch."""
if (
"use_mel_posterior_encoder" in self.hps.model
and self.hps.model.use_mel_posterior_encoder
):
posterior_channels = 80
else:
posterior_channels = self.hps.data.filter_length // 2 + 1
self.net_g = SynthesizerTrn(
len(symbols),
posterior_channels,
self.hps.train.segment_size // self.hps.data.hop_length,
**self.hps.model,
).to(self.device)
_ = self.net_g.eval()
_ = utils.load_checkpoint(self.model_path, self.net_g, None)
def _init_onnx_model(self):
"""Inicializa el motor de inferencia usando ONNX Runtime."""
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if self.device == "cuda" else ["CPUExecutionProvider"]
self.onnx_session = onnxruntime.InferenceSession(
self.model_path, providers=providers
)
def _get_text(self, text):
"""Convierte texto plano a una secuencia de IDs de fonemas."""
text_norm = text_to_sequence(text, self.hps.data.text_cleaners)
if self.hps.data.add_blank:
text_norm = commons.intersperse(text_norm, 0)
return np.array(text_norm, dtype=np.int64)
def text_to_speech(self, text, output_path="sample.wav", noise_scale=0.667, noise_scale_w=0.8, length_scale=1.0, sid=None):
"""Sintetiza audio a partir de un texto y lo guarda en un archivo WAV."""
phoneme_ids = self._get_text(text)
if self.model_type == "pytorch":
# Inferencia con PyTorch
stn_tst = torch.LongTensor(phoneme_ids).to(self.device).unsqueeze(0)
x_tst_lengths = torch.LongTensor([stn_tst.size(1)]).to(self.device)
sid_tensor = torch.LongTensor([sid]).to(self.device) if sid is not None else None
with torch.no_grad():
audio = self.net_g.infer(
stn_tst,
x_tst_lengths,
sid=sid_tensor,
noise_scale=noise_scale,
noise_scale_w=noise_scale_w,
length_scale=length_scale,
)[0][0, 0].data.cpu().float().numpy()
else: # Inferencia con ONNX
text_input = np.expand_dims(phoneme_ids, 0)
text_lengths = np.array([text_input.shape[1]], dtype=np.int64)
scales = np.array([noise_scale, length_scale, noise_scale_w], dtype=np.float32)
sid_input = np.array([sid], dtype=np.int64) if sid is not None else None
audio = self.onnx_session.run(
None,
{
"input": text_input,
"input_lengths": text_lengths,
"scales": scales,
"sid": sid_input,
},
)[0].squeeze((0, 1))
write(data=audio, rate=self.hps.data.sampling_rate, filename=output_path)
print(f"Audio guardado exitosamente en: {output_path}")
# --- Ejemplo de Uso ---
if __name__ == "__main__":
# Rutas de configuración
CONFIG_PATH = "./configs/config.json"
# --- PRUEBA CON MODELO ONNX CUANTIZADO (RECOMENDADO) ---
ONNX_MODEL_PATH = "./models/LJspeech_quantized.onnx"
if os.path.exists(ONNX_MODEL_PATH):
print("\n--- Probando con el modelo ONNX ---")
tts_onnx_engine = TTS(config_path=CONFIG_PATH, model_path=ONNX_MODEL_PATH)
tts_onnx_engine.text_to_speech(
"This is a test using the optimized ONNX model. It should be very fast.",
"sample_onnx.wav"
)
else:
print(f"No se encontró el modelo ONNX en {ONNX_MODEL_PATH}. Saltando prueba.")
# --- PRUEBA CON MODELO PYTORCH ORIGINAL ---
PTH_MODEL_PATH = "./models/LJspeech.pth"
if os.path.exists(PTH_MODEL_PATH):
print("\n--- Probando con el modelo PyTorch ---")
tts_pytorch_engine = TTS(config_path=CONFIG_PATH, model_path=PTH_MODEL_PATH)
tts_pytorch_engine.text_to_speech(
"This is a test using the original PyTorch model.",
"sample_pytorch.wav"
)
else:
print(f"No se encontró el modelo PyTorch en {PTH_MODEL_PATH}. Saltando prueba.")