r""" A textual head accepts visual features from the visual backbone, and performs task specific modeling (captioning, classification etc.) to predict an output distribution over vocabulary tokens for one or multiple time-steps in the batch. """ import torch from torch import nn from typing import Optional from virtex.modules.embedding import WordAndPositionalEmbedding from virtex.modules.transformer import ( PreNormTransformerEncoderLayer, PreNormTransformerDecoderLayer, ) class TextualHead(nn.Module): r""" Base class for all textual heads. All child classes can simply inherit from :class:`~torch.nn.Module`, however this is kept here for uniform type annotations. Parameters ---------- visual_feature_size: int Size (number of channels) of the input features from the visual backbone. vocab_size: int Number of tokens in the output vocabulary. hidden_size: int Size of the token embedding vectors, or hidden state vector of the language model. """ def __init__(self, visual_feature_size: int, vocab_size: int, hidden_size: int): super().__init__() self.visual_feature_size = visual_feature_size self.vocab_size = vocab_size self.hidden_size = hidden_size @property def textual_feature_size(self): r""" Size of the last dimension of output right before the output linear layer (which predicts a distribution over vocabulary tokens). This is typically same as :attr:`hidden_size` for most modules. This property is used to add more modules on top of this. """ return self.hidden_size class LinearTextualHead(TextualHead): r""" A textual head containing a single linear layer projecting from the visual feature size to the output vocabulary size. Parameters ---------- visual_feature_size: int Size (number of channels) of the input features from the visual backbone. vocab_size: int Number of tokens in the output vocabulary. """ def __init__(self, visual_feature_size: int, vocab_size: int, **kwargs): # For API consistency. hidden_size = visual_feature_size super().__init__(visual_feature_size, vocab_size, hidden_size) self.output = nn.Linear(visual_feature_size, vocab_size) def forward( self, visual_features: torch.Tensor, caption_tokens: Optional[torch.Tensor] = None, caption_lengths: Optional[torch.Tensor] = None, ) -> torch.Tensor: r""" Project visual features directly to predict a distribution over vocabulary tokens through a single linear layer. This textual head ignores arguments ``caption_tokens`` and ``caption_lengths``, they are here for API consistency. Parameters ---------- visual_features: torch.Tensor A tensor of shape ``(batch_size, channels, height, width)`` containing features from visual backbone. Returns ------- torch.Tensor A tensor of shape ``(batch_size, vocab_size)`` containing output vocabulary logits. """ # Convert to NHWC and project visual features to textual feature size. batch_size, channels, height, width = visual_features.size() visual_features = visual_features.view(batch_size, channels, -1) visual_features = visual_features.permute(0, 2, 1) # Perform global average pooling of visual features. 


class TransformerDecoderTextualHead(TextualHead):
    r"""
    A textual head composed of four main modules: (1) input projection (linear
    layer) for visual features to match size with textual features, (2) word
    and positional embedding for input captions, (3) a unidirectional
    transformer decoder, and (4) an output projection (linear layer) to predict
    a distribution over vocabulary tokens. The word embedding weights are tied
    with the output projection; the latter still has its own learnable bias.

    .. note::

        For the "bicaptioning" pretraining task, our *textual head* (as defined
        in the paper) must have two transformer decoders: one each to decode
        the caption in either direction. This class, however, will always have
        one transformer per object.

        Refer to :class:`~virtex.models.captioning.BidirectionalCaptioningModel`
        source to understand how an object of this class is cloned, along with
        tying embedding and output weights, for bicaptioning.

        Hence, while there are *two objects* of this class, it is pragmatically
        a *single* textual head as a whole, according to the terminology used
        in the paper.

    Parameters
    ----------
    visual_feature_size: int
        Size (number of channels) of the input features from the visual backbone.
    vocab_size: int
        Number of tokens in the output vocabulary.
    hidden_size: int
        Size of the token embedding vectors, or hidden state vector of the
        language model.
    num_layers: int
        Number of layers in the transformer.
    attention_heads: int
        Number of attention heads in the transformer.
    feedforward_size: int
        Size of feedforward layers in the transformer.
    dropout: float, optional (default = 0.1)
        Dropout probability for transformer (applied after layer normalization).
    norm_type: str, optional (default = "post")
        Type of transformer layer: pre-normalization (like GPT-2) or
        post-normalization (like BERT). One of ``{"pre", "post"}``.
    mask_future_positions: bool, optional (default = True)
        Whether to mask future positions for self-attention over caption tokens.
        This must be ``True`` for captioning (and bicaptioning) tasks to prevent
        the language model from cheating, and ``False`` for masked language
        modeling, as the self-attention should consider all tokens.
    max_caption_length: int, optional (default = 30)
        Maximum length of input captions; this is used to create a fixed
        positional embedding lookup table.
    padding_idx: int, optional (default = 0)
        Token index of the ``[PAD]`` token; word embeddings for these tokens
        will be a vector of zeroes (and not trainable).
    """

    def __init__(
        self,
        visual_feature_size: int,
        vocab_size: int,
        hidden_size: int,
        num_layers: int,
        attention_heads: int,
        feedforward_size: int,
        dropout: float = 0.1,
        norm_type: str = "post",
        mask_future_positions: bool = True,
        max_caption_length: int = 30,
        padding_idx: int = 0,
    ):
        super().__init__(visual_feature_size, vocab_size, hidden_size)
        self.num_layers = num_layers
        self.attention_heads = attention_heads
        self.feedforward_size = feedforward_size
        self.dropout = dropout
        self.mask_future_positions = mask_future_positions
        self.padding_idx = padding_idx

        self.visual_projection = nn.Linear(
            visual_feature_size, self.textual_feature_size
        )
        self.embedding = WordAndPositionalEmbedding(
            self.vocab_size,
            self.textual_feature_size,
            dropout=dropout,
            max_caption_length=max_caption_length,
            padding_idx=padding_idx,
        )
        # Make decoder layer depending on whether it's a Pre-Norm or Post-Norm.
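        # Pre-Norm applies LayerNorm before the self-attention and feedforward
        # sub-layers (GPT-2 style); Post-Norm applies it after each sub-layer
        # and residual connection (BERT / original Transformer style).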
        LayerClass = (
            nn.TransformerDecoderLayer
            if norm_type == "post"
            else PreNormTransformerDecoderLayer
        )
        _layer = LayerClass(
            self.textual_feature_size,
            self.attention_heads,
            dim_feedforward=self.feedforward_size,
            dropout=dropout,
            activation="gelu",
        )
        self.transformer = nn.TransformerDecoder(_layer, self.num_layers)
        self.apply(self._init_weights)

        # Create an output linear layer and tie the input and output word
        # embeddings to reduce parameters.
        self.output = nn.Linear(self.textual_feature_size, vocab_size)
        self.output.weight = self.embedding.words.weight

    @staticmethod
    def _init_weights(module):
        r"""Initialize weights like BERT - N(0.0, 0.02), bias = 0."""

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.MultiheadAttention):
            module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
            module.out_proj.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(
        self,
        visual_features: torch.Tensor,
        caption_tokens: torch.Tensor,
        caption_lengths: torch.Tensor,
    ) -> torch.Tensor:
        r"""
        Given (projected) visual features from the visual backbone and caption
        tokens, predict the output logits for the next time-step.

        Parameters
        ----------
        visual_features: torch.Tensor
            A tensor of shape ``(batch_size, channels, height, width)``
            containing features from visual backbone.
        caption_tokens: torch.Tensor
            A tensor of shape ``(batch_size, max_caption_length)`` of caption
            tokens padded to the right by ``padding_idx``.
        caption_lengths: torch.Tensor
            A tensor of shape ``(batch_size, )`` containing lengths of caption
            tokens in the batch.

        Returns
        -------
        torch.Tensor
            A tensor of shape ``(batch_size, max_caption_length, vocab_size)``
            containing output vocabulary logits for each time-step.
        """

        # Convert to NHWC and project visual features to textual feature size.
        batch_size, channels, height, width = visual_features.size()
        visual_features = visual_features.view(batch_size, channels, -1)
        visual_features = visual_features.permute(0, 2, 1)

        # shape: (batch_size, height * width, textual_feature_size)
        projected_visual_features = self.visual_projection(visual_features)

        # Now visual and textual features are of the same size.
        # Note that `max_caption_length` here may be less than the
        # `max_caption_length` passed in `__init__`, but it does not matter.
        batch_size, max_caption_length = caption_tokens.size()

        # Create a mask based on caption lengths, shape: (batch_size, )
        # Form a binary mask: it is True for padding positions.
        # These positions will be ignored for multi-headed attention.
        ones = torch.ones_like(caption_tokens)
        caption_mask = caption_lengths.unsqueeze(1) < ones.cumsum(dim=1)

        # shape: (batch_size, max_caption_length, textual_feature_size)
        caption_embeddings = self.embedding(caption_tokens)

        if self.mask_future_positions:
            # An additive mask for masking the future (one direction).
            unidirectional_mask = self._generate_future_mask(
                max_caption_length, caption_embeddings.dtype, caption_embeddings.device
            )
        else:
            unidirectional_mask = None

        # Transpose the first two dimensions of token embeddings and visual
        # features, as required by the decoder.
        caption_embeddings = caption_embeddings.transpose(0, 1)
        projected_visual_features = projected_visual_features.transpose(0, 1)

        # shape: (max_caption_length, batch_size, hidden_size)
        textual_features = self.transformer(
            caption_embeddings,
            projected_visual_features,
            tgt_mask=unidirectional_mask,
            tgt_key_padding_mask=caption_mask,
        )
        # Undo the transpose and bring batch to dim 0.
        # shape: (batch_size, max_caption_length, hidden_size)
        textual_features = textual_features.transpose(0, 1)

        # shape: (batch_size, max_caption_length, vocab_size)
        output_logits = self.output(textual_features)
        return output_logits

    def _generate_future_mask(
        self, size: int, dtype: torch.dtype, device: torch.device
    ) -> torch.Tensor:
        r"""
        Generate an additive mask for "future" positions, useful when using
        this module for language modeling.

        Parameters
        ----------
        size: int
            Length of the sequence (number of caption time-steps) to mask.
        """
        # Default mask is for forward direction. Flip for backward direction.
        mask = torch.triu(
            torch.ones(size, size, device=device, dtype=dtype), diagonal=1
        )
        mask = mask.masked_fill(mask == 1, float("-inf"))
        return mask
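

# A minimal usage sketch (illustrative only; every size below is an arbitrary
# assumption, not a virtex default). It shows the expected shapes: NCHW
# backbone features plus right-padded caption tokens go in, and per-time-step
# vocabulary logits come out.
def _example_transformer_decoder_textual_head():
    head = TransformerDecoderTextualHead(
        visual_feature_size=2048,
        vocab_size=10000,
        hidden_size=512,
        num_layers=1,
        attention_heads=8,
        feedforward_size=2048,
    )
    visual_features = torch.randn(2, 2048, 7, 7)

    # Two captions of lengths 5 and 3, right-padded with `padding_idx` (0).
    caption_tokens = torch.tensor(
        [[1, 4, 9, 2, 6], [1, 7, 2, 0, 0]], dtype=torch.long
    )
    caption_lengths = torch.tensor([5, 3])

    output_logits = head(visual_features, caption_tokens, caption_lengths)
    assert output_logits.size() == (2, 5, 10000)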


class TransformerEncoderTextualHead(TextualHead):
    r"""
    A textual head with only word and positional embeddings followed by a
    transformer encoder over caption tokens. It does not consume visual
    features and returns contextualized token features (after a final
    layer normalization) instead of vocabulary logits.
    """

    def __init__(
        self,
        visual_feature_size: int,
        vocab_size: int,
        hidden_size: int,
        num_layers: int,
        attention_heads: int,
        feedforward_size: int,
        dropout: float = 0.1,
        norm_type: str = "pre",
        mask_future_positions: bool = True,
        max_caption_length: int = 30,
        padding_idx: int = 0,
    ):
        super().__init__(visual_feature_size, vocab_size, hidden_size)
        self.num_layers = num_layers
        self.attention_heads = attention_heads
        self.feedforward_size = feedforward_size
        self.dropout = dropout
        self.mask_future_positions = mask_future_positions
        self.padding_idx = padding_idx

        self.embedding = WordAndPositionalEmbedding(
            self.vocab_size,
            self.textual_feature_size,
            dropout=dropout,
            max_caption_length=max_caption_length,
            padding_idx=padding_idx,
        )
        # Make encoder layer depending on whether it's a Pre-Norm or Post-Norm.
        LayerClass = (
            nn.TransformerEncoderLayer
            if norm_type == "post"
            else PreNormTransformerEncoderLayer
        )
        _layer = LayerClass(
            self.textual_feature_size,
            self.attention_heads,
            dim_feedforward=self.feedforward_size,
            dropout=dropout,
            activation="gelu",
        )
        self.transformer = nn.TransformerEncoder(_layer, self.num_layers)
        self.final_ln = nn.LayerNorm(self.textual_feature_size)

        self._init_weights()

    def _init_weights(self):
        # Depth-scaled normal initialization for embeddings and transformer
        # weights (std shrinks with hidden size and number of layers).
        nn.init.normal_(self.embedding.words.weight, std=0.02)
        nn.init.normal_(self.embedding.positions.weight, std=0.01)

        proj_std = (self.hidden_size ** -0.5) * ((2 * self.num_layers) ** -0.5)
        for layer in self.transformer.layers:
            nn.init.normal_(layer.self_attn.in_proj_weight, std=self.hidden_size ** -0.5)
            nn.init.normal_(layer.self_attn.out_proj.weight, std=proj_std)
            nn.init.normal_(layer.linear1.weight, std=(2 * self.hidden_size) ** -0.5)
            nn.init.normal_(layer.linear2.weight, std=proj_std)

    def forward(
        self,
        caption_tokens: torch.Tensor,
        caption_lengths: torch.Tensor,
    ) -> torch.Tensor:
        r"""
        Encode caption tokens and return contextualized token features of
        shape ``(batch_size, max_caption_length, hidden_size)``.
        """
        # Note that `max_caption_length` here may be less than the
        # `max_caption_length` passed in `__init__`, but it does not matter.
        batch_size, max_caption_length = caption_tokens.size()

        # Create a mask based on caption lengths, shape: (batch_size, )
        # Form a binary mask: it is True for padding positions.
        # These positions will be ignored for multi-headed attention.
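        # For example, with caption_lengths = [3, 2] and max_caption_length = 4,
        # the row-wise cumulative sum of ones is [1, 2, 3, 4], so the mask is
        # [[False, False, False, True], [False, False, True, True]].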
        ones = torch.ones_like(caption_tokens)
        caption_mask = caption_lengths.unsqueeze(1) < ones.cumsum(dim=1)

        # shape: (batch_size, max_caption_length, textual_feature_size)
        caption_embeddings = self.embedding(caption_tokens)

        if self.mask_future_positions:
            # An additive mask for masking the future (one direction).
            unidirectional_mask = self._generate_future_mask(
                max_caption_length, caption_embeddings.dtype, caption_embeddings.device
            )
        else:
            unidirectional_mask = None

        # Transpose the first two dimensions of token embeddings, as required
        # by the encoder.
        caption_embeddings = caption_embeddings.transpose(0, 1)

        # shape: (max_caption_length, batch_size, hidden_size)
        textual_features = self.transformer(
            caption_embeddings,
            mask=unidirectional_mask,
            src_key_padding_mask=caption_mask,
        )
        # Undo the transpose and bring batch to dim 0.
        # shape: (batch_size, max_caption_length, hidden_size)
        textual_features = textual_features.transpose(0, 1)

        textual_features = self.final_ln(textual_features)
        return textual_features

    @staticmethod
    def _generate_future_mask(
        size: int, dtype: torch.dtype, device: torch.device
    ) -> torch.Tensor:
        r"""
        Generate an additive mask for "future" positions, useful when using
        this module for language modeling.

        Parameters
        ----------
        size: int
            Length of the sequence (number of caption time-steps) to mask.
        """
        # Default mask is for forward direction. Flip for backward direction.
        mask = torch.triu(
            torch.ones(size, size, device=device, dtype=dtype), diagonal=1
        )
        mask = mask.masked_fill(mask == 1, float("-inf"))
        return mask
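

# A minimal usage sketch (illustrative only; the sizes below are arbitrary
# assumptions, not virtex defaults). Unlike the decoder head, this head looks
# only at caption tokens and returns contextualized token features rather than
# vocabulary logits.
def _example_transformer_encoder_textual_head():
    head = TransformerEncoderTextualHead(
        visual_feature_size=2048,
        vocab_size=10000,
        hidden_size=512,
        num_layers=1,
        attention_heads=8,
        feedforward_size=2048,
    )
    # Two captions of lengths 4 and 2, right-padded with `padding_idx` (0).
    caption_tokens = torch.tensor([[1, 5, 8, 2], [1, 2, 0, 0]], dtype=torch.long)
    caption_lengths = torch.tensor([4, 2])

    textual_features = head(caption_tokens, caption_lengths)
    assert textual_features.size() == (2, 4, 512)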