"""Monkey-patch for BART's learned positional embedding in Hugging Face transformers.

transformers 4.22.0 changed the signature of ``BartLearnedPositionalEmbedding.forward``
(it now receives the ``input_ids`` tensor instead of its shape), so this module
installs a drop-in replacement that works on both sides of that change.
"""

from typing import Optional

import torch
import transformers
from packaging import version
from torch import nn
from transformers.models.bart import modeling_bart


class KPlugLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        # BART offsets embedding ids by 2 when padding_idx is specified, so the
        # table needs two extra rows. Passing padding_idx=None matches nn.Embedding's
        # default, which lets both construction paths share one super().__init__ call.
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim, padding_idx=padding_idx)

    # transformers 4.22.0 changed forward() to take the input_ids tensor instead of
    # its shape. Compare parsed versions, not raw strings: lexicographically,
    # "4.9.0" >= "4.22.0" would be True.
    if version.parse(transformers.__version__) >= version.parse("4.22.0"):

        def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
            """`input_ids` is expected to have shape [bsz x seqlen]."""
            bsz, seq_len = input_ids.shape[:2]
            positions = torch.arange(
                past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
            ).expand(bsz, -1)
            return super().forward(positions + self.offset)

    else:

        def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
            """`input_ids_shape` is expected to be [bsz x seqlen]."""
            bsz, seq_len = input_ids_shape[:2]
            positions = torch.arange(
                past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
            )
            return super().forward(positions + self.offset)


# Swap in the patched class so that any BART model constructed after this module is
# imported uses KPlugLearnedPositionalEmbedding.
modeling_bart.BartLearnedPositionalEmbedding = KPlugLearnedPositionalEmbedding |
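
# Minimal smoke-test sketch (not part of the patch itself): importing this module
# before constructing a BART model is all that is needed. The tiny BartConfig
# values below are illustrative assumptions, chosen only to keep the check cheap.
if __name__ == "__main__":
    from transformers import BartConfig, BartModel

    config = BartConfig(
        vocab_size=128,
        d_model=16,
        encoder_layers=1,
        decoder_layers=1,
        encoder_attention_heads=2,
        decoder_attention_heads=2,
        encoder_ffn_dim=32,
        decoder_ffn_dim=32,
        max_position_embeddings=64,
    )
    model = BartModel(config)
    # BartEncoder stores its positional embedding as `embed_positions`; after the
    # patch it should be an instance of the class defined above.
    assert isinstance(model.encoder.embed_positions, KPlugLearnedPositionalEmbedding)

    # Exercise whichever forward signature this transformers version uses.
    if version.parse(transformers.__version__) >= version.parse("4.22.0"):
        pos = model.encoder.embed_positions(torch.ones(2, 5, dtype=torch.long))
    else:
        pos = model.encoder.embed_positions(torch.Size([2, 5]))
    print("patched:", type(model.encoder.embed_positions).__name__, tuple(pos.shape))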