|
|
|
|
|
|
|
|
|
|
|
import torch |
|
from torch import nn |
|
from transformers.models.bart import modeling_bart |
|
|
|
|
|
|
|
class KPlugLearnedPositionalEmbedding(nn.Embedding):
    """Learned positional embedding with BART's reserved-index offset.

    Following BART, the first ``offset`` (= 2) rows of the table are reserved,
    so the table is allocated with ``num_embeddings + offset`` rows and every
    position index is shifted by ``offset`` before the lookup.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # The offset is stored before calling nn.Embedding.__init__ because
        # the allocated table size depends on it.
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
        """`input_ids_shape` is expected to be [bsz x seqlen]."""
        _, seq_len = input_ids_shape[:2]
        # Positions continue from any cached (past) key/value length, so
        # incremental decoding looks up the correct rows.
        start = past_key_values_length
        position_indices = torch.arange(
            start, start + seq_len, dtype=torch.long, device=self.weight.device
        )
        return super().forward(position_indices + self.offset)
|
|
|
|
|
# Monkey-patch: replace BART's positional embedding class inside the
# transformers module so models constructed after this assignment use
# KPlugLearnedPositionalEmbedding instead.
# NOTE(review): code that bound BartLearnedPositionalEmbedding by name
# before this line executed will not pick up the patch — confirm that this
# module is imported before any BART model is built.
modeling_bart.BartLearnedPositionalEmbedding = KPlugLearnedPositionalEmbedding
|
|
|
|