Update modeling_qwen2.py
modeling_qwen2.py (+4 -2) CHANGED
@@ -375,7 +375,8 @@ class Qwen2Attention(nn.Module):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = torch.matmul(query_states.to(torch.float32), key_states.transpose(2, 3).to(torch.float32)) / math.sqrt(self.head_dim)
 
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             raise ValueError(
@@ -388,7 +389,8 @@ class Qwen2Attention(nn.Module):
         attn_weights = attn_weights + causal_mask
 
         # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        # attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query_states.dtype)
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
 
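In effect, the commit moves the fp32 upcast earlier: query_states and key_states are cast to torch.float32 before the QK^T matmul, so the attention scores are already fp32 and the explicit dtype=torch.float32 argument to softmax is no longer needed. The commit does not state its motivation; presumably the goal is to keep the score computation itself out of fp16/bf16. Below is a minimal standalone sketch (not part of the commit, with made-up tensor shapes) comparing the two paths.

import math
import torch

torch.manual_seed(0)
# Hypothetical shapes: (batch, heads, seq_len, head_dim), half-precision inputs.
q = torch.randn(1, 2, 4, 8, dtype=torch.float16)
k = torch.randn(1, 2, 4, 8, dtype=torch.float16)
head_dim = q.shape[-1]

# Original path: scores computed in fp16, upcast happens inside the softmax.
scores_fp16 = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
probs_old = torch.nn.functional.softmax(scores_fp16, dim=-1, dtype=torch.float32).to(q.dtype)

# Patched path: q and k upcast to fp32 first, so the scores themselves are fp32.
scores_fp32 = torch.matmul(q.to(torch.float32), k.transpose(2, 3).to(torch.float32)) / math.sqrt(head_dim)
probs_new = torch.nn.functional.softmax(scores_fp32, dim=-1).to(q.dtype)

# Any difference comes from half-precision rounding in the first matmul.
print((probs_old - probs_new).abs().max())

Both paths produce fp32 softmax probabilities cast back to the query dtype; the patched path additionally performs the score matmul in fp32, at the cost of extra memory and compute for the upcast tensors.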