Spaces: Running on T4
Update modules/model.py
Browse files: modules/model.py (+4 -6)
modules/model.py
CHANGED
@@ -203,14 +203,14 @@ class CrossAttnProcessor(nn.Module):
203 |   k_bucket_size = 1024
204 |
205 |   # use flash-attention
206 | - hidden_states =
207 |   query.contiguous(),
208 |   key.contiguous(),
209 |   value.contiguous(),
210 |   attention_mask,
211 | -
212 | - q_bucket_size
213 | - k_bucket_size
214 |   )
215 |   hidden_states = hidden_states.to(query.dtype)
216 |
@@ -1021,5 +1021,3 @@ class FlashAttentionFunction(Function):
1021 |   dvc.add_(dv_chunk)
1022 |
1023 |   return dq, dk, dv, None, None, None, None
1024 | -
1025 | - FlashAttn = FlashAttentionFunction()
203 |   k_bucket_size = 1024
204 |
205 |   # use flash-attention
206 | + hidden_states = FlashAttentionFunction.apply(
207 |   query.contiguous(),
208 |   key.contiguous(),
209 |   value.contiguous(),
210 |   attention_mask,
211 | + False,
212 | + q_bucket_size,
213 | + k_bucket_size,
214 |   )
215 |   hidden_states = hidden_states.to(query.dtype)
216 |
1021 |   dvc.add_(dv_chunk)
1022 |
1023 |   return dq, dk, dv, None, None, None, None