fix training
attention.py  CHANGED  (+3, -0)
```diff
@@ -332,6 +332,7 @@ class MultiheadAttention(nn.Module, Attn):
     key: torch.Tensor,
     value: torch.Tensor,
     n_heads: int,
+    past_key_value,
     softmax_scale: Optional[float],
     attn_bias: Optional[torch.Tensor],
     key_padding_mask: Optional[torch.ByteTensor],
@@ -345,6 +346,7 @@ class MultiheadAttention(nn.Module, Attn):
     key,
     value,
     n_heads,
+    past_key_value,
     softmax_scale,
     attn_bias,
     key_padding_mask,
@@ -361,6 +363,7 @@ class MultiheadAttention(nn.Module, Attn):
     key,
     value,
     self.n_heads,
+    past_key_value,
     self.softmax_scale,
     attn_bias,
     key_padding_mask,
```
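The commit threads a new `past_key_value` argument through the attention signature and both call sites, which is the usual plumbing for a key/value cache during incremental decoding. Below is a minimal, single-head sketch of how such a cache is typically consumed; `attn_with_cache`, its exact signature, and the `(key, value)` tuple layout are illustrative assumptions, not the function defined in attention.py.

```python
import torch
from typing import Optional, Tuple

def attn_with_cache(
    query: torch.Tensor,   # (batch, q_len, d_model), new positions only
    key: torch.Tensor,     # (batch, k_len, d_model), new positions only
    value: torch.Tensor,   # (batch, k_len, d_model)
    past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    softmax_scale: Optional[float] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Single-head scaled dot-product attention with a (key, value) cache.

    Hypothetical sketch: the real attention.py is multi-headed and also
    handles attn_bias and key_padding_mask, omitted here for brevity.
    """
    if past_key_value is not None:
        # Prepend cached keys/values so the new queries attend over the
        # whole sequence seen so far, not just the new positions.
        past_key, past_value = past_key_value
        key = torch.cat([past_key, key], dim=1)
        value = torch.cat([past_value, value], dim=1)
    present = (key, value)  # updated cache, handed back to the caller

    scale = softmax_scale if softmax_scale is not None else query.size(-1) ** -0.5
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    out = torch.matmul(torch.softmax(scores, dim=-1), value)
    return out, present
```

On each decoding step the caller feeds the returned `present` tuple back in as `past_key_value`, so keys and values for earlier tokens are computed once rather than on every step.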