Spaces:
Running
on
Zero
Running
on
Zero
ZeroGPU
#2
by
hysts
HF staff
- opened
This view is limited to 50 files because it contains too many changes.
See the raw diff here.
- app.py +22 -35
- causal-conv1d/AUTHORS +0 -1
- causal-conv1d/LICENSE +0 -29
- causal-conv1d/README.md +0 -1
- causal-conv1d/causal_conv1d/__init__.py +0 -3
- causal-conv1d/causal_conv1d/causal_conv1d_interface.py +0 -104
- causal-conv1d/csrc/causal_conv1d.cpp +0 -333
- causal-conv1d/csrc/causal_conv1d.h +0 -53
- causal-conv1d/csrc/causal_conv1d_bwd.cu +0 -525
- causal-conv1d/csrc/causal_conv1d_common.h +0 -64
- causal-conv1d/csrc/causal_conv1d_fwd.cu +0 -350
- causal-conv1d/csrc/causal_conv1d_update.cu +0 -96
- causal-conv1d/csrc/static_switch.h +0 -25
- causal-conv1d/setup.py +0 -264
- causal-conv1d/tests/test_causal_conv1d.py +0 -173
- causal_conv1d-1.0.0-cp310-cp310-linux_x86_64.whl +3 -0
- install.sh +0 -3
- mamba/.gitmodules +0 -3
- mamba/AUTHORS +0 -2
- mamba/LICENSE +0 -201
- mamba/README.md +0 -149
- mamba/assets/selection.png +0 -0
- mamba/benchmarks/benchmark_generation_mamba_simple.py +0 -88
- mamba/csrc/selective_scan/reverse_scan.cuh +0 -401
- mamba/csrc/selective_scan/selective_scan.cpp +0 -497
- mamba/csrc/selective_scan/selective_scan.h +0 -101
- mamba/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu +0 -9
- mamba/csrc/selective_scan/selective_scan_bwd_bf16_real.cu +0 -9
- mamba/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu +0 -9
- mamba/csrc/selective_scan/selective_scan_bwd_fp16_real.cu +0 -9
- mamba/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu +0 -9
- mamba/csrc/selective_scan/selective_scan_bwd_fp32_real.cu +0 -9
- mamba/csrc/selective_scan/selective_scan_bwd_kernel.cuh +0 -531
- mamba/csrc/selective_scan/selective_scan_common.h +0 -221
- mamba/csrc/selective_scan/selective_scan_fwd_bf16.cu +0 -10
- mamba/csrc/selective_scan/selective_scan_fwd_fp16.cu +0 -10
- mamba/csrc/selective_scan/selective_scan_fwd_fp32.cu +0 -10
- mamba/csrc/selective_scan/selective_scan_fwd_kernel.cuh +0 -345
- mamba/csrc/selective_scan/static_switch.h +0 -25
- mamba/csrc/selective_scan/uninitialized_copy.cuh +0 -69
- mamba/evals/lm_harness_eval.py +0 -39
- mamba/mamba_ssm/__init__.py +0 -5
- mamba/mamba_ssm/models/__init__.py +0 -0
- mamba/mamba_ssm/models/mixer_seq_simple.py +0 -233
- mamba/mamba_ssm/modules/__init__.py +0 -0
- mamba/mamba_ssm/modules/mamba_simple.py +0 -418
- mamba/mamba_ssm/ops/__init__.py +0 -0
- mamba/mamba_ssm/ops/selective_scan_interface.py +0 -709
- mamba/mamba_ssm/ops/triton/__init__.py +0 -0
- mamba/mamba_ssm/ops/triton/layernorm.py +0 -636
app.py
CHANGED
@@ -1,15 +1,13 @@
|
|
1 |
-
import
|
2 |
-
|
|
|
3 |
import torch
|
4 |
|
5 |
-
os.system("nvidia-smi")
|
6 |
-
print("TORCH_CUDA", torch.cuda.is_available())
|
7 |
-
|
8 |
-
|
9 |
# install packages for mamba
|
10 |
def install():
|
11 |
print("Install personal packages", flush=True)
|
12 |
-
|
|
|
13 |
|
14 |
install()
|
15 |
|
@@ -25,7 +23,7 @@ from videomamba_video import videomamba_tiny
|
|
25 |
from kinetics_class_index import kinetics_classnames
|
26 |
from imagenet_class_index import imagenet_classnames
|
27 |
from transforms import (
|
28 |
-
GroupNormalize, GroupScale, GroupCenterCrop,
|
29 |
Stack, ToTorchFormatTensor
|
30 |
)
|
31 |
|
@@ -38,7 +36,7 @@ from huggingface_hub import hf_hub_download
|
|
38 |
device = "cuda"
|
39 |
model_video_path = hf_hub_download(repo_id="OpenGVLab/VideoMamba", filename="videomamba_t16_k400_f16_res224.pth")
|
40 |
model_image_path = hf_hub_download(repo_id="OpenGVLab/VideoMamba", filename="videomamba_t16_in1k_res224.pth")
|
41 |
-
# Pick a pretrained model
|
42 |
model_video = videomamba_tiny(num_classes=400, num_frames=16)
|
43 |
video_sd = torch.load(model_video_path, map_location='cpu')
|
44 |
model_video.load_state_dict(video_sd)
|
@@ -55,7 +53,7 @@ for k, v in kinetics_classnames.items():
|
|
55 |
kinetics_id_to_classname[k] = v
|
56 |
imagenet_id_to_classname = {}
|
57 |
for k, v in imagenet_classnames.items():
|
58 |
-
imagenet_id_to_classname[k] = v[1]
|
59 |
|
60 |
|
61 |
def get_index(num_frames, num_segments=8):
|
@@ -83,7 +81,7 @@ def load_video(video_path):
|
|
83 |
GroupCenterCrop(crop_size),
|
84 |
Stack(),
|
85 |
ToTorchFormatTensor(),
|
86 |
-
GroupNormalize(input_mean, input_std)
|
87 |
])
|
88 |
|
89 |
images_group = list()
|
@@ -92,28 +90,24 @@ def load_video(video_path):
|
|
92 |
images_group.append(img)
|
93 |
torch_imgs = transform(images_group)
|
94 |
return torch_imgs
|
95 |
-
|
96 |
|
97 |
-
|
|
|
98 |
def inference_video(video):
|
99 |
vid = load_video(video)
|
100 |
-
|
101 |
# The model expects inputs of shape: B x C x H x W
|
102 |
TC, H, W = vid.shape
|
103 |
inputs = vid.reshape(1, TC//3, 3, H, W).permute(0, 2, 1, 3, 4)
|
104 |
-
|
105 |
with torch.no_grad():
|
106 |
prediction = model_video(inputs.to(device))
|
107 |
prediction = F.softmax(prediction, dim=1).flatten()
|
108 |
|
109 |
return {kinetics_id_to_classname[str(i)]: float(prediction[i]) for i in range(400)}
|
110 |
-
|
111 |
-
|
112 |
-
def set_example_video(example: list) -> dict:
|
113 |
-
return gr.Video.update(value=example[0])
|
114 |
|
115 |
|
116 |
-
|
117 |
def inference_image(img):
|
118 |
image = img
|
119 |
image_transform = T.Compose(
|
@@ -125,10 +119,10 @@ def inference_image(img):
|
|
125 |
]
|
126 |
)
|
127 |
image = image_transform(image)
|
128 |
-
|
129 |
# The model expects inputs of shape: B x C x H x W
|
130 |
image = image.unsqueeze(0)
|
131 |
-
|
132 |
with torch.no_grad():
|
133 |
prediction = model_image(image.to(device))
|
134 |
prediction = F.softmax(prediction, dim=1).flatten()
|
@@ -136,10 +130,6 @@ def inference_image(img):
|
|
136 |
return {imagenet_id_to_classname[str(i)]: float(prediction[i]) for i in range(1000)}
|
137 |
|
138 |
|
139 |
-
def set_example_image(example: list) -> dict:
|
140 |
-
return gr.Image.update(value=example[0])
|
141 |
-
|
142 |
-
|
143 |
demo = gr.Blocks()
|
144 |
with demo:
|
145 |
gr.Markdown(
|
@@ -154,26 +144,26 @@ with demo:
|
|
154 |
with gr.Row():
|
155 |
with gr.Column():
|
156 |
with gr.Row():
|
157 |
-
input_video = gr.Video(label='Input Video'
|
158 |
with gr.Row():
|
159 |
submit_video_button = gr.Button('Submit')
|
160 |
with gr.Column():
|
161 |
label_video = gr.Label(num_top_classes=5)
|
162 |
with gr.Row():
|
163 |
-
|
164 |
-
|
165 |
with gr.Tab("Image"):
|
166 |
# with gr.Box():
|
167 |
with gr.Row():
|
168 |
with gr.Column():
|
169 |
with gr.Row():
|
170 |
-
input_image = gr.Image(label='Input Image', type='pil'
|
171 |
with gr.Row():
|
172 |
submit_image_button = gr.Button('Submit')
|
173 |
with gr.Column():
|
174 |
label_image = gr.Label(num_top_classes=5)
|
175 |
with gr.Row():
|
176 |
-
|
177 |
|
178 |
gr.Markdown(
|
179 |
"""
|
@@ -182,9 +172,6 @@ with demo:
|
|
182 |
)
|
183 |
|
184 |
submit_video_button.click(fn=inference_video, inputs=input_video, outputs=label_video)
|
185 |
-
example_videos.click(fn=set_example_video, inputs=example_videos, outputs=example_videos._components)
|
186 |
submit_image_button.click(fn=inference_image, inputs=input_image, outputs=label_image)
|
187 |
-
example_images.click(fn=set_example_image, inputs=example_images, outputs=example_images._components)
|
188 |
|
189 |
-
demo.
|
190 |
-
# demo.launch(server_name="0.0.0.0", server_port=10034, enable_queue=True)
|
|
|
1 |
+
import shlex
|
2 |
+
import subprocess
|
3 |
+
import spaces
|
4 |
import torch
|
5 |
|
|
|
|
|
|
|
|
|
6 |
# install packages for mamba
|
7 |
def install():
|
8 |
print("Install personal packages", flush=True)
|
9 |
+
subprocess.run(shlex.split("pip install causal_conv1d-1.0.0-cp310-cp310-linux_x86_64.whl"))
|
10 |
+
subprocess.run(shlex.split("pip install mamba_ssm-1.0.1-cp310-cp310-linux_x86_64.whl"))
|
11 |
|
12 |
install()
|
13 |
|
|
|
23 |
from kinetics_class_index import kinetics_classnames
|
24 |
from imagenet_class_index import imagenet_classnames
|
25 |
from transforms import (
|
26 |
+
GroupNormalize, GroupScale, GroupCenterCrop,
|
27 |
Stack, ToTorchFormatTensor
|
28 |
)
|
29 |
|
|
|
36 |
device = "cuda"
|
37 |
model_video_path = hf_hub_download(repo_id="OpenGVLab/VideoMamba", filename="videomamba_t16_k400_f16_res224.pth")
|
38 |
model_image_path = hf_hub_download(repo_id="OpenGVLab/VideoMamba", filename="videomamba_t16_in1k_res224.pth")
|
39 |
+
# Pick a pretrained model
|
40 |
model_video = videomamba_tiny(num_classes=400, num_frames=16)
|
41 |
video_sd = torch.load(model_video_path, map_location='cpu')
|
42 |
model_video.load_state_dict(video_sd)
|
|
|
53 |
kinetics_id_to_classname[k] = v
|
54 |
imagenet_id_to_classname = {}
|
55 |
for k, v in imagenet_classnames.items():
|
56 |
+
imagenet_id_to_classname[k] = v[1]
|
57 |
|
58 |
|
59 |
def get_index(num_frames, num_segments=8):
|
|
|
81 |
GroupCenterCrop(crop_size),
|
82 |
Stack(),
|
83 |
ToTorchFormatTensor(),
|
84 |
+
GroupNormalize(input_mean, input_std)
|
85 |
])
|
86 |
|
87 |
images_group = list()
|
|
|
90 |
images_group.append(img)
|
91 |
torch_imgs = transform(images_group)
|
92 |
return torch_imgs
|
|
|
93 |
|
94 |
+
|
95 |
+
@spaces.GPU
|
96 |
def inference_video(video):
|
97 |
vid = load_video(video)
|
98 |
+
|
99 |
# The model expects inputs of shape: B x C x H x W
|
100 |
TC, H, W = vid.shape
|
101 |
inputs = vid.reshape(1, TC//3, 3, H, W).permute(0, 2, 1, 3, 4)
|
102 |
+
|
103 |
with torch.no_grad():
|
104 |
prediction = model_video(inputs.to(device))
|
105 |
prediction = F.softmax(prediction, dim=1).flatten()
|
106 |
|
107 |
return {kinetics_id_to_classname[str(i)]: float(prediction[i]) for i in range(400)}
|
|
|
|
|
|
|
|
|
108 |
|
109 |
|
110 |
+
@spaces.GPU
|
111 |
def inference_image(img):
|
112 |
image = img
|
113 |
image_transform = T.Compose(
|
|
|
119 |
]
|
120 |
)
|
121 |
image = image_transform(image)
|
122 |
+
|
123 |
# The model expects inputs of shape: B x C x H x W
|
124 |
image = image.unsqueeze(0)
|
125 |
+
|
126 |
with torch.no_grad():
|
127 |
prediction = model_image(image.to(device))
|
128 |
prediction = F.softmax(prediction, dim=1).flatten()
|
|
|
130 |
return {imagenet_id_to_classname[str(i)]: float(prediction[i]) for i in range(1000)}
|
131 |
|
132 |
|
|
|
|
|
|
|
|
|
133 |
demo = gr.Blocks()
|
134 |
with demo:
|
135 |
gr.Markdown(
|
|
|
144 |
with gr.Row():
|
145 |
with gr.Column():
|
146 |
with gr.Row():
|
147 |
+
input_video = gr.Video(label='Input Video', height=360)
|
148 |
with gr.Row():
|
149 |
submit_video_button = gr.Button('Submit')
|
150 |
with gr.Column():
|
151 |
label_video = gr.Label(num_top_classes=5)
|
152 |
with gr.Row():
|
153 |
+
gr.Examples(examples=['./videos/hitting_baseball.mp4', './videos/hoverboarding.mp4', './videos/yoga.mp4'], inputs=input_video, outputs=label_video, fn=inference_video, cache_examples=True)
|
154 |
+
|
155 |
with gr.Tab("Image"):
|
156 |
# with gr.Box():
|
157 |
with gr.Row():
|
158 |
with gr.Column():
|
159 |
with gr.Row():
|
160 |
+
input_image = gr.Image(label='Input Image', type='pil', height=360)
|
161 |
with gr.Row():
|
162 |
submit_image_button = gr.Button('Submit')
|
163 |
with gr.Column():
|
164 |
label_image = gr.Label(num_top_classes=5)
|
165 |
with gr.Row():
|
166 |
+
gr.Examples(examples=['./images/cat.png', './images/dog.png', './images/panda.png'], inputs=input_image, outputs=label_image, fn=inference_image, cache_examples=True)
|
167 |
|
168 |
gr.Markdown(
|
169 |
"""
|
|
|
172 |
)
|
173 |
|
174 |
submit_video_button.click(fn=inference_video, inputs=input_video, outputs=label_video)
|
|
|
175 |
submit_image_button.click(fn=inference_image, inputs=input_image, outputs=label_image)
|
|
|
176 |
|
177 |
+
demo.queue(max_size=20).launch()
|
|
causal-conv1d/AUTHORS
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
Tri Dao, tri@tridao.me
|
|
|
|
causal-conv1d/LICENSE
DELETED
@@ -1,29 +0,0 @@
|
|
1 |
-
BSD 3-Clause License
|
2 |
-
|
3 |
-
Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
|
4 |
-
All rights reserved.
|
5 |
-
|
6 |
-
Redistribution and use in source and binary forms, with or without
|
7 |
-
modification, are permitted provided that the following conditions are met:
|
8 |
-
|
9 |
-
* Redistributions of source code must retain the above copyright notice, this
|
10 |
-
list of conditions and the following disclaimer.
|
11 |
-
|
12 |
-
* Redistributions in binary form must reproduce the above copyright notice,
|
13 |
-
this list of conditions and the following disclaimer in the documentation
|
14 |
-
and/or other materials provided with the distribution.
|
15 |
-
|
16 |
-
* Neither the name of the copyright holder nor the names of its
|
17 |
-
contributors may be used to endorse or promote products derived from
|
18 |
-
this software without specific prior written permission.
|
19 |
-
|
20 |
-
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
21 |
-
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
22 |
-
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
23 |
-
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
24 |
-
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
25 |
-
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
26 |
-
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
27 |
-
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
28 |
-
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
29 |
-
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
causal-conv1d/README.md
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
# Causal depthwise conv1d in CUDA with a PyTorch interface
|
|
|
|
causal-conv1d/causal_conv1d/__init__.py
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
__version__ = "1.0.0"
|
2 |
-
|
3 |
-
from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
|
|
|
|
|
|
|
|
causal-conv1d/causal_conv1d/causal_conv1d_interface.py
DELETED
@@ -1,104 +0,0 @@
|
|
1 |
-
# Copyright (c) 2023, Tri Dao.
|
2 |
-
|
3 |
-
import torch
|
4 |
-
import torch.nn.functional as F
|
5 |
-
|
6 |
-
|
7 |
-
import causal_conv1d_cuda
|
8 |
-
|
9 |
-
|
10 |
-
class CausalConv1dFn(torch.autograd.Function):
|
11 |
-
@staticmethod
|
12 |
-
def forward(ctx, x, weight, bias=None, activation=None):
|
13 |
-
if activation not in [None, "silu", "swish"]:
|
14 |
-
raise NotImplementedError("activation must be None, silu, or swish")
|
15 |
-
if x.stride(2) != 1 and x.stride(1) != 1:
|
16 |
-
x = x.contiguous()
|
17 |
-
bias = bias.contiguous() if bias is not None else None
|
18 |
-
ctx.save_for_backward(x, weight, bias)
|
19 |
-
ctx.activation = activation in ["silu", "swish"]
|
20 |
-
out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, ctx.activation)
|
21 |
-
return out
|
22 |
-
|
23 |
-
@staticmethod
|
24 |
-
def backward(ctx, dout):
|
25 |
-
x, weight, bias = ctx.saved_tensors
|
26 |
-
if dout.stride(2) != 1 and dout.stride(1) != 1:
|
27 |
-
dout = dout.contiguous()
|
28 |
-
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
|
29 |
-
# backward of conv1d with the backward of chunk).
|
30 |
-
# Here we just pass in None and dx will be allocated in the C++ code.
|
31 |
-
dx, dweight, dbias = causal_conv1d_cuda.causal_conv1d_bwd(
|
32 |
-
x, weight, bias, dout, None, ctx.activation
|
33 |
-
)
|
34 |
-
return dx, dweight, dbias if bias is not None else None, None
|
35 |
-
|
36 |
-
|
37 |
-
def causal_conv1d_fn(x, weight, bias=None, activation=None):
|
38 |
-
"""
|
39 |
-
x: (batch, dim, seqlen)
|
40 |
-
weight: (dim, width)
|
41 |
-
bias: (dim,)
|
42 |
-
activation: either None or "silu" or "swish"
|
43 |
-
|
44 |
-
out: (batch, dim, seqlen)
|
45 |
-
"""
|
46 |
-
return CausalConv1dFn.apply(x, weight, bias, activation)
|
47 |
-
|
48 |
-
|
49 |
-
def causal_conv1d_ref(x, weight, bias=None, activation=None):
|
50 |
-
"""
|
51 |
-
x: (batch, dim, seqlen)
|
52 |
-
weight: (dim, width)
|
53 |
-
bias: (dim,)
|
54 |
-
|
55 |
-
out: (batch, dim, seqlen)
|
56 |
-
"""
|
57 |
-
if activation not in [None, "silu", "swish"]:
|
58 |
-
raise NotImplementedError("activation must be None, silu, or swish")
|
59 |
-
dtype_in = x.dtype
|
60 |
-
x = x.to(weight.dtype)
|
61 |
-
seqlen = x.shape[-1]
|
62 |
-
dim, width = weight.shape
|
63 |
-
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
|
64 |
-
out = out[..., :seqlen]
|
65 |
-
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
66 |
-
|
67 |
-
|
68 |
-
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None):
|
69 |
-
"""
|
70 |
-
x: (batch, dim)
|
71 |
-
conv_state: (batch, dim, width)
|
72 |
-
weight: (dim, width)
|
73 |
-
bias: (dim,)
|
74 |
-
|
75 |
-
out: (batch, dim)
|
76 |
-
"""
|
77 |
-
if activation not in [None, "silu", "swish"]:
|
78 |
-
raise NotImplementedError("activation must be None, silu, or swish")
|
79 |
-
activation = activation in ["silu", "swish"]
|
80 |
-
return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, activation)
|
81 |
-
|
82 |
-
|
83 |
-
def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None):
|
84 |
-
"""
|
85 |
-
x: (batch, dim)
|
86 |
-
conv_state: (batch, dim, width)
|
87 |
-
weight: (dim, width)
|
88 |
-
bias: (dim,)
|
89 |
-
|
90 |
-
out: (batch, dim)
|
91 |
-
"""
|
92 |
-
if activation not in [None, "silu", "swish"]:
|
93 |
-
raise NotImplementedError("activation must be None, silu, or swish")
|
94 |
-
dtype_in = x.dtype
|
95 |
-
batch, dim = x.shape
|
96 |
-
width = weight.shape[1]
|
97 |
-
assert conv_state.shape == (batch, dim, width)
|
98 |
-
assert weight.shape == (dim, width)
|
99 |
-
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
|
100 |
-
conv_state[:, :, -1] = x
|
101 |
-
out = torch.sum(conv_state * weight, dim=-1) # (B D)
|
102 |
-
if bias is not None:
|
103 |
-
out += bias
|
104 |
-
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
causal-conv1d/csrc/causal_conv1d.cpp
DELETED
@@ -1,333 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2023, Tri Dao.
|
3 |
-
******************************************************************************/
|
4 |
-
|
5 |
-
#include <ATen/cuda/CUDAContext.h>
|
6 |
-
#include <c10/cuda/CUDAGuard.h>
|
7 |
-
#include <torch/extension.h>
|
8 |
-
#include <vector>
|
9 |
-
|
10 |
-
#include "causal_conv1d.h"
|
11 |
-
|
12 |
-
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
13 |
-
|
14 |
-
#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
15 |
-
if (ITYPE == at::ScalarType::Half) { \
|
16 |
-
using input_t = at::Half; \
|
17 |
-
__VA_ARGS__(); \
|
18 |
-
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
19 |
-
using input_t = at::BFloat16; \
|
20 |
-
__VA_ARGS__(); \
|
21 |
-
} else if (ITYPE == at::ScalarType::Float) { \
|
22 |
-
using input_t = float; \
|
23 |
-
__VA_ARGS__(); \
|
24 |
-
} else { \
|
25 |
-
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
26 |
-
}
|
27 |
-
|
28 |
-
#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
|
29 |
-
if (WTYPE == at::ScalarType::Half) { \
|
30 |
-
using weight_t = at::Half; \
|
31 |
-
__VA_ARGS__(); \
|
32 |
-
} else if (WTYPE == at::ScalarType::BFloat16) { \
|
33 |
-
using weight_t = at::BFloat16; \
|
34 |
-
__VA_ARGS__(); \
|
35 |
-
} else if (WTYPE == at::ScalarType::Float) { \
|
36 |
-
using weight_t = float; \
|
37 |
-
__VA_ARGS__(); \
|
38 |
-
} else { \
|
39 |
-
AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
|
40 |
-
}
|
41 |
-
|
42 |
-
template<typename input_t, typename weight_t>
|
43 |
-
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
44 |
-
template <typename input_t, typename weight_t>
|
45 |
-
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
46 |
-
|
47 |
-
template<typename input_t, typename weight_t>
|
48 |
-
void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream);
|
49 |
-
template<typename input_t, typename weight_t>
|
50 |
-
void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream);
|
51 |
-
|
52 |
-
template<typename input_t, typename weight_t>
|
53 |
-
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
54 |
-
|
55 |
-
void set_conv_params_fwd(ConvParamsBase ¶ms,
|
56 |
-
// sizes
|
57 |
-
const size_t batch,
|
58 |
-
const size_t dim,
|
59 |
-
const size_t seqlen,
|
60 |
-
const size_t width,
|
61 |
-
// device pointers
|
62 |
-
const at::Tensor x,
|
63 |
-
const at::Tensor weight,
|
64 |
-
const at::Tensor out,
|
65 |
-
void* bias_ptr,
|
66 |
-
bool silu_activation) {
|
67 |
-
|
68 |
-
// Reset the parameters
|
69 |
-
memset(¶ms, 0, sizeof(params));
|
70 |
-
|
71 |
-
params.batch = batch;
|
72 |
-
params.dim = dim;
|
73 |
-
params.seqlen = seqlen;
|
74 |
-
params.width = width;
|
75 |
-
|
76 |
-
params.silu_activation = silu_activation;
|
77 |
-
|
78 |
-
// Set the pointers and strides.
|
79 |
-
params.x_ptr = x.data_ptr();
|
80 |
-
params.weight_ptr = weight.data_ptr();
|
81 |
-
params.bias_ptr = bias_ptr;
|
82 |
-
params.out_ptr = out.data_ptr();
|
83 |
-
// All stride are in elements, not bytes.
|
84 |
-
params.x_batch_stride = x.stride(0);
|
85 |
-
params.x_c_stride = x.stride(1);
|
86 |
-
params.x_l_stride = x.stride(-1);
|
87 |
-
params.weight_c_stride = weight.stride(0);
|
88 |
-
params.weight_width_stride = weight.stride(1);
|
89 |
-
params.out_batch_stride = out.stride(0);
|
90 |
-
params.out_c_stride = out.stride(1);
|
91 |
-
params.out_l_stride = out.stride(-1);
|
92 |
-
}
|
93 |
-
|
94 |
-
|
95 |
-
void set_conv_params_bwd(ConvParamsBwd ¶ms,
|
96 |
-
// sizes
|
97 |
-
const size_t batch,
|
98 |
-
const size_t dim,
|
99 |
-
const size_t seqlen,
|
100 |
-
const size_t width,
|
101 |
-
// device pointers
|
102 |
-
const at::Tensor x,
|
103 |
-
const at::Tensor weight,
|
104 |
-
void* bias_ptr,
|
105 |
-
const at::Tensor dout,
|
106 |
-
const at::Tensor dx,
|
107 |
-
const at::Tensor dweight,
|
108 |
-
void* dbias_ptr,
|
109 |
-
bool silu_activation) {
|
110 |
-
// Pass in "dout" instead of "out", we're not gonna use "out" at all.
|
111 |
-
set_conv_params_fwd(params, batch, dim, seqlen, width,
|
112 |
-
x, weight, dout, bias_ptr, silu_activation);
|
113 |
-
|
114 |
-
// Set the pointers and strides.
|
115 |
-
params.dout_ptr = dout.data_ptr();
|
116 |
-
params.dx_ptr = dx.data_ptr();
|
117 |
-
params.dweight_ptr = dweight.data_ptr();
|
118 |
-
params.dbias_ptr = dbias_ptr;
|
119 |
-
// All stride are in elements, not bytes.
|
120 |
-
params.dout_batch_stride = dout.stride(0);
|
121 |
-
params.dout_c_stride = dout.stride(1);
|
122 |
-
params.dout_l_stride = dout.stride(2);
|
123 |
-
params.dweight_c_stride = dweight.stride(0);
|
124 |
-
params.dweight_width_stride = dweight.stride(1);
|
125 |
-
params.dx_batch_stride = dx.stride(0);
|
126 |
-
params.dx_c_stride = dx.stride(1);
|
127 |
-
params.dx_l_stride = dx.stride(2);
|
128 |
-
}
|
129 |
-
|
130 |
-
at::Tensor
|
131 |
-
causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
132 |
-
const c10::optional<at::Tensor> &bias_,
|
133 |
-
bool silu_activation) {
|
134 |
-
auto input_type = x.scalar_type();
|
135 |
-
auto weight_type = weight.scalar_type();
|
136 |
-
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
137 |
-
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
138 |
-
|
139 |
-
TORCH_CHECK(x.is_cuda());
|
140 |
-
TORCH_CHECK(weight.is_cuda());
|
141 |
-
|
142 |
-
const auto sizes = x.sizes();
|
143 |
-
const int batch_size = sizes[0];
|
144 |
-
const int dim = sizes[1];
|
145 |
-
const int seqlen = sizes[2];
|
146 |
-
const int width = weight.size(-1);
|
147 |
-
|
148 |
-
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
149 |
-
CHECK_SHAPE(weight, dim, width);
|
150 |
-
|
151 |
-
TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
|
152 |
-
const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
|
153 |
-
|
154 |
-
if (is_channel_last) {
|
155 |
-
TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
|
156 |
-
}
|
157 |
-
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
158 |
-
|
159 |
-
|
160 |
-
if (bias_.has_value()) {
|
161 |
-
auto bias = bias_.value();
|
162 |
-
TORCH_CHECK(bias.scalar_type() == weight_type);
|
163 |
-
TORCH_CHECK(bias.is_cuda());
|
164 |
-
TORCH_CHECK(bias.stride(-1) == 1);
|
165 |
-
CHECK_SHAPE(bias, dim);
|
166 |
-
}
|
167 |
-
|
168 |
-
at::Tensor out = torch::empty_like(x);
|
169 |
-
|
170 |
-
ConvParamsBase params;
|
171 |
-
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
172 |
-
bias_.has_value() ? bias_.value().data_ptr() : nullptr,
|
173 |
-
silu_activation);
|
174 |
-
|
175 |
-
// Otherwise the kernel will be launched from cuda:0 device
|
176 |
-
// Cast to char to avoid compiler warning about narrowing
|
177 |
-
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
|
178 |
-
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
179 |
-
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
|
180 |
-
DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] {
|
181 |
-
if (!is_channel_last) {
|
182 |
-
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
|
183 |
-
} else {
|
184 |
-
causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
|
185 |
-
}
|
186 |
-
});
|
187 |
-
});
|
188 |
-
return out;
|
189 |
-
}
|
190 |
-
|
191 |
-
std::vector<at::Tensor>
|
192 |
-
causal_conv1d_bwd(const at::Tensor &x, const at::Tensor &weight,
|
193 |
-
const c10::optional<at::Tensor> &bias_,
|
194 |
-
at::Tensor &dout,
|
195 |
-
c10::optional<at::Tensor> &dx_,
|
196 |
-
bool silu_activation) {
|
197 |
-
auto input_type = x.scalar_type();
|
198 |
-
auto weight_type = weight.scalar_type();
|
199 |
-
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
200 |
-
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
201 |
-
|
202 |
-
TORCH_CHECK(x.is_cuda());
|
203 |
-
TORCH_CHECK(weight.is_cuda());
|
204 |
-
TORCH_CHECK(dout.is_cuda());
|
205 |
-
|
206 |
-
const auto sizes = x.sizes();
|
207 |
-
const int batch_size = sizes[0];
|
208 |
-
const int dim = sizes[1];
|
209 |
-
const int seqlen = sizes[2];
|
210 |
-
const int width = weight.size(-1);
|
211 |
-
|
212 |
-
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
213 |
-
|
214 |
-
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
215 |
-
CHECK_SHAPE(weight, dim, width);
|
216 |
-
CHECK_SHAPE(dout, batch_size, dim, seqlen);
|
217 |
-
|
218 |
-
TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
|
219 |
-
const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
|
220 |
-
if (!is_channel_last && dout.stride(2) != 1) { dout = dout.contiguous(); }
|
221 |
-
if (is_channel_last && dout.stride(1) != 1) { dout = dout.transpose(-1, -2).contiguous().transpose(-1, -2); }
|
222 |
-
|
223 |
-
if (bias_.has_value()) {
|
224 |
-
auto bias = bias_.value();
|
225 |
-
TORCH_CHECK(bias.scalar_type() == weight_type);
|
226 |
-
TORCH_CHECK(bias.is_cuda());
|
227 |
-
TORCH_CHECK(bias.stride(-1) == 1);
|
228 |
-
CHECK_SHAPE(bias, dim);
|
229 |
-
}
|
230 |
-
|
231 |
-
at::Tensor dx;
|
232 |
-
if (dx_.has_value()) {
|
233 |
-
dx = dx_.value();
|
234 |
-
TORCH_CHECK(dx.scalar_type() == input_type);
|
235 |
-
TORCH_CHECK(dx.is_cuda());
|
236 |
-
CHECK_SHAPE(dx, batch_size, dim, seqlen);
|
237 |
-
if (!is_channel_last) { TORCH_CHECK(dx.stride(2) == 1); }
|
238 |
-
if (is_channel_last) { TORCH_CHECK(dx.stride(1) == 1); }
|
239 |
-
} else {
|
240 |
-
dx = torch::empty_like(x);
|
241 |
-
}
|
242 |
-
|
243 |
-
// Otherwise the kernel will be launched from cuda:0 device
|
244 |
-
// Cast to char to avoid compiler warning about narrowing
|
245 |
-
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
|
246 |
-
|
247 |
-
at::Tensor dweight = torch::zeros_like(weight, weight.options().dtype(at::kFloat));
|
248 |
-
at::Tensor dbias;
|
249 |
-
if (bias_.has_value()) { dbias = torch::zeros_like(bias_.value(), bias_.value().options().dtype(at::kFloat)); }
|
250 |
-
|
251 |
-
ConvParamsBwd params;
|
252 |
-
set_conv_params_bwd(params, batch_size, dim, seqlen, width,
|
253 |
-
x, weight, bias_.has_value() ? bias_.value().data_ptr() : nullptr,
|
254 |
-
dout, dx, dweight, bias_.has_value() ? dbias.data_ptr() : nullptr,
|
255 |
-
silu_activation);
|
256 |
-
|
257 |
-
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
258 |
-
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_bwd", [&] {
|
259 |
-
DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_bwd", [&] {
|
260 |
-
if (!is_channel_last) {
|
261 |
-
causal_conv1d_bwd_cuda<input_t, weight_t>(params, stream);
|
262 |
-
} else {
|
263 |
-
causal_conv1d_channellast_bwd_cuda<input_t, weight_t>(params, stream);
|
264 |
-
}
|
265 |
-
});
|
266 |
-
});
|
267 |
-
return {dx, dweight.to(weight.dtype()), bias_.has_value() ? dbias.to(bias_.value().dtype()) : dbias};
|
268 |
-
}
|
269 |
-
|
270 |
-
at::Tensor
|
271 |
-
causal_conv1d_update(const at::Tensor &x,
|
272 |
-
const at::Tensor &conv_state,
|
273 |
-
const at::Tensor &weight,
|
274 |
-
const c10::optional<at::Tensor> &bias_,
|
275 |
-
bool silu_activation) {
|
276 |
-
auto input_type = x.scalar_type();
|
277 |
-
auto weight_type = weight.scalar_type();
|
278 |
-
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
279 |
-
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
280 |
-
TORCH_CHECK(conv_state.scalar_type() == input_type);
|
281 |
-
|
282 |
-
TORCH_CHECK(x.is_cuda());
|
283 |
-
TORCH_CHECK(conv_state.is_cuda());
|
284 |
-
TORCH_CHECK(weight.is_cuda());
|
285 |
-
|
286 |
-
const auto sizes = x.sizes();
|
287 |
-
const int batch_size = sizes[0];
|
288 |
-
const int dim = sizes[1];
|
289 |
-
const int width = weight.size(-1);
|
290 |
-
|
291 |
-
CHECK_SHAPE(x, batch_size, dim);
|
292 |
-
CHECK_SHAPE(conv_state, batch_size, dim, width);
|
293 |
-
CHECK_SHAPE(weight, dim, width);
|
294 |
-
|
295 |
-
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
296 |
-
|
297 |
-
if (bias_.has_value()) {
|
298 |
-
auto bias = bias_.value();
|
299 |
-
TORCH_CHECK(bias.scalar_type() == weight_type);
|
300 |
-
TORCH_CHECK(bias.is_cuda());
|
301 |
-
TORCH_CHECK(bias.stride(-1) == 1);
|
302 |
-
CHECK_SHAPE(bias, dim);
|
303 |
-
}
|
304 |
-
|
305 |
-
at::Tensor out = torch::empty_like(x);
|
306 |
-
|
307 |
-
ConvParamsBase params;
|
308 |
-
set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out,
|
309 |
-
bias_.has_value() ? bias_.value().data_ptr() : nullptr,
|
310 |
-
silu_activation);
|
311 |
-
params.conv_state_ptr = conv_state.data_ptr();
|
312 |
-
// All stride are in elements, not bytes.
|
313 |
-
params.conv_state_batch_stride = conv_state.stride(0);
|
314 |
-
params.conv_state_c_stride = conv_state.stride(1);
|
315 |
-
params.conv_state_l_stride = conv_state.stride(2);
|
316 |
-
|
317 |
-
// Otherwise the kernel will be launched from cuda:0 device
|
318 |
-
// Cast to char to avoid compiler warning about narrowing
|
319 |
-
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
|
320 |
-
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
321 |
-
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
|
322 |
-
DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] {
|
323 |
-
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
|
324 |
-
});
|
325 |
-
});
|
326 |
-
return out;
|
327 |
-
}
|
328 |
-
|
329 |
-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
330 |
-
m.def("causal_conv1d_fwd", &causal_conv1d_fwd, "Causal conv1d forward");
|
331 |
-
m.def("causal_conv1d_bwd", &causal_conv1d_bwd, "Causal conv1d backward");
|
332 |
-
m.def("causal_conv1d_update", &causal_conv1d_update, "Causal conv1d update");
|
333 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
causal-conv1d/csrc/causal_conv1d.h
DELETED
@@ -1,53 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2023, Tri Dao.
|
3 |
-
******************************************************************************/
|
4 |
-
|
5 |
-
#pragma once
|
6 |
-
|
7 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
8 |
-
|
9 |
-
struct ConvParamsBase {
|
10 |
-
using index_t = uint32_t;
|
11 |
-
|
12 |
-
int batch, dim, seqlen, width;
|
13 |
-
bool silu_activation;
|
14 |
-
|
15 |
-
index_t x_batch_stride;
|
16 |
-
index_t x_c_stride;
|
17 |
-
index_t x_l_stride;
|
18 |
-
index_t weight_c_stride;
|
19 |
-
index_t weight_width_stride;
|
20 |
-
index_t out_batch_stride;
|
21 |
-
index_t out_c_stride;
|
22 |
-
index_t out_l_stride;
|
23 |
-
|
24 |
-
index_t conv_state_batch_stride;
|
25 |
-
index_t conv_state_c_stride;
|
26 |
-
index_t conv_state_l_stride;
|
27 |
-
|
28 |
-
// Common data pointers.
|
29 |
-
void *__restrict__ x_ptr;
|
30 |
-
void *__restrict__ weight_ptr;
|
31 |
-
void *__restrict__ bias_ptr;
|
32 |
-
void *__restrict__ out_ptr;
|
33 |
-
|
34 |
-
void *__restrict__ conv_state_ptr;
|
35 |
-
};
|
36 |
-
|
37 |
-
struct ConvParamsBwd: public ConvParamsBase {
|
38 |
-
index_t dx_batch_stride;
|
39 |
-
index_t dx_c_stride;
|
40 |
-
index_t dx_l_stride;
|
41 |
-
index_t dweight_c_stride;
|
42 |
-
index_t dweight_width_stride;
|
43 |
-
index_t dout_batch_stride;
|
44 |
-
index_t dout_c_stride;
|
45 |
-
index_t dout_l_stride;
|
46 |
-
|
47 |
-
// Common data pointers.
|
48 |
-
void *__restrict__ dx_ptr;
|
49 |
-
void *__restrict__ dweight_ptr;
|
50 |
-
void *__restrict__ dbias_ptr;
|
51 |
-
void *__restrict__ dout_ptr;
|
52 |
-
};
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
causal-conv1d/csrc/causal_conv1d_bwd.cu
DELETED
@@ -1,525 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2023, Tri Dao.
|
3 |
-
******************************************************************************/
|
4 |
-
|
5 |
-
#include <c10/util/BFloat16.h>
|
6 |
-
#include <c10/util/Half.h>
|
7 |
-
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
8 |
-
|
9 |
-
#include <cub/block/block_load.cuh>
|
10 |
-
#include <cub/block/block_store.cuh>
|
11 |
-
#include <cub/block/block_reduce.cuh>
|
12 |
-
|
13 |
-
#include "causal_conv1d.h"
|
14 |
-
#include "causal_conv1d_common.h"
|
15 |
-
#include "static_switch.h"
|
16 |
-
|
17 |
-
template<int kNThreads_, int kWidth_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
18 |
-
struct Causal_conv1d_bwd_kernel_traits {
|
19 |
-
using input_t = input_t_;
|
20 |
-
using weight_t = weight_t_;
|
21 |
-
static constexpr int kNThreads = kNThreads_;
|
22 |
-
static constexpr int kWidth = kWidth_;
|
23 |
-
static constexpr bool kSiluAct = kSiluAct_;
|
24 |
-
static constexpr int kNBytes = sizeof(input_t);
|
25 |
-
static_assert(kNBytes == 2 || kNBytes == 4);
|
26 |
-
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
27 |
-
static_assert(kWidth <= kNElts);
|
28 |
-
// It's possible that we need to do 2 rounds of exchange if input_t is 16 bits
|
29 |
-
// (since then we'd have 8 values of float, and each round we can exchange 4 floats).
|
30 |
-
static constexpr int kNExchangeRounds = sizeof(float) / sizeof(input_t);
|
31 |
-
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
32 |
-
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
33 |
-
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
34 |
-
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
|
35 |
-
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
36 |
-
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
|
37 |
-
using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
|
38 |
-
static constexpr int kSmemIOSize = kIsVecLoad
|
39 |
-
? 0
|
40 |
-
: std::max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
|
41 |
-
static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts * (!kSiluAct ? 1 : kNExchangeRounds + 1);
|
42 |
-
static constexpr int kSmemSize = std::max({kSmemExchangeSize,
|
43 |
-
int(sizeof(typename BlockReduceFloatT::TempStorage))}) + (kIsVecLoad ? 0 : kSmemIOSize);
|
44 |
-
};
|
45 |
-
|
46 |
-
template<typename Ktraits>
|
47 |
-
__global__ __launch_bounds__(Ktraits::kNThreads)
|
48 |
-
void causal_conv1d_bwd_kernel(ConvParamsBwd params) {
|
49 |
-
constexpr int kWidth = Ktraits::kWidth;
|
50 |
-
constexpr int kNThreads = Ktraits::kNThreads;
|
51 |
-
constexpr bool kSiluAct = Ktraits::kSiluAct;
|
52 |
-
constexpr int kNElts = Ktraits::kNElts;
|
53 |
-
constexpr int kNExchangeRounds = Ktraits::kNExchangeRounds;
|
54 |
-
constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
|
55 |
-
using input_t = typename Ktraits::input_t;
|
56 |
-
using vec_t = typename Ktraits::vec_t;
|
57 |
-
using weight_t = typename Ktraits::weight_t;
|
58 |
-
|
59 |
-
// Shared memory.
|
60 |
-
extern __shared__ char smem_[];
|
61 |
-
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
62 |
-
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
|
63 |
-
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
64 |
-
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
|
65 |
-
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
|
66 |
-
vec_t *smem_exchange_x = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize) + kNThreads * kNExchangeRounds;
|
67 |
-
auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
|
68 |
-
|
69 |
-
const int tidx = threadIdx.x;
|
70 |
-
const int batch_id = blockIdx.x;
|
71 |
-
const int dim_id = blockIdx.y;
|
72 |
-
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
73 |
-
+ dim_id * params.x_c_stride;
|
74 |
-
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + dim_id * params.weight_c_stride;
|
75 |
-
input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
|
76 |
-
+ dim_id * params.dout_c_stride;
|
77 |
-
input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
|
78 |
-
+ dim_id * params.dx_c_stride;
|
79 |
-
float *dweight = reinterpret_cast<float *>(params.dweight_ptr) + dim_id * params.dweight_c_stride;
|
80 |
-
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[dim_id]);
|
81 |
-
|
82 |
-
// Thread kNThreads - 1 will load the first elements of the next chunk so we initialize those to 0.
|
83 |
-
if (tidx == 0) {
|
84 |
-
if constexpr (!kSiluAct) {
|
85 |
-
input_t zeros[kNElts] = {0};
|
86 |
-
smem_exchange[0] = reinterpret_cast<vec_t *>(zeros)[0];
|
87 |
-
} else {
|
88 |
-
float zeros[kNElts] = {0};
|
89 |
-
#pragma unroll
|
90 |
-
for (int r = 0; r < kNExchangeRounds; ++r) {
|
91 |
-
smem_exchange[r * kNThreads] = reinterpret_cast<vec_t *>(zeros)[r];
|
92 |
-
}
|
93 |
-
}
|
94 |
-
}
|
95 |
-
|
96 |
-
float weight_vals[kWidth];
|
97 |
-
#pragma unroll
|
98 |
-
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = weight[i * params.weight_width_stride]; }
|
99 |
-
|
100 |
-
float dweight_vals[kWidth] = {0};
|
101 |
-
float dbias_val = 0;
|
102 |
-
|
103 |
-
constexpr int kChunkSize = kNThreads * kNElts;
|
104 |
-
const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
|
105 |
-
x += (n_chunks - 1) * kChunkSize;
|
106 |
-
dout += (n_chunks - 1) * kChunkSize;
|
107 |
-
dx += (n_chunks - 1) * kChunkSize;
|
108 |
-
for (int chunk = n_chunks - 1; chunk >= 0; --chunk) {
|
109 |
-
input_t x_vals_load[2 * kNElts] = {0};
|
110 |
-
input_t dout_vals_load[2 * kNElts] = {0};
|
111 |
-
if constexpr(kIsVecLoad) {
|
112 |
-
Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
|
113 |
-
Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(dout), *reinterpret_cast<vec_t (*)[1]>(&dout_vals_load[0]), (params.seqlen - chunk * kChunkSize) / kNElts);
|
114 |
-
} else {
|
115 |
-
__syncthreads();
|
116 |
-
Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
|
117 |
-
__syncthreads();
|
118 |
-
Ktraits::BlockLoadT(smem_load).Load(dout, *reinterpret_cast<input_t (*)[kNElts]>(&dout_vals_load[0]), params.seqlen - chunk * kChunkSize);
|
119 |
-
}
|
120 |
-
float dout_vals[2 * kNElts], x_vals[2 * kNElts];
|
121 |
-
if constexpr (!kSiluAct) {
|
122 |
-
__syncthreads();
|
123 |
-
// Thread 0 don't write yet, so that thread kNThreads - 1 can read
|
124 |
-
// the first elements of the next chunk.
|
125 |
-
if (tidx > 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
|
126 |
-
__syncthreads();
|
127 |
-
reinterpret_cast<vec_t *>(dout_vals_load)[1] = smem_exchange[tidx < kNThreads - 1 ? tidx + 1 : 0];
|
128 |
-
__syncthreads();
|
129 |
-
// Now thread 0 can write the first elements of the current chunk.
|
130 |
-
if (tidx == 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
|
131 |
-
#pragma unroll
|
132 |
-
for (int i = 0; i < 2 * kNElts; ++i) {
|
133 |
-
dout_vals[i] = float(dout_vals_load[i]);
|
134 |
-
x_vals[i] = float(x_vals_load[i]);
|
135 |
-
}
|
136 |
-
} else {
|
137 |
-
if (tidx == 0 && chunk > 0) {
|
138 |
-
if constexpr(kIsVecLoad) {
|
139 |
-
reinterpret_cast<vec_t *>(x_vals_load)[0] = reinterpret_cast<vec_t *>(x)[-1];
|
140 |
-
} else {
|
141 |
-
#pragma unroll
|
142 |
-
for (int i = 0; i < kNElts; ++i) {
|
143 |
-
if (chunk * kChunkSize + i < params.seqlen) { x_vals_load[i] = x[-kNElts + i]; }
|
144 |
-
}
|
145 |
-
}
|
146 |
-
}
|
147 |
-
__syncthreads();
|
148 |
-
smem_exchange_x[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1];
|
149 |
-
__syncthreads();
|
150 |
-
if (tidx > 0) { reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange_x[tidx - 1]; }
|
151 |
-
#pragma unroll
|
152 |
-
for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
|
153 |
-
// Recompute the output
|
154 |
-
#pragma unroll
|
155 |
-
for (int i = 0; i < kNElts; ++i) {
|
156 |
-
float out_val = bias_val;
|
157 |
-
#pragma unroll
|
158 |
-
for (int w = 0; w < kWidth; ++w) {
|
159 |
-
out_val += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
|
160 |
-
}
|
161 |
-
float out_sigmoid_val = 1.0f / (1.0f + expf(-out_val));
|
162 |
-
dout_vals[i] = float(dout_vals_load[i]) * out_sigmoid_val
|
163 |
-
* (1.0f + out_val * (1.0f - out_sigmoid_val));
|
164 |
-
}
|
165 |
-
// Exchange the dout_vals. It's possible that we need to do 2 rounds of exchange
|
166 |
-
// if input_t is 16 bits (since then we'd have 8 values of float)
|
167 |
-
__syncthreads();
|
168 |
-
// Thread 0 don't write yet, so that thread kNThreads - 1 can read
|
169 |
-
// the first elements of the next chunk.
|
170 |
-
if (tidx > 0) {
|
171 |
-
#pragma unroll
|
172 |
-
for (int r = 0; r < kNExchangeRounds; ++r) {
|
173 |
-
smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
|
174 |
-
}
|
175 |
-
}
|
176 |
-
__syncthreads();
|
177 |
-
#pragma unroll
|
178 |
-
for (int r = 0; r < kNExchangeRounds; ++r) {
|
179 |
-
reinterpret_cast<vec_t *>(dout_vals)[kNExchangeRounds + r]
|
180 |
-
= smem_exchange[r * kNThreads + (tidx < kNThreads - 1 ? tidx + 1 : 0)];
|
181 |
-
}
|
182 |
-
__syncthreads();
|
183 |
-
// Now thread 0 can write the first elements of the current chunk.
|
184 |
-
if (tidx == 0) {
|
185 |
-
#pragma unroll
|
186 |
-
for (int r = 0; r < kNExchangeRounds; ++r) {
|
187 |
-
smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
|
188 |
-
}
|
189 |
-
}
|
190 |
-
}
|
191 |
-
dout -= kChunkSize;
|
192 |
-
x -= kChunkSize;
|
193 |
-
|
194 |
-
#pragma unroll
|
195 |
-
for (int i = 0; i < kNElts; ++i) { dbias_val += dout_vals[i]; }
|
196 |
-
|
197 |
-
float dx_vals[kNElts] = {0};
|
198 |
-
#pragma unroll
|
199 |
-
for (int i = 0; i < kNElts; ++i) {
|
200 |
-
#pragma unroll
|
201 |
-
for (int w = 0; w < kWidth; ++w) {
|
202 |
-
dx_vals[i] += weight_vals[w] * dout_vals[i + kWidth - w - 1];
|
203 |
-
}
|
204 |
-
}
|
205 |
-
|
206 |
-
input_t dx_vals_store[kNElts];
|
207 |
-
#pragma unroll
|
208 |
-
for (int i = 0; i < kNElts; ++i) { dx_vals_store[i] = dx_vals[i]; }
|
209 |
-
if constexpr(kIsVecLoad) {
|
210 |
-
Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(dx), reinterpret_cast<vec_t (&)[1]>(dx_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
|
211 |
-
} else {
|
212 |
-
Ktraits::BlockStoreT(smem_store).Store(dx, dx_vals_store, params.seqlen - chunk * kChunkSize);
|
213 |
-
}
|
214 |
-
dx -= kChunkSize;
|
215 |
-
|
216 |
-
#pragma unroll
|
217 |
-
for (int w = 0; w < kWidth; ++w) {
|
218 |
-
#pragma unroll
|
219 |
-
for (int i = 0; i < kNElts; ++i) {
|
220 |
-
dweight_vals[w] += x_vals[kNElts + i] * dout_vals[i + kWidth - w - 1];
|
221 |
-
}
|
222 |
-
}
|
223 |
-
}
|
224 |
-
|
225 |
-
#pragma unroll
|
226 |
-
for (int w = 0; w < kWidth; ++w) {
|
227 |
-
__syncthreads();
|
228 |
-
dweight_vals[w] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dweight_vals[w]);
|
229 |
-
if (tidx == 0) {
|
230 |
-
atomicAdd(&reinterpret_cast<float *>(dweight)[w * params.dweight_width_stride], dweight_vals[w]);
|
231 |
-
}
|
232 |
-
}
|
233 |
-
if (params.bias_ptr != nullptr) {
|
234 |
-
__syncthreads();
|
235 |
-
dbias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dbias_val);
|
236 |
-
if (tidx == 0) {
|
237 |
-
atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[dim_id], dbias_val);
|
238 |
-
}
|
239 |
-
}
|
240 |
-
}
|
241 |
-
|
242 |
-
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
243 |
-
void causal_conv1d_bwd_launch(ConvParamsBwd ¶ms, cudaStream_t stream) {
|
244 |
-
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
|
245 |
-
BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
|
246 |
-
BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
|
247 |
-
using Ktraits = Causal_conv1d_bwd_kernel_traits<kNThreads, kWidth, kSiluAct, kIsVecLoad, input_t, weight_t>;
|
248 |
-
constexpr int kSmemSize = Ktraits::kSmemSize;
|
249 |
-
dim3 grid(params.batch, params.dim);
|
250 |
-
auto kernel = &causal_conv1d_bwd_kernel<Ktraits>;
|
251 |
-
if (kSmemSize >= 48 * 1024) {
|
252 |
-
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
253 |
-
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
254 |
-
}
|
255 |
-
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
256 |
-
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
257 |
-
});
|
258 |
-
});
|
259 |
-
}
|
260 |
-
|
261 |
-
template<typename input_t, typename weight_t>
|
262 |
-
void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream) {
|
263 |
-
if (params.width == 2) {
|
264 |
-
causal_conv1d_bwd_launch<128, 2, input_t, weight_t>(params, stream);
|
265 |
-
} else if (params.width == 3) {
|
266 |
-
causal_conv1d_bwd_launch<128, 3, input_t, weight_t>(params, stream);
|
267 |
-
} else if (params.width == 4) {
|
268 |
-
causal_conv1d_bwd_launch<128, 4, input_t, weight_t>(params, stream);
|
269 |
-
}
|
270 |
-
}
|
271 |
-
|
272 |
-
template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
273 |
-
struct Causal_conv1d_channellast_bwd_kernel_traits {
|
274 |
-
// The cache line is 128 bytes, and we try to read 16 bytes per thread.
|
275 |
-
// So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
|
276 |
-
// That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
|
277 |
-
// threads). Each each load is 16 x 32|64 elements in the L x C dimensions.
|
278 |
-
using input_t = input_t_;
|
279 |
-
using weight_t = weight_t_;
|
280 |
-
static constexpr bool kSiluAct = kSiluAct_;
|
281 |
-
static constexpr int kNThreads = kNThreads_;
|
282 |
-
static_assert(kNThreads % 32 == 0);
|
283 |
-
static constexpr int kNWarps = kNThreads / 32;
|
284 |
-
static constexpr int kWidth = kWidth_;
|
285 |
-
static constexpr int kChunkSizeL = kChunkSizeL_;
|
286 |
-
static constexpr int kNBytes = sizeof(input_t);
|
287 |
-
static_assert(kNBytes == 2 || kNBytes == 4);
|
288 |
-
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
289 |
-
static constexpr int kNEltsPerRow = 128 / kNBytes;
|
290 |
-
static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
|
291 |
-
static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
|
292 |
-
static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
|
293 |
-
static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
|
294 |
-
static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
|
295 |
-
static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
|
296 |
-
static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
|
297 |
-
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
298 |
-
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
299 |
-
// using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
300 |
-
// using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
301 |
-
// static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
|
302 |
-
// sizeof(typename BlockStoreT::TempStorage)});
|
303 |
-
// static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
|
304 |
-
};
|
305 |
-
|
306 |
-
template<typename Ktraits>
|
307 |
-
__global__ __launch_bounds__(Ktraits::kNThreads)
|
308 |
-
void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) {
|
309 |
-
constexpr int kWidth = Ktraits::kWidth;
|
310 |
-
constexpr int kNThreads = Ktraits::kNThreads;
|
311 |
-
constexpr bool kSiluAct = Ktraits::kSiluAct;
|
312 |
-
constexpr int kNElts = Ktraits::kNElts;
|
313 |
-
constexpr int kNWarp = Ktraits::kNWarps;
|
314 |
-
constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
|
315 |
-
constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
|
316 |
-
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
317 |
-
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
318 |
-
using input_t = typename Ktraits::input_t;
|
319 |
-
using vec_t = typename Ktraits::vec_t;
|
320 |
-
using weight_t = typename Ktraits::weight_t;
|
321 |
-
|
322 |
-
// Shared memory.
|
323 |
-
__shared__ input_t dout_smem[kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
|
324 |
-
__shared__ input_t x_smem[kWidth - 1 + kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
|
325 |
-
|
326 |
-
const int tid = threadIdx.x;
|
327 |
-
const int l_idx = tid / kNThreadsPerC;
|
328 |
-
const int c_idx = tid % kNThreadsPerC;
|
329 |
-
const int batch_id = blockIdx.x;
|
330 |
-
const int chunk_l_id = blockIdx.y;
|
331 |
-
const int chunk_c_id = blockIdx.z;
|
332 |
-
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
333 |
-
+ (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
334 |
-
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
|
335 |
-
+ chunk_c_id * kChunkSizeC * params.weight_c_stride;
|
336 |
-
input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
|
337 |
-
+ (chunk_l_id * kChunkSizeL + l_idx) * params.dout_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
338 |
-
input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
|
339 |
-
+ (chunk_l_id * kChunkSizeL + l_idx) * params.dx_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
340 |
-
float *dweight = reinterpret_cast<float *>(params.dweight_ptr)
|
341 |
-
+ chunk_c_id * kChunkSizeC * params.dweight_c_stride;
|
342 |
-
|
343 |
-
#pragma unroll
|
344 |
-
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
345 |
-
input_t dout_vals_load[kNElts] = {0};
|
346 |
-
input_t x_vals_load[kNElts] = {0};
|
347 |
-
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
348 |
-
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
349 |
-
reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + l * kLPerLoad * params.dout_l_stride);
|
350 |
-
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
|
351 |
-
}
|
352 |
-
reinterpret_cast<vec_t *>(dout_smem[l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
|
353 |
-
reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
354 |
-
}
|
355 |
-
// Load the elements from the previous chunk or next chunk that are needed for convolution.
|
356 |
-
if (l_idx < kWidth - 1) {
|
357 |
-
input_t dout_vals_load[kNElts] = {0};
|
358 |
-
input_t x_vals_load[kNElts] = {0};
|
359 |
-
if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
|
360 |
-
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
361 |
-
reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + kChunkSizeL * params.dout_l_stride);
|
362 |
-
}
|
363 |
-
if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
|
364 |
-
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
|
365 |
-
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
366 |
-
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
|
367 |
-
}
|
368 |
-
reinterpret_cast<vec_t *>(dout_smem[kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
|
369 |
-
reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
370 |
-
}
|
371 |
-
// Need to load (kWdith - 1) extra x's on the right to recompute the (kChunkSizeL + kWidth - 1) outputs
|
372 |
-
if constexpr (kSiluAct) {
|
373 |
-
if (l_idx < kWidth - 1) {
|
374 |
-
input_t x_vals_load[kNElts] = {0};
|
375 |
-
if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
|
376 |
-
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
377 |
-
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + kChunkSizeL * params.x_l_stride);
|
378 |
-
}
|
379 |
-
reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
380 |
-
}
|
381 |
-
}
|
382 |
-
|
383 |
-
__syncthreads();
|
384 |
-
|
385 |
-
constexpr int kLPerThread = std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
|
386 |
-
static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
|
387 |
-
constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
|
388 |
-
static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
|
389 |
-
// kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
|
390 |
-
static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
|
391 |
-
static_assert((kLPerThread & (kLPerThread - 1)) == 0);
|
392 |
-
static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
|
393 |
-
static_assert(kNThreadsPerRow <= 32);
|
394 |
-
|
395 |
-
const int row_idx = tid / kNThreadsPerRow;
|
396 |
-
const int col_idx = tid % kNThreadsPerRow;
|
397 |
-
|
398 |
-
float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
|
399 |
-
float weight_vals[kWidth] = {0};
|
400 |
-
if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
401 |
-
#pragma unroll
|
402 |
-
for (int w = 0; w < kWidth; ++w) {
|
403 |
-
weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
|
404 |
-
}
|
405 |
-
}
|
406 |
-
float dout_vals[kLPerThread + kWidth - 1];
|
407 |
-
float x_vals[kWidth - 1 + kLPerThread + kWidth - 1];
|
408 |
-
#pragma unroll
|
409 |
-
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
|
410 |
-
dout_vals[i] = float(dout_smem[col_idx * kLPerThread + i][row_idx]);
|
411 |
-
x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
|
412 |
-
}
|
413 |
-
|
414 |
-
if constexpr (kSiluAct) { // Recompute the output
|
415 |
-
#pragma unroll
|
416 |
-
for (int i = kWidth - 1 + kLPerThread; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) {
|
417 |
-
x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
|
418 |
-
}
|
419 |
-
#pragma unroll
|
420 |
-
for (int i = 0; i < kLPerThread + kWidth - 1; ++i) {
|
421 |
-
float out_val = bias_val;
|
422 |
-
#pragma unroll
|
423 |
-
for (int w = 0; w < kWidth; ++w) { out_val += weight_vals[w] * x_vals[i + w]; }
|
424 |
-
float out_val_sigmoid = 1.f / (1.f + expf(-out_val));
|
425 |
-
dout_vals[i] *= out_val_sigmoid * (1 + out_val * (1 - out_val_sigmoid));
|
426 |
-
}
|
427 |
-
}
|
428 |
-
|
429 |
-
float dweight_vals[kWidth] = {0};
|
430 |
-
SumOp<float> sum_op;
|
431 |
-
#pragma unroll
|
432 |
-
for (int w = 0; w < kWidth; ++w) {
|
433 |
-
#pragma unroll
|
434 |
-
for (int i = 0; i < kLPerThread; ++i) { dweight_vals[w] += x_vals[i + w] * dout_vals[i]; }
|
435 |
-
dweight_vals[w] = Allreduce<kNThreadsPerRow>::run(dweight_vals[w], sum_op);
|
436 |
-
if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
437 |
-
atomicAdd(&reinterpret_cast<float *>(dweight)[row_idx * params.dweight_c_stride + w * params.dweight_width_stride], dweight_vals[w]);
|
438 |
-
}
|
439 |
-
}
|
440 |
-
|
441 |
-
if (params.bias_ptr != nullptr) {
|
442 |
-
float dbias_val = 0.f;
|
443 |
-
for (int i = 0; i < kLPerThread; ++i) { dbias_val += dout_vals[i]; }
|
444 |
-
dbias_val = Allreduce<kNThreadsPerRow>::run(dbias_val, sum_op);
|
445 |
-
if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
446 |
-
atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[chunk_c_id * kChunkSizeC + row_idx], dbias_val);
|
447 |
-
}
|
448 |
-
}
|
449 |
-
|
450 |
-
float dx_vals[kLPerThread] = {0};
|
451 |
-
#pragma unroll
|
452 |
-
for (int i = 0; i < kLPerThread; ++i) {
|
453 |
-
#pragma unroll
|
454 |
-
for (int w = 0; w < kWidth; ++w) { dx_vals[i] += weight_vals[kWidth - 1 - w] * dout_vals[i + w]; }
|
455 |
-
}
|
456 |
-
// Since kNThreadsPerRow is a power of 2 and <= 32, we only need syncwarp and not syncthreads.
|
457 |
-
__syncwarp();
|
458 |
-
#pragma unroll
|
459 |
-
for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = dx_vals[i]; }
|
460 |
-
__syncthreads();
|
461 |
-
|
462 |
-
#pragma unroll
|
463 |
-
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
464 |
-
input_t dx_vals_store[kNElts];
|
465 |
-
reinterpret_cast<vec_t *>(dx_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
|
466 |
-
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
467 |
-
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
468 |
-
*reinterpret_cast<vec_t *>(dx + l * kLPerLoad * params.dx_l_stride) = reinterpret_cast<vec_t *>(dx_vals_store)[0];
|
469 |
-
}
|
470 |
-
}
|
471 |
-
|
472 |
-
}
|
473 |
-
|
474 |
-
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
475 |
-
void causal_conv1d_channellast_bwd_launch(ConvParamsBwd ¶ms, cudaStream_t stream) {
|
476 |
-
BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
|
477 |
-
using Ktraits = Causal_conv1d_channellast_bwd_kernel_traits<kNThreads, kWidth, 64, kSiluAct, true, input_t, weight_t>;
|
478 |
-
// constexpr int kSmemSize = Ktraits::kSmemSize;
|
479 |
-
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
480 |
-
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
481 |
-
const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
|
482 |
-
const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
|
483 |
-
dim3 grid(params.batch, n_chunks_L, n_chunks_C);
|
484 |
-
dim3 block(Ktraits::kNThreads);
|
485 |
-
auto kernel = &causal_conv1d_channellast_bwd_kernel<Ktraits>;
|
486 |
-
// if (kSmemSize >= 48 * 1024) {
|
487 |
-
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
488 |
-
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
489 |
-
// }
|
490 |
-
// kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
491 |
-
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
492 |
-
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
493 |
-
});
|
494 |
-
}
|
495 |
-
|
496 |
-
template<typename input_t, typename weight_t>
|
497 |
-
void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream) {
|
498 |
-
if (params.width == 2) {
|
499 |
-
causal_conv1d_channellast_bwd_launch<128, 2, input_t, weight_t>(params, stream);
|
500 |
-
} else if (params.width == 3) {
|
501 |
-
causal_conv1d_channellast_bwd_launch<128, 3, input_t, weight_t>(params, stream);
|
502 |
-
} else if (params.width == 4) {
|
503 |
-
causal_conv1d_channellast_bwd_launch<128, 4, input_t, weight_t>(params, stream);
|
504 |
-
}
|
505 |
-
}
|
506 |
-
|
507 |
-
template void causal_conv1d_bwd_cuda<float, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
508 |
-
template void causal_conv1d_bwd_cuda<at::Half, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
509 |
-
template void causal_conv1d_bwd_cuda<at::BFloat16, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
510 |
-
template void causal_conv1d_bwd_cuda<float, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
511 |
-
template void causal_conv1d_bwd_cuda<at::Half, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
512 |
-
template void causal_conv1d_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
513 |
-
template void causal_conv1d_bwd_cuda<float, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
514 |
-
template void causal_conv1d_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
515 |
-
template void causal_conv1d_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
516 |
-
|
517 |
-
template void causal_conv1d_channellast_bwd_cuda<float, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
518 |
-
template void causal_conv1d_channellast_bwd_cuda<at::Half, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
519 |
-
template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
520 |
-
template void causal_conv1d_channellast_bwd_cuda<float, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
521 |
-
template void causal_conv1d_channellast_bwd_cuda<at::Half, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
522 |
-
template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
523 |
-
template void causal_conv1d_channellast_bwd_cuda<float, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
524 |
-
template void causal_conv1d_channellast_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
525 |
-
template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
causal-conv1d/csrc/causal_conv1d_common.h
DELETED
@@ -1,64 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2023, Tri Dao.
|
3 |
-
******************************************************************************/
|
4 |
-
|
5 |
-
#pragma once
|
6 |
-
|
7 |
-
#include <cuda_bf16.h>
|
8 |
-
#include <cuda_fp16.h>
|
9 |
-
|
10 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
11 |
-
|
12 |
-
template<int BYTES> struct BytesToType {};
|
13 |
-
|
14 |
-
template<> struct BytesToType<16> {
|
15 |
-
using Type = uint4;
|
16 |
-
static_assert(sizeof(Type) == 16);
|
17 |
-
};
|
18 |
-
|
19 |
-
template<> struct BytesToType<8> {
|
20 |
-
using Type = uint64_t;
|
21 |
-
static_assert(sizeof(Type) == 8);
|
22 |
-
};
|
23 |
-
|
24 |
-
template<> struct BytesToType<4> {
|
25 |
-
using Type = uint32_t;
|
26 |
-
static_assert(sizeof(Type) == 4);
|
27 |
-
};
|
28 |
-
|
29 |
-
template<> struct BytesToType<2> {
|
30 |
-
using Type = uint16_t;
|
31 |
-
static_assert(sizeof(Type) == 2);
|
32 |
-
};
|
33 |
-
|
34 |
-
template<> struct BytesToType<1> {
|
35 |
-
using Type = uint8_t;
|
36 |
-
static_assert(sizeof(Type) == 1);
|
37 |
-
};
|
38 |
-
|
39 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
40 |
-
|
41 |
-
template<typename T>
|
42 |
-
struct SumOp {
|
43 |
-
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
|
44 |
-
};
|
45 |
-
|
46 |
-
template<int THREADS>
|
47 |
-
struct Allreduce {
|
48 |
-
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
49 |
-
template<typename T, typename Operator>
|
50 |
-
static __device__ inline T run(T x, Operator &op) {
|
51 |
-
constexpr int OFFSET = THREADS / 2;
|
52 |
-
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
53 |
-
return Allreduce<OFFSET>::run(x, op);
|
54 |
-
}
|
55 |
-
};
|
56 |
-
|
57 |
-
template<>
|
58 |
-
struct Allreduce<2> {
|
59 |
-
template<typename T, typename Operator>
|
60 |
-
static __device__ inline T run(T x, Operator &op) {
|
61 |
-
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
62 |
-
return x;
|
63 |
-
}
|
64 |
-
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
causal-conv1d/csrc/causal_conv1d_fwd.cu
DELETED
@@ -1,350 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2023, Tri Dao.
|
3 |
-
******************************************************************************/
|
4 |
-
|
5 |
-
#include <c10/util/BFloat16.h>
|
6 |
-
#include <c10/util/Half.h>
|
7 |
-
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
8 |
-
|
9 |
-
#include <cub/block/block_load.cuh>
|
10 |
-
#include <cub/block/block_store.cuh>
|
11 |
-
|
12 |
-
#include "causal_conv1d.h"
|
13 |
-
#include "causal_conv1d_common.h"
|
14 |
-
#include "static_switch.h"
|
15 |
-
|
16 |
-
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
17 |
-
struct Causal_conv1d_fwd_kernel_traits {
|
18 |
-
using input_t = input_t_;
|
19 |
-
using weight_t = weight_t_;
|
20 |
-
static constexpr int kNThreads = kNThreads_;
|
21 |
-
static constexpr int kWidth = kWidth_;
|
22 |
-
static constexpr int kNBytes = sizeof(input_t);
|
23 |
-
static_assert(kNBytes == 2 || kNBytes == 4);
|
24 |
-
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
25 |
-
static_assert(kWidth <= kNElts);
|
26 |
-
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
27 |
-
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
28 |
-
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
29 |
-
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
|
30 |
-
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
31 |
-
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
|
32 |
-
static constexpr int kSmemIOSize = kIsVecLoad
|
33 |
-
? 0
|
34 |
-
: std::max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
|
35 |
-
static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
|
36 |
-
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
|
37 |
-
};
|
38 |
-
|
39 |
-
template<typename Ktraits>
|
40 |
-
__global__ __launch_bounds__(Ktraits::kNThreads)
|
41 |
-
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
42 |
-
constexpr int kWidth = Ktraits::kWidth;
|
43 |
-
constexpr int kNThreads = Ktraits::kNThreads;
|
44 |
-
constexpr int kNElts = Ktraits::kNElts;
|
45 |
-
constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
|
46 |
-
using input_t = typename Ktraits::input_t;
|
47 |
-
using vec_t = typename Ktraits::vec_t;
|
48 |
-
using weight_t = typename Ktraits::weight_t;
|
49 |
-
|
50 |
-
// Shared memory.
|
51 |
-
extern __shared__ char smem_[];
|
52 |
-
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
53 |
-
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
|
54 |
-
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
55 |
-
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
|
56 |
-
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
|
57 |
-
|
58 |
-
const int tidx = threadIdx.x;
|
59 |
-
const int batch_id = blockIdx.x;
|
60 |
-
const int channel_id = blockIdx.y;
|
61 |
-
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
62 |
-
+ channel_id * params.x_c_stride;
|
63 |
-
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
64 |
-
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
65 |
-
+ channel_id * params.out_c_stride;
|
66 |
-
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
67 |
-
|
68 |
-
// Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
|
69 |
-
if (tidx == 0) {
|
70 |
-
input_t zeros[kNElts] = {0};
|
71 |
-
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0];
|
72 |
-
}
|
73 |
-
|
74 |
-
float weight_vals[kWidth];
|
75 |
-
#pragma unroll
|
76 |
-
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
77 |
-
|
78 |
-
constexpr int kChunkSize = kNThreads * kNElts;
|
79 |
-
const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
|
80 |
-
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
81 |
-
input_t x_vals_load[2 * kNElts] = {0};
|
82 |
-
if constexpr(kIsVecLoad) {
|
83 |
-
Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
|
84 |
-
} else {
|
85 |
-
__syncthreads();
|
86 |
-
Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
|
87 |
-
}
|
88 |
-
x += kChunkSize;
|
89 |
-
__syncthreads();
|
90 |
-
// Thread kNThreads - 1 don't write yet, so that thread 0 can read
|
91 |
-
// the last elements of the previous chunk.
|
92 |
-
if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
93 |
-
__syncthreads();
|
94 |
-
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
|
95 |
-
__syncthreads();
|
96 |
-
// Now thread kNThreads - 1 can write the last elements of the current chunk.
|
97 |
-
if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
98 |
-
|
99 |
-
float x_vals[2 * kNElts];
|
100 |
-
#pragma unroll
|
101 |
-
for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
|
102 |
-
|
103 |
-
float out_vals[kNElts];
|
104 |
-
#pragma unroll
|
105 |
-
for (int i = 0; i < kNElts; ++i) {
|
106 |
-
out_vals[i] = bias_val;
|
107 |
-
#pragma unroll
|
108 |
-
for (int w = 0; w < kWidth; ++w) {
|
109 |
-
out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
|
110 |
-
}
|
111 |
-
}
|
112 |
-
|
113 |
-
if (params.silu_activation) {
|
114 |
-
#pragma unroll
|
115 |
-
for (int i = 0; i < kNElts; ++i) {
|
116 |
-
out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
|
117 |
-
}
|
118 |
-
}
|
119 |
-
|
120 |
-
input_t out_vals_store[kNElts];
|
121 |
-
#pragma unroll
|
122 |
-
for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
|
123 |
-
if constexpr(kIsVecLoad) {
|
124 |
-
Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
|
125 |
-
} else {
|
126 |
-
Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize);
|
127 |
-
}
|
128 |
-
out += kChunkSize;
|
129 |
-
}
|
130 |
-
}
|
131 |
-
|
132 |
-
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
133 |
-
void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
134 |
-
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
|
135 |
-
BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
|
136 |
-
using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
|
137 |
-
constexpr int kSmemSize = Ktraits::kSmemSize;
|
138 |
-
dim3 grid(params.batch, params.dim);
|
139 |
-
auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
|
140 |
-
if (kSmemSize >= 48 * 1024) {
|
141 |
-
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
142 |
-
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
143 |
-
}
|
144 |
-
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
145 |
-
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
146 |
-
});
|
147 |
-
}
|
148 |
-
|
149 |
-
template<typename input_t, typename weight_t>
|
150 |
-
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
151 |
-
if (params.width == 2) {
|
152 |
-
causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
153 |
-
} else if (params.width == 3) {
|
154 |
-
causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
155 |
-
} else if (params.width == 4) {
|
156 |
-
causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
157 |
-
}
|
158 |
-
}
|
159 |
-
|
160 |
-
template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
161 |
-
struct Causal_conv1d_channellast_fwd_kernel_traits {
|
162 |
-
// The cache line is 128 bytes, and we try to read 16 bytes per thread.
|
163 |
-
// So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
|
164 |
-
// That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
|
165 |
-
// threads). Each each load is 16 x 32|64 elements in the L x C dimensions.
|
166 |
-
using input_t = input_t_;
|
167 |
-
using weight_t = weight_t_;
|
168 |
-
static constexpr int kNThreads = kNThreads_;
|
169 |
-
static_assert(kNThreads % 32 == 0);
|
170 |
-
static constexpr int kNWarps = kNThreads / 32;
|
171 |
-
static constexpr int kWidth = kWidth_;
|
172 |
-
static constexpr int kChunkSizeL = kChunkSizeL_;
|
173 |
-
static constexpr int kNBytes = sizeof(input_t);
|
174 |
-
static_assert(kNBytes == 2 || kNBytes == 4);
|
175 |
-
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
176 |
-
static constexpr int kNEltsPerRow = 128 / kNBytes;
|
177 |
-
static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
|
178 |
-
static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
|
179 |
-
static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
|
180 |
-
static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
|
181 |
-
static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
|
182 |
-
static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
|
183 |
-
static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
|
184 |
-
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
185 |
-
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
186 |
-
// using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
187 |
-
// using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
188 |
-
// static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
|
189 |
-
// sizeof(typename BlockStoreT::TempStorage)});
|
190 |
-
// static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
|
191 |
-
};
|
192 |
-
|
193 |
-
template<typename Ktraits>
|
194 |
-
__global__ __launch_bounds__(Ktraits::kNThreads)
|
195 |
-
void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
|
196 |
-
constexpr int kWidth = Ktraits::kWidth;
|
197 |
-
constexpr int kNThreads = Ktraits::kNThreads;
|
198 |
-
constexpr int kNElts = Ktraits::kNElts;
|
199 |
-
constexpr int kNWarp = Ktraits::kNWarps;
|
200 |
-
constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
|
201 |
-
constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
|
202 |
-
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
203 |
-
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
204 |
-
using input_t = typename Ktraits::input_t;
|
205 |
-
using vec_t = typename Ktraits::vec_t;
|
206 |
-
using weight_t = typename Ktraits::weight_t;
|
207 |
-
|
208 |
-
// Shared memory.
|
209 |
-
__shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];
|
210 |
-
|
211 |
-
const int tid = threadIdx.x;
|
212 |
-
const int l_idx = tid / kNThreadsPerC;
|
213 |
-
const int c_idx = tid % kNThreadsPerC;
|
214 |
-
const int batch_id = blockIdx.x;
|
215 |
-
const int chunk_l_id = blockIdx.y;
|
216 |
-
const int chunk_c_id = blockIdx.z;
|
217 |
-
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
218 |
-
+ (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
219 |
-
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
|
220 |
-
+ chunk_c_id * kChunkSizeC * params.weight_c_stride;
|
221 |
-
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
222 |
-
+ (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
223 |
-
|
224 |
-
#pragma unroll
|
225 |
-
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
226 |
-
input_t x_vals_load[kNElts] = {0};
|
227 |
-
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
228 |
-
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
229 |
-
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
|
230 |
-
}
|
231 |
-
reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
232 |
-
}
|
233 |
-
// Load the elements from the previous chunk that are needed for convolution.
|
234 |
-
if (l_idx < kWidth - 1) {
|
235 |
-
input_t x_vals_load[kNElts] = {0};
|
236 |
-
if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
|
237 |
-
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
|
238 |
-
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
239 |
-
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
|
240 |
-
}
|
241 |
-
reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
242 |
-
}
|
243 |
-
|
244 |
-
__syncthreads();
|
245 |
-
|
246 |
-
constexpr int kLPerThread = std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
|
247 |
-
static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
|
248 |
-
constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
|
249 |
-
static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
|
250 |
-
// kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
|
251 |
-
static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
|
252 |
-
static_assert((kLPerThread & (kLPerThread - 1)) == 0);
|
253 |
-
static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
|
254 |
-
static_assert(kNThreadsPerRow <= 32);
|
255 |
-
|
256 |
-
const int row_idx = tid / kNThreadsPerRow;
|
257 |
-
const int col_idx = tid % kNThreadsPerRow;
|
258 |
-
|
259 |
-
float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
|
260 |
-
float weight_vals[kWidth] = {0};
|
261 |
-
if (chunk_c_id + kChunkSizeC + row_idx < params.dim) {
|
262 |
-
#pragma unroll
|
263 |
-
for (int w = 0; w < kWidth; ++w) {
|
264 |
-
weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
|
265 |
-
}
|
266 |
-
}
|
267 |
-
float x_vals[kWidth - 1 + kLPerThread];
|
268 |
-
#pragma unroll
|
269 |
-
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
|
270 |
-
x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
|
271 |
-
}
|
272 |
-
|
273 |
-
float out_vals[kLPerThread];
|
274 |
-
#pragma unroll
|
275 |
-
for (int i = 0; i < kLPerThread; ++i) {
|
276 |
-
out_vals[i] = bias_val;
|
277 |
-
#pragma unroll
|
278 |
-
for (int w = 0; w < kWidth; ++w) { out_vals[i] += weight_vals[w] * x_vals[i + w]; }
|
279 |
-
if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
|
280 |
-
}
|
281 |
-
|
282 |
-
// Since kNThreadsPerRow is a power of 2 and <= 32, we only need syncwarp and not syncthreads.
|
283 |
-
__syncwarp();
|
284 |
-
#pragma unroll
|
285 |
-
for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
|
286 |
-
__syncthreads();
|
287 |
-
|
288 |
-
#pragma unroll
|
289 |
-
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
290 |
-
input_t out_vals_store[kNElts];
|
291 |
-
reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
|
292 |
-
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
293 |
-
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
294 |
-
*reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
|
295 |
-
}
|
296 |
-
}
|
297 |
-
|
298 |
-
}
|
299 |
-
|
300 |
-
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
301 |
-
void causal_conv1d_channellast_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
302 |
-
using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
|
303 |
-
// constexpr int kSmemSize = Ktraits::kSmemSize;
|
304 |
-
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
305 |
-
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
306 |
-
const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
|
307 |
-
const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
|
308 |
-
// printf("n_chunks_L: %d, n_chunks_C: %d\n", n_chunks_L, n_chunks_C);
|
309 |
-
dim3 grid(params.batch, n_chunks_L, n_chunks_C);
|
310 |
-
dim3 block(Ktraits::kNThreads);
|
311 |
-
auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits>;
|
312 |
-
// if (kSmemSize >= 48 * 1024) {
|
313 |
-
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
314 |
-
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
315 |
-
// }
|
316 |
-
// kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
317 |
-
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
318 |
-
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
319 |
-
}
|
320 |
-
|
321 |
-
template<typename input_t, typename weight_t>
|
322 |
-
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
323 |
-
if (params.width == 2) {
|
324 |
-
causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
325 |
-
} else if (params.width == 3) {
|
326 |
-
causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
327 |
-
} else if (params.width == 4) {
|
328 |
-
causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
329 |
-
}
|
330 |
-
}
|
331 |
-
|
332 |
-
template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
333 |
-
template void causal_conv1d_fwd_cuda<at::Half, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
334 |
-
template void causal_conv1d_fwd_cuda<at::BFloat16, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
335 |
-
template void causal_conv1d_fwd_cuda<float, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
336 |
-
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
337 |
-
template void causal_conv1d_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
338 |
-
template void causal_conv1d_fwd_cuda<float, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
339 |
-
template void causal_conv1d_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
340 |
-
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
341 |
-
|
342 |
-
template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
343 |
-
template void causal_conv1d_channellast_fwd_cuda<at::Half, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
344 |
-
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
345 |
-
template void causal_conv1d_channellast_fwd_cuda<float, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
346 |
-
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
347 |
-
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
348 |
-
template void causal_conv1d_channellast_fwd_cuda<float, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
349 |
-
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
350 |
-
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
causal-conv1d/csrc/causal_conv1d_update.cu
DELETED
@@ -1,96 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2023, Tri Dao.
|
3 |
-
******************************************************************************/
|
4 |
-
|
5 |
-
#include <c10/util/BFloat16.h>
|
6 |
-
#include <c10/util/Half.h>
|
7 |
-
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
8 |
-
|
9 |
-
#include <cub/block/block_load.cuh>
|
10 |
-
#include <cub/block/block_store.cuh>
|
11 |
-
|
12 |
-
#include "causal_conv1d.h"
|
13 |
-
#include "causal_conv1d_common.h"
|
14 |
-
#include "static_switch.h"
|
15 |
-
|
16 |
-
template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
|
17 |
-
struct Causal_conv1d_update_kernel_traits {
|
18 |
-
using input_t = input_t_;
|
19 |
-
using weight_t = weight_t_;
|
20 |
-
static constexpr int kNThreads = kNThreads_;
|
21 |
-
static constexpr int kWidth = kWidth_;
|
22 |
-
static constexpr int kNBytes = sizeof(input_t);
|
23 |
-
static_assert(kNBytes == 2 || kNBytes == 4);
|
24 |
-
};
|
25 |
-
|
26 |
-
template<typename Ktraits>
|
27 |
-
__global__ __launch_bounds__(Ktraits::kNThreads)
|
28 |
-
void causal_conv1d_update_kernel(ConvParamsBase params) {
|
29 |
-
constexpr int kWidth = Ktraits::kWidth;
|
30 |
-
constexpr int kNThreads = Ktraits::kNThreads;
|
31 |
-
using input_t = typename Ktraits::input_t;
|
32 |
-
using weight_t = typename Ktraits::weight_t;
|
33 |
-
|
34 |
-
const int tidx = threadIdx.x;
|
35 |
-
const int batch_id = blockIdx.x;
|
36 |
-
const int channel_id = blockIdx.y * kNThreads + tidx;
|
37 |
-
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
38 |
-
+ channel_id * params.x_c_stride;
|
39 |
-
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
|
40 |
-
+ channel_id * params.conv_state_c_stride;
|
41 |
-
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
42 |
-
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
43 |
-
+ channel_id * params.out_c_stride;
|
44 |
-
float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
45 |
-
|
46 |
-
float weight_vals[kWidth] = {0};
|
47 |
-
if (channel_id < params.dim) {
|
48 |
-
#pragma unroll
|
49 |
-
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
50 |
-
}
|
51 |
-
|
52 |
-
float x_vals[kWidth] = {0};
|
53 |
-
if (channel_id < params.dim) {
|
54 |
-
#pragma unroll
|
55 |
-
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); }
|
56 |
-
x_vals[kWidth - 1] = float(x[0]);
|
57 |
-
#pragma unroll
|
58 |
-
for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); }
|
59 |
-
}
|
60 |
-
|
61 |
-
float out_val = bias_val;
|
62 |
-
#pragma unroll
|
63 |
-
for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; }
|
64 |
-
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
|
65 |
-
if (channel_id < params.dim) { out[0] = input_t(out_val); }
|
66 |
-
}
|
67 |
-
|
68 |
-
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
69 |
-
void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
70 |
-
using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
|
71 |
-
dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
|
72 |
-
auto kernel = &causal_conv1d_update_kernel<Ktraits>;
|
73 |
-
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
74 |
-
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
75 |
-
}
|
76 |
-
|
77 |
-
template<typename input_t, typename weight_t>
|
78 |
-
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
79 |
-
if (params.width == 2) {
|
80 |
-
causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
|
81 |
-
} else if (params.width == 3) {
|
82 |
-
causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
|
83 |
-
} else if (params.width == 4) {
|
84 |
-
causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
|
85 |
-
}
|
86 |
-
}
|
87 |
-
|
88 |
-
template void causal_conv1d_update_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
89 |
-
template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
90 |
-
template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
91 |
-
template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
92 |
-
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
93 |
-
template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
94 |
-
template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
95 |
-
template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
96 |
-
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
causal-conv1d/csrc/static_switch.h
DELETED
@@ -1,25 +0,0 @@
|
|
1 |
-
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
2 |
-
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
3 |
-
|
4 |
-
#pragma once
|
5 |
-
|
6 |
-
/// @param COND - a boolean expression to switch by
|
7 |
-
/// @param CONST_NAME - a name given for the constexpr bool variable.
|
8 |
-
/// @param ... - code to execute for true and false
|
9 |
-
///
|
10 |
-
/// Usage:
|
11 |
-
/// ```
|
12 |
-
/// BOOL_SWITCH(flag, BoolConst, [&] {
|
13 |
-
/// some_function<BoolConst>(...);
|
14 |
-
/// });
|
15 |
-
/// ```
|
16 |
-
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
|
17 |
-
[&] { \
|
18 |
-
if (COND) { \
|
19 |
-
static constexpr bool CONST_NAME = true; \
|
20 |
-
return __VA_ARGS__(); \
|
21 |
-
} else { \
|
22 |
-
static constexpr bool CONST_NAME = false; \
|
23 |
-
return __VA_ARGS__(); \
|
24 |
-
} \
|
25 |
-
}()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
causal-conv1d/setup.py
DELETED
@@ -1,264 +0,0 @@
|
|
1 |
-
# Copyright (c) 2023, Tri Dao.
|
2 |
-
import sys
|
3 |
-
import warnings
|
4 |
-
import os
|
5 |
-
import re
|
6 |
-
import ast
|
7 |
-
from pathlib import Path
|
8 |
-
from packaging.version import parse, Version
|
9 |
-
import platform
|
10 |
-
|
11 |
-
from setuptools import setup, find_packages
|
12 |
-
import subprocess
|
13 |
-
|
14 |
-
import urllib.request
|
15 |
-
import urllib.error
|
16 |
-
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
|
17 |
-
|
18 |
-
import torch
|
19 |
-
from torch.utils.cpp_extension import (
|
20 |
-
BuildExtension,
|
21 |
-
CppExtension,
|
22 |
-
CUDAExtension,
|
23 |
-
CUDA_HOME,
|
24 |
-
)
|
25 |
-
|
26 |
-
|
27 |
-
with open("README.md", "r", encoding="utf-8") as fh:
|
28 |
-
long_description = fh.read()
|
29 |
-
|
30 |
-
|
31 |
-
# ninja build does not work unless include_dirs are abs path
|
32 |
-
this_dir = os.path.dirname(os.path.abspath(__file__))
|
33 |
-
|
34 |
-
PACKAGE_NAME = "causal_conv1d"
|
35 |
-
|
36 |
-
BASE_WHEEL_URL = "https://github.com/Dao-AILab/causal-conv1d/releases/download/{tag_name}/{wheel_name}"
|
37 |
-
|
38 |
-
# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
|
39 |
-
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
|
40 |
-
FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "TRUE"
|
41 |
-
SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
|
42 |
-
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
|
43 |
-
FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE"
|
44 |
-
|
45 |
-
|
46 |
-
def get_platform():
|
47 |
-
"""
|
48 |
-
Returns the platform name as used in wheel filenames.
|
49 |
-
"""
|
50 |
-
if sys.platform.startswith("linux"):
|
51 |
-
return "linux_x86_64"
|
52 |
-
elif sys.platform == "darwin":
|
53 |
-
mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
|
54 |
-
return f"macosx_{mac_version}_x86_64"
|
55 |
-
elif sys.platform == "win32":
|
56 |
-
return "win_amd64"
|
57 |
-
else:
|
58 |
-
raise ValueError("Unsupported platform: {}".format(sys.platform))
|
59 |
-
|
60 |
-
|
61 |
-
def get_cuda_bare_metal_version(cuda_dir):
|
62 |
-
raw_output = subprocess.check_output(
|
63 |
-
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
|
64 |
-
)
|
65 |
-
output = raw_output.split()
|
66 |
-
release_idx = output.index("release") + 1
|
67 |
-
bare_metal_version = parse(output[release_idx].split(",")[0])
|
68 |
-
|
69 |
-
return raw_output, bare_metal_version
|
70 |
-
|
71 |
-
|
72 |
-
def check_if_cuda_home_none(global_option: str) -> None:
|
73 |
-
if CUDA_HOME is not None:
|
74 |
-
return
|
75 |
-
# warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
|
76 |
-
# in that case.
|
77 |
-
warnings.warn(
|
78 |
-
f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
|
79 |
-
"If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
|
80 |
-
"only images whose names contain 'devel' will provide nvcc."
|
81 |
-
)
|
82 |
-
|
83 |
-
|
84 |
-
def append_nvcc_threads(nvcc_extra_args):
|
85 |
-
return nvcc_extra_args + ["--threads", "4"]
|
86 |
-
|
87 |
-
|
88 |
-
cmdclass = {}
|
89 |
-
ext_modules = []
|
90 |
-
|
91 |
-
if not SKIP_CUDA_BUILD:
|
92 |
-
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
|
93 |
-
TORCH_MAJOR = int(torch.__version__.split(".")[0])
|
94 |
-
TORCH_MINOR = int(torch.__version__.split(".")[1])
|
95 |
-
|
96 |
-
check_if_cuda_home_none("causal_conv1d")
|
97 |
-
# Check, if CUDA11 is installed for compute capability 8.0
|
98 |
-
cc_flag = []
|
99 |
-
if CUDA_HOME is not None:
|
100 |
-
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
|
101 |
-
if bare_metal_version < Version("11.6"):
|
102 |
-
raise RuntimeError(
|
103 |
-
"causal_conv1d is only supported on CUDA 11.6 and above. "
|
104 |
-
"Note: make sure nvcc has a supported version by running nvcc -V."
|
105 |
-
)
|
106 |
-
|
107 |
-
cc_flag.append("-gencode")
|
108 |
-
cc_flag.append("arch=compute_70,code=sm_70")
|
109 |
-
cc_flag.append("-gencode")
|
110 |
-
cc_flag.append("arch=compute_80,code=sm_80")
|
111 |
-
if bare_metal_version >= Version("11.8"):
|
112 |
-
cc_flag.append("-gencode")
|
113 |
-
cc_flag.append("arch=compute_90,code=sm_90")
|
114 |
-
|
115 |
-
# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
|
116 |
-
# torch._C._GLIBCXX_USE_CXX11_ABI
|
117 |
-
# https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
|
118 |
-
if FORCE_CXX11_ABI:
|
119 |
-
torch._C._GLIBCXX_USE_CXX11_ABI = True
|
120 |
-
|
121 |
-
ext_modules.append(
|
122 |
-
CUDAExtension(
|
123 |
-
name="causal_conv1d_cuda",
|
124 |
-
sources=[
|
125 |
-
"csrc/causal_conv1d.cpp",
|
126 |
-
"csrc/causal_conv1d_fwd.cu",
|
127 |
-
"csrc/causal_conv1d_bwd.cu",
|
128 |
-
"csrc/causal_conv1d_update.cu",
|
129 |
-
],
|
130 |
-
extra_compile_args={
|
131 |
-
"cxx": ["-O3"],
|
132 |
-
"nvcc": append_nvcc_threads(
|
133 |
-
[
|
134 |
-
"-O3",
|
135 |
-
"-U__CUDA_NO_HALF_OPERATORS__",
|
136 |
-
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
137 |
-
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
|
138 |
-
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
|
139 |
-
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
|
140 |
-
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
|
141 |
-
"--expt-relaxed-constexpr",
|
142 |
-
"--expt-extended-lambda",
|
143 |
-
"--use_fast_math",
|
144 |
-
"--ptxas-options=-v",
|
145 |
-
"-lineinfo",
|
146 |
-
]
|
147 |
-
+ cc_flag
|
148 |
-
),
|
149 |
-
},
|
150 |
-
include_dirs=[this_dir],
|
151 |
-
)
|
152 |
-
)
|
153 |
-
|
154 |
-
|
155 |
-
def get_package_version():
|
156 |
-
with open(Path(this_dir) / "causal_conv1d" / "__init__.py", "r") as f:
|
157 |
-
version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
|
158 |
-
public_version = ast.literal_eval(version_match.group(1))
|
159 |
-
local_version = os.environ.get("CAUSAL_CONV1D_LOCAL_VERSION")
|
160 |
-
if local_version:
|
161 |
-
return f"{public_version}+{local_version}"
|
162 |
-
else:
|
163 |
-
return str(public_version)
|
164 |
-
|
165 |
-
|
166 |
-
def get_wheel_url():
|
167 |
-
# Determine the version numbers that will be used to determine the correct wheel
|
168 |
-
# We're using the CUDA version used to build torch, not the one currently installed
|
169 |
-
# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
|
170 |
-
torch_cuda_version = parse(torch.version.cuda)
|
171 |
-
torch_version_raw = parse(torch.__version__)
|
172 |
-
# For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
|
173 |
-
# to save CI time. Minor versions should be compatible.
|
174 |
-
torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
|
175 |
-
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
|
176 |
-
platform_name = get_platform()
|
177 |
-
causal_conv1d_version = get_package_version()
|
178 |
-
# cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
|
179 |
-
cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
|
180 |
-
torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
|
181 |
-
cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
|
182 |
-
|
183 |
-
# Determine wheel URL based on CUDA version, torch version, python version and OS
|
184 |
-
wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
|
185 |
-
wheel_url = BASE_WHEEL_URL.format(
|
186 |
-
tag_name=f"v{causal_conv1d_version}", wheel_name=wheel_filename
|
187 |
-
)
|
188 |
-
return wheel_url, wheel_filename
|
189 |
-
|
190 |
-
|
191 |
-
class CachedWheelsCommand(_bdist_wheel):
|
192 |
-
"""
|
193 |
-
The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
|
194 |
-
find an existing wheel (which is currently the case for all installs). We use
|
195 |
-
the environment parameters to detect whether there is already a pre-built version of a compatible
|
196 |
-
wheel available and short-circuits the standard full build pipeline.
|
197 |
-
"""
|
198 |
-
|
199 |
-
def run(self):
|
200 |
-
if FORCE_BUILD:
|
201 |
-
return super().run()
|
202 |
-
|
203 |
-
wheel_url, wheel_filename = get_wheel_url()
|
204 |
-
print("Guessing wheel URL: ", wheel_url)
|
205 |
-
try:
|
206 |
-
urllib.request.urlretrieve(wheel_url, wheel_filename)
|
207 |
-
|
208 |
-
# Make the archive
|
209 |
-
# Lifted from the root wheel processing command
|
210 |
-
# https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
|
211 |
-
if not os.path.exists(self.dist_dir):
|
212 |
-
os.makedirs(self.dist_dir)
|
213 |
-
|
214 |
-
impl_tag, abi_tag, plat_tag = self.get_tag()
|
215 |
-
archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
|
216 |
-
|
217 |
-
wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
|
218 |
-
print("Raw wheel path", wheel_path)
|
219 |
-
os.rename(wheel_filename, wheel_path)
|
220 |
-
except urllib.error.HTTPError:
|
221 |
-
print("Precompiled wheel not found. Building from source...")
|
222 |
-
# If the wheel could not be downloaded, build from source
|
223 |
-
super().run()
|
224 |
-
|
225 |
-
|
226 |
-
setup(
|
227 |
-
name=PACKAGE_NAME,
|
228 |
-
version=get_package_version(),
|
229 |
-
packages=find_packages(
|
230 |
-
exclude=(
|
231 |
-
"build",
|
232 |
-
"csrc",
|
233 |
-
"include",
|
234 |
-
"tests",
|
235 |
-
"dist",
|
236 |
-
"docs",
|
237 |
-
"benchmarks",
|
238 |
-
"causal_conv1d.egg-info",
|
239 |
-
)
|
240 |
-
),
|
241 |
-
author="Tri Dao",
|
242 |
-
author_email="tri@tridao.me",
|
243 |
-
description="Causal depthwise conv1d in CUDA, with a PyTorch interface",
|
244 |
-
long_description=long_description,
|
245 |
-
long_description_content_type="text/markdown",
|
246 |
-
url="https://github.com/Dao-AILab/causal-conv1d",
|
247 |
-
classifiers=[
|
248 |
-
"Programming Language :: Python :: 3",
|
249 |
-
"License :: OSI Approved :: BSD License",
|
250 |
-
"Operating System :: Unix",
|
251 |
-
],
|
252 |
-
ext_modules=ext_modules,
|
253 |
-
cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
|
254 |
-
if ext_modules
|
255 |
-
else {
|
256 |
-
"bdist_wheel": CachedWheelsCommand,
|
257 |
-
},
|
258 |
-
python_requires=">=3.7",
|
259 |
-
install_requires=[
|
260 |
-
"torch",
|
261 |
-
"packaging",
|
262 |
-
"ninja",
|
263 |
-
],
|
264 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
causal-conv1d/tests/test_causal_conv1d.py
DELETED
@@ -1,173 +0,0 @@
|
|
1 |
-
# Copyright (C) 2023, Tri Dao.
|
2 |
-
|
3 |
-
import math
|
4 |
-
|
5 |
-
import torch
|
6 |
-
import pytest
|
7 |
-
|
8 |
-
from einops import rearrange
|
9 |
-
|
10 |
-
from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_ref
|
11 |
-
from causal_conv1d.causal_conv1d_interface import causal_conv1d_update, causal_conv1d_update_ref
|
12 |
-
|
13 |
-
|
14 |
-
@pytest.mark.parametrize("channel_last", [False, True])
|
15 |
-
# @pytest.mark.parametrize('channel_last', [True])
|
16 |
-
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
17 |
-
# @pytest.mark.parametrize('itype', [torch.float16])
|
18 |
-
@pytest.mark.parametrize("silu_activation", [False, True])
|
19 |
-
# @pytest.mark.parametrize('silu_activation', [True])
|
20 |
-
@pytest.mark.parametrize("has_bias", [False, True])
|
21 |
-
# @pytest.mark.parametrize('has_bias', [True])
|
22 |
-
@pytest.mark.parametrize("width", [2, 3, 4])
|
23 |
-
# @pytest.mark.parametrize('width', [2])
|
24 |
-
@pytest.mark.parametrize(
|
25 |
-
"seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
|
26 |
-
)
|
27 |
-
# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
|
28 |
-
# @pytest.mark.parametrize('seqlen', [128])
|
29 |
-
def test_causal_conv1d(seqlen, width, has_bias, silu_activation, itype, channel_last):
|
30 |
-
device = "cuda"
|
31 |
-
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
32 |
-
if itype == torch.bfloat16:
|
33 |
-
rtol, atol = 1e-2, 5e-2
|
34 |
-
rtolw, atolw = (1e-3, 1e-3)
|
35 |
-
# set seed
|
36 |
-
torch.random.manual_seed(0)
|
37 |
-
batch_size = 2
|
38 |
-
# batch_size = 1
|
39 |
-
dim = 4096 + 32 # Try dim not divisible by 64
|
40 |
-
# dim = 64
|
41 |
-
if not channel_last:
|
42 |
-
x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
|
43 |
-
else:
|
44 |
-
x = rearrange(
|
45 |
-
torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
|
46 |
-
).requires_grad_()
|
47 |
-
weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
|
48 |
-
if has_bias:
|
49 |
-
bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
50 |
-
else:
|
51 |
-
bias = None
|
52 |
-
x_ref = x.detach().clone().requires_grad_()
|
53 |
-
weight_ref = weight.detach().clone().requires_grad_()
|
54 |
-
bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
|
55 |
-
activation = None if not silu_activation else "silu"
|
56 |
-
out = causal_conv1d_fn(x, weight, bias, activation=activation)
|
57 |
-
out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, activation=activation)
|
58 |
-
|
59 |
-
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
60 |
-
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
61 |
-
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
62 |
-
|
63 |
-
g = torch.randn_like(out)
|
64 |
-
out_ref.backward(g)
|
65 |
-
out.backward(g)
|
66 |
-
|
67 |
-
print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
|
68 |
-
print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
|
69 |
-
if has_bias:
|
70 |
-
print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")
|
71 |
-
|
72 |
-
assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
|
73 |
-
assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)
|
74 |
-
if has_bias:
|
75 |
-
assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)
|
76 |
-
|
77 |
-
|
78 |
-
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
79 |
-
# @pytest.mark.parametrize('itype', [torch.float16])
|
80 |
-
@pytest.mark.parametrize("silu_activation", [False, True])
|
81 |
-
# @pytest.mark.parametrize('silu_activation', [False])
|
82 |
-
@pytest.mark.parametrize("has_bias", [False, True])
|
83 |
-
# @pytest.mark.parametrize('has_bias', [True])
|
84 |
-
@pytest.mark.parametrize("width", [2, 3, 4])
|
85 |
-
# @pytest.mark.parametrize('width', [2])
|
86 |
-
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
87 |
-
# @pytest.mark.parametrize("dim", [2048])
|
88 |
-
def test_causal_conv1d_update(dim, width, has_bias, silu_activation, itype):
|
89 |
-
device = "cuda"
|
90 |
-
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
91 |
-
if itype == torch.bfloat16:
|
92 |
-
rtol, atol = 1e-2, 5e-2
|
93 |
-
rtolw, atolw = (1e-3, 1e-3)
|
94 |
-
# set seed
|
95 |
-
torch.random.manual_seed(0)
|
96 |
-
batch_size = 2
|
97 |
-
# batch_size = 1
|
98 |
-
# dim = 64
|
99 |
-
x = torch.randn(batch_size, dim, device=device, dtype=itype)
|
100 |
-
conv_state = torch.randn(batch_size, dim, width, device=device, dtype=itype)
|
101 |
-
weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
|
102 |
-
if has_bias:
|
103 |
-
bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
104 |
-
else:
|
105 |
-
bias = None
|
106 |
-
conv_state_ref = conv_state.detach().clone()
|
107 |
-
activation = None if not silu_activation else "silu"
|
108 |
-
out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
|
109 |
-
out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation)
|
110 |
-
|
111 |
-
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
112 |
-
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
113 |
-
assert torch.equal(conv_state, conv_state_ref)
|
114 |
-
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
115 |
-
|
116 |
-
|
117 |
-
# @pytest.mark.parametrize("channel_last", [False, True])
|
118 |
-
@pytest.mark.parametrize('channel_last', [True])
|
119 |
-
# @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
120 |
-
@pytest.mark.parametrize('itype', [torch.bfloat16])
|
121 |
-
# @pytest.mark.parametrize("silu_activation", [False, True])
|
122 |
-
@pytest.mark.parametrize('silu_activation', [True])
|
123 |
-
# @pytest.mark.parametrize("has_bias", [False, True])
|
124 |
-
@pytest.mark.parametrize('has_bias', [True])
|
125 |
-
# @pytest.mark.parametrize("width", [2, 3, 4])
|
126 |
-
@pytest.mark.parametrize('width', [4])
|
127 |
-
@pytest.mark.parametrize(
|
128 |
-
# "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
|
129 |
-
"seqlen", [2048]
|
130 |
-
)
|
131 |
-
# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
|
132 |
-
# @pytest.mark.parametrize('seqlen', [128])
|
133 |
-
def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last):
|
134 |
-
device = "cuda"
|
135 |
-
# set seed
|
136 |
-
torch.random.manual_seed(0)
|
137 |
-
batch_size = 2
|
138 |
-
# batch_size = 1
|
139 |
-
dim = 4096 + 32 # Try dim not divisible by 64
|
140 |
-
# dim = 64
|
141 |
-
if not channel_last:
|
142 |
-
x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
|
143 |
-
else:
|
144 |
-
x = rearrange(
|
145 |
-
torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
|
146 |
-
).requires_grad_()
|
147 |
-
weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
|
148 |
-
if has_bias:
|
149 |
-
bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
150 |
-
else:
|
151 |
-
bias = None
|
152 |
-
activation = None if not silu_activation else "silu"
|
153 |
-
out0 = causal_conv1d_fn(x, weight, bias, activation=activation)
|
154 |
-
g = torch.randn_like(out0)
|
155 |
-
dx0, dw0, db0 = torch.autograd.grad(out0, (x, weight, bias), g)
|
156 |
-
dw_atol = 1e-4
|
157 |
-
db_atol = 1e-4
|
158 |
-
|
159 |
-
for i in range(10000):
|
160 |
-
out = causal_conv1d_fn(x, weight, bias, activation=activation)
|
161 |
-
dx, dw, db = torch.autograd.grad(out, (x, weight, bias), g)
|
162 |
-
dw_equal = torch.allclose(dw, dw0, atol=dw_atol)
|
163 |
-
# if not dw_equal:
|
164 |
-
# breakpoint()
|
165 |
-
if has_bias:
|
166 |
-
db_equal = torch.allclose(db, db0, atol=db_atol)
|
167 |
-
# if not db_equal:
|
168 |
-
# breakpoint()
|
169 |
-
assert torch.equal(out, out0)
|
170 |
-
assert torch.equal(dx, dx0)
|
171 |
-
assert dw_equal
|
172 |
-
if has_bias:
|
173 |
-
assert dw_equal
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
causal_conv1d-1.0.0-cp310-cp310-linux_x86_64.whl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:78328bff9f0cf4814aa3c4029d63aa75128694e07ddae688b16215e3d8a2e7e7
|
3 |
+
size 8424758
|
install.sh
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118
|
2 |
-
pip install -e causal-conv1d
|
3 |
-
pip install -e mamba
|
|
|
|
|
|
|
|
mamba/.gitmodules
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
[submodule "3rdparty/lm-evaluation-harness"]
|
2 |
-
path = 3rdparty/lm-evaluation-harness
|
3 |
-
url = https://github.com/EleutherAI/lm-evaluation-harness/
|
|
|
|
|
|
|
|
mamba/AUTHORS
DELETED
@@ -1,2 +0,0 @@
|
|
1 |
-
Tri Dao, tri@tridao.me
|
2 |
-
Albert Gu, agu@andrew.cmu.edu
|
|
|
|
|
|
mamba/LICENSE
DELETED
@@ -1,201 +0,0 @@
|
|
1 |
-
Apache License
|
2 |
-
Version 2.0, January 2004
|
3 |
-
http://www.apache.org/licenses/
|
4 |
-
|
5 |
-
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
-
|
7 |
-
1. Definitions.
|
8 |
-
|
9 |
-
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
-
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
-
|
12 |
-
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
-
the copyright owner that is granting the License.
|
14 |
-
|
15 |
-
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
-
other entities that control, are controlled by, or are under common
|
17 |
-
control with that entity. For the purposes of this definition,
|
18 |
-
"control" means (i) the power, direct or indirect, to cause the
|
19 |
-
direction or management of such entity, whether by contract or
|
20 |
-
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
-
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
-
|
23 |
-
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
-
exercising permissions granted by this License.
|
25 |
-
|
26 |
-
"Source" form shall mean the preferred form for making modifications,
|
27 |
-
including but not limited to software source code, documentation
|
28 |
-
source, and configuration files.
|
29 |
-
|
30 |
-
"Object" form shall mean any form resulting from mechanical
|
31 |
-
transformation or translation of a Source form, including but
|
32 |
-
not limited to compiled object code, generated documentation,
|
33 |
-
and conversions to other media types.
|
34 |
-
|
35 |
-
"Work" shall mean the work of authorship, whether in Source or
|
36 |
-
Object form, made available under the License, as indicated by a
|
37 |
-
copyright notice that is included in or attached to the work
|
38 |
-
(an example is provided in the Appendix below).
|
39 |
-
|
40 |
-
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
-
form, that is based on (or derived from) the Work and for which the
|
42 |
-
editorial revisions, annotations, elaborations, or other modifications
|
43 |
-
represent, as a whole, an original work of authorship. For the purposes
|
44 |
-
of this License, Derivative Works shall not include works that remain
|
45 |
-
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
-
the Work and Derivative Works thereof.
|
47 |
-
|
48 |
-
"Contribution" shall mean any work of authorship, including
|
49 |
-
the original version of the Work and any modifications or additions
|
50 |
-
to that Work or Derivative Works thereof, that is intentionally
|
51 |
-
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
-
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
-
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
-
means any form of electronic, verbal, or written communication sent
|
55 |
-
to the Licensor or its representatives, including but not limited to
|
56 |
-
communication on electronic mailing lists, source code control systems,
|
57 |
-
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
-
Licensor for the purpose of discussing and improving the Work, but
|
59 |
-
excluding communication that is conspicuously marked or otherwise
|
60 |
-
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
-
|
62 |
-
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
-
on behalf of whom a Contribution has been received by Licensor and
|
64 |
-
subsequently incorporated within the Work.
|
65 |
-
|
66 |
-
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
-
this License, each Contributor hereby grants to You a perpetual,
|
68 |
-
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
-
copyright license to reproduce, prepare Derivative Works of,
|
70 |
-
publicly display, publicly perform, sublicense, and distribute the
|
71 |
-
Work and such Derivative Works in Source or Object form.
|
72 |
-
|
73 |
-
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
-
this License, each Contributor hereby grants to You a perpetual,
|
75 |
-
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
-
(except as stated in this section) patent license to make, have made,
|
77 |
-
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
-
where such license applies only to those patent claims licensable
|
79 |
-
by such Contributor that are necessarily infringed by their
|
80 |
-
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
-
with the Work to which such Contribution(s) was submitted. If You
|
82 |
-
institute patent litigation against any entity (including a
|
83 |
-
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
-
or a Contribution incorporated within the Work constitutes direct
|
85 |
-
or contributory patent infringement, then any patent licenses
|
86 |
-
granted to You under this License for that Work shall terminate
|
87 |
-
as of the date such litigation is filed.
|
88 |
-
|
89 |
-
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
-
Work or Derivative Works thereof in any medium, with or without
|
91 |
-
modifications, and in Source or Object form, provided that You
|
92 |
-
meet the following conditions:
|
93 |
-
|
94 |
-
(a) You must give any other recipients of the Work or
|
95 |
-
Derivative Works a copy of this License; and
|
96 |
-
|
97 |
-
(b) You must cause any modified files to carry prominent notices
|
98 |
-
stating that You changed the files; and
|
99 |
-
|
100 |
-
(c) You must retain, in the Source form of any Derivative Works
|
101 |
-
that You distribute, all copyright, patent, trademark, and
|
102 |
-
attribution notices from the Source form of the Work,
|
103 |
-
excluding those notices that do not pertain to any part of
|
104 |
-
the Derivative Works; and
|
105 |
-
|
106 |
-
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
-
distribution, then any Derivative Works that You distribute must
|
108 |
-
include a readable copy of the attribution notices contained
|
109 |
-
within such NOTICE file, excluding those notices that do not
|
110 |
-
pertain to any part of the Derivative Works, in at least one
|
111 |
-
of the following places: within a NOTICE text file distributed
|
112 |
-
as part of the Derivative Works; within the Source form or
|
113 |
-
documentation, if provided along with the Derivative Works; or,
|
114 |
-
within a display generated by the Derivative Works, if and
|
115 |
-
wherever such third-party notices normally appear. The contents
|
116 |
-
of the NOTICE file are for informational purposes only and
|
117 |
-
do not modify the License. You may add Your own attribution
|
118 |
-
notices within Derivative Works that You distribute, alongside
|
119 |
-
or as an addendum to the NOTICE text from the Work, provided
|
120 |
-
that such additional attribution notices cannot be construed
|
121 |
-
as modifying the License.
|
122 |
-
|
123 |
-
You may add Your own copyright statement to Your modifications and
|
124 |
-
may provide additional or different license terms and conditions
|
125 |
-
for use, reproduction, or distribution of Your modifications, or
|
126 |
-
for any such Derivative Works as a whole, provided Your use,
|
127 |
-
reproduction, and distribution of the Work otherwise complies with
|
128 |
-
the conditions stated in this License.
|
129 |
-
|
130 |
-
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
-
any Contribution intentionally submitted for inclusion in the Work
|
132 |
-
by You to the Licensor shall be under the terms and conditions of
|
133 |
-
this License, without any additional terms or conditions.
|
134 |
-
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
-
the terms of any separate license agreement you may have executed
|
136 |
-
with Licensor regarding such Contributions.
|
137 |
-
|
138 |
-
6. Trademarks. This License does not grant permission to use the trade
|
139 |
-
names, trademarks, service marks, or product names of the Licensor,
|
140 |
-
except as required for reasonable and customary use in describing the
|
141 |
-
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
-
|
143 |
-
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
-
agreed to in writing, Licensor provides the Work (and each
|
145 |
-
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
-
implied, including, without limitation, any warranties or conditions
|
148 |
-
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
-
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
-
appropriateness of using or redistributing the Work and assume any
|
151 |
-
risks associated with Your exercise of permissions under this License.
|
152 |
-
|
153 |
-
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
-
whether in tort (including negligence), contract, or otherwise,
|
155 |
-
unless required by applicable law (such as deliberate and grossly
|
156 |
-
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
-
liable to You for damages, including any direct, indirect, special,
|
158 |
-
incidental, or consequential damages of any character arising as a
|
159 |
-
result of this License or out of the use or inability to use the
|
160 |
-
Work (including but not limited to damages for loss of goodwill,
|
161 |
-
work stoppage, computer failure or malfunction, or any and all
|
162 |
-
other commercial damages or losses), even if such Contributor
|
163 |
-
has been advised of the possibility of such damages.
|
164 |
-
|
165 |
-
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
-
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
-
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
-
or other liability obligations and/or rights consistent with this
|
169 |
-
License. However, in accepting such obligations, You may act only
|
170 |
-
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
-
of any other Contributor, and only if You agree to indemnify,
|
172 |
-
defend, and hold each Contributor harmless for any liability
|
173 |
-
incurred by, or claims asserted against, such Contributor by reason
|
174 |
-
of your accepting any such warranty or additional liability.
|
175 |
-
|
176 |
-
END OF TERMS AND CONDITIONS
|
177 |
-
|
178 |
-
APPENDIX: How to apply the Apache License to your work.
|
179 |
-
|
180 |
-
To apply the Apache License to your work, attach the following
|
181 |
-
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
-
replaced with your own identifying information. (Don't include
|
183 |
-
the brackets!) The text should be enclosed in the appropriate
|
184 |
-
comment syntax for the file format. We also recommend that a
|
185 |
-
file or class name and description of purpose be included on the
|
186 |
-
same "printed page" as the copyright notice for easier
|
187 |
-
identification within third-party archives.
|
188 |
-
|
189 |
-
Copyright 2023 Tri Dao, Albert Gu
|
190 |
-
|
191 |
-
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
-
you may not use this file except in compliance with the License.
|
193 |
-
You may obtain a copy of the License at
|
194 |
-
|
195 |
-
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
-
|
197 |
-
Unless required by applicable law or agreed to in writing, software
|
198 |
-
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
-
See the License for the specific language governing permissions and
|
201 |
-
limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/README.md
DELETED
@@ -1,149 +0,0 @@
|
|
1 |
-
# Mamba
|
2 |
-
|
3 |
-
![Mamba](assets/selection.png "Selective State Space")
|
4 |
-
> **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\
|
5 |
-
> Albert Gu*, Tri Dao*\
|
6 |
-
> Paper: https://arxiv.org/abs/2312.00752
|
7 |
-
|
8 |
-
## About
|
9 |
-
|
10 |
-
Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers.
|
11 |
-
It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4),
|
12 |
-
with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
|
13 |
-
|
14 |
-
## Installation
|
15 |
-
|
16 |
-
- `pip install causal-conv1d`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
|
17 |
-
- `pip install mamba-ssm`: the core Mamba package.
|
18 |
-
|
19 |
-
It can also be built from source with `pip install .` from this repository.
|
20 |
-
|
21 |
-
If `pip` complains about PyTorch versions, try passing `--no-build-isolation` to `pip`.
|
22 |
-
|
23 |
-
Other requirements:
|
24 |
-
- Linux
|
25 |
-
- NVIDIA GPU
|
26 |
-
- PyTorch 1.12+
|
27 |
-
- CUDA 11.6+
|
28 |
-
|
29 |
-
## Usage
|
30 |
-
|
31 |
-
We expose several levels of interface with the Mamba model.
|
32 |
-
|
33 |
-
### Selective SSM
|
34 |
-
|
35 |
-
Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2).
|
36 |
-
|
37 |
-
Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py).
|
38 |
-
|
39 |
-
### Mamba Block
|
40 |
-
|
41 |
-
The main module of this repository is the Mamba architecture block wrapping the selective SSM.
|
42 |
-
|
43 |
-
Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py).
|
44 |
-
|
45 |
-
Usage:
|
46 |
-
```
|
47 |
-
from mamba_ssm import Mamba
|
48 |
-
|
49 |
-
batch, length, dim = 2, 64, 16
|
50 |
-
x = torch.randn(batch, length, dim).to("cuda")
|
51 |
-
model = Mamba(
|
52 |
-
# This module uses roughly 3 * expand * d_model^2 parameters
|
53 |
-
d_model=dim, # Model dimension d_model
|
54 |
-
d_state=16, # SSM state expansion factor
|
55 |
-
d_conv=4, # Local convolution width
|
56 |
-
expand=2, # Block expansion factor
|
57 |
-
).to("cuda")
|
58 |
-
y = model(x)
|
59 |
-
assert y.shape == x.shape
|
60 |
-
```
|
61 |
-
|
62 |
-
### Mamba Language Model
|
63 |
-
|
64 |
-
Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head.
|
65 |
-
|
66 |
-
Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py).
|
67 |
-
|
68 |
-
This is an example of how to integrate Mamba into an end-to-end neural network.
|
69 |
-
This example is used in the generation scripts below.
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
## Pretrained Models
|
74 |
-
|
75 |
-
Pretrained models are uploaded to
|
76 |
-
[HuggingFace](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`,
|
77 |
-
`mamba-790m`, `mamba-1.4b`, `mamba-2.8b`.
|
78 |
-
|
79 |
-
The models will be autodownloaded by the generation script below.
|
80 |
-
|
81 |
-
These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and follow the standard model dimensions described by GPT-3 and followed by many open source models:
|
82 |
-
|
83 |
-
| Parameters | Layers | Model dim. |
|
84 |
-
|------------|--------|------------|
|
85 |
-
| 130M | 12 | 768 |
|
86 |
-
| 370M | 24 | 1024 |
|
87 |
-
| 790M | 24 | 1536 |
|
88 |
-
| 1.4B | 24 | 2048 |
|
89 |
-
| 2.8B | 32 | 2560 |
|
90 |
-
|
91 |
-
(The layer count of Mamba should be doubled, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.)
|
92 |
-
|
93 |
-
Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.).
|
94 |
-
Performance is expected to be comparable or better than other architectures trained on similar data, but not to match larger or fine-tuned models.
|
95 |
-
|
96 |
-
|
97 |
-
## Evaluations
|
98 |
-
|
99 |
-
To run zero-shot evaluations of models (corresponding to Table 3 of the paper),
|
100 |
-
we use the
|
101 |
-
[lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor)
|
102 |
-
library.
|
103 |
-
|
104 |
-
1. Pull the `lm-evaluation-harness` repo by `git submodule update --init
|
105 |
-
--recursive`. We use the `big-refactor` branch.
|
106 |
-
2. Install `lm-evaluation-harness`: `pip install -e 3rdparty/lm-evaluation-harness`
|
107 |
-
3. Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo):
|
108 |
-
```
|
109 |
-
python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
|
110 |
-
python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
|
111 |
-
```
|
112 |
-
|
113 |
-
Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process.
|
114 |
-
|
115 |
-
## Inference
|
116 |
-
|
117 |
-
The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py)
|
118 |
-
1. autoloads a model from the HuggingFace Hub,
|
119 |
-
2. generates completions of a user-specified prompt,
|
120 |
-
3. benchmarks the inference speed of this generation.
|
121 |
-
|
122 |
-
Other configurable options include the top-p (nucleus sampling) probability, and the softmax temperature.
|
123 |
-
|
124 |
-
### Examples
|
125 |
-
|
126 |
-
To test generation latency (e.g. batch size = 1) with different sampling strategies:
|
127 |
-
|
128 |
-
```
|
129 |
-
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
|
130 |
-
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
|
131 |
-
```
|
132 |
-
|
133 |
-
To test generation throughput with random prompts (e.g. large batch size):
|
134 |
-
```
|
135 |
-
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128
|
136 |
-
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128
|
137 |
-
```
|
138 |
-
|
139 |
-
## Citation
|
140 |
-
|
141 |
-
If you use this codebase, or otherwise found our work valuable, please cite Mamba:
|
142 |
-
```
|
143 |
-
@article{mamba,
|
144 |
-
title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
|
145 |
-
author={Gu, Albert and Dao, Tri},
|
146 |
-
journal={arXiv preprint arXiv:2312.00752},
|
147 |
-
year={2023}
|
148 |
-
}
|
149 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/assets/selection.png
DELETED
Binary file (819 kB)
|
|
mamba/benchmarks/benchmark_generation_mamba_simple.py
DELETED
@@ -1,88 +0,0 @@
|
|
1 |
-
# Copyright (c) 2023, Tri Dao, Albert Gu.
|
2 |
-
|
3 |
-
import argparse
|
4 |
-
import time
|
5 |
-
import json
|
6 |
-
|
7 |
-
import torch
|
8 |
-
import torch.nn.functional as F
|
9 |
-
|
10 |
-
from einops import rearrange
|
11 |
-
|
12 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM
|
13 |
-
|
14 |
-
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
|
15 |
-
|
16 |
-
|
17 |
-
parser = argparse.ArgumentParser(description="Generation benchmarking")
|
18 |
-
parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m")
|
19 |
-
parser.add_argument("--prompt", type=str, default=None)
|
20 |
-
parser.add_argument("--promptlen", type=int, default=100)
|
21 |
-
parser.add_argument("--genlen", type=int, default=100)
|
22 |
-
parser.add_argument("--temperature", type=float, default=1.0)
|
23 |
-
parser.add_argument("--topk", type=int, default=1)
|
24 |
-
parser.add_argument("--topp", type=float, default=1.0)
|
25 |
-
parser.add_argument("--batch", type=int, default=1)
|
26 |
-
args = parser.parse_args()
|
27 |
-
|
28 |
-
repeats = 3
|
29 |
-
device = "cuda"
|
30 |
-
dtype = torch.float16
|
31 |
-
|
32 |
-
print(f"Loading model {args.model_name}")
|
33 |
-
is_mamba = args.model_name.startswith("state-spaces/mamba-") or "mamba" in args.model_name
|
34 |
-
|
35 |
-
if is_mamba:
|
36 |
-
tokenizer = AutoTokenizer.from_pretrained("/home/zhulianghui/VisionProjects/mamba/ckpts/gpt-neox-20b-tokenizer")
|
37 |
-
model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype)
|
38 |
-
else:
|
39 |
-
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
40 |
-
model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype)
|
41 |
-
model.eval()
|
42 |
-
print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
43 |
-
|
44 |
-
torch.random.manual_seed(0)
|
45 |
-
if args.prompt is None:
|
46 |
-
input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda")
|
47 |
-
attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda")
|
48 |
-
else:
|
49 |
-
tokens = tokenizer(args.prompt, return_tensors="pt")
|
50 |
-
input_ids = tokens.input_ids.to(device=device)
|
51 |
-
attn_mask = tokens.attention_mask.to(device=device)
|
52 |
-
max_length = input_ids.shape[1] + args.genlen
|
53 |
-
|
54 |
-
if is_mamba:
|
55 |
-
fn = lambda: model.generate(
|
56 |
-
input_ids=input_ids,
|
57 |
-
max_length=max_length,
|
58 |
-
cg=True,
|
59 |
-
return_dict_in_generate=True,
|
60 |
-
output_scores=True,
|
61 |
-
enable_timing=False,
|
62 |
-
temperature=args.temperature,
|
63 |
-
top_k=args.topk,
|
64 |
-
top_p=args.topp,
|
65 |
-
)
|
66 |
-
else:
|
67 |
-
fn = lambda: model.generate(
|
68 |
-
input_ids=input_ids,
|
69 |
-
attention_mask=attn_mask,
|
70 |
-
max_length=max_length,
|
71 |
-
return_dict_in_generate=True,
|
72 |
-
pad_token_id=tokenizer.eos_token_id,
|
73 |
-
do_sample=True,
|
74 |
-
temperature=args.temperature,
|
75 |
-
top_k=args.topk,
|
76 |
-
top_p=args.topp,
|
77 |
-
)
|
78 |
-
out = fn()
|
79 |
-
if args.prompt is not None:
|
80 |
-
print(tokenizer.batch_decode(out.sequences.tolist()))
|
81 |
-
|
82 |
-
torch.cuda.synchronize()
|
83 |
-
start = time.time()
|
84 |
-
for _ in range(repeats):
|
85 |
-
fn()
|
86 |
-
torch.cuda.synchronize()
|
87 |
-
print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}")
|
88 |
-
print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/csrc/selective_scan/reverse_scan.cuh
DELETED
@@ -1,401 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2023, Tri Dao.
|
3 |
-
******************************************************************************/
|
4 |
-
|
5 |
-
#pragma once
|
6 |
-
|
7 |
-
#include <cub/config.cuh>
|
8 |
-
|
9 |
-
#include <cub/util_ptx.cuh>
|
10 |
-
#include <cub/util_type.cuh>
|
11 |
-
#include <cub/block/block_raking_layout.cuh>
|
12 |
-
// #include <cub/detail/uninitialized_copy.cuh>
|
13 |
-
#include "uninitialized_copy.cuh"
|
14 |
-
|
15 |
-
/**
|
16 |
-
* Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned.
|
17 |
-
*/
|
18 |
-
template <
|
19 |
-
int LENGTH,
|
20 |
-
typename T,
|
21 |
-
typename ReductionOp>
|
22 |
-
__device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) {
|
23 |
-
static_assert(LENGTH > 0);
|
24 |
-
T retval = input[LENGTH - 1];
|
25 |
-
#pragma unroll
|
26 |
-
for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); }
|
27 |
-
return retval;
|
28 |
-
}
|
29 |
-
|
30 |
-
/**
|
31 |
-
* Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
|
32 |
-
*/
|
33 |
-
template <
|
34 |
-
int LENGTH,
|
35 |
-
typename T,
|
36 |
-
typename ScanOp>
|
37 |
-
__device__ __forceinline__ T ThreadReverseScanInclusive(
|
38 |
-
const T (&input)[LENGTH],
|
39 |
-
T (&output)[LENGTH],
|
40 |
-
ScanOp scan_op,
|
41 |
-
const T postfix)
|
42 |
-
{
|
43 |
-
T inclusive = postfix;
|
44 |
-
#pragma unroll
|
45 |
-
for (int i = LENGTH - 1; i >= 0; --i) {
|
46 |
-
inclusive = scan_op(inclusive, input[i]);
|
47 |
-
output[i] = inclusive;
|
48 |
-
}
|
49 |
-
}
|
50 |
-
|
51 |
-
/**
|
52 |
-
* Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
|
53 |
-
*/
|
54 |
-
template <
|
55 |
-
int LENGTH,
|
56 |
-
typename T,
|
57 |
-
typename ScanOp>
|
58 |
-
__device__ __forceinline__ T ThreadReverseScanExclusive(
|
59 |
-
const T (&input)[LENGTH],
|
60 |
-
T (&output)[LENGTH],
|
61 |
-
ScanOp scan_op,
|
62 |
-
const T postfix)
|
63 |
-
{
|
64 |
-
// Careful, output maybe be aliased to input
|
65 |
-
T exclusive = postfix;
|
66 |
-
T inclusive;
|
67 |
-
#pragma unroll
|
68 |
-
for (int i = LENGTH - 1; i >= 0; --i) {
|
69 |
-
inclusive = scan_op(exclusive, input[i]);
|
70 |
-
output[i] = exclusive;
|
71 |
-
exclusive = inclusive;
|
72 |
-
}
|
73 |
-
return inclusive;
|
74 |
-
}
|
75 |
-
|
76 |
-
|
77 |
-
/**
|
78 |
-
* \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp.
|
79 |
-
*
|
80 |
-
* LOGICAL_WARP_THREADS must be a power-of-two
|
81 |
-
*/
|
82 |
-
template <
|
83 |
-
typename T, ///< Data type being scanned
|
84 |
-
int LOGICAL_WARP_THREADS ///< Number of threads per logical warp
|
85 |
-
>
|
86 |
-
struct WarpReverseScan {
|
87 |
-
//---------------------------------------------------------------------
|
88 |
-
// Constants and type definitions
|
89 |
-
//---------------------------------------------------------------------
|
90 |
-
|
91 |
-
/// Whether the logical warp size and the PTX warp size coincide
|
92 |
-
static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0));
|
93 |
-
/// The number of warp scan steps
|
94 |
-
static constexpr int STEPS = cub::Log2<LOGICAL_WARP_THREADS>::VALUE;
|
95 |
-
static_assert(LOGICAL_WARP_THREADS == 1 << STEPS);
|
96 |
-
|
97 |
-
|
98 |
-
//---------------------------------------------------------------------
|
99 |
-
// Thread fields
|
100 |
-
//---------------------------------------------------------------------
|
101 |
-
|
102 |
-
/// Lane index in logical warp
|
103 |
-
unsigned int lane_id;
|
104 |
-
|
105 |
-
/// Logical warp index in 32-thread physical warp
|
106 |
-
unsigned int warp_id;
|
107 |
-
|
108 |
-
/// 32-thread physical warp member mask of logical warp
|
109 |
-
unsigned int member_mask;
|
110 |
-
|
111 |
-
//---------------------------------------------------------------------
|
112 |
-
// Construction
|
113 |
-
//---------------------------------------------------------------------
|
114 |
-
|
115 |
-
/// Constructor
|
116 |
-
explicit __device__ __forceinline__
|
117 |
-
WarpReverseScan()
|
118 |
-
: lane_id(cub::LaneId())
|
119 |
-
, warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS))
|
120 |
-
, member_mask(cub::WarpMask<LOGICAL_WARP_THREADS>(warp_id))
|
121 |
-
{
|
122 |
-
if (!IS_ARCH_WARP) {
|
123 |
-
lane_id = lane_id % LOGICAL_WARP_THREADS;
|
124 |
-
}
|
125 |
-
}
|
126 |
-
|
127 |
-
|
128 |
-
/// Broadcast
|
129 |
-
__device__ __forceinline__ T Broadcast(
|
130 |
-
T input, ///< [in] The value to broadcast
|
131 |
-
int src_lane) ///< [in] Which warp lane is to do the broadcasting
|
132 |
-
{
|
133 |
-
return cub::ShuffleIndex<LOGICAL_WARP_THREADS>(input, src_lane, member_mask);
|
134 |
-
}
|
135 |
-
|
136 |
-
|
137 |
-
/// Inclusive scan
|
138 |
-
template <typename ScanOpT>
|
139 |
-
__device__ __forceinline__ void InclusiveReverseScan(
|
140 |
-
T input, ///< [in] Calling thread's input item.
|
141 |
-
T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input.
|
142 |
-
ScanOpT scan_op) ///< [in] Binary scan operator
|
143 |
-
{
|
144 |
-
inclusive_output = input;
|
145 |
-
#pragma unroll
|
146 |
-
for (int STEP = 0; STEP < STEPS; STEP++) {
|
147 |
-
int offset = 1 << STEP;
|
148 |
-
T temp = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
|
149 |
-
inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask
|
150 |
-
);
|
151 |
-
// Perform scan op if from a valid peer
|
152 |
-
inclusive_output = static_cast<int>(lane_id) >= LOGICAL_WARP_THREADS - offset
|
153 |
-
? inclusive_output : scan_op(temp, inclusive_output);
|
154 |
-
}
|
155 |
-
}
|
156 |
-
|
157 |
-
/// Exclusive scan
|
158 |
-
// Get exclusive from inclusive
|
159 |
-
template <typename ScanOpT>
|
160 |
-
__device__ __forceinline__ void ExclusiveReverseScan(
|
161 |
-
T input, ///< [in] Calling thread's input item.
|
162 |
-
T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input.
|
163 |
-
ScanOpT scan_op, ///< [in] Binary scan operator
|
164 |
-
T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items.
|
165 |
-
{
|
166 |
-
T inclusive_output;
|
167 |
-
InclusiveReverseScan(input, inclusive_output, scan_op);
|
168 |
-
warp_aggregate = cub::ShuffleIndex<LOGICAL_WARP_THREADS>(inclusive_output, 0, member_mask);
|
169 |
-
// initial value unknown
|
170 |
-
exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
|
171 |
-
inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
|
172 |
-
);
|
173 |
-
}
|
174 |
-
|
175 |
-
/**
|
176 |
-
* \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last <em>warp-lane</em> is undefined.
|
177 |
-
*/
|
178 |
-
template <typename ScanOpT>
|
179 |
-
__device__ __forceinline__ void ReverseScan(
|
180 |
-
T input, ///< [in] Calling thread's input item.
|
181 |
-
T &inclusive_output, ///< [out] Calling thread's inclusive-scan output item.
|
182 |
-
T &exclusive_output, ///< [out] Calling thread's exclusive-scan output item.
|
183 |
-
ScanOpT scan_op) ///< [in] Binary scan operator
|
184 |
-
{
|
185 |
-
InclusiveReverseScan(input, inclusive_output, scan_op);
|
186 |
-
// initial value unknown
|
187 |
-
exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
|
188 |
-
inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
|
189 |
-
);
|
190 |
-
}
|
191 |
-
|
192 |
-
};
|
193 |
-
|
194 |
-
/**
|
195 |
-
* \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block.
|
196 |
-
*/
|
197 |
-
template <
|
198 |
-
typename T, ///< Data type being scanned
|
199 |
-
int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension
|
200 |
-
bool MEMOIZE=false ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure
|
201 |
-
>
|
202 |
-
struct BlockReverseScan {
|
203 |
-
//---------------------------------------------------------------------
|
204 |
-
// Types and constants
|
205 |
-
//---------------------------------------------------------------------
|
206 |
-
|
207 |
-
/// Constants
|
208 |
-
/// The thread block size in threads
|
209 |
-
static constexpr int BLOCK_THREADS = BLOCK_DIM_X;
|
210 |
-
|
211 |
-
/// Layout type for padded thread block raking grid
|
212 |
-
using BlockRakingLayout = cub::BlockRakingLayout<T, BLOCK_THREADS>;
|
213 |
-
// The number of reduction elements is not a multiple of the number of raking threads for now
|
214 |
-
static_assert(BlockRakingLayout::UNGUARDED);
|
215 |
-
|
216 |
-
/// Number of raking threads
|
217 |
-
static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS;
|
218 |
-
/// Number of raking elements per warp synchronous raking thread
|
219 |
-
static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH;
|
220 |
-
/// Cooperative work can be entirely warp synchronous
|
221 |
-
static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS));
|
222 |
-
|
223 |
-
/// WarpReverseScan utility type
|
224 |
-
using WarpReverseScan = WarpReverseScan<T, RAKING_THREADS>;
|
225 |
-
|
226 |
-
/// Shared memory storage layout type
|
227 |
-
struct _TempStorage {
|
228 |
-
typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid
|
229 |
-
};
|
230 |
-
|
231 |
-
|
232 |
-
/// Alias wrapper allowing storage to be unioned
|
233 |
-
struct TempStorage : cub::Uninitialized<_TempStorage> {};
|
234 |
-
|
235 |
-
|
236 |
-
//---------------------------------------------------------------------
|
237 |
-
// Per-thread fields
|
238 |
-
//---------------------------------------------------------------------
|
239 |
-
|
240 |
-
// Thread fields
|
241 |
-
_TempStorage &temp_storage;
|
242 |
-
unsigned int linear_tid;
|
243 |
-
T cached_segment[SEGMENT_LENGTH];
|
244 |
-
|
245 |
-
|
246 |
-
//---------------------------------------------------------------------
|
247 |
-
// Utility methods
|
248 |
-
//---------------------------------------------------------------------
|
249 |
-
|
250 |
-
/// Performs upsweep raking reduction, returning the aggregate
|
251 |
-
template <typename ScanOp>
|
252 |
-
__device__ __forceinline__ T Upsweep(ScanOp scan_op) {
|
253 |
-
T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
|
254 |
-
// Read data into registers
|
255 |
-
#pragma unroll
|
256 |
-
for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
|
257 |
-
T raking_partial = cached_segment[SEGMENT_LENGTH - 1];
|
258 |
-
#pragma unroll
|
259 |
-
for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) {
|
260 |
-
raking_partial = scan_op(raking_partial, cached_segment[i]);
|
261 |
-
}
|
262 |
-
return raking_partial;
|
263 |
-
}
|
264 |
-
|
265 |
-
|
266 |
-
/// Performs exclusive downsweep raking scan
|
267 |
-
template <typename ScanOp>
|
268 |
-
__device__ __forceinline__ void ExclusiveDownsweep(
|
269 |
-
ScanOp scan_op,
|
270 |
-
T raking_partial)
|
271 |
-
{
|
272 |
-
T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
|
273 |
-
// Read data back into registers
|
274 |
-
if (!MEMOIZE) {
|
275 |
-
#pragma unroll
|
276 |
-
for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
|
277 |
-
}
|
278 |
-
ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial);
|
279 |
-
// Write data back to smem
|
280 |
-
#pragma unroll
|
281 |
-
for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; }
|
282 |
-
}
|
283 |
-
|
284 |
-
|
285 |
-
//---------------------------------------------------------------------
|
286 |
-
// Constructors
|
287 |
-
//---------------------------------------------------------------------
|
288 |
-
|
289 |
-
/// Constructor
|
290 |
-
__device__ __forceinline__ BlockReverseScan(
|
291 |
-
TempStorage &temp_storage)
|
292 |
-
:
|
293 |
-
temp_storage(temp_storage.Alias()),
|
294 |
-
linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1))
|
295 |
-
{}
|
296 |
-
|
297 |
-
|
298 |
-
/// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
|
299 |
-
template <
|
300 |
-
typename ScanOp,
|
301 |
-
typename BlockPostfixCallbackOp>
|
302 |
-
__device__ __forceinline__ void ExclusiveReverseScan(
|
303 |
-
T input, ///< [in] Calling thread's input item
|
304 |
-
T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input)
|
305 |
-
ScanOp scan_op, ///< [in] Binary scan operator
|
306 |
-
BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a thread block-wide postfix to be applied to all inputs.
|
307 |
-
{
|
308 |
-
if (WARP_SYNCHRONOUS) {
|
309 |
-
// Short-circuit directly to warp-synchronous scan
|
310 |
-
T block_aggregate;
|
311 |
-
WarpReverseScan warp_scan;
|
312 |
-
warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate);
|
313 |
-
// Obtain warp-wide postfix in lane0, then broadcast to other lanes
|
314 |
-
T block_postfix = block_postfix_callback_op(block_aggregate);
|
315 |
-
block_postfix = warp_scan.Broadcast(block_postfix, 0);
|
316 |
-
exclusive_output = linear_tid == BLOCK_THREADS - 1 ? block_postfix : scan_op(block_postfix, exclusive_output);
|
317 |
-
} else {
|
318 |
-
// Place thread partial into shared memory raking grid
|
319 |
-
T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid);
|
320 |
-
detail::uninitialized_copy(placement_ptr, input);
|
321 |
-
cub::CTA_SYNC();
|
322 |
-
// Reduce parallelism down to just raking threads
|
323 |
-
if (linear_tid < RAKING_THREADS) {
|
324 |
-
WarpReverseScan warp_scan;
|
325 |
-
// Raking upsweep reduction across shared partials
|
326 |
-
T upsweep_partial = Upsweep(scan_op);
|
327 |
-
// Warp-synchronous scan
|
328 |
-
T exclusive_partial, block_aggregate;
|
329 |
-
warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate);
|
330 |
-
// Obtain block-wide postfix in lane0, then broadcast to other lanes
|
331 |
-
T block_postfix = block_postfix_callback_op(block_aggregate);
|
332 |
-
block_postfix = warp_scan.Broadcast(block_postfix, 0);
|
333 |
-
// Update postfix with warpscan exclusive partial
|
334 |
-
T downsweep_postfix = linear_tid == RAKING_THREADS - 1
|
335 |
-
? block_postfix : scan_op(block_postfix, exclusive_partial);
|
336 |
-
// Exclusive raking downsweep scan
|
337 |
-
ExclusiveDownsweep(scan_op, downsweep_postfix);
|
338 |
-
}
|
339 |
-
cub::CTA_SYNC();
|
340 |
-
// Grab thread postfix from shared memory
|
341 |
-
exclusive_output = *placement_ptr;
|
342 |
-
|
343 |
-
// // Compute warp scan in each warp.
|
344 |
-
// // The exclusive output from the last lane in each warp is invalid.
|
345 |
-
// T inclusive_output;
|
346 |
-
// WarpReverseScan warp_scan;
|
347 |
-
// warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op);
|
348 |
-
|
349 |
-
// // Compute the warp-wide postfix and block-wide aggregate for each warp. Warp postfix for the last warp is invalid.
|
350 |
-
// T block_aggregate;
|
351 |
-
// T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate);
|
352 |
-
|
353 |
-
// // Apply warp postfix to our lane's partial
|
354 |
-
// if (warp_id != 0) {
|
355 |
-
// exclusive_output = scan_op(warp_postfix, exclusive_output);
|
356 |
-
// if (lane_id == 0) { exclusive_output = warp_postfix; }
|
357 |
-
// }
|
358 |
-
|
359 |
-
// // Use the first warp to determine the thread block postfix, returning the result in lane0
|
360 |
-
// if (warp_id == 0) {
|
361 |
-
// T block_postfix = block_postfix_callback_op(block_aggregate);
|
362 |
-
// if (lane_id == 0) {
|
363 |
-
// // Share the postfix with all threads
|
364 |
-
// detail::uninitialized_copy(&temp_storage.block_postfix,
|
365 |
-
// block_postfix);
|
366 |
-
|
367 |
-
// exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0
|
368 |
-
// }
|
369 |
-
// }
|
370 |
-
|
371 |
-
// cub::CTA_SYNC();
|
372 |
-
|
373 |
-
// // Incorporate thread block postfix into outputs
|
374 |
-
// T block_postfix = temp_storage.block_postfix;
|
375 |
-
// if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); }
|
376 |
-
}
|
377 |
-
}
|
378 |
-
|
379 |
-
|
380 |
-
/**
|
381 |
-
* \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
|
382 |
-
*/
|
383 |
-
template <
|
384 |
-
int ITEMS_PER_THREAD,
|
385 |
-
typename ScanOp,
|
386 |
-
typename BlockPostfixCallbackOp>
|
387 |
-
__device__ __forceinline__ void InclusiveReverseScan(
|
388 |
-
T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items
|
389 |
-
T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input)
|
390 |
-
ScanOp scan_op, ///< [in] Binary scan functor
|
391 |
-
BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence.
|
392 |
-
{
|
393 |
-
// Reduce consecutive thread items in registers
|
394 |
-
T thread_postfix = ThreadReverseReduce(input, scan_op);
|
395 |
-
// Exclusive thread block-scan
|
396 |
-
ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op);
|
397 |
-
// Inclusive scan in registers with postfix as seed
|
398 |
-
ThreadReverseScanInclusive(input, output, scan_op, thread_postfix);
|
399 |
-
}
|
400 |
-
|
401 |
-
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/csrc/selective_scan/selective_scan.cpp
DELETED
@@ -1,497 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2023, Tri Dao.
|
3 |
-
******************************************************************************/
|
4 |
-
|
5 |
-
#include <ATen/cuda/CUDAContext.h>
|
6 |
-
#include <c10/cuda/CUDAGuard.h>
|
7 |
-
#include <torch/extension.h>
|
8 |
-
#include <vector>
|
9 |
-
|
10 |
-
#include "selective_scan.h"
|
11 |
-
|
12 |
-
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
13 |
-
|
14 |
-
#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
15 |
-
if (ITYPE == at::ScalarType::Half) { \
|
16 |
-
using input_t = at::Half; \
|
17 |
-
__VA_ARGS__(); \
|
18 |
-
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
19 |
-
using input_t = at::BFloat16; \
|
20 |
-
__VA_ARGS__(); \
|
21 |
-
} else if (ITYPE == at::ScalarType::Float) { \
|
22 |
-
using input_t = float; \
|
23 |
-
__VA_ARGS__(); \
|
24 |
-
} else { \
|
25 |
-
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
26 |
-
}
|
27 |
-
|
28 |
-
#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
|
29 |
-
if (WTYPE == at::ScalarType::Half) { \
|
30 |
-
using weight_t = at::Half; \
|
31 |
-
__VA_ARGS__(); \
|
32 |
-
} else if (WTYPE == at::ScalarType::BFloat16) { \
|
33 |
-
using weight_t = at::BFloat16; \
|
34 |
-
__VA_ARGS__(); \
|
35 |
-
} else if (WTYPE == at::ScalarType::Float) { \
|
36 |
-
using weight_t = float; \
|
37 |
-
__VA_ARGS__(); \
|
38 |
-
} else { \
|
39 |
-
AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
|
40 |
-
}
|
41 |
-
|
42 |
-
#define DISPATCH_WTYPE_FLOAT_AND_COMPLEX(WTYPE, NAME, ...) \
|
43 |
-
if (WTYPE == at::ScalarType::Float) { \
|
44 |
-
using weight_t = float; \
|
45 |
-
__VA_ARGS__(); \
|
46 |
-
} else if (WTYPE == at::ScalarType::ComplexFloat) { \
|
47 |
-
using weight_t = c10::complex<float>; \
|
48 |
-
__VA_ARGS__(); \
|
49 |
-
} else { \
|
50 |
-
AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
|
51 |
-
}
|
52 |
-
|
53 |
-
template<typename input_t, typename weight_t>
|
54 |
-
void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream);
|
55 |
-
|
56 |
-
template <typename input_t, typename weight_t>
|
57 |
-
void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream);
|
58 |
-
|
59 |
-
void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
60 |
-
// sizes
|
61 |
-
const size_t batch,
|
62 |
-
const size_t dim,
|
63 |
-
const size_t seqlen,
|
64 |
-
const size_t dstate,
|
65 |
-
const size_t n_groups,
|
66 |
-
const size_t n_chunks,
|
67 |
-
const bool is_variable_B,
|
68 |
-
const bool is_variable_C,
|
69 |
-
// device pointers
|
70 |
-
const at::Tensor u,
|
71 |
-
const at::Tensor delta,
|
72 |
-
const at::Tensor A,
|
73 |
-
const at::Tensor B,
|
74 |
-
const at::Tensor C,
|
75 |
-
const at::Tensor out,
|
76 |
-
const at::Tensor z,
|
77 |
-
const at::Tensor out_z,
|
78 |
-
void* D_ptr,
|
79 |
-
void* delta_bias_ptr,
|
80 |
-
void* x_ptr,
|
81 |
-
bool has_z,
|
82 |
-
bool delta_softplus) {
|
83 |
-
|
84 |
-
// Reset the parameters
|
85 |
-
memset(¶ms, 0, sizeof(params));
|
86 |
-
|
87 |
-
params.batch = batch;
|
88 |
-
params.dim = dim;
|
89 |
-
params.seqlen = seqlen;
|
90 |
-
params.dstate = dstate;
|
91 |
-
params.n_groups = n_groups;
|
92 |
-
params.n_chunks = n_chunks;
|
93 |
-
params.dim_ngroups_ratio = dim / n_groups;
|
94 |
-
|
95 |
-
params.delta_softplus = delta_softplus;
|
96 |
-
|
97 |
-
params.is_variable_B = is_variable_B;
|
98 |
-
params.is_variable_C = is_variable_C;
|
99 |
-
|
100 |
-
// Set the pointers and strides.
|
101 |
-
params.u_ptr = u.data_ptr();
|
102 |
-
params.delta_ptr = delta.data_ptr();
|
103 |
-
params.A_ptr = A.data_ptr();
|
104 |
-
params.B_ptr = B.data_ptr();
|
105 |
-
params.C_ptr = C.data_ptr();
|
106 |
-
params.D_ptr = D_ptr;
|
107 |
-
params.delta_bias_ptr = delta_bias_ptr;
|
108 |
-
params.out_ptr = out.data_ptr();
|
109 |
-
params.x_ptr = x_ptr;
|
110 |
-
params.z_ptr = has_z ? z.data_ptr() : nullptr;
|
111 |
-
params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
|
112 |
-
// All stride are in elements, not bytes.
|
113 |
-
params.A_d_stride = A.stride(0);
|
114 |
-
params.A_dstate_stride = A.stride(1);
|
115 |
-
if (!is_variable_B) {
|
116 |
-
params.B_d_stride = B.stride(0);
|
117 |
-
} else {
|
118 |
-
params.B_batch_stride = B.stride(0);
|
119 |
-
params.B_group_stride = B.stride(1);
|
120 |
-
}
|
121 |
-
params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
|
122 |
-
if (!is_variable_C) {
|
123 |
-
params.C_d_stride = C.stride(0);
|
124 |
-
} else {
|
125 |
-
params.C_batch_stride = C.stride(0);
|
126 |
-
params.C_group_stride = C.stride(1);
|
127 |
-
}
|
128 |
-
params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
|
129 |
-
params.u_batch_stride = u.stride(0);
|
130 |
-
params.u_d_stride = u.stride(1);
|
131 |
-
params.delta_batch_stride = delta.stride(0);
|
132 |
-
params.delta_d_stride = delta.stride(1);
|
133 |
-
if (has_z) {
|
134 |
-
params.z_batch_stride = z.stride(0);
|
135 |
-
params.z_d_stride = z.stride(1);
|
136 |
-
params.out_z_batch_stride = out_z.stride(0);
|
137 |
-
params.out_z_d_stride = out_z.stride(1);
|
138 |
-
}
|
139 |
-
params.out_batch_stride = out.stride(0);
|
140 |
-
params.out_d_stride = out.stride(1);
|
141 |
-
}
|
142 |
-
|
143 |
-
void set_ssm_params_bwd(SSMParamsBwd ¶ms,
|
144 |
-
// sizes
|
145 |
-
const size_t batch,
|
146 |
-
const size_t dim,
|
147 |
-
const size_t seqlen,
|
148 |
-
const size_t dstate,
|
149 |
-
const size_t n_groups,
|
150 |
-
const size_t n_chunks,
|
151 |
-
const bool is_variable_B,
|
152 |
-
const bool is_variable_C,
|
153 |
-
// device pointers
|
154 |
-
const at::Tensor u,
|
155 |
-
const at::Tensor delta,
|
156 |
-
const at::Tensor A,
|
157 |
-
const at::Tensor B,
|
158 |
-
const at::Tensor C,
|
159 |
-
const at::Tensor z,
|
160 |
-
const at::Tensor out,
|
161 |
-
const at::Tensor out_z,
|
162 |
-
void* D_ptr,
|
163 |
-
void* delta_bias_ptr,
|
164 |
-
void* x_ptr,
|
165 |
-
const at::Tensor dout,
|
166 |
-
const at::Tensor du,
|
167 |
-
const at::Tensor ddelta,
|
168 |
-
const at::Tensor dA,
|
169 |
-
const at::Tensor dB,
|
170 |
-
const at::Tensor dC,
|
171 |
-
const at::Tensor dz,
|
172 |
-
void* dD_ptr,
|
173 |
-
void* ddelta_bias_ptr,
|
174 |
-
bool has_z,
|
175 |
-
bool delta_softplus,
|
176 |
-
bool recompute_out_z) {
|
177 |
-
// Pass in "dout" instead of "out", we're not gonna use "out" unless we have z
|
178 |
-
set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
|
179 |
-
u, delta, A, B, C, has_z ? out : dout,
|
180 |
-
has_z ? z : dout,
|
181 |
-
// If not recompute_out_z, pass dout instead of out_z.
|
182 |
-
// This won't be used by the bwd kernel
|
183 |
-
recompute_out_z ? out_z : dout,
|
184 |
-
D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus);
|
185 |
-
if (!recompute_out_z) { params.out_z_ptr = nullptr; }
|
186 |
-
|
187 |
-
// Set the pointers and strides.
|
188 |
-
params.dout_ptr = dout.data_ptr();
|
189 |
-
params.du_ptr = du.data_ptr();
|
190 |
-
params.dA_ptr = dA.data_ptr();
|
191 |
-
params.dB_ptr = dB.data_ptr();
|
192 |
-
params.dC_ptr = dC.data_ptr();
|
193 |
-
params.dD_ptr = dD_ptr;
|
194 |
-
params.ddelta_ptr = ddelta.data_ptr();
|
195 |
-
params.ddelta_bias_ptr = ddelta_bias_ptr;
|
196 |
-
params.dz_ptr = has_z ? dz.data_ptr() : nullptr;
|
197 |
-
// All stride are in elements, not bytes.
|
198 |
-
params.dout_batch_stride = dout.stride(0);
|
199 |
-
params.dout_d_stride = dout.stride(1);
|
200 |
-
params.dA_d_stride = dA.stride(0);
|
201 |
-
params.dA_dstate_stride = dA.stride(1);
|
202 |
-
if (!is_variable_B) {
|
203 |
-
params.dB_d_stride = dB.stride(0);
|
204 |
-
} else {
|
205 |
-
params.dB_batch_stride = dB.stride(0);
|
206 |
-
params.dB_group_stride = dB.stride(1);
|
207 |
-
}
|
208 |
-
params.dB_dstate_stride = !is_variable_B ? dB.stride(1) : dB.stride(2);
|
209 |
-
if (!is_variable_C) {
|
210 |
-
params.dC_d_stride = dC.stride(0);
|
211 |
-
} else {
|
212 |
-
params.dC_batch_stride = dC.stride(0);
|
213 |
-
params.dC_group_stride = dC.stride(1);
|
214 |
-
}
|
215 |
-
params.dC_dstate_stride = !is_variable_C ? dC.stride(1) : dC.stride(2);
|
216 |
-
params.du_batch_stride = du.stride(0);
|
217 |
-
params.du_d_stride = du.stride(1);
|
218 |
-
params.ddelta_batch_stride = ddelta.stride(0);
|
219 |
-
params.ddelta_d_stride = ddelta.stride(1);
|
220 |
-
if (has_z) {
|
221 |
-
params.dz_batch_stride = dz.stride(0);
|
222 |
-
params.dz_d_stride = dz.stride(1);
|
223 |
-
}
|
224 |
-
}
|
225 |
-
|
226 |
-
std::vector<at::Tensor>
|
227 |
-
selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
|
228 |
-
const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
|
229 |
-
const c10::optional<at::Tensor> &D_,
|
230 |
-
const c10::optional<at::Tensor> &z_,
|
231 |
-
const c10::optional<at::Tensor> &delta_bias_,
|
232 |
-
bool delta_softplus) {
|
233 |
-
auto input_type = u.scalar_type();
|
234 |
-
auto weight_type = A.scalar_type();
|
235 |
-
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
236 |
-
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);
|
237 |
-
|
238 |
-
const bool is_variable_B = B.dim() >= 3;
|
239 |
-
const bool is_variable_C = C.dim() >= 3;
|
240 |
-
const bool is_complex = weight_type == at::ScalarType::ComplexFloat;
|
241 |
-
|
242 |
-
TORCH_CHECK(delta.scalar_type() == input_type);
|
243 |
-
TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
|
244 |
-
TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
|
245 |
-
|
246 |
-
TORCH_CHECK(u.is_cuda());
|
247 |
-
TORCH_CHECK(delta.is_cuda());
|
248 |
-
TORCH_CHECK(A.is_cuda());
|
249 |
-
TORCH_CHECK(B.is_cuda());
|
250 |
-
TORCH_CHECK(C.is_cuda());
|
251 |
-
|
252 |
-
TORCH_CHECK(u.stride(-1) == 1);
|
253 |
-
TORCH_CHECK(delta.stride(-1) == 1);
|
254 |
-
|
255 |
-
const auto sizes = u.sizes();
|
256 |
-
const int batch_size = sizes[0];
|
257 |
-
const int dim = sizes[1];
|
258 |
-
const int seqlen = sizes[2];
|
259 |
-
const int dstate = A.size(1);
|
260 |
-
const int n_groups = is_variable_B ? B.size(1) : 1;
|
261 |
-
|
262 |
-
TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
|
263 |
-
|
264 |
-
CHECK_SHAPE(u, batch_size, dim, seqlen);
|
265 |
-
CHECK_SHAPE(delta, batch_size, dim, seqlen);
|
266 |
-
CHECK_SHAPE(A, dim, dstate);
|
267 |
-
if (!is_variable_B) {
|
268 |
-
CHECK_SHAPE(B, dim, dstate);
|
269 |
-
} else {
|
270 |
-
CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
|
271 |
-
TORCH_CHECK(B.stride(-1) == 1);
|
272 |
-
}
|
273 |
-
if (!is_variable_C) {
|
274 |
-
CHECK_SHAPE(C, dim, dstate);
|
275 |
-
} else {
|
276 |
-
CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);
|
277 |
-
TORCH_CHECK(C.stride(-1) == 1);
|
278 |
-
}
|
279 |
-
|
280 |
-
if (D_.has_value()) {
|
281 |
-
auto D = D_.value();
|
282 |
-
TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
|
283 |
-
TORCH_CHECK(D.is_cuda());
|
284 |
-
TORCH_CHECK(D.stride(-1) == 1);
|
285 |
-
CHECK_SHAPE(D, dim);
|
286 |
-
}
|
287 |
-
|
288 |
-
if (delta_bias_.has_value()) {
|
289 |
-
auto delta_bias = delta_bias_.value();
|
290 |
-
TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
|
291 |
-
TORCH_CHECK(delta_bias.is_cuda());
|
292 |
-
TORCH_CHECK(delta_bias.stride(-1) == 1);
|
293 |
-
CHECK_SHAPE(delta_bias, dim);
|
294 |
-
}
|
295 |
-
|
296 |
-
at::Tensor z, out_z;
|
297 |
-
const bool has_z = z_.has_value();
|
298 |
-
if (has_z) {
|
299 |
-
z = z_.value();
|
300 |
-
TORCH_CHECK(z.scalar_type() == input_type);
|
301 |
-
TORCH_CHECK(z.is_cuda());
|
302 |
-
TORCH_CHECK(z.stride(-1) == 1);
|
303 |
-
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
304 |
-
out_z = torch::empty_like(z);
|
305 |
-
}
|
306 |
-
|
307 |
-
const int n_chunks = (seqlen + 2048 - 1) / 2048;
|
308 |
-
// const int n_chunks = (seqlen + 1024 - 1) / 1024;
|
309 |
-
// at::Tensor out = torch::empty_like(u);
|
310 |
-
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
|
311 |
-
at::Tensor out = torch::empty_like(delta);
|
312 |
-
at::Tensor x;
|
313 |
-
x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));
|
314 |
-
|
315 |
-
SSMParamsBase params;
|
316 |
-
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
|
317 |
-
u, delta, A, B, C, out, z, out_z,
|
318 |
-
D_.has_value() ? D_.value().data_ptr() : nullptr,
|
319 |
-
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
|
320 |
-
x.data_ptr(),
|
321 |
-
has_z,
|
322 |
-
delta_softplus);
|
323 |
-
|
324 |
-
// Otherwise the kernel will be launched from cuda:0 device
|
325 |
-
// Cast to char to avoid compiler warning about narrowing
|
326 |
-
at::cuda::CUDAGuard device_guard{(char)u.get_device()};
|
327 |
-
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
328 |
-
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
|
329 |
-
DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_fwd", [&] {
|
330 |
-
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
|
331 |
-
});
|
332 |
-
});
|
333 |
-
std::vector<at::Tensor> result = {out, x};
|
334 |
-
if (has_z) { result.push_back(out_z); }
|
335 |
-
return result;
|
336 |
-
}
|
337 |
-
|
338 |
-
std::vector<at::Tensor>
|
339 |
-
selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
|
340 |
-
const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
|
341 |
-
const c10::optional<at::Tensor> &D_,
|
342 |
-
const c10::optional<at::Tensor> &z_,
|
343 |
-
const c10::optional<at::Tensor> &delta_bias_,
|
344 |
-
const at::Tensor &dout,
|
345 |
-
const c10::optional<at::Tensor> &x_,
|
346 |
-
const c10::optional<at::Tensor> &out_,
|
347 |
-
c10::optional<at::Tensor> &dz_,
|
348 |
-
bool delta_softplus,
|
349 |
-
bool recompute_out_z) {
|
350 |
-
auto input_type = u.scalar_type();
|
351 |
-
auto weight_type = A.scalar_type();
|
352 |
-
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
353 |
-
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);
|
354 |
-
|
355 |
-
const bool is_variable_B = B.dim() >= 3;
|
356 |
-
const bool is_variable_C = C.dim() >= 3;
|
357 |
-
const bool is_complex = weight_type == at::ScalarType::ComplexFloat;
|
358 |
-
|
359 |
-
TORCH_CHECK(delta.scalar_type() == input_type);
|
360 |
-
TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
|
361 |
-
TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
|
362 |
-
TORCH_CHECK(dout.scalar_type() == input_type);
|
363 |
-
|
364 |
-
TORCH_CHECK(u.is_cuda());
|
365 |
-
TORCH_CHECK(delta.is_cuda());
|
366 |
-
TORCH_CHECK(A.is_cuda());
|
367 |
-
TORCH_CHECK(B.is_cuda());
|
368 |
-
TORCH_CHECK(C.is_cuda());
|
369 |
-
TORCH_CHECK(dout.is_cuda());
|
370 |
-
|
371 |
-
TORCH_CHECK(u.stride(-1) == 1);
|
372 |
-
TORCH_CHECK(delta.stride(-1) == 1);
|
373 |
-
TORCH_CHECK(dout.stride(-1) == 1);
|
374 |
-
|
375 |
-
const auto sizes = u.sizes();
|
376 |
-
const int batch_size = sizes[0];
|
377 |
-
const int dim = sizes[1];
|
378 |
-
const int seqlen = sizes[2];
|
379 |
-
const int dstate = A.size(1);
|
380 |
-
const int n_groups = is_variable_B ? B.size(1) : 1;
|
381 |
-
|
382 |
-
TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
|
383 |
-
|
384 |
-
CHECK_SHAPE(u, batch_size, dim, seqlen);
|
385 |
-
CHECK_SHAPE(delta, batch_size, dim, seqlen);
|
386 |
-
CHECK_SHAPE(A, dim, dstate);
|
387 |
-
if (!is_variable_B) {
|
388 |
-
CHECK_SHAPE(B, dim, dstate);
|
389 |
-
} else {
|
390 |
-
CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
|
391 |
-
TORCH_CHECK(B.stride(-1) == 1);
|
392 |
-
}
|
393 |
-
if (!is_variable_C) {
|
394 |
-
CHECK_SHAPE(C, dim, dstate);
|
395 |
-
} else {
|
396 |
-
CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);
|
397 |
-
TORCH_CHECK(C.stride(-1) == 1);
|
398 |
-
}
|
399 |
-
CHECK_SHAPE(dout, batch_size, dim, seqlen);
|
400 |
-
|
401 |
-
if (D_.has_value()) {
|
402 |
-
auto D = D_.value();
|
403 |
-
TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
|
404 |
-
TORCH_CHECK(D.is_cuda());
|
405 |
-
TORCH_CHECK(D.stride(-1) == 1);
|
406 |
-
CHECK_SHAPE(D, dim);
|
407 |
-
}
|
408 |
-
|
409 |
-
if (delta_bias_.has_value()) {
|
410 |
-
auto delta_bias = delta_bias_.value();
|
411 |
-
TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
|
412 |
-
TORCH_CHECK(delta_bias.is_cuda());
|
413 |
-
TORCH_CHECK(delta_bias.stride(-1) == 1);
|
414 |
-
CHECK_SHAPE(delta_bias, dim);
|
415 |
-
}
|
416 |
-
|
417 |
-
at::Tensor z, out, dz, out_z;
|
418 |
-
const bool has_z = z_.has_value();
|
419 |
-
if (has_z) {
|
420 |
-
z = z_.value();
|
421 |
-
TORCH_CHECK(z.scalar_type() == input_type);
|
422 |
-
TORCH_CHECK(z.is_cuda());
|
423 |
-
TORCH_CHECK(z.stride(-1) == 1);
|
424 |
-
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
425 |
-
|
426 |
-
TORCH_CHECK(out_.has_value());
|
427 |
-
out = out_.value();
|
428 |
-
TORCH_CHECK(out.scalar_type() == input_type);
|
429 |
-
TORCH_CHECK(out.is_cuda());
|
430 |
-
TORCH_CHECK(out.stride(-1) == 1);
|
431 |
-
CHECK_SHAPE(out, batch_size, dim, seqlen);
|
432 |
-
|
433 |
-
if (dz_.has_value()) {
|
434 |
-
dz = dz_.value();
|
435 |
-
TORCH_CHECK(dz.scalar_type() == input_type);
|
436 |
-
TORCH_CHECK(dz.is_cuda());
|
437 |
-
TORCH_CHECK(dz.stride(-1) == 1);
|
438 |
-
CHECK_SHAPE(dz, batch_size, dim, seqlen);
|
439 |
-
} else {
|
440 |
-
dz = torch::empty_like(z);
|
441 |
-
}
|
442 |
-
if (recompute_out_z) {
|
443 |
-
out_z = torch::empty_like(out);
|
444 |
-
}
|
445 |
-
}
|
446 |
-
|
447 |
-
const int n_chunks = (seqlen + 2048 - 1) / 2048;
|
448 |
-
// const int n_chunks = (seqlen + 1024 - 1) / 1024;
|
449 |
-
if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); }
|
450 |
-
if (x_.has_value()) {
|
451 |
-
auto x = x_.value();
|
452 |
-
TORCH_CHECK(x.scalar_type() == weight_type);
|
453 |
-
TORCH_CHECK(x.is_cuda());
|
454 |
-
TORCH_CHECK(x.is_contiguous());
|
455 |
-
CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate);
|
456 |
-
}
|
457 |
-
|
458 |
-
at::Tensor du = torch::empty_like(u);
|
459 |
-
at::Tensor ddelta = torch::empty_like(delta);
|
460 |
-
at::Tensor dA = torch::zeros_like(A);
|
461 |
-
at::Tensor dB = !is_variable_B ? torch::zeros_like(B) : torch::zeros_like(B, B.options().dtype(torch::kFloat32));
|
462 |
-
at::Tensor dC = !is_variable_C ? torch::zeros_like(C) : torch::zeros_like(C, C.options().dtype(torch::kFloat32));
|
463 |
-
at::Tensor dD;
|
464 |
-
if (D_.has_value()) { dD = torch::zeros_like(D_.value()); }
|
465 |
-
at::Tensor ddelta_bias;
|
466 |
-
if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); }
|
467 |
-
|
468 |
-
SSMParamsBwd params;
|
469 |
-
set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
|
470 |
-
u, delta, A, B, C, z, out, out_z,
|
471 |
-
D_.has_value() ? D_.value().data_ptr() : nullptr,
|
472 |
-
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
|
473 |
-
x_.has_value() ? x_.value().data_ptr() : nullptr,
|
474 |
-
dout, du, ddelta, dA, dB, dC, dz,
|
475 |
-
D_.has_value() ? dD.data_ptr() : nullptr,
|
476 |
-
delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
|
477 |
-
has_z, delta_softplus, recompute_out_z);
|
478 |
-
|
479 |
-
// Otherwise the kernel will be launched from cuda:0 device
|
480 |
-
// Cast to char to avoid compiler warning about narrowing
|
481 |
-
at::cuda::CUDAGuard device_guard{(char)u.get_device()};
|
482 |
-
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
483 |
-
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] {
|
484 |
-
DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_bwd", [&] {
|
485 |
-
selective_scan_bwd_cuda<input_t, weight_t>(params, stream);
|
486 |
-
});
|
487 |
-
});
|
488 |
-
std::vector<at::Tensor> result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias};
|
489 |
-
if (has_z) { result.push_back(dz); }
|
490 |
-
if (recompute_out_z) { result.push_back(out_z); }
|
491 |
-
return result;
|
492 |
-
}
|
493 |
-
|
494 |
-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
495 |
-
m.def("fwd", &selective_scan_fwd, "Selective scan forward");
|
496 |
-
m.def("bwd", &selective_scan_bwd, "Selective scan backward");
|
497 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/csrc/selective_scan/selective_scan.h
DELETED
@@ -1,101 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2023, Tri Dao.
|
3 |
-
******************************************************************************/
|
4 |
-
|
5 |
-
#pragma once
|
6 |
-
|
7 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
8 |
-
|
9 |
-
struct SSMScanParamsBase {
|
10 |
-
using index_t = uint32_t;
|
11 |
-
|
12 |
-
int batch, seqlen, n_chunks;
|
13 |
-
index_t a_batch_stride;
|
14 |
-
index_t b_batch_stride;
|
15 |
-
index_t out_batch_stride;
|
16 |
-
|
17 |
-
// Common data pointers.
|
18 |
-
void *__restrict__ a_ptr;
|
19 |
-
void *__restrict__ b_ptr;
|
20 |
-
void *__restrict__ out_ptr;
|
21 |
-
void *__restrict__ x_ptr;
|
22 |
-
};
|
23 |
-
|
24 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
25 |
-
|
26 |
-
struct SSMParamsBase {
|
27 |
-
using index_t = uint32_t;
|
28 |
-
|
29 |
-
int batch, dim, seqlen, dstate, n_groups, n_chunks;
|
30 |
-
int dim_ngroups_ratio;
|
31 |
-
bool is_variable_B;
|
32 |
-
bool is_variable_C;
|
33 |
-
|
34 |
-
bool delta_softplus;
|
35 |
-
|
36 |
-
index_t A_d_stride;
|
37 |
-
index_t A_dstate_stride;
|
38 |
-
index_t B_batch_stride;
|
39 |
-
index_t B_d_stride;
|
40 |
-
index_t B_dstate_stride;
|
41 |
-
index_t B_group_stride;
|
42 |
-
index_t C_batch_stride;
|
43 |
-
index_t C_d_stride;
|
44 |
-
index_t C_dstate_stride;
|
45 |
-
index_t C_group_stride;
|
46 |
-
index_t u_batch_stride;
|
47 |
-
index_t u_d_stride;
|
48 |
-
index_t delta_batch_stride;
|
49 |
-
index_t delta_d_stride;
|
50 |
-
index_t z_batch_stride;
|
51 |
-
index_t z_d_stride;
|
52 |
-
index_t out_batch_stride;
|
53 |
-
index_t out_d_stride;
|
54 |
-
index_t out_z_batch_stride;
|
55 |
-
index_t out_z_d_stride;
|
56 |
-
|
57 |
-
// Common data pointers.
|
58 |
-
void *__restrict__ A_ptr;
|
59 |
-
void *__restrict__ B_ptr;
|
60 |
-
void *__restrict__ C_ptr;
|
61 |
-
void *__restrict__ D_ptr;
|
62 |
-
void *__restrict__ u_ptr;
|
63 |
-
void *__restrict__ delta_ptr;
|
64 |
-
void *__restrict__ delta_bias_ptr;
|
65 |
-
void *__restrict__ out_ptr;
|
66 |
-
void *__restrict__ x_ptr;
|
67 |
-
void *__restrict__ z_ptr;
|
68 |
-
void *__restrict__ out_z_ptr;
|
69 |
-
};
|
70 |
-
|
71 |
-
struct SSMParamsBwd: public SSMParamsBase {
|
72 |
-
index_t dout_batch_stride;
|
73 |
-
index_t dout_d_stride;
|
74 |
-
index_t dA_d_stride;
|
75 |
-
index_t dA_dstate_stride;
|
76 |
-
index_t dB_batch_stride;
|
77 |
-
index_t dB_group_stride;
|
78 |
-
index_t dB_d_stride;
|
79 |
-
index_t dB_dstate_stride;
|
80 |
-
index_t dC_batch_stride;
|
81 |
-
index_t dC_group_stride;
|
82 |
-
index_t dC_d_stride;
|
83 |
-
index_t dC_dstate_stride;
|
84 |
-
index_t du_batch_stride;
|
85 |
-
index_t du_d_stride;
|
86 |
-
index_t dz_batch_stride;
|
87 |
-
index_t dz_d_stride;
|
88 |
-
index_t ddelta_batch_stride;
|
89 |
-
index_t ddelta_d_stride;
|
90 |
-
|
91 |
-
// Common data pointers.
|
92 |
-
void *__restrict__ dout_ptr;
|
93 |
-
void *__restrict__ dA_ptr;
|
94 |
-
void *__restrict__ dB_ptr;
|
95 |
-
void *__restrict__ dC_ptr;
|
96 |
-
void *__restrict__ dD_ptr;
|
97 |
-
void *__restrict__ du_ptr;
|
98 |
-
void *__restrict__ dz_ptr;
|
99 |
-
void *__restrict__ ddelta_ptr;
|
100 |
-
void *__restrict__ ddelta_bias_ptr;
|
101 |
-
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu
DELETED
@@ -1,9 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2023, Tri Dao.
|
3 |
-
******************************************************************************/
|
4 |
-
|
5 |
-
// Split into multiple files to compile in paralell
|
6 |
-
|
7 |
-
#include "selective_scan_bwd_kernel.cuh"
|
8 |
-
|
9 |
-
template void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/csrc/selective_scan/selective_scan_bwd_bf16_real.cu
DELETED
@@ -1,9 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2023, Tri Dao.
|
3 |
-
******************************************************************************/
|
4 |
-
|
5 |
-
// Split into multiple files to compile in paralell
|
6 |
-
|
7 |
-
#include "selective_scan_bwd_kernel.cuh"
|
8 |
-
|
9 |
-
template void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu
DELETED
@@ -1,9 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2023, Tri Dao.
|
3 |
-
******************************************************************************/
|
4 |
-
|
5 |
-
// Split into multiple files to compile in paralell
|
6 |
-
|
7 |
-
#include "selective_scan_bwd_kernel.cuh"
|
8 |
-
|
9 |
-
template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/csrc/selective_scan/selective_scan_bwd_fp16_real.cu
DELETED
@@ -1,9 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2023, Tri Dao.
|
3 |
-
******************************************************************************/
|
4 |
-
|
5 |
-
// Split into multiple files to compile in paralell
|
6 |
-
|
7 |
-
#include "selective_scan_bwd_kernel.cuh"
|
8 |
-
|
9 |
-
template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu
DELETED
@@ -1,9 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2023, Tri Dao.
|
3 |
-
******************************************************************************/
|
4 |
-
|
5 |
-
// Split into multiple files to compile in paralell
|
6 |
-
|
7 |
-
#include "selective_scan_bwd_kernel.cuh"
|
8 |
-
|
9 |
-
template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/csrc/selective_scan/selective_scan_bwd_fp32_real.cu
DELETED
@@ -1,9 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2023, Tri Dao.
|
3 |
-
******************************************************************************/
|
4 |
-
|
5 |
-
// Split into multiple files to compile in paralell
|
6 |
-
|
7 |
-
#include "selective_scan_bwd_kernel.cuh"
|
8 |
-
|
9 |
-
template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/csrc/selective_scan/selective_scan_bwd_kernel.cuh
DELETED
@@ -1,531 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2023, Tri Dao.
|
3 |
-
******************************************************************************/
|
4 |
-
|
5 |
-
#pragma once
|
6 |
-
|
7 |
-
#include <c10/util/BFloat16.h>
|
8 |
-
#include <c10/util/Half.h>
|
9 |
-
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
10 |
-
#include <ATen/cuda/Atomic.cuh> // For atomicAdd on complex
|
11 |
-
|
12 |
-
#include <cub/block/block_load.cuh>
|
13 |
-
#include <cub/block/block_store.cuh>
|
14 |
-
#include <cub/block/block_scan.cuh>
|
15 |
-
#include <cub/block/block_reduce.cuh>
|
16 |
-
|
17 |
-
#include "selective_scan.h"
|
18 |
-
#include "selective_scan_common.h"
|
19 |
-
#include "reverse_scan.cuh"
|
20 |
-
#include "static_switch.h"
|
21 |
-
|
22 |
-
template<typename scalar_t> __device__ __forceinline__ scalar_t conj(scalar_t x);
|
23 |
-
template<> __device__ __forceinline__ float conj<float>(float x) { return x; }
|
24 |
-
template<> __device__ __forceinline__ complex_t conj<complex_t>(complex_t x) { return std::conj(x); }
|
25 |
-
|
26 |
-
template<int kNThreads_, int kNItems_, bool kIsEvenLen_, bool kIsVariableB_, bool kIsVariableC_,
|
27 |
-
bool kDeltaSoftplus_, bool kHasZ_, typename input_t_, typename weight_t_>
|
28 |
-
struct Selective_Scan_bwd_kernel_traits {
|
29 |
-
static_assert(kNItems_ % 4 == 0);
|
30 |
-
using input_t = input_t_;
|
31 |
-
using weight_t = weight_t_;
|
32 |
-
static constexpr int kNThreads = kNThreads_;
|
33 |
-
static constexpr int kNItems = kNItems_;
|
34 |
-
static constexpr int kNBytes = sizeof(input_t);
|
35 |
-
static_assert(kNBytes == 2 || kNBytes == 4);
|
36 |
-
static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
|
37 |
-
static_assert(kNItems % kNElts == 0);
|
38 |
-
static constexpr int kNLoads = kNItems / kNElts;
|
39 |
-
static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;
|
40 |
-
static constexpr bool kIsEvenLen = kIsEvenLen_;
|
41 |
-
static constexpr bool kIsVariableB = kIsVariableB_;
|
42 |
-
static constexpr bool kIsVariableC = kIsVariableC_;
|
43 |
-
static constexpr bool kDeltaSoftplus = kDeltaSoftplus_;
|
44 |
-
static constexpr bool kHasZ = kHasZ_;
|
45 |
-
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy.
|
46 |
-
// For complex this would lead to massive register spilling, so we keep it at 2.
|
47 |
-
static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 3 : 2;
|
48 |
-
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
49 |
-
using scan_t = std::conditional_t<!kIsComplex, float2, float4>;
|
50 |
-
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
51 |
-
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
52 |
-
using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
53 |
-
using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
54 |
-
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
55 |
-
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
56 |
-
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
|
57 |
-
using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
|
58 |
-
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
|
59 |
-
using BlockReverseScanT = BlockReverseScan<scan_t, kNThreads>;
|
60 |
-
using BlockReduceT = cub::BlockReduce<scan_t, kNThreads>;
|
61 |
-
using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
|
62 |
-
using BlockReduceComplexT = cub::BlockReduce<complex_t, kNThreads>;
|
63 |
-
using BlockExchangeT = cub::BlockExchange<float, kNThreads, !kIsComplex ? kNItems : kNItems * 2>;
|
64 |
-
static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
|
65 |
-
sizeof(typename BlockLoadVecT::TempStorage),
|
66 |
-
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
|
67 |
-
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
|
68 |
-
sizeof(typename BlockStoreT::TempStorage),
|
69 |
-
sizeof(typename BlockStoreVecT::TempStorage)});
|
70 |
-
static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage);
|
71 |
-
static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage);
|
72 |
-
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage);
|
73 |
-
};
|
74 |
-
|
75 |
-
template<typename Ktraits>
|
76 |
-
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
|
77 |
-
void selective_scan_bwd_kernel(SSMParamsBwd params) {
|
78 |
-
constexpr bool kIsComplex = Ktraits::kIsComplex;
|
79 |
-
constexpr bool kIsVariableB = Ktraits::kIsVariableB;
|
80 |
-
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
|
81 |
-
constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus;
|
82 |
-
constexpr bool kHasZ = Ktraits::kHasZ;
|
83 |
-
constexpr int kNThreads = Ktraits::kNThreads;
|
84 |
-
constexpr int kNItems = Ktraits::kNItems;
|
85 |
-
using input_t = typename Ktraits::input_t;
|
86 |
-
using weight_t = typename Ktraits::weight_t;
|
87 |
-
using scan_t = typename Ktraits::scan_t;
|
88 |
-
|
89 |
-
// Shared memory.
|
90 |
-
extern __shared__ char smem_[];
|
91 |
-
// cast to lvalue reference of expected type
|
92 |
-
// char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
|
93 |
-
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
|
94 |
-
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
|
95 |
-
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
96 |
-
auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
|
97 |
-
auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
|
98 |
-
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
99 |
-
auto& smem_exchange = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
|
100 |
-
auto& smem_exchange1 = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage));
|
101 |
-
auto& smem_reduce = *reinterpret_cast<typename Ktraits::BlockReduceT::TempStorage*>(reinterpret_cast<char *>(&smem_exchange) + Ktraits::kSmemExchangeSize);
|
102 |
-
auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(&smem_reduce);
|
103 |
-
auto& smem_reduce_complex = *reinterpret_cast<typename Ktraits::BlockReduceComplexT::TempStorage*>(&smem_reduce);
|
104 |
-
auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(reinterpret_cast<char *>(&smem_reduce) + Ktraits::kSmemReduceSize);
|
105 |
-
auto& smem_reverse_scan = *reinterpret_cast<typename Ktraits::BlockReverseScanT::TempStorage*>(reinterpret_cast<char *>(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage));
|
106 |
-
weight_t *smem_delta_a = reinterpret_cast<weight_t *>(smem_ + Ktraits::kSmemSize);
|
107 |
-
scan_t *smem_running_postfix = reinterpret_cast<scan_t *>(smem_delta_a + 2 * MAX_DSTATE + kNThreads);
|
108 |
-
weight_t *smem_da = reinterpret_cast<weight_t *>(smem_running_postfix + MAX_DSTATE);
|
109 |
-
weight_t *smem_dbc = reinterpret_cast<weight_t *>(smem_da + MAX_DSTATE);
|
110 |
-
|
111 |
-
const int batch_id = blockIdx.x;
|
112 |
-
const int dim_id = blockIdx.y;
|
113 |
-
const int group_id = dim_id / (params.dim_ngroups_ratio);
|
114 |
-
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
|
115 |
-
+ dim_id * params.u_d_stride;
|
116 |
-
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
|
117 |
-
+ dim_id * params.delta_d_stride;
|
118 |
-
input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
|
119 |
-
+ dim_id * params.dout_d_stride;
|
120 |
-
weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * params.A_d_stride;
|
121 |
-
weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * params.B_d_stride;
|
122 |
-
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
|
123 |
-
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * params.C_d_stride;
|
124 |
-
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
|
125 |
-
weight_t *dA = reinterpret_cast<weight_t *>(params.dA_ptr) + dim_id * params.dA_d_stride;
|
126 |
-
weight_t *dB = reinterpret_cast<weight_t *>(params.dB_ptr)
|
127 |
-
+ (!kIsVariableB ? dim_id * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride);
|
128 |
-
weight_t *dC = reinterpret_cast<weight_t *>(params.dC_ptr)
|
129 |
-
+ (!kIsVariableC ? dim_id * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride);
|
130 |
-
float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.dD_ptr) + dim_id;
|
131 |
-
float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.D_ptr)[dim_id];
|
132 |
-
float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.ddelta_bias_ptr) + dim_id;
|
133 |
-
float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id];
|
134 |
-
scan_t *x = params.x_ptr == nullptr
|
135 |
-
? nullptr
|
136 |
-
: reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate;
|
137 |
-
float dD_val = 0;
|
138 |
-
float ddelta_bias_val = 0;
|
139 |
-
|
140 |
-
constexpr int kChunkSize = kNThreads * kNItems;
|
141 |
-
u += (params.n_chunks - 1) * kChunkSize;
|
142 |
-
delta += (params.n_chunks - 1) * kChunkSize;
|
143 |
-
dout += (params.n_chunks - 1) * kChunkSize;
|
144 |
-
Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2);
|
145 |
-
Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2);
|
146 |
-
for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) {
|
147 |
-
input_t u_vals[kNItems];
|
148 |
-
input_t delta_vals_load[kNItems];
|
149 |
-
input_t dout_vals_load[kNItems];
|
150 |
-
__syncthreads();
|
151 |
-
load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
|
152 |
-
u -= kChunkSize;
|
153 |
-
__syncthreads();
|
154 |
-
load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
|
155 |
-
// Will reload delta at the same location if kDeltaSoftplus
|
156 |
-
if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; }
|
157 |
-
__syncthreads();
|
158 |
-
load_input<Ktraits>(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
|
159 |
-
dout -= kChunkSize;
|
160 |
-
|
161 |
-
float dout_vals[kNItems], delta_vals[kNItems];
|
162 |
-
#pragma unroll
|
163 |
-
for (int i = 0; i < kNItems; ++i) {
|
164 |
-
dout_vals[i] = float(dout_vals_load[i]);
|
165 |
-
delta_vals[i] = float(delta_vals_load[i]) + delta_bias;
|
166 |
-
if constexpr (kDeltaSoftplus) {
|
167 |
-
delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i];
|
168 |
-
}
|
169 |
-
}
|
170 |
-
|
171 |
-
if constexpr (kHasZ) {
|
172 |
-
input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
|
173 |
-
+ dim_id * params.z_d_stride + chunk * kChunkSize;
|
174 |
-
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
175 |
-
+ dim_id * params.out_d_stride + chunk * kChunkSize;
|
176 |
-
input_t *dz = reinterpret_cast<input_t *>(params.dz_ptr) + batch_id * params.dz_batch_stride
|
177 |
-
+ dim_id * params.dz_d_stride + chunk * kChunkSize;
|
178 |
-
input_t z_vals[kNItems], out_vals[kNItems];
|
179 |
-
__syncthreads();
|
180 |
-
load_input<Ktraits>(z, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
|
181 |
-
__syncthreads();
|
182 |
-
load_input<Ktraits>(out, out_vals, smem_load, params.seqlen - chunk * kChunkSize);
|
183 |
-
float dz_vals[kNItems], z_silu_vals[kNItems];
|
184 |
-
#pragma unroll
|
185 |
-
for (int i = 0; i < kNItems; ++i) {
|
186 |
-
float z_val = z_vals[i];
|
187 |
-
float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val));
|
188 |
-
z_silu_vals[i] = z_val * z_sigmoid_val;
|
189 |
-
dz_vals[i] = dout_vals[i] * float(out_vals[i]) * z_sigmoid_val
|
190 |
-
* (1.0f + z_val * (1.0f - z_sigmoid_val));
|
191 |
-
dout_vals[i] *= z_silu_vals[i];
|
192 |
-
}
|
193 |
-
__syncthreads();
|
194 |
-
store_output<Ktraits>(dz, dz_vals, smem_store, params.seqlen - chunk * kChunkSize);
|
195 |
-
if (params.out_z_ptr != nullptr) { // Recompute and store out_z
|
196 |
-
float out_z_vals[kNItems];
|
197 |
-
#pragma unroll
|
198 |
-
for (int i = 0; i < kNItems; ++i) { out_z_vals[i] = float(out_vals[i]) * z_silu_vals[i]; }
|
199 |
-
// if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) {
|
200 |
-
// printf("out_val=%f, z_silu_val = %f, out_z_val = %f\n", float(out_vals[0]), z_silu_vals[0], out_z_vals[0]);
|
201 |
-
// }
|
202 |
-
input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
|
203 |
-
+ dim_id * params.out_z_d_stride + chunk * kChunkSize;
|
204 |
-
__syncthreads();
|
205 |
-
store_output<Ktraits>(out_z, out_z_vals, smem_store, params.seqlen - chunk * kChunkSize);
|
206 |
-
}
|
207 |
-
}
|
208 |
-
|
209 |
-
float du_vals[kNItems];
|
210 |
-
#pragma unroll
|
211 |
-
for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; }
|
212 |
-
#pragma unroll
|
213 |
-
for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); }
|
214 |
-
|
215 |
-
float ddelta_vals[kNItems] = {0};
|
216 |
-
__syncthreads();
|
217 |
-
for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
|
218 |
-
const weight_t A_val = A[state_idx * params.A_dstate_stride];
|
219 |
-
// Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
|
220 |
-
weight_t A_scaled;
|
221 |
-
constexpr float kLog2e = M_LOG2E;
|
222 |
-
if constexpr (!kIsComplex) {
|
223 |
-
A_scaled = A_val * kLog2e;
|
224 |
-
} else {
|
225 |
-
A_scaled = complex_t(A_val.real_ * kLog2e, A_val.imag_);
|
226 |
-
}
|
227 |
-
weight_t B_val, C_val;
|
228 |
-
weight_t B_vals[kNItems], C_vals[kNItems];
|
229 |
-
if constexpr (!kIsVariableB) {
|
230 |
-
B_val = B[state_idx * params.B_dstate_stride];
|
231 |
-
} else {
|
232 |
-
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
|
233 |
-
smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
|
234 |
-
}
|
235 |
-
if constexpr (!kIsVariableC) {
|
236 |
-
C_val = C[state_idx * params.C_dstate_stride];
|
237 |
-
} else {
|
238 |
-
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
|
239 |
-
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
|
240 |
-
smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
|
241 |
-
}
|
242 |
-
// const weight_t A_val = smem_a[state_idx];
|
243 |
-
scan_t thread_data[kNItems], thread_reverse_data[kNItems];
|
244 |
-
if constexpr (!kIsComplex) {
|
245 |
-
#pragma unroll
|
246 |
-
for (int i = 0; i < kNItems; ++i) {
|
247 |
-
const float delta_a_exp = exp2f(delta_vals[i] * A_scaled);
|
248 |
-
thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);
|
249 |
-
if (i == 0) {
|
250 |
-
smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;
|
251 |
-
} else {
|
252 |
-
thread_reverse_data[i - 1].x = delta_a_exp;
|
253 |
-
}
|
254 |
-
thread_reverse_data[i].y = dout_vals[i] *
|
255 |
-
(!kIsVariableC
|
256 |
-
? (!kIsVariableB ? B_val * C_val : C_val)
|
257 |
-
: (!kIsVariableB ? B_val * C_vals[i] : C_vals[i]));
|
258 |
-
}
|
259 |
-
__syncthreads();
|
260 |
-
thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1
|
261 |
-
? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE])
|
262 |
-
: smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];
|
263 |
-
// Initialize running total
|
264 |
-
scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f);
|
265 |
-
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
|
266 |
-
Ktraits::BlockScanT(smem_scan).InclusiveScan(
|
267 |
-
thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
|
268 |
-
);
|
269 |
-
scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f);
|
270 |
-
SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
|
271 |
-
Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
|
272 |
-
thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
|
273 |
-
);
|
274 |
-
if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
|
275 |
-
weight_t dA_val = 0, dBC_val = 0;
|
276 |
-
weight_t dB_vals[kNItems], dC_vals[kNItems];
|
277 |
-
#pragma unroll
|
278 |
-
for (int i = 0; i < kNItems; ++i) {
|
279 |
-
const float dx = thread_reverse_data[i].y;
|
280 |
-
const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i];
|
281 |
-
du_vals[i] += ddelta_u * delta_vals[i];
|
282 |
-
const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);
|
283 |
-
ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a;
|
284 |
-
dA_val += dx * delta_vals[i] * a;
|
285 |
-
if constexpr (!kIsVariableB || !kIsVariableC) {
|
286 |
-
if constexpr (!kIsVariableB) { // dBC_val is dB_val
|
287 |
-
dBC_val += dout_vals[i] * (!kIsVariableC ? thread_data[i].y : thread_data[i].y * C_vals[i]);
|
288 |
-
} else { // dBC_val is dC_val
|
289 |
-
dBC_val += dout_vals[i] * thread_data[i].y;
|
290 |
-
}
|
291 |
-
}
|
292 |
-
if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); }
|
293 |
-
if constexpr (kIsVariableC) {
|
294 |
-
dC_vals[i] = dout_vals[i] * (!kIsVariableB ? thread_data[i].y * B_val : thread_data[i].y);
|
295 |
-
}
|
296 |
-
}
|
297 |
-
// Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
|
298 |
-
if constexpr (kIsVariableB || kIsVariableC) {
|
299 |
-
if constexpr (kIsVariableB) {
|
300 |
-
Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
|
301 |
-
}
|
302 |
-
if constexpr (kIsVariableC) {
|
303 |
-
auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
|
304 |
-
Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals);
|
305 |
-
}
|
306 |
-
const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x;
|
307 |
-
weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x;
|
308 |
-
weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x;
|
309 |
-
#pragma unroll
|
310 |
-
for (int i = 0; i < kNItems; ++i) {
|
311 |
-
if (i * kNThreads < seqlen_remaining) {
|
312 |
-
if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); }
|
313 |
-
if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); }
|
314 |
-
}
|
315 |
-
}
|
316 |
-
}
|
317 |
-
if constexpr (!kIsVariableB || !kIsVariableC) {
|
318 |
-
float2 dA_dBC_val = make_float2(dA_val, dBC_val);
|
319 |
-
dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
|
320 |
-
dA_val = dA_dBC_val.x;
|
321 |
-
if (threadIdx.x == 0) {
|
322 |
-
smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx];
|
323 |
-
}
|
324 |
-
} else {
|
325 |
-
dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val);
|
326 |
-
}
|
327 |
-
if (threadIdx.x == 0) {
|
328 |
-
smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
|
329 |
-
}
|
330 |
-
} else {
|
331 |
-
#pragma unroll
|
332 |
-
for (int i = 0; i < kNItems; ++i) {
|
333 |
-
// Pytorch's implementation of complex exp (which calls thrust) is very slow
|
334 |
-
complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled);
|
335 |
-
weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]);
|
336 |
-
thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
|
337 |
-
if (i == 0) {
|
338 |
-
smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;
|
339 |
-
} else {
|
340 |
-
thread_reverse_data[i - 1].x = delta_a_exp.real_;
|
341 |
-
thread_reverse_data[i - 1].y = -delta_a_exp.imag_;
|
342 |
-
}
|
343 |
-
complex_t dout_BC = 2 * dout_vals[i]
|
344 |
-
* conj(!kIsVariableC
|
345 |
-
? (!kIsVariableB ? B_val * C_val : C_val)
|
346 |
-
: (!kIsVariableB ? B_val * C_vals[i] : C_vals[i]));
|
347 |
-
thread_reverse_data[i].z = dout_BC.real_;
|
348 |
-
thread_reverse_data[i].w = dout_BC.imag_;
|
349 |
-
}
|
350 |
-
__syncthreads();
|
351 |
-
complex_t delta_a_exp = threadIdx.x == kNThreads - 1
|
352 |
-
? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE])
|
353 |
-
: smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];
|
354 |
-
thread_reverse_data[kNItems - 1].x = delta_a_exp.real_;
|
355 |
-
thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_;
|
356 |
-
// Initialize running total
|
357 |
-
scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
|
358 |
-
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
|
359 |
-
Ktraits::BlockScanT(smem_scan).InclusiveScan(
|
360 |
-
thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
|
361 |
-
);
|
362 |
-
scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
|
363 |
-
SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
|
364 |
-
Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
|
365 |
-
thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
|
366 |
-
);
|
367 |
-
if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
|
368 |
-
weight_t dA_val = 0, dBC_val = 0;
|
369 |
-
weight_t dB_vals[kNItems], dC_vals[kNItems];
|
370 |
-
#pragma unroll
|
371 |
-
for (int i = 0; i < kNItems; ++i) {
|
372 |
-
complex_t x = complex_t(thread_data[i].z, thread_data[i].w);
|
373 |
-
complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w);
|
374 |
-
float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_;
|
375 |
-
if constexpr (!kIsVariableB || !kIsVariableC) {
|
376 |
-
if constexpr (!kIsVariableB) { // dBC_val is dB_val
|
377 |
-
dBC_val += (2 * dout_vals[i]) * conj(!kIsVariableC ? x : x * C_vals[i]);
|
378 |
-
} else { // dBC_val is dC_val
|
379 |
-
dBC_val += (2 * dout_vals[i]) * conj(x);
|
380 |
-
}
|
381 |
-
}
|
382 |
-
const complex_t a_conj = conj(x - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]));
|
383 |
-
du_vals[i] += ddelta_u * delta_vals[i];
|
384 |
-
ddelta_vals[i] += ddelta_u * float(u_vals[i]) + (dx * conj(A_val) * a_conj).real_;
|
385 |
-
dA_val += delta_vals[i] * dx * a_conj;
|
386 |
-
if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); }
|
387 |
-
if constexpr (kIsVariableC) {
|
388 |
-
dC_vals[i] = (2 * dout_vals[i]) * conj(!kIsVariableB ? x * B_val : x);
|
389 |
-
}
|
390 |
-
}
|
391 |
-
// Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
|
392 |
-
if constexpr (kIsVariableB || kIsVariableC) {
|
393 |
-
float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2];
|
394 |
-
if constexpr (kIsVariableB) {
|
395 |
-
#pragma unroll
|
396 |
-
for (int i = 0; i < kNItems; ++i) {
|
397 |
-
dB_vals_f[i * 2] = dB_vals[i].real_;
|
398 |
-
dB_vals_f[i * 2 + 1] = dB_vals[i].imag_;
|
399 |
-
}
|
400 |
-
Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f);
|
401 |
-
}
|
402 |
-
if constexpr (kIsVariableC) {
|
403 |
-
#pragma unroll
|
404 |
-
for (int i = 0; i < kNItems; ++i) {
|
405 |
-
dC_vals_f[i * 2] = dC_vals[i].real_;
|
406 |
-
dC_vals_f[i * 2 + 1] = dC_vals[i].imag_;
|
407 |
-
}
|
408 |
-
auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
|
409 |
-
Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f);
|
410 |
-
}
|
411 |
-
const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x;
|
412 |
-
float *dB_cur = reinterpret_cast<float *>(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x;
|
413 |
-
float *dC_cur = reinterpret_cast<float *>(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x;
|
414 |
-
#pragma unroll
|
415 |
-
for (int i = 0; i < kNItems * 2; ++i) {
|
416 |
-
if (i * kNThreads < seqlen_remaining) {
|
417 |
-
if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); }
|
418 |
-
if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); }
|
419 |
-
}
|
420 |
-
}
|
421 |
-
}
|
422 |
-
if constexpr (!kIsVariableB || !kIsVariableC) {
|
423 |
-
float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_);
|
424 |
-
dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
|
425 |
-
dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y);
|
426 |
-
dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w);
|
427 |
-
if (threadIdx.x == 0) {
|
428 |
-
smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dBC_val : dBC_val + smem_dbc[state_idx];
|
429 |
-
}
|
430 |
-
} else {
|
431 |
-
dA_val = Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val);
|
432 |
-
}
|
433 |
-
if (threadIdx.x == 0) {
|
434 |
-
smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
|
435 |
-
}
|
436 |
-
}
|
437 |
-
}
|
438 |
-
|
439 |
-
if constexpr (kDeltaSoftplus) {
|
440 |
-
__syncthreads();
|
441 |
-
input_t delta_vals_load[kNItems];
|
442 |
-
load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
|
443 |
-
delta -= kChunkSize;
|
444 |
-
#pragma unroll
|
445 |
-
for (int i = 0; i < kNItems; ++i) {
|
446 |
-
float delta_val = float(delta_vals_load[i]) + delta_bias;
|
447 |
-
float delta_val_neg_exp = expf(-delta_val);
|
448 |
-
ddelta_vals[i] = delta_val <= 20.f
|
449 |
-
? ddelta_vals[i] / (1.f + delta_val_neg_exp)
|
450 |
-
: ddelta_vals[i];
|
451 |
-
}
|
452 |
-
}
|
453 |
-
for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; }
|
454 |
-
|
455 |
-
input_t *du = reinterpret_cast<input_t *>(params.du_ptr) + batch_id * params.du_batch_stride
|
456 |
-
+ dim_id * params.du_d_stride + chunk * kChunkSize;
|
457 |
-
input_t *ddelta = reinterpret_cast<input_t *>(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride
|
458 |
-
+ dim_id * params.ddelta_d_stride + chunk * kChunkSize;
|
459 |
-
__syncthreads();
|
460 |
-
store_output<Ktraits>(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize);
|
461 |
-
__syncthreads();
|
462 |
-
store_output<Ktraits>(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize);
|
463 |
-
|
464 |
-
Bvar -= kChunkSize * (!kIsComplex ? 1 : 2);
|
465 |
-
Cvar -= kChunkSize * (!kIsComplex ? 1 : 2);
|
466 |
-
}
|
467 |
-
if (params.dD_ptr != nullptr) {
|
468 |
-
dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val);
|
469 |
-
if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); }
|
470 |
-
}
|
471 |
-
if (params.ddelta_bias_ptr != nullptr) {
|
472 |
-
__syncthreads();
|
473 |
-
ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val);
|
474 |
-
if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); }
|
475 |
-
}
|
476 |
-
for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
|
477 |
-
gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]);
|
478 |
-
weight_t dBC_val;
|
479 |
-
if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx]; }
|
480 |
-
if constexpr (!kIsVariableB) {
|
481 |
-
gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride]),
|
482 |
-
!kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride]) : dBC_val);
|
483 |
-
}
|
484 |
-
if constexpr (!kIsVariableC) {
|
485 |
-
gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride]),
|
486 |
-
!kIsVariableB ? dBC_val * conj(B[state_idx * params.B_dstate_stride]) : dBC_val);
|
487 |
-
}
|
488 |
-
}
|
489 |
-
}
|
490 |
-
|
491 |
-
template<int kNThreads, int kNItems, typename input_t, typename weight_t>
|
492 |
-
void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) {
|
493 |
-
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
|
494 |
-
BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
|
495 |
-
BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
|
496 |
-
BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] {
|
497 |
-
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
|
498 |
-
using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
|
499 |
-
// using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, true, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
|
500 |
-
// TODO: check this
|
501 |
-
constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t);
|
502 |
-
// printf("smem_size = %d\n", kSmemSize);
|
503 |
-
dim3 grid(params.batch, params.dim);
|
504 |
-
auto kernel = &selective_scan_bwd_kernel<Ktraits>;
|
505 |
-
if (kSmemSize >= 48 * 1024) {
|
506 |
-
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
507 |
-
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
508 |
-
}
|
509 |
-
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
510 |
-
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
511 |
-
});
|
512 |
-
});
|
513 |
-
});
|
514 |
-
});
|
515 |
-
});
|
516 |
-
}
|
517 |
-
|
518 |
-
template<typename input_t, typename weight_t>
|
519 |
-
void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) {
|
520 |
-
if (params.seqlen <= 128) {
|
521 |
-
selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream);
|
522 |
-
} else if (params.seqlen <= 256) {
|
523 |
-
selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream);
|
524 |
-
} else if (params.seqlen <= 512) {
|
525 |
-
selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream);
|
526 |
-
} else if (params.seqlen <= 1024) {
|
527 |
-
selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
|
528 |
-
} else {
|
529 |
-
selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
|
530 |
-
}
|
531 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/csrc/selective_scan/selective_scan_common.h
DELETED
@@ -1,221 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2023, Tri Dao.
|
3 |
-
******************************************************************************/
|
4 |
-
|
5 |
-
#pragma once
|
6 |
-
|
7 |
-
#include <cuda_bf16.h>
|
8 |
-
#include <cuda_fp16.h>
|
9 |
-
#include <c10/util/complex.h> // For scalar_value_type
|
10 |
-
|
11 |
-
#define MAX_DSTATE 256
|
12 |
-
|
13 |
-
using complex_t = c10::complex<float>;
|
14 |
-
|
15 |
-
inline __device__ float2 operator+(const float2 & a, const float2 & b){
|
16 |
-
return {a.x + b.x, a.y + b.y};
|
17 |
-
}
|
18 |
-
|
19 |
-
inline __device__ float3 operator+(const float3 &a, const float3 &b) {
|
20 |
-
return {a.x + b.x, a.y + b.y, a.z + b.z};
|
21 |
-
}
|
22 |
-
|
23 |
-
inline __device__ float4 operator+(const float4 & a, const float4 & b){
|
24 |
-
return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
|
25 |
-
}
|
26 |
-
|
27 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
28 |
-
|
29 |
-
template<int BYTES> struct BytesToType {};
|
30 |
-
|
31 |
-
template<> struct BytesToType<16> {
|
32 |
-
using Type = uint4;
|
33 |
-
static_assert(sizeof(Type) == 16);
|
34 |
-
};
|
35 |
-
|
36 |
-
template<> struct BytesToType<8> {
|
37 |
-
using Type = uint64_t;
|
38 |
-
static_assert(sizeof(Type) == 8);
|
39 |
-
};
|
40 |
-
|
41 |
-
template<> struct BytesToType<4> {
|
42 |
-
using Type = uint32_t;
|
43 |
-
static_assert(sizeof(Type) == 4);
|
44 |
-
};
|
45 |
-
|
46 |
-
template<> struct BytesToType<2> {
|
47 |
-
using Type = uint16_t;
|
48 |
-
static_assert(sizeof(Type) == 2);
|
49 |
-
};
|
50 |
-
|
51 |
-
template<> struct BytesToType<1> {
|
52 |
-
using Type = uint8_t;
|
53 |
-
static_assert(sizeof(Type) == 1);
|
54 |
-
};
|
55 |
-
|
56 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
57 |
-
|
58 |
-
template<typename scalar_t, int N>
|
59 |
-
struct Converter{
|
60 |
-
static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
|
61 |
-
#pragma unroll
|
62 |
-
for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
|
63 |
-
}
|
64 |
-
};
|
65 |
-
|
66 |
-
template<int N>
|
67 |
-
struct Converter<at::Half, N>{
|
68 |
-
static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
|
69 |
-
static_assert(N % 2 == 0);
|
70 |
-
auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
|
71 |
-
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
|
72 |
-
#pragma unroll
|
73 |
-
for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
|
74 |
-
}
|
75 |
-
};
|
76 |
-
|
77 |
-
#if __CUDA_ARCH__ >= 800
|
78 |
-
template<int N>
|
79 |
-
struct Converter<at::BFloat16, N>{
|
80 |
-
static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
|
81 |
-
static_assert(N % 2 == 0);
|
82 |
-
auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
|
83 |
-
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
|
84 |
-
#pragma unroll
|
85 |
-
for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
|
86 |
-
}
|
87 |
-
};
|
88 |
-
#endif
|
89 |
-
|
90 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
91 |
-
|
92 |
-
// From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp
|
93 |
-
// and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696
|
94 |
-
__device__ __forceinline__ complex_t cexp2f(complex_t z) {
|
95 |
-
float t = exp2f(z.real_);
|
96 |
-
float c, s;
|
97 |
-
sincosf(z.imag_, &s, &c);
|
98 |
-
return complex_t(c * t, s * t);
|
99 |
-
}
|
100 |
-
|
101 |
-
__device__ __forceinline__ complex_t cexpf(complex_t z) {
|
102 |
-
float t = expf(z.real_);
|
103 |
-
float c, s;
|
104 |
-
sincosf(z.imag_, &s, &c);
|
105 |
-
return complex_t(c * t, s * t);
|
106 |
-
}
|
107 |
-
|
108 |
-
template<typename scalar_t> struct SSMScanOp;
|
109 |
-
|
110 |
-
template<>
|
111 |
-
struct SSMScanOp<float> {
|
112 |
-
__device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
|
113 |
-
return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
|
114 |
-
}
|
115 |
-
};
|
116 |
-
|
117 |
-
template<>
|
118 |
-
struct SSMScanOp<complex_t> {
|
119 |
-
__device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const {
|
120 |
-
complex_t a0 = complex_t(ab0.x, ab0.y);
|
121 |
-
complex_t b0 = complex_t(ab0.z, ab0.w);
|
122 |
-
complex_t a1 = complex_t(ab1.x, ab1.y);
|
123 |
-
complex_t b1 = complex_t(ab1.z, ab1.w);
|
124 |
-
complex_t out_a = a1 * a0;
|
125 |
-
complex_t out_b = a1 * b0 + b1;
|
126 |
-
return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_);
|
127 |
-
}
|
128 |
-
};
|
129 |
-
|
130 |
-
// A stateful callback functor that maintains a running prefix to be applied
|
131 |
-
// during consecutive scan operations.
|
132 |
-
template <typename scalar_t> struct SSMScanPrefixCallbackOp {
|
133 |
-
using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
|
134 |
-
scan_t running_prefix;
|
135 |
-
// Constructor
|
136 |
-
__device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
|
137 |
-
// Callback operator to be entered by the first warp of threads in the block.
|
138 |
-
// Thread-0 is responsible for returning a value for seeding the block-wide scan.
|
139 |
-
__device__ scan_t operator()(scan_t block_aggregate) {
|
140 |
-
scan_t old_prefix = running_prefix;
|
141 |
-
running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
|
142 |
-
return old_prefix;
|
143 |
-
}
|
144 |
-
};
|
145 |
-
|
146 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
147 |
-
|
148 |
-
template<typename Ktraits>
|
149 |
-
inline __device__ void load_input(typename Ktraits::input_t *u,
|
150 |
-
typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
|
151 |
-
typename Ktraits::BlockLoadT::TempStorage &smem_load,
|
152 |
-
int seqlen) {
|
153 |
-
if constexpr (Ktraits::kIsEvenLen) {
|
154 |
-
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
|
155 |
-
using vec_t = typename Ktraits::vec_t;
|
156 |
-
Ktraits::BlockLoadVecT(smem_load_vec).Load(
|
157 |
-
reinterpret_cast<vec_t*>(u),
|
158 |
-
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
|
159 |
-
);
|
160 |
-
} else {
|
161 |
-
Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
|
162 |
-
}
|
163 |
-
}
|
164 |
-
|
165 |
-
template<typename Ktraits>
|
166 |
-
inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
|
167 |
-
typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
|
168 |
-
typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
|
169 |
-
int seqlen) {
|
170 |
-
constexpr int kNItems = Ktraits::kNItems;
|
171 |
-
if constexpr (!Ktraits::kIsComplex) {
|
172 |
-
typename Ktraits::input_t B_vals_load[kNItems];
|
173 |
-
if constexpr (Ktraits::kIsEvenLen) {
|
174 |
-
auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
|
175 |
-
using vec_t = typename Ktraits::vec_t;
|
176 |
-
Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
|
177 |
-
reinterpret_cast<vec_t*>(Bvar),
|
178 |
-
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
|
179 |
-
);
|
180 |
-
} else {
|
181 |
-
Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
|
182 |
-
}
|
183 |
-
// #pragma unroll
|
184 |
-
// for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
|
185 |
-
Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
|
186 |
-
} else {
|
187 |
-
typename Ktraits::input_t B_vals_load[kNItems * 2];
|
188 |
-
if constexpr (Ktraits::kIsEvenLen) {
|
189 |
-
auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
|
190 |
-
using vec_t = typename Ktraits::vec_t;
|
191 |
-
Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
|
192 |
-
reinterpret_cast<vec_t*>(Bvar),
|
193 |
-
reinterpret_cast<vec_t(&)[Ktraits::kNLoads * 2]>(B_vals_load)
|
194 |
-
);
|
195 |
-
} else {
|
196 |
-
Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
|
197 |
-
}
|
198 |
-
#pragma unroll
|
199 |
-
for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); }
|
200 |
-
}
|
201 |
-
}
|
202 |
-
|
203 |
-
template<typename Ktraits>
|
204 |
-
inline __device__ void store_output(typename Ktraits::input_t *out,
|
205 |
-
const float (&out_vals)[Ktraits::kNItems],
|
206 |
-
typename Ktraits::BlockStoreT::TempStorage &smem_store,
|
207 |
-
int seqlen) {
|
208 |
-
typename Ktraits::input_t write_vals[Ktraits::kNItems];
|
209 |
-
#pragma unroll
|
210 |
-
for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
|
211 |
-
if constexpr (Ktraits::kIsEvenLen) {
|
212 |
-
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
|
213 |
-
using vec_t = typename Ktraits::vec_t;
|
214 |
-
Ktraits::BlockStoreVecT(smem_store_vec).Store(
|
215 |
-
reinterpret_cast<vec_t*>(out),
|
216 |
-
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
|
217 |
-
);
|
218 |
-
} else {
|
219 |
-
Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
|
220 |
-
}
|
221 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/csrc/selective_scan/selective_scan_fwd_bf16.cu
DELETED
@@ -1,10 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2023, Tri Dao.
|
3 |
-
******************************************************************************/
|
4 |
-
|
5 |
-
// Split into multiple files to compile in paralell
|
6 |
-
|
7 |
-
#include "selective_scan_fwd_kernel.cuh"
|
8 |
-
|
9 |
-
template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
10 |
-
template void selective_scan_fwd_cuda<at::BFloat16, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/csrc/selective_scan/selective_scan_fwd_fp16.cu
DELETED
@@ -1,10 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2023, Tri Dao.
|
3 |
-
******************************************************************************/
|
4 |
-
|
5 |
-
// Split into multiple files to compile in paralell
|
6 |
-
|
7 |
-
#include "selective_scan_fwd_kernel.cuh"
|
8 |
-
|
9 |
-
template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
10 |
-
template void selective_scan_fwd_cuda<at::Half, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/csrc/selective_scan/selective_scan_fwd_fp32.cu
DELETED
@@ -1,10 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2023, Tri Dao.
|
3 |
-
******************************************************************************/
|
4 |
-
|
5 |
-
// Split into multiple files to compile in paralell
|
6 |
-
|
7 |
-
#include "selective_scan_fwd_kernel.cuh"
|
8 |
-
|
9 |
-
template void selective_scan_fwd_cuda<float, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
10 |
-
template void selective_scan_fwd_cuda<float, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/csrc/selective_scan/selective_scan_fwd_kernel.cuh
DELETED
@@ -1,345 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2023, Tri Dao.
|
3 |
-
******************************************************************************/
|
4 |
-
|
5 |
-
#pragma once
|
6 |
-
|
7 |
-
#include <c10/util/BFloat16.h>
|
8 |
-
#include <c10/util/Half.h>
|
9 |
-
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
10 |
-
|
11 |
-
#include <cub/block/block_load.cuh>
|
12 |
-
#include <cub/block/block_store.cuh>
|
13 |
-
#include <cub/block/block_scan.cuh>
|
14 |
-
|
15 |
-
#include "selective_scan.h"
|
16 |
-
#include "selective_scan_common.h"
|
17 |
-
#include "static_switch.h"
|
18 |
-
|
19 |
-
template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
|
20 |
-
bool kIsVariableB_, bool kIsVariableC_,
|
21 |
-
bool kHasZ_, typename input_t_, typename weight_t_>
|
22 |
-
struct Selective_Scan_fwd_kernel_traits {
|
23 |
-
static_assert(kNItems_ % 4 == 0);
|
24 |
-
using input_t = input_t_;
|
25 |
-
using weight_t = weight_t_;
|
26 |
-
static constexpr int kNThreads = kNThreads_;
|
27 |
-
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
|
28 |
-
static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
|
29 |
-
static constexpr int kNItems = kNItems_;
|
30 |
-
static constexpr int kNRows = kNRows_;
|
31 |
-
static constexpr int kNBytes = sizeof(input_t);
|
32 |
-
static_assert(kNBytes == 2 || kNBytes == 4);
|
33 |
-
static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
|
34 |
-
static_assert(kNItems % kNElts == 0);
|
35 |
-
static constexpr int kNLoads = kNItems / kNElts;
|
36 |
-
static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;
|
37 |
-
static constexpr bool kIsEvenLen = kIsEvenLen_;
|
38 |
-
static constexpr bool kIsVariableB = kIsVariableB_;
|
39 |
-
static constexpr bool kIsVariableC = kIsVariableC_;
|
40 |
-
static constexpr bool kHasZ = kHasZ_;
|
41 |
-
|
42 |
-
static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
|
43 |
-
|
44 |
-
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
45 |
-
using scan_t = std::conditional_t<!kIsComplex, float2, float4>;
|
46 |
-
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
47 |
-
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
|
48 |
-
!kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
|
49 |
-
using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
50 |
-
using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2,
|
51 |
-
!kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
|
52 |
-
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
53 |
-
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
|
54 |
-
!kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
|
55 |
-
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
|
56 |
-
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
|
57 |
-
using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
|
58 |
-
static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
|
59 |
-
sizeof(typename BlockLoadVecT::TempStorage),
|
60 |
-
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
|
61 |
-
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
|
62 |
-
sizeof(typename BlockStoreT::TempStorage),
|
63 |
-
sizeof(typename BlockStoreVecT::TempStorage)});
|
64 |
-
static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
|
65 |
-
};
|
66 |
-
|
67 |
-
template<typename Ktraits>
|
68 |
-
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
|
69 |
-
void selective_scan_fwd_kernel(SSMParamsBase params) {
|
70 |
-
constexpr bool kIsComplex = Ktraits::kIsComplex;
|
71 |
-
constexpr bool kIsVariableB = Ktraits::kIsVariableB;
|
72 |
-
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
|
73 |
-
constexpr bool kHasZ = Ktraits::kHasZ;
|
74 |
-
constexpr int kNThreads = Ktraits::kNThreads;
|
75 |
-
constexpr int kNItems = Ktraits::kNItems;
|
76 |
-
constexpr int kNRows = Ktraits::kNRows;
|
77 |
-
constexpr bool kDirectIO = Ktraits::kDirectIO;
|
78 |
-
using input_t = typename Ktraits::input_t;
|
79 |
-
using weight_t = typename Ktraits::weight_t;
|
80 |
-
using scan_t = typename Ktraits::scan_t;
|
81 |
-
|
82 |
-
// Shared memory.
|
83 |
-
extern __shared__ char smem_[];
|
84 |
-
// cast to lvalue reference of expected type
|
85 |
-
// char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
|
86 |
-
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
|
87 |
-
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
|
88 |
-
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
89 |
-
auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
|
90 |
-
auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
|
91 |
-
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
92 |
-
auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
|
93 |
-
// weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ + smem_loadstorescan_size);
|
94 |
-
// weight_t *smem_bc = reinterpret_cast<weight_t *>(smem_a + MAX_DSTATE);
|
95 |
-
scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);
|
96 |
-
|
97 |
-
const int batch_id = blockIdx.x;
|
98 |
-
const int dim_id = blockIdx.y;
|
99 |
-
const int group_id = dim_id / (params.dim_ngroups_ratio);
|
100 |
-
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
|
101 |
-
+ dim_id * kNRows * params.u_d_stride;
|
102 |
-
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
|
103 |
-
+ dim_id * kNRows * params.delta_d_stride;
|
104 |
-
weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride;
|
105 |
-
weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride;
|
106 |
-
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
|
107 |
-
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
|
108 |
-
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
|
109 |
-
scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;
|
110 |
-
|
111 |
-
float D_val[kNRows] = {0};
|
112 |
-
if (params.D_ptr != nullptr) {
|
113 |
-
#pragma unroll
|
114 |
-
for (int r = 0; r < kNRows; ++r) {
|
115 |
-
D_val[r] = reinterpret_cast<float *>(params.D_ptr)[dim_id * kNRows + r];
|
116 |
-
}
|
117 |
-
}
|
118 |
-
float delta_bias[kNRows] = {0};
|
119 |
-
if (params.delta_bias_ptr != nullptr) {
|
120 |
-
#pragma unroll
|
121 |
-
for (int r = 0; r < kNRows; ++r) {
|
122 |
-
delta_bias[r] = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id * kNRows + r];
|
123 |
-
}
|
124 |
-
}
|
125 |
-
|
126 |
-
// for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
|
127 |
-
// smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
|
128 |
-
// smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
|
129 |
-
// }
|
130 |
-
|
131 |
-
constexpr int kChunkSize = kNThreads * kNItems;
|
132 |
-
for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
|
133 |
-
input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
|
134 |
-
__syncthreads();
|
135 |
-
#pragma unroll
|
136 |
-
for (int r = 0; r < kNRows; ++r) {
|
137 |
-
if constexpr (!kDirectIO) {
|
138 |
-
if (r > 0) { __syncthreads(); }
|
139 |
-
}
|
140 |
-
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize);
|
141 |
-
if constexpr (!kDirectIO) { __syncthreads(); }
|
142 |
-
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
|
143 |
-
}
|
144 |
-
u += kChunkSize;
|
145 |
-
delta += kChunkSize;
|
146 |
-
|
147 |
-
float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
|
148 |
-
#pragma unroll
|
149 |
-
for (int r = 0; r < kNRows; ++r) {
|
150 |
-
#pragma unroll
|
151 |
-
for (int i = 0; i < kNItems; ++i) {
|
152 |
-
float u_val = float(u_vals[r][i]);
|
153 |
-
delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];
|
154 |
-
if (params.delta_softplus) {
|
155 |
-
delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];
|
156 |
-
}
|
157 |
-
delta_u_vals[r][i] = delta_vals[r][i] * u_val;
|
158 |
-
out_vals[r][i] = D_val[r] * u_val;
|
159 |
-
}
|
160 |
-
}
|
161 |
-
|
162 |
-
__syncthreads();
|
163 |
-
for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
|
164 |
-
weight_t A_val[kNRows];
|
165 |
-
#pragma unroll
|
166 |
-
for (int r = 0; r < kNRows; ++r) {
|
167 |
-
A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
|
168 |
-
// Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
|
169 |
-
constexpr float kLog2e = M_LOG2E;
|
170 |
-
if constexpr (!kIsComplex) {
|
171 |
-
A_val[r] *= kLog2e;
|
172 |
-
} else {
|
173 |
-
A_val[r].real_ *= kLog2e;
|
174 |
-
}
|
175 |
-
}
|
176 |
-
// This variable holds B * C if both B and C are constant across seqlen. If only B varies
|
177 |
-
// across seqlen, this holds C. If only C varies across seqlen, this holds B.
|
178 |
-
// If both B and C vary, this is unused.
|
179 |
-
weight_t BC_val[kNRows];
|
180 |
-
weight_t B_vals[kNItems], C_vals[kNItems];
|
181 |
-
if constexpr (kIsVariableB) {
|
182 |
-
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
|
183 |
-
smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
|
184 |
-
if constexpr (!kIsVariableC) {
|
185 |
-
#pragma unroll
|
186 |
-
for (int r = 0; r < kNRows; ++r) {
|
187 |
-
BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
|
188 |
-
}
|
189 |
-
}
|
190 |
-
}
|
191 |
-
if constexpr (kIsVariableC) {
|
192 |
-
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
|
193 |
-
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
|
194 |
-
smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
|
195 |
-
if constexpr (!kIsVariableB) {
|
196 |
-
#pragma unroll
|
197 |
-
for (int r = 0; r < kNRows; ++r) {
|
198 |
-
BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride];
|
199 |
-
}
|
200 |
-
}
|
201 |
-
}
|
202 |
-
if constexpr (!kIsVariableB && !kIsVariableC) {
|
203 |
-
#pragma unroll
|
204 |
-
for (int r = 0; r < kNRows; ++r) {
|
205 |
-
BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
|
206 |
-
}
|
207 |
-
}
|
208 |
-
|
209 |
-
#pragma unroll
|
210 |
-
for (int r = 0; r < kNRows; ++r) {
|
211 |
-
if (r > 0) { __syncthreads(); } // Scan could be using the same smem
|
212 |
-
scan_t thread_data[kNItems];
|
213 |
-
#pragma unroll
|
214 |
-
for (int i = 0; i < kNItems; ++i) {
|
215 |
-
if constexpr (!kIsComplex) {
|
216 |
-
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
|
217 |
-
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
|
218 |
-
if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
|
219 |
-
if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
|
220 |
-
thread_data[i] = make_float2(1.f, 0.f);
|
221 |
-
}
|
222 |
-
}
|
223 |
-
} else {
|
224 |
-
// Pytorch's implementation of complex exp (which calls thrust) is very slow
|
225 |
-
complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]);
|
226 |
-
weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i];
|
227 |
-
thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
|
228 |
-
if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
|
229 |
-
if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
|
230 |
-
thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f);
|
231 |
-
}
|
232 |
-
}
|
233 |
-
}
|
234 |
-
}
|
235 |
-
// Initialize running total
|
236 |
-
scan_t running_prefix;
|
237 |
-
if constexpr (!kIsComplex) {
|
238 |
-
// If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read
|
239 |
-
running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f);
|
240 |
-
// running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
|
241 |
-
} else {
|
242 |
-
running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f);
|
243 |
-
// running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
|
244 |
-
}
|
245 |
-
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
|
246 |
-
Ktraits::BlockScanT(smem_scan).InclusiveScan(
|
247 |
-
thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
|
248 |
-
);
|
249 |
-
// There's a syncthreads in the scan op, so we don't need to sync here.
|
250 |
-
// Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
|
251 |
-
if (threadIdx.x == 0) {
|
252 |
-
smem_running_prefix[state_idx] = prefix_op.running_prefix;
|
253 |
-
x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix;
|
254 |
-
}
|
255 |
-
#pragma unroll
|
256 |
-
for (int i = 0; i < kNItems; ++i) {
|
257 |
-
const weight_t C_val = !kIsVariableC
|
258 |
-
? BC_val[r]
|
259 |
-
: (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]);
|
260 |
-
if constexpr (!kIsComplex) {
|
261 |
-
out_vals[r][i] += thread_data[i].y * C_val;
|
262 |
-
} else {
|
263 |
-
out_vals[r][i] += (complex_t(thread_data[i].z, thread_data[i].w) * C_val).real_ * 2;
|
264 |
-
}
|
265 |
-
}
|
266 |
-
}
|
267 |
-
}
|
268 |
-
|
269 |
-
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
270 |
-
+ dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
|
271 |
-
__syncthreads();
|
272 |
-
#pragma unroll
|
273 |
-
for (int r = 0; r < kNRows; ++r) {
|
274 |
-
if constexpr (!kDirectIO) {
|
275 |
-
if (r > 0) { __syncthreads(); }
|
276 |
-
}
|
277 |
-
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
|
278 |
-
}
|
279 |
-
|
280 |
-
if constexpr (kHasZ) {
|
281 |
-
input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
|
282 |
-
+ dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
|
283 |
-
input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
|
284 |
-
+ dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
|
285 |
-
#pragma unroll
|
286 |
-
for (int r = 0; r < kNRows; ++r) {
|
287 |
-
input_t z_vals[kNItems];
|
288 |
-
__syncthreads();
|
289 |
-
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
|
290 |
-
#pragma unroll
|
291 |
-
for (int i = 0; i < kNItems; ++i) {
|
292 |
-
float z_val = z_vals[i];
|
293 |
-
out_vals[r][i] *= z_val / (1 + expf(-z_val));
|
294 |
-
}
|
295 |
-
__syncthreads();
|
296 |
-
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
|
297 |
-
}
|
298 |
-
}
|
299 |
-
|
300 |
-
Bvar += kChunkSize * (!kIsComplex ? 1 : 2);
|
301 |
-
Cvar += kChunkSize * (!kIsComplex ? 1 : 2);
|
302 |
-
}
|
303 |
-
}
|
304 |
-
|
305 |
-
template<int kNThreads, int kNItems, typename input_t, typename weight_t>
|
306 |
-
void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) {
|
307 |
-
// Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
|
308 |
-
// processing 1 row.
|
309 |
-
constexpr int kNRows = 1;
|
310 |
-
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
|
311 |
-
BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
|
312 |
-
BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
|
313 |
-
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
|
314 |
-
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, input_t, weight_t>;
|
315 |
-
// constexpr int kSmemSize = Ktraits::kSmemSize;
|
316 |
-
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
|
317 |
-
// printf("smem_size = %d\n", kSmemSize);
|
318 |
-
dim3 grid(params.batch, params.dim / kNRows);
|
319 |
-
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
320 |
-
if (kSmemSize >= 48 * 1024) {
|
321 |
-
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
322 |
-
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
323 |
-
}
|
324 |
-
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
325 |
-
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
326 |
-
});
|
327 |
-
});
|
328 |
-
});
|
329 |
-
});
|
330 |
-
}
|
331 |
-
|
332 |
-
template<typename input_t, typename weight_t>
|
333 |
-
void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) {
|
334 |
-
if (params.seqlen <= 128) {
|
335 |
-
selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
|
336 |
-
} else if (params.seqlen <= 256) {
|
337 |
-
selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
|
338 |
-
} else if (params.seqlen <= 512) {
|
339 |
-
selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
|
340 |
-
} else if (params.seqlen <= 1024) {
|
341 |
-
selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
|
342 |
-
} else {
|
343 |
-
selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
|
344 |
-
}
|
345 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/csrc/selective_scan/static_switch.h
DELETED
@@ -1,25 +0,0 @@
|
|
1 |
-
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
2 |
-
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
3 |
-
|
4 |
-
#pragma once
|
5 |
-
|
6 |
-
/// @param COND - a boolean expression to switch by
|
7 |
-
/// @param CONST_NAME - a name given for the constexpr bool variable.
|
8 |
-
/// @param ... - code to execute for true and false
|
9 |
-
///
|
10 |
-
/// Usage:
|
11 |
-
/// ```
|
12 |
-
/// BOOL_SWITCH(flag, BoolConst, [&] {
|
13 |
-
/// some_function<BoolConst>(...);
|
14 |
-
/// });
|
15 |
-
/// ```
|
16 |
-
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
|
17 |
-
[&] { \
|
18 |
-
if (COND) { \
|
19 |
-
constexpr bool CONST_NAME = true; \
|
20 |
-
return __VA_ARGS__(); \
|
21 |
-
} else { \
|
22 |
-
constexpr bool CONST_NAME = false; \
|
23 |
-
return __VA_ARGS__(); \
|
24 |
-
} \
|
25 |
-
}()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/csrc/selective_scan/uninitialized_copy.cuh
DELETED
@@ -1,69 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved.
|
3 |
-
*
|
4 |
-
* Redistribution and use in source and binary forms, with or without
|
5 |
-
* modification, are permitted provided that the following conditions are met:
|
6 |
-
* * Redistributions of source code must retain the above copyright
|
7 |
-
* notice, this list of conditions and the following disclaimer.
|
8 |
-
* * Redistributions in binary form must reproduce the above copyright
|
9 |
-
* notice, this list of conditions and the following disclaimer in the
|
10 |
-
* documentation and/or other materials provided with the distribution.
|
11 |
-
* * Neither the name of the NVIDIA CORPORATION nor the
|
12 |
-
* names of its contributors may be used to endorse or promote products
|
13 |
-
* derived from this software without specific prior written permission.
|
14 |
-
*
|
15 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
16 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
17 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
18 |
-
* ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
19 |
-
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
20 |
-
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
21 |
-
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
22 |
-
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
23 |
-
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
24 |
-
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
25 |
-
*
|
26 |
-
******************************************************************************/
|
27 |
-
|
28 |
-
#pragma once
|
29 |
-
|
30 |
-
#include <cub/config.cuh>
|
31 |
-
|
32 |
-
#include <cuda/std/type_traits>
|
33 |
-
|
34 |
-
|
35 |
-
namespace detail
|
36 |
-
{
|
37 |
-
|
38 |
-
#if defined(_NVHPC_CUDA)
|
39 |
-
template <typename T, typename U>
|
40 |
-
__host__ __device__ void uninitialized_copy(T *ptr, U &&val)
|
41 |
-
{
|
42 |
-
// NVBug 3384810
|
43 |
-
new (ptr) T(::cuda::std::forward<U>(val));
|
44 |
-
}
|
45 |
-
#else
|
46 |
-
template <typename T,
|
47 |
-
typename U,
|
48 |
-
typename ::cuda::std::enable_if<
|
49 |
-
::cuda::std::is_trivially_copyable<T>::value,
|
50 |
-
int
|
51 |
-
>::type = 0>
|
52 |
-
__host__ __device__ void uninitialized_copy(T *ptr, U &&val)
|
53 |
-
{
|
54 |
-
*ptr = ::cuda::std::forward<U>(val);
|
55 |
-
}
|
56 |
-
|
57 |
-
template <typename T,
|
58 |
-
typename U,
|
59 |
-
typename ::cuda::std::enable_if<
|
60 |
-
!::cuda::std::is_trivially_copyable<T>::value,
|
61 |
-
int
|
62 |
-
>::type = 0>
|
63 |
-
__host__ __device__ void uninitialized_copy(T *ptr, U &&val)
|
64 |
-
{
|
65 |
-
new (ptr) T(::cuda::std::forward<U>(val));
|
66 |
-
}
|
67 |
-
#endif
|
68 |
-
|
69 |
-
} // namespace detail
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/evals/lm_harness_eval.py
DELETED
@@ -1,39 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
|
3 |
-
import transformers
|
4 |
-
from transformers import AutoTokenizer
|
5 |
-
|
6 |
-
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
|
7 |
-
|
8 |
-
from lm_eval.api.model import LM
|
9 |
-
from lm_eval.models.huggingface import HFLM
|
10 |
-
from lm_eval.api.registry import register_model
|
11 |
-
from lm_eval.__main__ import cli_evaluate
|
12 |
-
|
13 |
-
|
14 |
-
@register_model("mamba")
|
15 |
-
class MambaEvalWrapper(HFLM):
|
16 |
-
|
17 |
-
AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
|
18 |
-
|
19 |
-
def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda",
|
20 |
-
dtype=torch.float16):
|
21 |
-
LM.__init__(self)
|
22 |
-
self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype)
|
23 |
-
self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
|
24 |
-
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
|
25 |
-
self.vocab_size = self.tokenizer.vocab_size
|
26 |
-
self._batch_size = batch_size if batch_size is None else 64
|
27 |
-
self._max_length = max_length
|
28 |
-
self._device = torch.device(device)
|
29 |
-
|
30 |
-
@property
|
31 |
-
def batch_size(self):
|
32 |
-
return self._batch_size
|
33 |
-
|
34 |
-
def _model_generate(self, context, max_length, stop, **generation_kwargs):
|
35 |
-
raise NotImplementedError()
|
36 |
-
|
37 |
-
|
38 |
-
if __name__ == "__main__":
|
39 |
-
cli_evaluate()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/mamba_ssm/__init__.py
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
__version__ = "1.0.1"
|
2 |
-
|
3 |
-
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn
|
4 |
-
from mamba_ssm.modules.mamba_simple import Mamba
|
5 |
-
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
|
|
|
|
|
|
|
|
|
|
|
|
mamba/mamba_ssm/models/__init__.py
DELETED
File without changes
|
mamba/mamba_ssm/models/mixer_seq_simple.py
DELETED
@@ -1,233 +0,0 @@
|
|
1 |
-
# Copyright (c) 2023, Albert Gu, Tri Dao.
|
2 |
-
|
3 |
-
import math
|
4 |
-
from functools import partial
|
5 |
-
|
6 |
-
from collections import namedtuple
|
7 |
-
|
8 |
-
import torch
|
9 |
-
import torch.nn as nn
|
10 |
-
|
11 |
-
from mamba_ssm.modules.mamba_simple import Mamba, Block
|
12 |
-
from mamba_ssm.utils.generation import GenerationMixin
|
13 |
-
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
|
14 |
-
|
15 |
-
try:
|
16 |
-
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
|
17 |
-
except ImportError:
|
18 |
-
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
19 |
-
|
20 |
-
|
21 |
-
def create_block(
|
22 |
-
d_model,
|
23 |
-
ssm_cfg=None,
|
24 |
-
norm_epsilon=1e-5,
|
25 |
-
rms_norm=False,
|
26 |
-
residual_in_fp32=False,
|
27 |
-
fused_add_norm=False,
|
28 |
-
layer_idx=None,
|
29 |
-
device=None,
|
30 |
-
dtype=None,
|
31 |
-
):
|
32 |
-
if ssm_cfg is None:
|
33 |
-
ssm_cfg = {}
|
34 |
-
factory_kwargs = {"device": device, "dtype": dtype}
|
35 |
-
mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
|
36 |
-
norm_cls = partial(
|
37 |
-
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
|
38 |
-
)
|
39 |
-
block = Block(
|
40 |
-
d_model,
|
41 |
-
mixer_cls,
|
42 |
-
norm_cls=norm_cls,
|
43 |
-
fused_add_norm=fused_add_norm,
|
44 |
-
residual_in_fp32=residual_in_fp32,
|
45 |
-
)
|
46 |
-
block.layer_idx = layer_idx
|
47 |
-
return block
|
48 |
-
|
49 |
-
|
50 |
-
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
|
51 |
-
def _init_weights(
|
52 |
-
module,
|
53 |
-
n_layer,
|
54 |
-
initializer_range=0.02, # Now only used for embedding layer.
|
55 |
-
rescale_prenorm_residual=True,
|
56 |
-
n_residuals_per_layer=1, # Change to 2 if we have MLP
|
57 |
-
):
|
58 |
-
if isinstance(module, nn.Linear):
|
59 |
-
if module.bias is not None:
|
60 |
-
if not getattr(module.bias, "_no_reinit", False):
|
61 |
-
nn.init.zeros_(module.bias)
|
62 |
-
elif isinstance(module, nn.Embedding):
|
63 |
-
nn.init.normal_(module.weight, std=initializer_range)
|
64 |
-
|
65 |
-
if rescale_prenorm_residual:
|
66 |
-
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
67 |
-
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
68 |
-
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
69 |
-
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
70 |
-
#
|
71 |
-
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
72 |
-
for name, p in module.named_parameters():
|
73 |
-
if name in ["out_proj.weight", "fc2.weight"]:
|
74 |
-
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
75 |
-
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
|
76 |
-
# We need to reinit p since this code could be called multiple times
|
77 |
-
# Having just p *= scale would repeatedly scale it down
|
78 |
-
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
|
79 |
-
with torch.no_grad():
|
80 |
-
p /= math.sqrt(n_residuals_per_layer * n_layer)
|
81 |
-
|
82 |
-
|
83 |
-
class MixerModel(nn.Module):
|
84 |
-
def __init__(
|
85 |
-
self,
|
86 |
-
d_model: int,
|
87 |
-
n_layer: int,
|
88 |
-
vocab_size: int,
|
89 |
-
ssm_cfg=None,
|
90 |
-
norm_epsilon: float = 1e-5,
|
91 |
-
rms_norm: bool = False,
|
92 |
-
initializer_cfg=None,
|
93 |
-
fused_add_norm=False,
|
94 |
-
residual_in_fp32=False,
|
95 |
-
device=None,
|
96 |
-
dtype=None,
|
97 |
-
) -> None:
|
98 |
-
factory_kwargs = {"device": device, "dtype": dtype}
|
99 |
-
super().__init__()
|
100 |
-
self.residual_in_fp32 = residual_in_fp32
|
101 |
-
|
102 |
-
self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
|
103 |
-
|
104 |
-
# We change the order of residual and layer norm:
|
105 |
-
# Instead of LN -> Attn / MLP -> Add, we do:
|
106 |
-
# Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
|
107 |
-
# the main branch (output of MLP / Mixer). The model definition is unchanged.
|
108 |
-
# This is for performance reason: we can fuse add + layer_norm.
|
109 |
-
self.fused_add_norm = fused_add_norm
|
110 |
-
if self.fused_add_norm:
|
111 |
-
if layer_norm_fn is None or rms_norm_fn is None:
|
112 |
-
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
|
113 |
-
|
114 |
-
self.layers = nn.ModuleList(
|
115 |
-
[
|
116 |
-
create_block(
|
117 |
-
d_model,
|
118 |
-
ssm_cfg=ssm_cfg,
|
119 |
-
norm_epsilon=norm_epsilon,
|
120 |
-
rms_norm=rms_norm,
|
121 |
-
residual_in_fp32=residual_in_fp32,
|
122 |
-
fused_add_norm=fused_add_norm,
|
123 |
-
layer_idx=i,
|
124 |
-
**factory_kwargs,
|
125 |
-
)
|
126 |
-
for i in range(n_layer)
|
127 |
-
]
|
128 |
-
)
|
129 |
-
|
130 |
-
self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
|
131 |
-
d_model, eps=norm_epsilon, **factory_kwargs
|
132 |
-
)
|
133 |
-
|
134 |
-
self.apply(
|
135 |
-
partial(
|
136 |
-
_init_weights,
|
137 |
-
n_layer=n_layer,
|
138 |
-
**(initializer_cfg if initializer_cfg is not None else {}),
|
139 |
-
)
|
140 |
-
)
|
141 |
-
|
142 |
-
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
143 |
-
return {
|
144 |
-
i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
145 |
-
for i, layer in enumerate(self.layers)
|
146 |
-
}
|
147 |
-
|
148 |
-
def forward(self, input_ids, inference_params=None):
|
149 |
-
hidden_states = self.embedding(input_ids)
|
150 |
-
residual = None
|
151 |
-
for layer in self.layers:
|
152 |
-
hidden_states, residual = layer(
|
153 |
-
hidden_states, residual, inference_params=inference_params
|
154 |
-
)
|
155 |
-
if not self.fused_add_norm:
|
156 |
-
residual = (hidden_states + residual) if residual is not None else hidden_states
|
157 |
-
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
|
158 |
-
else:
|
159 |
-
# Set prenorm=False here since we don't need the residual
|
160 |
-
fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
|
161 |
-
hidden_states = fused_add_norm_fn(
|
162 |
-
hidden_states,
|
163 |
-
self.norm_f.weight,
|
164 |
-
self.norm_f.bias,
|
165 |
-
eps=self.norm_f.eps,
|
166 |
-
residual=residual,
|
167 |
-
prenorm=False,
|
168 |
-
residual_in_fp32=self.residual_in_fp32,
|
169 |
-
)
|
170 |
-
return hidden_states
|
171 |
-
|
172 |
-
|
173 |
-
class MambaLMHeadModel(nn.Module, GenerationMixin):
|
174 |
-
|
175 |
-
def __init__(
|
176 |
-
self,
|
177 |
-
d_model: int,
|
178 |
-
n_layer: int,
|
179 |
-
vocab_size: int,
|
180 |
-
initializer_cfg=None,
|
181 |
-
pad_vocab_size_multiple: int = 1,
|
182 |
-
device=None,
|
183 |
-
dtype=None,
|
184 |
-
**backbone_kwargs,
|
185 |
-
) -> None:
|
186 |
-
factory_kwargs = {"device": device, "dtype": dtype}
|
187 |
-
super().__init__()
|
188 |
-
if vocab_size % pad_vocab_size_multiple != 0:
|
189 |
-
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
|
190 |
-
self.backbone = MixerModel(
|
191 |
-
d_model=d_model,
|
192 |
-
n_layer=n_layer,
|
193 |
-
vocab_size=vocab_size,
|
194 |
-
initializer_cfg=initializer_cfg,
|
195 |
-
**backbone_kwargs,
|
196 |
-
**factory_kwargs,
|
197 |
-
)
|
198 |
-
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
|
199 |
-
|
200 |
-
# Initialize weights and apply final processing
|
201 |
-
self.apply(
|
202 |
-
partial(
|
203 |
-
_init_weights,
|
204 |
-
n_layer=n_layer,
|
205 |
-
**(initializer_cfg if initializer_cfg is not None else {}),
|
206 |
-
)
|
207 |
-
)
|
208 |
-
self.tie_weights()
|
209 |
-
|
210 |
-
def tie_weights(self):
|
211 |
-
self.lm_head.weight = self.backbone.embedding.weight
|
212 |
-
|
213 |
-
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
214 |
-
return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
215 |
-
|
216 |
-
def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
|
217 |
-
"""
|
218 |
-
"position_ids" is just to be compatible with Transformer generation. We don't use it.
|
219 |
-
num_last_tokens: if > 0, only return the logits for the last n tokens
|
220 |
-
"""
|
221 |
-
hidden_states = self.backbone(input_ids, inference_params=inference_params)
|
222 |
-
if num_last_tokens > 0:
|
223 |
-
hidden_states = hidden_states[:, -num_last_tokens:]
|
224 |
-
lm_logits = self.lm_head(hidden_states)
|
225 |
-
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
226 |
-
return CausalLMOutput(logits=lm_logits)
|
227 |
-
|
228 |
-
@classmethod
|
229 |
-
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
|
230 |
-
config = load_config_hf(pretrained_model_name)
|
231 |
-
model = cls(**config, device=device, dtype=dtype, **kwargs)
|
232 |
-
model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
|
233 |
-
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/mamba_ssm/modules/__init__.py
DELETED
File without changes
|
mamba/mamba_ssm/modules/mamba_simple.py
DELETED
@@ -1,418 +0,0 @@
|
|
1 |
-
# Copyright (c) 2023, Tri Dao, Albert Gu.
|
2 |
-
|
3 |
-
import math
|
4 |
-
from typing import Optional
|
5 |
-
|
6 |
-
import torch
|
7 |
-
import torch.nn as nn
|
8 |
-
import torch.nn.functional as F
|
9 |
-
from torch import Tensor
|
10 |
-
|
11 |
-
from einops import rearrange, repeat
|
12 |
-
|
13 |
-
try:
|
14 |
-
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
15 |
-
except ImportError:
|
16 |
-
causal_conv1d_fn, causal_conv1d_update = None
|
17 |
-
|
18 |
-
try:
|
19 |
-
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj
|
20 |
-
except ImportError:
|
21 |
-
selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None, None
|
22 |
-
|
23 |
-
try:
|
24 |
-
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
25 |
-
except ImportError:
|
26 |
-
selective_state_update = None
|
27 |
-
|
28 |
-
try:
|
29 |
-
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
|
30 |
-
except ImportError:
|
31 |
-
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
32 |
-
|
33 |
-
|
34 |
-
class Mamba(nn.Module):
|
35 |
-
def __init__(
|
36 |
-
self,
|
37 |
-
d_model,
|
38 |
-
d_state=16,
|
39 |
-
d_conv=4,
|
40 |
-
expand=2,
|
41 |
-
dt_rank="auto",
|
42 |
-
dt_min=0.001,
|
43 |
-
dt_max=0.1,
|
44 |
-
dt_init="random",
|
45 |
-
dt_scale=1.0,
|
46 |
-
dt_init_floor=1e-4,
|
47 |
-
conv_bias=True,
|
48 |
-
bias=False,
|
49 |
-
use_fast_path=True, # Fused kernel options
|
50 |
-
layer_idx=None,
|
51 |
-
device=None,
|
52 |
-
dtype=None,
|
53 |
-
bimamba=True,
|
54 |
-
):
|
55 |
-
factory_kwargs = {"device": device, "dtype": dtype}
|
56 |
-
super().__init__()
|
57 |
-
self.d_model = d_model
|
58 |
-
self.d_state = d_state
|
59 |
-
self.d_conv = d_conv
|
60 |
-
self.expand = expand
|
61 |
-
self.d_inner = int(self.expand * self.d_model)
|
62 |
-
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
|
63 |
-
self.use_fast_path = use_fast_path
|
64 |
-
self.layer_idx = layer_idx
|
65 |
-
self.bimamba = bimamba
|
66 |
-
|
67 |
-
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
|
68 |
-
|
69 |
-
self.conv1d = nn.Conv1d(
|
70 |
-
in_channels=self.d_inner,
|
71 |
-
out_channels=self.d_inner,
|
72 |
-
bias=conv_bias,
|
73 |
-
kernel_size=d_conv,
|
74 |
-
groups=self.d_inner,
|
75 |
-
padding=d_conv - 1,
|
76 |
-
**factory_kwargs,
|
77 |
-
)
|
78 |
-
|
79 |
-
self.activation = "silu"
|
80 |
-
self.act = nn.SiLU()
|
81 |
-
|
82 |
-
self.x_proj = nn.Linear(
|
83 |
-
self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
|
84 |
-
)
|
85 |
-
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
|
86 |
-
|
87 |
-
# Initialize special dt projection to preserve variance at initialization
|
88 |
-
dt_init_std = self.dt_rank**-0.5 * dt_scale
|
89 |
-
if dt_init == "constant":
|
90 |
-
nn.init.constant_(self.dt_proj.weight, dt_init_std)
|
91 |
-
elif dt_init == "random":
|
92 |
-
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
|
93 |
-
else:
|
94 |
-
raise NotImplementedError
|
95 |
-
|
96 |
-
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
|
97 |
-
dt = torch.exp(
|
98 |
-
torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
|
99 |
-
+ math.log(dt_min)
|
100 |
-
).clamp(min=dt_init_floor)
|
101 |
-
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
102 |
-
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
103 |
-
with torch.no_grad():
|
104 |
-
self.dt_proj.bias.copy_(inv_dt)
|
105 |
-
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
|
106 |
-
self.dt_proj.bias._no_reinit = True
|
107 |
-
|
108 |
-
# S4D real initialization
|
109 |
-
# NOTE: why plus 1?
|
110 |
-
A = repeat(
|
111 |
-
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
|
112 |
-
"n -> d n",
|
113 |
-
d=self.d_inner,
|
114 |
-
).contiguous()
|
115 |
-
A_log = torch.log(A) # Keep A_log in fp32
|
116 |
-
self.A_log = nn.Parameter(A_log)
|
117 |
-
self.A_log._no_weight_decay = True
|
118 |
-
|
119 |
-
# D "skip" parameter
|
120 |
-
self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
|
121 |
-
self.D._no_weight_decay = True
|
122 |
-
|
123 |
-
# bidirectional
|
124 |
-
# forked from https://github.com/hustvl/Vim
|
125 |
-
if self.bimamba:
|
126 |
-
A_b = repeat(
|
127 |
-
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
|
128 |
-
"n -> d n",
|
129 |
-
d=self.d_inner,
|
130 |
-
).contiguous()
|
131 |
-
A_b_log = torch.log(A_b) # Keep A_b_log in fp32
|
132 |
-
self.A_b_log = nn.Parameter(A_b_log)
|
133 |
-
self.A_b_log._no_weight_decay = True
|
134 |
-
|
135 |
-
self.conv1d_b = nn.Conv1d(
|
136 |
-
in_channels=self.d_inner,
|
137 |
-
out_channels=self.d_inner,
|
138 |
-
bias=conv_bias,
|
139 |
-
kernel_size=d_conv,
|
140 |
-
groups=self.d_inner,
|
141 |
-
padding=d_conv - 1,
|
142 |
-
**factory_kwargs,
|
143 |
-
)
|
144 |
-
|
145 |
-
self.x_proj_b = nn.Linear(
|
146 |
-
self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
|
147 |
-
)
|
148 |
-
self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
|
149 |
-
|
150 |
-
self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
|
151 |
-
self.D_b._no_weight_decay = True
|
152 |
-
|
153 |
-
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
|
154 |
-
|
155 |
-
def forward(self, hidden_states, inference_params=None, T=1):
|
156 |
-
"""
|
157 |
-
hidden_states: (B, L, D)
|
158 |
-
Returns: same shape as hidden_states
|
159 |
-
"""
|
160 |
-
batch, seqlen, dim = hidden_states.shape
|
161 |
-
|
162 |
-
conv_state, ssm_state = None, None
|
163 |
-
if inference_params is not None:
|
164 |
-
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
|
165 |
-
if inference_params.seqlen_offset > 0:
|
166 |
-
# The states are updated inplace
|
167 |
-
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
|
168 |
-
return out
|
169 |
-
|
170 |
-
# We do matmul and transpose BLH -> HBL at the same time
|
171 |
-
# NOTE: same as in_proj(hidden_states) but memory-efficient with the following operations
|
172 |
-
xz = rearrange(
|
173 |
-
self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
|
174 |
-
"d (b l) -> b d l",
|
175 |
-
l=seqlen,
|
176 |
-
)
|
177 |
-
if self.in_proj.bias is not None:
|
178 |
-
xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
|
179 |
-
|
180 |
-
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
181 |
-
# In the backward pass we write dx and dz next to each other to avoid torch.cat
|
182 |
-
if self.use_fast_path and inference_params is None: # Doesn't support outputting the states
|
183 |
-
if self.bimamba:
|
184 |
-
A_b = -torch.exp(self.A_b_log.float())
|
185 |
-
out = mamba_inner_fn_no_out_proj(
|
186 |
-
xz,
|
187 |
-
self.conv1d.weight,
|
188 |
-
self.conv1d.bias,
|
189 |
-
self.x_proj.weight,
|
190 |
-
self.dt_proj.weight,
|
191 |
-
A,
|
192 |
-
None, # input-dependent B
|
193 |
-
None, # input-dependent C
|
194 |
-
self.D.float(),
|
195 |
-
delta_bias=self.dt_proj.bias.float(),
|
196 |
-
delta_softplus=True,
|
197 |
-
)
|
198 |
-
out_b = mamba_inner_fn_no_out_proj(
|
199 |
-
xz.flip([-1]),
|
200 |
-
self.conv1d_b.weight,
|
201 |
-
self.conv1d_b.bias,
|
202 |
-
self.x_proj_b.weight,
|
203 |
-
self.dt_proj_b.weight,
|
204 |
-
A_b,
|
205 |
-
None,
|
206 |
-
None,
|
207 |
-
self.D_b.float(),
|
208 |
-
delta_bias=self.dt_proj_b.bias.float(),
|
209 |
-
delta_softplus=True,
|
210 |
-
)
|
211 |
-
out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias)
|
212 |
-
else:
|
213 |
-
out = mamba_inner_fn(
|
214 |
-
xz,
|
215 |
-
self.conv1d.weight,
|
216 |
-
self.conv1d.bias,
|
217 |
-
self.x_proj.weight,
|
218 |
-
self.dt_proj.weight,
|
219 |
-
self.out_proj.weight,
|
220 |
-
self.out_proj.bias,
|
221 |
-
A,
|
222 |
-
None, # input-dependent B
|
223 |
-
None, # input-dependent C
|
224 |
-
self.D.float(),
|
225 |
-
delta_bias=self.dt_proj.bias.float(),
|
226 |
-
delta_softplus=True,
|
227 |
-
)
|
228 |
-
else:
|
229 |
-
x, z = xz.chunk(2, dim=1)
|
230 |
-
# Compute short convolution
|
231 |
-
if conv_state is not None:
|
232 |
-
conv_state.copy_(x[:, :, -self.d_conv :]) # Update state (B D W)
|
233 |
-
if causal_conv1d_fn is None:
|
234 |
-
x = self.act(self.conv1d(x)[..., :seqlen])
|
235 |
-
else:
|
236 |
-
assert self.activation in ["silu", "swish"]
|
237 |
-
x = causal_conv1d_fn(
|
238 |
-
x,
|
239 |
-
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
240 |
-
self.conv1d.bias,
|
241 |
-
self.activation,
|
242 |
-
)
|
243 |
-
|
244 |
-
# We're careful here about the layout, to avoid extra transposes.
|
245 |
-
# We want dt to have d as the slowest moving dimension
|
246 |
-
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
247 |
-
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
|
248 |
-
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
249 |
-
dt = self.dt_proj.weight @ dt.t()
|
250 |
-
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
|
251 |
-
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
252 |
-
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
253 |
-
assert self.activation in ["silu", "swish"]
|
254 |
-
y = selective_scan_fn(
|
255 |
-
x,
|
256 |
-
dt,
|
257 |
-
A,
|
258 |
-
B,
|
259 |
-
C,
|
260 |
-
self.D.float(),
|
261 |
-
z=z,
|
262 |
-
delta_bias=self.dt_proj.bias.float(),
|
263 |
-
delta_softplus=True,
|
264 |
-
return_last_state=ssm_state is not None,
|
265 |
-
)
|
266 |
-
if ssm_state is not None:
|
267 |
-
y, last_state = y
|
268 |
-
ssm_state.copy_(last_state)
|
269 |
-
y = rearrange(y, "b d l -> b l d")
|
270 |
-
out = self.out_proj(y)
|
271 |
-
return out
|
272 |
-
|
273 |
-
def step(self, hidden_states, conv_state, ssm_state):
|
274 |
-
dtype = hidden_states.dtype
|
275 |
-
assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
|
276 |
-
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
|
277 |
-
x, z = xz.chunk(2, dim=-1) # (B D)
|
278 |
-
|
279 |
-
# Conv step
|
280 |
-
if causal_conv1d_update is None:
|
281 |
-
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
|
282 |
-
conv_state[:, :, -1] = x
|
283 |
-
x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
|
284 |
-
if self.conv1d.bias is not None:
|
285 |
-
x = x + self.conv1d.bias
|
286 |
-
x = self.act(x).to(dtype=dtype)
|
287 |
-
else:
|
288 |
-
x = causal_conv1d_update(
|
289 |
-
x,
|
290 |
-
conv_state,
|
291 |
-
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
292 |
-
self.conv1d.bias,
|
293 |
-
self.activation,
|
294 |
-
)
|
295 |
-
|
296 |
-
x_db = self.x_proj(x) # (B dt_rank+2*d_state)
|
297 |
-
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
298 |
-
# Don't add dt_bias here
|
299 |
-
dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
|
300 |
-
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
301 |
-
|
302 |
-
# SSM step
|
303 |
-
if selective_state_update is None:
|
304 |
-
# Discretize A and B
|
305 |
-
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
|
306 |
-
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
|
307 |
-
dB = torch.einsum("bd,bn->bdn", dt, B)
|
308 |
-
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
|
309 |
-
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
|
310 |
-
y = y + self.D.to(dtype) * x
|
311 |
-
y = y * self.act(z) # (B D)
|
312 |
-
else:
|
313 |
-
y = selective_state_update(
|
314 |
-
ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
|
315 |
-
)
|
316 |
-
|
317 |
-
out = self.out_proj(y)
|
318 |
-
return out.unsqueeze(1), conv_state, ssm_state
|
319 |
-
|
320 |
-
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
321 |
-
device = self.out_proj.weight.device
|
322 |
-
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
|
323 |
-
conv_state = torch.zeros(
|
324 |
-
batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
|
325 |
-
)
|
326 |
-
ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
|
327 |
-
# ssm_dtype = torch.float32
|
328 |
-
ssm_state = torch.zeros(
|
329 |
-
batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
|
330 |
-
)
|
331 |
-
return conv_state, ssm_state
|
332 |
-
|
333 |
-
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
|
334 |
-
assert self.layer_idx is not None
|
335 |
-
if self.layer_idx not in inference_params.key_value_memory_dict:
|
336 |
-
batch_shape = (batch_size,)
|
337 |
-
conv_state = torch.zeros(
|
338 |
-
batch_size,
|
339 |
-
self.d_model * self.expand,
|
340 |
-
self.d_conv,
|
341 |
-
device=self.conv1d.weight.device,
|
342 |
-
dtype=self.conv1d.weight.dtype,
|
343 |
-
)
|
344 |
-
ssm_state = torch.zeros(
|
345 |
-
batch_size,
|
346 |
-
self.d_model * self.expand,
|
347 |
-
self.d_state,
|
348 |
-
device=self.dt_proj.weight.device,
|
349 |
-
dtype=self.dt_proj.weight.dtype,
|
350 |
-
# dtype=torch.float32,
|
351 |
-
)
|
352 |
-
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
|
353 |
-
else:
|
354 |
-
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
|
355 |
-
# TODO: What if batch size changes between generation, and we reuse the same states?
|
356 |
-
if initialize_states:
|
357 |
-
conv_state.zero_()
|
358 |
-
ssm_state.zero_()
|
359 |
-
return conv_state, ssm_state
|
360 |
-
|
361 |
-
|
362 |
-
class Block(nn.Module):
|
363 |
-
def __init__(
|
364 |
-
self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
|
365 |
-
):
|
366 |
-
"""
|
367 |
-
Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
|
368 |
-
|
369 |
-
This Block has a slightly different structure compared to a regular
|
370 |
-
prenorm Transformer block.
|
371 |
-
The standard block is: LN -> MHA/MLP -> Add.
|
372 |
-
[Ref: https://arxiv.org/abs/2002.04745]
|
373 |
-
Here we have: Add -> LN -> Mixer, returning both
|
374 |
-
the hidden_states (output of the mixer) and the residual.
|
375 |
-
This is purely for performance reasons, as we can fuse add and LayerNorm.
|
376 |
-
The residual needs to be provided (except for the very first block).
|
377 |
-
"""
|
378 |
-
super().__init__()
|
379 |
-
self.residual_in_fp32 = residual_in_fp32
|
380 |
-
self.fused_add_norm = fused_add_norm
|
381 |
-
self.mixer = mixer_cls(dim)
|
382 |
-
self.norm = norm_cls(dim)
|
383 |
-
if self.fused_add_norm:
|
384 |
-
assert RMSNorm is not None, "RMSNorm import fails"
|
385 |
-
assert isinstance(
|
386 |
-
self.norm, (nn.LayerNorm, RMSNorm)
|
387 |
-
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
|
388 |
-
|
389 |
-
def forward(
|
390 |
-
self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
|
391 |
-
):
|
392 |
-
r"""Pass the input through the encoder layer.
|
393 |
-
|
394 |
-
Args:
|
395 |
-
hidden_states: the sequence to the encoder layer (required).
|
396 |
-
residual: hidden_states = Mixer(LN(residual))
|
397 |
-
"""
|
398 |
-
if not self.fused_add_norm:
|
399 |
-
residual = (hidden_states + residual) if residual is not None else hidden_states
|
400 |
-
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
|
401 |
-
if self.residual_in_fp32:
|
402 |
-
residual = residual.to(torch.float32)
|
403 |
-
else:
|
404 |
-
fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
|
405 |
-
hidden_states, residual = fused_add_norm_fn(
|
406 |
-
hidden_states,
|
407 |
-
self.norm.weight,
|
408 |
-
self.norm.bias,
|
409 |
-
residual=residual,
|
410 |
-
prenorm=True,
|
411 |
-
residual_in_fp32=self.residual_in_fp32,
|
412 |
-
eps=self.norm.eps,
|
413 |
-
)
|
414 |
-
hidden_states = self.mixer(hidden_states, inference_params=inference_params)
|
415 |
-
return hidden_states, residual
|
416 |
-
|
417 |
-
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
418 |
-
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/mamba_ssm/ops/__init__.py
DELETED
File without changes
|
mamba/mamba_ssm/ops/selective_scan_interface.py
DELETED
@@ -1,709 +0,0 @@
|
|
1 |
-
# Copyright (c) 2023, Tri Dao, Albert Gu.
|
2 |
-
|
3 |
-
import torch
|
4 |
-
import torch.nn.functional as F
|
5 |
-
from torch.cuda.amp import custom_bwd, custom_fwd
|
6 |
-
|
7 |
-
from einops import rearrange, repeat
|
8 |
-
|
9 |
-
from causal_conv1d import causal_conv1d_fn
|
10 |
-
import causal_conv1d_cuda
|
11 |
-
import selective_scan_cuda
|
12 |
-
|
13 |
-
|
14 |
-
class SelectiveScanFn(torch.autograd.Function):
|
15 |
-
|
16 |
-
@staticmethod
|
17 |
-
def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
|
18 |
-
return_last_state=False):
|
19 |
-
if u.stride(-1) != 1:
|
20 |
-
u = u.contiguous()
|
21 |
-
if delta.stride(-1) != 1:
|
22 |
-
delta = delta.contiguous()
|
23 |
-
if D is not None:
|
24 |
-
D = D.contiguous()
|
25 |
-
if B.stride(-1) != 1:
|
26 |
-
B = B.contiguous()
|
27 |
-
if C.stride(-1) != 1:
|
28 |
-
C = C.contiguous()
|
29 |
-
if z is not None and z.stride(-1) != 1:
|
30 |
-
z = z.contiguous()
|
31 |
-
if B.dim() == 3:
|
32 |
-
B = rearrange(B, "b dstate l -> b 1 dstate l")
|
33 |
-
ctx.squeeze_B = True
|
34 |
-
if C.dim() == 3:
|
35 |
-
C = rearrange(C, "b dstate l -> b 1 dstate l")
|
36 |
-
ctx.squeeze_C = True
|
37 |
-
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
|
38 |
-
ctx.delta_softplus = delta_softplus
|
39 |
-
ctx.has_z = z is not None
|
40 |
-
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
|
41 |
-
if not ctx.has_z:
|
42 |
-
ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
|
43 |
-
return out if not return_last_state else (out, last_state)
|
44 |
-
else:
|
45 |
-
ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
|
46 |
-
out_z = rest[0]
|
47 |
-
return out_z if not return_last_state else (out_z, last_state)
|
48 |
-
|
49 |
-
@staticmethod
|
50 |
-
def backward(ctx, dout, *args):
|
51 |
-
if not ctx.has_z:
|
52 |
-
u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
|
53 |
-
z = None
|
54 |
-
out = None
|
55 |
-
else:
|
56 |
-
u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
|
57 |
-
if dout.stride(-1) != 1:
|
58 |
-
dout = dout.contiguous()
|
59 |
-
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
60 |
-
# backward of selective_scan_cuda with the backward of chunk).
|
61 |
-
# Here we just pass in None and dz will be allocated in the C++ code.
|
62 |
-
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
|
63 |
-
u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
|
64 |
-
False # option to recompute out_z, not used here
|
65 |
-
)
|
66 |
-
dz = rest[0] if ctx.has_z else None
|
67 |
-
dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
|
68 |
-
dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
|
69 |
-
return (du, ddelta, dA, dB, dC,
|
70 |
-
dD if D is not None else None,
|
71 |
-
dz,
|
72 |
-
ddelta_bias if delta_bias is not None else None,
|
73 |
-
None,
|
74 |
-
None)
|
75 |
-
|
76 |
-
|
77 |
-
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
|
78 |
-
return_last_state=False):
|
79 |
-
"""if return_last_state is True, returns (out, last_state)
|
80 |
-
last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
|
81 |
-
not considered in the backward pass.
|
82 |
-
"""
|
83 |
-
return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
|
84 |
-
|
85 |
-
|
86 |
-
def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
|
87 |
-
return_last_state=False):
|
88 |
-
"""
|
89 |
-
u: r(B D L)
|
90 |
-
delta: r(B D L)
|
91 |
-
A: c(D N) or r(D N)
|
92 |
-
B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
93 |
-
C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
94 |
-
D: r(D)
|
95 |
-
z: r(B D L)
|
96 |
-
delta_bias: r(D), fp32
|
97 |
-
|
98 |
-
out: r(B D L)
|
99 |
-
last_state (optional): r(B D dstate) or c(B D dstate)
|
100 |
-
"""
|
101 |
-
dtype_in = u.dtype
|
102 |
-
u = u.float()
|
103 |
-
delta = delta.float()
|
104 |
-
if delta_bias is not None:
|
105 |
-
delta = delta + delta_bias[..., None].float()
|
106 |
-
if delta_softplus:
|
107 |
-
delta = F.softplus(delta)
|
108 |
-
batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
|
109 |
-
is_variable_B = B.dim() >= 3
|
110 |
-
is_variable_C = C.dim() >= 3
|
111 |
-
if A.is_complex():
|
112 |
-
if is_variable_B:
|
113 |
-
B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
|
114 |
-
if is_variable_C:
|
115 |
-
C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
|
116 |
-
else:
|
117 |
-
B = B.float()
|
118 |
-
C = C.float()
|
119 |
-
x = A.new_zeros((batch, dim, dstate))
|
120 |
-
ys = []
|
121 |
-
deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
|
122 |
-
if not is_variable_B:
|
123 |
-
deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
|
124 |
-
else:
|
125 |
-
if B.dim() == 3:
|
126 |
-
deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
|
127 |
-
else:
|
128 |
-
B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
|
129 |
-
deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
|
130 |
-
if is_variable_C and C.dim() == 4:
|
131 |
-
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
|
132 |
-
last_state = None
|
133 |
-
for i in range(u.shape[2]):
|
134 |
-
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
|
135 |
-
if not is_variable_C:
|
136 |
-
y = torch.einsum('bdn,dn->bd', x, C)
|
137 |
-
else:
|
138 |
-
if C.dim() == 3:
|
139 |
-
y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
|
140 |
-
else:
|
141 |
-
y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
|
142 |
-
if i == u.shape[2] - 1:
|
143 |
-
last_state = x
|
144 |
-
if y.is_complex():
|
145 |
-
y = y.real * 2
|
146 |
-
ys.append(y)
|
147 |
-
y = torch.stack(ys, dim=2) # (batch dim L)
|
148 |
-
out = y if D is None else y + u * rearrange(D, "d -> d 1")
|
149 |
-
if z is not None:
|
150 |
-
out = out * F.silu(z)
|
151 |
-
out = out.to(dtype=dtype_in)
|
152 |
-
return out if not return_last_state else (out, last_state)
|
153 |
-
|
154 |
-
|
155 |
-
class MambaInnerFnNoOutProj(torch.autograd.Function):
|
156 |
-
|
157 |
-
@staticmethod
|
158 |
-
@custom_fwd
|
159 |
-
def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
160 |
-
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
161 |
-
C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
|
162 |
-
"""
|
163 |
-
xz: (batch, dim, seqlen)
|
164 |
-
"""
|
165 |
-
assert checkpoint_lvl in [0, 1]
|
166 |
-
L = xz.shape[-1]
|
167 |
-
delta_rank = delta_proj_weight.shape[1]
|
168 |
-
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
169 |
-
if torch.is_autocast_enabled():
|
170 |
-
x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
171 |
-
delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
172 |
-
if xz.stride(-1) != 1:
|
173 |
-
xz = xz.contiguous()
|
174 |
-
conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
|
175 |
-
x, z = xz.chunk(2, dim=1)
|
176 |
-
conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
|
177 |
-
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True)
|
178 |
-
# We're being very careful here about the layout, to avoid extra transposes.
|
179 |
-
# We want delta to have d as the slowest moving dimension
|
180 |
-
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
181 |
-
x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
|
182 |
-
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
|
183 |
-
ctx.is_variable_B = B is None
|
184 |
-
ctx.is_variable_C = C is None
|
185 |
-
ctx.B_proj_bias_is_None = B_proj_bias is None
|
186 |
-
ctx.C_proj_bias_is_None = C_proj_bias is None
|
187 |
-
if B is None: # variable B
|
188 |
-
B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
|
189 |
-
if B_proj_bias is not None:
|
190 |
-
B = B + B_proj_bias.to(dtype=B.dtype)
|
191 |
-
if not A.is_complex():
|
192 |
-
# B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
|
193 |
-
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
194 |
-
else:
|
195 |
-
B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
196 |
-
else:
|
197 |
-
if B.stride(-1) != 1:
|
198 |
-
B = B.contiguous()
|
199 |
-
if C is None: # variable C
|
200 |
-
C = x_dbl[:, -d_state:] # (bl dstate)
|
201 |
-
if C_proj_bias is not None:
|
202 |
-
C = C + C_proj_bias.to(dtype=C.dtype)
|
203 |
-
if not A.is_complex():
|
204 |
-
# C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
|
205 |
-
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
206 |
-
else:
|
207 |
-
C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
208 |
-
else:
|
209 |
-
if C.stride(-1) != 1:
|
210 |
-
C = C.contiguous()
|
211 |
-
if D is not None:
|
212 |
-
D = D.contiguous()
|
213 |
-
out, scan_intermediates, out_z = selective_scan_cuda.fwd(
|
214 |
-
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
|
215 |
-
)
|
216 |
-
ctx.delta_softplus = delta_softplus
|
217 |
-
ctx.checkpoint_lvl = checkpoint_lvl
|
218 |
-
if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
|
219 |
-
conv1d_out, delta = None, None
|
220 |
-
ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
|
221 |
-
delta_proj_weight, conv1d_out, delta,
|
222 |
-
A, B, C, D, delta_bias, scan_intermediates, out)
|
223 |
-
# return rearrange(out_z, "b d l -> b l d")
|
224 |
-
return out_z
|
225 |
-
|
226 |
-
@staticmethod
|
227 |
-
@custom_bwd
|
228 |
-
def backward(ctx, dout):
|
229 |
-
# dout: (batch, seqlen, dim)
|
230 |
-
(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight,
|
231 |
-
conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
|
232 |
-
L = xz.shape[-1]
|
233 |
-
delta_rank = delta_proj_weight.shape[1]
|
234 |
-
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
235 |
-
x, z = xz.chunk(2, dim=1)
|
236 |
-
if dout.stride(-1) != 1:
|
237 |
-
dout = dout.contiguous()
|
238 |
-
if ctx.checkpoint_lvl == 1:
|
239 |
-
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True)
|
240 |
-
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
|
241 |
-
"d (b l) -> b d l", l = L)
|
242 |
-
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
243 |
-
# backward of selective_scan_cuda with the backward of chunk).
|
244 |
-
dxz = torch.empty_like(xz) # (batch, dim, seqlen)
|
245 |
-
dx, dz = dxz.chunk(2, dim=1)
|
246 |
-
# dout_y = rearrange(dout, "b l d -> b d l") # because no arrange at end of forward, so dout shape is b d l
|
247 |
-
dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
|
248 |
-
conv1d_out, delta, A, B, C, D, z, delta_bias, dout, scan_intermediates, out, dz,
|
249 |
-
ctx.delta_softplus,
|
250 |
-
True # option to recompute out_z
|
251 |
-
)
|
252 |
-
dD = dD if D is not None else None
|
253 |
-
dx_dbl = torch.empty_like(x_dbl)
|
254 |
-
dB_proj_bias = None
|
255 |
-
if ctx.is_variable_B:
|
256 |
-
if not A.is_complex():
|
257 |
-
dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
|
258 |
-
else:
|
259 |
-
dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
260 |
-
dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
|
261 |
-
dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
|
262 |
-
dB = None
|
263 |
-
dC_proj_bias = None
|
264 |
-
if ctx.is_variable_C:
|
265 |
-
if not A.is_complex():
|
266 |
-
dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
|
267 |
-
else:
|
268 |
-
dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
269 |
-
dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
|
270 |
-
dx_dbl[:, -d_state:] = dC # (bl d)
|
271 |
-
dC = None
|
272 |
-
ddelta = rearrange(ddelta, "b d l -> d (b l)")
|
273 |
-
ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
|
274 |
-
dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
|
275 |
-
dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
|
276 |
-
dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
|
277 |
-
dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
|
278 |
-
dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
|
279 |
-
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
|
280 |
-
# backward of conv1d with the backward of chunk).
|
281 |
-
dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
|
282 |
-
x, conv1d_weight, conv1d_bias, dconv1d_out, dx, True
|
283 |
-
)
|
284 |
-
dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
|
285 |
-
dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
|
286 |
-
return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
|
287 |
-
dA, dB, dC, dD,
|
288 |
-
ddelta_bias if delta_bias is not None else None,
|
289 |
-
dB_proj_bias, dC_proj_bias, None)
|
290 |
-
|
291 |
-
|
292 |
-
class MambaInnerFn(torch.autograd.Function):
|
293 |
-
|
294 |
-
@staticmethod
|
295 |
-
@custom_fwd
|
296 |
-
def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
297 |
-
out_proj_weight, out_proj_bias,
|
298 |
-
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
299 |
-
C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
|
300 |
-
"""
|
301 |
-
xz: (batch, dim, seqlen)
|
302 |
-
"""
|
303 |
-
assert checkpoint_lvl in [0, 1]
|
304 |
-
L = xz.shape[-1]
|
305 |
-
delta_rank = delta_proj_weight.shape[1]
|
306 |
-
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
307 |
-
if torch.is_autocast_enabled():
|
308 |
-
x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
309 |
-
delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
310 |
-
out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
311 |
-
out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
|
312 |
-
if out_proj_bias is not None else None)
|
313 |
-
if xz.stride(-1) != 1:
|
314 |
-
xz = xz.contiguous()
|
315 |
-
conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
|
316 |
-
x, z = xz.chunk(2, dim=1)
|
317 |
-
conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
|
318 |
-
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True)
|
319 |
-
# We're being very careful here about the layout, to avoid extra transposes.
|
320 |
-
# We want delta to have d as the slowest moving dimension
|
321 |
-
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
322 |
-
x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
|
323 |
-
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
|
324 |
-
ctx.is_variable_B = B is None
|
325 |
-
ctx.is_variable_C = C is None
|
326 |
-
ctx.B_proj_bias_is_None = B_proj_bias is None
|
327 |
-
ctx.C_proj_bias_is_None = C_proj_bias is None
|
328 |
-
if B is None: # variable B
|
329 |
-
B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
|
330 |
-
if B_proj_bias is not None:
|
331 |
-
B = B + B_proj_bias.to(dtype=B.dtype)
|
332 |
-
if not A.is_complex():
|
333 |
-
# B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
|
334 |
-
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
335 |
-
else:
|
336 |
-
B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
337 |
-
else:
|
338 |
-
if B.stride(-1) != 1:
|
339 |
-
B = B.contiguous()
|
340 |
-
if C is None: # variable C
|
341 |
-
C = x_dbl[:, -d_state:] # (bl dstate)
|
342 |
-
if C_proj_bias is not None:
|
343 |
-
C = C + C_proj_bias.to(dtype=C.dtype)
|
344 |
-
if not A.is_complex():
|
345 |
-
# C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
|
346 |
-
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
347 |
-
else:
|
348 |
-
C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
349 |
-
else:
|
350 |
-
if C.stride(-1) != 1:
|
351 |
-
C = C.contiguous()
|
352 |
-
if D is not None:
|
353 |
-
D = D.contiguous()
|
354 |
-
out, scan_intermediates, out_z = selective_scan_cuda.fwd(
|
355 |
-
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
|
356 |
-
)
|
357 |
-
ctx.delta_softplus = delta_softplus
|
358 |
-
ctx.out_proj_bias_is_None = out_proj_bias is None
|
359 |
-
ctx.checkpoint_lvl = checkpoint_lvl
|
360 |
-
if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
|
361 |
-
conv1d_out, delta = None, None
|
362 |
-
ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
|
363 |
-
delta_proj_weight, out_proj_weight, conv1d_out, delta,
|
364 |
-
A, B, C, D, delta_bias, scan_intermediates, out)
|
365 |
-
return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
|
366 |
-
|
367 |
-
@staticmethod
|
368 |
-
@custom_bwd
|
369 |
-
def backward(ctx, dout):
|
370 |
-
# dout: (batch, seqlen, dim)
|
371 |
-
(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
|
372 |
-
conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
|
373 |
-
L = xz.shape[-1]
|
374 |
-
delta_rank = delta_proj_weight.shape[1]
|
375 |
-
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
376 |
-
x, z = xz.chunk(2, dim=1)
|
377 |
-
if dout.stride(-1) != 1:
|
378 |
-
dout = dout.contiguous()
|
379 |
-
if ctx.checkpoint_lvl == 1:
|
380 |
-
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True)
|
381 |
-
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
|
382 |
-
"d (b l) -> b d l", l = L)
|
383 |
-
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
384 |
-
# backward of selective_scan_cuda with the backward of chunk).
|
385 |
-
dxz = torch.empty_like(xz) # (batch, dim, seqlen)
|
386 |
-
dx, dz = dxz.chunk(2, dim=1)
|
387 |
-
dout = rearrange(dout, "b l e -> e (b l)")
|
388 |
-
dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
|
389 |
-
dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
|
390 |
-
conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
|
391 |
-
ctx.delta_softplus,
|
392 |
-
True # option to recompute out_z
|
393 |
-
)
|
394 |
-
dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
|
395 |
-
dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
|
396 |
-
dD = dD if D is not None else None
|
397 |
-
dx_dbl = torch.empty_like(x_dbl)
|
398 |
-
dB_proj_bias = None
|
399 |
-
if ctx.is_variable_B:
|
400 |
-
if not A.is_complex():
|
401 |
-
dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
|
402 |
-
else:
|
403 |
-
dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
404 |
-
dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
|
405 |
-
dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
|
406 |
-
dB = None
|
407 |
-
dC_proj_bias = None
|
408 |
-
if ctx.is_variable_C:
|
409 |
-
if not A.is_complex():
|
410 |
-
dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
|
411 |
-
else:
|
412 |
-
dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
413 |
-
dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
|
414 |
-
dx_dbl[:, -d_state:] = dC # (bl d)
|
415 |
-
dC = None
|
416 |
-
ddelta = rearrange(ddelta, "b d l -> d (b l)")
|
417 |
-
ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
|
418 |
-
dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
|
419 |
-
dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
|
420 |
-
dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
|
421 |
-
dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
|
422 |
-
dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
|
423 |
-
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
|
424 |
-
# backward of conv1d with the backward of chunk).
|
425 |
-
dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
|
426 |
-
x, conv1d_weight, conv1d_bias, dconv1d_out, dx, True
|
427 |
-
)
|
428 |
-
dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
|
429 |
-
dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
|
430 |
-
return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
|
431 |
-
dout_proj_weight, dout_proj_bias,
|
432 |
-
dA, dB, dC, dD,
|
433 |
-
ddelta_bias if delta_bias is not None else None,
|
434 |
-
dB_proj_bias, dC_proj_bias, None)
|
435 |
-
|
436 |
-
|
437 |
-
class BiMambaInnerFn(torch.autograd.Function):
|
438 |
-
|
439 |
-
@staticmethod
|
440 |
-
@custom_fwd
|
441 |
-
def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
442 |
-
out_proj_weight, out_proj_bias,
|
443 |
-
A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
444 |
-
C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
|
445 |
-
"""
|
446 |
-
xz: (batch, dim, seqlen)
|
447 |
-
"""
|
448 |
-
assert checkpoint_lvl in [0, 1]
|
449 |
-
L = xz.shape[-1]
|
450 |
-
delta_rank = delta_proj_weight.shape[1]
|
451 |
-
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
452 |
-
if torch.is_autocast_enabled():
|
453 |
-
x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
454 |
-
delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
455 |
-
out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
456 |
-
out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
|
457 |
-
if out_proj_bias is not None else None)
|
458 |
-
if xz.stride(-1) != 1:
|
459 |
-
xz = xz.contiguous()
|
460 |
-
conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
|
461 |
-
x, z = xz.chunk(2, dim=1)
|
462 |
-
conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
|
463 |
-
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True)
|
464 |
-
# We're being very careful here about the layout, to avoid extra transposes.
|
465 |
-
# We want delta to have d as the slowest moving dimension
|
466 |
-
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
467 |
-
x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
|
468 |
-
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
|
469 |
-
ctx.is_variable_B = B is None
|
470 |
-
ctx.is_variable_C = C is None
|
471 |
-
ctx.B_proj_bias_is_None = B_proj_bias is None
|
472 |
-
ctx.C_proj_bias_is_None = C_proj_bias is None
|
473 |
-
if B is None: # variable B
|
474 |
-
B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
|
475 |
-
if B_proj_bias is not None:
|
476 |
-
B = B + B_proj_bias.to(dtype=B.dtype)
|
477 |
-
if not A.is_complex():
|
478 |
-
# B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
|
479 |
-
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
480 |
-
else:
|
481 |
-
B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
482 |
-
else:
|
483 |
-
if B.stride(-1) != 1:
|
484 |
-
B = B.contiguous()
|
485 |
-
if C is None: # variable C
|
486 |
-
C = x_dbl[:, -d_state:] # (bl dstate)
|
487 |
-
if C_proj_bias is not None:
|
488 |
-
C = C + C_proj_bias.to(dtype=C.dtype)
|
489 |
-
if not A.is_complex():
|
490 |
-
# C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
|
491 |
-
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
492 |
-
else:
|
493 |
-
C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
494 |
-
else:
|
495 |
-
if C.stride(-1) != 1:
|
496 |
-
C = C.contiguous()
|
497 |
-
if D is not None:
|
498 |
-
D = D.contiguous()
|
499 |
-
out_f, scan_intermediates_f, out_z_f = selective_scan_cuda.fwd(
|
500 |
-
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
|
501 |
-
)
|
502 |
-
assert not A_b.is_complex(), "A should not be complex!!"
|
503 |
-
out_b, scan_intermediates_b, out_z_b = selective_scan_cuda.fwd(
|
504 |
-
conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, delta_softplus,
|
505 |
-
)
|
506 |
-
|
507 |
-
out_z = out_z_f + out_z_b.flip([-1])
|
508 |
-
|
509 |
-
ctx.delta_softplus = delta_softplus
|
510 |
-
ctx.out_proj_bias_is_None = out_proj_bias is None
|
511 |
-
ctx.checkpoint_lvl = checkpoint_lvl
|
512 |
-
if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
|
513 |
-
conv1d_out, delta = None, None
|
514 |
-
ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
|
515 |
-
delta_proj_weight, out_proj_weight, conv1d_out, delta,
|
516 |
-
A, A_b, B, C, D, delta_bias, scan_intermediates_f, scan_intermediates_b, out_f, out_b)
|
517 |
-
return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
|
518 |
-
|
519 |
-
@staticmethod
|
520 |
-
@custom_bwd
|
521 |
-
def backward(ctx, dout):
|
522 |
-
# dout: (batch, seqlen, dim)
|
523 |
-
(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
|
524 |
-
conv1d_out, delta, A, A_b, B, C, D, delta_bias, scan_intermediates_f, scan_intermediates_b, out_f, out_b) = ctx.saved_tensors
|
525 |
-
L = xz.shape[-1]
|
526 |
-
delta_rank = delta_proj_weight.shape[1]
|
527 |
-
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
528 |
-
x, z = xz.chunk(2, dim=1)
|
529 |
-
if dout.stride(-1) != 1:
|
530 |
-
dout = dout.contiguous()
|
531 |
-
if ctx.checkpoint_lvl == 1:
|
532 |
-
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True)
|
533 |
-
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
|
534 |
-
"d (b l) -> b d l", l = L)
|
535 |
-
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
536 |
-
# backward of selective_scan_cuda with the backward of chunk).
|
537 |
-
dxz = torch.empty_like(xz) # (batch, dim, seqlen)
|
538 |
-
dx, dz = dxz.chunk(2, dim=1)
|
539 |
-
dout = rearrange(dout, "b l e -> e (b l)")
|
540 |
-
dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
|
541 |
-
dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z_f = selective_scan_cuda.bwd(
|
542 |
-
conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates_f, out_f, dz,
|
543 |
-
ctx.delta_softplus,
|
544 |
-
True # option to recompute out_z
|
545 |
-
)
|
546 |
-
# flip one
|
547 |
-
dz_b = torch.empty_like(dz)
|
548 |
-
dconv1d_out_f_b, ddelta_f_b, dA_b, dB_f_b, dC_f_b, dD_b, ddelta_bias_b, dz_b, out_z_b = selective_scan_cuda.bwd(
|
549 |
-
conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, dout_y.flip([-1]), scan_intermediates_b, out_b, dz_b,
|
550 |
-
ctx.delta_softplus,
|
551 |
-
True # option to recompute out_z
|
552 |
-
)
|
553 |
-
|
554 |
-
dconv1d_out = dconv1d_out + dconv1d_out_f_b.flip([-1])
|
555 |
-
ddelta = ddelta + ddelta_f_b.flip([-1])
|
556 |
-
dB = dB + dB_f_b.flip([-1])
|
557 |
-
dC = dC + dC_f_b.flip([-1])
|
558 |
-
dD = dD + dD_b
|
559 |
-
ddelta_bias = ddelta_bias + ddelta_bias_b
|
560 |
-
dz = dz + dz_b.flip([-1])
|
561 |
-
out_z = out_z_f + out_z_b.flip([-1])
|
562 |
-
|
563 |
-
dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
|
564 |
-
dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
|
565 |
-
dD = dD if D is not None else None
|
566 |
-
dx_dbl = torch.empty_like(x_dbl)
|
567 |
-
dB_proj_bias = None
|
568 |
-
if ctx.is_variable_B:
|
569 |
-
if not A.is_complex():
|
570 |
-
dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
|
571 |
-
else:
|
572 |
-
dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
573 |
-
dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
|
574 |
-
dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
|
575 |
-
dB = None
|
576 |
-
dC_proj_bias = None
|
577 |
-
if ctx.is_variable_C:
|
578 |
-
if not A.is_complex():
|
579 |
-
dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
|
580 |
-
else:
|
581 |
-
dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
582 |
-
dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
|
583 |
-
dx_dbl[:, -d_state:] = dC # (bl d)
|
584 |
-
dC = None
|
585 |
-
ddelta = rearrange(ddelta, "b d l -> d (b l)")
|
586 |
-
ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
|
587 |
-
dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
|
588 |
-
dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
|
589 |
-
dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
|
590 |
-
dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
|
591 |
-
dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
|
592 |
-
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
|
593 |
-
# backward of conv1d with the backward of chunk).
|
594 |
-
dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
|
595 |
-
x, conv1d_weight, conv1d_bias, dconv1d_out, dx, True
|
596 |
-
)
|
597 |
-
dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
|
598 |
-
dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
|
599 |
-
return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
|
600 |
-
dout_proj_weight, dout_proj_bias,
|
601 |
-
dA, dA_b, dB, dC, dD,
|
602 |
-
ddelta_bias if delta_bias is not None else None,
|
603 |
-
dB_proj_bias, dC_proj_bias, None)
|
604 |
-
|
605 |
-
|
606 |
-
def mamba_inner_fn(
|
607 |
-
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
608 |
-
out_proj_weight, out_proj_bias,
|
609 |
-
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
610 |
-
C_proj_bias=None, delta_softplus=True
|
611 |
-
):
|
612 |
-
return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
613 |
-
out_proj_weight, out_proj_bias,
|
614 |
-
A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
|
615 |
-
|
616 |
-
def bimamba_inner_fn(
|
617 |
-
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
618 |
-
out_proj_weight, out_proj_bias,
|
619 |
-
A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
620 |
-
C_proj_bias=None, delta_softplus=True
|
621 |
-
):
|
622 |
-
return BiMambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
623 |
-
out_proj_weight, out_proj_bias,
|
624 |
-
A, A_b, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
|
625 |
-
|
626 |
-
|
627 |
-
def mamba_inner_fn_no_out_proj(
|
628 |
-
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
629 |
-
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
630 |
-
C_proj_bias=None, delta_softplus=True
|
631 |
-
):
|
632 |
-
return MambaInnerFnNoOutProj.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
633 |
-
A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
|
634 |
-
|
635 |
-
|
636 |
-
def mamba_inner_ref(
|
637 |
-
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
638 |
-
out_proj_weight, out_proj_bias,
|
639 |
-
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
640 |
-
C_proj_bias=None, delta_softplus=True
|
641 |
-
):
|
642 |
-
L = xz.shape[-1]
|
643 |
-
delta_rank = delta_proj_weight.shape[1]
|
644 |
-
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
645 |
-
x, z = xz.chunk(2, dim=1)
|
646 |
-
x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
|
647 |
-
# We're being very careful here about the layout, to avoid extra transposes.
|
648 |
-
# We want delta to have d as the slowest moving dimension
|
649 |
-
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
650 |
-
x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
|
651 |
-
delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
|
652 |
-
delta = rearrange(delta, "d (b l) -> b d l", l=L)
|
653 |
-
if B is None: # variable B
|
654 |
-
B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d)
|
655 |
-
if B_proj_bias is not None:
|
656 |
-
B = B + B_proj_bias.to(dtype=B.dtype)
|
657 |
-
if not A.is_complex():
|
658 |
-
B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
|
659 |
-
else:
|
660 |
-
B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
|
661 |
-
if C is None: # variable B
|
662 |
-
C = x_dbl[:, -d_state:] # (bl d)
|
663 |
-
if C_proj_bias is not None:
|
664 |
-
C = C + C_proj_bias.to(dtype=C.dtype)
|
665 |
-
if not A.is_complex():
|
666 |
-
C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
|
667 |
-
else:
|
668 |
-
C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
|
669 |
-
y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
|
670 |
-
return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
|
671 |
-
|
672 |
-
|
673 |
-
def bimamba_inner_ref(
|
674 |
-
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
675 |
-
out_proj_weight, out_proj_bias,
|
676 |
-
A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
677 |
-
C_proj_bias=None, delta_softplus=True
|
678 |
-
):
|
679 |
-
L = xz.shape[-1]
|
680 |
-
delta_rank = delta_proj_weight.shape[1]
|
681 |
-
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
682 |
-
x, z = xz.chunk(2, dim=1)
|
683 |
-
x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
|
684 |
-
# We're being very careful here about the layout, to avoid extra transposes.
|
685 |
-
# We want delta to have d as the slowest moving dimension
|
686 |
-
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
687 |
-
x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
|
688 |
-
delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
|
689 |
-
delta = rearrange(delta, "d (b l) -> b d l", l=L)
|
690 |
-
if B is None: # variable B
|
691 |
-
B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d)
|
692 |
-
if B_proj_bias is not None:
|
693 |
-
B = B + B_proj_bias.to(dtype=B.dtype)
|
694 |
-
if not A.is_complex():
|
695 |
-
B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
|
696 |
-
else:
|
697 |
-
B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
|
698 |
-
if C is None: # variable B
|
699 |
-
C = x_dbl[:, -d_state:] # (bl d)
|
700 |
-
if C_proj_bias is not None:
|
701 |
-
C = C + C_proj_bias.to(dtype=C.dtype)
|
702 |
-
if not A.is_complex():
|
703 |
-
C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
|
704 |
-
else:
|
705 |
-
C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
|
706 |
-
y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
|
707 |
-
y_b = selective_scan_fn(x.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, delta_softplus=True)
|
708 |
-
y = y + y_b.flip([-1])
|
709 |
-
return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mamba/mamba_ssm/ops/triton/__init__.py
DELETED
File without changes
|
mamba/mamba_ssm/ops/triton/layernorm.py
DELETED
@@ -1,636 +0,0 @@
|
|
1 |
-
# Copyright (c) 2023, Tri Dao.
|
2 |
-
# Implement residual + layer_norm / rms_norm.
|
3 |
-
|
4 |
-
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
5 |
-
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
|
6 |
-
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
7 |
-
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
8 |
-
|
9 |
-
import math
|
10 |
-
|
11 |
-
import torch
|
12 |
-
import torch.nn.functional as F
|
13 |
-
from torch.cuda.amp import custom_fwd, custom_bwd
|
14 |
-
|
15 |
-
import triton
|
16 |
-
import triton.language as tl
|
17 |
-
|
18 |
-
|
19 |
-
def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
|
20 |
-
dtype = x.dtype
|
21 |
-
if upcast:
|
22 |
-
weight = weight.float()
|
23 |
-
bias = bias.float() if bias is not None else None
|
24 |
-
if upcast:
|
25 |
-
x = x.float()
|
26 |
-
residual = residual.float() if residual is not None else residual
|
27 |
-
if residual is not None:
|
28 |
-
x = (x + residual).to(x.dtype)
|
29 |
-
out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
|
30 |
-
dtype
|
31 |
-
)
|
32 |
-
return out if not prenorm else (out, x)
|
33 |
-
|
34 |
-
|
35 |
-
def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
|
36 |
-
dtype = x.dtype
|
37 |
-
if upcast:
|
38 |
-
weight = weight.float()
|
39 |
-
bias = bias.float() if bias is not None else None
|
40 |
-
if upcast:
|
41 |
-
x = x.float()
|
42 |
-
residual = residual.float() if residual is not None else residual
|
43 |
-
if residual is not None:
|
44 |
-
x = (x + residual).to(x.dtype)
|
45 |
-
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
46 |
-
out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
|
47 |
-
out = out.to(dtype)
|
48 |
-
return out if not prenorm else (out, x)
|
49 |
-
|
50 |
-
|
51 |
-
@triton.autotune(
|
52 |
-
configs=[
|
53 |
-
triton.Config({}, num_warps=1),
|
54 |
-
triton.Config({}, num_warps=2),
|
55 |
-
triton.Config({}, num_warps=4),
|
56 |
-
triton.Config({}, num_warps=8),
|
57 |
-
triton.Config({}, num_warps=16),
|
58 |
-
triton.Config({}, num_warps=32),
|
59 |
-
],
|
60 |
-
key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
|
61 |
-
)
|
62 |
-
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
63 |
-
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
|
64 |
-
@triton.jit
|
65 |
-
def _layer_norm_fwd_1pass_kernel(
|
66 |
-
X, # pointer to the input
|
67 |
-
Y, # pointer to the output
|
68 |
-
W, # pointer to the weights
|
69 |
-
B, # pointer to the biases
|
70 |
-
RESIDUAL, # pointer to the residual
|
71 |
-
RESIDUAL_OUT, # pointer to the residual
|
72 |
-
Mean, # pointer to the mean
|
73 |
-
Rstd, # pointer to the 1/std
|
74 |
-
stride_x_row, # how much to increase the pointer when moving by 1 row
|
75 |
-
stride_y_row,
|
76 |
-
stride_res_row,
|
77 |
-
stride_res_out_row,
|
78 |
-
N, # number of columns in X
|
79 |
-
eps, # epsilon to avoid division by zero
|
80 |
-
IS_RMS_NORM: tl.constexpr,
|
81 |
-
BLOCK_N: tl.constexpr,
|
82 |
-
HAS_RESIDUAL: tl.constexpr,
|
83 |
-
STORE_RESIDUAL_OUT: tl.constexpr,
|
84 |
-
HAS_BIAS: tl.constexpr,
|
85 |
-
):
|
86 |
-
# Map the program id to the row of X and Y it should compute.
|
87 |
-
row = tl.program_id(0)
|
88 |
-
X += row * stride_x_row
|
89 |
-
Y += row * stride_y_row
|
90 |
-
if HAS_RESIDUAL:
|
91 |
-
RESIDUAL += row * stride_res_row
|
92 |
-
if STORE_RESIDUAL_OUT:
|
93 |
-
RESIDUAL_OUT += row * stride_res_out_row
|
94 |
-
# Compute mean and variance
|
95 |
-
cols = tl.arange(0, BLOCK_N)
|
96 |
-
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
97 |
-
if HAS_RESIDUAL:
|
98 |
-
residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
|
99 |
-
x += residual
|
100 |
-
if STORE_RESIDUAL_OUT:
|
101 |
-
tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
|
102 |
-
if not IS_RMS_NORM:
|
103 |
-
mean = tl.sum(x, axis=0) / N
|
104 |
-
tl.store(Mean + row, mean)
|
105 |
-
xbar = tl.where(cols < N, x - mean, 0.0)
|
106 |
-
var = tl.sum(xbar * xbar, axis=0) / N
|
107 |
-
else:
|
108 |
-
xbar = tl.where(cols < N, x, 0.0)
|
109 |
-
var = tl.sum(xbar * xbar, axis=0) / N
|
110 |
-
rstd = 1 / tl.sqrt(var + eps)
|
111 |
-
tl.store(Rstd + row, rstd)
|
112 |
-
# Normalize and apply linear transformation
|
113 |
-
mask = cols < N
|
114 |
-
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
115 |
-
if HAS_BIAS:
|
116 |
-
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
117 |
-
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
118 |
-
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
119 |
-
# Write output
|
120 |
-
tl.store(Y + cols, y, mask=mask)
|
121 |
-
|
122 |
-
|
123 |
-
def _layer_norm_fwd(
|
124 |
-
x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
|
125 |
-
):
|
126 |
-
if residual is not None:
|
127 |
-
residual_dtype = residual.dtype
|
128 |
-
M, N = x.shape
|
129 |
-
assert x.stride(-1) == 1
|
130 |
-
if residual is not None:
|
131 |
-
assert residual.stride(-1) == 1
|
132 |
-
assert residual.shape == (M, N)
|
133 |
-
assert weight.shape == (N,)
|
134 |
-
assert weight.stride(-1) == 1
|
135 |
-
if bias is not None:
|
136 |
-
assert bias.stride(-1) == 1
|
137 |
-
assert bias.shape == (N,)
|
138 |
-
# allocate output
|
139 |
-
y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
|
140 |
-
assert y.stride(-1) == 1
|
141 |
-
if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
|
142 |
-
residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
|
143 |
-
assert residual_out.stride(-1) == 1
|
144 |
-
else:
|
145 |
-
residual_out = None
|
146 |
-
mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None
|
147 |
-
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
|
148 |
-
# Less than 64KB per feature: enqueue fused kernel
|
149 |
-
MAX_FUSED_SIZE = 65536 // x.element_size()
|
150 |
-
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
151 |
-
if N > BLOCK_N:
|
152 |
-
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
153 |
-
# heuristics for number of warps
|
154 |
-
with torch.cuda.device(x.device.index):
|
155 |
-
_layer_norm_fwd_1pass_kernel[(M,)](
|
156 |
-
x,
|
157 |
-
y,
|
158 |
-
weight,
|
159 |
-
bias,
|
160 |
-
residual,
|
161 |
-
residual_out,
|
162 |
-
mean,
|
163 |
-
rstd,
|
164 |
-
x.stride(0),
|
165 |
-
y.stride(0),
|
166 |
-
residual.stride(0) if residual is not None else 0,
|
167 |
-
residual_out.stride(0) if residual_out is not None else 0,
|
168 |
-
N,
|
169 |
-
eps,
|
170 |
-
is_rms_norm,
|
171 |
-
BLOCK_N,
|
172 |
-
residual is not None,
|
173 |
-
residual_out is not None,
|
174 |
-
bias is not None,
|
175 |
-
)
|
176 |
-
# residual_out is None if residual is None and residual_dtype == input_dtype
|
177 |
-
return y, mean, rstd, residual_out if residual_out is not None else x
|
178 |
-
|
179 |
-
|
180 |
-
@triton.autotune(
|
181 |
-
configs=[
|
182 |
-
triton.Config({}, num_warps=1),
|
183 |
-
triton.Config({}, num_warps=2),
|
184 |
-
triton.Config({}, num_warps=4),
|
185 |
-
triton.Config({}, num_warps=8),
|
186 |
-
triton.Config({}, num_warps=16),
|
187 |
-
triton.Config({}, num_warps=32),
|
188 |
-
],
|
189 |
-
key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
|
190 |
-
)
|
191 |
-
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
192 |
-
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
|
193 |
-
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
|
194 |
-
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
|
195 |
-
@triton.jit
|
196 |
-
def _layer_norm_bwd_kernel(
|
197 |
-
X, # pointer to the input
|
198 |
-
W, # pointer to the weights
|
199 |
-
B, # pointer to the biases
|
200 |
-
Y, # pointer to the output to be recomputed
|
201 |
-
DY, # pointer to the output gradient
|
202 |
-
DX, # pointer to the input gradient
|
203 |
-
DW, # pointer to the partial sum of weights gradient
|
204 |
-
DB, # pointer to the partial sum of biases gradient
|
205 |
-
DRESIDUAL,
|
206 |
-
DRESIDUAL_IN,
|
207 |
-
Mean, # pointer to the mean
|
208 |
-
Rstd, # pointer to the 1/std
|
209 |
-
stride_x_row, # how much to increase the pointer when moving by 1 row
|
210 |
-
stride_y_row,
|
211 |
-
stride_dy_row,
|
212 |
-
stride_dx_row,
|
213 |
-
stride_dres_row,
|
214 |
-
stride_dres_in_row,
|
215 |
-
M, # number of rows in X
|
216 |
-
N, # number of columns in X
|
217 |
-
eps, # epsilon to avoid division by zero
|
218 |
-
rows_per_program,
|
219 |
-
IS_RMS_NORM: tl.constexpr,
|
220 |
-
BLOCK_N: tl.constexpr,
|
221 |
-
HAS_DRESIDUAL: tl.constexpr,
|
222 |
-
STORE_DRESIDUAL: tl.constexpr,
|
223 |
-
HAS_BIAS: tl.constexpr,
|
224 |
-
RECOMPUTE_OUTPUT: tl.constexpr,
|
225 |
-
):
|
226 |
-
# Map the program id to the elements of X, DX, and DY it should compute.
|
227 |
-
row_block_id = tl.program_id(0)
|
228 |
-
row_start = row_block_id * rows_per_program
|
229 |
-
cols = tl.arange(0, BLOCK_N)
|
230 |
-
mask = cols < N
|
231 |
-
X += row_start * stride_x_row
|
232 |
-
if HAS_DRESIDUAL:
|
233 |
-
DRESIDUAL += row_start * stride_dres_row
|
234 |
-
if STORE_DRESIDUAL:
|
235 |
-
DRESIDUAL_IN += row_start * stride_dres_in_row
|
236 |
-
DY += row_start * stride_dy_row
|
237 |
-
DX += row_start * stride_dx_row
|
238 |
-
if RECOMPUTE_OUTPUT:
|
239 |
-
Y += row_start * stride_y_row
|
240 |
-
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
241 |
-
if RECOMPUTE_OUTPUT and HAS_BIAS:
|
242 |
-
b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
|
243 |
-
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
244 |
-
if HAS_BIAS:
|
245 |
-
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
246 |
-
row_end = min((row_block_id + 1) * rows_per_program, M)
|
247 |
-
for row in range(row_start, row_end):
|
248 |
-
# Load data to SRAM
|
249 |
-
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
250 |
-
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
|
251 |
-
if not IS_RMS_NORM:
|
252 |
-
mean = tl.load(Mean + row)
|
253 |
-
rstd = tl.load(Rstd + row)
|
254 |
-
# Compute dx
|
255 |
-
xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
256 |
-
xhat = tl.where(mask, xhat, 0.0)
|
257 |
-
if RECOMPUTE_OUTPUT:
|
258 |
-
y = xhat * w + b if HAS_BIAS else xhat * w
|
259 |
-
tl.store(Y + cols, y, mask=mask)
|
260 |
-
wdy = w * dy
|
261 |
-
dw += dy * xhat
|
262 |
-
if HAS_BIAS:
|
263 |
-
db += dy
|
264 |
-
if not IS_RMS_NORM:
|
265 |
-
c1 = tl.sum(xhat * wdy, axis=0) / N
|
266 |
-
c2 = tl.sum(wdy, axis=0) / N
|
267 |
-
dx = (wdy - (xhat * c1 + c2)) * rstd
|
268 |
-
else:
|
269 |
-
c1 = tl.sum(xhat * wdy, axis=0) / N
|
270 |
-
dx = (wdy - xhat * c1) * rstd
|
271 |
-
if HAS_DRESIDUAL:
|
272 |
-
dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
|
273 |
-
dx += dres
|
274 |
-
# Write dx
|
275 |
-
if STORE_DRESIDUAL:
|
276 |
-
tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
|
277 |
-
tl.store(DX + cols, dx, mask=mask)
|
278 |
-
|
279 |
-
X += stride_x_row
|
280 |
-
if HAS_DRESIDUAL:
|
281 |
-
DRESIDUAL += stride_dres_row
|
282 |
-
if STORE_DRESIDUAL:
|
283 |
-
DRESIDUAL_IN += stride_dres_in_row
|
284 |
-
if RECOMPUTE_OUTPUT:
|
285 |
-
Y += stride_y_row
|
286 |
-
DY += stride_dy_row
|
287 |
-
DX += stride_dx_row
|
288 |
-
tl.store(DW + row_block_id * N + cols, dw, mask=mask)
|
289 |
-
if HAS_BIAS:
|
290 |
-
tl.store(DB + row_block_id * N + cols, db, mask=mask)
|
291 |
-
|
292 |
-
|
293 |
-
def _layer_norm_bwd(
|
294 |
-
dy,
|
295 |
-
x,
|
296 |
-
weight,
|
297 |
-
bias,
|
298 |
-
eps,
|
299 |
-
mean,
|
300 |
-
rstd,
|
301 |
-
dresidual=None,
|
302 |
-
has_residual=False,
|
303 |
-
is_rms_norm=False,
|
304 |
-
x_dtype=None,
|
305 |
-
recompute_output=False,
|
306 |
-
):
|
307 |
-
M, N = x.shape
|
308 |
-
assert x.stride(-1) == 1
|
309 |
-
assert dy.stride(-1) == 1
|
310 |
-
assert dy.shape == (M, N)
|
311 |
-
if dresidual is not None:
|
312 |
-
assert dresidual.stride(-1) == 1
|
313 |
-
assert dresidual.shape == (M, N)
|
314 |
-
assert weight.shape == (N,)
|
315 |
-
assert weight.stride(-1) == 1
|
316 |
-
if bias is not None:
|
317 |
-
assert bias.stride(-1) == 1
|
318 |
-
assert bias.shape == (N,)
|
319 |
-
# allocate output
|
320 |
-
dx = (
|
321 |
-
torch.empty_like(x)
|
322 |
-
if x_dtype is None
|
323 |
-
else torch.empty(M, N, dtype=x_dtype, device=x.device)
|
324 |
-
)
|
325 |
-
dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
|
326 |
-
y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
|
327 |
-
|
328 |
-
# Less than 64KB per feature: enqueue fused kernel
|
329 |
-
MAX_FUSED_SIZE = 65536 // x.element_size()
|
330 |
-
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
331 |
-
if N > BLOCK_N:
|
332 |
-
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
333 |
-
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
|
334 |
-
_dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
|
335 |
-
_db = (
|
336 |
-
torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
|
337 |
-
if bias is not None
|
338 |
-
else None
|
339 |
-
)
|
340 |
-
rows_per_program = math.ceil(M / sm_count)
|
341 |
-
grid = (sm_count,)
|
342 |
-
with torch.cuda.device(x.device.index):
|
343 |
-
_layer_norm_bwd_kernel[grid](
|
344 |
-
x,
|
345 |
-
weight,
|
346 |
-
bias,
|
347 |
-
y,
|
348 |
-
dy,
|
349 |
-
dx,
|
350 |
-
_dw,
|
351 |
-
_db,
|
352 |
-
dresidual,
|
353 |
-
dresidual_in,
|
354 |
-
mean,
|
355 |
-
rstd,
|
356 |
-
x.stride(0),
|
357 |
-
0 if not recompute_output else y.stride(0),
|
358 |
-
dy.stride(0),
|
359 |
-
dx.stride(0),
|
360 |
-
dresidual.stride(0) if dresidual is not None else 0,
|
361 |
-
dresidual_in.stride(0) if dresidual_in is not None else 0,
|
362 |
-
M,
|
363 |
-
N,
|
364 |
-
eps,
|
365 |
-
rows_per_program,
|
366 |
-
is_rms_norm,
|
367 |
-
BLOCK_N,
|
368 |
-
dresidual is not None,
|
369 |
-
dresidual_in is not None,
|
370 |
-
bias is not None,
|
371 |
-
)
|
372 |
-
dw = _dw.sum(0).to(weight.dtype)
|
373 |
-
db = _db.sum(0).to(bias.dtype) if bias is not None else None
|
374 |
-
# Don't need to compute dresidual_in separately in this case
|
375 |
-
if has_residual and dx.dtype == x.dtype:
|
376 |
-
dresidual_in = dx
|
377 |
-
return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
|
378 |
-
|
379 |
-
|
380 |
-
class LayerNormFn(torch.autograd.Function):
|
381 |
-
@staticmethod
|
382 |
-
def forward(
|
383 |
-
ctx,
|
384 |
-
x,
|
385 |
-
weight,
|
386 |
-
bias,
|
387 |
-
residual=None,
|
388 |
-
eps=1e-6,
|
389 |
-
prenorm=False,
|
390 |
-
residual_in_fp32=False,
|
391 |
-
is_rms_norm=False,
|
392 |
-
):
|
393 |
-
x_shape_og = x.shape
|
394 |
-
# reshape input data into 2D tensor
|
395 |
-
x = x.reshape(-1, x.shape[-1])
|
396 |
-
if x.stride(-1) != 1:
|
397 |
-
x = x.contiguous()
|
398 |
-
if residual is not None:
|
399 |
-
assert residual.shape == x_shape_og
|
400 |
-
residual = residual.reshape(-1, residual.shape[-1])
|
401 |
-
if residual.stride(-1) != 1:
|
402 |
-
residual = residual.contiguous()
|
403 |
-
weight = weight.contiguous()
|
404 |
-
if bias is not None:
|
405 |
-
bias = bias.contiguous()
|
406 |
-
residual_dtype = (
|
407 |
-
residual.dtype
|
408 |
-
if residual is not None
|
409 |
-
else (torch.float32 if residual_in_fp32 else None)
|
410 |
-
)
|
411 |
-
y, mean, rstd, residual_out = _layer_norm_fwd(
|
412 |
-
x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm
|
413 |
-
)
|
414 |
-
ctx.save_for_backward(residual_out, weight, bias, mean, rstd)
|
415 |
-
ctx.x_shape_og = x_shape_og
|
416 |
-
ctx.eps = eps
|
417 |
-
ctx.is_rms_norm = is_rms_norm
|
418 |
-
ctx.has_residual = residual is not None
|
419 |
-
ctx.prenorm = prenorm
|
420 |
-
ctx.x_dtype = x.dtype
|
421 |
-
y = y.reshape(x_shape_og)
|
422 |
-
return y if not prenorm else (y, residual_out.reshape(x_shape_og))
|
423 |
-
|
424 |
-
@staticmethod
|
425 |
-
def backward(ctx, dy, *args):
|
426 |
-
x, weight, bias, mean, rstd = ctx.saved_tensors
|
427 |
-
dy = dy.reshape(-1, dy.shape[-1])
|
428 |
-
if dy.stride(-1) != 1:
|
429 |
-
dy = dy.contiguous()
|
430 |
-
assert dy.shape == x.shape
|
431 |
-
if ctx.prenorm:
|
432 |
-
dresidual = args[0]
|
433 |
-
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
434 |
-
if dresidual.stride(-1) != 1:
|
435 |
-
dresidual = dresidual.contiguous()
|
436 |
-
assert dresidual.shape == x.shape
|
437 |
-
else:
|
438 |
-
dresidual = None
|
439 |
-
dx, dw, db, dresidual_in = _layer_norm_bwd(
|
440 |
-
dy,
|
441 |
-
x,
|
442 |
-
weight,
|
443 |
-
bias,
|
444 |
-
ctx.eps,
|
445 |
-
mean,
|
446 |
-
rstd,
|
447 |
-
dresidual,
|
448 |
-
ctx.has_residual,
|
449 |
-
ctx.is_rms_norm,
|
450 |
-
x_dtype=ctx.x_dtype,
|
451 |
-
)
|
452 |
-
return (
|
453 |
-
dx.reshape(ctx.x_shape_og),
|
454 |
-
dw,
|
455 |
-
db,
|
456 |
-
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
457 |
-
None,
|
458 |
-
None,
|
459 |
-
None,
|
460 |
-
None,
|
461 |
-
)
|
462 |
-
|
463 |
-
|
464 |
-
def layer_norm_fn(
|
465 |
-
x,
|
466 |
-
weight,
|
467 |
-
bias,
|
468 |
-
residual=None,
|
469 |
-
eps=1e-6,
|
470 |
-
prenorm=False,
|
471 |
-
residual_in_fp32=False,
|
472 |
-
is_rms_norm=False,
|
473 |
-
):
|
474 |
-
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm)
|
475 |
-
|
476 |
-
|
477 |
-
def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6):
|
478 |
-
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
|
479 |
-
|
480 |
-
|
481 |
-
class RMSNorm(torch.nn.Module):
|
482 |
-
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
|
483 |
-
factory_kwargs = {"device": device, "dtype": dtype}
|
484 |
-
super().__init__()
|
485 |
-
self.eps = eps
|
486 |
-
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
487 |
-
self.register_parameter("bias", None)
|
488 |
-
self.reset_parameters()
|
489 |
-
|
490 |
-
def reset_parameters(self):
|
491 |
-
torch.nn.init.ones_(self.weight)
|
492 |
-
|
493 |
-
def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
|
494 |
-
return rms_norm_fn(
|
495 |
-
x,
|
496 |
-
self.weight,
|
497 |
-
self.bias,
|
498 |
-
residual=residual,
|
499 |
-
eps=self.eps,
|
500 |
-
prenorm=prenorm,
|
501 |
-
residual_in_fp32=residual_in_fp32,
|
502 |
-
# is_rms_norm=True,
|
503 |
-
)
|
504 |
-
|
505 |
-
|
506 |
-
class LayerNormLinearFn(torch.autograd.Function):
|
507 |
-
@staticmethod
|
508 |
-
@custom_fwd
|
509 |
-
def forward(
|
510 |
-
ctx,
|
511 |
-
x,
|
512 |
-
norm_weight,
|
513 |
-
norm_bias,
|
514 |
-
linear_weight,
|
515 |
-
linear_bias,
|
516 |
-
residual=None,
|
517 |
-
eps=1e-6,
|
518 |
-
prenorm=False,
|
519 |
-
residual_in_fp32=False,
|
520 |
-
is_rms_norm=False,
|
521 |
-
):
|
522 |
-
x_shape_og = x.shape
|
523 |
-
# reshape input data into 2D tensor
|
524 |
-
x = x.reshape(-1, x.shape[-1])
|
525 |
-
if x.stride(-1) != 1:
|
526 |
-
x = x.contiguous()
|
527 |
-
if residual is not None:
|
528 |
-
assert residual.shape == x_shape_og
|
529 |
-
residual = residual.reshape(-1, residual.shape[-1])
|
530 |
-
if residual.stride(-1) != 1:
|
531 |
-
residual = residual.contiguous()
|
532 |
-
norm_weight = norm_weight.contiguous()
|
533 |
-
if norm_bias is not None:
|
534 |
-
norm_bias = norm_bias.contiguous()
|
535 |
-
residual_dtype = (
|
536 |
-
residual.dtype
|
537 |
-
if residual is not None
|
538 |
-
else (torch.float32 if residual_in_fp32 else None)
|
539 |
-
)
|
540 |
-
y, mean, rstd, residual_out = _layer_norm_fwd(
|
541 |
-
x,
|
542 |
-
norm_weight,
|
543 |
-
norm_bias,
|
544 |
-
eps,
|
545 |
-
residual,
|
546 |
-
out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
|
547 |
-
residual_dtype=residual_dtype,
|
548 |
-
is_rms_norm=is_rms_norm,
|
549 |
-
)
|
550 |
-
y = y.reshape(x_shape_og)
|
551 |
-
dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
|
552 |
-
linear_weight = linear_weight.to(dtype)
|
553 |
-
linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
|
554 |
-
out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
|
555 |
-
# We don't store y, will be recomputed in the backward pass to save memory
|
556 |
-
ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
|
557 |
-
ctx.x_shape_og = x_shape_og
|
558 |
-
ctx.eps = eps
|
559 |
-
ctx.is_rms_norm = is_rms_norm
|
560 |
-
ctx.has_residual = residual is not None
|
561 |
-
ctx.prenorm = prenorm
|
562 |
-
ctx.x_dtype = x.dtype
|
563 |
-
ctx.linear_bias_is_none = linear_bias is None
|
564 |
-
return out if not prenorm else (out, residual_out.reshape(x_shape_og))
|
565 |
-
|
566 |
-
@staticmethod
|
567 |
-
@custom_bwd
|
568 |
-
def backward(ctx, dout, *args):
|
569 |
-
x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
|
570 |
-
dout = dout.reshape(-1, dout.shape[-1])
|
571 |
-
dy = F.linear(dout, linear_weight.t())
|
572 |
-
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
|
573 |
-
if dy.stride(-1) != 1:
|
574 |
-
dy = dy.contiguous()
|
575 |
-
assert dy.shape == x.shape
|
576 |
-
if ctx.prenorm:
|
577 |
-
dresidual = args[0]
|
578 |
-
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
579 |
-
if dresidual.stride(-1) != 1:
|
580 |
-
dresidual = dresidual.contiguous()
|
581 |
-
assert dresidual.shape == x.shape
|
582 |
-
else:
|
583 |
-
dresidual = None
|
584 |
-
dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
|
585 |
-
dy,
|
586 |
-
x,
|
587 |
-
norm_weight,
|
588 |
-
norm_bias,
|
589 |
-
ctx.eps,
|
590 |
-
mean,
|
591 |
-
rstd,
|
592 |
-
dresidual,
|
593 |
-
ctx.has_residual,
|
594 |
-
ctx.is_rms_norm,
|
595 |
-
x_dtype=ctx.x_dtype,
|
596 |
-
recompute_output=True,
|
597 |
-
)
|
598 |
-
dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
|
599 |
-
return (
|
600 |
-
dx.reshape(ctx.x_shape_og),
|
601 |
-
dnorm_weight,
|
602 |
-
dnorm_bias,
|
603 |
-
dlinear_weight,
|
604 |
-
dlinear_bias,
|
605 |
-
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
606 |
-
None,
|
607 |
-
None,
|
608 |
-
None,
|
609 |
-
None,
|
610 |
-
)
|
611 |
-
|
612 |
-
|
613 |
-
def layer_norm_linear_fn(
|
614 |
-
x,
|
615 |
-
norm_weight,
|
616 |
-
norm_bias,
|
617 |
-
linear_weight,
|
618 |
-
linear_bias,
|
619 |
-
residual=None,
|
620 |
-
eps=1e-6,
|
621 |
-
prenorm=False,
|
622 |
-
residual_in_fp32=False,
|
623 |
-
is_rms_norm=False,
|
624 |
-
):
|
625 |
-
return LayerNormLinearFn.apply(
|
626 |
-
x,
|
627 |
-
norm_weight,
|
628 |
-
norm_bias,
|
629 |
-
linear_weight,
|
630 |
-
linear_bias,
|
631 |
-
residual,
|
632 |
-
eps,
|
633 |
-
prenorm,
|
634 |
-
residual_in_fp32,
|
635 |
-
is_rms_norm,
|
636 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|