Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import numpy as np | |
from utils import filter_bank_mean | |
from fast_decoders import DecodeFunc_Sp | |
from model_sea import Encoder_2 as Encoder_Code_2 | |
from override_decoder import OnmtDecoder_1 as OnmtDecoder | |
from onmt_modules.misc import sequence_mask | |
from onmt_modules.embeddings import PositionalEncoding | |
from onmt_modules.encoder_transformer import TransformerEncoder as OnmtEncoder | |
class Prenet(nn.Module):
    """Input prenet: a linear projection followed by a positional encoding.

    Both stages live in an ``nn.Sequential`` named ``make_prenet`` with
    submodules ``mlp`` and ``pe``. The names are kept as-is so state-dict
    keys stay ``make_prenet.mlp.*`` / ``make_prenet.pe.*``.

    Args:
        dim_input: Size of the last dimension of the incoming tensor.
        dim_output: Output size of the linear projection. It is also passed
            to ``PositionalEncoding`` as its model dimension.
        dropout: Forwarded to ``PositionalEncoding`` as its first argument.
        max_len: Forwarded to ``PositionalEncoding`` as its third argument.
            Presumably this is the longest sequence it can encode; confirm
            against ``onmt_modules.embeddings``. Defaults to 1600, the value
            that was previously hard-coded.
    """

    def __init__(self, dim_input, dim_output, dropout=0.1, max_len=1600):
        super().__init__()
        mlp = nn.Linear(dim_input, dim_output, bias=True)
        pe = PositionalEncoding(dropout, dim_output, max_len)
        self.make_prenet = nn.Sequential()
        self.make_prenet.add_module('mlp', mlp)
        self.make_prenet.add_module('pe', pe)
        # Not read anywhere in this class.
        # NOTE(review): presumably consumed by the onmt encoder/decoder that
        # receives this prenet in place of an embeddings object; confirm.
        self.word_padding_idx = 1

    def forward(self, source, step=None):
        """Project ``source`` with the linear layer, then apply ``pe``.

        Args:
            source: Input features whose last dimension is ``dim_input``.
            step: Passed only to the positional-encoding stage.
                Presumably this is the current step during incremental
                decoding, and ``None`` for whole-sequence processing; confirm.

        Returns:
            Output of the positional-encoding stage.
        """
        # Only the positional encoding accepts ``step``. Call the two named
        # stages directly instead of iterating the private ``_modules`` dict
        # and recounting its length on every pass.
        projected = self.make_prenet.mlp(source)
        return self.make_prenet.pe(projected, step=step)
class Decoder_Sp(nn.Module):
    """
    Speech Decoder

    Wraps an onmt decoder (fed through a ``Prenet`` over ``dim_freq``-sized
    inputs) and a linear postnet. The postnet emits ``dim_freq + 1``
    channels per step: channel 0 is returned as ``gate`` and the remaining
    ``dim_freq`` channels as ``spect``.
    """

    def __init__(self, hparams):
        super().__init__()
        # Settings kept on the module. They are not read in forward;
        # presumably they are used by external inference code.
        self.dim_freq = hparams.dim_freq
        self.max_decoder_steps = hparams.dec_steps_sp
        self.gate_threshold = hparams.gate_threshold

        # The prenet is built before the decoder, as in the argument
        # evaluation order here, so parameter initialisation order is unchanged.
        self.decoder = OnmtDecoder.from_opt(
            hparams, Prenet(hparams.dim_freq, hparams.dec_rnn_size))
        self.postnet = nn.Linear(hparams.dec_rnn_size,
                                 hparams.dim_freq + 1,
                                 bias=True)

    def forward(self, tgt, tgt_lengths, memory_bank, memory_lengths):
        """Decode ``tgt`` against ``memory_bank``.

        Returns:
            ``(spect, gate)``. Both are slices of the postnet output along
            its last axis: ``spect`` is channels ``1:`` and ``gate`` is
            channel ``0`` (kept as a size-1 axis). The output is indexed as
            rank 3.
        """
        decoded, _attn = self.decoder(tgt, memory_bank, step=None,
                                      memory_lengths=memory_lengths,
                                      tgt_lengths=tgt_lengths)
        projected = self.postnet(decoded)
        gate = projected[:, :, :1]
        spect = projected[:, :, 1:]
        return spect, gate
class Encoder_Tx_Spk(nn.Module):
    """
    Text Encoder

    Concatenates a speaker embedding onto every step of the input, then runs
    an onmt encoder. That encoder is fed by a ``Prenet`` over
    ``dim_code + dim_spk`` features.
    """

    def __init__(self, hparams):
        super().__init__()
        in_dim = hparams.dim_code + hparams.dim_spk
        self.encoder = OnmtEncoder.from_opt(
            hparams, Prenet(in_dim, hparams.enc_rnn_size))

    def forward(self, src, src_lengths, spk_emb):
        """Encode ``src`` conditioned on ``spk_emb``.

        ``spk_emb`` gets a new leading axis and is broadcast to
        ``src.size(0)`` steps before being joined to ``src`` on the last
        axis.
        NOTE(review): this implies ``src`` is time-major
        (steps, batch, dim_code) and ``spk_emb`` is (batch, dim_spk); confirm.

        Returns:
            The three values produced by the onmt encoder:
            ``(enc_states, memory_bank, src_lengths)``.
        """
        n_steps = src.size(0)
        spk_tiled = spk_emb.unsqueeze(0).expand(n_steps, -1, -1)
        joined = torch.cat((src, spk_tiled), dim=-1)
        states, memory, lengths = self.encoder(joined, src_lengths)
        return states, memory, lengths
