# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import FairseqIncrementalDecoder
from fairseq.modules import (
    FairseqDropout,
    LayerDropModuleList,
    LayerNorm,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from torch import Tensor

from .encoder import RelativePositionalEncoding
from .transformer_layer import TransformerDecoderLayer

DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)


class TransformerDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to skip cross-attention to
            encoder outputs (default: False).
    """

    def __init__(
        self,
        args,
        no_encoder_attn=False,
    ):
        self.args = args
        super().__init__(None)
        self.register_buffer("version", torch.Tensor([3]))
        self._future_mask = torch.empty(0)

        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.decoder_layerdrop = args.decoder_layerdrop
        # self.max_s_positions = args.max_target_positions
        export = getattr(args, "export", False)

        self.cross_self_attention = getattr(args, "cross_self_attention", False)

        if self.decoder_layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
        else:
            self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                self.build_decoder_layer(args, no_encoder_attn)
                for _ in range(args.decoder_layers)
            ]
        )
        self.num_layers = len(self.layers)

        if args.decoder_normalize_before and not getattr(
            args, "no_decoder_final_norm", False
        ):
            self.layer_norm = LayerNorm(
                args.decoder_embed_dim, eps=args.layer_norm_eps, export=export
            )
        else:
            self.layer_norm = None

        if args.relative_position_embedding:
            self.pos_emb = RelativePositionalEncoding(
                args.encoder_embed_dim // args.encoder_attention_heads,
                args.decoder_max_relative_position,
            )

    def build_decoder_layer(self, args, no_encoder_attn=False):
        layer = TransformerDecoderLayer(
            args,
            no_encoder_attn=no_encoder_attn,
            has_relative_attention_bias=args.relative_position_embedding,
        )
        checkpoint = getattr(args, "checkpoint_activations", False)
        if checkpoint:
            offload_to_cpu = getattr(args, "offload_activations", False)
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
        # if we are checkpointing, enforce that FSDP always wraps the
        # checkpointed layer, regardless of layer size
        min_params_to_wrap = (
            getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP)
            if not checkpoint
            else 0
        )
        layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
        return layer
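
    # NOTE: the attributes this class reads directly from `args` are dropout,
    # decoder_layerdrop, decoder_layers, decoder_embed_dim,
    # decoder_normalize_before, layer_norm_eps, relative_position_embedding,
    # encoder_embed_dim, encoder_attention_heads and decoder_max_relative_position,
    # plus the optional export, cross_self_attention, no_decoder_final_norm,
    # checkpoint_activations, offload_activations and min_params_to_wrap flags.
    # TransformerDecoderLayer consumes further fields from the same config.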

    def forward(
        self,
        prev_output_tokens,
        tgt_mask,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            prev_output_tokens (Tensor): already-embedded previous decoder
                outputs (e.g. decoder prenet features) of shape
                `(batch, tgt_len, embed_dim)`, for teacher forcing
            tgt_mask (optional): target-side padding mask of shape
                `(batch, tgt_len)`, used as the self-attention padding mask
            encoder_out (optional): output from the encoder, used for
                encoder-side attention; `encoder_out["encoder_out"][0]` should
                be of size `(src_len, batch, embed_dim)`
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
            full_context_alignment (bool, optional): don't apply auto-regressive
                mask to self-attention (default: False).

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        x, extra = self.extract_features(
            prev_output_tokens,
            tgt_mask,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            full_context_alignment=full_context_alignment,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
        )
        return x, extra

    def extract_features(
        self,
        prev_output_tokens,
        tgt_mask,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        return self.extract_features_scriptable(
            prev_output_tokens,
            tgt_mask,
            encoder_out,
            incremental_state,
            full_context_alignment,
            alignment_layer,
            alignment_heads,
        )
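
    # Illustrative sketch of step-wise generation: the same `incremental_state`
    # dict is threaded through every call so each layer can cache its
    # keys/values (and the causal future mask is skipped); typically only the
    # newest embedded frame is fed per step. Variable names are hypothetical:
    #
    #     incremental_state = {}
    #     for step in range(max_steps):
    #         feats, _ = decoder(
    #             prenet_out[:, -1:],        # newest frame only, B x 1 x C
    #             tgt_mask=None,
    #             encoder_out=enc_out,
    #             incremental_state=incremental_state,
    #         )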

    """
    A scriptable subclass of this class has an extract_features method and calls
    super().extract_features, but super() is not supported in torchscript. A copy
    of this function is made to be used in the subclass instead.
    """

    def extract_features_scriptable(
        self,
        prev_output_tokens,
        tgt_mask,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """
        Similar to *forward* but only return features.

        Includes several features from "Jointly Learning to Align and Translate
        with Transformer Models" (Garg et al., EMNLP 2019).

        Args:
            full_context_alignment (bool, optional): don't apply auto-regressive
                mask to self-attention (default: False).
            alignment_layer (int, optional): return mean alignment over heads at
                this layer (default: last layer).
            alignment_heads (int, optional): only average alignment over this
                many heads (default: all heads).

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        bs = prev_output_tokens.size(0)
        if alignment_layer is None:
            alignment_layer = self.num_layers - 1

        enc: Optional[Tensor] = None
        padding_mask: Optional[Tensor] = None
        if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
            enc = encoder_out["encoder_out"][0]
            assert (
                enc.size()[1] == bs
            ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
        if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
            padding_mask = encoder_out["encoder_padding_mask"][0]

        # B x T x C -> T x B x C
        x = prev_output_tokens.transpose(0, 1)

        self_attn_padding_mask: Optional[Tensor] = None
        if self.cross_self_attention or tgt_mask is not None:
            self_attn_padding_mask = tgt_mask

        # relative position embedding
        if self.args.relative_position_embedding:
            x_len = x.shape[0]
            pos_seq = torch.arange(0, x_len).long().to(x.device)
            pos_seq = pos_seq[:, None] - pos_seq[None, :]
            pos_k, pos_v = self.pos_emb(pos_seq)
        else:
            pos_k = None

        # decoder layers
        attn_list = []
        attn: Optional[Tensor] = None
        inner_states: List[Optional[Tensor]] = [x]
        for idx, layer in enumerate(self.layers):
            if incremental_state is None and not full_context_alignment:
                self_attn_mask = self.buffered_future_mask(x)
            else:
                self_attn_mask = None

            x, layer_attn, _ = layer(
                x,
                enc,
                padding_mask,
                incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=bool(idx == alignment_layer or alignment_layer == -1),
                need_head_weights=bool(idx == alignment_layer or alignment_layer == -1),
                pos_bias=pos_k,
            )
            inner_states.append(x)
            if layer_attn is not None and (idx == alignment_layer or alignment_layer == -1):
                attn = layer_attn.float().to(x)
                attn_list.append(attn.transpose(0, 1))

        if attn is not None and len(attn_list) == 1:
            if alignment_heads is not None:
                attn = attn[:alignment_heads]

            # average probabilities over heads
            attn = attn.mean(dim=0)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, {
            "attn": [attn if len(attn_list) <= 1 else attn_list],
            "inner_states": inner_states,
        }

    # def max_positions(self):
    #     """Maximum output length supported by the decoder."""
    #     return self.max_target_positions
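
    # The cached causal mask built below is an additive float mask with -inf
    # strictly above the diagonal and 0 elsewhere, e.g. for dim = 3:
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    # Adding it to the self-attention logits prevents each position from
    # attending to future positions.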

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
        if (
            self._future_mask.size(0) == 0
            or (not self._future_mask.device == tensor.device)
            or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(
                    torch.zeros([dim, dim], device=tensor.device)
                ),
                1,
            )
        # always keep the cached mask on the same device and dtype as the input;
        # a freshly built mask is float32 and would otherwise leak into fp16
        # attention weights
        self._future_mask = self._future_mask.to(tensor)
        return self._future_mask[:dim, :dim]

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        for i in range(self.num_layers):
            # update layer norms
            layer_norm_map = {
                "0": "self_attn_layer_norm",
                "1": "encoder_attn_layer_norm",
                "2": "final_layer_norm",
            }
            for old, new in layer_norm_map.items():
                for m in ("weight", "bias"):
                    k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
                    if k in state_dict:
                        state_dict[
                            "{}.layers.{}.{}.{}".format(name, i, new, m)
                        ] = state_dict[k]
                        del state_dict[k]

        version_key = "{}.version".format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])

        return state_dict

    def set_num_updates(self, num_updates):
        """State from trainer to pass along to model at every update."""

        def _apply(m):
            if hasattr(m, "set_num_updates") and m != self:
                m.set_num_updates(num_updates)

        self.apply(_apply)
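

# --------------------------------------------------------
# Minimal usage sketch. It assumes this file is imported as part of the ArTST
# package (so the relative imports above resolve) and that
# TransformerDecoderLayer accepts the decoder fields guessed below; a real
# ArTST/SpeechT5 config may use different names and values.
# --------------------------------------------------------
if __name__ == "__main__":
    import argparse

    args = argparse.Namespace(
        # attributes read by TransformerDecoder itself
        dropout=0.1,
        decoder_layerdrop=0.0,
        decoder_layers=2,
        decoder_embed_dim=768,
        decoder_normalize_before=True,
        layer_norm_eps=1e-5,
        relative_position_embedding=False,
        encoder_embed_dim=768,
        encoder_attention_heads=12,
        decoder_max_relative_position=160,
        # assumed attributes for TransformerDecoderLayer (hypothetical values)
        decoder_attention_heads=12,
        decoder_ffn_embed_dim=3072,
        attention_dropout=0.1,
        activation_dropout=0.1,
        activation_fn="gelu",
    )

    # no_encoder_attn=True keeps the sketch self-contained (no cross-attention),
    # so encoder_out can be left as None below.
    decoder = TransformerDecoder(args, no_encoder_attn=True)
    decoder.eval()

    # prev_output_tokens are already-embedded prenet outputs: B x T x C
    prev = torch.randn(2, 5, args.decoder_embed_dim)
    with torch.no_grad():
        feats, extra = decoder(prev, tgt_mask=None)

    print(feats.shape)                  # expected: torch.Size([2, 5, 768])
    print(len(extra["inner_states"]))   # initial embedding + one entry per layer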