#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Decoder self-attention layer definition."""

import torch
from torch import nn

from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm


class DecoderLayer(nn.Module):
    """Single decoder layer module.

    The layer runs three residual sub-blocks in order: self-attention over the
    target sequence, source attention over the encoder memory, and a
    feed-forward network.

    :param int size: input dim
    :param espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention
        self_attn: self attention module
    :param espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention
        src_attn: source attention module
    :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
        PositionwiseFeedForward feed_forward: feed forward layer module
    :param float dropout_rate: dropout rate
    :param bool normalize_before: if True, layer norm is applied to the input of
        each of the three sub-blocks (pre-norm); if False, it is applied to the
        output of each residual addition (post-norm)
    :param bool concat_after: whether to concat attention layer's input and output
        if True, additional linear will be applied.
        i.e. x -> x + linear(concat(x, att(x)))
        if False, no additional linear will be applied. i.e. x -> x + att(x)
        Note: dropout is applied to the attention output only when this is False.
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct an DecoderLayer object."""
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # One layer norm per sub-block: self-attention, source attention, FFN.
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            # Project concat(input, attention output) (2 * size) back to size.
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Compute decoded features.

        Args:
            tgt (torch.Tensor): decoded previous target features
                (batch, max_time_out, size)
            tgt_mask (torch.Tensor): self-attention mask for ``tgt``. It must be
                3-D because its last query row is sliced when ``cache`` is given;
                presumably (batch, max_time_out, max_time_out). None is
                tolerated on the incremental (``cache``) path.
            memory (torch.Tensor): encoded source features
                (batch, max_time_in, size)
            memory_mask (torch.Tensor): mask for ``memory``, passed unchanged to
                ``src_attn`` (presumably (batch, 1, max_time_in) -- confirm
                against the attention module)
            cache (torch.Tensor): this layer's output for the first
                max_time_out - 1 frames, (batch, max_time_out - 1, size). When
                given, only the last target frame is computed and appended.

        Returns:
            torch.Tensor: decoded features (batch, max_time_out, size)
            torch.Tensor: ``tgt_mask``, returned unchanged
            torch.Tensor: ``memory``, returned unchanged
            torch.Tensor: ``memory_mask``, returned unchanged

        """
        # --- Sub-block 1: target self-attention -------------------------------
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            expected_shape = (tgt.shape[0], tgt.shape[1] - 1, self.size)
            assert (
                cache.shape == expected_shape
            ), f"cache shape {tuple(cache.shape)} != expected {expected_shape}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                # Keep only the mask row belonging to the last query frame.
                tgt_q_mask = tgt_mask[:, -1:, :]

        # Keys/values are the full (possibly normalized) target sequence.
        if self.concat_after:
            tgt_concat = torch.cat(
                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
            )
            x = residual + self.concat_linear1(tgt_concat)
        else:
            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
        if not self.normalize_before:
            x = self.norm1(x)

        # --- Sub-block 2: source attention over encoder memory ----------------
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        if self.concat_after:
            x_concat = torch.cat(
                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
        if not self.normalize_before:
            x = self.norm2(x)

        # --- Sub-block 3: position-wise feed-forward --------------------------
        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        if cache is not None:
            # Append the newly decoded frame to the cached earlier frames.
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, memory, memory_mask