MAZALA2024 commited on
Commit
f35f27e
·
verified ·
1 Parent(s): 8d4ed80

Update voice_processing.py

Browse files
Files changed (1) hide show
  1. voice_processing.py +78 -82
voice_processing.py CHANGED
@@ -34,7 +34,7 @@ limitation = os.getenv("SYSTEM") == "spaces"
34
 
35
  config = Config()
36
 
37
- # Edge TTS
38
  tts_voice_list = asyncio.get_event_loop().run_until_complete(edge_tts.list_voices())
39
  tts_voices = ["mn-MN-BataaNeural", "mn-MN-YesuiNeural"] # Specific voices
40
 
@@ -47,12 +47,13 @@ def get_unique_filename(extension):
47
  return f"{uuid.uuid4()}.{extension}"
48
 
49
  def model_data(model_name):
50
- pth_path = [
51
- f"{model_root}/{model_name}/{f}"
52
- for f in os.listdir(f"{model_root}/{model_name}")
53
- if f.endswith(".pth")
54
- ][0]
55
- print(f"Loading {pth_path}")
 
56
  cpt = torch.load(pth_path, map_location="cpu")
57
  tgt_sr = cpt["config"][-1]
58
  cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
@@ -72,25 +73,23 @@ def model_data(model_name):
72
  raise ValueError("Unknown version")
73
  del net_g.enc_q
74
  net_g.load_state_dict(cpt["weight"], strict=False)
75
- print("Model loaded")
76
  net_g.eval().to(config.device)
77
  if config.is_half:
78
  net_g = net_g.half()
79
  else:
80
  net_g = net_g.float()
 
81
  vc = VC(tgt_sr, config)
82
 
83
  index_files = [
84
- f"{model_root}/{model_name}/{f}"
85
- for f in os.listdir(f"{model_root}/{model_name}")
86
- if f.endswith(".index")
87
  ]
88
- if len(index_files) == 0:
89
- print("No index file found")
90
- index_file = ""
91
- else:
92
- index_file = index_files[0]
93
  print(f"Index file found: {index_file}")
 
 
 
94
 
95
  return tgt_sr, net_g, vc, version, index_file, if_f0
96
 
@@ -108,22 +107,17 @@ def load_hubert():
108
  return hubert_model.eval()
109
 
110
  def get_model_names():
