Spaces:
Running
Running
File size: 7,239 Bytes
3e69bb2 82a3b65 3e69bb2 89cbeba 82a3b65 d501fdb 82a3b65 2690696 82a3b65 d501fdb c901ccb 2caeba1 4360556 06dc4b4 4360556 d501fdb 82a3b65 2690696 d501fdb b2de89e bd0c722 d501fdb bd0c722 b2de89e bd0c722 b2de89e bd0c722 b2de89e bd0c722 b2de89e 3d92837 bd0c722 3d92837 cd231b1 bd0c722 b2de89e bd0c722 82a3b65 2690696 d501fdb c901ccb d501fdb 82a3b65 2690696 3d92837 d501fdb d0fe171 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow import keras
import torch
from huggingface_hub import hf_hub_download
from speechbrain.inference.TTS import Tacotron2
import os
# Cargar modelo Tacotron2
tacotron2 = Tacotron2.from_hparams(
source="speechbrain/tts-tacotron2-ljspeech",
savedir="tmpdir_tts",
run_opts={"device": "cpu"}
)
# Diccionario para almacenar los modelos cargados
loaded_models = {}
# Modelos disponibles - define aquí las épocas que quieres incluir
available_models = {
"Época 100": "generator_epoch_100.keras",
"Época 300": "generator_epoch_300.keras",
"Época 400": "generator_epoch_400.keras",
"Época 1000": "generator_epoch_1000.keras",
"Época 4200": "generator_epoch_4200.keras",
"Época 4700": "generator_epoch_4700.keras",
"Época 7700": "generator_epoch_7700.keras",
}
# Función para cargar un modelo específico
def load_generator_model(model_name):
if model_name in loaded_models:
return loaded_models[model_name]
try:
model_path = hf_hub_download(
repo_id="Bmo411/WGAN",
filename=model_name
)
model = keras.models.load_model(model_path, compile=False)
loaded_models[model_name] = model
print(f"Modelo {model_name} cargado correctamente")
return model
except Exception as e:
print(f"Error al cargar el modelo {model_name}: {e}")
# Si falla la carga, intentamos usar el modelo de la época 1000 como fallback
try:
fallback_model = "generator_epoch_1000.keras"
model_path = hf_hub_download(
repo_id="Bmo411/WGAN",
filename=fallback_model
)
model = keras.models.load_model(model_path, compile=False)
loaded_models[model_name] = model # Guardamos con el nombre original para evitar recargar
print(f"Usando modelo fallback {fallback_model}")
return model
except:
print("Error crítico al cargar modelos. No hay modelos disponibles.")
return None
# Función para convertir texto a audio
def text_to_audio(text, model_epoch):
# Crear un array vacío por defecto en caso de error
default_audio = np.zeros(8000, dtype=np.float32)
sample_rate = 8000 # Ajusta según la configuración de tu modelo
if not text or not text.strip():
return (sample_rate, default_audio)
try:
# Obtener el nombre del archivo del modelo seleccionado
model_filename = available_models[model_epoch]
# Cargar el modelo generador correspondiente
generator = load_generator_model(model_filename)
if generator is None:
print("No se pudo cargar el generador")
return (sample_rate, default_audio)
# Convertir texto a mel-spectrograma con Tacotron2
mel_output, _, _ = tacotron2.encode_text(text)
mel = mel_output.detach().cpu().numpy().astype(np.float32)
# Imprimir forma original del mel para debugging
print(f"Forma original del mel: {mel.shape}")
# Reorganizar el mel para que coincida con la forma esperada (batch, 80, frames, 1)
# Si mel tiene forma (80, frames) - lo más probable
if len(mel.shape) == 2:
mel_input = np.expand_dims(mel, axis=0) # (1, 80, frames)
mel_input = np.expand_dims(mel_input, axis=-1) # (1, 80, frames, 1)
# Si viene con otra forma, intentamos adaptarla
elif len(mel.shape) == 3 and mel.shape[0] == 1:
# Si es (1, 80, frames) o (1, frames, 80)
if mel.shape[1] == 80:
mel_input = np.expand_dims(mel, axis=-1) # (1, 80, frames, 1)
else:
mel_input = np.expand_dims(np.transpose(mel, (0, 2, 1)), axis=-1) # (1, 80, frames, 1)
else:
# Intento final de reorganización
mel_input = np.expand_dims(np.expand_dims(mel, axis=0), axis=-1)
print(f"Forma del mel preparado: {mel_input.shape}")
# Generar audio
generated_audio = generator(mel_input, training=False)
# Procesar el audio generado
generated_audio = tf.squeeze(generated_audio).numpy()
# Asegurarse de que hay valores no cero antes de normalizar
if np.max(np.abs(generated_audio)) > 0:
generated_audio = generated_audio / np.max(np.abs(generated_audio))
# Convertir a float32 para gradio
generated_audio = generated_audio.astype(np.float32)
print(f"Forma del audio generado: {generated_audio.shape}")
current_length = len(generated_audio)
if current_length > 8000:
# Recortar si es más largo de 2 segundos
print(f"Recortando audio de {current_length} a {8000} muestras")
final_audio = generated_audio[:8000]
else:
# Rellenar con ceros si es más corto de 2 segundos
print(f"Rellenando audio de {current_length} a {8000} muestras")
final_audio = np.zeros(8000, dtype=np.float32)
final_audio[:current_length] = generated_audio
return (sample_rate, final_audio)
except Exception as e:
print(f"Error en la generación de audio: {e}")
# Si hay error, imprimir un traceback completo para mejor diagnóstico
import traceback
traceback.print_exc()
return (sample_rate, default_audio)
# Crear interfaz en Gradio
with gr.Blocks(title="Demo de TTS con Tacotron2 + Generador") as interface:
gr.Markdown("# Demo de TTS con Tacotron2 + Generador")
gr.Markdown("Convierte texto en audio usando Tacotron2 + modelo Generator entrenado en diferentes épocas.")
with gr.Row():
with gr.Column(scale=3):
text_input = gr.Textbox(lines=2, placeholder="Escribe nine-", label="Texto a convertir")
with gr.Column(scale=1):
model_selection = gr.Dropdown(
choices=list(available_models.keys()),
value="Época 1000",
label="Selecciona la época del modelo"
)
generate_btn = gr.Button("Generar Audio", variant="primary")
audio_output = gr.Audio(label="Audio generado")
# Configurar ejemplos
examples = gr.Examples(
examples=[
["nine", "Época 100"],
["nine", "Época 400"],
["nine", "Época 4700"]
],
inputs=[text_input, model_selection],
outputs=audio_output
)
# Conectar botón a la función
generate_btn.click(fn=text_to_audio, inputs=[text_input, model_selection], outputs=audio_output)
# También permitir enviar con Enter desde el cuadro de texto
text_input.submit(fn=text_to_audio, inputs=[text_input, model_selection], outputs=audio_output)
# Lanzar aplicación
if __name__ == "__main__":
# Precargamos el modelo de la época 1000 para tenerlo disponible inmediatamente
load_generator_model(available_models["Época 1000"])
# Lanzamos la interfaz
interface.launch(debug=True)
|