Plachta commited on
Commit
420ea59
·
verified ·
1 Parent(s): a83a03b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -50
app.py CHANGED
@@ -13,9 +13,8 @@ import spaces
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
15
  dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
16
- "DiT_step_298000_seed_uvit_facodec_small_wavenet_pruned.pth",
17
- "config_dit_mel_seed_facodec_small_wavenet.yml")
18
-
19
  config = yaml.safe_load(open(dit_config_path, 'r'))
20
  model_params = recursive_munch(config['model_params'])
21
  model = build_model(model_params, stage='DiT')
@@ -39,18 +38,6 @@ campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"
39
  campplus_model.eval()
40
  campplus_model.to(device)
41
 
42
- from modules.hifigan.generator import HiFTGenerator
43
- from modules.hifigan.f0_predictor import ConvRNNF0Predictor
44
-
45
- hift_checkpoint_path, hift_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
46
- "hift.pt",
47
- "hifigan.yml")
48
- hift_config = yaml.safe_load(open(hift_config_path, 'r'))
49
- hift_gen = HiFTGenerator(**hift_config['hift'], f0_predictor=ConvRNNF0Predictor(**hift_config['f0_predictor']))
50
- hift_gen.load_state_dict(torch.load(hift_checkpoint_path, map_location='cpu'))
51
- hift_gen.eval()
52
- hift_gen.to(device)
53
-
54
  from modules.bigvgan import bigvgan
55
 
56
  bigvgan_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_22khz_80band_256x', use_cuda_kernel=False)
@@ -59,25 +46,27 @@ bigvgan_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_22khz_80band_
59
  bigvgan_model.remove_weight_norm()
60
  bigvgan_model = bigvgan_model.eval().to(device)
61
 
62
- speech_tokenizer_type = config['model_params']['speech_tokenizer'].get('type', 'cosyvoice')
63
- if speech_tokenizer_type == 'cosyvoice':
64
- from modules.cosyvoice_tokenizer.frontend import CosyVoiceFrontEnd
65
- speech_tokenizer_path = load_custom_model_from_hf("Plachta/Seed-VC", "speech_tokenizer_v1.onnx", None)
66
- cosyvoice_frontend = CosyVoiceFrontEnd(speech_tokenizer_model=speech_tokenizer_path,
67
- device='cuda', device_id=0)
68
- elif speech_tokenizer_type == 'facodec':
69
- ckpt_path, config_path = load_custom_model_from_hf("Plachta/FAcodec", 'pytorch_model.bin', 'config.yml')
70
 
71
- codec_config = yaml.safe_load(open(config_path))
72
- codec_model_params = recursive_munch(codec_config['model_params'])
73
- codec_encoder = build_model(codec_model_params, stage="codec")
74
 
75
- ckpt_params = torch.load(ckpt_path, map_location="cpu")
76
 
77
- for key in codec_encoder:
78
- codec_encoder[key].load_state_dict(ckpt_params[key], strict=False)
79
- _ = [codec_encoder[key].eval() for key in codec_encoder]
80
- _ = [codec_encoder[key].to(device) for key in codec_encoder]
 
 
 
 
 
 
 
 
 
81
 
82
  # Generate mel spectrograms
83
  mel_fn_args = {
@@ -87,7 +76,7 @@ mel_fn_args = {
87
  "num_mels": config['preprocess_params']['spect_params']['n_mels'],
88
  "sampling_rate": sr,
89
  "fmin": 0,
90
- "fmax": 8000,
91
  "center": False
92
  }
93
  mel_fn_args_f0 = {
@@ -149,7 +138,7 @@ bitrate = "320k"
149
  @spaces.GPU
150
  @torch.no_grad()
151
  @torch.inference_mode()
152
- def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate, n_quantizers, f0_condition, auto_f0_adjust, pitch_shift):
153
  inference_module = model if not f0_condition else model_f0
154
  mel_fn = to_mel if not f0_condition else to_mel_f0
155
  # Load audio
@@ -161,14 +150,10 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
161
  ref_audio = torch.tensor(ref_audio[:sr * 25]).unsqueeze(0).float().to(device)
162
 
163
  # Resample
164
- source_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
165
  ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
166
 
167
  # Extract features
168
- if speech_tokenizer_type == 'cosyvoice':
169
- S_alt = cosyvoice_frontend.extract_speech_token(source_waves_16k)[0]
170
- S_ori = cosyvoice_frontend.extract_speech_token(ref_waves_16k)[0]
171
- elif speech_tokenizer_type == 'facodec':
172
  converted_waves_24k = torchaudio.functional.resample(source_audio, sr, 24000)
173
  waves_input = converted_waves_24k.unsqueeze(1)
174
  max_wave_len_per_chunk = 24000 * 20
@@ -201,6 +186,74 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
201
  waves_input,
202
  )
