xernder committed on
Commit
7c03c55
·
verified ·
1 Parent(s): c7a3a20

Create rvc.py

Browse files
Files changed (1) hide show
  1. rvc.py +556 -0
rvc.py ADDED
@@ -0,0 +1,556 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess, torch, os, traceback, sys, warnings, shutil, numpy as np
2
+ from mega import Mega
3
+ os.environ["no_proxy"] = "localhost, 127.0.0.1, ::1"
4
+ import threading
5
+ from time import sleep
6
+ from subprocess import Popen
7
+ import faiss
8
+ from random import shuffle
9
+ import json, datetime, requests
10
+ from gtts import gTTS
11
+ now_dir = os.getcwd()
12
+ sys.path.append(now_dir)
13
+ tmp = os.path.join(now_dir, "TEMP")
14
+ shutil.rmtree(tmp, ignore_errors=True)
15
+ shutil.rmtree("%s/runtime/Lib/site-packages/infer_pack" % (now_dir), ignore_errors=True)
16
+ os.makedirs(tmp, exist_ok=True)
17
+ os.makedirs(os.path.join(now_dir, "logs"), exist_ok=True)
18
+ os.makedirs(os.path.join(now_dir, "weights"), exist_ok=True)
19
+ os.environ["TEMP"] = tmp
20
+ warnings.filterwarnings("ignore")
21
+ torch.manual_seed(114514)
22
+ from i18n import I18nAuto
23
+
24
+ import signal
25
+
26
+ import math
27
+
28
+ from utils import load_audio, CSVutil
29
+
30
# Formant settings (DoFormant, Quefrency, Timbre) are loaded from
# csvdb/formanting.csv at import time and later read by vc_single.
# Make sure the directory and both CSV files exist. Existing files are left
# untouched; missing ones are created empty even if the directory is present.
os.makedirs('csvdb', exist_ok=True)
for _csv_path in ("csvdb/formanting.csv", "csvdb/stop.csv"):
    if not os.path.isfile(_csv_path):
        with open(_csv_path, 'w'):
            pass

try:
    DoFormant, Quefrency, Timbre = CSVutil('csvdb/formanting.csv', 'r', 'formanting')
    # Map the textual flag to a real bool; any other value is kept as read.
    DoFormant = {'true': True, 'false': False}.get(DoFormant.lower(), DoFormant)
except (ValueError, TypeError, IndexError):
    # Nothing usable stored yet: fall back to defaults and persist them.
    DoFormant, Quefrency, Timbre = False, 1.0, 1.0
    CSVutil('csvdb/formanting.csv', 'w+', 'formanting', DoFormant, Quefrency, Timbre)
