yangapku commited on
Commit
fcc99d6
1 Parent(s): f4b568f

update modeling_qwen.py

Browse files
Files changed (2) hide show
  1. assets/wechat.png +0 -0
  2. modeling_qwen.py +2 -1
assets/wechat.png CHANGED
modeling_qwen.py CHANGED
@@ -193,9 +193,10 @@ class FlashSelfAttention(torch.nn.Module):
193
  if attention_mask is not None:
194
  k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
195
  v = v[indices_k]
196
- if seqlen_q == seqlen_k:
197
  q = q[indices_k]
198
  cu_seqlens_q = cu_seqlens_k
 
199
  else:
200
  cu_seqlens_k = torch.arange(
201
  0,
 
193
  if attention_mask is not None:
194
  k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
195
  v = v[indices_k]
196
+ if self.training or q.size(0) == k.size(0):
197
  q = q[indices_k]
198
  cu_seqlens_q = cu_seqlens_k
199
+ seqlen_q = seqlen_k
200
  else:
201
  cu_seqlens_k = torch.arange(
202
  0,