import functools

import torch
from torch import nn


class WordAndPositionalEmbedding(nn.Module):
r"""
A :class:`~torch.nn.Module` for learned word embeddings and position
embeddings for input tokens. Each token is mapped to a fixed dimensional
word embedding; and corresponding positional embedding based on its index.
These are summed together followed by layer normalization and an optional
dropout.
Parameters
----------
vocab_size: int
Size of token vocabulary.
hidden_size: int
Size of token embedding vectors.
dropout: float, optional (default = 0.1)
Dropout probability for final dropout applied after layer normalization.
max_caption_length: int, optional (default = 30)
Maximum length of input captions; this is used to create a fixed
positional embedding lookup table.
padding_idx: int, optional (default = 0)
Token index of ``[PAD]`` token, word embedding for these tokens will
be a vector of zeroes (and not trainable).
"""
    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        dropout: float = 0.0,
        max_caption_length: int = 30,
        padding_idx: int = 0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.padding_idx = padding_idx

        self.words = nn.Embedding(vocab_size, hidden_size, padding_idx=padding_idx)

        # We provide no "padding index" for positional embeddings. Instead,
        # positional embeddings of padded positions are zeroed out as a
        # post-processing step in ``forward``.
        self.positions = nn.Embedding(max_caption_length, hidden_size)
        self.layer_norm = nn.LayerNorm(
            hidden_size, eps=1e-8, elementwise_affine=True
        )
        self.dropout = nn.Dropout(p=dropout)
    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        r"""
        Get combined word and positional embeddings for input tokens.

        Parameters
        ----------
        tokens: torch.Tensor
            A tensor of shape ``(batch_size, max_caption_length)`` containing
            a batch of caption tokens, with values in ``[0, vocab_size)``.

        Returns
        -------
        torch.Tensor
            A tensor of shape ``(batch_size, max_caption_length, hidden_size)``
            containing corresponding token embeddings.
        """
        position_indices = self._create_position_indices(tokens)

        # shape: (batch_size, max_caption_length, hidden_size)
        word_embeddings = self.words(tokens)
        position_embeddings = self.positions(position_indices)

        # shape: (batch_size, max_caption_length, hidden_size)
        embeddings = self.layer_norm(word_embeddings + position_embeddings)
        embeddings = self.dropout(embeddings)

        # Zero out embeddings for positions which have padding tokens.
        # shape: (batch_size, max_caption_length, 1)
        token_mask = (tokens != self.padding_idx).unsqueeze(-1)

        # shape: (batch_size, max_caption_length, hidden_size)
        embeddings = embeddings * token_mask.type(embeddings.dtype)
        return embeddings
    @functools.lru_cache(maxsize=128)
    def _create_position_indices(self, tokens: torch.Tensor):
        # Create position indices of the same size as token indices.
        # NOTE: ``lru_cache`` keys on the ``tokens`` tensor object itself, so
        # the cache only saves work when the very same tensor is passed again.
        batch_size, max_caption_length = tokens.size()
        positions = torch.arange(
            max_caption_length, dtype=tokens.dtype, device=tokens.device
        )
        # shape: (batch_size, max_caption_length)
        positions = positions.unsqueeze(0).expand(batch_size, max_caption_length)
        return positions
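

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original module: the vocabulary size,
# hidden size, and token ids below are made-up values, chosen only to
# illustrate the expected shapes and the zeroing of padded positions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    embedding = WordAndPositionalEmbedding(vocab_size=100, hidden_size=32)

    # Two captions of length 6, right-padded with ``padding_idx`` (0).
    tokens = torch.tensor(
        [
            [5, 8, 13, 21, 0, 0],
            [2, 3, 0, 0, 0, 0],
        ]
    )
    out = embedding(tokens)

    # torch.Size([2, 6, 32]): (batch_size, caption_length, hidden_size)
    print(out.shape)

    # Embeddings at padded positions are exactly zero.
    print(out[0, 4:].abs().sum().item(), out[1, 2:].abs().sum().item())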