MAZALA2024 commited on
Commit
b2c48c3
·
verified ·
1 Parent(s): c405c3a

Update voice_processing.py

Browse files
Files changed (1) hide show
  1. voice_processing.py +123 -106
voice_processing.py CHANGED
@@ -1,11 +1,18 @@
 
 
 
1
  import os
2
  import time
3
  import traceback
4
- import torch
5
- import numpy as np
 
 
6
  import librosa
 
7
  from fairseq import checkpoint_utils
8
- from rmvpe import RMVPE
 
9
  from config import Config
10
  from lib.infer_pack.models import (
11
  SynthesizerTrnMs256NSFsid,
@@ -13,65 +20,39 @@ from lib.infer_pack.models import (
13
  SynthesizerTrnMs768NSFsid,
14
  SynthesizerTrnMs768NSFsid_nono,
15
  )
 
16
  from vc_infer_pipeline import VC
17
- import uuid
18
- import tempfile # Ensure this is imported
19
- import asyncio # Ensure this is imported
 
 
 
 
 
 
20
 
21
  config = Config()
22
 
23
- # Global models loaded once
24
- hubert_model = None
25
- rmvpe_model = None
26
- model_cache = {} # Cache for RVC models
27
 
28
- def load_hubert():
29
- global hubert_model
30
- if hubert_model is None:
31
- print("Loading Hubert model...")
32
- models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
33
- ["hubert_base.pt"],
34
- suffix="",
35
- )
36
- hubert_model = models[0]
37
- hubert_model = hubert_model.to(config.device)
38
- if config.is_half:
39
- hubert_model = hubert_model.half()
40
- else:
41
- hubert_model = hubert_model.float()
42
- hubert_model.eval()
43
- print("Hubert model loaded.")
44
- return hubert_model
45
-
46
- def load_rmvpe():
47
- global rmvpe_model
48
- if rmvpe_model is None:
49
- print("Loading RMVPE model...")
50
- rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
51
- print("RMVPE model loaded.")
52
- return rmvpe_model
53
 
54
  def get_unique_filename(extension):
55
  return f"{uuid.uuid4()}.{extension}"
56
 
