p2o6e100 committed on
Commit
a709ac6
·
verified ·
1 Parent(s): b328734

Update modeling_qwen2.py

Browse files
Files changed (1) hide show
  1. modeling_qwen2.py +4 -2
modeling_qwen2.py CHANGED
@@ -375,7 +375,8 @@ class Qwen2Attention(nn.Module):
375
  key_states = repeat_kv(key_states, self.num_key_value_groups)
376
  value_states = repeat_kv(value_states, self.num_key_value_groups)
377
 
378
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
379
 
380
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
381
  raise ValueError(
@@ -388,7 +389,8 @@ class Qwen2Attention(nn.Module):
388
  attn_weights = attn_weights + causal_mask
389
 
390
  # upcast attention to fp32
391
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
 
392
  attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
393
  attn_output = torch.matmul(attn_weights, value_states)
394
 
 
375
  key_states = repeat_kv(key_states, self.num_key_value_groups)
376
  value_states = repeat_kv(value_states, self.num_key_value_groups)
377
 
378
+ # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
379
+ attn_weights = torch.matmul(query_states.to(torch.float32), key_states.transpose(2, 3).to(torch.float32)) / math.sqrt(self.head_dim)
380
 
381
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
382
  raise ValueError(
 
389
  attn_weights = attn_weights + causal_mask
390
 
391
  # upcast attention to fp32
392
+ # attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
393
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query_states.dtype)
394
  attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
395
  attn_output = torch.matmul(attn_weights, value_states)
396