SakuraD commited on
Commit
bc9387a
1 Parent(s): aeed542

update packages

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +184 -0
  2. __pycache__/imagenet_class_index.cpython-310.pyc +0 -0
  3. __pycache__/kinetics_class_index.cpython-310.pyc +0 -0
  4. __pycache__/transforms.cpython-310.pyc +0 -0
  5. __pycache__/videomamba_image.cpython-310.pyc +0 -0
  6. __pycache__/videomamba_video.cpython-310.pyc +0 -0
  7. app.py +1 -2
  8. causal-conv1d/AUTHORS +0 -1
  9. causal-conv1d/LICENSE +0 -29
  10. causal-conv1d/README.md +0 -1
  11. causal-conv1d/causal_conv1d/__init__.py +0 -3
  12. causal-conv1d/causal_conv1d/causal_conv1d_interface.py +0 -104
  13. causal-conv1d/csrc/causal_conv1d.cpp +0 -333
  14. causal-conv1d/csrc/causal_conv1d.h +0 -53
  15. causal-conv1d/csrc/causal_conv1d_bwd.cu +0 -525
  16. causal-conv1d/csrc/causal_conv1d_common.h +0 -64
  17. causal-conv1d/csrc/causal_conv1d_fwd.cu +0 -350
  18. causal-conv1d/csrc/causal_conv1d_update.cu +0 -96
  19. causal-conv1d/csrc/static_switch.h +0 -25
  20. causal-conv1d/setup.py +0 -264
  21. causal-conv1d/tests/test_causal_conv1d.py +0 -173
  22. install.sh +0 -2
  23. mamba/.gitmodules +0 -3
  24. mamba/AUTHORS +0 -2
  25. mamba/LICENSE +0 -201
  26. mamba/README.md +0 -149
  27. mamba/assets/selection.png +0 -0
  28. mamba/benchmarks/benchmark_generation_mamba_simple.py +0 -88
  29. mamba/csrc/selective_scan/reverse_scan.cuh +0 -401
  30. mamba/csrc/selective_scan/selective_scan.cpp +0 -497
  31. mamba/csrc/selective_scan/selective_scan.h +0 -101
  32. mamba/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu +0 -9
  33. mamba/csrc/selective_scan/selective_scan_bwd_bf16_real.cu +0 -9
  34. mamba/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu +0 -9
  35. mamba/csrc/selective_scan/selective_scan_bwd_fp16_real.cu +0 -9
  36. mamba/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu +0 -9
  37. mamba/csrc/selective_scan/selective_scan_bwd_fp32_real.cu +0 -9
  38. mamba/csrc/selective_scan/selective_scan_bwd_kernel.cuh +0 -531
  39. mamba/csrc/selective_scan/selective_scan_common.h +0 -221
  40. mamba/csrc/selective_scan/selective_scan_fwd_bf16.cu +0 -10
  41. mamba/csrc/selective_scan/selective_scan_fwd_fp16.cu +0 -10
  42. mamba/csrc/selective_scan/selective_scan_fwd_fp32.cu +0 -10
  43. mamba/csrc/selective_scan/selective_scan_fwd_kernel.cuh +0 -345
  44. mamba/csrc/selective_scan/static_switch.h +0 -25
  45. mamba/csrc/selective_scan/uninitialized_copy.cuh +0 -69
  46. mamba/evals/lm_harness_eval.py +0 -39
  47. mamba/mamba_ssm/__init__.py +0 -5
  48. mamba/mamba_ssm/models/__init__.py +0 -0
  49. mamba/mamba_ssm/models/mixer_seq_simple.py +0 -233
  50. mamba/mamba_ssm/modules/__init__.py +0 -0
