InPeerReview committed
Commit 7ff4dd0 · verified · 1 Parent(s): c914dec

Upload 62 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the remaining files.
Files changed (50)
  1. .gitattributes +5 -0
  2. rscd/models/backbones/lib_mamba/__init__.py +58 -0
  3. rscd/models/backbones/lib_mamba/__pycache__/__init__.cpython-38.pyc +0 -0
  4. rscd/models/backbones/lib_mamba/__pycache__/csm_triton.cpython-38.pyc +0 -0
  5. rscd/models/backbones/lib_mamba/__pycache__/csm_tritonk2.cpython-38.pyc +0 -0
  6. rscd/models/backbones/lib_mamba/__pycache__/csms6s.cpython-38.pyc +0 -0
  7. rscd/models/backbones/lib_mamba/__pycache__/vmamba.cpython-38.pyc +0 -0
  8. rscd/models/backbones/lib_mamba/__pycache__/vmambanew.cpython-38.pyc +0 -0
  9. rscd/models/backbones/lib_mamba/csm_triton.py +644 -0
  10. rscd/models/backbones/lib_mamba/csm_tritonk2.py +899 -0
  11. rscd/models/backbones/lib_mamba/csms6s.py +266 -0
  12. rscd/models/backbones/lib_mamba/kernels/selective_scan/README.md +97 -0
  13. rscd/models/backbones/lib_mamba/kernels/selective_scan/build/lib.linux-x86_64-3.8/selective_scan_cuda_oflex.cpython-38-x86_64-linux-gnu.so +3 -0
  14. rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/.ninja_deps +3 -0
  15. rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/.ninja_log +4 -0
  16. rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/build.ninja +35 -0
  17. rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_bwd.o +3 -0
  18. rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_fwd.o +3 -0
  19. rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_oflex.o +3 -0
  20. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cub_extra.cuh +50 -0
  21. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan.cpp +354 -0
  22. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_bwd_kernel.cuh +306 -0
  23. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_core_bwd.cu +9 -0
  24. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_core_fwd.cu +9 -0
  25. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_fwd_kernel.cuh +203 -0
  26. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_bwd_kernel_ndstate.cuh +302 -0
  27. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_core_bwd.cu +9 -0
  28. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_core_fwd.cu +9 -0
  29. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_fwd_kernel_ndstate.cuh +200 -0
  30. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_ndstate.cpp +341 -0
  31. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_ndstate.h +84 -0
  32. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_bwd_kernel_nrow.cuh +344 -0
  33. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd.cu +9 -0
  34. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd2.cu +9 -0
  35. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd3.cu +8 -0
  36. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd4.cu +8 -0
  37. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd.cu +9 -0
  38. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd2.cu +9 -0
  39. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd3.cu +9 -0
  40. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd4.cu +9 -0
  41. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_fwd_kernel_nrow.cuh +238 -0
  42. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_nrow.cpp +367 -0
  43. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_bwd_kernel_oflex.cuh +323 -0
  44. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_bwd.cu +11 -0
  45. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_fwd.cu +11 -0
  46. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_fwd_kernel_oflex.cuh +211 -0
  47. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_oflex.cpp +363 -0
  48. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/reverse_scan.cuh +403 -0
  49. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/selective_scan.h +90 -0
  50. rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/selective_scan_common.h +210 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ rscd/models/backbones/lib_mamba/kernels/selective_scan/build/lib.linux-x86_64-3.8/selective_scan_cuda_oflex.cpython-38-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/.ninja_deps filter=lfs diff=lfs merge=lfs -text
+ rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_bwd.o filter=lfs diff=lfs merge=lfs -text
+ rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_fwd.o filter=lfs diff=lfs merge=lfs -text
+ rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_oflex.o filter=lfs diff=lfs merge=lfs -text
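These five new rules route the compiled selective-scan build artifacts (a prebuilt CPython extension .so, ninja metadata, and object files) through Git LFS. Rules of this form are what `git lfs track` appends to .gitattributes; a hypothetical pattern-based equivalent of the per-file paths pinned here would be:

    git lfs track "rscd/models/backbones/lib_mamba/kernels/selective_scan/build/**"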
rscd/models/backbones/lib_mamba/__init__.py ADDED
@@ -0,0 +1,58 @@
+ import os
+ from functools import partial
+ import torch
+
+ from .vmamba import VSSM
+ from .csms6s import flops_selective_scan_fn,flops_selective_scan_ref
+
+
+ def build_vssm_model(config, **kwargs):
+     model_type = config.MODEL.TYPE
+     if model_type in ["vssm"]:
+         model = VSSM(
+             patch_size=config.MODEL.VSSM.PATCH_SIZE,
+             in_chans=config.MODEL.VSSM.IN_CHANS,
+             num_classes=config.MODEL.NUM_CLASSES,
+             depths=config.MODEL.VSSM.DEPTHS,
+             dims=config.MODEL.VSSM.EMBED_DIM,
+             # ===================
+             ssm_d_state=config.MODEL.VSSM.SSM_D_STATE,
+             ssm_ratio=config.MODEL.VSSM.SSM_RATIO,
+             ssm_rank_ratio=config.MODEL.VSSM.SSM_RANK_RATIO,
+             ssm_dt_rank=("auto" if config.MODEL.VSSM.SSM_DT_RANK == "auto" else int(config.MODEL.VSSM.SSM_DT_RANK)),
+             ssm_act_layer=config.MODEL.VSSM.SSM_ACT_LAYER,
+             ssm_conv=config.MODEL.VSSM.SSM_CONV,
+             ssm_conv_bias=config.MODEL.VSSM.SSM_CONV_BIAS,
+             ssm_drop_rate=config.MODEL.VSSM.SSM_DROP_RATE,
+             ssm_init=config.MODEL.VSSM.SSM_INIT,
+             forward_type=config.MODEL.VSSM.SSM_FORWARDTYPE,
+             # ===================
+             mlp_ratio=config.MODEL.VSSM.MLP_RATIO,
+             mlp_act_layer=config.MODEL.VSSM.MLP_ACT_LAYER,
+             mlp_drop_rate=config.MODEL.VSSM.MLP_DROP_RATE,
+             # ===================
+             drop_path_rate=config.MODEL.DROP_PATH_RATE,
+             patch_norm=config.MODEL.VSSM.PATCH_NORM,
+             norm_layer=config.MODEL.VSSM.NORM_LAYER,
+             downsample_version=config.MODEL.VSSM.DOWNSAMPLE,
+             patchembed_version=config.MODEL.VSSM.PATCHEMBED,
+             gmlp=config.MODEL.VSSM.GMLP,
+             use_checkpoint=config.TRAIN.USE_CHECKPOINT,
+             # ===================
+             posembed=config.MODEL.VSSM.POSEMBED,
+             imgsize=config.DATA.IMG_SIZE,
+         )
+         return model
+
+     return None
+
+
+ def build_model(config, is_pretrain=False):
+     model = None
+     if model is None:
+         model = build_vssm_model(config, is_pretrain)
+     return model
+
+
+
+
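For orientation, a minimal sketch of the configuration layout that build_vssm_model reads (the attribute names mirror the accesses above; the concrete values and the SimpleNamespace container are illustrative only, and actually instantiating the model also requires the VSSM backbone from .vmamba, uploaded elsewhere in this commit):

    from types import SimpleNamespace as NS

    # Illustrative config skeleton; values are placeholders, not taken from this repository.
    cfg = NS(
        MODEL=NS(
            TYPE="vssm",
            NUM_CLASSES=1000,
            DROP_PATH_RATE=0.2,
            VSSM=NS(
                PATCH_SIZE=4, IN_CHANS=3, DEPTHS=[2, 2, 9, 2], EMBED_DIM=96,
                SSM_D_STATE=16, SSM_RATIO=2.0, SSM_RANK_RATIO=2.0, SSM_DT_RANK="auto",
                SSM_ACT_LAYER="silu", SSM_CONV=3, SSM_CONV_BIAS=True, SSM_DROP_RATE=0.0,
                SSM_INIT="v0", SSM_FORWARDTYPE="v2",
                MLP_RATIO=4.0, MLP_ACT_LAYER="gelu", MLP_DROP_RATE=0.0,
                PATCH_NORM=True, NORM_LAYER="ln", DOWNSAMPLE="v3", PATCHEMBED="v2",
                GMLP=False, POSEMBED=False,
            ),
        ),
        TRAIN=NS(USE_CHECKPOINT=False),
        DATA=NS(IMG_SIZE=224),
    )
    # model = build_model(cfg)  # needs the VSSM class from .vmamba to be importable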
rscd/models/backbones/lib_mamba/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.71 kB).

rscd/models/backbones/lib_mamba/__pycache__/csm_triton.cpython-38.pyc ADDED
Binary file (19 kB).

rscd/models/backbones/lib_mamba/__pycache__/csm_tritonk2.cpython-38.pyc ADDED
Binary file (24.3 kB).

rscd/models/backbones/lib_mamba/__pycache__/csms6s.cpython-38.pyc ADDED
Binary file (8.38 kB).

rscd/models/backbones/lib_mamba/__pycache__/vmamba.cpython-38.pyc ADDED
Binary file (46.4 kB).

rscd/models/backbones/lib_mamba/__pycache__/vmambanew.cpython-38.pyc ADDED
Binary file (40.6 kB).
rscd/models/backbones/lib_mamba/csm_triton.py ADDED
@@ -0,0 +1,644 @@
1
+ import torch
2
+ import warnings
3
+
4
+ WITH_TRITON = True
5
+ # WITH_TRITON = False
6
+ try:
7
+ import triton
8
+ import triton.language as tl
9
+ except:
10
+ WITH_TRITON = False
11
+ warnings.warn("Triton not installed, fall back to pytorch implements.")
12
+
13
+ # to make sure cached_property can be loaded for triton
14
+ if WITH_TRITON:
15
+ try:
16
+ from functools import cached_property
17
+ except:
18
+ warnings.warn("if you are using py37, add this line to functools.py: "
19
+ "cached_property = lambda func: property(lru_cache()(func))")
20
+
21
+ # torch implementation ========================================
22
+ def cross_scan_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
23
+ if in_channel_first:
24
+ B, C, H, W = x.shape
25
+ if scans == 0:
26
+ y = x.new_empty((B, 4, C, H * W))
27
+ y[:, 0, :, :] = x.flatten(2, 3)
28
+ y[:, 1, :, :] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
29
+ y[:, 2:4, :, :] = torch.flip(y[:, 0:2, :, :], dims=[-1])
30
+ elif scans == 1:
31
+ y = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
32
+ elif scans == 2:
33
+ y = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
34
+ y = torch.cat([y, y.flip(dims=[-1])], dim=1)
35
+ else:
36
+ B, H, W, C = x.shape
37
+ if scans == 0:
38
+ y = x.new_empty((B, H * W, 4, C))
39
+ y[:, :, 0, :] = x.flatten(1, 2)
40
+ y[:, :, 1, :] = x.transpose(dim0=1, dim1=2).flatten(1, 2)
41
+ y[:, :, 2:4, :] = torch.flip(y[:, :, 0:2, :], dims=[1])
42
+ elif scans == 1:
43
+ y = x.view(B, H * W, 1, C).repeat(1, 1, 4, 1)
44
+ elif scans == 2:
45
+ y = x.view(B, H * W, 1, C).repeat(1, 1, 2, 1)
46
+ y = torch.cat([y, y.flip(dims=[1])], dim=2)
47
+
48
+ if in_channel_first and (not out_channel_first):
49
+ y = y.permute(0, 3, 1, 2).contiguous()
50
+ elif (not in_channel_first) and out_channel_first:
51
+ y = y.permute(0, 2, 3, 1).contiguous()
52
+
53
+ return y
54
+
55
+
56
+ def cross_merge_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
57
+ if out_channel_first:
58
+ B, K, D, H, W = y.shape
59
+ y = y.view(B, K, D, -1)
60
+ if scans == 0:
61
+ y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
62
+ y = y[:, 0] + y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
63
+ elif scans == 1:
64
+ y = y.sum(1)
65
+ elif scans == 2:
66
+ y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
67
+ y = y.sum(1)
68
+ else:
69
+ B, H, W, K, D = y.shape
70
+ y = y.view(B, -1, K, D)
71
+ if scans == 0:
72
+ y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
73
+ y = y[:, :, 0] + y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).contiguous().view(B, -1, D)
74
+ elif scans == 1:
75
+ y = y.sum(2)
76
+ elif scans == 2:
77
+ y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
78
+ y = y.sum(2)
79
+
80
+ if in_channel_first and (not out_channel_first):
81
+ y = y.permute(0, 2, 1).contiguous()
82
+ elif (not in_channel_first) and out_channel_first:
83
+ y = y.permute(0, 2, 1).contiguous()
84
+
85
+ return y
86
+
87
+
88
+ def cross_scan1b1_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
89
+ if in_channel_first:
90
+ B, _, C, H, W = x.shape
91
+ if scans == 0:
92
+ y = torch.stack([
93
+ x[:, 0].flatten(2, 3),
94
+ x[:, 1].transpose(dim0=2, dim1=3).flatten(2, 3),
95
+ torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
96
+ torch.flip(x[:, 3].transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
97
+ ], dim=1)
98
+ elif scans == 1:
99
+ y = x.flatten(2, 3)
100
+ elif scans == 2:
101
+ y = torch.stack([
102
+ x[:, 0].flatten(2, 3),
103
+ x[:, 1].flatten(2, 3),
104
+ torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
105
+ torch.flip(x[:, 3].flatten(2, 3), dims=[-1]),
106
+ ], dim=1)
107
+ else:
108
+ B, H, W, _, C = x.shape
109
+ if scans == 0:
110
+ y = torch.stack([
111
+ x[:, :, :, 0].flatten(1, 2),
112
+ x[:, :, :, 1].transpose(dim0=1, dim1=2).flatten(1, 2),
113
+ torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]),
114
+ torch.flip(x[:, :, :, 3].transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
115
+ ], dim=2)
116
+ elif scans == 1:
117
+ y = x.flatten(1, 2)
118
+ elif scans == 2:
119
+ y = torch.stack([
120
+ x[:, 0].flatten(1, 2),
121
+ x[:, 1].flatten(1, 2),
122
+ torch.flip(x[:, 2].flatten(1, 2), dims=[-1]),
123
+ torch.flip(x[:, 3].flatten(1, 2), dims=[-1]),
124
+ ], dim=2)
125
+
126
+ if in_channel_first and (not out_channel_first):
127
+ y = y.permute(0, 3, 1, 2).contiguous()
128
+ elif (not in_channel_first) and out_channel_first:
129
+ y = y.permute(0, 2, 3, 1).contiguous()
130
+
131
+ return y
132
+
133
+
134
+ def cross_merge1b1_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
135
+ if out_channel_first:
136
+ B, K, D, H, W = y.shape
137
+ y = y.view(B, K, D, -1)
138
+ if scans == 0:
139
+ y = torch.stack([
140
+ y[:, 0],
141
+ y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3),
142
+ torch.flip(y[:, 2], dims=[-1]),
143
+ torch.flip(y[:, 3].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
144
+ ], dim=1)
145
+ elif scans == 1:
146
+ y = y
147
+ elif scans == 2:
148
+ y = torch.stack([
149
+ y[:, 0],
150
+ y[:, 1],
151
+ torch.flip(y[:, 2], dims=[-1]),
152
+ torch.flip(y[:, 3], dims=[-1]),
153
+ ], dim=1)
154
+ else:
155
+ B, H, W, _, D = y.shape
156
+ y = y.view(B, -1, K, D)
157
+ if scans == 0:
158
+ y = torch.stack([
159
+ y[:, :, 0],
160
+ y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2),
161
+ torch.flip(y[:, :, 2], dims=[1]),
162
+ torch.flip(y[:, :, 3].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
163
+ ], dim=2)
164
+ elif scans == 1:
165
+ y = y
166
+ elif scans == 2:
167
+ y = torch.stack([
168
+ y[:, :, 0],
169
+ y[:, :, 1],
170
+ torch.flip(y[:, :, 2], dims=[1]),
171
+ torch.flip(y[:, :, 3], dims=[1]),
172
+ ], dim=2)
173
+
174
+ if out_channel_first and (not in_channel_first):
175
+ y = y.permute(0, 3, 1, 2).contiguous()
176
+ elif (not out_channel_first) and in_channel_first:
177
+ y = y.permute(0, 2, 3, 1).contiguous()
178
+
179
+ return y
180
+
181
+
182
+ class CrossScanF(torch.autograd.Function):
183
+ @staticmethod
184
+ def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
185
+ # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
186
+ # y: (B, 4, C, H * W) | (B, H * W, 4, C)
187
+ ctx.in_channel_first = in_channel_first
188
+ ctx.out_channel_first = out_channel_first
189
+ ctx.one_by_one = one_by_one
190
+ ctx.scans = scans
191
+
192
+ if one_by_one:
193
+ B, K, C, H, W = x.shape
194
+ if not in_channel_first:
195
+ B, H, W, K, C = x.shape
196
+ else:
197
+ B, C, H, W = x.shape
198
+ if not in_channel_first:
199
+ B, H, W, C = x.shape
200
+ ctx.shape = (B, C, H, W)
201
+
202
+ _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
203
+ y = _fn(x, in_channel_first, out_channel_first, scans)
204
+
205
+ return y
206
+
207
+ @staticmethod
208
+ def backward(ctx, ys: torch.Tensor):
209
+ # out: (b, k, d, l)
210
+ in_channel_first = ctx.in_channel_first
211
+ out_channel_first = ctx.out_channel_first
212
+ one_by_one = ctx.one_by_one
213
+ scans = ctx.scans
214
+ B, C, H, W = ctx.shape
215
+
216
+ ys = ys.view(B, -1, C, H, W) if out_channel_first else ys.view(B, H, W, -1, C)
217
+ _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
218
+ y = _fn(ys, in_channel_first, out_channel_first, scans)
219
+
220
+ if one_by_one:
221
+ y = y.view(B, 4, -1, H, W) if in_channel_first else y.view(B, H, W, 4, -1)
222
+ else:
223
+ y = y.view(B, -1, H, W) if in_channel_first else y.view(B, H, W, -1)
224
+
225
+ return y, None, None, None, None
226
+
227
+
228
+ class CrossMergeF(torch.autograd.Function):
229
+ @staticmethod
230
+ def forward(ctx, ys: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
231
+ # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
232
+ # y: (B, 4, C, H * W) | (B, H * W, 4, C)
233
+ ctx.in_channel_first = in_channel_first
234
+ ctx.out_channel_first = out_channel_first
235
+ ctx.one_by_one = one_by_one
236
+ ctx.scans = scans
237
+
238
+ B, K, C, H, W = ys.shape
239
+ if not out_channel_first:
240
+ B, H, W, K, C = ys.shape
241
+ ctx.shape = (B, C, H, W)
242
+
243
+ _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
244
+ y = _fn(ys, in_channel_first, out_channel_first, scans)
245
+
246
+ return y
247
+
248
+ @staticmethod
249
+ def backward(ctx, x: torch.Tensor):
250
+ # B, D, L = x.shape
251
+ # out: (b, k, d, h, w)
252
+ in_channel_first = ctx.in_channel_first
253
+ out_channel_first = ctx.out_channel_first
254
+ one_by_one = ctx.one_by_one
255
+ scans = ctx.scans
256
+ B, C, H, W = ctx.shape
257
+
258
+ if not one_by_one:
259
+ if in_channel_first:
260
+ x = x.view(B, C, H, W)
261
+ else:
262
+ x = x.view(B, H, W, C)
263
+ else:
264
+ if in_channel_first:
265
+ x = x.view(B, 4, C, H, W)
266
+ else:
267
+ x = x.view(B, H, W, 4, C)
268
+
269
+ _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
270
+ x = _fn(x, in_channel_first, out_channel_first, scans)
271
+ x = x.view(B, 4, C, H, W) if out_channel_first else x.view(B, H, W, 4, C)
272
+
273
+ return x, None, None, None, None
274
+
275
+
276
+ # triton implements ========================================
277
+
278
+ @triton.jit
279
+ def triton_cross_scan_flex(
280
+ x, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
281
+ y, # (B, 4, C, H, W) | (B, H, W, 4, C)
282
+ x_layout: tl.constexpr,
283
+ y_layout: tl.constexpr,
284
+ operation: tl.constexpr,
285
+ onebyone: tl.constexpr,
286
+ scans: tl.constexpr,
287
+ BC: tl.constexpr,
288
+ BH: tl.constexpr,
289
+ BW: tl.constexpr,
290
+ DC: tl.constexpr,
291
+ DH: tl.constexpr,
292
+ DW: tl.constexpr,
293
+ NH: tl.constexpr,
294
+ NW: tl.constexpr,
295
+ ):
296
+ # x_layout = 0
297
+ # y_layout = 1 # 0 BCHW, 1 BHWC
298
+ # operation = 0 # 0 scan, 1 merge
299
+ # onebyone = 0 # 0 false, 1 true
300
+ # scans = 0 # 0 cross scan, 1 unidirectional, 2 bidirectional
301
+
302
+ i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
303
+ i_h, i_w = (i_hw // NW), (i_hw % NW)
304
+ _mask_h = (i_h * BH + tl.arange(0, BH)) < DH
305
+ _mask_w = (i_w * BW + tl.arange(0, BW)) < DW
306
+ _mask_hw = _mask_h[:, None] & _mask_w[None, :]
307
+ _for_C = min(DC - i_c * BC, BC)
308
+
309
+ HWRoute0 = i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :]
310
+ HWRoute1 = i_w * BW * DH + tl.arange(0, BW)[None, :] * DH + i_h * BH + tl.arange(0, BH)[:, None] # trans
311
+ HWRoute2 = (NH - i_h - 1) * BH * DW + (BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + (BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW) # flip
312
+ HWRoute3 = (NW - i_w - 1) * BW * DH + (BW - 1 - tl.arange(0, BW)[None, :]) * DH + (NH - i_h - 1) * BH + (BH - 1 - tl.arange(0, BH)[:, None]) + (DH - NH * BH) + (DW - NW * BW) * DH # trans + flip
313
+
314
+ if scans == 1:
315
+ HWRoute1 = HWRoute0
316
+ HWRoute2 = HWRoute0
317
+ HWRoute3 = HWRoute0
318
+ elif scans == 2:
319
+ HWRoute1 = HWRoute0
320
+ HWRoute3 = HWRoute2
321
+
322
+ _tmp1 = DC * DH * DW
323
+
324
+ y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC)
325
+ if y_layout == 0:
326
+ p_y1 = y_ptr_base + HWRoute0
327
+ p_y2 = y_ptr_base + _tmp1 + HWRoute1
328
+ p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2
329
+ p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3
330
+ else:
331
+ p_y1 = y_ptr_base + HWRoute0 * 4 * DC
332
+ p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC
333
+ p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC
334
+ p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC
335
+
336
+ if onebyone == 0:
337
+ x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
338
+ if x_layout == 0:
339
+ p_x = x_ptr_base + HWRoute0
340
+ else:
341
+ p_x = x_ptr_base + HWRoute0 * DC
342
+
343
+ if operation == 0:
344
+ for idxc in range(_for_C):
345
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
346
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
347
+ _x = tl.load(p_x + _idx_x, mask=_mask_hw)
348
+ tl.store(p_y1 + _idx_y, _x, mask=_mask_hw)
349
+ tl.store(p_y2 + _idx_y, _x, mask=_mask_hw)
350
+ tl.store(p_y3 + _idx_y, _x, mask=_mask_hw)
351
+ tl.store(p_y4 + _idx_y, _x, mask=_mask_hw)
352
+ elif operation == 1:
353
+ for idxc in range(_for_C):
354
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
355
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
356
+ _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)
357
+ _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)
358
+ _y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw)
359
+ _y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw)
360
+ tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw)
361
+
362
+ else:
363
+ x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
364
+ if x_layout == 0:
365
+ p_x1 = x_ptr_base + HWRoute0
366
+ p_x2 = p_x1 + _tmp1
367
+ p_x3 = p_x2 + _tmp1
368
+ p_x4 = p_x3 + _tmp1
369
+ else:
370
+ p_x1 = x_ptr_base + HWRoute0 * 4 * DC
371
+ p_x2 = p_x1 + DC
372
+ p_x3 = p_x2 + DC
373
+ p_x4 = p_x3 + DC
374
+
375
+ if operation == 0:
376
+ for idxc in range(_for_C):
377
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
378
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
379
+ tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw)
380
+ tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw)
381
+ tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw)
382
+ tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw)
383
+ else:
384
+ for idxc in range(_for_C):
385
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
386
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
387
+ tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw)
388
+ tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw)
389
+ tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y), mask=_mask_hw)
390
+ tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y), mask=_mask_hw)
391
+
392
+
393
+ class CrossScanTritonF(torch.autograd.Function):
394
+ @staticmethod
395
+ def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
396
+ if one_by_one:
397
+ if in_channel_first:
398
+ B, _, C, H, W = x.shape
399
+ else:
400
+ B, H, W, _, C = x.shape
401
+ else:
402
+ if in_channel_first:
403
+ B, C, H, W = x.shape
404
+ else:
405
+ B, H, W, C = x.shape
406
+ B, C, H, W = int(B), int(C), int(H), int(W)
407
+ BC, BH, BW = 1, 32, 32
408
+ NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
409
+
410
+ ctx.in_channel_first = in_channel_first
411
+ ctx.out_channel_first = out_channel_first
412
+ ctx.one_by_one = one_by_one
413
+ ctx.scans = scans
414
+ ctx.shape = (B, C, H, W)
415
+ ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
416
+
417
+ y = x.new_empty((B, 4, C, H * W)) if out_channel_first else x.new_empty((B, H * W, 4, C))
418
+ triton_cross_scan_flex[(NH * NW, NC, B)](
419
+ x.contiguous(), y,
420
+ (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
421
+ BC, BH, BW, C, H, W, NH, NW
422
+ )
423
+ return y
424
+
425
+ @staticmethod
426
+ def backward(ctx, y: torch.Tensor):
427
+ in_channel_first = ctx.in_channel_first
428
+ out_channel_first = ctx.out_channel_first
429
+ one_by_one = ctx.one_by_one
430
+ scans = ctx.scans
431
+ B, C, H, W = ctx.shape
432
+ BC, BH, BW, NC, NH, NW = ctx.triton_shape
433
+ if one_by_one:
434
+ x = y.new_empty((B, 4, C, H, W)) if in_channel_first else y.new_empty((B, H, W, 4, C))
435
+ else:
436
+ x = y.new_empty((B, C, H, W)) if in_channel_first else y.new_empty((B, H, W, C))
437
+
438
+ triton_cross_scan_flex[(NH * NW, NC, B)](
439
+ x, y.contiguous(),
440
+ (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
441
+ BC, BH, BW, C, H, W, NH, NW
442
+ )
443
+ return x, None, None, None, None
444
+
445
+
446
+ class CrossMergeTritonF(torch.autograd.Function):
447
+ @staticmethod
448
+ def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
449
+ if out_channel_first:
450
+ B, _, C, H, W = y.shape
451
+ else:
452
+ B, H, W, _, C = y.shape
453
+ B, C, H, W = int(B), int(C), int(H), int(W)
454
+ BC, BH, BW = 1, 32, 32
455
+ NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
456
+ ctx.in_channel_first = in_channel_first
457
+ ctx.out_channel_first = out_channel_first
458
+ ctx.one_by_one = one_by_one
459
+ ctx.scans = scans
460
+ ctx.shape = (B, C, H, W)
461
+ ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
462
+ if one_by_one:
463
+ x = y.new_empty((B, 4, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 4, C))
464
+ else:
465
+ x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C))
466
+ triton_cross_scan_flex[(NH * NW, NC, B)](
467
+ x, y.contiguous(),
468
+ (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
469
+ BC, BH, BW, C, H, W, NH, NW
470
+ )
471
+ return x
472
+
473
+ @staticmethod
474
+ def backward(ctx, x: torch.Tensor):
475
+ in_channel_first = ctx.in_channel_first
476
+ out_channel_first = ctx.out_channel_first
477
+ one_by_one = ctx.one_by_one
478
+ scans = ctx.scans
479
+ B, C, H, W = ctx.shape
480
+ BC, BH, BW, NC, NH, NW = ctx.triton_shape
481
+ y = x.new_empty((B, 4, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 4, C))
482
+ triton_cross_scan_flex[(NH * NW, NC, B)](
483
+ x.contiguous(), y,
484
+ (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
485
+ BC, BH, BW, C, H, W, NH, NW
486
+ )
487
+ return y, None, None, None, None, None
488
+
489
+
490
+ # @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
491
+ def cross_scan_fn(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
492
+ # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
493
+ # y: (B, 4, C, L) | (B, L, 4, C)
494
+ # scans: 0: cross scan; 1 unidirectional; 2: bidirectional;
495
+ CSF = CrossScanTritonF if WITH_TRITON and x.is_cuda and (not force_torch) else CrossScanF
496
+ return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans)
497
+
498
+
499
+ # @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
500
+ def cross_merge_fn(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
501
+ # y: (B, 4, C, L) | (B, L, 4, C)
502
+ # x: (B, C, H * W) | (B, H * W, C) | (B, 4, C, H * W) | (B, H * W, 4, C)
503
+ # scans: 0: cross scan; 1 unidirectional; 2: bidirectional;
504
+ CMF = CrossMergeTritonF if WITH_TRITON and y.is_cuda and (not force_torch) else CrossMergeF
505
+ return CMF.apply(y, in_channel_first, out_channel_first, one_by_one, scans)
506
+
507
+
508
+ # checks =================================================================
509
+
510
+ class CHECK:
511
+ def check_csm_triton():
512
+ B, C, H, W = 2, 192, 56, 57
513
+ dtype=torch.float16
514
+ dtype=torch.float32
515
+ x = torch.randn((B, C, H, W), dtype=dtype, device=torch.device("cuda")).requires_grad_(True)
516
+ y = torch.randn((B, 4, C, H, W), dtype=dtype, device=torch.device("cuda")).requires_grad_(True)
517
+ x1 = x.clone().detach().requires_grad_(True)
518
+ y1 = y.clone().detach().requires_grad_(True)
519
+
520
+ def cross_scan(x: torch.Tensor):
521
+ B, C, H, W = x.shape
522
+ L = H * W
523
+ xs = torch.stack([
524
+ x.view(B, C, L),
525
+ torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L),
526
+ torch.flip(x.contiguous().view(B, C, L), dims=[-1]),
527
+ torch.flip(torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]),
528
+ ], dim=1).view(B, 4, C, L)
529
+ return xs
530
+
531
+ def cross_merge(out_y: torch.Tensor):
532
+ B, K, D, H, W = out_y.shape
533
+ L = H * W
534
+ out_y = out_y.view(B, K, D, L)
535
+ inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
536
+ wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
537
+ invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
538
+ y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
539
+ return y
540
+
541
+ def cross_scan_1b1(x: torch.Tensor):
542
+ B, K, C, H, W = x.shape
543
+ L = H * W
544
+ xs = torch.stack([
545
+ x[:, 0].view(B, C, L),
546
+ torch.transpose(x[:, 1], dim0=2, dim1=3).contiguous().view(B, C, L),
547
+ torch.flip(x[:, 2].contiguous().view(B, C, L), dims=[-1]),
548
+ torch.flip(torch.transpose(x[:, 3], dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]),
549
+ ], dim=1).view(B, 4, C, L)
550
+ return xs
551
+
552
+ def unidi_scan(x):
553
+ B, C, H, W = x.shape
554
+ x = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
555
+ return x
556
+
557
+ def unidi_merge(ys):
558
+ B, K, C, H, W = ys.shape
559
+ return ys.view(B, 4, -1, H * W).sum(1)
560
+
561
+ def bidi_scan(x):
562
+ B, C, H, W = x.shape
563
+ x = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
564
+ x = torch.cat([x, x.flip(dims=[-1])], dim=1)
565
+ return x
566
+
567
+ def bidi_merge(ys):
568
+ B, K, D, H, W = ys.shape
569
+ ys = ys.view(B, K, D, -1)
570
+ ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
571
+ return ys.contiguous().sum(1)
572
+
573
+ if True:
574
+ res0 = triton.testing.do_bench(lambda :cross_scan(x))
575
+ res1 = triton.testing.do_bench(lambda :cross_scan_fn(x, True, True, False))
576
+ # res2 = triton.testing.do_bench(lambda :CrossScanTriton.apply(x))
577
+ res3 = triton.testing.do_bench(lambda :cross_merge(y))
578
+ res4 = triton.testing.do_bench(lambda :cross_merge_fn(y, True, True, False))
579
+ # res5 = triton.testing.do_bench(lambda :CrossMergeTriton.apply(y))
580
+ # print(res0, res1, res2, res3, res4, res5)
581
+ print(res0, res1, res3, res4)
582
+ res0 = triton.testing.do_bench(lambda :cross_scan(x).sum().backward())
583
+ res1 = triton.testing.do_bench(lambda :cross_scan_fn(x, True, True, False).sum().backward())
584
+ # res2 = triton.testing.do_bench(lambda :CrossScanTriton.apply(x).sum().backward())
585
+ res3 = triton.testing.do_bench(lambda :cross_merge(y).sum().backward())
586
+ res4 = triton.testing.do_bench(lambda :cross_merge_fn(y, True, True, False).sum().backward())
587
+ # res5 = triton.testing.do_bench(lambda :CrossMergeTriton.apply(y).sum().backward())
588
+ # print(res0, res1, res2, res3, res4, res5)
589
+ print(res0, res1, res3, res4)
590
+
591
+ print("test cross scan")
592
+ for (cs0, cm0, cs1, cm1) in [
593
+ # channel_first -> channel_first
594
+ (cross_scan, cross_merge, cross_scan_fn, cross_merge_fn),
595
+ (unidi_scan, unidi_merge, lambda x: cross_scan_fn(x, scans=1), lambda x: cross_merge_fn(x, scans=1)),
596
+ (bidi_scan, bidi_merge, lambda x: cross_scan_fn(x, scans=2), lambda x: cross_merge_fn(x, scans=2)),
597
+
598
+ # flex: BLC->BCL; BCL->BLC; BLC->BLC;
599
+ (cross_scan, cross_merge, lambda x: cross_scan_fn(x.permute(0, 2, 3, 1), in_channel_first=False), lambda x: cross_merge_fn(x, in_channel_first=False).permute(0, 2, 1)),
600
+ (cross_scan, cross_merge, lambda x: cross_scan_fn(x, out_channel_first=False).permute(0, 2, 3, 1), lambda x: cross_merge_fn(x.permute(0, 3, 4, 1, 2), out_channel_first=False)),
601
+ (cross_scan, cross_merge, lambda x: cross_scan_fn(x.permute(0, 2, 3, 1), in_channel_first=False, out_channel_first=False).permute(0, 2, 3, 1), lambda x: cross_merge_fn(x.permute(0, 3, 4, 1, 2), in_channel_first=False, out_channel_first=False).permute(0, 2, 1)),
602
+
603
+ # previous
604
+ # (cross_scan, cross_merge, lambda x: CrossScanTriton.apply(x), lambda x: CrossMergeTriton.apply(x)),
605
+ # (unidi_scan, unidi_merge, lambda x: getCSM(1)[0].apply(x), lambda x: getCSM(1)[1].apply(x)),
606
+ # (bidi_scan, bidi_merge, lambda x: getCSM(2)[0].apply(x), lambda x: getCSM(2)[1].apply(x)),
607
+ ]:
608
+ x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
609
+ o0 = cs0(x)
610
+ o1 = cs1(x1)
611
+ o0.backward(y.view(B, 4, C, H * W))
612
+ o1.backward(y.view(B, 4, C, H * W))
613
+ print((o0 - o1).abs().max())
614
+ print((x.grad - x1.grad).abs().max())
615
+ o0 = cm0(y)
616
+ o1 = cm1(y1)
617
+ o0.backward(x.view(B, C, H * W))
618
+ o1.backward(x.view(B, C, H * W))
619
+ print((o0 - o1).abs().max())
620
+ print((y.grad - y1.grad).abs().max())
621
+ x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
622
+ print("===============", flush=True)
623
+
624
+ print("test cross scan one by one")
625
+ for (cs0, cs1) in [
626
+ (cross_scan_1b1, lambda x: cross_scan_fn(x, one_by_one=True)),
627
+ # (cross_scan_1b1, lambda x: CrossScanTriton1b1.apply(x)),
628
+ ]:
629
+ o0 = cs0(y)
630
+ o1 = cs1(y1)
631
+ o0.backward(y.view(B, 4, C, H * W))
632
+ o1.backward(y.view(B, 4, C, H * W))
633
+ print((o0 - o1).abs().max())
634
+ print((y.grad - y1.grad).abs().max())
635
+ x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
636
+ print("===============", flush=True)
637
+
638
+
639
+ if __name__ == "__main__":
640
+ CHECK.check_csm_triton()
641
+
642
+
643
+
644
+
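To make the data flow of the cross-scan / cross-merge pair above concrete, here is a small pure-PyTorch sketch of the reference semantics exercised in CHECK.check_csm_triton (the helper names ref_cross_scan and ref_cross_merge are illustrative, not part of the file); in this formulation, merging the scan of x recovers 4 * x:

    import torch

    def ref_cross_scan(x):
        # (B, C, H, W) -> (B, 4, C, H*W): row-major, column-major, and their reverses
        B, C, H, W = x.shape
        return torch.stack([
            x.flatten(2, 3),
            x.transpose(2, 3).flatten(2, 3),
            x.flatten(2, 3).flip(-1),
            x.transpose(2, 3).flatten(2, 3).flip(-1),
        ], dim=1)

    def ref_cross_merge(ys, H, W):
        # (B, 4, C, H*W) -> (B, C, H*W): undo each scan direction, then sum
        B, _, C, L = ys.shape
        inv = ys[:, 2:4].flip(-1)
        wh = ys[:, 1].view(B, C, W, H).transpose(2, 3).reshape(B, C, L)
        invwh = inv[:, 1].view(B, C, W, H).transpose(2, 3).reshape(B, C, L)
        return ys[:, 0] + inv[:, 0] + wh + invwh

    x = torch.randn(2, 3, 4, 5)
    ys = ref_cross_scan(x)                       # scans == 0, channel-first layout
    assert torch.allclose(ref_cross_merge(ys, 4, 5), 4 * x.flatten(2, 3))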
rscd/models/backbones/lib_mamba/csm_tritonk2.py ADDED
@@ -0,0 +1,899 @@
1
+ import torch
2
+ import warnings
3
+ import os
4
+ os.environ["TRITON_INTERPRET"] = "1"
5
+
6
+ WITH_TRITON = True
7
+ # WITH_TRITON = False
8
+ try:
9
+ import triton
10
+ import triton.language as tl
11
+ except:
12
+ WITH_TRITON = False
13
+ warnings.warn("Triton not installed, fall back to pytorch implements.")
14
+
15
+ # to make sure cached_property can be loaded for triton
16
+ if WITH_TRITON:
17
+ try:
18
+ from functools import cached_property
19
+ except:
20
+ warnings.warn("if you are using py37, add this line to functools.py: "
21
+ "cached_property = lambda func: property(lru_cache()(func))")
22
+
23
+ # torch implementation ========================================
24
+ def cross_scan_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=2):
25
+ if in_channel_first:
26
+ B, C, H, W = x.shape
27
+ if scans == 0:
28
+ y = x.new_empty((B, 4, C, H * W))
29
+ y[:, 0, :, :] = x.flatten(2, 3)
30
+ y[:, 1, :, :] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
31
+ y[:, 2:4, :, :] = torch.flip(y[:, 0:2, :, :], dims=[-1])
32
+ elif scans == 1:
33
+ y = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
34
+ elif scans == 2:
35
+ y = x.view(B, 1, C, H * W)
36
+ y = torch.cat([y, y.flip(dims=[-1])], dim=1)
37
+ else:
38
+ B, H, W, C = x.shape
39
+ if scans == 0:
40
+ y = x.new_empty((B, H * W, 4, C))
41
+ y[:, :, 0, :] = x.flatten(1, 2)
42
+ y[:, :, 1, :] = x.transpose(dim0=1, dim1=2).flatten(1, 2)
43
+ y[:, :, 2:4, :] = torch.flip(y[:, :, 0:2, :], dims=[1])
44
+ elif scans == 1:
45
+ y = x.view(B, H * W, 1, C).repeat(1, 1, 2, 1)
46
+ elif scans == 2:
47
+ y = x.view(B, H * W, 1, C)
48
+ y = torch.cat([y, y.flip(dims=[1])], dim=2)
49
+
50
+ if in_channel_first and (not out_channel_first):
51
+ y = y.permute(0, 3, 1, 2).contiguous()
52
+ elif (not in_channel_first) and out_channel_first:
53
+ y = y.permute(0, 2, 3, 1).contiguous()
54
+
55
+ return y
56
+
57
+
58
+ def cross_merge_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=2):
59
+ if out_channel_first:
60
+ B, K, D, H, W = y.shape
61
+ y = y.view(B, K, D, -1)
62
+ if scans == 0:
63
+ y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
64
+ y = y[:, 0] + y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
65
+ elif scans == 1:
66
+ y = y.sum(1)
67
+ elif scans == 2:
68
+ y = y[:, 0] + y[:, 1].flip(dims=[-1]).view(B, 1, D, -1)
69
+ y = y.sum(1)
70
+ else:
71
+ B, H, W, K, D = y.shape
72
+ y = y.view(B, -1, K, D)
73
+ if scans == 0:
74
+ y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
75
+ y = y[:, :, 0] + y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).contiguous().view(B, -1, D)
76
+ elif scans == 1:
77
+ y = y.sum(2)
78
+ elif scans == 2:
79
+ y = y[:, :, 0] + y[:, :, 1].flip(dims=[1]).view(B, -1, 1, D)
80
+ y = y.sum(2)
81
+
82
+ if in_channel_first and (not out_channel_first):
83
+ y = y.permute(0, 2, 1).contiguous()
84
+ elif (not in_channel_first) and out_channel_first:
85
+ y = y.permute(0, 2, 1).contiguous()
86
+
87
+ return y
88
+
89
+
90
+ def cross_scan1b1_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=2):
91
+ if in_channel_first:
92
+ B, _, C, H, W = x.shape
93
+ if scans == 0:
94
+ y = torch.stack([
95
+ x[:, 0].flatten(2, 3),
96
+ x[:, 1].transpose(dim0=2, dim1=3).flatten(2, 3),
97
+ torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
98
+ torch.flip(x[:, 3].transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
99
+ ], dim=1)
100
+ elif scans == 1:
101
+ y = x.flatten(2, 3)
102
+ elif scans == 2:
103
+ y = torch.stack([
104
+ x[:, 0].flatten(2, 3),
105
+ x[:, 1].flatten(2, 3),
106
+ torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
107
+ torch.flip(x[:, 3].flatten(2, 3), dims=[-1]),
108
+ ], dim=1)
109
+ else:
110
+ B, H, W, _, C = x.shape
111
+ if scans == 0:
112
+ y = torch.stack([
113
+ x[:, :, :, 0].flatten(1, 2),
114
+ x[:, :, :, 1].transpose(dim0=1, dim1=2).flatten(1, 2),
115
+ torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]),
116
+ torch.flip(x[:, :, :, 3].transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
117
+ ], dim=2)
118
+ elif scans == 1:
119
+ y = x.flatten(1, 2)
120
+ elif scans == 2:
121
+ y = torch.stack([
122
+ x[:, 0].flatten(1, 2),
123
+ x[:, 1].flatten(1, 2),
124
+ torch.flip(x[:, 2].flatten(1, 2), dims=[-1]),
125
+ torch.flip(x[:, 3].flatten(1, 2), dims=[-1]),
126
+ ], dim=2)
127
+
128
+ if in_channel_first and (not out_channel_first):
129
+ y = y.permute(0, 3, 1, 2).contiguous()
130
+ elif (not in_channel_first) and out_channel_first:
131
+ y = y.permute(0, 2, 3, 1).contiguous()
132
+
133
+ return y
134
+
135
+
136
+ def cross_merge1b1_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=2):
137
+ if out_channel_first:
138
+ B, K, D, H, W = y.shape
139
+ y = y.view(B, K, D, -1)
140
+ if scans == 0:
141
+ y = torch.stack([
142
+ y[:, 0],
143
+ y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3),
144
+ torch.flip(y[:, 2], dims=[-1]),
145
+ torch.flip(y[:, 3].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
146
+ ], dim=1)
147
+ elif scans == 1:
148
+ y = y
149
+ elif scans == 2:
150
+ y = torch.stack([
151
+ y[:, 0],
152
+ y[:, 1],
153
+ torch.flip(y[:, 2], dims=[-1]),
154
+ torch.flip(y[:, 3], dims=[-1]),
155
+ ], dim=1)
156
+ else:
157
+ B, H, W, _, D = y.shape
158
+ y = y.view(B, -1, 2, D)
159
+ if scans == 0:
160
+ y = torch.stack([
161
+ y[:, :, 0],
162
+ y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2),
163
+ torch.flip(y[:, :, 2], dims=[1]),
164
+ torch.flip(y[:, :, 3].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
165
+ ], dim=2)
166
+ elif scans == 1:
167
+ y = y
168
+ elif scans == 2:
169
+ y = torch.stack([
170
+ y[:, :, 0],
171
+ y[:, :, 1],
172
+ torch.flip(y[:, :, 2], dims=[1]),
173
+ torch.flip(y[:, :, 3], dims=[1]),
174
+ ], dim=2)
175
+
176
+ if out_channel_first and (not in_channel_first):
177
+ y = y.permute(0, 3, 1, 2).contiguous()
178
+ elif (not out_channel_first) and in_channel_first:
179
+ y = y.permute(0, 2, 3, 1).contiguous()
180
+
181
+ return y
182
+
183
+ class CrossScan(torch.nn.Module):
184
+ def __init__(self, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2):
185
+ super(CrossScan, self).__init__()
186
+ self.in_channel_first = in_channel_first
187
+ self.out_channel_first = out_channel_first
188
+ self.one_by_one = one_by_one
189
+ self.scans = scans
190
+
191
+ def forward(self, x: torch.Tensor):
192
+ if self.one_by_one:
193
+ B, K, C, H, W = x.shape
194
+ if not self.in_channel_first:
195
+ B, H, W, K, C = x.shape
196
+ else:
197
+ B, C, H, W = x.shape
198
+ if not self.in_channel_first:
199
+ B, H, W, C = x.shape
200
+ self.shape = (B, C, H, W)
201
+
202
+ _fn = cross_scan1b1_fwd if self.one_by_one else cross_scan_fwd
203
+ y = _fn(x, self.in_channel_first, self.out_channel_first, self.scans)
204
+
205
+ return y
206
+
207
+ def backward(self, ys: torch.Tensor):
208
+ B, C, H, W = self.shape
209
+
210
+ ys = ys.view(B, -1, C, H, W) if self.out_channel_first else ys.view(B, H, W, -1, C)
211
+ _fn = cross_merge1b1_fwd if self.one_by_one else cross_merge_fwd
212
+ y = _fn(ys, self.in_channel_first, self.out_channel_first, self.scans)
213
+
214
+ if self.one_by_one:
215
+ y = y.view(B, 2, -1, H, W) if self.in_channel_first else y.view(B, H, W, 2, -1)
216
+ else:
217
+ y = y.view(B, -1, H, W) if self.in_channel_first else y.view(B, H, W, -1)
218
+
219
+ return y
220
+
221
+
222
+ class CrossMerge(torch.nn.Module):
223
+ def __init__(self, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2):
224
+ super(CrossMerge, self).__init__()
225
+ self.in_channel_first = in_channel_first
226
+ self.out_channel_first = out_channel_first
227
+ self.one_by_one = one_by_one
228
+ self.scans = scans
229
+
230
+ def forward(self, ys: torch.Tensor):
231
+ B, K, C, H, W = ys.shape
232
+ if not self.out_channel_first:
233
+ B, H, W, K, C = ys.shape
234
+ self.shape = (B, C, H, W)
235
+
236
+ _fn = cross_merge1b1_fwd if self.one_by_one else cross_merge_fwd
237
+ y = _fn(ys, self.in_channel_first, self.out_channel_first, self.scans)
238
+
239
+ return y
240
+
241
+ def backward(self, x: torch.Tensor):
242
+ B, C, H, W = self.shape
243
+
244
+ if not self.one_by_one:
245
+ if self.in_channel_first:
246
+ x = x.view(B, C, H, W)
247
+ else:
248
+ x = x.view(B, H, W, C)
249
+ else:
250
+ if self.in_channel_first:
251
+ x = x.view(B, 2, C, H, W)
252
+ else:
253
+ x = x.view(B, H, W, 2, C)
254
+
255
+ _fn = cross_scan1b1_fwd if self.one_by_one else cross_scan_fwd
256
+ x = _fn(x, self.in_channel_first, self.out_channel_first, self.scans)
257
+ x = x.view(B, 2, C, H, W) if self.out_channel_first else x.view(B, H, W, 2, C)
258
+
259
+ return x
260
+ class CrossScanF(torch.autograd.Function):
261
+ @staticmethod
262
+ def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2):
263
+ # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 2, C)
264
+ # y: (B, 2, C, H * W) | (B, H * W, 2, C)
265
+ ctx.in_channel_first = in_channel_first
266
+ ctx.out_channel_first = out_channel_first
267
+ ctx.one_by_one = one_by_one
268
+ ctx.scans = scans
269
+
270
+ if one_by_one:
271
+ B, K, C, H, W = x.shape
272
+ if not in_channel_first:
273
+ B, H, W, K, C = x.shape
274
+ else:
275
+ B, C, H, W = x.shape
276
+ if not in_channel_first:
277
+ B, H, W, C = x.shape
278
+ ctx.shape = (B, C, H, W)
279
+
280
+ _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
281
+ y = _fn(x, in_channel_first, out_channel_first, scans)
282
+
283
+ return y
284
+
285
+ @staticmethod
286
+ def backward(ctx, ys: torch.Tensor):
287
+ # out: (b, k, d, l)
288
+ in_channel_first = ctx.in_channel_first
289
+ out_channel_first = ctx.out_channel_first
290
+ one_by_one = ctx.one_by_one
291
+ scans = ctx.scans
292
+ B, C, H, W = ctx.shape
293
+
294
+ ys = ys.view(B, -1, C, H, W) if out_channel_first else ys.view(B, H, W, -1, C)
295
+ _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
296
+ y = _fn(ys, in_channel_first, out_channel_first, scans)
297
+
298
+ if one_by_one:
299
+ y = y.view(B, 2, -1, H, W) if in_channel_first else y.view(B, H, W, 2, -1)
300
+ else:
301
+ y = y.view(B, -1, H, W) if in_channel_first else y.view(B, H, W, -1)
302
+
303
+ return y, None, None, None, None
304
+
305
+
306
+ class CrossMergeF(torch.autograd.Function):
307
+ @staticmethod
308
+ def forward(ctx, ys: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2):
309
+ # x: (B, C, H, W) | (B, H, W, C) | (B, 2, C, H, W) | (B, H, W, 2, C)
310
+ # y: (B, 2, C, H * W) | (B, H * W, 4, C)
311
+ ctx.in_channel_first = in_channel_first
312
+ ctx.out_channel_first = out_channel_first
313
+ ctx.one_by_one = one_by_one
314
+ ctx.scans = scans
315
+
316
+ B, K, C, H, W = ys.shape
317
+ if not out_channel_first:
318
+ B, H, W, K, C = ys.shape
319
+ ctx.shape = (B, C, H, W)
320
+
321
+ _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
322
+ y = _fn(ys, in_channel_first, out_channel_first, scans)
323
+
324
+ return y
325
+
326
+ @staticmethod
327
+ def backward(ctx, x: torch.Tensor):
328
+ # B, D, L = x.shape
329
+ # out: (b, k, d, h, w)
330
+ in_channel_first = ctx.in_channel_first
331
+ out_channel_first = ctx.out_channel_first
332
+ one_by_one = ctx.one_by_one
333
+ scans = ctx.scans
334
+ B, C, H, W = ctx.shape
335
+
336
+ if not one_by_one:
337
+ if in_channel_first:
338
+ x = x.view(B, C, H, W)
339
+ else:
340
+ x = x.view(B, H, W, C)
341
+ else:
342
+ if in_channel_first:
343
+ x = x.view(B, 2, C, H, W)
344
+ else:
345
+ x = x.view(B, H, W, 2, C)
346
+
347
+ _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
348
+ x = _fn(x, in_channel_first, out_channel_first, scans)
349
+ x = x.view(B, 2, C, H, W) if out_channel_first else x.view(B, H, W, 2, C)
350
+
351
+ return x, None, None, None, None
352
+
353
+
354
+ # triton implements ========================================
355
+
356
+ @triton.jit
357
+ def triton_cross_scan_flex_k2(
358
+ x, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
359
+ y, # (B, 4, C, H, W) | (B, H, W, 4, C)
360
+ x_layout: tl.constexpr,
361
+ y_layout: tl.constexpr,
362
+ operation: tl.constexpr,
363
+ onebyone: tl.constexpr,
364
+ scans: tl.constexpr,
365
+ BC: tl.constexpr,
366
+ BH: tl.constexpr,
367
+ BW: tl.constexpr,
368
+ DC: tl.constexpr,
369
+ DH: tl.constexpr,
370
+ DW: tl.constexpr,
371
+ NH: tl.constexpr,
372
+ NW: tl.constexpr,
373
+ ):
374
+ # x_layout = 0
375
+ # y_layout = 1 # 0 BCHW, 1 BHWC
376
+ # operation = 0 # 0 scan, 1 merge
377
+ # onebyone = 0 # 0 false, 1 true
378
+ # scans = 0 # 0 cross scan, 1 unidirectional, 2 bidirectional
379
+
380
+ i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
381
+ i_h, i_w = (i_hw // NW), (i_hw % NW)
382
+ _mask_h = (i_h * BH + tl.arange(0, BH)) < DH
383
+ _mask_w = (i_w * BW + tl.arange(0, BW)) < DW
384
+ _mask_hw = _mask_h[:, None] & _mask_w[None, :]
385
+ _for_C = min(DC - i_c * BC, BC)
386
+
387
+ HWRoute0 = i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :]
388
+ # HWRoute1 = i_w * BW * DH + tl.arange(0, BW)[None, :] * DH + i_h * BH + tl.arange(0, BH)[:, None] # trans
389
+ HWRoute2 = (NH - i_h - 1) * BH * DW + (BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + (BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW) # flip
390
+ # HWRoute3 = (NW - i_w - 1) * BW * DH + (BW - 1 - tl.arange(0, BW)[None, :]) * DH + (NH - i_h - 1) * BH + (BH - 1 - tl.arange(0, BH)[:, None]) + (DH - NH * BH) + (DW - NW * BW) * DH # trans + flip
391
+
392
+ if scans == 1:
393
+ HWRoute2 = HWRoute0
394
+
395
+
396
+ _tmp1 = DC * DH * DW
397
+
398
+ y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC)
399
+ if y_layout == 0:
400
+ p_y1 = y_ptr_base + HWRoute0
401
+ # p_y2 = y_ptr_base + _tmp1 + HWRoute1
402
+ p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2
403
+ # p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3
404
+ else:
405
+ p_y1 = y_ptr_base + HWRoute0 * 4 * DC
406
+ # p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC
407
+ p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC
408
+ # p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC
409
+
410
+ if onebyone == 0:
411
+ x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
412
+ if x_layout == 0:
413
+ p_x = x_ptr_base + HWRoute0
414
+ else:
415
+ p_x = x_ptr_base + HWRoute0 * DC
416
+
417
+ if operation == 0:
418
+ for idxc in range(_for_C):
419
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
420
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
421
+ _x = tl.load(p_x + _idx_x, mask=_mask_hw)
422
+ tl.store(p_y1 + _idx_y, _x, mask=_mask_hw)
423
+ # tl.store(p_y2 + _idx_y, _x, mask=_mask_hw)
424
+ tl.store(p_y3 + _idx_y, _x, mask=_mask_hw)
425
+ # tl.store(p_y4 + _idx_y, _x, mask=_mask_hw)
426
+ elif operation == 1:
427
+ for idxc in range(_for_C):
428
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
429
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
430
+ _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)
431
+ # _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)
432
+ _y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw)
433
+ # _y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw)
434
+ # tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw)
435
+ tl.store(p_x + _idx_x, _y1 + _y3, mask=_mask_hw)
436
+
437
+
438
+ else:
439
+ x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
440
+ if x_layout == 0:
441
+ p_x1 = x_ptr_base + HWRoute0
442
+ p_x2 = p_x1 + _tmp1
443
+ p_x3 = p_x2 + _tmp1
444
+ p_x4 = p_x3 + _tmp1
445
+ else:
446
+ p_x1 = x_ptr_base + HWRoute0 * 4 * DC
447
+ p_x2 = p_x1 + DC
448
+ p_x3 = p_x2 + DC
449
+ p_x4 = p_x3 + DC
450
+
451
+ if operation == 0:
452
+ for idxc in range(_for_C):
453
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
454
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
455
+ tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw)
456
+ # tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw)
457
+ tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw)
458
+ # tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw)
459
+ else:
460
+ for idxc in range(_for_C):
461
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
462
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
463
+ tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw)
464
+ # tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw)
465
+ tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y), mask=_mask_hw)
466
+ # tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y), mask=_mask_hw)
467
+
468
+ @triton.jit
469
+ def triton_cross_scan_flex_k2(
470
+ x, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
471
+ y, # (B, 4, C, H, W) | (B, H, W, 4, C)
472
+ x_layout: tl.constexpr,
473
+ y_layout: tl.constexpr,
474
+ operation: tl.constexpr,
475
+ onebyone: tl.constexpr,
476
+ scans: tl.constexpr,
477
+ BC: tl.constexpr,
478
+ BH: tl.constexpr,
479
+ BW: tl.constexpr,
480
+ DC: tl.constexpr,
481
+ DH: tl.constexpr,
482
+ DW: tl.constexpr,
483
+ NH: tl.constexpr,
484
+ NW: tl.constexpr,
485
+ ):
486
+ i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
487
+ i_h, i_w = (i_hw // NW), (i_hw % NW)
488
+ _mask_h = (i_h * BH + tl.arange(0, BH)) < DH
489
+ _mask_w = (i_w * BW + tl.arange(0, BW)) < DW
490
+ _mask_hw = _mask_h[:, None] & _mask_w[None, :]
491
+ _for_C = min(DC - i_c * BC, BC)
492
+
493
+ HWRoute0 = i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :]
494
+ HWRoute2 = (NH - i_h - 1) * BH * DW + (BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + (BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW) # flip
495
+
496
+ if scans == 1:
497
+ HWRoute2 = HWRoute0
498
+
499
+ _tmp1 = DC * DH * DW
500
+
501
+ y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC)
502
+ if y_layout == 0:
503
+ p_y1 = y_ptr_base + HWRoute0
504
+ p_y2 = y_ptr_base + 2 * _tmp1 + HWRoute2
505
+ else:
506
+ p_y1 = y_ptr_base + HWRoute0 * 4 * DC
507
+ p_y2 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC
508
+
509
+ if onebyone == 0:
510
+ x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
511
+ if x_layout == 0:
512
+ p_x = x_ptr_base + HWRoute0
513
+ else:
514
+ p_x = x_ptr_base + HWRoute0 * DC
515
+
516
+ if operation == 0:
517
+ for idxc in range(_for_C):
518
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
519
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
520
+ _x = tl.load(p_x + _idx_x, mask=_mask_hw)
521
+ tl.store(p_y1 + _idx_y, _x, mask=_mask_hw)
522
+ tl.store(p_y2 + _idx_y, _x, mask=_mask_hw)
523
+ elif operation == 1:
524
+ for idxc in range(_for_C):
525
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
526
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
527
+ _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)
528
+ _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)
529
+ tl.store(p_x + _idx_x, _y1 + _y2, mask=_mask_hw)
530
+
531
+ else:
532
+ x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
533
+ if x_layout == 0:
534
+ p_x1 = x_ptr_base + HWRoute0
535
+ p_x2 = p_x1 + 2 * _tmp1
536
+ else:
537
+ p_x1 = x_ptr_base + HWRoute0 * 4 * DC
538
+ p_x2 = p_x1 + 2 * DC
539
+
540
+ if operation == 0:
541
+ for idxc in range(_for_C):
542
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
543
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
544
+ tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw)
545
+ tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw)
546
+ else:
547
+ for idxc in range(_for_C):
548
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
549
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
550
+ tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw)
551
+ tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw)
552
+
553
+ @triton.jit
554
+ def triton_cross_scan_flex_k2(
555
+ x, # (B, C, H, W) | (B, H, W, C) | (B, 2, C, H, W) | (B, H, W, 2, C)
556
+ y, # (B, 2, C, H, W) | (B, H, W, 2, C)
557
+ x_layout: tl.constexpr,
558
+ y_layout: tl.constexpr,
559
+ operation: tl.constexpr,
560
+ onebyone: tl.constexpr,
561
+ scans: tl.constexpr,
562
+ BC: tl.constexpr,
563
+ BH: tl.constexpr,
564
+ BW: tl.constexpr,
565
+ DC: tl.constexpr,
566
+ DH: tl.constexpr,
567
+ DW: tl.constexpr,
568
+ NH: tl.constexpr,
569
+ NW: tl.constexpr,
570
+ ):
571
+ i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
572
+ i_h, i_w = (i_hw // NW), (i_hw % NW)
573
+ _mask_h = (i_h * BH + tl.arange(0, BH)) < DH
574
+ _mask_w = (i_w * BW + tl.arange(0, BW)) < DW
575
+ _mask_hw = _mask_h[:, None] & _mask_w[None, :]
576
+ _for_C = min(DC - i_c * BC, BC)
577
+
578
+ HWRoute0 = i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :]
579
+ HWRoute2 = (NH - i_h - 1) * BH * DW + (BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + (BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW) # flip
580
+
581
+ if scans == 1:
582
+ HWRoute2 = HWRoute0
583
+
584
+ _tmp1 = DC * DH * DW
585
+
586
+ y_ptr_base = y + i_b * 2 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC)
587
+ if y_layout == 0:
588
+ p_y1 = y_ptr_base + HWRoute0
589
+ p_y2 = y_ptr_base + 1 * _tmp1 + HWRoute2
590
+ else:
591
+ p_y1 = y_ptr_base + HWRoute0 * 2 * DC
592
+ p_y2 = y_ptr_base + 1 * DC + HWRoute2 * 2 * DC
593
+
594
+ if onebyone == 0:
595
+ x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
596
+ if x_layout == 0:
597
+ p_x = x_ptr_base + HWRoute0
598
+ else:
599
+ p_x = x_ptr_base + HWRoute0 * DC
600
+
601
+ if operation == 0:
602
+ for idxc in range(_for_C):
603
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
604
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
605
+ _x = tl.load(p_x + _idx_x, mask=_mask_hw)
606
+ tl.store(p_y1 + _idx_y, _x, mask=_mask_hw)
607
+ tl.store(p_y2 + _idx_y, _x, mask=_mask_hw)
608
+ elif operation == 1:
609
+ for idxc in range(_for_C):
610
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
611
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
612
+ _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)
613
+ _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)
614
+ tl.store(p_x + _idx_x, _y1 + _y2, mask=_mask_hw)
615
+
616
+ else:
617
+ x_ptr_base = x + i_b * 2 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
618
+ if x_layout == 0:
619
+ p_x1 = x_ptr_base + HWRoute0
620
+ p_x2 = p_x1 + 1 * _tmp1
621
+ else:
622
+ p_x1 = x_ptr_base + HWRoute0 * 2 * DC
623
+ p_x2 = p_x1 + 1 * DC
624
+
625
+ if operation == 0:
626
+ for idxc in range(_for_C):
627
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
628
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
629
+ _x1 = tl.load(p_x1 + _idx_x, mask=_mask_hw)
630
+ _x2 = tl.load(p_x2 + _idx_x, mask=_mask_hw)
631
+ tl.store(p_y1 + _idx_y, _x1, mask=_mask_hw)
632
+ tl.store(p_y2 + _idx_y, _x2, mask=_mask_hw)
633
+ else:
634
+ for idxc in range(_for_C):
635
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
636
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
637
+ _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)
638
+ _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)
639
+ tl.store(p_x1 + _idx_x, _y1, mask=_mask_hw)
640
+ tl.store(p_x2 + _idx_x, _y2, mask=_mask_hw)
641
+
642
+ class CrossScanTritonFk2(torch.autograd.Function):
643
+ @staticmethod
644
+ def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2):
645
+ if one_by_one:
646
+ if in_channel_first:
647
+ B, _, C, H, W = x.shape
648
+ else:
649
+ B, H, W, _, C = x.shape
650
+ else:
651
+ if in_channel_first:
652
+ B, C, H, W = x.shape
653
+ else:
654
+ B, H, W, C = x.shape
655
+ B, C, H, W = int(B), int(C), int(H), int(W)
656
+ BC, BH, BW = 1, 32, 32
657
+ NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
658
+
659
+ ctx.in_channel_first = in_channel_first
660
+ ctx.out_channel_first = out_channel_first
661
+ ctx.one_by_one = one_by_one
662
+ ctx.scans = scans
663
+ ctx.shape = (B, C, H, W)
664
+ ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
665
+
666
+ y = x.new_empty((B, 2, C, H * W)) if out_channel_first else x.new_empty((B, H * W, 2, C))
667
+ triton_cross_scan_flex_k2[(NH * NW, NC, B)](
668
+ x.contiguous(), y,
669
+ (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
670
+ BC, BH, BW, C, H, W, NH, NW
671
+ )
672
+ return y
673
+
674
+ @staticmethod
675
+ def backward(ctx, y: torch.Tensor):
676
+ in_channel_first = ctx.in_channel_first
677
+ out_channel_first = ctx.out_channel_first
678
+ one_by_one = ctx.one_by_one
679
+ scans = ctx.scans
680
+ B, C, H, W = ctx.shape
681
+ BC, BH, BW, NC, NH, NW = ctx.triton_shape
682
+ if one_by_one:
683
+ x = y.new_empty((B, 2, C, H, W)) if in_channel_first else y.new_empty((B, H, W, 2, C))
684
+ else:
685
+ x = y.new_empty((B, C, H, W)) if in_channel_first else y.new_empty((B, H, W, C))
686
+
687
+ triton_cross_scan_flex_k2[(NH * NW, NC, B)](
688
+ x, y.contiguous(),
689
+ (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
690
+ BC, BH, BW, C, H, W, NH, NW
691
+ )
692
+ return x, None, None, None, None
693
+
694
+
695
+ class CrossMergeTritonFk2(torch.autograd.Function):
696
+ @staticmethod
697
+ def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2):
698
+ if out_channel_first:
699
+ B, _, C, H, W = y.shape
700
+ else:
701
+ B, H, W, _, C = y.shape
702
+ B, C, H, W = int(B), int(C), int(H), int(W)
703
+ BC, BH, BW = 1, 32, 32
704
+ NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
705
+ ctx.in_channel_first = in_channel_first
706
+ ctx.out_channel_first = out_channel_first
707
+ ctx.one_by_one = one_by_one
708
+ ctx.scans = scans
709
+ ctx.shape = (B, C, H, W)
710
+ ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
711
+ if one_by_one:
712
+ x = y.new_empty((B, 2, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 2, C))
713
+ else:
714
+ x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C))
715
+ triton_cross_scan_flex_k2[(NH * NW, NC, B)](
716
+ x, y.contiguous(),
717
+ (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
718
+ BC, BH, BW, C, H, W, NH, NW
719
+ )
720
+ return x
721
+
722
+ @staticmethod
723
+ def backward(ctx, x: torch.Tensor):
724
+ in_channel_first = ctx.in_channel_first
725
+ out_channel_first = ctx.out_channel_first
726
+ one_by_one = ctx.one_by_one
727
+ scans = ctx.scans
728
+ B, C, H, W = ctx.shape
729
+ BC, BH, BW, NC, NH, NW = ctx.triton_shape
730
+ y = x.new_empty((B, 2, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 2, C))
731
+ triton_cross_scan_flex_k2[(NH * NW, NC, B)](
732
+ x.contiguous(), y,
733
+ (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
734
+ BC, BH, BW, C, H, W, NH, NW
735
+ )
736
+ return y, None, None, None, None
737
+
738
+
739
+ # @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
740
+ def cross_scan_fn_k2(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2, force_torch=False):
741
+ # x: (B, C, H, W) | (B, H, W, C) | (B, 2, C, H, W) | (B, H, W, 2, C)
742
+ # y: (B, 2, C, L) | (B, L, 2, C)
743
+ # scans: 0: cross scan; 1: unidirectional; 2: bidirectional;
744
+ CSF = CrossScanTritonFk2 if WITH_TRITON and x.is_cuda and (not force_torch) else CrossScanF
745
+ return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans)
746
+
747
+ # @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
748
+ def cross_merge_fn_k2(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2, force_torch=False):
749
+ # y: (B, 2, C, L) | (B, L, 2, C)
750
+ # x: (B, C, H * W) | (B, H * W, C) | (B, 2, C, H * W) | (B, H * W, 2, C)
751
+ # scans: 0: cross scan; 1: unidirectional; 2: bidirectional;
752
+ CMF = CrossMergeTritonFk2 if WITH_TRITON and y.is_cuda and (not force_torch) else CrossMergeF
753
+ return CMF.apply(y, in_channel_first, out_channel_first, one_by_one, scans)
754
+
755
+ def cross_scan_fn_k2_torch(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2, force_torch=False):
756
+ cross_scan = CrossScan(in_channel_first, out_channel_first, one_by_one, scans)
757
+ return cross_scan(x)
758
+
759
+ def cross_merge_fn_k2_torch(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2, force_torch=False):
760
+ cross_merge = CrossMerge(in_channel_first, out_channel_first, one_by_one, scans)
761
+ return cross_merge(y)
762
+
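+ # Usage sketch (illustrative comment, not part of the original file): with the default
+ # channel-first layouts, the k2 helpers expand (B, C, H, W) into a 2-direction sequence
+ # tensor and merge it back, following the shape comments above.
+ #   x  = torch.randn(2, 96, 56, 56, device="cuda")
+ #   xs = cross_scan_fn_k2(x)                            # (2, 2, 96, 56 * 56)
+ #   x2 = cross_merge_fn_k2(xs.view(2, 2, 96, 56, 56))   # (2, 96, 56 * 56)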
763
+ # checks =================================================================
764
+
765
+ class CHECK:
766
+ def check_csm_triton():
767
+ B, C, H, W = 2, 192, 56, 57
768
+ dtype=torch.float16
769
+ dtype=torch.float32
770
+ x = torch.randn((B, C, H, W), dtype=dtype, device=torch.device("cuda")).requires_grad_(True)
771
+ y = torch.randn((B, 2, C, H, W), dtype=dtype, device=torch.device("cuda")).requires_grad_(True)
772
+ x1 = x.clone().detach().requires_grad_(True)
773
+ y1 = y.clone().detach().requires_grad_(True)
774
+
775
+ def cross_scan(x: torch.Tensor):
776
+ B, C, H, W = x.shape
777
+ L = H * W
778
+ xs = torch.stack([
779
+ x.view(B, C, L),
780
+ torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L),
781
+ torch.flip(x.contiguous().view(B, C, L), dims=[-1]),
782
+ torch.flip(torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]),
783
+ ], dim=1).view(B, 4, C, L)
784
+ return xs
785
+
786
+ def cross_merge(out_y: torch.Tensor):
787
+ B, K, D, H, W = out_y.shape
788
+ L = H * W
789
+ out_y = out_y.view(B, K, D, L)
790
+ inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
791
+ wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
792
+ invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
793
+ y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
794
+ return y
795
+
796
+ def cross_scan_1b1(x: torch.Tensor):
797
+ B, K, C, H, W = x.shape
798
+ L = H * W
799
+ xs = torch.stack([
800
+ x[:, 0].view(B, C, L),
801
+ torch.transpose(x[:, 1], dim0=2, dim1=3).contiguous().view(B, C, L),
802
+ torch.flip(x[:, 2].contiguous().view(B, C, L), dims=[-1]),
803
+ torch.flip(torch.transpose(x[:, 3], dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]),
804
+ ], dim=1).view(B, 2, C, L)
805
+ return xs
806
+
807
+ def unidi_scan(x):
808
+ B, C, H, W = x.shape
809
+ x = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
810
+ return x
811
+
812
+ def unidi_merge(ys):
813
+ B, K, C, H, W = ys.shape
814
+ return ys.view(B, 4, -1, H * W).sum(1)
815
+
816
+ def bidi_scan(x):
817
+ B, C, H, W = x.shape
818
+ x = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
819
+ x = torch.cat([x, x.flip(dims=[-1])], dim=1)
820
+ return x
821
+
822
+ def bidi_merge(ys):
823
+ B, K, D, H, W = ys.shape
824
+ ys = ys.view(B, K, D, -1)
825
+ ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
826
+ return ys.contiguous().sum(1)
827
+
828
+ if True:
829
+ # res0 = triton.testing.do_bench(lambda :cross_scan(x))
830
+ res1 = triton.testing.do_bench(lambda :cross_scan_fn_k2(x, True, True, False))
831
+ # res2 = triton.testing.do_bench(lambda :CrossScanTriton.apply(x))
832
+ # res3 = triton.testing.do_bench(lambda :cross_merge(y))
833
+ res4 = triton.testing.do_bench(lambda :cross_merge_fn_k2(y, True, True, False))
834
+ # res5 = triton.testing.do_bench(lambda :CrossMergeTriton.apply(y))
835
+ # print(res0, res1, res2, res3, res4, res5)
836
+ print(res1, res4)
837
+ res0 = triton.testing.do_bench(lambda :cross_scan(x).sum().backward())
838
+ res1 = triton.testing.do_bench(lambda :cross_scan_fn_k2(x, True, True, False).sum().backward())
839
+ # res2 = triton.testing.do_bench(lambda :CrossScanTriton.apply(x).sum().backward())
840
+ res3 = triton.testing.do_bench(lambda :cross_merge(y).sum().backward())
841
+ res4 = triton.testing.do_bench(lambda :cross_merge_fn_k2(y, True, True, False).sum().backward())
842
+ # res5 = triton.testing.do_bench(lambda :CrossMergeTriton.apply(y).sum().backward())
843
+ # print(res0, res1, res2, res3, res4, res5)
844
+ print(res0, res1, res3, res4)
845
+
846
+ print("test cross scan")
847
+ for (cs0, cm0, cs1, cm1) in [
848
+ # channel_first -> channel_first
849
+ (cross_scan, cross_merge, cross_scan_fn_k2, cross_merge_fn_k2),
850
+ (unidi_scan, unidi_merge, lambda x: cross_scan_fn_k2(x, scans=1), lambda x: cross_merge_fn_k2(x, scans=1)),
851
+ (bidi_scan, bidi_merge, lambda x: cross_scan_fn_k2(x, scans=2), lambda x: cross_merge_fn_k2(x, scans=2)),
852
+
853
+ # flex: BLC->BCL; BCL->BLC; BLC->BLC;
854
+ (cross_scan, cross_merge, lambda x: cross_scan_fn_k2(x.permute(0, 2, 3, 1), in_channel_first=False), lambda x: cross_merge_fn_k2(x, in_channel_first=False).permute(0, 2, 1)),
855
+ (cross_scan, cross_merge, lambda x: cross_scan_fn_k2(x, out_channel_first=False).permute(0, 2, 3, 1), lambda x: cross_merge_fn_k2(x.permute(0, 3, 4, 1, 2), out_channel_first=False)),
856
+ (cross_scan, cross_merge, lambda x: cross_scan_fn_k2(x.permute(0, 2, 3, 1), in_channel_first=False, out_channel_first=False).permute(0, 2, 3, 1), lambda x: cross_merge_fn_k2(x.permute(0, 3, 4, 1, 2), in_channel_first=False, out_channel_first=False).permute(0, 2, 1)),
857
+
858
+ # previous
859
+ # (cross_scan, cross_merge, lambda x: CrossScanTriton.apply(x), lambda x: CrossMergeTriton.apply(x)),
860
+ # (unidi_scan, unidi_merge, lambda x: getCSM(1)[0].apply(x), lambda x: getCSM(1)[1].apply(x)),
861
+ # (bidi_scan, bidi_merge, lambda x: getCSM(2)[0].apply(x), lambda x: getCSM(2)[1].apply(x)),
862
+ ]:
863
+ x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
864
+ o0 = cs0(x)
865
+ o1 = cs1(x1)
866
+ o0.backward(y.view(B, 2, C, H * W))
867
+ o1.backward(y.view(B, 2, C, H * W))
868
+ print((o0 - o1).abs().max())
869
+ print((x.grad - x1.grad).abs().max())
870
+ o0 = cm0(y)
871
+ o1 = cm1(y1)
872
+ o0.backward(x.view(B, C, H * W))
873
+ o1.backward(x.view(B, C, H * W))
874
+ print((o0 - o1).abs().max())
875
+ print((y.grad - y1.grad).abs().max())
876
+ x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
877
+ print("===============", flush=True)
878
+
879
+ print("test cross scan one by one")
880
+ for (cs0, cs1) in [
881
+ (cross_scan_1b1, lambda x: cross_scan_fn_k2(x, one_by_one=True)),
882
+ # (cross_scan_1b1, lambda x: CrossScanTriton1b1.apply(x)),
883
+ ]:
884
+ o0 = cs0(y)
885
+ o1 = cs1(y1)
886
+ o0.backward(y.view(B, 2, C, H * W))
887
+ o1.backward(y.view(B, 2, C, H * W))
888
+ print((o0 - o1).abs().max())
889
+ print((y.grad - y1.grad).abs().max())
890
+ x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
891
+ print("===============", flush=True)
892
+
893
+
894
+ if __name__ == "__main__":
895
+ CHECK.check_csm_triton()
896
+
897
+
898
+
899
+
rscd/models/backbones/lib_mamba/csms6s.py ADDED
@@ -0,0 +1,266 @@
1
+ import time
2
+ import torch
3
+ import warnings
4
+
5
+
6
+ WITH_SELECTIVESCAN_OFLEX = True
7
+ WITH_SELECTIVESCAN_CORE = False
8
+ WITH_SELECTIVESCAN_MAMBA = True
9
+ try:
10
+ import selective_scan_cuda_oflex
11
+ except ImportError:
12
+ WITH_SELECTIVESCAN_OFLEX = False
13
+ warnings.warn("Cannot import selective_scan_cuda_oflex. This affects speed.")
14
+ print("Cannot import selective_scan_cuda_oflex. This affects speed.", flush=True)
15
+ try:
16
+ import selective_scan_cuda_core
17
+ except ImportError:
18
+ WITH_SELECTIVESCAN_CORE = False
19
+ try:
20
+ import selective_scan_cuda
21
+ except ImportError:
22
+ WITH_SELECTIVESCAN_MAMBA = False
23
+
24
+
25
+ def selective_scan_torch(
26
+ u: torch.Tensor, # (B, K * C, L)
27
+ delta: torch.Tensor, # (B, K * C, L)
28
+ A: torch.Tensor, # (K * C, N)
29
+ B: torch.Tensor, # (B, K, N, L)
30
+ C: torch.Tensor, # (B, K, N, L)
31
+ D: torch.Tensor = None, # (K * C)
32
+ delta_bias: torch.Tensor = None, # (K * C)
33
+ delta_softplus=True,
34
+ oflex=True,
35
+ *args,
36
+ **kwargs
37
+ ):
38
+ dtype_in = u.dtype
39
+ Batch, K, N, L = B.shape
40
+ KCdim = u.shape[1]
41
+ Cdim = int(KCdim / K)
42
+ assert u.shape == (Batch, KCdim, L)
43
+ assert delta.shape == (Batch, KCdim, L)
44
+ assert A.shape == (KCdim, N)
45
+ assert C.shape == B.shape
46
+
47
+ if delta_bias is not None:
48
+ delta = delta + delta_bias[..., None]
49
+ if delta_softplus:
50
+ delta = torch.nn.functional.softplus(delta)
51
+
52
+ u, delta, A, B, C = u.float(), delta.float(), A.float(), B.float(), C.float()
53
+ B = B.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L)
54
+ C = C.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L)
55
+ deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
56
+ deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
57
+
58
+ if True:
59
+ x = A.new_zeros((Batch, KCdim, N))
60
+ ys = []
61
+ for i in range(L):
62
+ x = deltaA[:, :, i, :] * x + deltaB_u[:, :, i, :]
63
+ y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
64
+ ys.append(y)
65
+ y = torch.stack(ys, dim=2) # (B, C, L)
66
+
67
+ out = y if D is None else y + u * D.unsqueeze(-1)
68
+ return out if oflex else out.to(dtype=dtype_in)
69
+
70
+
71
+ class SelectiveScanCuda(torch.autograd.Function):
72
+ @staticmethod
73
+ @torch.cuda.amp.custom_fwd
74
+ def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, oflex=True, backend=None):
75
+ ctx.delta_softplus = delta_softplus
76
+ backend = "oflex" if WITH_SELECTIVESCAN_OFLEX and (backend is None) else backend
77
+ backend = "core" if WITH_SELECTIVESCAN_CORE and (backend is None) else backend
78
+ backend = "mamba" if WITH_SELECTIVESCAN_MAMBA and (backend is None) else backend
79
+ ctx.backend = backend
80
+ if backend == "oflex":
81
+ out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex)
82
+ elif backend == "core":
83
+ out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1)
84
+ elif backend == "mamba":
85
+ out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus)
86
+ ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
87
+ return out
88
+
89
+ @staticmethod
90
+ @torch.cuda.amp.custom_bwd
91
+ def backward(ctx, dout, *args):
92
+ u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
93
+ backend = ctx.backend
94
+ if dout.stride(-1) != 1:
95
+ dout = dout.contiguous()
96
+ if backend == "oflex":
97
+ du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd(
98
+ u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
99
+ )
100
+ elif backend == "core":
101
+ du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd(
102
+ u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
103
+ )
104
+ elif backend == "mamba":
105
+ du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
106
+ u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus,
107
+ False
108
+ )
109
+ return du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None
110
+
111
+
112
+ def selective_scan_fn(
113
+ u: torch.Tensor, # (B, K * C, L)
114
+ delta: torch.Tensor, # (B, K * C, L)
115
+ A: torch.Tensor, # (K * C, N)
116
+ B: torch.Tensor, # (B, K, N, L)
117
+ C: torch.Tensor, # (B, K, N, L)
118
+ D: torch.Tensor = None, # (K * C)
119
+ delta_bias: torch.Tensor = None, # (K * C)
120
+ delta_softplus=True,
121
+ oflex=True,
122
+ backend=None,
123
+ ):
124
+ WITH_CUDA = (WITH_SELECTIVESCAN_OFLEX or WITH_SELECTIVESCAN_CORE or WITH_SELECTIVESCAN_MAMBA)
125
+ fn = selective_scan_torch if backend == "torch" or (not WITH_CUDA) else SelectiveScanCuda.apply
126
+ return fn(u, delta, A, B, C, D, delta_bias, delta_softplus, oflex, backend)
127
+
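+ # Usage sketch (illustrative comment; tensor shapes follow the annotations above): the
+ # pure-torch backend needs no compiled extension, which makes it handy for quick checks.
+ #   B_, K_, C_, N_, L_ = 1, 2, 8, 4, 32
+ #   u = torch.randn(B_, K_ * C_, L_)
+ #   delta = torch.rand(B_, K_ * C_, L_)
+ #   A = -torch.rand(K_ * C_, N_)
+ #   Bs = torch.randn(B_, K_, N_, L_)
+ #   Cs = torch.randn(B_, K_, N_, L_)
+ #   out = selective_scan_fn(u, delta, A, Bs, Cs, backend="torch")  # (B_, K_ * C_, L_)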
128
+
129
+ # fvcore flops =======================================
130
+ def print_jit_input_names(inputs):
131
+ print("input params: ", end=" ", flush=True)
132
+ try:
133
+ for i in range(10):
134
+ print(inputs[i].debugName(), end=" ", flush=True)
135
+ except Exception as e:
136
+ pass
137
+ print("", flush=True)
138
+
139
+ def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_complex=False):
140
+ """
141
+ u: r(B D L)
142
+ delta: r(B D L)
143
+ A: r(D N)
144
+ B: r(B N L)
145
+ C: r(B N L)
146
+ D: r(D)
147
+ z: r(B D L)
148
+ delta_bias: r(D), fp32
149
+
150
+ ignores:
151
+ [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
152
+ """
153
+ assert not with_complex
154
+ # https://github.com/state-spaces/mamba/issues/110
155
+ flops = 9 * B * L * D * N
156
+ if with_D:
157
+ flops += B * D * L
158
+ if with_Z:
159
+ flops += B * D * L
160
+ return flops
161
+
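+ # Worked example (illustrative comment): for B=1, L=56*56=3136, D=96, N=16 the estimate
+ # is 9 * 1 * 3136 * 96 * 16 = 43,352,064 FLOPs for the scan itself, plus
+ # B * D * L = 301,056 FLOPs when the D (skip-connection) term is counted.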
162
+ # this is only for selective_scan_ref...
163
+ def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
164
+ """
165
+ u: r(B D L)
166
+ delta: r(B D L)
167
+ A: r(D N)
168
+ B: r(B N L)
169
+ C: r(B N L)
170
+ D: r(D)
171
+ z: r(B D L)
172
+ delta_bias: r(D), fp32
173
+
174
+ ignores:
175
+ [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
176
+ """
177
+ import numpy as np
178
+
179
+ # fvcore.nn.jit_handles
180
+ def get_flops_einsum(input_shapes, equation):
181
+ np_arrs = [np.zeros(s) for s in input_shapes]
182
+ optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
183
+ for line in optim.split("\n"):
184
+ if "optimized flop" in line.lower():
185
+ # divided by 2 because we count MAC (multiply-add counted as one flop)
186
+ flop = float(np.floor(float(line.split(":")[-1]) / 2))
187
+ return flop
188
+
189
+
190
+ assert not with_complex
191
+
192
+ flops = 0 # below code flops = 0
193
+
194
+ flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
195
+ if with_Group:
196
+ flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
197
+ else:
198
+ flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")
199
+
200
+ in_for_flops = B * D * N
201
+ if with_Group:
202
+ in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
203
+ else:
204
+ in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
205
+ flops += L * in_for_flops
206
+ if with_D:
207
+ flops += B * D * L
208
+ if with_Z:
209
+ flops += B * D * L
210
+ return flops
211
+
212
+ def selective_scan_flop_jit(inputs, outputs, backend="prefixsum", verbose=True):
213
+ if verbose:
214
+ print_jit_input_names(inputs)
215
+ flops_fn = flops_selective_scan_ref if backend == "naive" else flops_selective_scan_fn
216
+ B, D, L = inputs[0].type().sizes()
217
+ N = inputs[2].type().sizes()[1]
218
+ flops = flops_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False)
219
+ return flops
220
+
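+ # Usage sketch (illustrative comment; the op name and fvcore call are assumptions, not
+ # taken from this file): selective_scan_flop_jit is meant to be registered as a custom
+ # op handler for fvcore's flop counter, e.g.
+ #   from fvcore.nn import flop_count
+ #   supported_ops = {"prim::PythonOp.SelectiveScanCuda": selective_scan_flop_jit}
+ #   gflops, unsupported = flop_count(model, (inp,), supported_ops=supported_ops)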
221
+
222
+ if __name__ == "__main__":
223
+ def params(B, K, C, N, L, device = torch.device("cuda"), itype = torch.float):
224
+ As = (-0.5 * torch.rand(K * C, N, device=device, dtype=torch.float32)).requires_grad_()
225
+ Bs = torch.randn((B, K, N, L), device=device, dtype=itype).requires_grad_()
226
+ Cs = torch.randn((B, K, N, L), device=device, dtype=itype).requires_grad_()
227
+ Ds = torch.randn((K * C), device=device, dtype=torch.float32).requires_grad_()
228
+ u = torch.randn((B, K * C, L), device=device, dtype=itype).requires_grad_()
229
+ delta = (0.5 * torch.rand((B, K * C, L), device=device, dtype=itype)).requires_grad_()
230
+ delta_bias = (0.5 * torch.rand((K * C), device=device, dtype=torch.float32)).requires_grad_()
231
+ return u, delta, As, Bs, Cs, Ds, delta_bias
232
+
233
+ def bench(func, xs, Warmup=30, NTimes=20):
234
+ import time
235
+ torch.cuda.synchronize()
236
+ for r in range(Warmup):
237
+ for x in xs:
238
+ func(x)
239
+ torch.cuda.synchronize()
240
+ tim0 = time.time()
241
+ for r in range(NTimes):
242
+ for x in xs:
243
+ func(x)
244
+ torch.cuda.synchronize()
245
+ return (time.time() - tim0) / NTimes
246
+
247
+ def check():
248
+ u, delta, As, Bs, Cs, Ds, delta_bias = params(1, 4, 16, 8, 512, itype=torch.float16)
249
+ u1, delta1, As1, Bs1, Cs1, Ds1, delta_bias1 = [x.clone().detach().requires_grad_() for x in [u, delta, As, Bs, Cs, Ds, delta_bias]]
250
+
251
+ # out_ref = selective_scan_fn(u, delta, As, Bs, Cs, Ds, delta_bias, True, backend="torch")
252
+ out = selective_scan_fn(u1, delta1, As1, Bs1, Cs1, Ds1, delta_bias1, True, backend="oflex")
253
+ out_ref = selective_scan_fn(u, delta, As, Bs, Cs, Ds, delta_bias, True, backend="mamba")
254
+ print((out_ref - out).abs().max())
255
+ out.sum().backward()
256
+ out_ref.sum().backward()
257
+ for x, y in zip([u, As, Bs, Cs, Ds, delta, delta_bias], [u1, As1, Bs1, Cs1, Ds1, delta1, delta_bias1]):
258
+ print((x.grad - y.grad).abs().max())
259
+
260
+ u, delta, As, Bs, Cs, Ds, delta_bias = params(128, 4, 96, 8, 56 * 56)
261
+ print(bench(lambda x: selective_scan_fn(x[0], x[1], x[2], x[3], x[4], x[5], x[6], True, backend="oflex"), [(u, delta, As, Bs, Cs, Ds, delta_bias),]))
262
+ print(bench(lambda x: selective_scan_fn(x[0], x[1], x[2], x[3], x[4], x[5], x[6], True, backend="mamba"), [(u, delta, As, Bs, Cs, Ds, delta_bias),]))
263
+ print(bench(lambda x: selective_scan_fn(x[0], x[1], x[2], x[3], x[4], x[5], x[6], True, backend="torch"), [(u, delta, As, Bs, Cs, Ds, delta_bias),]))
264
+
265
+ check()
266
+
rscd/models/backbones/lib_mamba/kernels/selective_scan/README.md ADDED
@@ -0,0 +1,97 @@
1
+ # mamba-mini
2
+ An efficient implementation of selective scan in one file that works on both CPU and GPU, with the corresponding mathematical derivation. It is probably the code closest to selective_scan_cuda in mamba.
3
+
4
+ ### mathematical derivation
5
+ ![image](../assets/derivation.png)
6
+
7
+ ### code
8
+ ```python
9
+ import torch
10
+ def selective_scan_easy(us, dts, As, Bs, Cs, Ds, delta_bias=None, delta_softplus=False, return_last_state=False, chunksize=64):
11
+ """
12
+ # B: batch_size, G: groups, D: dim, N: state dim, L: seqlen
13
+ us: B, G * D, L
14
+ dts: B, G * D, L
15
+ As: G * D, N
16
+ Bs: B, G, N, L
17
+ Cs: B, G, N, L
18
+ Ds: G * D
19
+ delta_bias: G * D
20
+ # chunksize can be any value you like, but as chunksize grows, hs may become NaN, since exp(sum(delta) A) becomes extremely small
21
+ """
22
+ def selective_scan_chunk(us, dts, As, Bs, Cs, hprefix):
23
+ """
24
+ partial(h) / partial(t) = Ah + Bu; y = Ch + Du;
25
+ => partial(h*exp(-At)) / partial(t) = Bu*exp(-At);
26
+ => h_t = h_0 + sum_{0}_{t}_{Bu*exp(A(t-v)) dv};
27
+ => h_b = exp(A(dt_a + ... + dt_{b-1})) * (h_a + sum_{a}_{b-1}_{Bu*exp(-A(dt_a + ... + dt_i)) dt_i});
28
+ y_i = C_i*h_i + D*u_i
29
+ """
30
+ """
31
+ us, dts: (L, B, G, D) # L is chunk_size
32
+ As: (G, D, N)
33
+ Bs, Cs: (L, B, G, N)
34
+ Ds: (G, D)
35
+ hprefix: (B, G, D, N)
36
+ """
37
+ ts = dts.cumsum(dim=0)
38
+ Ats = torch.einsum("gdn,lbgd->lbgdn", As, ts).exp()
39
+ scale = Ats[-1].detach()
40
+ rAts = Ats / scale
41
+ duts = dts * us
42
+ dtBus = torch.einsum("lbgd,lbgn->lbgdn", duts, Bs)
43
+ hs_tmp = rAts * (dtBus / rAts).cumsum(dim=0)
44
+ hs = hs_tmp + Ats * hprefix.unsqueeze(0)
45
+ ys = torch.einsum("lbgn,lbgdn->lbgd", Cs, hs)
46
+ return ys, hs
47
+
48
+ inp_dtype = us.dtype
49
+ has_D = Ds is not None
50
+
51
+ dts = dts.float()
52
+ if delta_bias is not None:
53
+ dts = dts + delta_bias.view(1, -1, 1).float()
54
+ if delta_softplus:
55
+ dts = torch.nn.functional.softplus(dts)
56
+
57
+ if len(Bs.shape) == 3:
58
+ Bs = Bs.unsqueeze(1)
59
+ if len(Cs.shape) == 3:
60
+ Cs = Cs.unsqueeze(1)
61
+ B, G, N, L = Bs.shape
62
+ us = us.view(B, G, -1, L).permute(3, 0, 1, 2).float()
63
+ dts = dts.view(B, G, -1, L).permute(3, 0, 1, 2).float()
64
+ As = As.view(G, -1, N).float()
65
+ Bs = Bs.permute(3, 0, 1, 2).float()
66
+ Cs = Cs.permute(3, 0, 1, 2).float()
67
+ Ds = Ds.view(G, -1).float() if has_D else None
68
+ D = As.shape[1]
69
+
70
+ oys = []
71
+ # ohs = []
72
+ hprefix = us.new_zeros((B, G, D, N), dtype=torch.float)
73
+ for i in range(0, L - 1, chunksize):
74
+ ys, hs = selective_scan_chunk(
75
+ us[i:i + chunksize], dts[i:i + chunksize],
76
+ As, Bs[i:i + chunksize], Cs[i:i + chunksize], hprefix,
77
+ )
78
+ oys.append(ys)
79
+ # ohs.append(hs)
80
+ hprefix = hs[-1]
81
+
82
+ oys = torch.cat(oys, dim=0)
83
+ # ohs = torch.cat(ohs, dim=0)
84
+ if has_D:
85
+ oys = oys + Ds * us
86
+ oys = oys.permute(1, 2, 3, 0).view(B, -1, L)
87
+ oys = oys.to(inp_dtype)
88
+ # hprefix = hprefix.to(inp_dtype)
89
+
90
+ return oys if not return_last_state else (oys, hprefix.view(B, G * D, N))
91
+
92
+ ```
93
+
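+ ### usage sketch
+ A minimal, illustrative example with random tensors (shapes as documented in the docstring above), assuming `selective_scan_easy` from the code block is in scope:
+ ```python
+ import torch
+ B, G, D, N, L = 2, 1, 16, 8, 128
+ us  = torch.randn(B, G * D, L)
+ dts = torch.rand(B, G * D, L)
+ As  = -torch.rand(G * D, N)
+ Bs  = torch.randn(B, G, N, L)
+ Cs  = torch.randn(B, G, N, L)
+ Ds  = torch.randn(G * D)
+ ys = selective_scan_easy(us, dts, As, Bs, Cs, Ds, delta_softplus=True, chunksize=64)
+ print(ys.shape)  # torch.Size([2, 16, 128])
+ ```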
94
+ ### to test
95
+ ```bash
96
+ pytest test_selective_scan.py
97
+ ```
rscd/models/backbones/lib_mamba/kernels/selective_scan/build/lib.linux-x86_64-3.8/selective_scan_cuda_oflex.cpython-38-x86_64-linux-gnu.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3fcf524eeaf71e641653c1aff1f8fac591a1e6916300d322830bd02476873ab1
3
+ size 34969816
rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/.ninja_deps ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e61c23bbef4f0b8f414187d9149e6e0818ce400c3351405d494b774b988bf6d
3
+ size 501136
rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/.ninja_log ADDED
@@ -0,0 +1,4 @@
1
+ # ninja log v5
2
+ 7 17272 1748140810026258300 /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_fwd.o ab3bac6bd7b8268f
3
+ 8 23832 1748140816810431800 /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_oflex.o 7f9a77b388057fc6
4
+ 7 57431 1748140850419474900 /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_bwd.o 3cffffbdd6b9fec1
rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/build.ninja ADDED
@@ -0,0 +1,35 @@
1
+ ninja_required_version = 1.3
2
+ cxx = c++
3
+ nvcc = /usr/local/cuda-11.8/bin/nvcc
4
+
5
+ cflags = -pthread -B /root/anaconda3/envs/rscd/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include/TH -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda-11.8/include -I/root/anaconda3/envs/rscd/include/python3.8 -c
6
+ post_cflags = -O3 -std=c++17 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=selective_scan_cuda_oflex -D_GLIBCXX_USE_CXX11_ABI=0
7
+ cuda_cflags = -I/mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include/TH -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda-11.8/include -I/root/anaconda3/envs/rscd/include/python3.8 -c
8
+ cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++17 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_BFLOAT16_OPERATORS__ -U__CUDA_NO_BFLOAT16_CONVERSIONS__ -U__CUDA_NO_BFLOAT162_OPERATORS__ -U__CUDA_NO_BFLOAT162_CONVERSIONS__ --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math --ptxas-options=-v -lineinfo -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_90,code=sm_90 --threads 4 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=selective_scan_cuda_oflex -D_GLIBCXX_USE_CXX11_ABI=0
9
+ cuda_dlink_post_cflags =
10
+ ldflags =
11
+
12
+ rule compile
13
+ command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
14
+ depfile = $out.d
15
+ deps = gcc
16
+
17
+ rule cuda_compile
18
+ depfile = $out.d
19
+ deps = gcc
20
+ command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags
21
+
22
+
23
+
24
+
25
+
26
+ build /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_bwd.o: cuda_compile /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_bwd.cu
27
+ build /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_fwd.o: cuda_compile /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_fwd.cu
28
+ build /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_oflex.o: compile /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_oflex.cpp
29
+
30
+
31
+
32
+
33
+
34
+
35
+
rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_bwd.o ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a18749c332876fbeddf0356bbb0d4c979fffd89df7f4a27796e2dd39b523f2e
3
+ size 12294744
rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_fwd.o ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:642208da032fd86584c184e13e9ed9d18a9c6e85d925770f7ce034e0d22774a9
3
+ size 13211880
rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_oflex.o ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0867e500c6352b2c0e1938ea0c8e6825aafbabc5699ec41d25a7793c56ed5d1e
3
+ size 14839600
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cub_extra.cuh ADDED
@@ -0,0 +1,50 @@
1
+ // WarpMask is copied from /usr/local/cuda-12.1/include/cub/util_ptx.cuh
2
+ // PowerOfTwo is copied from /usr/local/cuda-12.1/include/cub/util_type.cuh
3
+
4
+ #pragma once
5
+
6
+ #include <cub/util_type.cuh>
7
+ #include <cub/util_arch.cuh>
8
+ #include <cub/util_namespace.cuh>
9
+ #include <cub/util_debug.cuh>
10
+
11
+ /**
12
+ * \brief Statically determine if N is a power-of-two
13
+ */
14
+ template <int N>
15
+ struct PowerOfTwo
16
+ {
17
+ enum { VALUE = ((N & (N - 1)) == 0) };
18
+ };
19
+
20
+
21
+ /**
22
+ * @brief Returns the warp mask for a warp of @p LOGICAL_WARP_THREADS threads
23
+ *
24
+ * @par
25
+ * If the number of threads assigned to the virtual warp is not a power of two,
26
+ * it's assumed that only one virtual warp exists.
27
+ *
28
+ * @tparam LOGICAL_WARP_THREADS <b>[optional]</b> The number of threads per
29
+ * "logical" warp (may be less than the number of
30
+ * hardware warp threads).
31
+ * @param warp_id Id of virtual warp within architectural warp
32
+ */
33
+ template <int LOGICAL_WARP_THREADS, int LEGACY_PTX_ARCH = 0>
34
+ __host__ __device__ __forceinline__
35
+ unsigned int WarpMask(unsigned int warp_id)
36
+ {
37
+ constexpr bool is_pow_of_two = PowerOfTwo<LOGICAL_WARP_THREADS>::VALUE;
38
+ constexpr bool is_arch_warp = LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0);
39
+
40
+ unsigned int member_mask = 0xFFFFFFFFu >>
41
+ (CUB_WARP_THREADS(0) - LOGICAL_WARP_THREADS);
42
+
43
+ if (is_pow_of_two && !is_arch_warp)
44
+ {
45
+ member_mask <<= warp_id * LOGICAL_WARP_THREADS;
46
+ }
47
+
48
+ return member_mask;
49
+ }
50
+
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan.cpp ADDED
@@ -0,0 +1,354 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include <ATen/cuda/CUDAContext.h>
6
+ #include <c10/cuda/CUDAGuard.h>
7
+ #include <torch/extension.h>
8
+ #include <vector>
9
+
10
+ #include "selective_scan.h"
11
+ #define MAX_DSTATE 256
12
+
13
+ #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
14
+ using weight_t = float;
15
+
16
+ #define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
17
+ if (ITYPE == at::ScalarType::Half) { \
18
+ using input_t = at::Half; \
19
+ __VA_ARGS__(); \
20
+ } else if (ITYPE == at::ScalarType::BFloat16) { \
21
+ using input_t = at::BFloat16; \
22
+ __VA_ARGS__(); \
23
+ } else if (ITYPE == at::ScalarType::Float) { \
24
+ using input_t = float; \
25
+ __VA_ARGS__(); \
26
+ } else { \
27
+ AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
28
+ }
29
+
30
+ template<int knrows, typename input_t, typename weight_t>
31
+ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);
32
+
33
+ template <int knrows, typename input_t, typename weight_t>
34
+ void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream);
35
+
36
+ void set_ssm_params_fwd(SSMParamsBase &params,
37
+ // sizes
38
+ const size_t batch,
39
+ const size_t dim,
40
+ const size_t seqlen,
41
+ const size_t dstate,
42
+ const size_t n_groups,
43
+ const size_t n_chunks,
44
+ // device pointers
45
+ const at::Tensor u,
46
+ const at::Tensor delta,
47
+ const at::Tensor A,
48
+ const at::Tensor B,
49
+ const at::Tensor C,
50
+ const at::Tensor out,
51
+ void* D_ptr,
52
+ void* delta_bias_ptr,
53
+ void* x_ptr,
54
+ bool delta_softplus) {
55
+
56
+ // Reset the parameters
57
+ memset(&params, 0, sizeof(params));
58
+
59
+ params.batch = batch;
60
+ params.dim = dim;
61
+ params.seqlen = seqlen;
62
+ params.dstate = dstate;
63
+ params.n_groups = n_groups;
64
+ params.n_chunks = n_chunks;
65
+ params.dim_ngroups_ratio = dim / n_groups;
66
+
67
+ params.delta_softplus = delta_softplus;
68
+
69
+ // Set the pointers and strides.
70
+ params.u_ptr = u.data_ptr();
71
+ params.delta_ptr = delta.data_ptr();
72
+ params.A_ptr = A.data_ptr();
73
+ params.B_ptr = B.data_ptr();
74
+ params.C_ptr = C.data_ptr();
75
+ params.D_ptr = D_ptr;
76
+ params.delta_bias_ptr = delta_bias_ptr;
77
+ params.out_ptr = out.data_ptr();
78
+ params.x_ptr = x_ptr;
79
+
80
+ // All strides are in elements, not bytes.
81
+ params.A_d_stride = A.stride(0);
82
+ params.A_dstate_stride = A.stride(1);
83
+ params.B_batch_stride = B.stride(0);
84
+ params.B_group_stride = B.stride(1);
85
+ params.B_dstate_stride = B.stride(2);
86
+ params.C_batch_stride = C.stride(0);
87
+ params.C_group_stride = C.stride(1);
88
+ params.C_dstate_stride = C.stride(2);
89
+ params.u_batch_stride = u.stride(0);
90
+ params.u_d_stride = u.stride(1);
91
+ params.delta_batch_stride = delta.stride(0);
92
+ params.delta_d_stride = delta.stride(1);
93
+
94
+ params.out_batch_stride = out.stride(0);
95
+ params.out_d_stride = out.stride(1);
96
+ }
97
+
98
+ void set_ssm_params_bwd(SSMParamsBwd &params,
99
+ // sizes
100
+ const size_t batch,
101
+ const size_t dim,
102
+ const size_t seqlen,
103
+ const size_t dstate,
104
+ const size_t n_groups,
105
+ const size_t n_chunks,
106
+ // device pointers
107
+ const at::Tensor u,
108
+ const at::Tensor delta,
109
+ const at::Tensor A,
110
+ const at::Tensor B,
111
+ const at::Tensor C,
112
+ const at::Tensor out,
113
+ void* D_ptr,
114
+ void* delta_bias_ptr,
115
+ void* x_ptr,
116
+ const at::Tensor dout,
117
+ const at::Tensor du,
118
+ const at::Tensor ddelta,
119
+ const at::Tensor dA,
120
+ const at::Tensor dB,
121
+ const at::Tensor dC,
122
+ void* dD_ptr,
123
+ void* ddelta_bias_ptr,
124
+ bool delta_softplus) {
125
+ // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z
126
+ set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks,
127
+ u, delta, A, B, C, dout,
128
+ D_ptr, delta_bias_ptr, x_ptr, delta_softplus);
129
+
130
+ // Set the pointers and strides.
131
+ params.dout_ptr = dout.data_ptr();
132
+ params.du_ptr = du.data_ptr();
133
+ params.dA_ptr = dA.data_ptr();
134
+ params.dB_ptr = dB.data_ptr();
135
+ params.dC_ptr = dC.data_ptr();
136
+ params.dD_ptr = dD_ptr;
137
+ params.ddelta_ptr = ddelta.data_ptr();
138
+ params.ddelta_bias_ptr = ddelta_bias_ptr;
139
+ // All strides are in elements, not bytes.
140
+ params.dout_batch_stride = dout.stride(0);
141
+ params.dout_d_stride = dout.stride(1);
142
+ params.dA_d_stride = dA.stride(0);
143
+ params.dA_dstate_stride = dA.stride(1);
144
+ params.dB_batch_stride = dB.stride(0);
145
+ params.dB_group_stride = dB.stride(1);
146
+ params.dB_dstate_stride = dB.stride(2);
147
+ params.dC_batch_stride = dC.stride(0);
148
+ params.dC_group_stride = dC.stride(1);
149
+ params.dC_dstate_stride = dC.stride(2);
150
+ params.du_batch_stride = du.stride(0);
151
+ params.du_d_stride = du.stride(1);
152
+ params.ddelta_batch_stride = ddelta.stride(0);
153
+ params.ddelta_d_stride = ddelta.stride(1);
154
+
155
+ }
156
+
157
+ std::vector<at::Tensor>
158
+ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
159
+ const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
160
+ const c10::optional<at::Tensor> &D_,
161
+ const c10::optional<at::Tensor> &delta_bias_,
162
+ bool delta_softplus,
163
+ int nrows
164
+ ) {
165
+ auto input_type = u.scalar_type();
166
+ auto weight_type = A.scalar_type();
167
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
168
+ TORCH_CHECK(weight_type == at::ScalarType::Float);
169
+
170
+ TORCH_CHECK(delta.scalar_type() == input_type);
171
+ TORCH_CHECK(B.scalar_type() == input_type);
172
+ TORCH_CHECK(C.scalar_type() == input_type);
173
+
174
+ TORCH_CHECK(u.is_cuda());
175
+ TORCH_CHECK(delta.is_cuda());
176
+ TORCH_CHECK(A.is_cuda());
177
+ TORCH_CHECK(B.is_cuda());
178
+ TORCH_CHECK(C.is_cuda());
179
+
180
+ TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
181
+ TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
182
+
183
+ const auto sizes = u.sizes();
184
+ const int batch_size = sizes[0];
185
+ const int dim = sizes[1];
186
+ const int seqlen = sizes[2];
187
+ const int dstate = A.size(1);
188
+ const int n_groups = B.size(1);
189
+
190
+ TORCH_CHECK(dim % n_groups == 0, "dims should be divisible by n_groups");
191
+ TORCH_CHECK(dstate <= MAX_DSTATE, "selective_scan only supports state dimension <= 256");
192
+
193
+ CHECK_SHAPE(u, batch_size, dim, seqlen);
194
+ CHECK_SHAPE(delta, batch_size, dim, seqlen);
195
+ CHECK_SHAPE(A, dim, dstate);
196
+ CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
197
+ TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
198
+ CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
199
+ TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
200
+
201
+ if (D_.has_value()) {
202
+ auto D = D_.value();
203
+ TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
204
+ TORCH_CHECK(D.is_cuda());
205
+ TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
206
+ CHECK_SHAPE(D, dim);
207
+ }
208
+
209
+ if (delta_bias_.has_value()) {
210
+ auto delta_bias = delta_bias_.value();
211
+ TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
212
+ TORCH_CHECK(delta_bias.is_cuda());
213
+ TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
214
+ CHECK_SHAPE(delta_bias, dim);
215
+ }
216
+
217
+ const int n_chunks = (seqlen + 2048 - 1) / 2048; // max is 128 * 16 = 2048 in fwd_kernel
218
+ at::Tensor out = torch::empty_like(delta);
219
+ at::Tensor x;
220
+ x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));
221
+
222
+ SSMParamsBase params;
223
+ set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks,
224
+ u, delta, A, B, C, out,
225
+ D_.has_value() ? D_.value().data_ptr() : nullptr,
226
+ delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
227
+ x.data_ptr(),
228
+ delta_softplus);
229
+
230
+ // Otherwise the kernel will be launched from cuda:0 device
231
+ // Cast to char to avoid compiler warning about narrowing
232
+ at::cuda::CUDAGuard device_guard{(char)u.get_device()};
233
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
234
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
235
+ selective_scan_fwd_cuda<1, input_t, weight_t>(params, stream);
236
+ });
237
+ std::vector<at::Tensor> result = {out, x};
238
+ return result;
239
+ }
240
+
241
+ std::vector<at::Tensor>
242
+ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
243
+ const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
244
+ const c10::optional<at::Tensor> &D_,
245
+ const c10::optional<at::Tensor> &delta_bias_,
246
+ const at::Tensor &dout,
247
+ const c10::optional<at::Tensor> &x_,
248
+ bool delta_softplus,
249
+ int nrows
250
+ ) {
251
+ auto input_type = u.scalar_type();
252
+ auto weight_type = A.scalar_type();
253
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
254
+ TORCH_CHECK(weight_type == at::ScalarType::Float);
255
+
256
+ TORCH_CHECK(delta.scalar_type() == input_type);
257
+ TORCH_CHECK(B.scalar_type() == input_type);
258
+ TORCH_CHECK(C.scalar_type() == input_type);
259
+ TORCH_CHECK(dout.scalar_type() == input_type);
260
+
261
+ TORCH_CHECK(u.is_cuda());
262
+ TORCH_CHECK(delta.is_cuda());
263
+ TORCH_CHECK(A.is_cuda());
264
+ TORCH_CHECK(B.is_cuda());
265
+ TORCH_CHECK(C.is_cuda());
266
+ TORCH_CHECK(dout.is_cuda());
267
+
268
+ TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
269
+ TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
270
+ TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1);
271
+
272
+ const auto sizes = u.sizes();
273
+ const int batch_size = sizes[0];
274
+ const int dim = sizes[1];
275
+ const int seqlen = sizes[2];
276
+ const int dstate = A.size(1);
277
+ const int n_groups = B.size(1);
278
+
279
+ TORCH_CHECK(dim % n_groups == 0, "dims should be divisible by n_groups");
280
+ TORCH_CHECK(dstate <= MAX_DSTATE, "selective_scan only supports state dimension <= 256");
281
+
282
+ CHECK_SHAPE(u, batch_size, dim, seqlen);
283
+ CHECK_SHAPE(delta, batch_size, dim, seqlen);
284
+ CHECK_SHAPE(A, dim, dstate);
285
+ CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
286
+ TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
287
+ CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
288
+ TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
289
+ CHECK_SHAPE(dout, batch_size, dim, seqlen);
290
+
291
+ if (D_.has_value()) {
292
+ auto D = D_.value();
293
+ TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
294
+ TORCH_CHECK(D.is_cuda());
295
+ TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
296
+ CHECK_SHAPE(D, dim);
297
+ }
298
+
299
+ if (delta_bias_.has_value()) {
300
+ auto delta_bias = delta_bias_.value();
301
+ TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
302
+ TORCH_CHECK(delta_bias.is_cuda());
303
+ TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
304
+ CHECK_SHAPE(delta_bias, dim);
305
+ }
306
+
307
+ at::Tensor out;
308
+ const int n_chunks = (seqlen + 2048 - 1) / 2048;
309
+ // const int n_chunks = (seqlen + 1024 - 1) / 1024;
310
+ if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); }
311
+ if (x_.has_value()) {
312
+ auto x = x_.value();
313
+ TORCH_CHECK(x.scalar_type() == weight_type);
314
+ TORCH_CHECK(x.is_cuda());
315
+ TORCH_CHECK(x.is_contiguous());
316
+ CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate);
317
+ }
318
+
319
+ at::Tensor du = torch::empty_like(u);
320
+ at::Tensor ddelta = torch::empty_like(delta);
321
+ at::Tensor dA = torch::zeros_like(A);
322
+ at::Tensor dB = torch::zeros_like(B, B.options().dtype(torch::kFloat32));
323
+ at::Tensor dC = torch::zeros_like(C, C.options().dtype(torch::kFloat32));
324
+ at::Tensor dD;
325
+ if (D_.has_value()) { dD = torch::zeros_like(D_.value()); }
326
+ at::Tensor ddelta_bias;
327
+ if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); }
328
+
329
+ SSMParamsBwd params;
330
+ set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks,
331
+ u, delta, A, B, C, out,
332
+ D_.has_value() ? D_.value().data_ptr() : nullptr,
333
+ delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
334
+ x_.has_value() ? x_.value().data_ptr() : nullptr,
335
+ dout, du, ddelta, dA, dB, dC,
336
+ D_.has_value() ? dD.data_ptr() : nullptr,
337
+ delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
338
+ delta_softplus);
339
+
340
+ // Otherwise the kernel will be launched from cuda:0 device
341
+ // Cast to char to avoid compiler warning about narrowing
342
+ at::cuda::CUDAGuard device_guard{(char)u.get_device()};
343
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
344
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] {
345
+ selective_scan_bwd_cuda<1, input_t, weight_t>(params, stream);
346
+ });
347
+ std::vector<at::Tensor> result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias};
348
+ return result;
349
+ }
350
+
351
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
352
+ m.def("fwd", &selective_scan_fwd, "Selective scan forward");
353
+ m.def("bwd", &selective_scan_bwd, "Selective scan backward");
354
+ }
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_bwd_kernel.cuh ADDED
@@ -0,0 +1,306 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <c10/util/BFloat16.h>
8
+ #include <c10/util/Half.h>
9
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
10
+ #include <ATen/cuda/Atomic.cuh> // For atomicAdd on complex
11
+
12
+ #include <cub/block/block_load.cuh>
13
+ #include <cub/block/block_store.cuh>
14
+ #include <cub/block/block_scan.cuh>
15
+ #include <cub/block/block_reduce.cuh>
16
+
17
+ #include "selective_scan.h"
18
+ #include "selective_scan_common.h"
19
+ #include "reverse_scan.cuh"
20
+ #include "static_switch.h"
21
+
22
+ template<int kNThreads_, int kNItems_, bool kIsEvenLen_, bool kDeltaSoftplus_, typename input_t_, typename weight_t_>
23
+ struct Selective_Scan_bwd_kernel_traits {
24
+ static_assert(kNItems_ % 4 == 0);
25
+ using input_t = input_t_;
26
+ using weight_t = weight_t_;
27
+ static constexpr int kNThreads = kNThreads_;
28
+ static constexpr int kNItems = kNItems_;
29
+ static constexpr int MaxDState = MAX_DSTATE;
30
+ static constexpr int kNBytes = sizeof(input_t);
31
+ static_assert(kNBytes == 2 || kNBytes == 4);
32
+ static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
33
+ static_assert(kNItems % kNElts == 0);
34
+ static constexpr int kNLoads = kNItems / kNElts;
35
+ static constexpr bool kIsEvenLen = kIsEvenLen_;
36
+ static constexpr bool kDeltaSoftplus = kDeltaSoftplus_;
37
+ // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy.
38
+ // For complex this would lead to massive register spilling, so we keep it at 2.
39
+ static constexpr int kMinBlocks = kNThreads == 128 ? 3 : 2;
40
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
41
+ using scan_t = float2;
42
+ using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
43
+ using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
44
+ using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
45
+ using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
46
+ using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
47
+ using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads, cub::BLOCK_STORE_WARP_TRANSPOSE>;
48
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
49
+ using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
50
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
51
+ using BlockReverseScanT = BlockReverseScan<scan_t, kNThreads>;
52
+ using BlockReduceT = cub::BlockReduce<scan_t, kNThreads>;
53
+ using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
54
+ using BlockExchangeT = cub::BlockExchange<float, kNThreads, kNItems>;
55
+ static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
56
+ sizeof(typename BlockLoadVecT::TempStorage),
57
+ 2 * sizeof(typename BlockLoadWeightT::TempStorage),
58
+ 2 * sizeof(typename BlockLoadWeightVecT::TempStorage),
59
+ sizeof(typename BlockStoreT::TempStorage),
60
+ sizeof(typename BlockStoreVecT::TempStorage)});
61
+ static constexpr int kSmemExchangeSize = 2 * sizeof(typename BlockExchangeT::TempStorage);
62
+ static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage);
63
+ static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage);
64
+ };
65
+
66
+ template<typename Ktraits>
67
+ __global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
68
+ void selective_scan_bwd_kernel(SSMParamsBwd params) {
69
+ constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus;
70
+ constexpr int kNThreads = Ktraits::kNThreads;
71
+ constexpr int kNItems = Ktraits::kNItems;
72
+ using input_t = typename Ktraits::input_t;
73
+ using weight_t = typename Ktraits::weight_t;
74
+ using scan_t = typename Ktraits::scan_t;
75
+
76
+ // Shared memory.
77
+ extern __shared__ char smem_[];
78
+ auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
79
+ auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
80
+ auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
81
+ auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
82
+ auto& smem_exchange = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
83
+ auto& smem_exchange1 = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage));
84
+ auto& smem_reduce = *reinterpret_cast<typename Ktraits::BlockReduceT::TempStorage*>(reinterpret_cast<char *>(&smem_exchange) + Ktraits::kSmemExchangeSize);
85
+ auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(&smem_reduce);
86
+ auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(reinterpret_cast<char *>(&smem_reduce) + Ktraits::kSmemReduceSize);
87
+ auto& smem_reverse_scan = *reinterpret_cast<typename Ktraits::BlockReverseScanT::TempStorage*>(reinterpret_cast<char *>(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage));
88
+ weight_t *smem_delta_a = reinterpret_cast<weight_t *>(smem_ + Ktraits::kSmemSize);
89
+ scan_t *smem_running_postfix = reinterpret_cast<scan_t *>(smem_delta_a + 2 * Ktraits::MaxDState + kNThreads);
90
+ weight_t *smem_da = reinterpret_cast<weight_t *>(smem_running_postfix + Ktraits::MaxDState);
91
+
92
+ const int batch_id = blockIdx.x;
93
+ const int dim_id = blockIdx.y;
94
+ const int group_id = dim_id / (params.dim_ngroups_ratio);
95
+ input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
96
+ + dim_id * params.u_d_stride;
97
+ input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
98
+ + dim_id * params.delta_d_stride;
99
+ input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
100
+ + dim_id * params.dout_d_stride;
101
+ weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * params.A_d_stride;
102
+ input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
103
+ input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
104
+ weight_t *dA = reinterpret_cast<weight_t *>(params.dA_ptr) + dim_id * params.dA_d_stride;
105
+ weight_t *dB = reinterpret_cast<weight_t *>(params.dB_ptr)
106
+ + (batch_id * params.dB_batch_stride + group_id * params.dB_group_stride);
107
+ weight_t *dC = reinterpret_cast<weight_t *>(params.dC_ptr)
108
+ + (batch_id * params.dC_batch_stride + group_id * params.dC_group_stride);
109
+ float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.dD_ptr) + dim_id;
110
+ float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.D_ptr)[dim_id];
111
+ float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.ddelta_bias_ptr) + dim_id;
112
+ float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id];
113
+ scan_t *x = params.x_ptr == nullptr
114
+ ? nullptr
115
+ : reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate;
116
+ float dD_val = 0;
117
+ float ddelta_bias_val = 0;
118
+
119
+ constexpr int kChunkSize = kNThreads * kNItems;
120
+ u += (params.n_chunks - 1) * kChunkSize;
121
+ delta += (params.n_chunks - 1) * kChunkSize;
122
+ dout += (params.n_chunks - 1) * kChunkSize;
123
+ Bvar += (params.n_chunks - 1) * kChunkSize;
124
+ Cvar += (params.n_chunks - 1) * kChunkSize;
125
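+ // Walk the chunks in reverse: the forward scan of each chunk is recomputed from the
+ // per-chunk states saved in x, while a reverse scan propagates the hidden-state gradient.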
+ for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) {
126
+ input_t u_vals[kNItems];
127
+ input_t delta_vals_load[kNItems];
128
+ input_t dout_vals_load[kNItems];
129
+ __syncthreads();
130
+ load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
131
+ __syncthreads();
132
+ load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
133
+ __syncthreads();
134
+ load_input<Ktraits>(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
135
+ u -= kChunkSize;
136
+ // Will reload delta at the same location if kDeltaSoftplus
137
+ if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; }
138
+ dout -= kChunkSize;
139
+
140
+ float dout_vals[kNItems], delta_vals[kNItems];
141
+ float du_vals[kNItems];
142
+ #pragma unroll
143
+ for (int i = 0; i < kNItems; ++i) {
144
+ dout_vals[i] = float(dout_vals_load[i]);
145
+ delta_vals[i] = float(delta_vals_load[i]) + delta_bias;
146
+ if constexpr (kDeltaSoftplus) {
147
+ delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i];
148
+ }
149
+ }
150
+ #pragma unroll
151
+ for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; }
152
+ #pragma unroll
153
+ for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); }
154
+
155
+ float ddelta_vals[kNItems] = {0};
156
+ __syncthreads();
157
+ for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
158
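+ // exp(delta * A) is evaluated as exp2f(delta * A * log2(e)), which is cheaper than expf on GPU.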
+ constexpr float kLog2e = M_LOG2E;
159
+ weight_t A_val = A[state_idx * params.A_dstate_stride];
160
+ weight_t A_scaled = A_val * kLog2e;
161
+ weight_t B_vals[kNItems], C_vals[kNItems];
162
+ load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
163
+ smem_load_weight, (params.seqlen - chunk * kChunkSize));
164
+ auto &smem_load_weight_C = smem_load_weight1;
165
+ load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
166
+ smem_load_weight_C, (params.seqlen - chunk * kChunkSize));
167
+ scan_t thread_data[kNItems], thread_reverse_data[kNItems];
168
+ #pragma unroll
169
+ for (int i = 0; i < kNItems; ++i) {
170
+ const float delta_a_exp = exp2f(delta_vals[i] * A_scaled);
171
+ thread_data[i] = make_float2(delta_a_exp, delta_vals[i] * float(u_vals[i]) * B_vals[i]);
172
+ if (i == 0) {
173
+ smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * Ktraits::MaxDState: threadIdx.x + 2 * Ktraits::MaxDState] = delta_a_exp;
174
+ } else {
175
+ thread_reverse_data[i - 1].x = delta_a_exp;
176
+ }
177
+ thread_reverse_data[i].y = dout_vals[i] * C_vals[i];
178
+ }
179
+ __syncthreads();
180
+ thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1
181
+ ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState])
182
+ : smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState];
183
+ // Initialize running total
184
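+ // x was filled by the forward kernel with the scan state at the end of every chunk,
+ // so recomputation of this chunk can start from the previous chunk's final state.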
+ scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f);
185
+ SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
186
+ Ktraits::BlockScanT(smem_scan).InclusiveScan(
187
+ thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
188
+ );
189
+ scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f);
190
+ SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
191
+ Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
192
+ thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
193
+ );
194
+ if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
195
+ weight_t dA_val = 0;
196
+ weight_t dB_vals[kNItems], dC_vals[kNItems];
197
+ #pragma unroll
198
+ for (int i = 0; i < kNItems; ++i) {
199
+ const float dx = thread_reverse_data[i].y;
200
+ const float ddelta_u = dx * B_vals[i];
201
+ du_vals[i] += ddelta_u * delta_vals[i];
202
+ const float a = thread_data[i].y - (delta_vals[i] * float(u_vals[i]) * B_vals[i]);
203
+ ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a;
204
+ dA_val += dx * delta_vals[i] * a;
205
+ dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]);
206
+ dC_vals[i] = dout_vals[i] * thread_data[i].y;
207
+ }
208
+ // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
209
+ Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
210
+ auto &smem_exchange_C = smem_exchange1;
211
+ Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals);
212
+ const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x;
213
+ weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x;
214
+ weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x;
215
+ #pragma unroll
216
+ for (int i = 0; i < kNItems; ++i) {
217
+ if (i * kNThreads < seqlen_remaining) {
218
+ { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); }
219
+ { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); }
220
+ }
221
+ }
222
+ dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val);
223
+ if (threadIdx.x == 0) {
224
+ smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
225
+ }
226
+ }
227
+
228
+ if constexpr (kDeltaSoftplus) {
229
+ input_t delta_vals_load[kNItems];
230
+ __syncthreads();
231
+ load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
232
+ delta -= kChunkSize;
233
+ #pragma unroll
234
+ for (int i = 0; i < kNItems; ++i) {
235
+ float delta_val = float(delta_vals_load[i]) + delta_bias;
236
+ float delta_val_neg_exp = expf(-delta_val);
237
+ ddelta_vals[i] = delta_val <= 20.f
238
+ ? ddelta_vals[i] / (1.f + delta_val_neg_exp)
239
+ : ddelta_vals[i];
240
+ }
241
+ }
242
+
243
+ __syncthreads();
244
+ #pragma unroll
245
+ for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; }
246
+
247
+ input_t *du = reinterpret_cast<input_t *>(params.du_ptr) + batch_id * params.du_batch_stride
248
+ + dim_id * params.du_d_stride + chunk * kChunkSize;
249
+ input_t *ddelta = reinterpret_cast<input_t *>(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride
250
+ + dim_id * params.ddelta_d_stride + chunk * kChunkSize;
251
+ __syncthreads();
252
+ store_output<Ktraits>(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize);
253
+ __syncthreads();
254
+ store_output<Ktraits>(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize);
255
+ Bvar -= kChunkSize;
256
+ Cvar -= kChunkSize;
257
+ }
258
+
259
+ if (params.dD_ptr != nullptr) {
260
+ __syncthreads();
261
+ dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val);
262
+ if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); }
263
+ }
264
+ if (params.ddelta_bias_ptr != nullptr) {
265
+ __syncthreads();
266
+ ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val);
267
+ if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); }
268
+ }
269
+ __syncthreads();
270
+ for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
271
+ gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]);
272
+ }
273
+ }
274
+
275
+ template<int kNThreads, int kNItems, typename input_t, typename weight_t>
276
+ void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
277
+ BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
278
+ BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] {
279
+ using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, kDeltaSoftplus, input_t, weight_t>;
280
+ constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * Ktraits::MaxDState) * sizeof(typename Ktraits::weight_t);
281
+ // printf("smem_size = %d\n", kSmemSize);
282
+ dim3 grid(params.batch, params.dim);
283
+ auto kernel = &selective_scan_bwd_kernel<Ktraits>;
284
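+ // Requesting more than the default 48 KB of dynamic shared memory requires an explicit opt-in.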
+ if (kSmemSize >= 48 * 1024) {
285
+ C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
286
+ }
287
+ kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
288
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
289
+ });
290
+ });
291
+ }
292
+
293
+ template<int knrows, typename input_t, typename weight_t>
294
+ void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
295
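+ // Choose threads and items-per-thread from the sequence length; one chunk covers kNThreads * kNItems timesteps.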
+ if (params.seqlen <= 128) {
296
+ selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream);
297
+ } else if (params.seqlen <= 256) {
298
+ selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream);
299
+ } else if (params.seqlen <= 512) {
300
+ selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream);
301
+ } else if (params.seqlen <= 1024) {
302
+ selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
303
+ } else {
304
+ selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
305
+ }
306
+ }
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_core_bwd.cu ADDED
@@ -0,0 +1,9 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+ #include "selective_scan_bwd_kernel.cuh"
5
+
6
+ template void selective_scan_bwd_cuda<1, float, float>(SSMParamsBwd &params, cudaStream_t stream);
7
+ template void selective_scan_bwd_cuda<1, at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
8
+ template void selective_scan_bwd_cuda<1, at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);
9
+
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_core_fwd.cu ADDED
@@ -0,0 +1,9 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+ #include "selective_scan_fwd_kernel.cuh"
5
+
6
+ template void selective_scan_fwd_cuda<1, float, float>(SSMParamsBase &params, cudaStream_t stream);
7
+ template void selective_scan_fwd_cuda<1, at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
8
+ template void selective_scan_fwd_cuda<1, at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
9
+
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_fwd_kernel.cuh ADDED
@@ -0,0 +1,203 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <c10/util/BFloat16.h>
8
+ #include <c10/util/Half.h>
9
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
10
+
11
+ #include <cub/block/block_load.cuh>
12
+ #include <cub/block/block_store.cuh>
13
+ #include <cub/block/block_scan.cuh>
14
+
15
+ #include "selective_scan.h"
16
+ #include "selective_scan_common.h"
17
+ #include "static_switch.h"
18
+
19
+ template<int kNThreads_, int kNItems_, bool kIsEvenLen_, typename input_t_, typename weight_t_>
20
+ struct Selective_Scan_fwd_kernel_traits {
21
+ static_assert(kNItems_ % 4 == 0);
22
+ using input_t = input_t_;
23
+ using weight_t = weight_t_;
24
+ static constexpr int kNThreads = kNThreads_;
25
+ // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
26
+ static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
27
+ static constexpr int kNItems = kNItems_;
28
+ static constexpr int MaxDState = MAX_DSTATE;
29
+ static constexpr int kNBytes = sizeof(input_t);
30
+ static_assert(kNBytes == 2 || kNBytes == 4);
31
+ static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
32
+ static_assert(kNItems % kNElts == 0);
33
+ static constexpr int kNLoads = kNItems / kNElts;
34
+ static constexpr bool kIsEvenLen = kIsEvenLen_;
35
+
36
+ static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
37
+
38
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
39
+ using scan_t = float2;
40
+ using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
41
+ using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
42
+ !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
43
+ using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
44
+ using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
45
+ !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
46
+ using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
47
+ using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
48
+ !kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
49
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
50
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
51
+ using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
52
+ static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
53
+ sizeof(typename BlockLoadVecT::TempStorage),
54
+ 2 * sizeof(typename BlockLoadWeightT::TempStorage),
55
+ 2 * sizeof(typename BlockLoadWeightVecT::TempStorage),
56
+ sizeof(typename BlockStoreT::TempStorage),
57
+ sizeof(typename BlockStoreVecT::TempStorage)});
58
+ static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
59
+ };
60
+
61
+ template<typename Ktraits>
62
+ __global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
63
+ void selective_scan_fwd_kernel(SSMParamsBase params) {
64
+ constexpr int kNThreads = Ktraits::kNThreads;
65
+ constexpr int kNItems = Ktraits::kNItems;
66
+ constexpr bool kDirectIO = Ktraits::kDirectIO;
67
+ using input_t = typename Ktraits::input_t;
68
+ using weight_t = typename Ktraits::weight_t;
69
+ using scan_t = typename Ktraits::scan_t;
70
+
71
+ // Shared memory.
72
+ extern __shared__ char smem_[];
73
+ auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
74
+ auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
75
+ auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
76
+ auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
77
+ auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
78
+ scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);
79
+
80
+ const int batch_id = blockIdx.x;
81
+ const int dim_id = blockIdx.y;
82
+ const int group_id = dim_id / (params.dim_ngroups_ratio);
83
+ input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
84
+ + dim_id * params.u_d_stride;
85
+ input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
86
+ + dim_id * params.delta_d_stride;
87
+ weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * params.A_d_stride;
88
+ input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
89
+ input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
90
+ scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * params.n_chunks * params.dstate;
91
+
92
+ float D_val = 0; // D is optional; stays 0 (no feedthrough term) when D_ptr is null.
93
+ if (params.D_ptr != nullptr) {
94
+ D_val = reinterpret_cast<float *>(params.D_ptr)[dim_id];
95
+ }
96
+ float delta_bias = 0;
97
+ if (params.delta_bias_ptr != nullptr) {
98
+ delta_bias = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id];
99
+ }
100
+
101
+ constexpr int kChunkSize = kNThreads * kNItems;
102
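+ // Each block handles one (batch, channel) pair. The sequence is processed chunk by chunk,
+ // with the running state carried across chunks in shared memory and saved to x for the backward pass.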
+ for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
103
+ input_t u_vals[kNItems], delta_vals_load[kNItems];
104
+ __syncthreads();
105
+ load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
106
+ if constexpr (!kDirectIO) { __syncthreads(); }
107
+ load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
108
+ u += kChunkSize;
109
+ delta += kChunkSize;
110
+
111
+ float delta_vals[kNItems], delta_u_vals[kNItems], out_vals[kNItems];
112
+ #pragma unroll
113
+ for (int i = 0; i < kNItems; ++i) {
114
+ float u_val = float(u_vals[i]);
115
+ delta_vals[i] = float(delta_vals_load[i]) + delta_bias;
116
+ if (params.delta_softplus) {
117
+ delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i];
118
+ }
119
+ delta_u_vals[i] = delta_vals[i] * u_val;
120
+ out_vals[i] = D_val * u_val;
121
+ }
122
+
123
+ __syncthreads();
124
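+ // The scan runs independently for each of the dstate state dimensions.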
+ for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
125
+ constexpr float kLog2e = M_LOG2E;
126
+ weight_t A_val = A[state_idx * params.A_dstate_stride];
127
+ A_val *= kLog2e;
128
+ weight_t B_vals[kNItems], C_vals[kNItems];
129
+ load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
130
+ smem_load_weight, (params.seqlen - chunk * kChunkSize));
131
+ load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
132
+ smem_load_weight1, (params.seqlen - chunk * kChunkSize));
133
+ __syncthreads();
134
+ scan_t thread_data[kNItems];
135
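+ // Each scan element is the pair (exp(delta*A), delta*u*B); the inclusive scan composes them
+ // into the recurrence h[t] = exp(delta[t]*A) * h[t-1] + delta[t]*u[t]*B[t].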
+ #pragma unroll
136
+ for (int i = 0; i < kNItems; ++i) {
137
+ thread_data[i] = make_float2(exp2f(delta_vals[i] * A_val), B_vals[i] * delta_u_vals[i]);
138
+ if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
139
+ if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
140
+ thread_data[i] = make_float2(1.f, 0.f);
141
+ }
142
+ }
143
+ }
144
+ // Initialize running total
145
+ scan_t running_prefix;
146
+ // With BLOCK_SCAN_WARP_SCANS, lane 0 of every warp (not just thread 0) needs to read the running prefix.
147
+ running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
148
+ // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
149
+ SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
150
+ Ktraits::BlockScanT(smem_scan).InclusiveScan(
151
+ thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
152
+ );
153
+ // There's a syncthreads in the scan op, so we don't need to sync here.
154
+ // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
155
+ if (threadIdx.x == 0) {
156
+ smem_running_prefix[state_idx] = prefix_op.running_prefix;
157
+ x[chunk * params.dstate + state_idx] = prefix_op.running_prefix;
158
+ }
159
+ #pragma unroll
160
+ for (int i = 0; i < kNItems; ++i) {
161
+ out_vals[i] += thread_data[i].y * C_vals[i];
162
+ }
163
+ }
164
+
165
+ input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
166
+ + dim_id * params.out_d_stride + chunk * kChunkSize;
167
+ __syncthreads();
168
+ store_output<Ktraits>(out, out_vals, smem_store, params.seqlen - chunk * kChunkSize);
169
+ Bvar += kChunkSize;
170
+ Cvar += kChunkSize;
171
+ }
172
+ }
173
+
174
+ template<int kNThreads, int kNItems, typename input_t, typename weight_t>
175
+ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
176
+ BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
177
+ using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, input_t, weight_t>;
178
+ constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t);
179
+ // printf("smem_size = %d\n", kSmemSize);
180
+ dim3 grid(params.batch, params.dim);
181
+ auto kernel = &selective_scan_fwd_kernel<Ktraits>;
182
+ if (kSmemSize >= 48 * 1024) {
183
+ C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
184
+ }
185
+ kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
186
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
187
+ });
188
+ }
189
+
190
+ template<int knrows, typename input_t, typename weight_t>
191
+ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
192
+ if (params.seqlen <= 128) {
193
+ selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
194
+ } else if (params.seqlen <= 256) {
195
+ selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
196
+ } else if (params.seqlen <= 512) {
197
+ selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
198
+ } else if (params.seqlen <= 1024) {
199
+ selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
200
+ } else {
201
+ selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
202
+ }
203
+ }
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_bwd_kernel_ndstate.cuh ADDED
@@ -0,0 +1,302 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <c10/util/BFloat16.h>
8
+ #include <c10/util/Half.h>
9
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
10
+ #include <ATen/cuda/Atomic.cuh> // For atomicAdd on complex
11
+
12
+ #include <cub/block/block_load.cuh>
13
+ #include <cub/block/block_store.cuh>
14
+ #include <cub/block/block_scan.cuh>
15
+ #include <cub/block/block_reduce.cuh>
16
+
17
+ #include "selective_scan_ndstate.h"
18
+ #include "selective_scan_common.h"
19
+ #include "reverse_scan.cuh"
20
+ #include "static_switch.h"
21
+
22
+ template<int kNThreads_, int kNItems_, bool kIsEvenLen_, bool kDeltaSoftplus_, typename input_t_, typename weight_t_>
23
+ struct Selective_Scan_bwd_kernel_traits {
24
+ static_assert(kNItems_ % 4 == 0);
25
+ using input_t = input_t_;
26
+ using weight_t = weight_t_;
27
+ static constexpr int kNThreads = kNThreads_;
28
+ static constexpr int kNItems = kNItems_;
29
+ static constexpr int kNBytes = sizeof(input_t);
30
+ static_assert(kNBytes == 2 || kNBytes == 4);
31
+ static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
32
+ static_assert(kNItems % kNElts == 0);
33
+ static constexpr int kNLoads = kNItems / kNElts;
34
+ static constexpr bool kIsEvenLen = kIsEvenLen_;
35
+ static constexpr bool kDeltaSoftplus = kDeltaSoftplus_;
36
+ // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy.
37
+ // For complex this would lead to massive register spilling, so we keep it at 2.
38
+ static constexpr int kMinBlocks = kNThreads == 128 ? 3 : 2;
39
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
40
+ using scan_t = float2;
41
+ using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
42
+ using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
43
+ using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
44
+ using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
45
+ using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
46
+ using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads, cub::BLOCK_STORE_WARP_TRANSPOSE>;
47
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
48
+ using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
49
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
50
+ using BlockReverseScanT = BlockReverseScan<scan_t, kNThreads>;
51
+ using BlockReduceT = cub::BlockReduce<scan_t, kNThreads>;
52
+ using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
53
+ using BlockExchangeT = cub::BlockExchange<float, kNThreads, kNItems>;
54
+ static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
55
+ sizeof(typename BlockLoadVecT::TempStorage),
56
+ 2 * sizeof(typename BlockLoadWeightT::TempStorage),
57
+ 2 * sizeof(typename BlockLoadWeightVecT::TempStorage),
58
+ sizeof(typename BlockStoreT::TempStorage),
59
+ sizeof(typename BlockStoreVecT::TempStorage)});
60
+ static constexpr int kSmemExchangeSize = 2 * sizeof(typename BlockExchangeT::TempStorage);
61
+ static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage);
62
+ static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage);
63
+ };
64
+
65
+ template<typename Ktraits>
66
+ __global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
67
+ void selective_scan_bwd_kernel(SSMParamsBwd params) {
68
+ constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus;
69
+ constexpr int kNThreads = Ktraits::kNThreads;
70
+ constexpr int kNItems = Ktraits::kNItems;
71
+ using input_t = typename Ktraits::input_t;
72
+ using weight_t = typename Ktraits::weight_t;
73
+ using scan_t = typename Ktraits::scan_t;
74
+
75
+ // Shared memory.
76
+ extern __shared__ char smem_[];
77
+ auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
78
+ auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
79
+ auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
80
+ auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
81
+ auto& smem_exchange = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
82
+ auto& smem_exchange1 = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage));
83
+ auto& smem_reduce = *reinterpret_cast<typename Ktraits::BlockReduceT::TempStorage*>(reinterpret_cast<char *>(&smem_exchange) + Ktraits::kSmemExchangeSize);
84
+ auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(&smem_reduce);
85
+ auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(reinterpret_cast<char *>(&smem_reduce) + Ktraits::kSmemReduceSize);
86
+ auto& smem_reverse_scan = *reinterpret_cast<typename Ktraits::BlockReverseScanT::TempStorage*>(reinterpret_cast<char *>(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage));
87
+ weight_t *smem_delta_a = reinterpret_cast<weight_t *>(smem_ + Ktraits::kSmemSize);
88
+ scan_t *smem_running_postfix = reinterpret_cast<scan_t *>(smem_delta_a + 2 + kNThreads);
89
+ weight_t *smem_da = reinterpret_cast<weight_t *>(smem_running_postfix + 1);
90
+
91
+ const int batch_id = blockIdx.x;
92
+ const int dim_id = blockIdx.y;
93
+ const int group_id = dim_id / (params.dim_ngroups_ratio);
94
+ input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
95
+ + dim_id * params.u_d_stride;
96
+ input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
97
+ + dim_id * params.delta_d_stride;
98
+ input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
99
+ + dim_id * params.dout_d_stride;
100
+ weight_t A_val = reinterpret_cast<weight_t *>(params.A_ptr)[dim_id];
101
+ constexpr float kLog2e = M_LOG2E;
102
+ weight_t A_scaled = A_val * kLog2e;
103
+ input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
104
+ input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
105
+ weight_t *dA = reinterpret_cast<weight_t *>(params.dA_ptr) + dim_id;
106
+ weight_t *dB = reinterpret_cast<weight_t *>(params.dB_ptr)
107
+ + (batch_id * params.dB_batch_stride + group_id * params.dB_group_stride);
108
+ weight_t *dC = reinterpret_cast<weight_t *>(params.dC_ptr)
109
+ + (batch_id * params.dC_batch_stride + group_id * params.dC_group_stride);
110
+ float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.dD_ptr) + dim_id;
111
+ float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.D_ptr)[dim_id];
112
+ float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.ddelta_bias_ptr) + dim_id;
113
+ float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id];
114
+ scan_t *x = params.x_ptr == nullptr
115
+ ? nullptr
116
+ : reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks);
117
+ float dD_val = 0;
118
+ float ddelta_bias_val = 0;
119
+
120
+ constexpr int kChunkSize = kNThreads * kNItems;
121
+ u += (params.n_chunks - 1) * kChunkSize;
122
+ delta += (params.n_chunks - 1) * kChunkSize;
123
+ dout += (params.n_chunks - 1) * kChunkSize;
124
+ Bvar += (params.n_chunks - 1) * kChunkSize;
125
+ Cvar += (params.n_chunks - 1) * kChunkSize;
126
+ for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) {
127
+ input_t u_vals[kNItems];
128
+ input_t delta_vals_load[kNItems];
129
+ input_t dout_vals_load[kNItems];
130
+ __syncthreads();
131
+ load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
132
+ __syncthreads();
133
+ load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
134
+ __syncthreads();
135
+ load_input<Ktraits>(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
136
+ u -= kChunkSize;
137
+ // Will reload delta at the same location if kDeltaSoftplus
138
+ if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; }
139
+ dout -= kChunkSize;
140
+
141
+ float dout_vals[kNItems], delta_vals[kNItems];
142
+ float du_vals[kNItems];
143
+ #pragma unroll
144
+ for (int i = 0; i < kNItems; ++i) {
145
+ dout_vals[i] = float(dout_vals_load[i]);
146
+ delta_vals[i] = float(delta_vals_load[i]) + delta_bias;
147
+ if constexpr (kDeltaSoftplus) {
148
+ delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i];
149
+ }
150
+ }
151
+ #pragma unroll
152
+ for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; }
153
+ #pragma unroll
154
+ for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); }
155
+
156
+ float ddelta_vals[kNItems] = {0};
157
+ __syncthreads();
158
+ {
159
+ weight_t B_vals[kNItems], C_vals[kNItems];
160
+ load_weight<Ktraits>(Bvar, B_vals,
161
+ smem_load_weight, (params.seqlen - chunk * kChunkSize));
162
+ auto &smem_load_weight_C = smem_load_weight1;
163
+ load_weight<Ktraits>(Cvar, C_vals,
164
+ smem_load_weight_C, (params.seqlen - chunk * kChunkSize));
165
+ scan_t thread_data[kNItems], thread_reverse_data[kNItems];
166
+ #pragma unroll
167
+ for (int i = 0; i < kNItems; ++i) {
168
+ const float delta_a_exp = exp2f(delta_vals[i] * A_scaled);
169
+ thread_data[i] = make_float2(delta_a_exp, delta_vals[i] * float(u_vals[i]) * B_vals[i]);
170
+ if (i == 0) {
171
+ smem_delta_a[threadIdx.x == 0 ? (chunk % 2): threadIdx.x + 2] = delta_a_exp;
172
+ } else {
173
+ thread_reverse_data[i - 1].x = delta_a_exp;
174
+ }
175
+ thread_reverse_data[i].y = dout_vals[i] * C_vals[i];
176
+ }
177
+ __syncthreads();
178
+ thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1
179
+ ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[(chunk + 1) % 2])
180
+ : smem_delta_a[threadIdx.x + 1 + 2];
181
+ // Initialize running total
182
+ scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[chunk - 1] : make_float2(1.f, 0.f);
183
+ SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
184
+ Ktraits::BlockScanT(smem_scan).InclusiveScan(
185
+ thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
186
+ );
187
+ scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[0] : make_float2(1.f, 0.f);
188
+ SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
189
+ Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
190
+ thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
191
+ );
192
+ if (threadIdx.x == 0) { smem_running_postfix[0] = postfix_op.running_prefix; }
193
+ weight_t dA_val = 0;
194
+ weight_t dB_vals[kNItems], dC_vals[kNItems];
195
+ #pragma unroll
196
+ for (int i = 0; i < kNItems; ++i) {
197
+ const float dx = thread_reverse_data[i].y;
198
+ const float ddelta_u = dx * B_vals[i];
199
+ du_vals[i] += ddelta_u * delta_vals[i];
200
+ const float a = thread_data[i].y - (delta_vals[i] * float(u_vals[i]) * B_vals[i]);
201
+ ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a;
202
+ dA_val += dx * delta_vals[i] * a;
203
+ dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]);
204
+ dC_vals[i] = dout_vals[i] * thread_data[i].y;
205
+ }
206
+ // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
207
+ Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
208
+ auto &smem_exchange_C = smem_exchange1;
209
+ Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals);
210
+ const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x;
211
+ weight_t *dB_cur = dB + chunk * kChunkSize + threadIdx.x;
212
+ weight_t *dC_cur = dC + chunk * kChunkSize + threadIdx.x;
213
+ #pragma unroll
214
+ for (int i = 0; i < kNItems; ++i) {
215
+ if (i * kNThreads < seqlen_remaining) {
216
+ { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); }
217
+ { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); }
218
+ }
219
+ }
220
+ dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val);
221
+ if (threadIdx.x == 0) {
222
+ smem_da[0] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[0];
223
+ }
224
+ }
225
+
226
+ if constexpr (kDeltaSoftplus) {
227
+ input_t delta_vals_load[kNItems];
228
+ __syncthreads();
229
+ load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
230
+ delta -= kChunkSize;
231
+ #pragma unroll
232
+ for (int i = 0; i < kNItems; ++i) {
233
+ float delta_val = float(delta_vals_load[i]) + delta_bias;
234
+ float delta_val_neg_exp = expf(-delta_val);
235
+ ddelta_vals[i] = delta_val <= 20.f
236
+ ? ddelta_vals[i] / (1.f + delta_val_neg_exp)
237
+ : ddelta_vals[i];
238
+ }
239
+ }
240
+
241
+ __syncthreads();
242
+ #pragma unroll
243
+ for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; }
244
+
245
+ input_t *du = reinterpret_cast<input_t *>(params.du_ptr) + batch_id * params.du_batch_stride
246
+ + dim_id * params.du_d_stride + chunk * kChunkSize;
247
+ input_t *ddelta = reinterpret_cast<input_t *>(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride
248
+ + dim_id * params.ddelta_d_stride + chunk * kChunkSize;
249
+ __syncthreads();
250
+ store_output<Ktraits>(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize);
251
+ __syncthreads();
252
+ store_output<Ktraits>(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize);
253
+ Bvar -= kChunkSize;
254
+ Cvar -= kChunkSize;
255
+ }
256
+
257
+ if (params.dD_ptr != nullptr) {
258
+ __syncthreads();
259
+ dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val);
260
+ if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); }
261
+ }
262
+ if (params.ddelta_bias_ptr != nullptr) {
263
+ __syncthreads();
264
+ ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val);
265
+ if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); }
266
+ }
267
+ __syncthreads();
268
+ if (threadIdx.x == 0) { gpuAtomicAdd(dA, smem_da[0]); }
269
+ }
270
+
271
+ template<int kNThreads, int kNItems, typename input_t, typename weight_t>
272
+ void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
273
+ BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
274
+ BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] {
275
+ using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, kDeltaSoftplus, input_t, weight_t>;
276
+ constexpr int kSmemSize = Ktraits::kSmemSize + sizeof(typename Ktraits::scan_t) + (kNThreads + 4) * sizeof(typename Ktraits::weight_t);
277
+ // printf("smem_size = %d\n", kSmemSize);
278
+ dim3 grid(params.batch, params.dim);
279
+ auto kernel = &selective_scan_bwd_kernel<Ktraits>;
280
+ if (kSmemSize >= 48 * 1024) {
281
+ C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
282
+ }
283
+ kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
284
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
285
+ });
286
+ });
287
+ }
288
+
289
+ template<int knrows, typename input_t, typename weight_t>
290
+ void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
291
+ if (params.seqlen <= 128) {
292
+ selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream);
293
+ } else if (params.seqlen <= 256) {
294
+ selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream);
295
+ } else if (params.seqlen <= 512) {
296
+ selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream);
297
+ } else if (params.seqlen <= 1024) {
298
+ selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
299
+ } else {
300
+ selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
301
+ }
302
+ }
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_core_bwd.cu ADDED
@@ -0,0 +1,9 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+ #include "selective_scan_bwd_kernel_ndstate.cuh"
5
+
6
+ template void selective_scan_bwd_cuda<1, float, float>(SSMParamsBwd &params, cudaStream_t stream);
7
+ template void selective_scan_bwd_cuda<1, at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
8
+ template void selective_scan_bwd_cuda<1, at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);
9
+
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_core_fwd.cu ADDED
@@ -0,0 +1,9 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+ #include "selective_scan_fwd_kernel_ndstate.cuh"
5
+
6
+ template void selective_scan_fwd_cuda<1, float, float>(SSMParamsBase &params, cudaStream_t stream);
7
+ template void selective_scan_fwd_cuda<1, at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
8
+ template void selective_scan_fwd_cuda<1, at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
9
+
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_fwd_kernel_ndstate.cuh ADDED
@@ -0,0 +1,200 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <c10/util/BFloat16.h>
8
+ #include <c10/util/Half.h>
9
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
10
+
11
+ #include <cub/block/block_load.cuh>
12
+ #include <cub/block/block_store.cuh>
13
+ #include <cub/block/block_scan.cuh>
14
+
15
+ #include "selective_scan_ndstate.h"
16
+ #include "selective_scan_common.h"
17
+ #include "static_switch.h"
18
+
19
+ template<int kNThreads_, int kNItems_, bool kIsEvenLen_, typename input_t_, typename weight_t_>
20
+ struct Selective_Scan_fwd_kernel_traits {
21
+ static_assert(kNItems_ % 4 == 0);
22
+ using input_t = input_t_;
23
+ using weight_t = weight_t_;
24
+ static constexpr int kNThreads = kNThreads_;
25
+ // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
26
+ static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
27
+ static constexpr int kNItems = kNItems_;
28
+ static constexpr int kNBytes = sizeof(input_t);
29
+ static_assert(kNBytes == 2 || kNBytes == 4);
30
+ static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
31
+ static_assert(kNItems % kNElts == 0);
32
+ static constexpr int kNLoads = kNItems / kNElts;
33
+ static constexpr bool kIsEvenLen = kIsEvenLen_;
34
+
35
+ static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
36
+
37
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
38
+ using scan_t = float2;
39
+ using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
40
+ using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
41
+ !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
42
+ using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
43
+ using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
44
+ !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
45
+ using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
46
+ using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
47
+ !kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
48
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
49
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
50
+ using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
51
+ static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
52
+ sizeof(typename BlockLoadVecT::TempStorage),
53
+ 2 * sizeof(typename BlockLoadWeightT::TempStorage),
54
+ 2 * sizeof(typename BlockLoadWeightVecT::TempStorage),
55
+ sizeof(typename BlockStoreT::TempStorage),
56
+ sizeof(typename BlockStoreVecT::TempStorage)});
57
+ static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
58
+ };
59
+
60
+ template<typename Ktraits>
61
+ __global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
62
+ void selective_scan_fwd_kernel(SSMParamsBase params) {
63
+ constexpr int kNThreads = Ktraits::kNThreads;
64
+ constexpr int kNItems = Ktraits::kNItems;
65
+ constexpr bool kDirectIO = Ktraits::kDirectIO;
66
+ using input_t = typename Ktraits::input_t;
67
+ using weight_t = typename Ktraits::weight_t;
68
+ using scan_t = typename Ktraits::scan_t;
69
+
70
+ // Shared memory.
71
+ extern __shared__ char smem_[];
72
+ auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
73
+ auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
74
+ auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
75
+ auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
76
+ auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
77
+ scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);
78
+
79
+ const int batch_id = blockIdx.x;
80
+ const int dim_id = blockIdx.y;
81
+ const int group_id = dim_id / (params.dim_ngroups_ratio);
82
+ input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
83
+ + dim_id * params.u_d_stride;
84
+ input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
85
+ + dim_id * params.delta_d_stride;
86
+ constexpr float kLog2e = M_LOG2E;
87
+ weight_t A_val = reinterpret_cast<weight_t *>(params.A_ptr)[dim_id] * kLog2e;
88
+ input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
89
+ input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
90
+ scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * params.n_chunks;
91
+
92
+ float D_val = 0; // D is optional; stays 0 (no feedthrough term) when D_ptr is null.
93
+ if (params.D_ptr != nullptr) {
94
+ D_val = reinterpret_cast<float *>(params.D_ptr)[dim_id];
95
+ }
96
+ float delta_bias = 0;
97
+ if (params.delta_bias_ptr != nullptr) {
98
+ delta_bias = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id];
99
+ }
100
+
101
+ constexpr int kChunkSize = kNThreads * kNItems;
102
+ for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
103
+ input_t u_vals[kNItems], delta_vals_load[kNItems];
104
+ __syncthreads();
105
+ load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
106
+ if constexpr (!kDirectIO) { __syncthreads(); }
107
+ load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
108
+ u += kChunkSize;
109
+ delta += kChunkSize;
110
+
111
+ float delta_vals[kNItems], delta_u_vals[kNItems], out_vals[kNItems];
112
+ #pragma unroll
113
+ for (int i = 0; i < kNItems; ++i) {
114
+ float u_val = float(u_vals[i]);
115
+ delta_vals[i] = float(delta_vals_load[i]) + delta_bias;
116
+ if (params.delta_softplus) {
117
+ delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i];
118
+ }
119
+ delta_u_vals[i] = delta_vals[i] * u_val;
120
+ out_vals[i] = D_val * u_val;
121
+ }
122
+
123
+ __syncthreads();
124
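+ // Single-state variant: the dstate loop is absent; A is a per-channel scalar and x keeps one state per chunk.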
+ {
125
+ weight_t B_vals[kNItems], C_vals[kNItems];
126
+ load_weight<Ktraits>(Bvar, B_vals,
127
+ smem_load_weight, (params.seqlen - chunk * kChunkSize));
128
+ auto &smem_load_weight_C = smem_load_weight1;
129
+ load_weight<Ktraits>(Cvar, C_vals,
130
+ smem_load_weight_C, (params.seqlen - chunk * kChunkSize));
131
+ __syncthreads();
132
+ scan_t thread_data[kNItems];
133
+ #pragma unroll
134
+ for (int i = 0; i < kNItems; ++i) {
135
+ thread_data[i] = make_float2(exp2f(delta_vals[i] * A_val), B_vals[i] * delta_u_vals[i]);
136
+ if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
137
+ if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
138
+ thread_data[i] = make_float2(1.f, 0.f);
139
+ }
140
+ }
141
+ }
142
+ // Initialize running total
143
+ scan_t running_prefix;
144
+ // With BLOCK_SCAN_WARP_SCANS, lane 0 of every warp (not just thread 0) needs to read the running prefix.
145
+ running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[0] : make_float2(1.f, 0.f);
146
+ SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
147
+ Ktraits::BlockScanT(smem_scan).InclusiveScan(
148
+ thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
149
+ );
150
+ // There's a syncthreads in the scan op, so we don't need to sync here.
151
+ // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
152
+ if (threadIdx.x == 0) {
153
+ smem_running_prefix[0] = prefix_op.running_prefix;
154
+ x[chunk] = prefix_op.running_prefix;
155
+ }
156
+ #pragma unroll
157
+ for (int i = 0; i < kNItems; ++i) {
158
+ out_vals[i] += thread_data[i].y * C_vals[i];
159
+ }
160
+ }
161
+
162
+ input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
163
+ + dim_id * params.out_d_stride + chunk * kChunkSize;
164
+ __syncthreads();
165
+ store_output<Ktraits>(out, out_vals, smem_store, params.seqlen - chunk * kChunkSize);
166
+ Bvar += kChunkSize;
167
+ Cvar += kChunkSize;
168
+ }
169
+ }
170
+
171
+ template<int kNThreads, int kNItems, typename input_t, typename weight_t>
172
+ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
173
+ BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
174
+ using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, input_t, weight_t>;
175
+ constexpr int kSmemSize = Ktraits::kSmemSize + sizeof(typename Ktraits::scan_t);
176
+ // printf("smem_size = %d\n", kSmemSize);
177
+ dim3 grid(params.batch, params.dim);
178
+ auto kernel = &selective_scan_fwd_kernel<Ktraits>;
179
+ if (kSmemSize >= 48 * 1024) {
180
+ C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
181
+ }
182
+ kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
183
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
184
+ });
185
+ }
186
+
187
+ template<int knrows, typename input_t, typename weight_t>
188
+ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
189
+ if (params.seqlen <= 128) {
190
+ selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
191
+ } else if (params.seqlen <= 256) {
192
+ selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
193
+ } else if (params.seqlen <= 512) {
194
+ selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
195
+ } else if (params.seqlen <= 1024) {
196
+ selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
197
+ } else {
198
+ selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
199
+ }
200
+ }
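Each (kNThreads, kNItems) tier above processes kNThreads * kNItems timesteps per chunk (kChunkSize), topping out at 128 * 16 = 2048; the host-side binding further down sizes the chunk checkpoint tensor with n_chunks = ceil(seqlen / 2048) to match. A minimal Python sketch of that mapping (the helper name and standalone form are illustrative, not part of the extension):

    import math

    def fwd_launch_config(seqlen: int):
        # Mirror the seqlen -> (kNThreads, kNItems) dispatch in selective_scan_fwd_cuda above.
        if seqlen <= 128:
            threads, items = 32, 4
        elif seqlen <= 256:
            threads, items = 32, 8
        elif seqlen <= 512:
            threads, items = 32, 16
        elif seqlen <= 1024:
            threads, items = 64, 16
        else:
            threads, items = 128, 16
        chunk = threads * items              # timesteps handled per chunk (kChunkSize)
        n_chunks = math.ceil(seqlen / 2048)  # matches (seqlen + 2048 - 1) / 2048 on the host side
        return threads, items, chunk, n_chunks

    print(fwd_launch_config(4096))  # -> (128, 16, 2048, 2)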
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_ndstate.cpp ADDED
@@ -0,0 +1,341 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include <ATen/cuda/CUDAContext.h>
6
+ #include <c10/cuda/CUDAGuard.h>
7
+ #include <torch/extension.h>
8
+ #include <vector>
9
+
10
+ #include "selective_scan_ndstate.h"
11
+ #define MAX_DSTATE 256
12
+
13
+ #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
14
+ using weight_t = float;
15
+
16
+ #define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
17
+ if (ITYPE == at::ScalarType::Half) { \
18
+ using input_t = at::Half; \
19
+ __VA_ARGS__(); \
20
+ } else if (ITYPE == at::ScalarType::BFloat16) { \
21
+ using input_t = at::BFloat16; \
22
+ __VA_ARGS__(); \
23
+ } else if (ITYPE == at::ScalarType::Float) { \
24
+ using input_t = float; \
25
+ __VA_ARGS__(); \
26
+ } else { \
27
+ AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
28
+ }
29
+
30
+ template<int knrows, typename input_t, typename weight_t>
31
+ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);
32
+
33
+ template <int knrows, typename input_t, typename weight_t>
34
+ void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream);
35
+
36
+ void set_ssm_params_fwd(SSMParamsBase &params,
37
+ // sizes
38
+ const size_t batch,
39
+ const size_t dim,
40
+ const size_t seqlen,
41
+ const size_t n_groups,
42
+ const size_t n_chunks,
43
+ // device pointers
44
+ const at::Tensor u,
45
+ const at::Tensor delta,
46
+ const at::Tensor A,
47
+ const at::Tensor B,
48
+ const at::Tensor C,
49
+ const at::Tensor out,
50
+ void* D_ptr,
51
+ void* delta_bias_ptr,
52
+ void* x_ptr,
53
+ bool delta_softplus) {
54
+
55
+ // Reset the parameters
56
+ memset(&params, 0, sizeof(params));
57
+
58
+ params.batch = batch;
59
+ params.dim = dim;
60
+ params.seqlen = seqlen;
61
+ params.n_groups = n_groups;
62
+ params.n_chunks = n_chunks;
63
+ params.dim_ngroups_ratio = dim / n_groups;
64
+
65
+ params.delta_softplus = delta_softplus;
66
+
67
+ // Set the pointers and strides.
68
+ params.u_ptr = u.data_ptr();
69
+ params.delta_ptr = delta.data_ptr();
70
+ params.A_ptr = A.data_ptr();
71
+ params.B_ptr = B.data_ptr();
72
+ params.C_ptr = C.data_ptr();
73
+ params.D_ptr = D_ptr;
74
+ params.delta_bias_ptr = delta_bias_ptr;
75
+ params.out_ptr = out.data_ptr();
76
+ params.x_ptr = x_ptr;
77
+
78
+ // All strides are in elements, not bytes.
79
+ params.A_d_stride = A.stride(0);
80
+ params.B_batch_stride = B.stride(0);
81
+ params.B_group_stride = B.stride(1);
82
+ params.C_batch_stride = C.stride(0);
83
+ params.C_group_stride = C.stride(1);
84
+ params.u_batch_stride = u.stride(0);
85
+ params.u_d_stride = u.stride(1);
86
+ params.delta_batch_stride = delta.stride(0);
87
+ params.delta_d_stride = delta.stride(1);
88
+
89
+ params.out_batch_stride = out.stride(0);
90
+ params.out_d_stride = out.stride(1);
91
+ }
92
+
93
+ void set_ssm_params_bwd(SSMParamsBwd &params,
94
+ // sizes
95
+ const size_t batch,
96
+ const size_t dim,
97
+ const size_t seqlen,
98
+ const size_t n_groups,
99
+ const size_t n_chunks,
100
+ // device pointers
101
+ const at::Tensor u,
102
+ const at::Tensor delta,
103
+ const at::Tensor A,
104
+ const at::Tensor B,
105
+ const at::Tensor C,
106
+ const at::Tensor out,
107
+ void* D_ptr,
108
+ void* delta_bias_ptr,
109
+ void* x_ptr,
110
+ const at::Tensor dout,
111
+ const at::Tensor du,
112
+ const at::Tensor ddelta,
113
+ const at::Tensor dA,
114
+ const at::Tensor dB,
115
+ const at::Tensor dC,
116
+ void* dD_ptr,
117
+ void* ddelta_bias_ptr,
118
+ bool delta_softplus) {
119
+ // Pass in "dout" instead of "out"; we won't use "out" unless we have z.
120
+ set_ssm_params_fwd(params, batch, dim, seqlen, n_groups, n_chunks,
121
+ u, delta, A, B, C, dout,
122
+ D_ptr, delta_bias_ptr, x_ptr, delta_softplus);
123
+
124
+ // Set the pointers and strides.
125
+ params.dout_ptr = dout.data_ptr();
126
+ params.du_ptr = du.data_ptr();
127
+ params.dA_ptr = dA.data_ptr();
128
+ params.dB_ptr = dB.data_ptr();
129
+ params.dC_ptr = dC.data_ptr();
130
+ params.dD_ptr = dD_ptr;
131
+ params.ddelta_ptr = ddelta.data_ptr();
132
+ params.ddelta_bias_ptr = ddelta_bias_ptr;
133
+ // All strides are in elements, not bytes.
134
+ params.dout_batch_stride = dout.stride(0);
135
+ params.dout_d_stride = dout.stride(1);
136
+ params.dA_d_stride = dA.stride(0);
137
+ params.dB_batch_stride = dB.stride(0);
138
+ params.dB_group_stride = dB.stride(1);
139
+ params.dC_batch_stride = dC.stride(0);
140
+ params.dC_group_stride = dC.stride(1);
141
+ params.du_batch_stride = du.stride(0);
142
+ params.du_d_stride = du.stride(1);
143
+ params.ddelta_batch_stride = ddelta.stride(0);
144
+ params.ddelta_d_stride = ddelta.stride(1);
145
+
146
+ }
147
+
148
+ std::vector<at::Tensor>
149
+ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
150
+ const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
151
+ const c10::optional<at::Tensor> &D_,
152
+ const c10::optional<at::Tensor> &delta_bias_,
153
+ bool delta_softplus,
154
+ int nrows
155
+ ) {
156
+ auto input_type = u.scalar_type();
157
+ auto weight_type = A.scalar_type();
158
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
159
+ TORCH_CHECK(weight_type == at::ScalarType::Float);
160
+
161
+ TORCH_CHECK(delta.scalar_type() == input_type);
162
+ TORCH_CHECK(B.scalar_type() == input_type);
163
+ TORCH_CHECK(C.scalar_type() == input_type);
164
+
165
+ TORCH_CHECK(u.is_cuda());
166
+ TORCH_CHECK(delta.is_cuda());
167
+ TORCH_CHECK(A.is_cuda());
168
+ TORCH_CHECK(B.is_cuda());
169
+ TORCH_CHECK(C.is_cuda());
170
+
171
+ TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
172
+ TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
173
+
174
+ const auto sizes = u.sizes();
175
+ const int batch_size = sizes[0];
176
+ const int dim = sizes[1];
177
+ const int seqlen = sizes[2];
178
+ const int n_groups = B.size(1);
179
+
180
+ TORCH_CHECK(dim % n_groups == 0, "dim must be divisible by n_groups");
181
+
182
+ CHECK_SHAPE(u, batch_size, dim, seqlen);
183
+ CHECK_SHAPE(delta, batch_size, dim, seqlen);
184
+ CHECK_SHAPE(A, dim);
185
+ CHECK_SHAPE(B, batch_size, n_groups, seqlen);
186
+ TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
187
+ CHECK_SHAPE(C, batch_size, n_groups, seqlen);
188
+ TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
189
+
190
+ if (D_.has_value()) {
191
+ auto D = D_.value();
192
+ TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
193
+ TORCH_CHECK(D.is_cuda());
194
+ TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
195
+ CHECK_SHAPE(D, dim);
196
+ }
197
+
198
+ if (delta_bias_.has_value()) {
199
+ auto delta_bias = delta_bias_.value();
200
+ TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
201
+ TORCH_CHECK(delta_bias.is_cuda());
202
+ TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
203
+ CHECK_SHAPE(delta_bias, dim);
204
+ }
205
+
206
+ const int n_chunks = (seqlen + 2048 - 1) / 2048; // max is 128 * 16 = 2048 in fwd_kernel
207
+ at::Tensor out = torch::empty_like(delta);
208
+ at::Tensor x;
209
+ x = torch::empty({batch_size, dim, n_chunks, 1 * 2}, u.options().dtype(weight_type));
210
+
211
+ SSMParamsBase params;
212
+ set_ssm_params_fwd(params, batch_size, dim, seqlen, n_groups, n_chunks,
213
+ u, delta, A, B, C, out,
214
+ D_.has_value() ? D_.value().data_ptr() : nullptr,
215
+ delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
216
+ x.data_ptr(),
217
+ delta_softplus);
218
+
219
+ // Otherwise the kernel will be launched from the cuda:0 device
220
+ // Cast to char to avoid compiler warning about narrowing
221
+ at::cuda::CUDAGuard device_guard{(char)u.get_device()};
222
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
223
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
224
+ selective_scan_fwd_cuda<1, input_t, weight_t>(params, stream);
225
+ });
226
+ std::vector<at::Tensor> result = {out, x};
227
+ return result;
228
+ }
229
+
230
+ std::vector<at::Tensor>
231
+ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
232
+ const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
233
+ const c10::optional<at::Tensor> &D_,
234
+ const c10::optional<at::Tensor> &delta_bias_,
235
+ const at::Tensor &dout,
236
+ const c10::optional<at::Tensor> &x_,
237
+ bool delta_softplus,
238
+ int nrows
239
+ ) {
240
+ auto input_type = u.scalar_type();
241
+ auto weight_type = A.scalar_type();
242
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
243
+ TORCH_CHECK(weight_type == at::ScalarType::Float);
244
+
245
+ TORCH_CHECK(delta.scalar_type() == input_type);
246
+ TORCH_CHECK(B.scalar_type() == input_type);
247
+ TORCH_CHECK(C.scalar_type() == input_type);
248
+ TORCH_CHECK(dout.scalar_type() == input_type);
249
+
250
+ TORCH_CHECK(u.is_cuda());
251
+ TORCH_CHECK(delta.is_cuda());
252
+ TORCH_CHECK(A.is_cuda());
253
+ TORCH_CHECK(B.is_cuda());
254
+ TORCH_CHECK(C.is_cuda());
255
+ TORCH_CHECK(dout.is_cuda());
256
+
257
+ TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
258
+ TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
259
+ TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1);
260
+
261
+ const auto sizes = u.sizes();
262
+ const int batch_size = sizes[0];
263
+ const int dim = sizes[1];
264
+ const int seqlen = sizes[2];
265
+ const int n_groups = B.size(1);
266
+
267
+ TORCH_CHECK(dim % n_groups == 0, "dim must be divisible by n_groups");
268
+
269
+ CHECK_SHAPE(u, batch_size, dim, seqlen);
270
+ CHECK_SHAPE(delta, batch_size, dim, seqlen);
271
+ CHECK_SHAPE(A, dim);
272
+ CHECK_SHAPE(B, batch_size, n_groups, seqlen);
273
+ TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
274
+ CHECK_SHAPE(C, batch_size, n_groups, seqlen);
275
+ TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
276
+ CHECK_SHAPE(dout, batch_size, dim, seqlen);
277
+
278
+ if (D_.has_value()) {
279
+ auto D = D_.value();
280
+ TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
281
+ TORCH_CHECK(D.is_cuda());
282
+ TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
283
+ CHECK_SHAPE(D, dim);
284
+ }
285
+
286
+ if (delta_bias_.has_value()) {
287
+ auto delta_bias = delta_bias_.value();
288
+ TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
289
+ TORCH_CHECK(delta_bias.is_cuda());
290
+ TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
291
+ CHECK_SHAPE(delta_bias, dim);
292
+ }
293
+
294
+ at::Tensor out;
295
+ const int n_chunks = (seqlen + 2048 - 1) / 2048;
296
+ // const int n_chunks = (seqlen + 1024 - 1) / 1024;
297
+ if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); }
298
+ if (x_.has_value()) {
299
+ auto x = x_.value();
300
+ TORCH_CHECK(x.scalar_type() == weight_type);
301
+ TORCH_CHECK(x.is_cuda());
302
+ TORCH_CHECK(x.is_contiguous());
303
+ CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * 1);
304
+ }
305
+
306
+ at::Tensor du = torch::empty_like(u);
307
+ at::Tensor ddelta = torch::empty_like(delta);
308
+ at::Tensor dA = torch::zeros_like(A);
309
+ at::Tensor dB = torch::zeros_like(B, B.options().dtype(torch::kFloat32));
310
+ at::Tensor dC = torch::zeros_like(C, C.options().dtype(torch::kFloat32));
311
+ at::Tensor dD;
312
+ if (D_.has_value()) { dD = torch::zeros_like(D_.value()); }
313
+ at::Tensor ddelta_bias;
314
+ if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); }
315
+
316
+ SSMParamsBwd params;
317
+ set_ssm_params_bwd(params, batch_size, dim, seqlen, n_groups, n_chunks,
318
+ u, delta, A, B, C, out,
319
+ D_.has_value() ? D_.value().data_ptr() : nullptr,
320
+ delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
321
+ x_.has_value() ? x_.value().data_ptr() : nullptr,
322
+ dout, du, ddelta, dA, dB, dC,
323
+ D_.has_value() ? dD.data_ptr() : nullptr,
324
+ delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
325
+ delta_softplus);
326
+
327
+ // Otherwise the kernel will be launched from the cuda:0 device
328
+ // Cast to char to avoid compiler warning about narrowing
329
+ at::cuda::CUDAGuard device_guard{(char)u.get_device()};
330
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
331
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] {
332
+ selective_scan_bwd_cuda<1, input_t, weight_t>(params, stream);
333
+ });
334
+ std::vector<at::Tensor> result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias};
335
+ return result;
336
+ }
337
+
338
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
339
+ m.def("fwd", &selective_scan_fwd, "Selective scan forward");
340
+ m.def("bwd", &selective_scan_bwd, "Selective scan backward");
341
+ }
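For reference, a minimal usage sketch of the bindings above, assuming the extension compiled from this file is importable as selective_scan_cuda_ndstate (a hypothetical name; the actual module name depends on the repo's build setup). Shapes follow the CHECK_SHAPE calls: u/delta are (batch, dim, seqlen), A/D/delta_bias are (dim,), B/C are (batch, n_groups, seqlen):

    import torch
    import selective_scan_cuda_ndstate as ss  # hypothetical module name

    Bz, D, L, G = 2, 8, 300, 2                  # batch, dim, seqlen, n_groups (dim % n_groups == 0)
    u     = torch.randn(Bz, D, L, device="cuda")
    delta = torch.randn(Bz, D, L, device="cuda")
    A     = torch.randn(D, device="cuda")       # ndstate variant: A has shape (dim,)
    Bmat  = torch.randn(Bz, G, L, device="cuda")
    Cmat  = torch.randn(Bz, G, L, device="cuda")
    Dskip = torch.randn(D, device="cuda")
    delta_bias = torch.randn(D, device="cuda")

    # fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) -> [out, x]
    out, x = ss.fwd(u, delta, A, Bmat, Cmat, Dskip, delta_bias, True, 1)

    # bwd(u, delta, A, B, C, D, delta_bias, dout, x, delta_softplus, nrows)
    #   -> [du, ddelta, dA, dB, dC, dD, ddelta_bias]
    dout = torch.randn_like(out)
    du, ddelta, dA, dB, dC, dD, ddelta_bias = ss.bwd(
        u, delta, A, Bmat, Cmat, Dskip, delta_bias, dout, x, True, 1)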
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_ndstate.h ADDED
@@ -0,0 +1,84 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
8
+
9
+ struct SSMScanParamsBase {
10
+ using index_t = uint32_t;
11
+
12
+ int batch, seqlen, n_chunks;
13
+ index_t a_batch_stride;
14
+ index_t b_batch_stride;
15
+ index_t out_batch_stride;
16
+
17
+ // Common data pointers.
18
+ void *__restrict__ a_ptr;
19
+ void *__restrict__ b_ptr;
20
+ void *__restrict__ out_ptr;
21
+ void *__restrict__ x_ptr;
22
+ };
23
+
24
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
25
+
26
+ struct SSMParamsBase {
27
+ using index_t = uint32_t;
28
+
29
+ int batch, dim, seqlen, n_groups, n_chunks;
30
+ int dim_ngroups_ratio;
31
+
32
+ bool delta_softplus;
33
+
34
+ index_t A_d_stride;
35
+ index_t B_batch_stride;
36
+ index_t B_d_stride;
37
+ index_t B_group_stride;
38
+ index_t C_batch_stride;
39
+ index_t C_d_stride;
40
+ index_t C_group_stride;
41
+ index_t u_batch_stride;
42
+ index_t u_d_stride;
43
+ index_t delta_batch_stride;
44
+ index_t delta_d_stride;
45
+ index_t out_batch_stride;
46
+ index_t out_d_stride;
47
+
48
+ // Common data pointers.
49
+ void *__restrict__ A_ptr;
50
+ void *__restrict__ B_ptr;
51
+ void *__restrict__ C_ptr;
52
+ void *__restrict__ D_ptr;
53
+ void *__restrict__ u_ptr;
54
+ void *__restrict__ delta_ptr;
55
+ void *__restrict__ delta_bias_ptr;
56
+ void *__restrict__ out_ptr;
57
+ void *__restrict__ x_ptr;
58
+ };
59
+
60
+ struct SSMParamsBwd: public SSMParamsBase {
61
+ index_t dout_batch_stride;
62
+ index_t dout_d_stride;
63
+ index_t dA_d_stride;
64
+ index_t dB_batch_stride;
65
+ index_t dB_group_stride;
66
+ index_t dB_d_stride;
67
+ index_t dC_batch_stride;
68
+ index_t dC_group_stride;
69
+ index_t dC_d_stride;
70
+ index_t du_batch_stride;
71
+ index_t du_d_stride;
72
+ index_t ddelta_batch_stride;
73
+ index_t ddelta_d_stride;
74
+
75
+ // Common data pointers.
76
+ void *__restrict__ dout_ptr;
77
+ void *__restrict__ dA_ptr;
78
+ void *__restrict__ dB_ptr;
79
+ void *__restrict__ dC_ptr;
80
+ void *__restrict__ dD_ptr;
81
+ void *__restrict__ du_ptr;
82
+ void *__restrict__ ddelta_ptr;
83
+ void *__restrict__ ddelta_bias_ptr;
84
+ };
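All stride fields above are element strides taken directly from PyTorch tensors, so for contiguous inputs they reduce to the usual row-major products. A small sketch of that convention (illustrative, CPU-only):

    import torch

    # For a contiguous (batch, dim, seqlen) tensor, .stride() already reports elements,
    # not bytes, which is exactly what SSMParamsBase stores and the kernels index with.
    u = torch.empty(2, 8, 300)
    assert u.stride() == (8 * 300, 300, 1)
    print({"u_batch_stride": u.stride(0), "u_d_stride": u.stride(1)})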
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_bwd_kernel_nrow.cuh ADDED
@@ -0,0 +1,344 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <c10/util/BFloat16.h>
8
+ #include <c10/util/Half.h>
9
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
10
+ #include <ATen/cuda/Atomic.cuh> // For atomicAdd on complex
11
+
12
+ #include <cub/block/block_load.cuh>
13
+ #include <cub/block/block_store.cuh>
14
+ #include <cub/block/block_scan.cuh>
15
+ #include <cub/block/block_reduce.cuh>
16
+
17
+ #include "selective_scan.h"
18
+ #include "selective_scan_common.h"
19
+ #include "reverse_scan.cuh"
20
+ #include "static_switch.h"
21
+
22
+ template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_, bool kDeltaSoftplus_, typename input_t_, typename weight_t_>
23
+ struct Selective_Scan_bwd_kernel_traits {
24
+ static_assert(kNItems_ % 4 == 0);
25
+ using input_t = input_t_;
26
+ using weight_t = weight_t_;
27
+ static constexpr int kNThreads = kNThreads_;
28
+ static constexpr int kNItems = kNItems_;
29
+ static constexpr int kNRows = kNRows_;
30
+ static constexpr int MaxDState = MAX_DSTATE / kNRows_;
31
+ static constexpr int kNBytes = sizeof(input_t);
32
+ static_assert(kNBytes == 2 || kNBytes == 4);
33
+ static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
34
+ static_assert(kNItems % kNElts == 0);
35
+ static constexpr int kNLoads = kNItems / kNElts;
36
+ static constexpr bool kIsEvenLen = kIsEvenLen_;
37
+ static constexpr bool kDeltaSoftplus = kDeltaSoftplus_;
38
+ // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy.
39
+ // For complex this would lead to massive register spilling, so we keep it at 2.
40
+ static constexpr int kMinBlocks = kNThreads == 128 ? 3 : 2;
41
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
42
+ using scan_t = float2;
43
+ using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
44
+ using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
45
+ using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
46
+ using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
47
+ using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
48
+ using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads, cub::BLOCK_STORE_WARP_TRANSPOSE>;
49
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
50
+ using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
51
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
52
+ using BlockReverseScanT = BlockReverseScan<scan_t, kNThreads>;
53
+ using BlockReduceT = cub::BlockReduce<scan_t, kNThreads>;
54
+ using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
55
+ using BlockExchangeT = cub::BlockExchange<float, kNThreads, kNItems>;
56
+ static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
57
+ sizeof(typename BlockLoadVecT::TempStorage),
58
+ 2 * sizeof(typename BlockLoadWeightT::TempStorage),
59
+ 2 * sizeof(typename BlockLoadWeightVecT::TempStorage),
60
+ sizeof(typename BlockStoreT::TempStorage),
61
+ sizeof(typename BlockStoreVecT::TempStorage)});
62
+ static constexpr int kSmemExchangeSize = 2 * sizeof(typename BlockExchangeT::TempStorage);
63
+ static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage);
64
+ static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage);
65
+ };
66
+
67
+ template<typename Ktraits>
68
+ __global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
69
+ void selective_scan_bwd_kernel(SSMParamsBwd params) {
70
+ constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus;
71
+ constexpr int kNThreads = Ktraits::kNThreads;
72
+ constexpr int kNItems = Ktraits::kNItems;
73
+ constexpr int kNRows = Ktraits::kNRows;
74
+ using input_t = typename Ktraits::input_t;
75
+ using weight_t = typename Ktraits::weight_t;
76
+ using scan_t = typename Ktraits::scan_t;
77
+
78
+ // Shared memory.
79
+ extern __shared__ char smem_[];
80
+ auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
81
+ auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
82
+ auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
83
+ auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
84
+ auto& smem_exchange = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
85
+ auto& smem_exchange1 = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage));
86
+ auto& smem_reduce = *reinterpret_cast<typename Ktraits::BlockReduceT::TempStorage*>(reinterpret_cast<char *>(&smem_exchange) + Ktraits::kSmemExchangeSize);
87
+ auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(&smem_reduce);
88
+ auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(reinterpret_cast<char *>(&smem_reduce) + Ktraits::kSmemReduceSize);
89
+ auto& smem_reverse_scan = *reinterpret_cast<typename Ktraits::BlockReverseScanT::TempStorage*>(reinterpret_cast<char *>(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage));
90
+ weight_t *smem_delta_a = reinterpret_cast<weight_t *>(smem_ + Ktraits::kSmemSize);
91
+ // scan_t *smem_running_postfix = reinterpret_cast<scan_t *>(smem_delta_a + kNRows * (2 * Ktraits::MaxDState + kNThreads));
92
+ scan_t *smem_running_postfix = reinterpret_cast<scan_t *>(smem_delta_a + kNRows * 2 * Ktraits::MaxDState + kNThreads);
93
+ weight_t *smem_da = reinterpret_cast<weight_t *>(smem_running_postfix + kNRows * Ktraits::MaxDState);
94
+
95
+ const int batch_id = blockIdx.x;
96
+ const int dim_id = blockIdx.y;
97
+ const int dim_id_nrow = dim_id * kNRows;
98
+ const int group_id = dim_id_nrow / (params.dim_ngroups_ratio);
99
+ input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
100
+ + dim_id_nrow * params.u_d_stride;
101
+ input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
102
+ + dim_id_nrow * params.delta_d_stride;
103
+ input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
104
+ + dim_id_nrow * params.dout_d_stride;
105
+ weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id_nrow * params.A_d_stride;
106
+ input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
107
+ input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
108
+ weight_t *dA = reinterpret_cast<weight_t *>(params.dA_ptr) + dim_id_nrow * params.dA_d_stride;
109
+ weight_t *dB = reinterpret_cast<weight_t *>(params.dB_ptr)
110
+ + (batch_id * params.dB_batch_stride + group_id * params.dB_group_stride);
111
+ weight_t *dC = reinterpret_cast<weight_t *>(params.dC_ptr)
112
+ + (batch_id * params.dC_batch_stride + group_id * params.dC_group_stride);
113
+ float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.dD_ptr) + dim_id_nrow;
114
+ float *D_val = params.D_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.D_ptr) + dim_id_nrow;
115
+ float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.ddelta_bias_ptr) + dim_id_nrow;
116
+ float *delta_bias = params.delta_bias_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.delta_bias_ptr) + dim_id_nrow;
117
+ scan_t *x = params.x_ptr == nullptr
118
+ ? nullptr
119
+ : reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id_nrow) * (params.n_chunks) * params.dstate;
120
+ float dD_val[kNRows] = {0};
121
+ float ddelta_bias_val[kNRows] = {0};
122
+
123
+ constexpr int kChunkSize = kNThreads * kNItems;
124
+ u += (params.n_chunks - 1) * kChunkSize;
125
+ delta += (params.n_chunks - 1) * kChunkSize;
126
+ dout += (params.n_chunks - 1) * kChunkSize;
127
+ Bvar += (params.n_chunks - 1) * kChunkSize;
128
+ Cvar += (params.n_chunks - 1) * kChunkSize;
129
+ for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) {
130
+ input_t u_vals[kNRows][kNItems];
131
+ input_t delta_vals_load[kNRows][kNItems];
132
+ input_t dout_vals_load[kNRows][kNItems];
133
+ #pragma unroll
134
+ for (int r = 0; r < kNRows; ++r) {
135
+ __syncthreads();
136
+ load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize);
137
+ __syncthreads();
138
+ load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
139
+ __syncthreads();
140
+ load_input<Ktraits>(dout + r * params.dout_d_stride, dout_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
141
+ }
142
+ u -= kChunkSize;
143
+ // Will reload delta at the same location if kDeltaSoftplus
144
+ if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; }
145
+ dout -= kChunkSize;
146
+
147
+ float dout_vals[kNRows][kNItems], delta_vals[kNRows][kNItems];
148
+ float du_vals[kNRows][kNItems];
149
+ #pragma unroll
150
+ for (int r = 0; r < kNRows; ++r) {
151
+ #pragma unroll
152
+ for (int i = 0; i < kNItems; ++i) {
153
+ dout_vals[r][i] = float(dout_vals_load[r][i]);
154
+ delta_vals[r][i] = float(delta_vals_load[r][i]) + (delta_bias == nullptr? 0: delta_bias[r]);
155
+ if constexpr (kDeltaSoftplus) {
156
+ delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];
157
+ }
158
+ }
159
+ #pragma unroll
160
+ for (int i = 0; i < kNItems; ++i) { du_vals[r][i] = (D_val == nullptr? 0: D_val[r]) * dout_vals[r][i]; }
161
+ #pragma unroll
162
+ for (int i = 0; i < kNItems; ++i) { dD_val[r] += dout_vals[r][i] * float(u_vals[r][i]); }
163
+ }
164
+
165
+ float ddelta_vals[kNRows][kNItems] = {0};
166
+ __syncthreads();
167
+ for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
168
+ weight_t A_val[kNRows];
169
+ weight_t A_scaled[kNRows];
170
+ #pragma unroll
171
+ for (int r = 0; r < kNRows; ++r) {
172
+ A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
173
+ constexpr float kLog2e = M_LOG2E;
174
+ A_scaled[r] = A_val[r] * kLog2e;
175
+ }
176
+ weight_t B_vals[kNItems], C_vals[kNItems];
177
+ load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
178
+ smem_load_weight, (params.seqlen - chunk * kChunkSize));
179
+ auto &smem_load_weight_C = smem_load_weight1;
180
+ load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
181
+ smem_load_weight_C, (params.seqlen - chunk * kChunkSize));
182
+ #pragma unroll
183
+ for (int r = 0; r < kNRows; ++r) {
184
+ scan_t thread_data[kNItems], thread_reverse_data[kNItems];
185
+ #pragma unroll
186
+ for (int i = 0; i < kNItems; ++i) {
187
+ const float delta_a_exp = exp2f(delta_vals[r][i] * A_scaled[r]);
188
+ thread_data[i] = make_float2(delta_a_exp, delta_vals[r][i] * float(u_vals[r][i]) * B_vals[i]);
189
+ if (i == 0) {
190
+ // smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * Ktraits::MaxDState + r * (2 * Ktraits::MaxDState + kNThreads) : threadIdx.x + 2 * Ktraits::MaxDState + r * (2 * Ktraits::MaxDState + kNThreads)] = delta_a_exp;
191
+ smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState : threadIdx.x + kNRows * 2 * Ktraits::MaxDState] = delta_a_exp;
192
+
193
+ } else {
194
+ thread_reverse_data[i - 1].x = delta_a_exp;
195
+ }
196
+ thread_reverse_data[i].y = dout_vals[r][i] * C_vals[i];
197
+ }
198
+ __syncthreads();
199
+ thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1
200
+ // ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState + r * (2 * Ktraits::MaxDState + kNThreads)])
201
+ // : smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState + r * (2 * Ktraits::MaxDState + kNThreads)];
202
+ ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState])
203
+ : smem_delta_a[threadIdx.x + 1 + kNRows * 2 * Ktraits::MaxDState];
204
+ // Initialize running total
205
+ scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(r * params.n_chunks + chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f);
206
+ SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
207
+ Ktraits::BlockScanT(smem_scan).InclusiveScan(
208
+ thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
209
+ );
210
+ scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx + r * Ktraits::MaxDState] : make_float2(1.f, 0.f);
211
+ SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
212
+ Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
213
+ thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
214
+ );
215
+ if (threadIdx.x == 0) { smem_running_postfix[state_idx + r * Ktraits::MaxDState] = postfix_op.running_prefix; }
216
+ weight_t dA_val = 0;
217
+ weight_t dB_vals[kNItems], dC_vals[kNItems];
218
+ #pragma unroll
219
+ for (int i = 0; i < kNItems; ++i) {
220
+ const float dx = thread_reverse_data[i].y;
221
+ const float ddelta_u = dx * B_vals[i];
222
+ du_vals[r][i] += ddelta_u * delta_vals[r][i];
223
+ const float a = thread_data[i].y - (delta_vals[r][i] * float(u_vals[r][i]) * B_vals[i]);
224
+ ddelta_vals[r][i] += ddelta_u * float(u_vals[r][i]) + dx * A_val[r] * a;
225
+ dA_val += dx * delta_vals[r][i] * a;
226
+ dB_vals[i] = dx * delta_vals[r][i] * float(u_vals[r][i]);
227
+ dC_vals[i] = dout_vals[r][i] * thread_data[i].y;
228
+ }
229
+ // Block-exchange to make the atomicAdds coalesced; otherwise they're much slower
230
+ Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
231
+ auto &smem_exchange_C = smem_exchange1;
232
+ Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals);
233
+ const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x;
234
+ weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x;
235
+ weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x;
236
+ #pragma unroll
237
+ for (int i = 0; i < kNItems; ++i) {
238
+ if (i * kNThreads < seqlen_remaining) {
239
+ { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); }
240
+ { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); }
241
+ }
242
+ }
243
+ dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val);
244
+ if (threadIdx.x == 0) {
245
+ smem_da[state_idx + r * Ktraits::MaxDState] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx + r * Ktraits::MaxDState];
246
+ }
247
+ }
248
+ }
249
+
250
+ if constexpr (kDeltaSoftplus) {
251
+ input_t delta_vals_load[kNRows][kNItems];
252
+ #pragma unroll
253
+ for (int r = 0; r < kNRows; ++r) {
254
+ __syncthreads();
255
+ load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
256
+ }
257
+ delta -= kChunkSize;
258
+ #pragma unroll
259
+ for (int r = 0; r < kNRows; ++r) {
260
+ #pragma unroll
261
+ for (int i = 0; i < kNItems; ++i) {
262
+ float delta_val = float(delta_vals_load[r][i]) + (delta_bias == nullptr? 0: delta_bias[r]);
263
+ float delta_val_neg_exp = expf(-delta_val);
264
+ ddelta_vals[r][i] = delta_val <= 20.f
265
+ ? ddelta_vals[r][i] / (1.f + delta_val_neg_exp)
266
+ : ddelta_vals[r][i];
267
+ }
268
+ }
269
+ }
270
+
271
+ __syncthreads();
272
+ #pragma unroll
273
+ for (int r = 0; r < kNRows; ++r) {
274
+ #pragma unroll
275
+ for (int i = 0; i < kNItems; ++i) { ddelta_bias_val[r] += ddelta_vals[r][i]; }
276
+ }
277
+
278
+ input_t *du = reinterpret_cast<input_t *>(params.du_ptr) + batch_id * params.du_batch_stride
279
+ + dim_id_nrow * params.du_d_stride + chunk * kChunkSize;
280
+ input_t *ddelta = reinterpret_cast<input_t *>(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride
281
+ + dim_id_nrow * params.ddelta_d_stride + chunk * kChunkSize;
282
+ #pragma unroll
283
+ for (int r = 0; r < kNRows; ++r) {
284
+ __syncthreads();
285
+ store_output<Ktraits>(du + r * params.du_d_stride, du_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
286
+ __syncthreads();
287
+ store_output<Ktraits>(ddelta + r * params.ddelta_d_stride, ddelta_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
288
+ }
289
+
290
+ Bvar -= kChunkSize;
291
+ Cvar -= kChunkSize;
292
+ }
293
+
294
+ #pragma unroll
295
+ for (int r = 0; r < kNRows; ++r) {
296
+ if (params.dD_ptr != nullptr) {
297
+ __syncthreads();
298
+ dD_val[r] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val[r]);
299
+ if (threadIdx.x == 0) { gpuAtomicAdd(&(dD[r]), dD_val[r]); }
300
+ }
301
+ if (params.ddelta_bias_ptr != nullptr) {
302
+ __syncthreads();
303
+ ddelta_bias_val[r] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val[r]);
304
+ if (threadIdx.x == 0) { gpuAtomicAdd(&(ddelta_bias[r]), ddelta_bias_val[r]); }
305
+ }
306
+ __syncthreads();
307
+ for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
308
+ gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride + r * params.dA_d_stride]), smem_da[state_idx + r * Ktraits::MaxDState]);
309
+ }
310
+ }
311
+ }
312
+
313
+ template<int kNThreads, int kNItems, int kNRows, typename input_t, typename weight_t>
314
+ void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
315
+ BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
316
+ BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] {
317
+ using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kDeltaSoftplus, input_t, weight_t>;
318
+ constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * Ktraits::MaxDState * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * kNRows * Ktraits::MaxDState) * sizeof(typename Ktraits::weight_t);
319
+ // printf("smem_size = %d\n", kSmemSize);
320
+ dim3 grid(params.batch, params.dim / kNRows);
321
+ auto kernel = &selective_scan_bwd_kernel<Ktraits>;
322
+ if (kSmemSize >= 48 * 1024) {
323
+ C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
324
+ }
325
+ kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
326
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
327
+ });
328
+ });
329
+ }
330
+
331
+ template<int knrows, typename input_t, typename weight_t>
332
+ void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
333
+ if (params.seqlen <= 128) {
334
+ selective_scan_bwd_launch<32, 4, knrows, input_t, weight_t>(params, stream);
335
+ } else if (params.seqlen <= 256) {
336
+ selective_scan_bwd_launch<32, 8, knrows, input_t, weight_t>(params, stream);
337
+ } else if (params.seqlen <= 512) {
338
+ selective_scan_bwd_launch<32, 16, knrows, input_t, weight_t>(params, stream);
339
+ } else if (params.seqlen <= 1024) {
340
+ selective_scan_bwd_launch<64, 16, knrows, input_t, weight_t>(params, stream);
341
+ } else {
342
+ selective_scan_bwd_launch<128, 16, knrows, input_t, weight_t>(params, stream);
343
+ }
344
+ }
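The backward launch above asks for Ktraits::kSmemSize plus room for the running postfix and the per-state weight_t scratch (smem_delta_a, smem_da), and only opts in via cudaFuncSetAttribute once the total crosses the default 48 KB dynamic shared memory limit. A rough Python sketch of the extra bytes requested beyond the CUB temp storage (illustrative; assumes sizeof(float2) = 8 and sizeof(float) = 4):

    def bwd_extra_smem_bytes(n_threads: int, n_rows: int, max_dstate_total: int = 256):
        # Mirrors: kNRows * MaxDState * sizeof(scan_t)
        #        + (kNThreads + 4 * kNRows * MaxDState) * sizeof(weight_t),
        # with MaxDState = MAX_DSTATE / kNRows, so kNRows * MaxDState == MAX_DSTATE.
        max_dstate = max_dstate_total // n_rows
        scan_bytes = n_rows * max_dstate * 8                      # running postfix (one float2 per state)
        weight_bytes = (n_threads + 4 * n_rows * max_dstate) * 4  # smem_delta_a + smem_da scratch
        return scan_bytes + weight_bytes

    print(bwd_extra_smem_bytes(128, 1))  # 2048 + (128 + 1024) * 4 = 6656 bytes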
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd.cu ADDED
@@ -0,0 +1,9 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+ #include "selective_scan_bwd_kernel_nrow.cuh"
5
+
6
+ template void selective_scan_bwd_cuda<1, float, float>(SSMParamsBwd &params, cudaStream_t stream);
7
+ template void selective_scan_bwd_cuda<1, at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
8
+ template void selective_scan_bwd_cuda<1, at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);
9
+
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd2.cu ADDED
@@ -0,0 +1,9 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+ #include "selective_scan_bwd_kernel_nrow.cuh"
5
+
6
+ template void selective_scan_bwd_cuda<2, float, float>(SSMParamsBwd &params, cudaStream_t stream);
7
+ template void selective_scan_bwd_cuda<2, at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
8
+ template void selective_scan_bwd_cuda<2, at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);
9
+
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd3.cu ADDED
@@ -0,0 +1,8 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+ #include "selective_scan_bwd_kernel_nrow.cuh"
5
+
6
+ template void selective_scan_bwd_cuda<3, float, float>(SSMParamsBwd &params, cudaStream_t stream);
7
+ template void selective_scan_bwd_cuda<3, at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
8
+ template void selective_scan_bwd_cuda<3, at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd4.cu ADDED
@@ -0,0 +1,8 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+ #include "selective_scan_bwd_kernel_nrow.cuh"
5
+
6
+ template void selective_scan_bwd_cuda<4, float, float>(SSMParamsBwd &params, cudaStream_t stream);
7
+ template void selective_scan_bwd_cuda<4, at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
8
+ template void selective_scan_bwd_cuda<4, at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd.cu ADDED
@@ -0,0 +1,9 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+ #include "selective_scan_fwd_kernel_nrow.cuh"
5
+
6
+ template void selective_scan_fwd_cuda<1, float, float>(SSMParamsBase &params, cudaStream_t stream);
7
+ template void selective_scan_fwd_cuda<1, at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
8
+ template void selective_scan_fwd_cuda<1, at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
9
+
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd2.cu ADDED
@@ -0,0 +1,9 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+ #include "selective_scan_fwd_kernel_nrow.cuh"
5
+
6
+ template void selective_scan_fwd_cuda<2, float, float>(SSMParamsBase &params, cudaStream_t stream);
7
+ template void selective_scan_fwd_cuda<2, at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
8
+ template void selective_scan_fwd_cuda<2, at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
9
+
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd3.cu ADDED
@@ -0,0 +1,9 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+ #include "selective_scan_fwd_kernel_nrow.cuh"
5
+
6
+ template void selective_scan_fwd_cuda<3, float, float>(SSMParamsBase &params, cudaStream_t stream);
7
+ template void selective_scan_fwd_cuda<3, at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
8
+ template void selective_scan_fwd_cuda<3, at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
9
+
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd4.cu ADDED
@@ -0,0 +1,9 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+ #include "selective_scan_fwd_kernel_nrow.cuh"
5
+
6
+ template void selective_scan_fwd_cuda<4, float, float>(SSMParamsBase &params, cudaStream_t stream);
7
+ template void selective_scan_fwd_cuda<4, at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
8
+ template void selective_scan_fwd_cuda<4, at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
9
+
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_fwd_kernel_nrow.cuh ADDED
@@ -0,0 +1,238 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <c10/util/BFloat16.h>
8
+ #include <c10/util/Half.h>
9
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
10
+
11
+ #include <cub/block/block_load.cuh>
12
+ #include <cub/block/block_store.cuh>
13
+ #include <cub/block/block_scan.cuh>
14
+
15
+ #include "selective_scan.h"
16
+ #include "selective_scan_common.h"
17
+ #include "static_switch.h"
18
+
19
+ template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_, typename input_t_, typename weight_t_>
20
+ struct Selective_Scan_fwd_kernel_traits {
21
+ static_assert(kNItems_ % 4 == 0);
22
+ using input_t = input_t_;
23
+ using weight_t = weight_t_;
24
+ static constexpr int kNThreads = kNThreads_;
25
+ // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
26
+ static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
27
+ static constexpr int kNItems = kNItems_;
28
+ static constexpr int kNRows = kNRows_;
29
+ static constexpr int MaxDState = MAX_DSTATE / kNRows;
30
+ static constexpr int kNBytes = sizeof(input_t);
31
+ static_assert(kNBytes == 2 || kNBytes == 4);
32
+ static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
33
+ static_assert(kNItems % kNElts == 0);
34
+ static constexpr int kNLoads = kNItems / kNElts;
35
+ static constexpr bool kIsEvenLen = kIsEvenLen_;
36
+
37
+ static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
38
+
39
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
40
+ using scan_t = float2;
41
+ using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
42
+ using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
43
+ !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
44
+ using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
45
+ using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
46
+ !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
47
+ using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
48
+ using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
49
+ !kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
50
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
51
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
52
+ using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
53
+ static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
54
+ sizeof(typename BlockLoadVecT::TempStorage),
55
+ 2 * sizeof(typename BlockLoadWeightT::TempStorage),
56
+ 2 * sizeof(typename BlockLoadWeightVecT::TempStorage),
57
+ sizeof(typename BlockStoreT::TempStorage),
58
+ sizeof(typename BlockStoreVecT::TempStorage)});
59
+ static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
60
+ };
61
+
62
+ template<typename Ktraits>
63
+ __global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
64
+ void selective_scan_fwd_kernel(SSMParamsBase params) {
65
+ constexpr int kNThreads = Ktraits::kNThreads;
66
+ constexpr int kNItems = Ktraits::kNItems;
67
+ constexpr int kNRows = Ktraits::kNRows;
68
+ constexpr bool kDirectIO = Ktraits::kDirectIO;
69
+ using input_t = typename Ktraits::input_t;
70
+ using weight_t = typename Ktraits::weight_t;
71
+ using scan_t = typename Ktraits::scan_t;
72
+
73
+ // Shared memory.
74
+ extern __shared__ char smem_[];
75
+ auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
76
+ auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
77
+ auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
78
+ auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
79
+ auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
80
+ scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);
81
+
82
+ const int batch_id = blockIdx.x;
83
+ const int dim_id = blockIdx.y;
84
+ const int dim_id_nrow = dim_id * kNRows;
85
+ const int group_id = dim_id_nrow / (params.dim_ngroups_ratio);
86
+ input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
87
+ + dim_id_nrow * params.u_d_stride;
88
+ input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
89
+ + dim_id_nrow * params.delta_d_stride;
90
+ weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id_nrow * params.A_d_stride;
91
+ input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
92
+ input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
93
+ scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id_nrow) * params.n_chunks * params.dstate;
94
+
95
+ float D_val[kNRows] = {0};
96
+ if (params.D_ptr != nullptr) {
97
+ #pragma unroll
98
+ for (int r = 0; r < kNRows; ++r) {
99
+ D_val[r] = reinterpret_cast<float *>(params.D_ptr)[dim_id_nrow + r];
100
+ }
101
+ }
102
+ float delta_bias[kNRows] = {0};
103
+ if (params.delta_bias_ptr != nullptr) {
104
+ #pragma unroll
105
+ for (int r = 0; r < kNRows; ++r) {
106
+ delta_bias[r] = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id_nrow + r];
107
+ }
108
+ }
109
+
110
+ constexpr int kChunkSize = kNThreads * kNItems;
111
+ for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
112
+ input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
113
+ __syncthreads();
114
+ #pragma unroll
115
+ for (int r = 0; r < kNRows; ++r) {
116
+ if constexpr (!kDirectIO) {
117
+ if (r > 0) { __syncthreads(); }
118
+ }
119
+ load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize);
120
+ if constexpr (!kDirectIO) { __syncthreads(); }
121
+ load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
122
+ }
123
+ u += kChunkSize;
124
+ delta += kChunkSize;
125
+
126
+ float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
127
+ #pragma unroll
128
+ for (int r = 0; r < kNRows; ++r) {
129
+ #pragma unroll
130
+ for (int i = 0; i < kNItems; ++i) {
131
+ float u_val = float(u_vals[r][i]);
132
+ delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];
133
+ if (params.delta_softplus) {
134
+ delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];
135
+ }
136
+ delta_u_vals[r][i] = delta_vals[r][i] * u_val;
137
+ out_vals[r][i] = D_val[r] * u_val;
138
+ }
139
+ }
140
+
141
+ __syncthreads();
142
+ for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
143
+ weight_t A_val[kNRows];
144
+ #pragma unroll
145
+ for (int r = 0; r < kNRows; ++r) {
146
+ A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
147
+ // Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
148
+ constexpr float kLog2e = M_LOG2E;
149
+ A_val[r] *= kLog2e;
150
+ }
151
+ weight_t B_vals[kNItems], C_vals[kNItems];
152
+ load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
153
+ smem_load_weight, (params.seqlen - chunk * kChunkSize));
154
+ auto &smem_load_weight_C = smem_load_weight1;
155
+ load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
156
+ smem_load_weight_C, (params.seqlen - chunk * kChunkSize));
157
+ #pragma unroll
158
+ for (int r = 0; r < kNRows; ++r) {
159
+ if (r > 0) { __syncthreads(); } // Scan could be using the same smem
160
+ scan_t thread_data[kNItems];
161
+ #pragma unroll
162
+ for (int i = 0; i < kNItems; ++i) {
163
+ thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
164
+ B_vals[i] * delta_u_vals[r][i]);
165
+ if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
166
+ if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
167
+ thread_data[i] = make_float2(1.f, 0.f);
168
+ }
169
+ }
170
+ }
171
+ // Initialize running total
172
+ scan_t running_prefix;
173
+ // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read
174
+ running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * Ktraits::MaxDState] : make_float2(1.f, 0.f);
175
+ // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
176
+ SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
177
+ Ktraits::BlockScanT(smem_scan).InclusiveScan(
178
+ thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
179
+ );
180
+ // There's a syncthreads in the scan op, so we don't need to sync here.
181
+ // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
182
+ if (threadIdx.x == 0) {
183
+ smem_running_prefix[state_idx + r * Ktraits::MaxDState] = prefix_op.running_prefix;
184
+ x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix;
185
+ }
186
+ #pragma unroll
187
+ for (int i = 0; i < kNItems; ++i) {
188
+ out_vals[r][i] += thread_data[i].y * C_vals[i];
189
+ }
190
+ }
191
+ }
192
+
193
+ input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
194
+ + dim_id_nrow * params.out_d_stride + chunk * kChunkSize;
195
+ __syncthreads();
196
+ #pragma unroll
197
+ for (int r = 0; r < kNRows; ++r) {
198
+ if constexpr (!kDirectIO) {
199
+ if (r > 0) { __syncthreads(); }
200
+ }
201
+ store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
202
+ }
203
+
204
+ Bvar += kChunkSize;
205
+ Cvar += kChunkSize;
206
+ }
207
+ }
208
+
209
+ template<int kNThreads, int kNItems, int kNRows, typename input_t, typename weight_t>
210
+ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
211
+ BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
212
+ using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, input_t, weight_t>;
213
+ constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * Ktraits::MaxDState * sizeof(typename Ktraits::scan_t);
214
+ // printf("smem_size = %d\n", kSmemSize);
215
+ dim3 grid(params.batch, params.dim / kNRows);
216
+ auto kernel = &selective_scan_fwd_kernel<Ktraits>;
217
+ if (kSmemSize >= 48 * 1024) {
218
+ C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
219
+ }
220
+ kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
221
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
222
+ });
223
+ }
224
+
225
+ template<int knrows, typename input_t, typename weight_t>
226
+ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
227
+ if (params.seqlen <= 128) {
228
+ selective_scan_fwd_launch<32, 4, knrows, input_t, weight_t>(params, stream);
229
+ } else if (params.seqlen <= 256) {
230
+ selective_scan_fwd_launch<32, 8, knrows, input_t, weight_t>(params, stream);
231
+ } else if (params.seqlen <= 512) {
232
+ selective_scan_fwd_launch<32, 16, knrows, input_t, weight_t>(params, stream);
233
+ } else if (params.seqlen <= 1024) {
234
+ selective_scan_fwd_launch<64, 16, knrows, input_t, weight_t>(params, stream);
235
+ } else {
236
+ selective_scan_fwd_launch<128, 16, knrows, input_t, weight_t>(params, stream);
237
+ }
238
+ }
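With kNRows > 1 each block handles kNRows adjacent channels, so the grid shrinks to (batch, dim / kNRows) while MaxDState = MAX_DSTATE / kNRows keeps the per-block state budget constant. A small Python sketch of that bookkeeping (illustrative only):

    def nrow_grid(batch: int, dim: int, n_rows: int, max_dstate_total: int = 256):
        # dim3 grid(params.batch, params.dim / kNRows); MaxDState = MAX_DSTATE / kNRows
        assert dim % n_rows == 0, "dim must be divisible by kNRows"
        return (batch, dim // n_rows), max_dstate_total // n_rows

    print(nrow_grid(2, 96, 4))  # -> ((2, 24), 64)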
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_nrow.cpp ADDED
@@ -0,0 +1,367 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include <ATen/cuda/CUDAContext.h>
6
+ #include <c10/cuda/CUDAGuard.h>
7
+ #include <torch/extension.h>
8
+ #include <vector>
9
+
10
+ #include "selective_scan.h"
11
+ #define MAX_DSTATE 256
12
+
13
+ #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
14
+ using weight_t = float;
15
+
16
+ #define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
17
+ if (ITYPE == at::ScalarType::Half) { \
18
+ using input_t = at::Half; \
19
+ __VA_ARGS__(); \
20
+ } else if (ITYPE == at::ScalarType::BFloat16) { \
21
+ using input_t = at::BFloat16; \
22
+ __VA_ARGS__(); \
23
+ } else if (ITYPE == at::ScalarType::Float) { \
24
+ using input_t = float; \
25
+ __VA_ARGS__(); \
26
+ } else { \
27
+ AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
28
+ }
29
+
30
+ #define INT_SWITCH(INT, NAME, ...) [&] { \
31
+ if (INT == 2) {constexpr int NAME = 2; __VA_ARGS__(); } \
32
+ else if (INT == 3) {constexpr int NAME = 3; __VA_ARGS__(); } \
33
+ else if (INT == 4) {constexpr int NAME = 4; __VA_ARGS__(); } \
34
+ else {constexpr int NAME = 1; __VA_ARGS__(); } \
35
+ }() \
36
+
37
+
38
+ template<int knrows, typename input_t, typename weight_t>
39
+ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);
40
+
41
+ template <int knrows, typename input_t, typename weight_t>
42
+ void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream);
43
+
44
+ void set_ssm_params_fwd(SSMParamsBase &params,
45
+ // sizes
46
+ const size_t batch,
47
+ const size_t dim,
48
+ const size_t seqlen,
49
+ const size_t dstate,
50
+ const size_t n_groups,
51
+ const size_t n_chunks,
52
+ // device pointers
53
+ const at::Tensor u,
54
+ const at::Tensor delta,
55
+ const at::Tensor A,
56
+ const at::Tensor B,
57
+ const at::Tensor C,
58
+ const at::Tensor out,
59
+ void* D_ptr,
60
+ void* delta_bias_ptr,
61
+ void* x_ptr,
62
+ bool delta_softplus) {
63
+
64
+ // Reset the parameters
65
+ memset(&params, 0, sizeof(params));
66
+
67
+ params.batch = batch;
68
+ params.dim = dim;
69
+ params.seqlen = seqlen;
70
+ params.dstate = dstate;
71
+ params.n_groups = n_groups;
72
+ params.n_chunks = n_chunks;
73
+ params.dim_ngroups_ratio = dim / n_groups;
74
+
75
+ params.delta_softplus = delta_softplus;
76
+
77
+ // Set the pointers and strides.
78
+ params.u_ptr = u.data_ptr();
79
+ params.delta_ptr = delta.data_ptr();
80
+ params.A_ptr = A.data_ptr();
81
+ params.B_ptr = B.data_ptr();
82
+ params.C_ptr = C.data_ptr();
83
+ params.D_ptr = D_ptr;
84
+ params.delta_bias_ptr = delta_bias_ptr;
85
+ params.out_ptr = out.data_ptr();
86
+ params.x_ptr = x_ptr;
87
+
88
+ // All strides are in elements, not bytes.
89
+ params.A_d_stride = A.stride(0);
90
+ params.A_dstate_stride = A.stride(1);
91
+ params.B_batch_stride = B.stride(0);
92
+ params.B_group_stride = B.stride(1);
93
+ params.B_dstate_stride = B.stride(2);
94
+ params.C_batch_stride = C.stride(0);
95
+ params.C_group_stride = C.stride(1);
96
+ params.C_dstate_stride = C.stride(2);
97
+ params.u_batch_stride = u.stride(0);
98
+ params.u_d_stride = u.stride(1);
99
+ params.delta_batch_stride = delta.stride(0);
100
+ params.delta_d_stride = delta.stride(1);
101
+
102
+ params.out_batch_stride = out.stride(0);
103
+ params.out_d_stride = out.stride(1);
104
+ }
105
+
106
+ void set_ssm_params_bwd(SSMParamsBwd &params,
107
+ // sizes
108
+ const size_t batch,
109
+ const size_t dim,
110
+ const size_t seqlen,
111
+ const size_t dstate,
112
+ const size_t n_groups,
113
+ const size_t n_chunks,
114
+ // device pointers
115
+ const at::Tensor u,
116
+ const at::Tensor delta,
117
+ const at::Tensor A,
118
+ const at::Tensor B,
119
+ const at::Tensor C,
120
+ const at::Tensor out,
121
+ void* D_ptr,
122
+ void* delta_bias_ptr,
123
+ void* x_ptr,
124
+ const at::Tensor dout,
125
+ const at::Tensor du,
126
+ const at::Tensor ddelta,
127
+ const at::Tensor dA,
128
+ const at::Tensor dB,
129
+ const at::Tensor dC,
130
+ void* dD_ptr,
131
+ void* ddelta_bias_ptr,
132
+ bool delta_softplus) {
133
+ // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z
134
+ set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks,
135
+ u, delta, A, B, C, dout,
136
+ D_ptr, delta_bias_ptr, x_ptr, delta_softplus);
137
+
138
+ // Set the pointers and strides.
139
+ params.dout_ptr = dout.data_ptr();
140
+ params.du_ptr = du.data_ptr();
141
+ params.dA_ptr = dA.data_ptr();
142
+ params.dB_ptr = dB.data_ptr();
143
+ params.dC_ptr = dC.data_ptr();
144
+ params.dD_ptr = dD_ptr;
145
+ params.ddelta_ptr = ddelta.data_ptr();
146
+ params.ddelta_bias_ptr = ddelta_bias_ptr;
147
+ // All strides are in elements, not bytes.
148
+ params.dout_batch_stride = dout.stride(0);
149
+ params.dout_d_stride = dout.stride(1);
150
+ params.dA_d_stride = dA.stride(0);
151
+ params.dA_dstate_stride = dA.stride(1);
152
+ params.dB_batch_stride = dB.stride(0);
153
+ params.dB_group_stride = dB.stride(1);
154
+ params.dB_dstate_stride = dB.stride(2);
155
+ params.dC_batch_stride = dC.stride(0);
156
+ params.dC_group_stride = dC.stride(1);
157
+ params.dC_dstate_stride = dC.stride(2);
158
+ params.du_batch_stride = du.stride(0);
159
+ params.du_d_stride = du.stride(1);
160
+ params.ddelta_batch_stride = ddelta.stride(0);
161
+ params.ddelta_d_stride = ddelta.stride(1);
162
+
163
+ }
164
+
165
+ std::vector<at::Tensor>
166
+ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
167
+ const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
168
+ const c10::optional<at::Tensor> &D_,
169
+ const c10::optional<at::Tensor> &delta_bias_,
170
+ bool delta_softplus,
171
+ int nrows
172
+ ) {
173
+ auto input_type = u.scalar_type();
174
+ auto weight_type = A.scalar_type();
175
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
176
+ TORCH_CHECK(weight_type == at::ScalarType::Float);
177
+
178
+ TORCH_CHECK(delta.scalar_type() == input_type);
179
+ TORCH_CHECK(B.scalar_type() == input_type);
180
+ TORCH_CHECK(C.scalar_type() == input_type);
181
+
182
+ TORCH_CHECK(u.is_cuda());
183
+ TORCH_CHECK(delta.is_cuda());
184
+ TORCH_CHECK(A.is_cuda());
185
+ TORCH_CHECK(B.is_cuda());
186
+ TORCH_CHECK(C.is_cuda());
187
+
188
+ TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
189
+ TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
190
+
191
+ const auto sizes = u.sizes();
192
+ const int batch_size = sizes[0];
193
+ const int dim = sizes[1];
194
+ const int seqlen = sizes[2];
195
+ const int dstate = A.size(1);
196
+ const int n_groups = B.size(1);
197
+
198
+ TORCH_CHECK(dim % (n_groups * nrows) == 0, "dims should be divisible by n_groups * nrows");
199
+ TORCH_CHECK(dstate <= MAX_DSTATE / nrows, "selective_scan only supports state dimension <= 256 / nrows");
200
+
201
+ CHECK_SHAPE(u, batch_size, dim, seqlen);
202
+ CHECK_SHAPE(delta, batch_size, dim, seqlen);
203
+ CHECK_SHAPE(A, dim, dstate);
204
+ CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
205
+ TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
206
+ CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
207
+ TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
208
+
209
+ if (D_.has_value()) {
210
+ auto D = D_.value();
211
+ TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
212
+ TORCH_CHECK(D.is_cuda());
213
+ TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
214
+ CHECK_SHAPE(D, dim);
215
+ }
216
+
217
+ if (delta_bias_.has_value()) {
218
+ auto delta_bias = delta_bias_.value();
219
+ TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
220
+ TORCH_CHECK(delta_bias.is_cuda());
221
+ TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
222
+ CHECK_SHAPE(delta_bias, dim);
223
+ }
224
+
225
+ const int n_chunks = (seqlen + 2048 - 1) / 2048; // max is 128 * 16 = 2048 in fwd_kernel
226
+ at::Tensor out = torch::empty_like(delta);
227
+ at::Tensor x;
228
+ x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));
229
+
230
+ SSMParamsBase params;
231
+ set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks,
232
+ u, delta, A, B, C, out,
233
+ D_.has_value() ? D_.value().data_ptr() : nullptr,
234
+ delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
235
+ x.data_ptr(),
236
+ delta_softplus);
237
+
238
+ // Otherwise the kernel will be launched from cuda:0 device
239
+ // Cast to char to avoid compiler warning about narrowing
240
+ at::cuda::CUDAGuard device_guard{(char)u.get_device()};
241
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
242
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
243
+ INT_SWITCH(nrows, kNRows, [&] {
244
+ selective_scan_fwd_cuda<kNRows, input_t, weight_t>(params, stream);
245
+ });
246
+ });
247
+ std::vector<at::Tensor> result = {out, x};
248
+ return result;
249
+ }
250
+
251
+ std::vector<at::Tensor>
252
+ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
253
+ const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
254
+ const c10::optional<at::Tensor> &D_,
255
+ const c10::optional<at::Tensor> &delta_bias_,
256
+ const at::Tensor &dout,
257
+ const c10::optional<at::Tensor> &x_,
258
+ bool delta_softplus,
259
+ int nrows
260
+ ) {
261
+ auto input_type = u.scalar_type();
262
+ auto weight_type = A.scalar_type();
263
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
264
+ TORCH_CHECK(weight_type == at::ScalarType::Float);
265
+
266
+ TORCH_CHECK(delta.scalar_type() == input_type);
267
+ TORCH_CHECK(B.scalar_type() == input_type);
268
+ TORCH_CHECK(C.scalar_type() == input_type);
269
+ TORCH_CHECK(dout.scalar_type() == input_type);
270
+
271
+ TORCH_CHECK(u.is_cuda());
272
+ TORCH_CHECK(delta.is_cuda());
273
+ TORCH_CHECK(A.is_cuda());
274
+ TORCH_CHECK(B.is_cuda());
275
+ TORCH_CHECK(C.is_cuda());
276
+ TORCH_CHECK(dout.is_cuda());
277
+
278
+ TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
279
+ TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
280
+ TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1);
281
+
282
+ const auto sizes = u.sizes();
283
+ const int batch_size = sizes[0];
284
+ const int dim = sizes[1];
285
+ const int seqlen = sizes[2];
286
+ const int dstate = A.size(1);
287
+ const int n_groups = B.size(1);
288
+
289
+ TORCH_CHECK(dim % (n_groups * nrows) == 0, "dims should be divisible by n_groups * nrows");
290
+ TORCH_CHECK(dstate <= MAX_DSTATE / nrows, "selective_scan only supports state dimension <= 256 / nrows");
291
+
292
+ CHECK_SHAPE(u, batch_size, dim, seqlen);
293
+ CHECK_SHAPE(delta, batch_size, dim, seqlen);
294
+ CHECK_SHAPE(A, dim, dstate);
295
+ CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
296
+ TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
297
+ CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
298
+ TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
299
+ CHECK_SHAPE(dout, batch_size, dim, seqlen);
300
+
301
+ if (D_.has_value()) {
302
+ auto D = D_.value();
303
+ TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
304
+ TORCH_CHECK(D.is_cuda());
305
+ TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
306
+ CHECK_SHAPE(D, dim);
307
+ }
308
+
309
+ if (delta_bias_.has_value()) {
310
+ auto delta_bias = delta_bias_.value();
311
+ TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
312
+ TORCH_CHECK(delta_bias.is_cuda());
313
+ TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
314
+ CHECK_SHAPE(delta_bias, dim);
315
+ }
316
+
317
+ at::Tensor out;
318
+ const int n_chunks = (seqlen + 2048 - 1) / 2048;
319
+ // const int n_chunks = (seqlen + 1024 - 1) / 1024;
320
+ if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); }
321
+ if (x_.has_value()) {
322
+ auto x = x_.value();
323
+ TORCH_CHECK(x.scalar_type() == weight_type);
324
+ TORCH_CHECK(x.is_cuda());
325
+ TORCH_CHECK(x.is_contiguous());
326
+ CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate);
327
+ }
328
+
329
+ at::Tensor du = torch::empty_like(u);
330
+ at::Tensor ddelta = torch::empty_like(delta);
331
+ at::Tensor dA = torch::zeros_like(A);
332
+ at::Tensor dB = torch::zeros_like(B, B.options().dtype(torch::kFloat32));
333
+ at::Tensor dC = torch::zeros_like(C, C.options().dtype(torch::kFloat32));
334
+ at::Tensor dD;
335
+ if (D_.has_value()) { dD = torch::zeros_like(D_.value()); }
336
+ at::Tensor ddelta_bias;
337
+ if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); }
338
+
339
+ SSMParamsBwd params;
340
+ set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks,
341
+ u, delta, A, B, C, out,
342
+ D_.has_value() ? D_.value().data_ptr() : nullptr,
343
+ delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
344
+ x_.has_value() ? x_.value().data_ptr() : nullptr,
345
+ dout, du, ddelta, dA, dB, dC,
346
+ D_.has_value() ? dD.data_ptr() : nullptr,
347
+ delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
348
+ delta_softplus);
349
+
350
+ // Otherwise the kernel will be launched from cuda:0 device
351
+ // Cast to char to avoid compiler warning about narrowing
352
+ at::cuda::CUDAGuard device_guard{(char)u.get_device()};
353
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
354
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] {
355
+ // constexpr int kNRows = 1;
356
+ INT_SWITCH(nrows, kNRows, [&] {
357
+ selective_scan_bwd_cuda<kNRows, input_t, weight_t>(params, stream);
358
+ });
359
+ });
360
+ std::vector<at::Tensor> result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias};
361
+ return result;
362
+ }
363
+
364
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
365
+ m.def("fwd", &selective_scan_fwd, "Selective scan forward");
366
+ m.def("bwd", &selective_scan_bwd, "Selective scan backward");
367
+ }
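The module above exposes fwd (returning out and the per-chunk scan states x) and bwd (returning du, ddelta, dA, dB, dC, dD, ddelta_bias). A hedged sketch of how such a binding is typically wrapped in a torch.autograd.Function on the Python side; the module name selective_scan_cuda_nrow and the wrapper class are illustrative assumptions, not code from this commit:

import torch
import selective_scan_cuda_nrow as _ext  # assumed name of the compiled extension above

class SelectiveScanNRowFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None,
                delta_softplus=False, nrows=1):
        # fwd(u, delta, A, B, C, D_, delta_bias_, delta_softplus, nrows) -> [out, x]
        out, x = _ext.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows)
        ctx.delta_softplus, ctx.nrows = delta_softplus, nrows
        ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
        return out

    @staticmethod
    def backward(ctx, dout):
        u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
        # bwd(u, delta, A, B, C, D_, delta_bias_, dout, x_, delta_softplus, nrows)
        du, ddelta, dA, dB, dC, dD, ddelta_bias = _ext.bwd(
            u, delta, A, B, C, D, delta_bias, dout.contiguous(), x,
            ctx.delta_softplus, ctx.nrows)
        return (du, ddelta, dA, dB, dC,
                dD if D is not None else None,
                ddelta_bias if delta_bias is not None else None,
                None, None)

The x tensor holds the running scan prefix of every chunk, which is what lets the backward kernel recompute each 2048-element chunk independently instead of storing all intermediate states.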
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_bwd_kernel_oflex.cuh ADDED
@@ -0,0 +1,323 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <c10/util/BFloat16.h>
8
+ #include <c10/util/Half.h>
9
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
10
+ #include <ATen/cuda/Atomic.cuh> // For atomicAdd on complex
11
+
12
+ #include <cub/block/block_load.cuh>
13
+ #include <cub/block/block_store.cuh>
14
+ #include <cub/block/block_scan.cuh>
15
+ #include <cub/block/block_reduce.cuh>
16
+
17
+ #include "selective_scan.h"
18
+ #include "selective_scan_common.h"
19
+ #include "reverse_scan.cuh"
20
+ #include "static_switch.h"
21
+
22
+ template<int kNThreads_, int kNItems_, bool kIsEvenLen_, bool kDeltaSoftplus_, typename input_t_, typename weight_t_, typename output_t_>
23
+ struct Selective_Scan_bwd_kernel_traits {
24
+ static_assert(kNItems_ % 4 == 0);
25
+ using input_t = input_t_;
26
+ using weight_t = weight_t_;
27
+ using output_t = output_t_;
28
+
29
+ static constexpr int kNThreads = kNThreads_;
30
+ static constexpr int kNItems = kNItems_;
31
+ static constexpr int MaxDState = MAX_DSTATE;
32
+ static constexpr int kNBytes = sizeof(input_t);
33
+ static_assert(kNBytes == 2 || kNBytes == 4);
34
+ static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
35
+ static_assert(kNItems % kNElts == 0);
36
+ static constexpr int kNLoads = kNItems / kNElts;
37
+ static constexpr bool kIsEvenLen = kIsEvenLen_;
38
+ static constexpr bool kDeltaSoftplus = kDeltaSoftplus_;
39
+ // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy.
40
+ // For complex this would lead to massive register spilling, so we keep it at 2.
41
+ static constexpr int kMinBlocks = kNThreads == 128 ? 3 : 2;
42
+ static constexpr int kNLoadsOutput = sizeof(output_t) * kNLoads / kNBytes;
43
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
44
+ using scan_t = float2;
45
+ using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
46
+ using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
47
+ using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
48
+ using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
49
+ using BlockLoadOutputT = cub::BlockLoad<output_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
50
+ using BlockLoadOutputVecT = cub::BlockLoad<vec_t, kNThreads, kNLoadsOutput, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
51
+ using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
52
+ using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads, cub::BLOCK_STORE_WARP_TRANSPOSE>;
53
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
54
+ using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
55
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
56
+ using BlockReverseScanT = BlockReverseScan<scan_t, kNThreads>;
57
+ using BlockReduceT = cub::BlockReduce<scan_t, kNThreads>;
58
+ using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
59
+ using BlockExchangeT = cub::BlockExchange<float, kNThreads, kNItems>;
60
+ static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
61
+ sizeof(typename BlockLoadVecT::TempStorage),
62
+ 2 * sizeof(typename BlockLoadWeightT::TempStorage),
63
+ 2 * sizeof(typename BlockLoadWeightVecT::TempStorage),
64
+ sizeof(typename BlockLoadOutputT::TempStorage),
65
+ sizeof(typename BlockLoadOutputVecT::TempStorage),
66
+ sizeof(typename BlockStoreT::TempStorage),
67
+ sizeof(typename BlockStoreVecT::TempStorage)});
68
+ static constexpr int kSmemExchangeSize = 2 * sizeof(typename BlockExchangeT::TempStorage);
69
+ static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage);
70
+ static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage);
71
+ };
72
+
73
+ template<typename Ktraits>
74
+ __global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
75
+ void selective_scan_bwd_kernel(SSMParamsBwd params) {
76
+ constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus;
77
+ constexpr int kNThreads = Ktraits::kNThreads;
78
+ constexpr int kNItems = Ktraits::kNItems;
79
+ using input_t = typename Ktraits::input_t;
80
+ using weight_t = typename Ktraits::weight_t;
81
+ using output_t = typename Ktraits::output_t;
82
+ using scan_t = typename Ktraits::scan_t;
83
+
84
+ // Shared memory.
85
+ extern __shared__ char smem_[];
86
+ auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
87
+ auto& smem_load1 = reinterpret_cast<typename Ktraits::BlockLoadOutputT::TempStorage&>(smem_);
88
+ auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
89
+ auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
90
+ auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
91
+ auto& smem_exchange = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
92
+ auto& smem_exchange1 = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage));
93
+ auto& smem_reduce = *reinterpret_cast<typename Ktraits::BlockReduceT::TempStorage*>(reinterpret_cast<char *>(&smem_exchange) + Ktraits::kSmemExchangeSize);
94
+ auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(&smem_reduce);
95
+ auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(reinterpret_cast<char *>(&smem_reduce) + Ktraits::kSmemReduceSize);
96
+ auto& smem_reverse_scan = *reinterpret_cast<typename Ktraits::BlockReverseScanT::TempStorage*>(reinterpret_cast<char *>(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage));
97
+ weight_t *smem_delta_a = reinterpret_cast<weight_t *>(smem_ + Ktraits::kSmemSize);
98
+ scan_t *smem_running_postfix = reinterpret_cast<scan_t *>(smem_delta_a + 2 * Ktraits::MaxDState + kNThreads);
99
+ weight_t *smem_da = reinterpret_cast<weight_t *>(smem_running_postfix + Ktraits::MaxDState);
100
+
101
+ const int batch_id = blockIdx.x;
102
+ const int dim_id = blockIdx.y;
103
+ const int group_id = dim_id / (params.dim_ngroups_ratio);
104
+ input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
105
+ + dim_id * params.u_d_stride;
106
+ input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
107
+ + dim_id * params.delta_d_stride;
108
+
109
+ weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * params.A_d_stride;
110
+ input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
111
+ input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
112
+ weight_t *dA = reinterpret_cast<weight_t *>(params.dA_ptr) + dim_id * params.dA_d_stride;
113
+ weight_t *dB = reinterpret_cast<weight_t *>(params.dB_ptr)
114
+ + (batch_id * params.dB_batch_stride + group_id * params.dB_group_stride);
115
+ weight_t *dC = reinterpret_cast<weight_t *>(params.dC_ptr)
116
+ + (batch_id * params.dC_batch_stride + group_id * params.dC_group_stride);
117
+ float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.dD_ptr) + dim_id;
118
+ float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.D_ptr)[dim_id];
119
+ float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.ddelta_bias_ptr) + dim_id;
120
+ float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id];
121
+ scan_t *x = params.x_ptr == nullptr
122
+ ? nullptr
123
+ : reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate;
124
+ float dD_val = 0;
125
+ float ddelta_bias_val = 0;
126
+
127
+ output_t *dout = reinterpret_cast<output_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride + dim_id * params.dout_d_stride;
128
+
129
+ constexpr int kChunkSize = kNThreads * kNItems;
130
+ u += (params.n_chunks - 1) * kChunkSize;
131
+ delta += (params.n_chunks - 1) * kChunkSize;
132
+ dout += (params.n_chunks - 1) * kChunkSize;
133
+ Bvar += (params.n_chunks - 1) * kChunkSize;
134
+ Cvar += (params.n_chunks - 1) * kChunkSize;
135
+ for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) {
136
+ input_t u_vals[kNItems];
137
+ input_t delta_vals_load[kNItems];
138
+ float dout_vals[kNItems];
139
+ __syncthreads();
140
+ load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
141
+ __syncthreads();
142
+ load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
143
+ __syncthreads();
144
+ if constexpr (std::is_same_v<output_t, input_t>) {
145
+ input_t dout_vals_load[kNItems];
146
+ load_input<Ktraits>(reinterpret_cast<input_t *>(dout), dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
147
+ Converter<typename Ktraits::input_t, kNItems>::to_float(dout_vals_load, dout_vals);
148
+ } else {
149
+ static_assert(std::is_same_v<output_t, float>);
150
+ load_output<Ktraits>(dout, dout_vals, smem_load1, params.seqlen - chunk * kChunkSize);
151
+ }
152
+ u -= kChunkSize;
153
+ // Will reload delta at the same location if kDeltaSoftplus
154
+ if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; }
155
+ dout -= kChunkSize;
156
+
157
+ float delta_vals[kNItems];
158
+ float du_vals[kNItems];
159
+ #pragma unroll
160
+ for (int i = 0; i < kNItems; ++i) {
161
+ delta_vals[i] = float(delta_vals_load[i]) + delta_bias;
162
+ if constexpr (kDeltaSoftplus) {
163
+ delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i];
164
+ }
165
+ }
166
+ #pragma unroll
167
+ for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; }
168
+ #pragma unroll
169
+ for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); }
170
+
171
+ float ddelta_vals[kNItems] = {0};
172
+ __syncthreads();
173
+ for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
174
+ constexpr float kLog2e = M_LOG2E;
175
+ weight_t A_val = A[state_idx * params.A_dstate_stride];
176
+ weight_t A_scaled = A_val * kLog2e;
177
+ weight_t B_vals[kNItems], C_vals[kNItems];
178
+ load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
179
+ smem_load_weight, (params.seqlen - chunk * kChunkSize));
180
+ auto &smem_load_weight_C = smem_load_weight1;
181
+ load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
182
+ smem_load_weight_C, (params.seqlen - chunk * kChunkSize));
183
+ scan_t thread_data[kNItems], thread_reverse_data[kNItems];
184
+ #pragma unroll
185
+ for (int i = 0; i < kNItems; ++i) {
186
+ const float delta_a_exp = exp2f(delta_vals[i] * A_scaled);
187
+ thread_data[i] = make_float2(delta_a_exp, delta_vals[i] * float(u_vals[i]) * B_vals[i]);
188
+ if (i == 0) {
189
+ smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * Ktraits::MaxDState: threadIdx.x + 2 * Ktraits::MaxDState] = delta_a_exp;
190
+ } else {
191
+ thread_reverse_data[i - 1].x = delta_a_exp;
192
+ }
193
+ thread_reverse_data[i].y = dout_vals[i] * C_vals[i];
194
+ }
195
+ __syncthreads();
196
+ thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1
197
+ ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState])
198
+ : smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState];
199
+ // Initialize running total
200
+ scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f);
201
+ SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
202
+ Ktraits::BlockScanT(smem_scan).InclusiveScan(
203
+ thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
204
+ );
205
+ scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f);
206
+ SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
207
+ Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
208
+ thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
209
+ );
210
+ if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
211
+ weight_t dA_val = 0;
212
+ weight_t dB_vals[kNItems], dC_vals[kNItems];
213
+ #pragma unroll
214
+ for (int i = 0; i < kNItems; ++i) {
215
+ const float dx = thread_reverse_data[i].y;
216
+ const float ddelta_u = dx * B_vals[i];
217
+ du_vals[i] += ddelta_u * delta_vals[i];
218
+ const float a = thread_data[i].y - (delta_vals[i] * float(u_vals[i]) * B_vals[i]);
219
+ ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a;
220
+ dA_val += dx * delta_vals[i] * a;
221
+ dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]);
222
+ dC_vals[i] = dout_vals[i] * thread_data[i].y;
223
+ }
224
+ // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
225
+ Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
226
+ auto &smem_exchange_C = smem_exchange1;
227
+ Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals);
228
+ const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x;
229
+ weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x;
230
+ weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x;
231
+ #pragma unroll
232
+ for (int i = 0; i < kNItems; ++i) {
233
+ if (i * kNThreads < seqlen_remaining) {
234
+ { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); }
235
+ { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); }
236
+ }
237
+ }
238
+ dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val);
239
+ if (threadIdx.x == 0) {
240
+ smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
241
+ }
242
+ }
243
+
244
+ if constexpr (kDeltaSoftplus) {
245
+ input_t delta_vals_load[kNItems];
246
+ __syncthreads();
247
+ load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
248
+ delta -= kChunkSize;
249
+ #pragma unroll
250
+ for (int i = 0; i < kNItems; ++i) {
251
+ float delta_val = float(delta_vals_load[i]) + delta_bias;
252
+ float delta_val_neg_exp = expf(-delta_val);
253
+ ddelta_vals[i] = delta_val <= 20.f
254
+ ? ddelta_vals[i] / (1.f + delta_val_neg_exp)
255
+ : ddelta_vals[i];
256
+ }
257
+ }
258
+
259
+ __syncthreads();
260
+ #pragma unroll
261
+ for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; }
262
+
263
+ input_t *du = reinterpret_cast<input_t *>(params.du_ptr) + batch_id * params.du_batch_stride
264
+ + dim_id * params.du_d_stride + chunk * kChunkSize;
265
+ input_t *ddelta = reinterpret_cast<input_t *>(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride
266
+ + dim_id * params.ddelta_d_stride + chunk * kChunkSize;
267
+ __syncthreads();
268
+ store_output<Ktraits>(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize);
269
+ __syncthreads();
270
+ store_output<Ktraits>(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize);
271
+ Bvar -= kChunkSize;
272
+ Cvar -= kChunkSize;
273
+ }
274
+
275
+ if (params.dD_ptr != nullptr) {
276
+ __syncthreads();
277
+ dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val);
278
+ if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); }
279
+ }
280
+ if (params.ddelta_bias_ptr != nullptr) {
281
+ __syncthreads();
282
+ ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val);
283
+ if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); }
284
+ }
285
+ __syncthreads();
286
+ for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
287
+ gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]);
288
+ }
289
+ }
290
+
291
+ template<int kNThreads, int kNItems, typename input_t, typename weight_t, typename output_t>
292
+ void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
293
+ BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
294
+ BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] {
295
+ using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, kDeltaSoftplus, input_t, weight_t, output_t>;
296
+ constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * Ktraits::MaxDState) * sizeof(typename Ktraits::weight_t);
297
+ // printf("smem_size = %d\n", kSmemSize);
298
+ dim3 grid(params.batch, params.dim);
299
+ auto kernel = &selective_scan_bwd_kernel<Ktraits>;
300
+ if (kSmemSize >= 48 * 1024) {
301
+ C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
302
+ }
303
+ kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
304
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
305
+ });
306
+ });
307
+ }
308
+
309
+ template<int knrows, typename input_t, typename weight_t, typename output_t>
310
+ void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
311
+ if (params.seqlen <= 128) {
312
+ selective_scan_bwd_launch<32, 4, input_t, weight_t, output_t>(params, stream);
313
+ } else if (params.seqlen <= 256) {
314
+ selective_scan_bwd_launch<32, 8, input_t, weight_t, output_t>(params, stream);
315
+ } else if (params.seqlen <= 512) {
316
+ selective_scan_bwd_launch<32, 16, input_t, weight_t, output_t>(params, stream);
317
+ } else if (params.seqlen <= 1024) {
318
+ selective_scan_bwd_launch<64, 16, input_t, weight_t, output_t>(params, stream);
319
+ } else {
320
+ selective_scan_bwd_launch<128, 16, input_t, weight_t, output_t>(params, stream);
321
+ }
322
+ }
323
+
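When kDeltaSoftplus is enabled, the loop above folds the softplus derivative into ddelta by dividing by (1 + exp(-delta_raw)), i.e. multiplying by sigmoid(delta_raw), and skips the correction past the same 20.0 cutoff used in the forward pass. A small illustrative sketch of that chain rule in plain PyTorch (not part of this commit):

import torch

def softplus_with_cutoff(x):
    # Forward convention in these kernels: softplus below the cutoff, identity above it.
    return torch.where(x <= 20.0, torch.log1p(torch.exp(x)), x)

def fold_softplus_grad(ddelta, delta_raw):
    # Backward: d softplus/dx = sigmoid(x) = 1 / (1 + exp(-x)); identity above the cutoff.
    return torch.where(delta_raw <= 20.0, ddelta * torch.sigmoid(delta_raw), ddelta)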
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_bwd.cu ADDED
@@ -0,0 +1,11 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+ #include "selective_scan_bwd_kernel_oflex.cuh"
5
+
6
+ template void selective_scan_bwd_cuda<1, float, float, float>(SSMParamsBwd &params, cudaStream_t stream);
7
+ template void selective_scan_bwd_cuda<1, at::Half, float, float>(SSMParamsBwd &params, cudaStream_t stream);
8
+ template void selective_scan_bwd_cuda<1, at::BFloat16, float, float>(SSMParamsBwd &params, cudaStream_t stream);
9
+ template void selective_scan_bwd_cuda<1, at::Half, float, at::Half>(SSMParamsBwd &params, cudaStream_t stream);
10
+ template void selective_scan_bwd_cuda<1, at::BFloat16, float, at::BFloat16>(SSMParamsBwd &params, cudaStream_t stream);
11
+
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_fwd.cu ADDED
@@ -0,0 +1,11 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+ #include "selective_scan_fwd_kernel_oflex.cuh"
5
+
6
+ template void selective_scan_fwd_cuda<1, float, float, float>(SSMParamsBase &params, cudaStream_t stream);
7
+ template void selective_scan_fwd_cuda<1, at::Half, float, float>(SSMParamsBase &params, cudaStream_t stream);
8
+ template void selective_scan_fwd_cuda<1, at::BFloat16, float, float>(SSMParamsBase &params, cudaStream_t stream);
9
+ template void selective_scan_fwd_cuda<1, at::Half, float, at::Half>(SSMParamsBase &params, cudaStream_t stream);
10
+ template void selective_scan_fwd_cuda<1, at::BFloat16, float, at::BFloat16>(SSMParamsBase &params, cudaStream_t stream);
11
+
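The two translation units above pin down the five (input_t, weight_t, output_t) combinations the oflex build ships; the out_float flag handled in selective_scan_oflex.cpp below selects between same-as-input and float32 outputs. An illustrative Python-side view of that mapping (the dict itself is an assumption for clarity, not code from this commit):

import torch

# (tensor dtype, out_float) -> (input_t, output_t) template arguments instantiated above;
# weight_t is always float.
OFLEX_DISPATCH = {
    (torch.float32,  False): ("float",        "float"),
    (torch.float32,  True):  ("float",        "float"),
    (torch.float16,  False): ("at::Half",     "at::Half"),
    (torch.float16,  True):  ("at::Half",     "float"),
    (torch.bfloat16, False): ("at::BFloat16", "at::BFloat16"),
    (torch.bfloat16, True):  ("at::BFloat16", "float"),
}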
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_fwd_kernel_oflex.cuh ADDED
@@ -0,0 +1,211 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <c10/util/BFloat16.h>
8
+ #include <c10/util/Half.h>
9
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
10
+
11
+ #include <cub/block/block_load.cuh>
12
+ #include <cub/block/block_store.cuh>
13
+ #include <cub/block/block_scan.cuh>
14
+
15
+ #include "selective_scan.h"
16
+ #include "selective_scan_common.h"
17
+ #include "static_switch.h"
18
+
19
+ template<int kNThreads_, int kNItems_, bool kIsEvenLen_, typename input_t_, typename weight_t_, typename output_t_>
20
+ struct Selective_Scan_fwd_kernel_traits {
21
+ static_assert(kNItems_ % 4 == 0);
22
+ using input_t = input_t_;
23
+ using weight_t = weight_t_;
24
+ using output_t = output_t_;
25
+ static constexpr int kNThreads = kNThreads_;
26
+ // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
27
+ static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
28
+ static constexpr int kNItems = kNItems_;
29
+ static constexpr int MaxDState = MAX_DSTATE;
30
+ static constexpr int kNBytes = sizeof(input_t);
31
+ static_assert(kNBytes == 2 || kNBytes == 4);
32
+ static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
33
+ static_assert(kNItems % kNElts == 0);
34
+ static constexpr int kNLoads = kNItems / kNElts;
35
+ static constexpr bool kIsEvenLen = kIsEvenLen_;
36
+ static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
37
+ static constexpr int kNLoadsOutput = sizeof(output_t) * kNLoads / kNBytes;
38
+ static constexpr bool kDirectIOOutput = kDirectIO && (kNLoadsOutput == 1);
39
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
40
+ using scan_t = float2;
41
+ using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
42
+ using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
43
+ !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
44
+ using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
45
+ using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
46
+ !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
47
+ using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
48
+ using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
49
+ !kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
50
+ using BlockStoreOutputT = cub::BlockStore<output_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
51
+ using BlockStoreOutputVecT = cub::BlockStore<vec_t, kNThreads, kNLoadsOutput,
52
+ !kDirectIOOutput ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
53
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
54
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
55
+ using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
56
+ static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
57
+ sizeof(typename BlockLoadVecT::TempStorage),
58
+ 2 * sizeof(typename BlockLoadWeightT::TempStorage),
59
+ 2 * sizeof(typename BlockLoadWeightVecT::TempStorage),
60
+ sizeof(typename BlockStoreT::TempStorage),
61
+ sizeof(typename BlockStoreVecT::TempStorage),
62
+ sizeof(typename BlockStoreOutputT::TempStorage),
63
+ sizeof(typename BlockStoreOutputVecT::TempStorage)});
64
+ static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
65
+ };
66
+
67
+ template<typename Ktraits>
68
+ __global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
69
+ void selective_scan_fwd_kernel(SSMParamsBase params) {
70
+ constexpr int kNThreads = Ktraits::kNThreads;
71
+ constexpr int kNItems = Ktraits::kNItems;
72
+ constexpr bool kDirectIO = Ktraits::kDirectIO;
73
+ using input_t = typename Ktraits::input_t;
74
+ using weight_t = typename Ktraits::weight_t;
75
+ using output_t = typename Ktraits::output_t;
76
+ using scan_t = typename Ktraits::scan_t;
77
+
78
+ // Shared memory.
79
+ extern __shared__ char smem_[];
80
+ auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
81
+ auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
82
+ auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
83
+ auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
84
+ auto& smem_store1 = reinterpret_cast<typename Ktraits::BlockStoreOutputT::TempStorage&>(smem_);
85
+ auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
86
+ scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);
87
+
88
+ const int batch_id = blockIdx.x;
89
+ const int dim_id = blockIdx.y;
90
+ const int group_id = dim_id / (params.dim_ngroups_ratio);
91
+ input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
92
+ + dim_id * params.u_d_stride;
93
+ input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
94
+ + dim_id * params.delta_d_stride;
95
+ weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * params.A_d_stride;
96
+ input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
97
+ input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
98
+ scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * params.n_chunks * params.dstate;
99
+
100
+ float D_val = 0; // D is optional: stays 0 when params.D_ptr is null
101
+ if (params.D_ptr != nullptr) {
102
+ D_val = reinterpret_cast<float *>(params.D_ptr)[dim_id];
103
+ }
104
+ float delta_bias = 0;
105
+ if (params.delta_bias_ptr != nullptr) {
106
+ delta_bias = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id];
107
+ }
108
+
109
+ constexpr int kChunkSize = kNThreads * kNItems;
110
+ for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
111
+ input_t u_vals[kNItems], delta_vals_load[kNItems];
112
+ __syncthreads();
113
+ load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
114
+ if constexpr (!kDirectIO) { __syncthreads(); }
115
+ load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
116
+ u += kChunkSize;
117
+ delta += kChunkSize;
118
+
119
+ float delta_vals[kNItems], delta_u_vals[kNItems], out_vals[kNItems];
120
+ #pragma unroll
121
+ for (int i = 0; i < kNItems; ++i) {
122
+ float u_val = float(u_vals[i]);
123
+ delta_vals[i] = float(delta_vals_load[i]) + delta_bias;
124
+ if (params.delta_softplus) {
125
+ delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i];
126
+ }
127
+ delta_u_vals[i] = delta_vals[i] * u_val;
128
+ out_vals[i] = D_val * u_val;
129
+ }
130
+
131
+ __syncthreads();
132
+ for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
133
+ constexpr float kLog2e = M_LOG2E;
134
+ weight_t A_val = A[state_idx * params.A_dstate_stride];
135
+ A_val *= kLog2e;
136
+ weight_t B_vals[kNItems], C_vals[kNItems];
137
+ load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
138
+ smem_load_weight, (params.seqlen - chunk * kChunkSize));
139
+ load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
140
+ smem_load_weight1, (params.seqlen - chunk * kChunkSize));
141
+ __syncthreads();
142
+ scan_t thread_data[kNItems];
143
+ #pragma unroll
144
+ for (int i = 0; i < kNItems; ++i) {
145
+ thread_data[i] = make_float2(exp2f(delta_vals[i] * A_val), B_vals[i] * delta_u_vals[i]);
146
+ if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
147
+ if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
148
+ thread_data[i] = make_float2(1.f, 0.f);
149
+ }
150
+ }
151
+ }
152
+ // Initialize running total
153
+ scan_t running_prefix;
154
+ // If we use WARP_SCAN then lane 0 of every warp (not just thread 0) needs to read
155
+ running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
156
+ // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
157
+ SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
158
+ Ktraits::BlockScanT(smem_scan).InclusiveScan(
159
+ thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
160
+ );
161
+ // There's a syncthreads in the scan op, so we don't need to sync here.
162
+ // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
163
+ if (threadIdx.x == 0) {
164
+ smem_running_prefix[state_idx] = prefix_op.running_prefix;
165
+ x[chunk * params.dstate + state_idx] = prefix_op.running_prefix;
166
+ }
167
+ #pragma unroll
168
+ for (int i = 0; i < kNItems; ++i) {
169
+ out_vals[i] += thread_data[i].y * C_vals[i];
170
+ }
171
+ }
172
+
173
+ output_t *out = reinterpret_cast<output_t *>(params.out_ptr) + batch_id * params.out_batch_stride
174
+ + dim_id * params.out_d_stride + chunk * kChunkSize;
175
+ __syncthreads();
176
+ store_output1<Ktraits>(out, out_vals, smem_store1, params.seqlen - chunk * kChunkSize);
177
+ Bvar += kChunkSize;
178
+ Cvar += kChunkSize;
179
+ }
180
+ }
181
+
182
+ template<int kNThreads, int kNItems, typename input_t, typename weight_t, typename output_t>
183
+ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
184
+ BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
185
+ using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, input_t, weight_t, output_t>;
186
+ constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t);
187
+ // printf("smem_size = %d\n", kSmemSize);
188
+ dim3 grid(params.batch, params.dim);
189
+ auto kernel = &selective_scan_fwd_kernel<Ktraits>;
190
+ if (kSmemSize >= 48 * 1024) {
191
+ C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
192
+ }
193
+ kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
194
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
195
+ });
196
+ }
197
+
198
+ template<int knrows, typename input_t, typename weight_t, typename output_t>
199
+ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
200
+ if (params.seqlen <= 128) {
201
+ selective_scan_fwd_launch<32, 4, input_t, weight_t, output_t>(params, stream);
202
+ } else if (params.seqlen <= 256) {
203
+ selective_scan_fwd_launch<32, 8, input_t, weight_t, output_t>(params, stream);
204
+ } else if (params.seqlen <= 512) {
205
+ selective_scan_fwd_launch<32, 16, input_t, weight_t, output_t>(params, stream);
206
+ } else if (params.seqlen <= 1024) {
207
+ selective_scan_fwd_launch<64, 16, input_t, weight_t, output_t>(params, stream);
208
+ } else {
209
+ selective_scan_fwd_launch<128, 16, input_t, weight_t, output_t>(params, stream);
210
+ }
211
+ }
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_oflex.cpp ADDED
@@ -0,0 +1,363 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include <ATen/cuda/CUDAContext.h>
6
+ #include <c10/cuda/CUDAGuard.h>
7
+ #include <torch/extension.h>
8
+ #include <vector>
9
+
10
+ #include "selective_scan.h"
11
+ #define MAX_DSTATE 256
12
+
13
+ #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
14
+ using weight_t = float;
15
+
16
+ #define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
17
+ if (ITYPE == at::ScalarType::Half) { \
18
+ using input_t = at::Half; \
19
+ __VA_ARGS__(); \
20
+ } else if (ITYPE == at::ScalarType::BFloat16) { \
21
+ using input_t = at::BFloat16; \
22
+ __VA_ARGS__(); \
23
+ } else if (ITYPE == at::ScalarType::Float) { \
24
+ using input_t = float; \
25
+ __VA_ARGS__(); \
26
+ } else { \
27
+ AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
28
+ }
29
+
30
+ template<int knrows, typename input_t, typename weight_t, typename output_t>
31
+ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);
32
+
33
+ template <int knrows, typename input_t, typename weight_t, typename output_t>
34
+ void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream);
35
+
36
+ void set_ssm_params_fwd(SSMParamsBase &params,
37
+ // sizes
38
+ const size_t batch,
39
+ const size_t dim,
40
+ const size_t seqlen,
41
+ const size_t dstate,
42
+ const size_t n_groups,
43
+ const size_t n_chunks,
44
+ // device pointers
45
+ const at::Tensor u,
46
+ const at::Tensor delta,
47
+ const at::Tensor A,
48
+ const at::Tensor B,
49
+ const at::Tensor C,
50
+ const at::Tensor out,
51
+ void* D_ptr,
52
+ void* delta_bias_ptr,
53
+ void* x_ptr,
54
+ bool delta_softplus) {
55
+
56
+ // Reset the parameters
57
+ memset(&params, 0, sizeof(params));
58
+
59
+ params.batch = batch;
60
+ params.dim = dim;
61
+ params.seqlen = seqlen;
62
+ params.dstate = dstate;
63
+ params.n_groups = n_groups;
64
+ params.n_chunks = n_chunks;
65
+ params.dim_ngroups_ratio = dim / n_groups;
66
+
67
+ params.delta_softplus = delta_softplus;
68
+
69
+ // Set the pointers and strides.
70
+ params.u_ptr = u.data_ptr();
71
+ params.delta_ptr = delta.data_ptr();
72
+ params.A_ptr = A.data_ptr();
73
+ params.B_ptr = B.data_ptr();
74
+ params.C_ptr = C.data_ptr();
75
+ params.D_ptr = D_ptr;
76
+ params.delta_bias_ptr = delta_bias_ptr;
77
+ params.out_ptr = out.data_ptr();
78
+ params.x_ptr = x_ptr;
79
+
80
+ // All strides are in elements, not bytes.
81
+ params.A_d_stride = A.stride(0);
82
+ params.A_dstate_stride = A.stride(1);
83
+ params.B_batch_stride = B.stride(0);
84
+ params.B_group_stride = B.stride(1);
85
+ params.B_dstate_stride = B.stride(2);
86
+ params.C_batch_stride = C.stride(0);
87
+ params.C_group_stride = C.stride(1);
88
+ params.C_dstate_stride = C.stride(2);
89
+ params.u_batch_stride = u.stride(0);
90
+ params.u_d_stride = u.stride(1);
91
+ params.delta_batch_stride = delta.stride(0);
92
+ params.delta_d_stride = delta.stride(1);
93
+
94
+ params.out_batch_stride = out.stride(0);
95
+ params.out_d_stride = out.stride(1);
96
+ }
97
+
98
+ void set_ssm_params_bwd(SSMParamsBwd &params,
99
+ // sizes
100
+ const size_t batch,
101
+ const size_t dim,
102
+ const size_t seqlen,
103
+ const size_t dstate,
104
+ const size_t n_groups,
105
+ const size_t n_chunks,
106
+ // device pointers
107
+ const at::Tensor u,
108
+ const at::Tensor delta,
109
+ const at::Tensor A,
110
+ const at::Tensor B,
111
+ const at::Tensor C,
112
+ const at::Tensor out,
113
+ void* D_ptr,
114
+ void* delta_bias_ptr,
115
+ void* x_ptr,
116
+ const at::Tensor dout,
117
+ const at::Tensor du,
118
+ const at::Tensor ddelta,
119
+ const at::Tensor dA,
120
+ const at::Tensor dB,
121
+ const at::Tensor dC,
122
+ void* dD_ptr,
123
+ void* ddelta_bias_ptr,
124
+ bool delta_softplus) {
125
+ // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z
126
+ set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks,
127
+ u, delta, A, B, C, dout,
128
+ D_ptr, delta_bias_ptr, x_ptr, delta_softplus);
129
+
130
+ // Set the pointers and strides.
131
+ params.dout_ptr = dout.data_ptr();
132
+ params.du_ptr = du.data_ptr();
133
+ params.dA_ptr = dA.data_ptr();
134
+ params.dB_ptr = dB.data_ptr();
135
+ params.dC_ptr = dC.data_ptr();
136
+ params.dD_ptr = dD_ptr;
137
+ params.ddelta_ptr = ddelta.data_ptr();
138
+ params.ddelta_bias_ptr = ddelta_bias_ptr;
139
+ // All strides are in elements, not bytes.
140
+ params.dout_batch_stride = dout.stride(0);
141
+ params.dout_d_stride = dout.stride(1);
142
+ params.dA_d_stride = dA.stride(0);
143
+ params.dA_dstate_stride = dA.stride(1);
144
+ params.dB_batch_stride = dB.stride(0);
145
+ params.dB_group_stride = dB.stride(1);
146
+ params.dB_dstate_stride = dB.stride(2);
147
+ params.dC_batch_stride = dC.stride(0);
148
+ params.dC_group_stride = dC.stride(1);
149
+ params.dC_dstate_stride = dC.stride(2);
150
+ params.du_batch_stride = du.stride(0);
151
+ params.du_d_stride = du.stride(1);
152
+ params.ddelta_batch_stride = ddelta.stride(0);
153
+ params.ddelta_d_stride = ddelta.stride(1);
154
+
155
+ }
156
+
157
+ std::vector<at::Tensor>
158
+ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
159
+ const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
160
+ const c10::optional<at::Tensor> &D_,
161
+ const c10::optional<at::Tensor> &delta_bias_,
162
+ bool delta_softplus,
163
+ int nrows,
164
+ bool out_float
165
+ ) {
166
+ auto input_type = u.scalar_type();
167
+ auto weight_type = A.scalar_type();
168
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
169
+ TORCH_CHECK(weight_type == at::ScalarType::Float);
170
+
171
+ TORCH_CHECK(delta.scalar_type() == input_type);
172
+ TORCH_CHECK(B.scalar_type() == input_type);
173
+ TORCH_CHECK(C.scalar_type() == input_type);
174
+
175
+ TORCH_CHECK(u.is_cuda());
176
+ TORCH_CHECK(delta.is_cuda());
177
+ TORCH_CHECK(A.is_cuda());
178
+ TORCH_CHECK(B.is_cuda());
179
+ TORCH_CHECK(C.is_cuda());
180
+
181
+ TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
182
+ TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
183
+
184
+ const auto sizes = u.sizes();
185
+ const int batch_size = sizes[0];
186
+ const int dim = sizes[1];
187
+ const int seqlen = sizes[2];
188
+ const int dstate = A.size(1);
189
+ const int n_groups = B.size(1);
190
+
191
+ TORCH_CHECK(dim % n_groups == 0, "dims should be divisible by n_groups");
192
+ TORCH_CHECK(dstate <= MAX_DSTATE, "selective_scan only supports state dimension <= 256");
193
+
194
+ CHECK_SHAPE(u, batch_size, dim, seqlen);
195
+ CHECK_SHAPE(delta, batch_size, dim, seqlen);
196
+ CHECK_SHAPE(A, dim, dstate);
197
+ CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
198
+ TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
199
+ CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
200
+ TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
201
+
202
+ if (D_.has_value()) {
203
+ auto D = D_.value();
204
+ TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
205
+ TORCH_CHECK(D.is_cuda());
206
+ TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
207
+ CHECK_SHAPE(D, dim);
208
+ }
209
+
210
+ if (delta_bias_.has_value()) {
211
+ auto delta_bias = delta_bias_.value();
212
+ TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
213
+ TORCH_CHECK(delta_bias.is_cuda());
214
+ TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
215
+ CHECK_SHAPE(delta_bias, dim);
216
+ }
217
+
218
+ const int n_chunks = (seqlen + 2048 - 1) / 2048; // max is 128 * 16 = 2048 in fwd_kernel
219
+ at::Tensor out = torch::empty({batch_size, dim, seqlen}, u.options().dtype(out_float? (at::ScalarType::Float): input_type));
220
+ at::Tensor x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));
221
+
222
+ SSMParamsBase params;
223
+ set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks,
224
+ u, delta, A, B, C, out,
225
+ D_.has_value() ? D_.value().data_ptr() : nullptr,
226
+ delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
227
+ x.data_ptr(),
228
+ delta_softplus);
229
+
230
+ // Otherwise the kernel will be launched from cuda:0 device
231
+ // Cast to char to avoid compiler warning about narrowing
232
+ at::cuda::CUDAGuard device_guard{(char)u.get_device()};
233
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
234
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
235
+ if (!out_float) {
236
+ selective_scan_fwd_cuda<1, input_t, weight_t, input_t>(params, stream);
237
+ } else {
238
+ selective_scan_fwd_cuda<1, input_t, weight_t, float>(params, stream);
239
+ }
240
+ });
241
+ std::vector<at::Tensor> result = {out, x};
242
+ return result;
243
+ }
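For context, a minimal and hedged sketch of how this forward entry point can be driven from Python once the file is compiled as a PyTorch extension. The module name `selective_scan_cuda_core` is a placeholder (the real name depends on the build target); shapes and dtypes follow the checks above: `u`/`delta` are `[batch, dim, seqlen]`, `A` is float32 `[dim, dstate]`, `B`/`C` are `[batch, n_groups, dstate, seqlen]`, and `D`/`delta_bias` are float32 `[dim]`.

```python
import torch
import selective_scan_cuda_core as _C  # placeholder name for the built extension

Bsz, D, L, N, G = 2, 96, 1024, 16, 1            # batch, dim, seqlen, dstate, n_groups
u     = torch.randn(Bsz, D, L, device="cuda", dtype=torch.float16)
delta = torch.randn(Bsz, D, L, device="cuda", dtype=torch.float16)
A     = torch.randn(D, N, device="cuda", dtype=torch.float32)
Bmat  = torch.randn(Bsz, G, N, L, device="cuda", dtype=torch.float16)
Cmat  = torch.randn(Bsz, G, N, L, device="cuda", dtype=torch.float16)
Dskip = torch.randn(D, device="cuda", dtype=torch.float32)
dbias = torch.randn(D, device="cuda", dtype=torch.float32)

# fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, out_float) -> [out, x]
out, x = _C.fwd(u, delta, A, Bmat, Cmat, Dskip, dbias, True, 1, False)
```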
244
+
245
+ std::vector<at::Tensor>
246
+ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
247
+ const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
248
+ const c10::optional<at::Tensor> &D_,
249
+ const c10::optional<at::Tensor> &delta_bias_,
250
+ const at::Tensor &dout,
251
+ const c10::optional<at::Tensor> &x_,
252
+ bool delta_softplus,
253
+ int nrows
254
+ ) {
255
+ auto input_type = u.scalar_type();
256
+ auto weight_type = A.scalar_type();
257
+ auto output_type = dout.scalar_type();
258
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
259
+ TORCH_CHECK(weight_type == at::ScalarType::Float);
260
+ TORCH_CHECK(output_type == input_type || output_type == at::ScalarType::Float);
261
+
262
+ TORCH_CHECK(delta.scalar_type() == input_type);
263
+ TORCH_CHECK(B.scalar_type() == input_type);
264
+ TORCH_CHECK(C.scalar_type() == input_type);
265
+
266
+ TORCH_CHECK(u.is_cuda());
267
+ TORCH_CHECK(delta.is_cuda());
268
+ TORCH_CHECK(A.is_cuda());
269
+ TORCH_CHECK(B.is_cuda());
270
+ TORCH_CHECK(C.is_cuda());
271
+ TORCH_CHECK(dout.is_cuda());
272
+
273
+ TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
274
+ TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
275
+ TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1);
276
+
277
+ const auto sizes = u.sizes();
278
+ const int batch_size = sizes[0];
279
+ const int dim = sizes[1];
280
+ const int seqlen = sizes[2];
281
+ const int dstate = A.size(1);
282
+ const int n_groups = B.size(1);
283
+
284
+ TORCH_CHECK(dim % n_groups == 0, "dims should be divisible by n_groups");
285
+ TORCH_CHECK(dstate <= MAX_DSTATE, "selective_scan only supports state dimension <= 256");
286
+
287
+ CHECK_SHAPE(u, batch_size, dim, seqlen);
288
+ CHECK_SHAPE(delta, batch_size, dim, seqlen);
289
+ CHECK_SHAPE(A, dim, dstate);
290
+ CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
291
+ TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
292
+ CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
293
+ TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
294
+ CHECK_SHAPE(dout, batch_size, dim, seqlen);
295
+
296
+ if (D_.has_value()) {
297
+ auto D = D_.value();
298
+ TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
299
+ TORCH_CHECK(D.is_cuda());
300
+ TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
301
+ CHECK_SHAPE(D, dim);
302
+ }
303
+
304
+ if (delta_bias_.has_value()) {
305
+ auto delta_bias = delta_bias_.value();
306
+ TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
307
+ TORCH_CHECK(delta_bias.is_cuda());
308
+ TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
309
+ CHECK_SHAPE(delta_bias, dim);
310
+ }
311
+
312
+ at::Tensor out;
313
+ const int n_chunks = (seqlen + 2048 - 1) / 2048;
314
+ // const int n_chunks = (seqlen + 1024 - 1) / 1024;
315
+ if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); }
316
+ if (x_.has_value()) {
317
+ auto x = x_.value();
318
+ TORCH_CHECK(x.scalar_type() == weight_type);
319
+ TORCH_CHECK(x.is_cuda());
320
+ TORCH_CHECK(x.is_contiguous());
321
+ CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate);
322
+ }
323
+
324
+ at::Tensor du = torch::empty_like(u);
325
+ at::Tensor ddelta = torch::empty_like(delta);
326
+ at::Tensor dA = torch::zeros_like(A);
327
+ at::Tensor dB = torch::zeros_like(B, B.options().dtype(torch::kFloat32));
328
+ at::Tensor dC = torch::zeros_like(C, C.options().dtype(torch::kFloat32));
329
+ at::Tensor dD;
330
+ if (D_.has_value()) { dD = torch::zeros_like(D_.value()); }
331
+ at::Tensor ddelta_bias;
332
+ if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); }
333
+
334
+ SSMParamsBwd params;
335
+ set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks,
336
+ u, delta, A, B, C, out,
337
+ D_.has_value() ? D_.value().data_ptr() : nullptr,
338
+ delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
339
+ x_.has_value() ? x_.value().data_ptr() : nullptr,
340
+ dout, du, ddelta, dA, dB, dC,
341
+ D_.has_value() ? dD.data_ptr() : nullptr,
342
+ delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
343
+ delta_softplus);
344
+
345
+ // Otherwise the kernel will be launched from cuda:0 device
346
+ // Cast to char to avoid compiler warning about narrowing
347
+ at::cuda::CUDAGuard device_guard{(char)u.get_device()};
348
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
349
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] {
350
+ if (output_type == input_type) {
351
+ selective_scan_bwd_cuda<1, input_t, weight_t, input_t>(params, stream);
352
+ } else {
353
+ selective_scan_bwd_cuda<1, input_t, weight_t, float>(params, stream);
354
+ }
355
+ });
356
+ std::vector<at::Tensor> result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias};
357
+ return result;
358
+ }
359
+
360
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
361
+ m.def("fwd", &selective_scan_fwd, "Selective scan forward");
362
+ m.def("bwd", &selective_scan_bwd, "Selective scan backward");
363
+ }
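The two bindings are typically wrapped in a `torch.autograd.Function` so that the scan intermediates `x` produced by `fwd` can be fed back into `bwd`. The sketch below is illustrative only: `_C` again stands for the compiled module (hypothetical import name), and for brevity it assumes `D` and `delta_bias` are always passed as tensors.

```python
import torch
import selective_scan_cuda_core as _C  # placeholder name for the built extension


class SelectiveScanFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, u, delta, A, B, C, D, delta_bias, delta_softplus=True):
        # fwd returns the output and the per-chunk scan states needed by bwd.
        out, x = _C.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, False)
        ctx.delta_softplus = delta_softplus
        ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
        return out

    @staticmethod
    def backward(ctx, dout):
        u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
        du, ddelta, dA, dB, dC, dD, ddelta_bias = _C.bwd(
            u, delta, A, B, C, D, delta_bias, dout.contiguous(), x,
            ctx.delta_softplus, 1)
        # One gradient per forward argument; delta_softplus gets None.
        return du, ddelta, dA, dB, dC, dD, ddelta_bias, None
```

It would then be called as `out = SelectiveScanFn.apply(u, delta, A, B, C, D, delta_bias, True)`.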
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/reverse_scan.cuh ADDED
@@ -0,0 +1,403 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <cub/config.cuh>
8
+
9
+ #include <cub/util_ptx.cuh>
10
+ #include <cub/util_type.cuh>
11
+ #include <cub/block/block_raking_layout.cuh>
12
+ // #include <cub/detail/uninitialized_copy.cuh>
13
+ #include "uninitialized_copy.cuh"
14
+ #include "cub_extra.cuh"
15
+
16
+ /**
17
+ * Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned.
18
+ */
19
+ template <
20
+ int LENGTH,
21
+ typename T,
22
+ typename ReductionOp>
23
+ __device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) {
24
+ static_assert(LENGTH > 0);
25
+ T retval = input[LENGTH - 1];
26
+ #pragma unroll
27
+ for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); }
28
+ return retval;
29
+ }
30
+
31
+ /**
32
+ * Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
33
+ */
34
+ template <
35
+ int LENGTH,
36
+ typename T,
37
+ typename ScanOp>
38
+ __device__ __forceinline__ T ThreadReverseScanInclusive(
39
+ const T (&input)[LENGTH],
40
+ T (&output)[LENGTH],
41
+ ScanOp scan_op,
42
+ const T postfix)
43
+ {
44
+ T inclusive = postfix;
45
+ #pragma unroll
46
+ for (int i = LENGTH - 1; i >= 0; --i) {
47
+ inclusive = scan_op(inclusive, input[i]);
48
+ output[i] = inclusive;
49
+ }
50
+ return inclusive;
+ }
51
+
52
+ /**
53
+ * Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
54
+ */
55
+ template <
56
+ int LENGTH,
57
+ typename T,
58
+ typename ScanOp>
59
+ __device__ __forceinline__ T ThreadReverseScanExclusive(
60
+ const T (&input)[LENGTH],
61
+ T (&output)[LENGTH],
62
+ ScanOp scan_op,
63
+ const T postfix)
64
+ {
65
+ // Careful, output may be aliased to input
66
+ T exclusive = postfix;
67
+ T inclusive;
68
+ #pragma unroll
69
+ for (int i = LENGTH - 1; i >= 0; --i) {
70
+ inclusive = scan_op(exclusive, input[i]);
71
+ output[i] = exclusive;
72
+ exclusive = inclusive;
73
+ }
74
+ return inclusive;
75
+ }
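As a reading aid, here is a small Python rendering (illustrative only, not part of the kernels) of the two postfix-seeded thread scans above, keeping the same operand order as the device code:

```python
def thread_reverse_scan_inclusive(inp, scan_op, postfix):
    # out[i] aggregates postfix with inp[-1], inp[-2], ..., inp[i], in that order.
    out, inclusive = [None] * len(inp), postfix
    for i in range(len(inp) - 1, -1, -1):
        inclusive = scan_op(inclusive, inp[i])
        out[i] = inclusive
    return out, inclusive


def thread_reverse_scan_exclusive(inp, scan_op, postfix):
    # out[i] receives the aggregate of everything to the right of position i
    # (seeded with postfix); the full inclusive aggregate is returned.
    out, exclusive = [None] * len(inp), postfix
    for i in range(len(inp) - 1, -1, -1):
        inclusive = scan_op(exclusive, inp[i])
        out[i] = exclusive
        exclusive = inclusive
    return out, inclusive
```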
76
+
77
+
78
+ /**
79
+ * \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp.
80
+ *
81
+ * LOGICAL_WARP_THREADS must be a power-of-two
82
+ */
83
+ template <
84
+ typename T, ///< Data type being scanned
85
+ int LOGICAL_WARP_THREADS ///< Number of threads per logical warp
86
+ >
87
+ struct WarpReverseScan {
88
+ //---------------------------------------------------------------------
89
+ // Constants and type definitions
90
+ //---------------------------------------------------------------------
91
+
92
+ /// Whether the logical warp size and the PTX warp size coincide
93
+ static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0));
94
+ /// The number of warp scan steps
95
+ static constexpr int STEPS = cub::Log2<LOGICAL_WARP_THREADS>::VALUE;
96
+ static_assert(LOGICAL_WARP_THREADS == 1 << STEPS);
97
+
98
+
99
+ //---------------------------------------------------------------------
100
+ // Thread fields
101
+ //---------------------------------------------------------------------
102
+
103
+ /// Lane index in logical warp
104
+ unsigned int lane_id;
105
+
106
+ /// Logical warp index in 32-thread physical warp
107
+ unsigned int warp_id;
108
+
109
+ /// 32-thread physical warp member mask of logical warp
110
+ unsigned int member_mask;
111
+
112
+ //---------------------------------------------------------------------
113
+ // Construction
114
+ //---------------------------------------------------------------------
115
+
116
+ /// Constructor
117
+ explicit __device__ __forceinline__
118
+ WarpReverseScan()
119
+ : lane_id(cub::LaneId())
120
+ , warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS))
121
+ // , member_mask(cub::WarpMask<LOGICAL_WARP_THREADS>(warp_id))
122
+ , member_mask(WarpMask<LOGICAL_WARP_THREADS>(warp_id))
123
+ {
124
+ if (!IS_ARCH_WARP) {
125
+ lane_id = lane_id % LOGICAL_WARP_THREADS;
126
+ }
127
+ }
128
+
129
+
130
+ /// Broadcast
131
+ __device__ __forceinline__ T Broadcast(
132
+ T input, ///< [in] The value to broadcast
133
+ int src_lane) ///< [in] Which warp lane is to do the broadcasting
134
+ {
135
+ return cub::ShuffleIndex<LOGICAL_WARP_THREADS>(input, src_lane, member_mask);
136
+ }
137
+
138
+
139
+ /// Inclusive scan
140
+ template <typename ScanOpT>
141
+ __device__ __forceinline__ void InclusiveReverseScan(
142
+ T input, ///< [in] Calling thread's input item.
143
+ T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input.
144
+ ScanOpT scan_op) ///< [in] Binary scan operator
145
+ {
146
+ inclusive_output = input;
147
+ #pragma unroll
148
+ for (int STEP = 0; STEP < STEPS; STEP++) {
149
+ int offset = 1 << STEP;
150
+ T temp = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
151
+ inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask
152
+ );
153
+ // Perform scan op if from a valid peer
154
+ inclusive_output = static_cast<int>(lane_id) >= LOGICAL_WARP_THREADS - offset
155
+ ? inclusive_output : scan_op(temp, inclusive_output);
156
+ }
157
+ }
158
+
159
+ /// Exclusive scan
160
+ // Get exclusive from inclusive
161
+ template <typename ScanOpT>
162
+ __device__ __forceinline__ void ExclusiveReverseScan(
163
+ T input, ///< [in] Calling thread's input item.
164
+ T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input.
165
+ ScanOpT scan_op, ///< [in] Binary scan operator
166
+ T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items.
167
+ {
168
+ T inclusive_output;
169
+ InclusiveReverseScan(input, inclusive_output, scan_op);
170
+ warp_aggregate = cub::ShuffleIndex<LOGICAL_WARP_THREADS>(inclusive_output, 0, member_mask);
171
+ // initial value unknown
172
+ exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
173
+ inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
174
+ );
175
+ }
176
+
177
+ /**
178
+ * \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last <em>warp-lane</em> is undefined.
179
+ */
180
+ template <typename ScanOpT>
181
+ __device__ __forceinline__ void ReverseScan(
182
+ T input, ///< [in] Calling thread's input item.
183
+ T &inclusive_output, ///< [out] Calling thread's inclusive-scan output item.
184
+ T &exclusive_output, ///< [out] Calling thread's exclusive-scan output item.
185
+ ScanOpT scan_op) ///< [in] Binary scan operator
186
+ {
187
+ InclusiveReverseScan(input, inclusive_output, scan_op);
188
+ // initial value unknown
189
+ exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
190
+ inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
191
+ );
192
+ }
193
+
194
+ };
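The shuffle-based loop in `InclusiveReverseScan` can likewise be checked against a plain sequential fold. The emulation below is purely illustrative: it mirrors the log-step structure, where at each step a lane folds in the partial held by the lane `offset` positions above it, using the same operand order as the device code (and, like the device code, it assumes a power-of-two lane count).

```python
def warp_inclusive_reverse_scan(vals, op):
    # Log-step emulation of the shuffle-down loop: lanes with a valid peer
    # combine that peer's partial (higher lane index) in front of their own.
    out, n, offset = list(vals), len(vals), 1
    while offset < n:
        prev = out[:]                       # all shuffles happen "simultaneously"
        for lane in range(n - offset):
            out[lane] = op(prev[lane + offset], prev[lane])
        offset *= 2
    return out


# With string concatenation (non-commutative) the result matches a sequential
# reverse inclusive scan: lane i holds vals[n-1] + ... + vals[i].
assert warp_inclusive_reverse_scan(list("abcd"), lambda x, y: x + y) == ["dcba", "dcb", "dc", "d"]
```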
195
+
196
+ /**
197
+ * \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block.
198
+ */
199
+ template <
200
+ typename T, ///< Data type being scanned
201
+ int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension
202
+ bool MEMOIZE=false ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure
203
+ >
204
+ struct BlockReverseScan {
205
+ //---------------------------------------------------------------------
206
+ // Types and constants
207
+ //---------------------------------------------------------------------
208
+
209
+ /// Constants
210
+ /// The thread block size in threads
211
+ static constexpr int BLOCK_THREADS = BLOCK_DIM_X;
212
+
213
+ /// Layout type for padded thread block raking grid
214
+ using BlockRakingLayout = cub::BlockRakingLayout<T, BLOCK_THREADS>;
215
+ // The number of reduction elements is not a multiple of the number of raking threads for now
216
+ static_assert(BlockRakingLayout::UNGUARDED);
217
+
218
+ /// Number of raking threads
219
+ static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS;
220
+ /// Number of raking elements per warp synchronous raking thread
221
+ static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH;
222
+ /// Cooperative work can be entirely warp synchronous
223
+ static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS));
224
+
225
+ /// WarpReverseScan utility type
226
+ using WarpReverseScan = WarpReverseScan<T, RAKING_THREADS>;
227
+
228
+ /// Shared memory storage layout type
229
+ struct _TempStorage {
230
+ typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid
231
+ };
232
+
233
+
234
+ /// Alias wrapper allowing storage to be unioned
235
+ struct TempStorage : cub::Uninitialized<_TempStorage> {};
236
+
237
+
238
+ //---------------------------------------------------------------------
239
+ // Per-thread fields
240
+ //---------------------------------------------------------------------
241
+
242
+ // Thread fields
243
+ _TempStorage &temp_storage;
244
+ unsigned int linear_tid;
245
+ T cached_segment[SEGMENT_LENGTH];
246
+
247
+
248
+ //---------------------------------------------------------------------
249
+ // Utility methods
250
+ //---------------------------------------------------------------------
251
+
252
+ /// Performs upsweep raking reduction, returning the aggregate
253
+ template <typename ScanOp>
254
+ __device__ __forceinline__ T Upsweep(ScanOp scan_op) {
255
+ T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
256
+ // Read data into registers
257
+ #pragma unroll
258
+ for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
259
+ T raking_partial = cached_segment[SEGMENT_LENGTH - 1];
260
+ #pragma unroll
261
+ for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) {
262
+ raking_partial = scan_op(raking_partial, cached_segment[i]);
263
+ }
264
+ return raking_partial;
265
+ }
266
+
267
+
268
+ /// Performs exclusive downsweep raking scan
269
+ template <typename ScanOp>
270
+ __device__ __forceinline__ void ExclusiveDownsweep(
271
+ ScanOp scan_op,
272
+ T raking_partial)
273
+ {
274
+ T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
275
+ // Read data back into registers
276
+ if (!MEMOIZE) {
277
+ #pragma unroll
278
+ for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
279
+ }
280
+ ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial);
281
+ // Write data back to smem
282
+ #pragma unroll
283
+ for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; }
284
+ }
285
+
286
+
287
+ //---------------------------------------------------------------------
288
+ // Constructors
289
+ //---------------------------------------------------------------------
290
+
291
+ /// Constructor
292
+ __device__ __forceinline__ BlockReverseScan(
293
+ TempStorage &temp_storage)
294
+ :
295
+ temp_storage(temp_storage.Alias()),
296
+ linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1))
297
+ {}
298
+
299
+
300
+ /// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes one input element. The call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
301
+ template <
302
+ typename ScanOp,
303
+ typename BlockPostfixCallbackOp>
304
+ __device__ __forceinline__ void ExclusiveReverseScan(
305
+ T input, ///< [in] Calling thread's input item
306
+ T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input)
307
+ ScanOp scan_op, ///< [in] Binary scan operator
308
+ BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a thread block-wide postfix to be applied to all inputs.
309
+ {
310
+ if (WARP_SYNCHRONOUS) {
311
+ // Short-circuit directly to warp-synchronous scan
312
+ T block_aggregate;
313
+ WarpReverseScan warp_scan;
314
+ warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate);
315
+ // Obtain warp-wide postfix in lane0, then broadcast to other lanes
316
+ T block_postfix = block_postfix_callback_op(block_aggregate);
317
+ block_postfix = warp_scan.Broadcast(block_postfix, 0);
318
+ exclusive_output = linear_tid == BLOCK_THREADS - 1 ? block_postfix : scan_op(block_postfix, exclusive_output);
319
+ } else {
320
+ // Place thread partial into shared memory raking grid
321
+ T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid);
322
+ detail::uninitialized_copy(placement_ptr, input);
323
+ cub::CTA_SYNC();
324
+ // Reduce parallelism down to just raking threads
325
+ if (linear_tid < RAKING_THREADS) {
326
+ WarpReverseScan warp_scan;
327
+ // Raking upsweep reduction across shared partials
328
+ T upsweep_partial = Upsweep(scan_op);
329
+ // Warp-synchronous scan
330
+ T exclusive_partial, block_aggregate;
331
+ warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate);
332
+ // Obtain block-wide postfix in lane0, then broadcast to other lanes
333
+ T block_postfix = block_postfix_callback_op(block_aggregate);
334
+ block_postfix = warp_scan.Broadcast(block_postfix, 0);
335
+ // Update postfix with warpscan exclusive partial
336
+ T downsweep_postfix = linear_tid == RAKING_THREADS - 1
337
+ ? block_postfix : scan_op(block_postfix, exclusive_partial);
338
+ // Exclusive raking downsweep scan
339
+ ExclusiveDownsweep(scan_op, downsweep_postfix);
340
+ }
341
+ cub::CTA_SYNC();
342
+ // Grab thread postfix from shared memory
343
+ exclusive_output = *placement_ptr;
344
+
345
+ // // Compute warp scan in each warp.
346
+ // // The exclusive output from the last lane in each warp is invalid.
347
+ // T inclusive_output;
348
+ // WarpReverseScan warp_scan;
349
+ // warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op);
350
+
351
+ // // Compute the warp-wide postfix and block-wide aggregate for each warp. Warp postfix for the last warp is invalid.
352
+ // T block_aggregate;
353
+ // T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate);
354
+
355
+ // // Apply warp postfix to our lane's partial
356
+ // if (warp_id != 0) {
357
+ // exclusive_output = scan_op(warp_postfix, exclusive_output);
358
+ // if (lane_id == 0) { exclusive_output = warp_postfix; }
359
+ // }
360
+
361
+ // // Use the first warp to determine the thread block postfix, returning the result in lane0
362
+ // if (warp_id == 0) {
363
+ // T block_postfix = block_postfix_callback_op(block_aggregate);
364
+ // if (lane_id == 0) {
365
+ // // Share the postfix with all threads
366
+ // detail::uninitialized_copy(&temp_storage.block_postfix,
367
+ // block_postfix);
368
+
369
+ // exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0
370
+ // }
371
+ // }
372
+
373
+ // cub::CTA_SYNC();
374
+
375
+ // // Incorporate thread block postfix into outputs
376
+ // T block_postfix = temp_storage.block_postfix;
377
+ // if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); }
378
+ }
379
+ }
380
+
381
+
382
+ /**
383
+ * \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. The call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
384
+ */
385
+ template <
386
+ int ITEMS_PER_THREAD,
387
+ typename ScanOp,
388
+ typename BlockPostfixCallbackOp>
389
+ __device__ __forceinline__ void InclusiveReverseScan(
390
+ T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items
391
+ T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input)
392
+ ScanOp scan_op, ///< [in] Binary scan functor
393
+ BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence.
394
+ {
395
+ // Reduce consecutive thread items in registers
396
+ T thread_postfix = ThreadReverseReduce(input, scan_op);
397
+ // Exclusive thread block-scan
398
+ ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op);
399
+ // Inclusive scan in registers with postfix as seed
400
+ ThreadReverseScanInclusive(input, output, scan_op, thread_postfix);
401
+ }
402
+
403
+ };
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/selective_scan.h ADDED
@@ -0,0 +1,90 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
8
+
9
+ struct SSMScanParamsBase {
10
+ using index_t = uint32_t;
11
+
12
+ int batch, seqlen, n_chunks;
13
+ index_t a_batch_stride;
14
+ index_t b_batch_stride;
15
+ index_t out_batch_stride;
16
+
17
+ // Common data pointers.
18
+ void *__restrict__ a_ptr;
19
+ void *__restrict__ b_ptr;
20
+ void *__restrict__ out_ptr;
21
+ void *__restrict__ x_ptr;
22
+ };
23
+
24
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
25
+
26
+ struct SSMParamsBase {
27
+ using index_t = uint32_t;
28
+
29
+ int batch, dim, seqlen, dstate, n_groups, n_chunks;
30
+ int dim_ngroups_ratio;
31
+
32
+ bool delta_softplus;
33
+
34
+ index_t A_d_stride;
35
+ index_t A_dstate_stride;
36
+ index_t B_batch_stride;
37
+ index_t B_d_stride;
38
+ index_t B_dstate_stride;
39
+ index_t B_group_stride;
40
+ index_t C_batch_stride;
41
+ index_t C_d_stride;
42
+ index_t C_dstate_stride;
43
+ index_t C_group_stride;
44
+ index_t u_batch_stride;
45
+ index_t u_d_stride;
46
+ index_t delta_batch_stride;
47
+ index_t delta_d_stride;
48
+ index_t out_batch_stride;
49
+ index_t out_d_stride;
50
+
51
+ // Common data pointers.
52
+ void *__restrict__ A_ptr;
53
+ void *__restrict__ B_ptr;
54
+ void *__restrict__ C_ptr;
55
+ void *__restrict__ D_ptr;
56
+ void *__restrict__ u_ptr;
57
+ void *__restrict__ delta_ptr;
58
+ void *__restrict__ delta_bias_ptr;
59
+ void *__restrict__ out_ptr;
60
+ void *__restrict__ x_ptr;
61
+ };
62
+
63
+ struct SSMParamsBwd: public SSMParamsBase {
64
+ index_t dout_batch_stride;
65
+ index_t dout_d_stride;
66
+ index_t dA_d_stride;
67
+ index_t dA_dstate_stride;
68
+ index_t dB_batch_stride;
69
+ index_t dB_group_stride;
70
+ index_t dB_d_stride;
71
+ index_t dB_dstate_stride;
72
+ index_t dC_batch_stride;
73
+ index_t dC_group_stride;
74
+ index_t dC_d_stride;
75
+ index_t dC_dstate_stride;
76
+ index_t du_batch_stride;
77
+ index_t du_d_stride;
78
+ index_t ddelta_batch_stride;
79
+ index_t ddelta_d_stride;
80
+
81
+ // Common data pointers.
82
+ void *__restrict__ dout_ptr;
83
+ void *__restrict__ dA_ptr;
84
+ void *__restrict__ dB_ptr;
85
+ void *__restrict__ dC_ptr;
86
+ void *__restrict__ dD_ptr;
87
+ void *__restrict__ du_ptr;
88
+ void *__restrict__ ddelta_ptr;
89
+ void *__restrict__ ddelta_bias_ptr;
90
+ };
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/selective_scan_common.h ADDED
@@ -0,0 +1,210 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <cuda_bf16.h>
8
+ #include <cuda_fp16.h>
9
+ #include <c10/util/complex.h> // For scalar_value_type
10
+
11
+ #define MAX_DSTATE 256
12
+
13
+ inline __device__ float2 operator+(const float2 & a, const float2 & b){
14
+ return {a.x + b.x, a.y + b.y};
15
+ }
16
+
17
+ inline __device__ float3 operator+(const float3 &a, const float3 &b) {
18
+ return {a.x + b.x, a.y + b.y, a.z + b.z};
19
+ }
20
+
21
+ inline __device__ float4 operator+(const float4 & a, const float4 & b){
22
+ return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
23
+ }
24
+
25
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
26
+
27
+ template<int BYTES> struct BytesToType {};
28
+
29
+ template<> struct BytesToType<16> {
30
+ using Type = uint4;
31
+ static_assert(sizeof(Type) == 16);
32
+ };
33
+
34
+ template<> struct BytesToType<8> {
35
+ using Type = uint64_t;
36
+ static_assert(sizeof(Type) == 8);
37
+ };
38
+
39
+ template<> struct BytesToType<4> {
40
+ using Type = uint32_t;
41
+ static_assert(sizeof(Type) == 4);
42
+ };
43
+
44
+ template<> struct BytesToType<2> {
45
+ using Type = uint16_t;
46
+ static_assert(sizeof(Type) == 2);
47
+ };
48
+
49
+ template<> struct BytesToType<1> {
50
+ using Type = uint8_t;
51
+ static_assert(sizeof(Type) == 1);
52
+ };
53
+
54
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
55
+
56
+ template<typename scalar_t, int N>
57
+ struct Converter{
58
+ static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
59
+ #pragma unroll
60
+ for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
61
+ }
62
+ };
63
+
64
+ template<int N>
65
+ struct Converter<at::Half, N>{
66
+ static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
67
+ static_assert(N % 2 == 0);
68
+ auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
69
+ auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
70
+ #pragma unroll
71
+ for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
72
+ }
73
+ };
74
+
75
+ #if __CUDA_ARCH__ >= 800
76
+ template<int N>
77
+ struct Converter<at::BFloat16, N>{
78
+ static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
79
+ static_assert(N % 2 == 0);
80
+ auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
81
+ auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
82
+ #pragma unroll
83
+ for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
84
+ }
85
+ };
86
+ #endif
87
+
88
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
89
+ template<typename scalar_t> struct SSMScanOp;
90
+
91
+ template<>
92
+ struct SSMScanOp<float> {
93
+ __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
94
+ return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
95
+ }
96
+ };
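This operator composes two affine updates `x -> a*x + b`, so scanning with it reproduces the recurrence `h_t = a_t * h_{t-1} + b_t`. A short self-check (illustrative only):

```python
import itertools
import random


def ssm_op(ab0, ab1):
    # Same composition rule as SSMScanOp<float>: ab0 is applied first, ab1 second.
    (a0, b0), (a1, b1) = ab0, ab1
    return (a1 * a0, a1 * b0 + b1)


random.seed(0)
a = [random.uniform(0.5, 1.0) for _ in range(8)]
b = [random.uniform(-1.0, 1.0) for _ in range(8)]

# Reference recurrence with h_{-1} = 0.
h, ref = 0.0, []
for at, bt in zip(a, b):
    h = at * h + bt
    ref.append(h)

# Inclusive scan with the associative operator: the second component equals h_t.
scanned = list(itertools.accumulate(zip(a, b), ssm_op))
assert all(abs(s[1] - r) < 1e-9 for s, r in zip(scanned, ref))
```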
97
+
98
+ // A stateful callback functor that maintains a running prefix to be applied
99
+ // during consecutive scan operations.
100
+ template <typename scalar_t> struct SSMScanPrefixCallbackOp {
101
+ using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
102
+ scan_t running_prefix;
103
+ // Constructor
104
+ __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
105
+ // Callback operator to be entered by the first warp of threads in the block.
106
+ // Thread-0 is responsible for returning a value for seeding the block-wide scan.
107
+ __device__ scan_t operator()(scan_t block_aggregate) {
108
+ scan_t old_prefix = running_prefix;
109
+ running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
110
+ return old_prefix;
111
+ }
112
+ };
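The chunking pattern this callback supports can be pictured with a few lines of Python (illustrative only): each chunk is scanned locally, seeded with the aggregate of all previous chunks, and by associativity the stitched result equals a single global scan.

```python
def affine_op(ab0, ab1):
    # Same composition rule as SSMScanOp<float> above.
    (a0, b0), (a1, b1) = ab0, ab1
    return (a1 * a0, a1 * b0 + b1)


def chunked_scan(chunks, op, identity):
    running, out = identity, []      # running prefix carried across chunks
    for chunk in chunks:
        acc = running                # "old prefix" handed to this chunk
        for item in chunk:
            acc = op(acc, item)
            out.append(acc)
        running = acc                # prefix now includes this chunk's aggregate
    return out


items = [(0.9, 0.1), (0.8, -0.2), (0.7, 0.3), (0.6, 0.4)]
assert chunked_scan([items], affine_op, (1.0, 0.0)) == \
       chunked_scan([items[:2], items[2:]], affine_op, (1.0, 0.0))
```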
113
+
114
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
115
+
116
+ template<typename Ktraits>
117
+ inline __device__ void load_input(typename Ktraits::input_t *u,
118
+ typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
119
+ typename Ktraits::BlockLoadT::TempStorage &smem_load,
120
+ int seqlen) {
121
+ if constexpr (Ktraits::kIsEvenLen) {
122
+ auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
123
+ using vec_t = typename Ktraits::vec_t;
124
+ Ktraits::BlockLoadVecT(smem_load_vec).Load(
125
+ reinterpret_cast<vec_t*>(u),
126
+ reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
127
+ );
128
+ } else {
129
+ Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
130
+ }
131
+ }
132
+
133
+ template<typename Ktraits>
134
+ inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
135
+ typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
136
+ typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
137
+ int seqlen) {
138
+ constexpr int kNItems = Ktraits::kNItems;
139
+ typename Ktraits::input_t B_vals_load[kNItems];
140
+ if constexpr (Ktraits::kIsEvenLen) {
141
+ auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
142
+ using vec_t = typename Ktraits::vec_t;
143
+ Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
144
+ reinterpret_cast<vec_t*>(Bvar),
145
+ reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
146
+ );
147
+ } else {
148
+ Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
149
+ }
150
+ // #pragma unroll
151
+ // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
152
+ Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
153
+ }
154
+
155
+ template<typename Ktraits>
156
+ inline __device__ void store_output(typename Ktraits::input_t *out,
157
+ const float (&out_vals)[Ktraits::kNItems],
158
+ typename Ktraits::BlockStoreT::TempStorage &smem_store,
159
+ int seqlen) {
160
+ typename Ktraits::input_t write_vals[Ktraits::kNItems];
161
+ #pragma unroll
162
+ for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
163
+ if constexpr (Ktraits::kIsEvenLen) {
164
+ auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
165
+ using vec_t = typename Ktraits::vec_t;
166
+ Ktraits::BlockStoreVecT(smem_store_vec).Store(
167
+ reinterpret_cast<vec_t*>(out),
168
+ reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
169
+ );
170
+ } else {
171
+ Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
172
+ }
173
+ }
174
+
175
+ template<typename Ktraits>
176
+ inline __device__ void store_output1(typename Ktraits::output_t *out,
177
+ const float (&out_vals)[Ktraits::kNItems],
178
+ typename Ktraits::BlockStoreOutputT::TempStorage &smem_store,
179
+ int seqlen) {
180
+ typename Ktraits::output_t write_vals[Ktraits::kNItems];
181
+ #pragma unroll
182
+ for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
183
+ if constexpr (Ktraits::kIsEvenLen) {
184
+ auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreOutputVecT::TempStorage&>(smem_store);
185
+ using vec_t = typename Ktraits::vec_t;
186
+ Ktraits::BlockStoreOutputVecT(smem_store_vec).Store(
187
+ reinterpret_cast<vec_t*>(out),
188
+ reinterpret_cast<vec_t(&)[Ktraits::kNLoadsOutput]>(write_vals)
189
+ );
190
+ } else {
191
+ Ktraits::BlockStoreOutputT(smem_store).Store(out, write_vals, seqlen);
192
+ }
193
+ }
194
+
195
+ template<typename Ktraits>
196
+ inline __device__ void load_output(typename Ktraits::output_t *u,
197
+ typename Ktraits::output_t (&u_vals)[Ktraits::kNItems],
198
+ typename Ktraits::BlockLoadOutputT::TempStorage &smem_load,
199
+ int seqlen) {
200
+ if constexpr (Ktraits::kIsEvenLen) {
201
+ auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadOutputVecT::TempStorage&>(smem_load);
202
+ using vec_t = typename Ktraits::vec_t;
203
+ Ktraits::BlockLoadOutputVecT(smem_load_vec).Load(
204
+ reinterpret_cast<vec_t*>(u),
205
+ reinterpret_cast<vec_t(&)[Ktraits::kNLoadsOutput]>(u_vals)
206
+ );
207
+ } else {
208
+ Ktraits::BlockLoadOutputT(smem_load).Load(u, u_vals, seqlen, 0.f);
209
+ }
210
+ }