import functools

import torch
from torch import nn

class WordAndPositionalEmbedding(nn.Module):
    r"""
    A :class:`~torch.nn.Module` for learned word embeddings and positional
    embeddings for input tokens. Each token is mapped to a fixed-dimensional
    word embedding and a corresponding positional embedding based on its
    index. These are summed together, followed by layer normalization and an
    optional dropout.

    Parameters
    ----------
    vocab_size: int
        Size of token vocabulary.
    hidden_size: int
        Size of token embedding vectors.
    dropout: float, optional (default = 0.0)
        Dropout probability for the final dropout applied after layer
        normalization.
    max_caption_length: int, optional (default = 30)
        Maximum length of input captions; this is used to create a fixed
        positional embedding lookup table.
    padding_idx: int, optional (default = 0)
        Token index of the ``[PAD]`` token; the word embedding for these
        tokens will be a vector of zeroes (and not trainable).
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        dropout: float = 0.0,
        max_caption_length: int = 30,
        padding_idx: int = 0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.padding_idx = padding_idx

        self.words = nn.Embedding(vocab_size, hidden_size, padding_idx=padding_idx)

        # We provide no "padding index" for positional embeddings. We zero out
        # the positional embeddings of padded positions as a post-processing step.
        self.positions = nn.Embedding(max_caption_length, hidden_size)
        self.layer_norm = nn.LayerNorm(
            hidden_size, eps=1e-8, elementwise_affine=True
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        r"""
        Get combined word and positional embeddings for input tokens.

        Parameters
        ----------
        tokens: torch.Tensor
            A tensor of shape ``(batch_size, max_caption_length)`` containing
            a batch of caption tokens, with values in ``[0, vocab_size)``.

        Returns
        -------
        torch.Tensor
            A tensor of shape ``(batch_size, max_caption_length, hidden_size)``
            containing corresponding token embeddings.
        """
        position_indices = self._create_position_indices(tokens)

        # shape: (batch_size, max_caption_length, hidden_size)
        word_embeddings = self.words(tokens)
        position_embeddings = self.positions(position_indices)

        # shape: (batch_size, max_caption_length, hidden_size)
        embeddings = self.layer_norm(word_embeddings + position_embeddings)
        embeddings = self.dropout(embeddings)

        # Zero out embeddings for positions which have padding tokens.
        # shape: (batch_size, max_caption_length, 1)
        token_mask = (tokens != self.padding_idx).unsqueeze(-1)

        # shape: (batch_size, max_caption_length, hidden_size)
        embeddings = embeddings * token_mask.type(embeddings.dtype)
        return embeddings

    def _create_position_indices(self, tokens: torch.Tensor) -> torch.Tensor:
        # Create position indices of the same size as token indices.
        batch_size, max_caption_length = tokens.size()
        positions = torch.arange(
            max_caption_length, dtype=tokens.dtype, device=tokens.device
        )
        # shape: (batch_size, max_caption_length)
        positions = positions.unsqueeze(0).expand(batch_size, max_caption_length)
        return positions
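

# A minimal usage sketch, not part of the original module: the sizes below
# (vocab_size=1000, hidden_size=64, max_caption_length=10) and the token
# indices are arbitrary illustrative values. It builds a small embedding
# layer, feeds it a batch of right-padded caption tokens, and checks that
# padded positions come out as all-zero vectors.
if __name__ == "__main__":
    embedding = WordAndPositionalEmbedding(
        vocab_size=1000, hidden_size=64, dropout=0.1, max_caption_length=10
    )
    # Two captions of length 10, right-padded with the [PAD] token (index 0).
    tokens = torch.tensor(
        [
            [5, 42, 7, 99, 3, 0, 0, 0, 0, 0],
            [12, 8, 256, 0, 0, 0, 0, 0, 0, 0],
        ]
    )
    out = embedding(tokens)
    print(out.shape)  # torch.Size([2, 10, 64])
    # Embeddings at padded positions are zeroed out by the token mask.
    print(out[0, 5:].abs().sum().item())  # 0.0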