oweller2 committed on
Commit
9e4ff15
1 Parent(s): f66abc1
Files changed (1) hide show
  1. attention.py +1 -1
attention.py CHANGED
@@ -863,7 +863,7 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
863
  qkv = self.Wqkv(hidden_states)
864
 
865
  # only needed for inference when we have KV cache
866
- seqlen_offset = max_seqlen * (len(cu_seqlens) - 2) if len(cu_seqlens) > 1 else 0
867
 
868
  # (total_seqlen, 3, nheads, headdim)
869
  qkv = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size)
 
863
  qkv = self.Wqkv(hidden_states)
864
 
865
  # only needed for inference when we have KV cache
866
+ seqlen_offset = max_seqlen * (cu_seqlens[0].item() // max_seqlen)
867
 
868
  # (total_seqlen, 3, nheads, headdim)
869
  qkv = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size)