# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from torch import Tensor


class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    Padding ids are ignored by either offsetting based on padding_idx
    or by setting padding_idx to None and ensuring that the appropriate
    position ids are passed to the forward function.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.onnx_trace = False
        if self.padding_idx is not None:
            self.max_positions = self.num_embeddings - self.padding_idx - 1
        else:
            self.max_positions = self.num_embeddings

    def forward(
        self,
        input: Tensor,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        positions: Optional[Tensor] = None,
    ):
        """Input is expected to be of size [bsz x seqlen]."""
        assert (positions is None) or (
            self.padding_idx is None
        ), "If positions is pre-computed then padding_idx should not be set."

        if positions is None:
            if incremental_state is not None:
                # During incremental decoding the position is the same for every
                # token in the single new timestep: padding_idx + current length.
                # The int() cast avoids issues when exporting to ONNX.
                positions = torch.zeros(
                    (1, 1), device=input.device, dtype=input.dtype
                ).fill_(int(self.padding_idx + input.size(1)))
            else:
                positions = utils.make_positions(
                    input, self.padding_idx, onnx_trace=self.onnx_trace
                )
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
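
# A minimal usage sketch (illustrative, not part of the original module). It
# assumes fairseq is installed and follows the fairseq convention that learned
# positions start at padding_idx + 1, so the embedding table needs
# max_source_positions + padding_idx + 1 rows. The names and numbers below
# (padding_idx=1, max_source_positions, embedding_dim=16) are hypothetical.
if __name__ == "__main__":
    padding_idx = 1
    max_source_positions = 128  # assumed maximum sequence length
    emb = LearnedPositionalEmbedding(
        max_source_positions + padding_idx + 1,
        embedding_dim=16,
        padding_idx=padding_idx,
    )
    # Two sequences of length 5; token id 1 marks padding in the first one.
    tokens = torch.tensor([[5, 6, 7, 1, 1], [5, 6, 7, 8, 9]])
    out = emb(tokens)
    print(out.shape)  # torch.Size([2, 5, 16]); padded slots map to the zero row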