import functools

import torch
from torch import nn


class WordAndPositionalEmbedding(nn.Module):
    r"""
    A :class:`~torch.nn.Module` for learned word embeddings and position
    embeddings for input tokens. Each token is mapped to a fixed-dimensional
    word embedding and a corresponding positional embedding based on its
    index. These are summed together, followed by layer normalization and an
    optional dropout.

    Parameters
    ----------
    vocab_size: int
        Size of token vocabulary.
    hidden_size: int
        Size of token embedding vectors.
    dropout: float, optional (default = 0.0)
        Dropout probability for final dropout applied after layer normalization.
    max_caption_length: int, optional (default = 30)
        Maximum length of input captions; this is used to create a fixed
        positional embedding lookup table.
    padding_idx: int, optional (default = 0)
        Token index of ``[PAD]`` token; the word embedding for this token will
        be a vector of zeros (and not trainable).
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        dropout: float = 0.0,
        max_caption_length: int = 30,
        padding_idx: int = 0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.padding_idx = padding_idx

        self.words = nn.Embedding(vocab_size, hidden_size, padding_idx=padding_idx)

        # We provide no "padding index" for positional embeddings. We zero out
        # the positional embeddings of padded positions as a post-processing.
        self.positions = nn.Embedding(max_caption_length, hidden_size)
        self.layer_norm = nn.LayerNorm(
            hidden_size, eps=1e-8, elementwise_affine=True
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        r"""
        Get combined word and positional embeddings for input tokens.

        Parameters
        ----------
        tokens: torch.Tensor
            A tensor of shape ``(batch_size, max_caption_length)`` containing
            a batch of caption tokens, with values in ``[0, vocab_size)``.

        Returns
        -------
        torch.Tensor
            A tensor of shape ``(batch_size, max_caption_length, hidden_size)``
            containing corresponding token embeddings.
        """
        position_indices = self._create_position_indices(tokens)

        # shape: (batch_size, max_caption_length, hidden_size)
        word_embeddings = self.words(tokens)
        position_embeddings = self.positions(position_indices)

        # shape: (batch_size, max_caption_length, hidden_size)
        embeddings = self.layer_norm(word_embeddings + position_embeddings)
        embeddings = self.dropout(embeddings)

        # Zero-out embeddings for positions which have padding tokens.
        # shape: (batch_size, max_caption_length, 1)
        token_mask = (tokens != self.padding_idx).unsqueeze(-1)

        # shape: (batch_size, max_caption_length, hidden_size)
        embeddings = embeddings * token_mask.type(embeddings.dtype)
        return embeddings

    @functools.lru_cache(maxsize=128)
    def _create_position_indices(self, tokens: torch.Tensor):
        # Create position indices of the same size as token indices.
        batch_size, max_caption_length = tokens.size()
        positions = torch.arange(
            max_caption_length, dtype=tokens.dtype, device=tokens.device
        )
        # shape: (batch_size, max_caption_length)
        positions = positions.unsqueeze(0).expand(batch_size, max_caption_length)
        return positions
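

# A minimal usage sketch, not part of the original module: it assumes a toy
# vocabulary of 100 tokens, hidden size 32, and two short captions right-padded
# with the default padding index 0 to the default max_caption_length of 30.
# It shows that the output has shape (batch_size, max_caption_length,
# hidden_size) and that embeddings at padded positions are zeroed out.
if __name__ == "__main__":
    embedding = WordAndPositionalEmbedding(
        vocab_size=100, hidden_size=32, dropout=0.1, max_caption_length=30
    )

    # Two captions of different lengths, right-padded with 0 ([PAD]).
    tokens = torch.zeros(2, 30, dtype=torch.long)
    tokens[0, :5] = torch.tensor([12, 47, 3, 99, 8])
    tokens[1, :3] = torch.tensor([7, 64, 21])

    out = embedding(tokens)
    print(out.shape)                      # torch.Size([2, 30, 32])
    print(out[0, 5:].abs().sum().item())  # 0.0 -- padded positions are zeroed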