203
  S_ori = torch.cat([codes[1], codes[0]], dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
  mel = mel_fn(source_audio.to(device).float())
206
  mel2 = mel_fn(ref_audio.to(device).float())
@@ -248,8 +301,8 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
248
  shifted_f0_alt = None
249
 
250
  # Length regulation
251
- cond = inference_module.length_regulator(S_alt, ylens=target_lengths, n_quantizers=int(n_quantizers), f0=shifted_f0_alt)[0]
252
- prompt_condition = inference_module.length_regulator(S_ori, ylens=target2_lengths, n_quantizers=int(n_quantizers), f0=F0_ori)[0]
253
 
254
  max_source_window = max_context_window - mel2.size(2)
255
  # split source condition (cond) into chunks
@@ -266,10 +319,7 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
266
  mel2, style2, None, diffusion_steps,
267
  inference_cfg_rate=inference_cfg_rate)
268
  vc_target = vc_target[:, :, mel2.size(-1):]
269
- if not f0_condition:
270
- vc_wave = hift_gen.inference(vc_target, f0=None)
271
- else:
272
- vc_wave = bigvgan_model(vc_target)[0]
273
  if processed_frames == 0:
274
  if is_last_chunk:
275
  output_wave = vc_wave[0].cpu().numpy()
@@ -327,19 +377,18 @@ if __name__ == "__main__":
327
  gr.Slider(minimum=1, maximum=200, value=10, step=1, label="Diffusion Steps / 扩散步数", info="10 by default, 50~100 for best quality / 默认为 10,50~100 为最佳质量"),
328
  gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust / 长度调整", info="<1.0 for speed-up speech, >1.0 for slow-down speech / <1.0 加速语速,>1.0 减慢语速"),
329
  gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Inference CFG Rate", info="has subtle influence / 有微小影响"),
330
- gr.Slider(minimum=1, maximum=3, step=1, value=3, label="N FAcodec Quantizers / FAcodec码本数量", info="the less FAcodec quantizer used, the less prosody of source audio is preserved / 使用的FAcodec码本越少,源音频的韵律保留越少"),
331
  gr.Checkbox(label="Use F0 conditioned model / 启用F0输入", value=False, info="Must set to true for singing voice conversion / 歌声转换时必须勾选"),
332
  gr.Checkbox(label="Auto F0 adjust / 自动F0调整", value=True,
333
  info="Roughly adjust F0 to match target voice. Only works when F0 conditioned model is used. / 粗略调整 F0 以匹配目标音色,仅在勾选 '启用F0输入' 时生效"),
334
  gr.Slider(label='Pitch shift / 音调变换', minimum=-24, maximum=24, step=1, value=0, info="Pitch shift in semitones, only works when F0 conditioned model is used / 半音数的音高变换,仅在勾选 '启用F0输入' 时生效"),
335
  ]
336
 
337
- examples = [["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 25, 1.0, 0.7, 1, False, True, 0],
338
- ["examples/source/jay_0.wav", "examples/reference/azuma_0.wav", 25, 1.0, 0.7, 1, True, True, 0],
339
  ["examples/source/Wiz Khalifa,Charlie Puth - See You Again [vocals]_[cut_28sec].wav",
340
- "examples/reference/teio_0.wav", 100, 1.0, 0.7, 1, True, False, 0],
341
  ["examples/source/TECHNOPOLIS - 2085 [vocals]_[cut_14sec].wav",
342
- "examples/reference/trump_0.wav", 50, 1.0, 0.7, 1, True, False, -12],
343
  ]
344
 
345
  outputs = [gr.Audio(label="Stream Output Audio / 流式输出", streaming=True, format='mp3'),
@@ -352,4 +401,4 @@ if __name__ == "__main__":
352
  title="Seed Voice Conversion",
353
  examples=examples,
354
  cache_examples=False,
355
- ).launch()
 
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
15
  dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
16
+ "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
17
+ "config_dit_mel_seed_uvit_whisper_small_wavenet.yml")
 
18
  config = yaml.safe_load(open(dit_config_path, 'r'))
19
  model_params = recursive_munch(config['model_params'])
20
  model = build_model(model_params, stage='DiT')
 
38
  campplus_model.eval()
39
  campplus_model.to(device)
40
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  from modules.bigvgan import bigvgan
42
 
43
  bigvgan_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_22khz_80band_256x', use_cuda_kernel=False)
 
46
  bigvgan_model.remove_weight_norm()
