Bajiyo committed on
Commit 03e24ee · verified · 1 Parent(s): 130c071

Update app.py

Files changed (1)
  1. app.py +163 -362
app.py CHANGED
@@ -1,375 +1,176 @@
- import spaces
  import gradio as gr
  import torch
- import torchaudio
- import librosa
- from modules.commons import build_model, load_checkpoint, recursive_munch
  import yaml
- from hf_utils import load_custom_model_from_hf
- import numpy as np
- from pydub import AudioSegment
-
- # Load model and configuration
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
-     "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
-     "config_dit_mel_seed_uvit_whisper_small_wavenet.yml")
- # dit_checkpoint_path = "E:/DiT_epoch_00018_step_801000.pth"
- # dit_config_path = "configs/config_dit_mel_seed_uvit_whisper_small_encoder_wavenet.yml"
- config = yaml.safe_load(open(dit_config_path, 'r'))
- model_params = recursive_munch(config['model_params'])
- model = build_model(model_params, stage='DiT')
- hop_length = config['preprocess_params']['spect_params']['hop_length']
- sr = config['preprocess_params']['sr']
-
- # Load checkpoints
- model, _, _, _ = load_checkpoint(model, None, dit_checkpoint_path,
-     load_only_params=True, ignore_modules=[], is_distributed=False)
- for key in model:
-     model[key].eval()
-     model[key].to(device)
- model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
-
- # Load additional modules
- from modules.campplus.DTDNN import CAMPPlus
-
- campplus_ckpt_path = load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None)
- campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
- campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
- campplus_model.eval()
- campplus_model.to(device)
-
- from modules.bigvgan import bigvgan
-
- bigvgan_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_22khz_80band_256x', use_cuda_kernel=False)
-
- # remove weight norm in the model and set to eval mode
- bigvgan_model.remove_weight_norm()
- bigvgan_model = bigvgan_model.eval().to(device)
-
- ckpt_path, config_path = load_custom_model_from_hf("Plachta/FAcodec", 'pytorch_model.bin', 'config.yml')
-
- codec_config = yaml.safe_load(open(config_path))
- codec_model_params = recursive_munch(codec_config['model_params'])
- codec_encoder = build_model(codec_model_params, stage="codec")
-
- ckpt_params = torch.load(ckpt_path, map_location="cpu")
-
- for key in codec_encoder:
-     codec_encoder[key].load_state_dict(ckpt_params[key], strict=False)
- _ = [codec_encoder[key].eval() for key in codec_encoder]
- _ = [codec_encoder[key].to(device) for key in codec_encoder]
-
- # whisper
- from transformers import AutoFeatureExtractor, WhisperModel
-
- whisper_name = model_params.speech_tokenizer.whisper_name if hasattr(model_params.speech_tokenizer,
-     'whisper_name') else "openai/whisper-small"
- whisper_model = WhisperModel.from_pretrained(whisper_name, torch_dtype=torch.float16).to(device)
- del whisper_model.decoder
- whisper_feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_name)
-
- # Generate mel spectrograms
- mel_fn_args = {
-     "n_fft": config['preprocess_params']['spect_params']['n_fft'],
-     "win_size": config['preprocess_params']['spect_params']['win_length'],
-     "hop_size": config['preprocess_params']['spect_params']['hop_length'],
-     "num_mels": config['preprocess_params']['spect_params']['n_mels'],
-     "sampling_rate": sr,
-     "fmin": 0,
-     "fmax": None,
-     "center": False
- }
- from modules.audio import mel_spectrogram
-
- to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
-
- # f0 conditioned model
- dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
-     "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth",
-     "config_dit_mel_seed_uvit_whisper_base_f0_44k.yml")
-
- config = yaml.safe_load(open(dit_config_path, 'r'))
- model_params = recursive_munch(config['model_params'])
- model_f0 = build_model(model_params, stage='DiT')
- hop_length = config['preprocess_params']['spect_params']['hop_length']
- sr = config['preprocess_params']['sr']
-
- # Load checkpoints
- model_f0, _, _, _ = load_checkpoint(model_f0, None, dit_checkpoint_path,
-     load_only_params=True, ignore_modules=[], is_distributed=False)
- for key in model_f0:
-     model_f0[key].eval()
-     model_f0[key].to(device)
- model_f0.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
-
- # f0 extractor
- from modules.rmvpe import RMVPE
-
- model_path = load_custom_model_from_hf("lj1995/VoiceConversionWebUI", "rmvpe.pt", None)
- rmvpe = RMVPE(model_path, is_half=False, device=device)
-
- mel_fn_args_f0 = {
-     "n_fft": config['preprocess_params']['spect_params']['n_fft'],
-     "win_size": config['preprocess_params']['spect_params']['win_length'],
-     "hop_size": config['preprocess_params']['spect_params']['hop_length'],
-     "num_mels": config['preprocess_params']['spect_params']['n_mels'],
-     "sampling_rate": sr,
-     "fmin": 0,
-     "fmax": None,
-     "center": False
- }
- to_mel_f0 = lambda x: mel_spectrogram(x, **mel_fn_args_f0)
- bigvgan_44k_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x', use_cuda_kernel=False)
-
- # remove weight norm in the model and set to eval mode
- bigvgan_44k_model.remove_weight_norm()
- bigvgan_44k_model = bigvgan_44k_model.eval().to(device)
-
- def adjust_f0_semitones(f0_sequence, n_semitones):
-     factor = 2 ** (n_semitones / 12)
-     return f0_sequence * factor
-
- def crossfade(chunk1, chunk2, overlap):
-     fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
-     fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
-     chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
-     return chunk2
-
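The two removed helpers above carry the arithmetic behind pitch shifting and chunk stitching: a shift of n semitones scales F0 by 2^(n/12), and the cosine-squared fade pair sums to 1 at every sample, so the overlapped region keeps constant amplitude. A minimal NumPy sketch of those two facts (names chosen purely for illustration):

import numpy as np

def semitone_factor(n_semitones):
    # +12 semitones doubles the frequency, -12 halves it
    return 2 ** (n_semitones / 12)

print(semitone_factor(12) * 220.0)   # 440.0 (A3 -> A4)
print(semitone_factor(-12) * 220.0)  # 110.0 (A3 -> A2)

overlap = 8
fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
# cos^2 + sin^2 = 1, so the two crossfade gains always sum to one
print(np.allclose(fade_out + fade_in, 1.0))  # True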
- # streaming and chunk processing related params
- bitrate = "320k"
- overlap_frame_len = 16
- @spaces.GPU
- @torch.no_grad()
- @torch.inference_mode()
- def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate, f0_condition, auto_f0_adjust, pitch_shift):
-     inference_module = model if not f0_condition else model_f0
-     mel_fn = to_mel if not f0_condition else to_mel_f0
-     bigvgan_fn = bigvgan_model if not f0_condition else bigvgan_44k_model
-     sr = 22050 if not f0_condition else 44100
-     hop_length = 256 if not f0_condition else 512
-     max_context_window = sr // hop_length * 30
-     overlap_wave_len = overlap_frame_len * hop_length
-     # Load audio
-     source_audio = librosa.load(source, sr=sr)[0]
-     ref_audio = librosa.load(target, sr=sr)[0]
-
-     # Process audio
-     source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(device)
-     ref_audio = torch.tensor(ref_audio[:sr * 25]).unsqueeze(0).float().to(device)
-
-     # Resample
-     ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
-     converted_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
-     # if source audio less than 30 seconds, whisper can handle in one forward
-     if converted_waves_16k.size(-1) <= 16000 * 30:
-         alt_inputs = whisper_feature_extractor([converted_waves_16k.squeeze(0).cpu().numpy()],
-             return_tensors="pt",
-             return_attention_mask=True,
-             sampling_rate=16000)
-         alt_input_features = whisper_model._mask_input_features(
-             alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(device)
-         alt_outputs = whisper_model.encoder(
-             alt_input_features.to(whisper_model.encoder.dtype),
-             head_mask=None,
-             output_attentions=False,
-             output_hidden_states=False,
-             return_dict=True,
-         )
-         S_alt = alt_outputs.last_hidden_state.to(torch.float32)
-         S_alt = S_alt[:, :converted_waves_16k.size(-1) // 320 + 1]
-     else:
-         overlapping_time = 5  # 5 seconds
-         S_alt_list = []
-         buffer = None
-         traversed_time = 0
-         while traversed_time < converted_waves_16k.size(-1):
-             if buffer is None:  # first chunk
-                 chunk = converted_waves_16k[:, traversed_time:traversed_time + 16000 * 30]
-             else:
-                 chunk = torch.cat([buffer, converted_waves_16k[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)]], dim=-1)
-             alt_inputs = whisper_feature_extractor([chunk.squeeze(0).cpu().numpy()],
-                 return_tensors="pt",
-                 return_attention_mask=True,
-                 sampling_rate=16000)
-             alt_input_features = whisper_model._mask_input_features(
-                 alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(device)
-             alt_outputs = whisper_model.encoder(
-                 alt_input_features.to(whisper_model.encoder.dtype),
-                 head_mask=None,
-                 output_attentions=False,
-                 output_hidden_states=False,
-                 return_dict=True,
-             )
-             S_alt = alt_outputs.last_hidden_state.to(torch.float32)
-             S_alt = S_alt[:, :chunk.size(-1) // 320 + 1]
-             if traversed_time == 0:
-                 S_alt_list.append(S_alt)
-             else:
-                 S_alt_list.append(S_alt[:, 50 * overlapping_time:])
-             buffer = chunk[:, -16000 * overlapping_time:]
-             traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time
-         S_alt = torch.cat(S_alt_list, dim=1)
-
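For sources longer than 30 s, the removed loop above feeds Whisper 30-second windows, carries the last 5 s of each window into the next, and drops the first 50 * overlapping_time encoder frames of every later window (the encoder produces roughly one frame per 320 samples at 16 kHz, i.e. about 50 frames per second), so no segment is encoded twice. A small pure-Python sketch of the same schedule, with illustrative names and a 70 s source assumed:

SR, WINDOW_S, OVERLAP_S = 16000, 30, 5

def chunk_schedule(total_samples):
    # Mirrors the traversed_time bookkeeping of the loop above.
    traversed, first = 0, True
    while traversed < total_samples:
        start = traversed if first else traversed - OVERLAP_S * SR
        end = min(start + WINDOW_S * SR, total_samples)
        yield start, end, 0 if first else 50 * OVERLAP_S  # encoder frames to drop
        traversed = end
        first = False

for start, end, dropped in chunk_schedule(70 * SR):
    print(start / SR, end / SR, dropped)
# 0.0 30.0 0    -> keep all frames of the first window
# 25.0 55.0 250 -> drop the 250 frames covering the repeated 5 s
# 50.0 70.0 250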
-     ori_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
-     ori_inputs = whisper_feature_extractor([ori_waves_16k.squeeze(0).cpu().numpy()],
-         return_tensors="pt",
-         return_attention_mask=True)
-     ori_input_features = whisper_model._mask_input_features(
-         ori_inputs.input_features, attention_mask=ori_inputs.attention_mask).to(device)
-     with torch.no_grad():
-         ori_outputs = whisper_model.encoder(
-             ori_input_features.to(whisper_model.encoder.dtype),
-             head_mask=None,
-             output_attentions=False,
-             output_hidden_states=False,
-             return_dict=True,
          )
-     S_ori = ori_outputs.last_hidden_state.to(torch.float32)
-     S_ori = S_ori[:, :ori_waves_16k.size(-1) // 320 + 1]
-
-     mel = mel_fn(source_audio.to(device).float())
-     mel2 = mel_fn(ref_audio.to(device).float())
-
-     target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
-     target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
-
-     feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k,
-         num_mel_bins=80,
-         dither=0,
-         sample_frequency=16000)
-     feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
-     style2 = campplus_model(feat2.unsqueeze(0))
-
-     if f0_condition:
-         F0_ori = rmvpe.infer_from_audio(ref_waves_16k[0], thred=0.5)
-         F0_alt = rmvpe.infer_from_audio(converted_waves_16k[0], thred=0.5)
-
-         F0_ori = torch.from_numpy(F0_ori).to(device)[None]
-         F0_alt = torch.from_numpy(F0_alt).to(device)[None]
-
-         voiced_F0_ori = F0_ori[F0_ori > 1]
-         voiced_F0_alt = F0_alt[F0_alt > 1]

-         log_f0_alt = torch.log(F0_alt + 1e-5)
-         voiced_log_f0_ori = torch.log(voiced_F0_ori + 1e-5)
-         voiced_log_f0_alt = torch.log(voiced_F0_alt + 1e-5)
-         median_log_f0_ori = torch.median(voiced_log_f0_ori)
-         median_log_f0_alt = torch.median(voiced_log_f0_alt)
-
-         # shift alt log f0 level to ori log f0 level
-         shifted_log_f0_alt = log_f0_alt.clone()
-         if auto_f0_adjust:
-             shifted_log_f0_alt[F0_alt > 1] = log_f0_alt[F0_alt > 1] - median_log_f0_alt + median_log_f0_ori
-         shifted_f0_alt = torch.exp(shifted_log_f0_alt)
-         if pitch_shift != 0:
-             shifted_f0_alt[F0_alt > 1] = adjust_f0_semitones(shifted_f0_alt[F0_alt > 1], pitch_shift)
-     else:
-         F0_ori = None
-         F0_alt = None
-         shifted_f0_alt = None
-
-     # Length regulation
-     cond, _, codes, commitment_loss, codebook_loss = inference_module.length_regulator(S_alt, ylens=target_lengths, n_quantizers=3, f0=shifted_f0_alt)
-     prompt_condition, _, codes, commitment_loss, codebook_loss = inference_module.length_regulator(S_ori, ylens=target2_lengths, n_quantizers=3, f0=F0_ori)
-
-     max_source_window = max_context_window - mel2.size(2)
-     # split source condition (cond) into chunks
-     processed_frames = 0
-     generated_wave_chunks = []
-     # generate chunk by chunk and stream the output
-     while processed_frames < cond.size(1):
-         chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
-         is_last_chunk = processed_frames + max_source_window >= cond.size(1)
-         cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
-         with torch.autocast(device_type='cuda', dtype=torch.float16):
-             # Voice Conversion
-             vc_target = inference_module.cfm.inference(cat_condition,
-                 torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
-                 mel2, style2, None, diffusion_steps,
-                 inference_cfg_rate=inference_cfg_rate)
-             vc_target = vc_target[:, :, mel2.size(-1):]
-         vc_wave = bigvgan_fn(vc_target.float())[0]
-         if processed_frames == 0:
-             if is_last_chunk:
-                 output_wave = vc_wave[0].cpu().numpy()
-                 generated_wave_chunks.append(output_wave)
-                 output_wave = (output_wave * 32768.0).astype(np.int16)
-                 mp3_bytes = AudioSegment(
-                     output_wave.tobytes(), frame_rate=sr,
-                     sample_width=output_wave.dtype.itemsize, channels=1
-                 ).export(format="mp3", bitrate=bitrate).read()
-                 yield mp3_bytes, (sr, np.concatenate(generated_wave_chunks))
-                 break
-             output_wave = vc_wave[0, :-overlap_wave_len].cpu().numpy()
-             generated_wave_chunks.append(output_wave)
-             previous_chunk = vc_wave[0, -overlap_wave_len:]
-             processed_frames += vc_target.size(2) - overlap_frame_len
-             output_wave = (output_wave * 32768.0).astype(np.int16)
-             mp3_bytes = AudioSegment(
-                 output_wave.tobytes(), frame_rate=sr,
-                 sample_width=output_wave.dtype.itemsize, channels=1
-             ).export(format="mp3", bitrate=bitrate).read()
-             yield mp3_bytes, None
-         elif is_last_chunk:
-             output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0].cpu().numpy(), overlap_wave_len)
-             generated_wave_chunks.append(output_wave)
-             processed_frames += vc_target.size(2) - overlap_frame_len
-             output_wave = (output_wave * 32768.0).astype(np.int16)
-             mp3_bytes = AudioSegment(
-                 output_wave.tobytes(), frame_rate=sr,
-                 sample_width=output_wave.dtype.itemsize, channels=1
-             ).export(format="mp3", bitrate=bitrate).read()
-             yield mp3_bytes, (sr, np.concatenate(generated_wave_chunks))
-             break
-         else:
-             output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0, :-overlap_wave_len].cpu().numpy(), overlap_wave_len)
-             generated_wave_chunks.append(output_wave)
-             previous_chunk = vc_wave[0, -overlap_wave_len:]
-             processed_frames += vc_target.size(2) - overlap_frame_len
-             output_wave = (output_wave * 32768.0).astype(np.int16)
-             mp3_bytes = AudioSegment(
-                 output_wave.tobytes(), frame_rate=sr,
-                 sample_width=output_wave.dtype.itemsize, channels=1
-             ).export(format="mp3", bitrate=bitrate).read()
-             yield mp3_bytes, None
-
-
- if __name__ == "__main__":
-     description = ("State-of-the-Art zero-shot voice conversion/singing voice conversion. For local deployment please check [GitHub repository](https://github.com/Plachtaa/seed-vc) "
          "for details and updates.<br>Note that any reference audio will be forcefully clipped to 25s if beyond this length.<br> "
          "If total duration of source and reference audio exceeds 30s, source audio will be processed in chunks.<br> "
-         "无需训练的 zero-shot 语音/歌声转换模型,若需本地部署查看[GitHub页面](https://github.com/Plachtaa/seed-vc)<br>"
          "请注意,参考音频若超过 25 秒,则会被自动裁剪至此长度。<br>若源音频和参考音频的总时长超过 30 秒,源音频将被分段处理。")
      inputs = [
          gr.Audio(type="filepath", label="Source Audio / 源音频"),
          gr.Audio(type="filepath", label="Reference Audio / 参考音频"),
-         gr.Slider(minimum=1, maximum=200, value=25, step=1, label="Diffusion Steps / 扩散步数", info="25 by default, 50~100 for best quality / 默认为 25,50~100 为最佳质量"),
-         gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust / 长度调整", info="<1.0 for speed-up speech, >1.0 for slow-down speech / <1.0 加速语速,>1.0 减慢语速"),
-         gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Inference CFG Rate", info="has subtle influence / 有微小影响"),
-         gr.Checkbox(label="Use F0 conditioned model / 启用F0输入", value=False, info="Must set to true for singing voice conversion / 歌声转换时必须勾选"),
-         gr.Checkbox(label="Auto F0 adjust / 自动F0调整", value=True,
-             info="Roughly adjust F0 to match target voice. Only works when F0 conditioned model is used. / 粗略调整 F0 以匹配目标音色,仅在勾选 '启用F0输入' 时生效"),
-         gr.Slider(label='Pitch shift / 音调变换', minimum=-24, maximum=24, step=1, value=0, info="Pitch shift in semitones, only works when F0 conditioned model is used / 半音数的音高变换,仅在勾选 '启用F0输入' 时生效"),
      ]

-     examples = [["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 25, 1.0, 0.7, False, True, 0],
-         ["examples/source/jay_0.wav", "examples/reference/azuma_0.wav", 25, 1.0, 0.7, False, True, 0],
-         ["examples/source/Wiz Khalifa,Charlie Puth - See You Again [vocals]_[cut_28sec].wav",
-          "examples/reference/kobe_0.wav", 50, 1.0, 0.7, True, False, -6],
-         ["examples/source/TECHNOPOLIS - 2085 [vocals]_[cut_14sec].wav",
-          "examples/reference/trump_0.wav", 50, 1.0, 0.7, True, False, -12],
-     ]
-
-     outputs = [gr.Audio(label="Stream Output Audio / 流式输出", streaming=True, format='mp3'),
-         gr.Audio(label="Full Output Audio / 完整输出", streaming=False, format='wav')]
-
-     gr.Interface(fn=voice_conversion,
-         description=description,
-         inputs=inputs,
-         outputs=outputs,
-         title="Seed Voice Conversion",
-         examples=examples,
-         cache_examples=False,
-     ).launch()

  import gradio as gr
  import torch
  import yaml
+ import os
+ import spaces  # provides the @spaces.GPU decorator used below
+ from huggingface_hub import hf_hub_download
+ # Assuming these are available in your Space's environment
+ # from seed_vc_wrapper import SeedVCWrapper
+ # from modules.v2.vc_wrapper import VoiceConversionWrapper
+
+ # --- CONFIGURATION (UPDATE YOUR_USERNAME HERE) ---
+ # Replace 'YOUR_USERNAME' with your actual Hugging Face username
+ MODEL_REPO_ID = "Bajiyo/dhanush_seedvc"
+ CFM_FILE = "CFM_epoch_00651_step_21500.pth"
+ AR_FILE = "AR_epoch_00651_step_21500.pth"
+ # -----------------------------------------------
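These constants are all the Space needs in order to fetch a different fine-tune: point MODEL_REPO_ID and the two filenames at another Hub repo. If that repo is private, hf_hub_download also accepts an access token; a minimal sketch, where the repo id and the HF_TOKEN secret name are placeholders and not part of this commit:

import os
from huggingface_hub import hf_hub_download

# Hypothetical private repo; the token would come from a Space secret.
ckpt = hf_hub_download(
    repo_id="your-username/your-private-seedvc",   # placeholder repo id
    filename="CFM_epoch_00651_step_21500.pth",
    token=os.environ.get("HF_TOKEN"),
)
print(ckpt)  # local path of the downloaded checkpoint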
+
+ if torch.cuda.is_available():
+     device = torch.device("cuda")
+ elif torch.backends.mps.is_available():
+     device = torch.device("mps")
+ else:
+     device = torch.device("cpu")
+
+ dtype = torch.float16
+
+ def load_models(args):
+     """
+     Loads models, handling checkpoint download from Hugging Face Hub.
+     """
+     # 1. Setup local directory and download checkpoints
+     LOCAL_CHECKPOINTS_DIR = "downloaded_checkpoints"
+     os.makedirs(LOCAL_CHECKPOINTS_DIR, exist_ok=True)
+     print(f"Downloading checkpoints from {MODEL_REPO_ID}...")
+
+     # Download CFM
+     cfm_local_path = hf_hub_download(
+         repo_id=MODEL_REPO_ID,
+         filename=CFM_FILE,
+         local_dir=LOCAL_CHECKPOINTS_DIR,
+         local_dir_use_symlinks=False
+     )
+     print(f"CFM checkpoint downloaded to: {cfm_local_path}")
+
+     # Download AR
+     ar_local_path = hf_hub_download(
+         repo_id=MODEL_REPO_ID,
+         filename=AR_FILE,
+         local_dir=LOCAL_CHECKPOINTS_DIR,
+         local_dir_use_symlinks=False
+     )
+     print(f"AR checkpoint downloaded to: {ar_local_path}")
+
+     # 2. Instantiate and load models
+     from hydra.utils import instantiate
+     from omegaconf import DictConfig
+
+     # Assuming 'configs/v2/vc_wrapper.yaml' is present in the Space repo
+     cfg = DictConfig(yaml.safe_load(open("configs/v2/vc_wrapper.yaml", "r")))
+     vc_wrapper = instantiate(cfg)
+
+     # Load the downloaded checkpoints
+     vc_wrapper.load_checkpoints(
+         ar_checkpoint_path=ar_local_path,
+         cfm_checkpoint_path=cfm_local_path
+     )
+     vc_wrapper.to(device)
+     vc_wrapper.eval()
+
+     vc_wrapper.setup_ar_caches(max_batch_size=1, max_seq_len=4096, dtype=dtype, device=device)
+
+     if args.compile:
+         # Standard torch compile settings
+         torch._inductor.config.coordinate_descent_tuning = True
+         torch._inductor.config.triton.unique_kernel_names = True
+
+         if hasattr(torch._inductor.config, "fx_graph_cache"):
+             torch._inductor.config.fx_graph_cache = True
+         vc_wrapper.compile_ar()
+         # vc_wrapper.compile_cfm()
+
+     return vc_wrapper
+
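load_models builds the wrapper from configs/v2/vc_wrapper.yaml through hydra.utils.instantiate, which looks up the _target_ key to find the class (or callable) to construct and passes the remaining keys as keyword arguments. A minimal, self-contained illustration of that mechanism; the toy config below is invented and is not the Space's actual YAML:

from hydra.utils import instantiate
from omegaconf import OmegaConf

# _target_ names any importable callable; the other keys become kwargs.
cfg = OmegaConf.create({"_target_": "collections.Counter", "red": 2, "blue": 1})
counter = instantiate(cfg)
print(counter)  # Counter({'red': 2, 'blue': 1})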
+ def main(args):
+     # load_models handles the download and initialization now
+     vc_wrapper = load_models(args)
+
+     # Define wrapper function for Gradio to ensure arguments are handled correctly
+     @spaces.GPU  # Ensures conversion runs on the specified GPU if available
+     def convert_voice_wrapper(source_audio_path, target_audio_path, diffusion_steps,
+                               length_adjust, intelligibility_cfg_rate, similarity_cfg_rate,
+                               top_p, temperature, repetition_penalty, convert_style,
+                               anonymization_only, stream_output=True):
+         """
+         Wrapper function for vc_wrapper.convert_voice_with_streaming that can be decorated.
+         """
+         # Ensure correct type for the stream_output argument if needed,
+         # though the main function is now calling convert_voice_with_streaming directly
+         yield from vc_wrapper.convert_voice_with_streaming(
+             source_audio_path=source_audio_path,
+             target_audio_path=target_audio_path,
+             diffusion_steps=diffusion_steps,
+             length_adjust=length_adjust,
+             intelligebility_cfg_rate=intelligibility_cfg_rate,
+             similarity_cfg_rate=similarity_cfg_rate,
+             top_p=top_p,
+             temperature=temperature,
+             repetition_penalty=repetition_penalty,
+             convert_style=convert_style,
+             anonymization_only=anonymization_only,
+             device=device,
+             dtype=dtype,
+             stream_output=stream_output
          )

+     # Set up Gradio interface
+     description = ("Zero-shot voice conversion with in-context learning. For local deployment please check [GitHub repository](https://github.com/Plachtaa/seed-vc) "
          "for details and updates.<br>Note that any reference audio will be forcefully clipped to 25s if beyond this length.<br> "
          "If total duration of source and reference audio exceeds 30s, source audio will be processed in chunks.<br> "
+         "无需训练的 zero-shot 语音/歌声转换模型,若需本地部署查看[GitHub页面](https://github.com/Plachtaa/seed-vc)<br>"
          "请注意,参考音频若超过 25 秒,则会被自动裁剪至此长度。<br>若源音频和参考音频的总时长超过 30 秒,源音频将被分段处理。")
+
      inputs = [
          gr.Audio(type="filepath", label="Source Audio / 源音频"),
          gr.Audio(type="filepath", label="Reference Audio / 参考音频"),
+         gr.Slider(minimum=1, maximum=200, value=30, step=1, label="Diffusion Steps / 扩散步数",
+             info="30 by default, 50~100 for best quality / 默认为 30,50~100 为最佳质量"),
+         gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust / 长度调整",
+             info="<1.0 for speed-up speech, >1.0 for slow-down speech / <1.0 加速语速,>1.0 减慢语速"),
+         gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.5, label="Intelligibility CFG Rate",
+             info="controls pronunciation intelligibility / 控制发音清晰度"),
+         gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.5, label="Similarity CFG Rate",
+             info="controls similarity to reference audio / 控制与参考音频的相似度"),
+         gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.9, label="Top-p",
+             info="Controls diversity of generated audio / 控制生成音频的多样性"),
+         gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Temperature",
+             info="Controls randomness of generated audio / 控制生成音频的随机性"),
+         gr.Slider(minimum=1.0, maximum=3.0, step=0.1, value=1.0, label="Repetition Penalty",
+             info="Penalizes repetition in generated audio / 惩罚生成音频中的重复"),
+         gr.Checkbox(label="convert style", value=False),
+         gr.Checkbox(label="anonymization only", value=False),
      ]
+
+     examples = [
+         ["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 50, 1.0, 0.5, 0.5, 0.9, 1.0, 1.0, False, False],
+         ["examples/source/jay_0.wav", "examples/reference/azuma_0.wav", 50, 1.0, 0.5, 0.5, 0.9, 1.0, 1.0, False, False],
+     ]
+
+     outputs = [
+         gr.Audio(label="Stream Output Audio / 流式输出", streaming=True, format='mp3'),
+         gr.Audio(label="Full Output Audio / 完整输出", streaming=False, format='wav')
+     ]
+
+     # Launch the Gradio interface
+     gr.Interface(
+         fn=convert_voice_wrapper,
+         description=description,
+         inputs=inputs,
+         outputs=outputs,
+         title="Seed Voice Conversion V2",
+         examples=examples,
+         cache_examples=False,
+     ).queue().launch(share=False)  # Changed share=True to share=False for Spaces deployment
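Because convert_voice_wrapper is a generator, the Interface streams whatever it yields to the two outputs as each chunk is produced, instead of waiting for the full conversion. A minimal, self-contained sketch of that generator-streaming pattern in Gradio, using a Textbox in place of the audio pipeline for brevity:

import time
import gradio as gr

def count_up(n):
    # Each yield pushes a partial update to the output component.
    text = ""
    for i in range(int(n)):
        text += f"{i} "
        time.sleep(0.2)
        yield text

demo = gr.Interface(fn=count_up, inputs=gr.Number(value=5), outputs=gr.Textbox())
# demo.launch()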

+ if __name__ == "__main__":
+     import argparse
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--compile", action="store_true", help="Compile the model using torch.compile")
+     # These arguments are now effectively ignored/not needed since we download the models
+     # but we keep them to maintain compatibility with the original script structure.
+     parser.add_argument("--ar-checkpoint-path", type=str, default=None,
+         help="Path to custom checkpoint file (overridden by HF download in Space)")
+     parser.add_argument("--cfm-checkpoint-path", type=str, default=None,
+         help="Path to custom checkpoint file (overridden by HF download in Space)")
+     args = parser.parse_args()
+     main(args)
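Run outside the Space, the script behaves the same way: `python app.py` starts the interface eagerly, `python app.py --compile` enables the torch.compile path in load_models, and the two checkpoint-path flags are accepted but overridden by the Hub download. A tiny standalone sketch of how the flag parses:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--compile", action="store_true")
print(parser.parse_args([]))             # Namespace(compile=False)
print(parser.parse_args(["--compile"]))  # Namespace(compile=True)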