# Web file-view header (not part of the program):
# author: yuancwang | commit: "init" (b725c5a) | size: 3.81 kB
import argparse
import os
import torch
import soundfile as sf
import numpy as np
from models.tts.naturalspeech2.ns2 import NaturalSpeech2
from encodec import EncodecModel
from encodec.utils import convert_audio
from utils.util import load_config
from text import text_to_sequence
from text.cmudict import valid_symbols
from text.g2p import preprocess_english, read_lexicon
import torchaudio
class NS2Inference:
    """Single-utterance NaturalSpeech2 inference driver.

    Builds the NaturalSpeech2 model from a checkpoint and an EnCodec 24 kHz
    codec, encodes a reference recording into codec codes, converts the input
    text to phoneme ids, runs the model and writes the decoded waveform to a
    ``.wav`` file.

    Args:
        args: Parsed command-line namespace. The attributes read by this
            class are ``checkpoint_path``, ``device``, ``ref_audio``,
            ``text``, ``inference_step`` and ``output_dir``. Only
            ``ref_audio``, ``device`` and ``inference_step`` are registered
            by :meth:`add_arguments`; the others must be registered by the
            caller.
        cfg: Loaded configuration object. ``cfg.model`` and
            ``cfg.preprocess.lexicon_path`` are read.
    """

    def __init__(self, args, cfg):
        self.cfg = cfg
        self.args = args

        self.model = self.build_model()
        self.codec = self.build_codec()

        # Ids are assigned by position in this list, so reordering it
        # changes every phoneme id handed to the model.
        self.symbols = valid_symbols + ["sp", "spn", "sil"] + ["<s>", "</s>"]
        self.phone2id = {s: i for i, s in enumerate(self.symbols)}
        self.id2phone = {i: s for s, i in self.phone2id.items()}

    def build_model(self):
        """Create NaturalSpeech2 and load ``pytorch_model.bin`` into it.

        Returns:
            The model, moved to ``args.device`` and switched to eval mode.
        """
        checkpoint_file = os.path.join(
            self.args.checkpoint_path, "pytorch_model.bin"
        )
        model = NaturalSpeech2(self.cfg.model)
        # NOTE(review): torch.load unpickles the file, so only checkpoints
        # from a trusted source should be passed via --checkpoint_path.
        model.load_state_dict(torch.load(checkpoint_file, map_location="cpu"))
        model = model.to(self.args.device)
        # This class only synthesizes: put train-mode-dependent layers
        # (e.g. dropout) into their inference behaviour.
        model.eval()
        return model

    def build_codec(self):
        """Create the EnCodec 24 kHz codec on ``args.device``.

        Returns:
            The codec with its target bandwidth set to 12.0.
        """
        encodec_model = EncodecModel.encodec_model_24khz()
        encodec_model = encodec_model.to(device=self.args.device)
        encodec_model.set_target_bandwidth(12.0)
        return encodec_model

    def get_ref_code(self):
        """Encode the reference audio into codec codes.

        Returns:
            A ``(ref_code, ref_mask)`` pair. ``ref_code`` is the codes of
            all encoded frames concatenated along the last axis;
            ``ref_mask`` is an all-ones tensor of shape
            ``(ref_code.shape[0], ref_code.shape[-1])`` on the same device.

        Raises:
            ValueError: If ``args.ref_audio`` is empty (the argparse
                default), instead of failing inside the audio loader.
        """
        ref_wav_path = self.args.ref_audio
        if not ref_wav_path:
            raise ValueError(
                "--ref_audio is required: no reference audio path was given"
            )

        ref_wav, sr = torchaudio.load(ref_wav_path)
        # Resample / remix to what the codec expects.
        ref_wav = convert_audio(
            ref_wav, sr, self.codec.sample_rate, self.codec.channels
        )
        # Add a leading batch axis.
        ref_wav = ref_wav.unsqueeze(0).to(device=self.args.device)

        with torch.no_grad():
            encoded_frames = self.codec.encode(ref_wav)
            # Element 0 of each frame entry holds the codes.
            ref_code = torch.cat([frame[0] for frame in encoded_frames], dim=-1)

        ref_mask = torch.ones(
            ref_code.shape[0], ref_code.shape[-1], device=ref_code.device
        )
        return ref_code, ref_mask

    def _phones_to_ids(self, phone_seq):
        """Convert a phoneme string into a ``(1, T)`` int64 id tensor.

        Braces are stripped and the string is split on whitespace; every
        token is then looked up in ``self.phone2id``.

        Args:
            phone_seq: Whitespace-separated phoneme symbols, optionally
                wrapped in ``{`` / ``}``.

        Returns:
            A tensor of phoneme ids on ``args.device`` with a leading batch
            axis of size 1.

        Raises:
            ValueError: If the string contains no phonemes, or contains
                symbols that are not in the vocabulary. (A plain ``.get``
                lookup would yield ``None`` ids and fail later with an
                unrelated dtype error.)
        """
        phones = phone_seq.replace("{", "").replace("}", "").split()
        if not phones:
            raise ValueError("input text produced an empty phoneme sequence")

        unknown = sorted({p for p in phones if p not in self.phone2id})
        if unknown:
            raise ValueError(
                "phonemes missing from the symbol table: {}".format(
                    ", ".join(unknown)
                )
            )

        ids = np.array([self.phone2id[p] for p in phones], dtype=np.int64)
        return torch.from_numpy(ids).unsqueeze(0).to(device=self.args.device)

    def inference(self):
        """Synthesize ``args.text`` in the reference voice and save it.

        The waveform is written to ``<output_dir>/<text>.wav``, where spaces
        in the text are replaced by underscores (at most 100 replacements).
        Intermediate values (phoneme string, ids, predicted durations) are
        printed to stdout.
        """
        ref_code, ref_mask = self.get_ref_code()

        lexicon = read_lexicon(self.cfg.preprocess.lexicon_path)
        phone_seq = preprocess_english(self.args.text, lexicon)
        print(phone_seq)

        phone_id = self._phones_to_ids(phone_seq)
        print(phone_id)

        # No gradients are needed for synthesis; avoid building a graph.
        with torch.no_grad():
            x0, prior_out = self.model.inference(
                ref_code, phone_id, ref_mask, self.args.inference_step
            )
            rec_wav = self.codec.decoder(x0)

        print(prior_out["dur_pred"])
        print(prior_out["dur_pred_round"])
        print(torch.sum(prior_out["dur_pred_round"]))

        os.makedirs(self.args.output_dir, exist_ok=True)

        sf.write(
            "{}/{}.wav".format(
                self.args.output_dir, self.args.text.replace(" ", "_", 100)
            ),
            rec_wav[0, 0].detach().cpu().numpy(),
            # Write at the codec's own rate rather than a separate literal.
            samplerate=self.codec.sample_rate,
        )

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser):
        """Register the options specific to this inference class.

        Adds ``--ref_audio`` (str, default ``""``), ``--device`` (str,
        default ``"cuda"``) and ``--inference_step`` (int, default 200).

        Args:
            parser: The parser to extend in place.
        """
        parser.add_argument(
            "--ref_audio",
            type=str,
            default="",
            help="Reference audio path",
        )
        parser.add_argument(
            "--device",
            type=str,
            default="cuda",
        )
        parser.add_argument(
            "--inference_step",
            type=int,
            default=200,
            help="Total inference steps for the diffusion model",
        )


# NOTE(review): indentation was lost in the reviewed copy of this file, so it
# is unclear whether add_arguments was a class member or a module function.
# The alias keeps both NS2Inference.add_arguments(parser) and
# add_arguments(parser) working.
add_arguments = NS2Inference.add_arguments