# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import math import random from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from fairseq import utils from fairseq.distributed import fsdp_wrap from fairseq.models import ( FairseqEncoder, FairseqEncoderDecoderModel, FairseqIncrementalDecoder, register_model, register_model_architecture, ) from fairseq.modules import ( AdaptiveSoftmax, BaseLayer, FairseqDropout, LayerDropModuleList, LayerNorm, SinusoidalPositionalEmbedding, GradMultiply ) from fairseq.modules.checkpoint_activations import checkpoint_wrapper from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ from torch import Tensor from .unify_transformer_layer import TransformerEncoderLayer, TransformerDecoderLayer from .resnet import ResNet DEFAULT_MAX_SOURCE_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024 DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8) def BatchNorm2d(out_chan, momentum=0.1, eps=1e-3): return nn.SyncBatchNorm.convert_sync_batchnorm( nn.BatchNorm2d(out_chan, momentum=momentum, eps=eps) ) def make_token_bucket_position(bucket_size, max_position=DEFAULT_MAX_SOURCE_POSITIONS): context_pos = torch.arange(max_position, dtype=torch.long)[:, None] memory_pos = torch.arange(max_position, dtype=torch.long)[None, :] relative_pos = context_pos - memory_pos sign = torch.sign(relative_pos) mid = bucket_size // 2 abs_pos = torch.where((relative_pos -mid), mid-1, torch.abs(relative_pos)) log_pos = torch.ceil(torch.log(abs_pos/mid)/math.log((max_position-1)/mid) * (mid-1)) + mid log_pos = log_pos.int() bucket_pos = torch.where(abs_pos.le(mid), relative_pos, log_pos*sign).long() return bucket_pos + bucket_size - 1 def make_image_bucket_position(bucket_size, num_relative_distance): coords_h = torch.arange(bucket_size) coords_w = torch.arange(bucket_size) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += bucket_size - 1 # shift to start from 0 relative_coords[:, :, 1] += bucket_size - 1 relative_coords[:, :, 0] *= 2 * bucket_size - 1 relative_position_index = torch.zeros(size=(bucket_size * bucket_size + 1,) * 2, dtype=relative_coords.dtype) relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_position_index[0, 0:] = num_relative_distance - 3 relative_position_index[0:, 0] = num_relative_distance - 2 relative_position_index[0, 0] = num_relative_distance - 1 return relative_position_index @register_model("unify_transformer") class TransformerModel(FairseqEncoderDecoderModel): """ Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017) `_. Args: encoder (TransformerEncoder): the encoder decoder (TransformerDecoder): the decoder The Transformer model provides the following named architectures and command-line arguments: .. argparse:: :ref: fairseq.models.transformer_parser :prog: """ def __init__(self, args, encoder, decoder): super().__init__(encoder, decoder) self.args = args self.supports_align_args = True @staticmethod def add_args(parser): """Add model-specific arguments to the parser.""" # fmt: off parser.add_argument('--activation-fn', choices=utils.get_available_activation_fns(), help='activation function to use') parser.add_argument('--dropout', type=float, metavar='D', help='dropout probability') parser.add_argument('--attention-dropout', type=float, metavar='D', help='dropout probability for attention weights') parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D', help='dropout probability after activation in FFN.') parser.add_argument('--encoder-embed-path', type=str, metavar='STR', help='path to pre-trained encoder embedding') parser.add_argument('--encoder-embed-dim', type=int, metavar='N', help='encoder embedding dimension') parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N', help='encoder embedding dimension for FFN') parser.add_argument('--encoder-layers', type=int, metavar='N', help='num encoder layers') parser.add_argument('--encoder-attention-heads', type=int, metavar='N', help='num encoder attention heads') parser.add_argument('--encoder-normalize-before', action='store_true', help='apply layernorm before each encoder block') parser.add_argument('--encoder-learned-pos', action='store_true', help='use learned positional embeddings in the encoder') parser.add_argument('--decoder-embed-path', type=str, metavar='STR', help='path to pre-trained decoder embedding') parser.add_argument('--decoder-embed-dim', type=int, metavar='N', help='decoder embedding dimension') parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N', help='decoder embedding dimension for FFN') parser.add_argument('--decoder-layers', type=int, metavar='N', help='num decoder layers') parser.add_argument('--decoder-attention-heads', type=int, metavar='N', help='num decoder attention heads') parser.add_argument('--decoder-learned-pos', action='store_true', help='use learned positional embeddings in the decoder') parser.add_argument('--decoder-normalize-before', action='store_true', help='apply layernorm before each decoder block') parser.add_argument('--decoder-output-dim', type=int, metavar='N', help='decoder output dimension (extra linear layer ' 'if different from decoder embed dim') parser.add_argument('--share-decoder-input-output-embed', action='store_true', help='share decoder input and output embeddings') parser.add_argument('--share-all-embeddings', action='store_true', help='share encoder, decoder and output embeddings' ' (requires shared dictionary and embed dim)') parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true', help='if set, disables positional embeddings (outside self attention)') parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', help='comma separated list of adaptive softmax cutoff points. ' 'Must be used with adaptive_loss criterion'), parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D', help='sets adaptive softmax dropout for the tail projections') parser.add_argument('--layernorm-embedding', action='store_true', help='add layernorm to embedding') parser.add_argument('--no-scale-embedding', action='store_true', help='if True, dont scale embeddings') parser.add_argument('--checkpoint-activations', action='store_true', help='checkpoint activations at each layer, which saves GPU ' 'memory usage at the cost of some additional compute') parser.add_argument('--offload-activations', action='store_true', help='checkpoint activations at each layer, then save to gpu. Sets --checkpoint-activations.') # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019) parser.add_argument('--no-cross-attention', default=False, action='store_true', help='do not perform cross-attention') parser.add_argument('--cross-self-attention', default=False, action='store_true', help='perform cross+self-attention') # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0, help='LayerDrop probability for encoder') parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0, help='LayerDrop probability for decoder') parser.add_argument('--encoder-layers-to-keep', default=None, help='which layers to *keep* when pruning as a comma-separated list') parser.add_argument('--decoder-layers-to-keep', default=None, help='which layers to *keep* when pruning as a comma-separated list') # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020) parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0, help='iterative PQ quantization noise at training time') parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8, help='block size of quantization noise at training time') parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0, help='scalar quantization noise and scalar quantization at training time') # args for Fully Sharded Data Parallel (FSDP) training parser.add_argument( '--min-params-to-wrap', type=int, metavar='D', default=DEFAULT_MIN_PARAMS_TO_WRAP, help=( 'minimum number of params for a layer to be wrapped with FSDP() when ' 'training with --ddp-backend=fully_sharded. Smaller values will ' 'improve memory efficiency, but may make torch.distributed ' 'communication less efficient due to smaller input sizes. This option ' 'is set to 0 (i.e., always wrap) when --checkpoint-activations or ' '--offload-activations are passed.' ) ) parser.add_argument('--resnet-drop-path-rate', type=float, help='resnet drop path rate') parser.add_argument('--encoder-drop-path-rate', type=float, help='encoder drop path rate') parser.add_argument('--decoder-drop-path-rate', type=float, help='encoder drop path rate') parser.add_argument('--token-bucket-size', type=int, help='token bucket size') parser.add_argument('--image-bucket-size', type=int, help='image bucket size') parser.add_argument('--attn-scale-factor', type=float, help='attention scale factor') parser.add_argument('--freeze-resnet', action='store_true', help='freeze resnet') parser.add_argument('--freeze-encoder-embedding', action='store_true', help='freeze encoder token embedding') parser.add_argument('--freeze-decoder-embedding', action='store_true', help='freeze decoder token embedding') parser.add_argument('--add-type-embedding', action='store_true', help='add source/region/patch type embedding') parser.add_argument('--resnet-type', choices=['resnet50', 'resnet101', 'resnet152'], help='resnet type') parser.add_argument('--resnet-model-path', type=str, metavar='STR', help='path to load resnet') parser.add_argument('--code-image-size', type=int, help='code image size') parser.add_argument('--patch-layernorm-embedding', action='store_true', help='add layernorm to patch embedding') parser.add_argument('--code-layernorm-embedding', action='store_true', help='add layernorm to code embedding') parser.add_argument('--entangle-position-embedding', action='store_true', help='entangle position embedding') parser.add_argument('--disable-entangle', action='store_true', help='disable entangle') parser.add_argument('--sync-bn', action='store_true', help='sync batchnorm') parser.add_argument('--scale-attn', action='store_true', help='scale attn') parser.add_argument('--scale-fc', action='store_true', help='scale fc') parser.add_argument('--scale-heads', action='store_true', help='scale heads') parser.add_argument('--scale-resids', action='store_true', help='scale resids') # fmt: on @classmethod def build_model(cls, args, task): """Build a new model instance.""" # make sure all arguments are present in older models base_architecture(args) if args.encoder_layers_to_keep: args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) if args.decoder_layers_to_keep: args.decoder_layers = len(args.decoder_layers_to_keep.split(",")) if getattr(args, "max_source_positions", None) is None: args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS if getattr(args, "max_target_positions", None) is None: args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS src_dict, tgt_dict = task.source_dictionary, task.target_dictionary if args.share_all_embeddings: if src_dict != tgt_dict: raise ValueError("--share-all-embeddings requires a joined dictionary") if args.encoder_embed_dim != args.decoder_embed_dim: raise ValueError( "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" ) if args.decoder_embed_path and ( args.decoder_embed_path != args.encoder_embed_path ): raise ValueError( "--share-all-embeddings not compatible with --decoder-embed-path" ) encoder_embed_tokens = cls.build_embedding( args, src_dict, args.encoder_embed_dim, args.encoder_embed_path ) decoder_embed_tokens = encoder_embed_tokens args.share_decoder_input_output_embed = True else: encoder_embed_tokens = cls.build_embedding( args, src_dict, args.encoder_embed_dim, args.encoder_embed_path ) decoder_embed_tokens = cls.build_embedding( args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path ) if getattr(args, "freeze_encoder_embedding", False): encoder_embed_tokens.weight.requires_grad = False if getattr(args, "freeze_decoder_embedding", False): decoder_embed_tokens.weight.requires_grad = False if getattr(args, "offload_activations", False): args.checkpoint_activations = True # offloading implies checkpointing encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) if not args.share_all_embeddings: min_params_to_wrap = getattr( args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP ) # fsdp_wrap is a no-op when --ddp-backend != fully_sharded encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap) decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap) return cls(args, encoder, decoder) @classmethod def build_embedding(cls, args, dictionary, embed_dim, path=None): num_embeddings = len(dictionary) padding_idx = dictionary.pad() emb = Embedding(num_embeddings, embed_dim, padding_idx) # if provided, load from preloaded dictionaries if path: embed_dict = utils.parse_embedding(path) utils.load_embedding(embed_dict, dictionary, emb) return emb @classmethod def build_encoder(cls, args, src_dict, embed_tokens): return TransformerEncoder(args, src_dict, embed_tokens) @classmethod def build_decoder(cls, args, tgt_dict, embed_tokens): return TransformerDecoder( args, tgt_dict, embed_tokens, no_encoder_attn=getattr(args, "no_cross_attention", False), ) # TorchScript doesn't support optional arguments with variable length (**kwargs). # Current workaround is to add union of all arguments in child classes. def forward( self, src_tokens, src_lengths, prev_output_tokens, return_all_hiddens: bool = True, features_only: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, ): """ Run the forward pass for an encoder-decoder model. Copied from the base class, but without ``**kwargs``, which are not supported by TorchScript. """ encoder_out = self.encoder( src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens ) decoder_out = self.decoder( prev_output_tokens, encoder_out=encoder_out, features_only=features_only, alignment_layer=alignment_layer, alignment_heads=alignment_heads, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens, ) return decoder_out # Since get_normalized_probs is in the Fairseq Model which is not scriptable, # I rewrite the get_normalized_probs from Base Class to call the # helper function in the Base Class. @torch.jit.export def get_normalized_probs( self, net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], log_probs: bool, sample: Optional[Dict[str, Tensor]] = None, ): """Get normalized probabilities (or log probs) from a net's output.""" return self.get_normalized_probs_scriptable(net_output, log_probs, sample) class TransformerEncoder(FairseqEncoder): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens): self.args = args super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) self.dropout_module = FairseqDropout( args.dropout, module_name=self.__class__.__name__ ) self.encoder_layerdrop = args.encoder_layerdrop embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.num_attention_heads = args.encoder_attention_heads self.embed_tokens = embed_tokens self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) if getattr(args, "layernorm_embedding", False): self.layernorm_embedding = LayerNorm(embed_dim) else: self.layernorm_embedding = None if getattr(args, "add_type_embedding", False): self.type_embedding = Embedding(2, embed_dim, padding_idx=None) else: self.type_embedding = None if getattr(args, "sync_bn", False): norm_layer = BatchNorm2d else: norm_layer = None if args.resnet_type == 'resnet101': self.embed_images = ResNet([3, 4, 23], norm_layer=norm_layer, drop_path_rate=args.resnet_drop_path_rate) elif args.resnet_type == 'resnet152': self.embed_images = ResNet([3, 8, 36], norm_layer=norm_layer, drop_path_rate=args.resnet_drop_path_rate) else: raise NotImplementedError self.image_proj = Linear(1024, embed_dim) if getattr(args, "resnet_model_path", None): print("load resnet {}".format(args.resnet_model_path)) resnet_state_dict = torch.load(self.args.resnet_model_path) self.embed_images.load_state_dict(resnet_state_dict) if getattr(args, "patch_layernorm_embedding", False): self.patch_layernorm_embedding = LayerNorm(embed_dim) else: self.patch_layernorm_embedding = None self.embed_positions = Embedding(args.max_source_positions + 2, embed_dim) self.embed_image_positions = Embedding(args.image_bucket_size ** 2 + 1, embed_dim) self.pos_ln = LayerNorm(embed_dim) self.image_pos_ln = LayerNorm(embed_dim) self.pos_scaling = float(embed_dim / args.encoder_attention_heads * args.attn_scale_factor) ** -0.5 self.pos_q_linear = nn.Linear(embed_dim, embed_dim) self.pos_k_linear = nn.Linear(embed_dim, embed_dim) if not args.adaptive_input and args.quant_noise_pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), args.quant_noise_pq, args.quant_noise_pq_block_size, ) else: self.quant_noise = None if self.encoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.encoder_layerdrop) else: self.layers = nn.ModuleList([]) dpr = [x.item() for x in torch.linspace(0, args.encoder_drop_path_rate, args.encoder_layers)] self.layers.extend( [self.build_encoder_layer(args, drop_path_rate=dpr[i]) for i in range(args.encoder_layers)] ) self.num_layers = len(self.layers) if args.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None token_bucket_size = args.token_bucket_size token_num_rel_dis = 2 * token_bucket_size - 1 token_rp_bucket = make_token_bucket_position(token_bucket_size) self.token_rel_pos_table_list = nn.ModuleList( [Embedding(token_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.encoder_layers)] ) image_bucket_size = args.image_bucket_size image_num_rel_dis = (2 * image_bucket_size - 1) * (2 * image_bucket_size - 1) + 3 image_rp_bucket = make_image_bucket_position(image_bucket_size, image_num_rel_dis) self.image_rel_pos_table_list = nn.ModuleList( [Embedding(image_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.encoder_layers)] ) self.register_buffer("token_rp_bucket", token_rp_bucket) self.register_buffer("image_rp_bucket", image_rp_bucket) self.entangle_position_embedding = args.entangle_position_embedding def train(self, mode=True): super(TransformerEncoder, self).train(mode) if getattr(self.args, "freeze_resnet", False): for m in self.embed_images.modules(): if isinstance(m, nn.BatchNorm2d): m.eval() m.weight.requires_grad = False m.bias.requires_grad = False def build_encoder_layer(self, args, drop_path_rate=0.0): layer = TransformerEncoderLayer(args, drop_path_rate=drop_path_rate) checkpoint = getattr(args, "checkpoint_activations", False) if checkpoint: offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) # if we are checkpointing, enforce that FSDP always wraps the # checkpointed layer, regardless of layer size min_params_to_wrap = ( getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) if not checkpoint else 0 ) layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) return layer def get_rel_pos_bias(self, x, idx): seq_len = x.size(1) rp_bucket = self.token_rp_bucket[:seq_len, :seq_len] values = F.embedding(rp_bucket, self.token_rel_pos_table_list[idx].weight) values = values.unsqueeze(0).expand(x.size(0), -1, -1, -1) values = values.permute([0, 3, 1, 2]) return values.contiguous() def get_image_rel_pos_bias(self, image_position_ids, idx): bsz, seq_len = image_position_ids.shape rp_bucket_size = self.image_rp_bucket.size(1) rp_bucket = self.image_rp_bucket.unsqueeze(0).expand( bsz, rp_bucket_size, rp_bucket_size ).gather(1, image_position_ids[:, :, None].expand(bsz, seq_len, rp_bucket_size) ).gather(2, image_position_ids[:, None, :].expand(bsz, seq_len, seq_len)) values = F.embedding(rp_bucket, self.image_rel_pos_table_list[idx].weight) values = values.permute(0, 3, 1, 2) return values def get_patch_images_info(self, patch_images, sample_patch_num, device): image_embed = self.embed_images(patch_images) h, w = image_embed.shape[-2:] image_num_patches = h * w image_padding_mask = patch_images.new_zeros((patch_images.size(0), image_num_patches)).bool() image_position_idx = torch.arange(w).unsqueeze(0).expand(h, w) + \ torch.arange(h).unsqueeze(1) * self.args.image_bucket_size + 1 image_position_idx = image_position_idx.view(-1).to(device) image_position_ids = image_position_idx[None, :].expand(patch_images.size(0), image_num_patches) image_embed = image_embed.flatten(2).transpose(1, 2) if sample_patch_num is not None: patch_orders = [ random.sample(range(image_num_patches), k=sample_patch_num) for _ in range(patch_images.size(0)) ] patch_orders = torch.LongTensor(patch_orders).to(device) image_embed = image_embed.gather( 1, patch_orders.unsqueeze(2).expand(-1, -1, image_embed.size(2)) ) image_num_patches = sample_patch_num image_padding_mask = image_padding_mask.gather(1, patch_orders) image_position_ids = image_position_ids.gather(1, patch_orders) image_pos_embed = self.embed_image_positions(image_position_ids) return image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed def forward_embedding( self, src_tokens, image_embed: Optional[torch.Tensor] = None, image_embed_2: Optional[torch.Tensor] = None, token_embedding: Optional[torch.Tensor] = None, pos_embed: Optional[torch.Tensor] = None, image_pos_embed: Optional[torch.Tensor] = None, image_pos_embed_2: Optional[torch.Tensor] = None ): # embed tokens and positions if token_embedding is None: token_embedding = self.embed_tokens(src_tokens) x = embed = self.embed_scale * token_embedding if self.entangle_position_embedding and pos_embed is not None: x += pos_embed if self.type_embedding is not None: x += self.type_embedding(src_tokens.new_zeros(x.size()[:2])) if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) x = self.dropout_module(x) if self.quant_noise is not None: x = self.quant_noise(x) # embed raw images if image_embed is not None: image_embed = self.image_proj(image_embed) image_x = image_embed = self.embed_scale * image_embed if self.entangle_position_embedding and image_pos_embed is not None: image_x += image_pos_embed if self.type_embedding is not None: image_x += self.type_embedding(src_tokens.new_ones(image_x.size()[:2])) if self.patch_layernorm_embedding is not None: image_x = self.patch_layernorm_embedding(image_x) image_x = self.dropout_module(image_x) if self.quant_noise is not None: image_x = self.quant_noise(image_x) x = torch.cat([image_x, x], dim=1) embed = torch.cat([image_embed, embed], dim=1) if image_embed_2 is not None: assert self.type_embedding is not None image_embed_2 = self.image_proj(image_embed_2) image_x_2 = image_embed_2 = self.embed_scale * image_embed_2 if self.entangle_position_embedding and image_pos_embed_2 is not None: image_x_2 += image_pos_embed_2 if self.type_embedding is not None: image_x_2 += self.type_embedding(src_tokens.new_full(image_x_2.size()[:2], fill_value=2)) if self.patch_layernorm_embedding is not None: image_x_2 = self.patch_layernorm_embedding(image_x_2) image_x_2 = self.dropout_module(image_x_2) if self.quant_noise is not None: image_x_2 = self.quant_noise(image_x_2) x = torch.cat([image_x_2, x], dim=1) embed = torch.cat([image_embed_2, embed], dim=1) return x, embed def forward( self, src_tokens, src_lengths, patch_images: Optional[torch.Tensor] = None, patch_images_2: Optional[torch.Tensor] = None, patch_masks: Optional[torch.Tensor] = None, code_masks: Optional[torch.Tensor] = None, return_all_hiddens: bool = False, token_embeddings: Optional[torch.Tensor] = None, sample_patch_num: Optional[int] = None ): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). token_embeddings (torch.Tensor, optional): precomputed embeddings default `None` will recompute embeddings Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` - **encoder_embedding** (Tensor): the (scaled) embedding lookup of shape `(batch, src_len, embed_dim)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. """ return self.forward_scriptable(src_tokens, src_lengths, patch_images, patch_images_2, patch_masks, return_all_hiddens, token_embeddings, sample_patch_num) # TorchScript doesn't support super() method so that the scriptable Subclass # can't access the base class model in Torchscript. # Current workaround is to add a helper function with different name and # call the helper function from scriptable Subclass. def forward_scriptable( self, src_tokens, src_lengths, patch_images: Optional[torch.Tensor] = None, patch_images_2: Optional[torch.Tensor] = None, patch_masks: Optional[torch.Tensor] = None, return_all_hiddens: bool = False, token_embeddings: Optional[torch.Tensor] = None, sample_patch_num: Optional[int] = None ): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). token_embeddings (torch.Tensor, optional): precomputed embeddings default `None` will recompute embeddings Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` - **encoder_embedding** (Tensor): the (scaled) embedding lookup of shape `(batch, src_len, embed_dim)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. """ image_embed = None image_embed_2 = None image_pos_embed = None image_pos_embed_2 = None if patch_images is not None: image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \ self.get_patch_images_info(patch_images, sample_patch_num, src_tokens.device) image_padding_mask[~patch_masks] = True if patch_images_2 is not None: image_embed_2, image_num_patches_2, image_padding_mask_2, image_position_ids_2, image_pos_embed_2 = \ self.get_patch_images_info(patch_images_2, sample_patch_num, src_tokens.device) image_padding_mask_2[~patch_masks] = True encoder_padding_mask = src_tokens.eq(self.padding_idx) if patch_images is not None: encoder_padding_mask = torch.cat([image_padding_mask, encoder_padding_mask], dim=1) if patch_images_2 is not None: encoder_padding_mask = torch.cat([image_padding_mask_2, encoder_padding_mask], dim=1) has_pads = (src_tokens.device.type == "xla" or encoder_padding_mask.any()) pos_embed = self.embed_positions(utils.new_arange(src_tokens)) x, encoder_embedding = self.forward_embedding( src_tokens, image_embed, image_embed_2, token_embeddings, pos_embed, image_pos_embed, image_pos_embed_2 ) # account for padding while computing the representation if has_pads: x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) # B x T x C -> T x B x C x = x.transpose(0, 1) pos_embed = self.pos_ln(pos_embed) if patch_images is not None: image_pos_embed = self.image_pos_ln(image_pos_embed) pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1) if patch_images_2 is not None: image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2) pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1) pos_q = self.pos_q_linear(pos_embed).view( x.size(1), x.size(0), self.num_attention_heads, -1 ).transpose(1, 2) * self.pos_scaling pos_k = self.pos_k_linear(pos_embed).view( x.size(1), x.size(0), self.num_attention_heads, -1 ).transpose(1, 2) abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) encoder_states = [] if return_all_hiddens: encoder_states.append(x) # encoder layers for idx, layer in enumerate(self.layers): self_attn_bias = abs_pos_bias.clone() self_attn_bias[:, :, -src_tokens.size(1):, -src_tokens.size(1):] += self.get_rel_pos_bias(src_tokens, idx) if patch_images_2 is not None: self_attn_bias[:, :, :image_num_patches_2, :image_num_patches_2] += \ self.get_image_rel_pos_bias(image_position_ids_2, idx) self_attn_bias[:, :, image_num_patches_2:image_num_patches_2+image_num_patches, image_num_patches_2:image_num_patches_2+image_num_patches] += \ self.get_image_rel_pos_bias(image_position_ids, idx) elif patch_images is not None: self_attn_bias[:, :, :x.size(0) - src_tokens.size(1), :x.size(0) - src_tokens.size(1)] += \ self.get_image_rel_pos_bias(image_position_ids, idx) self_attn_bias = self_attn_bias.reshape(-1, x.size(0), x.size(0)) x = layer( x, encoder_padding_mask=encoder_padding_mask if has_pads else None, self_attn_bias=self_attn_bias ) if return_all_hiddens: assert encoder_states is not None encoder_states.append(x) if self.layer_norm is not None: x = self.layer_norm(x) # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in # `forward` so we use a dictionary instead. # TorchScript does not support mixed values so the values are all lists. # The empty list is equivalent to None. return { "encoder_out": [x], # T x B x C "encoder_padding_mask": [encoder_padding_mask], # B x T "encoder_embedding": [], # B x T x C "encoder_states": encoder_states, # List[T x B x C] "src_tokens": [], "src_lengths": [], "position_embeddings": [pos_embed], # B x T x C } @torch.jit.export def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): """ Reorder encoder output according to *new_order*. Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ if len(encoder_out["encoder_out"]) == 0: new_encoder_out = [] else: new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] if len(encoder_out["encoder_padding_mask"]) == 0: new_encoder_padding_mask = [] else: new_encoder_padding_mask = [ encoder_out["encoder_padding_mask"][0].index_select(0, new_order) ] if len(encoder_out["encoder_embedding"]) == 0: new_encoder_embedding = [] else: new_encoder_embedding = [ encoder_out["encoder_embedding"][0].index_select(0, new_order) ] if len(encoder_out["src_tokens"]) == 0: new_src_tokens = [] else: new_src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)] if len(encoder_out["src_lengths"]) == 0: new_src_lengths = [] else: new_src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)] if len(encoder_out["position_embeddings"]) == 0: new_position_embeddings = [] else: new_position_embeddings = [(encoder_out["position_embeddings"][0]).index_select(0, new_order)] encoder_states = encoder_out["encoder_states"] if len(encoder_states) > 0: for idx, state in enumerate(encoder_states): encoder_states[idx] = state.index_select(1, new_order) return { "encoder_out": new_encoder_out, # T x B x C "encoder_padding_mask": new_encoder_padding_mask, # B x T "encoder_embedding": new_encoder_embedding, # B x T x C "encoder_states": encoder_states, # List[T x B x C] "src_tokens": new_src_tokens, # B x T "src_lengths": new_src_lengths, # B x 1 "position_embeddings": new_position_embeddings, # B x T x C } def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return self.max_source_positions def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = "{}.embed_positions.weights".format(name) if weights_key in state_dict: print("deleting {0}".format(weights_key)) del state_dict[weights_key] state_dict[ "{}.embed_positions._float_tensor".format(name) ] = torch.FloatTensor(1) for i in range(self.num_layers): # update layer norms self.layers[i].upgrade_state_dict_named( state_dict, "{}.layers.{}".format(name, i) ) # version_key = "{}.version".format(name) # if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # # earlier checkpoints did not normalize after the stack of layers # self.layer_norm = None # self.normalize = False # state_dict[version_key] = torch.Tensor([1]) prefix = name + "." if name != "" else "" for param_name, param_tensor in self.state_dict().items(): if (prefix + param_name) not in state_dict and param_name in self.state_dict(): state_dict[prefix + param_name] = self.state_dict()[param_name] if len(state_dict["encoder.embed_image_positions.weight"]) < len(self.state_dict()["embed_image_positions.weight"]): num_posids_to_add = len(self.state_dict()["embed_image_positions.weight"]) - len(state_dict["encoder.embed_image_positions.weight"]) embed_dim = state_dict["encoder.embed_image_positions.weight"].size(1) new_pos_embed_to_add = torch.zeros(num_posids_to_add, embed_dim) nn.init.normal_(new_pos_embed_to_add, mean=0, std=embed_dim ** -0.5) new_pos_embed_to_add = new_pos_embed_to_add.to( dtype=state_dict["encoder.embed_image_positions.weight"].dtype, ) state_dict["encoder.embed_image_positions.weight"] = torch.cat( [state_dict["encoder.embed_image_positions.weight"], new_pos_embed_to_add] ) return state_dict class TransformerDecoder(FairseqIncrementalDecoder): """ Transformer decoder consisting of *args.decoder_layers* layers. Each layer is a :class:`TransformerDecoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): decoding dictionary embed_tokens (torch.nn.Embedding): output embedding no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__( self, args, dictionary, embed_tokens, no_encoder_attn=False, output_projection=None, ): self.args = args super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) self._future_mask = torch.empty(0) self.dropout_module = FairseqDropout( args.dropout, module_name=self.__class__.__name__ ) self.decoder_layerdrop = args.decoder_layerdrop self.share_input_output_embed = args.share_decoder_input_output_embed self.num_attention_heads = args.decoder_attention_heads input_embed_dim = embed_tokens.embedding_dim embed_dim = args.decoder_embed_dim self.embed_dim = embed_dim self.output_embed_dim = args.decoder_output_dim self.padding_idx = embed_tokens.padding_idx self.max_target_positions = args.max_target_positions self.embed_tokens = embed_tokens self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) if not args.adaptive_input and args.quant_noise_pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), args.quant_noise_pq, args.quant_noise_pq_block_size, ) else: self.quant_noise = None self.project_in_dim = ( Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None ) if getattr(args, "layernorm_embedding", False): self.layernorm_embedding = LayerNorm(embed_dim) else: self.layernorm_embedding = None self.window_size = args.code_image_size // 8 self.embed_positions = Embedding(args.max_target_positions + 2, embed_dim) self.embed_image_positions = Embedding(args.image_bucket_size ** 2 + 1, embed_dim) self.pos_ln = LayerNorm(embed_dim) self.image_pos_ln = LayerNorm(embed_dim) self.pos_scaling = float(embed_dim / self.num_attention_heads * args.attn_scale_factor) ** -0.5 self.self_pos_q_linear = nn.Linear(embed_dim, embed_dim) self.self_pos_k_linear = nn.Linear(embed_dim, embed_dim) self.cross_pos_q_linear = nn.Linear(embed_dim, embed_dim) self.cross_pos_k_linear = nn.Linear(embed_dim, embed_dim) if getattr(args, "code_layernorm_embedding", False): self.code_layernorm_embedding = LayerNorm(embed_dim) else: self.code_layernorm_embedding = None self.cross_self_attention = getattr(args, "cross_self_attention", False) if self.decoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.decoder_layerdrop) else: self.layers = nn.ModuleList([]) dpr = [x.item() for x in torch.linspace(0, args.decoder_drop_path_rate, args.decoder_layers)] self.layers.extend( [ self.build_decoder_layer(args, no_encoder_attn, drop_path_rate=dpr[i]) for i in range(args.decoder_layers) ] ) self.num_layers = len(self.layers) if args.decoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None self.project_out_dim = ( Linear(embed_dim, self.output_embed_dim, bias=False) if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights else None ) self.adaptive_softmax = None self.output_projection = output_projection if self.output_projection is None: self.build_output_projection(args, dictionary, embed_tokens) token_bucket_size = args.token_bucket_size token_num_rel_dis = 2 * token_bucket_size - 1 token_rp_bucket = make_token_bucket_position(token_bucket_size) self.token_rel_pos_table_list = nn.ModuleList( [Embedding(token_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.decoder_layers)] ) image_bucket_size = args.image_bucket_size image_num_rel_dis = (2 * image_bucket_size - 1) * (2 * image_bucket_size - 1) + 3 image_rp_bucket = make_image_bucket_position(image_bucket_size, image_num_rel_dis) image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \ torch.arange(self.window_size).unsqueeze(1) * image_bucket_size + 1 image_position_idx = torch.cat([torch.tensor([0]), image_position_idx.view(-1)]) image_position_idx = torch.cat([image_position_idx, torch.tensor([1024] * 768)]) self.image_rel_pos_table_list = nn.ModuleList( [Embedding(image_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.decoder_layers)] ) self.register_buffer("token_rp_bucket", token_rp_bucket) self.register_buffer("image_rp_bucket", image_rp_bucket) self.register_buffer("image_position_idx", image_position_idx) self.entangle_position_embedding = args.entangle_position_embedding def build_output_projection(self, args, dictionary, embed_tokens): if args.adaptive_softmax_cutoff is not None: self.adaptive_softmax = AdaptiveSoftmax( len(dictionary), self.output_embed_dim, utils.eval_str_list(args.adaptive_softmax_cutoff, type=int), dropout=args.adaptive_softmax_dropout, adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, factor=args.adaptive_softmax_factor, tie_proj=args.tie_adaptive_proj, ) elif self.share_input_output_embed: self.output_projection = nn.Linear( self.embed_tokens.weight.shape[1], self.embed_tokens.weight.shape[0], bias=False, ) self.output_projection.weight = self.embed_tokens.weight else: self.output_projection = nn.Linear( self.output_embed_dim, len(dictionary), bias=False ) nn.init.normal_( self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5 ) num_base_layers = getattr(args, "base_layers", 0) for i in range(num_base_layers): self.layers.insert(((i+1) * args.decoder_layers) // (num_base_layers + 1), BaseLayer(args)) def build_decoder_layer(self, args, no_encoder_attn=False, drop_path_rate=0.0): layer = TransformerDecoderLayer(args, no_encoder_attn, drop_path_rate=drop_path_rate) checkpoint = getattr(args, "checkpoint_activations", False) if checkpoint: offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) # if we are checkpointing, enforce that FSDP always wraps the # checkpointed layer, regardless of layer size min_params_to_wrap = ( getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) if not checkpoint else 0 ) layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) return layer def get_rel_pos_bias(self, x, idx): seq_len = x.size(1) rp_bucket = self.token_rp_bucket[:seq_len, :seq_len] values = F.embedding(rp_bucket, self.token_rel_pos_table_list[idx].weight) values = values.permute([2, 0, 1]) return values.contiguous() def get_image_rel_pos_bias(self, x, idx): seq_len = x.size(1) image_position_idx = self.image_position_idx[:seq_len] rp_bucket = self.image_rp_bucket[image_position_idx][:, image_position_idx] values = F.embedding(rp_bucket, self.image_rel_pos_table_list[idx].weight) values = values.permute(2, 0, 1) return values def get_pos_info(self, tokens, tgt_pos_embed, src_pos_embed=None, use_image=False): batch_size = tokens.size(0) tgt_len = tokens.size(1) tgt_pos_embed = self.image_pos_ln(tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed) if src_pos_embed is not None: src_len = src_pos_embed.size(1) pos_q = self.cross_pos_q_linear(tgt_pos_embed).view( batch_size, tgt_len, self.num_attention_heads, -1 ).transpose(1, 2) * self.pos_scaling pos_k = self.cross_pos_k_linear(src_pos_embed).view( batch_size, src_len, self.num_attention_heads, -1 ).transpose(1, 2) else: src_len = tgt_pos_embed.size(1) pos_q = self.self_pos_q_linear(tgt_pos_embed).view( batch_size, tgt_len, self.num_attention_heads, -1 ).transpose(1, 2) * self.pos_scaling pos_k = self.self_pos_k_linear(tgt_pos_embed).view( batch_size, src_len, self.num_attention_heads, -1 ).transpose(1, 2) abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) return abs_pos_bias def forward( self, prev_output_tokens, code_masks: Optional[torch.Tensor] = None, encoder_out: Optional[Dict[str, List[Tensor]]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, features_only: bool = False, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, src_lengths: Optional[Any] = None, return_all_hiddens: bool = False, ): """ Args: prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for teacher forcing encoder_out (optional): output from the encoder, used for encoder-side attention, should be of size T x B x C incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` features_only (bool, optional): only return features without applying output layer (default: False). full_context_alignment (bool, optional): don't apply auto-regressive mask to self-attention (default: False). Returns: tuple: - the decoder's output of shape `(batch, tgt_len, vocab)` - a dictionary with any model-specific outputs """ x, extra = self.extract_features( prev_output_tokens, code_masks=code_masks, encoder_out=encoder_out, incremental_state=incremental_state, full_context_alignment=full_context_alignment, alignment_layer=alignment_layer, alignment_heads=alignment_heads, ) if not features_only: x = self.output_layer(x) return x, extra def extract_features( self, prev_output_tokens, code_masks: Optional[torch.Tensor], encoder_out: Optional[Dict[str, List[Tensor]]], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, ): return self.extract_features_scriptable( prev_output_tokens, code_masks, encoder_out, incremental_state, full_context_alignment, alignment_layer, alignment_heads, ) """ A scriptable subclass of this class has an extract_features method and calls super().extract_features, but super() is not supported in torchscript. A copy of this function is made to be used in the subclass instead. """ def extract_features_scriptable( self, prev_output_tokens, code_masks: Optional[torch.Tensor], encoder_out: Optional[Dict[str, List[Tensor]]], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, ): """ Similar to *forward* but only return features. Includes several features from "Jointly Learning to Align and Translate with Transformer Models" (Garg et al., EMNLP 2019). Args: full_context_alignment (bool, optional): don't apply auto-regressive mask to self-attention (default: False). alignment_layer (int, optional): return mean alignment over heads at this layer (default: last layer). alignment_heads (int, optional): only average alignment over this many heads (default: all heads). Returns: tuple: - the decoder's features of shape `(batch, tgt_len, embed_dim)` - a dictionary with any model-specific outputs """ bs, slen = prev_output_tokens.size() if alignment_layer is None: alignment_layer = self.num_layers - 1 enc: Optional[Tensor] = None padding_mask: Optional[Tensor] = None if encoder_out is not None and len(encoder_out["encoder_out"]) > 0: enc = encoder_out["encoder_out"][0] assert ( enc.size()[1] == bs ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}" if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0: padding_mask = encoder_out["encoder_padding_mask"][0] bsz, tgt_len = prev_output_tokens.shape token_position_idx = utils.new_arange(prev_output_tokens) tgt_pos_embed = self.embed_positions(token_position_idx) if code_masks is not None and torch.any(code_masks): image_position_idx = self.image_position_idx[:prev_output_tokens.size(1)].unsqueeze(0).expand(bsz, tgt_len) tgt_pos_embed[code_masks] = self.embed_image_positions(image_position_idx)[code_masks] # self attn position bias self_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, use_image=False) if code_masks is not None and torch.any(code_masks): self_image_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, use_image=True) self_abs_pos_bias[code_masks] = self_image_abs_pos_bias[code_masks] # cross attn position bias src_pos_embed = encoder_out['position_embeddings'][0] cross_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, src_pos_embed=src_pos_embed) if code_masks is not None and torch.any(code_masks): cross_image_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, src_pos_embed=src_pos_embed, use_image=True) cross_abs_pos_bias[code_masks] = cross_image_abs_pos_bias[code_masks] cross_abs_pos_bias = cross_abs_pos_bias.reshape(-1, *cross_abs_pos_bias.size()[-2:]) all_prev_output_tokens = prev_output_tokens.clone() if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] cross_abs_pos_bias = cross_abs_pos_bias[:, -1:, :] tgt_pos_embed = tgt_pos_embed[:, -1:, :] # embed tokens and positions x = self.embed_scale * self.embed_tokens(prev_output_tokens) if self.quant_noise is not None: x = self.quant_noise(x) if self.project_in_dim is not None: x = self.project_in_dim(x) if self.entangle_position_embedding is not None and not self.args.disable_entangle: x += tgt_pos_embed if self.layernorm_embedding is not None: if code_masks is None or not code_masks.any() or not getattr(self, "code_layernorm_embedding", False): x = self.layernorm_embedding(x) elif code_masks is not None and code_masks.all(): x = self.code_layernorm_embedding(x) else: x[~code_masks] = self.layernorm_embedding(x[~code_masks]) x[code_masks] = self.code_layernorm_embedding(x[code_masks]) x = self.dropout_module(x) # B x T x C -> T x B x C x = x.transpose(0, 1) self_attn_padding_mask: Optional[Tensor] = None if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any(): self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) # decoder layers attn: Optional[Tensor] = None inner_states: List[Optional[Tensor]] = [x] for idx, layer in enumerate(self.layers): if incremental_state is None and not full_context_alignment: self_attn_mask = self.buffered_future_mask(x) else: self_attn_mask = None self_attn_bias = self_abs_pos_bias.clone() if code_masks is None or not code_masks.any(): self_attn_bias += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0) elif code_masks is not None and code_masks.all(): self_attn_bias += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0) else: self_attn_bias[~code_masks] += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0) self_attn_bias[code_masks] += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0) self_attn_bias = self_attn_bias.reshape(-1, *self_attn_bias.size()[-2:]) if incremental_state is not None: self_attn_bias = self_attn_bias[:, -1:, :] x, layer_attn, _ = layer( x, enc, padding_mask, incremental_state, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, need_attn=bool((idx == alignment_layer)), need_head_weights=bool((idx == alignment_layer)), self_attn_bias=self_attn_bias, cross_attn_bias=cross_abs_pos_bias ) inner_states.append(x) if layer_attn is not None and idx == alignment_layer: attn = layer_attn.float().to(x) if attn is not None: if alignment_heads is not None: attn = attn[:alignment_heads] # average probabilities over heads attn = attn.mean(dim=0) if self.layer_norm is not None: x = self.layer_norm(x) # T x B x C -> B x T x C x = x.transpose(0, 1) if self.project_out_dim is not None: x = self.project_out_dim(x) return x, {"attn": [attn], "inner_states": inner_states} def output_layer(self, features): """Project features to the vocabulary size.""" if self.adaptive_softmax is None: # project back to size of vocabulary return self.output_projection(features) else: return features def max_positions(self): """Maximum output length supported by the decoder.""" if self.embed_positions is None: return self.max_target_positions return self.max_target_positions def buffered_future_mask(self, tensor): dim = tensor.size(0) # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround. if ( self._future_mask.size(0) == 0 or (not self._future_mask.device == tensor.device) or self._future_mask.size(0) < dim ): self._future_mask = torch.triu( utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1 ) self._future_mask = self._future_mask.to(tensor) return self._future_mask[:dim, :dim] def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = "{}.embed_positions.weights".format(name) if weights_key in state_dict: del state_dict[weights_key] state_dict[ "{}.embed_positions._float_tensor".format(name) ] = torch.FloatTensor(1) if f"{name}.output_projection.weight" not in state_dict: if self.share_input_output_embed: embed_out_key = f"{name}.embed_tokens.weight" else: embed_out_key = f"{name}.embed_out" if embed_out_key in state_dict: state_dict[f"{name}.output_projection.weight"] = state_dict[ embed_out_key ] if not self.share_input_output_embed: del state_dict[embed_out_key] for i in range(self.num_layers): # update layer norms self.layers[i].upgrade_state_dict_named( state_dict, "{}.layers.{}".format(name, i) ) # version_key = "{}.version".format(name) # if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2: # # earlier checkpoints did not normalize after the stack of layers # self.layer_norm = None # self.normalize = False # state_dict[version_key] = torch.Tensor([1]) prefix = name + "." if name != "" else "" image_params = ["image_position_idx"] for image_param in image_params: state_dict[prefix + image_param] = self.state_dict()[image_param] for param_name, param_tensor in self.state_dict().items(): if (prefix + param_name) not in state_dict and param_name in self.state_dict(): state_dict[prefix + param_name] = self.state_dict()[param_name] if len(state_dict["decoder.embed_image_positions.weight"]) < len(self.state_dict()["embed_image_positions.weight"]): num_posids_to_add = len(self.state_dict()["embed_image_positions.weight"]) - len(state_dict["decoder.embed_image_positions.weight"]) embed_dim = state_dict["decoder.embed_image_positions.weight"].size(1) new_pos_embed_to_add = torch.zeros(num_posids_to_add, embed_dim) nn.init.normal_(new_pos_embed_to_add, mean=0, std=embed_dim ** -0.5) new_pos_embed_to_add = new_pos_embed_to_add.to( dtype=state_dict["decoder.embed_image_positions.weight"].dtype, ) state_dict["decoder.embed_image_positions.weight"] = torch.cat( [state_dict["decoder.embed_image_positions.weight"], new_pos_embed_to_add] ) return state_dict def Embedding(num_embeddings, embedding_dim, padding_idx=None, zero_init=False): m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) if padding_idx is not None: nn.init.constant_(m.weight[padding_idx], 0) if zero_init: nn.init.constant_(m.weight, 0) return m def Linear(in_features, out_features, bias=True): m = nn.Linear(in_features, out_features, bias) nn.init.xavier_uniform_(m.weight) if bias: nn.init.constant_(m.bias, 0.0) return m @register_model_architecture("unify_transformer", "unify_transformer") def base_architecture(args): args.encoder_embed_path = getattr(args, "encoder_embed_path", None) args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) args.encoder_layers = getattr(args, "encoder_layers", 6) args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) args.decoder_embed_path = getattr(args, "decoder_embed_path", None) args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) args.decoder_ffn_embed_dim = getattr( args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim ) args.decoder_layers = getattr(args, "decoder_layers", 6) args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) args.attention_dropout = getattr(args, "attention_dropout", 0.0) args.activation_dropout = getattr(args, "activation_dropout", 0.0) args.activation_fn = getattr(args, "activation_fn", "relu") args.dropout = getattr(args, "dropout", 0.1) args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) args.share_decoder_input_output_embed = getattr( args, "share_decoder_input_output_embed", False ) args.share_all_embeddings = getattr(args, "share_all_embeddings", False) args.no_token_positional_embeddings = getattr( args, "no_token_positional_embeddings", False ) args.adaptive_input = getattr(args, "adaptive_input", False) args.no_cross_attention = getattr(args, "no_cross_attention", False) args.cross_self_attention = getattr(args, "cross_self_attention", False) args.decoder_output_dim = getattr( args, "decoder_output_dim", args.decoder_embed_dim ) args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) args.no_scale_embedding = getattr(args, "no_scale_embedding", False) args.layernorm_embedding = getattr(args, "layernorm_embedding", False) args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) args.checkpoint_activations = getattr(args, "checkpoint_activations", False) args.offload_activations = getattr(args, "offload_activations", False) if args.offload_activations: args.checkpoint_activations = True args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0) args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)