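"""Batch inference entry point for CosyVoice.

Loads the llm/flow/hifigan checkpoints, runs each (prompt, tts_text) pair from
the prompt dataset through CosyVoiceModel.inference in 'sft' or 'zero_shot'
mode, and writes one wav per utterance plus a Kaldi-style wav.scp index to
--result_dir.

Typical invocation (all paths are illustrative):

    python inference.py --mode zero_shot \
        --config conf/cosyvoice.yaml \
        --prompt_data data/test/parquet/data.list \
        --prompt_utt2data data/test/parquet/utt2data.list \
        --tts_text data/test/tts_text.json \
        --llm_model exp/llm.pt --flow_model exp/flow.pt \
        --hifigan_model exp/hift.pt \
        --gpu 0 --result_dir exp/results
"""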
from __future__ import print_function

import argparse
import logging
import os

# Quiet matplotlib's verbose logger before downstream imports can pull it in.
logging.getLogger('matplotlib').setLevel(logging.WARNING)

import torch
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from torch.utils.data import DataLoader
from tqdm import tqdm

from cosyvoice.cli.model import CosyVoiceModel
from cosyvoice.dataset.dataset import Dataset


def get_args():
    parser = argparse.ArgumentParser(description='inference with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--prompt_data', required=True, help='prompt data file')
    parser.add_argument('--prompt_utt2data', required=True, help='prompt utt2data file')
    parser.add_argument('--tts_text', required=True, help='tts input file')
    parser.add_argument('--llm_model', required=True, help='llm model file')
    parser.add_argument('--flow_model', required=True, help='flow model file')
    parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--mode',
                        default='sft',
                        choices=['sft', 'zero_shot'],
                        help='inference mode')
    parser.add_argument('--result_dir', required=True,
                        help='directory to save generated wavs and wav.scp')
    args = parser.parse_args()
    print(args)
    return args


def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    # CUDA_VISIBLE_DEVICES must be set before the first CUDA call so that
    # --gpu selects the physical device; -1 hides every GPU and forces CPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    with open(args.config, 'r') as f:
        configs = load_hyperpyyaml(f)
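    # Assemble the three-stage pipeline instantiated by the config: a
    # text-to-speech-token LLM, a flow-matching decoder, and a HiFiGAN-style
    # vocoder (hift), then load their checkpoints.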
    model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
    model.load(args.llm_model, args.flow_model, args.hifigan_model)
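    # The dataset emits ready-made single-utterance batches in inference mode,
    # so the DataLoader must not batch again (batch_size=None).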
    test_dataset = Dataset(args.prompt_data,
                           data_pipeline=configs['data_pipeline'],
                           mode='inference',
                           shuffle=False,
                           partition=False,
                           tts_file=args.tts_text,
                           prompt_utt2data=args.prompt_utt2data)
    test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
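    # The model holds its own references to the instantiated modules, so the
    # config dict can be dropped to free memory.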
    del configs
    os.makedirs(args.result_dir, exist_ok=True)
    fn = os.path.join(args.result_dir, 'wav.scp')
    f = open(fn, 'w')
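    # Gradients are never needed at inference time; no_grad saves memory.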
    with torch.no_grad():
        for batch_idx, batch in tqdm(enumerate(test_data_loader)):
            utts = batch["utts"]
            assert len(utts) == 1, "inference mode only supports batch size 1"
            text = batch["text"]
            text_token = batch["text_token"].to(device)
            text_token_len = batch["text_token_len"].to(device)
            tts_text = batch["tts_text"]
            tts_index = batch["tts_index"]
            tts_text_token = batch["tts_text_token"].to(device)
            tts_text_token_len = batch["tts_text_token_len"].to(device)
            speech_token = batch["speech_token"].to(device)
            speech_token_len = batch["speech_token_len"].to(device)
            speech_feat = batch["speech_feat"].to(device)
            speech_feat_len = batch["speech_feat_len"].to(device)
            utt_embedding = batch["utt_embedding"].to(device)
            spk_embedding = batch["spk_embedding"].to(device)
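            # 'sft' conditions only on the speaker embedding; 'zero_shot'
            # additionally conditions the llm and flow stages on the prompt
            # transcript, its speech tokens, and its acoustic features.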
            if args.mode == 'sft':
                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
                               'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
            else:
                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
                               'prompt_text': text_token, 'prompt_text_len': text_token_len,
                               'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
                               'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
                               'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                               'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
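            # Synthesize, write <result_dir>/<utt>_<index>.wav, and append the
            # pair to wav.scp so downstream tools can locate each wav.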
            model_output = model.inference(**model_input)
            tts_key = '{}_{}'.format(utts[0], tts_index[0])
            tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
            torchaudio.save(tts_fn, model_output['tts_speech'], sample_rate=22050)
            f.write('{} {}\n'.format(tts_key, tts_fn))
            f.flush()
    f.close()
    logging.info('Result wav.scp saved in {}'.format(fn))


if __name__ == '__main__':
    main()