# Source page metadata (scraped): author yuancwang, commit b725c5a ("init"), 6 kB.
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import numpy as np
from .Layers import FFTBlock
from text.symbols import symbols
# Special-token ids. PAD is used below as the padding index of the encoder's
# symbol embedding; UNK/BOS/EOS are not referenced in this file section.
PAD, UNK, BOS, EOS = 0, 1, 2, 3

# Surface strings for the special tokens, in the same order as the ids above.
PAD_WORD, UNK_WORD, BOS_WORD, EOS_WORD = "<blank>", "<unk>", "<s>", "</s>"
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    """Sinusoid position encoding table.

    Entry ``[pos, j]`` is ``sin(pos / 10000 ** (2 * (j // 2) / d_hid))`` for
    even ``j`` and ``cos`` of the same angle for odd ``j``.

    Args:
        n_position: Number of positions (rows) in the table.
        d_hid: Encoding dimension (columns); odd values are allowed.
        padding_idx: Optional row index whose encoding is set to all zeros.

    Returns:
        A float32 CPU ``torch.FloatTensor`` of shape ``(n_position, d_hid)``.
    """
    # Broadcast positions (column) against per-dimension frequencies (row)
    # instead of looping element by element in Python. Building a 2-D array
    # directly also means n_position == 0 yields a (0, d_hid) table rather
    # than a 1-D empty array that the column slicing below would reject.
    positions = np.arange(n_position, dtype=np.float64)[:, None]
    hid_idx = np.arange(d_hid)
    # Dimensions 2i and 2i+1 share the same frequency.
    denominators = np.power(10000, 2 * (hid_idx // 2) / d_hid)
    sinusoid_table = positions / denominators[None, :]

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    if padding_idx is not None:
        # zero vector for padding dimension
        sinusoid_table[padding_idx] = 0.0

    return torch.FloatTensor(sinusoid_table)
class Encoder(nn.Module):
    """Encoder: symbol embedding + sinusoidal positions + a stack of FFTBlocks.

    Hyper-parameters are read from ``config["max_seq_len"]`` and, under
    ``config["transformer"]``, from ``encoder_hidden``, ``encoder_layer``,
    ``encoder_head``, ``conv_filter_size``, ``conv_kernel_size`` and
    ``encoder_dropout``.
    """

    def __init__(self, config):
        super(Encoder, self).__init__()

        transformer = config["transformer"]
        # The precomputed position table has max_seq_len + 1 rows.
        n_position = config["max_seq_len"] + 1
        n_src_vocab = len(symbols) + 1
        d_word_vec = transformer["encoder_hidden"]
        n_layers = transformer["encoder_layer"]
        n_head = transformer["encoder_head"]
        # Per-head key/value size: the hidden size split across the heads.
        d_k = d_v = transformer["encoder_hidden"] // transformer["encoder_head"]
        d_model = transformer["encoder_hidden"]
        d_inner = transformer["conv_filter_size"]
        kernel_size = transformer["conv_kernel_size"]
        dropout = transformer["encoder_dropout"]

        self.max_seq_len = config["max_seq_len"]
        self.d_model = d_model

        # PAD is the embedding's padding index.
        self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=PAD)
        # Fixed sinusoid table, shape (1, max_seq_len + 1, encoder_hidden);
        # stored as a Parameter with requires_grad=False so it follows the
        # module's device/dtype but is never trained.
        self.position_enc = nn.Parameter(
            get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0),
            requires_grad=False,
        )
        self.layer_stack = nn.ModuleList(
            [
                FFTBlock(
                    d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout
                )
                for _ in range(n_layers)
            ]
        )

    def forward(self, src_seq, mask, return_attns=False):
        """Encode a batch of symbol-id sequences.

        Args:
            src_seq: Integer symbol ids indexed as ``(batch, seq_len)``.
            mask: 2-D mask, assumed ``(batch, seq_len)``, passed to every
                FFTBlock. NOTE(review): presumably ``True`` marks padding —
                confirm against FFTBlock.
            return_attns: Accepted for interface compatibility only; the
                per-layer attention maps are not returned.

        Returns:
            Encoder output of shape ``(batch, seq_len, encoder_hidden)``.
        """
        max_len = src_seq.shape[1]

        # -- Prepare masks: every query row uses the same key mask.
        slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)

        # -- Forward
        enc_output = self.src_word_emb(src_seq)
        if not self.training and max_len > self.max_seq_len:
            # Inference on a sequence longer than the precomputed table:
            # build a fresh table and match the activations' device AND dtype
            # (a float32 table would otherwise promote half-precision
            # activations to float32 before they reach the FFTBlocks).
            position = get_sinusoid_encoding_table(max_len, self.d_model)
            position = position.unsqueeze(0).to(
                device=enc_output.device, dtype=enc_output.dtype
            )
        else:
            position = self.position_enc[:, :max_len, :]
        # (1, seq_len, d) broadcasts over the batch; no explicit expand needed.
        enc_output = enc_output + position

        for enc_layer in self.layer_stack:
            enc_output, _ = enc_layer(
                enc_output, mask=mask, slf_attn_mask=slf_attn_mask
            )

        return enc_output
