Update modeling_baichuan.py
modeling_baichuan.py  CHANGED  (+8 -6)
@@ -173,12 +173,14 @@ class BaichuanAttention(torch.nn.Module):
         past_key_value = (key_states, value_states) if use_cache else None
         if xops is not None and self.training:
             attn_weights = None
-            query_states = query_states.transpose(1, 2)
-            key_states = key_states.transpose(1, 2)
-            value_states = value_states.transpose(1, 2)
-            attn_output = xops.memory_efficient_attention(
-                query_states, key_states, value_states, attn_bias=attention_mask
-            )
+            # query_states = query_states.transpose(1, 2)
+            # key_states = key_states.transpose(1, 2)
+            # value_states = value_states.transpose(1, 2)
+            # attn_output = xops.memory_efficient_attention(
+            #     query_states, key_states, value_states, attn_bias=attention_mask
+            # )
+            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
+                attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask)
         else:
             attn_weights = torch.matmul(
                 query_states, key_states.transpose(2, 3)
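For context, the change swaps xformers' `xops.memory_efficient_attention` for PyTorch's built-in `F.scaled_dot_product_attention`. The transposes are dropped because SDPA consumes tensors laid out as `(batch, num_heads, seq_len, head_dim)`, which is the layout this attention module already has, whereas the xformers op expects `(batch, seq_len, num_heads, head_dim)`. The `sdp_kernel` context manager simply allows PyTorch to pick the flash, memory-efficient, or math backend. Below is a minimal, standalone sketch (not part of the commit; tensor sizes and the all-zeros additive mask are illustrative assumptions) showing the expected shapes for the SDPA call:

```python
# Illustrative sketch only: shows the (batch, num_heads, seq_len, head_dim)
# layout that F.scaled_dot_product_attention expects, matching the diff above.
import torch
import torch.nn.functional as F

batch, num_heads, seq_len, head_dim = 2, 4, 16, 32  # assumed toy sizes
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)
v = torch.randn(batch, num_heads, seq_len, head_dim)

# Additive float mask broadcastable to (batch, num_heads, seq_len, seq_len),
# standing in for the `attention_mask` passed in the patched code.
attn_mask = torch.zeros(batch, 1, seq_len, seq_len)

attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(attn_output.shape)  # torch.Size([2, 4, 16, 32])
```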