File size: 1,956 Bytes
101e5d8
 
 
 
9d1e66f
01920f9
9d1e66f
 
101e5d8
 
 
01920f9
101e5d8
 
 
 
 
 
 
 
01920f9
 
101e5d8
 
 
01920f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62edb7a
101e5d8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# coding=utf-8
# author: xusong <xusong28@jd.com>
# time: 2022/3/3 14:18

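"""
Version-compatible replacement for BART's learned positional embedding.

Importing this module patches ``transformers.models.bart.modeling_bart`` so that
``BartLearnedPositionalEmbedding`` refers to ``KPlugLearnedPositionalEmbedding``.
"""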


import torch
from torch import nn
import transformers
from transformers.models.bart import modeling_bart


def _major_minor(version: str) -> tuple:
    """Return the leading ``(major, minor)`` integers of a version string.

    Comparing version strings directly is lexicographic, so e.g.
    ``"4.9.2" >= "4.22.0"`` is ``True``. Parsing to integers gives the intended
    numeric ordering. Non-numeric trailing characters inside a component (such
    as ``"22rc1"``) are ignored; a component with no leading digits counts as 0.

    >>> _major_minor("4.22.0.dev0")
    (4, 22)
    >>> _major_minor("4.9.2") < (4, 22)
    True
    """
    numbers = []
    for component in version.split(".")[:2]:
        digits = ""
        for char in component:
            if not char.isdigit():
                break
            digits += char
        numbers.append(int(digits) if digits else 0)
    return tuple(numbers)


class KPlugLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.

    Replacement for ``modeling_bart.BartLearnedPositionalEmbedding`` that works
    across transformers versions: the constructor accepts the optional
    ``padding_idx`` argument used by older releases, and ``forward`` is defined
    with the signature matching the installed transformers version.
    """

    # padding_idx exists for version compatibility: in older transformers
    # releases (e.g. 4.2.1) the constructor is called with three arguments.
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int = None):
        # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2

        if padding_idx is None:
            super().__init__(num_embeddings + self.offset, embedding_dim)
        else:
            super().__init__(num_embeddings + self.offset, embedding_dim, padding_idx=padding_idx)

    # Compare versions numerically: a plain string comparison would treat e.g.
    # "4.9.2" as newer than "4.22.0" and select the wrong forward signature
    # (those releases pass a torch.Size, which has no ``.shape``).
    if _major_minor(transformers.__version__) >= (4, 22):
        def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
            """Embed positions for a batch of token ids.

            `input_ids` shape is expected to be [bsz x seqlen]; only its first two
            dimensions are read. Positions start at `past_key_values_length`.
            Returns embeddings of shape [bsz x seqlen x embedding_dim].
            """

            bsz, seq_len = input_ids.shape[:2]
            positions = torch.arange(
                past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
            ).expand(bsz, -1)

            return super().forward(positions + self.offset)
    else:
        def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
            """Embed positions given the input shape (older transformers API).

            `input_ids_shape` is expected to be [bsz x seqlen]; only seqlen is used.
            Positions start at `past_key_values_length`.
            Returns embeddings of shape [seqlen x embedding_dim] (no batch dimension).
            """
            bsz, seq_len = input_ids_shape[:2]
            positions = torch.arange(
                past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
            )
            return super().forward(positions + self.offset)


# Swap the class on the ``modeling_bart`` module for the version-compatible one
# defined above. Code that already bound the original class (for example via
# ``from ... import BartLearnedPositionalEmbedding``) before this point keeps it.
setattr(modeling_bart, "BartLearnedPositionalEmbedding", KPlugLearnedPositionalEmbedding)