class Decoder(nn.Module):
    """Decoder: sinusoidal positions added to its input + a stack of FFTBlocks.

    Hyper-parameters are read from ``config["max_seq_len"]`` and, under
    ``config["transformer"]``, from ``decoder_hidden``, ``decoder_layer``,
    ``decoder_head``, ``conv_filter_size``, ``conv_kernel_size`` and
    ``decoder_dropout``.
    """

    def __init__(self, config):
        super(Decoder, self).__init__()

        transformer = config["transformer"]
        # The precomputed position table has max_seq_len + 1 rows.
        n_position = config["max_seq_len"] + 1
        d_word_vec = transformer["decoder_hidden"]
        n_layers = transformer["decoder_layer"]
        n_head = transformer["decoder_head"]
        # Per-head key/value size: the hidden size split across the heads.
        d_k = d_v = transformer["decoder_hidden"] // transformer["decoder_head"]
        d_model = transformer["decoder_hidden"]
        d_inner = transformer["conv_filter_size"]
        kernel_size = transformer["conv_kernel_size"]
        dropout = transformer["decoder_dropout"]

        self.max_seq_len = config["max_seq_len"]
        self.d_model = d_model

        # Fixed sinusoid table, shape (1, max_seq_len + 1, decoder_hidden);
        # stored as a Parameter with requires_grad=False so it follows the
        # module's device/dtype but is never trained.
        self.position_enc = nn.Parameter(
            get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0),
            requires_grad=False,
        )
        self.layer_stack = nn.ModuleList(
            [
                FFTBlock(
                    d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout
                )
                for _ in range(n_layers)
            ]
        )

    def forward(self, enc_seq, mask, return_attns=False):
        """Decode a batch of hidden sequences.

        Args:
            enc_seq: Input features indexed as ``(batch, seq_len, hidden)``,
                with ``hidden == decoder_hidden``.
            mask: 2-D mask, assumed ``(batch, seq_len)``, passed to every
                FFTBlock. NOTE(review): presumably ``True`` marks padding —
                confirm against FFTBlock.
            return_attns: Accepted for interface compatibility only; the
                per-layer attention maps are not returned.

        Returns:
            ``(dec_output, mask)``. Unless the module is in eval mode and the
            input is longer than ``max_seq_len``, both are truncated to the
            first ``max_seq_len`` positions.
        """
        max_len = enc_seq.shape[1]

        # -- Forward
        if not self.training and max_len > self.max_seq_len:
            # Inference on a sequence longer than the precomputed table:
            # build a fresh table and match the input's device AND dtype
            # (a float32 table would otherwise promote half-precision
            # inputs to float32 before they reach the FFTBlocks).
            position = get_sinusoid_encoding_table(max_len, self.d_model)
            position = position.unsqueeze(0).to(
                device=enc_seq.device, dtype=enc_seq.dtype
            )
        else:
            # Otherwise keep at most max_seq_len positions.
            max_len = min(max_len, self.max_seq_len)
            enc_seq = enc_seq[:, :max_len, :]
            position = self.position_enc[:, :max_len, :]

        # -- Prepare masks (sliced to the processed length in both branches)
        mask = mask[:, :max_len]
        slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)

        # (1, seq_len, d) broadcasts over the batch; no explicit expand needed.
        dec_output = enc_seq + position

        for dec_layer in self.layer_stack:
            dec_output, _ = dec_layer(
                dec_output, mask=mask, slf_attn_mask=slf_attn_mask
            )

        return dec_output, mask