Somunia committed on
Commit
8b19012
1 Parent(s): 6bfd1d4

Upload 28 files

app.py ADDED
@@ -0,0 +1,62 @@
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import time
+
+ def generate_prompt(instruction, input=""):
+     instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
+     input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
+     if input:
+         return f"""Instruction: {instruction}
+
+ Input: {input}
+
+ Response:"""
+     else:
+         return f"""User: hi
+
+ Lover: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
+
+ User: {instruction}
+
+ Lover:"""
+
+ model_path = "models/rwkv-6-world-1b6/"  # Path to your local model directory
+
+ model = AutoModelForCausalLM.from_pretrained(
+     model_path,
+     trust_remote_code=True,
+     use_flash_attention_2=False  # Explicitly disable Flash Attention
+ ).to(torch.float32)
+
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     model_path,
+     bos_token="</s>",
+     eos_token="</s>",
+     unk_token="<unk>",
+     pad_token="<pad>",
+     trust_remote_code=True,
+     padding_side='left',
+     clean_up_tokenization_spaces=False  # Or set to True if you prefer
+ )
+
+ print(tokenizer.special_tokens_map)
+
+ text = "Hi"
+
+ prompt = generate_prompt(text)
+
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+ # Generate text word by word with stop sequence
+ generated_text = ""
+ for i in range(333):  # Generate up to 333 tokens
+     output = model.generate(input_ids, max_new_tokens=1, do_sample=True, temperature=1.0, top_p=0.3, top_k=0)
+     new_word = tokenizer.decode(output[0][-1:], skip_special_tokens=True)
+
+     print(new_word, end="", flush=True)  # Print word-by-word
+     generated_text += new_word
+
+     input_ids = output  # Update input_ids for next iteration
+
+ print()  # Add a newline at the end
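The loop in app.py is commented "with stop sequence" but never actually checks one. Below is a minimal sketch of how that loop could be wrapped with an explicit stop check, assuming the `model` and `tokenizer` objects defined above; the `generate_streaming` helper and the default `stop_sequence` value are illustrative additions, not part of the uploaded files.

```python
import torch

def generate_streaming(model, tokenizer, prompt, max_new_tokens=333,
                       stop_sequence="\n\nUser:", temperature=1.0, top_p=0.3):
    """Sample one token at a time and stop once stop_sequence appears in the output."""
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    generated_text = ""
    for _ in range(max_new_tokens):
        with torch.no_grad():
            output = model.generate(input_ids, max_new_tokens=1, do_sample=True,
                                    temperature=temperature, top_p=top_p, top_k=0)
        new_word = tokenizer.decode(output[0][-1:], skip_special_tokens=True)
        print(new_word, end="", flush=True)
        generated_text += new_word
        if stop_sequence in generated_text:
            # Trim the stop marker and everything after it, then stop generating.
            generated_text = generated_text.split(stop_sequence)[0]
            break
        input_ids = output  # Feed the extended sequence back in for the next step.
    print()
    return generated_text
```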
causal-conv1d/.github/workflows/publish.yaml ADDED
@@ -0,0 +1,209 @@
1
+ # This workflow will:
2
+ # - Create a new Github release
3
+ # - Build wheels for supported architectures
4
+ # - Deploy the wheels to the Github release
5
+ # - Release the static code to PyPi
6
+ # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
7
+
8
+ name: Build wheels and deploy
9
+
10
+ on:
11
+ create:
12
+ tags:
13
+ - v*
14
+
15
+ jobs:
16
+
17
+ setup_release:
18
+ name: Create Release
19
+ runs-on: ubuntu-latest
20
+ steps:
21
+ - name: Get the tag version
22
+ id: extract_branch
23
+ run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
24
+ shell: bash
25
+
26
+ - name: Create Release
27
+ id: create_release
28
+ uses: actions/create-release@v1
29
+ env:
30
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
31
+ with:
32
+ tag_name: ${{ steps.extract_branch.outputs.branch }}
33
+ release_name: ${{ steps.extract_branch.outputs.branch }}
34
+
35
+ build_wheels:
36
+ name: Build Wheel
37
+ needs: setup_release
38
+ runs-on: ${{ matrix.os }}
39
+
40
+ strategy:
41
+ fail-fast: false
42
+ matrix:
43
+ # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
44
+ # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
45
+ os: [ubuntu-20.04]
46
+ python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
47
+ torch-version: ['2.0.1', '2.1.2', '2.2.2', '2.3.1', '2.4.0.dev20240505']
48
+ cuda-version: ['11.8.0', '12.2.2']
49
+ # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
50
+ # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
51
+ # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
52
+ # when building without C++11 ABI and using it on nvcr images.
53
+ cxx11_abi: ['FALSE', 'TRUE']
54
+ exclude:
55
+ # Pytorch < 2.2 does not support Python 3.12
56
+ - torch-version: '2.0.1'
57
+ python-version: '3.12'
58
+ - torch-version: '2.1.2'
59
+ python-version: '3.12'
60
+ # Pytorch <= 2.0 only supports CUDA <= 11.8
61
+ - torch-version: '2.0.1'
62
+ cuda-version: '12.2.2'
63
+
64
+ steps:
65
+ - name: Checkout
66
+ uses: actions/checkout@v3
67
+
68
+ - name: Set up Python
69
+ uses: actions/setup-python@v4
70
+ with:
71
+ python-version: ${{ matrix.python-version }}
72
+
73
+ - name: Set CUDA and PyTorch versions
74
+ run: |
75
+ echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
76
+ echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
77
+
78
+ - name: Free up disk space
79
+ if: ${{ runner.os == 'Linux' }}
80
+ # https://github.com/easimon/maximize-build-space/blob/master/action.yml
81
+ # https://github.com/easimon/maximize-build-space/tree/test-report
82
+ run: |
83
+ sudo rm -rf /usr/share/dotnet
84
+ sudo rm -rf /opt/ghc
85
+ sudo rm -rf /opt/hostedtoolcache/CodeQL
86
+
87
+ - name: Set up swap space
88
+ if: runner.os == 'Linux'
89
+ uses: pierotofy/set-swap-space@v1.0
90
+ with:
91
+ swap-size-gb: 10
92
+
93
+ - name: Install CUDA ${{ matrix.cuda-version }}
94
+ if: ${{ matrix.cuda-version != 'cpu' }}
95
+ uses: Jimver/cuda-toolkit@v0.2.14
96
+ id: cuda-toolkit
97
+ with:
98
+ cuda: ${{ matrix.cuda-version }}
99
+ linux-local-args: '["--toolkit"]'
100
+ # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
101
+ # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
102
+ method: 'network'
103
+ # We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions,
104
+ # not just nvcc
105
+ # sub-packages: '["nvcc"]'
106
+
107
+ - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
108
+ run: |
109
+ pip install --upgrade pip
110
+ # If we don't install before installing Pytorch, we get error for torch 2.0.1
111
+ # ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none)
112
+ pip install lit
113
+ # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools
114
+ pip install setuptools==68.0.0
115
+ # We want to figure out the CUDA version to download pytorch
116
+ # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
117
+ # This code is ugly, maybe there's a better way to do this.
118
+ export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
119
+ minv = {'2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118}[env['MATRIX_TORCH_VERSION']]; \
120
+ maxv = {'2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 121}[env['MATRIX_TORCH_VERSION']]; \
121
+ print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \
122
+ )
123
+ if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
124
+ pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
125
+ else
126
+ pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
127
+ fi
128
+ nvcc --version
129
+ python --version
130
+ python -c "import torch; print('PyTorch:', torch.__version__)"
131
+ python -c "import torch; print('CUDA:', torch.version.cuda)"
132
+ python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
133
+ shell:
134
+ bash
135
+
136
+ - name: Build wheel
137
+ run: |
138
+ # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
139
+ # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
140
+ # However this still fails so I'm using a newer version of setuptools
141
+ pip install setuptools==68.0.0
142
+ pip install ninja packaging wheel
143
+ export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
144
+ export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
145
+ # Limit MAX_JOBS otherwise the github runner goes OOM
146
+ MAX_JOBS=2 CAUSAL_CONV1D_FORCE_BUILD="TRUE" CAUSAL_CONV1D_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
147
+ tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
148
+ wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
149
+ ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
150
+ echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
151
+
152
+ - name: Log Built Wheels
153
+ run: |
154
+ ls dist
155
+
156
+ - name: Get the tag version
157
+ id: extract_branch
158
+ run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
159
+
160
+ - name: Get Release with tag
161
+ id: get_current_release
162
+ uses: joutvhu/get-release@v1
163
+ with:
164
+ tag_name: ${{ steps.extract_branch.outputs.branch }}
165
+ env:
166
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
167
+
168
+ - name: Upload Release Asset
169
+ id: upload_release_asset
170
+ uses: actions/upload-release-asset@v1
171
+ env:
172
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
173
+ with:
174
+ upload_url: ${{ steps.get_current_release.outputs.upload_url }}
175
+ asset_path: ./dist/${{env.wheel_name}}
176
+ asset_name: ${{env.wheel_name}}
177
+ asset_content_type: application/*
178
+
179
+ publish_package:
180
+ name: Publish package
181
+ needs: [build_wheels]
182
+
183
+ runs-on: ubuntu-latest
184
+
185
+ steps:
186
+ - uses: actions/checkout@v3
187
+
188
+ - uses: actions/setup-python@v4
189
+ with:
190
+ python-version: '3.10'
191
+
192
+ - name: Install dependencies
193
+ run: |
194
+ pip install ninja packaging setuptools wheel twine
195
+ # We don't want to download anything CUDA-related here
196
+ pip install torch --index-url https://download.pytorch.org/whl/cpu
197
+
198
+ - name: Build core package
199
+ env:
200
+ CAUSAL_CONV1D_SKIP_CUDA_BUILD: "TRUE"
201
+ run: |
202
+ python setup.py sdist --dist-dir=dist
203
+
204
+ - name: Deploy
205
+ env:
206
+ TWINE_USERNAME: "__token__"
207
+ TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
208
+ run: |
209
+ python -m twine upload dist/*
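The "Install PyTorch" step above clamps the detected system CUDA version into the range of CUDA builds that actually exist for each torch release before choosing a download index. A standalone sketch of that clamping logic, using the same mapping tables as the workflow; the example inputs are illustrative.

```python
# Oldest and newest CUDA builds published for each torch minor version,
# taken from the inline script in the workflow step above.
MIN_CUDA = {'2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118}
MAX_CUDA = {'2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 121}

def torch_cuda_version(torch_minor: str, system_cuda: int) -> int:
    """Clamp the system CUDA version (e.g. 122 for CUDA 12.2) into the supported range."""
    return max(min(system_cuda, MAX_CUDA[torch_minor]), MIN_CUDA[torch_minor])

print(torch_cuda_version('2.1', 122))  # 121 -> wheels come from the cu121 index
print(torch_cuda_version('2.0', 122))  # 118 -> torch 2.0 has no cu122 build
```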
causal-conv1d/.gitignore ADDED
@@ -0,0 +1,6 @@
+ *__pycache__/
+ *.egg-info/
+ build/
+ **.so
+ *.hip
+ *_hip.*
causal-conv1d/AUTHORS ADDED
@@ -0,0 +1 @@
+ Tri Dao, tri@tridao.me
causal-conv1d/LICENSE ADDED
@@ -0,0 +1,29 @@
+ BSD 3-Clause License
+
+ Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived from
+   this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
causal-conv1d/README.md ADDED
@@ -0,0 +1,43 @@
+ # Causal depthwise conv1d in CUDA with a PyTorch interface
+
+ Features:
+ - Support fp32, fp16, bf16.
+ - Kernel size 2, 3, 4.
+
+ ## How to use
+
+ ```
+ from causal_conv1d import causal_conv1d_fn
+ ```
+
+ ```
+ def causal_conv1d_fn(x, weight, bias=None, activation=None):
+     """
+     x: (batch, dim, seqlen)
+     weight: (dim, width)
+     bias: (dim,)
+     activation: either None or "silu" or "swish"
+
+     out: (batch, dim, seqlen)
+     """
+ ```
+
+ Equivalent to:
+ ```
+ import torch.nn.functional as F
+
+ F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
+ ```
+
+ ## Additional Prerequisites for AMD cards
+
+ ### Patching ROCm
+
+ If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
+
+ 1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
+
+ 2. Apply the Patch. Run with `sudo` in case you encounter permission issues.
+ ```bash
+ patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
+ ```
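As a quick check of the equivalence stated in the README above, a small PyTorch-only sketch that computes the same causal depthwise convolution directly with `F.conv1d`; the tensor sizes are arbitrary example values.

```python
import torch
import torch.nn.functional as F

batch, dim, seqlen, width = 2, 16, 32, 4
x = torch.randn(batch, dim, seqlen)
weight = torch.randn(dim, width)
bias = torch.randn(dim)

# Depthwise (groups=dim) convolution, left-padded by width - 1 and truncated back
# to seqlen, so each output position only sees the current and past inputs.
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
print(out.shape)  # torch.Size([2, 16, 32])
```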
causal-conv1d/build/lib/causal_conv1d/__init__.py ADDED
@@ -0,0 +1,3 @@
+ __version__ = "1.4.0"
+
+ from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
causal-conv1d/build/lib/causal_conv1d/causal_conv1d_interface.py ADDED
@@ -0,0 +1,239 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
+ import causal_conv1d_cuda
8
+
9
+
10
+ class CausalConv1dFn(torch.autograd.Function):
11
+ @staticmethod
12
+ def forward(
13
+ ctx,
14
+ x,
15
+ weight,
16
+ bias=None,
17
+ seq_idx=None,
18
+ initial_states=None,
19
+ return_final_states=False,
20
+ final_states_out=None,
21
+ activation=None,
22
+ ):
23
+ if activation not in [None, "silu", "swish"]:
24
+ raise NotImplementedError("activation must be None, silu, or swish")
25
+ if x.stride(2) != 1 and x.stride(1) != 1:
26
+ x = x.contiguous()
27
+ bias = bias.contiguous() if bias is not None else None
28
+ if seq_idx is not None:
29
+ assert (
30
+ initial_states is None
31
+ ), "initial_states must be None if seq_idx is not None"
32
+ assert (
33
+ not return_final_states
34
+ ), "If seq_idx is not None, we don't return final_states_out"
35
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
36
+ if initial_states is not None and (
37
+ initial_states.stride(2) != 1 and initial_states.stride(1) != 1
38
+ ):
39
+ initial_states = initial_states.contiguous()
40
+ if return_final_states:
41
+ assert (
42
+ x.stride(1) == 1
43
+ ), "Only channel-last layout support returning final_states_out"
44
+ if final_states_out is not None:
45
+ assert (
46
+ final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
47
+ )
48
+ else:
49
+ batch, dim, seqlen = x.shape
50
+ width = weight.shape[1]
51
+ final_states_out = torch.empty(
52
+ batch, width - 1, dim, device=x.device, dtype=x.dtype
53
+ ).transpose(1, 2)
54
+ else:
55
+ final_states_out = None
56
+ ctx.activation = activation in ["silu", "swish"]
57
+ out = causal_conv1d_cuda.causal_conv1d_fwd(
58
+ x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
59
+ )
60
+ ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
61
+ ctx.return_final_states = return_final_states
62
+ ctx.return_dinitial_states = (
63
+ initial_states is not None and initial_states.requires_grad
64
+ )
65
+ return out if not return_final_states else (out, final_states_out)
66
+
67
+ @staticmethod
68
+ def backward(ctx, dout, *args):
69
+ x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
70
+ dfinal_states = args[0] if ctx.return_final_states else None
71
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
72
+ dout = dout.contiguous()
73
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
74
+ # backward of conv1d with the backward of chunk).
75
+ # Here we just pass in None and dx will be allocated in the C++ code.
76
+ dx, dweight, dbias, dinitial_states = causal_conv1d_cuda.causal_conv1d_bwd(
77
+ x,
78
+ weight,
79
+ bias,
80
+ dout,
81
+ seq_idx,
82
+ initial_states,
83
+ dfinal_states,
84
+ None,
85
+ ctx.return_dinitial_states,
86
+ ctx.activation,
87
+ )
88
+ return (
89
+ dx,
90
+ dweight,
91
+ dbias if bias is not None else None,
92
+ None,
93
+ dinitial_states if initial_states is not None else None,
94
+ None,
95
+ None,
96
+ None,
97
+ )
98
+
99
+
100
+ def causal_conv1d_fn(
101
+ x,
102
+ weight,
103
+ bias=None,
104
+ seq_idx=None,
105
+ initial_states=None,
106
+ return_final_states=False,
107
+ final_states_out=None,
108
+ activation=None,
109
+ ):
110
+ """
111
+ x: (batch, dim, seqlen)
112
+ weight: (dim, width)
113
+ bias: (dim,)
114
+ seq_idx: (batch, seqlen)
115
+ initial_states: (batch, dim, width - 1)
116
+ final_states_out: (batch, dim, width - 1), to be written to
117
+ activation: either None or "silu" or "swish"
118
+
119
+ out: (batch, dim, seqlen)
120
+ """
121
+ return CausalConv1dFn.apply(
122
+ x,
123
+ weight,
124
+ bias,
125
+ seq_idx,
126
+ initial_states,
127
+ return_final_states,
128
+ final_states_out,
129
+ activation,
130
+ )
131
+
132
+
133
+ def causal_conv1d_ref(
134
+ x,
135
+ weight,
136
+ bias=None,
137
+ initial_states=None,
138
+ return_final_states=False,
139
+ final_states_out=None,
140
+ activation=None,
141
+ ):
142
+ """
143
+ x: (batch, dim, seqlen)
144
+ weight: (dim, width)
145
+ bias: (dim,)
146
+ initial_states: (batch, dim, width - 1)
147
+ final_states_out: (batch, dim, width - 1)
148
+
149
+ out: (batch, dim, seqlen)
150
+ """
151
+ if activation not in [None, "silu", "swish"]:
152
+ raise NotImplementedError("activation must be None, silu, or swish")
153
+ dtype_in = x.dtype
154
+ x = x.to(weight.dtype)
155
+ seqlen = x.shape[-1]
156
+ dim, width = weight.shape
157
+ if initial_states is None:
158
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
159
+ else:
160
+ x = torch.cat([initial_states, x], dim=-1)
161
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
162
+ out = out[..., :seqlen]
163
+ if return_final_states:
164
+ final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
165
+ dtype_in
166
+ ) # (batch, dim, width - 1)
167
+ if final_states_out is not None:
168
+ final_states_out.copy_(final_states)
169
+ else:
170
+ final_states_out = final_states
171
+ out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
172
+ return out if not return_final_states else (out, final_states_out)
173
+
174
+
175
+ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
176
+ """
177
+ x: (batch, dim) or (batch, dim, seqlen)
178
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
179
+ weight: (dim, width)
180
+ bias: (dim,)
181
+ cache_seqlens: (batch,), dtype int32.
182
+ If not None, the conv_state is treated as a circular buffer.
183
+ The conv_state will be updated by copying x to the conv_state starting at the index
184
+ @cache_seqlens % state_len.
185
+
186
+ out: (batch, dim) or (batch, dim, seqlen)
187
+ """
188
+ if activation not in [None, "silu", "swish"]:
189
+ raise NotImplementedError("activation must be None, silu, or swish")
190
+ activation = activation in ["silu", "swish"]
191
+ unsqueeze = x.dim() == 2
192
+ if unsqueeze:
193
+ x = x.unsqueeze(-1)
194
+ out = causal_conv1d_cuda.causal_conv1d_update(
195
+ x, conv_state, weight, bias, activation, cache_seqlens
196
+ )
197
+ if unsqueeze:
198
+ out = out.squeeze(-1)
199
+ return out
200
+
201
+
202
+ def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
203
+ """
204
+ x: (batch, dim) or (batch, dim, seqlen)
205
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
206
+ weight: (dim, width)
207
+ bias: (dim,)
208
+ cache_seqlens: (batch,), dtype int32.
209
+ If not None, the conv_state is treated as a circular buffer.
210
+ The conv_state will be updated by copying x to the conv_state starting at the index
211
+ @cache_seqlens % state_len before performing the convolution.
212
+
213
+ out: (batch, dim) or (batch, dim, seqlen)
214
+ """
215
+ if activation not in [None, "silu", "swish"]:
216
+ raise NotImplementedError("activation must be None, silu, or swish")
217
+ dtype_in = x.dtype
218
+ unsqueeze = x.dim() == 2
219
+ if unsqueeze:
220
+ x = x.unsqueeze(-1)
221
+ batch, dim, seqlen = x.shape
222
+ width = weight.shape[1]
223
+ state_len = conv_state.shape[-1]
224
+ assert conv_state.shape == (batch, dim, state_len)
225
+ assert weight.shape == (dim, width)
226
+ if cache_seqlens is None:
227
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
228
+ conv_state.copy_(x_new[:, :, -state_len:])
229
+ else:
230
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
231
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
232
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
233
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
234
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
235
+ conv_state.scatter_(2, copy_idx, x)
236
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
237
+ if unsqueeze:
238
+ out = out.squeeze(-1)
239
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
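The docstrings above describe how `causal_conv1d_update` treats `conv_state` as a rolling (optionally circular) cache during step-by-step decoding. A short sketch of one decoding step using the pure-PyTorch reference `causal_conv1d_update_ref` defined in this file; the shapes are illustrative, and importing the module assumes the package (including its compiled `causal_conv1d_cuda` extension) is installed.

```python
import torch
from causal_conv1d.causal_conv1d_interface import causal_conv1d_update_ref

batch, dim, width = 2, 16, 4
state_len = width - 1

conv_state = torch.zeros(batch, dim, state_len)  # cache of the most recent inputs
weight = torch.randn(dim, width)
bias = torch.randn(dim)

# One decoding step: a single new timestep per sequence, shape (batch, dim).
x_step = torch.randn(batch, dim)
out = causal_conv1d_update_ref(x_step, conv_state, weight, bias, activation="silu")
print(out.shape)                             # torch.Size([2, 16])
print(conv_state[0, 0, -1] == x_step[0, 0])  # cache now ends with the newest input
```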
causal-conv1d/build/lib/causal_conv1d/causal_conv1d_varlen.py ADDED
@@ -0,0 +1,86 @@
1
+ import torch
2
+ from torch import Tensor
3
+
4
+ import triton
5
+ import triton.language as tl
6
+
7
+
8
+ @triton.jit
9
+ def _causal_conv1d_varlen_states(
10
+ X,
11
+ CU_SEQLENS,
12
+ STATES,
13
+ state_len,
14
+ dim,
15
+ stride_x_seqlen, stride_x_dim,
16
+ stride_states_batch, stride_states_seqlen, stride_states_dim,
17
+ BLOCK_M: tl.constexpr,
18
+ BLOCK_N: tl.constexpr
19
+ ):
20
+ batch_idx = tl.program_id(2)
21
+ STATES += batch_idx * stride_states_batch
22
+ end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
23
+ start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
24
+ rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
25
+ cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
26
+ x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
27
+ mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
28
+ other=0)
29
+ rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
30
+ tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
31
+ x,
32
+ mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
33
+
34
+
35
+ def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
36
+ """
37
+ Forward pass only, does not support backward pass.
38
+ Parameters:
39
+ x: (total_tokens, dim)
40
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
41
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
42
+ If some of those elements belong to a different sequence, the value of the states will be zero.
43
+ Return:
44
+ states: (batch, dim, state_len)
45
+ """
46
+ _, dim = x.shape
47
+ batch = cu_seqlens.shape[0] - 1
48
+ cu_seqlens = cu_seqlens.contiguous()
49
+ states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
50
+ BLOCK_M = min(triton.next_power_of_2(state_len), 16)
51
+ BLOCK_N = min(triton.next_power_of_2(dim), 256)
52
+ grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
53
+ with torch.cuda.device(x.device.index):
54
+ _causal_conv1d_varlen_states[grid](
55
+ x,
56
+ cu_seqlens,
57
+ states,
58
+ state_len,
59
+ dim,
60
+ x.stride(0), x.stride(1),
61
+ states.stride(0), states.stride(2), states.stride(1),
62
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
63
+ )
64
+ return states
65
+
66
+
67
+ def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
68
+ """
69
+ Forward pass only, does not support backward pass.
70
+ Parameters:
71
+ x: (total_tokens, dim)
72
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
73
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
74
+ If some of those elements belong to a different sequence, the value of the states will be zero.
75
+ Return:
76
+ states: (batch, dim, state_len)
77
+ """
78
+ _, dim = x.shape
79
+ batch = cu_seqlens.shape[0] - 1
80
+ cu_seqlens = cu_seqlens.contiguous()
81
+ states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
82
+ for i in range(batch):
83
+ end_idx = cu_seqlens[i + 1]
84
+ start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
85
+ states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
86
+ return states
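To illustrate the `cu_seqlens` convention documented above, a short sketch that packs three variable-length sequences into one tensor and extracts their per-sequence conv states with the reference `causal_conv1d_varlen_states_ref`; the sizes are illustrative, and importing this module assumes `triton` is available.

```python
import torch
from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states_ref

dim, state_len = 8, 3
seqlens = torch.tensor([5, 2, 7])                        # three packed sequences
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long),
                        torch.cumsum(seqlens, dim=0)])   # tensor([0, 5, 7, 14])
x = torch.randn(int(cu_seqlens[-1]), dim)                # (total_tokens, dim)

states = causal_conv1d_varlen_states_ref(x, cu_seqlens, state_len)
print(states.shape)  # torch.Size([3, 8, 3])
# The second sequence has only 2 tokens, so its oldest state slot stays zero.
print(states[1, :, 0].abs().max())  # tensor(0.)
```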
causal-conv1d/causal_conv1d.egg-info/PKG-INFO ADDED
@@ -0,0 +1,62 @@
+ Metadata-Version: 2.1
+ Name: causal-conv1d
+ Version: 1.4.0
+ Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
+ Home-page: https://github.com/Dao-AILab/causal-conv1d
+ Author: Tri Dao
+ Author-email: tri@tridao.me
+ License: UNKNOWN
+ Platform: UNKNOWN
+ Classifier: Programming Language :: Python :: 3
+ Classifier: License :: OSI Approved :: BSD License
+ Classifier: Operating System :: Unix
+ Requires-Python: >=3.8
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ License-File: AUTHORS
+
+ # Causal depthwise conv1d in CUDA with a PyTorch interface
+
+ Features:
+ - Support fp32, fp16, bf16.
+ - Kernel size 2, 3, 4.
+
+ ## How to use
+
+ ```
+ from causal_conv1d import causal_conv1d_fn
+ ```
+
+ ```
+ def causal_conv1d_fn(x, weight, bias=None, activation=None):
+     """
+     x: (batch, dim, seqlen)
+     weight: (dim, width)
+     bias: (dim,)
+     activation: either None or "silu" or "swish"
+
+     out: (batch, dim, seqlen)
+     """
+ ```
+
+ Equivalent to:
+ ```
+ import torch.nn.functional as F
+
+ F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
+ ```
+
+ ## Additional Prerequisites for AMD cards
+
+ ### Patching ROCm
+
+ If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
+
+ 1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
+
+ 2. Apply the Patch. Run with `sudo` in case you encounter permission issues.
+ ```bash
+ patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
+ ```
+
+
causal-conv1d/causal_conv1d.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,12 @@
+ AUTHORS
+ LICENSE
+ README.md
+ setup.py
+ causal_conv1d/__init__.py
+ causal_conv1d/causal_conv1d_interface.py
+ causal_conv1d/causal_conv1d_varlen.py
+ causal_conv1d.egg-info/PKG-INFO
+ causal_conv1d.egg-info/SOURCES.txt
+ causal_conv1d.egg-info/dependency_links.txt
+ causal_conv1d.egg-info/requires.txt
+ causal_conv1d.egg-info/top_level.txt
causal-conv1d/causal_conv1d.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
causal-conv1d/causal_conv1d.egg-info/requires.txt ADDED
@@ -0,0 +1,3 @@
+ torch
+ packaging
+ ninja
causal-conv1d/causal_conv1d.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ causal_conv1d
causal-conv1d/causal_conv1d/__init__.py ADDED
@@ -0,0 +1,3 @@
+ __version__ = "1.4.0"
+
+ from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
causal-conv1d/causal_conv1d/causal_conv1d_interface.py ADDED
@@ -0,0 +1,239 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
+ import causal_conv1d_cuda
8
+
9
+
10
+ class CausalConv1dFn(torch.autograd.Function):
11
+ @staticmethod
12
+ def forward(
13
+ ctx,
14
+ x,
15
+ weight,
16
+ bias=None,
17
+ seq_idx=None,
18
+ initial_states=None,
19
+ return_final_states=False,
20
+ final_states_out=None,
21
+ activation=None,
22
+ ):
23
+ if activation not in [None, "silu", "swish"]:
24
+ raise NotImplementedError("activation must be None, silu, or swish")
25
+ if x.stride(2) != 1 and x.stride(1) != 1:
26
+ x = x.contiguous()
27
+ bias = bias.contiguous() if bias is not None else None
28
+ if seq_idx is not None:
29
+ assert (
30
+ initial_states is None
31
+ ), "initial_states must be None if seq_idx is not None"
32
+ assert (
33
+ not return_final_states
34
+ ), "If seq_idx is not None, we don't return final_states_out"
35
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
36
+ if initial_states is not None and (
37
+ initial_states.stride(2) != 1 and initial_states.stride(1) != 1
38
+ ):
39
+ initial_states = initial_states.contiguous()
40
+ if return_final_states:
41
+ assert (
42
+ x.stride(1) == 1
43
+ ), "Only channel-last layout support returning final_states_out"
44
+ if final_states_out is not None:
45
+ assert (
46
+ final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
47
+ )
48
+ else:
49
+ batch, dim, seqlen = x.shape
50
+ width = weight.shape[1]
51
+ final_states_out = torch.empty(
52
+ batch, width - 1, dim, device=x.device, dtype=x.dtype
53
+ ).transpose(1, 2)
54
+ else:
55
+ final_states_out = None
56
+ ctx.activation = activation in ["silu", "swish"]
57
+ out = causal_conv1d_cuda.causal_conv1d_fwd(
58
+ x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
59
+ )
60
+ ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
61
+ ctx.return_final_states = return_final_states
62
+ ctx.return_dinitial_states = (
63
+ initial_states is not None and initial_states.requires_grad
64
+ )
65
+ return out if not return_final_states else (out, final_states_out)
66
+
67
+ @staticmethod
68
+ def backward(ctx, dout, *args):
69
+ x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
70
+ dfinal_states = args[0] if ctx.return_final_states else None
71
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
72
+ dout = dout.contiguous()
73
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
74
+ # backward of conv1d with the backward of chunk).
75
+ # Here we just pass in None and dx will be allocated in the C++ code.
76
+ dx, dweight, dbias, dinitial_states = causal_conv1d_cuda.causal_conv1d_bwd(
77
+ x,
78
+ weight,
79
+ bias,
80
+ dout,
81
+ seq_idx,
82
+ initial_states,
83
+ dfinal_states,
84
+ None,
85
+ ctx.return_dinitial_states,
86
+ ctx.activation,
87
+ )
88
+ return (
89
+ dx,
90
+ dweight,
91
+ dbias if bias is not None else None,
92
+ None,
93
+ dinitial_states if initial_states is not None else None,
94
+ None,
95
+ None,
96
+ None,
97
+ )
98
+
99
+
100
+ def causal_conv1d_fn(
101
+ x,
102
+ weight,
103
+ bias=None,
104
+ seq_idx=None,
105
+ initial_states=None,
106
+ return_final_states=False,
107
+ final_states_out=None,
108
+ activation=None,
109
+ ):
110
+ """
111
+ x: (batch, dim, seqlen)
112
+ weight: (dim, width)
113
+ bias: (dim,)
114
+ seq_idx: (batch, seqlen)
115
+ initial_states: (batch, dim, width - 1)
116
+ final_states_out: (batch, dim, width - 1), to be written to
117
+ activation: either None or "silu" or "swish"
118
+
119
+ out: (batch, dim, seqlen)
120
+ """
121
+ return CausalConv1dFn.apply(
122
+ x,
123
+ weight,
124
+ bias,
125
+ seq_idx,
126
+ initial_states,
127
+ return_final_states,
128
+ final_states_out,
129
+ activation,
130
+ )
131
+
132
+
133
+ def causal_conv1d_ref(
134
+ x,
135
+ weight,
136
+ bias=None,
137
+ initial_states=None,
138
+ return_final_states=False,
139
+ final_states_out=None,
140
+ activation=None,
141
+ ):
142
+ """
143
+ x: (batch, dim, seqlen)
144
+ weight: (dim, width)
145
+ bias: (dim,)
146
+ initial_states: (batch, dim, width - 1)
147
+ final_states_out: (batch, dim, width - 1)
148
+
149
+ out: (batch, dim, seqlen)
150
+ """
151
+ if activation not in [None, "silu", "swish"]:
152
+ raise NotImplementedError("activation must be None, silu, or swish")
153
+ dtype_in = x.dtype
154
+ x = x.to(weight.dtype)
155
+ seqlen = x.shape[-1]
156
+ dim, width = weight.shape
157
+ if initial_states is None:
158
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
159
+ else:
160
+ x = torch.cat([initial_states, x], dim=-1)
161
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
162
+ out = out[..., :seqlen]
163
+ if return_final_states:
164
+ final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
165
+ dtype_in
166
+ ) # (batch, dim, width - 1)
167
+ if final_states_out is not None:
168
+ final_states_out.copy_(final_states)
169
+ else:
170
+ final_states_out = final_states
171
+ out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
172
+ return out if not return_final_states else (out, final_states_out)
173
+
174
+
175
+ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
176
+ """
177
+ x: (batch, dim) or (batch, dim, seqlen)
178
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
179
+ weight: (dim, width)
180
+ bias: (dim,)
181
+ cache_seqlens: (batch,), dtype int32.
182
+ If not None, the conv_state is treated as a circular buffer.
183
+ The conv_state will be updated by copying x to the conv_state starting at the index
184
+ @cache_seqlens % state_len.
185
+
186
+ out: (batch, dim) or (batch, dim, seqlen)
187
+ """
188
+ if activation not in [None, "silu", "swish"]:
189
+ raise NotImplementedError("activation must be None, silu, or swish")
190
+ activation = activation in ["silu", "swish"]
191
+ unsqueeze = x.dim() == 2
192
+ if unsqueeze:
193
+ x = x.unsqueeze(-1)
194
+ out = causal_conv1d_cuda.causal_conv1d_update(
195
+ x, conv_state, weight, bias, activation, cache_seqlens
196
+ )
197
+ if unsqueeze:
198
+ out = out.squeeze(-1)
199
+ return out
200
+
201
+
202
+ def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
203
+ """
204
+ x: (batch, dim) or (batch, dim, seqlen)
205
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
206
+ weight: (dim, width)
207
+ bias: (dim,)
208
+ cache_seqlens: (batch,), dtype int32.
209
+ If not None, the conv_state is treated as a circular buffer.
210
+ The conv_state will be updated by copying x to the conv_state starting at the index
211
+ @cache_seqlens % state_len before performing the convolution.
212
+
213
+ out: (batch, dim) or (batch, dim, seqlen)
214
+ """
215
+ if activation not in [None, "silu", "swish"]:
216
+ raise NotImplementedError("activation must be None, silu, or swish")
217
+ dtype_in = x.dtype
218
+ unsqueeze = x.dim() == 2
219
+ if unsqueeze:
220
+ x = x.unsqueeze(-1)
221
+ batch, dim, seqlen = x.shape
222
+ width = weight.shape[1]
223
+ state_len = conv_state.shape[-1]
224
+ assert conv_state.shape == (batch, dim, state_len)
225
+ assert weight.shape == (dim, width)
226
+ if cache_seqlens is None:
227
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
228
+ conv_state.copy_(x_new[:, :, -state_len:])
229
+ else:
230
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
231
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
232
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
233
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
234
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
235
+ conv_state.scatter_(2, copy_idx, x)
236
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
237
+ if unsqueeze:
238
+ out = out.squeeze(-1)
239
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
causal-conv1d/causal_conv1d/causal_conv1d_varlen.py ADDED
@@ -0,0 +1,86 @@
1
+ import torch
2
+ from torch import Tensor
3
+
4
+ import triton
5
+ import triton.language as tl
6
+
7
+
8
+ @triton.jit
9
+ def _causal_conv1d_varlen_states(
10
+ X,
11
+ CU_SEQLENS,
12
+ STATES,
13
+ state_len,
14
+ dim,
15
+ stride_x_seqlen, stride_x_dim,
16
+ stride_states_batch, stride_states_seqlen, stride_states_dim,
17
+ BLOCK_M: tl.constexpr,
18
+ BLOCK_N: tl.constexpr
19
+ ):
20
+ batch_idx = tl.program_id(2)
21
+ STATES += batch_idx * stride_states_batch
22
+ end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
23
+ start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
24
+ rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
25
+ cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
26
+ x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
27
+ mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
28
+ other=0)
29
+ rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
30
+ tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
31
+ x,
32
+ mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
33
+
34
+
35
+ def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
36
+ """
37
+ Forward pass only, does not support backward pass.
38
+ Parameters:
39
+ x: (total_tokens, dim)
40
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
41
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
42
+ If some of those elements belong to a different sequence, the value of the states will be zero.
43
+ Return:
44
+ states: (batch, dim, state_len)
45
+ """
46
+ _, dim = x.shape
47
+ batch = cu_seqlens.shape[0] - 1
48
+ cu_seqlens = cu_seqlens.contiguous()
49
+ states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
50
+ BLOCK_M = min(triton.next_power_of_2(state_len), 16)
51
+ BLOCK_N = min(triton.next_power_of_2(dim), 256)
52
+ grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
53
+ with torch.cuda.device(x.device.index):
54
+ _causal_conv1d_varlen_states[grid](
55
+ x,
56
+ cu_seqlens,
57
+ states,
58
+ state_len,
59
+ dim,
60
+ x.stride(0), x.stride(1),
61
+ states.stride(0), states.stride(2), states.stride(1),
62
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
63
+ )
64
+ return states
65
+
66
+
67
+ def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
68
+ """
69
+ Forward pass only, does not support backward pass.
70
+ Parameters:
71
+ x: (total_tokens, dim)
72
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
73
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
74
+ If some of those elements belong to a different sequence, the value of the states will be zero.
75
+ Return:
76
+ states: (batch, dim, state_len)
77
+ """
78
+ _, dim = x.shape
79
+ batch = cu_seqlens.shape[0] - 1
80
+ cu_seqlens = cu_seqlens.contiguous()
81
+ states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
82
+ for i in range(batch):
83
+ end_idx = cu_seqlens[i + 1]
84
+ start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
85
+ states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
86
+ return states
causal-conv1d/csrc/causal_conv1d.cpp ADDED
@@ -0,0 +1,464 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include <ATen/cuda/CUDAContext.h>
6
+ #include <c10/cuda/CUDAGuard.h>
7
+ #include <torch/extension.h>
8
+ #include <vector>
9
+
10
+ #include "causal_conv1d.h"
11
+
12
+ #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
13
+
14
+ #define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
15
+ if (ITYPE == at::ScalarType::Half) { \
16
+ using input_t = at::Half; \
17
+ __VA_ARGS__(); \
18
+ } else if (ITYPE == at::ScalarType::BFloat16) { \
19
+ using input_t = at::BFloat16; \
20
+ __VA_ARGS__(); \
21
+ } else if (ITYPE == at::ScalarType::Float) { \
22
+ using input_t = float; \
23
+ __VA_ARGS__(); \
24
+ } else { \
25
+ AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
26
+ }
27
+
28
+ #define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
29
+ if (WTYPE == at::ScalarType::Half) { \
30
+ using weight_t = at::Half; \
31
+ __VA_ARGS__(); \
32
+ } else if (WTYPE == at::ScalarType::BFloat16) { \
33
+ using weight_t = at::BFloat16; \
34
+ __VA_ARGS__(); \
35
+ } else if (WTYPE == at::ScalarType::Float) { \
36
+ using weight_t = float; \
37
+ __VA_ARGS__(); \
38
+ } else { \
39
+ AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
40
+ }
41
+
42
+ template<typename input_t, typename weight_t>
43
+ void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
44
+ template <typename input_t, typename weight_t>
45
+ void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
46
+
47
+ template<typename input_t, typename weight_t>
48
+ void causal_conv1d_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream);
49
+ template<typename input_t, typename weight_t>
50
+ void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream);
51
+
52
+ template<typename input_t, typename weight_t>
53
+ void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream);
54
+
55
+ void set_conv_params_fwd(ConvParamsBase &params,
56
+ // sizes
57
+ const size_t batch,
58
+ const size_t dim,
59
+ const size_t seqlen,
60
+ const size_t width,
61
+ // device pointers
62
+ const at::Tensor x,
63
+ const at::Tensor weight,
64
+ const at::Tensor out,
65
+ void* bias_ptr,
66
+ bool silu_activation) {
67
+
68
+ // Reset the parameters
69
+ memset(&params, 0, sizeof(params));
70
+
71
+ params.batch = batch;
72
+ params.dim = dim;
73
+ params.seqlen = seqlen;
74
+ params.width = width;
75
+
76
+ params.silu_activation = silu_activation;
77
+
78
+ // Set the pointers and strides.
79
+ params.x_ptr = x.data_ptr();
80
+ params.weight_ptr = weight.data_ptr();
81
+ params.bias_ptr = bias_ptr;
82
+ params.out_ptr = out.data_ptr();
83
+ // All stride are in elements, not bytes.
84
+ params.x_batch_stride = x.stride(0);
85
+ params.x_c_stride = x.stride(1);
86
+ params.x_l_stride = x.stride(-1);
87
+ params.weight_c_stride = weight.stride(0);
88
+ params.weight_width_stride = weight.stride(1);
89
+ params.out_batch_stride = out.stride(0);
90
+ params.out_c_stride = out.stride(1);
91
+ params.out_l_stride = out.stride(-1);
92
+ }
93
+
94
+
95
+ void set_conv_params_bwd(ConvParamsBwd &params,
96
+ // sizes
97
+ const size_t batch,
98
+ const size_t dim,
99
+ const size_t seqlen,
100
+ const size_t width,
101
+ // device pointers
102
+ const at::Tensor x,
103
+ const at::Tensor weight,
104
+ void* bias_ptr,
105
+ const at::Tensor dout,
106
+ const at::Tensor dx,
107
+ const at::Tensor dweight,
108
+ void* dbias_ptr,
109
+ bool silu_activation) {
110
+ // Pass in "dout" instead of "out", we're not gonna use "out" at all.
111
+ set_conv_params_fwd(params, batch, dim, seqlen, width,
112
+ x, weight, dout, bias_ptr, silu_activation);
113
+
114
+ // Set the pointers and strides.
115
+ params.dout_ptr = dout.data_ptr();
116
+ params.dx_ptr = dx.data_ptr();
117
+ params.dweight_ptr = dweight.data_ptr();
118
+ params.dbias_ptr = dbias_ptr;
119
+ // All stride are in elements, not bytes.
120
+ params.dout_batch_stride = dout.stride(0);
121
+ params.dout_c_stride = dout.stride(1);
122
+ params.dout_l_stride = dout.stride(2);
123
+ params.dweight_c_stride = dweight.stride(0);
124
+ params.dweight_width_stride = dweight.stride(1);
125
+ params.dx_batch_stride = dx.stride(0);
126
+ params.dx_c_stride = dx.stride(1);
127
+ params.dx_l_stride = dx.stride(2);
128
+ }
129
+
130
+ at::Tensor
131
+ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
132
+ const c10::optional<at::Tensor> &bias_,
133
+ const c10::optional<at::Tensor> &seq_idx_,
134
+ const c10::optional<at::Tensor> &initial_states_,
135
+ c10::optional<at::Tensor> &final_states_out_,
136
+ bool silu_activation) {
137
+ auto input_type = x.scalar_type();
138
+ auto weight_type = weight.scalar_type();
139
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
140
+ TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
141
+
142
+ TORCH_CHECK(x.is_cuda());
143
+ TORCH_CHECK(weight.is_cuda());
144
+
145
+ const auto sizes = x.sizes();
146
+ const int batch_size = sizes[0];
147
+ const int dim = sizes[1];
148
+ const int seqlen = sizes[2];
149
+ const int width = weight.size(-1);
150
+
151
+ CHECK_SHAPE(x, batch_size, dim, seqlen);
152
+ CHECK_SHAPE(weight, dim, width);
153
+
154
+ TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
155
+ const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
156
+
157
+ if (is_channel_last) {
158
+ TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
159
+ TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
160
+ }
161
+ TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
162
+
163
+ if (bias_.has_value()) {
164
+ auto bias = bias_.value();
165
+ TORCH_CHECK(bias.scalar_type() == weight_type);
166
+ TORCH_CHECK(bias.is_cuda());
167
+ TORCH_CHECK(bias.stride(-1) == 1);
168
+ CHECK_SHAPE(bias, dim);
169
+ }
170
+
171
+ if (seq_idx_.has_value()) {
172
+ TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout");
173
+ auto seq_idx = seq_idx_.value();
174
+ TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
175
+ TORCH_CHECK(seq_idx.is_cuda());
176
+ TORCH_CHECK(seq_idx.is_contiguous());
177
+ CHECK_SHAPE(seq_idx, batch_size, seqlen);
178
+ }
179
+
180
+ at::Tensor out = torch::empty_like(x);
181
+
182
+ ConvParamsBase params;
183
+ set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
184
+ bias_.has_value() ? bias_.value().data_ptr() : nullptr,
185
+ silu_activation);
186
+
187
+ if (seq_idx_.has_value()) {
188
+ params.seq_idx_ptr = seq_idx_.value().data_ptr();
189
+ } else {
190
+ params.seq_idx_ptr = nullptr;
191
+ }
192
+
193
+ if (initial_states_.has_value()) {
194
+ TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
195
+ auto initial_states = initial_states_.value();
196
+ TORCH_CHECK(initial_states.scalar_type() == input_type);
197
+ TORCH_CHECK(initial_states.is_cuda());
198
+ CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
199
+ TORCH_CHECK(initial_states.stride(1) == 1);
200
+ params.initial_states_ptr = initial_states.data_ptr();
201
+ params.initial_states_batch_stride = initial_states.stride(0);
202
+ params.initial_states_c_stride = initial_states.stride(1);
203
+ params.initial_states_l_stride = initial_states.stride(2);
204
+ } else {
205
+ params.initial_states_ptr = nullptr;
206
+ }
207
+
208
+ if (final_states_out_.has_value()) {
209
+ TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout");
210
+ auto final_states = final_states_out_.value();
211
+ TORCH_CHECK(final_states.scalar_type() == input_type);
212
+ TORCH_CHECK(final_states.is_cuda());
213
+ CHECK_SHAPE(final_states, batch_size, dim, width - 1);
214
+ TORCH_CHECK(final_states.stride(1) == 1);
215
+ params.final_states_ptr = final_states.data_ptr();
216
+ params.final_states_batch_stride = final_states.stride(0);
217
+ params.final_states_c_stride = final_states.stride(1);
218
+ params.final_states_l_stride = final_states.stride(2);
219
+ } else {
220
+ params.final_states_ptr = nullptr;
221
+ }
222
+
223
+ // Otherwise the kernel will be launched from cuda:0 device
224
+ // Cast to char to avoid compiler warning about narrowing
225
+ at::cuda::CUDAGuard device_guard{(char)x.get_device()};
226
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
227
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
228
+ DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] {
229
+ if (!is_channel_last) {
230
+ causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
231
+ } else {
232
+ causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
233
+ }
234
+ });
235
+ });
236
+ return out;
237
+ }
238
+
239
+ std::vector<at::Tensor>
240
+ causal_conv1d_bwd(const at::Tensor &x, const at::Tensor &weight,
241
+ const c10::optional<at::Tensor> &bias_,
242
+ at::Tensor &dout,
243
+ const c10::optional<at::Tensor> &seq_idx_,
244
+ const c10::optional<at::Tensor> &initial_states_,
245
+ const c10::optional<at::Tensor> &dfinal_states_,
246
+ c10::optional<at::Tensor> &dx_,
247
+ bool return_dinitial_states,
248
+ bool silu_activation) {
249
+ auto input_type = x.scalar_type();
250
+ auto weight_type = weight.scalar_type();
251
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
252
+ TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
253
+
254
+ TORCH_CHECK(x.is_cuda());
255
+ TORCH_CHECK(weight.is_cuda());
256
+ TORCH_CHECK(dout.is_cuda());
257
+
258
+ const auto sizes = x.sizes();
259
+ const int batch_size = sizes[0];
260
+ const int dim = sizes[1];
261
+ const int seqlen = sizes[2];
262
+ const int width = weight.size(-1);
263
+
264
+ TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
265
+
266
+ CHECK_SHAPE(x, batch_size, dim, seqlen);
267
+ CHECK_SHAPE(weight, dim, width);
268
+ CHECK_SHAPE(dout, batch_size, dim, seqlen);
269
+
270
+ TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
271
+ const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
272
+ if (!is_channel_last && dout.stride(2) != 1) { dout = dout.contiguous(); }
273
+ if (is_channel_last && dout.stride(1) != 1) { dout = dout.transpose(-1, -2).contiguous().transpose(-1, -2); }
274
+
275
+ if (is_channel_last) {
276
+ TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
277
+ TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
278
+ TORCH_CHECK(dout.stride(2) % 8 == 0 and dout.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (dout.stride(0) and dout.stride(2)) to be multiples of 8");
279
+ }
280
+
281
+ if (bias_.has_value()) {
282
+ auto bias = bias_.value();
283
+ TORCH_CHECK(bias.scalar_type() == weight_type);
284
+ TORCH_CHECK(bias.is_cuda());
285
+ TORCH_CHECK(bias.stride(-1) == 1);
286
+ CHECK_SHAPE(bias, dim);
287
+ }
288
+
289
+ if (seq_idx_.has_value()) {
290
+ TORCH_CHECK(is_channel_last, "seq_idx only supported for channel last layout");
291
+ auto seq_idx = seq_idx_.value();
292
+ TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
293
+ TORCH_CHECK(seq_idx.is_cuda());
294
+ TORCH_CHECK(seq_idx.is_contiguous());
295
+ CHECK_SHAPE(seq_idx, batch_size, seqlen);
296
+ }
297
+
298
+ at::Tensor dx;
299
+ if (dx_.has_value()) {
300
+ dx = dx_.value();
301
+ TORCH_CHECK(dx.scalar_type() == input_type);
302
+ TORCH_CHECK(dx.is_cuda());
303
+ CHECK_SHAPE(dx, batch_size, dim, seqlen);
304
+ if (!is_channel_last) { TORCH_CHECK(dx.stride(2) == 1); }
305
+ if (is_channel_last) { TORCH_CHECK(dx.stride(1) == 1); }
306
+ } else {
307
+ dx = torch::empty_like(x);
308
+ }
309
+
310
+ // Otherwise the kernel will be launched from cuda:0 device
311
+ // Cast to char to avoid compiler warning about narrowing
312
+ at::cuda::CUDAGuard device_guard{(char)x.get_device()};
313
+
314
+ at::Tensor dweight = torch::zeros_like(weight, weight.options().dtype(at::kFloat));
315
+ at::Tensor dbias;
316
+ if (bias_.has_value()) { dbias = torch::zeros_like(bias_.value(), bias_.value().options().dtype(at::kFloat)); }
317
+
318
+ ConvParamsBwd params;
319
+ set_conv_params_bwd(params, batch_size, dim, seqlen, width,
320
+ x, weight, bias_.has_value() ? bias_.value().data_ptr() : nullptr,
321
+ dout, dx, dweight, bias_.has_value() ? dbias.data_ptr() : nullptr,
322
+ silu_activation);
323
+
324
+ if (seq_idx_.has_value()) {
325
+ params.seq_idx_ptr = seq_idx_.value().data_ptr();
326
+ } else {
327
+ params.seq_idx_ptr = nullptr;
328
+ }
329
+
330
+ if (initial_states_.has_value()) {
331
+ TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
332
+ auto initial_states = initial_states_.value();
333
+ TORCH_CHECK(initial_states.scalar_type() == input_type);
334
+ TORCH_CHECK(initial_states.is_cuda());
335
+ CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
336
+ TORCH_CHECK(initial_states.stride(1) == 1);
337
+ params.initial_states_ptr = initial_states.data_ptr();
338
+ params.initial_states_batch_stride = initial_states.stride(0);
339
+ params.initial_states_c_stride = initial_states.stride(1);
340
+ params.initial_states_l_stride = initial_states.stride(2);
341
+ } else {
342
+ params.initial_states_ptr = nullptr;
343
+ }
344
+
345
+ if (dfinal_states_.has_value()) {
346
+ TORCH_CHECK(is_channel_last, "dfinal_states is only supported for channel last layout");
347
+ auto dfinal_states = dfinal_states_.value();
348
+ TORCH_CHECK(dfinal_states.scalar_type() == input_type);
349
+ TORCH_CHECK(dfinal_states.is_cuda());
350
+ CHECK_SHAPE(dfinal_states, batch_size, dim, width - 1);
351
+ params.dfinal_states_ptr = dfinal_states.data_ptr();
352
+ params.dfinal_states_batch_stride = dfinal_states.stride(0);
353
+ params.dfinal_states_c_stride = dfinal_states.stride(1);
354
+ params.dfinal_states_l_stride = dfinal_states.stride(2);
355
+ } else {
356
+ params.dfinal_states_ptr = nullptr;
357
+ }
358
+
359
+ at::Tensor dinitial_states;
360
+ if (return_dinitial_states) {
361
+ dinitial_states = torch::empty({batch_size, width - 1, dim}, x.options()).transpose(1, 2);
362
+ TORCH_CHECK(dinitial_states.stride(1) == 1);
363
+ params.dinitial_states_ptr = dinitial_states.data_ptr();
364
+ params.dinitial_states_batch_stride = dinitial_states.stride(0);
365
+ params.dinitial_states_c_stride = dinitial_states.stride(1);
366
+ params.dinitial_states_l_stride = dinitial_states.stride(2);
367
+ } else {
368
+ params.dinitial_states_ptr = nullptr;
369
+ }
370
+
371
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
372
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_bwd", [&] {
373
+ DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_bwd", [&] {
374
+ if (!is_channel_last) {
375
+ causal_conv1d_bwd_cuda<input_t, weight_t>(params, stream);
376
+ } else {
377
+ causal_conv1d_channellast_bwd_cuda<input_t, weight_t>(params, stream);
378
+ }
379
+ });
380
+ });
381
+ return {dx, dweight.to(weight.dtype()), bias_.has_value() ? dbias.to(bias_.value().dtype()) : dbias, dinitial_states};
382
+ }
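As a side note on the stride-based `is_channel_last` test used in this wrapper: a minimal sketch (libtorch context assumed, sizes hypothetical) of how the two accepted layouts differ:

  // Contiguous (batch, dim, seqlen): the innermost (sequence) stride is 1.
  auto x_cf = torch::randn({2, 64, 128}, torch::kCUDA);                    // strides (8192, 128, 1)
  bool cf = x_cf.stride(1) == 1 && x_cf.stride(2) > 1;                     // false -> generic kernels
  // Channel-last: allocated as (batch, seqlen, dim), then viewed as (batch, dim, seqlen).
  auto x_cl = torch::randn({2, 128, 64}, torch::kCUDA).transpose(1, 2);    // strides (8192, 1, 64)
  bool cl = x_cl.stride(1) == 1 && x_cl.stride(2) > 1;                     // true  -> channellast kernels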
383
+
384
+ at::Tensor
385
+ causal_conv1d_update(const at::Tensor &x,
386
+ const at::Tensor &conv_state,
387
+ const at::Tensor &weight,
388
+ const c10::optional<at::Tensor> &bias_,
389
+ bool silu_activation,
390
+ const c10::optional<at::Tensor> &cache_seqlens_
391
+ ) {
392
+ auto input_type = x.scalar_type();
393
+ auto weight_type = weight.scalar_type();
394
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
395
+ TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
396
+ TORCH_CHECK(conv_state.scalar_type() == input_type);
397
+
398
+ TORCH_CHECK(x.is_cuda());
399
+ TORCH_CHECK(conv_state.is_cuda());
400
+ TORCH_CHECK(weight.is_cuda());
401
+
402
+ const auto sizes = x.sizes();
403
+ const int batch_size = sizes[0];
404
+ const int dim = sizes[1];
405
+ const int seqlen = sizes[2];
406
+ const int width = weight.size(-1);
407
+ const int conv_state_len = conv_state.size(2);
408
+ TORCH_CHECK(conv_state_len >= width - 1);
409
+
410
+ CHECK_SHAPE(x, batch_size, dim, seqlen);
411
+ CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
412
+ CHECK_SHAPE(weight, dim, width);
413
+
414
+ TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
415
+
416
+ if (bias_.has_value()) {
417
+ auto bias = bias_.value();
418
+ TORCH_CHECK(bias.scalar_type() == weight_type);
419
+ TORCH_CHECK(bias.is_cuda());
420
+ TORCH_CHECK(bias.stride(-1) == 1);
421
+ CHECK_SHAPE(bias, dim);
422
+ }
423
+
424
+ at::Tensor out = torch::empty_like(x);
425
+
426
+ ConvParamsBase params;
427
+ set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
428
+ bias_.has_value() ? bias_.value().data_ptr() : nullptr,
429
+ silu_activation);
430
+ params.conv_state_ptr = conv_state.data_ptr();
431
+ params.conv_state_len = conv_state_len;
432
+ // All strides are in elements, not bytes.
433
+ params.conv_state_batch_stride = conv_state.stride(0);
434
+ params.conv_state_c_stride = conv_state.stride(1);
435
+ params.conv_state_l_stride = conv_state.stride(2);
436
+
437
+ if (cache_seqlens_.has_value()) {
438
+ auto cache_seqlens = cache_seqlens_.value();
439
+ TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
440
+ TORCH_CHECK(cache_seqlens.is_cuda());
441
+ TORCH_CHECK(cache_seqlens.stride(-1) == 1);
442
+ CHECK_SHAPE(cache_seqlens, batch_size);
443
+ params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
444
+ } else {
445
+ params.cache_seqlens = nullptr;
446
+ }
447
+
448
+ // Otherwise the kernel will be launched from cuda:0 device
449
+ // Cast to char to avoid compiler warning about narrowing
450
+ at::cuda::CUDAGuard device_guard{(char)x.get_device()};
451
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
452
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
453
+ DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] {
454
+ causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
455
+ });
456
+ });
457
+ return out;
458
+ }
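Shape summary for the update entry point above, read off the checks in this wrapper (a sketch; the in-place update of conv_state is an assumption based on the mutable pointer handed to the kernel):

  // x:              (batch, dim, seqlen)
  // conv_state:     (batch, dim, conv_state_len), conv_state_len >= width - 1
  // weight:         (dim, width), width in [2, 4]
  // bias_:          optional (dim,)
  // cache_seqlens_: optional int32 (batch,), a per-sequence position into conv_state
  // returns out with the same shape and dtype as x; conv_state is presumably updated in place.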
459
+
460
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
461
+ m.def("causal_conv1d_fwd", &causal_conv1d_fwd, "Causal conv1d forward");
462
+ m.def("causal_conv1d_bwd", &causal_conv1d_bwd, "Causal conv1d backward");
463
+ m.def("causal_conv1d_update", &causal_conv1d_update, "Causal conv1d update");
464
+ }
causal-conv1d/csrc/causal_conv1d.h ADDED
@@ -0,0 +1,77 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
8
+
9
+ struct ConvParamsBase {
10
+ using index_t = uint32_t;
11
+
12
+ int batch, dim, seqlen, width;
13
+ bool silu_activation;
14
+
15
+ index_t x_batch_stride;
16
+ index_t x_c_stride;
17
+ index_t x_l_stride;
18
+ index_t weight_c_stride;
19
+ index_t weight_width_stride;
20
+ index_t out_batch_stride;
21
+ index_t out_c_stride;
22
+ index_t out_l_stride;
23
+
24
+ int conv_state_len;
25
+ index_t conv_state_batch_stride;
26
+ index_t conv_state_c_stride;
27
+ index_t conv_state_l_stride;
28
+
29
+ // Common data pointers.
30
+ void *__restrict__ x_ptr;
31
+ void *__restrict__ weight_ptr;
32
+ void *__restrict__ bias_ptr;
33
+ void *__restrict__ out_ptr;
34
+
35
+ void *__restrict__ conv_state_ptr;
36
+ int32_t *__restrict__ cache_seqlens;
37
+
38
+ void *__restrict__ seq_idx_ptr;
39
+
40
+ // No __restrict__ since initial_states could be the same as final_states.
41
+ void * initial_states_ptr;
42
+ index_t initial_states_batch_stride;
43
+ index_t initial_states_l_stride;
44
+ index_t initial_states_c_stride;
45
+
46
+ void * final_states_ptr;
47
+ index_t final_states_batch_stride;
48
+ index_t final_states_l_stride;
49
+ index_t final_states_c_stride;
50
+ };
51
+
52
+ struct ConvParamsBwd: public ConvParamsBase {
53
+ index_t dx_batch_stride;
54
+ index_t dx_c_stride;
55
+ index_t dx_l_stride;
56
+ index_t dweight_c_stride;
57
+ index_t dweight_width_stride;
58
+ index_t dout_batch_stride;
59
+ index_t dout_c_stride;
60
+ index_t dout_l_stride;
61
+
62
+ // Common data pointers.
63
+ void *__restrict__ dx_ptr;
64
+ void *__restrict__ dweight_ptr;
65
+ void *__restrict__ dbias_ptr;
66
+ void *__restrict__ dout_ptr;
67
+
68
+ void * dinitial_states_ptr;
69
+ index_t dinitial_states_batch_stride;
70
+ index_t dinitial_states_l_stride;
71
+ index_t dinitial_states_c_stride;
72
+
73
+ void * dfinal_states_ptr;
74
+ index_t dfinal_states_batch_stride;
75
+ index_t dfinal_states_l_stride;
76
+ index_t dfinal_states_c_stride;
77
+ };
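For orientation, a sketch of the element-count strides these fields hold under the two layouts the kernels accept (strides count elements, not bytes):

  // contiguous (B, D, L):                  x_batch_stride = D * L,  x_c_stride = L,  x_l_stride = 1
  // channel-last, i.e. (B, L, D) storage
  // viewed as (B, D, L):                   x_batch_stride = L * D,  x_c_stride = 1,  x_l_stride = D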
causal-conv1d/csrc/causal_conv1d_bwd.cu ADDED
@@ -0,0 +1,627 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include <c10/util/BFloat16.h>
6
+ #include <c10/util/Half.h>
7
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
8
+
9
+ #ifndef USE_ROCM
10
+ #include <cub/block/block_load.cuh>
11
+ #include <cub/block/block_store.cuh>
12
+ #include <cub/block/block_reduce.cuh>
13
+ #else
14
+ #include <hipcub/hipcub.hpp>
15
+ namespace cub = hipcub;
16
+ #endif
17
+
18
+ #include "causal_conv1d.h"
19
+ #include "causal_conv1d_common.h"
20
+ #include "static_switch.h"
21
+
22
+ template<int kNThreads_, int kWidth_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
23
+ struct Causal_conv1d_bwd_kernel_traits {
24
+ using input_t = input_t_;
25
+ using weight_t = weight_t_;
26
+ static constexpr int kNThreads = kNThreads_;
27
+ static constexpr int kWidth = kWidth_;
28
+ static constexpr bool kSiluAct = kSiluAct_;
29
+ static constexpr int kNBytes = sizeof(input_t);
30
+ static_assert(kNBytes == 2 || kNBytes == 4);
31
+ static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
32
+ static_assert(kWidth <= kNElts);
33
+ // It's possible that we need to do 2 rounds of exchange if input_t is 16 bits
34
+ // (since then we'd have 8 values of float, and each round we can exchange 4 floats).
35
+ static constexpr int kNExchangeRounds = sizeof(float) / sizeof(input_t);
36
+ static constexpr bool kIsVecLoad = kIsVecLoad_;
37
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
38
+ using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
39
+ using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
40
+ using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
41
+ using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
42
+ using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
43
+ static constexpr int kSmemIOSize = kIsVecLoad
44
+ ? 0
45
+ : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
46
+ static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts * (!kSiluAct ? 1 : kNExchangeRounds + 1);
47
+ static constexpr int kSmemSize = custom_max({kSmemExchangeSize,
48
+ int(sizeof(typename BlockReduceFloatT::TempStorage))}) + (kIsVecLoad ? 0 : kSmemIOSize);
49
+ };
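A worked instantiation of the constants above, e.g. for kNThreads = 128 with 16-bit inputs (the numbers follow directly from the definitions):

  // input_t = at::Half:  kNBytes = 2,  kNElts = 8,  kNExchangeRounds = 4 / 2 = 2
  // kSmemExchangeSize without SiLU: 128 * 2 * 8 * 1       = 2048 bytes
  // kSmemExchangeSize with SiLU:    128 * 2 * 8 * (2 + 1) = 6144 bytes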
50
+
51
+ template<typename Ktraits>
52
+ __global__ __launch_bounds__(Ktraits::kNThreads)
53
+ void causal_conv1d_bwd_kernel(ConvParamsBwd params) {
54
+ constexpr int kWidth = Ktraits::kWidth;
55
+ constexpr int kNThreads = Ktraits::kNThreads;
56
+ constexpr bool kSiluAct = Ktraits::kSiluAct;
57
+ static constexpr int kNElts = Ktraits::kNElts;
58
+ constexpr int kNExchangeRounds = Ktraits::kNExchangeRounds;
59
+ static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
60
+ using input_t = typename Ktraits::input_t;
61
+ using vec_t = typename Ktraits::vec_t;
62
+ using weight_t = typename Ktraits::weight_t;
63
+
64
+ // Shared memory.
65
+ extern __shared__ char smem_[];
66
+ auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
67
+ auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
68
+ auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
69
+ auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
70
+ vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
71
+ vec_t *smem_exchange_x = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize) + kNThreads * kNExchangeRounds;
72
+ auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
73
+
74
+ const int tidx = threadIdx.x;
75
+ const int batch_id = blockIdx.x;
76
+ const int dim_id = blockIdx.y;
77
+ input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
78
+ + dim_id * params.x_c_stride;
79
+ weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + dim_id * params.weight_c_stride;
80
+ input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
81
+ + dim_id * params.dout_c_stride;
82
+ input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
83
+ + dim_id * params.dx_c_stride;
84
+ float *dweight = reinterpret_cast<float *>(params.dweight_ptr) + dim_id * params.dweight_c_stride;
85
+ float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[dim_id]);
86
+
87
+ // Thread kNThreads - 1 will load the first elements of the next chunk, so we initialize those to 0.
88
+ if (tidx == 0) {
89
+ if constexpr (!kSiluAct) {
90
+ input_t zeros[kNElts] = {0};
91
+ smem_exchange[0] = reinterpret_cast<vec_t *>(zeros)[0];
92
+ } else {
93
+ float zeros[kNElts] = {0};
94
+ #pragma unroll
95
+ for (int r = 0; r < kNExchangeRounds; ++r) {
96
+ smem_exchange[r * kNThreads] = reinterpret_cast<vec_t *>(zeros)[r];
97
+ }
98
+ }
99
+ }
100
+
101
+ float weight_vals[kWidth];
102
+ #pragma unroll
103
+ for (int i = 0; i < kWidth; ++i) { weight_vals[i] = weight[i * params.weight_width_stride]; }
104
+
105
+ float dweight_vals[kWidth] = {0};
106
+ float dbias_val = 0;
107
+
108
+ constexpr int kChunkSize = kNThreads * kNElts;
109
+ const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
110
+ x += (n_chunks - 1) * kChunkSize;
111
+ dout += (n_chunks - 1) * kChunkSize;
112
+ dx += (n_chunks - 1) * kChunkSize;
113
+ for (int chunk = n_chunks - 1; chunk >= 0; --chunk) {
114
+ input_t x_vals_load[2 * kNElts] = {0};
115
+ input_t dout_vals_load[2 * kNElts] = {0};
116
+ if constexpr(kIsVecLoad) {
117
+ typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
118
+ typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(dout), *reinterpret_cast<vec_t (*)[1]>(&dout_vals_load[0]), (params.seqlen - chunk * kChunkSize) / kNElts);
119
+ } else {
120
+ __syncthreads();
121
+ typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
122
+ __syncthreads();
123
+ typename Ktraits::BlockLoadT(smem_load).Load(dout, *reinterpret_cast<input_t (*)[kNElts]>(&dout_vals_load[0]), params.seqlen - chunk * kChunkSize);
124
+ }
125
+ float dout_vals[2 * kNElts], x_vals[2 * kNElts];
126
+ if constexpr (!kSiluAct) {
127
+ __syncthreads();
128
+ // Thread 0 doesn't write yet, so that thread kNThreads - 1 can read
129
+ // the first elements of the next chunk.
130
+ if (tidx > 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
131
+ __syncthreads();
132
+ reinterpret_cast<vec_t *>(dout_vals_load)[1] = smem_exchange[tidx < kNThreads - 1 ? tidx + 1 : 0];
133
+ __syncthreads();
134
+ // Now thread 0 can write the first elements of the current chunk.
135
+ if (tidx == 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
136
+ #pragma unroll
137
+ for (int i = 0; i < 2 * kNElts; ++i) {
138
+ dout_vals[i] = float(dout_vals_load[i]);
139
+ x_vals[i] = float(x_vals_load[i]);
140
+ }
141
+ } else {
142
+ if (tidx == 0 && chunk > 0) {
143
+ if constexpr(kIsVecLoad) {
144
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = reinterpret_cast<vec_t *>(x)[-1];
145
+ } else {
146
+ #pragma unroll
147
+ for (int i = 0; i < kNElts; ++i) {
148
+ if (chunk * kChunkSize + i < params.seqlen) { x_vals_load[i] = x[-kNElts + i]; }
149
+ }
150
+ }
151
+ }
152
+ __syncthreads();
153
+ smem_exchange_x[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1];
154
+ __syncthreads();
155
+ if (tidx > 0) { reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange_x[tidx - 1]; }
156
+ #pragma unroll
157
+ for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
158
+ // Recompute the output
159
+ #pragma unroll
160
+ for (int i = 0; i < kNElts; ++i) {
161
+ float out_val = bias_val;
162
+ #pragma unroll
163
+ for (int w = 0; w < kWidth; ++w) {
164
+ out_val += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
165
+ }
166
+ float out_sigmoid_val = 1.0f / (1.0f + expf(-out_val));
167
+ dout_vals[i] = float(dout_vals_load[i]) * out_sigmoid_val
168
+ * (1.0f + out_val * (1.0f - out_sigmoid_val));
169
+ }
170
+ // Exchange the dout_vals. It's possible that we need to do 2 rounds of exchange
171
+ // if input_t is 16 bits (since then we'd have 8 values of float)
172
+ __syncthreads();
173
+ // Thread 0 doesn't write yet, so that thread kNThreads - 1 can read
174
+ // the first elements of the next chunk.
175
+ if (tidx > 0) {
176
+ #pragma unroll
177
+ for (int r = 0; r < kNExchangeRounds; ++r) {
178
+ smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
179
+ }
180
+ }
181
+ __syncthreads();
182
+ #pragma unroll
183
+ for (int r = 0; r < kNExchangeRounds; ++r) {
184
+ reinterpret_cast<vec_t *>(dout_vals)[kNExchangeRounds + r]
185
+ = smem_exchange[r * kNThreads + (tidx < kNThreads - 1 ? tidx + 1 : 0)];
186
+ }
187
+ __syncthreads();
188
+ // Now thread 0 can write the first elements of the current chunk.
189
+ if (tidx == 0) {
190
+ #pragma unroll
191
+ for (int r = 0; r < kNExchangeRounds; ++r) {
192
+ smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
193
+ }
194
+ }
195
+ }
196
+ dout -= kChunkSize;
197
+ x -= kChunkSize;
198
+
199
+ #pragma unroll
200
+ for (int i = 0; i < kNElts; ++i) { dbias_val += dout_vals[i]; }
201
+
202
+ float dx_vals[kNElts] = {0};
203
+ #pragma unroll
204
+ for (int i = 0; i < kNElts; ++i) {
205
+ #pragma unroll
206
+ for (int w = 0; w < kWidth; ++w) {
207
+ dx_vals[i] += weight_vals[w] * dout_vals[i + kWidth - w - 1];
208
+ }
209
+ }
210
+
211
+ input_t dx_vals_store[kNElts];
212
+ #pragma unroll
213
+ for (int i = 0; i < kNElts; ++i) { dx_vals_store[i] = dx_vals[i]; }
214
+ if constexpr(kIsVecLoad) {
215
+ typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(dx), reinterpret_cast<vec_t (&)[1]>(dx_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
216
+ } else {
217
+ typename Ktraits::BlockStoreT(smem_store).Store(dx, dx_vals_store, params.seqlen - chunk * kChunkSize);
218
+ }
219
+ dx -= kChunkSize;
220
+
221
+ #pragma unroll
222
+ for (int w = 0; w < kWidth; ++w) {
223
+ #pragma unroll
224
+ for (int i = 0; i < kNElts; ++i) {
225
+ dweight_vals[w] += x_vals[kNElts + i] * dout_vals[i + kWidth - w - 1];
226
+ }
227
+ }
228
+ }
229
+
230
+ #pragma unroll
231
+ for (int w = 0; w < kWidth; ++w) {
232
+ __syncthreads();
233
+ dweight_vals[w] = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dweight_vals[w]);
234
+ if (tidx == 0) {
235
+ atomicAdd(&reinterpret_cast<float *>(dweight)[w * params.dweight_width_stride], dweight_vals[w]);
236
+ }
237
+ }
238
+ if (params.bias_ptr != nullptr) {
239
+ __syncthreads();
240
+ dbias_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dbias_val);
241
+ if (tidx == 0) {
242
+ atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[dim_id], dbias_val);
243
+ }
244
+ }
245
+ }
246
+
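The factor applied to dout in the SiLU branch of the kernel above is the derivative of silu(z) = z * sigmoid(z), written out here for reference (plain calculus, nothing beyond what the code already computes):

  // with s = sigmoid(z):  d/dz [z * s] = s + z * s * (1 - s) = s * (1 + z * (1 - s))
  // which is exactly out_sigmoid_val * (1.0f + out_val * (1.0f - out_sigmoid_val)) above;
  // this is why the kernel recomputes the pre-activation out_val before scaling dout.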
247
+ template<int kNThreads, int kWidth, typename input_t, typename weight_t>
248
+ void causal_conv1d_bwd_launch(ConvParamsBwd &params, cudaStream_t stream) {
249
+ static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
250
+ BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
251
+ BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
252
+ using Ktraits = Causal_conv1d_bwd_kernel_traits<kNThreads, kWidth, kSiluAct, kIsVecLoad, input_t, weight_t>;
253
+ constexpr int kSmemSize = Ktraits::kSmemSize;
254
+ dim3 grid(params.batch, params.dim);
255
+ auto kernel = &causal_conv1d_bwd_kernel<Ktraits>;
256
+
257
+ if (kSmemSize >= 48 * 1024) {
258
+ #ifndef USE_ROCM
259
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
260
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
261
+ #else
262
+ // There is a slight signature discrepancy between the HIP and CUDA "FuncSetAttribute" functions.
263
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
264
+ (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
265
+ std::cerr << "Warning (causal_conv1d bwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
266
+ #endif
267
+ }
268
+
269
+
270
+ kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
271
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
272
+ });
273
+ });
274
+ }
275
+
276
+ template<typename input_t, typename weight_t>
277
+ void causal_conv1d_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream) {
278
+ if (params.width == 2) {
279
+ causal_conv1d_bwd_launch<128, 2, input_t, weight_t>(params, stream);
280
+ } else if (params.width == 3) {
281
+ causal_conv1d_bwd_launch<128, 3, input_t, weight_t>(params, stream);
282
+ } else if (params.width == 4) {
283
+ causal_conv1d_bwd_launch<128, 4, input_t, weight_t>(params, stream);
284
+ }
285
+ }
286
+
287
+ template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
288
+ struct Causal_conv1d_channellast_bwd_kernel_traits {
289
+ // The cache line is 128 bytes, and we try to read 16 bytes per thread.
290
+ // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
291
+ // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
292
+ // threads). Each load is 16 x 32|64 elements in the L x C dimensions.
293
+ using input_t = input_t_;
294
+ using weight_t = weight_t_;
295
+ static constexpr bool kSiluAct = kSiluAct_;
296
+ static constexpr int kNThreads = kNThreads_;
297
+ static_assert(kNThreads % 32 == 0);
298
+ static constexpr int kNWarps = kNThreads / 32;
299
+ static constexpr int kWidth = kWidth_;
300
+ static constexpr int kChunkSizeL = kChunkSizeL_;
301
+ static constexpr int kNBytes = sizeof(input_t);
302
+ static_assert(kNBytes == 2 || kNBytes == 4);
303
+ static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
304
+ static constexpr int kNEltsPerRow = 128 / kNBytes;
305
+ static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
306
+ static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
307
+ static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
308
+ static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
309
+ static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
310
+ static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
311
+ static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
312
+ static constexpr bool kIsVecLoad = kIsVecLoad_;
313
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
314
+ // using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
315
+ // using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
316
+ // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
317
+ // sizeof(typename BlockStoreT::TempStorage)});
318
+ // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
319
+ };
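Making the tiling comment at the top of this traits struct concrete for the 16-bit case (values follow from the constants; the 32-bit case gives 32 channels per row instead of 64):

  // input_t = at::Half, kNThreads = 128, kChunkSizeL = 64:
  //   kNEltsPerRow = 128 / 2 = 64 channels,  kNThreadsPerRow = 64 / 8 = 8 threads per row
  //   kNColsPerWarp = 32 / 8 = 4,  kNWarps = 4,  so kNColsPerLoad = 16 rows of L per load
  //   kNLoads = 64 / 16 = 4 loads fill one 64 (L) x 64 (C) tile in shared memory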
320
+
321
+ template<typename Ktraits, bool kHasSeqIdx, bool kHasDfinalStates>
322
+ __global__ __launch_bounds__(Ktraits::kNThreads)
323
+ void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) {
324
+ constexpr int kWidth = Ktraits::kWidth;
325
+ constexpr int kNThreads = Ktraits::kNThreads;
326
+ constexpr bool kSiluAct = Ktraits::kSiluAct;
327
+ constexpr int kNElts = Ktraits::kNElts;
328
+ constexpr int kNWarp = Ktraits::kNWarps;
329
+ constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
330
+ constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
331
+ constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
332
+ constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
333
+ using input_t = typename Ktraits::input_t;
334
+ using vec_t = typename Ktraits::vec_t;
335
+ using weight_t = typename Ktraits::weight_t;
336
+
337
+ // Shared memory.
338
+ __shared__ input_t dout_smem[kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
339
+ __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
340
+
341
+ const int batch_id = blockIdx.x;
342
+ const int chunk_l_id = blockIdx.y;
343
+ const int chunk_c_id = blockIdx.z;
344
+ const int tid = threadIdx.x;
345
+ const int l_idx = tid / kNThreadsPerC;
346
+ const int c_idx = tid % kNThreadsPerC;
347
+ input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
348
+ + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
349
+ weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
350
+ + chunk_c_id * kChunkSizeC * params.weight_c_stride;
351
+ input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
352
+ + (chunk_l_id * kChunkSizeL + l_idx) * params.dout_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
353
+ input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
354
+ + (chunk_l_id * kChunkSizeL + l_idx) * params.dx_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
355
+ float *dweight = reinterpret_cast<float *>(params.dweight_ptr)
356
+ + chunk_c_id * kChunkSizeC * params.dweight_c_stride;
357
+ int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
358
+ + batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
359
+ input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
360
+ : reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
361
+ input_t *dinitial_states = params.dinitial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
362
+ : reinterpret_cast<input_t *>(params.dinitial_states_ptr) + batch_id * params.dinitial_states_batch_stride + l_idx * params.dinitial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
363
+ input_t *dfinal_states = params.dfinal_states_ptr == nullptr ? nullptr
364
+ : reinterpret_cast<input_t *>(params.dfinal_states_ptr) + batch_id * params.dfinal_states_batch_stride + chunk_c_id * kChunkSizeC;
365
+
366
+ #pragma unroll
367
+ for (int l = 0; l < Ktraits::kNLoads; ++l) {
368
+ input_t dout_vals_load[kNElts] = {0};
369
+ input_t x_vals_load[kNElts] = {0};
370
+ if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
371
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
372
+ reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + l * kLPerLoad * params.dout_l_stride);
373
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
374
+ }
375
+ reinterpret_cast<vec_t *>(dout_smem[l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
376
+ reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
377
+ }
378
+ // Load the elements from the previous chunk or next chunk that are needed for convolution.
379
+ if (l_idx < kWidth - 1) {
380
+ input_t dout_vals_load[kNElts] = {0};
381
+ input_t x_vals_load[kNElts] = {0};
382
+ if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
383
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
384
+ reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + kChunkSizeL * params.dout_l_stride);
385
+ }
386
+ if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
387
+ && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
388
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
389
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
390
+ } else if (initial_states != nullptr
391
+ && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
392
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
393
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
394
+ }
395
+ reinterpret_cast<vec_t *>(dout_smem[kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
396
+ reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
397
+ }
398
+ // Need to load (kWidth - 1) extra x's on the right to recompute the (kChunkSizeL + kWidth - 1) outputs
399
+ if constexpr (kSiluAct) {
400
+ if (l_idx < kWidth - 1) {
401
+ input_t x_vals_load[kNElts] = {0};
402
+ if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
403
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
404
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + kChunkSizeL * params.x_l_stride);
405
+ }
406
+ reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
407
+ }
408
+ }
409
+
410
+ __syncthreads();
411
+
412
+ constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
413
+ static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
414
+ constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
415
+ static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
416
+ // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
417
+ static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
418
+ static_assert((kLPerThread & (kLPerThread - 1)) == 0);
419
+ static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
420
+ static_assert(kNThreadsPerRow <= 32);
421
+
422
+ const int row_idx = tid / kNThreadsPerRow;
423
+ const int col_idx = tid % kNThreadsPerRow;
424
+
425
+ float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
426
+ float weight_vals[kWidth] = {0};
427
+ if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
428
+ #pragma unroll
429
+ for (int w = 0; w < kWidth; ++w) {
430
+ weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
431
+ }
432
+ }
433
+ float dout_vals[kLPerThread + kWidth - 1];
434
+ float x_vals[kWidth - 1 + kLPerThread + kWidth - 1];
435
+ #pragma unroll
436
+ for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
437
+ dout_vals[i] = float(dout_smem[col_idx * kLPerThread + i][row_idx]);
438
+ x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
439
+ }
440
+
441
+ int seq_idx_thread[kWidth - 1 + kLPerThread + kWidth - 1];
442
+ if constexpr (kHasSeqIdx) {
443
+ #pragma unroll
444
+ for (int i = 0; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) {
445
+ const int l_idx = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1);
446
+ seq_idx_thread[i] = l_idx >= 0 && l_idx < params.seqlen ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
447
+ }
448
+ }
449
+
450
+ if constexpr (kSiluAct) { // Recompute the output
451
+ #pragma unroll
452
+ for (int i = kWidth - 1 + kLPerThread; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) {
453
+ x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
454
+ }
455
+ #pragma unroll
456
+ for (int i = 0; i < kLPerThread + kWidth - 1; ++i) {
457
+ float out_val = bias_val;
458
+ const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
459
+ #pragma unroll
460
+ for (int w = 0; w < kWidth; ++w) {
461
+ if constexpr (!kHasSeqIdx) {
462
+ out_val += weight_vals[w] * x_vals[i + w];
463
+ } else {
464
+ out_val += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
465
+ }
466
+ }
467
+ float out_val_sigmoid = 1.f / (1.f + expf(-out_val));
468
+ dout_vals[i] *= out_val_sigmoid * (1 + out_val * (1 - out_val_sigmoid));
469
+ }
470
+ }
471
+
472
+ float dweight_vals[kWidth] = {0};
473
+ SumOp<float> sum_op;
474
+ #pragma unroll
475
+ for (int w = 0; w < kWidth; ++w) {
476
+ #pragma unroll
477
+ for (int i = 0; i < kLPerThread; ++i) {
478
+ if constexpr (!kHasSeqIdx) {
479
+ dweight_vals[w] += x_vals[i + w] * dout_vals[i];
480
+ } else {
481
+ dweight_vals[w] += seq_idx_thread[i + w] == seq_idx_thread[kWidth - 1 + i] ? x_vals[i + w] * dout_vals[i] : 0.f;
482
+ }
483
+ }
484
+ dweight_vals[w] = Allreduce<kNThreadsPerRow>::run(dweight_vals[w], sum_op);
485
+ if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
486
+ atomicAdd(&reinterpret_cast<float *>(dweight)[row_idx * params.dweight_c_stride + w * params.dweight_width_stride], dweight_vals[w]);
487
+ }
488
+ }
489
+
490
+ if (params.bias_ptr != nullptr) {
491
+ float dbias_val = 0.f;
492
+ for (int i = 0; i < kLPerThread; ++i) { dbias_val += dout_vals[i]; }
493
+ dbias_val = Allreduce<kNThreadsPerRow>::run(dbias_val, sum_op);
494
+ if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
495
+ atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[chunk_c_id * kChunkSizeC + row_idx], dbias_val);
496
+ }
497
+ }
498
+
499
+ float dx_vals[kLPerThread] = {0};
500
+ #pragma unroll
501
+ for (int i = 0; i < kLPerThread; ++i) {
502
+ const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
503
+ #pragma unroll
504
+ for (int w = 0; w < kWidth; ++w) {
505
+ if constexpr (!kHasSeqIdx) {
506
+ dx_vals[i] += weight_vals[kWidth - 1 - w] * dout_vals[i + w];
507
+ } else {
508
+ dx_vals[i] += seq_idx_thread[kWidth - 1 + i + w] == seq_idx_cur ? weight_vals[kWidth - 1 - w] * dout_vals[i + w] : 0.f;
509
+ }
510
+ }
511
+ // if (dfinal_states != nullptr) {
512
+ if constexpr (kHasDfinalStates) {
513
+ if (chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i >= params.seqlen - kWidth + 1
514
+ && chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i < params.seqlen
515
+ && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
516
+ dx_vals[i] += float(dfinal_states[((chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i) - (params.seqlen - kWidth + 1)) * params.dfinal_states_l_stride + row_idx * params.dfinal_states_c_stride]);
517
+ }
518
+ }
519
+ }
520
+
521
+ float dxinit_vals[kWidth - 1] = {0};
522
+ static_assert(kLPerThread >= kWidth - 1); // So only threads with col_idx == 0 need to handle dinitial_states
523
+ if (dinitial_states != nullptr && col_idx == 0) {
524
+ #pragma unroll
525
+ for (int i = 0; i < kWidth - 1; ++i) {
526
+ #pragma unroll
527
+ for (int w = 0; w < kWidth; ++w) {
528
+ dxinit_vals[i] += i + w - (kWidth - 1) >= 0 ? weight_vals[kWidth - 1 - w] * dout_vals[i + w - (kWidth - 1)] : 0.f;
529
+ }
530
+ // chunk_l_id must be 0 because dinitial_states != nullptr
531
+ // if (dfinal_states != nullptr) {
532
+ if constexpr (kHasDfinalStates) {
533
+ if (i >= params.seqlen) {
534
+ dxinit_vals[i] += float(dfinal_states[(i - params.seqlen) * params.dfinal_states_l_stride + row_idx * params.dfinal_states_c_stride]);
535
+ }
536
+ }
537
+ }
538
+ }
539
+
540
+ __syncthreads();
541
+ #pragma unroll
542
+ for (int i = 0; i < kLPerThread; ++i) { x_smem[kWidth - 1 + col_idx * kLPerThread + i][row_idx] = dx_vals[i]; }
543
+ if (dinitial_states != nullptr && col_idx == 0) {
544
+ #pragma unroll
545
+ for (int i = 0; i < kWidth - 1; ++i) { x_smem[i][row_idx] = dxinit_vals[i]; }
546
+ }
547
+ __syncthreads();
548
+
549
+ #pragma unroll
550
+ for (int l = 0; l < Ktraits::kNLoads; ++l) {
551
+ input_t dx_vals_store[kNElts];
552
+ reinterpret_cast<vec_t *>(dx_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx];
553
+ if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
554
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
555
+ *reinterpret_cast<vec_t *>(dx + l * kLPerLoad * params.dx_l_stride) = reinterpret_cast<vec_t *>(dx_vals_store)[0];
556
+ }
557
+ }
558
+ if (dinitial_states != nullptr
559
+ && l_idx < kWidth - 1
560
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
561
+ input_t dxinit_vals_store[kNElts];
562
+ reinterpret_cast<vec_t *>(dxinit_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx];
563
+ *reinterpret_cast<vec_t *>(dinitial_states) = reinterpret_cast<vec_t *>(dxinit_vals_store)[0];
564
+ }
565
+
566
+ }
567
+
568
+ template<int kNThreads, int kWidth, typename input_t, typename weight_t>
569
+ void causal_conv1d_channellast_bwd_launch(ConvParamsBwd &params, cudaStream_t stream) {
570
+ BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
571
+ BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
572
+ BOOL_SWITCH(params.dfinal_states_ptr != nullptr, kHasDfinalStates, [&] {
573
+ BOOL_SWITCH(params.seqlen <= 128, kChunkSizeL64, [&] {
574
+ // kChunkSizeL = 128 is slightly faster than 64 when seqlen is larger
575
+ static constexpr int kChunk = kChunkSizeL64 ? 64 : 128;
576
+ using Ktraits = Causal_conv1d_channellast_bwd_kernel_traits<kNThreads, kWidth, kChunk, kSiluAct, true, input_t, weight_t>;
577
+ // constexpr int kSmemSize = Ktraits::kSmemSize;
578
+ constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
579
+ constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
580
+ const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
581
+ const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
582
+ dim3 grid(params.batch, n_chunks_L, n_chunks_C);
583
+ dim3 block(Ktraits::kNThreads);
584
+ auto kernel = &causal_conv1d_channellast_bwd_kernel<Ktraits, kHasSeqIdx, kHasDfinalStates>;
585
+ // if (kSmemSize >= 48 * 1024) {
586
+ // C10_CUDA_CHECK(cudaFuncSetAttribute(
587
+ // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
588
+ // }
589
+ // kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
590
+ kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
591
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
592
+ });
593
+ });
594
+ });
595
+ });
596
+ }
597
+
598
+ template<typename input_t, typename weight_t>
599
+ void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream) {
600
+ if (params.width == 2) {
601
+ causal_conv1d_channellast_bwd_launch<128, 2, input_t, weight_t>(params, stream);
602
+ } else if (params.width == 3) {
603
+ causal_conv1d_channellast_bwd_launch<128, 3, input_t, weight_t>(params, stream);
604
+ } else if (params.width == 4) {
605
+ causal_conv1d_channellast_bwd_launch<128, 4, input_t, weight_t>(params, stream);
606
+ }
607
+ }
608
+
609
+ template void causal_conv1d_bwd_cuda<float, float>(ConvParamsBwd &params, cudaStream_t stream);
610
+ template void causal_conv1d_bwd_cuda<at::Half, float>(ConvParamsBwd &params, cudaStream_t stream);
611
+ template void causal_conv1d_bwd_cuda<at::BFloat16, float>(ConvParamsBwd &params, cudaStream_t stream);
612
+ template void causal_conv1d_bwd_cuda<float, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
613
+ template void causal_conv1d_bwd_cuda<at::Half, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
614
+ template void causal_conv1d_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
615
+ template void causal_conv1d_bwd_cuda<float, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
616
+ template void causal_conv1d_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
617
+ template void causal_conv1d_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
618
+
619
+ template void causal_conv1d_channellast_bwd_cuda<float, float>(ConvParamsBwd &params, cudaStream_t stream);
620
+ template void causal_conv1d_channellast_bwd_cuda<at::Half, float>(ConvParamsBwd &params, cudaStream_t stream);
621
+ template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, float>(ConvParamsBwd &params, cudaStream_t stream);
622
+ template void causal_conv1d_channellast_bwd_cuda<float, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
623
+ template void causal_conv1d_channellast_bwd_cuda<at::Half, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
624
+ template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
625
+ template void causal_conv1d_channellast_bwd_cuda<float, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
626
+ template void causal_conv1d_channellast_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
627
+ template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
causal-conv1d/csrc/causal_conv1d_common.h ADDED
@@ -0,0 +1,98 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #ifndef USE_ROCM
8
+ #include <cuda_bf16.h>
9
+
10
+ template<typename T>
11
+ __device__ inline T shuffle_xor(T val, int offset) {
12
+ return __shfl_xor_sync(uint32_t(-1), val, offset);
13
+ }
14
+
15
+ constexpr size_t custom_max(std::initializer_list<size_t> ilist)
16
+ {
17
+ return std::max(ilist);
18
+ }
19
+
20
+ template<typename T>
21
+ constexpr T constexpr_min(T a, T b) {
22
+ return std::min(a, b);
23
+ }
24
+
25
+ #else
26
+ #include <hip/hip_bf16.h>
27
+
28
+ template<typename T>
29
+ __device__ inline T shuffle_xor(T val, int offset) {
30
+ return __shfl_xor(val, offset);
31
+ }
32
+ constexpr size_t custom_max(std::initializer_list<size_t> ilist)
33
+ {
34
+ return *std::max_element(ilist.begin(), ilist.end());
35
+ }
36
+
37
+ template<typename T>
38
+ constexpr T constexpr_min(T a, T b) {
39
+ return a < b ? a : b;
40
+ }
41
+ #endif
42
+ #include <cuda_fp16.h>
43
+
44
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
45
+
46
+ template<int BYTES> struct BytesToType {};
47
+
48
+ template<> struct BytesToType<16> {
49
+ using Type = uint4;
50
+ static_assert(sizeof(Type) == 16);
51
+ };
52
+
53
+ template<> struct BytesToType<8> {
54
+ using Type = uint64_t;
55
+ static_assert(sizeof(Type) == 8);
56
+ };
57
+
58
+ template<> struct BytesToType<4> {
59
+ using Type = uint32_t;
60
+ static_assert(sizeof(Type) == 4);
61
+ };
62
+
63
+ template<> struct BytesToType<2> {
64
+ using Type = uint16_t;
65
+ static_assert(sizeof(Type) == 2);
66
+ };
67
+
68
+ template<> struct BytesToType<1> {
69
+ using Type = uint8_t;
70
+ static_assert(sizeof(Type) == 1);
71
+ };
72
+
73
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
74
+
75
+ template<typename T>
76
+ struct SumOp {
77
+ __device__ inline T operator()(T const & x, T const & y) { return x + y; }
78
+ };
79
+
80
+ template<int THREADS>
81
+ struct Allreduce {
82
+ static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
83
+ template<typename T, typename Operator>
84
+ static __device__ inline T run(T x, Operator &op) {
85
+ constexpr int OFFSET = THREADS / 2;
86
+ x = op(x, shuffle_xor(x, OFFSET));
87
+ return Allreduce<OFFSET>::run(x, op);
88
+ }
89
+ };
90
+
91
+ template<>
92
+ struct Allreduce<2> {
93
+ template<typename T, typename Operator>
94
+ static __device__ inline T run(T x, Operator &op) {
95
+ x = op(x, shuffle_xor(x, 1));
96
+ return x;
97
+ }
98
+ };
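A brief note on the reduction helpers above: Allreduce<THREADS> is a butterfly all-reduce over a group of lanes, combining values with shuffle_xor at offsets THREADS/2, ..., 2, 1. A minimal usage sketch mirroring how the backward kernels call it:

  // SumOp<float> sum_op;
  // float partial = ...;                            // one partial value per lane
  // partial = Allreduce<32>::run(partial, sum_op);  // every lane now holds the warp-wide sum
  //                                                 // (internally: offsets 16, 8, 4, 2, 1)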
causal-conv1d/csrc/causal_conv1d_fwd.cu ADDED
@@ -0,0 +1,399 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include <c10/util/BFloat16.h>
6
+ #include <c10/util/Half.h>
7
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
8
+
9
+ #ifndef USE_ROCM
10
+ #include <cub/block/block_load.cuh>
11
+ #include <cub/block/block_store.cuh>
12
+ #else
13
+ #include <hipcub/hipcub.hpp>
14
+ namespace cub = hipcub;
15
+ #endif
16
+
17
+ #include "causal_conv1d.h"
18
+ #include "causal_conv1d_common.h"
19
+ #include "static_switch.h"
20
+
21
+ template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
22
+ struct Causal_conv1d_fwd_kernel_traits {
23
+ using input_t = input_t_;
24
+ using weight_t = weight_t_;
25
+ static constexpr int kNThreads = kNThreads_;
26
+ static constexpr int kWidth = kWidth_;
27
+ static constexpr int kNBytes = sizeof(input_t);
28
+ static_assert(kNBytes == 2 || kNBytes == 4);
29
+ static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
30
+ static_assert(kWidth <= kNElts);
31
+ static constexpr bool kIsVecLoad = kIsVecLoad_;
32
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
33
+ using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
34
+ using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
35
+ using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
36
+ using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
37
+ static constexpr int kSmemIOSize = kIsVecLoad
38
+ ? 0
39
+ : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
40
+ static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
41
+ static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
42
+ };
43
+
44
+ template<typename Ktraits>
45
+ __global__ __launch_bounds__(Ktraits::kNThreads)
46
+ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
47
+ constexpr int kWidth = Ktraits::kWidth;
48
+ constexpr int kNThreads = Ktraits::kNThreads;
49
+ constexpr int kNElts = Ktraits::kNElts;
50
+ static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
51
+ using input_t = typename Ktraits::input_t;
52
+ using vec_t = typename Ktraits::vec_t;
53
+ using weight_t = typename Ktraits::weight_t;
54
+
55
+ // Shared memory.
56
+ extern __shared__ char smem_[];
57
+ auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
58
+ auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
59
+ auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
60
+ auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
61
+ vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
62
+
63
+ const int tidx = threadIdx.x;
64
+ const int batch_id = blockIdx.x;
65
+ const int channel_id = blockIdx.y;
66
+ input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
67
+ + channel_id * params.x_c_stride;
68
+ weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
69
+ input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
70
+ + channel_id * params.out_c_stride;
71
+ float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
72
+
73
+ // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
74
+ if (tidx == 0) {
75
+ input_t zeros[kNElts] = {0};
76
+ smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0];
77
+ }
78
+
79
+ float weight_vals[kWidth];
80
+ #pragma unroll
81
+ for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
82
+
83
+ constexpr int kChunkSize = kNThreads * kNElts;
84
+ const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
85
+ for (int chunk = 0; chunk < n_chunks; ++chunk) {
86
+ input_t x_vals_load[2 * kNElts] = {0};
87
+ if constexpr(kIsVecLoad) {
88
+ typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
89
+ } else {
90
+ __syncthreads();
91
+ typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
92
+ }
93
+ x += kChunkSize;
94
+ __syncthreads();
95
+ // Thread kNThreads - 1 doesn't write yet, so that thread 0 can read
96
+ // the last elements of the previous chunk.
97
+ if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
98
+ __syncthreads();
99
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
100
+ __syncthreads();
101
+ // Now thread kNThreads - 1 can write the last elements of the current chunk.
102
+ if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
103
+
104
+ float x_vals[2 * kNElts];
105
+ #pragma unroll
106
+ for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
107
+
108
+ float out_vals[kNElts];
109
+ #pragma unroll
110
+ for (int i = 0; i < kNElts; ++i) {
111
+ out_vals[i] = bias_val;
112
+ #pragma unroll
113
+ for (int w = 0; w < kWidth; ++w) {
114
+ out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
115
+ }
116
+ }
117
+
118
+ if (params.silu_activation) {
119
+ #pragma unroll
120
+ for (int i = 0; i < kNElts; ++i) {
121
+ out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
122
+ }
123
+ }
124
+
125
+ input_t out_vals_store[kNElts];
126
+ #pragma unroll
127
+ for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
128
+ if constexpr(kIsVecLoad) {
129
+ typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
130
+ } else {
131
+ typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize);
132
+ }
133
+ out += kChunkSize;
134
+ }
135
+ }
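Restating what the inner loops of the forward kernel above compute per channel (a paraphrase of the code, not an addition to it):

  // out[l] = bias + sum_{w=0}^{kWidth-1} weight[w] * x[l - (kWidth - 1 - w)]
  // so weight[kWidth - 1] multiplies the current element x[l] and earlier taps reach back in
  // time; out-of-range x comes from the previous chunk (or zeros at the very start), keeping
  // the convolution causal. If silu_activation is set, out[l] is then scaled by sigmoid(out[l]).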
136
+
137
+ template<int kNThreads, int kWidth, typename input_t, typename weight_t>
138
+ void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
139
+ static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
140
+ BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
141
+ using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
142
+ constexpr int kSmemSize = Ktraits::kSmemSize;
143
+ dim3 grid(params.batch, params.dim);
144
+
145
+ auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
146
+
147
+ if (kSmemSize >= 48 * 1024) {
148
+ #ifndef USE_ROCM
149
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
150
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
151
+ #else
152
+ // There is a slight signature discrepancy between the HIP and CUDA "FuncSetAttribute" functions.
153
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
154
+ (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
155
+ std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
156
+ #endif
157
+ }
158
+ kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
159
+
160
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
161
+ });
162
+ }
163
+
164
+ template<typename input_t, typename weight_t>
165
+ void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
166
+ if (params.width == 2) {
167
+ causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
168
+ } else if (params.width == 3) {
169
+ causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
170
+ } else if (params.width == 4) {
171
+ causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
172
+ }
173
+ }
174
+
175
+ template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
176
+ struct Causal_conv1d_channellast_fwd_kernel_traits {
177
+ // The cache line is 128 bytes, and we try to read 16 bytes per thread.
178
+ // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
179
+ // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
180
+ // threads). Each load is 16 x 32|64 elements in the L x C dimensions.
181
+ using input_t = input_t_;
182
+ using weight_t = weight_t_;
183
+ static constexpr int kNThreads = kNThreads_;
184
+ static_assert(kNThreads % 32 == 0);
185
+ static constexpr int kNWarps = kNThreads / 32;
186
+ static constexpr int kWidth = kWidth_;
187
+ static constexpr int kChunkSizeL = kChunkSizeL_;
188
+ static constexpr int kNBytes = sizeof(input_t);
189
+ static_assert(kNBytes == 2 || kNBytes == 4);
190
+ static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
191
+ static constexpr int kNEltsPerRow = 128 / kNBytes;
192
+ static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
193
+ static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
194
+ static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
195
+ static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
196
+ static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
197
+ static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
198
+ static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
199
+ static constexpr bool kIsVecLoad = kIsVecLoad_;
200
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
201
+ // using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
202
+ // using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
203
+ // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
204
+ // sizeof(typename BlockStoreT::TempStorage)});
205
+ // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
206
+ };
207
+
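A minimal worked example (Python, illustration only) of the trait arithmetic described in the comment above, assuming a 16-bit input type (fp16/bf16), kNThreads = 128 and kChunkSizeL = 64; the names mirror the C++ constants and are not part of the library.

kNThreads, kChunkSizeL, kNBytes = 128, 64, 2   # 2 bytes per fp16/bf16 element
kNElts = 4 if kNBytes == 4 else 8              # 16 bytes loaded per thread
kNEltsPerRow = 128 // kNBytes                  # one 128-byte cache line = 64 elements
kNThreadsPerRow = kNEltsPerRow // kNElts       # 8 threads cooperate on one row of C
kNWarps = kNThreads // 32                      # 4 warps per block
kNColsPerWarp = 32 // kNThreadsPerRow          # 4 L-positions per warp
kNColsPerLoad = kNColsPerWarp * kNWarps        # 16 L-positions per block-wide load
kNLoads = kChunkSizeL // kNColsPerLoad         # 4 loads fill a 64 (L) x 64 (C) chunk
assert (kNEltsPerRow, kNThreadsPerRow, kNColsPerLoad, kNLoads) == (64, 8, 16, 4)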
208
+ template<typename Ktraits, bool kHasSeqIdx>
209
+ __global__ __launch_bounds__(Ktraits::kNThreads)
210
+ void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
211
+ constexpr int kWidth = Ktraits::kWidth;
212
+ constexpr int kNThreads = Ktraits::kNThreads;
213
+ constexpr int kNElts = Ktraits::kNElts;
214
+ constexpr int kNWarp = Ktraits::kNWarps;
215
+ constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
216
+ constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
217
+ constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
218
+ constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
219
+ using input_t = typename Ktraits::input_t;
220
+ using vec_t = typename Ktraits::vec_t;
221
+ using weight_t = typename Ktraits::weight_t;
222
+
223
+ // Shared memory.
224
+ __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];
225
+
226
+ const int batch_id = blockIdx.x;
227
+ const int chunk_l_id = blockIdx.y;
228
+ const int chunk_c_id = blockIdx.z;
229
+ const int tid = threadIdx.x;
230
+ const int l_idx = tid / kNThreadsPerC;
231
+ const int c_idx = tid % kNThreadsPerC;
232
+ input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
233
+ + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
234
+ weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
235
+ + chunk_c_id * kChunkSizeC * params.weight_c_stride;
236
+ input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
237
+ + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
238
+ int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
239
+ + batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
240
+ input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
241
+ : reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
242
+ // The last L-chunk will also have enough info to write to final states, since it also contains a few x values
243
+ // from the previous L-chunk.
244
+ input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr
245
+ : reinterpret_cast<input_t *>(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
246
+
247
+ #pragma unroll
248
+ for (int l = 0; l < Ktraits::kNLoads; ++l) {
249
+ input_t x_vals_load[kNElts] = {0};
250
+ if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
251
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
252
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
253
+ }
254
+ reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
255
+ }
256
+ // Load the elements from the previous chunk that are needed for convolution.
257
+ if (l_idx < kWidth - 1) {
258
+ input_t x_vals_load[kNElts] = {0};
259
+ if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
260
+ && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
261
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
262
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
263
+ } else if (initial_states != nullptr
264
+ && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
265
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
266
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
267
+ }
268
+ reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
269
+ }
270
+
271
+ __syncthreads();
272
+
273
+ if (final_states != nullptr
274
+ && l_idx < kWidth - 1
275
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
276
+ // x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1)
277
+ // So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx]
278
+ *reinterpret_cast<vec_t *>(final_states) = reinterpret_cast<vec_t *>(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx];
279
+ }
280
+
281
+ constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
282
+ static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
283
+ constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
284
+ static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
285
+ // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
286
+ static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
287
+ static_assert((kLPerThread & (kLPerThread - 1)) == 0);
288
+ static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
289
+ static_assert(kNThreadsPerRow <= 32);
290
+
291
+ const int row_idx = tid / kNThreadsPerRow;
292
+ const int col_idx = tid % kNThreadsPerRow;
293
+
294
+ float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
295
+ float weight_vals[kWidth] = {0};
296
+ if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
297
+ #pragma unroll
298
+ for (int w = 0; w < kWidth; ++w) {
299
+ weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
300
+ }
301
+ }
302
+ float x_vals[kWidth - 1 + kLPerThread];
303
+ #pragma unroll
304
+ for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
305
+ x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
306
+ }
307
+ int seq_idx_thread[kWidth - 1 + kLPerThread];
308
+ if constexpr (kHasSeqIdx) {
309
+ #pragma unroll
310
+ for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
311
+ seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
312
+ }
313
+ }
314
+
315
+ float out_vals[kLPerThread];
316
+ #pragma unroll
317
+ for (int i = 0; i < kLPerThread; ++i) {
318
+ out_vals[i] = bias_val;
319
+ const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
320
+ #pragma unroll
321
+ for (int w = 0; w < kWidth; ++w) {
322
+ if constexpr (!kHasSeqIdx) {
323
+ out_vals[i] += weight_vals[w] * x_vals[i + w];
324
+ } else {
325
+ out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
326
+ }
327
+ }
328
+ if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
329
+ }
330
+
331
+ __syncthreads();
332
+ #pragma unroll
333
+ for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
334
+ __syncthreads();
335
+
336
+ #pragma unroll
337
+ for (int l = 0; l < Ktraits::kNLoads; ++l) {
338
+ input_t out_vals_store[kNElts];
339
+ reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
340
+ if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
341
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
342
+ *reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
343
+ }
344
+ }
345
+
346
+ }
347
+
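A small numeric check (Python, hypothetical values) of the final-states shared-memory indexing in the kernel above: with kWidth = 4, kChunkSizeL = 64 and seqlen = 130, the last L-chunk (chunk_l_id = 2) starts at token 128, so x_smem[0] holds token 125 and final-state entry l_idx needs token seqlen - (kWidth - 1) + l_idx, i.e. shared-memory row seqlen + l_idx - chunk_l_id * kChunkSizeL.

kWidth, kChunkSizeL, seqlen, chunk_l_id = 4, 64, 130, 2
for l_idx in range(kWidth - 1):
    token = seqlen - (kWidth - 1) + l_idx                          # tokens 127, 128, 129
    smem_row = token - (chunk_l_id * kChunkSizeL - (kWidth - 1))   # offset from x_smem[0]
    assert smem_row == seqlen + l_idx - chunk_l_id * kChunkSizeL   # rows 2, 3, 4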
348
+ template<int kNThreads, int kWidth, typename input_t, typename weight_t>
349
+ void causal_conv1d_channellast_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
350
+ BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
351
+ using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
352
+ // constexpr int kSmemSize = Ktraits::kSmemSize;
353
+ constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
354
+ constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
355
+ const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
356
+ const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
357
+ dim3 grid(params.batch, n_chunks_L, n_chunks_C);
358
+ dim3 block(Ktraits::kNThreads);
359
+ auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits, kHasSeqIdx>;
360
+ // if (kSmemSize >= 48 * 1024) {
361
+ // C10_CUDA_CHECK(cudaFuncSetAttribute(
362
+ // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
363
+ // }
364
+ // kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
365
+ kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
366
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
367
+ });
368
+ }
369
+
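For reference, a sketch (Python, hypothetical test-like shapes) of the grid sizing done by the launch above, assuming batch = 2, seqlen = 1134, dim = 4096 + 32 and a 16-bit input type, so kChunkSizeL = 64 and kChunkSizeC = kNEltsPerRow = 64.

import math
batch, seqlen, dim = 2, 1134, 4096 + 32
kChunkSizeL, kChunkSizeC = 64, 64
n_chunks_L = math.ceil(seqlen / kChunkSizeL)   # 18
n_chunks_C = math.ceil(dim / kChunkSizeC)      # 65
grid = (batch, n_chunks_L, n_chunks_C)
assert grid == (2, 18, 65)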
370
+ template<typename input_t, typename weight_t>
371
+ void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
372
+ if (params.width == 2) {
373
+ causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
374
+ } else if (params.width == 3) {
375
+ causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
376
+ } else if (params.width == 4) {
377
+ causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
378
+ }
379
+ }
380
+
381
+ template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
382
+ template void causal_conv1d_fwd_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
383
+ template void causal_conv1d_fwd_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
384
+ template void causal_conv1d_fwd_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
385
+ template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
386
+ template void causal_conv1d_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
387
+ template void causal_conv1d_fwd_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
388
+ template void causal_conv1d_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
389
+ template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
390
+
391
+ template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
392
+ template void causal_conv1d_channellast_fwd_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
393
+ template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
394
+ template void causal_conv1d_channellast_fwd_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
395
+ template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
396
+ template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
397
+ template void causal_conv1d_channellast_fwd_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
398
+ template void causal_conv1d_channellast_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
399
+ template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
causal-conv1d/csrc/causal_conv1d_update.cu ADDED
@@ -0,0 +1,130 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include <c10/util/BFloat16.h>
6
+ #include <c10/util/Half.h>
7
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
8
+
9
+ #include "causal_conv1d.h"
10
+ #include "causal_conv1d_common.h"
11
+ #include "static_switch.h"
12
+
13
+ template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
14
+ struct Causal_conv1d_update_kernel_traits {
15
+ using input_t = input_t_;
16
+ using weight_t = weight_t_;
17
+ static constexpr int kNThreads = kNThreads_;
18
+ static constexpr int kWidth = kWidth_;
19
+ static constexpr int kNBytes = sizeof(input_t);
20
+ static_assert(kNBytes == 2 || kNBytes == 4);
21
+ };
22
+
23
+ template<typename Ktraits, bool kIsCircularBuffer>
24
+ __global__ __launch_bounds__(Ktraits::kNThreads)
25
+ void causal_conv1d_update_kernel(ConvParamsBase params) {
26
+ constexpr int kWidth = Ktraits::kWidth;
27
+ constexpr int kNThreads = Ktraits::kNThreads;
28
+ using input_t = typename Ktraits::input_t;
29
+ using weight_t = typename Ktraits::weight_t;
30
+
31
+ const int tidx = threadIdx.x;
32
+ const int batch_id = blockIdx.x;
33
+ const int channel_id = blockIdx.y * kNThreads + tidx;
34
+ if (channel_id >= params.dim) return;
35
+
36
+ input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
37
+ + channel_id * params.x_c_stride;
38
+ input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
39
+ + channel_id * params.conv_state_c_stride;
40
+ weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
41
+ input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
42
+ + channel_id * params.out_c_stride;
43
+ float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
44
+
45
+ int state_len = params.conv_state_len;
46
+ int advance_len = params.seqlen;
47
+ int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
48
+ int update_idx = cache_seqlen - (kWidth - 1);
49
+ update_idx = update_idx < 0 ? update_idx + state_len : update_idx;
50
+
51
+ float weight_vals[kWidth] = {0};
52
+ #pragma unroll
53
+ for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
54
+
55
+ float x_vals[kWidth] = {0};
56
+ if constexpr (!kIsCircularBuffer) {
57
+ #pragma unroll 2
58
+ for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
59
+ conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
60
+ }
61
+ #pragma unroll
62
+ for (int i = 0; i < kWidth - 1; ++i) {
63
+ input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
64
+ if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
65
+ conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
66
+ }
67
+ x_vals[i] = float(state_val);
68
+ }
69
+ } else {
70
+ #pragma unroll
71
+ for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
72
+ input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
73
+ x_vals[i] = float(state_val);
74
+ }
75
+ }
76
+ #pragma unroll 2
77
+ for (int i = 0; i < params.seqlen; ++i) {
78
+ input_t x_val = x[i * params.x_l_stride];
79
+ if constexpr (!kIsCircularBuffer) {
80
+ if (i < advance_len && state_len - advance_len + i >= 0) {
81
+ conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
82
+ }
83
+ } else {
84
+ conv_state[update_idx * params.conv_state_l_stride] = x_val;
85
+ ++update_idx;
86
+ update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
87
+ }
88
+ x_vals[kWidth - 1] = float(x_val);
89
+ float out_val = bias_val;
90
+ #pragma unroll
91
+ for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
92
+ if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
93
+ out[i * params.out_l_stride] = input_t(out_val);
94
+ // Shift the input buffer by 1
95
+ #pragma unroll
96
+ for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
97
+ }
98
+ }
99
+
100
+ template<int kNThreads, int kWidth, typename input_t, typename weight_t>
101
+ void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
102
+ using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
103
+ dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
104
+ auto kernel = params.cache_seqlens == nullptr
105
+ ? &causal_conv1d_update_kernel<Ktraits, false>
106
+ : &causal_conv1d_update_kernel<Ktraits, true>;
107
+ kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
108
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
109
+ }
110
+
111
+ template<typename input_t, typename weight_t>
112
+ void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
113
+ if (params.width == 2) {
114
+ causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
115
+ } else if (params.width == 3) {
116
+ causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
117
+ } else if (params.width == 4) {
118
+ causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
119
+ }
120
+ }
121
+
122
+ template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
123
+ template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
124
+ template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
125
+ template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
126
+ template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
127
+ template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
128
+ template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
129
+ template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
130
+ template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
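A scalar reference (Python sketch, not the library API) of the circular-buffer branch of the update kernel above, for a single channel and assuming state_len >= width: `state` holds state_len ring slots, `cache_seqlen` counts tokens already written, each new x is written at the current ring position, and the convolution reads the previous width - 1 entries. The function name is hypothetical.

import math

def conv1d_update_circular_ref(state, cache_seqlen, xs, weight, bias=0.0, silu=False):
    # state: list of length state_len (one channel's rolling cache)
    state_len, width = len(state), len(weight)
    out = []
    pos = cache_seqlen % state_len                # ring slot for the next token
    for x in xs:
        state[pos] = x                            # write the new token into the ring
        acc = bias + sum(weight[j] * state[(pos - (width - 1) + j) % state_len]
                         for j in range(width))   # causal window x_{t-width+1..t}
        if silu:
            acc = acc / (1.0 + math.exp(-acc))
        out.append(acc)
        pos = (pos + 1) % state_len
    return out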
causal-conv1d/csrc/static_switch.h ADDED
@@ -0,0 +1,25 @@
1
+ // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
2
+ // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
3
+
4
+ #pragma once
5
+
6
+ /// @param COND - a boolean expression to switch by
7
+ /// @param CONST_NAME - a name given for the constexpr bool variable.
8
+ /// @param ... - code to execute for true and false
9
+ ///
10
+ /// Usage:
11
+ /// ```
12
+ /// BOOL_SWITCH(flag, BoolConst, [&] {
13
+ /// some_function<BoolConst>(...);
14
+ /// });
15
+ /// ```
16
+ #define BOOL_SWITCH(COND, CONST_NAME, ...) \
17
+ [&] { \
18
+ if (COND) { \
19
+ static constexpr bool CONST_NAME = true; \
20
+ return __VA_ARGS__(); \
21
+ } else { \
22
+ static constexpr bool CONST_NAME = false; \
23
+ return __VA_ARGS__(); \
24
+ } \
25
+ }()
causal-conv1d/dist/causal_conv1d-1.4.0-py3.9.egg ADDED
Binary file (10 kB).
 
causal-conv1d/rocm_patch/rocm6_0.patch ADDED
@@ -0,0 +1,56 @@
1
+ --- /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h 2023-12-12 20:11:48.000000000 +0000
2
+ +++ rocm_update_files/amd_hip_bf16.h 2024-05-20 17:40:26.983349079 +0000
3
+ @@ -137,7 +137,7 @@
4
+ * \ingroup HIP_INTRINSIC_BFLOAT16_CONV
5
+ * \brief Converts float to bfloat16
6
+ */
7
+ -__HOST_DEVICE__ __hip_bfloat16 __float2bfloat16(float f) {
8
+ +__HOST_DEVICE__ static inline __hip_bfloat16 __float2bfloat16(float f) {
9
+ __hip_bfloat16 ret;
10
+ union {
11
+ float fp32;
12
+ @@ -181,7 +181,7 @@
13
+ * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
14
+ * \brief Converts and moves bfloat162 to float2
15
+ */
16
+ -__HOST_DEVICE__ float2 __bfloat1622float2(const __hip_bfloat162 a) {
17
+ +__HOST_DEVICE__ static inline float2 __bfloat1622float2(const __hip_bfloat162 a) {
18
+ return float2{__bfloat162float(a.x), __bfloat162float(a.y)};
19
+ }
20
+
21
+ @@ -209,7 +209,7 @@
22
+ * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
23
+ * \brief Convert double to __hip_bfloat16
24
+ */
25
+ -__HOST_DEVICE__ __hip_bfloat16 __double2bfloat16(const double a) {
26
+ +__HOST_DEVICE__ static inline __hip_bfloat16 __double2bfloat16(const double a) {
27
+ return __float2bfloat16((float)a);
28
+ }
29
+
30
+ @@ -217,7 +217,7 @@
31
+ * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
32
+ * \brief Convert float2 to __hip_bfloat162
33
+ */
34
+ -__HOST_DEVICE__ __hip_bfloat162 __float22bfloat162_rn(const float2 a) {
35
+ +__HOST_DEVICE__ static inline __hip_bfloat162 __float22bfloat162_rn(const float2 a) {
36
+ return __hip_bfloat162{__float2bfloat16(a.x), __float2bfloat16(a.y)};
37
+ }
38
+
39
+ @@ -247,7 +247,7 @@
40
+ * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
41
+ * \brief Converts high 16 bits of __hip_bfloat162 to float and returns the result
42
+ */
43
+ -__HOST_DEVICE__ float __high2float(const __hip_bfloat162 a) { return __bfloat162float(a.y); }
44
+ +__HOST_DEVICE__ static inline float __high2float(const __hip_bfloat162 a) { return __bfloat162float(a.y); }
45
+
46
+ /**
47
+ * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
48
+ @@ -275,7 +275,7 @@
49
+ * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
50
+ * \brief Converts low 16 bits of __hip_bfloat162 to float and returns the result
51
+ */
52
+ -__HOST_DEVICE__ float __low2float(const __hip_bfloat162 a) { return __bfloat162float(a.x); }
53
+ +__HOST_DEVICE__ static inline float __low2float(const __hip_bfloat162 a) { return __bfloat162float(a.x); }
54
+
55
+ /**
56
+ * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
causal-conv1d/setup.py ADDED
@@ -0,0 +1,296 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+
3
+ import sys
4
+ import warnings
5
+ import os
6
+ import re
7
+ import shutil
8
+ import ast
9
+ from pathlib import Path
10
+ from packaging.version import parse, Version
11
+ import platform
12
+
13
+ from setuptools import setup, find_packages
14
+ import subprocess
15
+
16
+ import urllib.request
17
+ import urllib.error
18
+ from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
19
+
20
+ import torch
21
+ from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME, HIP_HOME
22
+
23
+
24
+ with open("README.md", "r", encoding="utf-8") as fh:
25
+ long_description = fh.read()
26
+
27
+
28
+ # ninja build does not work unless include_dirs are abs path
29
+ this_dir = os.path.dirname(os.path.abspath(__file__))
30
+
31
+ PACKAGE_NAME = "causal_conv1d"
32
+
33
+ BASE_WHEEL_URL = "https://github.com/Dao-AILab/causal-conv1d/releases/download/{tag_name}/{wheel_name}"
34
+
35
+ # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
36
+ # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
37
+ FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "TRUE"
38
+ SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
39
+ # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
40
+ FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE"
41
+
42
+
43
+ def get_platform():
44
+ """
45
+ Returns the platform name as used in wheel filenames.
46
+ """
47
+ if sys.platform.startswith("linux"):
48
+ return "linux_x86_64"
49
+ elif sys.platform == "darwin":
50
+ mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
51
+ return f"macosx_{mac_version}_x86_64"
52
+ elif sys.platform == "win32":
53
+ return "win_amd64"
54
+ else:
55
+ raise ValueError("Unsupported platform: {}".format(sys.platform))
56
+
57
+
58
+
59
+
60
+ def get_hip_version(rocm_dir):
61
+
62
+ hipcc_bin = "hipcc" if rocm_dir is None else os.path.join(rocm_dir, "bin", "hipcc")
63
+ try:
64
+ raw_output = subprocess.check_output(
65
+ [hipcc_bin, "--version"], universal_newlines=True
66
+ )
67
+ except Exception as e:
68
+ print(
69
+ f"hip installation not found: {e} ROCM_PATH={os.environ.get('ROCM_PATH')}"
70
+ )
71
+ return None, None
72
+
73
+ for line in raw_output.split("\n"):
74
+ if "HIP version" in line:
75
+ rocm_version = parse(line.split()[-1].replace("-", "+")) # local version is not parsed correctly
76
+ return line, rocm_version
77
+
78
+ return None, None
79
+
80
+
81
+ def get_torch_hip_version():
82
+ if torch.version.hip:
83
+ return parse(torch.version.hip.split()[-1].replace("-", "+"))
84
+ else:
85
+ return None
86
+
87
+
88
+ def check_if_hip_home_none(global_option: str) -> None:
89
+
90
+ if HIP_HOME is not None:
91
+ return
92
+ # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary
93
+ # in that case.
94
+ warnings.warn(
95
+ f"{global_option} was requested, but hipcc was not found. Are you sure your environment has hipcc available?"
96
+ )
97
+
98
+
99
+ def check_if_cuda_home_none(global_option: str) -> None:
100
+ if CUDA_HOME is not None:
101
+ return
102
+ # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
103
+ # in that case.
104
+ warnings.warn(
105
+ f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
106
+ "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
107
+ "only images whose names contain 'devel' will provide nvcc."
108
+ )
109
+
110
+
111
+ def append_nvcc_threads(nvcc_extra_args):
112
+ return nvcc_extra_args + ["--threads", "4"]
113
+
114
+
115
+ cmdclass = {}
116
+ ext_modules = []
117
+
118
+
119
+ HIP_BUILD = bool(torch.version.hip)
120
+
121
+ if not SKIP_CUDA_BUILD:
122
+
123
+ print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
124
+ TORCH_MAJOR = int(torch.__version__.split(".")[0])
125
+ TORCH_MINOR = int(torch.__version__.split(".")[1])
126
+
127
+
128
+ cc_flag = []
129
+
130
+ if HIP_BUILD:
131
+ check_if_hip_home_none(PACKAGE_NAME)
132
+
133
+ rocm_home = os.getenv("ROCM_PATH")
134
+ _, hip_version = get_hip_version(rocm_home)
135
+
136
+
137
+ if HIP_HOME is not None:
138
+ if hip_version < Version("6.0"):
139
+ raise RuntimeError(
140
+ f"{PACKAGE_NAME} is only supported on ROCm 6.0 and above. "
141
+ "Note: make sure HIP has a supported version by running hipcc --version."
142
+ )
143
+ if hip_version == Version("6.0"):
144
+ warnings.warn(
145
+ f"{PACKAGE_NAME} requires a patch to be applied when running on ROCm 6.0. "
146
+ "Refer to the README.md for detailed instructions.",
147
+ UserWarning
148
+ )
149
+
150
+ cc_flag.append("-DBUILD_PYTHON_PACKAGE")
151
+
152
+ else:
153
+ cc_flag.append("-gencode")
154
+ cc_flag.append("arch=compute_53,code=sm_53")
155
+ cc_flag.append("-gencode")
156
+ cc_flag.append("arch=compute_62,code=sm_62")
157
+ cc_flag.append("-gencode")
158
+ cc_flag.append("arch=compute_70,code=sm_70")
159
+ cc_flag.append("-gencode")
160
+ cc_flag.append("arch=compute_72,code=sm_72")
161
+ cc_flag.append("-gencode")
162
+ cc_flag.append("arch=compute_80,code=sm_80")
163
+ cc_flag.append("-gencode")
164
+ cc_flag.append("arch=compute_87,code=sm_87")
165
+
166
+ # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
167
+ # torch._C._GLIBCXX_USE_CXX11_ABI
168
+ # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
169
+ if FORCE_CXX11_ABI:
170
+ torch._C._GLIBCXX_USE_CXX11_ABI = True
171
+
172
+
173
+ if HIP_BUILD:
174
+ extra_compile_args = {
175
+ "cxx": ["-O3", "-std=c++17"],
176
+ }
177
+ else:
178
+ extra_compile_args = {
179
+ "cxx": ["-O3"],
180
+ }
181
+
182
+
183
+ def get_package_version():
184
+ with open(Path(this_dir) / "causal_conv1d" / "__init__.py", "r") as f:
185
+ version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
186
+ public_version = ast.literal_eval(version_match.group(1))
187
+ local_version = os.environ.get("CAUSAL_CONV1D_LOCAL_VERSION")
188
+ if local_version:
189
+ return f"{public_version}+{local_version}"
190
+ else:
191
+ return str(public_version)
192
+
193
+
194
+ def get_wheel_url():
195
+
196
+ # Determine the version numbers that will be used to determine the correct wheel
197
+ torch_version_raw = parse(torch.__version__)
198
+
199
+ if HIP_BUILD:
200
+ # We're using the HIP version used to build torch, not the one currently installed
201
+ torch_hip_version = get_torch_hip_version()
202
+ hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}"
203
+
204
+ gpu_compute_version = hip_version if HIP_BUILD else cuda_version
205
+ cuda_or_hip = "hip" if HIP_BUILD else "cu"
206
+
207
+ python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
208
+ platform_name = get_platform()
209
+ causal_conv1d_version = get_package_version()
210
+
211
+ torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
212
+ cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
213
+
214
+ # Determine wheel URL based on CUDA version, torch version, python version and OS
215
+ wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+{cuda_or_hip}{gpu_compute_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
216
+
217
+ wheel_url = BASE_WHEEL_URL.format(
218
+ tag_name=f"v{causal_conv1d_version}", wheel_name=wheel_filename
219
+ )
220
+ return wheel_url, wheel_filename
221
+
222
+
223
+ class CachedWheelsCommand(_bdist_wheel):
224
+ """
225
+ The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot
226
+ find an existing wheel (which is currently the case for all installs). We use
227
+ the environment parameters to detect whether there is already a pre-built version of a compatible
228
+ wheel available and short-circuit the standard full build pipeline.
229
+ """
230
+
231
+ def run(self):
232
+ if FORCE_BUILD:
233
+ return super().run()
234
+
235
+ wheel_url, wheel_filename = get_wheel_url()
236
+ print("Guessing wheel URL: ", wheel_url)
237
+ try:
238
+ urllib.request.urlretrieve(wheel_url, wheel_filename)
239
+
240
+ # Make the archive
241
+ # Lifted from the root wheel processing command
242
+ # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
243
+ if not os.path.exists(self.dist_dir):
244
+ os.makedirs(self.dist_dir)
245
+
246
+ impl_tag, abi_tag, plat_tag = self.get_tag()
247
+ archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
248
+
249
+ wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
250
+ print("Raw wheel path", wheel_path)
251
+ shutil.move(wheel_filename, wheel_path)
252
+ except urllib.error.HTTPError:
253
+ print("Precompiled wheel not found. Building from source...")
254
+ # If the wheel could not be downloaded, build from source
255
+ super().run()
256
+
257
+
258
+ setup(
259
+ name=PACKAGE_NAME,
260
+ version=get_package_version(),
261
+ packages=find_packages(
262
+ exclude=(
263
+ "build",
264
+ "csrc",
265
+ "include",
266
+ "tests",
267
+ "dist",
268
+ "docs",
269
+ "benchmarks",
270
+ "causal_conv1d.egg-info",
271
+ )
272
+ ),
273
+ author="Tri Dao",
274
+ author_email="tri@tridao.me",
275
+ description="Causal depthwise conv1d in CUDA, with a PyTorch interface",
276
+ long_description=long_description,
277
+ long_description_content_type="text/markdown",
278
+ url="https://github.com/Dao-AILab/causal-conv1d",
279
+ classifiers=[
280
+ "Programming Language :: Python :: 3",
281
+ "License :: OSI Approved :: BSD License",
282
+ "Operating System :: Unix",
283
+ ],
284
+ ext_modules=ext_modules,
285
+ cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
286
+ if ext_modules
287
+ else {
288
+ "bdist_wheel": CachedWheelsCommand,
289
+ },
290
+ python_requires=">=3.8",
291
+ install_requires=[
292
+ "torch",
293
+ "packaging",
294
+ "ninja",
295
+ ],
296
+ )
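A hedged illustration (Python) of the wheel naming produced by get_wheel_url above, assuming a HIP build of torch 2.3 against ROCm 6.2, CPython 3.10 on linux_x86_64 and package version 1.4.0; all concrete values below are assumptions, and the actual tags depend on the local toolchain.

PACKAGE_NAME, version = "causal_conv1d", "1.4.0"
cuda_or_hip, gpu_compute_version = "hip", "62"          # f"{hip.major}{hip.minor}" for ROCm 6.2
torch_version, cxx11_abi = "2.3", "FALSE"               # torch major.minor, _GLIBCXX_USE_CXX11_ABI
python_version, platform_name = "cp310", "linux_x86_64"
wheel_filename = (f"{PACKAGE_NAME}-{version}+{cuda_or_hip}{gpu_compute_version}"
                  f"torch{torch_version}cxx11abi{cxx11_abi}"
                  f"-{python_version}-{python_version}-{platform_name}.whl")
# -> causal_conv1d-1.4.0+hip62torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl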
causal-conv1d/tests/test_causal_conv1d.py ADDED
@@ -0,0 +1,301 @@
1
+ # Copyright (C) 2024, Tri Dao.
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ import pytest
9
+
10
+ from einops import rearrange
11
+
12
+ from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_ref
13
+ from causal_conv1d.causal_conv1d_interface import causal_conv1d_update, causal_conv1d_update_ref
14
+ from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states, causal_conv1d_varlen_states_ref
15
+
16
+
17
+ @pytest.mark.parametrize("return_final_states", [False, True])
18
+ # @pytest.mark.parametrize("return_final_states", [True])
19
+ @pytest.mark.parametrize("has_initial_states", [False, True])
20
+ # @pytest.mark.parametrize("has_initial_states", [False])
21
+ @pytest.mark.parametrize("channel_last", [False, True])
22
+ # @pytest.mark.parametrize('channel_last', [True])
23
+ @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
24
+ # @pytest.mark.parametrize('itype', [torch.float16])
25
+ @pytest.mark.parametrize("silu_activation", [False, True])
26
+ # @pytest.mark.parametrize('silu_activation', [True])
27
+ @pytest.mark.parametrize("has_bias", [False, True])
28
+ # @pytest.mark.parametrize('has_bias', [True])
29
+ @pytest.mark.parametrize("width", [2, 3, 4])
30
+ # @pytest.mark.parametrize('width', [3])
31
+ @pytest.mark.parametrize(
32
+ "seqlen", [1, 2, 8, 16, 32, 64, 128, 129, 130, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
33
+ )
34
+ # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
35
+ # @pytest.mark.parametrize('seqlen', [128])
36
+ @pytest.mark.parametrize('dim', [64, 4096 + 32])
37
+ # @pytest.mark.parametrize('dim', [64])
38
+ def test_causal_conv1d(dim, seqlen, width, has_bias, silu_activation, itype, channel_last, has_initial_states, return_final_states):
39
+ if not channel_last and (has_initial_states or return_final_states):
40
+ pytest.skip("Only channel_last support initial_states or return_final_states")
41
+ device = "cuda"
42
+ rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
43
+ if itype == torch.bfloat16:
44
+ rtol, atol = 1e-2, 5e-2
45
+ rtolw, atolw = (1e-3, 1e-3)
46
+ # set seed
47
+ torch.random.manual_seed(0)
48
+ batch = 2
49
+ # batch = 1
50
+ if not channel_last:
51
+ x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
52
+ else:
53
+ x = rearrange(
54
+ torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
55
+ ).requires_grad_()
56
+ weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
57
+ if has_bias:
58
+ bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
59
+ else:
60
+ bias = None
61
+ if has_initial_states:
62
+ initial_states = torch.randn(batch, width - 1, dim, device=device, dtype=itype).transpose(1, 2).requires_grad_()
63
+ else:
64
+ initial_states = None
65
+ x_ref = x.detach().clone().requires_grad_()
66
+ weight_ref = weight.detach().clone().requires_grad_()
67
+ bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
68
+ initial_states_ref = initial_states.detach().clone().requires_grad_() if initial_states is not None else None
69
+ activation = None if not silu_activation else "silu"
70
+ out = causal_conv1d_fn(x, weight, bias, initial_states=initial_states, return_final_states=return_final_states,
71
+ activation=activation)
72
+ out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, initial_states=initial_states_ref, return_final_states=return_final_states, activation=activation)
73
+ if return_final_states:
74
+ out, final_states = out
75
+ out_ref, final_states_ref = out_ref
76
+ print(f"Final states max diff: {(final_states - final_states_ref).abs().max().item()}")
77
+ print(f"Final states mean diff: {(final_states - final_states_ref).abs().mean().item()}")
78
+ assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol)
79
+
80
+ print(f"Output max diff: {(out - out_ref).abs().max().item()}")
81
+ print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
82
+ assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
83
+
84
+ if return_final_states:
85
+ out += F.sigmoid(final_states).sum(dim=-1, keepdim=True)
86
+ out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True)
87
+
88
+ g = torch.randn_like(out)
89
+ out.backward(g)
90
+ out_ref.backward(g)
91
+
92
+ print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
93
+ print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
94
+ if has_bias:
95
+ print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")
96
+ if has_initial_states:
97
+ print(f"dinitial_states max diff: {(initial_states.grad - initial_states_ref.grad).abs().max().item()}")
98
+
99
+ assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
100
+ assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)
101
+ if has_bias:
102
+ assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)
103
+ if has_initial_states:
104
+ assert torch.allclose(initial_states.grad, initial_states_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
105
+
106
+
107
+ @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
108
+ # @pytest.mark.parametrize('itype', [torch.float16])
109
+ @pytest.mark.parametrize("silu_activation", [False, True])
110
+ # @pytest.mark.parametrize('silu_activation', [True])
111
+ @pytest.mark.parametrize("has_bias", [False, True])
112
+ # @pytest.mark.parametrize('has_bias', [True])
113
+ @pytest.mark.parametrize("has_cache_seqlens", [False, True])
114
+ # @pytest.mark.parametrize('has_cache_seqlens', [True])
115
+ @pytest.mark.parametrize("seqlen", [1, 4, 5])
116
+ # @pytest.mark.parametrize('seqlen', [4])
117
+ @pytest.mark.parametrize("width", [2, 3, 4])
118
+ # @pytest.mark.parametrize('width', [4])
119
+ @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
120
+ # @pytest.mark.parametrize("dim", [2048])
121
+ def test_causal_conv1d_update(dim, width, seqlen, has_cache_seqlens, has_bias, silu_activation, itype):
122
+ device = "cuda"
123
+ rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
124
+ if itype == torch.bfloat16:
125
+ rtol, atol = 1e-2, 5e-2
126
+ rtolw, atolw = (1e-3, 1e-3)
127
+ # set seed
128
+ torch.random.manual_seed(0)
129
+ batch = 64
130
+ # batch = 1
131
+ # dim = 64
132
+ x = torch.randn(batch, seqlen, dim, device=device, dtype=itype).transpose(-1, -2)
133
+ state_len = torch.randint(width - 1, width + 10, (1,)).item()
134
+ conv_state = torch.randn(batch, state_len, dim, device=device, dtype=itype).transpose(-1, -2)
135
+ weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
136
+ if has_bias:
137
+ bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
138
+ else:
139
+ bias = None
140
+ conv_state_ref = conv_state.detach().clone()
141
+ activation = None if not silu_activation else "silu"
142
+ cache_seqlens = (torch.randint(0, 1024, (batch,), dtype=torch.int32, device=device)
143
+ if has_cache_seqlens else None)
144
+ out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation, cache_seqlens=cache_seqlens)
145
+ out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation, cache_seqlens=cache_seqlens)
146
+
147
+ print(f"Output max diff: {(out - out_ref).abs().max().item()}")
148
+ print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
149
+ assert torch.equal(conv_state, conv_state_ref)
150
+ assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
151
+
152
+
153
+ @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
154
+ # @pytest.mark.parametrize('itype', [torch.float16])
155
+ @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
156
+ # @pytest.mark.parametrize("dim", [2048])
157
+ def test_causal_conv1d_get_states(dim, itype):
158
+ device = "cuda"
159
+ # set seed
160
+ torch.random.manual_seed(0)
161
+ seqlens = torch.randint(1, 32, (100,), device=device)
162
+ total_seqlen = seqlens.sum().item()
163
+ x = torch.randn(total_seqlen, dim, device=device, dtype=itype)
164
+ cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0))
165
+ state_len = 20
166
+ out = causal_conv1d_varlen_states(x, cu_seqlens, state_len)
167
+ out_ref = causal_conv1d_varlen_states_ref(x, cu_seqlens, state_len)
168
+ assert torch.equal(out, out_ref)
169
+
170
+
171
+ # @pytest.mark.parametrize("channel_last", [False, True])
172
+ @pytest.mark.parametrize('channel_last', [True])
173
+ # @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
174
+ @pytest.mark.parametrize('itype', [torch.bfloat16])
175
+ # @pytest.mark.parametrize("silu_activation", [False, True])
176
+ @pytest.mark.parametrize('silu_activation', [True])
177
+ # @pytest.mark.parametrize("has_bias", [False, True])
178
+ @pytest.mark.parametrize('has_bias', [True])
179
+ # @pytest.mark.parametrize("width", [2, 3, 4])
180
+ @pytest.mark.parametrize('width', [4])
181
+ @pytest.mark.parametrize(
182
+ # "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
183
+ "seqlen", [2048]
184
+ )
185
+ # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
186
+ # @pytest.mark.parametrize('seqlen', [128])
187
+ def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last):
188
+ device = "cuda"
189
+ # set seed
190
+ torch.random.manual_seed(0)
191
+ batch = 2
192
+ # batch = 1
193
+ dim = 4096 + 32 # Try dim not divisible by 64
194
+ # dim = 64
195
+ if not channel_last:
196
+ x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
197
+ else:
198
+ x = rearrange(
199
+ torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
200
+ ).requires_grad_()
201
+ weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
202
+ if has_bias:
203
+ bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
204
+ else:
205
+ bias = None
206
+ activation = None if not silu_activation else "silu"
207
+ out0 = causal_conv1d_fn(x, weight, bias, activation=activation)
208
+ g = torch.randn_like(out0)
209
+ dx0, dw0, db0 = torch.autograd.grad(out0, (x, weight, bias), g)
210
+ dw_atol = 1e-4
211
+ db_atol = 1e-4
212
+
213
+ for i in range(10000):
214
+ out = causal_conv1d_fn(x, weight, bias, activation=activation)
215
+ dx, dw, db = torch.autograd.grad(out, (x, weight, bias), g)
216
+ dw_equal = torch.allclose(dw, dw0, atol=dw_atol)
217
+ # if not dw_equal:
218
+ # breakpoint()
219
+ if has_bias:
220
+ db_equal = torch.allclose(db, db0, atol=db_atol)
221
+ # if not db_equal:
222
+ # breakpoint()
223
+ assert torch.equal(out, out0)
224
+ assert torch.equal(dx, dx0)
225
+ assert dw_equal
226
+ if has_bias:
227
+ assert db_equal
228
+
229
+
230
+ @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
231
+ # @pytest.mark.parametrize('itype', [torch.float16])
232
+ @pytest.mark.parametrize("silu_activation", [False, True])
233
+ # @pytest.mark.parametrize('silu_activation', [False])
234
+ @pytest.mark.parametrize("has_bias", [False, True])
235
+ # @pytest.mark.parametrize('has_bias', [False])
236
+ @pytest.mark.parametrize("width", [2, 3, 4])
237
+ # @pytest.mark.parametrize('width', [2])
238
+ @pytest.mark.parametrize(
239
+ "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
240
+ )
241
+ # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
242
+ # @pytest.mark.parametrize('seqlen', [2048])
243
+ @pytest.mark.parametrize('dim', [64, 4096 + 32])
244
+ # @pytest.mark.parametrize('dim', [64])
245
+ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, itype):
246
+ device = "cuda"
247
+ rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
248
+ if itype == torch.bfloat16:
249
+ rtol, atol = 1e-2, 5e-2
250
+ rtolw, atolw = (1e-3, 1e-3)
251
+ # set seed
252
+ torch.random.manual_seed(seqlen + dim + width)
253
+ batch = 3
254
+ seqlens = []
255
+ for b in range(batch):
256
+ nsplits = torch.randint(1, 5, (1,)).item()
257
+ eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
258
+ seqlens.append(torch.diff(torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])).tolist())
259
+ assert sum(seqlens[-1]) == seqlen
260
+ assert all(s > 0 for s in seqlens[-1])
261
+ # Only support channel_last
262
+ x = rearrange(
263
+ torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
264
+ ).requires_grad_()
265
+ weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
266
+ if has_bias:
267
+ bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
268
+ else:
269
+ bias = None
270
+ seq_idx = torch.stack([torch.cat([torch.full((s,), i, dtype=torch.int32, device=device) for i, s in enumerate(sl)], dim=0)
271
+ for sl in seqlens], dim=0)
272
+ x_ref = x.detach().clone().requires_grad_()
273
+ weight_ref = weight.detach().clone().requires_grad_()
274
+ bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
275
+ activation = None if not silu_activation else "silu"
276
+ out = causal_conv1d_fn(x, weight, bias, seq_idx=seq_idx, activation=activation)
277
+ out_ref = []
278
+ for b in range(batch):
279
+ out_ref_b = []
280
+ for x_s in torch.split(x_ref[[b]], seqlens[b], dim=2):
281
+ out_ref_b.append(causal_conv1d_ref(x_s, weight_ref, bias_ref, activation=activation))
282
+ out_ref.append(torch.cat(out_ref_b, dim=2))
283
+ out_ref = torch.cat(out_ref, dim=0)
284
+
285
+ print(f"Output max diff: {(out - out_ref).abs().max().item()}")
286
+ print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
287
+ assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
288
+
289
+ g = torch.randn_like(out)
290
+ out_ref.backward(g)
291
+ out.backward(g)
292
+
293
+ print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
294
+ print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
295
+ if has_bias:
296
+ print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")
297
+
298
+ assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
299
+ assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)
300
+ if has_bias:
301
+ assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)