Update modeling_baichuan.py
Browse files- modeling_baichuan.py +1 -1
modeling_baichuan.py
CHANGED
@@ -177,7 +177,7 @@ class BaichuanAttention(torch.nn.Module):
|
|
177 |
key_states = key_states.transpose(1, 2)
|
178 |
value_states = value_states.transpose(1, 2)
|
179 |
attn_output = xops.memory_efficient_attention(
|
180 |
-
query_states, key_states, value_states, attn_bias=
|
181 |
)
|
182 |
else:
|
183 |
attn_weights = torch.matmul(
|
|
|
177 |
key_states = key_states.transpose(1, 2)
|
178 |
value_states = value_states.transpose(1, 2)
|
179 |
attn_output = xops.memory_efficient_attention(
|
180 |
+
query_states, key_states, value_states, attn_bias=attention_mask.unsqueeze(0).expand(bsz, -1, -1, -1)
|
181 |
)
|
182 |
else:
|
183 |
attn_weights = torch.matmul(
|