File size: 5,355 Bytes
d1b91e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch
from torch import nn
from modules.commons.layers import Embedding
from modules.commons.nar_tts_modules import EnergyPredictor, PitchPredictor
from modules.tts.commons.align_ops import expand_states
from modules.tts.fs import FastSpeech
from utils.audio.cwt import cwt2f0, get_lf0_cwt
from utils.audio.pitch.utils import denorm_f0, f0_to_coarse, norm_f0
import numpy as np


class FastSpeech2Orig(FastSpeech):
    """FastSpeech subclass adding energy and CWT-based pitch conditioning.

    On top of the base ``FastSpeech`` model this class optionally builds:

    * an energy predictor and an energy embedding table
      (``hparams['use_energy_embed']``);
    * a pitch predictor whose output is interpreted as a CWT spectrogram
      (plus one extra channel), together with a small MLP that predicts two
      per-utterance statistics stored as ``f0_mean`` / ``f0_std``
      (``hparams['pitch_type'] == 'cwt'`` and ``hparams['use_pitch_embed']``).

    NOTE(review): ``self.hparams``, ``self.hidden_size``, ``self.encoder``,
    ``self.pitch_embed``, ``self.z_channels``, ``self.dec_inp_noise_proj`` and
    the ``forward_style_embed`` / ``forward_dur`` / ``forward_decoder``
    helpers are not defined in this class; they are expected to be provided
    by ``FastSpeech`` -- confirm against the base class.
    """

    def __init__(self, dict_size, hparams, out_dims=None):
        """Build the base model and the optional energy / CWT pitch modules.

        Args:
            dict_size: Forwarded unchanged to ``FastSpeech.__init__``.
            hparams: Configuration mapping. Keys read here:
                ``predictor_hidden``, ``predictor_kernel``,
                ``use_energy_embed``, ``use_pitch_embed`` and ``pitch_type``.
            out_dims: Forwarded unchanged to ``FastSpeech.__init__``.
        """
        super().__init__(dict_size, hparams, out_dims)
        # A non-positive 'predictor_hidden' means "use the model hidden size".
        predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
        if hparams['use_energy_embed']:
            # forward_energy clamps indices to [0, 255], so only the first 256
            # of the 300 rows are ever looked up.
            self.energy_embed = Embedding(300, self.hidden_size, 0)
            self.energy_predictor = EnergyPredictor(
                self.hidden_size, n_chans=predictor_hidden,
                n_layers=5, dropout_rate=0.1, odim=2,
                kernel_size=hparams['predictor_kernel'])
        if hparams['pitch_type'] == 'cwt' and hparams['use_pitch_embed']:
            # odim=11: forward_pitch reads channels [0, 10) as the CWT
            # spectrogram and the last channel as the unvoiced logit.
            self.pitch_predictor = PitchPredictor(
                self.hidden_size, n_chans=predictor_hidden,
                n_layers=5, dropout_rate=0.1, odim=11,
                kernel_size=hparams['predictor_kernel'])
            # Maps one encoder state to two scalars (used as f0 mean and std).
            self.cwt_stats_layers = nn.Sequential(
                nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU(),
                nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU(), nn.Linear(self.hidden_size, 2))

    def forward(self, txt_tokens, mel2ph=None, spk_embed=None, spk_id=None,
                f0=None, uv=None, energy=None, infer=False, **kwargs):
        """Run the model and collect all outputs in a dictionary.

        Args:
            txt_tokens: Input token ids; positions that are not ``> 0`` are
                treated as padding.
            mel2ph: Passed to ``forward_dur``; the value it returns is used
                for the rest of the pass, and entries that are not ``> 0``
                are treated as padded frames.
            spk_embed: Forwarded to ``forward_style_embed``.
            spk_id: Forwarded to ``forward_style_embed``.
            f0: Forwarded to ``forward_pitch`` (``None`` makes the CWT branch
                predict it).
            uv: Forwarded to ``forward_pitch``.
            energy: Forwarded to ``forward_energy`` (``None`` makes it use
                the predicted energy).
            infer: Forwarded to ``forward_decoder``.
            **kwargs: Forwarded to ``forward_decoder``. If
                ``hparams['dec_inp_add_noise']`` is set, an optional
                ``adv_z`` tensor is used as the decoder-input noise instead
                of freshly sampled noise.

        Returns:
            dict: Always contains ``'decoder_inp'`` and ``'mel_out'``;
            ``'adv_z'`` is added when noise is injected. The helper methods
            receive the same dict and may add further entries.
        """
        ret = {}
        encoder_out = self.encoder(txt_tokens)  # [B, T, C]
        src_nonpadding = (txt_tokens > 0).float()[:, :, None]
        style_embed = self.forward_style_embed(spk_embed, spk_id)

        # add dur
        dur_inp = (encoder_out + style_embed) * src_nonpadding
        mel2ph = self.forward_dur(dur_inp, mel2ph, txt_tokens, ret)
        tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
        # decoder_inp_ keeps the plain expanded encoder states so that both
        # predictors see the same input, independent of what has already been
        # added to decoder_inp.
        decoder_inp = decoder_inp_ = expand_states(encoder_out, mel2ph)

        # add pitch embed
        if self.hparams['use_pitch_embed']:
            pitch_inp = (decoder_inp_ + style_embed) * tgt_nonpadding
            decoder_inp = decoder_inp + self.forward_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out)

        # add energy embed
        if self.hparams['use_energy_embed']:
            energy_inp = (decoder_inp_ + style_embed) * tgt_nonpadding
            decoder_inp = decoder_inp + self.forward_energy(energy_inp, energy, ret)

        # decoder input
        ret['decoder_inp'] = decoder_inp = (decoder_inp + style_embed) * tgt_nonpadding
        if self.hparams['dec_inp_add_noise']:
            B, T, _ = decoder_inp.shape
            # Sample noise only when the caller did not provide any. Using
            # kwargs.get('adv_z', torch.randn(...)) would build the default
            # (and advance the RNG) on every call, and would crash on an
            # explicit adv_z=None.
            z = kwargs.get('adv_z')
            if z is None:
                z = torch.randn([B, T, self.z_channels])
            z = z.to(decoder_inp.device)
            ret['adv_z'] = z
            decoder_inp = torch.cat([decoder_inp, z], -1)
            decoder_inp = self.dec_inp_noise_proj(decoder_inp) * tgt_nonpadding
        ret['mel_out'] = self.forward_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
        return ret

    def forward_pitch(self, decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
        """Return the pitch embedding that is added to the decoder input.

        For ``hparams['pitch_type'] == 'cwt'`` the pitch predictor output is
        stored in ``ret['cwt']``, the predicted statistics in
        ``ret['f0_mean']`` / ``ret['f0_std']`` and the denormalised f0 in
        ``ret['f0_denorm']``. Any other pitch type is delegated to the base
        class implementation.

        Args:
            decoder_inp: Predictor input (expanded encoder states plus style).
            f0: Ground-truth normalised f0, or ``None`` to reconstruct it
                from the predicted CWT spectrogram.
            uv: Unvoiced flags; when ``f0`` is ``None`` and
                ``hparams['use_uv']`` is set they are derived from the last
                predictor channel instead.
            mel2ph: Positions equal to 0 are treated as pitch padding.
            ret: Output dictionary, updated in place.
            encoder_out: Encoder states; required by the CWT branch, which
                indexes it directly.

        Returns:
            The result of ``self.pitch_embed`` applied to the coarse pitch
            (CWT branch), or whatever the base class returns otherwise.
        """
        if self.hparams['pitch_type'] == 'cwt':
            # Same forward value, but the gradient flowing back into
            # decoder_inp is scaled by hparams['predictor_grad'].
            decoder_inp = decoder_inp.detach() + self.hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
            pitch_padding = mel2ph == 0
            ret['cwt'] = cwt_out = self.pitch_predictor(decoder_inp)
            # Statistics are predicted from the encoder state at position 0.
            stats_out = self.cwt_stats_layers(encoder_out[:, 0, :])  # [B, 2]
            mean = ret['f0_mean'] = stats_out[:, 0]
            std = ret['f0_std'] = stats_out[:, 1]
            cwt_spec = cwt_out[:, :, :10]
            if f0 is None:
                # Only the local std is rescaled; ret['f0_std'] keeps the
                # unscaled prediction.
                std = std * self.hparams['cwt_std_scale']
                f0 = self.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
                if self.hparams['use_uv']:
                    assert cwt_out.shape[-1] == 11
                    uv = cwt_out[:, :, -1] > 0
            ret['f0_denorm'] = f0_denorm = denorm_f0(f0, uv if self.hparams['use_uv'] else None,
                                                     pitch_padding=pitch_padding)
            pitch = f0_to_coarse(f0_denorm)  # start from 0
            pitch_embed = self.pitch_embed(pitch)
            return pitch_embed
        else:
            return super().forward_pitch(decoder_inp, f0, uv, mel2ph, ret, encoder_out)

    def forward_energy(self, decoder_inp, energy, ret):
        """Predict energy and return the embedding of the quantised energy.

        Args:
            decoder_inp: Predictor input (expanded encoder states plus style).
            energy: Ground-truth energy, or ``None`` to embed the prediction.
            ret: Output dictionary; ``ret['energy_pred']`` is set to channel 0
                of the energy predictor output.

        Returns:
            The energy embedding looked up from ``self.energy_embed``.
        """
        # Same forward value, but the gradient flowing back into decoder_inp
        # is scaled by hparams['predictor_grad'].
        decoder_inp = decoder_inp.detach() + self.hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
        ret['energy_pred'] = energy_pred = self.energy_predictor(decoder_inp)[:, :, 0]
        energy_embed_inp = energy_pred if energy is None else energy
        # Quantise to integer bins: scale by 256, floor-divide by 4, then
        # clamp into the index range [0, 255].
        energy_embed_inp = torch.clamp(energy_embed_inp * 256 // 4, min=0, max=255).long()
        energy_embed = self.energy_embed(energy_embed_inp)
        return energy_embed

    def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
        """Convert a CWT spectrogram to normalised f0 with ``mel2ph``'s length.

        Args:
            cwt_spec: CWT spectrogram passed to ``cwt2f0``.
            mean: Passed to ``cwt2f0`` as the f0 mean.
            std: Passed to ``cwt2f0`` as the f0 std.
            mel2ph: Only ``mel2ph.shape[1]`` is used, as the target length.

        Returns:
            ``norm_f0`` applied to the reconstructed (and padded) f0.
        """
        # Only the scales are needed; the dummy signal's transform is dropped.
        _, cwt_scales = get_lf0_cwt(np.ones(10))
        f0 = cwt2f0(cwt_spec, mean, std, cwt_scales)
        # Pad along dim 1 by repeating the last frame until the length matches
        # mel2ph; nothing is appended if f0 is already at least that long.
        f0 = torch.cat(
            [f0] + [f0[:, -1:]] * (mel2ph.shape[1] - f0.shape[1]), 1)
        f0_norm = norm_f0(f0, None)
        return f0_norm