46
+
47
def _download_model_file(url, dest, label):
    """Stream ``url`` into ``dest``; ``label`` names the model in messages.

    The payload is written to ``dest + ".part"`` and renamed only after the
    transfer finished, so an interrupted download never leaves a truncated
    file at ``dest``.

    Raises:
        Exception: if the server answers with a status code other than 200.
    """
    partial = dest + ".part"
    # (connect timeout, read-inactivity timeout) in seconds.
    with requests.get(url, stream=True, timeout=(10, 60)) as response:
        if response.status_code != 200:
            raise Exception("Failed to download " + label + " model file. Status code: " + str(response.status_code) + ".")
        try:
            with open(partial, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
            os.replace(partial, dest)
        finally:
            # After a successful rename the partial file is gone already.
            if os.path.exists(partial):
                os.remove(partial)
    print("Downloaded " + label + " model file successfully. File saved to " + dest + ".")


def download_models():
    """Fetch ./hubert_base.pt and ./rmvpe.pt when they are not present.

    Raises:
        Exception: if either download returns a non-200 status code.
    """
    # Download hubert base model if not present
    if not os.path.isfile('./hubert_base.pt'):
        _download_model_file(
            'https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt',
            './hubert_base.pt',
            "hubert base",
        )

    # Download rmvpe model if not present
    if not os.path.isfile('./rmvpe.pt'):
        _download_model_file(
            'https://drive.usercontent.google.com/download?id=1Hkn4kNuVFRCNQwyxQFRtmzmMBGpQxptI&export=download&authuser=0&confirm=t&uuid=0b3a40de-465b-4c65-8c41-135b0b45c3f7&at=APZUnTV3lA3LnyTbeuduura6Dmi2:1693724254058',
            './rmvpe.pt',
            "rmvpe",
        )
69
+
70
+ download_models()
71
+
72
+
73
def formant_apply(qfrency, tmbre):
    """Enable formant shifting with the given quefrency and timbre.

    Updates the module-level ``DoFormant``/``Quefrency``/``Timbre`` values,
    writes them to csvdb/formanting.csv, and returns two update dicts
    carrying the new quefrency and timbre values.
    """
    # Without this declaration the assignments below only created locals and
    # the module-level settings were never changed.
    global DoFormant, Quefrency, Timbre
    Quefrency = qfrency
    Timbre = tmbre
    DoFormant = True
    CSVutil('csvdb/formanting.csv', 'w+', 'formanting', DoFormant, qfrency, tmbre)

    return (
        {"value": Quefrency, "__type__": "update"},
        {"value": Timbre, "__type__": "update"},
    )
80
+
81
def get_fshift_presets():
    """List the ``.txt`` preset files found anywhere under ./formantshiftcfg/.

    Returns:
        A list of paths (backslashes converted to forward slashes), or an
        empty string when no preset file exists.
    """
    found = [
        os.path.join(folder, entry).replace('\\', '/')
        for folder, _subdirs, entries in os.walk("./formantshiftcfg/")
        for entry in entries
        if entry.endswith(".txt")
    ]
    return found if found else ''
92
+
93
+
94
+
95
def formant_enabled(cbox, qfrency, tmbre, frmntapply, formantpreset, formant_refresh_button):
    """Switch formant shifting on or off according to the checkbox state.

    Stores the flag in the module-level ``DoFormant`` and in
    csvdb/formanting.csv together with the current quefrency and timbre.

    Returns:
        Six update dicts: the checkbox value followed by five visibility
        toggles. Both branches now return the same number of entries; the
        disabled branch previously returned a seventh one.
    """
    # Declared global so the flag is not just a throw-away local.
    global DoFormant
    enabled = bool(cbox)
    DoFormant = enabled
    CSVutil('csvdb/formanting.csv', 'w+', 'formanting', DoFormant, qfrency, tmbre)

    return (
        {"value": enabled, "__type__": "update"},
        {"visible": enabled, "__type__": "update"},
        {"visible": enabled, "__type__": "update"},
        {"visible": enabled, "__type__": "update"},
        {"visible": enabled, "__type__": "update"},
        {"visible": enabled, "__type__": "update"},
    )
128
+
129
+
130
+
131
def preset_apply(preset, qfer, tmbr):
    """Load quefrency and timbre from a preset file and apply them.

    When ``preset`` is an empty string the incoming ``qfer``/``tmbr`` are
    returned unchanged. Otherwise the first two lines of the file replace
    them and are passed to ``formant_apply``.

    Returns:
        Two update dicts carrying the resulting quefrency and timbre.

    Raises:
        IndexError: if the preset file has fewer than two lines.
    """
    preset_path = str(preset)
    if preset_path != '':
        with open(preset_path, 'r') as p:
            content = p.readlines()
        # Strip the line terminator from both values; previously only the
        # first one was cleaned, so timbre could carry a trailing newline.
        qfer, tmbr = content[0].rstrip('\n'), content[1].rstrip('\n')
        formant_apply(qfer, tmbr)
    return (
        {"value": qfer, "__type__": "update"},
        {"value": tmbr, "__type__": "update"},
    )
141
+
142
def update_fshift_presets(preset, qfrency, tmbre):
    """Refresh the preset list and re-apply the selected preset.

    Returns:
        Three update dicts: the preset choices, the quefrency value and the
        timbre value.
    """
    # preset_apply already returns ready-made update dicts. Reuse them rather
    # than wrapping them again (which nested a dict inside "value" for an
    # empty preset) or re-reading the file a second time.
    qfrency_update, tmbre_update = preset_apply(preset, qfrency, tmbre)
    return (
        {"choices": get_fshift_presets(), "__type__": "update"},
        qfrency_update,
        tmbre_update,
    )
159
+
160
i18n = I18nAuto()
# i18n.print()
# Check whether an NVIDIA card usable for training and accelerated inference
# is available.
ngpu = torch.cuda.device_count()
gpu_infos = []
mem = []
if_gpu_ok = False
# Device-name fragments that mark a card as usable, compared against the
# upper-cased name (digits are unaffected by upper-casing).
# Intended matches include A10, A100, V100, A40, P40, M40, K80, A4500.
_USABLE_GPU_MARKERS = (
    "10", "16", "20", "30", "40",
    "A2", "A3", "A4", "P4", "A50", "A60",
    "70", "80", "90",
    "M4", "T4", "TITAN",
)
if torch.cuda.is_available() and ngpu != 0:
    for i in range(ngpu):
        gpu_name = torch.cuda.get_device_name(i)
        if any(marker in gpu_name.upper() for marker in _USABLE_GPU_MARKERS):
            if_gpu_ok = True  # at least one usable NVIDIA card
            gpu_infos.append("%s\t%s" % (i, gpu_name))
            # total_memory scaled down by 1024**3, with 0.4 added before
            # truncating to an int.
            mem.append(
                int(
                    torch.cuda.get_device_properties(i).total_memory
                    / 1024
                    / 1024
                    / 1024
                    + 0.4
                )
            )
if if_gpu_ok and gpu_infos:
    gpu_info = "\n".join(gpu_infos)
    default_batch_size = min(mem) // 2
else:
    gpu_info = i18n("很遗憾您这没有能用的显卡来支持您训练")
    default_batch_size = 1
# Join the device indices. Take everything before the tab so that indices
# of 10 and above are not cut down to their first digit.
gpus = "-".join(info.split("\t", 1)[0] for info in gpu_infos)
209
+ from lib.infer_pack.models import (
210
+ SynthesizerTrnMs256NSFsid,
211
+ SynthesizerTrnMs256NSFsid_nono,
212
+ SynthesizerTrnMs768NSFsid,
213
+ SynthesizerTrnMs768NSFsid_nono,
214
+ )
215
+ import soundfile as sf
216
+ from fairseq import checkpoint_utils
217
+ import gradio as gr
218
+ import logging
219
+ from vc_infer_pipeline import VC
220
+ from config import Config
221
+
222
+ config = Config()
223
+ # from trainset_preprocess_pipeline import PreProcess
224
+ logging.getLogger("numba").setLevel(logging.WARNING)
225
+
226
hubert_model = None


def load_hubert():
    """Load hubert_base.pt and store the ready-to-use model in ``hubert_model``.

    The model is moved to ``config.device``, converted to half or full
    precision according to ``config.is_half`` and put in eval mode. The
    global is assigned only once all of that succeeded, so a failure midway
    leaves ``hubert_model`` untouched and a later call can retry.
    """
    global hubert_model
    models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
        ["hubert_base.pt"],
        suffix="",
    )
    model = models[0].to(config.device)
    model = model.half() if config.is_half else model.float()
    model.eval()
    hubert_model = model
