File size: 1,956 Bytes
101e5d8 9d1e66f 01920f9 9d1e66f 101e5d8 01920f9 101e5d8 01920f9 101e5d8 01920f9 62edb7a 101e5d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
# coding=utf-8
# author: xusong <xusong28@jd.com>
# time: 2022/3/3 14:18
"""
?
"""
import torch
from torch import nn
import transformers
from transformers.models.bart import modeling_bart
class KPlugLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.

    Drop-in replacement for ``BartLearnedPositionalEmbedding`` that selects
    the correct ``forward`` signature for the installed transformers version
    (the signature changed from ``input_ids_shape`` to ``input_ids`` in 4.22).
    """

    # ``padding_idx`` is kept for backward compatibility: older transformers
    # releases (e.g. 4.2.1) constructed this class with three arguments.
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int = None):
        # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        if padding_idx is None:
            super().__init__(num_embeddings + self.offset, embedding_dim)
        else:
            super().__init__(num_embeddings + self.offset, embedding_dim, padding_idx=padding_idx)

    # FIX: the original guard compared version *strings* lexicographically,
    # so e.g. "4.9.0" >= "4.22.0" evaluated to True ('9' > '2') even though
    # 4.9 < 4.22, selecting the wrong forward() for 4.3 <= version < 4.10.
    # Compare (major, minor) as an integer tuple instead.
    if tuple(int(part) for part in transformers.__version__.split(".")[:2]) >= (4, 22):
        # transformers >= 4.22: forward receives the input ids tensor itself.
        def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
            """`input_ids' shape is expected to be [bsz x seqlen]."""
            bsz, seq_len = input_ids.shape[:2]
            positions = torch.arange(
                past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
            ).expand(bsz, -1)
            return super().forward(positions + self.offset)
    else:
        # transformers < 4.22: forward receives only the shape of the input.
        def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
            """`input_ids_shape` is expected to be [bsz x seqlen]."""
            bsz, seq_len = input_ids_shape[:2]
            positions = torch.arange(
                past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
            )
            return super().forward(positions + self.offset)
# Monkey-patch: replace the stock Bart positional-embedding class so that any
# Bart model subsequently built from ``modeling_bart`` uses the K-PLUG variant
# defined above (module-level side effect of importing this file).
modeling_bart.BartLearnedPositionalEmbedding = KPlugLearnedPositionalEmbedding
|