class Decoder_Tx(nn.Module):
    """
    Text Decoder with stop
    and num_rep prediction

    An onmt decoder (fed through a ``Prenet`` over ``dim_code``-sized
    inputs) with two linear heads:

    * ``postnet_1`` emits ``dim_code + 1`` channels. Channel 0 is returned
      as ``gate`` and the rest as ``text``.
    * ``postnet_2`` emits ``dim_rep`` channels, returned as ``rep``.
    """

    def __init__(self, hparams):
        super().__init__()
        self.dim_code = hparams.dim_code
        self.max_decoder_steps = hparams.dec_steps_tx
        self.gate_threshold = hparams.gate_threshold
        self.dim_rep = hparams.dim_rep

        hidden = hparams.dec_rnn_size
        # Construction order (prenet, decoder, head 1, head 2) is preserved.
        self.decoder = OnmtDecoder.from_opt(
            hparams, Prenet(hparams.dim_code, hidden))
        self.postnet_1 = nn.Linear(hidden, hparams.dim_code + 1, bias=True)
        self.postnet_2 = nn.Linear(hidden, self.dim_rep, bias=True)

    def forward(self, tgt, tgt_lengths, memory_bank, memory_lengths):
        """Decode ``tgt`` against ``memory_bank``.

        Returns:
            ``(text, gate, rep)``, where ``text``/``gate`` are the last-axis
            split of ``postnet_1``'s output and ``rep`` is ``postnet_2``'s
            output.
        """
        decoded, _attn = self.decoder(tgt, memory_bank, step=None,
                                      memory_lengths=memory_lengths,
                                      tgt_lengths=tgt_lengths)
        head_1 = self.postnet_1(decoded)
        rep = self.postnet_2(decoded)
        text = head_1[:, :, 1:]
        gate = head_1[:, :, :1]
        return text, gate, rep
class Generator_1(nn.Module):
    '''
    sync stage 1

    Pipeline: code encoder -> segment pooling via ``filter_bank_mean`` ->
    re-expansion to the spectrogram time axis -> speaker-conditioned
    encoder -> speech decoder. A projected speaker vector is prepended to
    the encoder memory as one extra time step.
    '''

    def __init__(self, hparams):
        super().__init__()
        self.encoder_cd = Encoder_Code_2(hparams)
        self.encoder_tx = Encoder_Tx_Spk(hparams)
        self.decoder_sp = Decoder_Sp(hparams)
        self.encoder_spk = nn.Linear(hparams.dim_spk,
                                     hparams.enc_rnn_size,
                                     bias=True)
        self.fast_dec_sp = DecodeFunc_Sp(hparams, 'Sp')

    def pad_sequences_rnn(self, cd_short, num_rep, len_long):
        """Repeat segment codes back to frame rate and zero-pad the batch.

        Row ``b`` receives ``cd_short[b].repeat_interleave(num_rep[b])`` in
        its first ``len_long[b] - 1`` steps. The rest stays zero, up to
        ``len_long.max()`` steps. The repeated length must match
        ``len_long[b] - 1`` for the assignment to succeed.

        An earlier inline note said ``len_long = len_spect + 1``, while
        ``forward`` passes its ``len_spect`` argument here.
        NOTE(review): confirm which length convention callers use.

        Args:
            cd_short: Tensor unpacked as (batch, segments, channels).
            num_rep: Per-row repeat counts for ``repeat_interleave``.
            len_long: Per-row lengths; ``.max()`` sizes the output.

        Returns:
            A ``torch.zeros`` (default dtype) tensor of shape
            (batch, len_long.max(), channels) on ``cd_short``'s device.
        """
        batch_size, _, channels = cd_short.size()
        padded = torch.zeros((batch_size, len_long.max(), channels),
                             device=cd_short.device)
        for b in range(batch_size):
            expanded = cd_short[b].repeat_interleave(num_rep[b], dim=0)
            padded[b, :len_long[b] - 1, :] = expanded
        return padded

    def forward(self, cep_in, mask_long, codes_mask, num_rep, len_short,
                tgt_spect, len_spect,
                spk_emb):
        """Teacher-forced training pass; returns ``(spect_out, gate_sp_out)``.

        ``len_short`` is accepted for interface parity but is not used here.
        """
        cd_long = self.encoder_cd(cep_in, mask_long)
        # Pool frame codes into segments. The pooling matrix is detached, so
        # no gradient flows into it.
        # NOTE(review): filter_bank_mean is presumably a per-segment mean;
        # confirm in utils.
        pool = filter_bank_mean(num_rep, codes_mask, cd_long.size(1))
        cd_short = torch.bmm(pool.detach(), cd_long)
        cd_sync = self.pad_sequences_rnn(cd_short, num_rep, len_spect)
        spk_token = self.encoder_spk(spk_emb)

        # text to speech
        _, memory_tx, _ = self.encoder_tx(cd_sync.transpose(1, 0),
                                          len_spect, spk_emb)
        # Prepend the speaker token as one extra memory step, hence the +1
        # on the memory lengths below.
        memory = torch.cat((spk_token.unsqueeze(0), memory_tx), dim=0)
        self.decoder_sp.decoder.init_state(memory, None, None)
        spect_out, gate_sp_out = self.decoder_sp(
            tgt_spect, len_spect, memory, len_spect + 1)
        return spect_out, gate_sp_out

    def infer_onmt(self, cep_in, mask_long,
                   len_spect,
                   spk_emb):
        """Free-running inference through ``fast_dec_sp.infer``.

        Encodes frame-level codes directly (no segment pooling). Returns
        ``(spect_output, len_spect_out)``; the third value from ``infer``
        is discarded.
        """
        cd_long = self.encoder_cd(cep_in, mask_long)
        spk_token = self.encoder_spk(spk_emb)
        # text to speech
        _, memory_tx, _ = self.encoder_tx(cd_long.transpose(1, 0),
                                          len_spect, spk_emb)
        memory = torch.cat((spk_token.unsqueeze(0), memory_tx), dim=0)
        self.decoder_sp.decoder.init_state(memory, None, None)
        spect_output, len_spect_out, _ = self.fast_dec_sp.infer(
            None, memory, len_spect + 1,
            self.decoder_sp.decoder,
            self.decoder_sp.postnet)
        return spect_output, len_spect_out