241
+
242
+
243
weight_root = "weights"
index_root = "logs"
# Every .pth file directly inside the weights folder.
names = [entry for entry in os.listdir(weight_root) if entry.endswith(".pth")]
# Every .index file below the logs folder whose name does not contain
# "trained".
index_paths = [
    "%s/%s" % (folder, entry)
    for folder, _subdirs, entries in os.walk(index_root, topdown=False)
    for entry in entries
    if entry.endswith(".index") and "trained" not in entry
]
254
+
255
+
256
+
257
def vc_single(
    sid,
    input_audio_path,
    f0_up_key,
    f0_file,
    f0_method,
    file_index,
    # file_index2,
    # file_big_npy,
    index_rate,
    filter_radius,
    resample_sr,
    rms_mix_rate,
    protect,
    crepe_hop_length,
):  # spk_item, input_audio0, vc_transform0,f0_file,f0method0
    """Convert one audio file with the currently loaded voice model.

    Returns:
        ``(message, (sample_rate, audio))`` on success,
        ``("You need to upload an audio", None)`` when no input path is
        given, and ``(traceback_text, (None, None))`` when conversion fails.

    Raises:
        ValueError / TypeError: if ``f0_up_key`` cannot be converted to int
        (the conversion happens outside the guarded section).
    """
    global tgt_sr, net_g, vc, hubert_model, version
    if input_audio_path is None:
        return "You need to upload an audio", None
    f0_up_key = int(f0_up_key)
    try:
        audio = load_audio(input_audio_path, 16000, DoFormant, Quefrency, Timbre)
        # Scale down only when the peak would exceed 0.95.
        audio_max = np.abs(audio).max() / 0.95
        if audio_max > 1:
            audio /= audio_max
        times = [0, 0, 0]
        if hubert_model is None:
            load_hubert()
        if_f0 = cpt.get("f0", 1)
        # Guard against user typos: trim blanks, quotes and newlines, and
        # swap "trained" for "added" in the index path.
        file_index = (
            file_index.strip(" ")
            .strip('"')
            .strip("\n")
            .strip('"')
            .strip(" ")
            .replace("trained", "added")
        )
        audio_opt = vc.pipeline(
            hubert_model,
            net_g,
            sid,
            audio,
            input_audio_path,
            times,
            f0_up_key,
            f0_method,
            file_index,
            # file_big_npy,
            index_rate,
            if_f0,
            filter_radius,
            tgt_sr,
            resample_sr,
            rms_mix_rate,
            version,
            protect,
            crepe_hop_length,
            f0_file=f0_file,
        )
        # Report the resampled rate through a local. Writing it into the
        # global tgt_sr would feed the wrong model sample rate to every
        # later call.
        if resample_sr >= 16000 and tgt_sr != resample_sr:
            output_sr = resample_sr
        else:
            output_sr = tgt_sr
        index_info = (
            "Using index:%s." % file_index
            if os.path.exists(file_index)
            else "Index not used."
        )
        return "Success.\n %s\nTime:\n npy:%ss, f0:%ss, infer:%ss" % (
            index_info,
            times[0],
            times[1],
            times[2],
        ), (output_sr, audio_opt)
    except Exception:
        info = traceback.format_exc()
        print(info)
        return info, (None, None)
