s-JoL committed on
Commit
be60123
1 Parent(s): 7a4db28

Update modeling_baichuan.py

Browse files
Files changed (1) hide show
  1. modeling_baichuan.py +1 -1
modeling_baichuan.py CHANGED
@@ -177,7 +177,7 @@ class BaichuanAttention(torch.nn.Module):
177
  key_states = key_states.transpose(1, 2)
178
  value_states = value_states.transpose(1, 2)
179
  attn_output = xops.memory_efficient_attention(
180
- query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
181
  )
182
  else:
183
  attn_weights = torch.matmul(
 
177
  key_states = key_states.transpose(1, 2)
178
  value_states = value_states.transpose(1, 2)
179
  attn_output = xops.memory_efficient_attention(
180
+ query_states, key_states, value_states, attn_bias=attention_mask.unsqueeze(0).expand(bsz, -1, -1, -1)
181
  )
182
  else:
183
  attn_weights = torch.matmul(