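"""Command-line inference for PortaSpeechFlow (PortaSpeech with a flow-based post-net).

Loads a trained checkpoint, pre-computes the inverse flow weights, and converts
text input to a waveform via mel-spectrogram prediction and the vocoder.
"""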
import torch
from inference.tts.base_tts_infer import BaseTTSInfer
from modules.tts.portaspeech.portaspeech_flow import PortaSpeechFlow
from utils.commons.ckpt_utils import load_ckpt
from utils.commons.hparams import hparams


class PortaSpeechFlowInfer(BaseTTSInfer):
    def build_model(self):
        ph_dict_size = len(self.ph_encoder)
        word_dict_size = len(self.word_encoder)
        model = PortaSpeechFlow(ph_dict_size, word_dict_size, self.hparams)
        load_ckpt(model, hparams['work_dir'], 'model')
        with torch.no_grad():
            # Cache the inverse weights of the invertible convolutions so the
            # post-net glow can run in reverse (inference) mode efficiently.
            model.store_inverse_all()
        model.eval()
        return model

    def forward_model(self, inp):
        # Convert the raw input (text, speaker, ...) into a batched sample.
        sample = self.input_to_batch(inp)
        with torch.no_grad():
            output = self.model(
                sample['txt_tokens'],
                sample['word_tokens'],
                ph2word=sample['ph2word'],
                word_len=sample['word_lengths'].max(),
                infer=True,
                forward_post_glow=True,  # refine the coarse mel with the post-net flow
                spk_id=sample.get('spk_ids')
            )
            mel_out = output['mel_out']
            wav_out = self.run_vocoder(mel_out)
        wav_out = wav_out.cpu().numpy()
        return wav_out[0]  # single-item batch: return the waveform


if __name__ == '__main__':
    # Run a demo synthesis for the hparams-configured experiment.
    PortaSpeechFlowInfer.example_run()
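# A minimal usage sketch, assuming the NATSpeech-style hparams CLI where the
# experiment is selected via --exp_name (the script path and flag names may
# differ in your setup; <ps_flow_exp> is a placeholder for your experiment dir):
#   CUDA_VISIBLE_DEVICES=0 python inference/tts/ps_flow.py --exp_name <ps_flow_exp>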