338
+
339
+
340
def vc_multi(
    sid,
    dir_path,
    opt_root,
    paths,
    f0_up_key,
    f0_method,
    file_index,
    file_index2,
    # file_big_npy,
    index_rate,
    filter_radius,
    resample_sr,
    rms_mix_rate,
    protect,
    format1,
    crepe_hop_length,
):
    """Convert a batch of audio files, yielding a growing progress log.

    Inputs come from ``dir_path`` when it is non-empty, otherwise from the
    ``.name`` of each entry in ``paths``. Results are written into
    ``opt_root``; formats other than wav/flac are written as wav first and
    then converted with ffmpeg. ``file_index2`` is accepted but not used.

    Yields:
        The accumulated ``"<file>-><info>"`` lines after every file and once
        more at the end, or a traceback string if the batch setup fails.
    """
    try:
        # Guard against pasted paths that carry leading/trailing spaces,
        # quotes or newlines.
        dir_path = (
            dir_path.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
        )
        opt_root = opt_root.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
        os.makedirs(opt_root, exist_ok=True)
        try:
            if dir_path != "":
                paths = [os.path.join(dir_path, name) for name in os.listdir(dir_path)]
            else:
                paths = [path.name for path in paths]
        except Exception:
            traceback.print_exc()
            paths = [path.name for path in paths]
        infos = []
        for path in paths:
            info, opt = vc_single(
                sid,
                path,
                f0_up_key,
                None,
                f0_method,
                file_index,
                # file_big_npy,
                index_rate,
                filter_radius,
                resample_sr,
                rms_mix_rate,
                protect,
                crepe_hop_length,
            )
            if "Success" in info:
                try:
                    tgt_sr, audio_opt = opt
                    if format1 in ["wav", "flac"]:
                        sf.write(
                            "%s/%s.%s" % (opt_root, os.path.basename(path), format1),
                            audio_opt,
                            tgt_sr,
                        )
                    else:
                        path = "%s/%s.wav" % (opt_root, os.path.basename(path))
                        sf.write(
                            path,
                            audio_opt,
                            tgt_sr,
                        )
                        if os.path.exists(path):
                            # Argument list instead of a shell string, so
                            # spaces or shell metacharacters in file names
                            # can neither break nor hijack the command.
                            # Argument order is kept as it was.
                            subprocess.run(
                                [
                                    "ffmpeg",
                                    "-i",
                                    path,
                                    "-vn",
                                    path[:-4] + ".%s" % format1,
                                    "-q:a",
                                    "2",
                                    "-y",
                                ]
                            )
                except Exception:
                    info += traceback.format_exc()
            infos.append("%s->%s" % (os.path.basename(path), info))
            yield "\n".join(infos)
        yield "\n".join(infos)
    # Not a bare except: GeneratorExit must pass through when the consumer
    # closes the generator at one of the yields above.
    except Exception:
        yield traceback.format_exc()
