lmzjms's picture
Upload 591 files
9206300
import torch
from modules.GenerSpeech.model.glow_modules import Glow
from modules.fastspeech.tts_modules import PitchPredictor
import random
from modules.GenerSpeech.model.prosody_util import ProsodyAligner, LocalStyleAdaptor
from utils.pitch_utils import f0_to_coarse, denorm_f0
from modules.commons.common_layers import *
import torch.distributions as dist
from utils.hparams import hparams
from modules.GenerSpeech.model.mixstyle import MixStyle
from modules.fastspeech.fs2 import FastSpeech2
import json
from modules.fastspeech.tts_modules import DEFAULT_MAX_SOURCE_POSITIONS, DEFAULT_MAX_TARGET_POSITIONS
class GenerSpeech(FastSpeech2):
'''
GenerSpeech: Towards Style Transfer for Generalizable Out-Of-Domain Text-to-Speech
https://arxiv.org/abs/2205.07211
'''
def __init__(self, dictionary, out_dims=None):
super().__init__(dictionary, out_dims)
# Mixstyle
self.norm = MixStyle(p=0.5, alpha=0.1, eps=1e-6, hidden_size=self.hidden_size)
# emotion embedding
self.emo_embed_proj = Linear(256, self.hidden_size, bias=True)
# build prosody extractor
## frame level
self.prosody_extractor_utter = LocalStyleAdaptor(self.hidden_size, hparams['nVQ'], self.padding_idx)
self.l1_utter = nn.Linear(self.hidden_size * 2, self.hidden_size)
self.align_utter = ProsodyAligner(num_layers=2)
## phoneme level
self.prosody_extractor_ph = LocalStyleAdaptor(self.hidden_size, hparams['nVQ'], self.padding_idx)
self.l1_ph = nn.Linear(self.hidden_size * 2, self.hidden_size)
self.align_ph = ProsodyAligner(num_layers=2)
## word level
self.prosody_extractor_word = LocalStyleAdaptor(self.hidden_size, hparams['nVQ'], self.padding_idx)
self.l1_word = nn.Linear(self.hidden_size * 2, self.hidden_size)
self.align_word = ProsodyAligner(num_layers=2)
self.pitch_inpainter_predictor = PitchPredictor(
self.hidden_size, n_chans=self.hidden_size,
n_layers=3, dropout_rate=0.1, odim=2,
padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
# build attention layer
self.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
self.embed_positions = SinusoidalPositionalEmbedding(
self.hidden_size, self.padding_idx,
init_size=self.max_source_positions + self.padding_idx + 1,
)
# build post flow
cond_hs = 80
if hparams.get('use_txt_cond', True):
cond_hs = cond_hs + hparams['hidden_size']
cond_hs = cond_hs + hparams['hidden_size'] * 3 # for emo, spk embedding and prosody embedding
self.post_flow = Glow(
80, hparams['post_glow_hidden'], hparams['post_glow_kernel_size'], 1,
hparams['post_glow_n_blocks'], hparams['post_glow_n_block_layers'],
n_split=4, n_sqz=2,
gin_channels=cond_hs,
share_cond_layers=hparams['post_share_cond_layers'],
share_wn_layers=hparams['share_wn_layers'],
sigmoid_scale=hparams['sigmoid_scale']
)
self.prior_dist = dist.Normal(0, 1)
def forward(self, txt_tokens, mel2ph=None, ref_mel2ph=None, ref_mel2word=None, spk_embed=None, emo_embed=None, ref_mels=None,
f0=None, uv=None, skip_decoder=False, global_steps=0, infer=False, **kwargs):
ret = {}
encoder_out = self.encoder(txt_tokens) # [B, T, C]
src_nonpadding = (txt_tokens > 0).float()[:, :, None]
# add spk/emo embed
spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
emo_embed = self.emo_embed_proj(emo_embed)[:, None, :]
# add dur
dur_inp = (encoder_out + spk_embed + emo_embed) * src_nonpadding
mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret)
tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
decoder_inp = self.expand_states(encoder_out, mel2ph)
decoder_inp = self.norm(decoder_inp, spk_embed + emo_embed)
# add prosody VQ
ret['ref_mel2ph'] = ref_mel2ph
ret['ref_mel2word'] = ref_mel2word
prosody_utter_mel = self.get_prosody_utter(decoder_inp, ref_mels, ret, infer, global_steps)
prosody_ph_mel = self.get_prosody_ph(decoder_inp, ref_mels, ret, infer, global_steps)
prosody_word_mel = self.get_prosody_word(decoder_inp, ref_mels, ret, infer, global_steps)
# add pitch embed
pitch_inp_domain_agnostic = decoder_inp * tgt_nonpadding
pitch_inp_domain_specific = (decoder_inp + spk_embed + emo_embed + prosody_utter_mel + prosody_ph_mel + prosody_word_mel) * tgt_nonpadding
predicted_pitch = self.inpaint_pitch(pitch_inp_domain_agnostic, pitch_inp_domain_specific, f0, uv, mel2ph, ret)
# decode
decoder_inp = decoder_inp + spk_embed + emo_embed + predicted_pitch + prosody_utter_mel + prosody_ph_mel + prosody_word_mel
ret['decoder_inp'] = decoder_inp = decoder_inp * tgt_nonpadding
if skip_decoder:
return ret
ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
# postflow
is_training = self.training
ret['x_mask'] = tgt_nonpadding
ret['spk_embed'] = spk_embed
ret['emo_embed'] = emo_embed
ret['ref_prosody'] = prosody_utter_mel + prosody_ph_mel + prosody_word_mel
self.run_post_glow(ref_mels, infer, is_training, ret)
return ret
def get_prosody_ph(self, encoder_out, ref_mels, ret, infer=False, global_steps=0):
# get VQ prosody
if global_steps > hparams['vq_start'] or infer:
prosody_embedding, loss, ppl = self.prosody_extractor_ph(ref_mels, ret['ref_mel2ph'], no_vq=False)
ret['vq_loss_ph'] = loss
ret['ppl_ph'] = ppl
else:
prosody_embedding = self.prosody_extractor_ph(ref_mels, ret['ref_mel2ph'], no_vq=True)
# add positional embedding
positions = self.embed_positions(prosody_embedding[:, :, 0])
prosody_embedding = self.l1_ph(torch.cat([prosody_embedding, positions], dim=-1))
# style-to-content attention
src_key_padding_mask = encoder_out[:, :, 0].eq(self.padding_idx).data
prosody_key_padding_mask = prosody_embedding[:, :, 0].eq(self.padding_idx).data
if global_steps < hparams['forcing']:
output, guided_loss, attn_emo = self.align_ph(encoder_out.transpose(0, 1), prosody_embedding.transpose(0, 1),
src_key_padding_mask, prosody_key_padding_mask, forcing=True)
else:
output, guided_loss, attn_emo = self.align_ph(encoder_out.transpose(0, 1), prosody_embedding.transpose(0, 1),
src_key_padding_mask, prosody_key_padding_mask, forcing=False)
ret['gloss_ph'] = guided_loss
ret['attn_ph'] = attn_emo
return output.transpose(0, 1)
def get_prosody_word(self, encoder_out, ref_mels, ret, infer=False, global_steps=0):
# get VQ prosody
if global_steps > hparams['vq_start'] or infer:
prosody_embedding, loss, ppl = self.prosody_extractor_word(ref_mels, ret['ref_mel2word'], no_vq=False)
ret['vq_loss_word'] = loss
ret['ppl_word'] = ppl
else:
prosody_embedding = self.prosody_extractor_word(ref_mels, ret['ref_mel2word'], no_vq=True)
# add positional embedding
positions = self.embed_positions(prosody_embedding[:, :, 0])
prosody_embedding = self.l1_word(torch.cat([prosody_embedding, positions], dim=-1))
# style-to-content attention
src_key_padding_mask = encoder_out[:, :, 0].eq(self.padding_idx).data
prosody_key_padding_mask = prosody_embedding[:, :, 0].eq(self.padding_idx).data
if global_steps < hparams['forcing']:
output, guided_loss, attn_emo = self.align_word(encoder_out.transpose(0, 1), prosody_embedding.transpose(0, 1),
src_key_padding_mask, prosody_key_padding_mask, forcing=True)
else:
output, guided_loss, attn_emo = self.align_word(encoder_out.transpose(0, 1), prosody_embedding.transpose(0, 1),
src_key_padding_mask, prosody_key_padding_mask, forcing=False)
ret['gloss_word'] = guided_loss
ret['attn_word'] = attn_emo
return output.transpose(0, 1)
def get_prosody_utter(self, encoder_out, ref_mels, ret, infer=False, global_steps=0):
# get VQ prosody
if global_steps > hparams['vq_start'] or infer:
prosody_embedding, loss, ppl = self.prosody_extractor_utter(ref_mels, no_vq=False)
ret['vq_loss_utter'] = loss
ret['ppl_utter'] = ppl
else:
prosody_embedding = self.prosody_extractor_utter(ref_mels, no_vq=True)
# add positional embedding
positions = self.embed_positions(prosody_embedding[:, :, 0])
prosody_embedding = self.l1_utter(torch.cat([prosody_embedding, positions], dim=-1))
# style-to-content attention
src_key_padding_mask = encoder_out[:, :, 0].eq(self.padding_idx).data
prosody_key_padding_mask = prosody_embedding[:, :, 0].eq(self.padding_idx).data
if global_steps < hparams['forcing']:
output, guided_loss, attn_emo = self.align_utter(encoder_out.transpose(0, 1), prosody_embedding.transpose(0, 1),
src_key_padding_mask, prosody_key_padding_mask, forcing=True)
else:
output, guided_loss, attn_emo = self.align_utter(encoder_out.transpose(0, 1), prosody_embedding.transpose(0, 1),
src_key_padding_mask, prosody_key_padding_mask, forcing=False)
ret['gloss_utter'] = guided_loss
ret['attn_utter'] = attn_emo
return output.transpose(0, 1)
def inpaint_pitch(self, pitch_inp_domain_agnostic, pitch_inp_domain_specific, f0, uv, mel2ph, ret):
if hparams['pitch_type'] == 'frame':
pitch_padding = mel2ph == 0
if hparams['predictor_grad'] != 1:
pitch_inp_domain_agnostic = pitch_inp_domain_agnostic.detach() + hparams['predictor_grad'] * (pitch_inp_domain_agnostic - pitch_inp_domain_agnostic.detach())
pitch_inp_domain_specific = pitch_inp_domain_specific.detach() + hparams['predictor_grad'] * (pitch_inp_domain_specific - pitch_inp_domain_specific.detach())
pitch_domain_agnostic = self.pitch_predictor(pitch_inp_domain_agnostic)
pitch_domain_specific = self.pitch_inpainter_predictor(pitch_inp_domain_specific)
pitch_pred = pitch_domain_agnostic + pitch_domain_specific
ret['pitch_pred'] = pitch_pred
use_uv = hparams['pitch_type'] == 'frame' and hparams['use_uv']
if f0 is None:
f0 = pitch_pred[:, :, 0] # [B, T]
if use_uv:
uv = pitch_pred[:, :, 1] > 0 # [B, T]
f0_denorm = denorm_f0(f0, uv if use_uv else None, hparams, pitch_padding=pitch_padding)
pitch = f0_to_coarse(f0_denorm) # start from 0 [B, T_txt]
ret['f0_denorm'] = f0_denorm
ret['f0_denorm_pred'] = denorm_f0(pitch_pred[:, :, 0], (pitch_pred[:, :, 1] > 0) if use_uv else None, hparams, pitch_padding=pitch_padding)
if hparams['pitch_type'] == 'ph':
pitch = torch.gather(F.pad(pitch, [1, 0]), 1, mel2ph)
ret['f0_denorm'] = torch.gather(F.pad(ret['f0_denorm'], [1, 0]), 1, mel2ph)
ret['f0_denorm_pred'] = torch.gather(F.pad(ret['f0_denorm_pred'], [1, 0]), 1, mel2ph)
pitch_embed = self.pitch_embed(pitch)
return pitch_embed
def run_post_glow(self, tgt_mels, infer, is_training, ret):
x_recon = ret['mel_out'].transpose(1, 2)
g = x_recon
B, _, T = g.shape
if hparams.get('use_txt_cond', True):
g = torch.cat([g, ret['decoder_inp'].transpose(1, 2)], 1)
g_spk_embed = ret['spk_embed'].repeat(1, T, 1).transpose(1, 2)
g_emo_embed = ret['emo_embed'].repeat(1, T, 1).transpose(1, 2)
l_ref_prosody = ret['ref_prosody'].transpose(1, 2)
g = torch.cat([g, g_spk_embed, g_emo_embed, l_ref_prosody], dim=1)
prior_dist = self.prior_dist
if not infer:
if is_training:
self.train()
x_mask = ret['x_mask'].transpose(1, 2)
y_lengths = x_mask.sum(-1)
g = g.detach()
tgt_mels = tgt_mels.transpose(1, 2)
z_postflow, ldj = self.post_flow(tgt_mels, x_mask, g=g)
ldj = ldj / y_lengths / 80
ret['z_pf'], ret['ldj_pf'] = z_postflow, ldj
ret['postflow'] = -prior_dist.log_prob(z_postflow).mean() - ldj.mean()
else:
x_mask = torch.ones_like(x_recon[:, :1, :])
z_post = prior_dist.sample(x_recon.shape).to(g.device) * hparams['noise_scale']
x_recon_, _ = self.post_flow(z_post, x_mask, g, reverse=True)
x_recon = x_recon_
ret['mel_out'] = x_recon.transpose(1, 2)