DeepBeepMeep committed on
Commit c0c0b08 · 1 Parent(s): 6e7a596

fix flash attention

Files changed (1)
  1. wan/modules/attention.py +3 -3
wan/modules/attention.py CHANGED
@@ -276,7 +276,7 @@ def pay_attention(
         k=k,
         v=v,
         cu_seqlens_q= cu_seqlens_q,
-        cu_seqlens_kv= cu_seqlens_k,
+        cu_seqlens_k= cu_seqlens_k,
         seqused_q=None,
         seqused_k=None,
         max_seqlen_q=lq,
@@ -289,8 +289,8 @@ def pay_attention(
         q=q,
         k=k,
         v=v,
-        cu_seqlens_q= [0, lq],
-        cu_seqlens_kv=[0, lk],
+        cu_seqlens_q= cu_seqlens_q,
+        cu_seqlens_k= cu_seqlens_k,
         max_seqlen_q=lq,
         max_seqlen_k=lk,
         dropout_p=dropout_p,
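
The rename matters because flash-attn's varlen interface takes a cu_seqlens_k keyword (there is no cu_seqlens_kv), and the offsets are expected as int32 tensors of cumulative sequence lengths rather than Python lists such as [0, lq]. The sketch below is not part of this commit; it only illustrates how such offset tensors are typically built, and the helper name build_cu_seqlens plus the single-sequence usage are assumptions for illustration.

import torch
import torch.nn.functional as F

# Minimal sketch (assumption, not from wan/modules/attention.py): build the
# cumulative-offset tensors that flash-attn's varlen interface expects.
def build_cu_seqlens(lengths, device):
    # lengths: per-sequence token counts, e.g. [lq] for a batch of one.
    lengths = torch.as_tensor(lengths, dtype=torch.int32, device=device)
    # Prepend a leading 0 so the result is [0, len_0, len_0 + len_1, ...].
    return F.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0))

# For a single query/key sequence of lengths lq and lk:
#   cu_seqlens_q = build_cu_seqlens([lq], q.device)  # tensor([0, lq], dtype=int32)
#   cu_seqlens_k = build_cu_seqlens([lk], q.device)  # tensor([0, lk], dtype=int32)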