RedSparkie committed on
Commit
1a3ef6b
1 Parent(s): 3408722

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -31
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import torch
3
  from TTS.api import TTS
@@ -11,9 +12,6 @@ from TTS.tts.models.xtts import Xtts
11
  # Aceptar los términos de COQUI
12
  os.environ["COQUI_TOS_AGREED"] = "1"
13
 
14
- # Establecer precisión reducida para acelerar en CPU
15
- torch.set_default_dtype(torch.float16)
16
-
17
  # Definir el dispositivo como CPU
18
  device = "cpu"
19
 
@@ -22,24 +20,20 @@ model_path = hf_hub_download(repo_id="RedSparkie/danielmula", filename="model.pt
22
  config_path = hf_hub_download(repo_id="RedSparkie/danielmula", filename="config.json")
23
  vocab_path = hf_hub_download(repo_id="RedSparkie/danielmula", filename="vocab.json")
24
 
25
- # Función para resamplear el audio a 24000 Hz y convertirlo a 16 bits
26
- def preprocess_audio(audio_path, target_sr=24000):
27
- waveform, original_sr = torchaudio.load(audio_path)
28
-
29
- # Resamplear si la frecuencia de muestreo es diferente
30
- if original_sr != target_sr:
31
- resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=target_sr)
32
- waveform = resampler(waveform)
33
-
34
- # Convertir a 16 bits
35
- waveform = waveform * (2**15) # Escalar para 16 bits
36
- waveform = waveform.to(torch.int16) # Convertir a formato de 16 bits
37
- return waveform, target_sr
38
 
39
  # Cargar el modelo XTTS
40
  XTTS_MODEL = None
41
  def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
42
  global XTTS_MODEL
 
 
 
 
 
43
  config = XttsConfig()
44
  config.load_json(xtts_config)
45
 
@@ -48,7 +42,8 @@ def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
48
  print("Loading XTTS model!")
49
 
50
  # Cargar el checkpoint del modelo
51
- XTTS_MODEL.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False)
 
52
  print("Model Loaded!")
53
 
54
  # Función para ejecutar TTS
@@ -56,25 +51,14 @@ def run_tts(lang, tts_text, speaker_audio_file):
56
  if XTTS_MODEL is None or not speaker_audio_file:
57
  return "You need to run the previous step to load the model !!", None, None
58
 
59
- # Preprocesar el audio (resampleo a 24000 Hz y conversión a 16 bits)
60
- waveform, sr = preprocess_audio(speaker_audio_file)
61
-
62
- # Guardar el audio procesado temporalmente para usarlo con el modelo
63
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
64
- torchaudio.save(fp.name, waveform, sr)
65
- processed_audio_path = fp.name
66
-
67
  # Usar inference_mode para mejorar el rendimiento
68
  with torch.inference_mode():
69
  gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
70
- audio_path=processed_audio_path,
71
  gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
72
  max_ref_length=XTTS_MODEL.config.max_ref_len,
73
  sound_norm_refs=XTTS_MODEL.config.sound_norm_refs
74
  )
75
-
76
- if gpt_cond_latent is None or speaker_embedding is None:
77
- return "Failed to process the audio file.", None, None
78
 
79
  out = XTTS_MODEL.inference(
80
  text=tts_text,
@@ -98,6 +82,7 @@ def run_tts(lang, tts_text, speaker_audio_file):
98
  return out_path, speaker_audio_file
99
 
100
  # Definir la función para Gradio
 
101
  def generate(text, audio):
102
  load_model(model_path, config_path, vocab_path)
103
  out_path, speaker_audio_file = run_tts(lang='es', tts_text=text, speaker_audio_file=audio)
@@ -110,5 +95,5 @@ demo = gr.Interface(
110
  outputs=gr.Audio(type='filepath')
111
  )
112
 
113
- # Lanzar la interfaz con un enlace público
114
- demo.launch(share=True)
 
1
+ import spaces
2
  import gradio as gr
3
  import torch
4
  from TTS.api import TTS
 
12
  # Aceptar los términos de COQUI
13
  os.environ["COQUI_TOS_AGREED"] = "1"
14
 
 
 
 
15
  # Definir el dispositivo como CPU
16
  device = "cpu"
17
 
 
20
  config_path = hf_hub_download(repo_id="RedSparkie/danielmula", filename="config.json")
21
  vocab_path = hf_hub_download(repo_id="RedSparkie/danielmula", filename="vocab.json")
22
 
23
+ # Función para limpiar la caché de GPU (por si en el futuro se usa GPU)
24
+ def clear_gpu_cache():
25
+ if torch.cuda.is_available():
26
+ torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
27
 
28
  # Cargar el modelo XTTS
29
  XTTS_MODEL = None
30
  def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
31
  global XTTS_MODEL
32
+ clear_gpu_cache()
33
+ if not xtts_checkpoint or not xtts_config or not xtts_vocab:
34
+ return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!"
35
+
36
+ # Configuración del modelo
37
  config = XttsConfig()
38
  config.load_json(xtts_config)
39
 
 
42
  print("Loading XTTS model!")
43
 
44
  # Cargar el checkpoint del modelo
45
+ XTTS_MODEL.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False, weights_only=True)
46
+
47
  print("Model Loaded!")
48
 
49
  # Función para ejecutar TTS
 
51
  if XTTS_MODEL is None or not speaker_audio_file:
52
  return "You need to run the previous step to load the model !!", None, None
53
 
 
 
 
 
 
 
 
 
54
  # Usar inference_mode para mejorar el rendimiento
55
  with torch.inference_mode():
56
  gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
57
+ audio_path=speaker_audio_file,
58
  gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
59
  max_ref_length=XTTS_MODEL.config.max_ref_len,
60
  sound_norm_refs=XTTS_MODEL.config.sound_norm_refs
61
  )
 
 
 
62
 
63
  out = XTTS_MODEL.inference(
64
  text=tts_text,
 
82
  return out_path, speaker_audio_file
83
 
84
  # Definir la función para Gradio
85
+ @spaces.GPU
86
  def generate(text, audio):
87
  load_model(model_path, config_path, vocab_path)
88
  out_path, speaker_audio_file = run_tts(lang='es', tts_text=text, speaker_audio_file=audio)
 
95
  outputs=gr.Audio(type='filepath')
96
  )
97
 
98
+ # Lanzar la interfaz
99
+ demo.launch()