GenerSpeech / inference /GenerSpeech.py
Rongjiehuang's picture
update
222619b
raw
history blame
1.47 kB
import torch
from inference.base_tts_infer import BaseTTSInfer
from utils.ckpt_utils import load_ckpt, get_last_checkpoint
from utils.hparams import hparams
from modules.GenerSpeech.model.generspeech import GenerSpeech
from usr.diff.net import DiffNet
import os
import numpy as np
from functools import partial
class GenerSpeechInfer(BaseTTSInfer):
    """Inference wrapper that builds a GenerSpeech model and runs it end to end.

    ``build_model`` constructs the network and loads its checkpoint;
    ``forward_model`` turns one input item into a waveform by calling the
    model and then the vocoder.
    """

    def build_model(self):
        """Build the GenerSpeech network in eval mode and load its weights.

        If ``<binary_data_dir>/train_f0s_mean_std.npy`` exists, the two values
        it contains are written to ``hparams['f0_mean']`` and
        ``hparams['f0_std']`` as plain Python floats before the model is
        created.

        Returns:
            The ``GenerSpeech`` instance after ``load_ckpt`` has been applied
            with ``hparams['work_dir']``.
        """
        stats_path = f'{hparams["binary_data_dir"]}/train_f0s_mean_std.npy'
        if os.path.exists(stats_path):
            # The file holds a (mean, std) pair; store both as Python floats.
            f0_mean, f0_std = np.load(stats_path)
            hparams['f0_mean'] = float(f0_mean)
            hparams['f0_std'] = float(f0_std)

        net = GenerSpeech(self.ph_encoder)
        net.eval()
        load_ckpt(net, hparams['work_dir'], 'model')
        return net

    def forward_model(self, inp):
        """Synthesize a waveform for ``inp``.

        Args:
            inp: Raw input item; it is converted to a batch with
                ``self.input_to_batch``.

        Returns:
            The vocoder output, squeezed, moved to CPU and converted to a
            NumPy array.
        """
        batch = self.input_to_batch(inp)
        tokens = batch['txt_tokens']  # [B, T_t]

        # No autograd graph is needed for inference.
        with torch.no_grad():
            # NOTE(review): global_steps is pinned to 300000 here; its meaning
            # is defined by the model, not visible in this file.
            result = self.model(
                tokens,
                ref_mel2ph=batch['mel2ph'],
                ref_mel2word=batch['mel2word'],
                ref_mels=batch['mels'],
                spk_embed=batch['spk_embed'],
                emo_embed=batch['emo_embed'],
                global_steps=300000,
                infer=True,
            )
            waveform = self.run_vocoder(result['mel_out'])

        return waveform.squeeze().cpu().numpy()
# Script entry point: only runs when the file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    GenerSpeechInfer.example_run()