417
+
418
def _build_synthesizer(checkpoint):
    """Instantiate the synthesizer class for ``checkpoint``'s version and f0 flag.

    Returns None when the checkpoint's "version" is neither "v1" nor "v2".
    """
    if_f0 = checkpoint.get("f0", 1)
    model_version = checkpoint.get("version", "v1")
    if model_version == "v1":
        if if_f0 == 1:
            return SynthesizerTrnMs256NSFsid(
                *checkpoint["config"], is_half=config.is_half
            )
        return SynthesizerTrnMs256NSFsid_nono(*checkpoint["config"])
    if model_version == "v2":
        if if_f0 == 1:
            return SynthesizerTrnMs768NSFsid(
                *checkpoint["config"], is_half=config.is_half
            )
        return SynthesizerTrnMs768NSFsid_nono(*checkpoint["config"])
    return None


# Only one voice can be active globally per tab.
def get_vc(sid):
    """Load the voice model named ``sid``, or unload everything when it is empty.

    Returns:
        An update dict; after a successful load it also carries
        ``"maximum": n_spk``.

    Raises:
        ValueError: if the checkpoint's version is neither "v1" nor "v2".
            Previously the weights were loaded into whatever model was
            still held in ``net_g``.
    """
    global n_spk, tgt_sr, net_g, vc, cpt, version, hubert_model
    if sid == "" or sid == []:
        # Because of polling, only clean up when a model was actually loaded
        # before, i.e. sid switched from "some model" to "no model".
        if hubert_model is not None:
            print("clean_empty_cache")
            del net_g, n_spk, vc, hubert_model, tgt_sr  # ,cpt
            hubert_model = net_g = n_spk = vc = tgt_sr = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            # Original author's note: without the rebuild-and-delete dance
            # below, the cleanup is not complete.
            version = cpt.get("version", "v1")
            net_g = _build_synthesizer(cpt)
            del net_g, cpt
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            cpt = None
        return {"visible": False, "__type__": "update"}
    person = "%s/%s" % (weight_root, sid)
    print("loading %s" % person)
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from sources you trust.
    cpt = torch.load(person, map_location="cpu")
    tgt_sr = cpt["config"][-1]
    cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]  # n_spk
    version = cpt.get("version", "v1")
    model = _build_synthesizer(cpt)
    if model is None:
        raise ValueError("Unsupported model version: %r" % (version,))
    net_g = model
    del net_g.enc_q
    print(net_g.load_state_dict(cpt["weight"], strict=False))
    net_g.eval().to(config.device)
    if config.is_half:
        net_g = net_g.half()
    else:
        net_g = net_g.float()
    vc = VC(tgt_sr, config)
    n_spk = cpt["config"][-3]
    return {"visible": False, "maximum": n_spk, "__type__": "update"}
