yangapku commited on
Commit
de89198
1 Parent(s): c04bccd

update modeling_qwen.py

Browse files
Files changed (1) hide show
  1. modeling_qwen.py +4 -6
modeling_qwen.py CHANGED
@@ -520,9 +520,7 @@ class QWenAttention(nn.Module):
520
 
521
  if not self.use_cache_quantization and SUPPORT_TORCH2:
522
  if attention_mask is not None:
523
- attention_mask = attention_mask.expand(
524
- -1, -1, causal_mask.size(2), -1
525
- )
526
  if causal_mask is not None:
527
  attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
528
  else:
@@ -1330,14 +1328,14 @@ def apply_rotary_pos_emb(t, freqs):
1330
  t (tensor(batch_size, seq_len, n_head, head_dim)):
1331
  the input embedding/hidden states
1332
  freqs (list[tensor(1, seq_len, 1, rotary_dim), tensor(1, seq_len, 1, rotary_dim)]):
1333
- the cached cos/sin position embeddings
1334
  """
1335
  rot_dim = freqs[0].shape[-1]
1336
  cos, sin = freqs
1337
  t_float = t.float()
1338
  if apply_rotary_emb_func is not None and t.is_cuda:
1339
- # apply_rotary_emb in flash_attn requires cos/sin to be of
1340
- # shape (seqlen, rotary_dim / 2) and apply rotary embedding
1341
  # to the first rotary_dim of the input
1342
  cos = cos.squeeze(0).squeeze(1)[:, : rot_dim // 2]
1343
  sin = sin.squeeze(0).squeeze(1)[:, : rot_dim // 2]
 
520
 
521
  if not self.use_cache_quantization and SUPPORT_TORCH2:
522
  if attention_mask is not None:
523
+ attention_mask = attention_mask.expand(-1, -1, key_size, -1)
 
 
524
  if causal_mask is not None:
525
  attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
526
  else:
 
1328
  t (tensor(batch_size, seq_len, n_head, head_dim)):
1329
  the input embedding/hidden states
1330
  freqs (list[tensor(1, seq_len, 1, rotary_dim), tensor(1, seq_len, 1, rotary_dim)]):
1331
+ the cached cos/sin position embeddings
1332
  """
1333
  rot_dim = freqs[0].shape[-1]
1334
  cos, sin = freqs
1335
  t_float = t.float()
1336
  if apply_rotary_emb_func is not None and t.is_cuda:
1337
+ # apply_rotary_emb in flash_attn requires cos/sin to be of
1338
+ # shape (seqlen, rotary_dim / 2) and apply rotary embedding
1339
  # to the first rotary_dim of the input
1340
  cos = cos.squeeze(0).squeeze(1)[:, : rot_dim // 2]
1341
  sin = sin.squeeze(0).squeeze(1)[:, : rot_dim // 2]