111
- model_root = "weights" # Assuming this is where your models are stored
112
  return [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
113
 
114
- # Add this helper function to ensure a new event loop is created if none exists
115
- def run_async_in_thread(fn, *args):
116
- loop = asyncio.new_event_loop()
117
- asyncio.set_event_loop(loop)
118
- result = loop.run_until_complete(fn(*args))
119
- loop.close()
120
- return result
121
 
122
- def parallel_tts(tasks):
123
- with ThreadPoolExecutor() as executor:
124
- futures = [executor.submit(run_async_in_thread, tts, *task) for task in tasks]
125
- results = [future.result() for future in futures]
126
- return results
127
 
128
  async def tts(
129
  model_name,
@@ -133,63 +127,65 @@ async def tts(
133
  use_uploaded_voice,
134
  uploaded_voice,
135
  ):
136
- # Default values for parameters used in EdgeTTS
137
- speed = 0 # Default speech speed
138
- f0_up_key = 0 # Default pitch adjustment
139
- f0_method = "rmvpe" # Default pitch extraction method
140
- protect = 0.33 # Default protect value
141
- filter_radius = 3
142
- resample_sr = 0
143
- rms_mix_rate = 0.25
144
- edge_time = 0 # Initialize edge_time
145
-
146
- edge_output_filename = get_unique_filename("mp3")
147
-
148
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  if use_uploaded_voice:
150
  if uploaded_voice is None:
151
- return "No voice file uploaded.", None, None
152
-
153
  # Process the uploaded voice file
154
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
155
  tmp_file.write(uploaded_voice)
156
- uploaded_file_path = tmp_file.name
157
 
158
- audio, sr = librosa.load(uploaded_file_path, sr=16000, mono=True)
159
  else:
160
- # EdgeTTS processing
 
 
 
161
  if limitation and len(tts_text) > 12000:
162
- return (
163
- f"Text characters should be at most 12000 in this huggingface space, but got {len(tts_text)} characters.",
164
- None,
165
- None,
166
- )
167
-
168
- # Invoke Edge TTS
169
- t0 = time.time()
170
  speed_str = f"+{speed}%" if speed >= 0 else f"{speed}%"
171
- await edge_tts.Communicate(
172
  tts_text, tts_voice, rate=speed_str
173
- ).save(edge_output_filename)
 
 
 
174
  t1 = time.time()
175
  edge_time = t1 - t0
176
 
177
  audio, sr = librosa.load(edge_output_filename, sr=16000, mono=True)
 
178
 
179
- # Common processing after loading the audio
180
- duration = len(audio) / sr
181
- print(f"Audio duration: {duration}s")
182
- if limitation and duration >= 20000:
183
- return (
184
- f"Audio should be less than 20 seconds in this huggingface space, but got {duration}s.",
185
- None,
186
- None,
187
- )
188
-
189
- f0_up_key = int(f0_up_key)
190
  tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
191
 
192
- # Setup for RMVPE or other pitch extraction methods
193
  if f0_method == "rmvpe":
194
  vc.model_rmvpe = rmvpe_model
195
 
@@ -198,9 +194,9 @@ async def tts(
198
  audio_opt = vc.pipeline(
199
  hubert_model,
200
  net_g,
201
- 0,
202
  audio,
203
- edge_output_filename if not use_uploaded_voice else uploaded_file_path,
204
  times,
205
  f0_up_key,
206
  f0_method,
@@ -218,25 +214,25 @@ async def tts(
218
 
219
  if tgt_sr != resample_sr and resample_sr >= 16000:
220
  tgt_sr = resample_sr
221
-
222
- info = f"Success. Time: tts: {edge_time}s, npy: {times[0]}s, f0: {times[1]}s, infer: {times[2]}s"
223
  print(info)
224
- return (
225
- info,
226
- edge_output_filename if not use_uploaded_voice else None,
227
- (tgt_sr, audio_opt),
228
- )
 
229
 
230
  except EOFError:
231
- info = (
232
- "output not valid. This may occur when input text and speaker do not match."
233
- )
234
  print(info)
235
- return info, None, None
236
  except Exception as e:
237
  traceback_info = traceback.format_exc()
238
  print(traceback_info)
239
- return str(e), None, None
 
240
 
241
  voice_mapping = {
242
  "Mongolian Male": "mn-MN-BataaNeural",
 
34
 
35
  config = Config()
36
 
37
+ # Edge TTS voices
38
  tts_voice_list = asyncio.get_event_loop().run_until_complete(edge_tts.list_voices())
39
  tts_voices = ["mn-MN-BataaNeural", "mn-MN-YesuiNeural"] # Specific voices
40
 
 
47
  return f"{uuid.uuid4()}.{extension}"
48
 
49
  def model_data(model_name):
50
+ pth_files = [
51
+ f for f in os.listdir(f"{model_root}/{model_name}") if f.endswith(".pth")
52
+ ]
53
+ if not pth_files:
54
+ raise FileNotFoundError(f"No .pth file found for model '{model_name}'")
55
+ pth_path = f"{model_root}/{model_name}/{pth_files[0]}"
56
+ print(f"Loading model from {pth_path}")
57
  cpt = torch.load(pth_path, map_location="cpu")
58
  tgt_sr = cpt["config"][-1]
59
  cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
 
73
  raise ValueError("Unknown version")
74
  del net_g.enc_q
75
  net_g.load_state_dict(cpt["weight"], strict=False)
 
76
  net_g.eval().to(config.device)
77
  if config.is_half:
78
  net_g = net_g.half()
79
  else:
80
  net_g = net_g.float()
81
+ print(f"Model '{model_name}' loaded.")
82
  vc = VC(tgt_sr, config)
83
 
84
  index_files = [
85
+ f for f in os.listdir(f"{model_root}/{model_name}") if f.endswith(".index")
 
 
86
  ]
87
+ if index_files:
88
+ index_file = f"{model_root}/{model_name}/{index_files[0]}"
 
 
 
89
  print(f"Index file found: {index_file}")
90
+ else:
91
+ index_file = ""
92
+ print("No index file found.")
93
 
94
  return tgt_sr, net_g, vc, version, index_file, if_f0
95
 
 
107
  return hubert_model.eval()
108
 
109
  def get_model_names():
 
110
  return [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
111
 
112
+ # Voice mapping dictionary
113
+ voice_mapping = {
114
+ "Mongolian Male": "mn-MN-BataaNeural",
115
+ "Mongolian Female": "mn-MN-YesuiNeural"
116
+ }
 
 
117
 
118
+ # Load models once
119
+ hubert_model = load_hubert()
120
+ rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
 
 
121
 
122
  async def tts(
123
  model_name,
 
127
  use_uploaded_voice,
128
  uploaded_voice,
129
  ):
 
 
 
 
 
 
 
 
 
 
 
 
130
  try:
131
+ # Validate inputs
132
+ if not tts_text.strip():
133
+ return {"success": False, "error": "Input text is empty."}
134
+
135
+ if tts_voice not in voice_mapping.values():
136
+ return {"success": False, "error": f"Invalid voice '{tts_voice}'."}
137
+
138
+ # Default parameters
139
+ f0_up_key = 0 # Pitch adjustment
140
+ f0_method = "rmvpe" # Pitch extraction method
141
+ protect = 0.33 # Protect value
142
+ filter_radius = 3
143
+ resample_sr = 0
144
+ rms_mix_rate = 0.25
145
+ edge_time = 0
146
+
147
+ audio = None
148
+ sr = 16000 # Sample rate
149
+
150
  if use_uploaded_voice:
151
  if uploaded_voice is None:
152
+ return {"success": False, "error": "No voice file uploaded."}
153
+
154
  # Process the uploaded voice file
155
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
156
  tmp_file.write(uploaded_voice)
157
+ input_audio_path = tmp_file.name
158
 
159
+ audio, sr = librosa.load(input_audio_path, sr=16000, mono=True)
160
  else:
161
+ # Edge TTS processing
162
+ edge_output_filename = get_unique_filename("mp3")
163
+
164
+ # Edge TTS limitations
165
  if limitation and len(tts_text) > 12000:
166
+ return {
167
+ "success": False,
168
+ "error": f"Text characters should be at most 12000 in this huggingface space, but got {len(tts_text)} characters."
169
+ }
170
+
171
+ speed = 0 # Speech speed
 
 
172
  speed_str = f"+{speed}%" if speed >= 0 else f"{speed}%"
173
+ communicate = edge_tts.Communicate(
174
  tts_text, tts_voice, rate=speed_str
175
+ )
176
+
177
+ t0 = time.time()
178
+ await communicate.save(edge_output_filename)
179
  t1 = time.time()
180
  edge_time = t1 - t0
181
 
182
  audio, sr = librosa.load(edge_output_filename, sr=16000, mono=True)
183
+ input_audio_path = edge_output_filename
184
 
185
+ # Load the specified RVC model
 
 
 
 
 
 
 
 
 
 
186
  tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
187
 
188
+ # Set RMVPE model for pitch extraction
189
  if f0_method == "rmvpe":
190
  vc.model_rmvpe = rmvpe_model
191
 
 
194
  audio_opt = vc.pipeline(
195
  hubert_model,
196
  net_g,
197
+ 0, # Speaker ID
198
  audio,
199
+ input_audio_path,
200
  times,
201
  f0_up_key,
202
  f0_method,
 
214
 
215
  if tgt_sr != resample_sr and resample_sr >= 16000:
216
  tgt_sr = resample_sr
217
+
218
+ info = f"Success. Time: tts: {edge_time:.2f}s, npy: {times[0]:.2f}s, f0: {times[1]:.2f}s, infer: {times[2]:.2f}s"
219
  print(info)
220
+ return {
221
+ "success": True,
222
+ "info": info,
223
+ "tgt_sr": tgt_sr,
224
+ "audio_opt": audio_opt
225
+ }
226
 
227
  except EOFError:
228
+ info = "Output not valid. This may occur when input text and speaker do not match."
 
 
229
  print(info)
230
+ return {"success": False, "error": info}
231
  except Exception as e:
232
  traceback_info = traceback.format_exc()
233
  print(traceback_info)
234
+ return {"success": False, "error": str(e)}
235
+
236
 
237
  voice_mapping = {
238
  "Mongolian Male": "mn-MN-BataaNeural",