class Generator_2(nn.Module):
    '''
    async stage 2

    Same sub-modules as stage 1. The training pass instead feeds the
    segment-level codes directly to the speaker-conditioned encoder, without
    re-expanding them to frame rate.
    '''

    def __init__(self, hparams):
        super().__init__()
        self.encoder_cd = Encoder_Code_2(hparams)
        self.encoder_tx = Encoder_Tx_Spk(hparams)
        self.decoder_sp = Decoder_Sp(hparams)
        self.encoder_spk = nn.Linear(hparams.dim_spk,
                                     hparams.enc_rnn_size,
                                     bias=True)
        self.fast_dec_sp = DecodeFunc_Sp(hparams, 'Sp')

    def forward(self, cep_in, mask_long, codes_mask, num_rep, len_short,
                tgt_spect, len_spect,
                spk_emb):
        """Teacher-forced training pass; returns ``(spect_out, gate_sp_out)``.

        Both the pooling matrix and ``cd_long`` are detached. As a result,
        ``encoder_cd`` receives no gradient from this pass's outputs.
        """
        cd_long = self.encoder_cd(cep_in, mask_long)
        pool = filter_bank_mean(num_rep, codes_mask, cd_long.size(1))
        cd_short = torch.bmm(pool.detach(), cd_long.detach())
        spk_token = self.encoder_spk(spk_emb)

        # text to speech
        _, memory_tx, _ = self.encoder_tx(cd_short.transpose(1, 0),
                                          len_short, spk_emb)
        # One extra memory step for the prepended speaker token.
        memory = torch.cat((spk_token.unsqueeze(0), memory_tx), dim=0)
        self.decoder_sp.decoder.init_state(memory, None, None)
        spect_out, gate_sp_out = self.decoder_sp(
            tgt_spect, len_spect, memory, len_short + 1)
        return spect_out, gate_sp_out

    def infer_onmt(self, cep_in, mask_long, len_spect,
                   spk_emb):
        """Free-running inference through ``fast_dec_sp.infer``.

        NOTE(review): this feeds frame-level ``cd_long`` with ``len_spect``
        lengths, whereas ``forward`` feeds segment-pooled codes with
        ``len_short``. Confirm this asymmetry is intended.

        Returns:
            ``(spect_output, len_spect_out)``; the third value from
            ``infer`` is discarded.
        """
        cd_long = self.encoder_cd(cep_in, mask_long)
        spk_token = self.encoder_spk(spk_emb)
        # text to speech
        _, memory_tx, _ = self.encoder_tx(cd_long.transpose(1, 0),
                                          len_spect, spk_emb)
        memory = torch.cat((spk_token.unsqueeze(0), memory_tx), dim=0)
        self.decoder_sp.decoder.init_state(memory, None, None)
        spect_output, len_spect_out, _ = self.fast_dec_sp.infer(
            None, memory, len_spect + 1,
            self.decoder_sp.decoder,
            self.decoder_sp.postnet)
        return spect_output, len_spect_out