Update modules/v2/vc_wrapper.py
modules/v2/vc_wrapper.py  +21 -80
@@ -1,3 +1,4 @@
+import spaces
 import torch
 import librosa
 import torchaudio
@@ -52,56 +53,6 @@ class VoiceConversionWrapper(torch.nn.Module):
         self.ar_max_content_len = 1500  # in num of narrow tokens
         self.compile_len = 87 * self.dit_max_context_len
 
-    def forward_cfm(self, content_indices_wide, content_lens, mels, mel_lens, style_vectors):
-        device = content_indices_wide.device
-        B = content_indices_wide.size(0)
-        cond, _ = self.cfm_length_regulator(content_indices_wide, ylens=mel_lens)
-
-        # randomly set a length as prompt
-        prompt_len_max = mel_lens - 1
-        prompt_len = (torch.rand([B], device=device) * prompt_len_max).floor().to(dtype=torch.long)
-        prompt_len[torch.rand([B], device=device) < 0.1] = 0
-
-        loss = self.cfm(mels, mel_lens, prompt_len, cond, style_vectors)
-        return loss
-
-    def forward_ar(self, content_indices_narrow, content_indices_wide, content_lens):
-        device = content_indices_narrow.device
-        duration_reduced_narrow_tokens = []
-        duration_reduced_narrow_lens = []
-        for bib in range(content_indices_narrow.size(0)):
-            reduced, reduced_len = self.duration_reduction_func(content_indices_narrow[bib])
-            duration_reduced_narrow_tokens.append(reduced)
-            duration_reduced_narrow_lens.append(reduced_len)
-        duration_reduced_narrow_tokens = torch.nn.utils.rnn.pad_sequence(duration_reduced_narrow_tokens,
-                                                                         batch_first=True, padding_value=0).to(device)
-        duration_reduced_narrow_lens = torch.LongTensor(duration_reduced_narrow_lens).to(device)
-
-        # interpolate speech token to match acoustic feature length
-        cond, _ = self.ar_length_regulator(duration_reduced_narrow_tokens)
-        loss = self.ar(cond, duration_reduced_narrow_lens, content_indices_wide, content_lens)
-        return loss
-
-    def forward(self, waves_16k, mels, wave_lens_16k, mel_lens, forward_ar=False, forward_cfm=True):
-        """
-        Forward pass for the model.
-        """
-        # extract wide content features as both AR and CFM models use them
-        with torch.no_grad():
-            _, content_indices_wide, content_lens = self.content_extractor_wide(waves_16k, wave_lens_16k)
-        if forward_ar:
-            # extract narrow content features for AR model
-            _, content_indices_narrow, _ = self.content_extractor_narrow(waves_16k, wave_lens_16k, ssl_model=self.content_extractor_wide.ssl_model)
-            loss_ar = self.forward_ar(content_indices_narrow.clone(), content_indices_wide.clone(), content_lens)
-        else:
-            loss_ar = torch.tensor(0.0, device=waves_16k.device, dtype=waves_16k.dtype)
-        if forward_cfm:
-            style_vectors = self.compute_style(waves_16k, wave_lens_16k)
-            loss_cfm = self.forward_cfm(content_indices_wide, content_lens, mels, mel_lens, style_vectors)
-        else:
-            loss_cfm = torch.tensor(0.0, device=waves_16k.device, dtype=waves_16k.dtype)
-        return loss_ar, loss_cfm
-
     def compile_ar(self):
         """
         Compile the AR model for inference.
@@ -258,28 +209,24 @@ class VoiceConversionWrapper(torch.nn.Module):
                 repo_id=DEFAULT_REPO_ID,
                 model_filename=DEFAULT_CFM_CHECKPOINT,
             )
-        else:
-            print(f"Loading CFM checkpoint from {cfm_checkpoint_path}...")
         if ar_checkpoint_path is None:
             ar_checkpoint_path = load_custom_model_from_hf(
                 repo_id=DEFAULT_REPO_ID,
                 model_filename=DEFAULT_AR_CHECKPOINT,
             )
-        else:
-            print(f"Loading AR checkpoint from {ar_checkpoint_path}...")
         # cfm
         cfm_checkpoint = torch.load(cfm_checkpoint_path, map_location="cpu")
         cfm_length_regulator_state_dict = self.strip_prefix(cfm_checkpoint["net"]['length_regulator'], "module.")
         cfm_state_dict = self.strip_prefix(cfm_checkpoint["net"]['cfm'], "module.")
-
-
+        self.cfm.load_state_dict(cfm_state_dict, strict=False)
+        self.cfm_length_regulator.load_state_dict(cfm_length_regulator_state_dict, strict=False)
 
         # ar
         ar_checkpoint = torch.load(ar_checkpoint_path, map_location="cpu")
         ar_length_regulator_state_dict = self.strip_prefix(ar_checkpoint["net"]['length_regulator'], "module.")
         ar_state_dict = self.strip_prefix(ar_checkpoint["net"]['ar'], "module.")
-
-
+        self.ar.load_state_dict(ar_state_dict, strict=False)
+        self.ar_length_regulator.load_state_dict(ar_length_regulator_state_dict, strict=False)
 
         # content extractor
         content_extractor_narrow_checkpoint_path = load_custom_model_from_hf(
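The new load_state_dict calls pass strict=False, which tolerates missing or unexpected keys once the "module." prefix left by DataParallel/DDP training has been stripped. A minimal, self-contained sketch of that loading pattern (the model and state dict here are illustrative stand-ins, not the wrapper's real checkpoints):

    import torch

    model = torch.nn.Linear(4, 4)                                        # hypothetical stand-in model
    raw = {"module.weight": torch.zeros(4, 4), "module.bias": torch.zeros(4)}
    state_dict = {k.removeprefix("module."): v for k, v in raw.items()}  # strip the DDP prefix
    # strict=False skips mismatched keys and reports them instead of raising
    result = model.load_state_dict(state_dict, strict=False)
    print("missing:", result.missing_keys, "unexpected:", result.unexpected_keys)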
@@ -308,26 +255,13 @@ class VoiceConversionWrapper(torch.nn.Module):
     def setup_ar_caches(self, max_batch_size=1, max_seq_len=4096, dtype=torch.float32, device=torch.device("cpu")):
         self.ar.setup_caches(max_batch_size=max_batch_size, max_seq_len=max_seq_len, dtype=dtype, device=device)
 
-
-
-
-
-
-
-
-                                                     num_mel_bins=80,
-                                                     dither=0,
-                                                     sample_frequency=16000)
-            feat = feat - feat.mean(dim=0, keepdim=True)
-            feat_list.append(feat)
-        max_feat_len = max([feat.size(0) for feat in feat_list])
-        feat_lens = torch.tensor([feat.size(0) for feat in feat_list], dtype=torch.int32).to(waves_16k.device) // 2
-        feat_list = [
-            torch.nn.functional.pad(feat, (0, 0, 0, max_feat_len - feat.size(0)), value=float(feat.min().item()))
-            for feat in feat_list
-        ]
-        feat = torch.stack(feat_list, dim=0)
-        style = self.style_encoder(feat, feat_lens)
+    def compute_style(self, waves_16k: torch.Tensor):
+        feat = torchaudio.compliance.kaldi.fbank(waves_16k,
+                                                 num_mel_bins=80,
+                                                 dither=0,
+                                                 sample_frequency=16000)
+        feat = feat - feat.mean(dim=0, keepdim=True)
+        style = self.style_encoder(feat.unsqueeze(0))
         return style
 
     @torch.no_grad()
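The rewritten compute_style reduces to a Kaldi-style filterbank followed by per-utterance mean normalization on a single waveform. A standalone sketch of that feature step (the style encoder itself is left out; the function name and shapes are illustrative, not part of the module):

    import torch
    import torchaudio

    def extract_style_features(wave_16k: torch.Tensor) -> torch.Tensor:
        # wave_16k: (1, num_samples) mono waveform already resampled to 16 kHz
        feat = torchaudio.compliance.kaldi.fbank(wave_16k,
                                                 num_mel_bins=80,
                                                 dither=0,
                                                 sample_frequency=16000)   # (num_frames, 80)
        feat = feat - feat.mean(dim=0, keepdim=True)   # per-utterance mean normalization
        return feat.unsqueeze(0)                       # (1, num_frames, 80), fed to a style encoder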
@@ -490,6 +424,7 @@ class VoiceConversionWrapper(torch.nn.Module):
 
         return content_indices
 
+    @spaces.GPU
     @torch.no_grad()
     @torch.inference_mode()
     def convert_voice_with_streaming(
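@spaces.GPU comes from Hugging Face's spaces package; on a ZeroGPU Space it allocates a GPU only while the decorated function (here the streaming conversion method) is running. A minimal sketch of the pattern, assuming a Spaces/ZeroGPU runtime:

    import spaces
    import torch

    @spaces.GPU  # request a ZeroGPU device for the duration of this call
    def denoise(x: torch.Tensor) -> torch.Tensor:
        # hypothetical GPU-bound step; any CUDA work inside the call is covered
        return x.to("cuda").clamp_(min=0).cpu()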
@@ -623,7 +558,10 @@ class VoiceConversionWrapper(torch.nn.Module):
 
                 if stream_output and mp3_bytes is not None:
                     yield mp3_bytes, full_audio
+
                 if should_break:
+                    if not stream_output:
+                        return full_audio
                     break
             else:
                 cond, _ = self.cfm_length_regulator(source_content_indices, ylens=torch.LongTensor([source_mel_len]).to(device))
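Because convert_voice_with_streaming is a generator (it yields MP3 chunks while streaming), the added return full_audio does not reach a caller as an ordinary return value; in Python 3 it becomes the value carried by StopIteration. A caller wanting the non-streaming result would have to drain the generator along these lines (a sketch against the signature shown in this diff):

    def collect_full_audio(vc_generator):
        # Drain a generator whose `return full_audio` rides on StopIteration.value.
        try:
            while True:
                next(vc_generator)
        except StopIteration as stop:
            return stop.value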
@@ -641,7 +579,7 @@ class VoiceConversionWrapper(torch.nn.Module):
             if self.dit_compiled:
                 cat_condition = torch.nn.functional.pad(cat_condition,
                                                         (0, 0, 0, self.compile_len - cat_condition.size(1),), value=0)
-            with torch.autocast(device_type=device.type, dtype=
+            with torch.autocast(device_type=device.type, dtype=dtype):
                 # Voice Conversion
                 vc_mel = self.cfm.inference(
                     cat_condition,
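The change replaces the previous autocast line (shown truncated in this view) with one that passes dtype=dtype explicitly, so the CFM inference call runs under mixed precision. A small self-contained sketch of the same pattern; the dtype choice here is an assumption, not the module's actual setting:

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16 if device.type == "cuda" else torch.bfloat16   # assumed precision choice
    with torch.autocast(device_type=device.type, dtype=dtype):
        a = torch.randn(8, 8, device=device)
        b = torch.randn(8, 8, device=device)
        y = a @ b   # matmul runs in the autocast dtype where supported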
@@ -660,5 +598,8 @@ class VoiceConversionWrapper(torch.nn.Module):
 
         if stream_output and mp3_bytes is not None:
             yield mp3_bytes, full_audio
+
         if should_break:
-
+            if not stream_output:
+                return full_audio
+            break