47
  bigvgan_model = bigvgan_model.eval().to(device)
48
 
49
+ ckpt_path, config_path = load_custom_model_from_hf("Plachta/FAcodec", 'pytorch_model.bin', 'config.yml')
 
 
 
 
 
 
 
50
 
51
+ codec_config = yaml.safe_load(open(config_path))
52
+ codec_model_params = recursive_munch(codec_config['model_params'])
53
+ codec_encoder = build_model(codec_model_params, stage="codec")
54
 
55
+ ckpt_params = torch.load(ckpt_path, map_location="cpu")
56
 
57
+ for key in codec_encoder:
58
+ codec_encoder[key].load_state_dict(ckpt_params[key], strict=False)
59
+ _ = [codec_encoder[key].eval() for key in codec_encoder]
60
+ _ = [codec_encoder[key].to(device) for key in codec_encoder]
61
+
62
+ # whisper
63
+ from transformers import AutoFeatureExtractor, WhisperModel
64
+
65
+ whisper_name = model_params.speech_tokenizer.whisper_name if hasattr(model_params.speech_tokenizer,
66
+ 'whisper_name') else "openai/whisper-small"
67
+ whisper_model = WhisperModel.from_pretrained(whisper_name, torch_dtype=torch.float16).to(device)
68
+ del whisper_model.decoder
69
+ whisper_feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_name)
70
 
71
  # Generate mel spectrograms
72
  mel_fn_args = {
 
76
  "num_mels": config['preprocess_params']['spect_params']['n_mels'],
77
  "sampling_rate": sr,
78
  "fmin": 0,
79
+ "fmax": None,
80
  "center": False
81
  }
82
  mel_fn_args_f0 = {
 
138
  @spaces.GPU
139
  @torch.no_grad()
140
  @torch.inference_mode()
141
+ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate, f0_condition, auto_f0_adjust, pitch_shift):
142
  inference_module = model if not f0_condition else model_f0
143
  mel_fn = to_mel if not f0_condition else to_mel_f0
144
  # Load audio
 
150
  ref_audio = torch.tensor(ref_audio[:sr * 25]).unsqueeze(0).float().to(device)
151
 
152
  # Resample
 
153
  ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
154
 
155
  # Extract features
156
+ if f0_condition:
 
 
 
157
  converted_waves_24k = torchaudio.functional.resample(source_audio, sr, 24000)
158
  waves_input = converted_waves_24k.unsqueeze(1)
159
  max_wave_len_per_chunk = 24000 * 20
 
186
  waves_input,
187
  )
188
  S_ori = torch.cat([codes[1], codes[0]], dim=1)
