# tts-rvc-autopst / model_autopst.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from utils import filter_bank_mean
from fast_decoders import DecodeFunc_Sp
from model_sea import Encoder_2 as Encoder_Code_2
from override_decoder import OnmtDecoder_1 as OnmtDecoder
from onmt_modules.misc import sequence_mask
from onmt_modules.embeddings import PositionalEncoding
from onmt_modules.encoder_transformer import TransformerEncoder as OnmtEncoder


class Prenet(nn.Module):
    """
    Linear projection followed by positional encoding, used as the
    prenet of the OpenNMT-style encoder/decoder wrappers below.
    """
    def __init__(self, dim_input, dim_output, dropout=0.1):
        super().__init__()

        mlp = nn.Linear(dim_input, dim_output, bias=True)
        pe = PositionalEncoding(dropout, dim_output, 1600)

        self.make_prenet = nn.Sequential()
        self.make_prenet.add_module('mlp', mlp)
        self.make_prenet.add_module('pe', pe)

        # Expected by the OpenNMT decoder, which treats the prenet as its embeddings module.
        self.word_padding_idx = 1

    def forward(self, source, step=None):
        # Only the last module (the positional encoding) receives the decoding step.
        for i, module in enumerate(self.make_prenet._modules.values()):
            if i == len(self.make_prenet._modules.values()) - 1:
                source = module(source, step=step)
            else:
                source = module(source)

        return source
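
# Illustrative shapes (assuming the OpenNMT (time, batch, channels) layout used by
# the wrappers below): a (T, B, dim_freq) spectrogram passed through
# Prenet(dim_freq, dec_rnn_size) comes out as (T, B, dec_rnn_size) with positional
# encodings added; the positional table built above covers up to 1600 steps.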


class Decoder_Sp(nn.Module):
    """
    Speech decoder: autoregressive transformer decoder over spectrogram
    frames with a joint stop-gate prediction.
    """
    def __init__(self, hparams):
        super().__init__()

        self.dim_freq = hparams.dim_freq
        self.max_decoder_steps = hparams.dec_steps_sp
        self.gate_threshold = hparams.gate_threshold

        prenet = Prenet(hparams.dim_freq, hparams.dec_rnn_size)
        self.decoder = OnmtDecoder.from_opt(hparams, prenet)

        # Projects decoder states to dim_freq spectrogram bins plus one gate logit.
        self.postnet = nn.Linear(hparams.dec_rnn_size,
                                 hparams.dim_freq + 1, bias=True)

    def forward(self, tgt, tgt_lengths, memory_bank, memory_lengths):
        dec_outs, attns = self.decoder(tgt, memory_bank, step=None,
                                       memory_lengths=memory_lengths,
                                       tgt_lengths=tgt_lengths)
        spect_gate = self.postnet(dec_outs)

        # Channel 0 is the stop gate; the remaining dim_freq channels are the frame.
        spect, gate = spect_gate[:, :, 1:], spect_gate[:, :, :1]

        return spect, gate


class Encoder_Tx_Spk(nn.Module):
    """
    Text (code) encoder conditioned on a speaker embedding.
    """
    def __init__(self, hparams):
        super().__init__()

        prenet = Prenet(hparams.dim_code + hparams.dim_spk,
                        hparams.enc_rnn_size)
        self.encoder = OnmtEncoder.from_opt(hparams, prenet)

    def forward(self, src, src_lengths, spk_emb):
        # Broadcast the speaker embedding across time and concatenate it to
        # every input frame (src is in (time, batch, channels) layout).
        spk_emb = spk_emb.unsqueeze(0).expand(src.size(0), -1, -1)
        src_spk = torch.cat((src, spk_emb), dim=-1)

        enc_states, memory_bank, src_lengths = self.encoder(src_spk, src_lengths)

        return enc_states, memory_bank, src_lengths


class Decoder_Tx(nn.Module):
    """
    Text decoder with stop-gate and num_rep (repetition count) prediction.
    """
    def __init__(self, hparams):
        super().__init__()

        self.dim_code = hparams.dim_code
        self.max_decoder_steps = hparams.dec_steps_tx
        self.gate_threshold = hparams.gate_threshold
        self.dim_rep = hparams.dim_rep

        prenet = Prenet(hparams.dim_code, hparams.dec_rnn_size)
        self.decoder = OnmtDecoder.from_opt(hparams, prenet)

        # postnet_1 predicts the stop gate plus the code frame;
        # postnet_2 produces a dim_rep-way repetition-count prediction.
        self.postnet_1 = nn.Linear(hparams.dec_rnn_size,
                                   hparams.dim_code + 1, bias=True)
        self.postnet_2 = nn.Linear(hparams.dec_rnn_size,
                                   self.dim_rep, bias=True)

    def forward(self, tgt, tgt_lengths, memory_bank, memory_lengths):
        dec_outs, attns = self.decoder(tgt, memory_bank, step=None,
                                       memory_lengths=memory_lengths,
                                       tgt_lengths=tgt_lengths)
        gate_text = self.postnet_1(dec_outs)
        rep = self.postnet_2(dec_outs)

        # Channel 0 is the stop gate; the remaining dim_code channels are the code.
        gate, text = gate_text[:, :, :1], gate_text[:, :, 1:]

        return text, gate, rep


class Generator_1(nn.Module):
    """
    Sync stage 1: spectrogram resynthesis from codes that are resampled
    back to the original frame rate.
    """
    def __init__(self, hparams):
        super().__init__()

        self.encoder_cd = Encoder_Code_2(hparams)
        self.encoder_tx = Encoder_Tx_Spk(hparams)
        self.decoder_sp = Decoder_Sp(hparams)
        self.encoder_spk = nn.Linear(hparams.dim_spk,
                                     hparams.enc_rnn_size, bias=True)
        self.fast_dec_sp = DecodeFunc_Sp(hparams, 'Sp')

    def pad_sequences_rnn(self, cd_short, num_rep, len_long):
        # Expand each short code by its repetition count and zero-pad to the
        # batch maximum; len_long = len_spect + 1, so the final frame stays zero.
        B, L, C = cd_short.size()
        out_tensor = torch.zeros((B, len_long.max(), C), device=cd_short.device)

        for i in range(B):
            code_sync = cd_short[i].repeat_interleave(num_rep[i], dim=0)
            out_tensor[i, :len_long[i]-1, :] = code_sync

        return out_tensor
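
    # Illustrative example: with 3 short codes and num_rep[i] = [2, 1, 3],
    # repeat_interleave produces 2 + 1 + 3 = 6 synced frames; if len_long[i] is 7,
    # they fill positions 0..5 of out_tensor[i] and position 6 stays zero-padded.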

    def forward(self, cep_in, mask_long, codes_mask, num_rep, len_short,
                tgt_spect, len_spect,
                spk_emb):
        # Frame-level codes from the content encoder.
        cd_long = self.encoder_cd(cep_in, mask_long)

        # Average groups of frames into short codes, then expand them back to
        # the spectrogram frame rate with the given repetition counts.
        fb = filter_bank_mean(num_rep, codes_mask, cd_long.size(1))
        cd_short = torch.bmm(fb.detach(), cd_long)
        cd_short_sync = self.pad_sequences_rnn(cd_short, num_rep, len_spect)

        spk_emb_1 = self.encoder_spk(spk_emb)

        # text to speech
        _, memory_tx, _ = self.encoder_tx(cd_short_sync.transpose(1, 0), len_spect,
                                          spk_emb)
        # Prepend the projected speaker embedding as an extra memory slot
        # (hence the +1 on the memory lengths below).
        memory_tx_spk = torch.cat((spk_emb_1.unsqueeze(0), memory_tx), dim=0)

        self.decoder_sp.decoder.init_state(memory_tx_spk, None, None)
        spect_out, gate_sp_out \
            = self.decoder_sp(tgt_spect, len_spect, memory_tx_spk, len_spect+1)

        return spect_out, gate_sp_out

    def infer_onmt(self, cep_in, mask_long,
                   len_spect,
                   spk_emb):
        # Inference path: the frame-level codes are used directly (no resampling).
        cd_long = self.encoder_cd(cep_in, mask_long)
        spk_emb_1 = self.encoder_spk(spk_emb)

        # text to speech
        _, memory_tx, _ = self.encoder_tx(cd_long.transpose(1, 0), len_spect,
                                          spk_emb)
        memory_tx_spk = torch.cat((spk_emb_1.unsqueeze(0), memory_tx), dim=0)

        self.decoder_sp.decoder.init_state(memory_tx_spk, None, None)
        spect_output, len_spect_out, stop_sp_output \
            = self.fast_dec_sp.infer(None, memory_tx_spk, len_spect+1,
                                     self.decoder_sp.decoder,
                                     self.decoder_sp.postnet)

        return spect_output, len_spect_out
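
# Note: Generator_1.forward trains with teacher forcing on codes that have been
# averaged into short codes and then expanded back to the frame rate, whereas
# infer_onmt feeds the frame-level codes to the text encoder directly and
# decodes autoregressively through fast_dec_sp, returning the predicted
# spectrogram and its length.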


class Generator_2(nn.Module):
    """
    Async stage 2: spectrogram synthesis directly from the short
    (non-synced) code sequence.
    """
    def __init__(self, hparams):
        super().__init__()

        self.encoder_cd = Encoder_Code_2(hparams)
        self.encoder_tx = Encoder_Tx_Spk(hparams)
        self.decoder_sp = Decoder_Sp(hparams)
        self.encoder_spk = nn.Linear(hparams.dim_spk,
                                     hparams.enc_rnn_size, bias=True)
        self.fast_dec_sp = DecodeFunc_Sp(hparams, 'Sp')

    def forward(self, cep_in, mask_long, codes_mask, num_rep, len_short,
                tgt_spect, len_spect,
                spk_emb):
        # Frame-level codes, averaged into short codes; both the filter bank and
        # the codes are detached, so no gradient flows back into the content
        # encoder through this path.
        cd_long = self.encoder_cd(cep_in, mask_long)
        fb = filter_bank_mean(num_rep, codes_mask, cd_long.size(1))
        cd_short = torch.bmm(fb.detach(), cd_long.detach())

        spk_emb_1 = self.encoder_spk(spk_emb)

        # text to speech
        _, memory_tx, _ = self.encoder_tx(cd_short.transpose(1, 0), len_short,
                                          spk_emb)
        memory_tx_spk = torch.cat((spk_emb_1.unsqueeze(0), memory_tx), dim=0)

        self.decoder_sp.decoder.init_state(memory_tx_spk, None, None)
        spect_out, gate_sp_out \
            = self.decoder_sp(tgt_spect, len_spect, memory_tx_spk, len_short+1)

        return spect_out, gate_sp_out

    def infer_onmt(self, cep_in, mask_long, len_spect,
                   spk_emb):
        # Inference path: the frame-level codes are used directly (no resampling).
        cd_long = self.encoder_cd(cep_in, mask_long)
        spk_emb_1 = self.encoder_spk(spk_emb)

        # text to speech
        _, memory_tx, _ = self.encoder_tx(cd_long.transpose(1, 0), len_spect,
                                          spk_emb)
        memory_tx_spk = torch.cat((spk_emb_1.unsqueeze(0), memory_tx), dim=0)

        self.decoder_sp.decoder.init_state(memory_tx_spk, None, None)
        spect_output, len_spect_out, stop_sp_output \
            = self.fast_dec_sp.infer(None, memory_tx_spk, len_spect+1,
                                     self.decoder_sp.decoder,
                                     self.decoder_sp.postnet)

        return spect_output, len_spect_out
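

if __name__ == '__main__':
    # Minimal, self-contained sketch of the resampling steps used above, with toy
    # tensors. `fb_toy` is a hand-built stand-in for the output of
    # utils.filter_bank_mean (assumed here, from how it is called above, to be a
    # (B, L_short, L_long) averaging matrix); it is not the real implementation.
    B, L_long, L_short, C = 1, 6, 3, 2

    cd_long = torch.arange(B * L_long * C, dtype=torch.float32).view(B, L_long, C)
    num_rep = torch.tensor([[2, 1, 3]])  # frames per short code: 2 + 1 + 3 = 6

    # Row j of the toy filter bank averages the j-th group of consecutive frames.
    fb_toy = torch.zeros(B, L_short, L_long)
    start = 0
    for j, n in enumerate(num_rep[0].tolist()):
        fb_toy[0, j, start:start + n] = 1.0 / n
        start += n

    # Collapse frame-level codes to short codes (cf. torch.bmm(fb, cd_long)).
    cd_short = torch.bmm(fb_toy, cd_long)                       # (1, 3, 2)

    # Expand back to the frame rate (cf. pad_sequences_rnn's repeat_interleave).
    cd_sync = cd_short[0].repeat_interleave(num_rep[0], dim=0)  # (6, 2)

    print(cd_short.shape, cd_sync.shape)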