from typing import Optional, Tuple, final

from torch import Tensor
from torch.nn import Dropout, Module

from fairseq2.nn.incremental_state import IncrementalStateBag
from fairseq2.nn.normalization import LayerNorm
from fairseq2.nn.padding import PaddingMask
from fairseq2.nn.transformer import (
    AttentionMask,
    FeedForwardNetwork,
    MultiheadAttention,
    create_standard_layer_norm,
)
from fairseq2.typing import DataType, Device, finaloverride

from seamless_communication.models.monotonic_decoder.p_choose import PChooseLayer
|
|
@final
class MonotonicTransformerDecoderLayer(Module):
    """Represents a Monotonic Transformer decoder layer."""

    model_dim: int
    self_attn: MultiheadAttention
    self_attn_dropout: Optional[Dropout]
    self_attn_layer_norm: LayerNorm
    encoder_decoder_attn: MultiheadAttention
    encoder_decoder_attn_dropout: Optional[Dropout]
    encoder_decoder_attn_layer_norm: LayerNorm
    p_choose_layer: PChooseLayer
    ffn: FeedForwardNetwork
    ffn_dropout: Optional[Dropout]
    ffn_layer_norm: LayerNorm
|
    def __init__(
        self,
        self_attn: MultiheadAttention,
        encoder_decoder_attn: MultiheadAttention,
        p_choose_layer: PChooseLayer,
        ffn: FeedForwardNetwork,
        *,
        dropout_p: float = 0.1,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param self_attn:
            The self attention layer.
        :param encoder_decoder_attn:
            The encoder-decoder attention layer.
        :param p_choose_layer:
            The layer that computes the monotonic selection probabilities
            (``p_choose``).
        :param ffn:
            The feed-forward network.
        :param dropout_p:
            The dropout probability on outputs of the attention layers and the
            feed-forward network.
        """
        super().__init__()

        self.model_dim = self_attn.model_dim

        self.self_attn_layer_norm = create_standard_layer_norm(
            self.model_dim, device=device, dtype=dtype
        )

        self.self_attn = self_attn

        if dropout_p > 0.0:
            self.self_attn_dropout = Dropout(dropout_p)
        else:
            self.register_module("self_attn_dropout", None)

        self.encoder_decoder_attn_layer_norm = create_standard_layer_norm(
            self.model_dim, device=device, dtype=dtype
        )

        self.encoder_decoder_attn = encoder_decoder_attn

        if dropout_p > 0.0:
            self.encoder_decoder_attn_dropout = Dropout(dropout_p)
        else:
            self.register_module("encoder_decoder_attn_dropout", None)

        self.p_choose_layer = p_choose_layer

        self.ffn_layer_norm = create_standard_layer_norm(
            self.model_dim, device=device, dtype=dtype
        )

        self.ffn = ffn

        if dropout_p > 0.0:
            self.ffn_dropout = Dropout(dropout_p)
        else:
            self.register_module("ffn_dropout", None)
|
    @finaloverride
    def forward(
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        self_attn_mask: Optional[AttentionMask] = None,
        encoder_output: Optional[Tensor] = None,
        encoder_padding_mask: Optional[PaddingMask] = None,
        *,
        state_bag: Optional[IncrementalStateBag] = None,
    ) -> Tuple[Tensor, Optional[PaddingMask], Tensor]:
        """
        :param seqs:
            The sequences to process. *Shape:* :math:`(N,S,M)`.
        :param padding_mask:
            The padding mask of ``seqs``.
        :param self_attn_mask:
            The mask applied to the self attention weights.
        :param encoder_output:
            The encoder output used in encoder-decoder attention.
        :param encoder_padding_mask:
            The padding mask of ``encoder_output``.
        :param state_bag:
            The state bag to use for incremental decoding.

        :returns:
            The processed sequences, their padding mask, and the monotonic
            selection probabilities (``p_choose``).
        """
        seqs = self._forward_self_attn(seqs, padding_mask, self_attn_mask, state_bag)

        seqs, p_choose = self._forward_encoder_decoder_attn(
            seqs, padding_mask, encoder_output, encoder_padding_mask
        )

        seqs = self._forward_ffn(seqs)

        return seqs, padding_mask, p_choose
|
    def _forward_self_attn(
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        self_attn_mask: Optional[AttentionMask],
        state_bag: Optional[IncrementalStateBag],
    ) -> Tensor:
        residual = seqs

        # Pre-LN: normalize before attending, then add the residual back.
        seqs = self.self_attn_layer_norm(seqs)

        seqs = self.self_attn(
            seqs,
            padding_mask,
            keys=seqs,
            key_padding_mask=padding_mask,
            values=seqs,
            attn_mask=self_attn_mask,
            state_bag=state_bag,
        )

        if self.self_attn_dropout is not None:
            seqs = self.self_attn_dropout(seqs)

        seqs = seqs + residual

        return seqs
|
    def _forward_encoder_decoder_attn(
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        encoder_output: Optional[Tensor],
        encoder_padding_mask: Optional[PaddingMask],
    ) -> Tuple[Tensor, Tensor]:
        if encoder_output is None:
            raise ValueError(
                "`encoder_output` must not be `None` for encoder-decoder attention."
            )

        residual = seqs

        seqs = self.encoder_decoder_attn_layer_norm(seqs)

        # Compute the monotonic selection probabilities from the normalized
        # decoder states and the encoder output.
        p_choose = self.p_choose_layer(seqs, encoder_output)

        seqs = self.encoder_decoder_attn(
            seqs,
            padding_mask,
            keys=encoder_output,
            key_padding_mask=encoder_padding_mask,
            values=encoder_output,
        )

        if self.encoder_decoder_attn_dropout is not None:
            seqs = self.encoder_decoder_attn_dropout(seqs)

        seqs = seqs + residual

        return seqs, p_choose
|
    def _forward_ffn(self, seqs: Tensor) -> Tensor:
        residual = seqs

        seqs = self.ffn_layer_norm(seqs)

        seqs = self.ffn(seqs)

        if self.ffn_dropout is not None:
            seqs = self.ffn_dropout(seqs)

        seqs = seqs + residual

        return seqs
|
|
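# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, kept as comments so this module stays a
# plain library file). The sub-modules (`self_attn`, `encoder_decoder_attn`,
# `p_choose_layer`, `ffn`) are assumed to be constructed elsewhere, e.g. by
# the monotonic decoder builder in this package; their constructor arguments
# are not shown in this file. Shapes follow the fairseq2 convention
# `(batch, seq_len, model_dim)`.
#
#   layer = MonotonicTransformerDecoderLayer(
#       self_attn, encoder_decoder_attn, p_choose_layer, ffn, dropout_p=0.1
#   )
#
#   seqs, padding_mask, p_choose = layer(
#       seqs,                            # (N, S_tgt, M) decoder input
#       padding_mask,                    # padding mask of `seqs`, or `None`
#       self_attn_mask=self_attn_mask,   # e.g. a causal attention mask
#       encoder_output=encoder_output,   # (N, S_src, M)
#       encoder_padding_mask=encoder_padding_mask,
#   )
#
#   # `p_choose` holds the monotonic selection probabilities produced by
#   # `PChooseLayer`, which downstream code (e.g. a simultaneous decoding
#   # policy) can use for read/write decisions.
# ---------------------------------------------------------------------------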