# Scraped page header (not part of the original source):
#   NaturalSpeech2 / models/tts/valle/valle_inference.py
#   commit b725c5a "init" by yuancwang; 8.37 kB; scan result: no virus
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
import torch
import torchaudio
import argparse
from text.g2p_module import G2PModule
from utils.tokenizer import AudioTokenizer, tokenize_audio
from models.tts.valle.valle import VALLE
from models.tts.base.tts_inferece import TTSInference
from models.tts.valle.valle_dataset import VALLETestDataset, VALLETestCollator
from processors.phone_extractor import phoneExtractor
from text.text_token_collation import phoneIDCollation
class VALLEInference(TTSInference):
    """VALL-E inference driver built on ``TTSInference``.

    Converts text to a phone-id sequence, tokenizes an audio prompt with
    ``AudioTokenizer``, runs ``VALLE.inference`` (or ``VALLE.continual`` when
    ``--continual`` is set) and decodes the predicted codes back to audio.

    ``self.cfg``, ``self.args``, ``self.model``, ``self.device``,
    ``self.exp_dir`` and ``self.infer_type`` are not set in this class;
    presumably they are provided by ``TTSInference`` (confirm there).
    """

    def __init__(self, args=None, cfg=None):
        TTSInference.__init__(self, args, cfg)
        # NOTE: not used by the phone pipeline below (which goes through
        # ``phoneExtractor``); kept as a public attribute for compatibility.
        self.g2p_module = G2PModule(backend=self.cfg.preprocess.phone_extractor)
        self.audio_tokenizer = AudioTokenizer()

    def _build_model(self):
        """Instantiate the VALL-E network from ``self.cfg.model``."""
        return VALLE(self.cfg.model)

    def _build_test_dataset(self):
        """Return the (dataset class, collator class) pair used for testing."""
        return VALLETestDataset, VALLETestCollator

    def inference_one_clip(self, text, text_prompt, audio_file, save_name="pred"):
        """Synthesize one utterance conditioned on a text/audio prompt pair.

        Args:
            text: Target text to synthesize ("" in continual mode).
            text_prompt: Transcript of ``audio_file``. It is prepended to
                ``text``; when non-empty, its phone count is also passed to
                ``VALLE.inference`` as ``enroll_x_lens`` (otherwise None).
            audio_file: Path of the audio prompt.
            save_name: File-name stem for the optional ``--copysyn`` output.

        Returns:
            The decoded waveform ``samples[0].squeeze(0)``, detached, on CPU.
        """
        # A phone symbol dictionary is required unless phones come from a
        # lexicon.
        phone_symbol_file = None
        if self.cfg.preprocess.phone_extractor != "lexicon":
            phone_symbol_file = os.path.join(
                self.exp_dir, self.cfg.preprocess.symbols_dict
            )
            assert os.path.exists(
                phone_symbol_file
            ), f"phone symbol file not found: {phone_symbol_file}"

        phone_extractor = phoneExtractor(self.cfg)
        phon_id_collator = phoneIDCollation(
            self.cfg, symbols_dict_file=phone_symbol_file
        )

        def to_phone_ids(raw_text):
            # text -> phone sequence (list) -> phone id sequence
            phone_seq = phone_extractor.extract_phone(raw_text)
            return phon_id_collator.get_phone_id_sequence(self.cfg, phone_seq)

        # The model receives prompt + target phones as a single sequence.
        phone_id_seq = to_phone_ids(f"{text_prompt} {text}".strip())
        phone_id_seq_len = torch.IntTensor([len(phone_id_seq)]).to(self.device)
        # Wrapping in a list adds a leading batch dimension of size 1.
        phone_id_seq = torch.from_numpy(np.array([phone_id_seq])).to(self.device)

        # Discrete acoustic tokens of the audio prompt.
        prompt_frames = tokenize_audio(self.audio_tokenizer, audio_file)
        audio_prompt_token = prompt_frames[0][0].transpose(2, 1).to(self.device)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            if self.args.copysyn:
                # Copy-synthesis: decode the prompt tokens straight back to
                # audio with the tokenizer's own decoder.
                out_dir = os.path.join(self.args.output_dir, self.infer_type)
                os.makedirs(out_dir, exist_ok=True)
                copysyn_samples = self.audio_tokenizer.decode(prompt_frames)
                audio_copysyn = copysyn_samples[0].cpu().detach()
                torchaudio.save(
                    os.path.join(out_dir, f"{save_name}_copysyn.wav"),
                    audio_copysyn,
                    self.cfg.preprocess.sampling_rate,
                )

            if self.args.continual:
                pred_frames = self.model.continual(
                    phone_id_seq,
                    phone_id_seq_len,
                    audio_prompt_token,
                )
            else:
                # enroll_x_lens stays None when there is no text prompt, so the
                # model is never handed an unbound length.
                enroll_x_lens = None
                if text_prompt:
                    prompt_phone_id_seq = to_phone_ids(f"{text_prompt}".strip())
                    enroll_x_lens = torch.IntTensor(
                        [len(prompt_phone_id_seq)]
                    ).to(self.device)
                pred_frames = self.model.inference(
                    phone_id_seq,
                    phone_id_seq_len,
                    audio_prompt_token,
                    enroll_x_lens=enroll_x_lens,
                    top_k=self.args.top_k,
                    temperature=self.args.temperature,
                )

            samples = self.audio_tokenizer.decode(
                [(pred_frames.transpose(2, 1), None)]
            )
        return samples[0].squeeze(0).cpu().detach()

    def inference_for_single_utterance(self):
        """Synthesize the utterance given by ``--text``, ``--text_prompt`` and
        ``--audio_prompt``.

        ``--text`` is required unless ``--continual`` is set, in which case it
        is ignored. ``--text_prompt`` and ``--audio_prompt`` are always
        required. Raises AssertionError when a required argument is empty.
        """
        text = self.args.text
        text_prompt = self.args.text_prompt
        audio_file = self.args.audio_prompt

        if self.args.continual:
            text = ""
        else:
            assert text != "", "--text is required unless --continual is set"
        assert text_prompt != "", "--text_prompt is required"
        assert audio_file != "", "--audio_prompt is required"

        return self.inference_one_clip(text, text_prompt, audio_file)

    def inference_for_batches(self):
        """Synthesize every utterance listed in ``--test_list_file``.

        Each non-blank line is ``text_prompt|audio_prompt_path|text``, or
        ``text_prompt|audio_prompt_path`` when ``--continual`` is set. Blank
        lines are skipped. The 0-based line index is used as ``save_name``.

        Returns:
            A list of waveforms in file order.
        """
        test_list_file = self.args.test_list_file
        assert test_list_file is not None, "--test_list_file is required"

        pred_res = []
        with open(test_list_file, "r", encoding="utf-8") as fin:
            for idx, raw_line in enumerate(fin):
                line = raw_line.strip()
                if not line:
                    # Tolerate blank lines (e.g. extra trailing newlines).
                    continue
                fields = line.split("|")
                if self.args.continual:
                    assert (
                        len(fields) == 2
                    ), f"line {idx + 1}: expected 'text_prompt|audio_prompt_path'"
                    text_prompt, audio_prompt_path = fields
                    text = ""
                else:
                    assert (
                        len(fields) == 3
                    ), f"line {idx + 1}: expected 'text_prompt|audio_prompt_path|text'"
                    text_prompt, audio_prompt_path, text = fields
                pred_res.append(
                    self.inference_one_clip(
                        text, text_prompt, audio_prompt_path, str(idx)
                    )
                )
        return pred_res

        """
        TODO: batch inference
        ###### Construct test_batch ######
        n_batch = len(self.test_dataloader)
        now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
        print(
            "Model eval time: {}, batch_size = {}, n_batch = {}".format(
                now, self.test_batch_size, n_batch
            )
        )
        ###### Inference for each batch ######
        pred_res = []
        with torch.no_grad():
            for i, batch_data in enumerate(
                self.test_dataloader if n_batch == 1 else tqdm(self.test_dataloader)
            ):
                if self.args.continual:
                    encoded_frames = self.model.continual(
                        batch_data["phone_seq"],
                        batch_data["phone_len"],
                        batch_data["acoustic_token"],
                    )
                else:
                    encoded_frames = self.model.inference(
                        batch_data["phone_seq"],
                        batch_data["phone_len"],
                        batch_data["acoustic_token"],
                        enroll_x_lens=batch_data["pmt_phone_len"],
                        top_k=self.args.top_k,
                        temperature=self.args.temperature
                    )
                samples = self.audio_tokenizer.decode(
                    [(encoded_frames.transpose(2, 1), None)]
                )
                for idx in range(samples.size(0)):
                    audio = samples[idx].cpu()
                    pred_res.append(audio)
        return pred_res
        """

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser):
        """Register VALL-E specific inference options on ``parser``.

        Called on the class (``VALLEInference.add_arguments(parser)``); it
        does not use any instance state.
        """
        parser.add_argument(
            "--text_prompt",
            type=str,
            default="",
            help="Text prompt that should be aligned with --audio_prompt.",
        )
        parser.add_argument(
            "--audio_prompt",
            type=str,
            default="",
            help="Audio prompt that should be aligned with --text_prompt.",
        )
        # argparse exposes this option as ``args.top_k``.
        parser.add_argument(
            "--top-k",
            type=int,
            default=-100,
            help="Whether AR Decoder do top_k(if > 0) sampling.",
        )
        parser.add_argument(
            "--temperature",
            type=float,
            default=1.0,
            help="The temperature of AR Decoder top_k sampling.",
        )
        parser.add_argument(
            "--continual",
            action="store_true",
            help="Inference for continual task.",
        )
        parser.add_argument(
            "--copysyn",
            action="store_true",
            help="Copysyn: generate audio by decoder of the original audio tokenizer.",
        )