# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import FairseqIncrementalDecoder
from fairseq.modules import (
    FairseqDropout,
    LayerDropModuleList,
    LayerNorm,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from torch import Tensor

from .encoder import RelativePositionalEncoding
from .transformer_layer import TransformerDecoderLayer

DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)

class TransformerDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """
    def __init__(
        self,
        args,
        no_encoder_attn=False,
    ):
        self.args = args
        super().__init__(None)
        self.register_buffer("version", torch.Tensor([3]))
        self._future_mask = torch.empty(0)

        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.decoder_layerdrop = args.decoder_layerdrop
        # self.max_s_positions = args.max_target_positions

        export = getattr(args, "export", False)
        self.cross_self_attention = getattr(args, "cross_self_attention", False)

        if self.decoder_layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
        else:
            self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                self.build_decoder_layer(args, no_encoder_attn)
                for _ in range(args.decoder_layers)
            ]
        )
        self.num_layers = len(self.layers)

        if args.decoder_normalize_before and not getattr(
            args, "no_decoder_final_norm", False
        ):
            self.layer_norm = LayerNorm(
                args.decoder_embed_dim, eps=args.layer_norm_eps, export=export
            )
        else:
            self.layer_norm = None

        if args.relative_position_embedding:
            self.pos_emb = RelativePositionalEncoding(
                args.encoder_embed_dim // args.encoder_attention_heads,
                args.decoder_max_relative_position,
            )

    def build_decoder_layer(self, args, no_encoder_attn=False):
        layer = TransformerDecoderLayer(
            args,
            no_encoder_attn=no_encoder_attn,
            has_relative_attention_bias=args.relative_position_embedding,
        )
        checkpoint = getattr(args, "checkpoint_activations", False)
        if checkpoint:
            offload_to_cpu = getattr(args, "offload_activations", False)
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
        # if we are checkpointing, enforce that FSDP always wraps the
        # checkpointed layer, regardless of layer size
        min_params_to_wrap = (
            getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP)
            if not checkpoint
            else 0
        )
        layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
        return layer
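
    # Note on the wrapping order above: activation checkpointing (if enabled)
    # wraps the raw layer first, then FSDP wraps the result. With checkpointing
    # on, ``min_params_to_wrap`` is forced to 0 so FSDP always wraps the
    # checkpointed layer; otherwise layers smaller than DEFAULT_MIN_PARAMS_TO_WRAP
    # (1e8 parameters) are left for the parent FSDP unit to handle.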

    def forward(
        self,
        prev_output_tokens,
        tgt_mask,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            prev_output_tokens (Tensor): previous decoder outputs (already
                embedded, e.g. by a prenet) of shape `(batch, tgt_len, embed_dim)`,
                for teacher forcing
            tgt_mask (Tensor, optional): padding mask for the target sequence,
                used as the self-attention padding mask
            encoder_out (optional): output from the encoder, used for
                encoder-side attention, should be of size T x B x C
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        x, extra = self.extract_features(
            prev_output_tokens,
            tgt_mask,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            full_context_alignment=full_context_alignment,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
        )
        return x, extra
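
    # Minimal call sketch (shapes only; ``decoder``, ``prenet_out`` and
    # ``enc_out`` are hypothetical names):
    #
    #   prenet_out: Tensor of shape (B, T_tgt, embed_dim)   # teacher-forced inputs
    #   tgt_mask:   BoolTensor of shape (B, T_tgt) or None  # True at padded steps
    #   enc_out = {"encoder_out": [Tensor of shape (T_src, B, C)],
    #              "encoder_padding_mask": [BoolTensor of shape (B, T_src)]}
    #   feats, extra = decoder(prenet_out, tgt_mask, encoder_out=enc_out)
    #   # feats: (B, T_tgt, embed_dim); extra holds "attn" and "inner_states"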

    def extract_features(
        self,
        prev_output_tokens,
        tgt_mask,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        return self.extract_features_scriptable(
            prev_output_tokens,
            tgt_mask,
            encoder_out,
            incremental_state,
            full_context_alignment,
            alignment_layer,
            alignment_heads,
        )

    """
    A scriptable subclass of this class has an extract_features method and calls
    super().extract_features, but super() is not supported in torchscript. A copy of
    this function is made to be used in the subclass instead.
    """

    def extract_features_scriptable(
        self,
        prev_output_tokens,
        tgt_mask,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """
        Similar to *forward* but only return features.

        Includes several features from "Jointly Learning to Align and
        Translate with Transformer Models" (Garg et al., EMNLP 2019).

        Args:
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).
            alignment_layer (int, optional): return mean alignment over
                heads at this layer (default: last layer).
            alignment_heads (int, optional): only average alignment over
                this many heads (default: all heads).

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        bs = prev_output_tokens.size(0)
        if alignment_layer is None:
            alignment_layer = self.num_layers - 1

        enc: Optional[Tensor] = None
        padding_mask: Optional[Tensor] = None
        if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
            enc = encoder_out["encoder_out"][0]
            assert (
                enc.size()[1] == bs
            ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
        if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
            padding_mask = encoder_out["encoder_padding_mask"][0]

        # B x T x C -> T x B x C
        x = prev_output_tokens.transpose(0, 1)

        self_attn_padding_mask: Optional[Tensor] = None
        if self.cross_self_attention or tgt_mask is not None:
            self_attn_padding_mask = tgt_mask

        # relative position embedding
        if self.args.relative_position_embedding:
            x_len = x.shape[0]
            pos_seq = torch.arange(0, x_len).long().to(x.device)
            pos_seq = pos_seq[:, None] - pos_seq[None, :]
            pos_k, pos_v = self.pos_emb(pos_seq)
        else:
            pos_k = None

        # decoder layers
        attn_list = []
        attn: Optional[Tensor] = None
        inner_states: List[Optional[Tensor]] = [x]
        for idx, layer in enumerate(self.layers):
            if incremental_state is None and not full_context_alignment:
                self_attn_mask = self.buffered_future_mask(x)
            else:
                self_attn_mask = None

            x, layer_attn, _ = layer(
                x,
                enc,
                padding_mask,
                incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=bool(idx == alignment_layer or alignment_layer == -1),
                need_head_weights=bool(idx == alignment_layer or alignment_layer == -1),
                pos_bias=pos_k,
            )
            inner_states.append(x)
            if layer_attn is not None and (idx == alignment_layer or alignment_layer == -1):
                attn = layer_attn.float().to(x)
                attn_list.append(attn.transpose(0, 1))

        if attn is not None and len(attn_list) == 1:
            if alignment_heads is not None:
                attn = attn[:alignment_heads]
            # average probabilities over heads
            attn = attn.mean(dim=0)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, {
            "attn": [attn if len(attn_list) <= 1 else attn_list],
            "inner_states": inner_states,
        }

    # def max_positions(self):
    #     """Maximum output length supported by the decoder."""
    #     return self.max_target_positions

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
        if (
            self._future_mask.size(0) == 0
            or (not self._future_mask.device == tensor.device)
            or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(torch.zeros([dim, dim], device=tensor.device)),
                1,
            )
        else:
            self._future_mask = self._future_mask.to(tensor)
        return self._future_mask[:dim, :dim]
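
    # Illustration (not executed): for dim == 3 the cached mask is
    #
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    #
    # so position i may only attend to positions j <= i. The mask is rebuilt
    # lazily whenever the cache is empty, too small, or on the wrong device.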

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        for i in range(self.num_layers):
            # update layer norms
            layer_norm_map = {
                "0": "self_attn_layer_norm",
                "1": "encoder_attn_layer_norm",
                "2": "final_layer_norm",
            }
            for old, new in layer_norm_map.items():
                for m in ("weight", "bias"):
                    k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
                    if k in state_dict:
                        state_dict[
                            "{}.layers.{}.{}.{}".format(name, i, new, m)
                        ] = state_dict[k]
                        del state_dict[k]

        version_key = "{}.version".format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])

        return state_dict
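
    # Example of the renaming performed above (keys are illustrative, assuming
    # ``name == "decoder"``):
    #
    #   "decoder.layers.0.layer_norms.0.weight" -> "decoder.layers.0.self_attn_layer_norm.weight"
    #   "decoder.layers.0.layer_norms.2.bias"   -> "decoder.layers.0.final_layer_norm.bias"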

    def set_num_updates(self, num_updates):
        """State from trainer to pass along to model at every update."""

        def _apply(m):
            if hasattr(m, "set_num_updates") and m != self:
                m.set_num_updates(num_updates)

        self.apply(_apply)