Spaces:
Runtime error
Runtime error
ddd
commited on
Commit
•
a3411b4
1
Parent(s):
853fd97
fix hparam
Browse files- modules/diffsinger_midi/fs2.py +0 -109
modules/diffsinger_midi/fs2.py
CHANGED
@@ -117,112 +117,3 @@ class FastSpeech2MIDI(FastSpeech2):
|
|
117 |
|
118 |
return ret
|
119 |
|
120 |
-
def add_pitch(self, decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
|
121 |
-
decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
|
122 |
-
pitch_padding = mel2ph == 0
|
123 |
-
if hparams['pitch_ar']:
|
124 |
-
ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp, f0 if self.training else None)
|
125 |
-
if f0 is None:
|
126 |
-
f0 = pitch_pred[:, :, 0]
|
127 |
-
else:
|
128 |
-
ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp)
|
129 |
-
if f0 is None:
|
130 |
-
f0 = pitch_pred[:, :, 0]
|
131 |
-
if hparams['use_uv'] and uv is None:
|
132 |
-
uv = pitch_pred[:, :, 1] > 0
|
133 |
-
|
134 |
-
# here f0_denorm for pitch prediction
|
135 |
-
ret['f0_denorm'] = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding)
|
136 |
-
|
137 |
-
# here f0_denorm for mel prediction
|
138 |
-
if self.training:
|
139 |
-
mask = torch.full(uv.shape, hparams.get('mask_uv_prob', 0.)).to(f0.device)
|
140 |
-
masked_uv = torch.bernoulli(mask).bool().to(f0.device) # prob 的概率吐出一个随机uv.
|
141 |
-
uv_masked = uv.bool() | masked_uv
|
142 |
-
# print((uv.float()-uv_masked.float()).mean(dim=1))
|
143 |
-
f0_denorm = denorm_f0(f0, uv_masked, hparams, pitch_padding=pitch_padding)
|
144 |
-
else:
|
145 |
-
f0_denorm = ret['f0_denorm']
|
146 |
-
|
147 |
-
if pitch_padding is not None:
|
148 |
-
f0[pitch_padding] = 0
|
149 |
-
|
150 |
-
pitch = f0_to_coarse(f0_denorm) # start from 0
|
151 |
-
pitch_embed = self.pitch_embed(pitch)
|
152 |
-
return pitch_embed
|
153 |
-
|
154 |
-
|
155 |
-
class FastSpeech2MIDIMasked(FastSpeech2MIDI):
|
156 |
-
def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
|
157 |
-
ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False,
|
158 |
-
spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
|
159 |
-
ret = {}
|
160 |
-
|
161 |
-
midi_dur_embedding, slur_embedding = 0, 0
|
162 |
-
if kwargs.get('midi_dur') is not None:
|
163 |
-
midi_dur_embedding = self.midi_dur_layer(kwargs['midi_dur'][:, :, None]) # [B, T, 1] -> [B, T, H]
|
164 |
-
if kwargs.get('is_slur') is not None:
|
165 |
-
slur_embedding = self.is_slur_embed(kwargs['is_slur'])
|
166 |
-
encoder_out = self.encoder(txt_tokens, 0, midi_dur_embedding, slur_embedding) # [B, T, C]
|
167 |
-
src_nonpadding = (txt_tokens > 0).float()[:, :, None]
|
168 |
-
|
169 |
-
# add ref style embed
|
170 |
-
# Not implemented
|
171 |
-
# variance encoder
|
172 |
-
var_embed = 0
|
173 |
-
|
174 |
-
# encoder_out_dur denotes encoder outputs for duration predictor
|
175 |
-
# in speech adaptation, duration predictor use old speaker embedding
|
176 |
-
if hparams['use_spk_embed']:
|
177 |
-
spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
|
178 |
-
elif hparams['use_spk_id']:
|
179 |
-
spk_embed_id = spk_embed
|
180 |
-
if spk_embed_dur_id is None:
|
181 |
-
spk_embed_dur_id = spk_embed_id
|
182 |
-
if spk_embed_f0_id is None:
|
183 |
-
spk_embed_f0_id = spk_embed_id
|
184 |
-
spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :]
|
185 |
-
spk_embed_dur = spk_embed_f0 = spk_embed
|
186 |
-
if hparams['use_split_spk_id']:
|
187 |
-
spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :]
|
188 |
-
spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :]
|
189 |
-
else:
|
190 |
-
spk_embed_dur = spk_embed_f0 = spk_embed = 0
|
191 |
-
|
192 |
-
# add dur
|
193 |
-
dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding
|
194 |
-
|
195 |
-
mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret)
|
196 |
-
|
197 |
-
decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])
|
198 |
-
|
199 |
-
mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
|
200 |
-
decoder_inp = torch.gather(decoder_inp, 1, mel2ph_) # [B, T, H]
|
201 |
-
|
202 |
-
# expanded midi
|
203 |
-
midi_embedding = self.midi_embed(kwargs['pitch_midi'])
|
204 |
-
midi_embedding = F.pad(midi_embedding, [0, 0, 1, 0])
|
205 |
-
midi_embedding = torch.gather(midi_embedding, 1, mel2ph_)
|
206 |
-
print(midi_embedding.shape, decoder_inp.shape)
|
207 |
-
midi_mask = torch.full(midi_embedding.shape, hparams.get('mask_uv_prob', 0.)).to(midi_embedding.device)
|
208 |
-
midi_mask = 1 - torch.bernoulli(midi_mask).bool().to(midi_embedding.device) # prob 的概率吐出一个随机uv.
|
209 |
-
|
210 |
-
tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
|
211 |
-
|
212 |
-
decoder_inp += midi_embedding
|
213 |
-
decoder_inp_origin = decoder_inp
|
214 |
-
# add pitch and energy embed
|
215 |
-
pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
|
216 |
-
if hparams['use_pitch_embed']:
|
217 |
-
pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
|
218 |
-
decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
|
219 |
-
if hparams['use_energy_embed']:
|
220 |
-
decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)
|
221 |
-
|
222 |
-
ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding
|
223 |
-
|
224 |
-
if skip_decoder:
|
225 |
-
return ret
|
226 |
-
ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
|
227 |
-
|
228 |
-
return ret
|
|
|
117 |
|
118 |
return ret
|
119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|