xusong28 committed on
Commit
ae1bf72
1 Parent(s): fff27c0
Files changed (1) hide show
  1. kplug/modeling_kplug_s2s_patch.py +15 -6
kplug/modeling_kplug_s2s_patch.py CHANGED
@@ -24,14 +24,23 @@ class KPlugLearnedPositionalEmbedding(nn.Embedding):
24
  self.offset = 2
25
  super().__init__(num_embeddings + self.offset, embedding_dim)
26
 
27
- def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
28
- """`input_ids' shape is expected to be [bsz x seqlen]."""
29
-
30
- bsz, seq_len = input_ids.shape[:2]
 
 
 
 
 
 
 
 
 
 
31
  positions = torch.arange(
32
  past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
33
- ).expand(bsz, -1)
34
-
35
  return super().forward(positions + self.offset)
36
 
37
 
 
24
  self.offset = 2
25
  super().__init__(num_embeddings + self.offset, embedding_dim)
26
 
27
+ ### Version for releases after 4.21.1
28
+ # def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
29
+ # """`input_ids' shape is expected to be [bsz x seqlen]."""
30
+ #
31
+ # bsz, seq_len = input_ids.shape[:2]
32
+ # positions = torch.arange(
33
+ # past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
34
+ # ).expand(bsz, -1)
35
+ #
36
+ # return super().forward(positions + self.offset)
37
+
38
+ def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
39
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
40
+ bsz, seq_len = input_ids_shape[:2]
41
  positions = torch.arange(
42
  past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
43
+ )
 
44
  return super().forward(positions + self.offset)
45
 
46