s-JoL committed on
Commit
3c9628b
1 Parent(s): a9f41a6

Update modeling_baichuan.py

Files changed (1)
  1. modeling_baichuan.py +8 -6
modeling_baichuan.py CHANGED
@@ -173,12 +173,14 @@ class BaichuanAttention(torch.nn.Module):
         past_key_value = (key_states, value_states) if use_cache else None
         if xops is not None and self.training:
             attn_weights = None
-            query_states = query_states.transpose(1, 2)
-            key_states = key_states.transpose(1, 2)
-            value_states = value_states.transpose(1, 2)
-            attn_output = xops.memory_efficient_attention(
-                query_states, key_states, value_states, attn_bias=attention_mask
-            )
+            # query_states = query_states.transpose(1, 2)
+            # key_states = key_states.transpose(1, 2)
+            # value_states = value_states.transpose(1, 2)
+            # attn_output = xops.memory_efficient_attention(
+            #     query_states, key_states, value_states, attn_bias=attention_mask
+            # )
+            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
+                attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask)
         else:
             attn_weights = torch.matmul(
                 query_states, key_states.transpose(2, 3)
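The change swaps xFormers' memory_efficient_attention for PyTorch's built-in F.scaled_dot_product_attention under the sdp_kernel backend selector. Since xFormers expects inputs laid out as (batch, seq_len, heads, head_dim) while SDPA takes the (batch, num_heads, seq_len, head_dim) layout the states are already in, the transposes are commented out along with the old call. A minimal standalone sketch of the new call path, assuming PyTorch >= 2.0 and with illustrative shapes and a dummy attention bias (not taken from the repo):

# Minimal sketch of the new attention call; shapes and the zero bias are illustrative.
import torch
import torch.nn.functional as F

batch, num_heads, seq_len, head_dim = 2, 8, 16, 64

# F.scaled_dot_product_attention expects (batch, num_heads, seq_len, head_dim),
# so no transposes are needed, unlike xops.memory_efficient_attention,
# which expects (batch, seq_len, heads, head_dim).
query_states = torch.randn(batch, num_heads, seq_len, head_dim)
key_states = torch.randn(batch, num_heads, seq_len, head_dim)
value_states = torch.randn(batch, num_heads, seq_len, head_dim)

# Additive bias broadcastable to (batch, num_heads, seq_len, seq_len);
# the model passes its causal/alibi mask here.
attention_mask = torch.zeros(batch, 1, seq_len, seq_len)

# The context manager only restricts which SDPA backends PyTorch may pick
# (flash, math, memory-efficient); the call itself works the same without it.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
    attn_output = F.scaled_dot_product_attention(
        query_states, key_states, value_states, attn_mask=attention_mask
    )

print(attn_output.shape)  # torch.Size([2, 8, 16, 64])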