Use input attention mask instead of casual mask in attention
#101
by
CyberZHG
- opened
- modelling_RW.py +2 -2
modelling_RW.py
CHANGED
@@ -281,13 +281,14 @@ class Attention(nn.Module):
|
|
281 |
else:
|
282 |
present = None
|
283 |
|
|
|
284 |
if alibi is None:
|
285 |
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
286 |
key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
287 |
value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
288 |
|
289 |
attn_output = F.scaled_dot_product_attention(
|
290 |
-
query_layer_, key_layer_, value_layer_,
|
291 |
)
|
292 |
|
293 |
x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
|
@@ -300,7 +301,6 @@ class Attention(nn.Module):
|
|
300 |
assert not output_attentions # not supported.
|
301 |
return outputs
|
302 |
else:
|
303 |
-
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
|
304 |
matmul_result = query_layer @ key_layer.transpose(-1, -2)
|
305 |
|
306 |
# change view to [batch_size, num_heads, q_length, kv_length]
|
|
|
281 |
else:
|
282 |
present = None
|
283 |
|
284 |
+
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(query_layer.dtype)
|
285 |
if alibi is None:
|
286 |
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
287 |
key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
288 |
value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
289 |
|
290 |
attn_output = F.scaled_dot_product_attention(
|
291 |
+
query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
|
292 |
)
|
293 |
|
294 |
x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
|
|
|
301 |
assert not output_attentions # not supported.
|
302 |
return outputs
|
303 |
else:
|
|
|
304 |
matmul_result = query_layer @ key_layer.transpose(-1, -2)
|
305 |
|
306 |
# change view to [batch_size, num_heads, q_length, kv_length]
|