# Copyright (c) Meta Platforms, Inc. and affiliates. # # Contents of this file were adapted from the open source fairseq repository. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import math from typing import Any, Dict, List, Optional import torch import torch.nn as nn from torch import Tensor from esm.modules import SinusoidalPositionalEmbedding from .transformer_layer import TransformerDecoderLayer def fill_with_neg_inf(t): """FP16-compatible function that fills a tensor with -inf.""" return t.float().fill_(float("-inf")).type_as(t) class TransformerDecoder(nn.Module): """ Transformer decoder consisting of *args.decoder.layers* layers. Each layer is a :class:`TransformerDecoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): decoding dictionary embed_tokens (torch.nn.Embedding): output embedding no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__( self, args, dictionary, embed_tokens, ): super().__init__() self.args = args self.dictionary = dictionary self._future_mask = torch.empty(0) self.dropout_module = nn.Dropout(args.dropout) input_embed_dim = embed_tokens.embedding_dim embed_dim = args.decoder_embed_dim self.embed_dim = embed_dim self.padding_idx = embed_tokens.padding_idx self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) self.project_in_dim = ( nn.Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None ) self.embed_positions = SinusoidalPositionalEmbedding( embed_dim, self.padding_idx, ) self.layers = nn.ModuleList([]) self.layers.extend( [ self.build_decoder_layer(args) for _ in range(args.decoder_layers) ] ) self.num_layers = len(self.layers) self.layer_norm = nn.LayerNorm(embed_dim) self.build_output_projection(args, dictionary) def build_output_projection(self, args, dictionary): self.output_projection = nn.Linear( args.decoder_embed_dim, len(dictionary), bias=False ) nn.init.normal_( self.output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5 ) def build_decoder_layer(self, args): return TransformerDecoderLayer(args) def forward( self, prev_output_tokens, encoder_out: Optional[Dict[str, List[Tensor]]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, features_only: bool = False, return_all_hiddens: bool = False, ): """ Args: prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for teacher forcing encoder_out (optional): output from the encoder, used for encoder-side attention, should be of size T x B x C incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` features_only (bool, optional): only return features without applying output layer (default: False). Returns: tuple: - the decoder's output of shape `(batch, tgt_len, vocab)` - a dictionary with any model-specific outputs """ x, extra = self.extract_features( prev_output_tokens, encoder_out=encoder_out, incremental_state=incremental_state, ) if not features_only: x = self.output_layer(x) x = x.transpose(1, 2) # B x T x C -> B x C x T return x, extra def extract_features( self, prev_output_tokens, encoder_out: Optional[Dict[str, List[Tensor]]], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, ): """ Similar to *forward* but only return features. Includes several features from "Jointly Learning to Align and Translate with Transformer Models" (Garg et al., EMNLP 2019). Returns: tuple: - the decoder's features of shape `(batch, tgt_len, embed_dim)` - a dictionary with any model-specific outputs """ bs, slen = prev_output_tokens.size() enc: Optional[Tensor] = None padding_mask: Optional[Tensor] = None if encoder_out is not None and len(encoder_out["encoder_out"]) > 0: enc = encoder_out["encoder_out"][0] assert ( enc.size()[1] == bs ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}" if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0: padding_mask = encoder_out["encoder_padding_mask"][0] # embed positions positions = self.embed_positions( prev_output_tokens ) if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] positions = positions[:, -1:] # embed tokens and positions x = self.embed_scale * self.embed_tokens(prev_output_tokens) if self.project_in_dim is not None: x = self.project_in_dim(x) x += positions x = self.dropout_module(x) # B x T x C -> T x B x C x = x.transpose(0, 1) self_attn_padding_mask: Optional[Tensor] = None if prev_output_tokens.eq(self.padding_idx).any(): self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) # decoder layers attn: Optional[Tensor] = None inner_states: List[Optional[Tensor]] = [x] for idx, layer in enumerate(self.layers): if incremental_state is None: self_attn_mask = self.buffered_future_mask(x) else: self_attn_mask = None x, layer_attn, _ = layer( x, enc, padding_mask, incremental_state, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, need_attn=False, need_head_weights=False, ) inner_states.append(x) if self.layer_norm is not None: x = self.layer_norm(x) # T x B x C -> B x C x T x = x.transpose(0, 1) return x, {"inner_states": inner_states} def output_layer(self, features): """Project features to the vocabulary size.""" return self.output_projection(features) def buffered_future_mask(self, tensor): dim = tensor.size(0) # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround. if ( self._future_mask.size(0) == 0 or (not self._future_mask.device == tensor.device) or self._future_mask.size(0) < dim ): self._future_mask = torch.triu( fill_with_neg_inf(torch.zeros([dim, dim])), 1 ) self._future_mask = self._future_mask.to(tensor) return self._future_mask[:dim, :dim]