# coding=utf-8
# author: xusong
# time: 2022/3/3 14:18

"""
Learned positional embedding for KPlug, monkey-patched over BART's
implementation so it works across transformers versions.
"""

import torch
from torch import nn
from packaging import version

import transformers
from transformers.models.bart import modeling_bart


class KPlugLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    # padding_idx is kept for backward compatibility (e.g. in transformers 4.2.1
    # the constructor takes three arguments).
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int = None):
        # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        if padding_idx is None:
            super().__init__(num_embeddings + self.offset, embedding_dim)
        else:
            super().__init__(num_embeddings + self.offset, embedding_dim, padding_idx=padding_idx)

    # transformers >= 4.22 passes the input_ids tensor to forward(); older versions
    # pass only its shape. Compare versions properly instead of lexicographically
    # (as strings, "4.9.0" >= "4.22.0" is True, which picks the wrong branch).
    if version.parse(transformers.__version__) >= version.parse("4.22.0"):
        def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
            """`input_ids`' shape is expected to be [bsz x seqlen]."""
            bsz, seq_len = input_ids.shape[:2]
            positions = torch.arange(
                past_key_values_length, past_key_values_length + seq_len,
                dtype=torch.long, device=self.weight.device
            ).expand(bsz, -1)
            return super().forward(positions + self.offset)
    else:
        def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
            """`input_ids_shape` is expected to be [bsz x seqlen]."""
            bsz, seq_len = input_ids_shape[:2]
            positions = torch.arange(
                past_key_values_length, past_key_values_length + seq_len,
                dtype=torch.long, device=self.weight.device
            )
            return super().forward(positions + self.offset)


# Replace BART's positional embedding class with the version-aware one above.
modeling_bart.BartLearnedPositionalEmbedding = KPlugLearnedPositionalEmbedding
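

# ---------------------------------------------------------------------------
# Minimal smoke test (not part of the original module): a sketch showing the
# expected output shapes of the patched embedding. The sizes below
# (num_embeddings=1026, embedding_dim=16, bsz=2, seq_len=8) are illustrative
# values, not taken from the KPlug configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    emb = KPlugLearnedPositionalEmbedding(num_embeddings=1026, embedding_dim=16)
    dummy_input_ids = torch.zeros(2, 8, dtype=torch.long)  # [bsz=2, seq_len=8]

    if version.parse(transformers.__version__) >= version.parse("4.22.0"):
        out = emb(dummy_input_ids)        # newer API: pass the input_ids tensor
    else:
        out = emb(dummy_input_ids.shape)  # older API: pass only the shape

    # Expected: torch.Size([2, 8, 16]) on transformers >= 4.22;
    # torch.Size([8, 16]) on older versions (broadcast by the BART encoder).
    print(out.shape)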