# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torchaudio


class ValleInference(torch.nn.Module):
    """End-to-end VALL-E inference wrapper.

    Combines a pretrained autoregressive (AR) model, a non-autoregressive
    (NAR) model, and a neural codec (SpeechTokenizer or EnCodec, optionally
    decoded with Vocos) to synthesize speech from phone ids and an acoustic
    prompt.
    """

    def __init__(
        self,
        use_vocos=False,
        use_speechtokenizer=True,
        ar_path=None,
        nar_path=None,
        speechtokenizer_path=None,
        device="cuda",
    ):
        super().__init__()
        self.device = device
        # prepare the pretrained VALL-E AR model
        from .valle_ar import ValleAR

        self.ar_model = ValleAR(
            phone_vocab_size=300,
            target_vocab_size=1024,
            pad_token_id=1324,
            bos_target_id=1325,
            eos_target_id=1326,
            bos_phone_id=1327,
            eos_phone_id=1328,
            bos_prompt_id=1329,
            eos_prompt_id=1330,
            num_hidden_layers=16,
        )
        # ar_path must point to your trained AR model checkpoint
        assert ar_path is not None
        self.ar_model.load_state_dict(torch.load(ar_path, map_location="cpu"))
        self.ar_model.eval().to(self.device)
        # prepare the pretrained VALL-E NAR model
        from .valle_nar import ValleNAR

        self.nar_model = ValleNAR(
            phone_vocab_size=300,
            target_vocab_size=1024,
            pad_token_id=1324,
            bos_target_id=1325,
            eos_target_id=1326,
            bos_phone_id=1327,
            eos_phone_id=1328,
            bos_prompt_id=1329,
            eos_prompt_id=1330,
            num_hidden_layers=16,
        )
        assert nar_path is not None
        self.nar_model.load_state_dict(torch.load(nar_path, map_location="cpu"))
        self.nar_model.eval().to(self.device)
        # prepare the codec encoder
        assert not (
            use_speechtokenizer and use_vocos
        ), "Only one of use_speechtokenizer and use_vocos can be True"
        self.use_speechtokenizer = use_speechtokenizer
        if use_speechtokenizer:
            from models.codec.speechtokenizer.model import SpeechTokenizer

            # download from https://huggingface.co/fnlp/SpeechTokenizer/tree/main/speechtokenizer_hubert_avg
            config_path = speechtokenizer_path + "/config.json"
            ckpt_path = speechtokenizer_path + "/SpeechTokenizer.pt"
            self.codec_encoder = SpeechTokenizer.load_from_checkpoint(
                config_path, ckpt_path
            )
            self.codec_encoder.eval()
            self.codec_encoder.to(device)
            print(f"Loaded SpeechTokenizer from {config_path} and {ckpt_path}")
        else:
            # use Encodec
            from encodec import EncodecModel

            self.codec_encoder = EncodecModel.encodec_model_24khz()
            self.codec_encoder.set_target_bandwidth(6.0)
            self.codec_encoder.to(self.device)
            if use_vocos:
                from vocos import Vocos

                self.codec_decoder = Vocos.from_pretrained(
                    "charactr/vocos-encodec-24khz"
                )
                self.codec_decoder.to(self.device)
                print("Loaded Vocos")
            print("Loaded EncodecModel")
        self.use_vocos = use_vocos

    def decode(self, vq_ids):
        """Decode codec token ids back to a waveform.

        vq_ids shape: [n_q, B, T] (n_q = 8 codebooks); returns: [B, 1, T].
        """
        if self.use_speechtokenizer:
            # decode with SpeechTokenizer
            return self.codec_encoder.decode(vq_ids)  # [B, 1, T]
        else:
            if not self.use_vocos:
                # decode with the EnCodec decoder
                return self.codec_encoder.decode([(vq_ids.transpose(0, 1), None)])
            else:
                # decode with the Vocos decoder
                features = self.codec_decoder.codes_to_features(vq_ids.squeeze(1))
                # index 2 corresponds to the 6 kbps bandwidth set at encode time
                bandwidth_id = torch.tensor([2], device=vq_ids.device)
                return self.codec_decoder.decode(
                    features, bandwidth_id=bandwidth_id
                ).unsqueeze(0)

    def forward(self, batch, chunk_configs: list, return_prompt=False, prompt_len=None):
        """Synthesize speech for a batch.

        batch: dict(
            speech: [B, T] prompt waveform
            phone_ids: [B, T] phone token ids
        )
        chunk_configs: list of sampling configs; each dict provides the keys
            top_p, top_k, temperature, num_beams, repeat_penalty, max_length.
        return_prompt: if True, prepend the prompt codes to the generated codes.
        prompt_len: number of codec frames of the prompt to condition on.

        returns: [B, 1, T] audio
        """
        if prompt_len is None:
            prompt_len = 100000  # no prompt length limit
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(self.device)
        with torch.no_grad():
            if self.use_speechtokenizer:
                vq_id = self.codec_encoder.encode(
                    batch["speech"].unsqueeze(1)
                )  # [B, 1, T] -> (n_q, B, T)
            else:
                vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1))
                vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose(
                    0, 1
                )  # list of (codes, scale) frames -> (n_q, B, T)

            # Typically only one config is needed, but multiple configs can be
            # passed to, for example, use different sampling temperatures at
            # different positions.
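            # A minimal sketch of the expected format (illustrative values only;
            # all six keys are read from each chunk below):
            # chunk_configs = [
            #     dict(top_p=0.9, top_k=1024, temperature=0.95,
            #          num_beams=1, repeat_penalty=1.0, max_length=2000),
            # ]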
            for chunk in chunk_configs:
                ar_vq_ids = self.ar_model.sample_hf(
                    batch["phone_ids"],
                    vq_id[0, :, :prompt_len],  # first-codebook prompt tokens
                    top_p=chunk["top_p"],
                    top_k=chunk["top_k"],
                    temperature=chunk["temperature"],
                    num_beams=chunk["num_beams"],
                    repeat_penalty=chunk["repeat_penalty"],
                    max_length=chunk["max_length"],
                )
                # recovered_audio_ar = self.decode(ar_vq_ids.unsqueeze(0))
                # torchaudio.save('recovered_audio_ar.wav', recovered_audio_ar[0].cpu(), 24000)

                nar_vq_ids = self.nar_model.sample_hf(
                    phone_ids=batch["phone_ids"],
                    prompt_ids=vq_id[:, :, :prompt_len],
                    first_stage_ids=ar_vq_ids,
                    # first_stage_ids=vq_id[0, :, prompt_len:],
                )

                if return_prompt:
                    nar_vq_ids = torch.cat(
                        [vq_id[..., :prompt_len], nar_vq_ids], dim=-1
                    )

                recovered_audio = self.decode(nar_vq_ids)
                return recovered_audio  # [B, 1, T]
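

if __name__ == "__main__":
    # Minimal usage sketch, assuming trained AR/NAR checkpoints and
    # SpeechTokenizer weights exist at the placeholder paths below. Real
    # phone_ids must come from the same G2P/tokenizer used at training time;
    # the random ids here only exercise the pipeline end to end.
    inference = ValleInference(
        use_speechtokenizer=True,
        ar_path="ckpts/valle_ar.bin",  # placeholder path
        nar_path="ckpts/valle_nar.bin",  # placeholder path
        speechtokenizer_path="ckpts/speechtokenizer_hubert_avg",  # placeholder path
        device="cuda",
    )
    batch = {
        "speech": torch.randn(1, 48000),  # [B, T] prompt waveform (placeholder)
        "phone_ids": torch.randint(0, 300, (1, 64)),  # [B, T] phone ids (placeholder)
    }
    chunk_configs = [
        dict(
            top_p=0.9,
            top_k=1024,
            temperature=0.95,
            num_beams=1,
            repeat_penalty=1.0,
            max_length=2000,
        )
    ]
    audio = inference(batch, chunk_configs, prompt_len=150)  # [B, 1, T]
    torchaudio.save("output.wav", audio[0].cpu(), 16000)  # SpeechTokenizer codec is 16 kHz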