Spaces: Running on Zero

asigalov61 committed • Commit d418680 • Parent: e9613c6
Upload x_transformer_1_23_2.py

x_transformer_1_23_2.py CHANGED (+14 -5)

@@ -26,10 +26,16 @@
 from functools import partial
 from typing import Optional, Tuple
 
+import os
+os.environ['USE_FLASH_ATTENTION'] = '1'
+
 import torch
 from torch import nn, einsum, Tensor
 import torch.nn.functional as F
+
+# Flash attention
 from torch.nn.attention import SDPBackend, sdpa_kernel
+torch.backends.cuda.enable_flash_sdp(True)
 
 from collections import namedtuple
 from functools import wraps
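
For reference, a minimal standalone sketch (not part of the commit) of the runtime SDPA toggles that torch.backends.cuda exposes in PyTorch 2.x. USE_FLASH_ATTENTION appears to be primarily a build-time flag for PyTorch itself, so on a prebuilt wheel the enable_flash_sdp(True) call is the switch that actually takes effect; the snippet below only flips and reports the backend preference flags.

import torch

# Runtime SDPA backend toggles (all enabled by default in PyTorch 2.x).
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)

# Report the current state of each backend preference.
print('flash_sdp_enabled:        ', torch.backends.cuda.flash_sdp_enabled())
print('mem_efficient_sdp_enabled:', torch.backends.cuda.mem_efficient_sdp_enabled())
print('math_sdp_enabled:         ', torch.backends.cuda.math_sdp_enabled())
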
@@ -259,11 +265,14 @@ class Attend(nn.Module):
 
         # Legacy code...
         # with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=True):
+        # with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
 
-        #
-
-        with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
+        # PyTorch 2.3-2.4 SDPA backend code...
+        with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
 
+            # New PyTorch 2.5 SDPA backend code:
+            # with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
+
             out = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask = mask,
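
For reference, a minimal sketch (separate from the diff) of how the sdpa_kernel context manager constrains F.scaled_dot_product_attention; the tensor shapes and the backend list are illustrative, and including SDPBackend.MATH keeps it runnable on CPU.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative shapes: (batch, heads, seq_len, dim_head).
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# Only the listed backends may be used inside the context; the dispatcher
# picks among them based on dtype, device, mask type and head dimension.
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal = True)

print(out.shape)  # torch.Size([1, 8, 128, 64])
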
@@ -508,7 +517,7 @@ class AutoregressiveWrapper(Module):
         # whether to add router z-loss
         self.add_attn_z_loss = add_attn_z_loss
 
-    @torch.
+    @torch.inference_mode()
     @eval_decorator
     def generate(
         self,
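
For reference, a minimal sketch (separate from the diff, with a hypothetical generate_step standing in for the wrapped generate method) of what @torch.inference_mode() does: like no_grad it skips autograd graph recording, and it additionally marks the produced tensors as inference tensors, which cannot be reused in autograd later.

import torch

@torch.inference_mode()
def generate_step(x):
    # No autograd history is recorded inside inference mode.
    return x * 2 + 1

y = generate_step(torch.randn(4, requires_grad = True))
print(y.requires_grad)   # False
print(y.is_inference())  # True: y cannot participate in autograd later
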
@@ -2462,4 +2471,4 @@ class XTransformer(nn.Module):
             enc, mask = dropout_seq(enc, mask, self.cross_attn_tokens_dropout)
 
         out = self.decoder(tgt, context = enc, context_mask = mask)
-        return out
+        return out