#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Encoder self-attention layer definition."""

import copy

import torch
from torch import nn

from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm


class EncoderLayer(nn.Module):
    """Encoder layer module.

    The layer applies, in order: an optional macaron-style feed-forward
    module, multi-headed self-attention, an optional convolution module and a
    feed-forward module, each wrapped in a residual connection with layer
    normalization (pre- or post-norm depending on ``normalize_before``).

    :param int size: input dim
    :param espnet.nets.pytorch_backend.transformer.attention.
        MultiHeadedAttention self_attn: self attention module
        RelPositionMultiHeadedAttention self_attn: self attention module
    :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
        PositionwiseFeedForward feed_forward: feed forward module
    :param espnet.nets.pytorch_backend.transformer.convolution.
        ConvolutionModule conv_module: convolution module, or None to skip
        the convolution block (and the final layer norm)
    :param float dropout_rate: dropout rate
    :param bool normalize_before: whether to use layer_norm before the first block
    :param bool concat_after: whether to concat attention layer's input and output
        if True, additional linear will be applied.
        i.e. x -> x + linear(concat(x, att(x)))
        if False, no additional linear will be applied. i.e. x -> x + att(x)
    :param bool macaron_style: whether to use macaron style for PositionwiseFeedForward.
        If True, a deep copy of ``feed_forward`` is applied before the
        attention module and both feed-forward outputs are scaled by 0.5.
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        conv_module,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        macaron_style=False,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.ff_scale = 1.0
        self.conv_module = conv_module
        self.macaron_style = macaron_style
        self.norm_ff = LayerNorm(size)  # for the FFN module
        self.norm_mha = LayerNorm(size)  # for the MHA module
        if self.macaron_style:
            # Independent parameters for the second (leading) FFN.
            self.feed_forward_macaron = copy.deepcopy(feed_forward)
            # Each FFN contributes a half-step residual in macaron style.
            self.ff_scale = 0.5
            # for another FFN module in macaron style
            self.norm_ff_macaron = LayerNorm(size)
        if self.conv_module is not None:
            self.norm_conv = LayerNorm(size)  # for the CNN module
            self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        :param x_input: encoded source features (batch, max_time_in, size),
            or a tuple ``(x, pos_emb)``; when ``pos_emb`` is given it is passed
            to ``self_attn`` as an extra argument and returned unchanged.
        :type x_input: torch.Tensor or Tuple[torch.Tensor, torch.Tensor]
        :param torch.Tensor mask: mask for x, or None. When ``cache`` is given
            it is sliced as ``mask[:, -1:, :]``, so it must be 3-D
            (presumably (batch, 1, max_time_in) or
            (batch, max_time_in, max_time_in) -- confirm against callers).
        :param torch.Tensor cache: cache for x (batch, max_time_in - 1, size).
            When given, only the last frame of x is used as the attention
            query, and the layer output for that frame is appended to
            ``cache``.
        :return: ``((x, pos_emb), mask)`` when a positional embedding was
            passed in, otherwise ``(x, mask)``. ``x`` has shape
            (batch, max_time_in, size). In cache mode ``mask`` is the sliced
            mask.
        :rtype: Tuple
        """
        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None

        # whether to use macaron style (leading half-step feed-forward module)
        if self.macaron_style:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)

        if cache is None:
            x_q = x
        else:
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            # Incremental mode: only the newest frame is the query; the full
            # x is still passed as the other two attention inputs.
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if pos_emb is not None:
            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)

        if self.concat_after:
            # Concatenate with the query, not the full x: in cache mode x has
            # max_time_in frames while x_att and residual have only one.
            x_concat = torch.cat((x_q, x_att), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        if cache is not None:
            # Re-attach previously computed frames to the new last frame.
            x = torch.cat([cache, x], dim=1)

        if pos_emb is not None:
            return (x, pos_emb), mask
        return x, mask