57
- def get_model_names():
58
- model_root = "weights" # Assuming this is where your models are stored
59
- return [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
60
-
61
  def model_data(model_name):
62
- global model_cache
63
- if model_name in model_cache:
64
- # Return cached model data
65
- return model_cache[model_name]
66
-
67
- model_root = "weights"
68
- pth_files = [
69
- f for f in os.listdir(f"{model_root}/{model_name}") if f.endswith(".pth")
70
- ]
71
- if not pth_files:
72
- raise FileNotFoundError(f"No .pth file found for model '{model_name}'")
73
- pth_path = f"{model_root}/{model_name}/{pth_files[0]}"
74
- print(f"Loading model from {pth_path}")
75
  cpt = torch.load(pth_path, map_location="cpu")
76
  tgt_sr = cpt["config"][-1]
77
  cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
@@ -89,32 +70,61 @@ def model_data(model_name):
89
  net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
90
  else:
91
  raise ValueError("Unknown version")
92
-
93
  del net_g.enc_q
94
  net_g.load_state_dict(cpt["weight"], strict=False)
 
95
  net_g.eval().to(config.device)
96
  if config.is_half:
97
  net_g = net_g.half()
98
  else:
99
  net_g = net_g.float()
100
- print(f"Model '{model_name}' loaded.")
101
-
102
  vc = VC(tgt_sr, config)
103
 
104
  index_files = [
105
- f for f in os.listdir(f"{model_root}/{model_name}") if f.endswith(".index")
 
 
106
  ]
107
- if index_files:
108
- index_file = f"{model_root}/{model_name}/{index_files[0]}"
109
- print(f"Index file found: {index_file}")
110
- else:
111
  index_file = ""
112
- print("No index file found.")
 
 
113
 
114
- # Cache the loaded model data
115
- model_cache[model_name] = (tgt_sr, net_g, vc, version, index_file, if_f0)
116
  return tgt_sr, net_g, vc, version, index_file, if_f0
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  async def tts(
119
  model_name,
120
  tts_text,
@@ -123,58 +133,63 @@ async def tts(
123
  use_uploaded_voice,
124
  uploaded_voice,
125
  ):
126
- try:
127
- # Load models if not already loaded
128
- load_hubert()
129
- load_rmvpe()
130
-
131
- # Default values for parameters used in EdgeTTS
132
- f0_up_key = 0 # Default pitch adjustment
133
- f0_method = "rmvpe" # Default pitch extraction method
134
- protect = 0.33 # Default protect value
135
- filter_radius = 3
136
- resample_sr = 0
137
- rms_mix_rate = 0.25
138
- edge_time = 0 # Initialize edge_time
139
-
140
- edge_output_filename = get_unique_filename("mp3")
141
- audio = None
142
- sr = 16000 # Default sample rate
143
 
 
144
  if use_uploaded_voice:
145
  if uploaded_voice is None:
146
- return {"error": "No voice file uploaded."}, None, None
147
-
148
  # Process the uploaded voice file
149
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
150
  tmp_file.write(uploaded_voice)
151
  uploaded_file_path = tmp_file.name
152
 
153
  audio, sr = librosa.load(uploaded_file_path, sr=16000, mono=True)
154
- input_audio_path = uploaded_file_path
155
  else:
156
  # EdgeTTS processing
157
- import edge_tts
 
 
 
 
 
 
 
158
  t0 = time.time()
159
- speed = 0 # Default speech speed
160
  speed_str = f"+{speed}%" if speed >= 0 else f"{speed}%"
161
- communicate = edge_tts.Communicate(
162
  tts_text, tts_voice, rate=speed_str
163
- )
164
- try:
165
- await asyncio.wait_for(communicate.save(edge_output_filename), timeout=30)
166
- except asyncio.TimeoutError:
167
- return {"error": "EdgeTTS operation timed out"}, None, None
168
  t1 = time.time()
169
  edge_time = t1 - t0
170
 
171
  audio, sr = librosa.load(edge_output_filename, sr=16000, mono=True)
172
- input_audio_path = edge_output_filename
173
 
174
- # Load the specified RVC model
 
 
 
 
 
 
 
 
 
 
175
  tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
176
 
177
- # Set RMVPE model for pitch extraction
178
  if f0_method == "rmvpe":
179
  vc.model_rmvpe = rmvpe_model
180
 
@@ -183,9 +198,9 @@ async def tts(
183
  audio_opt = vc.pipeline(
184
  hubert_model,
185
  net_g,
186
- 0, # Speaker ID
187
  audio,
188
- input_audio_path,
189
  times,
190
  f0_up_key,
191
  f0_method,
@@ -203,29 +218,31 @@ async def tts(
203
 
204
  if tgt_sr != resample_sr and resample_sr >= 16000:
205
  tgt_sr = resample_sr
206
-
207
- info = f"Success. Time: tts: {edge_time:.2f}s, npy: {times[0]:.2f}s, f0: {times[1]:.2f}s, infer: {times[2]:.2f}s"
208
  print(info)
209
  return (
210
  info,
211
- edge_output_filename,
212
  (tgt_sr, audio_opt),
213
  )
214
 
215
- except asyncio.CancelledError:
216
- print("TTS operation was cancelled")
217
- return {"error": "Operation cancelled"}, None, None
218
  except EOFError:
219
- info = "Output not valid. This may occur when input text and speaker do not match."
 
 
220
  print(info)
221
- return {"error": info}, None, None
222
  except Exception as e:
223
  traceback_info = traceback.format_exc()
224
  print(traceback_info)
225
- return {"error": str(e)}, None, None
226
 
227
- # Voice mapping dictionary
228
  voice_mapping = {
229
  "Mongolian Male": "mn-MN-BataaNeural",
230
  "Mongolian Female": "mn-MN-YesuiNeural"
231
  }
 
 
 
 
 
1
+ import asyncio
2
+ import datetime
3
+ import logging
4
  import os
5
  import time
6
  import traceback
7
+ import tempfile
8
+ from concurrent.futures import ThreadPoolExecutor
9
+
10
+ import edge_tts
11
  import librosa
12
+ import torch
13
  from fairseq import checkpoint_utils
14
+ import uuid
15
+
16
  from config import Config
17
  from lib.infer_pack.models import (
18
  SynthesizerTrnMs256NSFsid,
 
20
  SynthesizerTrnMs768NSFsid,
21
  SynthesizerTrnMs768NSFsid_nono,
22
  )
23
+ from rmvpe import RMVPE
24
  from vc_infer_pipeline import VC
25
+
26
+ # Set logging levels
27
+ logging.getLogger("fairseq").setLevel(logging.WARNING)
28
+ logging.getLogger("numba").setLevel(logging.WARNING)
29
+ logging.getLogger("markdown_it").setLevel(logging.WARNING)
30
+ logging.getLogger("urllib3").setLevel(logging.WARNING)
31
+ logging.getLogger("matplotlib").setLevel(logging.WARNING)
32
+
33
+ limitation = os.getenv("SYSTEM") == "spaces"
34
 
35
  config = Config()
36
 
37
+ # Edge TTS
38
+ tts_voice_list = asyncio.get_event_loop().run_until_complete(edge_tts.list_voices())
39
+ tts_voices = ["mn-MN-BataaNeural", "mn-MN-YesuiNeural"] # Specific voices
 
40
 
41
+ # RVC models
42
+ model_root = "weights"
43
+ models = [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
44
+ models.sort()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  def get_unique_filename(extension):
47
  return f"{uuid.uuid4()}.{extension}"
48
 
 
 
 
 
49
  def model_data(model_name):
50
+ pth_path = [
51
+ f"{model_root}/{model_name}/{f}"
52
+ for f in os.listdir(f"{model_root}/{model_name}")
53
+ if f.endswith(".pth")
54
+ ][0]
55
+ print(f"Loading {pth_path}")
 
 
 
 
 
 
 
56
  cpt = torch.load(pth_path, map_location="cpu")
57
  tgt_sr = cpt["config"][-1]
58
  cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
 
70
  net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
71
  else:
72
  raise ValueError("Unknown version")
 
73
  del net_g.enc_q
74
  net_g.load_state_dict(cpt["weight"], strict=False)
75
+ print("Model loaded")
76
  net_g.eval().to(config.device)
77
  if config.is_half:
78
  net_g = net_g.half()
79
  else:
80
  net_g = net_g.float()
 
 
81
  vc = VC(tgt_sr, config)
82
 
83
  index_files = [
84
+ f"{model_root}/{model_name}/{f}"
85
+ for f in os.listdir(f"{model_root}/{model_name}")
86
+ if f.endswith(".index")
87
  ]
88
+ if len(index_files) == 0:
89
+ print("No index file found")
 
 
90
  index_file = ""
91
+ else:
92
+ index_file = index_files[0]
93
+ print(f"Index file found: {index_file}")
94
 
 
 
95
  return tgt_sr, net_g, vc, version, index_file, if_f0
96
 
97
+ def load_hubert():
98
+ models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
99
+ ["hubert_base.pt"],
100
+ suffix="",
101
+ )
102
+ hubert_model = models[0]
103
+ hubert_model = hubert_model.to(config.device)
104
+ if config.is_half:
105
+ hubert_model = hubert_model.half()
106
+ else:
107
+ hubert_model = hubert_model.float()
108
+ return hubert_model.eval()
109
+
110
+ def get_model_names():
111
+ model_root = "weights" # Assuming this is where your models are stored
112
+ return [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
113
+
114
+ # Add this helper function to ensure a new event loop is created if none exists
115
+ def run_async_in_thread(fn, *args):
116
+ loop = asyncio.new_event_loop()
117
+ asyncio.set_event_loop(loop)
118
+ result = loop.run_until_complete(fn(*args))
119
+ loop.close()
120
+ return result
121
+
122
+ def parallel_tts(tasks):
123
+ with ThreadPoolExecutor() as executor:
124
+ futures = [executor.submit(run_async_in_thread, tts, *task) for task in tasks]
125
+ results = [future.result() for future in futures]
126
+ return results
127
+
128
  async def tts(
129
  model_name,
130
  tts_text,
 
133
  use_uploaded_voice,
134
  uploaded_voice,
135
  ):
136
+ # Default values for parameters used in EdgeTTS
137
+ speed = 0 # Default speech speed
138
+ f0_up_key = 0 # Default pitch adjustment
139
+ f0_method = "rmvpe" # Default pitch extraction method
140
+ protect = 0.33 # Default protect value
141
+ filter_radius = 3
142
+ resample_sr = 0
143
+ rms_mix_rate = 0.25
144
+ edge_time = 0 # Initialize edge_time
145
+
146
+ edge_output_filename = get_unique_filename("mp3")
 
 
 
 
 
 
147
 
148
+ try:
149
  if use_uploaded_voice:
150
  if uploaded_voice is None:
151
+ return "No voice file uploaded.", None, None
152
+
153
  # Process the uploaded voice file
154
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
155
  tmp_file.write(uploaded_voice)
156
  uploaded_file_path = tmp_file.name
157
 
158
  audio, sr = librosa.load(uploaded_file_path, sr=16000, mono=True)
 
159
  else:
160
  # EdgeTTS processing
161
+ if limitation and len(tts_text) > 12000:
162
+ return (
163
+ f"Text characters should be at most 12000 in this huggingface space, but got {len(tts_text)} characters.",
164
+ None,
165
+ None,
166
+ )
167
+
168
+ # Invoke Edge TTS
169
  t0 = time.time()
 
170
  speed_str = f"+{speed}%" if speed >= 0 else f"{speed}%"
171
+ await edge_tts.Communicate(
172
  tts_text, tts_voice, rate=speed_str
173
+ ).save(edge_output_filename)
 
 
 
 
174
  t1 = time.time()
175
  edge_time = t1 - t0
176
 
177
  audio, sr = librosa.load(edge_output_filename, sr=16000, mono=True)
 
178
 
179
+ # Common processing after loading the audio
180
+ duration = len(audio) / sr
181
+ print(f"Audio duration: {duration}s")
182
+ if limitation and duration >= 20000:
183
+ return (
184
+ f"Audio should be less than 20 seconds in this huggingface space, but got {duration}s.",
185
+ None,
186
+ None,
187
+ )
188
+
189
+ f0_up_key = int(f0_up_key)
190
  tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
191
 
192
+ # Setup for RMVPE or other pitch extraction methods
193
  if f0_method == "rmvpe":
194
  vc.model_rmvpe = rmvpe_model
195
 
 
198
  audio_opt = vc.pipeline(
199
  hubert_model,
200
  net_g,
201
+ 0,
202
  audio,
203
+ edge_output_filename if not use_uploaded_voice else uploaded_file_path,
204
  times,
205
  f0_up_key,
206
  f0_method,
 
218
 
219
  if tgt_sr != resample_sr and resample_sr >= 16000:
220
  tgt_sr = resample_sr
221
+
222
+ info = f"Success. Time: tts: {edge_time}s, npy: {times[0]}s, f0: {times[1]}s, infer: {times[2]}s"
223
  print(info)
224
  return (
225
  info,
226
+ edge_output_filename if not use_uploaded_voice else None,
227
  (tgt_sr, audio_opt),
228
  )
229
 
 
 
 
230
  except EOFError:
231
+ info = (
232
+ "output not valid. This may occur when input text and speaker do not match."
233
+ )
234
  print(info)
235
+ return info, None, None
236
  except Exception as e:
237
  traceback_info = traceback.format_exc()
238
  print(traceback_info)
239
+ return str(e), None, None
240
 
 
241
  voice_mapping = {
242
  "Mongolian Male": "mn-MN-BataaNeural",
243
  "Mongolian Female": "mn-MN-YesuiNeural"
244
  }
245
+
246
+ hubert_model = load_hubert()
247
+
248
+ rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)