asigalov61 committed
Commit d418680
Parent: e9613c6

Upload x_transformer_1_23_2.py

Files changed (1):
  1. x_transformer_1_23_2.py +14 -5
x_transformer_1_23_2.py CHANGED

@@ -26,10 +26,16 @@
 from functools import partial
 from typing import Optional, Tuple
 
+import os
+os.environ['USE_FLASH_ATTENTION'] = '1'
+
 import torch
 from torch import nn, einsum, Tensor
 import torch.nn.functional as F
+
+# Flash attention
 from torch.nn.attention import SDPBackend, sdpa_kernel
+torch.backends.cuda.enable_flash_sdp(True)
 
 from collections import namedtuple
 from functools import wraps
@@ -259,11 +265,14 @@ class Attend(nn.Module):
 
         # Legacy code...
         # with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=True):
+        # with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
 
-        # New SDP kernel code...
-        # with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
-        with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
+        # PyTorch 2.3-2.4 SDPA backend code...
+        with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
 
+            # New PyTorch 2.5 SDPA backend code:
+            # with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
+
             out = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask = mask,
@@ -508,7 +517,7 @@ class AutoregressiveWrapper(Module):
         # whether to add router z-loss
         self.add_attn_z_loss = add_attn_z_loss
 
-    @torch.no_grad()
+    @torch.inference_mode()
     @eval_decorator
     def generate(
         self,
@@ -2462,4 +2471,4 @@ class XTransformer(nn.Module):
         enc, mask = dropout_seq(enc, mask, self.cross_attn_tokens_dropout)
 
         out = self.decoder(tgt, context = enc, context_mask = mask)
-        return out
+        return out
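The import-block change turns flash attention on through two global switches. One caveat worth noting: `USE_FLASH_ATTENTION` appears to be primarily a build-time flag for compiling PyTorch itself, so setting it at runtime may be a no-op; the runtime switch is `torch.backends.cuda.enable_flash_sdp`. Below is a minimal sketch of these toggles, assuming PyTorch 2.3+; it is illustrative, not part of the commit.

import torch

# Allow (not force) the flash kernel for scaled_dot_product_attention.
torch.backends.cuda.enable_flash_sdp(True)

# Each enable_*_sdp switch has a matching *_sdp_enabled() query.
print(torch.backends.cuda.flash_sdp_enabled())          # True after the call above
print(torch.backends.cuda.mem_efficient_sdp_enabled())  # True by default
print(torch.backends.cuda.math_sdp_enabled())           # True by default

Even with the flash kernel enabled, whether it is actually used still depends on device, dtype, and mask layout at call time.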
 
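The `Attend` change widens the allowed SDPA backends from math + memory-efficient to all four kernels. A minimal, self-contained sketch of how the `sdpa_kernel` context manager behaves (shapes are arbitrary illustrative choices; assumes PyTorch 2.3+, where it accepts either one backend or a list):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# (batch, heads, seq_len, head_dim); arbitrary illustrative sizes
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# A list permits any of the named backends; PyTorch dispatches to the
# fastest one whose constraints (device, dtype, mask layout) are satisfied.
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v)

# A single backend forces it; if no permitted backend supports the inputs,
# scaled_dot_product_attention raises an error rather than falling back.
with sdpa_kernel(SDPBackend.MATH):
    out_math = F.scaled_dot_product_attention(q, k, v)

assert torch.allclose(out, out_math, atol=1e-5)

Note that `SDPBackend.CUDNN_ATTENTION`, used in the diff, only exists in recent PyTorch releases (roughly 2.4 onward), which is presumably why the commit's comments distinguish the 2.3-2.4 code path from the 2.5 one.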
 
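The remaining functional change swaps `@torch.no_grad()` for `@torch.inference_mode()` on `generate`. Inference mode is a stricter variant of no-grad that also skips view tracking and version-counter bookkeeping, making it slightly faster for pure inference. A minimal sketch with a hypothetical stand-in module (not the repository's `AutoregressiveWrapper`):

import torch
from torch import nn

model = nn.Linear(16, 16)  # hypothetical stand-in for the wrapped model
x = torch.randn(2, 16)

@torch.inference_mode()
def generate(inp):
    # No autograd graph, no view/version-counter tracking inside this call
    return model(inp)

y = generate(x)
print(y.requires_grad)   # False: no autograd graph was recorded
print(y.is_inference())  # True: marked as an inference tensor

# Inference tensors cannot later participate in autograd; clone() escapes that.
z = y.clone().requires_grad_(True)

The trade-off is that tensors created under inference mode cannot feed a differentiable computation later, hence the `clone()` above; for a sampling-only `generate` path this restriction is harmless.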