File size: 10,500 Bytes
e882f51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
import torch
from torch.nn import functional as F

from utils.hparams import hparams
from utils.pitch_utils import f0_to_coarse, denorm_f0


class Batch2Loss:
    '''
        Stateless pipeline steps that turn a training batch into a diffusion loss.

        pipeline: batch -> insert1 -> module1 -> insert2 -> module2 -> insert3 -> module3 -> insert4 -> module4 -> loss

        Every step is a @staticmethod: the "modules" (callables such as embeddings / predictors)
        are passed in explicitly next to the "variables" (tensors), so the caller owns all
        parameters and this class holds no state.
    '''

    @staticmethod
    def _pad_or_trim(x, length):
        '''
            Return *x* with its time axis (dim 1) forced to exactly *length* entries.

            A shorter input is right-padded by repeating each row's last value; a longer one is
            truncated. dtype and device are preserved. When no padding is needed the result is a
            slice (view) of *x*, so in-place writes to it reach the caller's tensor.
        '''
        delta_l = length - x.size(1)
        if delta_l > 0:
            last = x[:, -1:]  # keep dim 1 so the value can be broadcast along time
            x = torch.cat((x, last.expand(-1, delta_l, *last.shape[2:])), 1)
        return x[:, :length]

    @staticmethod
    def _energy_to_coarse(energy):
        '''
            Quantize continuous energy into integer embedding indices in [0, 255].

            The lower bound guards against slightly negative (predicted) energy, which would
            otherwise produce negative indices and make the embedding lookup fail.
        '''
        return torch.clamp(energy * 256 // 4, min=0, max=255).long()

    @staticmethod
    def insert1(pitch_midi, midi_dur, is_slur,  # variables
                midi_embed, midi_dur_layer, is_slur_embed):  # modules
        '''
            Add embeddings for midi, midi_dur and slur.

            Returns:
                (midi_embedding, midi_dur_embedding, slur_embedding); the last two are the
                integer 0 when midi_dur / is_slur is None.
        '''
        midi_embedding = midi_embed(pitch_midi)
        midi_dur_embedding, slur_embedding = 0, 0
        if midi_dur is not None:
            midi_dur_embedding = midi_dur_layer(midi_dur[:, :, None])  # [B, T, 1] -> [B, T, H]
        if is_slur is not None:
            slur_embedding = is_slur_embed(is_slur)
        return midi_embedding, midi_dur_embedding, slur_embedding

    @staticmethod
    def module1(fs2_encoder,  # modules
                txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding):  # variables
        '''
            get *encoder_out* == fs2_encoder(*txt_tokens*, some embeddings)
        '''
        encoder_out = fs2_encoder(txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding)
        return encoder_out

    @staticmethod
    def insert2(encoder_out, spk_embed_id, spk_embed_dur_id, spk_embed_f0_id, src_nonpadding,  # variables
                spk_embed_proj,  # modules
                spk_embed_dur_layer=None, spk_embed_f0_layer=None):  # optional modules (split speaker ids)
        '''
            1. add embeddings for spk, spk_dur, spk_f0
            2. get *dur_inp* ~= *encoder_out* + *spk_embed_dur*

            Args:
                spk_embed_dur_layer, spk_embed_f0_layer: lookup modules for the separate duration /
                    f0 speaker ids. Only used (and then required) when hparams['use_spk_id'] and
                    hparams['use_split_spk_id'] are both set.
            Returns:
                (var_embed, spk_embed, spk_embed_dur, spk_embed_f0, dur_inp); speaker embeddings
                are the integer 0 when neither 'use_spk_embed' nor 'use_spk_id' is set.
            Raises:
                ValueError: split speaker ids are enabled but the two lookup modules were not given.
        '''
        # Reference-style embedding is not implemented; the variance embedding is a placeholder 0.
        var_embed = 0

        # spk_embed_dur feeds the duration predictor; in speech adaptation the duration
        # predictor uses the old speaker embedding.
        if hparams['use_spk_embed']:
            spk_embed_dur = spk_embed_f0 = spk_embed = spk_embed_proj(spk_embed_id)[:, None, :]
        elif hparams['use_spk_id']:
            if spk_embed_dur_id is None:
                spk_embed_dur_id = spk_embed_id
            if spk_embed_f0_id is None:
                spk_embed_f0_id = spk_embed_id
            spk_embed = spk_embed_proj(spk_embed_id)[:, None, :]
            spk_embed_dur = spk_embed_f0 = spk_embed
            if hparams['use_split_spk_id']:
                # Separate lookup modules are needed here: the embedding tensors computed above
                # are not callable.
                if spk_embed_dur_layer is None or spk_embed_f0_layer is None:
                    raise ValueError(
                        "use_split_spk_id requires spk_embed_dur_layer and spk_embed_f0_layer modules")
                spk_embed_dur = spk_embed_dur_layer(spk_embed_dur_id)[:, None, :]
                spk_embed_f0 = spk_embed_f0_layer(spk_embed_f0_id)[:, None, :]
        else:
            spk_embed_dur = spk_embed_f0 = spk_embed = 0

        # add dur
        dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding
        return var_embed, spk_embed, spk_embed_dur, spk_embed_f0, dur_inp

    @staticmethod
    def module2(dur_predictor, length_regulator,  # modules
                dur_input, mel2ph, txt_tokens, all_vowel_tokens, ret, midi_dur=None):  # variables
        '''
            1. get *dur* ~= dur_predictor(*dur_inp*)
            2. (mel2ph is None): get *mel2ph* ~= length_regulator(*dur*)

            When mel2ph is None, the predicted phone durations are re-fitted to the MIDI note
            durations before building mel2ph:
              - a vowel followed by a non-vowel gets midi_dur minus the next phone's predicted
                duration; if that is negative it becomes 0 and the next phone takes the whole
                midi_dur;
              - any other vowel gets midi_dur directly.
            midi_dur is therefore required in that branch whenever txt_tokens contains a vowel.

            Writes ret['dur'] and ret['mel2ph'] (plus ret['dur_choice'] when mel2ph is None).
            Returns mel2ph.
        '''
        src_padding = (txt_tokens == 0)
        # Forward value is unchanged; the gradient into the encoder is scaled by predictor_grad.
        dur_input = dur_input.detach() + hparams['predictor_grad'] * (dur_input - dur_input.detach())

        if mel2ph is None:
            _, xs = dur_predictor.inference(dur_input, src_padding)
            ret['dur'] = xs
            # xs is treated as log(1 + duration)
            dur = xs.squeeze(-1).exp() - 1.0
            # Resolve vowel membership once on Python ints instead of per-element tensor indexing.
            is_vowel = [[tok in all_vowel_tokens for tok in row] for row in txt_tokens.tolist()]
            n_batch, n_ph = dur.shape[0], dur.shape[1]
            for i in range(n_batch):
                for j in range(n_ph):
                    if not is_vowel[i][j]:
                        continue
                    if j < n_ph - 1 and not is_vowel[i][j + 1]:
                        dur[i, j] = midi_dur[i, j] - dur[i, j + 1]
                        if dur[i, j] < 0:
                            dur[i, j] = 0
                            dur[i, j + 1] = midi_dur[i, j]
                    else:
                        dur[i, j] = midi_dur[i, j]
            dur[:, 0] = dur[:, 0] + 0.5
            # Round the cumulative sum (not each duration) so rounding error does not accumulate.
            dur_acc = F.pad(torch.round(torch.cumsum(dur, dim=1)), (1, 0))
            dur = torch.clamp(dur_acc[:, 1:] - dur_acc[:, :-1], min=0).long()
            ret['dur_choice'] = dur
            mel2ph = length_regulator(dur, src_padding).detach()
        else:
            ret['dur'] = dur_predictor(dur_input, src_padding)
        ret['mel2ph'] = mel2ph

        return mel2ph

    @staticmethod
    def insert3(encoder_out, mel2ph, var_embed, spk_embed_f0, src_nonpadding, tgt_nonpadding):  # variables
        '''
            1. get *decoder_inp* ~= gather *encoder_out* according to *mel2ph*
            2. get *pitch_inp* ~= *decoder_inp* + *spk_embed_f0*
            3. get *pitch_inp_ph* ~= *encoder_out* + *spk_embed_f0*

            mel2ph index 0 selects an all-zero row prepended to encoder_out, so index k >= 1 maps
            a frame to encoder_out[:, k - 1].
        '''
        decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])
        mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
        decoder_inp = torch.gather(decoder_inp, 1, mel2ph_)  # [B, T, H]

        pitch_inp = (decoder_inp + var_embed + spk_embed_f0) * tgt_nonpadding
        pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
        return decoder_inp, pitch_inp, pitch_inp_ph

    @staticmethod
    def module3(pitch_predictor, pitch_embed, energy_predictor, energy_embed,  # modules
                pitch_inp, pitch_inp_ph, f0, uv, energy, mel2ph, is_training, ret):  # variables
        '''
            1. get *ret['pitch_pred']*, *ret['energy_pred']* ~= pitch_predictor(*pitch_inp*), energy_predictor(*pitch_inp*)
            2. get *pitch_embedding* ~= pitch_embed(f0_to_coarse(denorm_f0(*f0* or *pitch_pred*))
            3. get *energy_embedding* ~= energy_embed(energy_to_coarse(*energy* or *energy_pred*))

            Given f0 / uv / energy are first padded (repeating their last frame) or truncated to
            mel2ph's frame count. An embedding is the integer 0 when its hparams switch
            ('use_pitch_embed' / 'use_energy_embed') is off.
            Returns (pitch_embedding, energy_embedding).
        '''

        def add_pitch(decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
            if hparams['pitch_type'] == 'ph':
                # Phone-level pitch: predict on encoder outputs, then expand to frames via mel2ph.
                pitch_pred_inp = encoder_out.detach() + hparams['predictor_grad'] * (encoder_out - encoder_out.detach())
                # NOTE(review): this is a single scalar for the whole batch, not a per-phone mask.
                # A per-phone mask would be encoder_out.abs().sum(-1) == 0, but f0 has already been
                # resized to the mel length by this point, so the shapes may not line up — confirm.
                pitch_padding = (encoder_out.sum().abs() == 0)
                ret['pitch_pred'] = pitch_pred = pitch_predictor(pitch_pred_inp)
                if f0 is None:
                    f0 = pitch_pred[:, :, 0]
                ret['f0_denorm'] = f0_denorm = denorm_f0(f0, None, hparams, pitch_padding=pitch_padding)
                pitch = f0_to_coarse(f0_denorm)  # start from 0 [B, T_txt]
                pitch = F.pad(pitch, [1, 0])
                pitch = torch.gather(pitch, 1, mel2ph)  # [B, T_mel]
                return pitch_embed(pitch)

            decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
            pitch_padding = (mel2ph == 0)

            if hparams['pitch_ar']:
                # the autoregressive predictor is fed ground-truth f0 only while training
                ret['pitch_pred'] = pitch_pred = pitch_predictor(decoder_inp, f0 if is_training else None)
                if f0 is None:
                    f0 = pitch_pred[:, :, 0]
            else:
                ret['pitch_pred'] = pitch_pred = pitch_predictor(decoder_inp)
                if f0 is None:
                    f0 = pitch_pred[:, :, 0]
                if hparams['use_uv'] and uv is None:
                    uv = pitch_pred[:, :, 1] > 0
            ret['f0_denorm'] = f0_denorm = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding)
            if pitch_padding is not None:
                # NOTE: in-place write; f0 may be a view of the caller's tensor (or of pitch_pred).
                f0[pitch_padding] = 0

            pitch = f0_to_coarse(f0_denorm)  # start from 0
            return pitch_embed(pitch)

        def add_energy(decoder_inp, energy, ret):
            decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
            ret['energy_pred'] = energy_pred = energy_predictor(decoder_inp)[:, :, 0]
            if energy is None:
                energy = energy_pred
            return energy_embed(Batch2Loss._energy_to_coarse(energy))

        # add pitch and energy embed
        nframes = mel2ph.size(1)

        pitch_embedding = 0
        if hparams['use_pitch_embed']:
            if f0 is not None:
                f0 = Batch2Loss._pad_or_trim(f0, nframes)
            if uv is not None:
                uv = Batch2Loss._pad_or_trim(uv, nframes)
            pitch_embedding = add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)

        energy_embedding = 0
        if hparams['use_energy_embed']:
            if energy is not None:
                energy = Batch2Loss._pad_or_trim(energy, nframes)
            energy_embedding = add_energy(pitch_inp, energy, ret)

        return pitch_embedding, energy_embedding

    @staticmethod
    def insert4(decoder_inp, pitch_embedding, energy_embedding, spk_embed, ret, tgt_nonpadding):
        '''
            *decoder_inp* ~= *decoder_inp* + embeddings for spk, pitch, energy

            The result is masked by tgt_nonpadding, stored in ret['decoder_inp'] and returned.
        '''
        decoder_inp = (decoder_inp + pitch_embedding + energy_embedding + spk_embed) * tgt_nonpadding
        ret['decoder_inp'] = decoder_inp
        return decoder_inp

    @staticmethod
    def module4(diff_main_loss,  # modules
                norm_spec, decoder_inp_t, ret, K_step, batch_size, device):  # variables
        '''
            training diffusion using spec as input and decoder_inp as condition.

            Args:
                norm_spec: (normalized) spec
                decoder_inp_t: (transposed) decoder_inp
                K_step: diffusion steps are sampled uniformly from [0, K_step)
            Returns:
                nothing; the loss is written to ret['diff_loss']
        '''
        t = torch.randint(0, K_step, (batch_size,), device=device).long()
        norm_spec = norm_spec.transpose(1, 2)[:, None, :, :]  # [B, 1, M, T]
        ret['diff_loss'] = diff_main_loss(norm_spec, t, cond=decoder_inp_t)