# encoding: utf-8
"""Class Declaration of Transformer's Decoder."""
import chainer
import chainer.functions as F
import chainer.links as L
import numpy as np

from espnet.nets.chainer_backend.transformer.decoder_layer import DecoderLayer
from espnet.nets.chainer_backend.transformer.embedding import PositionalEncoding
from espnet.nets.chainer_backend.transformer.layer_norm import LayerNorm
from espnet.nets.chainer_backend.transformer.mask import make_history_mask


class Decoder(chainer.Chain):
"""Decoder layer.
Args:
odim (int): The output dimension.
n_layers (int): Number of ecoder layers.
n_units (int): Number of attention units.
d_units (int): Dimension of input vector of decoder.
h (int): Number of attention heads.
dropout (float): Dropout rate.
initialW (Initializer): Initializer to initialize the weight.
initial_bias (Initializer): Initializer to initialize teh bias.
"""
def __init__(self, odim, args, initialW=None, initial_bias=None):
"""Initialize Decoder."""
super(Decoder, self).__init__()
self.sos = odim - 1
self.eos = odim - 1
initialW = chainer.initializers.Uniform if initialW is None else initialW
initial_bias = (
chainer.initializers.Uniform if initial_bias is None else initial_bias
)
with self.init_scope():
self.output_norm = LayerNorm(args.adim)
self.pe = PositionalEncoding(args.adim, args.dropout_rate)
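            # Linear layers are initialized uniformly in [-stvd, stvd] with
            # stvd = 1/sqrt(adim); chainer.initializers.Uniform(scale) draws
            # from [-scale, scale].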
stvd = 1.0 / np.sqrt(args.adim)
self.output_layer = L.Linear(
args.adim,
odim,
initialW=initialW(scale=stvd),
initial_bias=initial_bias(scale=stvd),
)
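            # Token embedding; ids equal to ignore_label (-1, i.e. padding)
            # are embedded as zero vectors.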
self.embed = L.EmbedID(
odim,
args.adim,
ignore_label=-1,
initialW=chainer.initializers.Normal(scale=1.0),
)
for i in range(args.dlayers):
name = "decoders." + str(i)
layer = DecoderLayer(
args.adim,
d_units=args.dunits,
h=args.aheads,
dropout=args.dropout_rate,
initialW=initialW,
initial_bias=initial_bias,
)
self.add_link(name, layer)
        self.n_layers = args.dlayers

    def make_attention_mask(self, source_block, target_block):
"""Prepare the attention mask.
Args:
source_block (ndarray): Source block with dimensions: (B x S).
target_block (ndarray): Target block with dimensions: (B x T).
Returns:
ndarray: Mask with dimensions (B, S, T).
"""
mask = (target_block[:, None, :] >= 0) * (source_block[:, :, None] >= 0)
# (batch, source_length, target_length)
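        # Illustrative sketch (hypothetical values, padding id -1):
        #   source_block = [[3, 5, -1]]   # (B=1, S=3)
        #   target_block = [[7, -1]]      # (B=1, T=2)
        #   mask[0]      = [[1, 0],
        #                   [1, 0],
        #                   [0, 0]]       # (S=3, T=2): valid source x valid target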
        return mask

    def forward(self, ys_pad, source, x_mask):
"""Forward decoder.
:param xp.array e: input token ids, int64 (batch, maxlen_out)
:param xp.array yy_mask: input token mask, uint8 (batch, maxlen_out)
:param xp.array source: encoded memory, float32 (batch, maxlen_in, feat)
:param xp.array xy_mask: encoded memory mask, uint8 (batch, maxlen_in)
:return e: decoded token score before softmax (batch, maxlen_out, token)
:rtype: chainer.Variable
"""
xp = self.xp
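        # Prepend <sos> to every target sequence; pad the batch with <eos>.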
sos = np.array([self.sos], np.int32)
ys = [np.concatenate([sos, y], axis=0) for y in ys_pad]
e = F.pad_sequence(ys, padding=self.eos).data
e = xp.array(e)
# mask preparation
xy_mask = self.make_attention_mask(e, xp.array(x_mask))
yy_mask = self.make_attention_mask(e, e)
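        # Combine with a causal (history) mask: each position may attend only
        # to itself and earlier target positions.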
yy_mask *= make_history_mask(xp, e)
e = self.pe(self.embed(e))
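        # Flatten (batch, length, dims) to (batch * length, dims); the decoder
        # layers take 2-D arrays and receive the batch size as a separate
        # argument.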
batch, length, dims = e.shape
e = e.reshape(-1, dims)
source = source.reshape(-1, dims)
for i in range(self.n_layers):
e = self["decoders." + str(i)](e, source, xy_mask, yy_mask, batch)
        return self.output_layer(self.output_norm(e)).reshape(batch, length, -1)

    def recognize(self, e, yy_mask, source):
"""Process recognition function."""
e = self.forward(e, source, yy_mask)
return F.log_softmax(e, axis=-1)
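

# Minimal usage sketch (hypothetical values; in ESPnet, ``args`` is the parsed
# training configuration):
#
#     from argparse import Namespace
#     args = Namespace(adim=256, dunits=2048, aheads=4, dlayers=6,
#                      dropout_rate=0.1)
#     decoder = Decoder(odim=52, args=args)
#     scores = decoder.forward(ys_pad, encoded, x_mask)  # (batch, maxlen_out, 52)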