xusong28
commited on
Commit
•
9d1e66f
1
Parent(s):
19fb2f0
update
Browse files
kplug/modeling_kplug_s2s_patch.py
CHANGED
@@ -2,6 +2,10 @@
|
|
2 |
# author: xusong <xusong28@jd.com>
|
3 |
# time: 2022/3/3 14:18
|
4 |
|
|
|
|
|
|
|
|
|
5 |
|
6 |
import torch
|
7 |
from torch import nn
|
@@ -20,12 +24,14 @@ class KPlugLearnedPositionalEmbedding(nn.Embedding):
|
|
20 |
self.offset = 2
|
21 |
super().__init__(num_embeddings + self.offset, embedding_dim)
|
22 |
|
23 |
-
def forward(self,
|
24 |
-
"""`
|
25 |
-
|
|
|
26 |
positions = torch.arange(
|
27 |
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
|
28 |
-
)
|
|
|
29 |
return super().forward(positions + self.offset)
|
30 |
|
31 |
|
|
|
2 |
# author: xusong <xusong28@jd.com>
|
3 |
# time: 2022/3/3 14:18
|
4 |
|
5 |
+
"""
|
6 |
+
?
|
7 |
+
"""
|
8 |
+
|
9 |
|
10 |
import torch
|
11 |
from torch import nn
|
|
|
24 |
self.offset = 2
|
25 |
super().__init__(num_embeddings + self.offset, embedding_dim)
|
26 |
|
27 |
+
def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
|
28 |
+
"""`input_ids' shape is expected to be [bsz x seqlen]."""
|
29 |
+
|
30 |
+
bsz, seq_len = input_ids.shape[:2]
|
31 |
positions = torch.arange(
|
32 |
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
|
33 |
+
).expand(bsz, -1)
|
34 |
+
|
35 |
return super().forward(positions + self.offset)
|
36 |
|
37 |
|