.gitignore ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .DS_Store
2
+ causal-conv1d/causal_conv1d/__pycache__/*
3
+ mamba/mamba_ssm/utils/__pycache__/*
4
+ mamba/mamba_ssm/ops/triton/__pycache__/*
5
+ mamba/mamba_ssm/ops/__pycache__/*
6
+ mamba/mamba_ssm/modules/__pycache__/*
7
+ mamba/mamba_ssm/__pycache__/*
8
+ mamba/mamba_ssm/models/__pycache__/*
9
+ causal-conv1d/build/
10
+ causal-conv1d/causal_conv1d.egg-info/
11
+ causal-conv1d/causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so
12
+ mamba/build/
13
+ mamba/mamba_ssm.egg-info/
14
+ mamba/selective_scan_cuda.cpython-310-x86_64-linux-gnu.so
15
+
16
+ # Docker file from Python is inspired from here :
17
+ # https://github.com/github/gitignore/blob/master/Python.gitignore
18
+
19
+ *_init
20
+ events*
21
+ *log.txt
22
+ *.pth
23
+ log_*.txt
24
+ log
25
+ latest
26
+ *.txt
27
+ *.pt
28
+ checkpoint*
29
+ batchscript*
30
+ logs/
31
+
32
+ # custom
33
+ .vscode
34
+
35
+ # Byte-compiled / optimized / DLL files
36
+ __pycache__/
37
+ *.py[cod]
38
+ *$py.class
39
+
40
+ # C extensions
41
+ *.so
42
+
43
+ # Distribution / packaging
44
+ .Python
45
+ build/
46
+ develop-eggs/
47
+ dist/
48
+ downloads/
49
+ eggs/
50
+ .eggs/
51
+ lib/
52
+ lib64/
53
+ parts/
54
+ sdist/
55
+ var/
56
+ wheels/
57
+ share/python-wheels/
58
+ *.egg-info/
59
+ .installed.cfg
60
+ *.egg
61
+ MANIFEST
62
+
63
+ # PyInstaller
64
+ # Usually these files are written by a python script from a template
65
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
66
+ *.manifest
67
+ *.spec
68
+
69
+ # Installer logs
70
+ pip-log.txt
71
+ pip-delete-this-directory.txt
72
+
73
+ # Unit test / coverage reports
74
+ tests/report/
75
+ .coverage
76
+ .coverage.*
77
+ .cache
78
+ nosetests.xml
79
+ coverage.xml
80
+ *.cover
81
+ *.py,cover
82
+ .hypothesis/
83
+ .pytest_cache/
84
+
85
+ # Translations
86
+ *.mo
87
+ *.pot
88
+
89
+ # Django stuff:
90
+ *.log
91
+ local_settings.py
92
+ db.sqlite3
93
+ db.sqlite3-journal
94
+
95
+ # Flask stuff:
96
+ instance/
97
+ .webassets-cache
98
+
99
+ # Scrapy stuff:
100
+ .scrapy
101
+
102
+ # Sphinx documentation
103
+ docs/_build/
104
+
105
+ # PyBuilder
106
+ .pybuilder/
107
+ target/
108
+
109
+ # Jupyter Notebook
110
+ .ipynb_checkpoints
111
+
112
+ # IPython
113
+ profile_default/
114
+ ipython_config.py
115
+
116
+ # pyenv
117
+ # For a library or package, you might want to ignore these files since the code is
118
+ # intended to run in multiple environments; otherwise, check them in:
119
+ # .python-version
120
+
121
+ # pipenv
122
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
123
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
124
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
125
+ # install all needed dependencies.
126
+ #Pipfile.lock
127
+
128
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
129
+ __pypackages__/
130
+
131
+ # Celery stuff
132
+ celerybeat-schedule
133
+ celerybeat.pid
134
+
135
+ # SageMath parsed files
136
+ *.sage.py
137
+
138
+ # Environments
139
+ .env
140
+ .venv
141
+ env/
142
+ venv/
143
+ ENV/
144
+ env.bak/
145
+ venv.bak/
146
+
147
+ # Spyder project settings
148
+ .spyderproject
149
+ .spyproject
150
+
151
+ # Rope project settings
152
+ .ropeproject
153
+
154
+ # mkdocs documentation
155
+ /site
156
+
157
+ # mypy
158
+ .mypy_cache/
159
+ .dmypy.json
160
+ dmypy.json
161
+
162
+ # Pyre type checker
163
+ .pyre/
164
+
165
+ # pytype static type analyzer
166
+ .pytype/
167
+
168
+
169
+ # Cython debug symbols
170
+ cython_debug/
171
+
172
+
173
+ # others
174
+ work_dir/
175
+ batchscript-*
176
+ phoenix-slurm-*
177
+ debug*
178
+ wandb
179
+
180
+ *.pkl
181
+ *.err
182
+ *.out
183
+ *.csv
184
+ ckpt/
__pycache__/imagenet_class_index.cpython-310.pyc DELETED
Binary file (60.6 kB)
 
__pycache__/kinetics_class_index.cpython-310.pyc DELETED
Binary file (15.2 kB)
 
__pycache__/transforms.cpython-310.pyc DELETED
Binary file (13.4 kB)
 
__pycache__/videomamba_image.cpython-310.pyc DELETED
Binary file (9.64 kB)
 
__pycache__/videomamba_video.cpython-310.pyc DELETED
Binary file (11.2 kB)
 
app.py CHANGED
@@ -2,10 +2,9 @@ import os
2
  import spaces
3
 
4
  # install packages for mamba
5
- @spaces.GPU
6
  def install():
7
  print("Install personal packages", flush=True)
8
- os.system("bash install.sh")
9
 
10
  install()
11
 
 
2
  import spaces
3
 
4
  # install packages for mamba
 
5
  def install():
6
  print("Install personal packages", flush=True)
7
+ os.system("pip install mamba_ssm-1.0.1-cp310-cp310-linux_x86_64.whl")
8
 
9
  install()
10
 
causal-conv1d/AUTHORS DELETED
@@ -1 +0,0 @@
1
- Tri Dao, tri@tridao.me
 
 
causal-conv1d/LICENSE DELETED
@@ -1,29 +0,0 @@
1
- BSD 3-Clause License
2
-
3
- Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
4
- All rights reserved.
5
-
6
- Redistribution and use in source and binary forms, with or without
7
- modification, are permitted provided that the following conditions are met:
8
-
9
- * Redistributions of source code must retain the above copyright notice, this
10
- list of conditions and the following disclaimer.
11
-
12
- * Redistributions in binary form must reproduce the above copyright notice,
13
- this list of conditions and the following disclaimer in the documentation
14
- and/or other materials provided with the distribution.
15
-
16
- * Neither the name of the copyright holder nor the names of its
17
- contributors may be used to endorse or promote products derived from
18
- this software without specific prior written permission.
19
-
20
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
causal-conv1d/README.md DELETED
@@ -1 +0,0 @@
1
- # Causal depthwise conv1d in CUDA with a PyTorch interface
 
 
causal-conv1d/causal_conv1d/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- __version__ = "1.0.0"
2
-
3
- from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
 
 
 
 
causal-conv1d/causal_conv1d/causal_conv1d_interface.py DELETED
@@ -1,104 +0,0 @@
1
- # Copyright (c) 2023, Tri Dao.
2
-
3
- import torch
4
- import torch.nn.functional as F
5
-
6
-
7
- import causal_conv1d_cuda
8
-
9
-
10
- class CausalConv1dFn(torch.autograd.Function):
11
- @staticmethod
12
- def forward(ctx, x, weight, bias=None, activation=None):
13
- if activation not in [None, "silu", "swish"]:
14
- raise NotImplementedError("activation must be None, silu, or swish")
15
- if x.stride(2) != 1 and x.stride(1) != 1:
16
- x = x.contiguous()
17
- bias = bias.contiguous() if bias is not None else None
18
- ctx.save_for_backward(x, weight, bias)
19
- ctx.activation = activation in ["silu", "swish"]
20
- out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, ctx.activation)
21
- return out
22
-
23
- @staticmethod
24
- def backward(ctx, dout):
25
- x, weight, bias = ctx.saved_tensors
26
- if dout.stride(2) != 1 and dout.stride(1) != 1:
27
- dout = dout.contiguous()
28
- # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
29
- # backward of conv1d with the backward of chunk).
30
- # Here we just pass in None and dx will be allocated in the C++ code.
31
- dx, dweight, dbias = causal_conv1d_cuda.causal_conv1d_bwd(
32
- x, weight, bias, dout, None, ctx.activation
33
- )
34
- return dx, dweight, dbias if bias is not None else None, None
35
-
36
-
37
- def causal_conv1d_fn(x, weight, bias=None, activation=None):
38
- """
39
- x: (batch, dim, seqlen)
40
- weight: (dim, width)
41
- bias: (dim,)
42
- activation: either None or "silu" or "swish"
43
-
44
- out: (batch, dim, seqlen)
45
- """
46
- return CausalConv1dFn.apply(x, weight, bias, activation)
47
-
48
-
49
- def causal_conv1d_ref(x, weight, bias=None, activation=None):
50
- """
51
- x: (batch, dim, seqlen)
52
- weight: (dim, width)
53
- bias: (dim,)
54
-
55
- out: (batch, dim, seqlen)
56
- """
57
- if activation not in [None, "silu", "swish"]:
58
- raise NotImplementedError("activation must be None, silu, or swish")
59
- dtype_in = x.dtype
60
- x = x.to(weight.dtype)
61
- seqlen = x.shape[-1]
62
- dim, width = weight.shape
63
- out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
64
- out = out[..., :seqlen]
65
- return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
66
-
67
-
68
- def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None):
69
- """
70
- x: (batch, dim)
71
- conv_state: (batch, dim, width)
72
- weight: (dim, width)
73
- bias: (dim,)
74
-
75
- out: (batch, dim)
76
- """
77
- if activation not in [None, "silu", "swish"]:
78
- raise NotImplementedError("activation must be None, silu, or swish")
79
- activation = activation in ["silu", "swish"]
80
- return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, activation)
81
-
82
-
83
- def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None):
84
- """
85
- x: (batch, dim)
86
- conv_state: (batch, dim, width)
87
- weight: (dim, width)
88
- bias: (dim,)
89
-
90
- out: (batch, dim)
91
- """
92
- if activation not in [None, "silu", "swish"]:
93
- raise NotImplementedError("activation must be None, silu, or swish")
94
- dtype_in = x.dtype
95
- batch, dim = x.shape
96
- width = weight.shape[1]
97
- assert conv_state.shape == (batch, dim, width)
98
- assert weight.shape == (dim, width)
99
- conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
100
- conv_state[:, :, -1] = x
101
- out = torch.sum(conv_state * weight, dim=-1) # (B D)
102
- if bias is not None:
103
- out += bias
104
- return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
causal-conv1d/csrc/causal_conv1d.cpp DELETED
@@ -1,333 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #include <ATen/cuda/CUDAContext.h>
6
- #include <c10/cuda/CUDAGuard.h>
7
- #include <torch/extension.h>
8
- #include <vector>
9
-
10
- #include "causal_conv1d.h"
11
-
12
- #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
13
-
14
- #define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
15
- if (ITYPE == at::ScalarType::Half) { \
16
- using input_t = at::Half; \
17
- __VA_ARGS__(); \
18
- } else if (ITYPE == at::ScalarType::BFloat16) { \
19
- using input_t = at::BFloat16; \
20
- __VA_ARGS__(); \
21
- } else if (ITYPE == at::ScalarType::Float) { \
22
- using input_t = float; \
23
- __VA_ARGS__(); \
24
- } else { \
25
- AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
26
- }
27
-
28
- #define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
29
- if (WTYPE == at::ScalarType::Half) { \
30
- using weight_t = at::Half; \
31
- __VA_ARGS__(); \
32
- } else if (WTYPE == at::ScalarType::BFloat16) { \
33
- using weight_t = at::BFloat16; \
34
- __VA_ARGS__(); \
35
- } else if (WTYPE == at::ScalarType::Float) { \
36
- using weight_t = float; \
37
- __VA_ARGS__(); \
38
- } else { \
39
- AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
40
- }
41
-
42
- template<typename input_t, typename weight_t>
43
- void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
44
- template <typename input_t, typename weight_t>
45
- void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
46
-
47
- template<typename input_t, typename weight_t>
48
- void causal_conv1d_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream);
49
- template<typename input_t, typename weight_t>
50
- void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream);
51
-
52
- template<typename input_t, typename weight_t>
53
- void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream);
54
-
55
- void set_conv_params_fwd(ConvParamsBase &params,
56
- // sizes
57
- const size_t batch,
58
- const size_t dim,
59
- const size_t seqlen,
60
- const size_t width,
61
- // device pointers
62
- const at::Tensor x,
63
- const at::Tensor weight,
64
- const at::Tensor out,
65
- void* bias_ptr,
66
- bool silu_activation) {
67
-
68
- // Reset the parameters
69
- memset(&params, 0, sizeof(params));
70
-
71
- params.batch = batch;
72
- params.dim = dim;
73
- params.seqlen = seqlen;
74
- params.width = width;
75
-
76
- params.silu_activation = silu_activation;
77
-
78
- // Set the pointers and strides.
79
- params.x_ptr = x.data_ptr();
80
- params.weight_ptr = weight.data_ptr();
81
- params.bias_ptr = bias_ptr;
82
- params.out_ptr = out.data_ptr();
83
- // All stride are in elements, not bytes.
84
- params.x_batch_stride = x.stride(0);
85
- params.x_c_stride = x.stride(1);
86
- params.x_l_stride = x.stride(-1);
87
- params.weight_c_stride = weight.stride(0);
88
- params.weight_width_stride = weight.stride(1);
89
- params.out_batch_stride = out.stride(0);
90
- params.out_c_stride = out.stride(1);
91
- params.out_l_stride = out.stride(-1);
92
- }
93
-
94
-
95
- void set_conv_params_bwd(ConvParamsBwd &params,
96
- // sizes
97
- const size_t batch,
98
- const size_t dim,
99
- const size_t seqlen,
100
- const size_t width,
101
- // device pointers
102
- const at::Tensor x,
103
- const at::Tensor weight,
104
- void* bias_ptr,
105
- const at::Tensor dout,
106
- const at::Tensor dx,
107
- const at::Tensor dweight,
108
- void* dbias_ptr,
109
- bool silu_activation) {
110
- // Pass in "dout" instead of "out", we're not gonna use "out" at all.
111
- set_conv_params_fwd(params, batch, dim, seqlen, width,
112
- x, weight, dout, bias_ptr, silu_activation);
113
-
114
- // Set the pointers and strides.
115
- params.dout_ptr = dout.data_ptr();
116
- params.dx_ptr = dx.data_ptr();
117
- params.dweight_ptr = dweight.data_ptr();
118
- params.dbias_ptr = dbias_ptr;
119
- // All stride are in elements, not bytes.
120
- params.dout_batch_stride = dout.stride(0);
121
- params.dout_c_stride = dout.stride(1);
122
- params.dout_l_stride = dout.stride(2);
123
- params.dweight_c_stride = dweight.stride(0);
124
- params.dweight_width_stride = dweight.stride(1);
125
- params.dx_batch_stride = dx.stride(0);
126
- params.dx_c_stride = dx.stride(1);
127
- params.dx_l_stride = dx.stride(2);
128
- }
129
-
130
- at::Tensor
131
- causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
132
- const c10::optional<at::Tensor> &bias_,
133
- bool silu_activation) {
134
- auto input_type = x.scalar_type();
135
- auto weight_type = weight.scalar_type();
136
- TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
137
- TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
138
-
139
- TORCH_CHECK(x.is_cuda());
140
- TORCH_CHECK(weight.is_cuda());
141
-
142
- const auto sizes = x.sizes();
143
- const int batch_size = sizes[0];
144
- const int dim = sizes[1];
145
- const int seqlen = sizes[2];
146
- const int width = weight.size(-1);
147
-
148
- CHECK_SHAPE(x, batch_size, dim, seqlen);
149
- CHECK_SHAPE(weight, dim, width);
150
-
151
- TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
152
- const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
153
-
154
- if (is_channel_last) {
155
- TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
156
- }
157
- TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
158
-
159
-
160
- if (bias_.has_value()) {
161
- auto bias = bias_.value();
162
- TORCH_CHECK(bias.scalar_type() == weight_type);
163
- TORCH_CHECK(bias.is_cuda());
164
- TORCH_CHECK(bias.stride(-1) == 1);
165
- CHECK_SHAPE(bias, dim);
166
- }
167
-
168
- at::Tensor out = torch::empty_like(x);
169
-
170
- ConvParamsBase params;
171
- set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
172
- bias_.has_value() ? bias_.value().data_ptr() : nullptr,
173
- silu_activation);
174
-
175
- // Otherwise the kernel will be launched from cuda:0 device
176
- // Cast to char to avoid compiler warning about narrowing
177
- at::cuda::CUDAGuard device_guard{(char)x.get_device()};
178
- auto stream = at::cuda::getCurrentCUDAStream().stream();
179
- DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
180
- DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] {
181
- if (!is_channel_last) {
182
- causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
183
- } else {
184
- causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
185
- }
186
- });
187
- });
188
- return out;
189
- }
190
-
191
- std::vector<at::Tensor>
192
- causal_conv1d_bwd(const at::Tensor &x, const at::Tensor &weight,
193
- const c10::optional<at::Tensor> &bias_,
194
- at::Tensor &dout,
195
- c10::optional<at::Tensor> &dx_,
196
- bool silu_activation) {
197
- auto input_type = x.scalar_type();
198
- auto weight_type = weight.scalar_type();
199
- TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
200
- TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
201
-
202
- TORCH_CHECK(x.is_cuda());
203
- TORCH_CHECK(weight.is_cuda());
204
- TORCH_CHECK(dout.is_cuda());
205
-
206
- const auto sizes = x.sizes();
207
- const int batch_size = sizes[0];
208
- const int dim = sizes[1];
209
- const int seqlen = sizes[2];
210
- const int width = weight.size(-1);
211
-
212
- TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
213
-
214
- CHECK_SHAPE(x, batch_size, dim, seqlen);
215
- CHECK_SHAPE(weight, dim, width);
216
- CHECK_SHAPE(dout, batch_size, dim, seqlen);
217
-
218
- TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
219
- const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
220
- if (!is_channel_last && dout.stride(2) != 1) { dout = dout.contiguous(); }
221
- if (is_channel_last && dout.stride(1) != 1) { dout = dout.transpose(-1, -2).contiguous().transpose(-1, -2); }
222
-
223
- if (bias_.has_value()) {
224
- auto bias = bias_.value();
225
- TORCH_CHECK(bias.scalar_type() == weight_type);
226
- TORCH_CHECK(bias.is_cuda());
227
- TORCH_CHECK(bias.stride(-1) == 1);
228
- CHECK_SHAPE(bias, dim);
229
- }
230
-
231
- at::Tensor dx;
232
- if (dx_.has_value()) {
233
- dx = dx_.value();
234
- TORCH_CHECK(dx.scalar_type() == input_type);
235
- TORCH_CHECK(dx.is_cuda());
236
- CHECK_SHAPE(dx, batch_size, dim, seqlen);
237
- if (!is_channel_last) { TORCH_CHECK(dx.stride(2) == 1); }
238
- if (is_channel_last) { TORCH_CHECK(dx.stride(1) == 1); }
239
- } else {
240
- dx = torch::empty_like(x);
241
- }
242
-
243
- // Otherwise the kernel will be launched from cuda:0 device
244
- // Cast to char to avoid compiler warning about narrowing
245
- at::cuda::CUDAGuard device_guard{(char)x.get_device()};
246
-
247
- at::Tensor dweight = torch::zeros_like(weight, weight.options().dtype(at::kFloat));
248
- at::Tensor dbias;
249
- if (bias_.has_value()) { dbias = torch::zeros_like(bias_.value(), bias_.value().options().dtype(at::kFloat)); }
250
-
251
- ConvParamsBwd params;
252
- set_conv_params_bwd(params, batch_size, dim, seqlen, width,
253
- x, weight, bias_.has_value() ? bias_.value().data_ptr() : nullptr,
254
- dout, dx, dweight, bias_.has_value() ? dbias.data_ptr() : nullptr,
255
- silu_activation);
256
-
257
- auto stream = at::cuda::getCurrentCUDAStream().stream();
258
- DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_bwd", [&] {
259
- DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_bwd", [&] {
260
- if (!is_channel_last) {
261
- causal_conv1d_bwd_cuda<input_t, weight_t>(params, stream);
262
- } else {
263
- causal_conv1d_channellast_bwd_cuda<input_t, weight_t>(params, stream);
264
- }
265
- });
266
- });
267
- return {dx, dweight.to(weight.dtype()), bias_.has_value() ? dbias.to(bias_.value().dtype()) : dbias};
268
- }
269
-
270
- at::Tensor
271
- causal_conv1d_update(const at::Tensor &x,
272
- const at::Tensor &conv_state,
273
- const at::Tensor &weight,
274
- const c10::optional<at::Tensor> &bias_,
275
- bool silu_activation) {
276
- auto input_type = x.scalar_type();
277
- auto weight_type = weight.scalar_type();
278
- TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
279
- TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
280
- TORCH_CHECK(conv_state.scalar_type() == input_type);
281
-
282
- TORCH_CHECK(x.is_cuda());
283
- TORCH_CHECK(conv_state.is_cuda());
284
- TORCH_CHECK(weight.is_cuda());
285
-
286
- const auto sizes = x.sizes();
287
- const int batch_size = sizes[0];
288
- const int dim = sizes[1];
289
- const int width = weight.size(-1);
290
-
291
- CHECK_SHAPE(x, batch_size, dim);
292
- CHECK_SHAPE(conv_state, batch_size, dim, width);
293
- CHECK_SHAPE(weight, dim, width);
294
-
295
- TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
296
-
297
- if (bias_.has_value()) {
298
- auto bias = bias_.value();
299
- TORCH_CHECK(bias.scalar_type() == weight_type);
300
- TORCH_CHECK(bias.is_cuda());
301
- TORCH_CHECK(bias.stride(-1) == 1);
302
- CHECK_SHAPE(bias, dim);
303
- }
304
-
305
- at::Tensor out = torch::empty_like(x);
306
-
307
- ConvParamsBase params;
308
- set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out,
309
- bias_.has_value() ? bias_.value().data_ptr() : nullptr,
310
- silu_activation);
311
- params.conv_state_ptr = conv_state.data_ptr();
312
- // All stride are in elements, not bytes.
313
- params.conv_state_batch_stride = conv_state.stride(0);
314
- params.conv_state_c_stride = conv_state.stride(1);
315
- params.conv_state_l_stride = conv_state.stride(2);
316
-
317
- // Otherwise the kernel will be launched from cuda:0 device
318
- // Cast to char to avoid compiler warning about narrowing
319
- at::cuda::CUDAGuard device_guard{(char)x.get_device()};
320
- auto stream = at::cuda::getCurrentCUDAStream().stream();
321
- DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
322
- DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] {
323
- causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
324
- });
325
- });
326
- return out;
327
- }
328
-
329
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
330
- m.def("causal_conv1d_fwd", &causal_conv1d_fwd, "Causal conv1d forward");
331
- m.def("causal_conv1d_bwd", &causal_conv1d_bwd, "Causal conv1d backward");
332
- m.def("causal_conv1d_update", &causal_conv1d_update, "Causal conv1d update");
333
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
causal-conv1d/csrc/causal_conv1d.h DELETED
@@ -1,53 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- ////////////////////////////////////////////////////////////////////////////////////////////////////
8
-
9
- struct ConvParamsBase {
10
- using index_t = uint32_t;
11
-
12
- int batch, dim, seqlen, width;
13
- bool silu_activation;
14
-
15
- index_t x_batch_stride;
16
- index_t x_c_stride;
17
- index_t x_l_stride;
18
- index_t weight_c_stride;
19
- index_t weight_width_stride;
20
- index_t out_batch_stride;
21
- index_t out_c_stride;
22
- index_t out_l_stride;
23
-
24
- index_t conv_state_batch_stride;
25
- index_t conv_state_c_stride;
26
- index_t conv_state_l_stride;
27
-
28
- // Common data pointers.
29
- void *__restrict__ x_ptr;
30
- void *__restrict__ weight_ptr;
31
- void *__restrict__ bias_ptr;
32
- void *__restrict__ out_ptr;
33
-
34
- void *__restrict__ conv_state_ptr;
35
- };
36
-
37
- struct ConvParamsBwd: public ConvParamsBase {
38
- index_t dx_batch_stride;
39
- index_t dx_c_stride;
40
- index_t dx_l_stride;
41
- index_t dweight_c_stride;
42
- index_t dweight_width_stride;
43
- index_t dout_batch_stride;
44
- index_t dout_c_stride;
45
- index_t dout_l_stride;
46
-
47
- // Common data pointers.
48
- void *__restrict__ dx_ptr;
49
- void *__restrict__ dweight_ptr;
50
- void *__restrict__ dbias_ptr;
51
- void *__restrict__ dout_ptr;
52
- };
53
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
causal-conv1d/csrc/causal_conv1d_bwd.cu DELETED
@@ -1,525 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #include <c10/util/BFloat16.h>
6
- #include <c10/util/Half.h>
7
- #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
8
-
9
- #include <cub/block/block_load.cuh>
10
- #include <cub/block/block_store.cuh>
11
- #include <cub/block/block_reduce.cuh>
12
-
13
- #include "causal_conv1d.h"
14
- #include "causal_conv1d_common.h"
15
- #include "static_switch.h"
16
-
17
- template<int kNThreads_, int kWidth_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
18
- struct Causal_conv1d_bwd_kernel_traits {
19
- using input_t = input_t_;
20
- using weight_t = weight_t_;
21
- static constexpr int kNThreads = kNThreads_;
22
- static constexpr int kWidth = kWidth_;
23
- static constexpr bool kSiluAct = kSiluAct_;
24
- static constexpr int kNBytes = sizeof(input_t);
25
- static_assert(kNBytes == 2 || kNBytes == 4);
26
- static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
27
- static_assert(kWidth <= kNElts);
28
- // It's possible that we need to do 2 rounds of exchange if input_t is 16 bits
29
- // (since then we'd have 8 values of float, and each round we can exchange 4 floats).
30
- static constexpr int kNExchangeRounds = sizeof(float) / sizeof(input_t);
31
- static constexpr bool kIsVecLoad = kIsVecLoad_;
32
- using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
33
- using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
34
- using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
35
- using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
36
- using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
37
- using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
38
- static constexpr int kSmemIOSize = kIsVecLoad
39
- ? 0
40
- : std::max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
41
- static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts * (!kSiluAct ? 1 : kNExchangeRounds + 1);
42
- static constexpr int kSmemSize = std::max({kSmemExchangeSize,
43
- int(sizeof(typename BlockReduceFloatT::TempStorage))}) + (kIsVecLoad ? 0 : kSmemIOSize);
44
- };
45
-
46
- template<typename Ktraits>
47
- __global__ __launch_bounds__(Ktraits::kNThreads)
48
- void causal_conv1d_bwd_kernel(ConvParamsBwd params) {
49
- constexpr int kWidth = Ktraits::kWidth;
50
- constexpr int kNThreads = Ktraits::kNThreads;
51
- constexpr bool kSiluAct = Ktraits::kSiluAct;
52
- constexpr int kNElts = Ktraits::kNElts;
53
- constexpr int kNExchangeRounds = Ktraits::kNExchangeRounds;
54
- constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
55
- using input_t = typename Ktraits::input_t;
56
- using vec_t = typename Ktraits::vec_t;
57
- using weight_t = typename Ktraits::weight_t;
58
-
59
- // Shared memory.
60
- extern __shared__ char smem_[];
61
- auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
62
- auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
63
- auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
64
- auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
65
- vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
66
- vec_t *smem_exchange_x = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize) + kNThreads * kNExchangeRounds;
67
- auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
68
-
69
- const int tidx = threadIdx.x;
70
- const int batch_id = blockIdx.x;
71
- const int dim_id = blockIdx.y;
72
- input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
73
- + dim_id * params.x_c_stride;
74
- weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + dim_id * params.weight_c_stride;
75
- input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
76
- + dim_id * params.dout_c_stride;
77
- input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
78
- + dim_id * params.dx_c_stride;
79
- float *dweight = reinterpret_cast<float *>(params.dweight_ptr) + dim_id * params.dweight_c_stride;
80
- float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[dim_id]);
81
-
82
- // Thread kNThreads - 1 will load the first elements of the next chunk so we initialize those to 0.
83
- if (tidx == 0) {
84
- if constexpr (!kSiluAct) {
85
- input_t zeros[kNElts] = {0};
86
- smem_exchange[0] = reinterpret_cast<vec_t *>(zeros)[0];
87
- } else {
88
- float zeros[kNElts] = {0};
89
- #pragma unroll
90
- for (int r = 0; r < kNExchangeRounds; ++r) {
91
- smem_exchange[r * kNThreads] = reinterpret_cast<vec_t *>(zeros)[r];
92
- }
93
- }
94
- }
95
-
96
- float weight_vals[kWidth];
97
- #pragma unroll
98
- for (int i = 0; i < kWidth; ++i) { weight_vals[i] = weight[i * params.weight_width_stride]; }
99
-
100
- float dweight_vals[kWidth] = {0};
101
- float dbias_val = 0;
102
-
103
- constexpr int kChunkSize = kNThreads * kNElts;
104
- const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
105
- x += (n_chunks - 1) * kChunkSize;
106
- dout += (n_chunks - 1) * kChunkSize;
107
- dx += (n_chunks - 1) * kChunkSize;
108
- for (int chunk = n_chunks - 1; chunk >= 0; --chunk) {
109
- input_t x_vals_load[2 * kNElts] = {0};
110
- input_t dout_vals_load[2 * kNElts] = {0};
111
- if constexpr(kIsVecLoad) {
112
- Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
113
- Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(dout), *reinterpret_cast<vec_t (*)[1]>(&dout_vals_load[0]), (params.seqlen - chunk * kChunkSize) / kNElts);
114
- } else {
115
- __syncthreads();
116
- Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
117
- __syncthreads();
118
- Ktraits::BlockLoadT(smem_load).Load(dout, *reinterpret_cast<input_t (*)[kNElts]>(&dout_vals_load[0]), params.seqlen - chunk * kChunkSize);
119
- }
120
- float dout_vals[2 * kNElts], x_vals[2 * kNElts];
121
- if constexpr (!kSiluAct) {
122
- __syncthreads();
123
- // Thread 0 don't write yet, so that thread kNThreads - 1 can read
124
- // the first elements of the next chunk.
125
- if (tidx > 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
126
- __syncthreads();
127
- reinterpret_cast<vec_t *>(dout_vals_load)[1] = smem_exchange[tidx < kNThreads - 1 ? tidx + 1 : 0];
128
- __syncthreads();
129
- // Now thread 0 can write the first elements of the current chunk.
130
- if (tidx == 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
131
- #pragma unroll
132
- for (int i = 0; i < 2 * kNElts; ++i) {
133
- dout_vals[i] = float(dout_vals_load[i]);
134
- x_vals[i] = float(x_vals_load[i]);
135
- }
136
- } else {
137
- if (tidx == 0 && chunk > 0) {
138
- if constexpr(kIsVecLoad) {
139
- reinterpret_cast<vec_t *>(x_vals_load)[0] = reinterpret_cast<vec_t *>(x)[-1];
140
- } else {
141
- #pragma unroll
142
- for (int i = 0; i < kNElts; ++i) {
143
- if (chunk * kChunkSize + i < params.seqlen) { x_vals_load[i] = x[-kNElts + i]; }
144
- }
145
- }
146
- }
147
- __syncthreads();
148
- smem_exchange_x[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1];
149
- __syncthreads();
150
- if (tidx > 0) { reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange_x[tidx - 1]; }
151
- #pragma unroll
152
- for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
153
- // Recompute the output
154
- #pragma unroll
155
- for (int i = 0; i < kNElts; ++i) {
156
- float out_val = bias_val;
157
- #pragma unroll
158
- for (int w = 0; w < kWidth; ++w) {
159
- out_val += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
160
- }
161
- float out_sigmoid_val = 1.0f / (1.0f + expf(-out_val));
162
- dout_vals[i] = float(dout_vals_load[i]) * out_sigmoid_val
163
- * (1.0f + out_val * (1.0f - out_sigmoid_val));
164
- }
165
- // Exchange the dout_vals. It's possible that we need to do 2 rounds of exchange
166
- // if input_t is 16 bits (since then we'd have 8 values of float)
167
- __syncthreads();
168
- // Thread 0 don't write yet, so that thread kNThreads - 1 can read
169
- // the first elements of the next chunk.
170
- if (tidx > 0) {
171
- #pragma unroll
172
- for (int r = 0; r < kNExchangeRounds; ++r) {
173
- smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
174
- }
175
- }
176
- __syncthreads();
177
- #pragma unroll
178
- for (int r = 0; r < kNExchangeRounds; ++r) {
179
- reinterpret_cast<vec_t *>(dout_vals)[kNExchangeRounds + r]
180
- = smem_exchange[r * kNThreads + (tidx < kNThreads - 1 ? tidx + 1 : 0)];
181
- }
182
- __syncthreads();
183
- // Now thread 0 can write the first elements of the current chunk.
184
- if (tidx == 0) {
185
- #pragma unroll
186
- for (int r = 0; r < kNExchangeRounds; ++r) {
187
- smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
188
- }
189
- }
190
- }
191
- dout -= kChunkSize;
192
- x -= kChunkSize;
193
-
194
- #pragma unroll
195
- for (int i = 0; i < kNElts; ++i) { dbias_val += dout_vals[i]; }
196
-
197
- float dx_vals[kNElts] = {0};
198
- #pragma unroll
199
- for (int i = 0; i < kNElts; ++i) {
200
- #pragma unroll
201
- for (int w = 0; w < kWidth; ++w) {
202
- dx_vals[i] += weight_vals[w] * dout_vals[i + kWidth - w - 1];
203
- }
204
- }
205
-
206
- input_t dx_vals_store[kNElts];
207
- #pragma unroll
208
- for (int i = 0; i < kNElts; ++i) { dx_vals_store[i] = dx_vals[i]; }
209
- if constexpr(kIsVecLoad) {
210
- Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(dx), reinterpret_cast<vec_t (&)[1]>(dx_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
211
- } else {
212
- Ktraits::BlockStoreT(smem_store).Store(dx, dx_vals_store, params.seqlen - chunk * kChunkSize);
213
- }
214
- dx -= kChunkSize;
215
-
216
- #pragma unroll
217
- for (int w = 0; w < kWidth; ++w) {
218
- #pragma unroll
219
- for (int i = 0; i < kNElts; ++i) {
220
- dweight_vals[w] += x_vals[kNElts + i] * dout_vals[i + kWidth - w - 1];
221
- }
222
- }
223
- }
224
-
225
- #pragma unroll
226
- for (int w = 0; w < kWidth; ++w) {
227
- __syncthreads();
228
- dweight_vals[w] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dweight_vals[w]);
229
- if (tidx == 0) {
230
- atomicAdd(&reinterpret_cast<float *>(dweight)[w * params.dweight_width_stride], dweight_vals[w]);
231
- }
232
- }
233
- if (params.bias_ptr != nullptr) {
234
- __syncthreads();
235
- dbias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dbias_val);
236
- if (tidx == 0) {
237
- atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[dim_id], dbias_val);
238
- }
239
- }
240
- }
241
-
242
- template<int kNThreads, int kWidth, typename input_t, typename weight_t>
243
- void causal_conv1d_bwd_launch(ConvParamsBwd &params, cudaStream_t stream) {
244
- static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
245
- BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
246
- BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
247
- using Ktraits = Causal_conv1d_bwd_kernel_traits<kNThreads, kWidth, kSiluAct, kIsVecLoad, input_t, weight_t>;
248
- constexpr int kSmemSize = Ktraits::kSmemSize;
249
- dim3 grid(params.batch, params.dim);
250
- auto kernel = &causal_conv1d_bwd_kernel<Ktraits>;
251
- if (kSmemSize >= 48 * 1024) {
252
- C10_CUDA_CHECK(cudaFuncSetAttribute(
253
- kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
254
- }
255
- kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
256
- C10_CUDA_KERNEL_LAUNCH_CHECK();
257
- });
258
- });
259
- }
260
-
261
- template<typename input_t, typename weight_t>
262
- void causal_conv1d_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream) {
263
- if (params.width == 2) {
264
- causal_conv1d_bwd_launch<128, 2, input_t, weight_t>(params, stream);
265
- } else if (params.width == 3) {
266
- causal_conv1d_bwd_launch<128, 3, input_t, weight_t>(params, stream);
267
- } else if (params.width == 4) {
268
- causal_conv1d_bwd_launch<128, 4, input_t, weight_t>(params, stream);
269
- }
270
- }
271
-
272
- template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
273
- struct Causal_conv1d_channellast_bwd_kernel_traits {
274
- // The cache line is 128 bytes, and we try to read 16 bytes per thread.
275
- // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
276
- // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
277
- // threads). Each each load is 16 x 32|64 elements in the L x C dimensions.
278
- using input_t = input_t_;
279
- using weight_t = weight_t_;
280
- static constexpr bool kSiluAct = kSiluAct_;
281
- static constexpr int kNThreads = kNThreads_;
282
- static_assert(kNThreads % 32 == 0);
283
- static constexpr int kNWarps = kNThreads / 32;
284
- static constexpr int kWidth = kWidth_;
285
- static constexpr int kChunkSizeL = kChunkSizeL_;
286
- static constexpr int kNBytes = sizeof(input_t);
287
- static_assert(kNBytes == 2 || kNBytes == 4);
288
- static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
289
- static constexpr int kNEltsPerRow = 128 / kNBytes;
290
- static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
291
- static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
292
- static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
293
- static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
294
- static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
295
- static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
296
- static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
297
- static constexpr bool kIsVecLoad = kIsVecLoad_;
298
- using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
299
- // using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
300
- // using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
301
- // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
302
- // sizeof(typename BlockStoreT::TempStorage)});
303
- // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
304
- };
305
-
306
- template<typename Ktraits>
307
- __global__ __launch_bounds__(Ktraits::kNThreads)
308
- void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) {
309
- constexpr int kWidth = Ktraits::kWidth;
310
- constexpr int kNThreads = Ktraits::kNThreads;
311
- constexpr bool kSiluAct = Ktraits::kSiluAct;
312
- constexpr int kNElts = Ktraits::kNElts;
313
- constexpr int kNWarp = Ktraits::kNWarps;
314
- constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
315
- constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
316
- constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
317
- constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
318
- using input_t = typename Ktraits::input_t;
319
- using vec_t = typename Ktraits::vec_t;
320
- using weight_t = typename Ktraits::weight_t;
321
-
322
- // Shared memory.
323
- __shared__ input_t dout_smem[kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
324
- __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
325
-
326
- const int tid = threadIdx.x;
327
- const int l_idx = tid / kNThreadsPerC;
328
- const int c_idx = tid % kNThreadsPerC;
329
- const int batch_id = blockIdx.x;
330
- const int chunk_l_id = blockIdx.y;
331
- const int chunk_c_id = blockIdx.z;
332
- input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
333
- + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
334
- weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
335
- + chunk_c_id * kChunkSizeC * params.weight_c_stride;
336
- input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
337
- + (chunk_l_id * kChunkSizeL + l_idx) * params.dout_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
338
- input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
339
- + (chunk_l_id * kChunkSizeL + l_idx) * params.dx_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
340
- float *dweight = reinterpret_cast<float *>(params.dweight_ptr)
341
- + chunk_c_id * kChunkSizeC * params.dweight_c_stride;
342
-
343
- #pragma unroll
344
- for (int l = 0; l < Ktraits::kNLoads; ++l) {
345
- input_t dout_vals_load[kNElts] = {0};
346
- input_t x_vals_load[kNElts] = {0};
347
- if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
348
- && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
349
- reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + l * kLPerLoad * params.dout_l_stride);
350
- reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
351
- }
352
- reinterpret_cast<vec_t *>(dout_smem[l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
353
- reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
354
- }
355
- // Load the elements from the previous chunk or next chunk that are needed for convolution.
356
- if (l_idx < kWidth - 1) {
357
- input_t dout_vals_load[kNElts] = {0};
358
- input_t x_vals_load[kNElts] = {0};
359
- if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
360
- && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
361
- reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + kChunkSizeL * params.dout_l_stride);
362
- }
363
- if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
364
- && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
365
- && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
366
- reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
367
- }
368
- reinterpret_cast<vec_t *>(dout_smem[kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
369
- reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
370
- }
371
- // Need to load (kWdith - 1) extra x's on the right to recompute the (kChunkSizeL + kWidth - 1) outputs
372
- if constexpr (kSiluAct) {
373
- if (l_idx < kWidth - 1) {
374
- input_t x_vals_load[kNElts] = {0};
375
- if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
376
- && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
377
- reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + kChunkSizeL * params.x_l_stride);
378
- }
379
- reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
380
- }
381
- }
382
-
383
- __syncthreads();
384
-
385
- constexpr int kLPerThread = std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
386
- static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
387
- constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
388
- static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
389
- // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
390
- static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
391
- static_assert((kLPerThread & (kLPerThread - 1)) == 0);
392
- static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
393
- static_assert(kNThreadsPerRow <= 32);
394
-
395
- const int row_idx = tid / kNThreadsPerRow;
396
- const int col_idx = tid % kNThreadsPerRow;
397
-
398
- float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
399
- float weight_vals[kWidth] = {0};
400
- if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
401
- #pragma unroll
402
- for (int w = 0; w < kWidth; ++w) {
403
- weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
404
- }
405
- }
406
- float dout_vals[kLPerThread + kWidth - 1];
407
- float x_vals[kWidth - 1 + kLPerThread + kWidth - 1];
408
- #pragma unroll
409
- for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
410
- dout_vals[i] = float(dout_smem[col_idx * kLPerThread + i][row_idx]);
411
- x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
412
- }
413
-
414
- if constexpr (kSiluAct) { // Recompute the output
415
- #pragma unroll
416
- for (int i = kWidth - 1 + kLPerThread; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) {
417
- x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
418
- }
419
- #pragma unroll
420
- for (int i = 0; i < kLPerThread + kWidth - 1; ++i) {
421
- float out_val = bias_val;
422
- #pragma unroll
423
- for (int w = 0; w < kWidth; ++w) { out_val += weight_vals[w] * x_vals[i + w]; }
424
- float out_val_sigmoid = 1.f / (1.f + expf(-out_val));
425
- dout_vals[i] *= out_val_sigmoid * (1 + out_val * (1 - out_val_sigmoid));
426
- }
427
- }
428
-
429
- float dweight_vals[kWidth] = {0};
430
- SumOp<float> sum_op;
431
- #pragma unroll
432
- for (int w = 0; w < kWidth; ++w) {
433
- #pragma unroll
434
- for (int i = 0; i < kLPerThread; ++i) { dweight_vals[w] += x_vals[i + w] * dout_vals[i]; }
435
- dweight_vals[w] = Allreduce<kNThreadsPerRow>::run(dweight_vals[w], sum_op);
436
- if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
437
- atomicAdd(&reinterpret_cast<float *>(dweight)[row_idx * params.dweight_c_stride + w * params.dweight_width_stride], dweight_vals[w]);
438
- }
439
- }
440
-
441
- if (params.bias_ptr != nullptr) {
442
- float dbias_val = 0.f;
443
- for (int i = 0; i < kLPerThread; ++i) { dbias_val += dout_vals[i]; }
444
- dbias_val = Allreduce<kNThreadsPerRow>::run(dbias_val, sum_op);
445
- if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
446
- atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[chunk_c_id * kChunkSizeC + row_idx], dbias_val);
447
- }
448
- }
449
-
450
- float dx_vals[kLPerThread] = {0};
451
- #pragma unroll
452
- for (int i = 0; i < kLPerThread; ++i) {
453
- #pragma unroll
454
- for (int w = 0; w < kWidth; ++w) { dx_vals[i] += weight_vals[kWidth - 1 - w] * dout_vals[i + w]; }
455
- }
456
- // Since kNThreadsPerRow is a power of 2 and <= 32, we only need syncwarp and not syncthreads.
457
- __syncwarp();
458
- #pragma unroll
459
- for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = dx_vals[i]; }
460
- __syncthreads();
461
-
462
- #pragma unroll
463
- for (int l = 0; l < Ktraits::kNLoads; ++l) {
464
- input_t dx_vals_store[kNElts];
465
- reinterpret_cast<vec_t *>(dx_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
466
- if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
467
- && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
468
- *reinterpret_cast<vec_t *>(dx + l * kLPerLoad * params.dx_l_stride) = reinterpret_cast<vec_t *>(dx_vals_store)[0];
469
- }
470
- }
471
-
472
- }
473
-
474
- template<int kNThreads, int kWidth, typename input_t, typename weight_t>
475
- void causal_conv1d_channellast_bwd_launch(ConvParamsBwd &params, cudaStream_t stream) {
476
- BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
477
- using Ktraits = Causal_conv1d_channellast_bwd_kernel_traits<kNThreads, kWidth, 64, kSiluAct, true, input_t, weight_t>;
478
- // constexpr int kSmemSize = Ktraits::kSmemSize;
479
- constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
480
- constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
481
- const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
482
- const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
483
- dim3 grid(params.batch, n_chunks_L, n_chunks_C);
484
- dim3 block(Ktraits::kNThreads);
485
- auto kernel = &causal_conv1d_channellast_bwd_kernel<Ktraits>;
486
- // if (kSmemSize >= 48 * 1024) {
487
- // C10_CUDA_CHECK(cudaFuncSetAttribute(
488
- // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
489
- // }
490
- // kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
491
- kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
492
- C10_CUDA_KERNEL_LAUNCH_CHECK();
493
- });
494
- }
495
-
496
- template<typename input_t, typename weight_t>
497
- void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream) {
498
- if (params.width == 2) {
499
- causal_conv1d_channellast_bwd_launch<128, 2, input_t, weight_t>(params, stream);
500
- } else if (params.width == 3) {
501
- causal_conv1d_channellast_bwd_launch<128, 3, input_t, weight_t>(params, stream);
502
- } else if (params.width == 4) {
503
- causal_conv1d_channellast_bwd_launch<128, 4, input_t, weight_t>(params, stream);
504
- }
505
- }
506
-
507
- template void causal_conv1d_bwd_cuda<float, float>(ConvParamsBwd &params, cudaStream_t stream);
508
- template void causal_conv1d_bwd_cuda<at::Half, float>(ConvParamsBwd &params, cudaStream_t stream);
509
- template void causal_conv1d_bwd_cuda<at::BFloat16, float>(ConvParamsBwd &params, cudaStream_t stream);
510
- template void causal_conv1d_bwd_cuda<float, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
511
- template void causal_conv1d_bwd_cuda<at::Half, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
512
- template void causal_conv1d_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
513
- template void causal_conv1d_bwd_cuda<float, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
514
- template void causal_conv1d_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
515
- template void causal_conv1d_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
516
-
517
- template void causal_conv1d_channellast_bwd_cuda<float, float>(ConvParamsBwd &params, cudaStream_t stream);
518
- template void causal_conv1d_channellast_bwd_cuda<at::Half, float>(ConvParamsBwd &params, cudaStream_t stream);
519
- template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, float>(ConvParamsBwd &params, cudaStream_t stream);
520
- template void causal_conv1d_channellast_bwd_cuda<float, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
521
- template void causal_conv1d_channellast_bwd_cuda<at::Half, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
522
- template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
523
- template void causal_conv1d_channellast_bwd_cuda<float, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
524
- template void causal_conv1d_channellast_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
525
- template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
causal-conv1d/csrc/causal_conv1d_common.h DELETED
@@ -1,64 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include <cuda_bf16.h>
8
- #include <cuda_fp16.h>
9
-
10
- ////////////////////////////////////////////////////////////////////////////////////////////////////
11
-
12
- template<int BYTES> struct BytesToType {};
13
-
14
- template<> struct BytesToType<16> {
15
- using Type = uint4;
16
- static_assert(sizeof(Type) == 16);
17
- };
18
-
19
- template<> struct BytesToType<8> {
20
- using Type = uint64_t;
21
- static_assert(sizeof(Type) == 8);
22
- };
23
-
24
- template<> struct BytesToType<4> {
25
- using Type = uint32_t;
26
- static_assert(sizeof(Type) == 4);
27
- };
28
-
29
- template<> struct BytesToType<2> {
30
- using Type = uint16_t;
31
- static_assert(sizeof(Type) == 2);
32
- };
33
-
34
- template<> struct BytesToType<1> {
35
- using Type = uint8_t;
36
- static_assert(sizeof(Type) == 1);
37
- };
38
-
39
- ////////////////////////////////////////////////////////////////////////////////////////////////////
40
-
41
- template<typename T>
42
- struct SumOp {
43
- __device__ inline T operator()(T const & x, T const & y) { return x + y; }
44
- };
45
-
46
- template<int THREADS>
47
- struct Allreduce {
48
- static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
49
- template<typename T, typename Operator>
50
- static __device__ inline T run(T x, Operator &op) {
51
- constexpr int OFFSET = THREADS / 2;
52
- x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
53
- return Allreduce<OFFSET>::run(x, op);
54
- }
55
- };
56
-
57
- template<>
58
- struct Allreduce<2> {
59
- template<typename T, typename Operator>
60
- static __device__ inline T run(T x, Operator &op) {
61
- x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
62
- return x;
63
- }
64
- };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
causal-conv1d/csrc/causal_conv1d_fwd.cu DELETED
@@ -1,350 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #include <c10/util/BFloat16.h>
6
- #include <c10/util/Half.h>
7
- #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
8
-
9
- #include <cub/block/block_load.cuh>
10
- #include <cub/block/block_store.cuh>
11
-
12
- #include "causal_conv1d.h"
13
- #include "causal_conv1d_common.h"
14
- #include "static_switch.h"
15
-
16
- template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
17
- struct Causal_conv1d_fwd_kernel_traits {
18
- using input_t = input_t_;
19
- using weight_t = weight_t_;
20
- static constexpr int kNThreads = kNThreads_;
21
- static constexpr int kWidth = kWidth_;
22
- static constexpr int kNBytes = sizeof(input_t);
23
- static_assert(kNBytes == 2 || kNBytes == 4);
24
- static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
25
- static_assert(kWidth <= kNElts);
26
- static constexpr bool kIsVecLoad = kIsVecLoad_;
27
- using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
28
- using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
29
- using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
30
- using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
31
- using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
32
- static constexpr int kSmemIOSize = kIsVecLoad
33
- ? 0
34
- : std::max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
35
- static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
36
- static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
37
- };
38
-
39
- template<typename Ktraits>
40
- __global__ __launch_bounds__(Ktraits::kNThreads)
41
- void causal_conv1d_fwd_kernel(ConvParamsBase params) {
42
- constexpr int kWidth = Ktraits::kWidth;
43
- constexpr int kNThreads = Ktraits::kNThreads;
44
- constexpr int kNElts = Ktraits::kNElts;
45
- constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
46
- using input_t = typename Ktraits::input_t;
47
- using vec_t = typename Ktraits::vec_t;
48
- using weight_t = typename Ktraits::weight_t;
49
-
50
- // Shared memory.
51
- extern __shared__ char smem_[];
52
- auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
53
- auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
54
- auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
55
- auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
56
- vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
57
-
58
- const int tidx = threadIdx.x;
59
- const int batch_id = blockIdx.x;
60
- const int channel_id = blockIdx.y;
61
- input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
62
- + channel_id * params.x_c_stride;
63
- weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
64
- input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
65
- + channel_id * params.out_c_stride;
66
- float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
67
-
68
- // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
69
- if (tidx == 0) {
70
- input_t zeros[kNElts] = {0};
71
- smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0];
72
- }
73
-
74
- float weight_vals[kWidth];
75
- #pragma unroll
76
- for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
77
-
78
- constexpr int kChunkSize = kNThreads * kNElts;
79
- const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
80
- for (int chunk = 0; chunk < n_chunks; ++chunk) {
81
- input_t x_vals_load[2 * kNElts] = {0};
82
- if constexpr(kIsVecLoad) {
83
- Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
84
- } else {
85
- __syncthreads();
86
- Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
87
- }
88
- x += kChunkSize;
89
- __syncthreads();
90
- // Thread kNThreads - 1 don't write yet, so that thread 0 can read
91
- // the last elements of the previous chunk.
92
- if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
93
- __syncthreads();
94
- reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
95
- __syncthreads();
96
- // Now thread kNThreads - 1 can write the last elements of the current chunk.
97
- if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
98
-
99
- float x_vals[2 * kNElts];
100
- #pragma unroll
101
- for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
102
-
103
- float out_vals[kNElts];
104
- #pragma unroll
105
- for (int i = 0; i < kNElts; ++i) {
106
- out_vals[i] = bias_val;
107
- #pragma unroll
108
- for (int w = 0; w < kWidth; ++w) {
109
- out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
110
- }
111
- }
112
-
113
- if (params.silu_activation) {
114
- #pragma unroll
115
- for (int i = 0; i < kNElts; ++i) {
116
- out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
117
- }
118
- }
119
-
120
- input_t out_vals_store[kNElts];
121
- #pragma unroll
122
- for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
123
- if constexpr(kIsVecLoad) {
124
- Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
125
- } else {
126
- Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize);
127
- }
128
- out += kChunkSize;
129
- }
130
- }
131
-
132
- template<int kNThreads, int kWidth, typename input_t, typename weight_t>
133
- void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
134
- static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
135
- BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
136
- using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
137
- constexpr int kSmemSize = Ktraits::kSmemSize;
138
- dim3 grid(params.batch, params.dim);
139
- auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
140
- if (kSmemSize >= 48 * 1024) {
141
- C10_CUDA_CHECK(cudaFuncSetAttribute(
142
- kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
143
- }
144
- kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
145
- C10_CUDA_KERNEL_LAUNCH_CHECK();
146
- });
147
- }
148
-
149
- template<typename input_t, typename weight_t>
150
- void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
151
- if (params.width == 2) {
152
- causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
153
- } else if (params.width == 3) {
154
- causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
155
- } else if (params.width == 4) {
156
- causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
157
- }
158
- }
159
-
160
- template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
161
- struct Causal_conv1d_channellast_fwd_kernel_traits {
162
- // The cache line is 128 bytes, and we try to read 16 bytes per thread.
163
- // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
164
- // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
165
- // threads). Each each load is 16 x 32|64 elements in the L x C dimensions.
166
- using input_t = input_t_;
167
- using weight_t = weight_t_;
168
- static constexpr int kNThreads = kNThreads_;
169
- static_assert(kNThreads % 32 == 0);
170
- static constexpr int kNWarps = kNThreads / 32;
171
- static constexpr int kWidth = kWidth_;
172
- static constexpr int kChunkSizeL = kChunkSizeL_;
173
- static constexpr int kNBytes = sizeof(input_t);
174
- static_assert(kNBytes == 2 || kNBytes == 4);
175
- static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
176
- static constexpr int kNEltsPerRow = 128 / kNBytes;
177
- static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
178
- static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
179
- static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
180
- static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
181
- static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
182
- static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
183
- static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
184
- static constexpr bool kIsVecLoad = kIsVecLoad_;
185
- using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
186
- // using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
187
- // using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
188
- // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
189
- // sizeof(typename BlockStoreT::TempStorage)});
190
- // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
191
- };
192
-
193
- template<typename Ktraits>
194
- __global__ __launch_bounds__(Ktraits::kNThreads)
195
- void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
196
- constexpr int kWidth = Ktraits::kWidth;
197
- constexpr int kNThreads = Ktraits::kNThreads;
198
- constexpr int kNElts = Ktraits::kNElts;
199
- constexpr int kNWarp = Ktraits::kNWarps;
200
- constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
201
- constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
202
- constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
203
- constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
204
- using input_t = typename Ktraits::input_t;
205
- using vec_t = typename Ktraits::vec_t;
206
- using weight_t = typename Ktraits::weight_t;
207
-
208
- // Shared memory.
209
- __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];
210
-
211
- const int tid = threadIdx.x;
212
- const int l_idx = tid / kNThreadsPerC;
213
- const int c_idx = tid % kNThreadsPerC;
214
- const int batch_id = blockIdx.x;
215
- const int chunk_l_id = blockIdx.y;
216
- const int chunk_c_id = blockIdx.z;
217
- input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
218
- + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
219
- weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
220
- + chunk_c_id * kChunkSizeC * params.weight_c_stride;
221
- input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
222
- + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
223
-
224
- #pragma unroll
225
- for (int l = 0; l < Ktraits::kNLoads; ++l) {
226
- input_t x_vals_load[kNElts] = {0};
227
- if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
228
- && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
229
- reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
230
- }
231
- reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
232
- }
233
- // Load the elements from the previous chunk that are needed for convolution.
234
- if (l_idx < kWidth - 1) {
235
- input_t x_vals_load[kNElts] = {0};
236
- if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
237
- && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
238
- && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
239
- reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
240
- }
241
- reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
242
- }
243
-
244
- __syncthreads();
245
-
246
- constexpr int kLPerThread = std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
247
- static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
248
- constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
249
- static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
250
- // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
251
- static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
252
- static_assert((kLPerThread & (kLPerThread - 1)) == 0);
253
- static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
254
- static_assert(kNThreadsPerRow <= 32);
255
-
256
- const int row_idx = tid / kNThreadsPerRow;
257
- const int col_idx = tid % kNThreadsPerRow;
258
-
259
- float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
260
- float weight_vals[kWidth] = {0};
261
- if (chunk_c_id + kChunkSizeC + row_idx < params.dim) {
262
- #pragma unroll
263
- for (int w = 0; w < kWidth; ++w) {
264
- weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
265
- }
266
- }
267
- float x_vals[kWidth - 1 + kLPerThread];
268
- #pragma unroll
269
- for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
270
- x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
271
- }
272
-
273
- float out_vals[kLPerThread];
274
- #pragma unroll
275
- for (int i = 0; i < kLPerThread; ++i) {
276
- out_vals[i] = bias_val;
277
- #pragma unroll
278
- for (int w = 0; w < kWidth; ++w) { out_vals[i] += weight_vals[w] * x_vals[i + w]; }
279
- if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
280
- }
281
-
282
- // Since kNThreadsPerRow is a power of 2 and <= 32, we only need syncwarp and not syncthreads.
283
- __syncwarp();
284
- #pragma unroll
285
- for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
286
- __syncthreads();
287
-
288
- #pragma unroll
289
- for (int l = 0; l < Ktraits::kNLoads; ++l) {
290
- input_t out_vals_store[kNElts];
291
- reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
292
- if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
293
- && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
294
- *reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
295
- }
296
- }
297
-
298
- }
299
-
300
- template<int kNThreads, int kWidth, typename input_t, typename weight_t>
301
- void causal_conv1d_channellast_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
302
- using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
303
- // constexpr int kSmemSize = Ktraits::kSmemSize;
304
- constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
305
- constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
306
- const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
307
- const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
308
- // printf("n_chunks_L: %d, n_chunks_C: %d\n", n_chunks_L, n_chunks_C);
309
- dim3 grid(params.batch, n_chunks_L, n_chunks_C);
310
- dim3 block(Ktraits::kNThreads);
311
- auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits>;
312
- // if (kSmemSize >= 48 * 1024) {
313
- // C10_CUDA_CHECK(cudaFuncSetAttribute(
314
- // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
315
- // }
316
- // kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
317
- kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
318
- C10_CUDA_KERNEL_LAUNCH_CHECK();
319
- }
320
-
321
- template<typename input_t, typename weight_t>
322
- void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
323
- if (params.width == 2) {
324
- causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
325
- } else if (params.width == 3) {
326
- causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
327
- } else if (params.width == 4) {
328
- causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
329
- }
330
- }
331
-
332
- template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
333
- template void causal_conv1d_fwd_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
334
- template void causal_conv1d_fwd_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
335
- template void causal_conv1d_fwd_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
336
- template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
337
- template void causal_conv1d_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
338
- template void causal_conv1d_fwd_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
339
- template void causal_conv1d_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
340
- template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
341
-
342
- template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
343
- template void causal_conv1d_channellast_fwd_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
344
- template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
345
- template void causal_conv1d_channellast_fwd_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
346
- template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
347
- template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
348
- template void causal_conv1d_channellast_fwd_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
349
- template void causal_conv1d_channellast_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
350
- template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
causal-conv1d/csrc/causal_conv1d_update.cu DELETED
@@ -1,96 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #include <c10/util/BFloat16.h>
6
- #include <c10/util/Half.h>
7
- #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
8
-
9
- #include <cub/block/block_load.cuh>
10
- #include <cub/block/block_store.cuh>
11
-
12
- #include "causal_conv1d.h"
13
- #include "causal_conv1d_common.h"
14
- #include "static_switch.h"
15
-
16
- template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
17
- struct Causal_conv1d_update_kernel_traits {
18
- using input_t = input_t_;
19
- using weight_t = weight_t_;
20
- static constexpr int kNThreads = kNThreads_;
21
- static constexpr int kWidth = kWidth_;
22
- static constexpr int kNBytes = sizeof(input_t);
23
- static_assert(kNBytes == 2 || kNBytes == 4);
24
- };
25
-
26
- template<typename Ktraits>
27
- __global__ __launch_bounds__(Ktraits::kNThreads)
28
- void causal_conv1d_update_kernel(ConvParamsBase params) {
29
- constexpr int kWidth = Ktraits::kWidth;
30
- constexpr int kNThreads = Ktraits::kNThreads;
31
- using input_t = typename Ktraits::input_t;
32
- using weight_t = typename Ktraits::weight_t;
33
-
34
- const int tidx = threadIdx.x;
35
- const int batch_id = blockIdx.x;
36
- const int channel_id = blockIdx.y * kNThreads + tidx;
37
- input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
38
- + channel_id * params.x_c_stride;
39
- input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
40
- + channel_id * params.conv_state_c_stride;
41
- weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
42
- input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
43
- + channel_id * params.out_c_stride;
44
- float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
45
-
46
- float weight_vals[kWidth] = {0};
47
- if (channel_id < params.dim) {
48
- #pragma unroll
49
- for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
50
- }
51
-
52
- float x_vals[kWidth] = {0};
53
- if (channel_id < params.dim) {
54
- #pragma unroll
55
- for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); }
56
- x_vals[kWidth - 1] = float(x[0]);
57
- #pragma unroll
58
- for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); }
59
- }
60
-
61
- float out_val = bias_val;
62
- #pragma unroll
63
- for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; }
64
- if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
65
- if (channel_id < params.dim) { out[0] = input_t(out_val); }
66
- }
67
-
68
- template<int kNThreads, int kWidth, typename input_t, typename weight_t>
69
- void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
70
- using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
71
- dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
72
- auto kernel = &causal_conv1d_update_kernel<Ktraits>;
73
- kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
74
- C10_CUDA_KERNEL_LAUNCH_CHECK();
75
- }
76
-
77
- template<typename input_t, typename weight_t>
78
- void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
79
- if (params.width == 2) {
80
- causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
81
- } else if (params.width == 3) {
82
- causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
83
- } else if (params.width == 4) {
84
- causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
85
- }
86
- }
87
-
88
- template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
89
- template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
90
- template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
91
- template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
92
- template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
93
- template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
94
- template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
95
- template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
96
- template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
causal-conv1d/csrc/static_switch.h DELETED
@@ -1,25 +0,0 @@
1
- // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
2
- // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
3
-
4
- #pragma once
5
-
6
- /// @param COND - a boolean expression to switch by
7
- /// @param CONST_NAME - a name given for the constexpr bool variable.
8
- /// @param ... - code to execute for true and false
9
- ///
10
- /// Usage:
11
- /// ```
12
- /// BOOL_SWITCH(flag, BoolConst, [&] {
13
- /// some_function<BoolConst>(...);
14
- /// });
15
- /// ```
16
- #define BOOL_SWITCH(COND, CONST_NAME, ...) \
17
- [&] { \
18
- if (COND) { \
19
- static constexpr bool CONST_NAME = true; \
20
- return __VA_ARGS__(); \
21
- } else { \
22
- static constexpr bool CONST_NAME = false; \
23
- return __VA_ARGS__(); \
24
- } \
25
- }()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
causal-conv1d/setup.py DELETED
@@ -1,264 +0,0 @@
1
- # Copyright (c) 2023, Tri Dao.
2
- import sys
3
- import warnings
4
- import os
5
- import re
6
- import ast
7
- from pathlib import Path
8
- from packaging.version import parse, Version
9
- import platform
10
-
11
- from setuptools import setup, find_packages
12
- import subprocess
13
-
14
- import urllib.request
15
- import urllib.error
16
- from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
17
-
18
- import torch
19
- from torch.utils.cpp_extension import (
20
- BuildExtension,
21
- CppExtension,
22
- CUDAExtension,
23
- CUDA_HOME,
24
- )
25
-
26
-
27
- with open("README.md", "r", encoding="utf-8") as fh:
28
- long_description = fh.read()
29
-
30
-
31
- # ninja build does not work unless include_dirs are abs path
32
- this_dir = os.path.dirname(os.path.abspath(__file__))
33
-
34
- PACKAGE_NAME = "causal_conv1d"
35
-
36
- BASE_WHEEL_URL = "https://github.com/Dao-AILab/causal-conv1d/releases/download/{tag_name}/{wheel_name}"
37
-
38
- # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
39
- # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
40
- FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "TRUE"
41
- SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
42
- # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
43
- FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE"
44
-
45
-
46
- def get_platform():
47
- """
48
- Returns the platform name as used in wheel filenames.
49
- """
50
- if sys.platform.startswith("linux"):
51
- return "linux_x86_64"
52
- elif sys.platform == "darwin":
53
- mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
54
- return f"macosx_{mac_version}_x86_64"
55
- elif sys.platform == "win32":
56
- return "win_amd64"
57
- else:
58
- raise ValueError("Unsupported platform: {}".format(sys.platform))
59
-
60
-
61
- def get_cuda_bare_metal_version(cuda_dir):
62
- raw_output = subprocess.check_output(
63
- [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
64
- )
65
- output = raw_output.split()
66
- release_idx = output.index("release") + 1
67
- bare_metal_version = parse(output[release_idx].split(",")[0])
68
-
69
- return raw_output, bare_metal_version
70
-
71
-
72
- def check_if_cuda_home_none(global_option: str) -> None:
73
- if CUDA_HOME is not None:
74
- return
75
- # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
76
- # in that case.
77
- warnings.warn(
78
- f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
79
- "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
80
- "only images whose names contain 'devel' will provide nvcc."
81
- )
82
-
83
-
84
- def append_nvcc_threads(nvcc_extra_args):
85
- return nvcc_extra_args + ["--threads", "4"]
86
-
87
-
88
- cmdclass = {}
89
- ext_modules = []
90
-
91
- if not SKIP_CUDA_BUILD:
92
- print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
93
- TORCH_MAJOR = int(torch.__version__.split(".")[0])
94
- TORCH_MINOR = int(torch.__version__.split(".")[1])
95
-
96
- check_if_cuda_home_none("causal_conv1d")
97
- # Check, if CUDA11 is installed for compute capability 8.0
98
- cc_flag = []
99
- if CUDA_HOME is not None:
100
- _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
101
- if bare_metal_version < Version("11.6"):
102
- raise RuntimeError(
103
- "causal_conv1d is only supported on CUDA 11.6 and above. "
104
- "Note: make sure nvcc has a supported version by running nvcc -V."
105
- )
106
-
107
- cc_flag.append("-gencode")
108
- cc_flag.append("arch=compute_70,code=sm_70")
109
- cc_flag.append("-gencode")
110
- cc_flag.append("arch=compute_80,code=sm_80")
111
- if bare_metal_version >= Version("11.8"):
112
- cc_flag.append("-gencode")
113
- cc_flag.append("arch=compute_90,code=sm_90")
114
-
115
- # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
116
- # torch._C._GLIBCXX_USE_CXX11_ABI
117
- # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
118
- if FORCE_CXX11_ABI:
119
- torch._C._GLIBCXX_USE_CXX11_ABI = True
120
-
121
- ext_modules.append(
122
- CUDAExtension(
123
- name="causal_conv1d_cuda",
124
- sources=[
125
- "csrc/causal_conv1d.cpp",
126
- "csrc/causal_conv1d_fwd.cu",
127
- "csrc/causal_conv1d_bwd.cu",
128
- "csrc/causal_conv1d_update.cu",
129
- ],
130
- extra_compile_args={
131
- "cxx": ["-O3"],
132
- "nvcc": append_nvcc_threads(
133
- [
134
- "-O3",
135
- "-U__CUDA_NO_HALF_OPERATORS__",
136
- "-U__CUDA_NO_HALF_CONVERSIONS__",
137
- "-U__CUDA_NO_BFLOAT16_OPERATORS__",
138
- "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
139
- "-U__CUDA_NO_BFLOAT162_OPERATORS__",
140
- "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
141
- "--expt-relaxed-constexpr",
142
- "--expt-extended-lambda",
143
- "--use_fast_math",
144
- "--ptxas-options=-v",
145
- "-lineinfo",
146
- ]
147
- + cc_flag
148
- ),
149
- },
150
- include_dirs=[this_dir],
151
- )
152
- )
153
-
154
-
155
- def get_package_version():
156
- with open(Path(this_dir) / "causal_conv1d" / "__init__.py", "r") as f:
157
- version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
158
- public_version = ast.literal_eval(version_match.group(1))
159
- local_version = os.environ.get("CAUSAL_CONV1D_LOCAL_VERSION")
160
- if local_version:
161
- return f"{public_version}+{local_version}"
162
- else:
163
- return str(public_version)
164
-
165
-
166
- def get_wheel_url():
167
- # Determine the version numbers that will be used to determine the correct wheel
168
- # We're using the CUDA version used to build torch, not the one currently installed
169
- # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
170
- torch_cuda_version = parse(torch.version.cuda)
171
- torch_version_raw = parse(torch.__version__)
172
- # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
173
- # to save CI time. Minor versions should be compatible.
174
- torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
175
- python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
176
- platform_name = get_platform()
177
- causal_conv1d_version = get_package_version()
178
- # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
179
- cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
180
- torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
181
- cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
182
-
183
- # Determine wheel URL based on CUDA version, torch version, python version and OS
184
- wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
185
- wheel_url = BASE_WHEEL_URL.format(
186
- tag_name=f"v{causal_conv1d_version}", wheel_name=wheel_filename
187
- )
188
- return wheel_url, wheel_filename
189
-
190
-
191
- class CachedWheelsCommand(_bdist_wheel):
192
- """
193
- The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
194
- find an existing wheel (which is currently the case for all installs). We use
195
- the environment parameters to detect whether there is already a pre-built version of a compatible
196
- wheel available and short-circuits the standard full build pipeline.
197
- """
198
-
199
- def run(self):
200
- if FORCE_BUILD:
201
- return super().run()
202
-
203
- wheel_url, wheel_filename = get_wheel_url()
204
- print("Guessing wheel URL: ", wheel_url)
205
- try:
206
- urllib.request.urlretrieve(wheel_url, wheel_filename)
207
-
208
- # Make the archive
209
- # Lifted from the root wheel processing command
210
- # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
211
- if not os.path.exists(self.dist_dir):
212
- os.makedirs(self.dist_dir)
213
-
214
- impl_tag, abi_tag, plat_tag = self.get_tag()
215
- archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
216
-
217
- wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
218
- print("Raw wheel path", wheel_path)
219
- os.rename(wheel_filename, wheel_path)
220
- except urllib.error.HTTPError:
221
- print("Precompiled wheel not found. Building from source...")
222
- # If the wheel could not be downloaded, build from source
223
- super().run()
224
-
225
-
226
- setup(
227
- name=PACKAGE_NAME,
228
- version=get_package_version(),
229
- packages=find_packages(
230
- exclude=(
231
- "build",
232
- "csrc",
233
- "include",
234
- "tests",
235
- "dist",
236
- "docs",
237
- "benchmarks",
238
- "causal_conv1d.egg-info",
239
- )
240
- ),
241
- author="Tri Dao",
242
- author_email="tri@tridao.me",
243
- description="Causal depthwise conv1d in CUDA, with a PyTorch interface",
244
- long_description=long_description,
245
- long_description_content_type="text/markdown",
246
- url="https://github.com/Dao-AILab/causal-conv1d",
247
- classifiers=[
248
- "Programming Language :: Python :: 3",
249
- "License :: OSI Approved :: BSD License",
250
- "Operating System :: Unix",
251
- ],
252
- ext_modules=ext_modules,
253
- cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
254
- if ext_modules
255
- else {
256
- "bdist_wheel": CachedWheelsCommand,
257
- },
258
- python_requires=">=3.7",
259
- install_requires=[
260
- "torch",
261
- "packaging",
262
- "ninja",
263
- ],
264
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
causal-conv1d/tests/test_causal_conv1d.py DELETED
@@ -1,173 +0,0 @@
1
- # Copyright (C) 2023, Tri Dao.
2
-
3
- import math
4
-
5
- import torch
6
- import pytest
7
-
8
- from einops import rearrange
9
-
10
- from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_ref
11
- from causal_conv1d.causal_conv1d_interface import causal_conv1d_update, causal_conv1d_update_ref
12
-
13
-
14
- @pytest.mark.parametrize("channel_last", [False, True])
15
- # @pytest.mark.parametrize('channel_last', [True])
16
- @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
17
- # @pytest.mark.parametrize('itype', [torch.float16])
18
- @pytest.mark.parametrize("silu_activation", [False, True])
19
- # @pytest.mark.parametrize('silu_activation', [True])
20
- @pytest.mark.parametrize("has_bias", [False, True])
21
- # @pytest.mark.parametrize('has_bias', [True])
22
- @pytest.mark.parametrize("width", [2, 3, 4])
23
- # @pytest.mark.parametrize('width', [2])
24
- @pytest.mark.parametrize(
25
- "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
26
- )
27
- # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
28
- # @pytest.mark.parametrize('seqlen', [128])
29
- def test_causal_conv1d(seqlen, width, has_bias, silu_activation, itype, channel_last):
30
- device = "cuda"
31
- rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
32
- if itype == torch.bfloat16:
33
- rtol, atol = 1e-2, 5e-2
34
- rtolw, atolw = (1e-3, 1e-3)
35
- # set seed
36
- torch.random.manual_seed(0)
37
- batch_size = 2
38
- # batch_size = 1
39
- dim = 4096 + 32 # Try dim not divisible by 64
40
- # dim = 64
41
- if not channel_last:
42
- x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
43
- else:
44
- x = rearrange(
45
- torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
46
- ).requires_grad_()
47
- weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
48
- if has_bias:
49
- bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
50
- else:
51
- bias = None
52
- x_ref = x.detach().clone().requires_grad_()
53
- weight_ref = weight.detach().clone().requires_grad_()
54
- bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
55
- activation = None if not silu_activation else "silu"
56
- out = causal_conv1d_fn(x, weight, bias, activation=activation)
57
- out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, activation=activation)
58
-
59
- print(f"Output max diff: {(out - out_ref).abs().max().item()}")
60
- print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
61
- assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
62
-
63
- g = torch.randn_like(out)
64
- out_ref.backward(g)
65
- out.backward(g)
66
-
67
- print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
68
- print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
69
- if has_bias:
70
- print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")
71
-
72
- assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
73
- assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)
74
- if has_bias:
75
- assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)
76
-
77
-
78
- @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
79
- # @pytest.mark.parametrize('itype', [torch.float16])
80
- @pytest.mark.parametrize("silu_activation", [False, True])
81
- # @pytest.mark.parametrize('silu_activation', [False])
82
- @pytest.mark.parametrize("has_bias", [False, True])
83
- # @pytest.mark.parametrize('has_bias', [True])
84
- @pytest.mark.parametrize("width", [2, 3, 4])
85
- # @pytest.mark.parametrize('width', [2])
86
- @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
87
- # @pytest.mark.parametrize("dim", [2048])
88
- def test_causal_conv1d_update(dim, width, has_bias, silu_activation, itype):
89
- device = "cuda"
90
- rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
91
- if itype == torch.bfloat16:
92
- rtol, atol = 1e-2, 5e-2
93
- rtolw, atolw = (1e-3, 1e-3)
94
- # set seed
95
- torch.random.manual_seed(0)
96
- batch_size = 2
97
- # batch_size = 1
98
- # dim = 64
99
- x = torch.randn(batch_size, dim, device=device, dtype=itype)
100
- conv_state = torch.randn(batch_size, dim, width, device=device, dtype=itype)
101
- weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
102
- if has_bias:
103
- bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
104
- else:
105
- bias = None
106
- conv_state_ref = conv_state.detach().clone()
107
- activation = None if not silu_activation else "silu"
108
- out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
109
- out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation)
110
-
111
- print(f"Output max diff: {(out - out_ref).abs().max().item()}")
112
- print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
113
- assert torch.equal(conv_state, conv_state_ref)
114
- assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
115
-
116
-
117
- # @pytest.mark.parametrize("channel_last", [False, True])
118
- @pytest.mark.parametrize('channel_last', [True])
119
- # @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
120
- @pytest.mark.parametrize('itype', [torch.bfloat16])
121
- # @pytest.mark.parametrize("silu_activation", [False, True])
122
- @pytest.mark.parametrize('silu_activation', [True])
123
- # @pytest.mark.parametrize("has_bias", [False, True])
124
- @pytest.mark.parametrize('has_bias', [True])
125
- # @pytest.mark.parametrize("width", [2, 3, 4])
126
- @pytest.mark.parametrize('width', [4])
127
- @pytest.mark.parametrize(
128
- # "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
129
- "seqlen", [2048]
130
- )
131
- # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
132
- # @pytest.mark.parametrize('seqlen', [128])
133
- def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last):
134
- device = "cuda"
135
- # set seed
136
- torch.random.manual_seed(0)
137
- batch_size = 2
138
- # batch_size = 1
139
- dim = 4096 + 32 # Try dim not divisible by 64
140
- # dim = 64
141
- if not channel_last:
142
- x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
143
- else:
144
- x = rearrange(
145
- torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
146
- ).requires_grad_()
147
- weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
148
- if has_bias:
149
- bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
150
- else:
151
- bias = None
152
- activation = None if not silu_activation else "silu"
153
- out0 = causal_conv1d_fn(x, weight, bias, activation=activation)
154
- g = torch.randn_like(out0)
155
- dx0, dw0, db0 = torch.autograd.grad(out0, (x, weight, bias), g)
156
- dw_atol = 1e-4
157
- db_atol = 1e-4
158
-
159
- for i in range(10000):
160
- out = causal_conv1d_fn(x, weight, bias, activation=activation)
161
- dx, dw, db = torch.autograd.grad(out, (x, weight, bias), g)
162
- dw_equal = torch.allclose(dw, dw0, atol=dw_atol)
163
- # if not dw_equal:
164
- # breakpoint()
165
- if has_bias:
166
- db_equal = torch.allclose(db, db0, atol=db_atol)
167
- # if not db_equal:
168
- # breakpoint()
169
- assert torch.equal(out, out0)
170
- assert torch.equal(dx, dx0)
171
- assert dw_equal
172
- if has_bias:
173
- assert dw_equal
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
install.sh DELETED
@@ -1,2 +0,0 @@
1
- pip install -e causal-conv1d
2
- pip install -e mamba
 
 
 
mamba/.gitmodules DELETED
@@ -1,3 +0,0 @@
1
- [submodule "3rdparty/lm-evaluation-harness"]
2
- path = 3rdparty/lm-evaluation-harness
3
- url = https://github.com/EleutherAI/lm-evaluation-harness/
 
 
 
 
mamba/AUTHORS DELETED
@@ -1,2 +0,0 @@
1
- Tri Dao, tri@tridao.me
2
- Albert Gu, agu@andrew.cmu.edu
 
 
 
mamba/LICENSE DELETED
@@ -1,201 +0,0 @@
1
- Apache License
2
- Version 2.0, January 2004
3
- http://www.apache.org/licenses/
4
-
5
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
-
7
- 1. Definitions.
8
-
9
- "License" shall mean the terms and conditions for use, reproduction,
10
- and distribution as defined by Sections 1 through 9 of this document.
11
-
12
- "Licensor" shall mean the copyright owner or entity authorized by
13
- the copyright owner that is granting the License.
14
-
15
- "Legal Entity" shall mean the union of the acting entity and all
16
- other entities that control, are controlled by, or are under common
17
- control with that entity. For the purposes of this definition,
18
- "control" means (i) the power, direct or indirect, to cause the
19
- direction or management of such entity, whether by contract or
20
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
- outstanding shares, or (iii) beneficial ownership of such entity.
22
-
23
- "You" (or "Your") shall mean an individual or Legal Entity
24
- exercising permissions granted by this License.
25
-
26
- "Source" form shall mean the preferred form for making modifications,
27
- including but not limited to software source code, documentation
28
- source, and configuration files.
29
-
30
- "Object" form shall mean any form resulting from mechanical
31
- transformation or translation of a Source form, including but
32
- not limited to compiled object code, generated documentation,
33
- and conversions to other media types.
34
-
35
- "Work" shall mean the work of authorship, whether in Source or
36
- Object form, made available under the License, as indicated by a
37
- copyright notice that is included in or attached to the work
38
- (an example is provided in the Appendix below).
39
-
40
- "Derivative Works" shall mean any work, whether in Source or Object
41
- form, that is based on (or derived from) the Work and for which the
42
- editorial revisions, annotations, elaborations, or other modifications
43
- represent, as a whole, an original work of authorship. For the purposes
44
- of this License, Derivative Works shall not include works that remain
45
- separable from, or merely link (or bind by name) to the interfaces of,
46
- the Work and Derivative Works thereof.
47
-
48
- "Contribution" shall mean any work of authorship, including
49
- the original version of the Work and any modifications or additions
50
- to that Work or Derivative Works thereof, that is intentionally
51
- submitted to Licensor for inclusion in the Work by the copyright owner
52
- or by an individual or Legal Entity authorized to submit on behalf of
53
- the copyright owner. For the purposes of this definition, "submitted"
54
- means any form of electronic, verbal, or written communication sent
55
- to the Licensor or its representatives, including but not limited to
56
- communication on electronic mailing lists, source code control systems,
57
- and issue tracking systems that are managed by, or on behalf of, the
58
- Licensor for the purpose of discussing and improving the Work, but
59
- excluding communication that is conspicuously marked or otherwise
60
- designated in writing by the copyright owner as "Not a Contribution."
61
-
62
- "Contributor" shall mean Licensor and any individual or Legal Entity
63
- on behalf of whom a Contribution has been received by Licensor and
64
- subsequently incorporated within the Work.
65
-
66
- 2. Grant of Copyright License. Subject to the terms and conditions of
67
- this License, each Contributor hereby grants to You a perpetual,
68
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
- copyright license to reproduce, prepare Derivative Works of,
70
- publicly display, publicly perform, sublicense, and distribute the
71
- Work and such Derivative Works in Source or Object form.
72
-
73
- 3. Grant of Patent License. Subject to the terms and conditions of
74
- this License, each Contributor hereby grants to You a perpetual,
75
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
- (except as stated in this section) patent license to make, have made,
77
- use, offer to sell, sell, import, and otherwise transfer the Work,
78
- where such license applies only to those patent claims licensable
79
- by such Contributor that are necessarily infringed by their
80
- Contribution(s) alone or by combination of their Contribution(s)
81
- with the Work to which such Contribution(s) was submitted. If You
82
- institute patent litigation against any entity (including a
83
- cross-claim or counterclaim in a lawsuit) alleging that the Work
84
- or a Contribution incorporated within the Work constitutes direct
85
- or contributory patent infringement, then any patent licenses
86
- granted to You under this License for that Work shall terminate
87
- as of the date such litigation is filed.
88
-
89
- 4. Redistribution. You may reproduce and distribute copies of the
90
- Work or Derivative Works thereof in any medium, with or without
91
- modifications, and in Source or Object form, provided that You
92
- meet the following conditions:
93
-
94
- (a) You must give any other recipients of the Work or
95
- Derivative Works a copy of this License; and
96
-
97
- (b) You must cause any modified files to carry prominent notices
98
- stating that You changed the files; and
99
-
100
- (c) You must retain, in the Source form of any Derivative Works
101
- that You distribute, all copyright, patent, trademark, and
102
- attribution notices from the Source form of the Work,
103
- excluding those notices that do not pertain to any part of
104
- the Derivative Works; and
105
-
106
- (d) If the Work includes a "NOTICE" text file as part of its
107
- distribution, then any Derivative Works that You distribute must
108
- include a readable copy of the attribution notices contained
109
- within such NOTICE file, excluding those notices that do not
110
- pertain to any part of the Derivative Works, in at least one
111
- of the following places: within a NOTICE text file distributed
112
- as part of the Derivative Works; within the Source form or
113
- documentation, if provided along with the Derivative Works; or,
114
- within a display generated by the Derivative Works, if and
115
- wherever such third-party notices normally appear. The contents
116
- of the NOTICE file are for informational purposes only and
117
- do not modify the License. You may add Your own attribution
118
- notices within Derivative Works that You distribute, alongside
119
- or as an addendum to the NOTICE text from the Work, provided
120
- that such additional attribution notices cannot be construed
121
- as modifying the License.
122
-
123
- You may add Your own copyright statement to Your modifications and
124
- may provide additional or different license terms and conditions
125
- for use, reproduction, or distribution of Your modifications, or
126
- for any such Derivative Works as a whole, provided Your use,
127
- reproduction, and distribution of the Work otherwise complies with
128
- the conditions stated in this License.
129
-
130
- 5. Submission of Contributions. Unless You explicitly state otherwise,
131
- any Contribution intentionally submitted for inclusion in the Work
132
- by You to the Licensor shall be under the terms and conditions of
133
- this License, without any additional terms or conditions.
134
- Notwithstanding the above, nothing herein shall supersede or modify
135
- the terms of any separate license agreement you may have executed
136
- with Licensor regarding such Contributions.
137
-
138
- 6. Trademarks. This License does not grant permission to use the trade
139
- names, trademarks, service marks, or product names of the Licensor,
140
- except as required for reasonable and customary use in describing the
141
- origin of the Work and reproducing the content of the NOTICE file.
142
-
143
- 7. Disclaimer of Warranty. Unless required by applicable law or
144
- agreed to in writing, Licensor provides the Work (and each
145
- Contributor provides its Contributions) on an "AS IS" BASIS,
146
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
- implied, including, without limitation, any warranties or conditions
148
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
- PARTICULAR PURPOSE. You are solely responsible for determining the
150
- appropriateness of using or redistributing the Work and assume any
151
- risks associated with Your exercise of permissions under this License.
152
-
153
- 8. Limitation of Liability. In no event and under no legal theory,
154
- whether in tort (including negligence), contract, or otherwise,
155
- unless required by applicable law (such as deliberate and grossly
156
- negligent acts) or agreed to in writing, shall any Contributor be
157
- liable to You for damages, including any direct, indirect, special,
158
- incidental, or consequential damages of any character arising as a
159
- result of this License or out of the use or inability to use the
160
- Work (including but not limited to damages for loss of goodwill,
161
- work stoppage, computer failure or malfunction, or any and all
162
- other commercial damages or losses), even if such Contributor
163
- has been advised of the possibility of such damages.
164
-
165
- 9. Accepting Warranty or Additional Liability. While redistributing
166
- the Work or Derivative Works thereof, You may choose to offer,
167
- and charge a fee for, acceptance of support, warranty, indemnity,
168
- or other liability obligations and/or rights consistent with this
169
- License. However, in accepting such obligations, You may act only
170
- on Your own behalf and on Your sole responsibility, not on behalf
171
- of any other Contributor, and only if You agree to indemnify,
172
- defend, and hold each Contributor harmless for any liability
173
- incurred by, or claims asserted against, such Contributor by reason
174
- of your accepting any such warranty or additional liability.
175
-
176
- END OF TERMS AND CONDITIONS
177
-
178
- APPENDIX: How to apply the Apache License to your work.
179
-
180
- To apply the Apache License to your work, attach the following
181
- boilerplate notice, with the fields enclosed by brackets "[]"
182
- replaced with your own identifying information. (Don't include
183
- the brackets!) The text should be enclosed in the appropriate
184
- comment syntax for the file format. We also recommend that a
185
- file or class name and description of purpose be included on the
186
- same "printed page" as the copyright notice for easier
187
- identification within third-party archives.
188
-
189
- Copyright 2023 Tri Dao, Albert Gu
190
-
191
- Licensed under the Apache License, Version 2.0 (the "License");
192
- you may not use this file except in compliance with the License.
193
- You may obtain a copy of the License at
194
-
195
- http://www.apache.org/licenses/LICENSE-2.0
196
-
197
- Unless required by applicable law or agreed to in writing, software
198
- distributed under the License is distributed on an "AS IS" BASIS,
199
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
- See the License for the specific language governing permissions and
201
- limitations under the License.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mamba/README.md DELETED
@@ -1,149 +0,0 @@
1
- # Mamba
2
-
3
- ![Mamba](assets/selection.png "Selective State Space")
4
- > **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\
5
- > Albert Gu*, Tri Dao*\
6
- > Paper: https://arxiv.org/abs/2312.00752
7
-
8
- ## About
9
-
10
- Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers.
11
- It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4),
12
- with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
13
-
14
- ## Installation
15
-
16
- - `pip install causal-conv1d`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
17
- - `pip install mamba-ssm`: the core Mamba package.
18
-
19
- It can also be built from source with `pip install .` from this repository.
20
-
21
- If `pip` complains about PyTorch versions, try passing `--no-build-isolation` to `pip`.
22
-
23
- Other requirements:
24
- - Linux
25
- - NVIDIA GPU
26
- - PyTorch 1.12+
27
- - CUDA 11.6+
28
-
29
- ## Usage
30
-
31
- We expose several levels of interface with the Mamba model.
32
-
33
- ### Selective SSM
34
-
35
- Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2).
36
-
37
- Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py).
38
-
39
- ### Mamba Block
40
-
41
- The main module of this repository is the Mamba architecture block wrapping the selective SSM.
42
-
43
- Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py).
44
-
45
- Usage:
46
- ```
47
- from mamba_ssm import Mamba
48
-
49
- batch, length, dim = 2, 64, 16
50
- x = torch.randn(batch, length, dim).to("cuda")
51
- model = Mamba(
52
- # This module uses roughly 3 * expand * d_model^2 parameters
53
- d_model=dim, # Model dimension d_model
54
- d_state=16, # SSM state expansion factor
55
- d_conv=4, # Local convolution width
56
- expand=2, # Block expansion factor
57
- ).to("cuda")
58
- y = model(x)
59
- assert y.shape == x.shape
60
- ```
61
-
62
- ### Mamba Language Model
63
-
64
- Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head.
65
-
66
- Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py).
67
-
68
- This is an example of how to integrate Mamba into an end-to-end neural network.
69
- This example is used in the generation scripts below.
70
-
71
-
72
-
73
- ## Pretrained Models
74
-
75
- Pretrained models are uploaded to
76
- [HuggingFace](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`,
77
- `mamba-790m`, `mamba-1.4b`, `mamba-2.8b`.
78
-
79
- The models will be autodownloaded by the generation script below.
80
-
81
- These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and follow the standard model dimensions described by GPT-3 and followed by many open source models:
82
-
83
- | Parameters | Layers | Model dim. |
84
- |------------|--------|------------|
85
- | 130M | 12 | 768 |
86
- | 370M | 24 | 1024 |
87
- | 790M | 24 | 1536 |
88
- | 1.4B | 24 | 2048 |
89
- | 2.8B | 32 | 2560 |
90
-
91
- (The layer count of Mamba should be doubled, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.)
92
-
93
- Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.).
94
- Performance is expected to be comparable or better than other architectures trained on similar data, but not to match larger or fine-tuned models.
95
-
96
-
97
- ## Evaluations
98
-
99
- To run zero-shot evaluations of models (corresponding to Table 3 of the paper),
100
- we use the
101
- [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor)
102
- library.
103
-
104
- 1. Pull the `lm-evaluation-harness` repo by `git submodule update --init
105
- --recursive`. We use the `big-refactor` branch.
106
- 2. Install `lm-evaluation-harness`: `pip install -e 3rdparty/lm-evaluation-harness`
107
- 3. Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo):
108
- ```
109
- python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
110
- python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
111
- ```
112
-
113
- Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process.
114
-
115
- ## Inference
116
-
117
- The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py)
118
- 1. autoloads a model from the HuggingFace Hub,
119
- 2. generates completions of a user-specified prompt,
120
- 3. benchmarks the inference speed of this generation.
121
-
122
- Other configurable options include the top-p (nucleus sampling) probability, and the softmax temperature.
123
-
124
- ### Examples
125
-
126
- To test generation latency (e.g. batch size = 1) with different sampling strategies:
127
-
128
- ```
129
- python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
130
- python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
131
- ```
132
-
133
- To test generation throughput with random prompts (e.g. large batch size):
134
- ```
135
- python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128
136
- python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128
137
- ```
138
-
139
- ## Citation
140
-
141
- If you use this codebase, or otherwise found our work valuable, please cite Mamba:
142
- ```
143
- @article{mamba,
144
- title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
145
- author={Gu, Albert and Dao, Tri},
146
- journal={arXiv preprint arXiv:2312.00752},
147
- year={2023}
148
- }
149
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mamba/assets/selection.png DELETED
Binary file (819 kB)
 
mamba/benchmarks/benchmark_generation_mamba_simple.py DELETED
@@ -1,88 +0,0 @@
1
- # Copyright (c) 2023, Tri Dao, Albert Gu.
2
-
3
- import argparse
4
- import time
5
- import json
6
-
7
- import torch
8
- import torch.nn.functional as F
9
-
10
- from einops import rearrange
11
-
12
- from transformers import AutoTokenizer, AutoModelForCausalLM
13
-
14
- from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
15
-
16
-
17
- parser = argparse.ArgumentParser(description="Generation benchmarking")
18
- parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m")
19
- parser.add_argument("--prompt", type=str, default=None)
20
- parser.add_argument("--promptlen", type=int, default=100)
21
- parser.add_argument("--genlen", type=int, default=100)
22
- parser.add_argument("--temperature", type=float, default=1.0)
23
- parser.add_argument("--topk", type=int, default=1)
24
- parser.add_argument("--topp", type=float, default=1.0)
25
- parser.add_argument("--batch", type=int, default=1)
26
- args = parser.parse_args()
27
-
28
- repeats = 3
29
- device = "cuda"
30
- dtype = torch.float16
31
-
32
- print(f"Loading model {args.model_name}")
33
- is_mamba = args.model_name.startswith("state-spaces/mamba-") or "mamba" in args.model_name
34
-
35
- if is_mamba:
36
- tokenizer = AutoTokenizer.from_pretrained("/home/zhulianghui/VisionProjects/mamba/ckpts/gpt-neox-20b-tokenizer")
37
- model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype)
38
- else:
39
- tokenizer = AutoTokenizer.from_pretrained(args.model_name)
40
- model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype)
41
- model.eval()
42
- print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
43
-
44
- torch.random.manual_seed(0)
45
- if args.prompt is None:
46
- input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda")
47
- attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda")
48
- else:
49
- tokens = tokenizer(args.prompt, return_tensors="pt")
50
- input_ids = tokens.input_ids.to(device=device)
51
- attn_mask = tokens.attention_mask.to(device=device)
52
- max_length = input_ids.shape[1] + args.genlen
53
-
54
- if is_mamba:
55
- fn = lambda: model.generate(
56
- input_ids=input_ids,
57
- max_length=max_length,
58
- cg=True,
59
- return_dict_in_generate=True,
60
- output_scores=True,
61
- enable_timing=False,
62
- temperature=args.temperature,
63
- top_k=args.topk,
64
- top_p=args.topp,
65
- )
66
- else:
67
- fn = lambda: model.generate(
68
- input_ids=input_ids,
69
- attention_mask=attn_mask,
70
- max_length=max_length,
71
- return_dict_in_generate=True,
72
- pad_token_id=tokenizer.eos_token_id,
73
- do_sample=True,
74
- temperature=args.temperature,
75
- top_k=args.topk,
76
- top_p=args.topp,
77
- )
78
- out = fn()
79
- if args.prompt is not None:
80
- print(tokenizer.batch_decode(out.sequences.tolist()))
81
-
82
- torch.cuda.synchronize()
83
- start = time.time()
84
- for _ in range(repeats):
85
- fn()
86
- torch.cuda.synchronize()
87
- print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}")
88
- print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mamba/csrc/selective_scan/reverse_scan.cuh DELETED
@@ -1,401 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include <cub/config.cuh>
8
-
9
- #include <cub/util_ptx.cuh>
10
- #include <cub/util_type.cuh>
11
- #include <cub/block/block_raking_layout.cuh>
12
- // #include <cub/detail/uninitialized_copy.cuh>
13
- #include "uninitialized_copy.cuh"
14
-
15
- /**
16
- * Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned.
17
- */
18
- template <
19
- int LENGTH,
20
- typename T,
21
- typename ReductionOp>
22
- __device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) {
23
- static_assert(LENGTH > 0);
24
- T retval = input[LENGTH - 1];
25
- #pragma unroll
26
- for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); }
27
- return retval;
28
- }
29
-
30
- /**
31
- * Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
32
- */
33
- template <
34
- int LENGTH,
35
- typename T,
36
- typename ScanOp>
37
- __device__ __forceinline__ T ThreadReverseScanInclusive(
38
- const T (&input)[LENGTH],
39
- T (&output)[LENGTH],
40
- ScanOp scan_op,
41
- const T postfix)
42
- {
43
- T inclusive = postfix;
44
- #pragma unroll
45
- for (int i = LENGTH - 1; i >= 0; --i) {
46
- inclusive = scan_op(inclusive, input[i]);
47
- output[i] = inclusive;
48
- }
49
- }
50
-
51
- /**
52
- * Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
53
- */
54
- template <
55
- int LENGTH,
56
- typename T,
57
- typename ScanOp>
58
- __device__ __forceinline__ T ThreadReverseScanExclusive(
59
- const T (&input)[LENGTH],
60
- T (&output)[LENGTH],
61
- ScanOp scan_op,
62
- const T postfix)
63
- {
64
- // Careful, output maybe be aliased to input
65
- T exclusive = postfix;
66
- T inclusive;
67
- #pragma unroll
68
- for (int i = LENGTH - 1; i >= 0; --i) {
69
- inclusive = scan_op(exclusive, input[i]);
70
- output[i] = exclusive;
71
- exclusive = inclusive;
72
- }
73
- return inclusive;
74
- }
75
-
76
-
77
- /**
78
- * \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp.
79
- *
80
- * LOGICAL_WARP_THREADS must be a power-of-two
81
- */
82
- template <
83
- typename T, ///< Data type being scanned
84
- int LOGICAL_WARP_THREADS ///< Number of threads per logical warp
85
- >
86
- struct WarpReverseScan {
87
- //---------------------------------------------------------------------
88
- // Constants and type definitions
89
- //---------------------------------------------------------------------
90
-
91
- /// Whether the logical warp size and the PTX warp size coincide
92
- static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0));
93
- /// The number of warp scan steps
94
- static constexpr int STEPS = cub::Log2<LOGICAL_WARP_THREADS>::VALUE;
95
- static_assert(LOGICAL_WARP_THREADS == 1 << STEPS);
96
-
97
-
98
- //---------------------------------------------------------------------
99
- // Thread fields
100
- //---------------------------------------------------------------------
101
-
102
- /// Lane index in logical warp
103
- unsigned int lane_id;
104
-
105
- /// Logical warp index in 32-thread physical warp
106
- unsigned int warp_id;
107
-
108
- /// 32-thread physical warp member mask of logical warp
109
- unsigned int member_mask;
110
-
111
- //---------------------------------------------------------------------
112
- // Construction
113
- //---------------------------------------------------------------------
114
-
115
- /// Constructor
116
- explicit __device__ __forceinline__
117
- WarpReverseScan()
118
- : lane_id(cub::LaneId())
119
- , warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS))
120
- , member_mask(cub::WarpMask<LOGICAL_WARP_THREADS>(warp_id))
121
- {
122
- if (!IS_ARCH_WARP) {
123
- lane_id = lane_id % LOGICAL_WARP_THREADS;
124
- }
125
- }
126
-
127
-
128
- /// Broadcast
129
- __device__ __forceinline__ T Broadcast(
130
- T input, ///< [in] The value to broadcast
131
- int src_lane) ///< [in] Which warp lane is to do the broadcasting
132
- {
133
- return cub::ShuffleIndex<LOGICAL_WARP_THREADS>(input, src_lane, member_mask);
134
- }
135
-
136
-
137
- /// Inclusive scan
138
- template <typename ScanOpT>
139
- __device__ __forceinline__ void InclusiveReverseScan(
140
- T input, ///< [in] Calling thread's input item.
141
- T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input.
142
- ScanOpT scan_op) ///< [in] Binary scan operator
143
- {
144
- inclusive_output = input;
145
- #pragma unroll
146
- for (int STEP = 0; STEP < STEPS; STEP++) {
147
- int offset = 1 << STEP;
148
- T temp = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
149
- inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask
150
- );
151
- // Perform scan op if from a valid peer
152
- inclusive_output = static_cast<int>(lane_id) >= LOGICAL_WARP_THREADS - offset
153
- ? inclusive_output : scan_op(temp, inclusive_output);
154
- }
155
- }
156
-
157
- /// Exclusive scan
158
- // Get exclusive from inclusive
159
- template <typename ScanOpT>
160
- __device__ __forceinline__ void ExclusiveReverseScan(
161
- T input, ///< [in] Calling thread's input item.
162
- T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input.
163
- ScanOpT scan_op, ///< [in] Binary scan operator
164
- T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items.
165
- {
166
- T inclusive_output;
167
- InclusiveReverseScan(input, inclusive_output, scan_op);
168
- warp_aggregate = cub::ShuffleIndex<LOGICAL_WARP_THREADS>(inclusive_output, 0, member_mask);
169
- // initial value unknown
170
- exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
171
- inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
172
- );
173
- }
174
-
175
- /**
176
- * \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last <em>warp-lane</em> is undefined.
177
- */
178
- template <typename ScanOpT>
179
- __device__ __forceinline__ void ReverseScan(
180
- T input, ///< [in] Calling thread's input item.
181
- T &inclusive_output, ///< [out] Calling thread's inclusive-scan output item.
182
- T &exclusive_output, ///< [out] Calling thread's exclusive-scan output item.
183
- ScanOpT scan_op) ///< [in] Binary scan operator
184
- {
185
- InclusiveReverseScan(input, inclusive_output, scan_op);
186
- // initial value unknown
187
- exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
188
- inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
189
- );
190
- }
191
-
192
- };
193
-
194
- /**
195
- * \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block.
196
- */
197
- template <
198
- typename T, ///< Data type being scanned
199
- int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension
200
- bool MEMOIZE=false ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure
201
- >
202
- struct BlockReverseScan {
203
- //---------------------------------------------------------------------
204
- // Types and constants
205
- //---------------------------------------------------------------------
206
-
207
- /// Constants
208
- /// The thread block size in threads
209
- static constexpr int BLOCK_THREADS = BLOCK_DIM_X;
210
-
211
- /// Layout type for padded thread block raking grid
212
- using BlockRakingLayout = cub::BlockRakingLayout<T, BLOCK_THREADS>;
213
- // The number of reduction elements is not a multiple of the number of raking threads for now
214
- static_assert(BlockRakingLayout::UNGUARDED);
215
-
216
- /// Number of raking threads
217
- static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS;
218
- /// Number of raking elements per warp synchronous raking thread
219
- static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH;
220
- /// Cooperative work can be entirely warp synchronous
221
- static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS));
222
-
223
- /// WarpReverseScan utility type
224
- using WarpReverseScan = WarpReverseScan<T, RAKING_THREADS>;
225
-
226
- /// Shared memory storage layout type
227
- struct _TempStorage {
228
- typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid
229
- };
230
-
231
-
232
- /// Alias wrapper allowing storage to be unioned
233
- struct TempStorage : cub::Uninitialized<_TempStorage> {};
234
-
235
-
236
- //---------------------------------------------------------------------
237
- // Per-thread fields
238
- //---------------------------------------------------------------------
239
-
240
- // Thread fields
241
- _TempStorage &temp_storage;
242
- unsigned int linear_tid;
243
- T cached_segment[SEGMENT_LENGTH];
244
-
245
-
246
- //---------------------------------------------------------------------
247
- // Utility methods
248
- //---------------------------------------------------------------------
249
-
250
- /// Performs upsweep raking reduction, returning the aggregate
251
- template <typename ScanOp>
252
- __device__ __forceinline__ T Upsweep(ScanOp scan_op) {
253
- T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
254
- // Read data into registers
255
- #pragma unroll
256
- for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
257
- T raking_partial = cached_segment[SEGMENT_LENGTH - 1];
258
- #pragma unroll
259
- for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) {
260
- raking_partial = scan_op(raking_partial, cached_segment[i]);
261
- }
262
- return raking_partial;
263
- }
264
-
265
-
266
- /// Performs exclusive downsweep raking scan
267
- template <typename ScanOp>
268
- __device__ __forceinline__ void ExclusiveDownsweep(
269
- ScanOp scan_op,
270
- T raking_partial)
271
- {
272
- T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
273
- // Read data back into registers
274
- if (!MEMOIZE) {
275
- #pragma unroll
276
- for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
277
- }
278
- ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial);
279
- // Write data back to smem
280
- #pragma unroll
281
- for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; }
282
- }
283
-
284
-
285
- //---------------------------------------------------------------------
286
- // Constructors
287
- //---------------------------------------------------------------------
288
-
289
- /// Constructor
290
- __device__ __forceinline__ BlockReverseScan(
291
- TempStorage &temp_storage)
292
- :
293
- temp_storage(temp_storage.Alias()),
294
- linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1))
295
- {}
296
-
297
-
298
- /// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
299
- template <
300
- typename ScanOp,
301
- typename BlockPostfixCallbackOp>
302
- __device__ __forceinline__ void ExclusiveReverseScan(
303
- T input, ///< [in] Calling thread's input item
304
- T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input)
305
- ScanOp scan_op, ///< [in] Binary scan operator
306
- BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a thread block-wide postfix to be applied to all inputs.
307
- {
308
- if (WARP_SYNCHRONOUS) {
309
- // Short-circuit directly to warp-synchronous scan
310
- T block_aggregate;
311
- WarpReverseScan warp_scan;
312
- warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate);
313
- // Obtain warp-wide postfix in lane0, then broadcast to other lanes
314
- T block_postfix = block_postfix_callback_op(block_aggregate);
315
- block_postfix = warp_scan.Broadcast(block_postfix, 0);
316
- exclusive_output = linear_tid == BLOCK_THREADS - 1 ? block_postfix : scan_op(block_postfix, exclusive_output);
317
- } else {
318
- // Place thread partial into shared memory raking grid
319
- T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid);
320
- detail::uninitialized_copy(placement_ptr, input);
321
- cub::CTA_SYNC();
322
- // Reduce parallelism down to just raking threads
323
- if (linear_tid < RAKING_THREADS) {
324
- WarpReverseScan warp_scan;
325
- // Raking upsweep reduction across shared partials
326
- T upsweep_partial = Upsweep(scan_op);
327
- // Warp-synchronous scan
328
- T exclusive_partial, block_aggregate;
329
- warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate);
330
- // Obtain block-wide postfix in lane0, then broadcast to other lanes
331
- T block_postfix = block_postfix_callback_op(block_aggregate);
332
- block_postfix = warp_scan.Broadcast(block_postfix, 0);
333
- // Update postfix with warpscan exclusive partial
334
- T downsweep_postfix = linear_tid == RAKING_THREADS - 1
335
- ? block_postfix : scan_op(block_postfix, exclusive_partial);
336
- // Exclusive raking downsweep scan
337
- ExclusiveDownsweep(scan_op, downsweep_postfix);
338
- }
339
- cub::CTA_SYNC();
340
- // Grab thread postfix from shared memory
341
- exclusive_output = *placement_ptr;
342
-
343
- // // Compute warp scan in each warp.
344
- // // The exclusive output from the last lane in each warp is invalid.
345
- // T inclusive_output;
346
- // WarpReverseScan warp_scan;
347
- // warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op);
348
-
349
- // // Compute the warp-wide postfix and block-wide aggregate for each warp. Warp postfix for the last warp is invalid.
350
- // T block_aggregate;
351
- // T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate);
352
-
353
- // // Apply warp postfix to our lane's partial
354
- // if (warp_id != 0) {
355
- // exclusive_output = scan_op(warp_postfix, exclusive_output);
356
- // if (lane_id == 0) { exclusive_output = warp_postfix; }
357
- // }
358
-
359
- // // Use the first warp to determine the thread block postfix, returning the result in lane0
360
- // if (warp_id == 0) {
361
- // T block_postfix = block_postfix_callback_op(block_aggregate);
362
- // if (lane_id == 0) {
363
- // // Share the postfix with all threads
364
- // detail::uninitialized_copy(&temp_storage.block_postfix,
365
- // block_postfix);
366
-
367
- // exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0
368
- // }
369
- // }
370
-
371
- // cub::CTA_SYNC();
372
-
373
- // // Incorporate thread block postfix into outputs
374
- // T block_postfix = temp_storage.block_postfix;
375
- // if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); }
376
- }
377
- }
378
-
379
-
380
- /**
381
- * \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
382
- */
383
- template <
384
- int ITEMS_PER_THREAD,
385
- typename ScanOp,
386
- typename BlockPostfixCallbackOp>
387
- __device__ __forceinline__ void InclusiveReverseScan(
388
- T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items
389
- T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input)
390
- ScanOp scan_op, ///< [in] Binary scan functor
391
- BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence.
392
- {
393
- // Reduce consecutive thread items in registers
394
- T thread_postfix = ThreadReverseReduce(input, scan_op);
395
- // Exclusive thread block-scan
396
- ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op);
397
- // Inclusive scan in registers with postfix as seed
398
- ThreadReverseScanInclusive(input, output, scan_op, thread_postfix);
399
- }
400
-
401
- };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mamba/csrc/selective_scan/selective_scan.cpp DELETED
@@ -1,497 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #include <ATen/cuda/CUDAContext.h>
6
- #include <c10/cuda/CUDAGuard.h>
7
- #include <torch/extension.h>
8
- #include <vector>
9
-
10
- #include "selective_scan.h"
11
-
12
- #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
13
-
14
- #define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
15
- if (ITYPE == at::ScalarType::Half) { \
16
- using input_t = at::Half; \
17
- __VA_ARGS__(); \
18
- } else if (ITYPE == at::ScalarType::BFloat16) { \
19
- using input_t = at::BFloat16; \
20
- __VA_ARGS__(); \
21
- } else if (ITYPE == at::ScalarType::Float) { \
22
- using input_t = float; \
23
- __VA_ARGS__(); \
24
- } else { \
25
- AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
26
- }
27
-
28
- #define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
29
- if (WTYPE == at::ScalarType::Half) { \
30
- using weight_t = at::Half; \
31
- __VA_ARGS__(); \
32
- } else if (WTYPE == at::ScalarType::BFloat16) { \
33
- using weight_t = at::BFloat16; \
34
- __VA_ARGS__(); \
35
- } else if (WTYPE == at::ScalarType::Float) { \
36
- using weight_t = float; \
37
- __VA_ARGS__(); \
38
- } else { \
39
- AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
40
- }
41
-
42
- #define DISPATCH_WTYPE_FLOAT_AND_COMPLEX(WTYPE, NAME, ...) \
43
- if (WTYPE == at::ScalarType::Float) { \
44
- using weight_t = float; \
45
- __VA_ARGS__(); \
46
- } else if (WTYPE == at::ScalarType::ComplexFloat) { \
47
- using weight_t = c10::complex<float>; \
48
- __VA_ARGS__(); \
49
- } else { \
50
- AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
51
- }
52
-
53
- template<typename input_t, typename weight_t>
54
- void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);
55
-
56
- template <typename input_t, typename weight_t>
57
- void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream);
58
-
59
- void set_ssm_params_fwd(SSMParamsBase &params,
60
- // sizes
61
- const size_t batch,
62
- const size_t dim,
63
- const size_t seqlen,
64
- const size_t dstate,
65
- const size_t n_groups,
66
- const size_t n_chunks,
67
- const bool is_variable_B,
68
- const bool is_variable_C,
69
- // device pointers
70
- const at::Tensor u,
71
- const at::Tensor delta,
72
- const at::Tensor A,
73
- const at::Tensor B,
74
- const at::Tensor C,
75
- const at::Tensor out,
76
- const at::Tensor z,
77
- const at::Tensor out_z,
78
- void* D_ptr,
79
- void* delta_bias_ptr,
80
- void* x_ptr,
81
- bool has_z,
82
- bool delta_softplus) {
83
-
84
- // Reset the parameters
85
- memset(&params, 0, sizeof(params));
86
-
87
- params.batch = batch;
88
- params.dim = dim;
89
- params.seqlen = seqlen;
90
- params.dstate = dstate;
91
- params.n_groups = n_groups;
92
- params.n_chunks = n_chunks;
93
- params.dim_ngroups_ratio = dim / n_groups;
94
-
95
- params.delta_softplus = delta_softplus;
96
-
97
- params.is_variable_B = is_variable_B;
98
- params.is_variable_C = is_variable_C;
99
-
100
- // Set the pointers and strides.
101
- params.u_ptr = u.data_ptr();
102
- params.delta_ptr = delta.data_ptr();
103
- params.A_ptr = A.data_ptr();
104
- params.B_ptr = B.data_ptr();
105
- params.C_ptr = C.data_ptr();
106
- params.D_ptr = D_ptr;
107
- params.delta_bias_ptr = delta_bias_ptr;
108
- params.out_ptr = out.data_ptr();
109
- params.x_ptr = x_ptr;
110
- params.z_ptr = has_z ? z.data_ptr() : nullptr;
111
- params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
112
- // All stride are in elements, not bytes.
113
- params.A_d_stride = A.stride(0);
114
- params.A_dstate_stride = A.stride(1);
115
- if (!is_variable_B) {
116
- params.B_d_stride = B.stride(0);
117
- } else {
118
- params.B_batch_stride = B.stride(0);
119
- params.B_group_stride = B.stride(1);
120
- }
121
- params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
122
- if (!is_variable_C) {
123
- params.C_d_stride = C.stride(0);
124
- } else {
125
- params.C_batch_stride = C.stride(0);
126
- params.C_group_stride = C.stride(1);
127
- }
128
- params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
129
- params.u_batch_stride = u.stride(0);
130
- params.u_d_stride = u.stride(1);
131
- params.delta_batch_stride = delta.stride(0);
132
- params.delta_d_stride = delta.stride(1);
133
- if (has_z) {
134
- params.z_batch_stride = z.stride(0);
135
- params.z_d_stride = z.stride(1);
136
- params.out_z_batch_stride = out_z.stride(0);
137
- params.out_z_d_stride = out_z.stride(1);
138
- }
139
- params.out_batch_stride = out.stride(0);
140
- params.out_d_stride = out.stride(1);
141
- }
142
-
143
- void set_ssm_params_bwd(SSMParamsBwd &params,
144
- // sizes
145
- const size_t batch,
146
- const size_t dim,
147
- const size_t seqlen,
148
- const size_t dstate,
149
- const size_t n_groups,
150
- const size_t n_chunks,
151
- const bool is_variable_B,
152
- const bool is_variable_C,
153
- // device pointers
154
- const at::Tensor u,
155
- const at::Tensor delta,
156
- const at::Tensor A,
157
- const at::Tensor B,
158
- const at::Tensor C,
159
- const at::Tensor z,
160
- const at::Tensor out,
161
- const at::Tensor out_z,
162
- void* D_ptr,
163
- void* delta_bias_ptr,
164
- void* x_ptr,
165
- const at::Tensor dout,
166
- const at::Tensor du,
167
- const at::Tensor ddelta,
168
- const at::Tensor dA,
169
- const at::Tensor dB,
170
- const at::Tensor dC,
171
- const at::Tensor dz,
172
- void* dD_ptr,
173
- void* ddelta_bias_ptr,
174
- bool has_z,
175
- bool delta_softplus,
176
- bool recompute_out_z) {
177
- // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z
178
- set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
179
- u, delta, A, B, C, has_z ? out : dout,
180
- has_z ? z : dout,
181
- // If not recompute_out_z, pass dout instead of out_z.
182
- // This won't be used by the bwd kernel
183
- recompute_out_z ? out_z : dout,
184
- D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus);
185
- if (!recompute_out_z) { params.out_z_ptr = nullptr; }
186
-
187
- // Set the pointers and strides.
188
- params.dout_ptr = dout.data_ptr();
189
- params.du_ptr = du.data_ptr();
190
- params.dA_ptr = dA.data_ptr();
191
- params.dB_ptr = dB.data_ptr();
192
- params.dC_ptr = dC.data_ptr();
193
- params.dD_ptr = dD_ptr;
194
- params.ddelta_ptr = ddelta.data_ptr();
195
- params.ddelta_bias_ptr = ddelta_bias_ptr;
196
- params.dz_ptr = has_z ? dz.data_ptr() : nullptr;
197
- // All stride are in elements, not bytes.
198
- params.dout_batch_stride = dout.stride(0);
199
- params.dout_d_stride = dout.stride(1);
200
- params.dA_d_stride = dA.stride(0);
201
- params.dA_dstate_stride = dA.stride(1);
202
- if (!is_variable_B) {
203
- params.dB_d_stride = dB.stride(0);
204
- } else {
205
- params.dB_batch_stride = dB.stride(0);
206
- params.dB_group_stride = dB.stride(1);
207
- }
208
- params.dB_dstate_stride = !is_variable_B ? dB.stride(1) : dB.stride(2);
209
- if (!is_variable_C) {
210
- params.dC_d_stride = dC.stride(0);
211
- } else {
212
- params.dC_batch_stride = dC.stride(0);
213
- params.dC_group_stride = dC.stride(1);
214
- }
215
- params.dC_dstate_stride = !is_variable_C ? dC.stride(1) : dC.stride(2);
216
- params.du_batch_stride = du.stride(0);
217
- params.du_d_stride = du.stride(1);
218
- params.ddelta_batch_stride = ddelta.stride(0);
219
- params.ddelta_d_stride = ddelta.stride(1);
220
- if (has_z) {
221
- params.dz_batch_stride = dz.stride(0);
222
- params.dz_d_stride = dz.stride(1);
223
- }
224
- }
225
-
226
- std::vector<at::Tensor>
227
- selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
228
- const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
229
- const c10::optional<at::Tensor> &D_,
230
- const c10::optional<at::Tensor> &z_,
231
- const c10::optional<at::Tensor> &delta_bias_,
232
- bool delta_softplus) {
233
- auto input_type = u.scalar_type();
234
- auto weight_type = A.scalar_type();
235
- TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
236
- TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);
237
-
238
- const bool is_variable_B = B.dim() >= 3;
239
- const bool is_variable_C = C.dim() >= 3;
240
- const bool is_complex = weight_type == at::ScalarType::ComplexFloat;
241
-
242
- TORCH_CHECK(delta.scalar_type() == input_type);
243
- TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
244
- TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
245
-
246
- TORCH_CHECK(u.is_cuda());
247
- TORCH_CHECK(delta.is_cuda());
248
- TORCH_CHECK(A.is_cuda());
249
- TORCH_CHECK(B.is_cuda());
250
- TORCH_CHECK(C.is_cuda());
251
-
252
- TORCH_CHECK(u.stride(-1) == 1);
253
- TORCH_CHECK(delta.stride(-1) == 1);
254
-
255
- const auto sizes = u.sizes();
256
- const int batch_size = sizes[0];
257
- const int dim = sizes[1];
258
- const int seqlen = sizes[2];
259
- const int dstate = A.size(1);
260
- const int n_groups = is_variable_B ? B.size(1) : 1;
261
-
262
- TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
263
-
264
- CHECK_SHAPE(u, batch_size, dim, seqlen);
265
- CHECK_SHAPE(delta, batch_size, dim, seqlen);
266
- CHECK_SHAPE(A, dim, dstate);
267
- if (!is_variable_B) {
268
- CHECK_SHAPE(B, dim, dstate);
269
- } else {
270
- CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
271
- TORCH_CHECK(B.stride(-1) == 1);
272
- }
273
- if (!is_variable_C) {
274
- CHECK_SHAPE(C, dim, dstate);
275
- } else {
276
- CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);
277
- TORCH_CHECK(C.stride(-1) == 1);
278
- }
279
-
280
- if (D_.has_value()) {
281
- auto D = D_.value();
282
- TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
283
- TORCH_CHECK(D.is_cuda());
284
- TORCH_CHECK(D.stride(-1) == 1);
285
- CHECK_SHAPE(D, dim);
286
- }
287
-
288
- if (delta_bias_.has_value()) {
289
- auto delta_bias = delta_bias_.value();
290
- TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
291
- TORCH_CHECK(delta_bias.is_cuda());
292
- TORCH_CHECK(delta_bias.stride(-1) == 1);
293
- CHECK_SHAPE(delta_bias, dim);
294
- }
295
-
296
- at::Tensor z, out_z;
297
- const bool has_z = z_.has_value();
298
- if (has_z) {
299
- z = z_.value();
300
- TORCH_CHECK(z.scalar_type() == input_type);
301
- TORCH_CHECK(z.is_cuda());
302
- TORCH_CHECK(z.stride(-1) == 1);
303
- CHECK_SHAPE(z, batch_size, dim, seqlen);
304
- out_z = torch::empty_like(z);
305
- }
306
-
307
- const int n_chunks = (seqlen + 2048 - 1) / 2048;
308
- // const int n_chunks = (seqlen + 1024 - 1) / 1024;
309
- // at::Tensor out = torch::empty_like(u);
310
- // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
311
- at::Tensor out = torch::empty_like(delta);
312
- at::Tensor x;
313
- x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));
314
-
315
- SSMParamsBase params;
316
- set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
317
- u, delta, A, B, C, out, z, out_z,
318
- D_.has_value() ? D_.value().data_ptr() : nullptr,
319
- delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
320
- x.data_ptr(),
321
- has_z,
322
- delta_softplus);
323
-
324
- // Otherwise the kernel will be launched from cuda:0 device
325
- // Cast to char to avoid compiler warning about narrowing
326
- at::cuda::CUDAGuard device_guard{(char)u.get_device()};
327
- auto stream = at::cuda::getCurrentCUDAStream().stream();
328
- DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
329
- DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_fwd", [&] {
330
- selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
331
- });
332
- });
333
- std::vector<at::Tensor> result = {out, x};
334
- if (has_z) { result.push_back(out_z); }
335
- return result;
336
- }
337
-
338
- std::vector<at::Tensor>
339
- selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
340
- const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
341
- const c10::optional<at::Tensor> &D_,
342
- const c10::optional<at::Tensor> &z_,
343
- const c10::optional<at::Tensor> &delta_bias_,
344
- const at::Tensor &dout,
345
- const c10::optional<at::Tensor> &x_,
346
- const c10::optional<at::Tensor> &out_,
347
- c10::optional<at::Tensor> &dz_,
348
- bool delta_softplus,
349
- bool recompute_out_z) {
350
- auto input_type = u.scalar_type();
351
- auto weight_type = A.scalar_type();
352
- TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
353
- TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);
354
-
355
- const bool is_variable_B = B.dim() >= 3;
356
- const bool is_variable_C = C.dim() >= 3;
357
- const bool is_complex = weight_type == at::ScalarType::ComplexFloat;
358
-
359
- TORCH_CHECK(delta.scalar_type() == input_type);
360
- TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
361
- TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
362
- TORCH_CHECK(dout.scalar_type() == input_type);
363
-
364
- TORCH_CHECK(u.is_cuda());
365
- TORCH_CHECK(delta.is_cuda());
366
- TORCH_CHECK(A.is_cuda());
367
- TORCH_CHECK(B.is_cuda());
368
- TORCH_CHECK(C.is_cuda());
369
- TORCH_CHECK(dout.is_cuda());
370
-
371
- TORCH_CHECK(u.stride(-1) == 1);
372
- TORCH_CHECK(delta.stride(-1) == 1);
373
- TORCH_CHECK(dout.stride(-1) == 1);
374
-
375
- const auto sizes = u.sizes();
376
- const int batch_size = sizes[0];
377
- const int dim = sizes[1];
378
- const int seqlen = sizes[2];
379
- const int dstate = A.size(1);
380
- const int n_groups = is_variable_B ? B.size(1) : 1;
381
-
382
- TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
383
-
384
- CHECK_SHAPE(u, batch_size, dim, seqlen);
385
- CHECK_SHAPE(delta, batch_size, dim, seqlen);
386
- CHECK_SHAPE(A, dim, dstate);
387
- if (!is_variable_B) {
388
- CHECK_SHAPE(B, dim, dstate);
389
- } else {
390
- CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
391
- TORCH_CHECK(B.stride(-1) == 1);
392
- }
393
- if (!is_variable_C) {
394
- CHECK_SHAPE(C, dim, dstate);
395
- } else {
396
- CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);
397
- TORCH_CHECK(C.stride(-1) == 1);
398
- }
399
- CHECK_SHAPE(dout, batch_size, dim, seqlen);
400
-
401
- if (D_.has_value()) {
402
- auto D = D_.value();
403
- TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
404
- TORCH_CHECK(D.is_cuda());
405
- TORCH_CHECK(D.stride(-1) == 1);
406
- CHECK_SHAPE(D, dim);
407
- }
408
-
409
- if (delta_bias_.has_value()) {
410
- auto delta_bias = delta_bias_.value();
411
- TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
412
- TORCH_CHECK(delta_bias.is_cuda());
413
- TORCH_CHECK(delta_bias.stride(-1) == 1);
414
- CHECK_SHAPE(delta_bias, dim);
415
- }
416
-
417
- at::Tensor z, out, dz, out_z;
418
- const bool has_z = z_.has_value();
419
- if (has_z) {
420
- z = z_.value();
421
- TORCH_CHECK(z.scalar_type() == input_type);
422
- TORCH_CHECK(z.is_cuda());
423
- TORCH_CHECK(z.stride(-1) == 1);
424
- CHECK_SHAPE(z, batch_size, dim, seqlen);
425
-
426
- TORCH_CHECK(out_.has_value());
427
- out = out_.value();
428
- TORCH_CHECK(out.scalar_type() == input_type);
429
- TORCH_CHECK(out.is_cuda());
430
- TORCH_CHECK(out.stride(-1) == 1);
431
- CHECK_SHAPE(out, batch_size, dim, seqlen);
432
-
433
- if (dz_.has_value()) {
434
- dz = dz_.value();
435
- TORCH_CHECK(dz.scalar_type() == input_type);
436
- TORCH_CHECK(dz.is_cuda());
437
- TORCH_CHECK(dz.stride(-1) == 1);
438
- CHECK_SHAPE(dz, batch_size, dim, seqlen);
439
- } else {
440
- dz = torch::empty_like(z);
441
- }
442
- if (recompute_out_z) {
443
- out_z = torch::empty_like(out);
444
- }
445
- }
446
-
447
- const int n_chunks = (seqlen + 2048 - 1) / 2048;
448
- // const int n_chunks = (seqlen + 1024 - 1) / 1024;
449
- if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); }
450
- if (x_.has_value()) {
451
- auto x = x_.value();
452
- TORCH_CHECK(x.scalar_type() == weight_type);
453
- TORCH_CHECK(x.is_cuda());
454
- TORCH_CHECK(x.is_contiguous());
455
- CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate);
456
- }
457
-
458
- at::Tensor du = torch::empty_like(u);
459
- at::Tensor ddelta = torch::empty_like(delta);
460
- at::Tensor dA = torch::zeros_like(A);
461
- at::Tensor dB = !is_variable_B ? torch::zeros_like(B) : torch::zeros_like(B, B.options().dtype(torch::kFloat32));
462
- at::Tensor dC = !is_variable_C ? torch::zeros_like(C) : torch::zeros_like(C, C.options().dtype(torch::kFloat32));
463
- at::Tensor dD;
464
- if (D_.has_value()) { dD = torch::zeros_like(D_.value()); }
465
- at::Tensor ddelta_bias;
466
- if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); }
467
-
468
- SSMParamsBwd params;
469
- set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
470
- u, delta, A, B, C, z, out, out_z,
471
- D_.has_value() ? D_.value().data_ptr() : nullptr,
472
- delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
473
- x_.has_value() ? x_.value().data_ptr() : nullptr,
474
- dout, du, ddelta, dA, dB, dC, dz,
475
- D_.has_value() ? dD.data_ptr() : nullptr,
476
- delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
477
- has_z, delta_softplus, recompute_out_z);
478
-
479
- // Otherwise the kernel will be launched from cuda:0 device
480
- // Cast to char to avoid compiler warning about narrowing
481
- at::cuda::CUDAGuard device_guard{(char)u.get_device()};
482
- auto stream = at::cuda::getCurrentCUDAStream().stream();
483
- DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] {
484
- DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_bwd", [&] {
485
- selective_scan_bwd_cuda<input_t, weight_t>(params, stream);
486
- });
487
- });
488
- std::vector<at::Tensor> result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias};
489
- if (has_z) { result.push_back(dz); }
490
- if (recompute_out_z) { result.push_back(out_z); }
491
- return result;
492
- }
493
-
494
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
495
- m.def("fwd", &selective_scan_fwd, "Selective scan forward");
496
- m.def("bwd", &selective_scan_bwd, "Selective scan backward");
497
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mamba/csrc/selective_scan/selective_scan.h DELETED
@@ -1,101 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- ////////////////////////////////////////////////////////////////////////////////////////////////////
8
-
9
- struct SSMScanParamsBase {
10
- using index_t = uint32_t;
11
-
12
- int batch, seqlen, n_chunks;
13
- index_t a_batch_stride;
14
- index_t b_batch_stride;
15
- index_t out_batch_stride;
16
-
17
- // Common data pointers.
18
- void *__restrict__ a_ptr;
19
- void *__restrict__ b_ptr;
20
- void *__restrict__ out_ptr;
21
- void *__restrict__ x_ptr;
22
- };
23
-
24
- ////////////////////////////////////////////////////////////////////////////////////////////////////
25
-
26
- struct SSMParamsBase {
27
- using index_t = uint32_t;
28
-
29
- int batch, dim, seqlen, dstate, n_groups, n_chunks;
30
- int dim_ngroups_ratio;
31
- bool is_variable_B;
32
- bool is_variable_C;
33
-
34
- bool delta_softplus;
35
-
36
- index_t A_d_stride;
37
- index_t A_dstate_stride;
38
- index_t B_batch_stride;
39
- index_t B_d_stride;
40
- index_t B_dstate_stride;
41
- index_t B_group_stride;
42
- index_t C_batch_stride;
43
- index_t C_d_stride;
44
- index_t C_dstate_stride;
45
- index_t C_group_stride;
46
- index_t u_batch_stride;
47
- index_t u_d_stride;
48
- index_t delta_batch_stride;
49
- index_t delta_d_stride;
50
- index_t z_batch_stride;
51
- index_t z_d_stride;
52
- index_t out_batch_stride;
53
- index_t out_d_stride;
54
- index_t out_z_batch_stride;
55
- index_t out_z_d_stride;
56
-
57
- // Common data pointers.
58
- void *__restrict__ A_ptr;
59
- void *__restrict__ B_ptr;
60
- void *__restrict__ C_ptr;
61
- void *__restrict__ D_ptr;
62
- void *__restrict__ u_ptr;
63
- void *__restrict__ delta_ptr;
64
- void *__restrict__ delta_bias_ptr;
65
- void *__restrict__ out_ptr;
66
- void *__restrict__ x_ptr;
67
- void *__restrict__ z_ptr;
68
- void *__restrict__ out_z_ptr;
69
- };
70
-
71
- struct SSMParamsBwd: public SSMParamsBase {
72
- index_t dout_batch_stride;
73
- index_t dout_d_stride;
74
- index_t dA_d_stride;
75
- index_t dA_dstate_stride;
76
- index_t dB_batch_stride;
77
- index_t dB_group_stride;
78
- index_t dB_d_stride;
79
- index_t dB_dstate_stride;
80
- index_t dC_batch_stride;
81
- index_t dC_group_stride;
82
- index_t dC_d_stride;
83
- index_t dC_dstate_stride;
84
- index_t du_batch_stride;
85
- index_t du_d_stride;
86
- index_t dz_batch_stride;
87
- index_t dz_d_stride;
88
- index_t ddelta_batch_stride;
89
- index_t ddelta_d_stride;
90
-
91
- // Common data pointers.
92
- void *__restrict__ dout_ptr;
93
- void *__restrict__ dA_ptr;
94
- void *__restrict__ dB_ptr;
95
- void *__restrict__ dC_ptr;
96
- void *__restrict__ dD_ptr;
97
- void *__restrict__ du_ptr;
98
- void *__restrict__ dz_ptr;
99
- void *__restrict__ ddelta_ptr;
100
- void *__restrict__ ddelta_bias_ptr;
101
- };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mamba/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu DELETED
@@ -1,9 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- // Split into multiple files to compile in paralell
6
-
7
- #include "selective_scan_bwd_kernel.cuh"
8
-
9
- template void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
 
 
 
 
 
 
 
 
 
 
mamba/csrc/selective_scan/selective_scan_bwd_bf16_real.cu DELETED
@@ -1,9 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- // Split into multiple files to compile in paralell
6
-
7
- #include "selective_scan_bwd_kernel.cuh"
8
-
9
- template void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);
 
 
 
 
 
 
 
 
 
 
mamba/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu DELETED
@@ -1,9 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- // Split into multiple files to compile in paralell
6
-
7
- #include "selective_scan_bwd_kernel.cuh"
8
-
9
- template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
 
 
 
 
 
 
 
 
 
 
mamba/csrc/selective_scan/selective_scan_bwd_fp16_real.cu DELETED
@@ -1,9 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- // Split into multiple files to compile in paralell
6
-
7
- #include "selective_scan_bwd_kernel.cuh"
8
-
9
- template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
 
 
 
 
 
 
 
 
 
 
mamba/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu DELETED
@@ -1,9 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- // Split into multiple files to compile in paralell
6
-
7
- #include "selective_scan_bwd_kernel.cuh"
8
-
9
- template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
 
 
 
 
 
 
 
 
 
 
mamba/csrc/selective_scan/selective_scan_bwd_fp32_real.cu DELETED
@@ -1,9 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- // Split into multiple files to compile in paralell
6
-
7
- #include "selective_scan_bwd_kernel.cuh"
8
-
9
- template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd &params, cudaStream_t stream);
 
 
 
 
 
 
 
 
 
 
mamba/csrc/selective_scan/selective_scan_bwd_kernel.cuh DELETED
@@ -1,531 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include <c10/util/BFloat16.h>
8
- #include <c10/util/Half.h>
9
- #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
10
- #include <ATen/cuda/Atomic.cuh> // For atomicAdd on complex
11
-
12
- #include <cub/block/block_load.cuh>
13
- #include <cub/block/block_store.cuh>
14
- #include <cub/block/block_scan.cuh>
15
- #include <cub/block/block_reduce.cuh>
16
-
17
- #include "selective_scan.h"
18
- #include "selective_scan_common.h"
19
- #include "reverse_scan.cuh"
20
- #include "static_switch.h"
21
-
22
- template<typename scalar_t> __device__ __forceinline__ scalar_t conj(scalar_t x);
23
- template<> __device__ __forceinline__ float conj<float>(float x) { return x; }
24
- template<> __device__ __forceinline__ complex_t conj<complex_t>(complex_t x) { return std::conj(x); }
25
-
26
- template<int kNThreads_, int kNItems_, bool kIsEvenLen_, bool kIsVariableB_, bool kIsVariableC_,
27
- bool kDeltaSoftplus_, bool kHasZ_, typename input_t_, typename weight_t_>
28
- struct Selective_Scan_bwd_kernel_traits {
29
- static_assert(kNItems_ % 4 == 0);
30
- using input_t = input_t_;
31
- using weight_t = weight_t_;
32
- static constexpr int kNThreads = kNThreads_;
33
- static constexpr int kNItems = kNItems_;
34
- static constexpr int kNBytes = sizeof(input_t);
35
- static_assert(kNBytes == 2 || kNBytes == 4);
36
- static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
37
- static_assert(kNItems % kNElts == 0);
38
- static constexpr int kNLoads = kNItems / kNElts;
39
- static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;
40
- static constexpr bool kIsEvenLen = kIsEvenLen_;
41
- static constexpr bool kIsVariableB = kIsVariableB_;
42
- static constexpr bool kIsVariableC = kIsVariableC_;
43
- static constexpr bool kDeltaSoftplus = kDeltaSoftplus_;
44
- static constexpr bool kHasZ = kHasZ_;
45
- // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy.
46
- // For complex this would lead to massive register spilling, so we keep it at 2.
47
- static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 3 : 2;
48
- using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
49
- using scan_t = std::conditional_t<!kIsComplex, float2, float4>;
50
- using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
51
- using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
52
- using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
53
- using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
54
- using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
55
- using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads, cub::BLOCK_STORE_WARP_TRANSPOSE>;
56
- // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
57
- using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
58
- // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
59
- using BlockReverseScanT = BlockReverseScan<scan_t, kNThreads>;
60
- using BlockReduceT = cub::BlockReduce<scan_t, kNThreads>;
61
- using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
62
- using BlockReduceComplexT = cub::BlockReduce<complex_t, kNThreads>;
63
- using BlockExchangeT = cub::BlockExchange<float, kNThreads, !kIsComplex ? kNItems : kNItems * 2>;
64
- static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
65
- sizeof(typename BlockLoadVecT::TempStorage),
66
- (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
67
- (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
68
- sizeof(typename BlockStoreT::TempStorage),
69
- sizeof(typename BlockStoreVecT::TempStorage)});
70
- static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage);
71
- static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage);
72
- static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage);
73
- };
74
-
75
- template<typename Ktraits>
76
- __global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
77
- void selective_scan_bwd_kernel(SSMParamsBwd params) {
78
- constexpr bool kIsComplex = Ktraits::kIsComplex;
79
- constexpr bool kIsVariableB = Ktraits::kIsVariableB;
80
- constexpr bool kIsVariableC = Ktraits::kIsVariableC;
81
- constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus;
82
- constexpr bool kHasZ = Ktraits::kHasZ;
83
- constexpr int kNThreads = Ktraits::kNThreads;
84
- constexpr int kNItems = Ktraits::kNItems;
85
- using input_t = typename Ktraits::input_t;
86
- using weight_t = typename Ktraits::weight_t;
87
- using scan_t = typename Ktraits::scan_t;
88
-
89
- // Shared memory.
90
- extern __shared__ char smem_[];
91
- // cast to lvalue reference of expected type
92
- // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
93
- // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
94
- // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
95
- auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
96
- auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
97
- auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
98
- auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
99
- auto& smem_exchange = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
100
- auto& smem_exchange1 = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage));
101
- auto& smem_reduce = *reinterpret_cast<typename Ktraits::BlockReduceT::TempStorage*>(reinterpret_cast<char *>(&smem_exchange) + Ktraits::kSmemExchangeSize);
102
- auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(&smem_reduce);
103
- auto& smem_reduce_complex = *reinterpret_cast<typename Ktraits::BlockReduceComplexT::TempStorage*>(&smem_reduce);
104
- auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(reinterpret_cast<char *>(&smem_reduce) + Ktraits::kSmemReduceSize);
105
- auto& smem_reverse_scan = *reinterpret_cast<typename Ktraits::BlockReverseScanT::TempStorage*>(reinterpret_cast<char *>(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage));
106
- weight_t *smem_delta_a = reinterpret_cast<weight_t *>(smem_ + Ktraits::kSmemSize);
107
- scan_t *smem_running_postfix = reinterpret_cast<scan_t *>(smem_delta_a + 2 * MAX_DSTATE + kNThreads);
108
- weight_t *smem_da = reinterpret_cast<weight_t *>(smem_running_postfix + MAX_DSTATE);
109
- weight_t *smem_dbc = reinterpret_cast<weight_t *>(smem_da + MAX_DSTATE);
110
-
111
- const int batch_id = blockIdx.x;
112
- const int dim_id = blockIdx.y;
113
- const int group_id = dim_id / (params.dim_ngroups_ratio);
114
- input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
115
- + dim_id * params.u_d_stride;
116
- input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
117
- + dim_id * params.delta_d_stride;
118
- input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
119
- + dim_id * params.dout_d_stride;
120
- weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * params.A_d_stride;
121
- weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * params.B_d_stride;
122
- input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
123
- weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * params.C_d_stride;
124
- input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
125
- weight_t *dA = reinterpret_cast<weight_t *>(params.dA_ptr) + dim_id * params.dA_d_stride;
126
- weight_t *dB = reinterpret_cast<weight_t *>(params.dB_ptr)
127
- + (!kIsVariableB ? dim_id * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride);
128
- weight_t *dC = reinterpret_cast<weight_t *>(params.dC_ptr)
129
- + (!kIsVariableC ? dim_id * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride);
130
- float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.dD_ptr) + dim_id;
131
- float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.D_ptr)[dim_id];
132
- float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.ddelta_bias_ptr) + dim_id;
133
- float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id];
134
- scan_t *x = params.x_ptr == nullptr
135
- ? nullptr
136
- : reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate;
137
- float dD_val = 0;
138
- float ddelta_bias_val = 0;
139
-
140
- constexpr int kChunkSize = kNThreads * kNItems;
141
- u += (params.n_chunks - 1) * kChunkSize;
142
- delta += (params.n_chunks - 1) * kChunkSize;
143
- dout += (params.n_chunks - 1) * kChunkSize;
144
- Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2);
145
- Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2);
146
- for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) {
147
- input_t u_vals[kNItems];
148
- input_t delta_vals_load[kNItems];
149
- input_t dout_vals_load[kNItems];
150
- __syncthreads();
151
- load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
152
- u -= kChunkSize;
153
- __syncthreads();
154
- load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
155
- // Will reload delta at the same location if kDeltaSoftplus
156
- if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; }
157
- __syncthreads();
158
- load_input<Ktraits>(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
159
- dout -= kChunkSize;
160
-
161
- float dout_vals[kNItems], delta_vals[kNItems];
162
- #pragma unroll
163
- for (int i = 0; i < kNItems; ++i) {
164
- dout_vals[i] = float(dout_vals_load[i]);
165
- delta_vals[i] = float(delta_vals_load[i]) + delta_bias;
166
- if constexpr (kDeltaSoftplus) {
167
- delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i];
168
- }
169
- }
170
-
171
- if constexpr (kHasZ) {
172
- input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
173
- + dim_id * params.z_d_stride + chunk * kChunkSize;
174
- input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
175
- + dim_id * params.out_d_stride + chunk * kChunkSize;
176
- input_t *dz = reinterpret_cast<input_t *>(params.dz_ptr) + batch_id * params.dz_batch_stride
177
- + dim_id * params.dz_d_stride + chunk * kChunkSize;
178
- input_t z_vals[kNItems], out_vals[kNItems];
179
- __syncthreads();
180
- load_input<Ktraits>(z, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
181
- __syncthreads();
182
- load_input<Ktraits>(out, out_vals, smem_load, params.seqlen - chunk * kChunkSize);
183
- float dz_vals[kNItems], z_silu_vals[kNItems];
184
- #pragma unroll
185
- for (int i = 0; i < kNItems; ++i) {
186
- float z_val = z_vals[i];
187
- float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val));
188
- z_silu_vals[i] = z_val * z_sigmoid_val;
189
- dz_vals[i] = dout_vals[i] * float(out_vals[i]) * z_sigmoid_val
190
- * (1.0f + z_val * (1.0f - z_sigmoid_val));
191
- dout_vals[i] *= z_silu_vals[i];
192
- }
193
- __syncthreads();
194
- store_output<Ktraits>(dz, dz_vals, smem_store, params.seqlen - chunk * kChunkSize);
195
- if (params.out_z_ptr != nullptr) { // Recompute and store out_z
196
- float out_z_vals[kNItems];
197
- #pragma unroll
198
- for (int i = 0; i < kNItems; ++i) { out_z_vals[i] = float(out_vals[i]) * z_silu_vals[i]; }
199
- // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) {
200
- // printf("out_val=%f, z_silu_val = %f, out_z_val = %f\n", float(out_vals[0]), z_silu_vals[0], out_z_vals[0]);
201
- // }
202
- input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
203
- + dim_id * params.out_z_d_stride + chunk * kChunkSize;
204
- __syncthreads();
205
- store_output<Ktraits>(out_z, out_z_vals, smem_store, params.seqlen - chunk * kChunkSize);
206
- }
207
- }
208
-
209
- float du_vals[kNItems];
210
- #pragma unroll
211
- for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; }
212
- #pragma unroll
213
- for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); }
214
-
215
- float ddelta_vals[kNItems] = {0};
216
- __syncthreads();
217
- for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
218
- const weight_t A_val = A[state_idx * params.A_dstate_stride];
219
- // Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
220
- weight_t A_scaled;
221
- constexpr float kLog2e = M_LOG2E;
222
- if constexpr (!kIsComplex) {
223
- A_scaled = A_val * kLog2e;
224
- } else {
225
- A_scaled = complex_t(A_val.real_ * kLog2e, A_val.imag_);
226
- }
227
- weight_t B_val, C_val;
228
- weight_t B_vals[kNItems], C_vals[kNItems];
229
- if constexpr (!kIsVariableB) {
230
- B_val = B[state_idx * params.B_dstate_stride];
231
- } else {
232
- load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
233
- smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
234
- }
235
- if constexpr (!kIsVariableC) {
236
- C_val = C[state_idx * params.C_dstate_stride];
237
- } else {
238
- auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
239
- load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
240
- smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
241
- }
242
- // const weight_t A_val = smem_a[state_idx];
243
- scan_t thread_data[kNItems], thread_reverse_data[kNItems];
244
- if constexpr (!kIsComplex) {
245
- #pragma unroll
246
- for (int i = 0; i < kNItems; ++i) {
247
- const float delta_a_exp = exp2f(delta_vals[i] * A_scaled);
248
- thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);
249
- if (i == 0) {
250
- smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;
251
- } else {
252
- thread_reverse_data[i - 1].x = delta_a_exp;
253
- }
254
- thread_reverse_data[i].y = dout_vals[i] *
255
- (!kIsVariableC
256
- ? (!kIsVariableB ? B_val * C_val : C_val)
257
- : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i]));
258
- }
259
- __syncthreads();
260
- thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1
261
- ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE])
262
- : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];
263
- // Initialize running total
264
- scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f);
265
- SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
266
- Ktraits::BlockScanT(smem_scan).InclusiveScan(
267
- thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
268
- );
269
- scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f);
270
- SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
271
- Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
272
- thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
273
- );
274
- if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
275
- weight_t dA_val = 0, dBC_val = 0;
276
- weight_t dB_vals[kNItems], dC_vals[kNItems];
277
- #pragma unroll
278
- for (int i = 0; i < kNItems; ++i) {
279
- const float dx = thread_reverse_data[i].y;
280
- const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i];
281
- du_vals[i] += ddelta_u * delta_vals[i];
282
- const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);
283
- ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a;
284
- dA_val += dx * delta_vals[i] * a;
285
- if constexpr (!kIsVariableB || !kIsVariableC) {
286
- if constexpr (!kIsVariableB) { // dBC_val is dB_val
287
- dBC_val += dout_vals[i] * (!kIsVariableC ? thread_data[i].y : thread_data[i].y * C_vals[i]);
288
- } else { // dBC_val is dC_val
289
- dBC_val += dout_vals[i] * thread_data[i].y;
290
- }
291
- }
292
- if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); }
293
- if constexpr (kIsVariableC) {
294
- dC_vals[i] = dout_vals[i] * (!kIsVariableB ? thread_data[i].y * B_val : thread_data[i].y);
295
- }
296
- }
297
- // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
298
- if constexpr (kIsVariableB || kIsVariableC) {
299
- if constexpr (kIsVariableB) {
300
- Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
301
- }
302
- if constexpr (kIsVariableC) {
303
- auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
304
- Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals);
305
- }
306
- const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x;
307
- weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x;
308
- weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x;
309
- #pragma unroll
310
- for (int i = 0; i < kNItems; ++i) {
311
- if (i * kNThreads < seqlen_remaining) {
312
- if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); }
313
- if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); }
314
- }
315
- }
316
- }
317
- if constexpr (!kIsVariableB || !kIsVariableC) {
318
- float2 dA_dBC_val = make_float2(dA_val, dBC_val);
319
- dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
320
- dA_val = dA_dBC_val.x;
321
- if (threadIdx.x == 0) {
322
- smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx];
323
- }
324
- } else {
325
- dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val);
326
- }
327
- if (threadIdx.x == 0) {
328
- smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
329
- }
330
- } else {
331
- #pragma unroll
332
- for (int i = 0; i < kNItems; ++i) {
333
- // Pytorch's implementation of complex exp (which calls thrust) is very slow
334
- complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled);
335
- weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]);
336
- thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
337
- if (i == 0) {
338
- smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;
339
- } else {
340
- thread_reverse_data[i - 1].x = delta_a_exp.real_;
341
- thread_reverse_data[i - 1].y = -delta_a_exp.imag_;
342
- }
343
- complex_t dout_BC = 2 * dout_vals[i]
344
- * conj(!kIsVariableC
345
- ? (!kIsVariableB ? B_val * C_val : C_val)
346
- : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i]));
347
- thread_reverse_data[i].z = dout_BC.real_;
348
- thread_reverse_data[i].w = dout_BC.imag_;
349
- }
350
- __syncthreads();
351
- complex_t delta_a_exp = threadIdx.x == kNThreads - 1
352
- ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE])
353
- : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];
354
- thread_reverse_data[kNItems - 1].x = delta_a_exp.real_;
355
- thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_;
356
- // Initialize running total
357
- scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
358
- SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
359
- Ktraits::BlockScanT(smem_scan).InclusiveScan(
360
- thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
361
- );
362
- scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
363
- SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
364
- Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
365
- thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
366
- );
367
- if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
368
- weight_t dA_val = 0, dBC_val = 0;
369
- weight_t dB_vals[kNItems], dC_vals[kNItems];
370
- #pragma unroll
371
- for (int i = 0; i < kNItems; ++i) {
372
- complex_t x = complex_t(thread_data[i].z, thread_data[i].w);
373
- complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w);
374
- float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_;
375
- if constexpr (!kIsVariableB || !kIsVariableC) {
376
- if constexpr (!kIsVariableB) { // dBC_val is dB_val
377
- dBC_val += (2 * dout_vals[i]) * conj(!kIsVariableC ? x : x * C_vals[i]);
378
- } else { // dBC_val is dC_val
379
- dBC_val += (2 * dout_vals[i]) * conj(x);
380
- }
381
- }
382
- const complex_t a_conj = conj(x - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]));
383
- du_vals[i] += ddelta_u * delta_vals[i];
384
- ddelta_vals[i] += ddelta_u * float(u_vals[i]) + (dx * conj(A_val) * a_conj).real_;
385
- dA_val += delta_vals[i] * dx * a_conj;
386
- if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); }
387
- if constexpr (kIsVariableC) {
388
- dC_vals[i] = (2 * dout_vals[i]) * conj(!kIsVariableB ? x * B_val : x);
389
- }
390
- }
391
- // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
392
- if constexpr (kIsVariableB || kIsVariableC) {
393
- float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2];
394
- if constexpr (kIsVariableB) {
395
- #pragma unroll
396
- for (int i = 0; i < kNItems; ++i) {
397
- dB_vals_f[i * 2] = dB_vals[i].real_;
398
- dB_vals_f[i * 2 + 1] = dB_vals[i].imag_;
399
- }
400
- Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f);
401
- }
402
- if constexpr (kIsVariableC) {
403
- #pragma unroll
404
- for (int i = 0; i < kNItems; ++i) {
405
- dC_vals_f[i * 2] = dC_vals[i].real_;
406
- dC_vals_f[i * 2 + 1] = dC_vals[i].imag_;
407
- }
408
- auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
409
- Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f);
410
- }
411
- const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x;
412
- float *dB_cur = reinterpret_cast<float *>(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x;
413
- float *dC_cur = reinterpret_cast<float *>(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x;
414
- #pragma unroll
415
- for (int i = 0; i < kNItems * 2; ++i) {
416
- if (i * kNThreads < seqlen_remaining) {
417
- if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); }
418
- if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); }
419
- }
420
- }
421
- }
422
- if constexpr (!kIsVariableB || !kIsVariableC) {
423
- float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_);
424
- dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
425
- dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y);
426
- dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w);
427
- if (threadIdx.x == 0) {
428
- smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dBC_val : dBC_val + smem_dbc[state_idx];
429
- }
430
- } else {
431
- dA_val = Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val);
432
- }
433
- if (threadIdx.x == 0) {
434
- smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
435
- }
436
- }
437
- }
438
-
439
- if constexpr (kDeltaSoftplus) {
440
- __syncthreads();
441
- input_t delta_vals_load[kNItems];
442
- load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
443
- delta -= kChunkSize;
444
- #pragma unroll
445
- for (int i = 0; i < kNItems; ++i) {
446
- float delta_val = float(delta_vals_load[i]) + delta_bias;
447
- float delta_val_neg_exp = expf(-delta_val);
448
- ddelta_vals[i] = delta_val <= 20.f
449
- ? ddelta_vals[i] / (1.f + delta_val_neg_exp)
450
- : ddelta_vals[i];
451
- }
452
- }
453
- for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; }
454
-
455
- input_t *du = reinterpret_cast<input_t *>(params.du_ptr) + batch_id * params.du_batch_stride
456
- + dim_id * params.du_d_stride + chunk * kChunkSize;
457
- input_t *ddelta = reinterpret_cast<input_t *>(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride
458
- + dim_id * params.ddelta_d_stride + chunk * kChunkSize;
459
- __syncthreads();
460
- store_output<Ktraits>(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize);
461
- __syncthreads();
462
- store_output<Ktraits>(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize);
463
-
464
- Bvar -= kChunkSize * (!kIsComplex ? 1 : 2);
465
- Cvar -= kChunkSize * (!kIsComplex ? 1 : 2);
466
- }
467
- if (params.dD_ptr != nullptr) {
468
- dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val);
469
- if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); }
470
- }
471
- if (params.ddelta_bias_ptr != nullptr) {
472
- __syncthreads();
473
- ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val);
474
- if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); }
475
- }
476
- for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
477
- gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]);
478
- weight_t dBC_val;
479
- if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx]; }
480
- if constexpr (!kIsVariableB) {
481
- gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride]),
482
- !kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride]) : dBC_val);
483
- }
484
- if constexpr (!kIsVariableC) {
485
- gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride]),
486
- !kIsVariableB ? dBC_val * conj(B[state_idx * params.B_dstate_stride]) : dBC_val);
487
- }
488
- }
489
- }
490
-
491
- template<int kNThreads, int kNItems, typename input_t, typename weight_t>
492
- void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
493
- BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
494
- BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
495
- BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
496
- BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] {
497
- BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
498
- using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
499
- // using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, true, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
500
- // TODO: check this
501
- constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t);
502
- // printf("smem_size = %d\n", kSmemSize);
503
- dim3 grid(params.batch, params.dim);
504
- auto kernel = &selective_scan_bwd_kernel<Ktraits>;
505
- if (kSmemSize >= 48 * 1024) {
506
- C10_CUDA_CHECK(cudaFuncSetAttribute(
507
- kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
508
- }
509
- kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
510
- C10_CUDA_KERNEL_LAUNCH_CHECK();
511
- });
512
- });
513
- });
514
- });
515
- });
516
- }
517
-
518
- template<typename input_t, typename weight_t>
519
- void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
520
- if (params.seqlen <= 128) {
521
- selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream);
522
- } else if (params.seqlen <= 256) {
523
- selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream);
524
- } else if (params.seqlen <= 512) {
525
- selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream);
526
- } else if (params.seqlen <= 1024) {
527
- selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
528
- } else {
529
- selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
530
- }
531
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mamba/csrc/selective_scan/selective_scan_common.h DELETED
@@ -1,221 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include <cuda_bf16.h>
8
- #include <cuda_fp16.h>
9
- #include <c10/util/complex.h> // For scalar_value_type
10
-
11
- #define MAX_DSTATE 256
12
-
13
- using complex_t = c10::complex<float>;
14
-
15
- inline __device__ float2 operator+(const float2 & a, const float2 & b){
16
- return {a.x + b.x, a.y + b.y};
17
- }
18
-
19
- inline __device__ float3 operator+(const float3 &a, const float3 &b) {
20
- return {a.x + b.x, a.y + b.y, a.z + b.z};
21
- }
22
-
23
- inline __device__ float4 operator+(const float4 & a, const float4 & b){
24
- return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
25
- }
26
-
27
- ////////////////////////////////////////////////////////////////////////////////////////////////////
28
-
29
- template<int BYTES> struct BytesToType {};
30
-
31
- template<> struct BytesToType<16> {
32
- using Type = uint4;
33
- static_assert(sizeof(Type) == 16);
34
- };
35
-
36
- template<> struct BytesToType<8> {
37
- using Type = uint64_t;
38
- static_assert(sizeof(Type) == 8);
39
- };
40
-
41
- template<> struct BytesToType<4> {
42
- using Type = uint32_t;
43
- static_assert(sizeof(Type) == 4);
44
- };
45
-
46
- template<> struct BytesToType<2> {
47
- using Type = uint16_t;
48
- static_assert(sizeof(Type) == 2);
49
- };
50
-
51
- template<> struct BytesToType<1> {
52
- using Type = uint8_t;
53
- static_assert(sizeof(Type) == 1);
54
- };
55
-
56
- ////////////////////////////////////////////////////////////////////////////////////////////////////
57
-
58
- template<typename scalar_t, int N>
59
- struct Converter{
60
- static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
61
- #pragma unroll
62
- for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
63
- }
64
- };
65
-
66
- template<int N>
67
- struct Converter<at::Half, N>{
68
- static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
69
- static_assert(N % 2 == 0);
70
- auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
71
- auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
72
- #pragma unroll
73
- for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
74
- }
75
- };
76
-
77
- #if __CUDA_ARCH__ >= 800
78
- template<int N>
79
- struct Converter<at::BFloat16, N>{
80
- static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
81
- static_assert(N % 2 == 0);
82
- auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
83
- auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
84
- #pragma unroll
85
- for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
86
- }
87
- };
88
- #endif
89
-
90
- ////////////////////////////////////////////////////////////////////////////////////////////////////
91
-
92
- // From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp
93
- // and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696
94
- __device__ __forceinline__ complex_t cexp2f(complex_t z) {
95
- float t = exp2f(z.real_);
96
- float c, s;
97
- sincosf(z.imag_, &s, &c);
98
- return complex_t(c * t, s * t);
99
- }
100
-
101
- __device__ __forceinline__ complex_t cexpf(complex_t z) {
102
- float t = expf(z.real_);
103
- float c, s;
104
- sincosf(z.imag_, &s, &c);
105
- return complex_t(c * t, s * t);
106
- }
107
-
108
- template<typename scalar_t> struct SSMScanOp;
109
-
110
- template<>
111
- struct SSMScanOp<float> {
112
- __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
113
- return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
114
- }
115
- };
116
-
117
- template<>
118
- struct SSMScanOp<complex_t> {
119
- __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const {
120
- complex_t a0 = complex_t(ab0.x, ab0.y);
121
- complex_t b0 = complex_t(ab0.z, ab0.w);
122
- complex_t a1 = complex_t(ab1.x, ab1.y);
123
- complex_t b1 = complex_t(ab1.z, ab1.w);
124
- complex_t out_a = a1 * a0;
125
- complex_t out_b = a1 * b0 + b1;
126
- return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_);
127
- }
128
- };
129
-
130
- // A stateful callback functor that maintains a running prefix to be applied
131
- // during consecutive scan operations.
132
- template <typename scalar_t> struct SSMScanPrefixCallbackOp {
133
- using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
134
- scan_t running_prefix;
135
- // Constructor
136
- __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
137
- // Callback operator to be entered by the first warp of threads in the block.
138
- // Thread-0 is responsible for returning a value for seeding the block-wide scan.
139
- __device__ scan_t operator()(scan_t block_aggregate) {
140
- scan_t old_prefix = running_prefix;
141
- running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
142
- return old_prefix;
143
- }
144
- };
145
-
146
- ////////////////////////////////////////////////////////////////////////////////////////////////////
147
-
148
- template<typename Ktraits>
149
- inline __device__ void load_input(typename Ktraits::input_t *u,
150
- typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
151
- typename Ktraits::BlockLoadT::TempStorage &smem_load,
152
- int seqlen) {
153
- if constexpr (Ktraits::kIsEvenLen) {
154
- auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
155
- using vec_t = typename Ktraits::vec_t;
156
- Ktraits::BlockLoadVecT(smem_load_vec).Load(
157
- reinterpret_cast<vec_t*>(u),
158
- reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
159
- );
160
- } else {
161
- Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
162
- }
163
- }
164
-
165
- template<typename Ktraits>
166
- inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
167
- typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
168
- typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
169
- int seqlen) {
170
- constexpr int kNItems = Ktraits::kNItems;
171
- if constexpr (!Ktraits::kIsComplex) {
172
- typename Ktraits::input_t B_vals_load[kNItems];
173
- if constexpr (Ktraits::kIsEvenLen) {
174
- auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
175
- using vec_t = typename Ktraits::vec_t;
176
- Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
177
- reinterpret_cast<vec_t*>(Bvar),
178
- reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
179
- );
180
- } else {
181
- Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
182
- }
183
- // #pragma unroll
184
- // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
185
- Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
186
- } else {
187
- typename Ktraits::input_t B_vals_load[kNItems * 2];
188
- if constexpr (Ktraits::kIsEvenLen) {
189
- auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
190
- using vec_t = typename Ktraits::vec_t;
191
- Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
192
- reinterpret_cast<vec_t*>(Bvar),
193
- reinterpret_cast<vec_t(&)[Ktraits::kNLoads * 2]>(B_vals_load)
194
- );
195
- } else {
196
- Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
197
- }
198
- #pragma unroll
199
- for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); }
200
- }
201
- }
202
-
203
- template<typename Ktraits>
204
- inline __device__ void store_output(typename Ktraits::input_t *out,
205
- const float (&out_vals)[Ktraits::kNItems],
206
- typename Ktraits::BlockStoreT::TempStorage &smem_store,
207
- int seqlen) {
208
- typename Ktraits::input_t write_vals[Ktraits::kNItems];
209
- #pragma unroll
210
- for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
211
- if constexpr (Ktraits::kIsEvenLen) {
212
- auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
213
- using vec_t = typename Ktraits::vec_t;
214
- Ktraits::BlockStoreVecT(smem_store_vec).Store(
215
- reinterpret_cast<vec_t*>(out),
216
- reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
217
- );
218
- } else {
219
- Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
220
- }
221
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mamba/csrc/selective_scan/selective_scan_fwd_bf16.cu DELETED
@@ -1,10 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- // Split into multiple files to compile in paralell
6
-
7
- #include "selective_scan_fwd_kernel.cuh"
8
-
9
- template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
10
- template void selective_scan_fwd_cuda<at::BFloat16, complex_t>(SSMParamsBase &params, cudaStream_t stream);
 
 
 
 
 
 
 
 
 
 
 
mamba/csrc/selective_scan/selective_scan_fwd_fp16.cu DELETED
@@ -1,10 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- // Split into multiple files to compile in paralell
6
-
7
- #include "selective_scan_fwd_kernel.cuh"
8
-
9
- template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
10
- template void selective_scan_fwd_cuda<at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream);
 
 
 
 
 
 
 
 
 
 
 
