OFA-Image_Caption / fairseq /fairseq /modules /learned_positional_embedding.py
JustinLin610
update
8437114
raw history blame
No virus
2.26 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from torch import Tensor
class LearnedPositionalEmbedding(nn.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
Padding ids are ignored by either offsetting based on padding_idx
or by setting padding_idx to None and ensuring that the appropriate
position ids are passed to the forward function.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.onnx_trace = False
if self.padding_idx is not None:
self.max_positions = self.num_embeddings - self.padding_idx - 1
else:
self.max_positions = self.num_embeddings
def forward(
self,
input: Tensor,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
positions: Optional[Tensor] = None,
):
"""Input is expected to be of size [bsz x seqlen]."""
assert (positions is None) or (
self.padding_idx is None
), "If positions is pre-computed then padding_idx should not be set."
if positions is None:
if incremental_state is not None:
# positions is the same for every token when decoding a single step
# Without the int() cast, it doesn't work in some cases when exporting to ONNX
positions = torch.zeros(
(1, 1), device=input.device, dtype=input.dtype
).fill_(int(self.padding_idx + input.size(1)))
else:
positions = utils.make_positions(
input, self.padding_idx, onnx_trace=self.onnx_trace
)
return F.embedding(
positions,
self.weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)