DiffSpeech / inference /tts /ps_flow.py
RayeRen's picture
update
91c5bdb
raw
history blame
1.29 kB
import torch
from inference.tts.base_tts_infer import BaseTTSInfer
from modules.tts.portaspeech.portaspeech_flow import PortaSpeechFlow
from utils.commons.ckpt_utils import load_ckpt
from utils.commons.hparams import hparams
class PortaSpeechFlowInfer(BaseTTSInfer):
    """Inference wrapper that runs a ``PortaSpeechFlow`` model plus a vocoder.

    This class reads ``self.ph_encoder``, ``self.word_encoder``,
    ``self.hparams`` and ``self.model`` and calls ``self.input_to_batch`` and
    ``self.run_vocoder`` without defining them here; presumably they are all
    provided by ``BaseTTSInfer`` -- confirm against the base class.
    """

    def build_model(self):
        """Build the acoustic model, load its weights and switch to eval mode.

        The checkpoint directory is taken from ``self.hparams['work_dir']``,
        i.e. from the same configuration object that is passed to the model
        constructor, so an instance created with its own hparams cannot end
        up pairing that configuration with weights located via the
        module-level ``hparams`` singleton.

        Returns:
            The ``PortaSpeechFlow`` instance, in eval mode.
        """
        model = PortaSpeechFlow(
            len(self.ph_encoder),    # phoneme dictionary size
            len(self.word_encoder),  # word dictionary size
            self.hparams,
        )
        load_ckpt(model, self.hparams['work_dir'], 'model')
        # NOTE(review): presumably pre-computes the inverse of the flow's
        # invertible layers for sampling -- confirm in PortaSpeechFlow.
        with torch.no_grad():
            model.store_inverse_all()
        model.eval()
        return model

    def forward_model(self, inp):
        """Synthesize audio for one input item.

        Args:
            inp: Raw input item; it is converted with ``self.input_to_batch``
                and must yield the keys ``txt_tokens``, ``word_tokens``,
                ``ph2word`` and ``word_lengths`` (``spk_ids`` is optional).

        Returns:
            Element ``0`` of the vocoder output after moving it to the CPU
            and converting it to a NumPy array (presumably the waveform of
            the first batch item -- confirm against ``run_vocoder``).
        """
        sample = self.input_to_batch(inp)
        # The vocoder runs inside no_grad as well: the result is converted
        # with .numpy() below, which requires a tensor that does not track
        # gradients.
        with torch.no_grad():
            output = self.model(
                sample['txt_tokens'],
                sample['word_tokens'],
                ph2word=sample['ph2word'],
                word_len=sample['word_lengths'].max(),
                infer=True,
                forward_post_glow=True,
                spk_id=sample.get('spk_ids'),  # None when no speaker ids are given
            )
            mel_out = output['mel_out']
            wav_out = self.run_vocoder(mel_out)
        wav_out = wav_out.cpu().numpy()
        return wav_out[0]
# Script entry point: when this file is executed directly (not imported),
# run the example inference routine exposed on the class.
if __name__ == '__main__':
    PortaSpeechFlowInfer.example_run()