478
+
479
+
480
def change_choices():
    """Rescan the weights and logs folders and return refreshed choice lists.

    Returns:
        Two update dicts: sorted ``.pth`` file names from ``weight_root`` and
        sorted ``.index`` paths under ``index_root`` whose file name does not
        contain "trained".
    """
    model_files = [
        entry for entry in os.listdir(weight_root) if entry.endswith(".pth")
    ]
    index_files = [
        "%s/%s" % (folder, entry)
        for folder, _subdirs, entries in os.walk(index_root, topdown=False)
        for entry in entries
        if entry.endswith(".index") and "trained" not in entry
    ]
    models_update = {"choices": sorted(model_files), "__type__": "update"}
    indexes_update = {"choices": sorted(index_files), "__type__": "update"}
    return models_update, indexes_update
494
+
495
+
496
def clean():
    """Return an update dict that resets a component's value to an empty string."""
    reset = {"value": ""}
    reset["__type__"] = "update"
    return reset
498
+
499
+
500
# Maps the sample-rate labels "32k", "40k" and "48k" to their value in Hz.
sr_dict = {"%dk" % khz: khz * 1000 for khz in (32, 40, 48)}
505
+
506
+
507
def if_done(done, p):
    """Block until process ``p`` has exited, then set ``done[0]`` to True.

    ``p.poll()`` is checked every half second.
    """
    while p.poll() is None:
        sleep(0.5)
    done[0] = True
514
+
515
+
516
def if_done_multi(done, ps):
    """Block until every process in ``ps`` has exited, then set ``done[0]``.

    ``poll()`` returning None means a process is still running. As long as
    any process is unfinished, wait half a second and check again from the
    start of the list.
    """
    while any(proc.poll() is None for proc in ps):
        sleep(0.5)
    done[0] = True
529
+
530
+
531
+
532
+ global log_interval
533
+
534
+
535
def set_log_interval(exp_dir, batch_size12):
    """Derive a log interval from the number of wav files in an experiment.

    Counts the ``.wav`` files in ``<exp_dir>/1_16k_wavs``. With no such
    folder or no wav files the result is 1. Otherwise it is
    ``ceil(count / batch_size12)``, increased by one when that is above 1.
    """
    wav_dir = os.path.join(exp_dir, "1_16k_wavs")
    if not os.path.isdir(wav_dir):
        return 1

    wav_count = sum(1 for entry in os.listdir(wav_dir) if entry.endswith(".wav"))
    if wav_count == 0:
        return 1

    interval = math.ceil(wav_count / batch_size12)
    return interval + 1 if interval > 1 else interval
548
+
549
+
550
+
551
+
552
+
553
def whethercrepeornah(radio):
    """Return an update dict that is visible only for the mangio-crepe methods."""
    show = radio in ('mangio-crepe', 'mangio-crepe-tiny')
    return {"visible": show, "__type__": "update"}
556
+