xusong28 committed on
Commit
9d1e66f
1 Parent(s): 19fb2f0
Files changed (1) hide show
  1. kplug/modeling_kplug_s2s_patch.py +10 -4
kplug/modeling_kplug_s2s_patch.py CHANGED
@@ -2,6 +2,10 @@
2
  # author: xusong <xusong28@jd.com>
3
  # time: 2022/3/3 14:18
4
 
 
 
 
 
5
 
6
  import torch
7
  from torch import nn
@@ -20,12 +24,14 @@ class KPlugLearnedPositionalEmbedding(nn.Embedding):
20
  self.offset = 2
21
  super().__init__(num_embeddings + self.offset, embedding_dim)
22
 
23
- def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
24
- """`input_ids_shape` is expected to be [bsz x seqlen]."""
25
- bsz, seq_len = input_ids_shape[:2]
 
26
  positions = torch.arange(
27
  past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
28
- )
 
29
  return super().forward(positions + self.offset)
30
 
31
 
 
2
  # author: xusong <xusong28@jd.com>
3
  # time: 2022/3/3 14:18
4
 
5
+ """
6
+ ?
7
+ """
8
+
9
 
10
  import torch
11
  from torch import nn
 
24
  self.offset = 2
25
  super().__init__(num_embeddings + self.offset, embedding_dim)
26
 
27
+ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
28
+ """`input_ids' shape is expected to be [bsz x seqlen]."""
29
+
30
+ bsz, seq_len = input_ids.shape[:2]
31
  positions = torch.arange(
32
  past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
33
+ ).expand(bsz, -1)
34
+
35
  return super().forward(positions + self.offset)
36
 
37