mamba/csrc/selective_scan/selective_scan_fwd_fp32.cu DELETED
@@ -1,10 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- // Split into multiple files to compile in paralell
6
-
7
- #include "selective_scan_fwd_kernel.cuh"
8
-
9
- template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);
10
- template void selective_scan_fwd_cuda<float, complex_t>(SSMParamsBase &params, cudaStream_t stream);
 
 
 
 
 
 
 
 
 
 
 
mamba/csrc/selective_scan/selective_scan_fwd_kernel.cuh DELETED
@@ -1,345 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include <c10/util/BFloat16.h>
8
- #include <c10/util/Half.h>
9
- #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
10
-
11
- #include <cub/block/block_load.cuh>
12
- #include <cub/block/block_store.cuh>
13
- #include <cub/block/block_scan.cuh>
14
-
15
- #include "selective_scan.h"
16
- #include "selective_scan_common.h"
17
- #include "static_switch.h"
18
-
19
- template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
20
- bool kIsVariableB_, bool kIsVariableC_,
21
- bool kHasZ_, typename input_t_, typename weight_t_>
22
- struct Selective_Scan_fwd_kernel_traits {
23
- static_assert(kNItems_ % 4 == 0);
24
- using input_t = input_t_;
25
- using weight_t = weight_t_;
26
- static constexpr int kNThreads = kNThreads_;
27
- // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
28
- static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
29
- static constexpr int kNItems = kNItems_;
30
- static constexpr int kNRows = kNRows_;
31
- static constexpr int kNBytes = sizeof(input_t);
32
- static_assert(kNBytes == 2 || kNBytes == 4);
33
- static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
34
- static_assert(kNItems % kNElts == 0);
35
- static constexpr int kNLoads = kNItems / kNElts;
36
- static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;
37
- static constexpr bool kIsEvenLen = kIsEvenLen_;
38
- static constexpr bool kIsVariableB = kIsVariableB_;
39
- static constexpr bool kIsVariableC = kIsVariableC_;
40
- static constexpr bool kHasZ = kHasZ_;
41
-
42
- static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
43
-
44
- using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
45
- using scan_t = std::conditional_t<!kIsComplex, float2, float4>;
46
- using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
47
- using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
48
- !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
49
- using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
50
- using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2,
51
- !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
52
- using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
53
- using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
54
- !kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
55
- // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
56
- // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
57
- using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
58
- static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
59
- sizeof(typename BlockLoadVecT::TempStorage),
60
- (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
61
- (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
62
- sizeof(typename BlockStoreT::TempStorage),
63
- sizeof(typename BlockStoreVecT::TempStorage)});
64
- static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
65
- };
66
-
67
- template<typename Ktraits>
68
- __global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
69
- void selective_scan_fwd_kernel(SSMParamsBase params) {
70
- constexpr bool kIsComplex = Ktraits::kIsComplex;
71
- constexpr bool kIsVariableB = Ktraits::kIsVariableB;
72
- constexpr bool kIsVariableC = Ktraits::kIsVariableC;
73
- constexpr bool kHasZ = Ktraits::kHasZ;
74
- constexpr int kNThreads = Ktraits::kNThreads;
75
- constexpr int kNItems = Ktraits::kNItems;
76
- constexpr int kNRows = Ktraits::kNRows;
77
- constexpr bool kDirectIO = Ktraits::kDirectIO;
78
- using input_t = typename Ktraits::input_t;
79
- using weight_t = typename Ktraits::weight_t;
80
- using scan_t = typename Ktraits::scan_t;
81
-
82
- // Shared memory.
83
- extern __shared__ char smem_[];
84
- // cast to lvalue reference of expected type
85
- // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
86
- // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
87
- // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
88
- auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
89
- auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
90
- auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
91
- auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
92
- auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
93
- // weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ + smem_loadstorescan_size);
94
- // weight_t *smem_bc = reinterpret_cast<weight_t *>(smem_a + MAX_DSTATE);
95
- scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);
96
-
97
- const int batch_id = blockIdx.x;
98
- const int dim_id = blockIdx.y;
99
- const int group_id = dim_id / (params.dim_ngroups_ratio);
100
- input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
101
- + dim_id * kNRows * params.u_d_stride;
102
- input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
103
- + dim_id * kNRows * params.delta_d_stride;
104
- weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride;
105
- weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride;
106
- input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
107
- weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
108
- input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
109
- scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;
110
-
111
- float D_val[kNRows] = {0};
112
- if (params.D_ptr != nullptr) {
113
- #pragma unroll
114
- for (int r = 0; r < kNRows; ++r) {
115
- D_val[r] = reinterpret_cast<float *>(params.D_ptr)[dim_id * kNRows + r];
116
- }
117
- }
118
- float delta_bias[kNRows] = {0};
119
- if (params.delta_bias_ptr != nullptr) {
120
- #pragma unroll
121
- for (int r = 0; r < kNRows; ++r) {
122
- delta_bias[r] = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id * kNRows + r];
123
- }
124
- }
125
-
126
- // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
127
- // smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
128
- // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
129
- // }
130
-
131
- constexpr int kChunkSize = kNThreads * kNItems;
132
- for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
133
- input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
134
- __syncthreads();
135
- #pragma unroll
136
- for (int r = 0; r < kNRows; ++r) {
137
- if constexpr (!kDirectIO) {
138
- if (r > 0) { __syncthreads(); }
139
- }
140
- load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize);
141
- if constexpr (!kDirectIO) { __syncthreads(); }
142
- load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
143
- }
144
- u += kChunkSize;
145
- delta += kChunkSize;
146
-
147
- float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
148
- #pragma unroll
149
- for (int r = 0; r < kNRows; ++r) {
150
- #pragma unroll
151
- for (int i = 0; i < kNItems; ++i) {
152
- float u_val = float(u_vals[r][i]);
153
- delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];
154
- if (params.delta_softplus) {
155
- delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];
156
- }
157
- delta_u_vals[r][i] = delta_vals[r][i] * u_val;
158
- out_vals[r][i] = D_val[r] * u_val;
159
- }
160
- }
161
-
162
- __syncthreads();
163
- for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
164
- weight_t A_val[kNRows];
165
- #pragma unroll
166
- for (int r = 0; r < kNRows; ++r) {
167
- A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
168
- // Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
169
- constexpr float kLog2e = M_LOG2E;
170
- if constexpr (!kIsComplex) {
171
- A_val[r] *= kLog2e;
172
- } else {
173
- A_val[r].real_ *= kLog2e;
174
- }
175
- }
176
- // This variable holds B * C if both B and C are constant across seqlen. If only B varies
177
- // across seqlen, this holds C. If only C varies across seqlen, this holds B.
178
- // If both B and C vary, this is unused.
179
- weight_t BC_val[kNRows];
180
- weight_t B_vals[kNItems], C_vals[kNItems];
181
- if constexpr (kIsVariableB) {
182
- load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
183
- smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
184
- if constexpr (!kIsVariableC) {
185
- #pragma unroll
186
- for (int r = 0; r < kNRows; ++r) {
187
- BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
188
- }
189
- }
190
- }
191
- if constexpr (kIsVariableC) {
192
- auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
193
- load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
194
- smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
195
- if constexpr (!kIsVariableB) {
196
- #pragma unroll
197
- for (int r = 0; r < kNRows; ++r) {
198
- BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride];
199
- }
200
- }
201
- }
202
- if constexpr (!kIsVariableB && !kIsVariableC) {
203
- #pragma unroll
204
- for (int r = 0; r < kNRows; ++r) {
205
- BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
206
- }
207
- }
208
-
209
- #pragma unroll
210
- for (int r = 0; r < kNRows; ++r) {
211
- if (r > 0) { __syncthreads(); } // Scan could be using the same smem
212
- scan_t thread_data[kNItems];
213
- #pragma unroll
214
- for (int i = 0; i < kNItems; ++i) {
215
- if constexpr (!kIsComplex) {
216
- thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
217
- !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
218
- if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
219
- if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
220
- thread_data[i] = make_float2(1.f, 0.f);
221
- }
222
- }
223
- } else {
224
- // Pytorch's implementation of complex exp (which calls thrust) is very slow
225
- complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]);
226
- weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i];
227
- thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
228
- if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
229
- if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
230
- thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f);
231
- }
232
- }
233
- }
234
- }
235
- // Initialize running total
236
- scan_t running_prefix;
237
- if constexpr (!kIsComplex) {
238
- // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read
239
- running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f);
240
- // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
241
- } else {
242
- running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f);
243
- // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
244
- }
245
- SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
246
- Ktraits::BlockScanT(smem_scan).InclusiveScan(
247
- thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
248
- );
249
- // There's a syncthreads in the scan op, so we don't need to sync here.
250
- // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
251
- if (threadIdx.x == 0) {
252
- smem_running_prefix[state_idx] = prefix_op.running_prefix;
253
- x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix;
254
- }
255
- #pragma unroll
256
- for (int i = 0; i < kNItems; ++i) {
257
- const weight_t C_val = !kIsVariableC
258
- ? BC_val[r]
259
- : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]);
260
- if constexpr (!kIsComplex) {
261
- out_vals[r][i] += thread_data[i].y * C_val;
262
- } else {
263
- out_vals[r][i] += (complex_t(thread_data[i].z, thread_data[i].w) * C_val).real_ * 2;
264
- }
265
- }
266
- }
267
- }
268
-
269
- input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
270
- + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
271
- __syncthreads();
272
- #pragma unroll
273
- for (int r = 0; r < kNRows; ++r) {
274
- if constexpr (!kDirectIO) {
275
- if (r > 0) { __syncthreads(); }
276
- }
277
- store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
278
- }
279
-
280
- if constexpr (kHasZ) {
281
- input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
282
- + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
283
- input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
284
- + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
285
- #pragma unroll
286
- for (int r = 0; r < kNRows; ++r) {
287
- input_t z_vals[kNItems];
288
- __syncthreads();
289
- load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
290
- #pragma unroll
291
- for (int i = 0; i < kNItems; ++i) {
292
- float z_val = z_vals[i];
293
- out_vals[r][i] *= z_val / (1 + expf(-z_val));
294
- }
295
- __syncthreads();
296
- store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
297
- }
298
- }
299
-
300
- Bvar += kChunkSize * (!kIsComplex ? 1 : 2);
301
- Cvar += kChunkSize * (!kIsComplex ? 1 : 2);
302
- }
303
- }
304
-
305
- template<int kNThreads, int kNItems, typename input_t, typename weight_t>
306
- void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
307
- // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
308
- // processing 1 row.
309
- constexpr int kNRows = 1;
310
- BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
311
- BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
312
- BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
313
- BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
314
- using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, input_t, weight_t>;
315
- // constexpr int kSmemSize = Ktraits::kSmemSize;
316
- constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
317
- // printf("smem_size = %d\n", kSmemSize);
318
- dim3 grid(params.batch, params.dim / kNRows);
319
- auto kernel = &selective_scan_fwd_kernel<Ktraits>;
320
- if (kSmemSize >= 48 * 1024) {
321
- C10_CUDA_CHECK(cudaFuncSetAttribute(
322
- kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
323
- }
324
- kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
325
- C10_CUDA_KERNEL_LAUNCH_CHECK();
326
- });
327
- });
328
- });
329
- });
330
- }
331
-
332
- template<typename input_t, typename weight_t>
333
- void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
334
- if (params.seqlen <= 128) {
335
- selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
336
- } else if (params.seqlen <= 256) {
337
- selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
338
- } else if (params.seqlen <= 512) {
339
- selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
340
- } else if (params.seqlen <= 1024) {
341
- selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
342
- } else {
343
- selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
344
- }
345
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mamba/csrc/selective_scan/static_switch.h DELETED
@@ -1,25 +0,0 @@
1
- // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
2
- // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
3
-
4
- #pragma once
5
-
6
- /// @param COND - a boolean expression to switch by
7
- /// @param CONST_NAME - a name given for the constexpr bool variable.
8
- /// @param ... - code to execute for true and false
9
- ///
10
- /// Usage:
11
- /// ```
12
- /// BOOL_SWITCH(flag, BoolConst, [&] {
13
- /// some_function<BoolConst>(...);
14
- /// });
15
- /// ```
16
- #define BOOL_SWITCH(COND, CONST_NAME, ...) \
17
- [&] { \
18
- if (COND) { \
19
- constexpr bool CONST_NAME = true; \
20
- return __VA_ARGS__(); \
21
- } else { \
22
- constexpr bool CONST_NAME = false; \
23
- return __VA_ARGS__(); \
24
- } \
25
- }()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mamba/csrc/selective_scan/uninitialized_copy.cuh DELETED
@@ -1,69 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved.
3
- *
4
- * Redistribution and use in source and binary forms, with or without
5
- * modification, are permitted provided that the following conditions are met:
6
- * * Redistributions of source code must retain the above copyright
7
- * notice, this list of conditions and the following disclaimer.
8
- * * Redistributions in binary form must reproduce the above copyright
9
- * notice, this list of conditions and the following disclaimer in the
10
- * documentation and/or other materials provided with the distribution.
11
- * * Neither the name of the NVIDIA CORPORATION nor the
12
- * names of its contributors may be used to endorse or promote products
13
- * derived from this software without specific prior written permission.
14
- *
15
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18
- * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19
- * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20
- * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21
- * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22
- * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23
- * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24
- * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25
- *
26
- ******************************************************************************/
27
-
28
- #pragma once
29
-
30
- #include <cub/config.cuh>
31
-
32
- #include <cuda/std/type_traits>
33
-
34
-
35
- namespace detail
36
- {
37
-
38
- #if defined(_NVHPC_CUDA)
39
- template <typename T, typename U>
40
- __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
41
- {
42
- // NVBug 3384810
43
- new (ptr) T(::cuda::std::forward<U>(val));
44
- }
45
- #else
46
- template <typename T,
47
- typename U,
48
- typename ::cuda::std::enable_if<
49
- ::cuda::std::is_trivially_copyable<T>::value,
50
- int
51
- >::type = 0>
52
- __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
53
- {
54
- *ptr = ::cuda::std::forward<U>(val);
55
- }
56
-
57
- template <typename T,
58
- typename U,
59
- typename ::cuda::std::enable_if<
60
- !::cuda::std::is_trivially_copyable<T>::value,
61
- int
62
- >::type = 0>
63
- __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
64
- {
65
- new (ptr) T(::cuda::std::forward<U>(val));
66
- }
67
- #endif
68
-
69
- } // namespace detail
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mamba/evals/lm_harness_eval.py DELETED
@@ -1,39 +0,0 @@
1
- import torch
2
-
3
- import transformers
4
- from transformers import AutoTokenizer
5
-
6
- from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
7
-
8
- from lm_eval.api.model import LM
9
- from lm_eval.models.huggingface import HFLM
10
- from lm_eval.api.registry import register_model
11
- from lm_eval.__main__ import cli_evaluate
12
-
13
-
14
- @register_model("mamba")
15
- class MambaEvalWrapper(HFLM):
16
-
17
- AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
18
-
19
- def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda",
20
- dtype=torch.float16):
21
- LM.__init__(self)
22
- self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype)
23
- self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
24
- self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
25
- self.vocab_size = self.tokenizer.vocab_size
26
- self._batch_size = batch_size if batch_size is None else 64
27
- self._max_length = max_length
28
- self._device = torch.device(device)
29
-
30
- @property
31
- def batch_size(self):
32
- return self._batch_size
33
-
34
- def _model_generate(self, context, max_length, stop, **generation_kwargs):
35
- raise NotImplementedError()
36
-
37
-
38
- if __name__ == "__main__":
39
- cli_evaluate()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mamba/mamba_ssm/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- __version__ = "1.0.1"
2
-
3
- from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn
4
- from mamba_ssm.modules.mamba_simple import Mamba
5
- from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
 
 
 
 
 
 
mamba/mamba_ssm/models/__init__.py DELETED
File without changes
mamba/mamba_ssm/models/mixer_seq_simple.py DELETED
@@ -1,233 +0,0 @@
1
- # Copyright (c) 2023, Albert Gu, Tri Dao.
2
-
3
- import math
4
- from functools import partial
5
-
6
- from collections import namedtuple
7
-
8
- import torch
9
- import torch.nn as nn
10
-
11
- from mamba_ssm.modules.mamba_simple import Mamba, Block
12
- from mamba_ssm.utils.generation import GenerationMixin
13
- from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
14
-
15
- try:
16
- from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
17
- except ImportError:
18
- RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
19
-
20
-
21
- def create_block(
22
- d_model,
23
- ssm_cfg=None,
24
- norm_epsilon=1e-5,
25
- rms_norm=False,
26
- residual_in_fp32=False,
27
- fused_add_norm=False,
28
- layer_idx=None,
29
- device=None,
30
- dtype=None,
31
- ):
32
- if ssm_cfg is None:
33
- ssm_cfg = {}
34
- factory_kwargs = {"device": device, "dtype": dtype}
35
- mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
36
- norm_cls = partial(
37
- nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
38
- )
39
- block = Block(
40
- d_model,
41
- mixer_cls,
42
- norm_cls=norm_cls,
43
- fused_add_norm=fused_add_norm,
44
- residual_in_fp32=residual_in_fp32,
45
- )
46
- block.layer_idx = layer_idx
47
- return block
48
-
49
-
50
- # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
51
- def _init_weights(
52
- module,
53
- n_layer,
54
- initializer_range=0.02, # Now only used for embedding layer.
55
- rescale_prenorm_residual=True,
56
- n_residuals_per_layer=1, # Change to 2 if we have MLP
57
- ):
58
- if isinstance(module, nn.Linear):
59
- if module.bias is not None:
60
- if not getattr(module.bias, "_no_reinit", False):
61
- nn.init.zeros_(module.bias)
62
- elif isinstance(module, nn.Embedding):
63
- nn.init.normal_(module.weight, std=initializer_range)
64
-
65
- if rescale_prenorm_residual:
66
- # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
67
- # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
68
- # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
69
- # > -- GPT-2 :: https://openai.com/blog/better-language-models/
70
- #
71
- # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
72
- for name, p in module.named_parameters():
73
- if name in ["out_proj.weight", "fc2.weight"]:
74
- # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
75
- # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
76
- # We need to reinit p since this code could be called multiple times
77
- # Having just p *= scale would repeatedly scale it down
78
- nn.init.kaiming_uniform_(p, a=math.sqrt(5))
79
- with torch.no_grad():
80
- p /= math.sqrt(n_residuals_per_layer * n_layer)
81
-
82
-
83
- class MixerModel(nn.Module):
84
- def __init__(
85
- self,
86
- d_model: int,
87
- n_layer: int,
88
- vocab_size: int,
89
- ssm_cfg=None,
90
- norm_epsilon: float = 1e-5,
91
- rms_norm: bool = False,
92
- initializer_cfg=None,
93
- fused_add_norm=False,
94
- residual_in_fp32=False,
95
- device=None,
96
- dtype=None,
97
- ) -> None:
98
- factory_kwargs = {"device": device, "dtype": dtype}
99
- super().__init__()
100
- self.residual_in_fp32 = residual_in_fp32
101
-
102
- self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
103
-
104
- # We change the order of residual and layer norm:
105
- # Instead of LN -> Attn / MLP -> Add, we do:
106
- # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
107
- # the main branch (output of MLP / Mixer). The model definition is unchanged.
108
- # This is for performance reason: we can fuse add + layer_norm.
109
- self.fused_add_norm = fused_add_norm
110
- if self.fused_add_norm:
111
- if layer_norm_fn is None or rms_norm_fn is None:
112
- raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
113
-
114
- self.layers = nn.ModuleList(
115
- [
116
- create_block(
117
- d_model,
118
- ssm_cfg=ssm_cfg,
119
- norm_epsilon=norm_epsilon,
120
- rms_norm=rms_norm,
121
- residual_in_fp32=residual_in_fp32,
122
- fused_add_norm=fused_add_norm,
123
- layer_idx=i,
124
- **factory_kwargs,
125
- )
126
- for i in range(n_layer)
127
- ]
128
- )
129
-
130
- self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
131
- d_model, eps=norm_epsilon, **factory_kwargs
132
- )
133
-
134
- self.apply(
135
- partial(
136
- _init_weights,
137
- n_layer=n_layer,
138
- **(initializer_cfg if initializer_cfg is not None else {}),
139
- )
140
- )
141
-
142
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
143
- return {
144
- i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
145
- for i, layer in enumerate(self.layers)
146
- }
147
-
148
- def forward(self, input_ids, inference_params=None):
149
- hidden_states = self.embedding(input_ids)
150
- residual = None
151
- for layer in self.layers:
152
- hidden_states, residual = layer(
153
- hidden_states, residual, inference_params=inference_params
154
- )
155
- if not self.fused_add_norm:
156
- residual = (hidden_states + residual) if residual is not None else hidden_states
157
- hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
158
- else:
159
- # Set prenorm=False here since we don't need the residual
160
- fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
161
- hidden_states = fused_add_norm_fn(
162
- hidden_states,
163
- self.norm_f.weight,
164
- self.norm_f.bias,
165
- eps=self.norm_f.eps,
166
- residual=residual,
167
- prenorm=False,
168
- residual_in_fp32=self.residual_in_fp32,
169
- )
170
- return hidden_states
171
-
172
-
173
- class MambaLMHeadModel(nn.Module, GenerationMixin):
174
-
175
- def __init__(
176
- self,
177
- d_model: int,
178
- n_layer: int,
179
- vocab_size: int,
180
- initializer_cfg=None,
181
- pad_vocab_size_multiple: int = 1,
182
- device=None,
183
- dtype=None,
184
- **backbone_kwargs,
185
- ) -> None:
186
- factory_kwargs = {"device": device, "dtype": dtype}
187
- super().__init__()
188
- if vocab_size % pad_vocab_size_multiple != 0:
189
- vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
190
- self.backbone = MixerModel(
191
- d_model=d_model,
192
- n_layer=n_layer,
193
- vocab_size=vocab_size,
194
- initializer_cfg=initializer_cfg,
195
- **backbone_kwargs,
196
- **factory_kwargs,
197
- )
198
- self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
199
-
200
- # Initialize weights and apply final processing
201
- self.apply(
202
- partial(
203
- _init_weights,
204
- n_layer=n_layer,
205
- **(initializer_cfg if initializer_cfg is not None else {}),
206
- )
207
- )
208
- self.tie_weights()
209
-
210
- def tie_weights(self):
211
- self.lm_head.weight = self.backbone.embedding.weight
212
-
213
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
214
- return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
215
-
216
- def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
217
- """
218
- "position_ids" is just to be compatible with Transformer generation. We don't use it.
219
- num_last_tokens: if > 0, only return the logits for the last n tokens
220
- """
221
- hidden_states = self.backbone(input_ids, inference_params=inference_params)
222
- if num_last_tokens > 0:
223
- hidden_states = hidden_states[:, -num_last_tokens:]
224
- lm_logits = self.lm_head(hidden_states)
225
- CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
226
- return CausalLMOutput(logits=lm_logits)
227
-
228
- @classmethod
229
- def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
230
- config = load_config_hf(pretrained_model_name)
231
- model = cls(**config, device=device, dtype=dtype, **kwargs)
232
- model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
233
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mamba/mamba_ssm/modules/__init__.py DELETED
File without changes