s-JoL committed
Commit 0496919
1 Parent(s): 3c9628b

Update modeling_baichuan.py

Files changed (1)
  1. modeling_baichuan.py +1 -0
modeling_baichuan.py CHANGED
@@ -181,6 +181,7 @@ class BaichuanAttention(torch.nn.Module):
             # )
             with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
                 attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask)
+            attn_output = attn_output.transpose(1, 2)
         else:
             attn_weights = torch.matmul(
                 query_states, key_states.transpose(2, 3)
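
For context (not part of the commit): F.scaled_dot_product_attention keeps the [batch, num_heads, seq_len, head_dim] layout of its inputs, so the added transpose(1, 2) swaps the head and sequence dimensions, presumably so the rest of forward can flatten the output to [batch, seq_len, hidden_size] as usual. A minimal standalone sketch of the shape change follows; the tensor names and sizes are illustrative assumptions, not the model's real configuration.

import torch
import torch.nn.functional as F

# Illustrative sizes only; the real values come from the model config.
batch, num_heads, seq_len, head_dim = 2, 8, 16, 64

q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)
v = torch.randn(batch, num_heads, seq_len, head_dim)

# SDPA returns the same [batch, num_heads, seq_len, head_dim] layout as its inputs.
attn_output = F.scaled_dot_product_attention(q, k, v)
print(attn_output.shape)  # torch.Size([2, 8, 16, 64])

# The line added by this commit: swap heads and sequence.
attn_output = attn_output.transpose(1, 2)
print(attn_output.shape)  # torch.Size([2, 16, 8, 64])

# A downstream projection expecting [batch, seq_len, hidden_size] now lines up.
hidden = attn_output.reshape(batch, seq_len, num_heads * head_dim)
print(hidden.shape)       # torch.Size([2, 16, 512])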