NeoPy committed · Commit ac2697c (verified) · 1 Parent(s): 0a6f6ac

Upload utils.py

Files changed (1)
  1. infer/lib/utils.py +478 -0
infer/lib/utils.py ADDED
@@ -0,0 +1,478 @@
+ import os
+ import re
+ import gc
+ import sys
+ import torch
+ import faiss
+ import codecs
+ import logging
+
+ import numpy as np
+
+ from pydub import AudioSegment
+
+ sys.path.append(os.getcwd())
+
+ from main.tools import huggingface
+ from main.library.backends import directml, opencl
+ from main.app.variables import translations, configs, config, logger, embedders_model, spin_model, whisper_model
+
+ for l in ["httpx", "httpcore"]:
+     logging.getLogger(l).setLevel(logging.ERROR)
+
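+ # Verify that the pitch predictor(s) and embedder selected for inference are
+ # present locally, downloading any missing weights and retrying up to
+ # configs["num_of_restart"] times before exiting.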
+ def check_assets(f0_method, hubert, predictor_onnx=False, embedders_mode="fairseq"):
+     predictors_url = codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/cerqvpgbef/", "rot13")
+     embedders_url = codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/rzorqqref/", "rot13")
+
+     if embedders_mode == "spin": embedders_mode = "transformers"
+
+     def download_predictor(predictor):
+         model_path = os.path.join(configs["predictors_path"], predictor)
+
+         if not os.path.exists(model_path):
+             huggingface.HF_download_file(predictors_url + predictor, model_path)
+
+         return os.path.exists(model_path)
+
+     def download_embedder(embedders_mode, hubert):
+         model_path = os.path.join(configs["speaker_diarization_path"], "models", hubert) if embedders_mode == "whisper" else os.path.join(configs["embedders_path"], hubert)
+
+         if embedders_mode != "transformers" and not os.path.exists(model_path):
+             if embedders_mode == "whisper":
+                 huggingface.HF_download_file("".join([codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/fcrnxre_qvnevmngvba/", "rot13"), hubert]), model_path)
+             else:
+                 huggingface.HF_download_file("".join([embedders_url, "fairseq/" if embedders_mode == "fairseq" else "onnx/", hubert]), model_path)
+         elif embedders_mode == "transformers":
+             url = "transformers/" if not hubert.startswith("spin") else "spin/"
+
+             bin_file = os.path.join(model_path, "model.safetensors")
+             config_file = os.path.join(model_path, "config.json")
+
+             os.makedirs(model_path, exist_ok=True)
+
+             if not os.path.exists(bin_file):
+                 huggingface.HF_download_file("".join([embedders_url, url, hubert, "/model.safetensors"]), bin_file)
+
+             if not os.path.exists(config_file):
+                 huggingface.HF_download_file("".join([embedders_url, url, hubert, "/config.json"]), config_file)
+
+             return os.path.exists(bin_file) and os.path.exists(config_file)
+
+         return os.path.exists(model_path)
+
+     def get_modelname(f0_method, predictor_onnx=False):
+         suffix = ".onnx" if predictor_onnx else (".pt" if "crepe" not in f0_method else ".pth")
+
+         if "rmvpe" in f0_method:
+             modelname = ("hpa-rmvpe-76000" if "previous" in f0_method else "hpa-rmvpe-112000") if "hpa" in f0_method else "rmvpe"
+         elif "fcpe" in f0_method:
+             modelname = ("fcpe_legacy" if "legacy" in f0_method else "fcpe") if "previous" in f0_method or "legacy" in f0_method else "ddsp_200k"
+         elif "crepe" in f0_method:
+             modelname = "crepe_" + f0_method.replace("mangio-", "").split("-")[1]
+         elif "penn" in f0_method:
+             modelname = "fcn"
+         elif "djcm" in f0_method:
+             modelname = "djcm" + ("-svs" if "svs" in f0_method else "")
+         elif "pesto" in f0_method:
+             modelname = "pesto"
+         elif "swift" in f0_method:
+             return "swift.onnx"
+         else:
+             return None
+
+         return modelname + suffix
+
+     results = []
+     count = configs.get("num_of_restart", 5)
+
+     for _ in range(count):
+         if "hybrid" in f0_method:
+             methods_str = re.search(r"hybrid\[(.+)\]", f0_method)
+
+             if methods_str:
+                 methods = [method.strip() for method in methods_str.group(1).split("+")]
+
+                 for method in methods:
+                     modelname = get_modelname(method, predictor_onnx)
+
+                     if modelname is not None:
+                         results.append(download_predictor(modelname))
+         else:
+             modelname = get_modelname(f0_method, predictor_onnx)
+
+             if modelname is not None:
+                 results.append(download_predictor(modelname))
+
+         if hubert in embedders_model + spin_model + whisper_model:
+             hubert_file = hubert if embedders_mode == "transformers" else hubert + (".pt" if embedders_mode in ["fairseq", "whisper"] else ".onnx")
+
+             results.append(download_embedder(embedders_mode, hubert_file))
+
+         if all(results): return
+         else: results = []
+
+     logger.warning(translations["check_assets_error"].format(count=count))
+     sys.exit(1)
+
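+ # Fetch the Whisper checkpoint for the given model size and, if requested,
+ # the SpeechBrain speaker-embedding files used for diarization.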
+ def check_spk_diarization(model_size, speechbrain=True):
+     whisper_model = os.path.join(configs["speaker_diarization_path"], "models", f"{model_size}.pt")
+
+     if not os.path.exists(whisper_model):
+         huggingface.HF_download_file("".join([codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/fcrnxre_qvnevmngvba/", "rot13"), model_size, ".pt"]), whisper_model)
+
+     speechbrain_path = os.path.join(configs["speaker_diarization_path"], "models", "speechbrain")
+     if not os.path.exists(speechbrain_path): os.makedirs(speechbrain_path, exist_ok=True)
+
+     if speechbrain:
+         for f in ["classifier.ckpt", "config.json", "embedding_model.ckpt", "hyperparams.yaml", "mean_var_norm_emb.ckpt"]:
+             speechbrain_model = os.path.join(speechbrain_path, f)
+
+             if not os.path.exists(speechbrain_model):
+                 huggingface.HF_download_file(codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/fcrnxre_qvnevmngvba/fcrrpuoenva/", "rot13") + f, speechbrain_model)
+
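+ # Load an audio file as a mono float32 array at sample_rate (soundfile first,
+ # librosa as fallback), with optional STFT-based formant shifting.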
+ def load_audio(file, sample_rate=16000, formant_shifting=False, formant_qfrency=0.8, formant_timbre=0.8):
+     import librosa
+     import soundfile as sf
+
+     try:
+         file = file.strip(' \n"')
+         if not os.path.isfile(file): raise FileNotFoundError(translations["not_found"].format(name=file))
+
+         try:
+             audio, sr = sf.read(file, dtype=np.float32)
+         except Exception:
+             audio, sr = librosa.load(file, sr=None)
+
+         if len(audio.shape) > 1: audio = librosa.to_mono(audio.T)
+
+         if sr != sample_rate:
+             audio = librosa.resample(audio, orig_sr=sr, target_sr=sample_rate, res_type="soxr_vhq")
+
+         if formant_shifting:
+             from main.library.algorithm.stftpitchshift import StftPitchShift
+
+             pitchshifter = StftPitchShift(1024, 32, sample_rate)
+             audio = pitchshifter.shiftpitch(audio, factors=1, quefrency=formant_qfrency * 1e-3, distortion=formant_timbre)
+     except Exception as e:
+         raise RuntimeError(f"{translations['errors_loading_audio']}: {e}")
+
+     return audio.flatten()
+
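+ # Open an audio file with pydub, choosing the reader by extension and falling
+ # back to format autodetection; volume, if given, is applied as gain in dB.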
+ def pydub_load(input_path, volume=None):
+     try:
+         if input_path.endswith(".wav"): audio = AudioSegment.from_wav(input_path)
+         elif input_path.endswith(".mp3"): audio = AudioSegment.from_mp3(input_path)
+         elif input_path.endswith(".ogg"): audio = AudioSegment.from_ogg(input_path)
+         else: audio = AudioSegment.from_file(input_path)
+     except Exception:
+         audio = AudioSegment.from_file(input_path)
+
+     return audio if volume is None else (audio + volume)
+
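+ # Resolve the embedder checkpoint path for the requested backend (fairseq,
+ # onnx, transformers/spin, or whisper) and load the matching model class.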
+ def load_embedders_model(embedder_model, embedders_mode="fairseq"):
+     if embedders_mode in ["fairseq", "whisper"]: embedder_model += ".pt"
+     elif embedders_mode == "onnx": embedder_model += ".onnx"
+     elif embedders_mode == "spin": embedders_mode = "transformers"
+
+     embedder_model_path = os.path.join(configs["speaker_diarization_path"], "models", embedder_model) if embedders_mode == "whisper" else os.path.join(configs["embedders_path"], embedder_model)
+
+     if not os.path.exists(embedder_model_path):
+         raise FileNotFoundError(f"{translations['not_found'].format(name=translations['model'])}: {embedder_model}")
+
+     try:
+         if embedders_mode == "fairseq":
+             from main.library.embedders.fairseq import load_model
+
+             hubert_model = load_model(embedder_model_path)
+         elif embedders_mode == "onnx":
+             from main.library.embedders.onnx import HubertModelONNX
+
+             hubert_model = HubertModelONNX(embedder_model_path, config.providers, config.device)
+         elif embedders_mode == "transformers":
+             from main.library.embedders.transformers import HubertModelWithFinalProj
+
+             hubert_model = HubertModelWithFinalProj.from_pretrained(embedder_model_path)
+         elif embedders_mode == "whisper":
+             from main.library.embedders.ppg import WhisperModel
+
+             hubert_model = WhisperModel(embedder_model_path, config.device)
+         else: raise ValueError(translations["option_not_valid"])
+     except Exception as e:
+         raise RuntimeError(translations["read_model_error"].format(e=e))
+
+     return hubert_model
+
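+ # Split audio into voiced segments with Slicer2, using db_thresh (dB) as the
+ # silence threshold and min_interval (presumably milliseconds, as in common
+ # slicer implementations) as the minimum silence length between slices.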
+ def cut(audio, sr, db_thresh=-60, min_interval=250):
+     from main.inference.preprocess.slicer2 import Slicer2
+
+     slicer = Slicer2(sr=sr, threshold=db_thresh, min_interval=min_interval)
+     return slicer.slice2(audio)
+
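+ # Stitch processed segments back into one array of length total_len,
+ # zero-filling any gaps between and after the segments.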
+ def restore(segments, total_len, dtype=np.float32):
+     out = []
+     last_end = 0
+
+     for start, end, processed_seg in segments:
+         if start > last_end:
+             out.append(np.zeros(start - last_end, dtype=dtype))
+
+         out.append(processed_seg)
+         last_end = end
+
+     if last_end < total_len:
+         out.append(np.zeros(total_len - last_end, dtype=dtype))
+
+     return np.concatenate(out, axis=-1)
+
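+ # Extract embedder features: layer 9 projected through final_proj for v1
+ # models, raw layer 12 output for v2.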
+ def extract_features(model, feats, version, device="cpu"):
+     with torch.no_grad():
+         logits = model.extract_features(source=feats, padding_mask=torch.BoolTensor(feats.shape).fill_(False).to(device), output_layer=9 if version == "v1" else 12)
+         feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
+
+     return feats
+
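+ # Pull every F0 value toward the closest note frequency in note_dict;
+ # strength 1.0 snaps fully to the note, 0.0 leaves the pitch untouched.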
+ def autotune_f0(note_dict, f0, f0_autotune_strength):
+     autotuned_f0 = np.zeros_like(f0)
+
+     for i, freq in enumerate(f0):
+         autotuned_f0[i] = freq + (min(note_dict, key=lambda x: abs(x - freq)) - freq) * f0_autotune_strength
+
+     return autotuned_f0
+
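+ # Morph the loudness envelope of target_audio toward source_audio's: each
+ # sample is scaled by (source_rms / target_rms) ** (1 - rate), so rate=1
+ # keeps the target envelope and rate=0 fully imposes the source envelope.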
+ def change_rms(source_audio, source_rate, target_audio, target_rate, rate):
+     import librosa
+     import torch.nn.functional as F
+
+     rms2 = F.interpolate(torch.from_numpy(librosa.feature.rms(y=target_audio, frame_length=target_rate // 2 * 2, hop_length=target_rate // 2)).float().unsqueeze(0), size=target_audio.shape[0], mode="linear").squeeze()
+     rms1 = F.interpolate(torch.from_numpy(librosa.feature.rms(y=source_audio, frame_length=source_rate // 2 * 2, hop_length=source_rate // 2)).float().unsqueeze(0), size=target_audio.shape[0], mode="linear").squeeze()
+
+     return target_audio * (rms1.pow(1 - rate) * rms2.maximum(torch.zeros_like(rms2) + 1e-6).pow(rate - 1)).numpy()
+
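+ # Collect garbage and empty the cache of whichever accelerator backend is
+ # available (CUDA, MPS, DirectML, or OpenCL).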
+ def clear_gpu_cache():
+     gc.collect()
+
+     if torch.cuda.is_available(): torch.cuda.empty_cache()
+     elif torch.backends.mps.is_available(): torch.mps.empty_cache()
+     elif directml.is_available(): directml.empty_cache()
+     elif opencl.is_available(): opencl.pytorch_ocl.empty_cache()
+
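+ # Median F0 over voiced frames: zeros are marked unvoiced (NaN) and linearly
+ # interpolated before the median is taken.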
+ def extract_median_f0(f0):
+     f0 = np.where(f0 == 0, np.nan, f0)
+
+     return float(np.median(np.interp(np.arange(len(f0)), np.where(~np.isnan(f0))[0], f0[~np.isnan(f0)])))
+
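+ # Propose a transpose in semitones (clamped to +/-limit) that shifts the
+ # median F0 onto target_f0; falls back to 0 if the median cannot be computed.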
+ def proposal_f0_up_key(f0, target_f0=155.0, limit=12):
+     try:
+         return max(-limit, min(limit, int(np.round(12 * np.log2(target_f0 / extract_median_f0(f0))))))
+     except ValueError:
+         return 0
+
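+ # Treat `target` as a ring buffer: shift its contents left by the length of
+ # `new_data` and write the new samples into the tail.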
+ def circular_write(new_data, target):
+     offset = new_data.shape[0]
+
+     target[:-offset] = target[offset:].detach().clone()
+     target[-offset:] = new_data
+
+     return target
+
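+ # Read a FAISS index from disk together with its reconstructed vectors;
+ # returns (None, None) if the path is empty, missing, or unreadable.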
+ def load_faiss_index(index_path):
+     if index_path != "" and os.path.exists(index_path):
+         try:
+             index = faiss.read_index(index_path)
+             big_npy = index.reconstruct_n(0, index.ntotal)
+         except Exception as e:
+             logger.error(translations["read_faiss_index_error"].format(e=e))
+             index = big_npy = None
+     else: index = big_npy = None
+
+     return index, big_npy
+
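+ # Load a voice model: .pth checkpoints go through torch.load on CPU; anything
+ # else is treated as ONNX and wrapped in ONNXRVC.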
+ def load_model(model_path, weights_only=True, log_severity_level=3):
+     if not os.path.isfile(model_path): return None
+
+     if model_path.endswith(".pth"):
+         return torch.load(model_path, map_location="cpu", weights_only=weights_only)
+     else:
+         from main.library.onnx.wrapper import ONNXRVC
+
+         return ONNXRVC(model_path, config.providers, log_severity_level=log_severity_level)