189
+ else:
190
+ converted_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
191
+ # if source audio less than 30 seconds, whisper can handle in one forward
192
+ if converted_waves_16k.size(-1) <= 16000 * 30:
193
+ alt_inputs = whisper_feature_extractor([converted_waves_16k.squeeze(0).cpu().numpy()],
194
+ return_tensors="pt",
195
+ return_attention_mask=True,
196
+ sampling_rate=16000)
197
+ alt_input_features = whisper_model._mask_input_features(
198
+ alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(device)
199
+ alt_outputs = whisper_model.encoder(
200
+ alt_input_features.to(whisper_model.encoder.dtype),
201
+ head_mask=None,
202
+ output_attentions=False,
203
+ output_hidden_states=False,
204
+ return_dict=True,
205
+ )
206
+ S_alt = alt_outputs.last_hidden_state.to(torch.float32)
207
+ S_alt = S_alt[:, :converted_waves_16k.size(-1) // 320 + 1]
208
+ else:
209
+ overlapping_time = 5 # 5 seconds
210
+ S_alt_list = []
211
+ buffer = None
212
+ traversed_time = 0
213
+ while traversed_time < converted_waves_16k.size(-1):
214
+ if buffer is None: # first chunk
215
+ chunk = converted_waves_16k[:, traversed_time:traversed_time + 16000 * 30]
216
+ else:
217
+ chunk = torch.cat([buffer, converted_waves_16k[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)]], dim=-1)
218
+ alt_inputs = whisper_feature_extractor([chunk.squeeze(0).cpu().numpy()],
219
+ return_tensors="pt",
220
+ return_attention_mask=True,
221
+ sampling_rate=16000)
222
+ alt_input_features = whisper_model._mask_input_features(
223
+ alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(device)
224
+ alt_outputs = whisper_model.encoder(
225
+ alt_input_features.to(whisper_model.encoder.dtype),
226
+ head_mask=None,
227
+ output_attentions=False,
228
+ output_hidden_states=False,
229
+ return_dict=True,
230
+ )
231
+ S_alt = alt_outputs.last_hidden_state.to(torch.float32)
232
+ S_alt = S_alt[:, :chunk.size(-1) // 320 + 1]
233
+ if traversed_time == 0:
234
+ S_alt_list.append(S_alt)
235
+ else:
236
+ S_alt_list.append(S_alt[:, 50 * overlapping_time:])
237
+ buffer = chunk[:, -16000 * overlapping_time:]
238
+ traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time
239
+ S_alt = torch.cat(S_alt_list, dim=1)
240
+
241
+ ori_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
242
+ ori_inputs = whisper_feature_extractor([ori_waves_16k.squeeze(0).cpu().numpy()],
243
+ return_tensors="pt",
244
+ return_attention_mask=True)
245
+ ori_input_features = whisper_model._mask_input_features(
246
+ ori_inputs.input_features, attention_mask=ori_inputs.attention_mask).to(device)
247
+ with torch.no_grad():
248
+ ori_outputs = whisper_model.encoder(
249
+ ori_input_features.to(whisper_model.encoder.dtype),
250
+ head_mask=None,
251
+ output_attentions=False,
252
+ output_hidden_states=False,
253
+ return_dict=True,
254
+ )
255
+ S_ori = ori_outputs.last_hidden_state.to(torch.float32)
256
+ S_ori = S_ori[:, :ori_waves_16k.size(-1) // 320 + 1]
257
 
258
  mel = mel_fn(source_audio.to(device).float())
259
  mel2 = mel_fn(ref_audio.to(device).float())
 
301
  shifted_f0_alt = None
302
 
303
  # Length regulation
304
+ cond, _, codes, commitment_loss, codebook_loss = inference_module.length_regulator(S_alt, ylens=target_lengths, n_quantizers=3, f0=shifted_f0_alt)
305
+ prompt_condition, _, codes, commitment_loss, codebook_loss = inference_module.length_regulator(S_ori, ylens=target2_lengths, n_quantizers=3, f0=F0_ori)
306
 
307
  max_source_window = max_context_window - mel2.size(2)
308
  # split source condition (cond) into chunks
 
319
  mel2, style2, None, diffusion_steps,
320
  inference_cfg_rate=inference_cfg_rate)
321
  vc_target = vc_target[:, :, mel2.size(-1):]
322
+ vc_wave = bigvgan_model(vc_target)[0]
 
 
 
323
  if processed_frames == 0:
324
  if is_last_chunk:
325
  output_wave = vc_wave[0].cpu().numpy()
 
377
  gr.Slider(minimum=1, maximum=200, value=10, step=1, label="Diffusion Steps / 扩散步数", info="10 by default, 50~100 for best quality / 默认为 10,50~100 为最佳质量"),
378
  gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust / 长度调整", info="<1.0 for speed-up speech, >1.0 for slow-down speech / <1.0 加速语速,>1.0 减慢语速"),
379
  gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Inference CFG Rate", info="has subtle influence / 有微小影响"),
 
380
  gr.Checkbox(label="Use F0 conditioned model / 启用F0输入", value=False, info="Must set to true for singing voice conversion / 歌声转换时必须勾选"),
381
  gr.Checkbox(label="Auto F0 adjust / 自动F0调整", value=True,
382
  info="Roughly adjust F0 to match target voice. Only works when F0 conditioned model is used. / 粗略调整 F0 以匹配目标音色,仅在勾选 '启用F0输入' 时生效"),
383
  gr.Slider(label='Pitch shift / 音调变换', minimum=-24, maximum=24, step=1, value=0, info="Pitch shift in semitones, only works when F0 conditioned model is used / 半音数的音高变换,仅在勾选 '启用F0输入' 时生效"),
384
  ]
385
 
386
+ examples = [["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 25, 1.0, 0.7, False, True, 0],
387
+ ["examples/source/jay_0.wav", "examples/reference/azuma_0.wav", 25, 1.0, 0.7, True, True, 0],
388
  ["examples/source/Wiz Khalifa,Charlie Puth - See You Again [vocals]_[cut_28sec].wav",
389
+ "examples/reference/teio_0.wav", 100, 1.0, 0.7, True, False, 0],
390
  ["examples/source/TECHNOPOLIS - 2085 [vocals]_[cut_14sec].wav",
391
+ "examples/reference/trump_0.wav", 50, 1.0, 0.7, True, False, -12],
392
  ]
393
 
394
  outputs = [gr.Audio(label="Stream Output Audio / 流式输出", streaming=True, format='mp3'),
 
401
  title="Seed Voice Conversion",
402
  examples=examples,
403
  cache_examples=False,
404
+ ).launch()