Upstream builds
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete list.
- build.toml +2 -1
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/__init__.py +0 -14
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py +0 -326
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py +0 -338
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/selective_scan_interface.py +0 -659
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/layer_norm.py +0 -1166
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/selective_state_update.py +0 -389
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_scan.py +0 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_state.py +0 -2012
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_combined.py +0 -1884
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/__init__.py +0 -14
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py +0 -326
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py +0 -338
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/selective_scan_interface.py +0 -659
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/triton/layer_norm.py +0 -1166
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/triton/selective_state_update.py +0 -389
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_scan.py +0 -0
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_state.py +0 -2012
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/triton/ssd_combined.py +0 -1884
- build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/__init__.py +0 -14
- build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/distributed/__init__.py +0 -0
- build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py +0 -326
- build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/models/__init__.py +0 -0
- build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py +0 -338
- build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/modules/__init__.py +0 -0
- build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/__init__.py +0 -0
- build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/selective_scan_interface.py +0 -659
- build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/__init__.py +0 -0
- build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/layer_norm.py +0 -1166
- build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/selective_state_update.py +0 -389
- build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_scan.py +0 -0
- build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_state.py +0 -2012
- build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/ssd_combined.py +0 -1884
- build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/utils/__init__.py +0 -0
- build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/__init__.py +0 -14
- build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/distributed/__init__.py +0 -0
- build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py +0 -326
- build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/models/__init__.py +0 -0
- build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py +0 -338
- build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/modules/__init__.py +0 -0
- build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/__init__.py +0 -0
- build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/selective_scan_interface.py +0 -659
- build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/__init__.py +0 -0
- build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/layer_norm.py +0 -1166
- build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/selective_state_update.py +0 -389
- build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_scan.py +0 -0
- build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_state.py +0 -2012
- build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_combined.py +0 -1884
- build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/utils/__init__.py +0 -0
- build/torch25-cxx98-cu121-x86_64-linux/mamba_ssm/__init__.py +0 -14
build.toml
CHANGED

```diff
@@ -1,6 +1,7 @@
 [general]
 name = "mamba_ssm"
-
+backends = ["cuda"]
+python-depends = ["einops"]
 
 [torch]
 src = [
```
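The manifest change declares a CUDA backend and an `einops` Python-level dependency for the kernel builder; the deletions above drop the checked-in `build/` outputs, which are now produced upstream. As a minimal sketch of how a kernel packaged this way is typically fetched at runtime (assuming the Hugging Face `kernels` loader; the repository id below is hypothetical):

```python
# Minimal sketch, not part of this commit: loading a pre-built kernel with the
# Hugging Face `kernels` library. The repository id here is hypothetical.
from kernels import get_kernel

mamba_ssm = get_kernel("kernels-community/mamba-ssm")  # fetches the build matching torch/CUDA
scan = mamba_ssm.selective_scan_fn                     # exported names become module attributes
```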
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/__init__.py
DELETED (@@ -1,14 +0,0 @@)

```python
__version__ = "2.2.4"

from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from .modules.mamba_simple import Mamba
from .modules.mamba2 import Mamba2
from .models.mixer_seq_simple import MambaLMHeadModel

__all__ = [
    "selective_scan_fn",
    "mamba_inner_fn",
    "Mamba",
    "Mamba2",
    "MambaLMHeadModel",
]
```
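The package `__init__` pins the version and re-exports the public API. A quick sanity check of that surface (a sketch, assuming an environment where the built package imports cleanly):

```python
# Sketch: inspecting the public surface re-exported by the __init__ above
# (assumes a working CUDA build of the package is on the Python path).
import mamba_ssm

print(mamba_ssm.__version__)  # "2.2.4"
print(mamba_ssm.__all__)      # ['selective_scan_fn', 'mamba_inner_fn', 'Mamba', 'Mamba2', 'MambaLMHeadModel']
```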
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py
DELETED (@@ -1,326 +0,0 @@)

```python
# Copyright (c) 2024, Tri Dao.
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from ..utils.torch import custom_bwd, custom_fwd

from einops import rearrange

from ..distributed.distributed_utils import (
    all_gather_raw,
    all_reduce,
    all_reduce_raw,
    reduce_scatter,
    reduce_scatter_raw,
)


class ParallelLinearFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
        """
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
            bias = (
                bias.to(dtype=torch.get_autocast_gpu_dtype())
                if bias is not None
                else None
            )
        weight = weight.contiguous()
        if process_group is not None and sequence_parallel:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        output = F.linear(total_x, weight, bias)
        if ctx.compute_weight_gradient:
            ctx.save_for_backward(x, weight)
        else:
            ctx.save_for_backward(weight)
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        if ctx.compute_weight_gradient:
            x, weight = ctx.saved_tensors
            if process_group is not None and sequence_parallel:
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            else:
                total_x = x
        else:
            (weight,) = ctx.saved_tensors
            total_x = None
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            grad_input = F.linear(grad_output, weight.t())
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(
                    grad_input, process_group, async_op=True
                )
        else:
            grad_input = None
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            if process_group is not None and sequence_parallel:
                handle_x.wait()
            grad_weight = torch.einsum(
                "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
            )
        else:
            grad_weight = None
        grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None


def parallel_linear_func(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True,
):
    return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)


class ColumnParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        if out_features % multiple_of:
            raise ValueError(
                f"out_features ({out_features}) must be a multiple of {multiple_of}"
            )
        multiple = out_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
        super().__init__(
            in_features,
            local_multiple * multiple_of,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
        # we do an all_gather of x before doing the matmul.
        # If not, then the input is already gathered.
        return parallel_linear_func(
            x,
            self.weight,
            self.bias,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel,
        )


class RowParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        rank = torch.distributed.get_rank(process_group)
        if in_features % multiple_of:
            raise ValueError(
                f"in_features ({in_features}) must be a multiple of {multiple_of}"
            )
        multiple = in_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
        # Only rank 0 will have bias
        super().__init__(
            local_multiple * multiple_of,
            out_features,
            bias=bias and rank == 0,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        """
        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
        a reduce_scatter of the result.
        """
        out = parallel_linear_func(x, self.weight, self.bias)
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)


class VocabParallelEmbedding(nn.Embedding):
    def __init__(
        self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs
    ):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if num_embeddings % world_size != 0:
                raise ValueError(
                    f"num_embeddings ({num_embeddings}) must be divisible by "
                    f"world_size ({world_size})"
                )
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")
        else:
            world_size = 1
        super().__init__(
            num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        if self.process_group is None:
            return super().forward(input)
        else:
            rank = torch.distributed.get_rank(self.process_group)
            vocab_size = self.num_embeddings
            vocab_start_index, vocab_end_index = (
                rank * vocab_size,
                (rank + 1) * vocab_size,
            )
            # Create a mask of valid vocab ids (1 means it needs to be masked).
            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
            input = input - vocab_start_index
            input[input_ids_mask] = 0
            embeddings = super().forward(input)
            embeddings[input_ids_mask] = 0.0
            return embeddings


class ColumnParallelEmbedding(nn.Embedding):
    def __init__(
        self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs
    ):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if embedding_dim % world_size != 0:
                raise ValueError(
                    f"embedding_dim ({embedding_dim}) must be divisible by "
                    f"world_size ({world_size})"
                )
        else:
            world_size = 1
        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)


class ParallelEmbeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        process_group,
        padding_idx=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size,
            embed_dim,
            padding_idx=padding_idx,
            process_group=process_group,
            **factory_kwargs,
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = ColumnParallelEmbedding(
                max_position_embeddings,
                embed_dim,
                process_group=process_group,
                **factory_kwargs,
            )

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(
                    seqlen, dtype=torch.long, device=input_ids.device
                )
            position_embeddings = self.position_embeddings(position_ids)
            if world_size <= 1:
                embeddings = embeddings + position_embeddings
            else:
                partition_dim = self.position_embeddings.embedding_dim
                rank = torch.distributed.get_rank(self.process_group)
                embeddings[
                    ..., rank * partition_dim : (rank + 1) * partition_dim
                ] += position_embeddings
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, "b s d -> (b s) d")
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return (
            embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
        )
```
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py
DELETED (@@ -1,338 +0,0 @@)

```python
# Copyright (c) 2023, Albert Gu, Tri Dao.

import math
from functools import partial
import json
import os
import copy

from collections import namedtuple

import torch
import torch.nn as nn

from .config_mamba import MambaConfig
from ..modules.mamba_simple import Mamba
from ..modules.mamba2 import Mamba2
from ..modules.mha import MHA
from ..modules.mlp import GatedMLP
from ..modules.block import Block
from ..utils.generation import GenerationMixin
from ..utils.hf import load_config_hf, load_state_dict_hf

try:
    from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


def create_block(
    d_model,
    d_intermediate,
    ssm_cfg=None,
    attn_layer_idx=None,
    attn_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    if ssm_cfg is None:
        ssm_cfg = {}
    if attn_layer_idx is None:
        attn_layer_idx = []
    if attn_cfg is None:
        attn_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    if layer_idx not in attn_layer_idx:
        # Create a copy of the config to modify
        ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
        ssm_layer = ssm_cfg.pop("layer", "Mamba1")
        if ssm_layer not in ["Mamba1", "Mamba2"]:
            raise ValueError(
                f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2"
            )
        mixer_cls = partial(
            Mamba2 if ssm_layer == "Mamba2" else Mamba,
            layer_idx=layer_idx,
            **ssm_cfg,
            **factory_kwargs,
        )
    else:
        mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    if d_intermediate == 0:
        mlp_cls = nn.Identity
    else:
        mlp_cls = partial(
            GatedMLP,
            hidden_features=d_intermediate,
            out_features=d_model,
            **factory_kwargs,
        )
    block = Block(
        d_model,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)


class MixerModel(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_layer: int,
        d_intermediate: int,
        vocab_size: int,
        ssm_cfg=None,
        attn_layer_idx=None,
        attn_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        self.fused_add_norm = fused_add_norm
        if self.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    d_intermediate=d_intermediate,
                    ssm_cfg=ssm_cfg,
                    attn_layer_idx=attn_layer_idx,
                    attn_cfg=attn_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
                n_residuals_per_layer=(
                    1 if d_intermediate == 0 else 2
                ),  # 2 if we have MLP
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(
                batch_size, max_seqlen, dtype=dtype, **kwargs
            )
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, inference_params=None, **mixer_kwargs):
        hidden_states = self.embedding(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states,
                residual,
                inference_params=inference_params,
                **mixer_kwargs,
            )
        if not self.fused_add_norm:
            residual = (
                (hidden_states + residual) if residual is not None else hidden_states
            )
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm=False here since we don't need the residual
            hidden_states = layer_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm_f, RMSNorm),
            )
        return hidden_states


class MambaLMHeadModel(nn.Module, GenerationMixin):

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
    ) -> None:
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        d_intermediate = config.d_intermediate
        vocab_size = config.vocab_size
        ssm_cfg = config.ssm_cfg
        attn_layer_idx = config.attn_layer_idx
        attn_cfg = config.attn_cfg
        rms_norm = config.rms_norm
        residual_in_fp32 = config.residual_in_fp32
        fused_add_norm = config.fused_add_norm
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (
                vocab_size % pad_vocab_size_multiple
            )
        self.backbone = MixerModel(
            d_model=d_model,
            n_layer=n_layer,
            d_intermediate=d_intermediate,
            vocab_size=vocab_size,
            ssm_cfg=ssm_cfg,
            attn_layer_idx=attn_layer_idx,
            attn_cfg=attn_cfg,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        if self.config.tie_embeddings:
            self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.backbone.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self,
        input_ids,
        position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        **mixer_kwargs,
    ):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        hidden_states = self.backbone(
            input_ids, inference_params=inference_params, **mixer_kwargs
        )
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)
        model = cls(config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(
            load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
        )
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory.
        """
        # Ensure save_directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f, indent=4)
```
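For reference, a small end-to-end sketch of the classes above. It assumes `MambaConfig` is a dataclass with defaults for the fields read in `__init__` (the config module itself is not part of this diff) and a CUDA device:

```python
# Sketch, not part of this diff: building a tiny LM from MixerModel/MambaLMHeadModel.
# Assumption: MambaConfig accepts these fields as keyword arguments with sane defaults.
import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

config = MambaConfig(d_model=256, n_layer=4, vocab_size=1000)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.float16)

input_ids = torch.randint(0, 1000, (2, 128), device="cuda")
logits = model(input_ids).logits  # forward returns CausalLMOutput(logits=...)
print(logits.shape)               # (2, 128, vocab size padded to pad_vocab_size_multiple)
```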
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/selective_scan_interface.py
DELETED (@@ -1,659 +0,0 @@)

```python
# Copyright (c) 2023, Tri Dao, Albert Gu.

import torch
import torch.nn.functional as F
from ..utils.torch import custom_fwd, custom_bwd

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn
    import causal_conv1d_cuda
except ImportError:
    causal_conv1d_fn = None
    causal_conv1d_cuda = None

from .triton.layer_norm import _layer_norm_fwd

from .._ops import ops


class SelectiveScanFn(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        u,
        delta,
        A,
        B,
        C,
        D=None,
        z=None,
        delta_bias=None,
        delta_softplus=False,
        return_last_state=False,
    ):
        if u.stride(-1) != 1:
            u = u.contiguous()
        if delta.stride(-1) != 1:
            delta = delta.contiguous()
        if D is not None:
            D = D.contiguous()
        if B.stride(-1) != 1:
            B = B.contiguous()
        if C.stride(-1) != 1:
            C = C.contiguous()
        if z is not None and z.stride(-1) != 1:
            z = z.contiguous()
        if B.dim() == 3:
            B = rearrange(B, "b dstate l -> b 1 dstate l")
            ctx.squeeze_B = True
        if C.dim() == 3:
            C = rearrange(C, "b dstate l -> b 1 dstate l")
            ctx.squeeze_C = True
        out, x, *rest = ops.selective_scan_fwd(
            u, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.has_z = z is not None
        last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
        if not ctx.has_z:
            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
            return out if not return_last_state else (out, last_state)
        else:
            ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
            out_z = rest[0]
            return out_z if not return_last_state else (out_z, last_state)

    @staticmethod
    def backward(ctx, dout, *args):
        if not ctx.has_z:
            u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
            z = None
            out = None
        else:
            u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        # Here we just pass in None and dz will be allocated in the C++ code.
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = ops.selective_scan_bwd(
            u,
            delta,
            A,
            B,
            C,
            D,
            z,
            delta_bias,
            dout,
            x,
            out,
            None,
            ctx.delta_softplus,
            False,  # option to recompute out_z, not used here
        )
        dz = rest[0] if ctx.has_z else None
        dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
        dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
        return (
            du,
            ddelta,
            dA,
            dB,
            dC,
            dD if D is not None else None,
            dz,
            ddelta_bias if delta_bias is not None else None,
            None,
            None,
        )


def rms_norm_forward(
    x,
    weight,
    bias,
    eps=1e-6,
    is_rms_norm=True,
):
    # x (b l) d
    if x.stride(-1) != 1:
        x = x.contiguous()
    weight = weight.contiguous()
    if bias is not None:
        bias = bias.contiguous()
    y = _layer_norm_fwd(
        x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm
    )[0]
    # y (b l) d
    return y


def selective_scan_fn(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    delta_softplus=False,
    return_last_state=False,
):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return SelectiveScanFn.apply(
        u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state
    )


def selective_scan_ref(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    delta_softplus=False,
    return_last_state=False,
):
    """
    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    out: r(B D L)
    last_state (optional): r(B D dstate) or c(B D dstate)
    """
    dtype_in = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    if A.is_complex():
        if is_variable_B:
            B = torch.view_as_complex(
                rearrange(B.float(), "... (L two) -> ... L two", two=2)
            )
        if is_variable_C:
            C = torch.view_as_complex(
                rearrange(C.float(), "... (L two) -> ... L two", two=2)
            )
    else:
        B = B.float()
        C = C.float()
    x = A.new_zeros((batch, dim, dstate))
    ys = []
    deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
    if not is_variable_B:
        deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
    else:
        if B.dim() == 3:
            deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
        else:
            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
            deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
    if is_variable_C and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
    last_state = None
    for i in range(u.shape[2]):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = torch.einsum("bdn,dn->bd", x, C)
        else:
            if C.dim() == 3:
                y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
            else:
                y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
        if i == u.shape[2] - 1:
            last_state = x
        if y.is_complex():
            y = y.real * 2
        ys.append(y)
    y = torch.stack(ys, dim=2)  # (batch dim L)
    out = y if D is None else y + u * rearrange(D, "d -> d 1")
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)
    return out if not return_last_state else (out, last_state)


class MambaInnerFn(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        xz,
        conv1d_weight,
        conv1d_bias,
        x_proj_weight,
        delta_proj_weight,
        out_proj_weight,
        out_proj_bias,
        A,
        B=None,
        C=None,
        D=None,
        delta_bias=None,
        B_proj_bias=None,
        C_proj_bias=None,
        delta_softplus=True,
        checkpoint_lvl=1,
        b_rms_weight=None,
        c_rms_weight=None,
        dt_rms_weight=None,
        b_c_dt_rms_eps=1e-6,
    ):
        """
        xz: (batch, dim, seqlen)
        """
        assert (
            causal_conv1d_cuda is not None
        ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if torch.is_autocast_enabled():
            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            delta_proj_weight = delta_proj_weight.to(
                dtype=torch.get_autocast_gpu_dtype()
            )
            out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_bias = (
                out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
                if out_proj_bias is not None
                else None
            )
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
            x, conv1d_weight, conv1d_bias, None, None, None, True
        )
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(
            rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight
        )  # (bl d)
        delta = rearrange(
            delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
        )
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:  # variable B
            B = x_dbl[:, delta_rank : delta_rank + d_state]  # (bl dstate)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                B = rearrange(
                    B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
                ).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl dstate)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                C = rearrange(
                    C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
                ).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()

        if b_rms_weight is not None:
            B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
            B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
            B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
        if c_rms_weight is not None:
            C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
            C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
            C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
        if dt_rms_weight is not None:
            delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
            delta = rms_norm_forward(
                delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps
            )
            delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()

        out, scan_intermediates, out_z = ops.selective_scan_fwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.out_proj_bias_is_None = out_proj_bias is None
        ctx.checkpoint_lvl = checkpoint_lvl
        ctx.b_rms_weight = b_rms_weight
        ctx.c_rms_weight = c_rms_weight
        ctx.dt_rms_weight = dt_rms_weight
        ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
        if (
            checkpoint_lvl >= 1
        ):  # Will recompute conv1d_out and delta in the backward pass
            conv1d_out, delta = None, None
        ctx.save_for_backward(
            xz,
            conv1d_weight,
            conv1d_bias,
            x_dbl,
            x_proj_weight,
            delta_proj_weight,
            out_proj_weight,
            conv1d_out,
            delta,
            A,
            B,
            C,
            D,
            delta_bias,
            scan_intermediates,
            b_rms_weight,
            c_rms_weight,
            dt_rms_weight,
            out,
        )
        return F.linear(
            rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias
        )

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        # dout: (batch, seqlen, dim)
        assert (
            causal_conv1d_cuda is not None
        ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
        (
            xz,
            conv1d_weight,
            conv1d_bias,
            x_dbl,
            x_proj_weight,
            delta_proj_weight,
            out_proj_weight,
            conv1d_out,
            delta,
            A,
            B,
            C,
            D,
            delta_bias,
            scan_intermediates,
            b_rms_weight,
            c_rms_weight,
            dt_rms_weight,
            out,
        ) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if ctx.checkpoint_lvl == 1:
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
                x, conv1d_weight, conv1d_bias, None, None, None, True
            )
            delta = rearrange(
                delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
            )
            if dt_rms_weight is not None:
                delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
                delta = rms_norm_forward(
                    delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps
                )
                delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
            if b_rms_weight is not None:
                # Recompute & RMSNorm B
                B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
                B = rms_norm_forward(B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps)
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            if c_rms_weight is not None:
                # Recompute & RMSNorm C
                C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
                C = rms_norm_forward(C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps)
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()

        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
        dx, dz = dxz.chunk(2, dim=1)
        dout = rearrange(dout, "b l e -> e (b l)")
        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = (
            ops.selective_scan_bwd(
                conv1d_out,
                delta,
                A,
                B,
                C,
                D,
                z,
                delta_bias,
                dout_y,
                scan_intermediates,
                out,
                dz,
                ctx.delta_softplus,
                True,  # option to recompute out_z
            )
        )
        dout_proj_weight = torch.einsum(
            "eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")
        )
        dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
        dD = dD if D is not None else None
        dx_dbl = torch.empty_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(
                    dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
                ).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank : delta_rank + d_state] = dB  # (bl d)
            dB = None
        dC_proj_bias = None
        if ctx.is_variable_C:
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(
                    dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
                ).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC  # (bl d)
            dC = None
        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum(
            "Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")
        )
        dconv1d_out = torch.addmm(
            dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out
        )
        dconv1d_out = rearrange(
            dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]
        )
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
            x,
            conv1d_weight,
            conv1d_bias,
            dconv1d_out,
            None,
            None,
            None,
            dx,
            False,
            True,
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        return (
            dxz,
            dconv1d_weight,
            dconv1d_bias,
            dx_proj_weight,
            ddelta_proj_weight,
            dout_proj_weight,
            dout_proj_bias,
            dA,
            dB,
            dC,
            dD,
            ddelta_bias if delta_bias is not None else None,
            # 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
            dB_proj_bias,
            dC_proj_bias,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def mamba_inner_fn(
    xz,
    conv1d_weight,
    conv1d_bias,
    x_proj_weight,
    delta_proj_weight,
    out_proj_weight,
    out_proj_bias,
    A,
    B=None,
    C=None,
    D=None,
    delta_bias=None,
    B_proj_bias=None,
    C_proj_bias=None,
    delta_softplus=True,
    checkpoint_lvl=1,
    b_rms_weight=None,
    c_rms_weight=None,
    dt_rms_weight=None,
    b_c_dt_rms_eps=1e-6,
):
    return MambaInnerFn.apply(
        xz,
        conv1d_weight,
        conv1d_bias,
        x_proj_weight,
        delta_proj_weight,
        out_proj_weight,
        out_proj_bias,
        A,
        B,
        C,
        D,
        delta_bias,
        B_proj_bias,
        C_proj_bias,
        delta_softplus,
        checkpoint_lvl,
        b_rms_weight,
        c_rms_weight,
        dt_rms_weight,
        b_c_dt_rms_eps,
    )


def mamba_inner_ref(
    xz,
    conv1d_weight,
    conv1d_bias,
    x_proj_weight,
    delta_proj_weight,
    out_proj_weight,
    out_proj_bias,
    A,
    B=None,
    C=None,
    D=None,
    delta_bias=None,
    B_proj_bias=None,
    C_proj_bias=None,
    delta_softplus=True,
):
    assert (
        causal_conv1d_fn is not None
    ), "causal_conv1d_fn is not available. Please install causal-conv1d."
    L = xz.shape[-1]
    delta_rank = delta_proj_weight.shape[1]
    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
    x, z = xz.chunk(2, dim=1)
    x = causal_conv1d_fn(
        x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu"
    )
    # We're being very careful here about the layout, to avoid extra transposes.
    # We want delta to have d as the slowest moving dimension
    # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
    x_dbl = F.linear(rearrange(x, "b d l -> (b l) d"), x_proj_weight)  # (bl d)
    delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
    delta = rearrange(delta, "d (b l) -> b d l", l=L)
    if B is None:  # variable B
        B = x_dbl[:, delta_rank : delta_rank + d_state]  # (bl d)
        if B_proj_bias is not None:
            B = B + B_proj_bias.to(dtype=B.dtype)
        if not A.is_complex():
            B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            B = rearrange(
                B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
            ).contiguous()
    if C is None:  # variable C
        C = x_dbl[:, -d_state:]  # (bl d)
        if C_proj_bias is not None:
            C = C + C_proj_bias.to(dtype=C.dtype)
        if not A.is_complex():
            C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            C = rearrange(
                C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
            ).contiguous()
    y = selective_scan_fn(
        x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True
    )
    return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
```
|
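For orientation, the deleted `mamba_inner_fn` above fuses the causal conv1d, the x/delta projections, the selective scan, and the output projection into one autograd Function, with `mamba_inner_ref` as its pure-PyTorch reference. Below is a minimal usage sketch; the import path mirrors this build tree, the shapes follow the einops patterns in the code, but the sizes, device, and random weights are illustrative assumptions, not values taken from this repo.

# Hypothetical sizes, chosen only to satisfy the shape contracts visible above.
import torch
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn

batch, seqlen = 2, 128
d_model, d_inner, d_state, d_conv, dt_rank = 128, 256, 16, 4, 8
dev = "cuda"  # the fused path needs the CUDA and causal-conv1d extensions

xz = torch.randn(batch, 2 * d_inner, seqlen, device=dev, requires_grad=True)
conv1d_weight = torch.randn(d_inner, 1, d_conv, device=dev, requires_grad=True)
conv1d_bias = torch.randn(d_inner, device=dev, requires_grad=True)
x_proj_weight = torch.randn(dt_rank + 2 * d_state, d_inner, device=dev, requires_grad=True)
delta_proj_weight = torch.randn(d_inner, dt_rank, device=dev, requires_grad=True)
out_proj_weight = torch.randn(d_model, d_inner, device=dev, requires_grad=True)
A = (-torch.rand(d_inner, d_state, device=dev)).requires_grad_()  # negative for a stable scan

out = mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, None,  # out_proj_bias
    A, B=None, C=None,      # B/C left None: projected from x_dbl inside the function
    D=torch.ones(d_inner, device=dev, requires_grad=True),
    delta_bias=torch.zeros(d_inner, device=dev, requires_grad=True),
    delta_softplus=True,
)
out.sum().backward()  # exercises the hand-written backward reconstructed above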
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/layer_norm.py DELETED
@@ -1,1166 +0,0 @@
# Copyright (c) 2024, Tri Dao.
# Implement dropout + residual + layer_norm / rms_norm.

# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.

import math
import warnings

import torch
import torch.nn.functional as F
from ...utils.torch import custom_bwd, custom_fwd

import triton
import triton.language as tl


def layer_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = bias.float() if bias is not None else None
        residual = residual.float() if residual is not None else residual
        x1 = x1.float() if x1 is not None else None
        weight1 = weight1.float() if weight1 is not None else None
        bias1 = bias1.float() if bias1 is not None else None
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
        if x1 is not None:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    out = F.layer_norm(
        x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
    ).to(dtype)
    if weight1 is None:
        return out if not prenorm else (out, x)
    else:
        out1 = F.layer_norm(
            x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
        ).to(dtype)
        return (out, out1) if not prenorm else (out, out1, x)


def rms_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = bias.float() if bias is not None else None
        residual = residual.float() if residual is not None else residual
        x1 = x1.float() if x1 is not None else None
        weight1 = weight1.float() if weight1 is not None else None
        bias1 = bias1.float() if bias1 is not None else None
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
        if x1 is not None:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
    out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(
        dtype
    )
    if weight1 is None:
        return out if not prenorm else (out, x)
    else:
        out1 = (
            (x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)
        ).to(dtype)
        return (out, out1) if not prenorm else (out, out1, x)


def config_prune(configs):

    if torch.version.hip:
        try:
            # set warp size based on gcn architecture
            gcn_arch_name = torch.cuda.get_device_properties(0).gcnArchName
            if "gfx10" in gcn_arch_name or "gfx11" in gcn_arch_name:
                # radeon
                warp_size = 32
            else:
                # instinct
                warp_size = 64
        except AttributeError as e:
            # fall back to crude method to set warp size
            device_name = torch.cuda.get_device_properties(0).name
            if "instinct" in device_name.lower():
                warp_size = 64
            else:
                warp_size = 32
            warnings.warn(
                f"{e}, warp size set to {warp_size} based on device name: {device_name}",
                UserWarning,
            )

    else:
        # cuda
        warp_size = 32

    max_block_sz = 1024
    max_num_warps = max_block_sz // warp_size
    pruned_configs = [config for config in configs if config.num_warps <= max_num_warps]
    return pruned_configs


configs_autotune = [
    triton.Config({}, num_warps=1),
    triton.Config({}, num_warps=2),
    triton.Config({}, num_warps=4),
    triton.Config({}, num_warps=8),
    triton.Config({}, num_warps=16),
    triton.Config({}, num_warps=32),
]

pruned_configs_autotune = config_prune(configs_autotune)


@triton.autotune(
    configs=pruned_configs_autotune,
    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    X1,
    W1,
    B1,
    Y1,
    RESIDUAL_OUT,  # pointer to the residual
    ROWSCALE,
    SEEDS,  # Dropout seeds for each row
    DROPOUT_MASK,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    stride_x1_row,
    stride_y1_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,  # Dropout probability
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    STORE_DROPOUT_MASK: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_X1: tl.constexpr,
    HAS_W1: tl.constexpr,
    HAS_B1: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    if HAS_X1:
        X1 += row * stride_x1_row
    if HAS_W1:
        Y1 += row * stride_y1_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_ROWSCALE:
        rowscale = tl.load(ROWSCALE + row).to(tl.float32)
        x *= rowscale
    if HAS_DROPOUT:
        # Compute dropout mask
        # 7 rounds is good enough, and reduces register pressure
        keep_mask = (
            tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
        )
        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
        if STORE_DROPOUT_MASK:
            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
    if HAS_X1:
        x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
            x1 *= rowscale
        if HAS_DROPOUT:
            # Compute dropout mask
            # 7 rounds is good enough, and reduces register pressure
            keep_mask = (
                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
                > dropout_p
            )
            x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
            if STORE_DROPOUT_MASK:
                tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
        x += x1
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
    if STORE_RESIDUAL_OUT:
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    # Write output
    tl.store(Y + cols, y, mask=mask)
    if HAS_W1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
        if HAS_B1:
            b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
        y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
        tl.store(Y1 + cols, y1, mask=mask)


def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    dropout_p=0.0,
    rowscale=None,
    out_dtype=None,
    residual_dtype=None,
    is_rms_norm=False,
    return_dropout_mask=False,
):
    if residual is not None:
        residual_dtype = residual.dtype
    M, N = x.shape
    assert x.stride(-1) == 1
    if residual is not None:
        assert residual.stride(-1) == 1
        assert residual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if x1 is not None:
        assert x1.shape == x.shape
        assert rowscale is None
        assert x1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    assert y.stride(-1) == 1
    if weight1 is not None:
        y1 = torch.empty_like(y)
        assert y1.stride(-1) == 1
    else:
        y1 = None
    if (
        residual is not None
        or (residual_dtype is not None and residual_dtype != x.dtype)
        or dropout_p > 0.0
        or rowscale is not None
        or x1 is not None
    ):
        residual_out = torch.empty(
            M,
            N,
            device=x.device,
            dtype=residual_dtype if residual_dtype is not None else x.dtype,
        )
        assert residual_out.stride(-1) == 1
    else:
        residual_out = None
    mean = (
        torch.empty((M,), dtype=torch.float32, device=x.device)
        if not is_rms_norm
        else None
    )
    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
    if dropout_p > 0.0:
        seeds = torch.randint(
            2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
        )
    else:
        seeds = None
    if return_dropout_mask and dropout_p > 0.0:
        dropout_mask = torch.empty(
            M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool
        )
    else:
        dropout_mask = None
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[(M,)](
            x,
            y,
            weight,
            bias,
            residual,
            x1,
            weight1,
            bias1,
            y1,
            residual_out,
            rowscale,
            seeds,
            dropout_mask,
            mean,
            rstd,
            x.stride(0),
            y.stride(0),
            residual.stride(0) if residual is not None else 0,
            residual_out.stride(0) if residual_out is not None else 0,
            x1.stride(0) if x1 is not None else 0,
            y1.stride(0) if y1 is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            is_rms_norm,
            BLOCK_N,
            residual is not None,
            residual_out is not None,
            bias is not None,
            dropout_p > 0.0,
            dropout_mask is not None,
            rowscale is not None,
        )
    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
    if dropout_mask is not None and x1 is not None:
        dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
    else:
        dropout_mask1 = None
    return (
        y,
        y1,
        mean,
        rstd,
        residual_out if residual_out is not None else x,
        seeds,
        dropout_mask,
        dropout_mask1,
    )


@triton.autotune(
    configs=pruned_configs_autotune,
    key=[
        "N",
        "HAS_DRESIDUAL",
        "STORE_DRESIDUAL",
        "IS_RMS_NORM",
        "HAS_BIAS",
        "HAS_DROPOUT",
    ],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
    X,  # pointer to the input
    W,  # pointer to the weights
    B,  # pointer to the biases
    Y,  # pointer to the output to be recomputed
    DY,  # pointer to the output gradient
    DX,  # pointer to the input gradient
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    DRESIDUAL,
    W1,
    DY1,
    DX1,
    DW1,
    DB1,
    DRESIDUAL_IN,
    ROWSCALE,
    SEEDS,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_dy_row,
    stride_dx_row,
    stride_dres_row,
    stride_dy1_row,
    stride_dx1_row,
    stride_dres_in_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,
    rows_per_program,
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_DRESIDUAL: tl.constexpr,
    STORE_DRESIDUAL: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_DY1: tl.constexpr,
    HAS_DX1: tl.constexpr,
    HAS_B1: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Map the program id to the elements of X, DX, and DY it should compute.
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    # Do not early exit if row_start >= M, because we need to write DW and DB
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    X += row_start * stride_x_row
    if HAS_DRESIDUAL:
        DRESIDUAL += row_start * stride_dres_row
    if STORE_DRESIDUAL:
        DRESIDUAL_IN += row_start * stride_dres_in_row
    DY += row_start * stride_dy_row
    DX += row_start * stride_dx_row
    if HAS_DY1:
        DY1 += row_start * stride_dy1_row
    if HAS_DX1:
        DX1 += row_start * stride_dx1_row
    if RECOMPUTE_OUTPUT:
        Y += row_start * stride_y_row
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if RECOMPUTE_OUTPUT and HAS_BIAS:
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
    if HAS_DY1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_BIAS:
        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_DY1:
        dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
        if HAS_B1:
            db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    row_end = min((row_block_id + 1) * rows_per_program, M)
    for row in range(row_start, row_end):
        # Load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
        if HAS_DY1:
            dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
        if not IS_RMS_NORM:
            mean = tl.load(Mean + row)
        rstd = tl.load(Rstd + row)
        # Compute dx
        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        xhat = tl.where(mask, xhat, 0.0)
        if RECOMPUTE_OUTPUT:
            y = xhat * w + b if HAS_BIAS else xhat * w
            tl.store(Y + cols, y, mask=mask)
        wdy = w * dy
        dw += dy * xhat
        if HAS_BIAS:
            db += dy
        if HAS_DY1:
            wdy += w1 * dy1
            dw1 += dy1 * xhat
            if HAS_B1:
                db1 += dy1
        if not IS_RMS_NORM:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            c2 = tl.sum(wdy, axis=0) / N
            dx = (wdy - (xhat * c1 + c2)) * rstd
        else:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            dx = (wdy - xhat * c1) * rstd
        if HAS_DRESIDUAL:
            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
            dx += dres
        # Write dx
        if STORE_DRESIDUAL:
            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
        if HAS_DX1:
            if HAS_DROPOUT:
                keep_mask = (
                    tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
                    > dropout_p
                )
                dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
            else:
                dx1 = dx
            tl.store(DX1 + cols, dx1, mask=mask)
        if HAS_DROPOUT:
            keep_mask = (
                tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7)
                > dropout_p
            )
            dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + row).to(tl.float32)
            dx *= rowscale
        tl.store(DX + cols, dx, mask=mask)

        X += stride_x_row
        if HAS_DRESIDUAL:
            DRESIDUAL += stride_dres_row
        if STORE_DRESIDUAL:
            DRESIDUAL_IN += stride_dres_in_row
        if RECOMPUTE_OUTPUT:
            Y += stride_y_row
        DY += stride_dy_row
        DX += stride_dx_row
        if HAS_DY1:
            DY1 += stride_dy1_row
        if HAS_DX1:
            DX1 += stride_dx1_row
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
    if HAS_BIAS:
        tl.store(DB + row_block_id * N + cols, db, mask=mask)
    if HAS_DY1:
        tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
        if HAS_B1:
            tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)


def _layer_norm_bwd(
    dy,
    x,
    weight,
    bias,
    eps,
    mean,
    rstd,
    dresidual=None,
    dy1=None,
    weight1=None,
    bias1=None,
    seeds=None,
    dropout_p=0.0,
    rowscale=None,
    has_residual=False,
    has_x1=False,
    is_rms_norm=False,
    x_dtype=None,
    recompute_output=False,
):
    M, N = x.shape
    assert x.stride(-1) == 1
    assert dy.stride(-1) == 1
    assert dy.shape == (M, N)
    if dresidual is not None:
        assert dresidual.stride(-1) == 1
        assert dresidual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if dy1 is not None:
        assert weight1 is not None
        assert dy1.shape == dy.shape
        assert dy1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if seeds is not None:
        assert seeds.is_contiguous()
        assert seeds.shape == (M if not has_x1 else M * 2,)
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    dx = (
        torch.empty_like(x)
        if x_dtype is None
        else torch.empty(M, N, dtype=x_dtype, device=x.device)
    )
    dresidual_in = (
        torch.empty_like(x)
        if has_residual
        and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
        else None
    )
    dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
    y = (
        torch.empty(M, N, dtype=dy.dtype, device=dy.device)
        if recompute_output
        else None
    )
    if recompute_output:
        assert (
            weight1 is None
        ), "recompute_output is not supported with parallel LayerNorm"

    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
    _db = (
        torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
        if bias is not None
        else None
    )
    _dw1 = torch.empty_like(_dw) if weight1 is not None else None
    _db1 = torch.empty_like(_db) if bias1 is not None else None
    rows_per_program = math.ceil(M / sm_count)
    grid = (sm_count,)
    with torch.cuda.device(x.device.index):
        _layer_norm_bwd_kernel[grid](
            x,
            weight,
            bias,
            y,
            dy,
            dx,
            _dw,
            _db,
            dresidual,
            weight1,
            dy1,
            dx1,
            _dw1,
            _db1,
            dresidual_in,
            rowscale,
            seeds,
            mean,
            rstd,
            x.stride(0),
            0 if not recompute_output else y.stride(0),
            dy.stride(0),
            dx.stride(0),
            dresidual.stride(0) if dresidual is not None else 0,
            dy1.stride(0) if dy1 is not None else 0,
            dx1.stride(0) if dx1 is not None else 0,
            dresidual_in.stride(0) if dresidual_in is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            rows_per_program,
            is_rms_norm,
            BLOCK_N,
            dresidual is not None,
            dresidual_in is not None,
            bias is not None,
            dropout_p > 0.0,
        )
    dw = _dw.sum(0).to(weight.dtype)
    db = _db.sum(0).to(bias.dtype) if bias is not None else None
    dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
    db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
    # Don't need to compute dresidual_in separately in this case
    if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
        dresidual_in = dx
    if has_x1 and dropout_p == 0.0:
        dx1 = dx
    return (
        (dx, dw, db, dresidual_in, dx1, dw1, db1)
        if not recompute_output
        else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
    )


class LayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        residual=None,
        x1=None,
        weight1=None,
        bias1=None,
        eps=1e-6,
        dropout_p=0.0,
        rowscale=None,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
        return_dropout_mask=False,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        if x1 is not None:
            assert x1.shape == x_shape_og
            assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
            x1 = x1.reshape(-1, x1.shape[-1])
            if x1.stride(-1) != 1:
                x1 = x1.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        if weight1 is not None:
            weight1 = weight1.contiguous()
        if bias1 is not None:
            bias1 = bias1.contiguous()
        if rowscale is not None:
            rowscale = rowscale.reshape(-1).contiguous()
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = (
            _layer_norm_fwd(
                x,
                weight,
                bias,
                eps,
                residual,
                x1,
                weight1,
                bias1,
                dropout_p=dropout_p,
                rowscale=rowscale,
                residual_dtype=residual_dtype,
                is_rms_norm=is_rms_norm,
                return_dropout_mask=return_dropout_mask,
            )
        )
        ctx.save_for_backward(
            residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
        )
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.dropout_p = dropout_p
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.has_x1 = x1 is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        y = y.reshape(x_shape_og)
        y1 = y1.reshape(x_shape_og) if y1 is not None else None
        residual_out = (
            residual_out.reshape(x_shape_og) if residual_out is not None else None
        )
        dropout_mask = (
            dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
        )
        dropout_mask1 = (
            dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
        )
        if not return_dropout_mask:
            if weight1 is None:
                return y if not prenorm else (y, residual_out)
            else:
                return (y, y1) if not prenorm else (y, y1, residual_out)
        else:
            if weight1 is None:
                return (
                    (y, dropout_mask, dropout_mask1)
                    if not prenorm
                    else (y, residual_out, dropout_mask, dropout_mask1)
                )
            else:
                return (
                    (y, y1, dropout_mask, dropout_mask1)
                    if not prenorm
                    else (y, y1, residual_out, dropout_mask, dropout_mask1)
                )

    @staticmethod
    def backward(ctx, dy, *args):
        x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
        dy = dy.reshape(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if weight1 is not None:
            dy1, args = args[0], args[1:]
            dy1 = dy1.reshape(-1, dy1.shape[-1])
            if dy1.stride(-1) != 1:
                dy1 = dy1.contiguous()
            assert dy1.shape == x.shape
        else:
            dy1 = None
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            if dresidual.stride(-1) != 1:
                dresidual = dresidual.contiguous()
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
            dy,
            x,
            weight,
            bias,
            ctx.eps,
            mean,
            rstd,
            dresidual,
            dy1,
            weight1,
            bias1,
            seeds,
            ctx.dropout_p,
            rowscale,
            ctx.has_residual,
            ctx.has_x1,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
        )
        return (
            dx.reshape(ctx.x_shape_og),
            dw,
            db,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
            dw1,
            db1,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def layer_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
    return_dropout_mask=False,
):
    return LayerNormFn.apply(
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
        return_dropout_mask,
    )


def rms_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    return LayerNormFn.apply(
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        True,
        return_dropout_mask,
    )


class RMSNorm(torch.nn.Module):

    def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if dropout_p > 0.0:
            self.drop = torch.nn.Dropout(dropout_p)
        else:
            self.drop = None
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
        return rms_norm_fn(
            x,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
        )


class LayerNormLinearFn(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual=None,
        eps=1e-6,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        norm_weight = norm_weight.contiguous()
        if norm_bias is not None:
            norm_bias = norm_bias.contiguous()
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
            x,
            norm_weight,
            norm_bias,
            eps,
            residual,
            out_dtype=(
                None
                if not torch.is_autocast_enabled()
                else torch.get_autocast_gpu_dtype()
            ),
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
        )
        y = y.reshape(x_shape_og)
        dtype = (
            torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
        )
        linear_weight = linear_weight.to(dtype)
        linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
        out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
        # We don't store y, will be recomputed in the backward pass to save memory
        ctx.save_for_backward(
            residual_out, norm_weight, norm_bias, linear_weight, mean, rstd
        )
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        ctx.linear_bias_is_none = linear_bias is None
        return out if not prenorm else (out, residual_out.reshape(x_shape_og))

    @staticmethod
    @custom_bwd
    def backward(ctx, dout, *args):
        x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
        dout = dout.reshape(-1, dout.shape[-1])
        dy = F.linear(dout, linear_weight.t())
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            if dresidual.stride(-1) != 1:
                dresidual = dresidual.contiguous()
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
            dy,
            x,
            norm_weight,
            norm_bias,
            ctx.eps,
            mean,
            rstd,
            dresidual=dresidual,
            has_residual=ctx.has_residual,
            is_rms_norm=ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
            recompute_output=True,
        )
        dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
        return (
            dx.reshape(ctx.x_shape_og),
            dnorm_weight,
            dnorm_bias,
            dlinear_weight,
            dlinear_bias,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            None,
            None,
            None,
            None,
        )


def layer_norm_linear_fn(
    x,
    norm_weight,
    norm_bias,
    linear_weight,
    linear_bias,
    residual=None,
    eps=1e-6,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    return LayerNormLinearFn.apply(
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual,
        eps,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
    )
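As a companion to the listing above, here is a minimal usage sketch of the fused residual-add + RMSNorm path; the import path mirrors this build tree, while the sizes, dtypes, and random inputs are illustrative assumptions.

# Assumes a CUDA device with Triton available; sizes are arbitrary.
import torch
from mamba_ssm.ops.triton.layer_norm import RMSNorm, rms_norm_fn

hidden = 1024
norm = RMSNorm(hidden, eps=1e-5).cuda()
x = torch.randn(4, 128, hidden, device="cuda", dtype=torch.bfloat16)
residual = torch.randn_like(x)

# Fused residual-add + RMSNorm. With prenorm=True the updated residual stream
# is returned as well, which a pre-norm block threads into the next layer.
y, new_residual = norm(x, residual=residual, prenorm=True, residual_in_fp32=True)

# The functional form exposes the same kernel without the module wrapper.
y2 = rms_norm_fn(x, norm.weight, norm.bias, residual=residual, eps=norm.eps)

Returning the post-add residual from the kernel is what lets a pre-norm stack skip a separate elementwise add per layer and, with residual_in_fp32, keep the residual stream accumulated in float32 even when activations are bf16.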
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/selective_state_update.py
DELETED
@@ -1,389 +0,0 @@
-# Copyright (c) 2024, Tri Dao, Albert Gu.
-
-"""We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
-"""
-
-import math
-import torch
-import torch.nn.functional as F
-
-import triton
-import triton.language as tl
-
-from einops import rearrange, repeat
-
-from .softplus import softplus
-
-
-@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
-@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
-@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
-@triton.heuristics(
-    {
-        "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
-        is not None
-    }
-)
-@triton.heuristics(
-    {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
-)
-@triton.jit
-def _selective_scan_update_kernel(
-    # Pointers to matrices
-    state_ptr,
-    x_ptr,
-    dt_ptr,
-    dt_bias_ptr,
-    A_ptr,
-    B_ptr,
-    C_ptr,
-    D_ptr,
-    z_ptr,
-    out_ptr,
-    state_batch_indices_ptr,
-    # Matrix dimensions
-    batch,
-    nheads,
-    dim,
-    dstate,
-    nheads_ngroups_ratio,
-    # Strides
-    stride_state_batch,
-    stride_state_head,
-    stride_state_dim,
-    stride_state_dstate,
-    stride_x_batch,
-    stride_x_head,
-    stride_x_dim,
-    stride_dt_batch,
-    stride_dt_head,
-    stride_dt_dim,
-    stride_dt_bias_head,
-    stride_dt_bias_dim,
-    stride_A_head,
-    stride_A_dim,
-    stride_A_dstate,
-    stride_B_batch,
-    stride_B_group,
-    stride_B_dstate,
-    stride_C_batch,
-    stride_C_group,
-    stride_C_dstate,
-    stride_D_head,
-    stride_D_dim,
-    stride_z_batch,
-    stride_z_head,
-    stride_z_dim,
-    stride_out_batch,
-    stride_out_head,
-    stride_out_dim,
-    # Meta-parameters
-    DT_SOFTPLUS: tl.constexpr,
-    TIE_HDIM: tl.constexpr,
-    BLOCK_SIZE_M: tl.constexpr,
-    HAS_DT_BIAS: tl.constexpr,
-    HAS_D: tl.constexpr,
-    HAS_Z: tl.constexpr,
-    HAS_STATE_BATCH_INDICES: tl.constexpr,
-    BLOCK_SIZE_DSTATE: tl.constexpr,
-):
-    pid_m = tl.program_id(axis=0)
-    pid_b = tl.program_id(axis=1)
-    pid_h = tl.program_id(axis=2)
-
-    if HAS_STATE_BATCH_INDICES:
-        state_batch_indices_ptr += pid_b
-        state_batch_idx = tl.load(state_batch_indices_ptr)
-        state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
-    else:
-        state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
-
-    x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
-    dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
-    if HAS_DT_BIAS:
-        dt_bias_ptr += pid_h * stride_dt_bias_head
-    A_ptr += pid_h * stride_A_head
-    B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
-    C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
-    if HAS_Z:
-        z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
-    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
-
-    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
-    state_ptrs = state_ptr + (
-        offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
-    )
-    x_ptrs = x_ptr + offs_m * stride_x_dim
-    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
-    if HAS_DT_BIAS:
-        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
-    if HAS_D:
-        D_ptr += pid_h * stride_D_head
-    A_ptrs = A_ptr + (
-        offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
-    )
-    B_ptrs = B_ptr + offs_n * stride_B_dstate
-    C_ptrs = C_ptr + offs_n * stride_C_dstate
-    if HAS_D:
-        D_ptrs = D_ptr + offs_m * stride_D_dim
-    if HAS_Z:
-        z_ptrs = z_ptr + offs_m * stride_z_dim
-    out_ptrs = out_ptr + offs_m * stride_out_dim
-
-    state = tl.load(
-        state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
-    )
-    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
-    if not TIE_HDIM:
-        dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
-        if HAS_DT_BIAS:
-            dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
-        if DT_SOFTPLUS:
-            dt = tl.where(dt <= 20.0, softplus(dt), dt)
-        A = tl.load(
-            A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
-        ).to(tl.float32)
-        dA = tl.exp(A * dt[:, None])
-    else:
-        dt = tl.load(dt_ptr).to(tl.float32)
-        if HAS_DT_BIAS:
-            dt += tl.load(dt_bias_ptr).to(tl.float32)
-        if DT_SOFTPLUS:
-            dt = tl.where(dt <= 20.0, softplus(dt), dt)
-        A = tl.load(A_ptr).to(tl.float32)
-        dA = tl.exp(A * dt)  # scalar, not a matrix
-
-    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
-    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
-    if HAS_D:
-        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
-    if HAS_Z:
-        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
-
-    if not TIE_HDIM:
-        dB = B[None, :] * dt[:, None]
-    else:
-        dB = B * dt  # vector of size (dstate,)
-    state = state * dA + dB * x[:, None]
-    tl.store(
-        state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
-    )
-    out = tl.sum(state * C[None, :], axis=1)
-    if HAS_D:
-        out += x * D
-    if HAS_Z:
-        out *= z * tl.sigmoid(z)
-    tl.store(out_ptrs, out, mask=offs_m < dim)
-
-
-def selective_state_update(
-    state,
-    x,
-    dt,
-    A,
-    B,
-    C,
-    D=None,
-    z=None,
-    dt_bias=None,
-    dt_softplus=False,
-    state_batch_indices=None,
-):
-    """
-    Argument:
-        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
-        x: (batch, dim) or (batch, nheads, dim)
-        dt: (batch, dim) or (batch, nheads, dim)
-        A: (dim, dstate) or (nheads, dim, dstate)
-        B: (batch, dstate) or (batch, ngroups, dstate)
-        C: (batch, dstate) or (batch, ngroups, dstate)
-        D: (dim,) or (nheads, dim)
-        z: (batch, dim) or (batch, nheads, dim)
-        dt_bias: (dim,) or (nheads, dim)
-    Return:
-        out: (batch, dim) or (batch, nheads, dim)
-    """
-    has_heads = state.dim() > 3
-    if state.dim() == 3:
-        state = state.unsqueeze(1)
-    if x.dim() == 2:
-        x = x.unsqueeze(1)
-    if dt.dim() == 2:
-        dt = dt.unsqueeze(1)
-    if A.dim() == 2:
-        A = A.unsqueeze(0)
-    if B.dim() == 2:
-        B = B.unsqueeze(1)
-    if C.dim() == 2:
-        C = C.unsqueeze(1)
-    if D is not None and D.dim() == 1:
-        D = D.unsqueeze(0)
-    if z is not None and z.dim() == 2:
-        z = z.unsqueeze(1)
-    if dt_bias is not None and dt_bias.dim() == 1:
-        dt_bias = dt_bias.unsqueeze(0)
-    _, nheads, dim, dstate = state.shape
-    batch = x.shape[0]
-    if x.shape != (batch, nheads, dim):
-        print(f"{state.shape} {x.shape} {batch} {nheads} {dim}")
-    assert x.shape == (batch, nheads, dim)
-    assert dt.shape == x.shape
-    assert A.shape == (nheads, dim, dstate)
-    ngroups = B.shape[1]
-    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
-    assert B.shape == (batch, ngroups, dstate)
-    assert C.shape == B.shape
-    if D is not None:
-        assert D.shape == (nheads, dim)
-    if z is not None:
-        assert z.shape == x.shape
-    if dt_bias is not None:
-        assert dt_bias.shape == (nheads, dim)
-    if state_batch_indices is not None:
-        assert state_batch_indices.shape == (batch,)
-    out = torch.empty_like(x)
-    grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
-    z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
-    # We don't want autotune since it will overwrite the state
-    # We instead tune by hand.
-    BLOCK_SIZE_M, num_warps = (
-        (32, 4)
-        if dstate <= 16
-        else (
-            (16, 4)
-            if dstate <= 32
-            else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
-        )
-    )
-    tie_hdim = (
-        A.stride(-1) == 0
-        and A.stride(-2) == 0
-        and dt.stride(-1) == 0
-        and dt_bias.stride(-1) == 0
-    )
-    with torch.cuda.device(x.device.index):
-        _selective_scan_update_kernel[grid](
-            state,
-            x,
-            dt,
-            dt_bias,
-            A,
-            B,
-            C,
-            D,
-            z,
-            out,
-            state_batch_indices,
-            batch,
-            nheads,
-            dim,
-            dstate,
-            nheads // ngroups,
-            state.stride(0),
-            state.stride(1),
-            state.stride(2),
-            state.stride(3),
-            x.stride(0),
-            x.stride(1),
-            x.stride(2),
-            dt.stride(0),
-            dt.stride(1),
-            dt.stride(2),
-            *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
-            A.stride(0),
-            A.stride(1),
-            A.stride(2),
-            B.stride(0),
-            B.stride(1),
-            B.stride(2),
-            C.stride(0),
-            C.stride(1),
-            C.stride(2),
-            *(D.stride(0), D.stride(1)) if D is not None else 0,
-            z_strides[0],
-            z_strides[1],
-            z_strides[2],
-            out.stride(0),
-            out.stride(1),
-            out.stride(2),
-            dt_softplus,
-            tie_hdim,
-            BLOCK_SIZE_M,
-            num_warps=num_warps,
-        )
-    if not has_heads:
-        out = out.squeeze(1)
-    return out
-
-
-def selective_state_update_ref(
-    state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
-):
-    """
-    Argument:
-        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
-        x: (batch, dim) or (batch, nheads, dim)
-        dt: (batch, dim) or (batch, nheads, dim)
-        A: (dim, dstate) or (nheads, dim, dstate)
-        B: (batch, dstate) or (batch, ngroups, dstate)
-        C: (batch, dstate) or (batch, ngroups, dstate)
-        D: (dim,) or (nheads, dim)
-        z: (batch, dim) or (batch, nheads, dim)
-        dt_bias: (dim,) or (nheads, dim)
-    Return:
-        out: (batch, dim) or (batch, nheads, dim)
-    """
-    has_heads = state.dim() > 3
-    if state.dim() == 3:
-        state = state.unsqueeze(1)
-    if x.dim() == 2:
-        x = x.unsqueeze(1)
-    if dt.dim() == 2:
-        dt = dt.unsqueeze(1)
-    if A.dim() == 2:
-        A = A.unsqueeze(0)
-    if B.dim() == 2:
-        B = B.unsqueeze(1)
-    if C.dim() == 2:
-        C = C.unsqueeze(1)
-    if D is not None and D.dim() == 1:
-        D = D.unsqueeze(0)
-    if z is not None and z.dim() == 2:
-        z = z.unsqueeze(1)
-    if dt_bias is not None and dt_bias.dim() == 1:
-        dt_bias = dt_bias.unsqueeze(0)
-    batch, nheads, dim, dstate = state.shape
-    assert x.shape == (batch, nheads, dim)
-    assert dt.shape == x.shape
-    assert A.shape == (nheads, dim, dstate)
-    ngroups = B.shape[1]
-    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
-    assert B.shape == (batch, ngroups, dstate)
-    assert C.shape == B.shape
-    if D is not None:
-        assert D.shape == (nheads, dim)
-    if z is not None:
-        assert z.shape == x.shape
-    if dt_bias is not None:
-        assert dt_bias.shape == (nheads, dim)
-        dt = dt + dt_bias
-    dt = F.softplus(dt) if dt_softplus else dt
-    dA = torch.exp(
-        rearrange(dt, "b h d -> b h d 1") * A
-    )  # (batch, nheads, dim, dstate)
-    B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups)  # (batch, nheads, dstate)
-    C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups)  # (batch, nheads, dstate)
-    dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
-        B, "b h n -> b h 1 n"
-    )  # (batch, nheads, dim, dstate)
-    state.copy_(
-        state * dA + dB * rearrange(x, "b h d -> b h d 1")
-    )  # (batch, dim, dstate)
-    out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
-    if D is not None:
-        out += (x * D).to(out.dtype)
-    out = (out if z is None else out * F.silu(z)).to(x.dtype)
-    if not has_heads:
-        out = out.squeeze(1)
-    return out
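For orientation, the deleted `selective_state_update_ref` above spells out what a single decoding step of the selective scan computes: the step size `dt` (optionally biased and passed through softplus) discretizes the SSM, the cached state is updated in place as `state * exp(dt * A) + dt * B * x`, and the output is `C . state` plus the optional `D` skip connection and SiLU(`z`) gate. Below is a minimal single-head PyTorch sketch of that same step; the function and variable names are illustrative, not part of the deleted API:

import torch
import torch.nn.functional as F

def ssm_step(state, x, dt, A, B, C, D=None, z=None):
    # state: (dim, dstate); x, dt: (dim,); A: (dim, dstate); B, C: (dstate,)
    dA = torch.exp(dt[:, None] * A)       # discretized state transition
    dB = dt[:, None] * B[None, :]         # discretized input matrix
    state = state * dA + dB * x[:, None]  # one step of the recurrence
    out = state @ C                       # readout, shape (dim,)
    if D is not None:
        out = out + D * x                 # skip connection
    if z is not None:
        out = out * F.silu(z)             # gating, matching out *= z * sigmoid(z)
    return state, out

# Example: dim=4, dstate=8
state = torch.zeros(4, 8)
x, dt, z = torch.randn(4), torch.rand(4), torch.randn(4)
A, D = -torch.rand(4, 8), torch.randn(4)
B, C = torch.randn(8), torch.randn(8)
state, out = ssm_step(state, x, dt, A, B, C, D, z)

The Triton kernel fuses this update and readout into a single launch per (batch, head), avoiding the intermediate tensors the reference version materializes.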
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_scan.py
DELETED

The diff for this file is too large to render. See raw diff.
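The next deleted file, ssd_chunk_state.py, opens with a chunked cumulative-sum kernel (`_chunk_cumsum_fwd_kernel`): within each chunk it applies the optional `dt_bias`, softplus (only where `dt <= 20.0`), and clamping to `[dt_min, dt_max]`, then stores the per-chunk `dt` alongside `cumsum(dt * A)` over the chunk dimension. A rough PyTorch equivalent of the forward computation follows; this is a sketch that assumes `seqlen` is divisible by `chunk_size`, with an illustrative helper name and illustrative `dt_min`/`dt_max` defaults:

import torch
import torch.nn.functional as F
from einops import rearrange

def chunk_cumsum_ref(dt, A, chunk_size, dt_bias=None, dt_softplus=True,
                     dt_min=0.001, dt_max=0.1):
    # dt: (batch, seqlen, nheads); A: (nheads,)
    if dt_bias is not None:
        dt = dt + dt_bias                                 # (nheads,) broadcast
    if dt_softplus:
        dt = torch.where(dt <= 20.0, F.softplus(dt), dt)  # same threshold as the kernel
    dt = dt.clamp(dt_min, dt_max)
    dt = rearrange(dt, "b (c l) h -> b c h l", l=chunk_size)
    dA_cumsum = torch.cumsum(dt * A[None, None, :, None], dim=-1)
    return dt, dA_cumsum  # per-chunk dt and cumsum(dt * A), as the kernel stores them

These per-chunk cumulative sums are what the chunk-state kernels below consume: `_chunk_state_fwd_kernel` weights each position by `exp(dA_cs_last - dA_cs_k) * dt_k` before the matmul with `x` that forms each chunk's end-of-chunk state.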
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_state.py
DELETED
|
@@ -1,2012 +0,0 @@
|
|
| 1 |
-
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
| 2 |
-
|
| 3 |
-
"""We want triton==2.1.0 or 2.2.0 for this
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
import math
|
| 7 |
-
import torch
|
| 8 |
-
import torch.nn.functional as F
|
| 9 |
-
|
| 10 |
-
import triton
|
| 11 |
-
import triton.language as tl
|
| 12 |
-
|
| 13 |
-
from einops import rearrange, repeat
|
| 14 |
-
|
| 15 |
-
from .softplus import softplus
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def init_to_zero(names):
|
| 19 |
-
return lambda nargs: [
|
| 20 |
-
nargs[name].zero_() for name in names if nargs[name] is not None
|
| 21 |
-
]
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
@triton.autotune(
|
| 25 |
-
configs=[
|
| 26 |
-
triton.Config({"BLOCK_SIZE_H": 1}),
|
| 27 |
-
triton.Config({"BLOCK_SIZE_H": 2}),
|
| 28 |
-
triton.Config({"BLOCK_SIZE_H": 4}),
|
| 29 |
-
triton.Config({"BLOCK_SIZE_H": 8}),
|
| 30 |
-
triton.Config({"BLOCK_SIZE_H": 16}),
|
| 31 |
-
triton.Config({"BLOCK_SIZE_H": 32}),
|
| 32 |
-
triton.Config({"BLOCK_SIZE_H": 64}),
|
| 33 |
-
],
|
| 34 |
-
key=["chunk_size", "nheads"],
|
| 35 |
-
)
|
| 36 |
-
@triton.jit
|
| 37 |
-
def _chunk_cumsum_fwd_kernel(
|
| 38 |
-
# Pointers to matrices
|
| 39 |
-
dt_ptr,
|
| 40 |
-
A_ptr,
|
| 41 |
-
dt_bias_ptr,
|
| 42 |
-
dt_out_ptr,
|
| 43 |
-
dA_cumsum_ptr,
|
| 44 |
-
# Matrix dimension
|
| 45 |
-
batch,
|
| 46 |
-
seqlen,
|
| 47 |
-
nheads,
|
| 48 |
-
chunk_size,
|
| 49 |
-
dt_min,
|
| 50 |
-
dt_max,
|
| 51 |
-
# Strides
|
| 52 |
-
stride_dt_batch,
|
| 53 |
-
stride_dt_seqlen,
|
| 54 |
-
stride_dt_head,
|
| 55 |
-
stride_A_head,
|
| 56 |
-
stride_dt_bias_head,
|
| 57 |
-
stride_dt_out_batch,
|
| 58 |
-
stride_dt_out_chunk,
|
| 59 |
-
stride_dt_out_head,
|
| 60 |
-
stride_dt_out_csize,
|
| 61 |
-
stride_dA_cs_batch,
|
| 62 |
-
stride_dA_cs_chunk,
|
| 63 |
-
stride_dA_cs_head,
|
| 64 |
-
stride_dA_cs_csize,
|
| 65 |
-
# Meta-parameters
|
| 66 |
-
DT_SOFTPLUS: tl.constexpr,
|
| 67 |
-
HAS_DT_BIAS: tl.constexpr,
|
| 68 |
-
BLOCK_SIZE_H: tl.constexpr,
|
| 69 |
-
BLOCK_SIZE_CHUNK: tl.constexpr,
|
| 70 |
-
):
|
| 71 |
-
pid_b = tl.program_id(axis=0)
|
| 72 |
-
pid_c = tl.program_id(axis=1)
|
| 73 |
-
pid_h = tl.program_id(axis=2)
|
| 74 |
-
dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
|
| 75 |
-
dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
|
| 76 |
-
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk
|
| 77 |
-
|
| 78 |
-
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
|
| 79 |
-
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
|
| 80 |
-
dt_ptrs = dt_ptr + (
|
| 81 |
-
offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
|
| 82 |
-
)
|
| 83 |
-
A_ptrs = A_ptr + offs_h * stride_A_head
|
| 84 |
-
dt_out_ptrs = dt_out_ptr + (
|
| 85 |
-
offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize
|
| 86 |
-
)
|
| 87 |
-
dA_cs_ptrs = dA_cumsum_ptr + (
|
| 88 |
-
offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize
|
| 89 |
-
)
|
| 90 |
-
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
| 91 |
-
|
| 92 |
-
dt = tl.load(
|
| 93 |
-
dt_ptrs,
|
| 94 |
-
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
| 95 |
-
other=0.0,
|
| 96 |
-
).to(tl.float32)
|
| 97 |
-
if HAS_DT_BIAS:
|
| 98 |
-
dt_bias = tl.load(
|
| 99 |
-
dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
|
| 100 |
-
).to(tl.float32)
|
| 101 |
-
dt += dt_bias[:, None]
|
| 102 |
-
if DT_SOFTPLUS:
|
| 103 |
-
dt = tl.where(dt <= 20.0, softplus(dt), dt)
|
| 104 |
-
# As of Triton 2.2.0, tl.clamp is not available yet
|
| 105 |
-
# dt = tl.clamp(dt, dt_min, dt_max)
|
| 106 |
-
dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
|
| 107 |
-
dt = tl.where(
|
| 108 |
-
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
|
| 109 |
-
)
|
| 110 |
-
tl.store(
|
| 111 |
-
dt_out_ptrs,
|
| 112 |
-
dt,
|
| 113 |
-
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
|
| 114 |
-
)
|
| 115 |
-
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
|
| 116 |
-
dA = dt * A[:, None]
|
| 117 |
-
dA_cs = tl.cumsum(dA, axis=1)
|
| 118 |
-
tl.store(
|
| 119 |
-
dA_cs_ptrs,
|
| 120 |
-
dA_cs,
|
| 121 |
-
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
|
| 122 |
-
)
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
@triton.autotune(
|
| 126 |
-
configs=[
|
| 127 |
-
triton.Config(
|
| 128 |
-
{"BLOCK_SIZE_H": 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
| 129 |
-
),
|
| 130 |
-
triton.Config(
|
| 131 |
-
{"BLOCK_SIZE_H": 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
| 132 |
-
),
|
| 133 |
-
triton.Config(
|
| 134 |
-
{"BLOCK_SIZE_H": 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
| 135 |
-
),
|
| 136 |
-
triton.Config(
|
| 137 |
-
{"BLOCK_SIZE_H": 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
| 138 |
-
),
|
| 139 |
-
triton.Config(
|
| 140 |
-
{"BLOCK_SIZE_H": 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
| 141 |
-
),
|
| 142 |
-
triton.Config(
|
| 143 |
-
{"BLOCK_SIZE_H": 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
| 144 |
-
),
|
| 145 |
-
triton.Config(
|
| 146 |
-
{"BLOCK_SIZE_H": 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
| 147 |
-
),
|
| 148 |
-
],
|
| 149 |
-
key=["chunk_size", "nheads"],
|
| 150 |
-
)
|
| 151 |
-
@triton.jit
|
| 152 |
-
def _chunk_cumsum_bwd_kernel(
|
| 153 |
-
# Pointers to matrices
|
| 154 |
-
ddA_ptr,
|
| 155 |
-
ddt_out_ptr,
|
| 156 |
-
dt_ptr,
|
| 157 |
-
A_ptr,
|
| 158 |
-
dt_bias_ptr,
|
| 159 |
-
ddt_ptr,
|
| 160 |
-
dA_ptr,
|
| 161 |
-
ddt_bias_ptr,
|
| 162 |
-
# Matrix dimensions
|
| 163 |
-
batch,
|
| 164 |
-
seqlen,
|
| 165 |
-
nheads,
|
| 166 |
-
chunk_size,
|
| 167 |
-
dt_min,
|
| 168 |
-
dt_max,
|
| 169 |
-
# Strides
|
| 170 |
-
stride_ddA_batch,
|
| 171 |
-
stride_ddA_chunk,
|
| 172 |
-
stride_ddA_head,
|
| 173 |
-
stride_ddA_csize,
|
| 174 |
-
stride_ddt_out_batch,
|
| 175 |
-
stride_ddt_out_chunk,
|
| 176 |
-
stride_ddt_out_head,
|
| 177 |
-
stride_ddt_out_csize,
|
| 178 |
-
stride_dt_batch,
|
| 179 |
-
stride_dt_seqlen,
|
| 180 |
-
stride_dt_head,
|
| 181 |
-
stride_A_head,
|
| 182 |
-
stride_dt_bias_head,
|
| 183 |
-
stride_ddt_batch,
|
| 184 |
-
stride_ddt_seqlen,
|
| 185 |
-
stride_ddt_head,
|
| 186 |
-
stride_dA_head,
|
| 187 |
-
stride_ddt_bias_head,
|
| 188 |
-
# Meta-parameters
|
| 189 |
-
DT_SOFTPLUS: tl.constexpr,
|
| 190 |
-
HAS_DT_BIAS: tl.constexpr,
|
| 191 |
-
BLOCK_SIZE_H: tl.constexpr,
|
| 192 |
-
BLOCK_SIZE_CHUNK: tl.constexpr,
|
| 193 |
-
):
|
| 194 |
-
pid_b = tl.program_id(axis=0)
|
| 195 |
-
pid_c = tl.program_id(axis=1)
|
| 196 |
-
pid_h = tl.program_id(axis=2)
|
| 197 |
-
ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk
|
| 198 |
-
ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk
|
| 199 |
-
dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
|
| 200 |
-
ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen
|
| 201 |
-
|
| 202 |
-
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
|
| 203 |
-
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
|
| 204 |
-
ddt_out_ptrs = ddt_out_ptr + (
|
| 205 |
-
offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize
|
| 206 |
-
)
|
| 207 |
-
ddA_ptrs = ddA_ptr + (
|
| 208 |
-
offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize
|
| 209 |
-
)
|
| 210 |
-
dt_ptrs = dt_ptr + (
|
| 211 |
-
offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
|
| 212 |
-
)
|
| 213 |
-
ddt_ptrs = ddt_ptr + (
|
| 214 |
-
offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen
|
| 215 |
-
)
|
| 216 |
-
A_ptrs = A_ptr + offs_h * stride_A_head
|
| 217 |
-
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
| 218 |
-
|
| 219 |
-
ddA = tl.load(
|
| 220 |
-
ddA_ptrs,
|
| 221 |
-
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
| 222 |
-
other=0.0,
|
| 223 |
-
).to(tl.float32)
|
| 224 |
-
ddt_out = tl.load(
|
| 225 |
-
ddt_out_ptrs,
|
| 226 |
-
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
| 227 |
-
other=0.0,
|
| 228 |
-
).to(tl.float32)
|
| 229 |
-
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
|
| 230 |
-
ddt = ddA * A[:, None] + ddt_out
|
| 231 |
-
dt = tl.load(
|
| 232 |
-
dt_ptrs,
|
| 233 |
-
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
| 234 |
-
other=0.0,
|
| 235 |
-
).to(tl.float32)
|
| 236 |
-
if HAS_DT_BIAS:
|
| 237 |
-
dt_bias = tl.load(
|
| 238 |
-
dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
|
| 239 |
-
).to(tl.float32)
|
| 240 |
-
dt += dt_bias[:, None]
|
| 241 |
-
if DT_SOFTPLUS:
|
| 242 |
-
dt_presoftplus = dt
|
| 243 |
-
dt = tl.where(dt <= 20.0, softplus(dt), ddt)
|
| 244 |
-
clamp_mask = (dt < dt_min) | (dt > dt_max)
|
| 245 |
-
# As of Triton 2.2.0, tl.clamp is not available yet
|
| 246 |
-
# dt = tl.clamp(dt, dt_min, dt_max)
|
| 247 |
-
dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
|
| 248 |
-
dt = tl.where(
|
| 249 |
-
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
|
| 250 |
-
)
|
| 251 |
-
ddt = tl.where(
|
| 252 |
-
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0
|
| 253 |
-
)
|
| 254 |
-
ddt = tl.where(clamp_mask, 0.0, ddt)
|
| 255 |
-
if DT_SOFTPLUS:
|
| 256 |
-
ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)
|
| 257 |
-
tl.store(
|
| 258 |
-
ddt_ptrs,
|
| 259 |
-
ddt,
|
| 260 |
-
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
| 261 |
-
)
|
| 262 |
-
dA = tl.sum(ddA * dt, axis=1)
|
| 263 |
-
tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)
|
| 264 |
-
if HAS_DT_BIAS:
|
| 265 |
-
ddt_bias = tl.sum(ddt, axis=1)
|
| 266 |
-
tl.atomic_add(
|
| 267 |
-
ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads
|
| 268 |
-
)
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
@triton.autotune(
|
| 272 |
-
configs=[
|
| 273 |
-
triton.Config(
|
| 274 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
| 275 |
-
num_stages=3,
|
| 276 |
-
num_warps=8,
|
| 277 |
-
),
|
| 278 |
-
triton.Config(
|
| 279 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
| 280 |
-
num_stages=4,
|
| 281 |
-
num_warps=4,
|
| 282 |
-
),
|
| 283 |
-
triton.Config(
|
| 284 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
| 285 |
-
num_stages=4,
|
| 286 |
-
num_warps=4,
|
| 287 |
-
),
|
| 288 |
-
triton.Config(
|
| 289 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 290 |
-
num_stages=4,
|
| 291 |
-
num_warps=4,
|
| 292 |
-
),
|
| 293 |
-
triton.Config(
|
| 294 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
| 295 |
-
num_stages=4,
|
| 296 |
-
num_warps=4,
|
| 297 |
-
),
|
| 298 |
-
triton.Config(
|
| 299 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
| 300 |
-
num_stages=4,
|
| 301 |
-
num_warps=4,
|
| 302 |
-
),
|
| 303 |
-
triton.Config(
|
| 304 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
| 305 |
-
num_stages=5,
|
| 306 |
-
num_warps=2,
|
| 307 |
-
),
|
| 308 |
-
triton.Config(
|
| 309 |
-
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 310 |
-
num_stages=5,
|
| 311 |
-
num_warps=2,
|
| 312 |
-
),
|
| 313 |
-
triton.Config(
|
| 314 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 315 |
-
num_stages=4,
|
| 316 |
-
num_warps=2,
|
| 317 |
-
),
|
| 318 |
-
],
|
| 319 |
-
key=["hdim", "dstate", "chunk_size"],
|
| 320 |
-
)
|
| 321 |
-
@triton.jit
|
| 322 |
-
def _chunk_state_fwd_kernel(
|
| 323 |
-
# Pointers to matrices
|
| 324 |
-
x_ptr,
|
| 325 |
-
b_ptr,
|
| 326 |
-
states_ptr,
|
| 327 |
-
dt_ptr,
|
| 328 |
-
dA_cumsum_ptr,
|
| 329 |
-
seq_idx_ptr,
|
| 330 |
-
# Matrix dimensions
|
| 331 |
-
hdim,
|
| 332 |
-
dstate,
|
| 333 |
-
chunk_size,
|
| 334 |
-
batch,
|
| 335 |
-
seqlen,
|
| 336 |
-
nheads_ngroups_ratio,
|
| 337 |
-
# Strides
|
| 338 |
-
stride_x_batch,
|
| 339 |
-
stride_x_seqlen,
|
| 340 |
-
stride_x_head,
|
| 341 |
-
stride_x_hdim,
|
| 342 |
-
stride_b_batch,
|
| 343 |
-
stride_b_seqlen,
|
| 344 |
-
stride_b_head,
|
| 345 |
-
stride_b_dstate,
|
| 346 |
-
stride_states_batch,
|
| 347 |
-
stride_states_chunk,
|
| 348 |
-
stride_states_head,
|
| 349 |
-
stride_states_hdim,
|
| 350 |
-
stride_states_dstate,
|
| 351 |
-
stride_dt_batch,
|
| 352 |
-
stride_dt_chunk,
|
| 353 |
-
stride_dt_head,
|
| 354 |
-
stride_dt_csize,
|
| 355 |
-
stride_dA_cs_batch,
|
| 356 |
-
stride_dA_cs_chunk,
|
| 357 |
-
stride_dA_cs_head,
|
| 358 |
-
stride_dA_cs_csize,
|
| 359 |
-
stride_seq_idx_batch,
|
| 360 |
-
stride_seq_idx_seqlen,
|
| 361 |
-
# Meta-parameters
|
| 362 |
-
HAS_SEQ_IDX: tl.constexpr,
|
| 363 |
-
BLOCK_SIZE_M: tl.constexpr,
|
| 364 |
-
BLOCK_SIZE_N: tl.constexpr,
|
| 365 |
-
BLOCK_SIZE_K: tl.constexpr,
|
| 366 |
-
):
|
| 367 |
-
pid_bc = tl.program_id(axis=1)
|
| 368 |
-
pid_c = pid_bc // batch
|
| 369 |
-
pid_b = pid_bc - pid_c * batch
|
| 370 |
-
pid_h = tl.program_id(axis=2)
|
| 371 |
-
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
| 372 |
-
pid_m = tl.program_id(axis=0) // num_pid_n
|
| 373 |
-
pid_n = tl.program_id(axis=0) % num_pid_n
|
| 374 |
-
b_ptr += (
|
| 375 |
-
pid_b * stride_b_batch
|
| 376 |
-
+ pid_c * chunk_size * stride_b_seqlen
|
| 377 |
-
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
| 378 |
-
)
|
| 379 |
-
x_ptr += (
|
| 380 |
-
pid_b * stride_x_batch
|
| 381 |
-
+ pid_c * chunk_size * stride_x_seqlen
|
| 382 |
-
+ pid_h * stride_x_head
|
| 383 |
-
)
|
| 384 |
-
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
| 385 |
-
dA_cumsum_ptr += (
|
| 386 |
-
pid_b * stride_dA_cs_batch
|
| 387 |
-
+ pid_c * stride_dA_cs_chunk
|
| 388 |
-
+ pid_h * stride_dA_cs_head
|
| 389 |
-
)
|
| 390 |
-
if HAS_SEQ_IDX:
|
| 391 |
-
seq_idx_ptr += (
|
| 392 |
-
pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
| 393 |
-
)
|
| 394 |
-
|
| 395 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 396 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 397 |
-
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
| 398 |
-
x_ptrs = x_ptr + (
|
| 399 |
-
offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
|
| 400 |
-
)
|
| 401 |
-
b_ptrs = b_ptr + (
|
| 402 |
-
offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
|
| 403 |
-
)
|
| 404 |
-
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
| 405 |
-
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
|
| 406 |
-
tl.float32
|
| 407 |
-
)
|
| 408 |
-
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
| 409 |
-
if HAS_SEQ_IDX:
|
| 410 |
-
seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
|
| 411 |
-
|
| 412 |
-
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
| 413 |
-
if HAS_SEQ_IDX:
|
| 414 |
-
seq_idx_last = tl.load(
|
| 415 |
-
seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
|
| 416 |
-
)
|
| 417 |
-
|
| 418 |
-
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
| 419 |
-
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
|
| 420 |
-
x = tl.load(
|
| 421 |
-
x_ptrs,
|
| 422 |
-
mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k),
|
| 423 |
-
other=0.0,
|
| 424 |
-
)
|
| 425 |
-
b = tl.load(
|
| 426 |
-
b_ptrs,
|
| 427 |
-
mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate),
|
| 428 |
-
other=0.0,
|
| 429 |
-
).to(tl.float32)
|
| 430 |
-
dA_cs_k = tl.load(
|
| 431 |
-
dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
|
| 432 |
-
).to(tl.float32)
|
| 433 |
-
if HAS_SEQ_IDX:
|
| 434 |
-
seq_idx_k = tl.load(
|
| 435 |
-
seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1
|
| 436 |
-
)
|
| 437 |
-
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
|
| 438 |
-
tl.float32
|
| 439 |
-
)
|
| 440 |
-
if not HAS_SEQ_IDX:
|
| 441 |
-
scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
|
| 442 |
-
else:
|
| 443 |
-
scale = tl.where(
|
| 444 |
-
seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0
|
| 445 |
-
)
|
| 446 |
-
b *= scale[:, None]
|
| 447 |
-
b = b.to(x_ptr.dtype.element_ty)
|
| 448 |
-
acc += tl.dot(x, b)
|
| 449 |
-
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
| 450 |
-
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
|
| 451 |
-
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
| 452 |
-
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
| 453 |
-
if HAS_SEQ_IDX:
|
| 454 |
-
seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
|
| 455 |
-
states = acc.to(states_ptr.dtype.element_ty)
|
| 456 |
-
|
| 457 |
-
states_ptr += (
|
| 458 |
-
pid_b * stride_states_batch
|
| 459 |
-
+ pid_c * stride_states_chunk
|
| 460 |
-
+ pid_h * stride_states_head
|
| 461 |
-
)
|
| 462 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 463 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 464 |
-
states_ptrs = states_ptr + (
|
| 465 |
-
offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
|
| 466 |
-
)
|
| 467 |
-
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
|
| 468 |
-
tl.store(states_ptrs, states, mask=c_mask)
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
@triton.autotune(
|
| 472 |
-
configs=[
|
| 473 |
-
triton.Config(
|
| 474 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
| 475 |
-
num_stages=3,
|
| 476 |
-
num_warps=8,
|
| 477 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 478 |
-
),
|
| 479 |
-
triton.Config(
|
| 480 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
| 481 |
-
num_stages=4,
|
| 482 |
-
num_warps=4,
|
| 483 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 484 |
-
),
|
| 485 |
-
triton.Config(
|
| 486 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
| 487 |
-
num_stages=4,
|
| 488 |
-
num_warps=4,
|
| 489 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 490 |
-
),
|
| 491 |
-
triton.Config(
|
| 492 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 493 |
-
num_stages=4,
|
| 494 |
-
num_warps=4,
|
| 495 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 496 |
-
),
|
| 497 |
-
triton.Config(
|
| 498 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
| 499 |
-
num_stages=4,
|
| 500 |
-
num_warps=4,
|
| 501 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 502 |
-
),
|
| 503 |
-
triton.Config(
|
| 504 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
| 505 |
-
num_stages=4,
|
| 506 |
-
num_warps=4,
|
| 507 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 508 |
-
),
|
| 509 |
-
triton.Config(
|
| 510 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
| 511 |
-
num_stages=5,
|
| 512 |
-
num_warps=4,
|
| 513 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 514 |
-
),
|
| 515 |
-
triton.Config(
|
| 516 |
-
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 517 |
-
num_stages=5,
|
| 518 |
-
num_warps=4,
|
| 519 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 520 |
-
),
|
| 521 |
-
triton.Config(
|
| 522 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 523 |
-
num_stages=4,
|
| 524 |
-
num_warps=4,
|
| 525 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 526 |
-
),
|
| 527 |
-
],
|
| 528 |
-
key=["chunk_size", "hdim", "dstate"],
|
| 529 |
-
)
|
| 530 |
-
@triton.jit
|
| 531 |
-
def _chunk_state_bwd_dx_kernel(
|
| 532 |
-
# Pointers to matrices
|
| 533 |
-
x_ptr,
|
| 534 |
-
b_ptr,
|
| 535 |
-
dstates_ptr,
|
| 536 |
-
dt_ptr,
|
| 537 |
-
dA_cumsum_ptr,
|
| 538 |
-
dx_ptr,
|
| 539 |
-
ddt_ptr,
|
| 540 |
-
ddA_cumsum_ptr,
|
| 541 |
-
# Matrix dimensions
|
| 542 |
-
chunk_size,
|
| 543 |
-
hdim,
|
| 544 |
-
dstate,
|
| 545 |
-
batch,
|
| 546 |
-
seqlen,
|
| 547 |
-
nheads_ngroups_ratio,
|
| 548 |
-
# Strides
|
| 549 |
-
stride_x_batch,
|
| 550 |
-
stride_x_seqlen,
|
| 551 |
-
stride_x_head,
|
| 552 |
-
stride_x_hdim,
|
| 553 |
-
stride_b_batch,
|
| 554 |
-
stride_b_seqlen,
|
| 555 |
-
stride_b_head,
|
| 556 |
-
stride_b_dstate,
|
| 557 |
-
stride_dstates_batch,
|
| 558 |
-
stride_dstates_chunk,
|
| 559 |
-
stride_states_head,
|
| 560 |
-
stride_states_hdim,
|
| 561 |
-
stride_states_dstate,
|
| 562 |
-
stride_dt_batch,
|
| 563 |
-
stride_dt_chunk,
|
| 564 |
-
stride_dt_head,
|
| 565 |
-
stride_dt_csize,
|
| 566 |
-
stride_dA_cs_batch,
|
| 567 |
-
stride_dA_cs_chunk,
|
| 568 |
-
stride_dA_cs_head,
|
| 569 |
-
stride_dA_cs_csize,
|
| 570 |
-
stride_dx_batch,
|
| 571 |
-
stride_dx_seqlen,
|
| 572 |
-
stride_dx_head,
|
| 573 |
-
stride_dx_hdim,
|
| 574 |
-
stride_ddt_batch,
|
| 575 |
-
stride_ddt_chunk,
|
| 576 |
-
stride_ddt_head,
|
| 577 |
-
stride_ddt_csize,
|
| 578 |
-
stride_ddA_cs_batch,
|
| 579 |
-
stride_ddA_cs_chunk,
|
| 580 |
-
stride_ddA_cs_head,
|
| 581 |
-
stride_ddA_cs_csize,
|
| 582 |
-
# Meta-parameters
|
| 583 |
-
BLOCK_SIZE_M: tl.constexpr,
|
| 584 |
-
BLOCK_SIZE_N: tl.constexpr,
|
| 585 |
-
BLOCK_SIZE_K: tl.constexpr,
|
| 586 |
-
BLOCK_SIZE_DSTATE: tl.constexpr,
|
| 587 |
-
):
|
| 588 |
-
pid_bc = tl.program_id(axis=1)
|
| 589 |
-
pid_c = pid_bc // batch
|
| 590 |
-
pid_b = pid_bc - pid_c * batch
|
| 591 |
-
pid_h = tl.program_id(axis=2)
|
| 592 |
-
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
| 593 |
-
pid_m = tl.program_id(axis=0) // num_pid_n
|
| 594 |
-
pid_n = tl.program_id(axis=0) % num_pid_n
|
| 595 |
-
x_ptr += (
|
| 596 |
-
pid_b * stride_x_batch
|
| 597 |
-
+ pid_c * chunk_size * stride_x_seqlen
|
| 598 |
-
+ pid_h * stride_x_head
|
| 599 |
-
)
|
| 600 |
-
b_ptr += (
|
| 601 |
-
pid_b * stride_b_batch
|
| 602 |
-
+ pid_c * chunk_size * stride_b_seqlen
|
| 603 |
-
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
| 604 |
-
)
|
| 605 |
-
dstates_ptr += (
|
| 606 |
-
pid_b * stride_dstates_batch
|
| 607 |
-
+ pid_c * stride_dstates_chunk
|
| 608 |
-
+ pid_h * stride_states_head
|
| 609 |
-
)
|
| 610 |
-
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
| 611 |
-
ddt_ptr += (
|
| 612 |
-
pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
|
| 613 |
-
)
|
| 614 |
-
ddA_cumsum_ptr += (
|
| 615 |
-
pid_b * stride_ddA_cs_batch
|
| 616 |
-
+ pid_c * stride_ddA_cs_chunk
|
| 617 |
-
+ pid_h * stride_ddA_cs_head
|
| 618 |
-
)
|
| 619 |
-
dA_cumsum_ptr += (
|
| 620 |
-
pid_b * stride_dA_cs_batch
|
| 621 |
-
+ pid_c * stride_dA_cs_chunk
|
| 622 |
-
+ pid_h * stride_dA_cs_head
|
| 623 |
-
)
|
| 624 |
-
|
| 625 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 626 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 627 |
-
|
| 628 |
-
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
| 629 |
-
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
| 630 |
-
offs_k = tl.arange(
|
| 631 |
-
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
|
| 632 |
-
)
|
| 633 |
-
b_ptrs = b_ptr + (
|
| 634 |
-
offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
|
| 635 |
-
)
|
| 636 |
-
dstates_ptrs = dstates_ptr + (
|
| 637 |
-
offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
|
| 638 |
-
)
|
| 639 |
-
if BLOCK_SIZE_DSTATE <= 128:
|
| 640 |
-
b = tl.load(
|
| 641 |
-
b_ptrs,
|
| 642 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
|
| 643 |
-
other=0.0,
|
| 644 |
-
)
|
| 645 |
-
dstates = tl.load(
|
| 646 |
-
dstates_ptrs,
|
| 647 |
-
mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
|
| 648 |
-
other=0.0,
|
| 649 |
-
)
|
| 650 |
-
dstates = dstates.to(b_ptr.dtype.element_ty)
|
| 651 |
-
acc = tl.dot(b, dstates)
|
| 652 |
-
else:
|
| 653 |
-
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
| 654 |
-
for k in range(0, dstate, BLOCK_SIZE_K):
|
| 655 |
-
b = tl.load(
|
| 656 |
-
b_ptrs,
|
| 657 |
-
mask=(offs_m[:, None] < chunk_size_limit)
|
| 658 |
-
& (offs_k[None, :] < dstate - k),
|
| 659 |
-
other=0.0,
|
| 660 |
-
)
|
| 661 |
-
dstates = tl.load(
|
| 662 |
-
dstates_ptrs,
|
| 663 |
-
mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
|
| 664 |
-
other=0.0,
|
| 665 |
-
)
|
| 666 |
-
dstates = dstates.to(b_ptr.dtype.element_ty)
|
| 667 |
-
acc += tl.dot(b, dstates)
|
| 668 |
-
b_ptrs += BLOCK_SIZE_K * stride_b_dstate
|
| 669 |
-
dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
|
| 670 |
-
|
| 671 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 672 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 673 |
-
|
| 674 |
-
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
|
| 675 |
-
tl.float32
|
| 676 |
-
)
|
| 677 |
-
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
| 678 |
-
dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
|
| 679 |
-
dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
|
| 680 |
-
tl.float32
|
| 681 |
-
)
|
| 682 |
-
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
| 683 |
-
acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
|
| 684 |
-
|
| 685 |
-
x_ptrs = x_ptr + (
|
| 686 |
-
offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
|
| 687 |
-
)
|
| 688 |
-
x = tl.load(
|
| 689 |
-
x_ptrs,
|
| 690 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
| 691 |
-
other=0.0,
|
| 692 |
-
).to(tl.float32)
|
| 693 |
-
ddt = tl.sum(acc * x, axis=1)
|
| 694 |
-
ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
|
| 695 |
-
tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
|
| 696 |
-
ddA_cs = -(ddt * dt_m)
|
| 697 |
-
ddA_cs_last = -tl.sum(ddA_cs)
|
| 698 |
-
ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
|
| 699 |
-
tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
|
| 700 |
-
tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last)
|
| 701 |
-
|
| 702 |
-
dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty)
|
| 703 |
-
dx_ptr += (
|
| 704 |
-
pid_b * stride_dx_batch
|
| 705 |
-
+ pid_c * chunk_size * stride_dx_seqlen
|
| 706 |
-
+ pid_h * stride_dx_head
|
| 707 |
-
)
|
| 708 |
-
dx_ptrs = dx_ptr + (
|
| 709 |
-
offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim
|
| 710 |
-
)
|
| 711 |
-
tl.store(
|
| 712 |
-
dx_ptrs,
|
| 713 |
-
dx,
|
| 714 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
| 715 |
-
)
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
@triton.autotune(
|
| 719 |
-
configs=[
|
| 720 |
-
triton.Config(
|
| 721 |
-
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128},
|
| 722 |
-
num_stages=3,
|
| 723 |
-
num_warps=4,
|
| 724 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 725 |
-
),
|
| 726 |
-
triton.Config(
|
| 727 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32},
|
| 728 |
-
num_stages=3,
|
| 729 |
-
num_warps=4,
|
| 730 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 731 |
-
),
|
| 732 |
-
triton.Config(
|
| 733 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128},
|
| 734 |
-
num_stages=3,
|
| 735 |
-
num_warps=4,
|
| 736 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 737 |
-
),
|
| 738 |
-
triton.Config(
|
| 739 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64},
|
| 740 |
-
num_stages=3,
|
| 741 |
-
num_warps=4,
|
| 742 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 743 |
-
),
|
| 744 |
-
triton.Config(
|
| 745 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64},
|
| 746 |
-
num_stages=3,
|
| 747 |
-
num_warps=4,
|
| 748 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 749 |
-
),
|
| 750 |
-
triton.Config(
|
| 751 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32},
|
| 752 |
-
num_stages=3,
|
| 753 |
-
num_warps=4,
|
| 754 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 755 |
-
),
|
| 756 |
-
triton.Config(
|
| 757 |
-
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64},
|
| 758 |
-
num_stages=3,
|
| 759 |
-
num_warps=4,
|
| 760 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 761 |
-
),
|
| 762 |
-
triton.Config(
|
| 763 |
-
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32},
|
| 764 |
-
num_stages=3,
|
| 765 |
-
num_warps=4,
|
| 766 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 767 |
-
),
|
| 768 |
-
],
|
| 769 |
-
key=["chunk_size", "dstate", "hdim"],
|
| 770 |
-
)
|
| 771 |
-
@triton.jit
|
| 772 |
-
def _chunk_state_bwd_db_kernel(
|
| 773 |
-
# Pointers to matrices
|
| 774 |
-
x_ptr,
|
| 775 |
-
dstates_ptr,
|
| 776 |
-
b_ptr,
|
| 777 |
-
dt_ptr,
|
| 778 |
-
dA_cumsum_ptr,
|
| 779 |
-
seq_idx_ptr,
|
| 780 |
-
db_ptr,
|
| 781 |
-
ddA_cumsum_ptr,
|
| 782 |
-
# Matrix dimensions
|
| 783 |
-
chunk_size,
|
| 784 |
-
dstate,
|
| 785 |
-
hdim,
|
| 786 |
-
batch,
|
| 787 |
-
seqlen,
|
| 788 |
-
nheads,
|
| 789 |
-
nheads_per_program,
|
| 790 |
-
ngroups,
|
| 791 |
-
# Strides
|
| 792 |
-
stride_x_batch,
|
| 793 |
-
stride_x_seqlen,
|
| 794 |
-
stride_x_head,
|
| 795 |
-
stride_x_hdim,
|
| 796 |
-
stride_dstates_batch,
|
| 797 |
-
stride_dstates_chunk,
|
| 798 |
-
stride_states_head,
|
| 799 |
-
stride_states_hdim,
|
| 800 |
-
stride_states_dstate,
|
| 801 |
-
stride_b_batch,
|
| 802 |
-
stride_b_seqlen,
|
| 803 |
-
stride_b_head,
|
| 804 |
-
stride_b_dstate,
|
| 805 |
-
stride_dt_batch,
|
| 806 |
-
stride_dt_chunk,
|
| 807 |
-
stride_dt_head,
|
| 808 |
-
stride_dt_csize,
|
| 809 |
-
stride_dA_cs_batch,
|
| 810 |
-
stride_dA_cs_chunk,
|
| 811 |
-
stride_dA_cs_head,
|
| 812 |
-
stride_dA_cs_csize,
|
| 813 |
-
stride_seq_idx_batch,
|
| 814 |
-
stride_seq_idx_seqlen,
|
| 815 |
-
stride_db_batch,
|
| 816 |
-
stride_db_seqlen,
|
| 817 |
-
stride_db_split,
|
| 818 |
-
stride_db_group,
|
| 819 |
-
stride_db_dstate,
|
| 820 |
-
stride_ddA_cs_batch,
|
| 821 |
-
stride_ddA_cs_chunk,
|
| 822 |
-
stride_ddA_cs_head,
|
| 823 |
-
stride_ddA_cs_csize,
|
| 824 |
-
# Meta-parameters
|
| 825 |
-
HAS_DDA_CS: tl.constexpr,
|
| 826 |
-
HAS_SEQ_IDX: tl.constexpr,
|
| 827 |
-
BLOCK_SIZE_M: tl.constexpr,
|
| 828 |
-
BLOCK_SIZE_N: tl.constexpr,
|
| 829 |
-
BLOCK_SIZE_K: tl.constexpr,
|
| 830 |
-
):
|
| 831 |
-
pid_bc = tl.program_id(axis=1)
|
| 832 |
-
pid_c = pid_bc // batch
|
| 833 |
-
pid_b = pid_bc - pid_c * batch
|
| 834 |
-
pid_sg = tl.program_id(axis=2)
|
| 835 |
-
pid_s = pid_sg // ngroups
|
| 836 |
-
pid_g = pid_sg - pid_s * ngroups
|
| 837 |
-
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
| 838 |
-
pid_m = tl.program_id(axis=0) // num_pid_n
|
| 839 |
-
pid_n = tl.program_id(axis=0) % num_pid_n
|
| 840 |
-
x_ptr += (
|
| 841 |
-
pid_b * stride_x_batch
|
| 842 |
-
+ pid_c * chunk_size * stride_x_seqlen
|
| 843 |
-
+ (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head
|
| 844 |
-
)
|
| 845 |
-
db_ptr += (
|
| 846 |
-
pid_b * stride_db_batch
|
| 847 |
-
+ pid_c * chunk_size * stride_db_seqlen
|
| 848 |
-
+ pid_g * stride_db_group
|
| 849 |
-
+ pid_s * stride_db_split
|
| 850 |
-
)
|
| 851 |
-
dstates_ptr += (
|
| 852 |
-
pid_b * stride_dstates_batch
|
| 853 |
-
+ pid_c * stride_dstates_chunk
|
| 854 |
-
+ (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
|
| 855 |
-
* stride_states_head
|
| 856 |
-
)
|
| 857 |
-
dt_ptr += (
|
| 858 |
-
pid_b * stride_dt_batch
|
| 859 |
-
+ pid_c * stride_dt_chunk
|
| 860 |
-
+ (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head
|
| 861 |
-
)
|
| 862 |
-
dA_cumsum_ptr += (
|
| 863 |
-
pid_b * stride_dA_cs_batch
|
| 864 |
-
+ pid_c * stride_dA_cs_chunk
|
| 865 |
-
+ (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
|
| 866 |
-
)
|
| 867 |
-
if HAS_DDA_CS:
|
| 868 |
-
b_ptr += (
|
| 869 |
-
pid_b * stride_b_batch
|
| 870 |
-
+ pid_c * chunk_size * stride_b_seqlen
|
| 871 |
-
+ pid_g * stride_b_head
|
| 872 |
-
)
|
| 873 |
-
ddA_cumsum_ptr += (
|
| 874 |
-
pid_b * stride_ddA_cs_batch
|
| 875 |
-
+ pid_c * stride_ddA_cs_chunk
|
| 876 |
-
+ (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
|
| 877 |
-
* stride_ddA_cs_head
|
| 878 |
-
)
|
| 879 |
-
if HAS_SEQ_IDX:
|
| 880 |
-
seq_idx_ptr += (
|
| 881 |
-
pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
| 882 |
-
)
|
| 883 |
-
|
| 884 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 885 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 886 |
-
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
| 887 |
-
x_ptrs = x_ptr + (
|
| 888 |
-
offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim
|
| 889 |
-
)
|
| 890 |
-
dstates_ptrs = dstates_ptr + (
|
| 891 |
-
offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim
|
| 892 |
-
)
|
| 893 |
-
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
| 894 |
-
dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
|
| 895 |
-
if HAS_DDA_CS:
|
| 896 |
-
b_ptrs = b_ptr + (
|
| 897 |
-
offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate
|
| 898 |
-
)
|
| 899 |
-
ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
|
| 900 |
-
|
| 901 |
-
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
| 902 |
-
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
| 903 |
-
if HAS_DDA_CS:
|
| 904 |
-
b = tl.load(
|
| 905 |
-
b_ptrs,
|
| 906 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
|
| 907 |
-
other=0.0,
|
| 908 |
-
).to(tl.float32)
|
| 909 |
-
if HAS_SEQ_IDX:
|
| 910 |
-
seq_idx_m = tl.load(
|
| 911 |
-
seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
|
| 912 |
-
mask=offs_m < chunk_size_limit,
|
| 913 |
-
other=-1,
|
| 914 |
-
)
|
| 915 |
-
seq_idx_last = tl.load(
|
| 916 |
-
seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
|
| 917 |
-
)
|
| 918 |
-
nheads_iter = min(
|
| 919 |
-
nheads_per_program, nheads // ngroups - pid_s * nheads_per_program
|
| 920 |
-
)
|
| 921 |
-
for h in range(nheads_iter):
|
| 922 |
-
x = tl.load(
|
| 923 |
-
x_ptrs,
|
| 924 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim),
|
| 925 |
-
other=0.0,
|
| 926 |
-
)
|
| 927 |
-
dstates = tl.load(
|
| 928 |
-
dstates_ptrs,
|
| 929 |
-
mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate),
|
| 930 |
-
other=0.0,
|
| 931 |
-
)
|
| 932 |
-
dstates = dstates.to(x_ptrs.dtype.element_ty)
|
| 933 |
-
db = tl.dot(x, dstates)
|
| 934 |
-
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
|
| 935 |
-
tl.float32
|
| 936 |
-
)
|
| 937 |
-
dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
|
| 938 |
-
tl.float32
|
| 939 |
-
)
|
| 940 |
-
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
| 941 |
-
if not HAS_SEQ_IDX:
|
| 942 |
-
scale = tl.exp(dA_cs_last - dA_cs_m)
|
| 943 |
-
else:
|
| 944 |
-
scale = tl.where(
|
| 945 |
-
seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0
|
| 946 |
-
)
|
| 947 |
-
db *= (scale * dt_m)[:, None]
|
| 948 |
-
if HAS_DDA_CS:
|
| 949 |
-
# This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
|
| 950 |
-
ddA_cs = tl.sum(db * b, axis=1)
|
| 951 |
-
tl.atomic_add(
|
| 952 |
-
ddA_cumsum_ptrs + stride_ddA_cs_csize,
|
| 953 |
-
ddA_cs,
|
| 954 |
-
mask=offs_m < chunk_size - 1,
|
| 955 |
-
)
|
| 956 |
-
acc += db
|
| 957 |
-
x_ptrs += stride_x_head
|
| 958 |
-
dstates_ptrs += stride_states_head
|
| 959 |
-
dt_ptrs += stride_dt_head
|
| 960 |
-
dA_cumsum_ptr += stride_dA_cs_head
|
| 961 |
-
dA_cumsum_ptrs += stride_dA_cs_head
|
| 962 |
-
if HAS_DDA_CS:
|
| 963 |
-
ddA_cumsum_ptrs += stride_ddA_cs_head
|
| 964 |
-
|
| 965 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 966 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 967 |
-
# if HAS_SEQ_IDX:
|
| 968 |
-
# seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
|
| 969 |
-
# seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
|
| 970 |
-
# acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0)
|
| 971 |
-
db_ptrs = db_ptr + (
|
| 972 |
-
offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate
|
| 973 |
-
)
|
| 974 |
-
tl.store(
|
| 975 |
-
db_ptrs,
|
| 976 |
-
acc,
|
| 977 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
|
| 978 |
-
)
|
| 979 |
-
|
| 980 |
-
|
| 981 |
-
@triton.autotune(
|
| 982 |
-
configs=[
|
| 983 |
-
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 984 |
-
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 985 |
-
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 986 |
-
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 987 |
-
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 988 |
-
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 989 |
-
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 990 |
-
# triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 991 |
-
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 992 |
-
triton.Config(
|
| 993 |
-
{"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
|
| 994 |
-
num_stages=3,
|
| 995 |
-
num_warps=4,
|
| 996 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 997 |
-
),
|
| 998 |
-
triton.Config(
|
| 999 |
-
{"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
| 1000 |
-
num_stages=3,
|
| 1001 |
-
num_warps=4,
|
| 1002 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 1003 |
-
),
|
| 1004 |
-
triton.Config(
|
| 1005 |
-
{"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 1006 |
-
num_stages=3,
|
| 1007 |
-
num_warps=4,
|
| 1008 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 1009 |
-
),
|
| 1010 |
-
triton.Config(
|
| 1011 |
-
{"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
| 1012 |
-
num_stages=3,
|
| 1013 |
-
num_warps=4,
|
| 1014 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 1015 |
-
),
|
| 1016 |
-
triton.Config(
|
| 1017 |
-
{"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
|
| 1018 |
-
num_stages=4,
|
| 1019 |
-
num_warps=8,
|
| 1020 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 1021 |
-
),
|
| 1022 |
-
triton.Config(
|
| 1023 |
-
{"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
| 1024 |
-
num_stages=4,
|
| 1025 |
-
num_warps=8,
|
| 1026 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 1027 |
-
),
|
| 1028 |
-
triton.Config(
|
| 1029 |
-
{"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 1030 |
-
num_stages=4,
|
| 1031 |
-
num_warps=8,
|
| 1032 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 1033 |
-
),
|
| 1034 |
-
triton.Config(
|
| 1035 |
-
{"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
| 1036 |
-
num_stages=4,
|
| 1037 |
-
num_warps=8,
|
| 1038 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 1039 |
-
),
|
| 1040 |
-
],
|
| 1041 |
-
key=["chunk_size", "hdim", "dstate"],
|
| 1042 |
-
)
|
| 1043 |
-
@triton.jit
def _chunk_state_bwd_ddAcs_stable_kernel(
    # Pointers to matrices
    x_ptr,
    b_ptr,
    dstates_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    seq_idx_ptr,
    ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size,
    hdim,
    dstate,
    batch,
    seqlen,
    nheads_ngroups_ratio,
    # Strides
    stride_x_batch,
    stride_x_seqlen,
    stride_x_head,
    stride_x_hdim,
    stride_b_batch,
    stride_b_seqlen,
    stride_b_head,
    stride_b_dstate,
    stride_dstates_batch,
    stride_dstates_chunk,
    stride_states_head,
    stride_states_hdim,
    stride_states_dstate,
    stride_dt_batch,
    stride_dt_chunk,
    stride_dt_head,
    stride_dt_csize,
    stride_dA_cs_batch,
    stride_dA_cs_chunk,
    stride_dA_cs_head,
    stride_dA_cs_csize,
    stride_seq_idx_batch,
    stride_seq_idx_seqlen,
    stride_ddA_cs_batch,
    stride_ddA_cs_chunk,
    stride_ddA_cs_head,
    stride_ddA_cs_csize,
    # Meta-parameters
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    x_ptr += (
        pid_b * stride_x_batch
        + pid_c * chunk_size * stride_x_seqlen
        + pid_h * stride_x_head
    )
    b_ptr += (
        pid_b * stride_b_batch
        + pid_c * chunk_size * stride_b_seqlen
        + (pid_h // nheads_ngroups_ratio) * stride_b_head
    )
    dstates_ptr += (
        pid_b * stride_dstates_batch
        + pid_c * stride_dstates_chunk
        + pid_h * stride_states_head
    )
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    ddA_cumsum_ptr += (
        pid_b * stride_ddA_cs_batch
        + pid_c * stride_ddA_cs_chunk
        + pid_h * stride_ddA_cs_head
    )
    dA_cumsum_ptr += (
        pid_b * stride_dA_cs_batch
        + pid_c * stride_dA_cs_chunk
        + pid_h * stride_dA_cs_head
    )
    if HAS_SEQ_IDX:
        seq_idx_ptr += (
            pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
        )

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
    offs_k = tl.arange(
        0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
    )
    b_ptrs = b_ptr + (
        offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
    )
    dstates_ptrs = dstates_ptr + (
        offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
    )
    if BLOCK_SIZE_DSTATE <= 128:
        b = tl.load(
            b_ptrs,
            mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
            other=0.0,
        )
        dstates = tl.load(
            dstates_ptrs,
            mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
            other=0.0,
        )
        dstates = dstates.to(b_ptr.dtype.element_ty)
        acc = tl.dot(b, dstates)
    else:
        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for k in range(0, dstate, BLOCK_SIZE_K):
            b = tl.load(
                b_ptrs,
                mask=(offs_m[:, None] < chunk_size_limit)
                & (offs_k[None, :] < dstate - k),
                other=0.0,
            )
            dstates = tl.load(
                dstates_ptrs,
                mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
                other=0.0,
            )
            dstates = dstates.to(b_ptr.dtype.element_ty)
            acc += tl.dot(b, dstates)
            b_ptrs += BLOCK_SIZE_K * stride_b_dstate
            dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    dA_cs_m = tl.load(
        dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
    ).to(tl.float32)
    dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
        tl.float32
    )
    if not HAS_SEQ_IDX:
        scale = tl.exp(dA_cs_last - dA_cs_m)
    else:
        seq_idx_m = tl.load(
            seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
            mask=offs_m < chunk_size_limit,
            other=-1,
        )
        seq_idx_last = tl.load(
            seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
        )
        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
    acc *= scale[:, None]

    x_ptrs = x_ptr + (
        offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
    )
    x = tl.load(
        x_ptrs,
        mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
        other=0.0,
    ).to(tl.float32)
    dt_ptrs = dt_ptr + offs_m * stride_dt_csize
    dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
    ddt = tl.sum(acc * x, axis=1)
    # ddA_cs = -(ddt * dt_m)
    # Triton 2.2.0 errors if we have the cumsum here, so we just write it out
    # then call torch.cumsum outside this kernel.
    # ddA_cs = tl.cumsum(ddt * dt_m)
    ddA_cs = ddt * dt_m
    ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
    # tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
    tl.atomic_add(
        ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1
    )

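# --- Illustrative sketch (not part of the original file) ---------------------
# Why the kernel above stops at writing ddt * dt_m: tl.cumsum inside the
# kernel errored under Triton 2.2.0, so each program atomically accumulates
# the per-position terms (shifted right by one via
# `ddA_cumsum_ptrs + stride_ddA_cs_csize`) and the running sum is taken on the
# host, as `_chunk_state_bwd_ddAcs_stable` below does. A minimal CPU-only
# sketch of that two-step pattern, with made-up shapes:
def _sketch_host_side_cumsum():
    import torch

    # (batch, nheads, nchunks, chunk_size); position 0 never receives an
    # atomic_add, matching the `offs_m < chunk_size - 1` mask above.
    terms = torch.randn(2, 4, 3, 8)
    terms[..., 0] = 0.0
    ddA_cs = terms.clone()
    torch.cumsum(ddA_cs[..., 1:], dim=-1, out=ddA_cs[..., 1:])
    return ddA_cs
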
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=2,
        ),
    ],
    key=["hdim", "dstate", "chunk_size"],
)
@triton.jit
def _chunk_state_varlen_kernel(
    # Pointers to matrices
    x_ptr,
    b_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    chunk_states_ptr,
    cu_seqlens_ptr,
    states_ptr,
    # Matrix dimensions
    hdim,
    dstate,
    chunk_size,
    seqlen,
    nheads_ngroups_ratio,
    # Strides
    stride_x_seqlen,
    stride_x_head,
    stride_x_hdim,
    stride_b_seqlen,
    stride_b_head,
    stride_b_dstate,
    stride_dt_chunk,
    stride_dt_head,
    stride_dt_csize,
    stride_dA_cs_chunk,
    stride_dA_cs_head,
    stride_dA_cs_csize,
    stride_chunk_states_chunk,
    stride_chunk_states_head,
    stride_chunk_states_hdim,
    stride_chunk_states_dstate,
    stride_states_batch,
    stride_states_head,
    stride_states_hdim,
    stride_states_dstate,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
    pid_c = (end_idx - 1) // chunk_size
    b_ptr += (
        pid_c * chunk_size * stride_b_seqlen
        + (pid_h // nheads_ngroups_ratio) * stride_b_head
    )
    x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    chunk_states_ptr += (
        pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
    )

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    x_ptrs = x_ptr + (
        offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
    )
    b_ptrs = b_ptr + (
        offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
    )
    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
    dA_cs_last = tl.load(
        dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
    ).to(tl.float32)
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize

    chunk_size_limit = end_idx - pid_c * chunk_size
    start_idx = tl.load(cu_seqlens_ptr + pid_b)
    start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
        x = tl.load(
            x_ptrs,
            mask=(offs_m[:, None] < hdim)
            & (offs_k[None, :] < chunk_size_limit - k)
            & (offs_k[None, :] >= start_idx_cur - k),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < chunk_size_limit - k)
            & (offs_n[None, :] < dstate)
            & (offs_k[:, None] >= start_idx_cur - k),
            other=0.0,
        ).to(tl.float32)
        dA_cs_k = tl.load(
            dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
        ).to(tl.float32)
        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
            tl.float32
        )
        scale = tl.where(
            (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
            tl.exp((dA_cs_last - dA_cs_k)) * dt_k,
            0.0,
        )
        b *= scale[:, None]
        b = b.to(x_ptr.dtype.element_ty)
        acc += tl.dot(x, b)
        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
        b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

    # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
    if start_idx < pid_c * chunk_size:
        chunk_states_ptrs = chunk_states_ptr + (
            offs_m[:, None] * stride_chunk_states_hdim
            + offs_n[None, :] * stride_chunk_states_dstate
        )
        chunk_states = tl.load(
            chunk_states_ptrs,
            mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),
            other=0.0,
        ).to(tl.float32)
        # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0)
        scale = tl.exp(dA_cs_last)
        acc += chunk_states * scale

    states = acc.to(states_ptr.dtype.element_ty)

    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    states_ptrs = states_ptr + (
        offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
    )
    c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
    tl.store(states_ptrs, states, mask=c_mask)

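# --- Illustrative sketch (not part of the original file) ---------------------
# In the varlen kernel, each program handles one packed sequence (pid_b) and
# reads only the chunk containing that sequence's last token:
# pid_c = (end_idx - 1) // chunk_size. The same index arithmetic in plain
# Python, with a hypothetical cu_seqlens:
def _sketch_varlen_chunk_index(chunk_size=16):
    cu_seqlens = [0, 10, 35, 51]  # three packed sequences of lengths 10, 25, 16
    out = []
    for b in range(len(cu_seqlens) - 1):
        end_idx = cu_seqlens[b + 1]
        pid_c = (end_idx - 1) // chunk_size       # chunk holding the last token
        chunk_size_limit = end_idx - pid_c * chunk_size
        start_idx_cur = max(cu_seqlens[b] - pid_c * chunk_size, 0)
        out.append((b, pid_c, start_idx_cur, chunk_size_limit))
    return out
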
def _chunk_cumsum_fwd(
    dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))
):
    batch, seqlen, nheads = dt.shape
    assert A.shape == (nheads,)
    if dt_bias is not None:
        assert dt_bias.shape == (nheads,)
    nchunks = math.ceil(seqlen / chunk_size)
    dt_out = torch.empty(
        batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
    )
    dA_cumsum = torch.empty(
        batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
    )
    grid_chunk_cs = lambda META: (
        batch,
        nchunks,
        triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
    )
    with torch.cuda.device(dt.device.index):
        _chunk_cumsum_fwd_kernel[grid_chunk_cs](
            dt,
            A,
            dt_bias,
            dt_out,
            dA_cumsum,
            batch,
            seqlen,
            nheads,
            chunk_size,
            dt_limit[0],
            dt_limit[1],
            dt.stride(0),
            dt.stride(1),
            dt.stride(2),
            A.stride(0),
            dt_bias.stride(0) if dt_bias is not None else 0,
            dt_out.stride(0),
            dt_out.stride(2),
            dt_out.stride(1),
            dt_out.stride(3),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(3),
            dt_softplus,
            HAS_DT_BIAS=dt_bias is not None,
            BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
        )
    return dA_cumsum, dt_out

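# --- Illustrative sketch (not part of the original file) ---------------------
# What `_chunk_cumsum_fwd` computes, in unoptimized PyTorch (this sketch omits
# the dt_bias and dt_limit handling the kernel also does): dA_cumsum[b, h, c, l]
# is the running sum of A[h] * dt over positions within each chunk. Note also
# the stride order passed to the kernel above: the outputs are laid out
# (batch, nheads, nchunks, chunk_size) but the kernel expects
# (batch, chunk, head, csize), hence stride(0), stride(2), stride(1), stride(3).
def _sketch_chunk_cumsum_reference(dt, A, chunk_size, dt_softplus=False):
    import math
    import torch
    import torch.nn.functional as F

    batch, seqlen, nheads = dt.shape
    nchunks = math.ceil(seqlen / chunk_size)
    dt = F.pad(dt, (0, 0, 0, nchunks * chunk_size - seqlen))
    # (batch, nheads, nchunks, chunk_size), matching the Triton fast path
    dt = dt.reshape(batch, nchunks, chunk_size, nheads).permute(0, 3, 1, 2)
    if dt_softplus:
        dt = F.softplus(dt)
    dA_cumsum = torch.cumsum(A[None, :, None, None] * dt, dim=-1)
    return dA_cumsum, dt
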
def _chunk_cumsum_bwd(
    ddA,
    ddt_out,
    dt,
    A,
    dt_bias=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    ddt=None,
):
    batch, seqlen, nheads = dt.shape
    _, _, nchunks, chunk_size = ddA.shape
    assert ddA.shape == (batch, nheads, nchunks, chunk_size)
    assert ddt_out.shape == (batch, nheads, nchunks, chunk_size)
    assert A.shape == (nheads,)
    if dt_bias is not None:
        assert dt_bias.shape == (nheads,)
        ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)
    else:
        ddt_bias = None
    if ddt is not None:
        assert ddt.shape == dt.shape
    else:
        ddt = torch.empty_like(dt)
    dA = torch.empty_like(A, dtype=torch.float32)
    grid_chunk_cs = lambda META: (
        batch,
        nchunks,
        triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
    )
    with torch.cuda.device(dt.device.index):
        _chunk_cumsum_bwd_kernel[grid_chunk_cs](
            ddA,
            ddt_out,
            dt,
            A,
            dt_bias,
            ddt,
            dA,
            ddt_bias,
            batch,
            seqlen,
            nheads,
            chunk_size,
            dt_limit[0],
            dt_limit[1],
            ddA.stride(0),
            ddA.stride(2),
            ddA.stride(1),
            ddA.stride(3),
            ddt_out.stride(0),
            ddt_out.stride(2),
            ddt_out.stride(1),
            ddt_out.stride(3),
            dt.stride(0),
            dt.stride(1),
            dt.stride(2),
            A.stride(0),
            dt_bias.stride(0) if dt_bias is not None else 0,
            ddt.stride(0),
            ddt.stride(1),
            ddt.stride(2),
            dA.stride(0),
            ddt_bias.stride(0) if ddt_bias is not None else 0,
            dt_softplus,
            HAS_DT_BIAS=dt_bias is not None,
            BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
        )
    return ddt, dA, ddt_bias

def _chunk_state_fwd(
    B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True
):
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if states is not None:
        assert states.shape == (batch, nchunks, nheads, headdim, dstate)
    else:
        states_dtype = torch.float32 if states_in_fp32 else B.dtype
        states = torch.empty(
            (batch, nchunks, nheads, headdim, dstate),
            device=x.device,
            dtype=states_dtype,
        )
    grid = lambda META: (
        triton.cdiv(headdim, META["BLOCK_SIZE_M"])
        * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
        batch * nchunks,
        nheads,
    )
    with torch.cuda.device(x.device.index):
        _chunk_state_fwd_kernel[grid](
            x,
            B,
            states,
            dt,
            dA_cumsum,
            seq_idx,
            headdim,
            dstate,
            chunk_size,
            batch,
            seqlen,
            nheads // ngroups,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            x.stride(3),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            B.stride(-1),
            states.stride(0),
            states.stride(1),
            states.stride(2),
            states.stride(3),
            states.stride(4),
            dt.stride(0),
            dt.stride(2),
            dt.stride(1),
            dt.stride(3),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(3),
            *(
                (seq_idx.stride(0), seq_idx.stride(1))
                if seq_idx is not None
                else (0, 0)
            ),
            HAS_SEQ_IDX=seq_idx is not None,
        )
    return states

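# --- Illustrative sketch (not part of the original file) ---------------------
# The launch grid used by `_chunk_state_fwd` above flattens the output-tile
# index into axis 0 and batch*chunk into axis 1; the kernel unpacks them with
# integer division and modulo. The same bookkeeping in plain Python, assuming
# a hypothetical autotuned block size:
def _sketch_chunk_state_grid(headdim=64, dstate=128, batch=2, nchunks=3, nheads=8):
    import math

    BLOCK_SIZE_M, BLOCK_SIZE_N = 64, 32  # whatever the autotuner picked
    num_pid_m = math.ceil(headdim / BLOCK_SIZE_M)
    num_pid_n = math.ceil(dstate / BLOCK_SIZE_N)
    grid = (num_pid_m * num_pid_n, batch * nchunks, nheads)
    # Inside the kernel: pid_m = pid0 // num_pid_n, pid_n = pid0 % num_pid_n,
    # pid_c = pid1 // batch, pid_b = pid1 - pid_c * batch.
    return grid
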
def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None):
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
    if dx is not None:
        assert dx.shape == x.shape
    else:
        dx = torch.empty_like(x)
    ddt = torch.empty(
        batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
    )
    ddA_cumsum = torch.empty(
        batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32
    )
    grid_dx = lambda META: (
        triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
        * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
        batch * nchunks,
        nheads,
    )
    with torch.cuda.device(x.device.index):
        _chunk_state_bwd_dx_kernel[grid_dx](
            x,
            B,
            dstates,
            dt,
            dA_cumsum,
            dx,
            ddt,
            ddA_cumsum,
            chunk_size,
            headdim,
            dstate,
            batch,
            seqlen,
            nheads // ngroups,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            x.stride(3),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            B.stride(-1),
            dstates.stride(0),
            dstates.stride(1),
            dstates.stride(2),
            dstates.stride(3),
            dstates.stride(4),
            dt.stride(0),
            dt.stride(2),
            dt.stride(1),
            dt.stride(3),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(3),
            dx.stride(0),
            dx.stride(1),
            dx.stride(2),
            dx.stride(3),
            ddt.stride(0),
            ddt.stride(2),
            ddt.stride(1),
            ddt.stride(3),
            ddA_cumsum.stride(0),
            ddA_cumsum.stride(2),
            ddA_cumsum.stride(1),
            ddA_cumsum.stride(3),
            BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
        )
    return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype)

def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1):
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    dstate = dstates.shape[-1]
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if B is not None:
        assert B.shape == (batch, seqlen, ngroups, dstate)
        B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3))
        # Use torch.empty since the Triton kernel will call init_to_zero
        ddA_cumsum = torch.empty(
            batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
        )
        ddA_cumsum_strides = (
            ddA_cumsum.stride(0),
            ddA_cumsum.stride(2),
            ddA_cumsum.stride(1),
            ddA_cumsum.stride(3),
        )
    else:
        B_strides = (0, 0, 0, 0)
        ddA_cumsum = None
        ddA_cumsum_strides = (0, 0, 0, 0)
    nheads_ngroups_ratio = nheads // ngroups
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    nheads_per_program = max(
        min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1
    )
    nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
    dB = torch.empty(
        batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32
    )
    grid_db = lambda META: (
        triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
        * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
        batch * nchunks,
        nsplits * ngroups,
    )
    with torch.cuda.device(x.device.index):
        _chunk_state_bwd_db_kernel[grid_db](
            x,
            dstates,
            B,
            dt,
            dA_cumsum,
            seq_idx,
            dB,
            ddA_cumsum,
            chunk_size,
            dstate,
            headdim,
            batch,
            seqlen,
            nheads,
            nheads_per_program,
            ngroups,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            x.stride(3),
            dstates.stride(0),
            dstates.stride(1),
            dstates.stride(2),
            dstates.stride(3),
            dstates.stride(4),
            *B_strides,
            dt.stride(0),
            dt.stride(2),
            dt.stride(1),
            dt.stride(3),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(3),
            *(
                (seq_idx.stride(0), seq_idx.stride(1))
                if seq_idx is not None
                else (0, 0)
            ),
            dB.stride(0),
            dB.stride(1),
            dB.stride(2),
            dB.stride(3),
            dB.stride(4),
            *ddA_cumsum_strides,
            HAS_DDA_CS=ddA_cumsum is not None,
            HAS_SEQ_IDX=seq_idx is not None,
            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
        )
    dB = dB.sum(2)
    if ddA_cumsum is not None:
        # The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute
        # to the state of the chunk.
        # torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
        # But it's easier to just do the cumsum for all elements, the result will be the same.
        torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum)
    return dB if B is None else (dB, ddA_cumsum)

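# --- Illustrative sketch (not part of the original file) ---------------------
# `_chunk_state_bwd_db` amortizes the reduction over heads: each program sums
# over `nheads_per_program` heads of its group, sized so that roughly one
# program lands per SM, and the remaining `nsplits` partial dB results are
# reduced with dB.sum(2) on the host. The heuristic with made-up numbers:
def _sketch_db_split(batch=2, nchunks=8, nheads=32, ngroups=4, sm_count=108):
    import math

    nheads_ngroups_ratio = nheads // ngroups
    nheads_per_program = max(
        min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1
    )
    nsplits = math.ceil(nheads_ngroups_ratio / nheads_per_program)
    return nheads_per_program, nsplits  # (5, 2) for these values
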
def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    # Use torch.empty since the Triton kernel will call init_to_zero
    ddA_cumsum = torch.empty(
        batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
    )
    grid_ddtcs = lambda META: (
        triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
        * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
        batch * nchunks,
        nheads,
    )
    with torch.cuda.device(x.device.index):
        _chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs](
            x,
            B,
            dstates,
            dt,
            dA_cumsum,
            seq_idx,
            ddA_cumsum,
            chunk_size,
            headdim,
            dstate,
            batch,
            seqlen,
            nheads // ngroups,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            x.stride(3),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            B.stride(-1),
            dstates.stride(0),
            dstates.stride(1),
            dstates.stride(2),
            dstates.stride(3),
            dstates.stride(4),
            dt.stride(0),
            dt.stride(2),
            dt.stride(1),
            dt.stride(3),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(3),
            *(
                (seq_idx.stride(0), seq_idx.stride(1))
                if seq_idx is not None
                else (0, 0)
            ),
            ddA_cumsum.stride(0),
            ddA_cumsum.stride(2),
            ddA_cumsum.stride(1),
            ddA_cumsum.stride(3),
            HAS_SEQ_IDX=seq_idx is not None,
            BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16),
            BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
        )
    torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
    return ddA_cumsum

def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):
    total_seqlen, nheads, headdim = x.shape
    _, nchunks, chunk_size = dt.shape
    _, ngroups, dstate = B.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    assert nheads % ngroups == 0
    assert B.shape == (total_seqlen, ngroups, dstate)
    assert dt.shape == (nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
    states = torch.empty(
        batch,
        nheads,
        headdim,
        dstate,
        dtype=chunk_states.dtype,
        device=chunk_states.device,
    )
    grid = lambda META: (
        triton.cdiv(headdim, META["BLOCK_SIZE_M"])
        * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
        batch,
        nheads,
    )
    with torch.cuda.device(x.device.index):
        _chunk_state_varlen_kernel[grid](
            x,
            B,
            dt,
            dA_cumsum,
            chunk_states,
            cu_seqlens,
            states,
            headdim,
            dstate,
            chunk_size,
            total_seqlen,
            nheads // ngroups,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            dt.stride(1),
            dt.stride(0),
            dt.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            chunk_states.stride(0),
            chunk_states.stride(1),
            chunk_states.stride(2),
            chunk_states.stride(3),
            states.stride(0),
            states.stride(1),
            states.stride(2),
            states.stride(3),
        )
    return states

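# --- Illustrative sketch (not part of the original file) ---------------------
# How `chunk_state_varlen` is meant to be fed: tensors are packed over a
# total_seqlen axis with no batch dimension, and cu_seqlens gives cumulative
# sequence boundaries. A hypothetical call with made-up shapes, guarded since
# the kernel requires a CUDA device:
def _sketch_chunk_state_varlen_call():
    import torch

    if not torch.cuda.is_available():
        return None
    total_seqlen, nheads, headdim, ngroups, dstate, chunk_size = 48, 8, 64, 2, 16, 16
    nchunks = total_seqlen // chunk_size
    x = torch.randn(total_seqlen, nheads, headdim, device="cuda")
    B = torch.randn(total_seqlen, ngroups, dstate, device="cuda")
    dt = torch.rand(nheads, nchunks, chunk_size, device="cuda")
    dA_cumsum = torch.cumsum(-dt, dim=-1)
    chunk_states = torch.zeros(nchunks, nheads, headdim, dstate, device="cuda")
    cu_seqlens = torch.tensor([0, 20, 48], device="cuda", dtype=torch.int32)
    return chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states)
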
class ChunkStateFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True):
        batch, seqlen, nheads, headdim = x.shape
        _, _, nchunks, chunk_size = dt.shape
        assert seqlen <= nchunks * chunk_size
        _, _, ngroups, dstate = B.shape
        assert B.shape == (batch, seqlen, ngroups, dstate)
        assert dt.shape == (batch, nheads, nchunks, chunk_size)
        assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
        if B.stride(-1) != 1:
            B = B.contiguous()
        if (
            x.stride(-1) != 1 and x.stride(1) != 1
        ):  # Either M or K dimension should be contiguous
            x = x.contiguous()
        states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32)
        ctx.save_for_backward(B, x, dt, dA_cumsum)
        return states

    @staticmethod
    def backward(ctx, dstates):
        B, x, dt, dA_cumsum = ctx.saved_tensors
        batch, seqlen, nheads, headdim = x.shape
        _, _, nchunks, chunk_size = dt.shape
        _, _, ngroups, dstate = B.shape
        assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
        if dstates.stride(-1) != 1:
            dstates = dstates.contiguous()
        dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates)
        dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups)
        dB = dB.to(B.dtype)
        return dB, dx, ddt, ddA_cumsum, None

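# --- Illustrative sketch (not part of the original file) ---------------------
# `ChunkStateFn.backward` returns one gradient per `forward` argument (ctx
# aside), with None for the non-tensor `states_in_fp32` flag. A minimal
# end-to-end backward pattern with hypothetical shapes, CUDA only:
def _sketch_chunk_state_backward():
    import torch

    if not torch.cuda.is_available():
        return None
    batch, seqlen, nheads, headdim = 1, 32, 4, 16
    ngroups, dstate, chunk_size = 1, 8, 16
    nchunks = seqlen // chunk_size
    B = torch.randn(batch, seqlen, ngroups, dstate, device="cuda", requires_grad=True)
    x = torch.randn(batch, seqlen, nheads, headdim, device="cuda", requires_grad=True)
    dt = torch.rand(batch, nheads, nchunks, chunk_size, device="cuda", requires_grad=True)
    dA_cumsum = torch.cumsum(-dt.detach(), dim=-1).requires_grad_()
    states = ChunkStateFn.apply(B, x, dt, dA_cumsum, True)
    states.sum().backward()  # populates B.grad, x.grad, dt.grad, dA_cumsum.grad
    return x.grad, B.grad
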
def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True):
    """
    Argument:
        B: (batch, seqlen, ngroups, dstate)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
    Return:
        states: (batch, nchunks, nheads, headdim, dstate)
    """
    return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32)

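# --- Illustrative note (not part of the original file) ------------------------
# Shape bookkeeping for `chunk_state`: it contracts the within-chunk length
# axis, producing one (headdim, dstate) state per (batch, chunk, head). In
# einsum terms (matching `chunk_state_ref` below):
#     states[b, c, h, p, n] = sum_l B[b, c, l, h, n] * decay[b, h, c, l]
#                                   * dt[b, h, c, l] * x[b, c, l, h, p]
# where decay[b, h, c, l] = exp(dA_cumsum[b, h, c, -1] - dA_cumsum[b, h, c, l]).
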
def chunk_state_ref(B, x, dt, dA_cumsum):
    """
    Argument:
        B: (batch, seqlen, ngroups, dstate)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
    Return:
        states: (batch, nchunks, nheads, headdim, dstate)
    """
    # Check constraints.
    batch, seqlen, nheads, headdim = x.shape
    dstate = B.shape[-1]
    _, _, nchunks, chunk_size = dt.shape
    assert seqlen <= nchunks * chunk_size
    assert x.shape == (batch, seqlen, nheads, headdim)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    ngroups = B.shape[2]
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
    if seqlen < nchunks * chunk_size:
        x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
        B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
    x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
    B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
    decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))
    return torch.einsum(
        "bclhn,bhcl,bhcl,bclhp->bchpn",
        B.to(x.dtype),
        decay_states.to(x.dtype),
        dt.to(x.dtype),
        x,
    )
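# --- Illustrative sketch (not part of the original file) ---------------------
# The reference implementation above is the ground truth the Triton path can be
# checked against. A hypothetical equivalence test, CUDA only:
def _sketch_chunk_state_equivalence():
    import torch

    if not torch.cuda.is_available():
        return None
    batch, seqlen, nheads, headdim = 2, 40, 4, 32
    ngroups, dstate, chunk_size = 2, 16, 16
    nchunks = (seqlen + chunk_size - 1) // chunk_size
    B = torch.randn(batch, seqlen, ngroups, dstate, device="cuda")
    x = torch.randn(batch, seqlen, nheads, headdim, device="cuda")
    dt = torch.rand(batch, nheads, nchunks, chunk_size, device="cuda")
    dA_cumsum = torch.cumsum(-dt, dim=-1)
    out = chunk_state(B, x, dt, dA_cumsum)
    ref = chunk_state_ref(B, x, dt, dA_cumsum)
    return torch.allclose(out, ref, atol=1e-3, rtol=1e-3)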
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_combined.py
DELETED
@@ -1,1884 +0,0 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.

"""We want triton==2.1.0 or 2.2.0 for this
"""

from typing import Optional

import math
from packaging import version

import torch
import torch.nn.functional as F
from torch import Tensor
from ...utils.torch import custom_bwd, custom_fwd

import triton
import triton.language as tl

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn
    import causal_conv1d_cuda
except ImportError:
    causal_conv1d_fn, causal_conv1d_cuda = None, None

from .ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
from .ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
from .ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
from .ssd_chunk_state import chunk_state, chunk_state_ref
from .ssd_chunk_state import chunk_state_varlen
from .ssd_state_passing import _state_passing_fwd, _state_passing_bwd
from .ssd_state_passing import state_passing, state_passing_ref
from .ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates
from .ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
from .ssd_chunk_scan import chunk_scan, chunk_scan_ref
from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
from .layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
from .k_activations import _swiglu_fwd, _swiglu_bwd

TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")


def init_to_zero(names):
    return lambda nargs: [
        nargs[name].zero_() for name in names if nargs[name] is not None
    ]

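# --- Illustrative sketch (not part of the original file) ---------------------
# `init_to_zero` returns a pre_hook for triton.autotune: before each candidate
# config is benchmarked (and before the real launch), the named kernel
# arguments are zeroed. That matters because the kernel below accumulates into
# ddt_ptr with tl.atomic_add, so stale values from a previous timing run would
# corrupt the result. The mechanics, detached from any kernel:
def _sketch_pre_hook():
    import torch

    hook = init_to_zero(["ddt_ptr"])
    nargs = {"ddt_ptr": torch.ones(4), "x_ptr": torch.ones(4)}
    hook(nargs)  # zeroes nargs["ddt_ptr"] in place, leaves x_ptr untouched
    return nargs["ddt_ptr"]
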
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
            num_stages=3,
            num_warps=8,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
    ],
    key=["chunk_size", "hdim", "dstate"],
)
@triton.jit
def _chunk_scan_chunk_state_bwd_dx_kernel(
    # Pointers to matrices
    x_ptr,
    cb_ptr,
    dout_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    seq_idx_ptr,
    D_ptr,
    b_ptr,
    dstates_ptr,
    dx_ptr,
    ddt_ptr,
    dD_ptr,
    # Matrix dimensions
    chunk_size,
    hdim,
    dstate,
    batch,
    seqlen,
    nheads_ngroups_ratio,
    # Strides
    stride_x_batch,
    stride_x_seqlen,
    stride_x_head,
    stride_x_hdim,
    stride_cb_batch,
    stride_cb_chunk,
    stride_cb_head,
    stride_cb_csize_m,
    stride_cb_csize_k,
    stride_dout_batch,
    stride_dout_seqlen,
    stride_dout_head,
    stride_dout_hdim,
    stride_dt_batch,
    stride_dt_chunk,
    stride_dt_head,
    stride_dt_csize,
    stride_dA_cs_batch,
    stride_dA_cs_chunk,
    stride_dA_cs_head,
    stride_dA_cs_csize,
    stride_seq_idx_batch,
    stride_seq_idx_seqlen,
    stride_D_head,
    stride_b_batch,
    stride_b_seqlen,
    stride_b_head,
    stride_b_dstate,
    stride_dstates_batch,
    stride_dstates_chunk,
    stride_dstates_head,
    stride_dstates_hdim,
    stride_dstates_dstate,
    stride_dx_batch,
    stride_dx_seqlen,
    stride_dx_head,
    stride_dx_hdim,
    stride_ddt_batch,
    stride_ddt_chunk,
    stride_ddt_head,
    stride_ddt_csize,
    stride_dD_batch,
    stride_dD_chunk,
    stride_dD_head,
    stride_dD_csize,
    stride_dD_hdim,
    # Meta-parameters
    HAS_D: tl.constexpr,
    D_HAS_HDIM: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
    IS_TRITON_22: tl.constexpr,
):
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    x_ptr += (
        pid_b * stride_x_batch
        + pid_c * chunk_size * stride_x_seqlen
        + pid_h * stride_x_head
    )
    cb_ptr += (
        pid_b * stride_cb_batch
        + pid_c * stride_cb_chunk
        + (pid_h // nheads_ngroups_ratio) * stride_cb_head
    )
    dout_ptr += (
        pid_b * stride_dout_batch
        + pid_c * chunk_size * stride_dout_seqlen
        + pid_h * stride_dout_head
    )
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    ddt_ptr += (
        pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
    )
    dA_cumsum_ptr += (
        pid_b * stride_dA_cs_batch
        + pid_c * stride_dA_cs_chunk
        + pid_h * stride_dA_cs_head
    )
    b_ptr += (
        pid_b * stride_b_batch
        + pid_c * chunk_size * stride_b_seqlen
        + (pid_h // nheads_ngroups_ratio) * stride_b_head
    )
    dstates_ptr += (
        pid_b * stride_dstates_batch
        + pid_c * stride_dstates_chunk
        + pid_h * stride_dstates_head
    )
    if HAS_SEQ_IDX:
        seq_idx_ptr += (
            pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
        )

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    dA_cs_m = tl.load(
        dA_cumsum_ptr + offs_m * stride_dA_cs_csize,
        mask=offs_m < chunk_size_limit,
        other=0.0,
    ).to(tl.float32)

    dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
        tl.float32
    )
    if not HAS_SEQ_IDX:
        scale = tl.exp(dA_cs_last - dA_cs_m)
    else:
        seq_idx_m = tl.load(
            seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
            mask=offs_m < chunk_size_limit,
            other=-1,
        )
        seq_idx_last = tl.load(
            seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
        )
        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
    # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
    # However, we're getting error with the Triton compiler 2.1.0 for that code path:
    # Unexpected mma -> mma layout conversion
    # Triton 2.2.0 fixes this
    offs_dstate = tl.arange(
        0,
        (
            BLOCK_SIZE_DSTATE
            if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128
            else BLOCK_SIZE_K
        ),
    )
    b_ptrs = b_ptr + (
        offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate
    )
    dstates_ptrs = dstates_ptr + (
        offs_n[None, :] * stride_dstates_hdim
        + offs_dstate[:, None] * stride_dstates_dstate
    )
    if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:
        b = tl.load(
            b_ptrs,
            mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate),
            other=0.0,
        )
        dstates = tl.load(
            dstates_ptrs,
            mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
            other=0.0,
        )
        dstates = dstates.to(b_ptr.dtype.element_ty)
        acc = tl.dot(b, dstates) * scale[:, None]
    else:
        for k in range(0, dstate, BLOCK_SIZE_K):
            b = tl.load(
                b_ptrs,
                mask=(offs_m[:, None] < chunk_size_limit)
                & (offs_dstate[None, :] < dstate - k),
                other=0.0,
            )
            dstates = tl.load(
                dstates_ptrs,
                mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim),
                other=0.0,
            )
            dstates = dstates.to(b_ptr.dtype.element_ty)
            acc += tl.dot(b, dstates)
            b_ptrs += BLOCK_SIZE_K * stride_b_dstate
            dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate
        acc *= scale[:, None]

    # x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
    # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
    # dt_ptrs = dt_ptr + offs_m * stride_dt_csize
    # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
    # ddt = tl.sum(acc * x, axis=1) * dt_m
    # ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
    # tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)

    offs_k = tl.arange(0, BLOCK_SIZE_K)
    cb_ptrs = cb_ptr + (
        offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k
    )
    dout_ptrs = dout_ptr + (
        offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim
    )
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
| 331 |
-
K_MAX = chunk_size_limit
|
| 332 |
-
K_MIN = pid_m * BLOCK_SIZE_M
|
| 333 |
-
cb_ptrs += K_MIN * stride_cb_csize_k
|
| 334 |
-
dout_ptrs += K_MIN * stride_dout_seqlen
|
| 335 |
-
dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize
|
| 336 |
-
for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):
|
| 337 |
-
k = tl.multiple_of(k, BLOCK_SIZE_K)
|
| 338 |
-
# For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower
|
| 339 |
-
cb = tl.load(
|
| 340 |
-
cb_ptrs,
|
| 341 |
-
mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k),
|
| 342 |
-
other=0.0,
|
| 343 |
-
)
|
| 344 |
-
dout = tl.load(
|
| 345 |
-
dout_ptrs,
|
| 346 |
-
mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim),
|
| 347 |
-
other=0.0,
|
| 348 |
-
)
|
| 349 |
-
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(
|
| 350 |
-
tl.float32
|
| 351 |
-
)
|
| 352 |
-
cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
|
| 353 |
-
# If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
|
| 354 |
-
# we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
|
| 355 |
-
# Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
|
| 356 |
-
# This will cause NaN in acc, and hence NaN in dx and ddt.
|
| 357 |
-
mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)
|
| 358 |
-
cb = tl.where(mask, cb, 0.0)
|
| 359 |
-
cb = cb.to(dout_ptr.dtype.element_ty)
|
| 360 |
-
acc += tl.dot(cb, dout)
|
| 361 |
-
cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
|
| 362 |
-
dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
|
| 363 |
-
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
| 364 |
-
|
| 365 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 366 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 367 |
-
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
| 368 |
-
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
|
| 369 |
-
dx = acc * dt_m[:, None]
|
| 370 |
-
dx_ptr += (
|
| 371 |
-
pid_b * stride_dx_batch
|
| 372 |
-
+ pid_c * chunk_size * stride_dx_seqlen
|
| 373 |
-
+ pid_h * stride_dx_head
|
| 374 |
-
)
|
| 375 |
-
dx_ptrs = dx_ptr + (
|
| 376 |
-
offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim
|
| 377 |
-
)
|
| 378 |
-
if HAS_D:
|
| 379 |
-
dout_res_ptrs = dout_ptr + (
|
| 380 |
-
offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim
|
| 381 |
-
)
|
| 382 |
-
dout_res = tl.load(
|
| 383 |
-
dout_res_ptrs,
|
| 384 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
| 385 |
-
other=0.0,
|
| 386 |
-
).to(tl.float32)
|
| 387 |
-
if D_HAS_HDIM:
|
| 388 |
-
D = tl.load(
|
| 389 |
-
D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0
|
| 390 |
-
).to(tl.float32)
|
| 391 |
-
else:
|
| 392 |
-
D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
|
| 393 |
-
dx += dout_res * D
|
| 394 |
-
tl.store(
|
| 395 |
-
dx_ptrs,
|
| 396 |
-
dx,
|
| 397 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
| 398 |
-
)
|
| 399 |
-
|
| 400 |
-
x_ptrs = x_ptr + (
|
| 401 |
-
offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
|
| 402 |
-
)
|
| 403 |
-
x = tl.load(
|
| 404 |
-
x_ptrs,
|
| 405 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
| 406 |
-
other=0.0,
|
| 407 |
-
).to(tl.float32)
|
| 408 |
-
if HAS_D:
|
| 409 |
-
dD_ptr += (
|
| 410 |
-
pid_b * stride_dD_batch
|
| 411 |
-
+ pid_c * stride_dD_chunk
|
| 412 |
-
+ pid_h * stride_dD_head
|
| 413 |
-
+ pid_m * stride_dD_csize
|
| 414 |
-
)
|
| 415 |
-
if D_HAS_HDIM:
|
| 416 |
-
dD_ptrs = dD_ptr + offs_n * stride_dD_hdim
|
| 417 |
-
dD = tl.sum(dout_res * x, axis=0)
|
| 418 |
-
tl.store(dD_ptrs, dD, mask=offs_n < hdim)
|
| 419 |
-
else:
|
| 420 |
-
dD = tl.sum(dout_res * x)
|
| 421 |
-
tl.store(dD_ptr, dD)
|
| 422 |
-
ddt = tl.sum(acc * x, axis=1)
|
| 423 |
-
ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
|
| 424 |
-
tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
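# Illustrative sketch (not part of the original file): the kernel above packs
# (batch, chunk) into grid axis 1 and splits program_id(axis=0) into an
# (m, n) tile index via num_pid_n. A minimal pure-Python check that these
# decompositions enumerate every (batch, chunk) pair and every (m, n) tile
# exactly once; the sizes are demo assumptions.
def _demo_pid_decomposition(batch=3, nchunks=5, num_pid_m=2, num_pid_n=4):
    seen_bc = set()
    for pid_bc in range(batch * nchunks):
        pid_c = pid_bc // batch          # chunk index, as in the kernel
        pid_b = pid_bc - pid_c * batch   # batch index
        seen_bc.add((pid_b, pid_c))
    assert seen_bc == {(b, c) for b in range(batch) for c in range(nchunks)}
    seen_mn = set()
    for pid0 in range(num_pid_m * num_pid_n):
        seen_mn.add((pid0 // num_pid_n, pid0 % num_pid_n))  # (pid_m, pid_n)
    assert seen_mn == {(m, n) for m in range(num_pid_m) for n in range(num_pid_n)}
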
def _chunk_scan_chunk_state_bwd_dx(
    x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None
):
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert dout.shape == x.shape
    assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
        assert D.stride(-1) == 1
        BLOCK_SIZE_min = 32
        dD = torch.empty(
            triton.cdiv(chunk_size, BLOCK_SIZE_min),
            batch,
            nchunks,
            nheads,
            headdim if D.dim() == 2 else 1,
            device=D.device,
            dtype=torch.float32,
        )
    else:
        dD = None
    dD_strides = (
        (dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
        if D is not None
        else (0, 0, 0, 0, 0)
    )
    if dx is None:
        dx = torch.empty_like(x)
    else:
        assert dx.shape == x.shape
    ddt = torch.empty(
        batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32
    )
    grid_dx = lambda META: (
        triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
        * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
        batch * nchunks,
        nheads,
    )
    with torch.cuda.device(x.device.index):
        _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](
            x,
            CB,
            dout,
            dt,
            dA_cumsum,
            seq_idx,
            D,
            B,
            dstates,
            dx,
            ddt,
            dD,
            chunk_size,
            headdim,
            dstate,
            batch,
            seqlen,
            nheads // ngroups,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            x.stride(3),
            CB.stride(0),
            CB.stride(1),
            CB.stride(2),
            CB.stride(-1),
            CB.stride(-2),
            dout.stride(0),
            dout.stride(1),
            dout.stride(2),
            dout.stride(3),
            dt.stride(0),
            dt.stride(2),
            dt.stride(1),
            dt.stride(3),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(3),
            *(
                (seq_idx.stride(0), seq_idx.stride(1))
                if seq_idx is not None
                else (0, 0)
            ),
            D.stride(0) if D is not None else 0,
            B.stride(0),
            B.stride(1),
            B.stride(2),
            B.stride(3),
            dstates.stride(0),
            dstates.stride(1),
            dstates.stride(2),
            dstates.stride(3),
            dstates.stride(4),
            dx.stride(0),
            dx.stride(1),
            dx.stride(2),
            dx.stride(3),
            ddt.stride(0),
            ddt.stride(2),
            ddt.stride(1),
            ddt.stride(3),
            dD_strides[1],
            dD_strides[2],
            dD_strides[3],
            dD_strides[0],
            dD_strides[4],
            D is not None,
            D.dim() == 2 if D is not None else True,
            HAS_SEQ_IDX=seq_idx is not None,
            BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
            IS_TRITON_22=TRITON_22,
        )
    if D is not None:
        BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs[
            "BLOCK_SIZE_M"
        ]
        n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
        dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
        if D.dim() == 1:
            dD = rearrange(dD, "h 1 -> h")
    return dx, ddt.to(dtype=dt.dtype), dD

def _mamba_chunk_scan_combined_fwd(
    x,
    dt,
    A,
    B,
    C,
    chunk_size,
    D=None,
    z=None,
    dt_bias=None,
    initial_states=None,
    seq_idx=None,
    cu_seqlens=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
):
    batch, seqlen, nheads, headdim = x.shape
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert x.shape == (batch, seqlen, nheads, headdim)
    assert dt.shape == (batch, seqlen, nheads)
    assert A.shape == (nheads,)
    assert C.shape == B.shape
    if z is not None:
        assert z.shape == x.shape
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if (
        x.stride(-1) != 1 and x.stride(1) != 1
    ):  # Either M or K dimension should be contiguous
        x = x.contiguous()
    if (
        z is not None and z.stride(-1) != 1 and z.stride(1) != 1
    ):  # Either M or K dimension should be contiguous
        z = z.contiguous()
    if D is not None and D.stride(-1) != 1:
        D = D.contiguous()
    if initial_states is not None:
        assert initial_states.shape == (batch, nheads, headdim, dstate)
    # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
    # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
    # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
    # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
    dA_cumsum, dt = _chunk_cumsum_fwd(
        dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit
    )
    states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
    # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
    # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
    # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)
    states, final_states = _state_passing_fwd(
        rearrange(states, "... p n -> ... (p n)"),
        dA_cumsum[:, :, :, -1],
        initial_states=(
            rearrange(initial_states, "... p n -> ... (p n)")
            if initial_states is not None
            else None
        ),
        seq_idx=seq_idx,
        chunk_size=chunk_size,
        out_dtype=C.dtype,
    )
    states, final_states = [
        rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]
    ]
    # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
    # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
    CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
    out, out_x = _chunk_scan_fwd(
        CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx
    )
    if cu_seqlens is None:
        return out, out_x, dt, dA_cumsum, states, final_states
    else:
        assert (
            batch == 1
        ), "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
        varlen_states = chunk_state_varlen(
            B.squeeze(0),
            x.squeeze(0),
            dt.squeeze(0),
            dA_cumsum.squeeze(0),
            cu_seqlens,
            states.squeeze(0),
        )
        return out, out_x, dt, dA_cumsum, states, final_states, varlen_states

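# Illustrative sketch (not part of the original file): the inter-chunk
# recurrence that _state_passing_fwd implements, written naively in PyTorch.
# states[c] holds the per-chunk state and dA_cs_last[c] the total decay of
# chunk c; the state entering chunk c is
#     passed[c] = passed[c - 1] * exp(dA_cs_last[c - 1]) + states[c - 1],
# with passed[0] equal to the initial state (zeros if none). The flattened
# (nchunks, d) shape is a demo assumption, not the kernel's actual layout.
def _demo_state_passing(states, dA_cs_last, initial=None):
    import torch
    nchunks, d = states.shape          # states: (nchunks, d), dA_cs_last: (nchunks,)
    passed = torch.zeros_like(states)
    prev = initial if initial is not None else torch.zeros(d, dtype=states.dtype)
    for c in range(nchunks):
        passed[c] = prev                                  # state entering chunk c
        prev = prev * torch.exp(dA_cs_last[c]) + states[c]  # decay, then add chunk c
    return passed, prev  # per-chunk incoming states, and the final state
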
def _mamba_chunk_scan_combined_bwd(
    dout,
    x,
    dt,
    A,
    B,
    C,
    out,
    chunk_size,
    D=None,
    z=None,
    dt_bias=None,
    initial_states=None,
    dfinal_states=None,
    seq_idx=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    dx=None,
    ddt=None,
    dB=None,
    dC=None,
    dz=None,
    recompute_output=False,
):
    if dout.stride(-1) != 1:
        dout = dout.contiguous()
    batch, seqlen, nheads, headdim = x.shape
    nchunks = math.ceil(seqlen / chunk_size)
    _, _, ngroups, dstate = B.shape
    assert dout.shape == (batch, seqlen, nheads, headdim)
    assert dt.shape == (batch, seqlen, nheads)
    assert A.shape == (nheads,)
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert C.shape == B.shape
    assert out.shape == x.shape
    if initial_states is not None:
        assert initial_states.shape == (batch, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if dx is not None:
        assert dx.shape == x.shape
    if dB is not None:
        assert dB.shape == B.shape
        dB_given = dB
    else:
        dB_given = torch.empty_like(B)
    if dC is not None:
        assert dC.shape == C.shape
        dC_given = dC
    else:
        dC_given = torch.empty_like(C)
    if dz is not None:
        assert z is not None
        assert dz.shape == z.shape
    if ddt is not None:
        assert ddt.shape == dt.shape
        ddt_given = ddt
    else:
        ddt_given = torch.empty_like(dt)
    # TD: For some reason Triton (2.1.0 and 2.2.0) errors with
    # "[CUDA]: invalid device context" (e.g. during varlen test), and cloning makes it work. Idk why.
    dt_in = dt.clone()
    dA_cumsum, dt = _chunk_cumsum_fwd(
        dt_in,
        A,
        chunk_size,
        dt_bias=dt_bias,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
    )
    CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
    states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
    states, _ = _state_passing_fwd(
        rearrange(states, "... p n -> ... (p n)"),
        dA_cumsum[:, :, :, -1],
        initial_states=(
            rearrange(initial_states, "... p n -> ... (p n)")
            if initial_states is not None
            else None
        ),
        seq_idx=seq_idx,
        chunk_size=chunk_size,
    )
    states = rearrange(states, "... (p n) -> ... p n", n=dstate)
    if z is not None:
        dz, dout, dD, *rest = _chunk_scan_bwd_dz(
            x,
            z,
            out,
            dout,
            chunk_size=chunk_size,
            has_ddAcs=False,
            D=D,
            dz=dz,
            recompute_output=recompute_output,
        )
        outz = rest[0] if recompute_output else out
    else:
        dz = None
        outz = out
    dstates = _chunk_scan_bwd_dstates(
        C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype
    )
    # dstates has length nchunks, containing the gradient to initial states at index 0 and
    # gradient to the states of chunk (nchunks - 2) at index (nchunks - 1)
    # Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states
    # will be used in matmul in the next kernels.
    dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd(
        rearrange(states, "... p n -> ... (p n)"),
        dA_cumsum[:, :, :, -1],
        rearrange(dstates, "... p n -> ... (p n)"),
        dfinal_states=(
            rearrange(dfinal_states, "... p n -> ... (p n)")
            if dfinal_states is not None
            else None
        ),
        seq_idx=seq_idx,
        has_initial_states=initial_states is not None,
        dstates_dtype=x.dtype,
        states_dtype=x.dtype,
        chunk_size=chunk_size,
    )
    # dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and
    # gradient to the final states at index (nchunks - 1)
    # states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1)
    # The final states is not stored.
    states = rearrange(states, "... (p n) -> ... p n", n=dstate)
    dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate)
    dinitial_states = (
        rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate)
        if dinitial_states is not None
        else None
    )
    dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(
        x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx
    )
    # dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups)
    dB, ddA_next = _chunk_state_bwd_db(
        x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups
    )
    # dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
    dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(
        states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups
    )
    # Computing ddA with the dcb kernel is much slower, so we're not using it for now
    dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
    # dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups)
    dCB = dCB.to(CB.dtype)
    _bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given)
    _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given)
    # If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate
    # than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16
    if z is None:
        dD = dD_from_x
    # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
    # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
    # However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might
    # be a lot of underflow.

    # This is already done as part of bwd_dC kernel
    # ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx)
    ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum
    ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1])
    # This is already done as part of bwd_dB kernel
    # ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx)
    # We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j]
    ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB)
    ddA += ddA_next + ddA_prev

    ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(
        ddA,
        ddt,
        dt_in,
        A,
        dt_bias=dt_bias,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
        ddt=ddt_given,
    )

    # These 2 lines are just to test ddt and dA being computed by old code
    # _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z)
    # ddt_given.copy_(ddt)

    return_vals = (
        dx,
        ddt_given,
        dA,
        dB_given,
        dC_given,
        dD,
        dz,
        ddt_bias,
        dinitial_states,
    )
    return return_vals if not recompute_output else (*return_vals, outz)

def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None):
    """
    Argument:
        dout: (batch, seqlen, nheads, headdim)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size)
        A: (nheads) or (dim, dstate)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
    Return:
        ddt: same shape as dt
        dA: (nheads) or (dim, dstate), matching A
    """
    import selective_scan

    batch, seqlen, nheads, headdim = x.shape
    chunk_size = dt.shape[-1]
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    x = rearrange(x, "b l h p -> b (h p) l")
    squeeze_dt = dt.dim() == 4
    if dt.dim() == 4:
        dt = repeat(dt, "b h c l -> b h p c l", p=headdim)
    dt = rearrange(dt, "b h p c l -> b (h p) (c l)", p=headdim)
    squeeze_A = A.dim() == 1
    if A.dim() == 1:
        A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
    else:
        A = A.to(dtype=torch.float32)
    B = rearrange(B, "b l g n -> b g n l")
    C = rearrange(C, "b l g n -> b g n l")
    if D is not None:
        if D.dim() == 2:
            D = rearrange(D, "h p -> (h p)")
        else:
            D = repeat(D, "h -> (h p)", p=headdim)
    if z is not None:
        z = rearrange(z, "b l h p -> b (h p) l")

    if x.stride(-1) != 1:
        x = x.contiguous()
    if dt.stride(-1) != 1:
        dt = dt.contiguous()
    if D is not None:
        D = D.contiguous()
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if z is not None and z.stride(-1) != 1:
        z = z.contiguous()
    _, intermediate, *rest = selective_scan.fwd(
        x, dt.to(dtype=x.dtype), A, B, C, D, z, None, False
    )
    if z is not None:
        out = rest[0]
    else:
        out = None

    dout = rearrange(dout, "b l h p -> b (h p) l")

    if dout.stride(-1) != 1:
        dout = dout.contiguous()
    # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
    # backward of selective_scan with the backward of chunk).
    # Here we just pass in None and dz will be allocated in the C++ code.
    _, ddt, dA, *rest = selective_scan.bwd(
        x,
        dt.to(dtype=x.dtype),
        A,
        B,
        C,
        D,
        z,
        None,
        dout,
        intermediate,
        out,
        None,
        False,
        False,  # option to recompute out_z, not used here
    )
    ddt = rearrange(ddt, "b (h p) (c l) -> b h p c l", p=headdim, l=chunk_size)
    if squeeze_dt:
        ddt = ddt.float().sum(dim=2)
    if squeeze_A:
        dA = rearrange(dA, "(h p) n -> h p n", p=headdim).sum(dim=(1, 2))
    return ddt, dA

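# Illustrative sketch (not part of the original file): the layout bridging
# that selective_scan_bwd performs before calling the fused CUDA ops, shown
# with einops on random tensors. Shapes follow the docstring above; the
# concrete sizes are demo assumptions.
def _demo_ssd_to_selective_scan_layout(
    batch=2, seqlen=8, nheads=4, headdim=16, ngroups=2, dstate=8
):
    import torch
    from einops import rearrange, repeat

    x = torch.randn(batch, seqlen, nheads, headdim)
    B = torch.randn(batch, seqlen, ngroups, dstate)
    A = torch.randn(nheads)
    x2 = rearrange(x, "b l h p -> b (h p) l")            # (batch, dim, seqlen)
    B2 = rearrange(B, "b l g n -> b g n l")              # (batch, ngroups, dstate, seqlen)
    A2 = repeat(A, "h -> (h p) n", p=headdim, n=dstate)  # per-head A broadcast to (dim, dstate)
    assert x2.shape == (batch, nheads * headdim, seqlen)
    assert B2.shape == (batch, ngroups, dstate, seqlen)
    assert A2.shape == (nheads * headdim, dstate)
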
class MambaChunkScanCombinedFn(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        x,
        dt,
        A,
        B,
        C,
        chunk_size,
        D=None,
        z=None,
        dt_bias=None,
        initial_states=None,
        seq_idx=None,
        cu_seqlens=None,
        dt_softplus=False,
        dt_limit=(0.0, float("inf")),
        return_final_states=False,
        return_varlen_states=False,
    ):
        ctx.dt_dtype = dt.dtype
        if not return_varlen_states:
            cu_seqlens = None
        else:
            assert (
                cu_seqlens is not None
            ), "cu_seqlens must be provided if return_varlen_states is True"
        out, out_x, dt_out, dA_cumsum, states, final_states, *rest = (
            _mamba_chunk_scan_combined_fwd(
                x,
                dt,
                A,
                B,
                C,
                chunk_size,
                D=D,
                z=z,
                dt_bias=dt_bias,
                initial_states=initial_states,
                seq_idx=seq_idx,
                cu_seqlens=cu_seqlens,
                dt_softplus=dt_softplus,
                dt_limit=dt_limit,
            )
        )
        ctx.save_for_backward(
            out if z is None else out_x,
            x,
            dt,
            dA_cumsum,
            A,
            B,
            C,
            D,
            z,
            dt_bias,
            initial_states,
            seq_idx,
        )
        ctx.dt_softplus = dt_softplus
        ctx.chunk_size = chunk_size
        ctx.dt_limit = dt_limit
        ctx.return_final_states = return_final_states
        ctx.return_varlen_states = return_varlen_states
        if not return_varlen_states:
            return out if not return_final_states else (out, final_states)
        else:
            varlen_states = rest[0]
            return (
                (out, varlen_states)
                if not return_final_states
                else (out, final_states, varlen_states)
            )

    @staticmethod
    def backward(ctx, dout, *args):
        out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = (
            ctx.saved_tensors
        )
        assert (
            not ctx.return_varlen_states
        ), "return_varlen_states is not supported in backward"
        dfinal_states = args[0] if ctx.return_final_states else None
        dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = (
            _mamba_chunk_scan_combined_bwd(
                dout,
                x,
                dt,
                A,
                B,
                C,
                out,
                ctx.chunk_size,
                D=D,
                z=z,
                dt_bias=dt_bias,
                initial_states=initial_states,
                dfinal_states=dfinal_states,
                seq_idx=seq_idx,
                dt_softplus=ctx.dt_softplus,
                dt_limit=ctx.dt_limit,
            )
        )
        return (
            dx,
            ddt,
            dA,
            dB,
            dC,
            None,
            dD,
            dz,
            ddt_bias,
            dinitial_states,
            None,
            None,
            None,
            None,
            None,
            None,
        )

def mamba_chunk_scan_combined(
    x,
    dt,
    A,
    B,
    C,
    chunk_size,
    D=None,
    z=None,
    dt_bias=None,
    initial_states=None,
    seq_idx=None,
    cu_seqlens=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    return_final_states=False,
    return_varlen_states=False,
):
    """
    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads)
        A: (nheads)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        chunk_size: int
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
        dt_bias: (nheads,)
        initial_states: (batch, nheads, headdim, dstate)
        seq_idx: (batch, seqlen)
        cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
        dt_softplus: Whether to apply softplus to dt
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    return MambaChunkScanCombinedFn.apply(
        x,
        dt,
        A,
        B,
        C,
        chunk_size,
        D,
        z,
        dt_bias,
        initial_states,
        seq_idx,
        cu_seqlens,
        dt_softplus,
        dt_limit,
        return_final_states,
        return_varlen_states,
    )

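# Illustrative usage sketch (not part of the original file): calling
# mamba_chunk_scan_combined with the shapes listed in its docstring. Requires
# a CUDA device and the Triton kernels in this module; all concrete sizes
# below are demo assumptions, not defaults.
def _demo_mamba_chunk_scan_combined():
    import torch

    batch, seqlen, nheads, headdim = 2, 512, 8, 64
    ngroups, dstate, chunk_size = 1, 16, 256
    x = torch.randn(batch, seqlen, nheads, headdim, device="cuda")
    dt = torch.rand(batch, seqlen, nheads, device="cuda")
    A = -torch.rand(nheads, device="cuda")  # negative so the state decays
    B = torch.randn(batch, seqlen, ngroups, dstate, device="cuda")
    C = torch.randn(batch, seqlen, ngroups, dstate, device="cuda")
    out = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, dt_softplus=True)
    assert out.shape == (batch, seqlen, nheads, headdim)
    return out
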
def mamba_chunk_scan(
    x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False
):
    """
    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads)
        A: (nheads)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
        dt_bias: (nheads,)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    batch, seqlen, nheads, headdim = x.shape
    dstate = B.shape[-1]
    if seqlen % chunk_size != 0:
        dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
    dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
    dt = dt.float()  # We want high precision for this before cumsum
    if dt_bias is not None:
        dt = dt + rearrange(dt_bias, "h -> h 1 1")
    if dt_softplus:
        dt = F.softplus(dt)
    dA = dt * rearrange(A, "h -> h 1 1")
    dA_cumsum = torch.cumsum(dA, dim=-1)
    # 1. Compute the state for each chunk
    states = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True)
    # 2. Pass the state to all the chunks by weighted cumsum.
    states = rearrange(
        state_passing(
            rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]
        )[0],
        "... (p n) -> ... p n",
        n=dstate,
    )
    # 3. Compute the output for each chunk
    out = chunk_scan(B, C, x, dt, dA_cumsum, states, D=D, z=z)
    return out

def ssd_chunk_scan_combined_ref(
    x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False
):
    """
    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads)
        A: (nheads)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
        dt_bias: (nheads,)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    batch, seqlen, nheads, headdim = x.shape
    dstate = B.shape[-1]
    if seqlen % chunk_size != 0:
        dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
    dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
    dt = dt.float()  # We want high precision for this before cumsum
    if dt_bias is not None:
        dt = dt + rearrange(dt_bias, "h -> h 1 1")
    if dt_softplus:
        dt = F.softplus(dt)
    dA = dt * rearrange(A, "h -> h 1 1")
    dA_cumsum = torch.cumsum(dA, dim=-1)
    # 1. Compute the state for each chunk
    states = chunk_state_ref(B, x, dt, dA_cumsum)
    states_dtype = states.dtype
    if states.dtype not in [torch.float32, torch.float64]:
        states = states.to(torch.float32)
    # 2. Pass the state to all the chunks by weighted cumsum.
    # state_passing_ref is much less numerically stable
    states = rearrange(
        state_passing_ref(
            rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]
        )[0],
        "... (p n) -> ... p n",
        n=dstate,
    )
    states = states.to(states_dtype)
    # 3. Compute the output for each chunk
    out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)
    return out

def ssd_selective_scan(
    x,
    dt,
    A,
    B,
    C,
    D=None,
    z=None,
    dt_bias=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
):
    """
    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
        A: (nheads) or (dim, dstate)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
        dt_bias: (nheads,) or (nheads, headdim)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    from ..selective_scan_interface import selective_scan_fn

    batch, seqlen, nheads, headdim = x.shape
    _, _, ngroups, dstate = B.shape
    x = rearrange(x, "b l h p -> b (h p) l")
    if dt.dim() == 3:
        dt = repeat(dt, "b l h -> b l h p", p=headdim)
    dt = rearrange(dt, "b l h p -> b (h p) l")
    if A.dim() == 1:
        A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
    else:
        A = A.to(dtype=torch.float32)
    B = rearrange(B, "b l g n -> b g n l")
    C = rearrange(C, "b l g n -> b g n l")
    if D is not None:
        if D.dim() == 2:
            D = rearrange(D, "h p -> (h p)")
        else:
            D = repeat(D, "h -> (h p)", p=headdim)
    if z is not None:
        z = rearrange(z, "b l h p -> b (h p) l")
    if dt_bias is not None:
        if dt_bias.dim() == 1:
            dt_bias = repeat(dt_bias, "h -> h p", p=headdim)
        dt_bias = rearrange(dt_bias, "h p -> (h p)")
    if dt_limit != (0.0, float("inf")):
        if dt_bias is not None:
            dt = dt + rearrange(dt_bias, "d -> d 1")
        if dt_softplus:
            dt = F.softplus(dt)
        dt = dt.clamp(min=dt_limit[0], max=dt_limit[1]).to(x.dtype)
        dt_bias = None
        dt_softplus = None
    out = selective_scan_fn(
        x, dt, A, B, C, D=D, z=z, delta_bias=dt_bias, delta_softplus=dt_softplus
    )
    return rearrange(out, "b (h p) l -> b l h p", p=headdim)

def mamba_conv1d_scan_ref(
    xBC,
    conv1d_weight,
    conv1d_bias,
    dt,
    A,
    chunk_size,
    D=None,
    z=None,
    dt_bias=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    activation="silu",
    headdim=None,
    ngroups=1,
):
    """
    Argument:
        xBC: (batch, seqlen, dim + 2 * ngroups * dstate) where dim == nheads * headdim
        conv1d_weight: (dim + 2 * ngroups * dstate, width)
        conv1d_bias: (dim + 2 * ngroups * dstate,)
        dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
        A: (nheads)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, dim)
        dt_bias: (nheads) or (nheads, headdim)
        headdim: if D is 1D and z is None, headdim must be passed in
    Return:
        out: (batch, seqlen, dim)
    """
    batch, seqlen, nheads = dt.shape[:3]
    assert nheads % ngroups == 0
    if z is not None:
        dim = z.shape[-1]
        assert dim % nheads == 0
        headdim = dim // nheads
    else:
        if D.dim() == 1:
            assert headdim is not None
        else:
            headdim = D.shape[1]
        dim = nheads * headdim
    xBC = rearrange(
        causal_conv1d_fn(
            rearrange(xBC, "b s d -> b d s"),
            conv1d_weight,
            conv1d_bias,
            activation=activation,
        ),
        "b d s -> b s d",
    )
    dstate = (xBC.shape[-1] - dim) // ngroups // 2
    x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
    x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
    B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
    C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
    z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
    out = ssd_selective_scan(
        x,
        dt.to(x.dtype),
        A,
        B,
        C,
        D=D.float(),
        z=z,
        dt_bias=dt_bias,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
    )
    return rearrange(out, "b s h p -> b s (h p)")

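# Illustrative sketch (not part of the original file): the packing arithmetic
# used throughout this file for the fused projections. With dim == nheads *
# headdim, the conv input xBC concatenates [x | B | C] along the feature axis,
# so dstate can be recovered from the widths alone, as the functions above and
# below do. Demo sizes are assumptions.
def _demo_xbc_packing(nheads=8, headdim=64, ngroups=2, dstate=128):
    dim = nheads * headdim
    xBC_width = dim + 2 * ngroups * dstate
    assert (xBC_width - dim) // ngroups // 2 == dstate  # same recovery as above
    # The fused zxbcdt tensor used by MambaSplitConv1dScanCombinedFn below
    # additionally carries z (dim) and dt (nheads):
    zxbcdt_width = 2 * dim + 2 * ngroups * dstate + nheads
    return xBC_width, zxbcdt_width
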
class MambaSplitConv1dScanCombinedFn(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        zxbcdt,
        conv1d_weight,
        conv1d_bias,
        dt_bias,
        A,
        D,
        chunk_size,
        initial_states=None,
        seq_idx=None,
        dt_limit=(0.0, float("inf")),
        return_final_states=False,
        activation="silu",
        rmsnorm_weight=None,
        rmsnorm_eps=1e-6,
        outproj_weight=None,
        outproj_bias=None,
        headdim=None,
        ngroups=1,
        norm_before_gate=True,
    ):
        assert activation in [None, "silu", "swish"]
        if D.dim() == 1:
            assert headdim is not None
            (nheads,) = D.shape
        else:
            nheads, headdim = D.shape
        batch, seqlen, _ = zxbcdt.shape
        dim = nheads * headdim
        assert nheads % ngroups == 0
        dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2
        d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2
        assert d_nonssm >= 0
        assert zxbcdt.shape == (
            batch,
            seqlen,
            2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads,
        )
        assert dt_bias.shape == (nheads,)
        assert A.shape == (nheads,)
        zx0, z, xBC, dt = torch.split(
            zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1
        )
        seq_idx = seq_idx.contiguous() if seq_idx is not None else None
        xBC_conv = rearrange(
            causal_conv1d_cuda.causal_conv1d_fwd(
                rearrange(xBC, "b s d -> b d s"),
                conv1d_weight,
                conv1d_bias,
                seq_idx,
                None,
                None,
                activation in ["silu", "swish"],
            ),
            "b d s -> b s d",
        )
        x, B, C = torch.split(
            xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1
        )
        x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
        B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
        C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
        z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
        if rmsnorm_weight is None:
            out, out_x, dt_out, dA_cumsum, states, final_states = (
                _mamba_chunk_scan_combined_fwd(
                    x,
                    dt,
                    A,
                    B,
                    C,
                    chunk_size=chunk_size,
                    D=D,
                    z=z,
                    dt_bias=dt_bias,
                    initial_states=initial_states,
                    seq_idx=seq_idx,
                    dt_softplus=True,
                    dt_limit=dt_limit,
                )
            )
            out = rearrange(out, "b s h p -> b s (h p)")
            rstd = None
            if d_nonssm > 0:
                out = torch.cat([_swiglu_fwd(zx0), out], dim=-1)
        else:
            out_x, _, dt_out, dA_cumsum, states, final_states = (
                _mamba_chunk_scan_combined_fwd(
                    x,
                    dt,
                    A,
                    B,
                    C,
                    chunk_size=chunk_size,
                    D=D,
                    z=None,
                    dt_bias=dt_bias,
                    initial_states=initial_states,
                    seq_idx=seq_idx,
                    dt_softplus=True,
                    dt_limit=dt_limit,
                )
            )
            # reshape input data into 2D tensor
            x_rms = rearrange(out_x, "b s h p -> (b s) (h p)")
            z_rms = rearrange(z, "b s h p -> (b s) (h p)")
            rmsnorm_weight = rmsnorm_weight.contiguous()
            if d_nonssm == 0:
                out = None
            else:
                out01 = torch.empty(
                    (batch, seqlen, d_nonssm + dim),
                    dtype=x_rms.dtype,
                    device=x_rms.device,
                )
                out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d")
                _swiglu_fwd(zx0, out=out01[..., :d_nonssm])
            out, _, rstd = _layer_norm_fwd(
                x_rms,
                rmsnorm_weight,
                None,
                rmsnorm_eps,
                z_rms,
                out=out,
                group_size=dim // ngroups,
                norm_before_gate=norm_before_gate,
                is_rms_norm=True,
            )
            if d_nonssm == 0:
                out = rearrange(out, "(b s) d -> b s d", b=batch)
            else:
                out = out01
        ctx.outproj_weight_dtype = (
            outproj_weight.dtype if outproj_weight is not None else None
        )
        if outproj_weight is not None:
            if torch.is_autocast_enabled():
                dtype = torch.get_autocast_gpu_dtype()
                out, outproj_weight = out.to(dtype), outproj_weight.to(dtype)
                outproj_bias = (
                    outproj_bias.to(dtype) if outproj_bias is not None else None
                )
            out = F.linear(out, outproj_weight, outproj_bias)
        else:
            assert outproj_bias is None
        ctx.save_for_backward(
            zxbcdt,
            conv1d_weight,
            conv1d_bias,
            out_x,
            A,
            D,
            dt_bias,
            initial_states,
            seq_idx,
            rmsnorm_weight,
            rstd,
            outproj_weight,
            outproj_bias,
        )
        ctx.dt_limit = dt_limit
        ctx.return_final_states = return_final_states
        ctx.activation = activation
        ctx.rmsnorm_eps = rmsnorm_eps
        ctx.norm_before_gate = norm_before_gate
        ctx.chunk_size = chunk_size
        ctx.headdim = headdim
        ctx.ngroups = ngroups
        return out if not return_final_states else (out, final_states)

    @staticmethod
    @custom_bwd
    def backward(ctx, dout, *args):
        (
            zxbcdt,
            conv1d_weight,
            conv1d_bias,
            out,
            A,
            D,
            dt_bias,
            initial_states,
            seq_idx,
            rmsnorm_weight,
            rstd,
            outproj_weight,
            outproj_bias,
        ) = ctx.saved_tensors
        dfinal_states = args[0] if ctx.return_final_states else None
        headdim = ctx.headdim
        nheads = D.shape[0]
        dim = nheads * headdim
        assert nheads % ctx.ngroups == 0
        dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2
        d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2
        assert d_nonssm >= 0
        recompute_output = outproj_weight is not None
        if recompute_output:
            out_recompute = torch.empty(
                *out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype
            )
            out0_recompute, out1_recompute = out_recompute.split(
                [d_nonssm, dim], dim=-1
            )
        zx0, z, xBC, dt = torch.split(
            zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
        )
        # Recompute x, B, C
        xBC_conv = rearrange(
            causal_conv1d_cuda.causal_conv1d_fwd(
                rearrange(xBC, "b s d -> b d s"),
                conv1d_weight,
                conv1d_bias,
                seq_idx,
                None,
                None,
                ctx.activation in ["silu", "swish"],
            ),
            "b d s -> b s d",
        )
        x, B, C = torch.split(
            xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
        )
        x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
        B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups)
        C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups)
        dzxbcdt = torch.empty_like(zxbcdt)
        dzx0, dz, dxBC_given, ddt_given = torch.split(
            dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
        )
        dxBC = torch.empty_like(xBC)
        dx, dB, dC = torch.split(
            dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
        )
        z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
        dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads)
        dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups)
        dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups)
        if outproj_weight is not None:
            dout_og = dout
            dout = F.linear(dout, outproj_weight.t())
        if d_nonssm > 0:
            dout0, dout = dout.split([d_nonssm, dim], dim=-1)
            _swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute)
        dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim)
        if rmsnorm_weight is None:
            dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads)
            dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states, *rest = (
                _mamba_chunk_scan_combined_bwd(
                    dout,
                    x,
                    dt,
                    A,
                    B,
                    C,
                    out,
                    ctx.chunk_size,
                    D=D,
                    z=z,
                    dt_bias=dt_bias,
                    initial_states=initial_states,
                    dfinal_states=dfinal_states,
                    seq_idx=seq_idx,
                    dt_softplus=True,
                    dt_limit=ctx.dt_limit,
                    dx=dx,
                    ddt=ddt_given,
                    dB=dB,
                    dC=dC,
                    dz=dz,
                    recompute_output=recompute_output,
                )
            )
            out_for_linear = (
                rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None
            )
            drmsnorm_weight = None
        else:
            batch = dout.shape[0]
            dy_rms = rearrange(dout, "b s h p -> (b s) (h p)")
            dz = rearrange(dz, "b l d -> (b l) d")
            x_rms = rearrange(out, "b s h p -> (b s) (h p)")
            z_rms = rearrange(z, "b s h p -> (b s) (h p)")
            out1_recompute = (
                rearrange(out1_recompute, "b s d -> (b s) d")
                if recompute_output
                else None
            )
            dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(
                dy_rms,
                x_rms,
                rmsnorm_weight,
                None,
                ctx.rmsnorm_eps,
                None,
                rstd,
                z_rms,
                group_size=dim // ctx.ngroups,
                norm_before_gate=ctx.norm_before_gate,
                is_rms_norm=True,
                recompute_output=recompute_output,
                dz=dz,
                out=out1_recompute if recompute_output else None,
            )
            out_for_linear = out_recompute if recompute_output else None
            dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim)
            dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = (
                _mamba_chunk_scan_combined_bwd(
                    dout,
                    x,
                    dt,
                    A,
                    B,
                    C,
                    out,
                    ctx.chunk_size,
                    D=D,
                    z=None,
                    dt_bias=dt_bias,
                    initial_states=initial_states,
                    dfinal_states=dfinal_states,
                    seq_idx=seq_idx,
                    dt_softplus=True,
                    dt_limit=ctx.dt_limit,
                    dx=dx,
                    ddt=ddt_given,
                    dB=dB,
                    dC=dC,
                )
            )

        if outproj_weight is not None:
            doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear)
            doutproj_bias = (
                dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None
            )
        else:
            doutproj_weight, doutproj_bias = None, None
        dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
        dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
            rearrange(xBC, "b s d -> b d s"),
            conv1d_weight,
            conv1d_bias,
            rearrange(dxBC, "b s d -> b d s"),
            seq_idx,
            None,
            None,
            dxBC_given,
            False,
            ctx.activation in ["silu", "swish"],
        )
        dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
        return (
            dzxbcdt,
            dweight,
            dbias,
            ddt_bias,
            dA,
            dD,
            None,
            dinitial_states,
            None,
            None,
            None,
            None,
            drmsnorm_weight,
            None,
            doutproj_weight,
            doutproj_bias,
            None,
            None,
            None,
        )

def mamba_split_conv1d_scan_combined(
    zxbcdt,
    conv1d_weight,
    conv1d_bias,
    dt_bias,
    A,
    D,
    chunk_size,
    initial_states=None,
    seq_idx=None,
    dt_limit=(0.0, float("inf")),
    return_final_states=False,
    activation="silu",
    rmsnorm_weight=None,
    rmsnorm_eps=1e-6,
    outproj_weight=None,
    outproj_bias=None,
    headdim=None,
    ngroups=1,
    norm_before_gate=True,
):
    """
    Argument:
        zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
        conv1d_weight: (dim + 2 * ngroups * dstate, width)
        conv1d_bias: (dim + 2 * ngroups * dstate,)
        dt_bias: (nheads,)
        A: (nheads)
        D: (nheads, headdim) or (nheads,)
        initial_states: (batch, nheads, headdim, dstate)
        seq_idx: (batch, seqlen), int32
        rmsnorm_weight: (dim,)
        outproj_weight: (out_dim, dim)
        outproj_bias: (out_dim,)
        headdim: if D is 1D, headdim must be passed in
        norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
    Return:
        out: (batch, seqlen, dim)
    """
    return MambaSplitConv1dScanCombinedFn.apply(
        zxbcdt,
        conv1d_weight,
        conv1d_bias,
        dt_bias,
        A,
        D,
        chunk_size,
        initial_states,
        seq_idx,
        dt_limit,
        return_final_states,
        activation,
        rmsnorm_weight,
        rmsnorm_eps,
        outproj_weight,
        outproj_bias,
        headdim,
        ngroups,
        norm_before_gate,
    )

def mamba_split_conv1d_scan_ref(
|
| 1798 |
-
zxbcdt,
|
| 1799 |
-
conv1d_weight,
|
| 1800 |
-
conv1d_bias,
|
| 1801 |
-
dt_bias,
|
| 1802 |
-
A,
|
| 1803 |
-
D,
|
| 1804 |
-
chunk_size,
|
| 1805 |
-
dt_limit=(0.0, float("inf")),
|
| 1806 |
-
activation="silu",
|
| 1807 |
-
rmsnorm_weight=None,
|
| 1808 |
-
rmsnorm_eps=1e-6,
|
| 1809 |
-
outproj_weight=None,
|
| 1810 |
-
outproj_bias=None,
|
| 1811 |
-
headdim=None,
|
| 1812 |
-
ngroups=1,
|
| 1813 |
-
norm_before_gate=True,
|
| 1814 |
-
):
|
| 1815 |
-
"""
|
| 1816 |
-
Argument:
|
| 1817 |
-
zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
|
| 1818 |
-
conv1d_weight: (dim + 2 * ngroups * dstate, width)
|
| 1819 |
-
conv1d_bias: (dim + 2 * ngroups * dstate,)
|
| 1820 |
-
dt_bias: (nheads,)
|
| 1821 |
-
A: (nheads)
|
| 1822 |
-
D: (nheads, headdim) or (nheads,)
|
| 1823 |
-
rmsnorm_weight: (dim,)
|
| 1824 |
-
outproj_weight: (out_dim, dim)
|
| 1825 |
-
outproj_bias: (out_dim,)
|
| 1826 |
-
headdim: if D is 1D, headdim must be passed in
|
| 1827 |
-
norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
|
| 1828 |
-
Return:
|
| 1829 |
-
out: (batch, seqlen, dim)
|
| 1830 |
-
"""
|
| 1831 |
-
if D.dim() == 1:
|
| 1832 |
-
assert headdim is not None
|
| 1833 |
-
(nheads,) = D.shape
|
| 1834 |
-
else:
|
| 1835 |
-
nheads, headdim = D.shape
|
| 1836 |
-
assert nheads % ngroups == 0
|
| 1837 |
-
batch, seqlen, _ = zxbcdt.shape
|
| 1838 |
-
dim = nheads * headdim
|
| 1839 |
-
dstate = (zxbcdt.shape[-1] - 2 * dim - nheads) // ngroups // 2
|
| 1840 |
-
assert zxbcdt.shape == (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads)
|
| 1841 |
-
assert dt_bias.shape == (nheads,)
|
| 1842 |
-
assert A.shape == (nheads,)
|
| 1843 |
-
if rmsnorm_weight is not None:
|
| 1844 |
-
assert rmsnorm_weight.shape == (dim,)
|
| 1845 |
-
z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1)
|
| 1846 |
-
xBC = rearrange(
|
| 1847 |
-
causal_conv1d_fn(
|
| 1848 |
-
rearrange(xBC, "b s d -> b d s"),
|
| 1849 |
-
conv1d_weight,
|
| 1850 |
-
conv1d_bias,
|
| 1851 |
-
activation=activation,
|
| 1852 |
-
),
|
| 1853 |
-
"b d s -> b s d",
|
| 1854 |
-
)
|
| 1855 |
-
x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
|
| 1856 |
-
x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
|
| 1857 |
-
B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
|
| 1858 |
-
C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
|
| 1859 |
-
z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
|
| 1860 |
-
out = ssd_selective_scan(
|
| 1861 |
-
x,
|
| 1862 |
-
dt.to(x.dtype),
|
| 1863 |
-
A,
|
| 1864 |
-
B,
|
| 1865 |
-
C,
|
| 1866 |
-
D=D.float(),
|
| 1867 |
-
z=z if rmsnorm_weight is None else None,
|
| 1868 |
-
dt_bias=dt_bias,
|
| 1869 |
-
dt_softplus=True,
|
| 1870 |
-
dt_limit=dt_limit,
|
| 1871 |
-
)
|
| 1872 |
-
out = rearrange(out, "b s h p -> b s (h p)")
|
| 1873 |
-
if rmsnorm_weight is not None:
|
| 1874 |
-
out = rmsnorm_fn(
|
| 1875 |
-
out,
|
| 1876 |
-
rmsnorm_weight,
|
| 1877 |
-
None,
|
| 1878 |
-
z=rearrange(z, "b l h p -> b l (h p)"),
|
| 1879 |
-
eps=rmsnorm_eps,
|
| 1880 |
-
norm_before_gate=norm_before_gate,
|
| 1881 |
-
)
|
| 1882 |
-
if outproj_weight is not None:
|
| 1883 |
-
out = F.linear(out, outproj_weight, outproj_bias)
|
| 1884 |
-
return out
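The norm_before_gate flag documented in both functions above only changes where the SiLU gate is applied relative to the RMSNorm. A plain-PyTorch restatement of the two orderings, as a minimal sketch for the ungrouped case (this is not the fused Triton kernel; weight stands in for rmsnorm_weight):

import torch
import torch.nn.functional as F

def gated_rmsnorm(x, z, weight, eps=1e-6, norm_before_gate=True):
    def rms(t):
        # RMS normalization over the last dim, scaled by the learned weight.
        return t * torch.rsqrt(t.pow(2).mean(dim=-1, keepdim=True) + eps) * weight
    if norm_before_gate:
        return rms(x) * F.silu(z)  # RMSNorm(x) * silu(z)
    return rms(x * F.silu(z))      # RMSNorm(x * silu(z))
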
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/__init__.py
DELETED
@@ -1,14 +0,0 @@
__version__ = "2.2.4"

from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from .modules.mamba_simple import Mamba
from .modules.mamba2 import Mamba2
from .models.mixer_seq_simple import MambaLMHeadModel

__all__ = [
    "selective_scan_fn",
    "mamba_inner_fn",
    "Mamba",
    "Mamba2",
    "MambaLMHeadModel",
]
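For reference, this package root is what made the top-level imports work; with a built wheel on the path, usage was simply (a sketch, nothing beyond the names exported above):

from mamba_ssm import Mamba, Mamba2, MambaLMHeadModel, selective_scan_fn
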
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py
DELETED
@@ -1,326 +0,0 @@
# Copyright (c) 2024, Tri Dao.
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from ..utils.torch import custom_bwd, custom_fwd

from einops import rearrange

from ..distributed.distributed_utils import (
    all_gather_raw,
    all_reduce,
    all_reduce_raw,
    reduce_scatter,
    reduce_scatter_raw,
)


class ParallelLinearFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
        """
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
            bias = (
                bias.to(dtype=torch.get_autocast_gpu_dtype())
                if bias is not None
                else None
            )
        weight = weight.contiguous()
        if process_group is not None and sequence_parallel:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        output = F.linear(total_x, weight, bias)
        if ctx.compute_weight_gradient:
            ctx.save_for_backward(x, weight)
        else:
            ctx.save_for_backward(weight)
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        if ctx.compute_weight_gradient:
            x, weight = ctx.saved_tensors
            if process_group is not None and sequence_parallel:
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            else:
                total_x = x
        else:
            (weight,) = ctx.saved_tensors
            total_x = None
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            grad_input = F.linear(grad_output, weight.t())
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(
                    grad_input, process_group, async_op=True
                )
        else:
            grad_input = None
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            if process_group is not None and sequence_parallel:
                handle_x.wait()
            grad_weight = torch.einsum(
                "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
            )
        else:
            grad_weight = None
        grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None


def parallel_linear_func(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True,
):
    return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)

class ColumnParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        if out_features % multiple_of:
            raise ValueError(
                f"out_features ({out_features}) must be a multiple of {multiple_of}"
            )
        multiple = out_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
        super().__init__(
            in_features,
            local_multiple * multiple_of,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
        # we do an all_gather of x before doing the matmul.
        # If not, then the input is already gathered.
        return parallel_linear_func(
            x,
            self.weight,
            self.bias,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel,
        )

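Both parallel linear classes size their local shard with the same div/mod rule seen above. A standalone restatement with a worked example (the sizes are hypothetical, chosen to show an uneven split):

def local_features(features, multiple_of, world_size, rank):
    multiple = features // multiple_of
    div, mod = divmod(multiple, world_size)
    # The first `mod` ranks get div + 1 multiples, the remaining ranks get div.
    return (div + (1 if rank < mod else 0)) * multiple_of

# e.g. 10 features split over 3 ranks -> shards of [4, 3, 3], summing back to 10
assert [local_features(10, 1, 3, r) for r in range(3)] == [4, 3, 3]
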
class RowParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        rank = torch.distributed.get_rank(process_group)
        if in_features % multiple_of:
            raise ValueError(
                f"in_features ({in_features}) must be a multiple of {multiple_of}"
            )
        multiple = in_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
        # Only rank 0 will have bias
        super().__init__(
            local_multiple * multiple_of,
            out_features,
            bias=bias and rank == 0,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        """
        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
        a reduce_scatter of the result.
        """
        out = parallel_linear_func(x, self.weight, self.bias)
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)


class VocabParallelEmbedding(nn.Embedding):
    def __init__(
        self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs
    ):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if num_embeddings % world_size != 0:
                raise ValueError(
                    f"num_embeddings ({num_embeddings}) must be divisible by "
                    f"world_size ({world_size})"
                )
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")
        else:
            world_size = 1
        super().__init__(
            num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        if self.process_group is None:
            return super().forward(input)
        else:
            rank = torch.distributed.get_rank(self.process_group)
            vocab_size = self.num_embeddings
            vocab_start_index, vocab_end_index = (
                rank * vocab_size,
                (rank + 1) * vocab_size,
            )
            # Create a mask of valid vocab ids (1 means it needs to be masked).
            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
            input = input - vocab_start_index
            input[input_ids_mask] = 0
            embeddings = super().forward(input)
            embeddings[input_ids_mask] = 0.0
            return embeddings

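VocabParallelEmbedding above shards the vocabulary by rank and relies on masking plus a later reduction to recover the full lookup. A numeric sketch of the trick, runnable in a single process with hypothetical sizes:

import torch

vocab_per_rank, rank = 4, 1                      # this rank owns global ids 4..7
ids = torch.tensor([1, 5, 7])                    # global token ids
start, end = rank * vocab_per_rank, (rank + 1) * vocab_per_rank
mask = (ids < start) | (ids >= end)              # True -> owned by another rank
local = (ids - start).masked_fill(mask, 0)       # safe local indices: [0, 1, 3]
emb = torch.nn.Embedding(vocab_per_rank, 2)(local)
emb[mask] = 0.0                                  # zero the rows this rank doesn't own
# Summing emb across ranks (all_reduce / reduce_scatter) yields the full lookup.
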
class ColumnParallelEmbedding(nn.Embedding):
    def __init__(
        self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs
    ):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if embedding_dim % world_size != 0:
                raise ValueError(
                    f"embedding_dim ({embedding_dim}) must be divisible by "
                    f"world_size ({world_size})"
                )
        else:
            world_size = 1
        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)


class ParallelEmbeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        process_group,
        padding_idx=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size,
            embed_dim,
            padding_idx=padding_idx,
            process_group=process_group,
            **factory_kwargs,
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = ColumnParallelEmbedding(
                max_position_embeddings,
                embed_dim,
                process_group=process_group,
                **factory_kwargs,
            )

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(
                    seqlen, dtype=torch.long, device=input_ids.device
                )
            position_embeddings = self.position_embeddings(position_ids)
            if world_size <= 1:
                embeddings = embeddings + position_embeddings
            else:
                partition_dim = self.position_embeddings.embedding_dim
                rank = torch.distributed.get_rank(self.process_group)
                embeddings[
                    ..., rank * partition_dim : (rank + 1) * partition_dim
                ] += position_embeddings
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, "b s d -> (b s) d")
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return (
            embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
        )
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py
DELETED
@@ -1,338 +0,0 @@
# Copyright (c) 2023, Albert Gu, Tri Dao.

import math
from functools import partial
import json
import os
import copy

from collections import namedtuple

import torch
import torch.nn as nn

from .config_mamba import MambaConfig
from ..modules.mamba_simple import Mamba
from ..modules.mamba2 import Mamba2
from ..modules.mha import MHA
from ..modules.mlp import GatedMLP
from ..modules.block import Block
from ..utils.generation import GenerationMixin
from ..utils.hf import load_config_hf, load_state_dict_hf

try:
    from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


def create_block(
    d_model,
    d_intermediate,
    ssm_cfg=None,
    attn_layer_idx=None,
    attn_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    if ssm_cfg is None:
        ssm_cfg = {}
    if attn_layer_idx is None:
        attn_layer_idx = []
    if attn_cfg is None:
        attn_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    if layer_idx not in attn_layer_idx:
        # Create a copy of the config to modify
        ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
        ssm_layer = ssm_cfg.pop("layer", "Mamba1")
        if ssm_layer not in ["Mamba1", "Mamba2"]:
            raise ValueError(
                f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2"
            )
        mixer_cls = partial(
            Mamba2 if ssm_layer == "Mamba2" else Mamba,
            layer_idx=layer_idx,
            **ssm_cfg,
            **factory_kwargs,
        )
    else:
        mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    if d_intermediate == 0:
        mlp_cls = nn.Identity
    else:
        mlp_cls = partial(
            GatedMLP,
            hidden_features=d_intermediate,
            out_features=d_model,
            **factory_kwargs,
        )
    block = Block(
        d_model,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block

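A hypothetical call to create_block, wiring a Mamba2 mixer with a gated MLP (the sizes are illustrative; "layer" is popped from ssm_cfg above and defaults to "Mamba1"):

import torch

block = create_block(
    d_model=768,
    d_intermediate=3072,          # 0 would mean no MLP (nn.Identity)
    ssm_cfg={"layer": "Mamba2"},
    layer_idx=0,
    device="cuda",
    dtype=torch.bfloat16,
)
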
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)

|
| 124 |
-
def __init__(
|
| 125 |
-
self,
|
| 126 |
-
d_model: int,
|
| 127 |
-
n_layer: int,
|
| 128 |
-
d_intermediate: int,
|
| 129 |
-
vocab_size: int,
|
| 130 |
-
ssm_cfg=None,
|
| 131 |
-
attn_layer_idx=None,
|
| 132 |
-
attn_cfg=None,
|
| 133 |
-
norm_epsilon: float = 1e-5,
|
| 134 |
-
rms_norm: bool = False,
|
| 135 |
-
initializer_cfg=None,
|
| 136 |
-
fused_add_norm=False,
|
| 137 |
-
residual_in_fp32=False,
|
| 138 |
-
device=None,
|
| 139 |
-
dtype=None,
|
| 140 |
-
) -> None:
|
| 141 |
-
factory_kwargs = {"device": device, "dtype": dtype}
|
| 142 |
-
super().__init__()
|
| 143 |
-
self.residual_in_fp32 = residual_in_fp32
|
| 144 |
-
|
| 145 |
-
self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
|
| 146 |
-
|
| 147 |
-
# We change the order of residual and layer norm:
|
| 148 |
-
# Instead of LN -> Attn / MLP -> Add, we do:
|
| 149 |
-
# Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
|
| 150 |
-
# the main branch (output of MLP / Mixer). The model definition is unchanged.
|
| 151 |
-
# This is for performance reason: we can fuse add + layer_norm.
|
| 152 |
-
self.fused_add_norm = fused_add_norm
|
| 153 |
-
if self.fused_add_norm:
|
| 154 |
-
if layer_norm_fn is None or rms_norm_fn is None:
|
| 155 |
-
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
|
| 156 |
-
|
| 157 |
-
self.layers = nn.ModuleList(
|
| 158 |
-
[
|
| 159 |
-
create_block(
|
| 160 |
-
d_model,
|
| 161 |
-
d_intermediate=d_intermediate,
|
| 162 |
-
ssm_cfg=ssm_cfg,
|
| 163 |
-
attn_layer_idx=attn_layer_idx,
|
| 164 |
-
attn_cfg=attn_cfg,
|
| 165 |
-
norm_epsilon=norm_epsilon,
|
| 166 |
-
rms_norm=rms_norm,
|
| 167 |
-
residual_in_fp32=residual_in_fp32,
|
| 168 |
-
fused_add_norm=fused_add_norm,
|
| 169 |
-
layer_idx=i,
|
| 170 |
-
**factory_kwargs,
|
| 171 |
-
)
|
| 172 |
-
for i in range(n_layer)
|
| 173 |
-
]
|
| 174 |
-
)
|
| 175 |
-
|
| 176 |
-
self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
|
| 177 |
-
d_model, eps=norm_epsilon, **factory_kwargs
|
| 178 |
-
)
|
| 179 |
-
|
| 180 |
-
self.apply(
|
| 181 |
-
partial(
|
| 182 |
-
_init_weights,
|
| 183 |
-
n_layer=n_layer,
|
| 184 |
-
**(initializer_cfg if initializer_cfg is not None else {}),
|
| 185 |
-
n_residuals_per_layer=(
|
| 186 |
-
1 if d_intermediate == 0 else 2
|
| 187 |
-
), # 2 if we have MLP
|
| 188 |
-
)
|
| 189 |
-
)
|
| 190 |
-
|
| 191 |
-
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 192 |
-
return {
|
| 193 |
-
i: layer.allocate_inference_cache(
|
| 194 |
-
batch_size, max_seqlen, dtype=dtype, **kwargs
|
| 195 |
-
)
|
| 196 |
-
for i, layer in enumerate(self.layers)
|
| 197 |
-
}
|
| 198 |
-
|
| 199 |
-
def forward(self, input_ids, inference_params=None, **mixer_kwargs):
|
| 200 |
-
hidden_states = self.embedding(input_ids)
|
| 201 |
-
residual = None
|
| 202 |
-
for layer in self.layers:
|
| 203 |
-
hidden_states, residual = layer(
|
| 204 |
-
hidden_states,
|
| 205 |
-
residual,
|
| 206 |
-
inference_params=inference_params,
|
| 207 |
-
**mixer_kwargs,
|
| 208 |
-
)
|
| 209 |
-
if not self.fused_add_norm:
|
| 210 |
-
residual = (
|
| 211 |
-
(hidden_states + residual) if residual is not None else hidden_states
|
| 212 |
-
)
|
| 213 |
-
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
|
| 214 |
-
else:
|
| 215 |
-
# Set prenorm=False here since we don't need the residual
|
| 216 |
-
hidden_states = layer_norm_fn(
|
| 217 |
-
hidden_states,
|
| 218 |
-
self.norm_f.weight,
|
| 219 |
-
self.norm_f.bias,
|
| 220 |
-
eps=self.norm_f.eps,
|
| 221 |
-
residual=residual,
|
| 222 |
-
prenorm=False,
|
| 223 |
-
residual_in_fp32=self.residual_in_fp32,
|
| 224 |
-
is_rms_norm=isinstance(self.norm_f, RMSNorm),
|
| 225 |
-
)
|
| 226 |
-
return hidden_states
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
class MambaLMHeadModel(nn.Module, GenerationMixin):
|
| 230 |
-
|
| 231 |
-
def __init__(
|
| 232 |
-
self,
|
| 233 |
-
config: MambaConfig,
|
| 234 |
-
initializer_cfg=None,
|
| 235 |
-
device=None,
|
| 236 |
-
dtype=None,
|
| 237 |
-
) -> None:
|
| 238 |
-
self.config = config
|
| 239 |
-
d_model = config.d_model
|
| 240 |
-
n_layer = config.n_layer
|
| 241 |
-
d_intermediate = config.d_intermediate
|
| 242 |
-
vocab_size = config.vocab_size
|
| 243 |
-
ssm_cfg = config.ssm_cfg
|
| 244 |
-
attn_layer_idx = config.attn_layer_idx
|
| 245 |
-
attn_cfg = config.attn_cfg
|
| 246 |
-
rms_norm = config.rms_norm
|
| 247 |
-
residual_in_fp32 = config.residual_in_fp32
|
| 248 |
-
fused_add_norm = config.fused_add_norm
|
| 249 |
-
pad_vocab_size_multiple = config.pad_vocab_size_multiple
|
| 250 |
-
factory_kwargs = {"device": device, "dtype": dtype}
|
| 251 |
-
|
| 252 |
-
super().__init__()
|
| 253 |
-
if vocab_size % pad_vocab_size_multiple != 0:
|
| 254 |
-
vocab_size += pad_vocab_size_multiple - (
|
| 255 |
-
vocab_size % pad_vocab_size_multiple
|
| 256 |
-
)
|
| 257 |
-
self.backbone = MixerModel(
|
| 258 |
-
d_model=d_model,
|
| 259 |
-
n_layer=n_layer,
|
| 260 |
-
d_intermediate=d_intermediate,
|
| 261 |
-
vocab_size=vocab_size,
|
| 262 |
-
ssm_cfg=ssm_cfg,
|
| 263 |
-
attn_layer_idx=attn_layer_idx,
|
| 264 |
-
attn_cfg=attn_cfg,
|
| 265 |
-
rms_norm=rms_norm,
|
| 266 |
-
initializer_cfg=initializer_cfg,
|
| 267 |
-
fused_add_norm=fused_add_norm,
|
| 268 |
-
residual_in_fp32=residual_in_fp32,
|
| 269 |
-
**factory_kwargs,
|
| 270 |
-
)
|
| 271 |
-
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
|
| 272 |
-
|
| 273 |
-
# Initialize weights and apply final processing
|
| 274 |
-
self.apply(
|
| 275 |
-
partial(
|
| 276 |
-
_init_weights,
|
| 277 |
-
n_layer=n_layer,
|
| 278 |
-
**(initializer_cfg if initializer_cfg is not None else {}),
|
| 279 |
-
)
|
| 280 |
-
)
|
| 281 |
-
self.tie_weights()
|
| 282 |
-
|
| 283 |
-
def tie_weights(self):
|
| 284 |
-
if self.config.tie_embeddings:
|
| 285 |
-
self.lm_head.weight = self.backbone.embedding.weight
|
| 286 |
-
|
| 287 |
-
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 288 |
-
return self.backbone.allocate_inference_cache(
|
| 289 |
-
batch_size, max_seqlen, dtype=dtype, **kwargs
|
| 290 |
-
)
|
| 291 |
-
|
| 292 |
-
def forward(
|
| 293 |
-
self,
|
| 294 |
-
input_ids,
|
| 295 |
-
position_ids=None,
|
| 296 |
-
inference_params=None,
|
| 297 |
-
num_last_tokens=0,
|
| 298 |
-
**mixer_kwargs,
|
| 299 |
-
):
|
| 300 |
-
"""
|
| 301 |
-
"position_ids" is just to be compatible with Transformer generation. We don't use it.
|
| 302 |
-
num_last_tokens: if > 0, only return the logits for the last n tokens
|
| 303 |
-
"""
|
| 304 |
-
hidden_states = self.backbone(
|
| 305 |
-
input_ids, inference_params=inference_params, **mixer_kwargs
|
| 306 |
-
)
|
| 307 |
-
if num_last_tokens > 0:
|
| 308 |
-
hidden_states = hidden_states[:, -num_last_tokens:]
|
| 309 |
-
lm_logits = self.lm_head(hidden_states)
|
| 310 |
-
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
| 311 |
-
return CausalLMOutput(logits=lm_logits)
|
| 312 |
-
|
| 313 |
-
@classmethod
|
| 314 |
-
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
|
| 315 |
-
config_data = load_config_hf(pretrained_model_name)
|
| 316 |
-
config = MambaConfig(**config_data)
|
| 317 |
-
model = cls(config, device=device, dtype=dtype, **kwargs)
|
| 318 |
-
model.load_state_dict(
|
| 319 |
-
load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
|
| 320 |
-
)
|
| 321 |
-
return model
|
| 322 |
-
|
| 323 |
-
def save_pretrained(self, save_directory):
|
| 324 |
-
"""
|
| 325 |
-
Minimal implementation of save_pretrained for MambaLMHeadModel.
|
| 326 |
-
Save the model and its configuration file to a directory.
|
| 327 |
-
"""
|
| 328 |
-
# Ensure save_directory exists
|
| 329 |
-
os.makedirs(save_directory, exist_ok=True)
|
| 330 |
-
|
| 331 |
-
# Save the model's state_dict
|
| 332 |
-
model_path = os.path.join(save_directory, "pytorch_model.bin")
|
| 333 |
-
torch.save(self.state_dict(), model_path)
|
| 334 |
-
|
| 335 |
-
# Save the configuration of the model
|
| 336 |
-
config_path = os.path.join(save_directory, "config.json")
|
| 337 |
-
with open(config_path, "w") as f:
|
| 338 |
-
json.dump(self.config.__dict__, f, indent=4)
|
|
|
|
|
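A hypothetical round trip with the class above; the checkpoint name is illustrative only, and from_pretrained/save_pretrained are the methods defined in this file:

import torch

model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m", device="cuda", dtype=torch.float16
)
input_ids = torch.randint(0, model.config.vocab_size, (1, 32), device="cuda")
logits = model(input_ids, num_last_tokens=1).logits  # (1, 1, padded vocab_size)
model.save_pretrained("/tmp/mamba-ckpt")
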
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/selective_scan_interface.py
DELETED
@@ -1,659 +0,0 @@
# Copyright (c) 2023, Tri Dao, Albert Gu.

import torch
import torch.nn.functional as F
from ..utils.torch import custom_fwd, custom_bwd

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn
    import causal_conv1d_cuda
except ImportError:
    causal_conv1d_fn = None
    causal_conv1d_cuda = None

from .triton.layer_norm import _layer_norm_fwd

from .._ops import ops


class SelectiveScanFn(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        u,
        delta,
        A,
        B,
        C,
        D=None,
        z=None,
        delta_bias=None,
        delta_softplus=False,
        return_last_state=False,
    ):
        if u.stride(-1) != 1:
            u = u.contiguous()
        if delta.stride(-1) != 1:
            delta = delta.contiguous()
        if D is not None:
            D = D.contiguous()
        if B.stride(-1) != 1:
            B = B.contiguous()
        if C.stride(-1) != 1:
            C = C.contiguous()
        if z is not None and z.stride(-1) != 1:
            z = z.contiguous()
        if B.dim() == 3:
            B = rearrange(B, "b dstate l -> b 1 dstate l")
            ctx.squeeze_B = True
        if C.dim() == 3:
            C = rearrange(C, "b dstate l -> b 1 dstate l")
            ctx.squeeze_C = True
        out, x, *rest = ops.selective_scan_fwd(
            u, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.has_z = z is not None
        last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
        if not ctx.has_z:
            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
            return out if not return_last_state else (out, last_state)
        else:
            ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
            out_z = rest[0]
            return out_z if not return_last_state else (out_z, last_state)

    @staticmethod
    def backward(ctx, dout, *args):
        if not ctx.has_z:
            u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
            z = None
            out = None
        else:
            u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        # Here we just pass in None and dz will be allocated in the C++ code.
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = ops.selective_scan_bwd(
            u,
            delta,
            A,
            B,
            C,
            D,
            z,
            delta_bias,
            dout,
            x,
            out,
            None,
            ctx.delta_softplus,
            False,  # option to recompute out_z, not used here
        )
        dz = rest[0] if ctx.has_z else None
        dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
        dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
        return (
            du,
            ddelta,
            dA,
            dB,
            dC,
            dD if D is not None else None,
            dz,
            ddelta_bias if delta_bias is not None else None,
            None,
            None,
        )


def rms_norm_forward(
    x,
    weight,
    bias,
    eps=1e-6,
    is_rms_norm=True,
):
    # x (b l) d
    if x.stride(-1) != 1:
        x = x.contiguous()
    weight = weight.contiguous()
    if bias is not None:
        bias = bias.contiguous()
    y = _layer_norm_fwd(
        x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm
    )[0]
    # y (b l) d
    return y


def selective_scan_fn(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    delta_softplus=False,
    return_last_state=False,
):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return SelectiveScanFn.apply(
        u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state
    )

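A shape sketch for the wrapper above, with hypothetical sizes; it requires the compiled CUDA extension, and uses the variable-B/C layout documented in selective_scan_ref below:

import torch

batch, dim, dstate, seqlen = 2, 64, 16, 128
u = torch.randn(batch, dim, seqlen, device="cuda")
delta = torch.rand(batch, dim, seqlen, device="cuda")
A = -torch.rand(dim, dstate, device="cuda")            # real, negative for stability
B = torch.randn(batch, dstate, seqlen, device="cuda")  # variable B: (B N L)
C = torch.randn(batch, dstate, seqlen, device="cuda")  # variable C: (B N L)

out, last_state = selective_scan_fn(
    u, delta, A, B, C, delta_softplus=True, return_last_state=True
)
assert out.shape == (batch, dim, seqlen)
assert last_state.shape == (batch, dim, dstate)
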
def selective_scan_ref(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    delta_softplus=False,
    return_last_state=False,
):
    """
    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    out: r(B D L)
    last_state (optional): r(B D dstate) or c(B D dstate)
    """
    dtype_in = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    if A.is_complex():
        if is_variable_B:
            B = torch.view_as_complex(
                rearrange(B.float(), "... (L two) -> ... L two", two=2)
            )
        if is_variable_C:
            C = torch.view_as_complex(
                rearrange(C.float(), "... (L two) -> ... L two", two=2)
            )
    else:
        B = B.float()
        C = C.float()
    x = A.new_zeros((batch, dim, dstate))
    ys = []
    deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
    if not is_variable_B:
        deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
    else:
        if B.dim() == 3:
            deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
        else:
            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
            deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
    if is_variable_C and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
    last_state = None
    for i in range(u.shape[2]):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = torch.einsum("bdn,dn->bd", x, C)
        else:
            if C.dim() == 3:
                y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
            else:
                y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
        if i == u.shape[2] - 1:
            last_state = x
        if y.is_complex():
            y = y.real * 2
        ys.append(y)
    y = torch.stack(ys, dim=2)  # (batch dim L)
    out = y if D is None else y + u * rearrange(D, "d -> d 1")
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)
    return out if not return_last_state else (out, last_state)

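In the notation of the docstring above (after delta_bias and the optional softplus have been applied to delta), the loop in selective_scan_ref computes, per batch and channel, the discretized state-space recurrence; this is a summary of the code, not an additional definition:

\bar{A}_t = \exp(\Delta_t A), \qquad x_t = \bar{A}_t \odot x_{t-1} + (\Delta_t B_t)\, u_t

y_t = \langle C_t,\, x_t \rangle + D\, u_t, \qquad \mathrm{out}_t = y_t \cdot \mathrm{silu}(z_t) \ \text{(when z is given)}
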
class MambaInnerFn(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        xz,
        conv1d_weight,
        conv1d_bias,
        x_proj_weight,
        delta_proj_weight,
        out_proj_weight,
        out_proj_bias,
        A,
        B=None,
        C=None,
        D=None,
        delta_bias=None,
        B_proj_bias=None,
        C_proj_bias=None,
        delta_softplus=True,
        checkpoint_lvl=1,
        b_rms_weight=None,
        c_rms_weight=None,
        dt_rms_weight=None,
        b_c_dt_rms_eps=1e-6,
    ):
        """
        xz: (batch, dim, seqlen)
        """
        assert (
            causal_conv1d_cuda is not None
        ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if torch.is_autocast_enabled():
            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            delta_proj_weight = delta_proj_weight.to(
                dtype=torch.get_autocast_gpu_dtype()
            )
            out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_bias = (
                out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
                if out_proj_bias is not None
                else None
            )
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
            x, conv1d_weight, conv1d_bias, None, None, None, True
        )
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(
            rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight
        )  # (bl d)
        delta = rearrange(
            delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
        )
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:  # variable B
            B = x_dbl[:, delta_rank : delta_rank + d_state]  # (bl dstate)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                B = rearrange(
                    B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
                ).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl dstate)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                C = rearrange(
                    C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
                ).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()

        if b_rms_weight is not None:
            B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
            B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
            B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
        if c_rms_weight is not None:
            C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
            C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
            C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
        if dt_rms_weight is not None:
            delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
            delta = rms_norm_forward(
                delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps
            )
            delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()

        out, scan_intermediates, out_z = ops.selective_scan_fwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.out_proj_bias_is_None = out_proj_bias is None
        ctx.checkpoint_lvl = checkpoint_lvl
        ctx.b_rms_weight = b_rms_weight
        ctx.c_rms_weight = c_rms_weight
        ctx.dt_rms_weight = dt_rms_weight
        ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
        if (
            checkpoint_lvl >= 1
        ):  # Will recompute conv1d_out and delta in the backward pass
            conv1d_out, delta = None, None
        ctx.save_for_backward(
            xz,
            conv1d_weight,
            conv1d_bias,
            x_dbl,
            x_proj_weight,
            delta_proj_weight,
            out_proj_weight,
            conv1d_out,
            delta,
            A,
            B,
            C,
            D,
            delta_bias,
            scan_intermediates,
            b_rms_weight,
            c_rms_weight,
            dt_rms_weight,
            out,
        )
        return F.linear(
            rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias
        )

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        # dout: (batch, seqlen, dim)
        assert (
            causal_conv1d_cuda is not None
        ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
        (
            xz,
            conv1d_weight,
            conv1d_bias,
            x_dbl,
            x_proj_weight,
            delta_proj_weight,
            out_proj_weight,
            conv1d_out,
            delta,
            A,
            B,
            C,
            D,
            delta_bias,
            scan_intermediates,
            b_rms_weight,
            c_rms_weight,
            dt_rms_weight,
            out,
        ) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if ctx.checkpoint_lvl == 1:
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
                x, conv1d_weight, conv1d_bias, None, None, None, True
            )
            delta = rearrange(
                delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
            )
            if dt_rms_weight is not None:
                delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
                delta = rms_norm_forward(
                    delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps
                )
                delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
            if b_rms_weight is not None:
                # Recompute & RMSNorm B
                B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
                B = rms_norm_forward(B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps)
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            if c_rms_weight is not None:
                # Recompute & RMSNorm C
                C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
                C = rms_norm_forward(C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps)
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()

        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
        dx, dz = dxz.chunk(2, dim=1)
        dout = rearrange(dout, "b l e -> e (b l)")
        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = (
            ops.selective_scan_bwd(
                conv1d_out,
                delta,
                A,
                B,
                C,
                D,
                z,
                delta_bias,
                dout_y,
                scan_intermediates,
                out,
                dz,
                ctx.delta_softplus,
                True,  # option to recompute out_z
            )
        )
        dout_proj_weight = torch.einsum(
            "eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")
        )
        dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
        dD = dD if D is not None else None
        dx_dbl = torch.empty_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(
                    dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
                ).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank : delta_rank + d_state] = dB  # (bl d)
            dB = None
        dC_proj_bias = None
        if ctx.is_variable_C:
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(
                    dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
                ).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC  # (bl d)
            dC = None
        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum(
            "Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")
        )
        dconv1d_out = torch.addmm(
            dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out
        )
        dconv1d_out = rearrange(
            dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]
        )
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
            x,
            conv1d_weight,
            conv1d_bias,
            dconv1d_out,
            None,
            None,
            None,
            dx,
            False,
            True,
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        return (
            dxz,
            dconv1d_weight,
            dconv1d_bias,
            dx_proj_weight,
            ddelta_proj_weight,
            dout_proj_weight,
            dout_proj_bias,
            dA,
            dB,
            dC,
            dD,
            ddelta_bias if delta_bias is not None else None,
            dB_proj_bias,
            dC_proj_bias,
            # The 6 Nones below are the (non-)gradients for delta_softplus,
            # checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps.
            None,
            None,
            None,
            None,
            None,
            None,
        )


def mamba_inner_fn(
    xz,
    conv1d_weight,
    conv1d_bias,
    x_proj_weight,
    delta_proj_weight,
    out_proj_weight,
    out_proj_bias,
    A,
    B=None,
    C=None,
    D=None,
    delta_bias=None,
    B_proj_bias=None,
    C_proj_bias=None,
    delta_softplus=True,
    checkpoint_lvl=1,
    b_rms_weight=None,
    c_rms_weight=None,
    dt_rms_weight=None,
    b_c_dt_rms_eps=1e-6,
):
    return MambaInnerFn.apply(
        xz,
        conv1d_weight,
        conv1d_bias,
        x_proj_weight,
        delta_proj_weight,
        out_proj_weight,
        out_proj_bias,
        A,
        B,
        C,
        D,
        delta_bias,
        B_proj_bias,
        C_proj_bias,
        delta_softplus,
        checkpoint_lvl,
        b_rms_weight,
        c_rms_weight,
        dt_rms_weight,
        b_c_dt_rms_eps,
    )


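The tensor layouts mamba_inner_fn expects are easiest to read off from MambaInnerFn.forward above; the summary below is a sketch, where names like d_inner, dt_rank, and d_conv are illustrative stand-ins for how a Mamba block typically wires its Linear/Conv1d weights in:

# Assumed layouts (illustrative, not from the upstream file):
#   xz:                (batch, 2 * d_inner, seqlen)      # x and z stacked along dim 1
#   conv1d_weight:     (d_inner, 1, d_conv)              # rearranged "d 1 w -> d w" internally
#   x_proj_weight:     (dt_rank + 2 * d_state, d_inner)  # x_dbl splits into delta, B, C slices
#   delta_proj_weight: (d_inner, dt_rank)                # delta_rank is read from shape[1]
#   out_proj_weight:   (d_model, d_inner)
#   A:                 (d_inner, d_state)                # complex A stores 2 floats per state
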
def mamba_inner_ref(
    xz,
    conv1d_weight,
    conv1d_bias,
    x_proj_weight,
    delta_proj_weight,
    out_proj_weight,
    out_proj_bias,
    A,
    B=None,
    C=None,
    D=None,
    delta_bias=None,
    B_proj_bias=None,
    C_proj_bias=None,
    delta_softplus=True,
):
    assert (
        causal_conv1d_fn is not None
    ), "causal_conv1d_fn is not available. Please install causal-conv1d."
    L = xz.shape[-1]
    delta_rank = delta_proj_weight.shape[1]
    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
    x, z = xz.chunk(2, dim=1)
    x = causal_conv1d_fn(
        x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu"
    )
    # We're being very careful here about the layout, to avoid extra transposes.
    # We want delta to have d as the slowest moving dimension
    # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
    x_dbl = F.linear(rearrange(x, "b d l -> (b l) d"), x_proj_weight)  # (bl d)
    delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
    delta = rearrange(delta, "d (b l) -> b d l", l=L)
    if B is None:  # variable B
        B = x_dbl[:, delta_rank : delta_rank + d_state]  # (bl dstate)
        if B_proj_bias is not None:
            B = B + B_proj_bias.to(dtype=B.dtype)
        if not A.is_complex():
            B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            B = rearrange(
                B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
            ).contiguous()
    if C is None:  # variable C
        C = x_dbl[:, -d_state:]  # (bl dstate)
        if C_proj_bias is not None:
            C = C + C_proj_bias.to(dtype=C.dtype)
        if not A.is_complex():
            C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            C = rearrange(
                C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
            ).contiguous()
    y = selective_scan_fn(
        # Pass the flag through instead of hardcoding True, so the parameter is honored.
        x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=delta_softplus
    )
    return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)

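A quick way to sanity-check the fused autograd path against this reference is to run both on the same random weights. The sketch below is illustrative (sizes are arbitrary) and assumes a CUDA device with the causal-conv1d and selective-scan extensions built:

import torch

batch, d_inner, d_state, dt_rank, d_model, seqlen, d_conv = 2, 64, 16, 4, 32, 128, 4
dev = "cuda"
xz = torch.randn(batch, 2 * d_inner, seqlen, device=dev)
conv1d_weight = torch.randn(d_inner, 1, d_conv, device=dev)
conv1d_bias = torch.randn(d_inner, device=dev)
x_proj_weight = torch.randn(dt_rank + 2 * d_state, d_inner, device=dev)
delta_proj_weight = torch.randn(d_inner, dt_rank, device=dev)
out_proj_weight = torch.randn(d_model, d_inner, device=dev)
A = -torch.rand(d_inner, d_state, device=dev)
D = torch.randn(d_inner, device=dev)

out_fused = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight,
                           delta_proj_weight, out_proj_weight, None, A, D=D)
out_ref = mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight,
                          delta_proj_weight, out_proj_weight, None, A, D=D)
print((out_fused - out_ref).abs().max())  # small but nonzero: different kernels
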
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/triton/layer_norm.py DELETED
@@ -1,1166 +0,0 @@
# Copyright (c) 2024, Tri Dao.
# Implement dropout + residual + layer_norm / rms_norm.

# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.

import math
import warnings

import torch
import torch.nn.functional as F
from ...utils.torch import custom_bwd, custom_fwd

import triton
import triton.language as tl


def layer_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = bias.float() if bias is not None else None
        residual = residual.float() if residual is not None else residual
        x1 = x1.float() if x1 is not None else None
        weight1 = weight1.float() if weight1 is not None else None
        bias1 = bias1.float() if bias1 is not None else None
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
        if x1 is not None:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    out = F.layer_norm(
        x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
    ).to(dtype)
    if weight1 is None:
        return out if not prenorm else (out, x)
    else:
        out1 = F.layer_norm(
            x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
        ).to(dtype)
        return (out, out1) if not prenorm else (out, out1, x)


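As a reading aid, a small illustrative sketch of the prenorm/residual contract above: with prenorm=True the function returns the normalized output together with the updated residual stream (x + residual), which the caller feeds into the next block:

import torch

x = torch.randn(4, 512)
residual = torch.randn(4, 512)
weight = torch.ones(512)
bias = torch.zeros(512)
out, new_residual = layer_norm_ref(x, weight, bias, residual=residual, prenorm=True)
assert torch.allclose(new_residual, x + residual)  # residual stream is x + residual
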
def rms_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = bias.float() if bias is not None else None
        residual = residual.float() if residual is not None else residual
        x1 = x1.float() if x1 is not None else None
        weight1 = weight1.float() if weight1 is not None else None
        bias1 = bias1.float() if bias1 is not None else None
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
        if x1 is not None:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
    out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(
        dtype
    )
    if weight1 is None:
        return out if not prenorm else (out, x)
    else:
        out1 = (
            (x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)
        ).to(dtype)
        return (out, out1) if not prenorm else (out, out1, x)


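The rstd line above is the whole of RMSNorm: out = x / sqrt(mean(x^2) + eps) * weight, plus an optional bias. A quick illustrative check of that identity:

import torch

x = torch.randn(4, 512)
weight = torch.rand(512)
out = rms_norm_ref(x, weight, None, eps=1e-6)
manual = x / torch.sqrt(x.square().mean(dim=-1, keepdim=True) + 1e-6) * weight
assert torch.allclose(out, manual)
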
def config_prune(configs):
    if torch.version.hip:
        try:
            # set warp size based on gcn architecture
            gcn_arch_name = torch.cuda.get_device_properties(0).gcnArchName
            if "gfx10" in gcn_arch_name or "gfx11" in gcn_arch_name:
                # radeon
                warp_size = 32
            else:
                # instinct
                warp_size = 64
        except AttributeError as e:
            # fall back to crude method to set warp size
            device_name = torch.cuda.get_device_properties(0).name
            if "instinct" in device_name.lower():
                warp_size = 64
            else:
                warp_size = 32
            warnings.warn(
                f"{e}, warp size set to {warp_size} based on device name: {device_name}",
                UserWarning,
            )
    else:
        # cuda
        warp_size = 32

    max_block_sz = 1024
    max_num_warps = max_block_sz // warp_size
    pruned_configs = [config for config in configs if config.num_warps <= max_num_warps]
    return pruned_configs


configs_autotune = [
    triton.Config({}, num_warps=1),
    triton.Config({}, num_warps=2),
    triton.Config({}, num_warps=4),
    triton.Config({}, num_warps=8),
    triton.Config({}, num_warps=16),
    triton.Config({}, num_warps=32),
]

pruned_configs_autotune = config_prune(configs_autotune)


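A worked example of the pruning rule (illustrative, not upstream code): with the Instinct warp size of 64, max_num_warps = 1024 // 64 = 16, so only the num_warps=32 config above is dropped; with warp size 32, all six configs survive.

assert 1024 // 64 == 16  # Instinct: num_warps=32 pruned, num_warps<=16 kept
assert 1024 // 32 == 32  # CUDA / Radeon: all six configs kept
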
@triton.autotune(
    configs=pruned_configs_autotune,
    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    X1,
    W1,
    B1,
    Y1,
    RESIDUAL_OUT,  # pointer to the residual
    ROWSCALE,
    SEEDS,  # Dropout seeds for each row
    DROPOUT_MASK,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    stride_x1_row,
    stride_y1_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,  # Dropout probability
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    STORE_DROPOUT_MASK: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_X1: tl.constexpr,
    HAS_W1: tl.constexpr,
    HAS_B1: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    if HAS_X1:
        X1 += row * stride_x1_row
    if HAS_W1:
        Y1 += row * stride_y1_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_ROWSCALE:
        rowscale = tl.load(ROWSCALE + row).to(tl.float32)
        x *= rowscale
    if HAS_DROPOUT:
        # Compute dropout mask
        # 7 rounds is good enough, and reduces register pressure
        keep_mask = (
            tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
        )
        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
        if STORE_DROPOUT_MASK:
            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
    if HAS_X1:
        x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
            x1 *= rowscale
        if HAS_DROPOUT:
            # Compute dropout mask
            # 7 rounds is good enough, and reduces register pressure
            keep_mask = (
                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
                > dropout_p
            )
            x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
            if STORE_DROPOUT_MASK:
                tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
        x += x1
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
    if STORE_RESIDUAL_OUT:
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    # Write output
    tl.store(Y + cols, y, mask=mask)
    if HAS_W1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
        if HAS_B1:
            b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
        y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
        tl.store(Y1 + cols, y1, mask=mask)


def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    dropout_p=0.0,
    rowscale=None,
    out_dtype=None,
    residual_dtype=None,
    is_rms_norm=False,
    return_dropout_mask=False,
):
    if residual is not None:
        residual_dtype = residual.dtype
    M, N = x.shape
    assert x.stride(-1) == 1
    if residual is not None:
        assert residual.stride(-1) == 1
        assert residual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if x1 is not None:
        assert x1.shape == x.shape
        assert rowscale is None
        assert x1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    assert y.stride(-1) == 1
    if weight1 is not None:
        y1 = torch.empty_like(y)
        assert y1.stride(-1) == 1
    else:
        y1 = None
    if (
        residual is not None
        or (residual_dtype is not None and residual_dtype != x.dtype)
        or dropout_p > 0.0
        or rowscale is not None
        or x1 is not None
    ):
        residual_out = torch.empty(
            M,
            N,
            device=x.device,
            dtype=residual_dtype if residual_dtype is not None else x.dtype,
        )
        assert residual_out.stride(-1) == 1
    else:
        residual_out = None
    mean = (
        torch.empty((M,), dtype=torch.float32, device=x.device)
        if not is_rms_norm
        else None
    )
    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
    if dropout_p > 0.0:
        seeds = torch.randint(
            2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
        )
    else:
        seeds = None
    if return_dropout_mask and dropout_p > 0.0:
        dropout_mask = torch.empty(
            M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool
        )
    else:
        dropout_mask = None
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[(M,)](
            x,
            y,
            weight,
            bias,
            residual,
            x1,
            weight1,
            bias1,
            y1,
            residual_out,
            rowscale,
            seeds,
            dropout_mask,
            mean,
            rstd,
            x.stride(0),
            y.stride(0),
            residual.stride(0) if residual is not None else 0,
            residual_out.stride(0) if residual_out is not None else 0,
            x1.stride(0) if x1 is not None else 0,
            y1.stride(0) if y1 is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            is_rms_norm,
            BLOCK_N,
            residual is not None,
            residual_out is not None,
            bias is not None,
            dropout_p > 0.0,
            dropout_mask is not None,
            rowscale is not None,
        )
    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
    if dropout_mask is not None and x1 is not None:
        dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
    else:
        dropout_mask1 = None
    return (
        y,
        y1,
        mean,
        rstd,
        residual_out if residual_out is not None else x,
        seeds,
        dropout_mask,
        dropout_mask1,
    )


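The 64KB guard above translates into a dtype-dependent column limit (worked arithmetic, illustrative): MAX_FUSED_SIZE = 65536 // element_size, so a single fused row tops out at 16384 features in fp32 and 32768 in fp16/bf16.

import torch

for dtype, limit in [(torch.float32, 16384), (torch.float16, 32768)]:
    assert 65536 // torch.empty(0, dtype=dtype).element_size() == limit
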
@triton.autotune(
    configs=pruned_configs_autotune,
    key=[
        "N",
        "HAS_DRESIDUAL",
        "STORE_DRESIDUAL",
        "IS_RMS_NORM",
        "HAS_BIAS",
        "HAS_DROPOUT",
    ],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
    X,  # pointer to the input
    W,  # pointer to the weights
    B,  # pointer to the biases
    Y,  # pointer to the output to be recomputed
    DY,  # pointer to the output gradient
    DX,  # pointer to the input gradient
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    DRESIDUAL,
    W1,
    DY1,
    DX1,
    DW1,
    DB1,
    DRESIDUAL_IN,
    ROWSCALE,
    SEEDS,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_dy_row,
    stride_dx_row,
    stride_dres_row,
    stride_dy1_row,
    stride_dx1_row,
    stride_dres_in_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,
    rows_per_program,
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_DRESIDUAL: tl.constexpr,
    STORE_DRESIDUAL: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_DY1: tl.constexpr,
    HAS_DX1: tl.constexpr,
    HAS_B1: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Map the program id to the elements of X, DX, and DY it should compute.
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    # Do not early exit if row_start >= M, because we need to write DW and DB
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    X += row_start * stride_x_row
    if HAS_DRESIDUAL:
        DRESIDUAL += row_start * stride_dres_row
    if STORE_DRESIDUAL:
        DRESIDUAL_IN += row_start * stride_dres_in_row
    DY += row_start * stride_dy_row
    DX += row_start * stride_dx_row
    if HAS_DY1:
        DY1 += row_start * stride_dy1_row
    if HAS_DX1:
        DX1 += row_start * stride_dx1_row
    if RECOMPUTE_OUTPUT:
        Y += row_start * stride_y_row
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if RECOMPUTE_OUTPUT and HAS_BIAS:
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
    if HAS_DY1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_BIAS:
        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_DY1:
        dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
        if HAS_B1:
            db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    row_end = min((row_block_id + 1) * rows_per_program, M)
    for row in range(row_start, row_end):
        # Load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
        if HAS_DY1:
            dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
        if not IS_RMS_NORM:
            mean = tl.load(Mean + row)
        rstd = tl.load(Rstd + row)
        # Compute dx
        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        xhat = tl.where(mask, xhat, 0.0)
        if RECOMPUTE_OUTPUT:
            y = xhat * w + b if HAS_BIAS else xhat * w
            tl.store(Y + cols, y, mask=mask)
        wdy = w * dy
        dw += dy * xhat
        if HAS_BIAS:
            db += dy
        if HAS_DY1:
            wdy += w1 * dy1
            dw1 += dy1 * xhat
            if HAS_B1:
                db1 += dy1
        if not IS_RMS_NORM:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            c2 = tl.sum(wdy, axis=0) / N
            dx = (wdy - (xhat * c1 + c2)) * rstd
        else:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            dx = (wdy - xhat * c1) * rstd
        if HAS_DRESIDUAL:
            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
            dx += dres
        # Write dx
        if STORE_DRESIDUAL:
            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
        if HAS_DX1:
            if HAS_DROPOUT:
                keep_mask = (
                    tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
                    > dropout_p
                )
                dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
            else:
                dx1 = dx
            tl.store(DX1 + cols, dx1, mask=mask)
        if HAS_DROPOUT:
            keep_mask = (
                tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7)
                > dropout_p
            )
            dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + row).to(tl.float32)
            dx *= rowscale
        tl.store(DX + cols, dx, mask=mask)

        X += stride_x_row
        if HAS_DRESIDUAL:
            DRESIDUAL += stride_dres_row
        if STORE_DRESIDUAL:
            DRESIDUAL_IN += stride_dres_in_row
        if RECOMPUTE_OUTPUT:
            Y += stride_y_row
        DY += stride_dy_row
        DX += stride_dx_row
        if HAS_DY1:
            DY1 += stride_dy1_row
        if HAS_DX1:
            DX1 += stride_dx1_row
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
    if HAS_BIAS:
        tl.store(DB + row_block_id * N + cols, db, mask=mask)
    if HAS_DY1:
        tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
        if HAS_B1:
            tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)


def _layer_norm_bwd(
    dy,
    x,
    weight,
    bias,
    eps,
    mean,
    rstd,
    dresidual=None,
    dy1=None,
    weight1=None,
    bias1=None,
    seeds=None,
    dropout_p=0.0,
    rowscale=None,
    has_residual=False,
    has_x1=False,
    is_rms_norm=False,
    x_dtype=None,
    recompute_output=False,
):
    M, N = x.shape
    assert x.stride(-1) == 1
    assert dy.stride(-1) == 1
    assert dy.shape == (M, N)
    if dresidual is not None:
        assert dresidual.stride(-1) == 1
        assert dresidual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if dy1 is not None:
        assert weight1 is not None
        assert dy1.shape == dy.shape
        assert dy1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if seeds is not None:
        assert seeds.is_contiguous()
        assert seeds.shape == (M if not has_x1 else M * 2,)
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    dx = (
        torch.empty_like(x)
        if x_dtype is None
        else torch.empty(M, N, dtype=x_dtype, device=x.device)
    )
    dresidual_in = (
        torch.empty_like(x)
        if has_residual
        and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
        else None
    )
    dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
    y = (
        torch.empty(M, N, dtype=dy.dtype, device=dy.device)
        if recompute_output
        else None
    )
    if recompute_output:
        assert (
            weight1 is None
        ), "recompute_output is not supported with parallel LayerNorm"

    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
    _db = (
        torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
        if bias is not None
        else None
    )
    _dw1 = torch.empty_like(_dw) if weight1 is not None else None
    _db1 = torch.empty_like(_db) if bias1 is not None else None
    rows_per_program = math.ceil(M / sm_count)
    grid = (sm_count,)
    with torch.cuda.device(x.device.index):
        _layer_norm_bwd_kernel[grid](
            x,
            weight,
            bias,
            y,
            dy,
            dx,
            _dw,
            _db,
            dresidual,
            weight1,
            dy1,
            dx1,
            _dw1,
            _db1,
            dresidual_in,
            rowscale,
            seeds,
            mean,
            rstd,
            x.stride(0),
            0 if not recompute_output else y.stride(0),
            dy.stride(0),
            dx.stride(0),
            dresidual.stride(0) if dresidual is not None else 0,
            dy1.stride(0) if dy1 is not None else 0,
            dx1.stride(0) if dx1 is not None else 0,
            dresidual_in.stride(0) if dresidual_in is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            rows_per_program,
            is_rms_norm,
            BLOCK_N,
            dresidual is not None,
            dresidual_in is not None,
            bias is not None,
            dropout_p > 0.0,
        )
    dw = _dw.sum(0).to(weight.dtype)
    db = _db.sum(0).to(bias.dtype) if bias is not None else None
    dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
    db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
    # Don't need to compute dresidual_in separately in this case
    if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
        dresidual_in = dx
    if has_x1 and dropout_p == 0.0:
        dx1 = dx
    return (
        (dx, dw, db, dresidual_in, dx1, dw1, db1)
        if not recompute_output
        else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
    )


class LayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        residual=None,
        x1=None,
        weight1=None,
        bias1=None,
        eps=1e-6,
        dropout_p=0.0,
        rowscale=None,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
        return_dropout_mask=False,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        if x1 is not None:
            assert x1.shape == x_shape_og
            assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
            x1 = x1.reshape(-1, x1.shape[-1])
            if x1.stride(-1) != 1:
                x1 = x1.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        if weight1 is not None:
            weight1 = weight1.contiguous()
        if bias1 is not None:
            bias1 = bias1.contiguous()
        if rowscale is not None:
            rowscale = rowscale.reshape(-1).contiguous()
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = (
            _layer_norm_fwd(
                x,
                weight,
                bias,
                eps,
                residual,
                x1,
                weight1,
                bias1,
                dropout_p=dropout_p,
                rowscale=rowscale,
                residual_dtype=residual_dtype,
                is_rms_norm=is_rms_norm,
                return_dropout_mask=return_dropout_mask,
            )
        )
        ctx.save_for_backward(
            residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
        )
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.dropout_p = dropout_p
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.has_x1 = x1 is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        y = y.reshape(x_shape_og)
        y1 = y1.reshape(x_shape_og) if y1 is not None else None
        residual_out = (
            residual_out.reshape(x_shape_og) if residual_out is not None else None
        )
        dropout_mask = (
            dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
        )
        dropout_mask1 = (
            dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
        )
        if not return_dropout_mask:
            if weight1 is None:
                return y if not prenorm else (y, residual_out)
            else:
                return (y, y1) if not prenorm else (y, y1, residual_out)
        else:
            if weight1 is None:
                return (
                    (y, dropout_mask, dropout_mask1)
                    if not prenorm
                    else (y, residual_out, dropout_mask, dropout_mask1)
                )
            else:
                return (
                    (y, y1, dropout_mask, dropout_mask1)
                    if not prenorm
                    else (y, y1, residual_out, dropout_mask, dropout_mask1)
                )

    @staticmethod
    def backward(ctx, dy, *args):
        x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
        dy = dy.reshape(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if weight1 is not None:
            dy1, args = args[0], args[1:]
            dy1 = dy1.reshape(-1, dy1.shape[-1])
            if dy1.stride(-1) != 1:
                dy1 = dy1.contiguous()
            assert dy1.shape == x.shape
        else:
            dy1 = None
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            if dresidual.stride(-1) != 1:
                dresidual = dresidual.contiguous()
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
            dy,
            x,
            weight,
            bias,
            ctx.eps,
            mean,
            rstd,
            dresidual,
            dy1,
            weight1,
            bias1,
            seeds,
            ctx.dropout_p,
            rowscale,
            ctx.has_residual,
            ctx.has_x1,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
        )
        return (
            dx.reshape(ctx.x_shape_og),
            dw,
            db,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
            dw1,
            db1,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def layer_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
    return_dropout_mask=False,
):
    return LayerNormFn.apply(
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
        return_dropout_mask,
    )


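A minimal usage sketch (illustrative; the Triton kernels require a CUDA device): with prenorm=True the fused path returns the residual stream alongside the output, matching layer_norm_ref above.

import torch

x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x)
weight = torch.ones(1024, device="cuda", dtype=torch.float16)
bias = torch.zeros(1024, device="cuda", dtype=torch.float16)
out, new_residual = layer_norm_fn(x, weight, bias, residual=residual, prenorm=True)
ref, _ = layer_norm_ref(x, weight, bias, residual=residual, prenorm=True, upcast=True)
print((out - ref).abs().max())  # expected to be small
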
def rms_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    return LayerNormFn.apply(
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        True,
        return_dropout_mask,
    )


class RMSNorm(torch.nn.Module):

    def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if dropout_p > 0.0:
            self.drop = torch.nn.Dropout(dropout_p)
        else:
            self.drop = None
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
        return rms_norm_fn(
            x,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
        )


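A usage sketch for the module (illustrative; needs CUDA for the Triton path). With no incoming residual and residual_in_fp32=True, the prenorm call creates an fp32 residual stream that subsequent blocks can pass back in:

import torch

norm = RMSNorm(1024, eps=1e-5, device="cuda", dtype=torch.float16)
x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
y, residual = norm(x, prenorm=True, residual_in_fp32=True)
assert y.dtype == torch.float16 and residual.dtype == torch.float32
# Next block: fused residual add + RMSNorm on the fp32 stream.
y2, residual = norm(torch.randn_like(x), residual=residual, prenorm=True)
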
class LayerNormLinearFn(torch.autograd.Function):
|
| 1029 |
-
@staticmethod
|
| 1030 |
-
@custom_fwd
|
| 1031 |
-
def forward(
|
| 1032 |
-
ctx,
|
| 1033 |
-
x,
|
| 1034 |
-
norm_weight,
|
| 1035 |
-
norm_bias,
|
| 1036 |
-
linear_weight,
|
| 1037 |
-
linear_bias,
|
| 1038 |
-
residual=None,
|
| 1039 |
-
eps=1e-6,
|
| 1040 |
-
prenorm=False,
|
| 1041 |
-
residual_in_fp32=False,
|
| 1042 |
-
is_rms_norm=False,
|
| 1043 |
-
):
|
| 1044 |
-
x_shape_og = x.shape
|
| 1045 |
-
# reshape input data into 2D tensor
|
| 1046 |
-
x = x.reshape(-1, x.shape[-1])
|
| 1047 |
-
if x.stride(-1) != 1:
|
| 1048 |
-
x = x.contiguous()
|
| 1049 |
-
if residual is not None:
|
| 1050 |
-
assert residual.shape == x_shape_og
|
| 1051 |
-
residual = residual.reshape(-1, residual.shape[-1])
|
| 1052 |
-
if residual.stride(-1) != 1:
|
| 1053 |
-
residual = residual.contiguous()
|
| 1054 |
-
norm_weight = norm_weight.contiguous()
|
| 1055 |
-
if norm_bias is not None:
|
| 1056 |
-
norm_bias = norm_bias.contiguous()
|
| 1057 |
-
residual_dtype = (
|
| 1058 |
-
residual.dtype
|
| 1059 |
-
if residual is not None
|
| 1060 |
-
else (torch.float32 if residual_in_fp32 else None)
|
| 1061 |
-
)
|
| 1062 |
-
y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
|
| 1063 |
-
x,
|
| 1064 |
-
norm_weight,
|
| 1065 |
-
norm_bias,
|
| 1066 |
-
eps,
|
| 1067 |
-
residual,
|
| 1068 |
-
out_dtype=(
|
| 1069 |
-
None
|
| 1070 |
-
if not torch.is_autocast_enabled()
|
| 1071 |
-
else torch.get_autocast_gpu_dtype()
|
| 1072 |
-
),
|
| 1073 |
-
residual_dtype=residual_dtype,
|
| 1074 |
-
is_rms_norm=is_rms_norm,
|
| 1075 |
-
)
|
| 1076 |
-
y = y.reshape(x_shape_og)
|
| 1077 |
-
dtype = (
|
| 1078 |
-
torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
|
| 1079 |
-
)
|
| 1080 |
-
linear_weight = linear_weight.to(dtype)
|
| 1081 |
-
linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
|
| 1082 |
-
out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
|
| 1083 |
-
# We don't store y, will be recomputed in the backward pass to save memory
|
| 1084 |
-
ctx.save_for_backward(
|
| 1085 |
-
residual_out, norm_weight, norm_bias, linear_weight, mean, rstd
|
| 1086 |
-
)
|
| 1087 |
-
ctx.x_shape_og = x_shape_og
|
| 1088 |
-
ctx.eps = eps
|
| 1089 |
-
ctx.is_rms_norm = is_rms_norm
|
| 1090 |
-
ctx.has_residual = residual is not None
|
| 1091 |
-
ctx.prenorm = prenorm
|
| 1092 |
-
ctx.x_dtype = x.dtype
|
| 1093 |
-
ctx.linear_bias_is_none = linear_bias is None
|
| 1094 |
-
return out if not prenorm else (out, residual_out.reshape(x_shape_og))
|
| 1095 |
-
|
| 1096 |
-
@staticmethod
|
| 1097 |
-
@custom_bwd
|
| 1098 |
-
def backward(ctx, dout, *args):
|
| 1099 |
-
x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
|
| 1100 |
-
dout = dout.reshape(-1, dout.shape[-1])
|
| 1101 |
-
dy = F.linear(dout, linear_weight.t())
|
| 1102 |
-
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
|
| 1103 |
-
if dy.stride(-1) != 1:
|
| 1104 |
-
dy = dy.contiguous()
|
| 1105 |
-
assert dy.shape == x.shape
|
| 1106 |
-
if ctx.prenorm:
|
| 1107 |
-
dresidual = args[0]
|
| 1108 |
-
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
| 1109 |
-
if dresidual.stride(-1) != 1:
|
| 1110 |
-
dresidual = dresidual.contiguous()
|
| 1111 |
-
assert dresidual.shape == x.shape
|
| 1112 |
-
else:
|
| 1113 |
-
dresidual = None
|
| 1114 |
-
dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
|
| 1115 |
-
dy,
|
| 1116 |
-
x,
|
| 1117 |
-
norm_weight,
|
| 1118 |
-
norm_bias,
|
| 1119 |
-
ctx.eps,
|
| 1120 |
-
mean,
|
| 1121 |
-
rstd,
|
| 1122 |
-
dresidual=dresidual,
|
| 1123 |
-
has_residual=ctx.has_residual,
|
| 1124 |
-
is_rms_norm=ctx.is_rms_norm,
|
| 1125 |
-
x_dtype=ctx.x_dtype,
|
| 1126 |
-
recompute_output=True,
|
| 1127 |
-
)
|
| 1128 |
-
dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
|
| 1129 |
-
return (
|
| 1130 |
-
dx.reshape(ctx.x_shape_og),
|
| 1131 |
-
dnorm_weight,
|
| 1132 |
-
dnorm_bias,
|
| 1133 |
-
dlinear_weight,
|
| 1134 |
-
dlinear_bias,
|
| 1135 |
-
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
| 1136 |
-
None,
|
| 1137 |
-
None,
|
| 1138 |
-
None,
|
| 1139 |
-
None,
|
| 1140 |
-
)
|
| 1141 |
-
|
| 1142 |
-
|
| 1143 |
-
def layer_norm_linear_fn(
|
| 1144 |
-
x,
|
| 1145 |
-
norm_weight,
|
| 1146 |
-
norm_bias,
|
| 1147 |
-
linear_weight,
|
| 1148 |
-
linear_bias,
|
| 1149 |
-
residual=None,
|
| 1150 |
-
eps=1e-6,
|
| 1151 |
-
prenorm=False,
|
| 1152 |
-
residual_in_fp32=False,
|
| 1153 |
-
is_rms_norm=False,
|
| 1154 |
-
):
|
| 1155 |
-
return LayerNormLinearFn.apply(
|
| 1156 |
-
x,
|
| 1157 |
-
norm_weight,
|
| 1158 |
-
norm_bias,
|
| 1159 |
-
linear_weight,
|
| 1160 |
-
linear_bias,
|
| 1161 |
-
residual,
|
| 1162 |
-
eps,
|
| 1163 |
-
prenorm,
|
| 1164 |
-
residual_in_fp32,
|
| 1165 |
-
is_rms_norm,
|
| 1166 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
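
For orientation (not part of the diff): a minimal usage sketch of the RMSNorm entry points in the deleted file, assuming a CUDA device with the Triton kernels available; sizes and shapes are arbitrary.

# Illustrative only; assumes CUDA + the fused Triton kernels defined above.
import torch

norm = RMSNorm(hidden_size=1024, eps=1e-5, device="cuda", dtype=torch.float16)
x = torch.randn(2, 128, 1024, device="cuda", dtype=torch.float16)

y = norm(x)  # plain fused RMSNorm

# Pre-norm residual pattern: the module returns (output, updated residual
# stream), and residual_in_fp32=True keeps that stream in fp32 for stability.
y, residual = norm(x, prenorm=True, residual_in_fp32=True)
y, residual = norm(y, residual=residual, prenorm=True, residual_in_fp32=True)

The prenorm variants fold the residual add into the norm kernel, which is why the module forward exposes `residual` and `prenorm` rather than leaving the addition to the caller.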
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/triton/selective_state_update.py
DELETED
@@ -1,389 +0,0 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.

"""We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
"""

import math
import torch
import torch.nn.functional as F

import triton
import triton.language as tl

from einops import rearrange, repeat

from .softplus import softplus


@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics(
    {
        "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
        is not None
    }
)
@triton.heuristics(
    {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
)
@triton.jit
def _selective_scan_update_kernel(
    # Pointers to matrices
    state_ptr,
    x_ptr,
    dt_ptr,
    dt_bias_ptr,
    A_ptr,
    B_ptr,
    C_ptr,
    D_ptr,
    z_ptr,
    out_ptr,
    state_batch_indices_ptr,
    # Matrix dimensions
    batch,
    nheads,
    dim,
    dstate,
    nheads_ngroups_ratio,
    # Strides
    stride_state_batch,
    stride_state_head,
    stride_state_dim,
    stride_state_dstate,
    stride_x_batch,
    stride_x_head,
    stride_x_dim,
    stride_dt_batch,
    stride_dt_head,
    stride_dt_dim,
    stride_dt_bias_head,
    stride_dt_bias_dim,
    stride_A_head,
    stride_A_dim,
    stride_A_dstate,
    stride_B_batch,
    stride_B_group,
    stride_B_dstate,
    stride_C_batch,
    stride_C_group,
    stride_C_dstate,
    stride_D_head,
    stride_D_dim,
    stride_z_batch,
    stride_z_head,
    stride_z_dim,
    stride_out_batch,
    stride_out_head,
    stride_out_dim,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    TIE_HDIM: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    HAS_D: tl.constexpr,
    HAS_Z: tl.constexpr,
    HAS_STATE_BATCH_INDICES: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)

    if HAS_STATE_BATCH_INDICES:
        state_batch_indices_ptr += pid_b
        state_batch_idx = tl.load(state_batch_indices_ptr)
        state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
    else:
        state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head

    x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
    dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
    if HAS_DT_BIAS:
        dt_bias_ptr += pid_h * stride_dt_bias_head
    A_ptr += pid_h * stride_A_head
    B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
    C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
    if HAS_Z:
        z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
    state_ptrs = state_ptr + (
        offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
    )
    x_ptrs = x_ptr + offs_m * stride_x_dim
    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
    if HAS_DT_BIAS:
        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
    if HAS_D:
        D_ptr += pid_h * stride_D_head
    A_ptrs = A_ptr + (
        offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
    )
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate
    if HAS_D:
        D_ptrs = D_ptr + offs_m * stride_D_dim
    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim

    state = tl.load(
        state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
    )
    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if not TIE_HDIM:
        dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if DT_SOFTPLUS:
            dt = tl.where(dt <= 20.0, softplus(dt), dt)
        A = tl.load(
            A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
        ).to(tl.float32)
        dA = tl.exp(A * dt[:, None])
    else:
        dt = tl.load(dt_ptr).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptr).to(tl.float32)
        if DT_SOFTPLUS:
            dt = tl.where(dt <= 20.0, softplus(dt), dt)
        A = tl.load(A_ptr).to(tl.float32)
        dA = tl.exp(A * dt)  # scalar, not a matrix

    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_Z:
        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

    if not TIE_HDIM:
        dB = B[None, :] * dt[:, None]
    else:
        dB = B * dt  # vector of size (dstate,)
    state = state * dA + dB * x[:, None]
    tl.store(
        state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
    )
    out = tl.sum(state * C[None, :], axis=1)
    if HAS_D:
        out += x * D
    if HAS_Z:
        out *= z * tl.sigmoid(z)
    tl.store(out_ptrs, out, mask=offs_m < dim)


def selective_state_update(
    state,
    x,
    dt,
    A,
    B,
    C,
    D=None,
    z=None,
    dt_bias=None,
    dt_softplus=False,
    state_batch_indices=None,
):
    """
    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    """
    has_heads = state.dim() > 3
    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)
    _, nheads, dim, dstate = state.shape
    batch = x.shape[0]
    if x.shape != (batch, nheads, dim):
        print(f"{state.shape} {x.shape} {batch} {nheads} {dim}")
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
    if state_batch_indices is not None:
        assert state_batch_indices.shape == (batch,)
    out = torch.empty_like(x)
    grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
    z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
    # We don't want autotune since it will overwrite the state
    # We instead tune by hand.
    BLOCK_SIZE_M, num_warps = (
        (32, 4)
        if dstate <= 16
        else (
            (16, 4)
            if dstate <= 32
            else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
        )
    )
    tie_hdim = (
        A.stride(-1) == 0
        and A.stride(-2) == 0
        and dt.stride(-1) == 0
        and dt_bias.stride(-1) == 0
    )
    with torch.cuda.device(x.device.index):
        _selective_scan_update_kernel[grid](
            state,
            x,
            dt,
            dt_bias,
            A,
            B,
            C,
            D,
            z,
            out,
            state_batch_indices,
            batch,
            nheads,
            dim,
            dstate,
            nheads // ngroups,
            state.stride(0),
            state.stride(1),
            state.stride(2),
            state.stride(3),
            x.stride(0),
            x.stride(1),
            x.stride(2),
            dt.stride(0),
            dt.stride(1),
            dt.stride(2),
            *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
            A.stride(0),
            A.stride(1),
            A.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            C.stride(0),
            C.stride(1),
            C.stride(2),
            *(D.stride(0), D.stride(1)) if D is not None else 0,
            z_strides[0],
            z_strides[1],
            z_strides[2],
            out.stride(0),
            out.stride(1),
            out.stride(2),
            dt_softplus,
            tie_hdim,
            BLOCK_SIZE_M,
            num_warps=num_warps,
        )
    if not has_heads:
        out = out.squeeze(1)
    return out


def selective_state_update_ref(
    state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
):
    """
    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    """
    has_heads = state.dim() > 3
    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)
    batch, nheads, dim, dstate = state.shape
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
        dt = dt + dt_bias
    dt = F.softplus(dt) if dt_softplus else dt
    dA = torch.exp(
        rearrange(dt, "b h d -> b h d 1") * A
    )  # (batch, nheads, dim, dstate)
    B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups)  # (batch, nheads, dstate)
    C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups)  # (batch, nheads, dstate)
    dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
        B, "b h n -> b h 1 n"
    )  # (batch, nheads, dim, dstate)
    state.copy_(
        state * dA + dB * rearrange(x, "b h d -> b h d 1")
    )  # (batch, dim, dstate)
    out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
    if D is not None:
        out += (x * D).to(out.dtype)
    out = (out if z is None else out * F.silu(z)).to(x.dtype)
    if not has_heads:
        out = out.squeeze(1)
    return out
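
For orientation (not part of the diff): a minimal sketch of one decoding step with the two entry points above, checked against each other. It assumes a CUDA device with Triton available; all sizes are arbitrary, and the optional D, z, and dt_bias inputs are omitted.

# Illustrative only; both functions update `state` in place and return `out`.
import torch

batch, nheads, ngroups, dim, dstate = 2, 4, 2, 64, 16
dev = "cuda"
state = torch.randn(batch, nheads, dim, dstate, device=dev)
x = torch.randn(batch, nheads, dim, device=dev)
dt = torch.rand(batch, nheads, dim, device=dev)
A = -torch.rand(nheads, dim, dstate, device=dev)  # negative for a decaying state
B = torch.randn(batch, ngroups, dstate, device=dev)
C = torch.randn(batch, ngroups, dstate, device=dev)

state_ref = state.clone()  # keep a copy, since the update is in place
out = selective_state_update(state, x, dt, A, B, C, dt_softplus=True)
out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, dt_softplus=True)
assert torch.allclose(out, out_ref, atol=1e-3)
assert torch.allclose(state, state_ref, atol=1e-3)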
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_scan.py
DELETED
The diff for this file is too large to render. See raw diff.

build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_state.py
DELETED
@@ -1,2012 +0,0 @@
|
|
| 1 |
-
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
| 2 |
-
|
| 3 |
-
"""We want triton==2.1.0 or 2.2.0 for this
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
import math
|
| 7 |
-
import torch
|
| 8 |
-
import torch.nn.functional as F
|
| 9 |
-
|
| 10 |
-
import triton
|
| 11 |
-
import triton.language as tl
|
| 12 |
-
|
| 13 |
-
from einops import rearrange, repeat
|
| 14 |
-
|
| 15 |
-
from .softplus import softplus
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def init_to_zero(names):
|
| 19 |
-
return lambda nargs: [
|
| 20 |
-
nargs[name].zero_() for name in names if nargs[name] is not None
|
| 21 |
-
]
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
@triton.autotune(
|
| 25 |
-
configs=[
|
| 26 |
-
triton.Config({"BLOCK_SIZE_H": 1}),
|
| 27 |
-
triton.Config({"BLOCK_SIZE_H": 2}),
|
| 28 |
-
triton.Config({"BLOCK_SIZE_H": 4}),
|
| 29 |
-
triton.Config({"BLOCK_SIZE_H": 8}),
|
| 30 |
-
triton.Config({"BLOCK_SIZE_H": 16}),
|
| 31 |
-
triton.Config({"BLOCK_SIZE_H": 32}),
|
| 32 |
-
triton.Config({"BLOCK_SIZE_H": 64}),
|
| 33 |
-
],
|
| 34 |
-
key=["chunk_size", "nheads"],
|
| 35 |
-
)
|
| 36 |
-
@triton.jit
|
| 37 |
-
def _chunk_cumsum_fwd_kernel(
|
| 38 |
-
# Pointers to matrices
|
| 39 |
-
dt_ptr,
|
| 40 |
-
A_ptr,
|
| 41 |
-
dt_bias_ptr,
|
| 42 |
-
dt_out_ptr,
|
| 43 |
-
dA_cumsum_ptr,
|
| 44 |
-
# Matrix dimension
|
| 45 |
-
batch,
|
| 46 |
-
seqlen,
|
| 47 |
-
nheads,
|
| 48 |
-
chunk_size,
|
| 49 |
-
dt_min,
|
| 50 |
-
dt_max,
|
| 51 |
-
# Strides
|
| 52 |
-
stride_dt_batch,
|
| 53 |
-
stride_dt_seqlen,
|
| 54 |
-
stride_dt_head,
|
| 55 |
-
stride_A_head,
|
| 56 |
-
stride_dt_bias_head,
|
| 57 |
-
stride_dt_out_batch,
|
| 58 |
-
stride_dt_out_chunk,
|
| 59 |
-
stride_dt_out_head,
|
| 60 |
-
stride_dt_out_csize,
|
| 61 |
-
stride_dA_cs_batch,
|
| 62 |
-
stride_dA_cs_chunk,
|
| 63 |
-
stride_dA_cs_head,
|
| 64 |
-
stride_dA_cs_csize,
|
| 65 |
-
# Meta-parameters
|
| 66 |
-
DT_SOFTPLUS: tl.constexpr,
|
| 67 |
-
HAS_DT_BIAS: tl.constexpr,
|
| 68 |
-
BLOCK_SIZE_H: tl.constexpr,
|
| 69 |
-
BLOCK_SIZE_CHUNK: tl.constexpr,
|
| 70 |
-
):
|
| 71 |
-
pid_b = tl.program_id(axis=0)
|
| 72 |
-
pid_c = tl.program_id(axis=1)
|
| 73 |
-
pid_h = tl.program_id(axis=2)
|
| 74 |
-
dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
|
| 75 |
-
dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
|
| 76 |
-
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk
|
| 77 |
-
|
| 78 |
-
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
|
| 79 |
-
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
|
| 80 |
-
dt_ptrs = dt_ptr + (
|
| 81 |
-
offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
|
| 82 |
-
)
|
| 83 |
-
A_ptrs = A_ptr + offs_h * stride_A_head
|
| 84 |
-
dt_out_ptrs = dt_out_ptr + (
|
| 85 |
-
offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize
|
| 86 |
-
)
|
| 87 |
-
dA_cs_ptrs = dA_cumsum_ptr + (
|
| 88 |
-
offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize
|
| 89 |
-
)
|
| 90 |
-
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
| 91 |
-
|
| 92 |
-
dt = tl.load(
|
| 93 |
-
dt_ptrs,
|
| 94 |
-
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
| 95 |
-
other=0.0,
|
| 96 |
-
).to(tl.float32)
|
| 97 |
-
if HAS_DT_BIAS:
|
| 98 |
-
dt_bias = tl.load(
|
| 99 |
-
dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
|
| 100 |
-
).to(tl.float32)
|
| 101 |
-
dt += dt_bias[:, None]
|
| 102 |
-
if DT_SOFTPLUS:
|
| 103 |
-
dt = tl.where(dt <= 20.0, softplus(dt), dt)
|
| 104 |
-
# As of Triton 2.2.0, tl.clamp is not available yet
|
| 105 |
-
# dt = tl.clamp(dt, dt_min, dt_max)
|
| 106 |
-
dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
|
| 107 |
-
dt = tl.where(
|
| 108 |
-
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
|
| 109 |
-
)
|
| 110 |
-
tl.store(
|
| 111 |
-
dt_out_ptrs,
|
| 112 |
-
dt,
|
| 113 |
-
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
|
| 114 |
-
)
|
| 115 |
-
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
|
| 116 |
-
dA = dt * A[:, None]
|
| 117 |
-
dA_cs = tl.cumsum(dA, axis=1)
|
| 118 |
-
tl.store(
|
| 119 |
-
dA_cs_ptrs,
|
| 120 |
-
dA_cs,
|
| 121 |
-
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
|
| 122 |
-
)
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
@triton.autotune(
|
| 126 |
-
configs=[
|
| 127 |
-
triton.Config(
|
| 128 |
-
{"BLOCK_SIZE_H": 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
| 129 |
-
),
|
| 130 |
-
triton.Config(
|
| 131 |
-
{"BLOCK_SIZE_H": 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
| 132 |
-
),
|
| 133 |
-
triton.Config(
|
| 134 |
-
{"BLOCK_SIZE_H": 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
| 135 |
-
),
|
| 136 |
-
triton.Config(
|
| 137 |
-
{"BLOCK_SIZE_H": 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
| 138 |
-
),
|
| 139 |
-
triton.Config(
|
| 140 |
-
{"BLOCK_SIZE_H": 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
| 141 |
-
),
|
| 142 |
-
triton.Config(
|
| 143 |
-
{"BLOCK_SIZE_H": 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
| 144 |
-
),
|
| 145 |
-
triton.Config(
|
| 146 |
-
{"BLOCK_SIZE_H": 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
| 147 |
-
),
|
| 148 |
-
],
|
| 149 |
-
key=["chunk_size", "nheads"],
|
| 150 |
-
)
|
| 151 |
-
@triton.jit
|
| 152 |
-
def _chunk_cumsum_bwd_kernel(
|
| 153 |
-
# Pointers to matrices
|
| 154 |
-
ddA_ptr,
|
| 155 |
-
ddt_out_ptr,
|
| 156 |
-
dt_ptr,
|
| 157 |
-
A_ptr,
|
| 158 |
-
dt_bias_ptr,
|
| 159 |
-
ddt_ptr,
|
| 160 |
-
dA_ptr,
|
| 161 |
-
ddt_bias_ptr,
|
| 162 |
-
# Matrix dimensions
|
| 163 |
-
batch,
|
| 164 |
-
seqlen,
|
| 165 |
-
nheads,
|
| 166 |
-
chunk_size,
|
| 167 |
-
dt_min,
|
| 168 |
-
dt_max,
|
| 169 |
-
# Strides
|
| 170 |
-
stride_ddA_batch,
|
| 171 |
-
stride_ddA_chunk,
|
| 172 |
-
stride_ddA_head,
|
| 173 |
-
stride_ddA_csize,
|
| 174 |
-
stride_ddt_out_batch,
|
| 175 |
-
stride_ddt_out_chunk,
|
| 176 |
-
stride_ddt_out_head,
|
| 177 |
-
stride_ddt_out_csize,
|
| 178 |
-
stride_dt_batch,
|
| 179 |
-
stride_dt_seqlen,
|
| 180 |
-
stride_dt_head,
|
| 181 |
-
stride_A_head,
|
| 182 |
-
stride_dt_bias_head,
|
| 183 |
-
stride_ddt_batch,
|
| 184 |
-
stride_ddt_seqlen,
|
| 185 |
-
stride_ddt_head,
|
| 186 |
-
stride_dA_head,
|
| 187 |
-
stride_ddt_bias_head,
|
| 188 |
-
# Meta-parameters
|
| 189 |
-
DT_SOFTPLUS: tl.constexpr,
|
| 190 |
-
HAS_DT_BIAS: tl.constexpr,
|
| 191 |
-
BLOCK_SIZE_H: tl.constexpr,
|
| 192 |
-
BLOCK_SIZE_CHUNK: tl.constexpr,
|
| 193 |
-
):
|
| 194 |
-
pid_b = tl.program_id(axis=0)
|
| 195 |
-
pid_c = tl.program_id(axis=1)
|
| 196 |
-
pid_h = tl.program_id(axis=2)
|
| 197 |
-
ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk
|
| 198 |
-
ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk
|
| 199 |
-
dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
|
| 200 |
-
ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen
|
| 201 |
-
|
| 202 |
-
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
|
| 203 |
-
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
|
| 204 |
-
ddt_out_ptrs = ddt_out_ptr + (
|
| 205 |
-
offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize
|
| 206 |
-
)
|
| 207 |
-
ddA_ptrs = ddA_ptr + (
|
| 208 |
-
offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize
|
| 209 |
-
)
|
| 210 |
-
dt_ptrs = dt_ptr + (
|
| 211 |
-
offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
|
| 212 |
-
)
|
| 213 |
-
ddt_ptrs = ddt_ptr + (
|
| 214 |
-
offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen
|
| 215 |
-
)
|
| 216 |
-
A_ptrs = A_ptr + offs_h * stride_A_head
|
| 217 |
-
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
| 218 |
-
|
| 219 |
-
ddA = tl.load(
|
| 220 |
-
ddA_ptrs,
|
| 221 |
-
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
| 222 |
-
other=0.0,
|
| 223 |
-
).to(tl.float32)
|
| 224 |
-
ddt_out = tl.load(
|
| 225 |
-
ddt_out_ptrs,
|
| 226 |
-
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
| 227 |
-
other=0.0,
|
| 228 |
-
).to(tl.float32)
|
| 229 |
-
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
|
| 230 |
-
ddt = ddA * A[:, None] + ddt_out
|
| 231 |
-
dt = tl.load(
|
| 232 |
-
dt_ptrs,
|
| 233 |
-
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
| 234 |
-
other=0.0,
|
| 235 |
-
).to(tl.float32)
|
| 236 |
-
if HAS_DT_BIAS:
|
| 237 |
-
dt_bias = tl.load(
|
| 238 |
-
dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
|
| 239 |
-
).to(tl.float32)
|
| 240 |
-
dt += dt_bias[:, None]
|
| 241 |
-
if DT_SOFTPLUS:
|
| 242 |
-
dt_presoftplus = dt
|
| 243 |
-
dt = tl.where(dt <= 20.0, softplus(dt), ddt)
|
| 244 |
-
clamp_mask = (dt < dt_min) | (dt > dt_max)
|
| 245 |
-
# As of Triton 2.2.0, tl.clamp is not available yet
|
| 246 |
-
# dt = tl.clamp(dt, dt_min, dt_max)
|
| 247 |
-
dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
|
| 248 |
-
dt = tl.where(
|
| 249 |
-
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
|
| 250 |
-
)
|
| 251 |
-
ddt = tl.where(
|
| 252 |
-
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0
|
| 253 |
-
)
|
| 254 |
-
ddt = tl.where(clamp_mask, 0.0, ddt)
|
| 255 |
-
if DT_SOFTPLUS:
|
| 256 |
-
ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)
|
| 257 |
-
tl.store(
|
| 258 |
-
ddt_ptrs,
|
| 259 |
-
ddt,
|
| 260 |
-
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
| 261 |
-
)
|
| 262 |
-
dA = tl.sum(ddA * dt, axis=1)
|
| 263 |
-
tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)
|
| 264 |
-
if HAS_DT_BIAS:
|
| 265 |
-
ddt_bias = tl.sum(ddt, axis=1)
|
| 266 |
-
tl.atomic_add(
|
| 267 |
-
ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads
|
| 268 |
-
)
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
@triton.autotune(
|
| 272 |
-
configs=[
|
| 273 |
-
triton.Config(
|
| 274 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
| 275 |
-
num_stages=3,
|
| 276 |
-
num_warps=8,
|
| 277 |
-
),
|
| 278 |
-
triton.Config(
|
| 279 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
| 280 |
-
num_stages=4,
|
| 281 |
-
num_warps=4,
|
| 282 |
-
),
|
| 283 |
-
triton.Config(
|
| 284 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
| 285 |
-
num_stages=4,
|
| 286 |
-
num_warps=4,
|
| 287 |
-
),
|
| 288 |
-
triton.Config(
|
| 289 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 290 |
-
num_stages=4,
|
| 291 |
-
num_warps=4,
|
| 292 |
-
),
|
| 293 |
-
triton.Config(
|
| 294 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
| 295 |
-
num_stages=4,
|
| 296 |
-
num_warps=4,
|
| 297 |
-
),
|
| 298 |
-
triton.Config(
|
| 299 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
| 300 |
-
num_stages=4,
|
| 301 |
-
num_warps=4,
|
| 302 |
-
),
|
| 303 |
-
triton.Config(
|
| 304 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
| 305 |
-
num_stages=5,
|
| 306 |
-
num_warps=2,
|
| 307 |
-
),
|
| 308 |
-
triton.Config(
|
| 309 |
-
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 310 |
-
num_stages=5,
|
| 311 |
-
num_warps=2,
|
| 312 |
-
),
|
| 313 |
-
triton.Config(
|
| 314 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 315 |
-
num_stages=4,
|
| 316 |
-
num_warps=2,
|
| 317 |
-
),
|
| 318 |
-
],
|
| 319 |
-
key=["hdim", "dstate", "chunk_size"],
|
| 320 |
-
)
|
| 321 |
-
@triton.jit
|
| 322 |
-
def _chunk_state_fwd_kernel(
|
| 323 |
-
# Pointers to matrices
|
| 324 |
-
x_ptr,
|
| 325 |
-
b_ptr,
|
| 326 |
-
states_ptr,
|
| 327 |
-
dt_ptr,
|
| 328 |
-
dA_cumsum_ptr,
|
| 329 |
-
seq_idx_ptr,
|
| 330 |
-
# Matrix dimensions
|
| 331 |
-
hdim,
|
| 332 |
-
dstate,
|
| 333 |
-
chunk_size,
|
| 334 |
-
batch,
|
| 335 |
-
seqlen,
|
| 336 |
-
nheads_ngroups_ratio,
|
| 337 |
-
# Strides
|
| 338 |
-
stride_x_batch,
|
| 339 |
-
stride_x_seqlen,
|
| 340 |
-
stride_x_head,
|
| 341 |
-
stride_x_hdim,
|
| 342 |
-
stride_b_batch,
|
| 343 |
-
stride_b_seqlen,
|
| 344 |
-
stride_b_head,
|
| 345 |
-
stride_b_dstate,
|
| 346 |
-
stride_states_batch,
|
| 347 |
-
stride_states_chunk,
|
| 348 |
-
stride_states_head,
|
| 349 |
-
stride_states_hdim,
|
| 350 |
-
stride_states_dstate,
|
| 351 |
-
stride_dt_batch,
|
| 352 |
-
stride_dt_chunk,
|
| 353 |
-
stride_dt_head,
|
| 354 |
-
stride_dt_csize,
|
| 355 |
-
stride_dA_cs_batch,
|
| 356 |
-
stride_dA_cs_chunk,
|
| 357 |
-
stride_dA_cs_head,
|
| 358 |
-
stride_dA_cs_csize,
|
| 359 |
-
stride_seq_idx_batch,
|
| 360 |
-
stride_seq_idx_seqlen,
|
| 361 |
-
# Meta-parameters
|
| 362 |
-
HAS_SEQ_IDX: tl.constexpr,
|
| 363 |
-
BLOCK_SIZE_M: tl.constexpr,
|
| 364 |
-
BLOCK_SIZE_N: tl.constexpr,
|
| 365 |
-
BLOCK_SIZE_K: tl.constexpr,
|
| 366 |
-
):
|
| 367 |
-
pid_bc = tl.program_id(axis=1)
|
| 368 |
-
pid_c = pid_bc // batch
|
| 369 |
-
pid_b = pid_bc - pid_c * batch
|
| 370 |
-
pid_h = tl.program_id(axis=2)
|
| 371 |
-
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
| 372 |
-
pid_m = tl.program_id(axis=0) // num_pid_n
|
| 373 |
-
pid_n = tl.program_id(axis=0) % num_pid_n
|
| 374 |
-
b_ptr += (
|
| 375 |
-
pid_b * stride_b_batch
|
| 376 |
-
+ pid_c * chunk_size * stride_b_seqlen
|
| 377 |
-
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
| 378 |
-
)
|
| 379 |
-
x_ptr += (
|
| 380 |
-
pid_b * stride_x_batch
|
| 381 |
-
+ pid_c * chunk_size * stride_x_seqlen
|
| 382 |
-
+ pid_h * stride_x_head
|
| 383 |
-
)
|
| 384 |
-
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
| 385 |
-
dA_cumsum_ptr += (
|
| 386 |
-
pid_b * stride_dA_cs_batch
|
| 387 |
-
+ pid_c * stride_dA_cs_chunk
|
| 388 |
-
+ pid_h * stride_dA_cs_head
|
| 389 |
-
)
|
| 390 |
-
if HAS_SEQ_IDX:
|
| 391 |
-
seq_idx_ptr += (
|
| 392 |
-
pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
| 393 |
-
)
|
| 394 |
-
|
| 395 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 396 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 397 |
-
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
| 398 |
-
x_ptrs = x_ptr + (
|
| 399 |
-
offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
|
| 400 |
-
)
|
| 401 |
-
b_ptrs = b_ptr + (
|
| 402 |
-
offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
|
| 403 |
-
)
|
| 404 |
-
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
| 405 |
-
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
|
| 406 |
-
tl.float32
|
| 407 |
-
)
|
| 408 |
-
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
| 409 |
-
if HAS_SEQ_IDX:
|
| 410 |
-
seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
|
| 411 |
-
|
| 412 |
-
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
| 413 |
-
if HAS_SEQ_IDX:
|
| 414 |
-
seq_idx_last = tl.load(
|
| 415 |
-
seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
|
| 416 |
-
)
|
| 417 |
-
|
| 418 |
-
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
| 419 |
-
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
|
| 420 |
-
x = tl.load(
|
| 421 |
-
x_ptrs,
|
| 422 |
-
mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k),
|
| 423 |
-
other=0.0,
|
| 424 |
-
)
|
| 425 |
-
b = tl.load(
|
| 426 |
-
b_ptrs,
|
| 427 |
-
mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate),
|
| 428 |
-
other=0.0,
|
| 429 |
-
).to(tl.float32)
|
| 430 |
-
dA_cs_k = tl.load(
|
| 431 |
-
dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
|
| 432 |
-
).to(tl.float32)
|
| 433 |
-
if HAS_SEQ_IDX:
|
| 434 |
-
seq_idx_k = tl.load(
|
| 435 |
-
seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1
|
| 436 |
-
)
|
| 437 |
-
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
|
| 438 |
-
tl.float32
|
| 439 |
-
)
|
| 440 |
-
if not HAS_SEQ_IDX:
|
| 441 |
-
scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
|
| 442 |
-
else:
|
| 443 |
-
scale = tl.where(
|
| 444 |
-
seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0
|
| 445 |
-
)
|
| 446 |
-
b *= scale[:, None]
|
| 447 |
-
b = b.to(x_ptr.dtype.element_ty)
|
| 448 |
-
acc += tl.dot(x, b)
|
| 449 |
-
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
| 450 |
-
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
|
| 451 |
-
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
| 452 |
-
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
| 453 |
-
if HAS_SEQ_IDX:
|
| 454 |
-
seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
|
| 455 |
-
states = acc.to(states_ptr.dtype.element_ty)
|
| 456 |
-
|
| 457 |
-
states_ptr += (
|
| 458 |
-
pid_b * stride_states_batch
|
| 459 |
-
+ pid_c * stride_states_chunk
|
| 460 |
-
+ pid_h * stride_states_head
|
| 461 |
-
)
|
| 462 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 463 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 464 |
-
states_ptrs = states_ptr + (
|
| 465 |
-
offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
|
| 466 |
-
)
|
| 467 |
-
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
|
| 468 |
-
tl.store(states_ptrs, states, mask=c_mask)
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
@triton.autotune(
|
| 472 |
-
configs=[
|
| 473 |
-
triton.Config(
|
| 474 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
| 475 |
-
num_stages=3,
|
| 476 |
-
num_warps=8,
|
| 477 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 478 |
-
),
|
| 479 |
-
triton.Config(
|
| 480 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
| 481 |
-
num_stages=4,
|
| 482 |
-
num_warps=4,
|
| 483 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 484 |
-
),
|
| 485 |
-
triton.Config(
|
| 486 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
| 487 |
-
num_stages=4,
|
| 488 |
-
num_warps=4,
|
| 489 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 490 |
-
),
|
| 491 |
-
triton.Config(
|
| 492 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 493 |
-
num_stages=4,
|
| 494 |
-
num_warps=4,
|
| 495 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 496 |
-
),
|
| 497 |
-
triton.Config(
|
| 498 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
| 499 |
-
num_stages=4,
|
| 500 |
-
num_warps=4,
|
| 501 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 502 |
-
),
|
| 503 |
-
triton.Config(
|
| 504 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
| 505 |
-
num_stages=4,
|
| 506 |
-
num_warps=4,
|
| 507 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 508 |
-
),
|
| 509 |
-
triton.Config(
|
| 510 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
| 511 |
-
num_stages=5,
|
| 512 |
-
num_warps=4,
|
| 513 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 514 |
-
),
|
| 515 |
-
triton.Config(
|
| 516 |
-
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 517 |
-
num_stages=5,
|
| 518 |
-
num_warps=4,
|
| 519 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 520 |
-
),
|
| 521 |
-
triton.Config(
|
| 522 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 523 |
-
num_stages=4,
|
| 524 |
-
num_warps=4,
|
| 525 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 526 |
-
),
|
| 527 |
-
],
|
| 528 |
-
key=["chunk_size", "hdim", "dstate"],
|
| 529 |
-
)
|
| 530 |
-
@triton.jit
|
| 531 |
-
def _chunk_state_bwd_dx_kernel(
|
| 532 |
-
# Pointers to matrices
|
| 533 |
-
x_ptr,
|
| 534 |
-
b_ptr,
|
| 535 |
-
dstates_ptr,
|
| 536 |
-
dt_ptr,
|
| 537 |
-
dA_cumsum_ptr,
|
| 538 |
-
dx_ptr,
|
| 539 |
-
ddt_ptr,
|
| 540 |
-
ddA_cumsum_ptr,
|
| 541 |
-
# Matrix dimensions
|
| 542 |
-
chunk_size,
|
| 543 |
-
hdim,
|
| 544 |
-
dstate,
|
| 545 |
-
batch,
|
| 546 |
-
seqlen,
|
| 547 |
-
nheads_ngroups_ratio,
|
| 548 |
-
# Strides
|
| 549 |
-
stride_x_batch,
|
| 550 |
-
stride_x_seqlen,
|
| 551 |
-
stride_x_head,
|
| 552 |
-
stride_x_hdim,
|
| 553 |
-
stride_b_batch,
|
| 554 |
-
stride_b_seqlen,
|
| 555 |
-
stride_b_head,
|
| 556 |
-
stride_b_dstate,
|
| 557 |
-
stride_dstates_batch,
|
| 558 |
-
stride_dstates_chunk,
|
| 559 |
-
stride_states_head,
|
| 560 |
-
stride_states_hdim,
|
| 561 |
-
stride_states_dstate,
|
| 562 |
-
stride_dt_batch,
|
| 563 |
-
stride_dt_chunk,
|
| 564 |
-
stride_dt_head,
|
| 565 |
-
stride_dt_csize,
|
| 566 |
-
stride_dA_cs_batch,
|
| 567 |
-
stride_dA_cs_chunk,
|
| 568 |
-
stride_dA_cs_head,
|
| 569 |
-
stride_dA_cs_csize,
|
| 570 |
-
stride_dx_batch,
|
| 571 |
-
stride_dx_seqlen,
|
| 572 |
-
stride_dx_head,
|
| 573 |
-
stride_dx_hdim,
|
| 574 |
-
stride_ddt_batch,
|
| 575 |
-
stride_ddt_chunk,
|
| 576 |
-
stride_ddt_head,
|
| 577 |
-
stride_ddt_csize,
|
| 578 |
-
stride_ddA_cs_batch,
|
| 579 |
-
stride_ddA_cs_chunk,
|
| 580 |
-
stride_ddA_cs_head,
|
| 581 |
-
stride_ddA_cs_csize,
|
| 582 |
-
# Meta-parameters
|
| 583 |
-
BLOCK_SIZE_M: tl.constexpr,
|
| 584 |
-
BLOCK_SIZE_N: tl.constexpr,
|
| 585 |
-
BLOCK_SIZE_K: tl.constexpr,
|
| 586 |
-
BLOCK_SIZE_DSTATE: tl.constexpr,
|
| 587 |
-
):
|
| 588 |
-
pid_bc = tl.program_id(axis=1)
|
| 589 |
-
pid_c = pid_bc // batch
|
| 590 |
-
pid_b = pid_bc - pid_c * batch
|
| 591 |
-
pid_h = tl.program_id(axis=2)
|
| 592 |
-
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
| 593 |
-
pid_m = tl.program_id(axis=0) // num_pid_n
|
| 594 |
-
pid_n = tl.program_id(axis=0) % num_pid_n
|
| 595 |
-
x_ptr += (
|
| 596 |
-
pid_b * stride_x_batch
|
| 597 |
-
+ pid_c * chunk_size * stride_x_seqlen
|
| 598 |
-
+ pid_h * stride_x_head
|
| 599 |
-
)
|
| 600 |
-
b_ptr += (
|
| 601 |
-
pid_b * stride_b_batch
|
| 602 |
-
+ pid_c * chunk_size * stride_b_seqlen
|
| 603 |
-
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
| 604 |
-
)
|
| 605 |
-
dstates_ptr += (
|
| 606 |
-
pid_b * stride_dstates_batch
|
| 607 |
-
+ pid_c * stride_dstates_chunk
|
| 608 |
-
+ pid_h * stride_states_head
|
| 609 |
-
)
|
| 610 |
-
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
| 611 |
-
ddt_ptr += (
|
| 612 |
-
pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
|
| 613 |
-
)
|
| 614 |
-
ddA_cumsum_ptr += (
|
| 615 |
-
pid_b * stride_ddA_cs_batch
|
| 616 |
-
+ pid_c * stride_ddA_cs_chunk
|
| 617 |
-
+ pid_h * stride_ddA_cs_head
|
| 618 |
-
)
|
| 619 |
-
dA_cumsum_ptr += (
|
| 620 |
-
pid_b * stride_dA_cs_batch
|
| 621 |
-
+ pid_c * stride_dA_cs_chunk
|
| 622 |
-
+ pid_h * stride_dA_cs_head
|
| 623 |
-
)
|
| 624 |
-
|
| 625 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 626 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 627 |
-
|
| 628 |
-
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
| 629 |
-
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
| 630 |
-
offs_k = tl.arange(
|
| 631 |
-
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
|
| 632 |
-
)
|
| 633 |
-
b_ptrs = b_ptr + (
|
| 634 |
-
offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
|
| 635 |
-
)
|
| 636 |
-
dstates_ptrs = dstates_ptr + (
|
| 637 |
-
offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
|
| 638 |
-
)
|
| 639 |
-
if BLOCK_SIZE_DSTATE <= 128:
|
| 640 |
-
b = tl.load(
|
| 641 |
-
b_ptrs,
|
| 642 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
|
| 643 |
-
other=0.0,
|
| 644 |
-
)
|
| 645 |
-
dstates = tl.load(
|
| 646 |
-
dstates_ptrs,
|
| 647 |
-
mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
|
| 648 |
-
other=0.0,
|
| 649 |
-
)
|
| 650 |
-
dstates = dstates.to(b_ptr.dtype.element_ty)
|
| 651 |
-
acc = tl.dot(b, dstates)
|
| 652 |
-
else:
|
| 653 |
-
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
| 654 |
-
for k in range(0, dstate, BLOCK_SIZE_K):
|
| 655 |
-
b = tl.load(
|
| 656 |
-
b_ptrs,
|
| 657 |
-
mask=(offs_m[:, None] < chunk_size_limit)
|
| 658 |
-
& (offs_k[None, :] < dstate - k),
|
| 659 |
-
other=0.0,
|
| 660 |
-
)
|
| 661 |
-
dstates = tl.load(
|
| 662 |
-
dstates_ptrs,
|
| 663 |
-
mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
|
| 664 |
-
other=0.0,
|
| 665 |
-
)
|
| 666 |
-
dstates = dstates.to(b_ptr.dtype.element_ty)
|
| 667 |
-
acc += tl.dot(b, dstates)
|
| 668 |
-
b_ptrs += BLOCK_SIZE_K * stride_b_dstate
|
| 669 |
-
dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
|
| 670 |
-
|
| 671 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 672 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 673 |
-
|
| 674 |
-
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
|
| 675 |
-
tl.float32
|
| 676 |
-
)
|
| 677 |
-
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
| 678 |
-
dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
|
| 679 |
-
dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
|
| 680 |
-
tl.float32
|
| 681 |
-
)
|
| 682 |
-
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
| 683 |
-
acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
|
| 684 |
-
|
| 685 |
-
x_ptrs = x_ptr + (
|
| 686 |
-
offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
|
| 687 |
-
)
|
| 688 |
-
x = tl.load(
|
| 689 |
-
x_ptrs,
|
| 690 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
| 691 |
-
other=0.0,
|
| 692 |
-
).to(tl.float32)
|
| 693 |
-
ddt = tl.sum(acc * x, axis=1)
|
| 694 |
-
ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
|
| 695 |
-
tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
|
| 696 |
-
ddA_cs = -(ddt * dt_m)
|
| 697 |
-
ddA_cs_last = -tl.sum(ddA_cs)
|
| 698 |
-
ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
|
| 699 |
-
tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
|
| 700 |
-
tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last)
|
| 701 |
-
|
| 702 |
-
dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty)
|
| 703 |
-
dx_ptr += (
|
| 704 |
-
pid_b * stride_dx_batch
|
| 705 |
-
+ pid_c * chunk_size * stride_dx_seqlen
|
| 706 |
-
+ pid_h * stride_dx_head
|
| 707 |
-
)
|
| 708 |
-
dx_ptrs = dx_ptr + (
|
| 709 |
-
offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim
|
| 710 |
-
)
|
| 711 |
-
tl.store(
|
| 712 |
-
dx_ptrs,
|
| 713 |
-
dx,
|
| 714 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
| 715 |
-
)
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128},
            num_stages=3,
            num_warps=4,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32},
            num_stages=3,
            num_warps=4,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128},
            num_stages=3,
            num_warps=4,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64},
            num_stages=3,
            num_warps=4,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64},
            num_stages=3,
            num_warps=4,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32},
            num_stages=3,
            num_warps=4,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64},
            num_stages=3,
            num_warps=4,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32},
            num_stages=3,
            num_warps=4,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
    ],
    key=["chunk_size", "dstate", "hdim"],
)
@triton.jit
def _chunk_state_bwd_db_kernel(
    # Pointers to matrices
    x_ptr,
    dstates_ptr,
    b_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    seq_idx_ptr,
    db_ptr,
    ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size,
    dstate,
    hdim,
    batch,
    seqlen,
    nheads,
    nheads_per_program,
    ngroups,
    # Strides
    stride_x_batch,
    stride_x_seqlen,
    stride_x_head,
    stride_x_hdim,
    stride_dstates_batch,
    stride_dstates_chunk,
    stride_states_head,
    stride_states_hdim,
    stride_states_dstate,
    stride_b_batch,
    stride_b_seqlen,
    stride_b_head,
    stride_b_dstate,
    stride_dt_batch,
    stride_dt_chunk,
    stride_dt_head,
    stride_dt_csize,
    stride_dA_cs_batch,
    stride_dA_cs_chunk,
    stride_dA_cs_head,
    stride_dA_cs_csize,
    stride_seq_idx_batch,
    stride_seq_idx_seqlen,
    stride_db_batch,
    stride_db_seqlen,
    stride_db_split,
    stride_db_group,
    stride_db_dstate,
    stride_ddA_cs_batch,
    stride_ddA_cs_chunk,
    stride_ddA_cs_head,
    stride_ddA_cs_csize,
    # Meta-parameters
    HAS_DDA_CS: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_sg = tl.program_id(axis=2)
    pid_s = pid_sg // ngroups
    pid_g = pid_sg - pid_s * ngroups
    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    x_ptr += (
        pid_b * stride_x_batch
        + pid_c * chunk_size * stride_x_seqlen
        + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head
    )
    db_ptr += (
        pid_b * stride_db_batch
        + pid_c * chunk_size * stride_db_seqlen
        + pid_g * stride_db_group
        + pid_s * stride_db_split
    )
    dstates_ptr += (
        pid_b * stride_dstates_batch
        + pid_c * stride_dstates_chunk
        + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
        * stride_states_head
    )
    dt_ptr += (
        pid_b * stride_dt_batch
        + pid_c * stride_dt_chunk
        + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head
    )
    dA_cumsum_ptr += (
        pid_b * stride_dA_cs_batch
        + pid_c * stride_dA_cs_chunk
        + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
    )
    if HAS_DDA_CS:
        b_ptr += (
            pid_b * stride_b_batch
            + pid_c * chunk_size * stride_b_seqlen
            + pid_g * stride_b_head
        )
        ddA_cumsum_ptr += (
            pid_b * stride_ddA_cs_batch
            + pid_c * stride_ddA_cs_chunk
            + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
            * stride_ddA_cs_head
        )
    if HAS_SEQ_IDX:
        seq_idx_ptr += (
            pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
        )

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    x_ptrs = x_ptr + (
        offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim
    )
    dstates_ptrs = dstates_ptr + (
        offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim
    )
    dt_ptrs = dt_ptr + offs_m * stride_dt_csize
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
    if HAS_DDA_CS:
        b_ptrs = b_ptr + (
            offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate
        )
        ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize

    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    if HAS_DDA_CS:
        b = tl.load(
            b_ptrs,
            mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
            other=0.0,
        ).to(tl.float32)
    if HAS_SEQ_IDX:
        seq_idx_m = tl.load(
            seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
            mask=offs_m < chunk_size_limit,
            other=-1,
        )
        seq_idx_last = tl.load(
            seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
        )
    nheads_iter = min(
        nheads_per_program, nheads // ngroups - pid_s * nheads_per_program
    )
    for h in range(nheads_iter):
        x = tl.load(
            x_ptrs,
            mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim),
            other=0.0,
        )
        dstates = tl.load(
            dstates_ptrs,
            mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate),
            other=0.0,
        )
        dstates = dstates.to(x_ptrs.dtype.element_ty)
        db = tl.dot(x, dstates)
        dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
            tl.float32
        )
        dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
            tl.float32
        )
        dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
        if not HAS_SEQ_IDX:
            scale = tl.exp(dA_cs_last - dA_cs_m)
        else:
            scale = tl.where(
                seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0
            )
        db *= (scale * dt_m)[:, None]
        if HAS_DDA_CS:
            # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
            ddA_cs = tl.sum(db * b, axis=1)
            tl.atomic_add(
                ddA_cumsum_ptrs + stride_ddA_cs_csize,
                ddA_cs,
                mask=offs_m < chunk_size - 1,
            )
        acc += db
        x_ptrs += stride_x_head
        dstates_ptrs += stride_states_head
        dt_ptrs += stride_dt_head
        dA_cumsum_ptr += stride_dA_cs_head
        dA_cumsum_ptrs += stride_dA_cs_head
        if HAS_DDA_CS:
            ddA_cumsum_ptrs += stride_ddA_cs_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # if HAS_SEQ_IDX:
    #     seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
    #     seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
    #     acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0)
    db_ptrs = db_ptr + (
        offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate
    )
    tl.store(
        db_ptrs,
        acc,
        mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
    )

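# [Editor's sketch - not part of the upstream file] Per head, the kernel above
# accumulates db[t, n] = dt[t] * exp(dA_cs[last] - dA_cs[t]) * sum_p x[t, p] *
# dstates[p, n], summed over the nheads_per_program heads each program owns. A
# single-chunk, single-head PyTorch equivalent (helper name assumed):
def _chunk_state_bwd_db_single_chunk_ref(x_c, dt_c, dA_cs_c, dstates_c):
    # x_c: (chunk, hdim), dt_c, dA_cs_c: (chunk,), dstates_c: (hdim, dstate)
    scale = torch.exp(dA_cs_c[-1] - dA_cs_c) * dt_c
    return (x_c @ dstates_c) * scale[:, None]  # (chunk, dstate)
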
@triton.autotune(
    configs=[
        # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config(
            {"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
            num_stages=3,
            num_warps=4,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=3,
            num_warps=4,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=3,
            num_warps=4,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=3,
            num_warps=4,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=8,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=8,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=8,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=8,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
    ],
    key=["chunk_size", "hdim", "dstate"],
)
@triton.jit
def _chunk_state_bwd_ddAcs_stable_kernel(
    # Pointers to matrices
    x_ptr,
    b_ptr,
    dstates_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    seq_idx_ptr,
    ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size,
    hdim,
    dstate,
    batch,
    seqlen,
    nheads_ngroups_ratio,
    # Strides
    stride_x_batch,
    stride_x_seqlen,
    stride_x_head,
    stride_x_hdim,
    stride_b_batch,
    stride_b_seqlen,
    stride_b_head,
    stride_b_dstate,
    stride_dstates_batch,
    stride_dstates_chunk,
    stride_states_head,
    stride_states_hdim,
    stride_states_dstate,
    stride_dt_batch,
    stride_dt_chunk,
    stride_dt_head,
    stride_dt_csize,
    stride_dA_cs_batch,
    stride_dA_cs_chunk,
    stride_dA_cs_head,
    stride_dA_cs_csize,
    stride_seq_idx_batch,
    stride_seq_idx_seqlen,
    stride_ddA_cs_batch,
    stride_ddA_cs_chunk,
    stride_ddA_cs_head,
    stride_ddA_cs_csize,
    # Meta-parameters
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    x_ptr += (
        pid_b * stride_x_batch
        + pid_c * chunk_size * stride_x_seqlen
        + pid_h * stride_x_head
    )
    b_ptr += (
        pid_b * stride_b_batch
        + pid_c * chunk_size * stride_b_seqlen
        + (pid_h // nheads_ngroups_ratio) * stride_b_head
    )
    dstates_ptr += (
        pid_b * stride_dstates_batch
        + pid_c * stride_dstates_chunk
        + pid_h * stride_states_head
    )
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    ddA_cumsum_ptr += (
        pid_b * stride_ddA_cs_batch
        + pid_c * stride_ddA_cs_chunk
        + pid_h * stride_ddA_cs_head
    )
    dA_cumsum_ptr += (
        pid_b * stride_dA_cs_batch
        + pid_c * stride_dA_cs_chunk
        + pid_h * stride_dA_cs_head
    )
    if HAS_SEQ_IDX:
        seq_idx_ptr += (
            pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
        )

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
    offs_k = tl.arange(
        0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
    )
    b_ptrs = b_ptr + (
        offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
    )
    dstates_ptrs = dstates_ptr + (
        offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
    )
    if BLOCK_SIZE_DSTATE <= 128:
        b = tl.load(
            b_ptrs,
            mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
            other=0.0,
        )
        dstates = tl.load(
            dstates_ptrs,
            mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
            other=0.0,
        )
        dstates = dstates.to(b_ptr.dtype.element_ty)
        acc = tl.dot(b, dstates)
    else:
        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for k in range(0, dstate, BLOCK_SIZE_K):
            b = tl.load(
                b_ptrs,
                mask=(offs_m[:, None] < chunk_size_limit)
                & (offs_k[None, :] < dstate - k),
                other=0.0,
            )
            dstates = tl.load(
                dstates_ptrs,
                mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
                other=0.0,
            )
            dstates = dstates.to(b_ptr.dtype.element_ty)
            acc += tl.dot(b, dstates)
            b_ptrs += BLOCK_SIZE_K * stride_b_dstate
            dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    dA_cs_m = tl.load(
        dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
    ).to(tl.float32)
    dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
        tl.float32
    )
    if not HAS_SEQ_IDX:
        scale = tl.exp(dA_cs_last - dA_cs_m)
    else:
        seq_idx_m = tl.load(
            seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
            mask=offs_m < chunk_size_limit,
            other=-1,
        )
        seq_idx_last = tl.load(
            seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
        )
        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
    acc *= scale[:, None]

    x_ptrs = x_ptr + (
        offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
    )
    x = tl.load(
        x_ptrs,
        mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
        other=0.0,
    ).to(tl.float32)
    dt_ptrs = dt_ptr + offs_m * stride_dt_csize
    dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
    ddt = tl.sum(acc * x, axis=1)
    # ddA_cs = -(ddt * dt_m)
    # Triton 2.2.0 errors if we have the cumsum here, so we just write it out
    # then call torch.cumsum outside this kernel.
    # ddA_cs = tl.cumsum(ddt * dt_m)
    ddA_cs = ddt * dt_m
    ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
    # tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
    tl.atomic_add(
        ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1
    )

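# [Editor's note] The kernel above writes the raw per-position terms ddt[t] * dt[t],
# shifted right by one slot, instead of their running sum: per its inline comment,
# tl.cumsum errors on Triton 2.2.0, so the host-side wrapper finishes with
# torch.cumsum (see _chunk_state_bwd_ddAcs_stable below). A shape-level equivalent
# of the two-step pattern, with assumed names:
#     raw = ddt * dt                                       # (..., chunk_size)
#     ddA = torch.zeros_like(raw)
#     ddA[..., 1:] = torch.cumsum(raw[..., :-1], dim=-1)   # exclusive cumsum
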
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=2,
        ),
    ],
    key=["hdim", "dstate", "chunk_size"],
)
@triton.jit
def _chunk_state_varlen_kernel(
    # Pointers to matrices
    x_ptr,
    b_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    chunk_states_ptr,
    cu_seqlens_ptr,
    states_ptr,
    # Matrix dimensions
    hdim,
    dstate,
    chunk_size,
    seqlen,
    nheads_ngroups_ratio,
    # Strides
    stride_x_seqlen,
    stride_x_head,
    stride_x_hdim,
    stride_b_seqlen,
    stride_b_head,
    stride_b_dstate,
    stride_dt_chunk,
    stride_dt_head,
    stride_dt_csize,
    stride_dA_cs_chunk,
    stride_dA_cs_head,
    stride_dA_cs_csize,
    stride_chunk_states_chunk,
    stride_chunk_states_head,
    stride_chunk_states_hdim,
    stride_chunk_states_dstate,
    stride_states_batch,
    stride_states_head,
    stride_states_hdim,
    stride_states_dstate,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
    pid_c = (end_idx - 1) // chunk_size
    b_ptr += (
        pid_c * chunk_size * stride_b_seqlen
        + (pid_h // nheads_ngroups_ratio) * stride_b_head
    )
    x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    chunk_states_ptr += (
        pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
    )

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    x_ptrs = x_ptr + (
        offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
    )
    b_ptrs = b_ptr + (
        offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
    )
    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
    dA_cs_last = tl.load(
        dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
    ).to(tl.float32)
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize

    chunk_size_limit = end_idx - pid_c * chunk_size
    start_idx = tl.load(cu_seqlens_ptr + pid_b)
    start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
        x = tl.load(
            x_ptrs,
            mask=(offs_m[:, None] < hdim)
            & (offs_k[None, :] < chunk_size_limit - k)
            & (offs_k[None, :] >= start_idx_cur - k),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < chunk_size_limit - k)
            & (offs_n[None, :] < dstate)
            & (offs_k[:, None] >= start_idx_cur - k),
            other=0.0,
        ).to(tl.float32)
        dA_cs_k = tl.load(
            dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
        ).to(tl.float32)
        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
            tl.float32
        )
        scale = tl.where(
            (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
            tl.exp(dA_cs_last - dA_cs_k) * dt_k,
            0.0,
        )
        b *= scale[:, None]
        b = b.to(x_ptr.dtype.element_ty)
        acc += tl.dot(x, b)
        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
        b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

    # Only add the state carried over from earlier chunks when this sequence started
    # before the current (last) chunk; if it started inside this chunk, the loop
    # above already covered the whole sequence.
    if start_idx < pid_c * chunk_size:
        chunk_states_ptrs = chunk_states_ptr + (
            offs_m[:, None] * stride_chunk_states_hdim
            + offs_n[None, :] * stride_chunk_states_dstate
        )
        chunk_states = tl.load(
            chunk_states_ptrs,
            mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),
            other=0.0,
        ).to(tl.float32)
        # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0)
        scale = tl.exp(dA_cs_last)
        acc += chunk_states * scale

    states = acc.to(states_ptr.dtype.element_ty)

    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    states_ptrs = states_ptr + (
        offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
    )
    c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
    tl.store(states_ptrs, states, mask=c_mask)

def _chunk_cumsum_fwd(
    dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))
):
    batch, seqlen, nheads = dt.shape
    assert A.shape == (nheads,)
    if dt_bias is not None:
        assert dt_bias.shape == (nheads,)
    nchunks = math.ceil(seqlen / chunk_size)
    dt_out = torch.empty(
        batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
    )
    dA_cumsum = torch.empty(
        batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
    )
    grid_chunk_cs = lambda META: (
        batch,
        nchunks,
        triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
    )
    with torch.cuda.device(dt.device.index):
        _chunk_cumsum_fwd_kernel[grid_chunk_cs](
            dt,
            A,
            dt_bias,
            dt_out,
            dA_cumsum,
            batch,
            seqlen,
            nheads,
            chunk_size,
            dt_limit[0],
            dt_limit[1],
            dt.stride(0),
            dt.stride(1),
            dt.stride(2),
            A.stride(0),
            dt_bias.stride(0) if dt_bias is not None else 0,
            dt_out.stride(0),
            dt_out.stride(2),
            dt_out.stride(1),
            dt_out.stride(3),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(3),
            dt_softplus,
            HAS_DT_BIAS=dt_bias is not None,
            BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
        )
    return dA_cumsum, dt_out

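# [Editor's sketch - not part of the upstream file] A plain PyTorch reference for
# what _chunk_cumsum_fwd returns, assuming the kernel applies the bias, then the
# optional softplus, then clamping to dt_limit, before a per-chunk cumulative sum
# of dt * A. The helper name and exact op order are assumptions.
def _chunk_cumsum_ref(dt, A, chunk_size, dt_bias=None, dt_softplus=False,
                      dt_limit=(0.0, float("inf"))):
    batch, seqlen, nheads = dt.shape
    nchunks = math.ceil(seqlen / chunk_size)
    dt = dt.float()
    if dt_bias is not None:
        dt = dt + dt_bias
    if dt_softplus:
        dt = F.softplus(dt)
    dt = dt.clamp(min=dt_limit[0], max=dt_limit[1])
    # Zero-pad so padded positions contribute nothing to the cumsum
    dt = F.pad(dt, (0, 0, 0, nchunks * chunk_size - seqlen))
    # (batch, nheads, nchunks, chunk_size), matching the kernel's output layout
    dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
    dA_cumsum = torch.cumsum(dt * A[None, :, None, None], dim=-1)
    return dA_cumsum, dt
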
def _chunk_cumsum_bwd(
    ddA,
    ddt_out,
    dt,
    A,
    dt_bias=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    ddt=None,
):
    batch, seqlen, nheads = dt.shape
    _, _, nchunks, chunk_size = ddA.shape
    assert ddA.shape == (batch, nheads, nchunks, chunk_size)
    assert ddt_out.shape == (batch, nheads, nchunks, chunk_size)
    assert A.shape == (nheads,)
    if dt_bias is not None:
        assert dt_bias.shape == (nheads,)
        ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)
    else:
        ddt_bias = None
    if ddt is not None:
        assert ddt.shape == dt.shape
    else:
        ddt = torch.empty_like(dt)
    dA = torch.empty_like(A, dtype=torch.float32)
    grid_chunk_cs = lambda META: (
        batch,
        nchunks,
        triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
    )
    with torch.cuda.device(dt.device.index):
        _chunk_cumsum_bwd_kernel[grid_chunk_cs](
            ddA,
            ddt_out,
            dt,
            A,
            dt_bias,
            ddt,
            dA,
            ddt_bias,
            batch,
            seqlen,
            nheads,
            chunk_size,
            dt_limit[0],
            dt_limit[1],
            ddA.stride(0),
            ddA.stride(2),
            ddA.stride(1),
            ddA.stride(3),
            ddt_out.stride(0),
            ddt_out.stride(2),
            ddt_out.stride(1),
            ddt_out.stride(3),
            dt.stride(0),
            dt.stride(1),
            dt.stride(2),
            A.stride(0),
            dt_bias.stride(0) if dt_bias is not None else 0,
            ddt.stride(0),
            ddt.stride(1),
            ddt.stride(2),
            dA.stride(0),
            ddt_bias.stride(0) if ddt_bias is not None else 0,
            dt_softplus,
            HAS_DT_BIAS=dt_bias is not None,
            BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
        )
    return ddt, dA, ddt_bias

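# [Editor's note, hypothetical snippet] _chunk_cumsum_bwd can be sanity-checked
# against autograd through the _chunk_cumsum_ref sketch above:
#     dt = torch.randn(2, 128, 8, device="cuda", requires_grad=True)
#     A = torch.randn(8, device="cuda", requires_grad=True)
#     dA_cs, dt_out = _chunk_cumsum_ref(dt, A, 64, dt_softplus=True)
#     (dA_cs.sum() + dt_out.sum()).backward()
#     # dt.grad and A.grad should then match ddt and dA from
#     # _chunk_cumsum_bwd(torch.ones_like(dA_cs), torch.ones_like(dt_out),
#     #                   dt, A, dt_softplus=True)
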
def _chunk_state_fwd(
    B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True
):
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if states is not None:
        assert states.shape == (batch, nchunks, nheads, headdim, dstate)
    else:
        states_dtype = torch.float32 if states_in_fp32 else B.dtype
        states = torch.empty(
            (batch, nchunks, nheads, headdim, dstate),
            device=x.device,
            dtype=states_dtype,
        )
    grid = lambda META: (
        triton.cdiv(headdim, META["BLOCK_SIZE_M"])
        * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
        batch * nchunks,
        nheads,
    )
    with torch.cuda.device(x.device.index):
        _chunk_state_fwd_kernel[grid](
            x,
            B,
            states,
            dt,
            dA_cumsum,
            seq_idx,
            headdim,
            dstate,
            chunk_size,
            batch,
            seqlen,
            nheads // ngroups,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            x.stride(3),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            B.stride(-1),
            states.stride(0),
            states.stride(1),
            states.stride(2),
            states.stride(3),
            states.stride(4),
            dt.stride(0),
            dt.stride(2),
            dt.stride(1),
            dt.stride(3),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(3),
            *(
                (seq_idx.stride(0), seq_idx.stride(1))
                if seq_idx is not None
                else (0, 0)
            ),
            HAS_SEQ_IDX=seq_idx is not None,
        )
    return states

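# [Editor's note] Assumed calling convention for the launcher above:
#     B:  (batch, seqlen, ngroups, dstate)
#     x:  (batch, seqlen, nheads, headdim), with nheads % ngroups == 0
#     dt, dA_cumsum: (batch, nheads, nchunks, chunk_size)
#     states = _chunk_state_fwd(B, x, dt, dA_cumsum)
#     # -> (batch, nchunks, nheads, headdim, dstate), fp32 by default
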
def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None):
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
    if dx is not None:
        assert dx.shape == x.shape
    else:
        dx = torch.empty_like(x)
    ddt = torch.empty(
        batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
    )
    ddA_cumsum = torch.empty(
        batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32
    )
    grid_dx = lambda META: (
        triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
        * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
        batch * nchunks,
        nheads,
    )
    with torch.cuda.device(x.device.index):
        _chunk_state_bwd_dx_kernel[grid_dx](
            x,
            B,
            dstates,
            dt,
            dA_cumsum,
            dx,
            ddt,
            ddA_cumsum,
            chunk_size,
            headdim,
            dstate,
            batch,
            seqlen,
            nheads // ngroups,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            x.stride(3),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            B.stride(-1),
            dstates.stride(0),
            dstates.stride(1),
            dstates.stride(2),
            dstates.stride(3),
            dstates.stride(4),
            dt.stride(0),
            dt.stride(2),
            dt.stride(1),
            dt.stride(3),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(3),
            dx.stride(0),
            dx.stride(1),
            dx.stride(2),
            dx.stride(3),
            ddt.stride(0),
            ddt.stride(2),
            ddt.stride(1),
            ddt.stride(3),
            ddA_cumsum.stride(0),
            ddA_cumsum.stride(2),
            ddA_cumsum.stride(1),
            ddA_cumsum.stride(3),
            BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
        )
    return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype)

def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1):
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    dstate = dstates.shape[-1]
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if B is not None:
        assert B.shape == (batch, seqlen, ngroups, dstate)
        B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3))
        # Use torch.empty since the Triton kernel will call init_to_zero
        ddA_cumsum = torch.empty(
            batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
        )
        ddA_cumsum_strides = (
            ddA_cumsum.stride(0),
            ddA_cumsum.stride(2),
            ddA_cumsum.stride(1),
            ddA_cumsum.stride(3),
        )
    else:
        B_strides = (0, 0, 0, 0)
        ddA_cumsum = None
        ddA_cumsum_strides = (0, 0, 0, 0)
    nheads_ngroups_ratio = nheads // ngroups
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    nheads_per_program = max(
        min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1
    )
    nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
    dB = torch.empty(
        batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32
    )
    grid_db = lambda META: (
        triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
        * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
        batch * nchunks,
        nsplits * ngroups,
    )
    with torch.cuda.device(x.device.index):
        _chunk_state_bwd_db_kernel[grid_db](
            x,
            dstates,
            B,
            dt,
            dA_cumsum,
            seq_idx,
            dB,
            ddA_cumsum,
            chunk_size,
            dstate,
            headdim,
            batch,
            seqlen,
            nheads,
            nheads_per_program,
            ngroups,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            x.stride(3),
            dstates.stride(0),
            dstates.stride(1),
            dstates.stride(2),
            dstates.stride(3),
            dstates.stride(4),
            *B_strides,
            dt.stride(0),
            dt.stride(2),
            dt.stride(1),
            dt.stride(3),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(3),
            *(
                (seq_idx.stride(0), seq_idx.stride(1))
                if seq_idx is not None
                else (0, 0)
            ),
            dB.stride(0),
            dB.stride(1),
            dB.stride(2),
            dB.stride(3),
            dB.stride(4),
            *ddA_cumsum_strides,
            HAS_DDA_CS=ddA_cumsum is not None,
            HAS_SEQ_IDX=seq_idx is not None,
            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
        )
    dB = dB.sum(2)
    if ddA_cumsum is not None:
        # The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute
        # to the state of the chunk.
        # torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
        # But it's easier to just do the cumsum for all elements, the result will be the same.
        torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum)
    return dB if B is None else (dB, ddA_cumsum)

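# [Editor's note - editorial reading of the launch logic above, not an upstream
# comment] nheads_per_program is sized so that roughly batch * nchunks * nheads
# programs cover the available SMs; the partial dB results carry an extra `nsplits`
# axis and are reduced on the host with dB.sum(2). This trades a
# (batch, seqlen, nsplits, ngroups, dstate) fp32 buffer for extra parallelism when
# ngroups is small relative to nheads.
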
def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    # Use torch.empty since the Triton kernel will call init_to_zero
    ddA_cumsum = torch.empty(
        batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
    )
    grid_ddtcs = lambda META: (
        triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
        * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
        batch * nchunks,
        nheads,
    )
    with torch.cuda.device(x.device.index):
        _chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs](
            x,
            B,
            dstates,
            dt,
            dA_cumsum,
            seq_idx,
            ddA_cumsum,
            chunk_size,
            headdim,
            dstate,
            batch,
            seqlen,
            nheads // ngroups,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            x.stride(3),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            B.stride(-1),
            dstates.stride(0),
            dstates.stride(1),
            dstates.stride(2),
            dstates.stride(3),
            dstates.stride(4),
            dt.stride(0),
            dt.stride(2),
            dt.stride(1),
            dt.stride(3),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(3),
            *(
                (seq_idx.stride(0), seq_idx.stride(1))
                if seq_idx is not None
                else (0, 0)
            ),
            ddA_cumsum.stride(0),
            ddA_cumsum.stride(2),
            ddA_cumsum.stride(1),
            ddA_cumsum.stride(3),
            HAS_SEQ_IDX=seq_idx is not None,
            BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16),
            BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
        )
    torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
    return ddA_cumsum

def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):
    total_seqlen, nheads, headdim = x.shape
    _, nchunks, chunk_size = dt.shape
    _, ngroups, dstate = B.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    assert nheads % ngroups == 0
    assert B.shape == (total_seqlen, ngroups, dstate)
    assert dt.shape == (nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
    states = torch.empty(
        batch,
        nheads,
        headdim,
        dstate,
        dtype=chunk_states.dtype,
        device=chunk_states.device,
    )
    grid = lambda META: (
        triton.cdiv(headdim, META["BLOCK_SIZE_M"])
        * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
        batch,
        nheads,
    )
    with torch.cuda.device(x.device.index):
        _chunk_state_varlen_kernel[grid](
            x,
            B,
            dt,
            dA_cumsum,
            chunk_states,
            cu_seqlens,
            states,
            headdim,
            dstate,
            chunk_size,
            total_seqlen,
            nheads // ngroups,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            dt.stride(1),
            dt.stride(0),
            dt.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            chunk_states.stride(0),
            chunk_states.stride(1),
            chunk_states.stride(2),
            chunk_states.stride(3),
            states.stride(0),
            states.stride(1),
            states.stride(2),
            states.stride(3),
        )
    return states

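# [Editor's example - hypothetical values] chunk_state_varlen recomputes the final
# state of every packed sequence from per-chunk tensors of the packed batch:
#     cu_seqlens = torch.tensor([0, 5, 12], dtype=torch.int32, device="cuda")  # 2 seqs
#     states = chunk_state_varlen(
#         B.squeeze(0), x.squeeze(0), dt.squeeze(0), dA_cumsum.squeeze(0),
#         cu_seqlens, chunk_states.squeeze(0),
#     )   # -> (2, nheads, headdim, dstate), one final state per sequence
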
class ChunkStateFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True):
        batch, seqlen, nheads, headdim = x.shape
        _, _, nchunks, chunk_size = dt.shape
        assert seqlen <= nchunks * chunk_size
        _, _, ngroups, dstate = B.shape
        assert B.shape == (batch, seqlen, ngroups, dstate)
        assert dt.shape == (batch, nheads, nchunks, chunk_size)
        assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
        if B.stride(-1) != 1:
            B = B.contiguous()
        if (
            x.stride(-1) != 1 and x.stride(1) != 1
        ):  # Either M or K dimension should be contiguous
            x = x.contiguous()
        states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32)
        ctx.save_for_backward(B, x, dt, dA_cumsum)
        return states

    @staticmethod
    def backward(ctx, dstates):
        B, x, dt, dA_cumsum = ctx.saved_tensors
        batch, seqlen, nheads, headdim = x.shape
        _, _, nchunks, chunk_size = dt.shape
        _, _, ngroups, dstate = B.shape
        assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
        if dstates.stride(-1) != 1:
            dstates = dstates.contiguous()
        dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates)
        dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups)
        dB = dB.to(B.dtype)
        return dB, dx, ddt, ddA_cumsum, None

def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True):
    """
    Argument:
        B: (batch, seqlen, ngroups, dstate)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
    Return:
        states: (batch, nchunks, nheads, headdim, dstate)
    """
    return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32)

def chunk_state_ref(B, x, dt, dA_cumsum):
    """
    Argument:
        B: (batch, seqlen, ngroups, dstate)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
    Return:
        states: (batch, nchunks, nheads, headdim, dstate)
    """
    # Check constraints.
    batch, seqlen, nheads, headdim = x.shape
    dstate = B.shape[-1]
    _, _, nchunks, chunk_size = dt.shape
    assert seqlen <= nchunks * chunk_size
    assert x.shape == (batch, seqlen, nheads, headdim)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    ngroups = B.shape[2]
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
    if seqlen < nchunks * chunk_size:
        x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
        B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
    x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
    B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
    decay_states = torch.exp(dA_cumsum[:, :, :, -1:] - dA_cumsum)
    return torch.einsum(
        "bclhn,bhcl,bhcl,bclhp->bchpn",
        B.to(x.dtype),
        decay_states.to(x.dtype),
        dt.to(x.dtype),
        x,
    )
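# [Editor's sketch - not part of the upstream file] A quick numerical check of the
# Triton path against the einsum reference above; shapes, inputs, and tolerances
# are assumptions for fp32 data.
def _check_chunk_state(batch=2, seqlen=256, nheads=8, headdim=64, ngroups=2,
                       dstate=16, chunk_size=64, device="cuda"):
    nchunks = seqlen // chunk_size
    B = torch.randn(batch, seqlen, ngroups, dstate, device=device)
    x = torch.randn(batch, seqlen, nheads, headdim, device=device)
    dt = torch.rand(batch, nheads, nchunks, chunk_size, device=device)
    dA_cumsum = torch.cumsum(-dt, dim=-1)  # any decreasing cumulative sum will do
    out = chunk_state(B, x, dt, dA_cumsum)
    ref = chunk_state_ref(B, x, dt, dA_cumsum)
    assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)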
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/triton/ssd_combined.py
DELETED
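
The deleted module implements the fused SSD chunked scan used by Mamba-2: the sequence is split into chunks, a state is computed per chunk, states are propagated across chunks by a weighted cumulative sum, and per-chunk outputs are then computed from the propagated states (the three numbered steps in mamba_chunk_scan below). Its public entry point is mamba_chunk_scan_combined. As a reference for what is being removed, here is a minimal usage sketch; the shape relationships come from the function's docstring, while the concrete sizes and the sign convention on A (negative, so the state decays) are illustrative assumptions, not taken from this build:

import torch
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined

# Illustrative sizes; only the shape relationships are prescribed by the docstring.
batch, seqlen, nheads, headdim = 2, 512, 8, 64
ngroups, dstate, chunk_size = 1, 128, 256
device, dtype = "cuda", torch.bfloat16

x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
dt = torch.rand(batch, seqlen, nheads, device=device, dtype=dtype)  # step sizes, positive after softplus
A = -torch.rand(nheads, device=device, dtype=torch.float32)  # assumed negative per-head decay
B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype)
C = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype)

out = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, dt_softplus=True)
assert out.shape == (batch, seqlen, nheads, headdim)

The variable-length path goes through seq_idx, cu_seqlens, and return_varlen_states; note the assertion in _mamba_chunk_scan_combined_fwd that cu_seqlens is only supported when the batch dimension is 1.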
|
@@ -1,1884 +0,0 @@
|
|
| 1 |
-
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
| 2 |
-
|
| 3 |
-
"""We want triton==2.1.0 or 2.2.0 for this
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
from typing import Optional
|
| 7 |
-
|
| 8 |
-
import math
|
| 9 |
-
from packaging import version
|
| 10 |
-
|
| 11 |
-
import torch
|
| 12 |
-
import torch.nn.functional as F
|
| 13 |
-
from torch import Tensor
|
| 14 |
-
from ...utils.torch import custom_bwd, custom_fwd
|
| 15 |
-
|
| 16 |
-
import triton
|
| 17 |
-
import triton.language as tl
|
| 18 |
-
|
| 19 |
-
from einops import rearrange, repeat
|
| 20 |
-
|
| 21 |
-
try:
|
| 22 |
-
from causal_conv1d import causal_conv1d_fn
|
| 23 |
-
import causal_conv1d_cuda
|
| 24 |
-
except ImportError:
|
| 25 |
-
causal_conv1d_fn, causal_conv1d_cuda = None, None
|
| 26 |
-
|
| 27 |
-
from .ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
|
| 28 |
-
from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
|
| 29 |
-
from .ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
|
| 30 |
-
from .ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
|
| 31 |
-
from .ssd_chunk_state import chunk_state, chunk_state_ref
|
| 32 |
-
from .ssd_chunk_state import chunk_state_varlen
|
| 33 |
-
from .ssd_state_passing import _state_passing_fwd, _state_passing_bwd
|
| 34 |
-
from .ssd_state_passing import state_passing, state_passing_ref
|
| 35 |
-
from .ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates
|
| 36 |
-
from .ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
|
| 37 |
-
from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
|
| 38 |
-
from .ssd_chunk_scan import chunk_scan, chunk_scan_ref
|
| 39 |
-
from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
|
| 40 |
-
from .layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
|
| 41 |
-
from .k_activations import _swiglu_fwd, _swiglu_bwd
|
| 42 |
-
|
| 43 |
-
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
def init_to_zero(names):
|
| 47 |
-
return lambda nargs: [
|
| 48 |
-
nargs[name].zero_() for name in names if nargs[name] is not None
|
| 49 |
-
]
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
@triton.autotune(
|
| 53 |
-
configs=[
|
| 54 |
-
triton.Config(
|
| 55 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
| 56 |
-
num_stages=3,
|
| 57 |
-
num_warps=8,
|
| 58 |
-
pre_hook=init_to_zero(["ddt_ptr"]),
|
| 59 |
-
),
|
| 60 |
-
triton.Config(
|
| 61 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
| 62 |
-
num_stages=4,
|
| 63 |
-
num_warps=4,
|
| 64 |
-
pre_hook=init_to_zero(["ddt_ptr"]),
|
| 65 |
-
),
|
| 66 |
-
triton.Config(
|
| 67 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
| 68 |
-
num_stages=4,
|
| 69 |
-
num_warps=4,
|
| 70 |
-
pre_hook=init_to_zero(["ddt_ptr"]),
|
| 71 |
-
),
|
| 72 |
-
triton.Config(
|
| 73 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 74 |
-
num_stages=4,
|
| 75 |
-
num_warps=4,
|
| 76 |
-
pre_hook=init_to_zero(["ddt_ptr"]),
|
| 77 |
-
),
|
| 78 |
-
triton.Config(
|
| 79 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
| 80 |
-
num_stages=4,
|
| 81 |
-
num_warps=4,
|
| 82 |
-
pre_hook=init_to_zero(["ddt_ptr"]),
|
| 83 |
-
),
|
| 84 |
-
triton.Config(
|
| 85 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
| 86 |
-
num_stages=4,
|
| 87 |
-
num_warps=4,
|
| 88 |
-
pre_hook=init_to_zero(["ddt_ptr"]),
|
| 89 |
-
),
|
| 90 |
-
triton.Config(
|
| 91 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
| 92 |
-
num_stages=5,
|
| 93 |
-
num_warps=4,
|
| 94 |
-
pre_hook=init_to_zero(["ddt_ptr"]),
|
| 95 |
-
),
|
| 96 |
-
triton.Config(
|
| 97 |
-
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 98 |
-
num_stages=5,
|
| 99 |
-
num_warps=4,
|
| 100 |
-
pre_hook=init_to_zero(["ddt_ptr"]),
|
| 101 |
-
),
|
| 102 |
-
triton.Config(
|
| 103 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 104 |
-
num_stages=4,
|
| 105 |
-
num_warps=4,
|
| 106 |
-
pre_hook=init_to_zero(["ddt_ptr"]),
|
| 107 |
-
),
|
| 108 |
-
],
|
| 109 |
-
key=["chunk_size", "hdim", "dstate"],
|
| 110 |
-
)
|
| 111 |
-
@triton.jit
|
| 112 |
-
def _chunk_scan_chunk_state_bwd_dx_kernel(
|
| 113 |
-
# Pointers to matrices
|
| 114 |
-
x_ptr,
|
| 115 |
-
cb_ptr,
|
| 116 |
-
dout_ptr,
|
| 117 |
-
dt_ptr,
|
| 118 |
-
dA_cumsum_ptr,
|
| 119 |
-
seq_idx_ptr,
|
| 120 |
-
D_ptr,
|
| 121 |
-
b_ptr,
|
| 122 |
-
dstates_ptr,
|
| 123 |
-
dx_ptr,
|
| 124 |
-
ddt_ptr,
|
| 125 |
-
dD_ptr,
|
| 126 |
-
# Matrix dimensions
|
| 127 |
-
chunk_size,
|
| 128 |
-
hdim,
|
| 129 |
-
dstate,
|
| 130 |
-
batch,
|
| 131 |
-
seqlen,
|
| 132 |
-
nheads_ngroups_ratio,
|
| 133 |
-
# Strides
|
| 134 |
-
stride_x_batch,
|
| 135 |
-
stride_x_seqlen,
|
| 136 |
-
stride_x_head,
|
| 137 |
-
stride_x_hdim,
|
| 138 |
-
stride_cb_batch,
|
| 139 |
-
stride_cb_chunk,
|
| 140 |
-
stride_cb_head,
|
| 141 |
-
stride_cb_csize_m,
|
| 142 |
-
stride_cb_csize_k,
|
| 143 |
-
stride_dout_batch,
|
| 144 |
-
stride_dout_seqlen,
|
| 145 |
-
stride_dout_head,
|
| 146 |
-
stride_dout_hdim,
|
| 147 |
-
stride_dt_batch,
|
| 148 |
-
stride_dt_chunk,
|
| 149 |
-
stride_dt_head,
|
| 150 |
-
stride_dt_csize,
|
| 151 |
-
stride_dA_cs_batch,
|
| 152 |
-
stride_dA_cs_chunk,
|
| 153 |
-
stride_dA_cs_head,
|
| 154 |
-
stride_dA_cs_csize,
|
| 155 |
-
stride_seq_idx_batch,
|
| 156 |
-
stride_seq_idx_seqlen,
|
| 157 |
-
stride_D_head,
|
| 158 |
-
stride_b_batch,
|
| 159 |
-
stride_b_seqlen,
|
| 160 |
-
stride_b_head,
|
| 161 |
-
stride_b_dstate,
|
| 162 |
-
stride_dstates_batch,
|
| 163 |
-
stride_dstates_chunk,
|
| 164 |
-
stride_dstates_head,
|
| 165 |
-
stride_dstates_hdim,
|
| 166 |
-
stride_dstates_dstate,
|
| 167 |
-
stride_dx_batch,
|
| 168 |
-
stride_dx_seqlen,
|
| 169 |
-
stride_dx_head,
|
| 170 |
-
stride_dx_hdim,
|
| 171 |
-
stride_ddt_batch,
|
| 172 |
-
stride_ddt_chunk,
|
| 173 |
-
stride_ddt_head,
|
| 174 |
-
stride_ddt_csize,
|
| 175 |
-
stride_dD_batch,
|
| 176 |
-
stride_dD_chunk,
|
| 177 |
-
stride_dD_head,
|
| 178 |
-
stride_dD_csize,
|
| 179 |
-
stride_dD_hdim,
|
| 180 |
-
# Meta-parameters
|
| 181 |
-
HAS_D: tl.constexpr,
|
| 182 |
-
D_HAS_HDIM: tl.constexpr,
|
| 183 |
-
HAS_SEQ_IDX: tl.constexpr,
|
| 184 |
-
BLOCK_SIZE_M: tl.constexpr,
|
| 185 |
-
BLOCK_SIZE_N: tl.constexpr,
|
| 186 |
-
BLOCK_SIZE_K: tl.constexpr,
|
| 187 |
-
BLOCK_SIZE_DSTATE: tl.constexpr,
|
| 188 |
-
IS_TRITON_22: tl.constexpr,
|
| 189 |
-
):
|
| 190 |
-
pid_bc = tl.program_id(axis=1)
|
| 191 |
-
pid_c = pid_bc // batch
|
| 192 |
-
pid_b = pid_bc - pid_c * batch
|
| 193 |
-
pid_h = tl.program_id(axis=2)
|
| 194 |
-
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
| 195 |
-
pid_m = tl.program_id(axis=0) // num_pid_n
|
| 196 |
-
pid_n = tl.program_id(axis=0) % num_pid_n
|
| 197 |
-
x_ptr += (
|
| 198 |
-
pid_b * stride_x_batch
|
| 199 |
-
+ pid_c * chunk_size * stride_x_seqlen
|
| 200 |
-
+ pid_h * stride_x_head
|
| 201 |
-
)
|
| 202 |
-
cb_ptr += (
|
| 203 |
-
pid_b * stride_cb_batch
|
| 204 |
-
+ pid_c * stride_cb_chunk
|
| 205 |
-
+ (pid_h // nheads_ngroups_ratio) * stride_cb_head
|
| 206 |
-
)
|
| 207 |
-
dout_ptr += (
|
| 208 |
-
pid_b * stride_dout_batch
|
| 209 |
-
+ pid_c * chunk_size * stride_dout_seqlen
|
| 210 |
-
+ pid_h * stride_dout_head
|
| 211 |
-
)
|
| 212 |
-
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
| 213 |
-
ddt_ptr += (
|
| 214 |
-
pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
|
| 215 |
-
)
|
| 216 |
-
dA_cumsum_ptr += (
|
| 217 |
-
pid_b * stride_dA_cs_batch
|
| 218 |
-
+ pid_c * stride_dA_cs_chunk
|
| 219 |
-
+ pid_h * stride_dA_cs_head
|
| 220 |
-
)
|
| 221 |
-
b_ptr += (
|
| 222 |
-
pid_b * stride_b_batch
|
| 223 |
-
+ pid_c * chunk_size * stride_b_seqlen
|
| 224 |
-
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
| 225 |
-
)
|
| 226 |
-
dstates_ptr += (
|
| 227 |
-
pid_b * stride_dstates_batch
|
| 228 |
-
+ pid_c * stride_dstates_chunk
|
| 229 |
-
+ pid_h * stride_dstates_head
|
| 230 |
-
)
|
| 231 |
-
if HAS_SEQ_IDX:
|
| 232 |
-
seq_idx_ptr += (
|
| 233 |
-
pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
| 234 |
-
)
|
| 235 |
-
|
| 236 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 237 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 238 |
-
|
| 239 |
-
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
| 240 |
-
|
| 241 |
-
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
| 242 |
-
|
| 243 |
-
dA_cs_m = tl.load(
|
| 244 |
-
dA_cumsum_ptr + offs_m * stride_dA_cs_csize,
|
| 245 |
-
mask=offs_m < chunk_size_limit,
|
| 246 |
-
other=0.0,
|
| 247 |
-
).to(tl.float32)
|
| 248 |
-
|
| 249 |
-
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
|
| 250 |
-
tl.float32
|
| 251 |
-
)
|
| 252 |
-
if not HAS_SEQ_IDX:
|
| 253 |
-
scale = tl.exp(dA_cs_last - dA_cs_m)
|
| 254 |
-
else:
|
| 255 |
-
seq_idx_m = tl.load(
|
| 256 |
-
seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
|
| 257 |
-
mask=offs_m < chunk_size_limit,
|
| 258 |
-
other=-1,
|
| 259 |
-
)
|
| 260 |
-
seq_idx_last = tl.load(
|
| 261 |
-
seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
|
| 262 |
-
)
|
| 263 |
-
scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
|
| 264 |
-
# Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
| 265 |
-
# However, we're getting error with the Triton compiler 2.1.0 for that code path:
|
| 266 |
-
# Unexpected mma -> mma layout conversion
|
| 267 |
-
# Triton 2.2.0 fixes this
|
| 268 |
-
offs_dstate = tl.arange(
|
| 269 |
-
0,
|
| 270 |
-
(
|
| 271 |
-
BLOCK_SIZE_DSTATE
|
| 272 |
-
if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128
|
| 273 |
-
else BLOCK_SIZE_K
|
| 274 |
-
),
|
| 275 |
-
)
|
| 276 |
-
b_ptrs = b_ptr + (
|
| 277 |
-
offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate
|
| 278 |
-
)
|
| 279 |
-
dstates_ptrs = dstates_ptr + (
|
| 280 |
-
offs_n[None, :] * stride_dstates_hdim
|
| 281 |
-
+ offs_dstate[:, None] * stride_dstates_dstate
|
| 282 |
-
)
|
| 283 |
-
if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:
|
| 284 |
-
b = tl.load(
|
| 285 |
-
b_ptrs,
|
| 286 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate),
|
| 287 |
-
other=0.0,
|
| 288 |
-
)
|
| 289 |
-
dstates = tl.load(
|
| 290 |
-
dstates_ptrs,
|
| 291 |
-
mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
|
| 292 |
-
other=0.0,
|
| 293 |
-
)
|
| 294 |
-
dstates = dstates.to(b_ptr.dtype.element_ty)
|
| 295 |
-
acc = tl.dot(b, dstates) * scale[:, None]
|
| 296 |
-
else:
|
| 297 |
-
for k in range(0, dstate, BLOCK_SIZE_K):
|
| 298 |
-
b = tl.load(
|
| 299 |
-
b_ptrs,
|
| 300 |
-
mask=(offs_m[:, None] < chunk_size_limit)
|
| 301 |
-
& (offs_dstate[None, :] < dstate - k),
|
| 302 |
-
other=0.0,
|
| 303 |
-
)
|
| 304 |
-
dstates = tl.load(
|
| 305 |
-
dstates_ptrs,
|
| 306 |
-
mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim),
|
| 307 |
-
other=0.0,
|
| 308 |
-
)
|
| 309 |
-
dstates = dstates.to(b_ptr.dtype.element_ty)
|
| 310 |
-
acc += tl.dot(b, dstates)
|
| 311 |
-
b_ptrs += BLOCK_SIZE_K * stride_b_dstate
|
| 312 |
-
dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate
|
| 313 |
-
acc *= scale[:, None]
|
| 314 |
-
|
| 315 |
-
# x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
|
| 316 |
-
# x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
|
| 317 |
-
# dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
| 318 |
-
# dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
|
| 319 |
-
# ddt = tl.sum(acc * x, axis=1) * dt_m
|
| 320 |
-
# ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
|
| 321 |
-
# tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
|
| 322 |
-
|
| 323 |
-
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
| 324 |
-
cb_ptrs = cb_ptr + (
|
| 325 |
-
offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k
|
| 326 |
-
)
|
| 327 |
-
dout_ptrs = dout_ptr + (
|
| 328 |
-
offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim
|
| 329 |
-
)
|
| 330 |
-
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
| 331 |
-
K_MAX = chunk_size_limit
|
| 332 |
-
K_MIN = pid_m * BLOCK_SIZE_M
|
| 333 |
-
cb_ptrs += K_MIN * stride_cb_csize_k
|
| 334 |
-
dout_ptrs += K_MIN * stride_dout_seqlen
|
| 335 |
-
dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize
|
| 336 |
-
for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):
|
| 337 |
-
k = tl.multiple_of(k, BLOCK_SIZE_K)
|
| 338 |
-
# For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower
|
| 339 |
-
cb = tl.load(
|
| 340 |
-
cb_ptrs,
|
| 341 |
-
mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k),
|
| 342 |
-
other=0.0,
|
| 343 |
-
)
|
| 344 |
-
dout = tl.load(
|
| 345 |
-
dout_ptrs,
|
| 346 |
-
mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim),
|
| 347 |
-
other=0.0,
|
| 348 |
-
)
|
| 349 |
-
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(
|
| 350 |
-
tl.float32
|
| 351 |
-
)
|
| 352 |
-
cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
|
| 353 |
-
# If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
|
| 354 |
-
# we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
|
| 355 |
-
# Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
|
| 356 |
-
# This will cause NaN in acc, and hence NaN in dx and ddt.
|
| 357 |
-
mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)
|
| 358 |
-
cb = tl.where(mask, cb, 0.0)
|
| 359 |
-
cb = cb.to(dout_ptr.dtype.element_ty)
|
| 360 |
-
acc += tl.dot(cb, dout)
|
| 361 |
-
cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
|
| 362 |
-
dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
|
| 363 |
-
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
| 364 |
-
|
| 365 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 366 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 367 |
-
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
| 368 |
-
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
|
| 369 |
-
dx = acc * dt_m[:, None]
|
| 370 |
-
dx_ptr += (
|
| 371 |
-
pid_b * stride_dx_batch
|
| 372 |
-
+ pid_c * chunk_size * stride_dx_seqlen
|
| 373 |
-
+ pid_h * stride_dx_head
|
| 374 |
-
)
|
| 375 |
-
dx_ptrs = dx_ptr + (
|
| 376 |
-
offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim
|
| 377 |
-
)
|
| 378 |
-
if HAS_D:
|
| 379 |
-
dout_res_ptrs = dout_ptr + (
|
| 380 |
-
offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim
|
| 381 |
-
)
|
| 382 |
-
dout_res = tl.load(
|
| 383 |
-
dout_res_ptrs,
|
| 384 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
| 385 |
-
other=0.0,
|
| 386 |
-
).to(tl.float32)
|
| 387 |
-
if D_HAS_HDIM:
|
| 388 |
-
D = tl.load(
|
| 389 |
-
D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0
|
| 390 |
-
).to(tl.float32)
|
| 391 |
-
else:
|
| 392 |
-
D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
|
| 393 |
-
dx += dout_res * D
|
| 394 |
-
tl.store(
|
| 395 |
-
dx_ptrs,
|
| 396 |
-
dx,
|
| 397 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
| 398 |
-
)
|
| 399 |
-
|
| 400 |
-
x_ptrs = x_ptr + (
|
| 401 |
-
offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
|
| 402 |
-
)
|
| 403 |
-
x = tl.load(
|
| 404 |
-
x_ptrs,
|
| 405 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
| 406 |
-
other=0.0,
|
| 407 |
-
).to(tl.float32)
|
| 408 |
-
if HAS_D:
|
| 409 |
-
dD_ptr += (
|
| 410 |
-
pid_b * stride_dD_batch
|
| 411 |
-
+ pid_c * stride_dD_chunk
|
| 412 |
-
+ pid_h * stride_dD_head
|
| 413 |
-
+ pid_m * stride_dD_csize
|
| 414 |
-
)
|
| 415 |
-
if D_HAS_HDIM:
|
| 416 |
-
dD_ptrs = dD_ptr + offs_n * stride_dD_hdim
|
| 417 |
-
dD = tl.sum(dout_res * x, axis=0)
|
| 418 |
-
tl.store(dD_ptrs, dD, mask=offs_n < hdim)
|
| 419 |
-
else:
|
| 420 |
-
dD = tl.sum(dout_res * x)
|
| 421 |
-
tl.store(dD_ptr, dD)
|
| 422 |
-
ddt = tl.sum(acc * x, axis=1)
|
| 423 |
-
ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
|
| 424 |
-
tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
def _chunk_scan_chunk_state_bwd_dx(
|
| 428 |
-
x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None
|
| 429 |
-
):
|
| 430 |
-
batch, seqlen, nheads, headdim = x.shape
|
| 431 |
-
_, _, nchunks, chunk_size = dt.shape
|
| 432 |
-
_, _, ngroups, dstate = B.shape
|
| 433 |
-
assert nheads % ngroups == 0
|
| 434 |
-
assert B.shape == (batch, seqlen, ngroups, dstate)
|
| 435 |
-
assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
|
| 436 |
-
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
| 437 |
-
assert dA_cumsum.shape == dt.shape
|
| 438 |
-
assert dout.shape == x.shape
|
| 439 |
-
assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
|
| 440 |
-
if seq_idx is not None:
|
| 441 |
-
assert seq_idx.shape == (batch, seqlen)
|
| 442 |
-
if D is not None:
|
| 443 |
-
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
|
| 444 |
-
assert D.stride(-1) == 1
|
| 445 |
-
BLOCK_SIZE_min = 32
|
| 446 |
-
dD = torch.empty(
|
| 447 |
-
triton.cdiv(chunk_size, BLOCK_SIZE_min),
|
| 448 |
-
batch,
|
| 449 |
-
nchunks,
|
| 450 |
-
nheads,
|
| 451 |
-
headdim if D.dim() == 2 else 1,
|
| 452 |
-
device=D.device,
|
| 453 |
-
dtype=torch.float32,
|
| 454 |
-
)
|
| 455 |
-
else:
|
| 456 |
-
dD = None
|
| 457 |
-
dD_strides = (
|
| 458 |
-
(dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
|
| 459 |
-
if D is not None
|
| 460 |
-
else (0, 0, 0, 0, 0)
|
| 461 |
-
)
|
| 462 |
-
if dx is None:
|
| 463 |
-
dx = torch.empty_like(x)
|
| 464 |
-
else:
|
| 465 |
-
assert dx.shape == x.shape
|
| 466 |
-
ddt = torch.empty(
|
| 467 |
-
batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32
|
| 468 |
-
)
|
| 469 |
-
grid_dx = lambda META: (
|
| 470 |
-
triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
|
| 471 |
-
* triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
|
| 472 |
-
batch * nchunks,
|
| 473 |
-
nheads,
|
| 474 |
-
)
|
| 475 |
-
with torch.cuda.device(x.device.index):
|
| 476 |
-
_chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](
|
| 477 |
-
x,
|
| 478 |
-
CB,
|
| 479 |
-
dout,
|
| 480 |
-
dt,
|
| 481 |
-
dA_cumsum,
|
| 482 |
-
seq_idx,
|
| 483 |
-
D,
|
| 484 |
-
B,
|
| 485 |
-
dstates,
|
| 486 |
-
dx,
|
| 487 |
-
ddt,
|
| 488 |
-
dD,
|
| 489 |
-
chunk_size,
|
| 490 |
-
headdim,
|
| 491 |
-
dstate,
|
| 492 |
-
batch,
|
| 493 |
-
seqlen,
|
| 494 |
-
nheads // ngroups,
|
| 495 |
-
x.stride(0),
|
| 496 |
-
x.stride(1),
|
| 497 |
-
x.stride(2),
|
| 498 |
-
x.stride(3),
|
| 499 |
-
CB.stride(0),
|
| 500 |
-
CB.stride(1),
|
| 501 |
-
CB.stride(2),
|
| 502 |
-
CB.stride(-1),
|
| 503 |
-
CB.stride(-2),
|
| 504 |
-
dout.stride(0),
|
| 505 |
-
dout.stride(1),
|
| 506 |
-
dout.stride(2),
|
| 507 |
-
dout.stride(3),
|
| 508 |
-
dt.stride(0),
|
| 509 |
-
dt.stride(2),
|
| 510 |
-
dt.stride(1),
|
| 511 |
-
dt.stride(3),
|
| 512 |
-
dA_cumsum.stride(0),
|
| 513 |
-
dA_cumsum.stride(2),
|
| 514 |
-
dA_cumsum.stride(1),
|
| 515 |
-
dA_cumsum.stride(3),
|
| 516 |
-
*(
|
| 517 |
-
(seq_idx.stride(0), seq_idx.stride(1))
|
| 518 |
-
if seq_idx is not None
|
| 519 |
-
else (0, 0)
|
| 520 |
-
),
|
| 521 |
-
D.stride(0) if D is not None else 0,
|
| 522 |
-
B.stride(0),
|
| 523 |
-
B.stride(1),
|
| 524 |
-
B.stride(2),
|
| 525 |
-
B.stride(3),
|
| 526 |
-
dstates.stride(0),
|
| 527 |
-
dstates.stride(1),
|
| 528 |
-
dstates.stride(2),
|
| 529 |
-
dstates.stride(3),
|
| 530 |
-
dstates.stride(4),
|
| 531 |
-
dx.stride(0),
|
| 532 |
-
dx.stride(1),
|
| 533 |
-
dx.stride(2),
|
| 534 |
-
dx.stride(3),
|
| 535 |
-
ddt.stride(0),
|
| 536 |
-
ddt.stride(2),
|
| 537 |
-
ddt.stride(1),
|
| 538 |
-
ddt.stride(3),
|
| 539 |
-
dD_strides[1],
|
| 540 |
-
dD_strides[2],
|
| 541 |
-
dD_strides[3],
|
| 542 |
-
dD_strides[0],
|
| 543 |
-
dD_strides[4],
|
| 544 |
-
D is not None,
|
| 545 |
-
D.dim() == 2 if D is not None else True,
|
| 546 |
-
HAS_SEQ_IDX=seq_idx is not None,
|
| 547 |
-
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
|
| 548 |
-
IS_TRITON_22=TRITON_22
|
| 549 |
-
)
|
| 550 |
-
if D is not None:
|
| 551 |
-
BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs[
|
| 552 |
-
"BLOCK_SIZE_M"
|
| 553 |
-
]
|
| 554 |
-
n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
|
| 555 |
-
dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
|
| 556 |
-
if D.dim() == 1:
|
| 557 |
-
dD = rearrange(dD, "h 1 -> h")
|
| 558 |
-
return dx, ddt.to(dtype=dt.dtype), dD
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
def _mamba_chunk_scan_combined_fwd(
|
| 562 |
-
x,
|
| 563 |
-
dt,
|
| 564 |
-
A,
|
| 565 |
-
B,
|
| 566 |
-
C,
|
| 567 |
-
chunk_size,
|
| 568 |
-
D=None,
|
| 569 |
-
z=None,
|
| 570 |
-
dt_bias=None,
|
| 571 |
-
initial_states=None,
|
| 572 |
-
seq_idx=None,
|
| 573 |
-
cu_seqlens=None,
|
| 574 |
-
dt_softplus=False,
|
| 575 |
-
dt_limit=(0.0, float("inf")),
|
| 576 |
-
):
|
| 577 |
-
batch, seqlen, nheads, headdim = x.shape
|
| 578 |
-
_, _, ngroups, dstate = B.shape
|
| 579 |
-
assert nheads % ngroups == 0
|
| 580 |
-
assert B.shape == (batch, seqlen, ngroups, dstate)
|
| 581 |
-
assert x.shape == (batch, seqlen, nheads, headdim)
|
| 582 |
-
assert dt.shape == (batch, seqlen, nheads)
|
| 583 |
-
assert A.shape == (nheads,)
|
| 584 |
-
assert C.shape == B.shape
|
| 585 |
-
if z is not None:
|
| 586 |
-
assert z.shape == x.shape
|
| 587 |
-
if D is not None:
|
| 588 |
-
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
|
| 589 |
-
if seq_idx is not None:
|
| 590 |
-
assert seq_idx.shape == (batch, seqlen)
|
| 591 |
-
if B.stride(-1) != 1:
|
| 592 |
-
B = B.contiguous()
|
| 593 |
-
if C.stride(-1) != 1:
|
| 594 |
-
C = C.contiguous()
|
| 595 |
-
if (
|
| 596 |
-
x.stride(-1) != 1 and x.stride(1) != 1
|
| 597 |
-
): # Either M or K dimension should be contiguous
|
| 598 |
-
x = x.contiguous()
|
| 599 |
-
if (
|
| 600 |
-
z is not None and z.stride(-1) != 1 and z.stride(1) != 1
|
| 601 |
-
): # Either M or K dimension should be contiguous
|
| 602 |
-
z = z.contiguous()
|
| 603 |
-
if D is not None and D.stride(-1) != 1:
|
| 604 |
-
D = D.contiguous()
|
| 605 |
-
if initial_states is not None:
|
| 606 |
-
assert initial_states.shape == (batch, nheads, headdim, dstate)
|
| 607 |
-
# # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
|
| 608 |
-
# dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
|
| 609 |
-
# dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
|
| 610 |
-
# dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
|
| 611 |
-
dA_cumsum, dt = _chunk_cumsum_fwd(
|
| 612 |
-
dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit
|
| 613 |
-
)
|
| 614 |
-
states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
|
| 615 |
-
# states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
|
| 616 |
-
# states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
|
| 617 |
-
# states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)
|
| 618 |
-
states, final_states = _state_passing_fwd(
|
| 619 |
-
rearrange(states, "... p n -> ... (p n)"),
|
| 620 |
-
dA_cumsum[:, :, :, -1],
|
| 621 |
-
initial_states=(
|
| 622 |
-
rearrange(initial_states, "... p n -> ... (p n)")
|
| 623 |
-
if initial_states is not None
|
| 624 |
-
else None
|
| 625 |
-
),
|
| 626 |
-
seq_idx=seq_idx,
|
| 627 |
-
chunk_size=chunk_size,
|
| 628 |
-
out_dtype=C.dtype,
|
| 629 |
-
)
|
| 630 |
-
states, final_states = [
|
| 631 |
-
rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]
|
| 632 |
-
]
|
| 633 |
-
# states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
|
| 634 |
-
# states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
|
| 635 |
-
CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
|
| 636 |
-
out, out_x = _chunk_scan_fwd(
|
| 637 |
-
CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx
|
| 638 |
-
)
|
| 639 |
-
if cu_seqlens is None:
|
| 640 |
-
return out, out_x, dt, dA_cumsum, states, final_states
|
| 641 |
-
else:
|
| 642 |
-
assert (
|
| 643 |
-
batch == 1
|
| 644 |
-
), "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
|
| 645 |
-
varlen_states = chunk_state_varlen(
|
| 646 |
-
B.squeeze(0),
|
| 647 |
-
x.squeeze(0),
|
| 648 |
-
dt.squeeze(0),
|
| 649 |
-
dA_cumsum.squeeze(0),
|
| 650 |
-
cu_seqlens,
|
| 651 |
-
states.squeeze(0),
|
| 652 |
-
)
|
| 653 |
-
return out, out_x, dt, dA_cumsum, states, final_states, varlen_states
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
def _mamba_chunk_scan_combined_bwd(
|
| 657 |
-
dout,
|
| 658 |
-
x,
|
| 659 |
-
dt,
|
| 660 |
-
A,
|
| 661 |
-
B,
|
| 662 |
-
C,
|
| 663 |
-
out,
|
| 664 |
-
chunk_size,
|
| 665 |
-
D=None,
|
| 666 |
-
z=None,
|
| 667 |
-
dt_bias=None,
|
| 668 |
-
initial_states=None,
|
| 669 |
-
dfinal_states=None,
|
| 670 |
-
seq_idx=None,
|
| 671 |
-
dt_softplus=False,
|
| 672 |
-
dt_limit=(0.0, float("inf")),
|
| 673 |
-
dx=None,
|
| 674 |
-
ddt=None,
|
| 675 |
-
dB=None,
|
| 676 |
-
dC=None,
|
| 677 |
-
dz=None,
|
| 678 |
-
recompute_output=False,
|
| 679 |
-
):
|
| 680 |
-
if dout.stride(-1) != 1:
|
| 681 |
-
dout = dout.contiguous()
|
| 682 |
-
batch, seqlen, nheads, headdim = x.shape
|
| 683 |
-
nchunks = math.ceil(seqlen / chunk_size)
|
| 684 |
-
_, _, ngroups, dstate = B.shape
|
| 685 |
-
assert dout.shape == (batch, seqlen, nheads, headdim)
|
| 686 |
-
assert dt.shape == (batch, seqlen, nheads)
|
| 687 |
-
assert A.shape == (nheads,)
|
| 688 |
-
assert nheads % ngroups == 0
|
| 689 |
-
assert B.shape == (batch, seqlen, ngroups, dstate)
|
| 690 |
-
assert C.shape == B.shape
|
| 691 |
-
assert out.shape == x.shape
|
| 692 |
-
if initial_states is not None:
|
| 693 |
-
assert initial_states.shape == (batch, nheads, headdim, dstate)
|
| 694 |
-
if seq_idx is not None:
|
| 695 |
-
assert seq_idx.shape == (batch, seqlen)
|
| 696 |
-
if dx is not None:
|
| 697 |
-
assert dx.shape == x.shape
|
| 698 |
-
if dB is not None:
|
| 699 |
-
assert dB.shape == B.shape
|
| 700 |
-
dB_given = dB
|
| 701 |
-
else:
|
| 702 |
-
dB_given = torch.empty_like(B)
|
| 703 |
-
if dC is not None:
|
| 704 |
-
assert dC.shape == C.shape
|
| 705 |
-
dC_given = dC
|
| 706 |
-
else:
|
| 707 |
-
dC_given = torch.empty_like(C)
|
| 708 |
-
if dz is not None:
|
| 709 |
-
assert z is not None
|
| 710 |
-
assert dz.shape == z.shape
|
| 711 |
-
if ddt is not None:
|
| 712 |
-
assert ddt.shape == dt.shape
|
| 713 |
-
ddt_given = ddt
|
| 714 |
-
else:
|
| 715 |
-
ddt_given = torch.empty_like(dt)
|
| 716 |
-
# TD: For some reason Triton (2.1.0 and 2.2.0) errors with
|
| 717 |
-
# "[CUDA]: invalid device context" (e.g. during varlne test), and cloning makes it work. Idk why.
|
| 718 |
-
dt_in = dt.clone()
|
| 719 |
-
dA_cumsum, dt = _chunk_cumsum_fwd(
|
| 720 |
-
dt_in,
|
| 721 |
-
A,
|
| 722 |
-
chunk_size,
|
| 723 |
-
dt_bias=dt_bias,
|
| 724 |
-
dt_softplus=dt_softplus,
|
| 725 |
-
dt_limit=dt_limit,
|
| 726 |
-
)
|
| 727 |
-
CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
|
| 728 |
-
states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
|
| 729 |
-
states, _ = _state_passing_fwd(
|
| 730 |
-
rearrange(states, "... p n -> ... (p n)"),
|
| 731 |
-
dA_cumsum[:, :, :, -1],
|
| 732 |
-
initial_states=(
|
| 733 |
-
rearrange(initial_states, "... p n -> ... (p n)")
|
| 734 |
-
if initial_states is not None
|
| 735 |
-
else None
|
| 736 |
-
),
|
| 737 |
-
seq_idx=seq_idx,
|
| 738 |
-
chunk_size=chunk_size,
|
| 739 |
-
)
|
| 740 |
-
states = rearrange(states, "... (p n) -> ... p n", n=dstate)
|
| 741 |
-
if z is not None:
|
| 742 |
-
dz, dout, dD, *rest = _chunk_scan_bwd_dz(
|
| 743 |
-
x,
|
| 744 |
-
z,
|
| 745 |
-
out,
|
| 746 |
-
dout,
|
| 747 |
-
chunk_size=chunk_size,
|
| 748 |
-
has_ddAcs=False,
|
| 749 |
-
D=D,
|
| 750 |
-
dz=dz,
|
| 751 |
-
recompute_output=recompute_output,
|
| 752 |
-
)
|
| 753 |
-
outz = rest[0] if recompute_output else out
|
| 754 |
-
else:
|
| 755 |
-
dz = None
|
| 756 |
-
outz = out
|
| 757 |
-
dstates = _chunk_scan_bwd_dstates(
|
| 758 |
-
C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype
|
| 759 |
-
)
|
| 760 |
-
# dstates has length nchunks, containing the gradient to initial states at index 0 and
|
| 761 |
-
# gradient to the states of chunk (nchunks - 2) at index (nchunks - 1)
|
| 762 |
-
# Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states
|
| 763 |
-
# will be used in matmul in the next kernels.
|
| 764 |
-
dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd(
|
| 765 |
-
rearrange(states, "... p n -> ... (p n)"),
|
| 766 |
-
dA_cumsum[:, :, :, -1],
|
| 767 |
-
rearrange(dstates, "... p n -> ... (p n)"),
|
| 768 |
-
dfinal_states=(
|
| 769 |
-
rearrange(dfinal_states, "... p n -> ... (p n)")
|
| 770 |
-
if dfinal_states is not None
|
| 771 |
-
else None
|
| 772 |
-
),
|
| 773 |
-
seq_idx=seq_idx,
|
| 774 |
-
has_initial_states=initial_states is not None,
|
| 775 |
-
dstates_dtype=x.dtype,
|
| 776 |
-
states_dtype=x.dtype,
|
| 777 |
-
chunk_size=chunk_size,
|
| 778 |
-
)
|
| 779 |
-
# dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and
|
| 780 |
-
# gradient to the final states at index (nchunks - 1)
|
| 781 |
-
# states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1)
|
| 782 |
-
# The final states is not stored.
|
| 783 |
-
states = rearrange(states, "... (p n) -> ... p n", n=dstate)
|
| 784 |
-
dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate)
|
| 785 |
-
dinitial_states = (
|
| 786 |
-
rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate)
|
| 787 |
-
if dinitial_states is not None
|
| 788 |
-
else None
|
| 789 |
-
)
|
| 790 |
-
dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(
|
| 791 |
-
x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx
|
| 792 |
-
)
|
| 793 |
-
# dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups)
|
| 794 |
-
dB, ddA_next = _chunk_state_bwd_db(
|
| 795 |
-
x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups
|
| 796 |
-
)
|
| 797 |
-
# dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
|
| 798 |
-
dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(
|
| 799 |
-
states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups
|
| 800 |
-
)
|
| 801 |
-
# Computing ddA with the dcb kernel is much slower, so we're not using it for now
|
| 802 |
-
dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
|
| 803 |
-
# dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups)
|
| 804 |
-
dCB = dCB.to(CB.dtype)
|
| 805 |
-
_bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given)
|
| 806 |
-
_bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given)
|
| 807 |
-
# If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate
|
| 808 |
-
# than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16
|
| 809 |
-
if z is None:
|
| 810 |
-
dD = dD_from_x
|
| 811 |
-
# Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
|
| 812 |
-
# ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
|
| 813 |
-
# However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might
|
| 814 |
-
# be a lot of underflow.
|
| 815 |
-
|
| 816 |
-
# This is already done as part of bwd_dC kernel
|
| 817 |
-
# ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx)
|
| 818 |
-
ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum
|
| 819 |
-
ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1])
|
| 820 |
-
# This is already done as part of bwd_dB kernel
|
| 821 |
-
# ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx)
|
| 822 |
-
# We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j]
|
| 823 |
-
ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB)
|
| 824 |
-
ddA += ddA_next + ddA_prev
|
| 825 |
-
|
| 826 |
-
ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(
|
| 827 |
-
ddA,
|
| 828 |
-
ddt,
|
| 829 |
-
dt_in,
|
| 830 |
-
A,
|
| 831 |
-
dt_bias=dt_bias,
|
| 832 |
-
dt_softplus=dt_softplus,
|
| 833 |
-
dt_limit=dt_limit,
|
| 834 |
-
ddt=ddt_given,
|
| 835 |
-
)
|
| 836 |
-
|
| 837 |
-
# These 2 lines are just to test ddt and dA being computed by old code
|
| 838 |
-
# _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z)
|
| 839 |
-
# ddt_given.copy_(ddt)
|
| 840 |
-
|
| 841 |
-
return_vals = (
|
| 842 |
-
dx,
|
| 843 |
-
ddt_given,
|
| 844 |
-
dA,
|
| 845 |
-
dB_given,
|
| 846 |
-
dC_given,
|
| 847 |
-
dD,
|
| 848 |
-
dz,
|
| 849 |
-
ddt_bias,
|
| 850 |
-
dinitial_states,
|
| 851 |
-
)
|
| 852 |
-
return return_vals if not recompute_output else (*return_vals, outz)
|
| 853 |
-
|
| 854 |
-
|
| 855 |
-
def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None):
|
| 856 |
-
"""
|
| 857 |
-
Argument:
|
| 858 |
-
dout: (batch, seqlen, nheads, headdim)
|
| 859 |
-
x: (batch, seqlen, nheads, headdim)
|
| 860 |
-
dt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size)
|
| 861 |
-
A: (nheads) or (dim, dstate)
|
| 862 |
-
B: (batch, seqlen, ngroups, dstate)
|
| 863 |
-
C: (batch, seqlen, ngroups, dstate)
|
| 864 |
-
D: (nheads, headdim) or (nheads,)
|
| 865 |
-
z: (batch, seqlen, nheads, headdim)
|
| 866 |
-
Return:
|
| 867 |
-
out: (batch, seqlen, nheads, headdim)
|
| 868 |
-
"""
|
| 869 |
-
import selective_scan
|
| 870 |
-
|
| 871 |
-
batch, seqlen, nheads, headdim = x.shape
|
| 872 |
-
chunk_size = dt.shape[-1]
|
| 873 |
-
_, _, ngroups, dstate = B.shape
|
| 874 |
-
assert nheads % ngroups == 0
|
| 875 |
-
x = rearrange(x, "b l h p -> b (h p) l")
|
| 876 |
-
squeeze_dt = dt.dim() == 4
|
| 877 |
-
if dt.dim() == 4:
|
| 878 |
-
dt = repeat(dt, "b h c l -> b h p c l", p=headdim)
|
| 879 |
-
dt = rearrange(dt, "b h p c l -> b (h p) (c l)", p=headdim)
|
| 880 |
-
squeeze_A = A.dim() == 1
|
| 881 |
-
if A.dim() == 1:
|
| 882 |
-
A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
|
| 883 |
-
else:
|
| 884 |
-
A = A.to(dtype=torch.float32)
|
| 885 |
-
B = rearrange(B, "b l g n -> b g n l")
|
| 886 |
-
C = rearrange(C, "b l g n -> b g n l")
|
| 887 |
-
if D is not None:
|
| 888 |
-
if D.dim() == 2:
|
| 889 |
-
D = rearrange(D, "h p -> (h p)")
|
| 890 |
-
else:
|
| 891 |
-
D = repeat(D, "h -> (h p)", p=headdim)
|
| 892 |
-
if z is not None:
|
| 893 |
-
z = rearrange(z, "b l h p -> b (h p) l")
|
| 894 |
-
|
| 895 |
-
if x.stride(-1) != 1:
|
| 896 |
-
x = x.contiguous()
|
| 897 |
-
if dt.stride(-1) != 1:
|
| 898 |
-
dt = dt.contiguous()
|
| 899 |
-
if D is not None:
|
| 900 |
-
D = D.contiguous()
|
| 901 |
-
if B.stride(-1) != 1:
|
| 902 |
-
B = B.contiguous()
|
| 903 |
-
if C.stride(-1) != 1:
|
| 904 |
-
C = C.contiguous()
|
| 905 |
-
if z is not None and z.stride(-1) != 1:
|
| 906 |
-
z = z.contiguous()
|
| 907 |
-
_, intermediate, *rest = selective_scan.fwd(
|
| 908 |
-
x, dt.to(dtype=x.dtype), A, B, C, D, z, None, False
|
| 909 |
-
)
|
| 910 |
-
if z is not None:
|
| 911 |
-
out = rest[0]
|
| 912 |
-
else:
|
| 913 |
-
out = None
|
| 914 |
-
|
| 915 |
-
dout = rearrange(dout, "b l h p -> b (h p) l")
|
| 916 |
-
|
| 917 |
-
if dout.stride(-1) != 1:
|
| 918 |
-
dout = dout.contiguous()
|
| 919 |
-
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
| 920 |
-
# backward of selective_scan with the backward of chunk).
|
| 921 |
-
# Here we just pass in None and dz will be allocated in the C++ code.
|
| 922 |
-
_, ddt, dA, *rest = selective_scan.bwd(
|
| 923 |
-
x,
|
| 924 |
-
dt.to(dtype=x.dtype),
|
| 925 |
-
A,
|
| 926 |
-
B,
|
| 927 |
-
C,
|
| 928 |
-
D,
|
| 929 |
-
z,
|
| 930 |
-
None,
|
| 931 |
-
dout,
|
| 932 |
-
intermediate,
|
| 933 |
-
out,
|
| 934 |
-
None,
|
| 935 |
-
False,
|
| 936 |
-
False, # option to recompute out_z, not used here
|
| 937 |
-
)
|
| 938 |
-
ddt = rearrange(ddt, "b (h p) (c l) -> b h p c l", p=headdim, l=chunk_size)
|
| 939 |
-
if squeeze_dt:
|
| 940 |
-
ddt = ddt.float().sum(dim=2)
|
| 941 |
-
if squeeze_A:
|
| 942 |
-
dA = rearrange(dA, "(h p) n -> h p n", p=headdim).sum(dim=(1, 2))
|
| 943 |
-
return ddt, dA
|
| 944 |
-
|
| 945 |
-
|
| 946 |
-
class MambaChunkScanCombinedFn(torch.autograd.Function):
|
| 947 |
-
|
| 948 |
-
@staticmethod
|
| 949 |
-
def forward(
|
| 950 |
-
ctx,
|
| 951 |
-
x,
|
| 952 |
-
dt,
|
| 953 |
-
A,
|
| 954 |
-
B,
|
| 955 |
-
C,
|
| 956 |
-
chunk_size,
|
| 957 |
-
D=None,
|
| 958 |
-
z=None,
|
| 959 |
-
dt_bias=None,
|
| 960 |
-
initial_states=None,
|
| 961 |
-
seq_idx=None,
|
| 962 |
-
cu_seqlens=None,
|
| 963 |
-
dt_softplus=False,
|
| 964 |
-
dt_limit=(0.0, float("inf")),
|
| 965 |
-
return_final_states=False,
|
| 966 |
-
return_varlen_states=False,
|
| 967 |
-
):
|
| 968 |
-
ctx.dt_dtype = dt.dtype
|
| 969 |
-
if not return_varlen_states:
|
| 970 |
-
cu_seqlens = None
|
| 971 |
-
else:
|
| 972 |
-
assert (
|
| 973 |
-
cu_seqlens is not None
|
| 974 |
-
), "cu_seqlens must be provided if return_varlen_states is True"
|
| 975 |
-
out, out_x, dt_out, dA_cumsum, states, final_states, *rest = (
|
| 976 |
-
_mamba_chunk_scan_combined_fwd(
|
| 977 |
-
x,
|
| 978 |
-
dt,
|
| 979 |
-
A,
|
| 980 |
-
B,
|
| 981 |
-
C,
|
| 982 |
-
chunk_size,
|
| 983 |
-
D=D,
|
| 984 |
-
z=z,
|
| 985 |
-
dt_bias=dt_bias,
|
| 986 |
-
initial_states=initial_states,
|
| 987 |
-
seq_idx=seq_idx,
|
| 988 |
-
cu_seqlens=cu_seqlens,
|
| 989 |
-
dt_softplus=dt_softplus,
|
| 990 |
-
dt_limit=dt_limit,
|
| 991 |
-
)
|
| 992 |
-
)
|
| 993 |
-
ctx.save_for_backward(
|
| 994 |
-
out if z is None else out_x,
|
| 995 |
-
x,
|
| 996 |
-
dt,
|
| 997 |
-
dA_cumsum,
|
| 998 |
-
A,
|
| 999 |
-
B,
|
| 1000 |
-
C,
|
| 1001 |
-
D,
|
| 1002 |
-
z,
|
| 1003 |
-
dt_bias,
|
| 1004 |
-
initial_states,
|
| 1005 |
-
seq_idx,
|
| 1006 |
-
)
|
| 1007 |
-
ctx.dt_softplus = dt_softplus
|
| 1008 |
-
ctx.chunk_size = chunk_size
|
| 1009 |
-
ctx.dt_limit = dt_limit
|
| 1010 |
-
ctx.return_final_states = return_final_states
|
| 1011 |
-
ctx.return_varlen_states = return_varlen_states
|
| 1012 |
-
if not return_varlen_states:
|
| 1013 |
-
return out if not return_final_states else (out, final_states)
|
| 1014 |
-
else:
|
| 1015 |
-
varlen_states = rest[0]
|
| 1016 |
-
return (
|
| 1017 |
-
(out, varlen_states)
|
| 1018 |
-
if not return_final_states
|
| 1019 |
-
else (out, final_states, varlen_states)
|
| 1020 |
-
)
|
| 1021 |
-
|
| 1022 |
-
@staticmethod
|
| 1023 |
-
def backward(ctx, dout, *args):
|
| 1024 |
-
out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = (
|
| 1025 |
-
ctx.saved_tensors
|
| 1026 |
-
)
|
| 1027 |
-
assert (
|
| 1028 |
-
not ctx.return_varlen_states
|
| 1029 |
-
), "return_varlen_states is not supported in backward"
|
| 1030 |
-
dfinal_states = args[0] if ctx.return_final_states else None
|
| 1031 |
-
dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = (
|
| 1032 |
-
_mamba_chunk_scan_combined_bwd(
|
| 1033 |
-
dout,
|
| 1034 |
-
x,
|
| 1035 |
-
dt,
|
| 1036 |
-
A,
|
| 1037 |
-
B,
|
| 1038 |
-
C,
|
| 1039 |
-
out,
|
| 1040 |
-
ctx.chunk_size,
|
| 1041 |
-
D=D,
|
| 1042 |
-
z=z,
|
| 1043 |
-
dt_bias=dt_bias,
|
| 1044 |
-
initial_states=initial_states,
|
| 1045 |
-
dfinal_states=dfinal_states,
|
| 1046 |
-
seq_idx=seq_idx,
|
| 1047 |
-
dt_softplus=ctx.dt_softplus,
|
| 1048 |
-
dt_limit=ctx.dt_limit,
|
| 1049 |
-
)
|
| 1050 |
-
)
|
| 1051 |
-
return (
|
| 1052 |
-
dx,
|
| 1053 |
-
ddt,
|
| 1054 |
-
dA,
|
| 1055 |
-
dB,
|
| 1056 |
-
dC,
|
| 1057 |
-
None,
|
| 1058 |
-
dD,
|
| 1059 |
-
dz,
|
| 1060 |
-
ddt_bias,
|
| 1061 |
-
dinitial_states,
|
| 1062 |
-
None,
|
| 1063 |
-
None,
|
| 1064 |
-
None,
|
| 1065 |
-
None,
|
| 1066 |
-
None,
|
| 1067 |
-
None,
|
| 1068 |
-
)
|
| 1069 |
-
|
| 1070 |
-
|
| 1071 |
-
def mamba_chunk_scan_combined(
|
| 1072 |
-
x,
|
| 1073 |
-
dt,
|
| 1074 |
-
A,
|
| 1075 |
-
B,
|
| 1076 |
-
C,
|
| 1077 |
-
chunk_size,
|
| 1078 |
-
D=None,
|
| 1079 |
-
z=None,
|
| 1080 |
-
dt_bias=None,
|
| 1081 |
-
initial_states=None,
|
| 1082 |
-
seq_idx=None,
|
| 1083 |
-
cu_seqlens=None,
|
| 1084 |
-
dt_softplus=False,
|
| 1085 |
-
dt_limit=(0.0, float("inf")),
|
| 1086 |
-
return_final_states=False,
|
| 1087 |
-
return_varlen_states=False,
|
| 1088 |
-
):
|
| 1089 |
-
"""
|
| 1090 |
-
Argument:
|
| 1091 |
-
x: (batch, seqlen, nheads, headdim)
|
| 1092 |
-
dt: (batch, seqlen, nheads)
|
| 1093 |
-
A: (nheads)
|
| 1094 |
-
B: (batch, seqlen, ngroups, dstate)
|
| 1095 |
-
C: (batch, seqlen, ngroups, dstate)
|
| 1096 |
-
chunk_size: int
|
| 1097 |
-
D: (nheads, headdim) or (nheads,)
|
| 1098 |
-
z: (batch, seqlen, nheads, headdim)
|
| 1099 |
-
dt_bias: (nheads,)
|
| 1100 |
-
initial_states: (batch, nheads, headdim, dstate)
|
| 1101 |
-
seq_idx: (batch, seqlen)
|
| 1102 |
-
cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
|
| 1103 |
-
dt_softplus: Whether to apply softplus to dt
|
| 1104 |
-
Return:
|
| 1105 |
-
out: (batch, seqlen, nheads, headdim)
|
| 1106 |
-
"""
|
| 1107 |
-
return MambaChunkScanCombinedFn.apply(
|
| 1108 |
-
x,
|
| 1109 |
-
dt,
|
| 1110 |
-
A,
|
| 1111 |
-
B,
|
| 1112 |
-
C,
|
| 1113 |
-
chunk_size,
|
| 1114 |
-
D,
|
| 1115 |
-
z,
|
| 1116 |
-
dt_bias,
|
| 1117 |
-
initial_states,
|
| 1118 |
-
seq_idx,
|
| 1119 |
-
cu_seqlens,
|
| 1120 |
-
dt_softplus,
|
| 1121 |
-
dt_limit,
|
| 1122 |
-
return_final_states,
|
| 1123 |
-
return_varlen_states,
|
| 1124 |
-
)
|
| 1125 |
-
|
| 1126 |
-
|
| 1127 |
-
def mamba_chunk_scan(
|
| 1128 |
-
x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False
|
| 1129 |
-
):
|
| 1130 |
-
"""
|
| 1131 |
-
Argument:
|
| 1132 |
-
x: (batch, seqlen, nheads, headdim)
|
| 1133 |
-
dt: (batch, seqlen, nheads)
|
| 1134 |
-
A: (nheads)
|
| 1135 |
-
B: (batch, seqlen, ngroups, dstate)
|
| 1136 |
-
C: (batch, seqlen, ngroups, dstate)
|
| 1137 |
-
D: (nheads, headdim) or (nheads,)
|
| 1138 |
-
z: (batch, seqlen, nheads, headdim)
|
| 1139 |
-
dt_bias: (nheads,)
|
| 1140 |
-
Return:
|
| 1141 |
-
out: (batch, seqlen, nheads, headdim)
|
| 1142 |
-
"""
|
| 1143 |
-
batch, seqlen, nheads, headdim = x.shape
|
| 1144 |
-
dstate = B.shape[-1]
|
| 1145 |
-
if seqlen % chunk_size != 0:
|
| 1146 |
-
dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
|
| 1147 |
-
dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
|
| 1148 |
-
dt = dt.float() # We want high precision for this before cumsum
|
| 1149 |
-
if dt_bias is not None:
|
| 1150 |
-
dt = dt + rearrange(dt_bias, "h -> h 1 1")
|
| 1151 |
-
if dt_softplus:
|
| 1152 |
-
dt = F.softplus(dt)
|
| 1153 |
-
dA = dt * rearrange(A, "h -> h 1 1")
|
| 1154 |
-
dA = dt * rearrange(A, "h -> h 1 1")
|
| 1155 |
-
dA_cumsum = torch.cumsum(dA, dim=-1)
|
| 1156 |
-
# 1. Compute the state for each chunk
|
| 1157 |
-
states = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True)
|
| 1158 |
-
# 2. Pass the state to all the chunks by weighted cumsum.
|
| 1159 |
-
states = rearrange(
|
| 1160 |
-
state_passing(
|
| 1161 |
-
rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]
|
| 1162 |
-
)[0],
|
| 1163 |
-
"... (p n) -> ... p n",
|
| 1164 |
-
n=dstate,
|
| 1165 |
-
)
|
| 1166 |
-
# 3. Compute the output for each chunk
|
| 1167 |
-
out = chunk_scan(B, C, x, dt, dA_cumsum, states, D=D, z=z)
|
| 1168 |
-
return out
|
| 1169 |
-
|
| 1170 |
-
|
| 1171 |
-
def ssd_chunk_scan_combined_ref(
|
| 1172 |
-
x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False
|
| 1173 |
-
):
|
| 1174 |
-
"""
|
| 1175 |
-
Argument:
|
| 1176 |
-
x: (batch, seqlen, nheads, headdim)
|
| 1177 |
-
dt: (batch, seqlen, nheads)
|
| 1178 |
-
A: (nheads)
|
| 1179 |
-
B: (batch, seqlen, ngroups, dstate)
|
| 1180 |
-
C: (batch, seqlen, ngroups, dstate)
|
| 1181 |
-
D: (nheads, headdim) or (nheads,)
|
| 1182 |
-
z: (batch, seqlen, nheads, headdim)
|
| 1183 |
-
dt_bias: (nheads,)
|
| 1184 |
-
Return:
|
| 1185 |
-
out: (batch, seqlen, nheads, headdim)
|
| 1186 |
-
"""
|
| 1187 |
-
batch, seqlen, nheads, headdim = x.shape
|
| 1188 |
-
dstate = B.shape[-1]
|
| 1189 |
-
if seqlen % chunk_size != 0:
|
| 1190 |
-
dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
|
| 1191 |
-
dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
|
| 1192 |
-
dt = dt.float() # We want high precision for this before cumsum
|
| 1193 |
-
if dt_bias is not None:
|
| 1194 |
-
dt = dt + rearrange(dt_bias, "h -> h 1 1")
|
| 1195 |
-
if dt_softplus:
|
| 1196 |
-
dt = F.softplus(dt)
|
| 1197 |
-
dA = dt * rearrange(A, "h -> h 1 1")
|
| 1198 |
-
dA_cumsum = torch.cumsum(dA, dim=-1)
|
| 1199 |
-
# 1. Compute the state for each chunk
|
| 1200 |
-
states = chunk_state_ref(B, x, dt, dA_cumsum)
|
| 1201 |
-
states_dtype = states.dtype
|
| 1202 |
-
if states.dtype not in [torch.float32, torch.float64]:
|
| 1203 |
-
states = states.to(torch.float32)
|
| 1204 |
-
# 2. Pass the state to all the chunks by weighted cumsum.
|
| 1205 |
-
# state_passing_ref is much less numerically stable
|
| 1206 |
-
states = rearrange(
|
| 1207 |
-
state_passing_ref(
|
| 1208 |
-
rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]
|
| 1209 |
-
)[0],
|
| 1210 |
-
"... (p n) -> ... p n",
|
| 1211 |
-
n=dstate,
|
| 1212 |
-
)
|
| 1213 |
-
states = states.to(states_dtype)
|
| 1214 |
-
# 3. Compute the output for each chunk
|
| 1215 |
-
out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)
|
| 1216 |
-
return out
|
| 1217 |
-
|
| 1218 |
-
|
| 1219 |
-
def ssd_selective_scan(
|
| 1220 |
-
x,
|
| 1221 |
-
dt,
|
| 1222 |
-
A,
|
| 1223 |
-
B,
|
| 1224 |
-
C,
|
| 1225 |
-
D=None,
|
| 1226 |
-
z=None,
|
| 1227 |
-
dt_bias=None,
|
| 1228 |
-
dt_softplus=False,
|
| 1229 |
-
dt_limit=(0.0, float("inf")),
|
| 1230 |
-
):
|
| 1231 |
-
"""
|
| 1232 |
-
Argument:
|
| 1233 |
-
x: (batch, seqlen, nheads, headdim)
|
| 1234 |
-
dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
|
| 1235 |
-
A: (nheads) or (dim, dstate)
|
| 1236 |
-
B: (batch, seqlen, ngroups, dstate)
|
| 1237 |
-
C: (batch, seqlen, ngroups, dstate)
|
| 1238 |
-
D: (nheads, headdim) or (nheads,)
|
| 1239 |
-
z: (batch, seqlen, nheads, headdim)
|
| 1240 |
-
dt_bias: (nheads,) or (nheads, headdim)
|
| 1241 |
-
Return:
|
| 1242 |
-
out: (batch, seqlen, nheads, headdim)
|
| 1243 |
-
"""
|
| 1244 |
-
from ..selective_scan_interface import selective_scan_fn
|
| 1245 |
-
|
| 1246 |
-
batch, seqlen, nheads, headdim = x.shape
|
| 1247 |
-
_, _, ngroups, dstate = B.shape
|
| 1248 |
-
x = rearrange(x, "b l h p -> b (h p) l")
|
| 1249 |
-
if dt.dim() == 3:
|
| 1250 |
-
dt = repeat(dt, "b l h -> b l h p", p=headdim)
|
| 1251 |
-
dt = rearrange(dt, "b l h p -> b (h p) l")
|
| 1252 |
-
if A.dim() == 1:
|
| 1253 |
-
A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
|
| 1254 |
-
else:
|
| 1255 |
-
A = A.to(dtype=torch.float32)
|
| 1256 |
-
B = rearrange(B, "b l g n -> b g n l")
|
| 1257 |
-
C = rearrange(C, "b l g n -> b g n l")
|
| 1258 |
-
if D is not None:
|
| 1259 |
-
if D.dim() == 2:
|
| 1260 |
-
D = rearrange(D, "h p -> (h p)")
|
| 1261 |
-
else:
|
| 1262 |
-
D = repeat(D, "h -> (h p)", p=headdim)
|
| 1263 |
-
if z is not None:
|
| 1264 |
-
z = rearrange(z, "b l h p -> b (h p) l")
|
| 1265 |
-
if dt_bias is not None:
|
| 1266 |
-
if dt_bias.dim() == 1:
|
| 1267 |
-
dt_bias = repeat(dt_bias, "h -> h p", p=headdim)
|
| 1268 |
-
dt_bias = rearrange(dt_bias, "h p -> (h p)")
|
| 1269 |
-
if dt_limit != (0.0, float("inf")):
|
| 1270 |
-
if dt_bias is not None:
|
| 1271 |
-
dt = dt + rearrange(dt_bias, "d -> d 1")
|
| 1272 |
-
if dt_softplus:
|
| 1273 |
-
dt = F.softplus(dt)
|
| 1274 |
-
dt = dt.clamp(min=dt_limit[0], max=dt_limit[1]).to(x.dtype)
|
| 1275 |
-
dt_bias = None
|
| 1276 |
-
dt_softplus = None
|
| 1277 |
-
out = selective_scan_fn(
|
| 1278 |
-
x, dt, A, B, C, D=D, z=z, delta_bias=dt_bias, delta_softplus=dt_softplus
|
| 1279 |
-
)
|
| 1280 |
-
return rearrange(out, "b (h p) l -> b l h p", p=headdim)
|
| 1281 |
-
|
| 1282 |
-
|
| 1283 |
-
def mamba_conv1d_scan_ref(
|
| 1284 |
-
xBC,
|
| 1285 |
-
conv1d_weight,
|
| 1286 |
-
conv1d_bias,
|
| 1287 |
-
dt,
|
| 1288 |
-
A,
|
| 1289 |
-
chunk_size,
|
| 1290 |
-
D=None,
|
| 1291 |
-
z=None,
|
| 1292 |
-
dt_bias=None,
|
| 1293 |
-
dt_softplus=False,
|
| 1294 |
-
dt_limit=(0.0, float("inf")),
|
| 1295 |
-
activation="silu",
|
| 1296 |
-
headdim=None,
|
| 1297 |
-
ngroups=1,
|
| 1298 |
-
):
|
| 1299 |
-
"""
|
| 1300 |
-
Argument:
|
| 1301 |
-
xBC: (batch, seqlen, dim + 2 * ngroups * dstate) where dim == nheads * headdim
|
| 1302 |
-
conv1d_weight: (dim + 2 * ngroups * dstate, width)
|
| 1303 |
-
conv1d_bias: (dim + 2 * ngroups * dstate,)
|
| 1304 |
-
dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
|
| 1305 |
-
A: (nheads)
|
| 1306 |
-
D: (nheads, headdim) or (nheads,)
|
| 1307 |
-
z: (batch, seqlen, dim)
|
| 1308 |
-
dt_bias: (nheads) or (nheads, headdim)
|
| 1309 |
-
headdim: if D is 1D and z is None, headdim must be passed in
|
| 1310 |
-
Return:
|
| 1311 |
-
out: (batch, seqlen, dim)
|
| 1312 |
-
"""
|
| 1313 |
-
batch, seqlen, nheads = dt.shape[:3]
|
| 1314 |
-
assert nheads % ngroups == 0
|
| 1315 |
-
if z is not None:
|
| 1316 |
-
dim = z.shape[-1]
|
| 1317 |
-
assert dim % nheads == 0
|
| 1318 |
-
headdim = dim // nheads
|
| 1319 |
-
else:
|
| 1320 |
-
if D.dim() == 1:
|
| 1321 |
-
assert headdim is not None
|
| 1322 |
-
else:
|
| 1323 |
-
headdim = D.shape[1]
|
| 1324 |
-
dim = nheads * headdim
|
| 1325 |
-
xBC = rearrange(
|
| 1326 |
-
causal_conv1d_fn(
|
| 1327 |
-
rearrange(xBC, "b s d -> b d s"),
|
| 1328 |
-
conv1d_weight,
|
| 1329 |
-
conv1d_bias,
|
| 1330 |
-
activation=activation,
|
| 1331 |
-
),
|
| 1332 |
-
"b d s -> b s d",
|
| 1333 |
-
)
|
| 1334 |
-
dstate = (xBC.shape[-1] - dim) // ngroups // 2
|
| 1335 |
-
x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
|
| 1336 |
-
x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
|
| 1337 |
-
B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
|
| 1338 |
-
C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
|
| 1339 |
-
z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
|
| 1340 |
-
out = ssd_selective_scan(
|
| 1341 |
-
x,
|
| 1342 |
-
dt.to(x.dtype),
|
| 1343 |
-
A,
|
| 1344 |
-
B,
|
| 1345 |
-
C,
|
| 1346 |
-
D=D.float(),
|
| 1347 |
-
z=z,
|
| 1348 |
-
dt_bias=dt_bias,
|
| 1349 |
-
dt_softplus=dt_softplus,
|
| 1350 |
-
dt_limit=dt_limit,
|
| 1351 |
-
)
|
| 1352 |
-
return rearrange(out, "b s h p -> b s (h p)")
class MambaSplitConv1dScanCombinedFn(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        zxbcdt,
        conv1d_weight,
        conv1d_bias,
        dt_bias,
        A,
        D,
        chunk_size,
        initial_states=None,
        seq_idx=None,
        dt_limit=(0.0, float("inf")),
        return_final_states=False,
        activation="silu",
        rmsnorm_weight=None,
        rmsnorm_eps=1e-6,
        outproj_weight=None,
        outproj_bias=None,
        headdim=None,
        ngroups=1,
        norm_before_gate=True,
    ):
        assert activation in [None, "silu", "swish"]
        if D.dim() == 1:
            assert headdim is not None
            (nheads,) = D.shape
        else:
            nheads, headdim = D.shape
        batch, seqlen, _ = zxbcdt.shape
        dim = nheads * headdim
        assert nheads % ngroups == 0
        dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2
        d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2
        assert d_nonssm >= 0
        assert zxbcdt.shape == (
            batch,
            seqlen,
            2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads,
        )
        assert dt_bias.shape == (nheads,)
        assert A.shape == (nheads,)
        zx0, z, xBC, dt = torch.split(
            zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1
        )
        seq_idx = seq_idx.contiguous() if seq_idx is not None else None
        xBC_conv = rearrange(
            causal_conv1d_cuda.causal_conv1d_fwd(
                rearrange(xBC, "b s d -> b d s"),
                conv1d_weight,
                conv1d_bias,
                seq_idx,
                None,
                None,
                activation in ["silu", "swish"],
            ),
            "b d s -> b s d",
        )
        x, B, C = torch.split(
            xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1
        )
        x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
        B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
        C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
        z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
        if rmsnorm_weight is None:
            out, out_x, dt_out, dA_cumsum, states, final_states = (
                _mamba_chunk_scan_combined_fwd(
                    x,
                    dt,
                    A,
                    B,
                    C,
                    chunk_size=chunk_size,
                    D=D,
                    z=z,
                    dt_bias=dt_bias,
                    initial_states=initial_states,
                    seq_idx=seq_idx,
                    dt_softplus=True,
                    dt_limit=dt_limit,
                )
            )
            out = rearrange(out, "b s h p -> b s (h p)")
            rstd = None
            if d_nonssm > 0:
                out = torch.cat([_swiglu_fwd(zx0), out], dim=-1)
        else:
            out_x, _, dt_out, dA_cumsum, states, final_states = (
                _mamba_chunk_scan_combined_fwd(
                    x,
                    dt,
                    A,
                    B,
                    C,
                    chunk_size=chunk_size,
                    D=D,
                    z=None,
                    dt_bias=dt_bias,
                    initial_states=initial_states,
                    seq_idx=seq_idx,
                    dt_softplus=True,
                    dt_limit=dt_limit,
                )
            )
            # reshape input data into 2D tensor
            x_rms = rearrange(out_x, "b s h p -> (b s) (h p)")
            z_rms = rearrange(z, "b s h p -> (b s) (h p)")
            rmsnorm_weight = rmsnorm_weight.contiguous()
            if d_nonssm == 0:
                out = None
            else:
                out01 = torch.empty(
                    (batch, seqlen, d_nonssm + dim),
                    dtype=x_rms.dtype,
                    device=x_rms.device,
                )
                out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d")
                _swiglu_fwd(zx0, out=out01[..., :d_nonssm])
            out, _, rstd = _layer_norm_fwd(
                x_rms,
                rmsnorm_weight,
                None,
                rmsnorm_eps,
                z_rms,
                out=out,
                group_size=dim // ngroups,
                norm_before_gate=norm_before_gate,
                is_rms_norm=True,
            )
            if d_nonssm == 0:
                out = rearrange(out, "(b s) d -> b s d", b=batch)
            else:
                out = out01
        ctx.outproj_weight_dtype = (
            outproj_weight.dtype if outproj_weight is not None else None
        )
        if outproj_weight is not None:
            if torch.is_autocast_enabled():
                dtype = torch.get_autocast_gpu_dtype()
                out, outproj_weight = out.to(dtype), outproj_weight.to(dtype)
                outproj_bias = (
                    outproj_bias.to(dtype) if outproj_bias is not None else None
                )
            out = F.linear(out, outproj_weight, outproj_bias)
        else:
            assert outproj_bias is None
        ctx.save_for_backward(
            zxbcdt,
            conv1d_weight,
            conv1d_bias,
            out_x,
            A,
            D,
            dt_bias,
            initial_states,
            seq_idx,
            rmsnorm_weight,
            rstd,
            outproj_weight,
            outproj_bias,
        )
        ctx.dt_limit = dt_limit
        ctx.return_final_states = return_final_states
        ctx.activation = activation
        ctx.rmsnorm_eps = rmsnorm_eps
        ctx.norm_before_gate = norm_before_gate
        ctx.chunk_size = chunk_size
        ctx.headdim = headdim
        ctx.ngroups = ngroups
        return out if not return_final_states else (out, final_states)

    @staticmethod
    @custom_bwd
    def backward(ctx, dout, *args):
        (
            zxbcdt,
            conv1d_weight,
            conv1d_bias,
            out,
            A,
            D,
            dt_bias,
            initial_states,
            seq_idx,
            rmsnorm_weight,
            rstd,
            outproj_weight,
            outproj_bias,
        ) = ctx.saved_tensors
        dfinal_states = args[0] if ctx.return_final_states else None
        headdim = ctx.headdim
        nheads = D.shape[0]
        dim = nheads * headdim
        assert nheads % ctx.ngroups == 0
        dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2
        d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2
        assert d_nonssm >= 0
        recompute_output = outproj_weight is not None
        if recompute_output:
            out_recompute = torch.empty(
                *out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype
            )
            out0_recompute, out1_recompute = out_recompute.split(
                [d_nonssm, dim], dim=-1
            )
        zx0, z, xBC, dt = torch.split(
            zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
        )
        # Recompute x, B, C
        xBC_conv = rearrange(
            causal_conv1d_cuda.causal_conv1d_fwd(
                rearrange(xBC, "b s d -> b d s"),
                conv1d_weight,
                conv1d_bias,
                seq_idx,
                None,
                None,
                ctx.activation in ["silu", "swish"],
            ),
            "b d s -> b s d",
        )
        x, B, C = torch.split(
            xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
        )
        x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
        B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups)
        C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups)
        dzxbcdt = torch.empty_like(zxbcdt)
        dzx0, dz, dxBC_given, ddt_given = torch.split(
            dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
        )
        dxBC = torch.empty_like(xBC)
        dx, dB, dC = torch.split(
            dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
        )
        z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
        dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads)
        dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups)
        dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups)
        if outproj_weight is not None:
            dout_og = dout
            dout = F.linear(dout, outproj_weight.t())
        if d_nonssm > 0:
            dout0, dout = dout.split([d_nonssm, dim], dim=-1)
            _swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute)
        dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim)
        if rmsnorm_weight is None:
            dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads)
            dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states, *rest = (
                _mamba_chunk_scan_combined_bwd(
                    dout,
                    x,
                    dt,
                    A,
                    B,
                    C,
                    out,
                    ctx.chunk_size,
                    D=D,
                    z=z,
                    dt_bias=dt_bias,
                    initial_states=initial_states,
                    dfinal_states=dfinal_states,
                    seq_idx=seq_idx,
                    dt_softplus=True,
                    dt_limit=ctx.dt_limit,
                    dx=dx,
                    ddt=ddt_given,
                    dB=dB,
                    dC=dC,
                    dz=dz,
                    recompute_output=recompute_output,
                )
            )
            out_for_linear = (
                rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None
            )
            drmsnorm_weight = None
        else:
            batch = dout.shape[0]
            dy_rms = rearrange(dout, "b s h p -> (b s) (h p)")
            dz = rearrange(dz, "b l d -> (b l) d")
            x_rms = rearrange(out, "b s h p -> (b s) (h p)")
            z_rms = rearrange(z, "b s h p -> (b s) (h p)")
            out1_recompute = (
                rearrange(out1_recompute, "b s d -> (b s) d")
                if recompute_output
                else None
            )
            dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(
                dy_rms,
                x_rms,
                rmsnorm_weight,
                None,
                ctx.rmsnorm_eps,
                None,
                rstd,
                z_rms,
                group_size=dim // ctx.ngroups,
                norm_before_gate=ctx.norm_before_gate,
                is_rms_norm=True,
                recompute_output=recompute_output,
                dz=dz,
                out=out1_recompute if recompute_output else None,
            )
            out_for_linear = out_recompute if recompute_output else None
            dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim)
            dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = (
                _mamba_chunk_scan_combined_bwd(
                    dout,
                    x,
                    dt,
                    A,
                    B,
                    C,
                    out,
                    ctx.chunk_size,
                    D=D,
                    z=None,
                    dt_bias=dt_bias,
                    initial_states=initial_states,
                    dfinal_states=dfinal_states,
                    seq_idx=seq_idx,
                    dt_softplus=True,
                    dt_limit=ctx.dt_limit,
                    dx=dx,
                    ddt=ddt_given,
                    dB=dB,
                    dC=dC,
                )
            )

        if outproj_weight is not None:
            doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear)
            doutproj_bias = (
                dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None
            )
        else:
            doutproj_weight, doutproj_bias = None, None
        dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
        dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
            rearrange(xBC, "b s d -> b d s"),
            conv1d_weight,
            conv1d_bias,
            rearrange(dxBC, "b s d -> b d s"),
            seq_idx,
            None,
            None,
            dxBC_given,
            False,
            ctx.activation in ["silu", "swish"],
        )
        dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
        return (
            dzxbcdt,
            dweight,
            dbias,
            ddt_bias,
            dA,
            dD,
            None,
            dinitial_states,
            None,
            None,
            None,
            None,
            drmsnorm_weight,
            None,
            doutproj_weight,
            doutproj_bias,
            None,
            None,
            None,
        )


def mamba_split_conv1d_scan_combined(
    zxbcdt,
    conv1d_weight,
    conv1d_bias,
    dt_bias,
    A,
    D,
    chunk_size,
    initial_states=None,
    seq_idx=None,
    dt_limit=(0.0, float("inf")),
    return_final_states=False,
    activation="silu",
    rmsnorm_weight=None,
    rmsnorm_eps=1e-6,
    outproj_weight=None,
    outproj_bias=None,
    headdim=None,
    ngroups=1,
    norm_before_gate=True,
):
    """
    Argument:
        zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
        conv1d_weight: (dim + 2 * ngroups * dstate, width)
        conv1d_bias: (dim + 2 * ngroups * dstate,)
        dt_bias: (nheads,)
        A: (nheads)
        D: (nheads, headdim) or (nheads,)
        initial_states: (batch, nheads, headdim, dstate)
        seq_idx: (batch, seqlen), int32
        rmsnorm_weight: (dim,)
        outproj_weight: (out_dim, dim)
        outproj_bias: (out_dim,)
        headdim: if D is 1D, headdim must be passed in
        norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
    Return:
        out: (batch, seqlen, dim)
    """
    return MambaSplitConv1dScanCombinedFn.apply(
        zxbcdt,
        conv1d_weight,
        conv1d_bias,
        dt_bias,
        A,
        D,
        chunk_size,
        initial_states,
        seq_idx,
        dt_limit,
        return_final_states,
        activation,
        rmsnorm_weight,
        rmsnorm_eps,
        outproj_weight,
        outproj_bias,
        headdim,
        ngroups,
        norm_before_gate,
    )


def mamba_split_conv1d_scan_ref(
    zxbcdt,
    conv1d_weight,
    conv1d_bias,
    dt_bias,
    A,
    D,
    chunk_size,
    dt_limit=(0.0, float("inf")),
    activation="silu",
    rmsnorm_weight=None,
    rmsnorm_eps=1e-6,
    outproj_weight=None,
    outproj_bias=None,
    headdim=None,
    ngroups=1,
    norm_before_gate=True,
):
    """
    Argument:
        zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
        conv1d_weight: (dim + 2 * ngroups * dstate, width)
        conv1d_bias: (dim + 2 * ngroups * dstate,)
        dt_bias: (nheads,)
        A: (nheads)
        D: (nheads, headdim) or (nheads,)
        rmsnorm_weight: (dim,)
        outproj_weight: (out_dim, dim)
        outproj_bias: (out_dim,)
        headdim: if D is 1D, headdim must be passed in
        norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
    Return:
        out: (batch, seqlen, dim)
    """
    if D.dim() == 1:
        assert headdim is not None
        (nheads,) = D.shape
    else:
        nheads, headdim = D.shape
    assert nheads % ngroups == 0
    batch, seqlen, _ = zxbcdt.shape
    dim = nheads * headdim
    dstate = (zxbcdt.shape[-1] - 2 * dim - nheads) // ngroups // 2
    assert zxbcdt.shape == (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads)
    assert dt_bias.shape == (nheads,)
    assert A.shape == (nheads,)
    if rmsnorm_weight is not None:
        assert rmsnorm_weight.shape == (dim,)
    z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1)
    xBC = rearrange(
        causal_conv1d_fn(
            rearrange(xBC, "b s d -> b d s"),
            conv1d_weight,
            conv1d_bias,
            activation=activation,
        ),
        "b d s -> b s d",
    )
    x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
    x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
    B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
    C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
    z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
    out = ssd_selective_scan(
        x,
        dt.to(x.dtype),
        A,
        B,
        C,
        D=D.float(),
        z=z if rmsnorm_weight is None else None,
        dt_bias=dt_bias,
        dt_softplus=True,
        dt_limit=dt_limit,
    )
    out = rearrange(out, "b s h p -> b s (h p)")
    if rmsnorm_weight is not None:
        out = rmsnorm_fn(
            out,
            rmsnorm_weight,
            None,
            z=rearrange(z, "b l h p -> b l (h p)"),
            eps=rmsnorm_eps,
            norm_before_gate=norm_before_gate,
        )
    if outproj_weight is not None:
        out = F.linear(out, outproj_weight, outproj_bias)
    return out
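As a usage sketch only, the fused entry point above can be exercised with shapes that satisfy its docstring; every size below (batch, seqlen, nheads, headdim, ngroups, dstate, width, chunk_size) is an illustrative assumption, not taken from this diff:

# Hypothetical shapes chosen only to satisfy the docstring contract above.
import torch
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined

batch, seqlen = 2, 128
nheads, headdim, ngroups, dstate, width, chunk_size = 8, 64, 1, 16, 4, 64
dim = nheads * headdim
conv_dim = dim + 2 * ngroups * dstate
zxbcdt = torch.randn(batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads, device="cuda")
conv1d_weight = torch.randn(conv_dim, width, device="cuda")
conv1d_bias = torch.randn(conv_dim, device="cuda")
dt_bias = torch.randn(nheads, device="cuda")
A = -torch.rand(nheads, device="cuda")  # A is (nheads,); negative values keep the scan stable
D = torch.randn(nheads, device="cuda")  # D is 1D here, so headdim must be passed explicitly
out = mamba_split_conv1d_scan_combined(
    zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size,
    headdim=headdim, ngroups=ngroups,
)
assert out.shape == (batch, seqlen, dim)  # per the docstring's Return section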
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/__init__.py
DELETED
@@ -1,14 +0,0 @@
__version__ = "2.2.4"

from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from .modules.mamba_simple import Mamba
from .modules.mamba2 import Mamba2
from .models.mixer_seq_simple import MambaLMHeadModel

__all__ = [
    "selective_scan_fn",
    "mamba_inner_fn",
    "Mamba",
    "Mamba2",
    "MambaLMHeadModel",
]
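The removed top-level `__init__` defined the package's public surface; a minimal sketch of the imports it supported downstream:

# Top-level imports that the deleted __init__.py exposed.
from mamba_ssm import Mamba, Mamba2, MambaLMHeadModel
from mamba_ssm import selective_scan_fn, mamba_inner_fn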
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/distributed/__init__.py
DELETED
File without changes
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py
DELETED
@@ -1,326 +0,0 @@
# Copyright (c) 2024, Tri Dao.
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from ..utils.torch import custom_bwd, custom_fwd

from einops import rearrange

from ..distributed.distributed_utils import (
    all_gather_raw,
    all_reduce,
    all_reduce_raw,
    reduce_scatter,
    reduce_scatter_raw,
)


class ParallelLinearFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
        """
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
            bias = (
                bias.to(dtype=torch.get_autocast_gpu_dtype())
                if bias is not None
                else None
            )
        weight = weight.contiguous()
        if process_group is not None and sequence_parallel:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        output = F.linear(total_x, weight, bias)
        if ctx.compute_weight_gradient:
            ctx.save_for_backward(x, weight)
        else:
            ctx.save_for_backward(weight)
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        if ctx.compute_weight_gradient:
            x, weight = ctx.saved_tensors
            if process_group is not None and sequence_parallel:
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            else:
                total_x = x
        else:
            (weight,) = ctx.saved_tensors
            total_x = None
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            grad_input = F.linear(grad_output, weight.t())
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(
                    grad_input, process_group, async_op=True
                )
        else:
            grad_input = None
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            if process_group is not None and sequence_parallel:
                handle_x.wait()
            grad_weight = torch.einsum(
                "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
            )
        else:
            grad_weight = None
        grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None


def parallel_linear_func(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True,
):
    return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)


class ColumnParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        if out_features % multiple_of:
            raise ValueError(
                f"out_features ({out_features}) must be a multiple of {multiple_of}"
            )
        multiple = out_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
        super().__init__(
            in_features,
            local_multiple * multiple_of,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
        # we do an all_gather of x before doing the matmul.
        # If not, then the input is already gathered.
        return parallel_linear_func(
            x,
            self.weight,
            self.bias,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel,
        )


class RowParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        rank = torch.distributed.get_rank(process_group)
        if in_features % multiple_of:
            raise ValueError(
                f"in_features ({in_features}) must be a multiple of {multiple_of}"
            )
        multiple = in_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
        # Only rank 0 will have bias
        super().__init__(
            local_multiple * multiple_of,
            out_features,
            bias=bias and rank == 0,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        """
        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
        a reduce_scatter of the result.
        """
        out = parallel_linear_func(x, self.weight, self.bias)
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)


class VocabParallelEmbedding(nn.Embedding):
    def __init__(
        self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs
    ):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if num_embeddings % world_size != 0:
                raise ValueError(
                    f"num_embeddings ({num_embeddings}) must be divisible by "
                    f"world_size ({world_size})"
                )
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")
        else:
            world_size = 1
        super().__init__(
            num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        if self.process_group is None:
            return super().forward(input)
        else:
            rank = torch.distributed.get_rank(self.process_group)
            vocab_size = self.num_embeddings
            vocab_start_index, vocab_end_index = (
                rank * vocab_size,
                (rank + 1) * vocab_size,
            )
            # Create a mask of valid vocab ids (1 means it needs to be masked).
            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
            input = input - vocab_start_index
            input[input_ids_mask] = 0
            embeddings = super().forward(input)
            embeddings[input_ids_mask] = 0.0
            return embeddings


class ColumnParallelEmbedding(nn.Embedding):
    def __init__(
        self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs
    ):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if embedding_dim % world_size != 0:
                raise ValueError(
                    f"embedding_dim ({embedding_dim}) must be divisible by "
                    f"world_size ({world_size})"
                )
        else:
            world_size = 1
        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)


class ParallelEmbeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        process_group,
        padding_idx=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size,
            embed_dim,
            padding_idx=padding_idx,
            process_group=process_group,
            **factory_kwargs,
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = ColumnParallelEmbedding(
                max_position_embeddings,
                embed_dim,
                process_group=process_group,
                **factory_kwargs,
            )

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(
                    seqlen, dtype=torch.long, device=input_ids.device
                )
            position_embeddings = self.position_embeddings(position_ids)
            if world_size <= 1:
                embeddings = embeddings + position_embeddings
            else:
                partition_dim = self.position_embeddings.embedding_dim
                rank = torch.distributed.get_rank(self.process_group)
                embeddings[
                    ..., rank * partition_dim : (rank + 1) * partition_dim
                ] += position_embeddings
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, "b s d -> (b s) d")
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return (
            embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
        )
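The two parallel linear layers above share one sharding rule: split the partitioned dimension into `multiple_of`-sized chunks and give the first `mod` ranks one extra chunk. A minimal sketch of that arithmetic, with illustrative sizes:

# Per-rank shard size, mirroring the div/mod logic in ColumnParallelLinear above.
def local_features(features: int, multiple_of: int, world_size: int, rank: int) -> int:
    multiple = features // multiple_of           # features must be a multiple of multiple_of
    div, mod = divmod(multiple, world_size)
    local_multiple = div + int(rank < mod)       # first `mod` ranks get one extra chunk
    return local_multiple * multiple_of

# Splitting 80 features (multiple_of=8) across 4 ranks -> [24, 24, 16, 16]
sizes = [local_features(80, 8, 4, r) for r in range(4)]
assert sum(sizes) == 80  # the shards always tile the full dimension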
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/models/__init__.py
DELETED
File without changes
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py
DELETED
@@ -1,338 +0,0 @@
# Copyright (c) 2023, Albert Gu, Tri Dao.

import math
from functools import partial
import json
import os
import copy

from collections import namedtuple

import torch
import torch.nn as nn

from .config_mamba import MambaConfig
from ..modules.mamba_simple import Mamba
from ..modules.mamba2 import Mamba2
from ..modules.mha import MHA
from ..modules.mlp import GatedMLP
from ..modules.block import Block
from ..utils.generation import GenerationMixin
from ..utils.hf import load_config_hf, load_state_dict_hf

try:
    from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


def create_block(
    d_model,
    d_intermediate,
    ssm_cfg=None,
    attn_layer_idx=None,
    attn_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    if ssm_cfg is None:
        ssm_cfg = {}
    if attn_layer_idx is None:
        attn_layer_idx = []
    if attn_cfg is None:
        attn_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    if layer_idx not in attn_layer_idx:
        # Create a copy of the config to modify
        ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
        ssm_layer = ssm_cfg.pop("layer", "Mamba1")
        if ssm_layer not in ["Mamba1", "Mamba2"]:
            raise ValueError(
                f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2"
            )
        mixer_cls = partial(
            Mamba2 if ssm_layer == "Mamba2" else Mamba,
            layer_idx=layer_idx,
            **ssm_cfg,
            **factory_kwargs,
        )
    else:
        mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    if d_intermediate == 0:
        mlp_cls = nn.Identity
    else:
        mlp_cls = partial(
            GatedMLP,
            hidden_features=d_intermediate,
            out_features=d_model,
            **factory_kwargs,
        )
    block = Block(
        d_model,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   > -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)


class MixerModel(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_layer: int,
        d_intermediate: int,
        vocab_size: int,
        ssm_cfg=None,
        attn_layer_idx=None,
        attn_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        self.fused_add_norm = fused_add_norm
        if self.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    d_intermediate=d_intermediate,
                    ssm_cfg=ssm_cfg,
                    attn_layer_idx=attn_layer_idx,
                    attn_cfg=attn_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
                n_residuals_per_layer=(
                    1 if d_intermediate == 0 else 2
                ),  # 2 if we have MLP
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(
                batch_size, max_seqlen, dtype=dtype, **kwargs
            )
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, inference_params=None, **mixer_kwargs):
        hidden_states = self.embedding(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states,
                residual,
                inference_params=inference_params,
                **mixer_kwargs,
            )
        if not self.fused_add_norm:
            residual = (
                (hidden_states + residual) if residual is not None else hidden_states
            )
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm=False here since we don't need the residual
            hidden_states = layer_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm_f, RMSNorm),
            )
        return hidden_states


class MambaLMHeadModel(nn.Module, GenerationMixin):

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
    ) -> None:
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        d_intermediate = config.d_intermediate
        vocab_size = config.vocab_size
        ssm_cfg = config.ssm_cfg
        attn_layer_idx = config.attn_layer_idx
        attn_cfg = config.attn_cfg
        rms_norm = config.rms_norm
        residual_in_fp32 = config.residual_in_fp32
        fused_add_norm = config.fused_add_norm
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (
                vocab_size % pad_vocab_size_multiple
            )
        self.backbone = MixerModel(
            d_model=d_model,
            n_layer=n_layer,
            d_intermediate=d_intermediate,
            vocab_size=vocab_size,
            ssm_cfg=ssm_cfg,
            attn_layer_idx=attn_layer_idx,
            attn_cfg=attn_cfg,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        if self.config.tie_embeddings:
            self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.backbone.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self,
        input_ids,
        position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        **mixer_kwargs,
    ):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        hidden_states = self.backbone(
            input_ids, inference_params=inference_params, **mixer_kwargs
        )
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)
        model = cls(config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(
            load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
        )
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory.
        """
        # Ensure save_directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f, indent=4)
|
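For orientation, the deleted MambaLMHeadModel above is the usual generation entry point. A minimal usage sketch (not part of the diff; assumes a CUDA machine with this wheel installed, and the checkpoint name is illustrative rather than taken from this commit):

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m", device="cuda", dtype=torch.float16
)
input_ids = torch.randint(0, model.config.vocab_size, (1, 16), device="cuda")
out = model(input_ids, num_last_tokens=1)  # keep only the final position's logits
print(out.logits.shape)                    # (1, 1, padded vocab size)
model.save_pretrained("./mamba-ckpt")      # writes pytorch_model.bin + config.json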
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/modules/__init__.py
DELETED
File without changes

build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/__init__.py
DELETED
File without changes

build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/selective_scan_interface.py
DELETED
@@ -1,659 +0,0 @@
# Copyright (c) 2023, Tri Dao, Albert Gu.

import torch
import torch.nn.functional as F
from ..utils.torch import custom_fwd, custom_bwd

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn
    import causal_conv1d_cuda
except ImportError:
    causal_conv1d_fn = None
    causal_conv1d_cuda = None

from .triton.layer_norm import _layer_norm_fwd

from .._ops import ops


class SelectiveScanFn(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        u,
        delta,
        A,
        B,
        C,
        D=None,
        z=None,
        delta_bias=None,
        delta_softplus=False,
        return_last_state=False,
    ):
        if u.stride(-1) != 1:
            u = u.contiguous()
        if delta.stride(-1) != 1:
            delta = delta.contiguous()
        if D is not None:
            D = D.contiguous()
        if B.stride(-1) != 1:
            B = B.contiguous()
        if C.stride(-1) != 1:
            C = C.contiguous()
        if z is not None and z.stride(-1) != 1:
            z = z.contiguous()
        if B.dim() == 3:
            B = rearrange(B, "b dstate l -> b 1 dstate l")
            ctx.squeeze_B = True
        if C.dim() == 3:
            C = rearrange(C, "b dstate l -> b 1 dstate l")
            ctx.squeeze_C = True
        out, x, *rest = ops.selective_scan_fwd(
            u, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.has_z = z is not None
        last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
        if not ctx.has_z:
            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
            return out if not return_last_state else (out, last_state)
        else:
            ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
            out_z = rest[0]
            return out_z if not return_last_state else (out_z, last_state)

    @staticmethod
    def backward(ctx, dout, *args):
        if not ctx.has_z:
            u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
            z = None
            out = None
        else:
            u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        # Here we just pass in None and dz will be allocated in the C++ code.
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = ops.selective_scan_bwd(
            u,
            delta,
            A,
            B,
            C,
            D,
            z,
            delta_bias,
            dout,
            x,
            out,
            None,
            ctx.delta_softplus,
            False,  # option to recompute out_z, not used here
        )
        dz = rest[0] if ctx.has_z else None
        dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
        dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
        return (
            du,
            ddelta,
            dA,
            dB,
            dC,
            dD if D is not None else None,
            dz,
            ddelta_bias if delta_bias is not None else None,
            None,
            None,
        )
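The repeated stride(-1) checks in SelectiveScanFn exist because the CUDA kernel requires the last dimension to be contiguous in memory. A standalone illustration of what the guard catches (not part of the deleted file):

import torch

t = torch.randn(4, 8).t()  # transposed view, shape (8, 4)
print(t.stride(-1))        # 8: last-dim elements are not adjacent in memory
t = t.contiguous()         # copy into a row-major layout
print(t.stride(-1))        # 1: safe to hand to the kernel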
def rms_norm_forward(
    x,
    weight,
    bias,
    eps=1e-6,
    is_rms_norm=True,
):
    # x (b l) d
    if x.stride(-1) != 1:
        x = x.contiguous()
    weight = weight.contiguous()
    if bias is not None:
        bias = bias.contiguous()
    y = _layer_norm_fwd(
        x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm
    )[0]
    # y (b l) d
    return y


def selective_scan_fn(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    delta_softplus=False,
    return_last_state=False,
):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return SelectiveScanFn.apply(
        u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state
    )


def selective_scan_ref(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    delta_softplus=False,
    return_last_state=False,
):
    """
    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    out: r(B D L)
    last_state (optional): r(B D dstate) or c(B D dstate)
    """
    dtype_in = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    if A.is_complex():
        if is_variable_B:
            B = torch.view_as_complex(
                rearrange(B.float(), "... (L two) -> ... L two", two=2)
            )
        if is_variable_C:
            C = torch.view_as_complex(
                rearrange(C.float(), "... (L two) -> ... L two", two=2)
            )
    else:
        B = B.float()
        C = C.float()
    x = A.new_zeros((batch, dim, dstate))
    ys = []
    deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
    if not is_variable_B:
        deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
    else:
        if B.dim() == 3:
            deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
        else:
            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
            deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
    if is_variable_C and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
    last_state = None
    for i in range(u.shape[2]):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = torch.einsum("bdn,dn->bd", x, C)
        else:
            if C.dim() == 3:
                y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
            else:
                y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
        if i == u.shape[2] - 1:
            last_state = x
        if y.is_complex():
            y = y.real * 2
        ys.append(y)
    y = torch.stack(ys, dim=2)  # (batch dim L)
    out = y if D is None else y + u * rearrange(D, "d -> d 1")
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)
    return out if not return_last_state else (out, last_state)
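To make the shape annotations in the selective_scan_ref docstring concrete, here is a small CPU-only call with variable (input-dependent) B and C; the sizes are arbitrary and the snippet is not part of the deleted file:

import torch

bsz, dim, dstate, seqlen = 2, 4, 16, 32
u = torch.randn(bsz, dim, seqlen)          # r(B D L)
delta = torch.rand(bsz, dim, seqlen)       # r(B D L)
A = -torch.rand(dim, dstate)               # r(D N); negative entries keep exp(delta*A) <= 1
B = torch.randn(bsz, dstate, seqlen)       # r(B N L), variable B
C = torch.randn(bsz, dstate, seqlen)       # r(B N L), variable C
D = torch.randn(dim)                       # r(D)
out, last_state = selective_scan_ref(
    u, delta, A, B, C, D=D, delta_softplus=True, return_last_state=True
)
print(out.shape)         # torch.Size([2, 4, 32]) == r(B D L)
print(last_state.shape)  # torch.Size([2, 4, 16]) == (batch, dim, dstate)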
class MambaInnerFn(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        xz,
        conv1d_weight,
        conv1d_bias,
        x_proj_weight,
        delta_proj_weight,
        out_proj_weight,
        out_proj_bias,
        A,
        B=None,
        C=None,
        D=None,
        delta_bias=None,
        B_proj_bias=None,
        C_proj_bias=None,
        delta_softplus=True,
        checkpoint_lvl=1,
        b_rms_weight=None,
        c_rms_weight=None,
        dt_rms_weight=None,
        b_c_dt_rms_eps=1e-6,
    ):
        """
        xz: (batch, dim, seqlen)
        """
        assert (
            causal_conv1d_cuda is not None
        ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if torch.is_autocast_enabled():
            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            delta_proj_weight = delta_proj_weight.to(
                dtype=torch.get_autocast_gpu_dtype()
            )
            out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_bias = (
                out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
                if out_proj_bias is not None
                else None
            )
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
            x, conv1d_weight, conv1d_bias, None, None, None, True
        )
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(
            rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight
        )  # (bl d)
        delta = rearrange(
            delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
        )
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:  # variable B
            B = x_dbl[:, delta_rank : delta_rank + d_state]  # (bl dstate)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                B = rearrange(
                    B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
                ).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl dstate)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                C = rearrange(
                    C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
                ).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()

        if b_rms_weight is not None:
            B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
            B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
            B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
        if c_rms_weight is not None:
            C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
            C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
            C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
        if dt_rms_weight is not None:
            delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
            delta = rms_norm_forward(
                delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps
            )
            delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()

        out, scan_intermediates, out_z = ops.selective_scan_fwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.out_proj_bias_is_None = out_proj_bias is None
        ctx.checkpoint_lvl = checkpoint_lvl
        ctx.b_rms_weight = b_rms_weight
        ctx.c_rms_weight = c_rms_weight
        ctx.dt_rms_weight = dt_rms_weight
        ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
        if (
            checkpoint_lvl >= 1
        ):  # Will recompute conv1d_out and delta in the backward pass
            conv1d_out, delta = None, None
        ctx.save_for_backward(
            xz,
            conv1d_weight,
            conv1d_bias,
            x_dbl,
            x_proj_weight,
            delta_proj_weight,
            out_proj_weight,
            conv1d_out,
            delta,
            A,
            B,
            C,
            D,
            delta_bias,
            scan_intermediates,
            b_rms_weight,
            c_rms_weight,
            dt_rms_weight,
            out,
        )
        return F.linear(
            rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias
        )

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        # dout: (batch, seqlen, dim)
        assert (
            causal_conv1d_cuda is not None
        ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
        (
            xz,
            conv1d_weight,
            conv1d_bias,
            x_dbl,
            x_proj_weight,
            delta_proj_weight,
            out_proj_weight,
            conv1d_out,
            delta,
            A,
            B,
            C,
            D,
            delta_bias,
            scan_intermediates,
            b_rms_weight,
            c_rms_weight,
            dt_rms_weight,
            out,
        ) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if ctx.checkpoint_lvl == 1:
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
                x, conv1d_weight, conv1d_bias, None, None, None, True
            )
            delta = rearrange(
                delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
            )
            if dt_rms_weight is not None:
                delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
                delta = rms_norm_forward(
                    delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps
                )
                delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
            if b_rms_weight is not None:
                # Recompute & RMSNorm B
                B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
                B = rms_norm_forward(B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps)
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            if c_rms_weight is not None:
                # Recompute & RMSNorm C
                C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
                C = rms_norm_forward(C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps)
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()

        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
        dx, dz = dxz.chunk(2, dim=1)
        dout = rearrange(dout, "b l e -> e (b l)")
        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = (
            ops.selective_scan_bwd(
                conv1d_out,
                delta,
                A,
                B,
                C,
                D,
                z,
                delta_bias,
                dout_y,
                scan_intermediates,
                out,
                dz,
                ctx.delta_softplus,
                True,  # option to recompute out_z
            )
        )
        dout_proj_weight = torch.einsum(
            "eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")
        )
        dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
        dD = dD if D is not None else None
        dx_dbl = torch.empty_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(
                    dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
                ).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank : delta_rank + d_state] = dB  # (bl d)
            dB = None
        dC_proj_bias = None
        if ctx.is_variable_C:
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(
                    dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
                ).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC  # (bl d)
            dC = None
        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum(
            "Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")
        )
        dconv1d_out = torch.addmm(
            dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out
        )
        dconv1d_out = rearrange(
            dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]
        )
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
            x,
            conv1d_weight,
            conv1d_bias,
            dconv1d_out,
            None,
            None,
            None,
            dx,
            False,
            True,
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        return (
            dxz,
            dconv1d_weight,
            dconv1d_bias,
            dx_proj_weight,
            ddelta_proj_weight,
            dout_proj_weight,
            dout_proj_bias,
            dA,
            dB,
            dC,
            dD,
            ddelta_bias if delta_bias is not None else None,
            dB_proj_bias,
            dC_proj_bias,
            # 6 Nones for delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
            None,
            None,
            None,
            None,
            None,
            None,
        )


def mamba_inner_fn(
    xz,
    conv1d_weight,
    conv1d_bias,
    x_proj_weight,
    delta_proj_weight,
    out_proj_weight,
    out_proj_bias,
    A,
    B=None,
    C=None,
    D=None,
    delta_bias=None,
    B_proj_bias=None,
    C_proj_bias=None,
    delta_softplus=True,
    checkpoint_lvl=1,
    b_rms_weight=None,
    c_rms_weight=None,
    dt_rms_weight=None,
    b_c_dt_rms_eps=1e-6,
):
    return MambaInnerFn.apply(
        xz,
        conv1d_weight,
        conv1d_bias,
        x_proj_weight,
        delta_proj_weight,
        out_proj_weight,
        out_proj_bias,
        A,
        B,
        C,
        D,
        delta_bias,
        B_proj_bias,
        C_proj_bias,
        delta_softplus,
        checkpoint_lvl,
        b_rms_weight,
        c_rms_weight,
        dt_rms_weight,
        b_c_dt_rms_eps,
    )


def mamba_inner_ref(
    xz,
    conv1d_weight,
    conv1d_bias,
    x_proj_weight,
    delta_proj_weight,
    out_proj_weight,
    out_proj_bias,
    A,
    B=None,
    C=None,
    D=None,
    delta_bias=None,
    B_proj_bias=None,
    C_proj_bias=None,
    delta_softplus=True,
):
    assert (
        causal_conv1d_fn is not None
    ), "causal_conv1d_fn is not available. Please install causal-conv1d."
    L = xz.shape[-1]
    delta_rank = delta_proj_weight.shape[1]
    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
    x, z = xz.chunk(2, dim=1)
    x = causal_conv1d_fn(
        x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu"
    )
    # We're being very careful here about the layout, to avoid extra transposes.
    # We want delta to have d as the slowest moving dimension
    # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
    x_dbl = F.linear(rearrange(x, "b d l -> (b l) d"), x_proj_weight)  # (bl d)
    delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
    delta = rearrange(delta, "d (b l) -> b d l", l=L)
    if B is None:  # variable B
        B = x_dbl[:, delta_rank : delta_rank + d_state]  # (bl d)
        if B_proj_bias is not None:
            B = B + B_proj_bias.to(dtype=B.dtype)
        if not A.is_complex():
            B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            B = rearrange(
                B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
            ).contiguous()
    if C is None:  # variable C
        C = x_dbl[:, -d_state:]  # (bl d)
        if C_proj_bias is not None:
            C = C + C_proj_bias.to(dtype=C.dtype)
        if not A.is_complex():
            C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            C = rearrange(
                C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
            ).contiguous()
    y = selective_scan_fn(
        x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True
    )
    return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
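Both MambaInnerFn.forward and mamba_inner_ref slice the fused x_proj output the same way: per position, x_dbl packs the delta input (width delta_rank), then B, then C (width d_state each). A toy illustration of that slicing (not part of the deleted file):

import torch

bl, delta_rank, d_state = 6, 3, 4                 # arbitrary toy sizes
x_dbl = torch.randn(bl, delta_rank + 2 * d_state)
delta_in = x_dbl[:, :delta_rank]                  # later multiplied by delta_proj_weight
B = x_dbl[:, delta_rank : delta_rank + d_state]
C = x_dbl[:, -d_state:]
print(delta_in.shape, B.shape, C.shape)           # (6, 3) (6, 4) (6, 4)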
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/__init__.py
DELETED
File without changes

build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/layer_norm.py
DELETED
@@ -1,1166 +0,0 @@
# Copyright (c) 2024, Tri Dao.
# Implement dropout + residual + layer_norm / rms_norm.

# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.

import math
import warnings

import torch
import torch.nn.functional as F
from ...utils.torch import custom_bwd, custom_fwd

import triton
import triton.language as tl


def layer_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = bias.float() if bias is not None else None
        residual = residual.float() if residual is not None else residual
        x1 = x1.float() if x1 is not None else None
        weight1 = weight1.float() if weight1 is not None else None
        bias1 = bias1.float() if bias1 is not None else None
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
        if x1 is not None:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    out = F.layer_norm(
        x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
    ).to(dtype)
    if weight1 is None:
        return out if not prenorm else (out, x)
    else:
        out1 = F.layer_norm(
            x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
        ).to(dtype)
        return (out, out1) if not prenorm else (out, out1, x)


def rms_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = bias.float() if bias is not None else None
        residual = residual.float() if residual is not None else residual
        x1 = x1.float() if x1 is not None else None
        weight1 = weight1.float() if weight1 is not None else None
        bias1 = bias1.float() if bias1 is not None else None
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
        if x1 is not None:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
    out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(
        dtype
    )
    if weight1 is None:
        return out if not prenorm else (out, x)
    else:
        out1 = (
            (x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)
        ).to(dtype)
        return (out, out1) if not prenorm else (out, out1, x)
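A quick numerical sanity check for rms_norm_ref against a hand-rolled RMSNorm, with no residual, dropout, or second weight set (not part of the deleted file):

import torch

x = torch.randn(2, 8)
w = torch.randn(8)
ref = rms_norm_ref(x, w, None, eps=1e-6)
manual = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6) * w
print(torch.allclose(ref, manual, atol=1e-6))  # True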
def config_prune(configs):

    if torch.version.hip:
        try:
            # set warp size based on gcn architecture
            gcn_arch_name = torch.cuda.get_device_properties(0).gcnArchName
            if "gfx10" in gcn_arch_name or "gfx11" in gcn_arch_name:
                # radeon
                warp_size = 32
            else:
                # instinct
                warp_size = 64
        except AttributeError as e:
            # fall back to crude method to set warp size
            device_name = torch.cuda.get_device_properties(0).name
            if "instinct" in device_name.lower():
                warp_size = 64
            else:
                warp_size = 32
            warnings.warn(
                f"{e}, warp size set to {warp_size} based on device name: {device_name}",
                UserWarning,
            )

    else:
        # cuda
        warp_size = 32

    max_block_sz = 1024
    max_num_warps = max_block_sz // warp_size
    pruned_configs = [config for config in configs if config.num_warps <= max_num_warps]
    return pruned_configs


configs_autotune = [
    triton.Config({}, num_warps=1),
    triton.Config({}, num_warps=2),
    triton.Config({}, num_warps=4),
    triton.Config({}, num_warps=8),
    triton.Config({}, num_warps=16),
    triton.Config({}, num_warps=32),
]

pruned_configs_autotune = config_prune(configs_autotune)
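The arithmetic behind config_prune: with a 1024-thread block, the warp-count bound is 1024 // warp_size, so the pruning only removes configs on wave64 (Instinct) hardware. A standalone check (not part of the deleted file):

for warp_size in (32, 64):  # CUDA/RDNA vs. Instinct
    max_num_warps = 1024 // warp_size
    kept = [w for w in (1, 2, 4, 8, 16, 32) if w <= max_num_warps]
    print(warp_size, kept)
# 32 -> all six configs survive
# 64 -> the num_warps=32 config is pruned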
@triton.autotune(
    configs=pruned_configs_autotune,
    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    X1,
    W1,
    B1,
    Y1,
    RESIDUAL_OUT,  # pointer to the residual
    ROWSCALE,
    SEEDS,  # Dropout seeds for each row
    DROPOUT_MASK,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    stride_x1_row,
    stride_y1_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,  # Dropout probability
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    STORE_DROPOUT_MASK: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_X1: tl.constexpr,
    HAS_W1: tl.constexpr,
    HAS_B1: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    if HAS_X1:
        X1 += row * stride_x1_row
    if HAS_W1:
        Y1 += row * stride_y1_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_ROWSCALE:
        rowscale = tl.load(ROWSCALE + row).to(tl.float32)
        x *= rowscale
    if HAS_DROPOUT:
        # Compute dropout mask
        # 7 rounds is good enough, and reduces register pressure
        keep_mask = (
            tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
        )
        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
        if STORE_DROPOUT_MASK:
            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
    if HAS_X1:
        x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
            x1 *= rowscale
        if HAS_DROPOUT:
            # Compute dropout mask
            # 7 rounds is good enough, and reduces register pressure
            keep_mask = (
                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
                > dropout_p
            )
            x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
            if STORE_DROPOUT_MASK:
                tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
        x += x1
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
    if STORE_RESIDUAL_OUT:
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    # Write output
    tl.store(Y + cols, y, mask=mask)
    if HAS_W1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
        if HAS_B1:
            b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
        y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
        tl.store(Y1 + cols, y1, mask=mask)


def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    dropout_p=0.0,
    rowscale=None,
    out_dtype=None,
    residual_dtype=None,
    is_rms_norm=False,
    return_dropout_mask=False,
):
    if residual is not None:
        residual_dtype = residual.dtype
    M, N = x.shape
    assert x.stride(-1) == 1
    if residual is not None:
        assert residual.stride(-1) == 1
        assert residual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if x1 is not None:
        assert x1.shape == x.shape
        assert rowscale is None
        assert x1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    assert y.stride(-1) == 1
    if weight1 is not None:
        y1 = torch.empty_like(y)
        assert y1.stride(-1) == 1
    else:
        y1 = None
    if (
        residual is not None
        or (residual_dtype is not None and residual_dtype != x.dtype)
        or dropout_p > 0.0
        or rowscale is not None
        or x1 is not None
    ):
        residual_out = torch.empty(
            M,
            N,
            device=x.device,
            dtype=residual_dtype if residual_dtype is not None else x.dtype,
        )
        assert residual_out.stride(-1) == 1
    else:
        residual_out = None
    mean = (
        torch.empty((M,), dtype=torch.float32, device=x.device)
        if not is_rms_norm
        else None
    )
    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
    if dropout_p > 0.0:
        seeds = torch.randint(
            2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
        )
    else:
        seeds = None
    if return_dropout_mask and dropout_p > 0.0:
        dropout_mask = torch.empty(
            M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool
        )
    else:
        dropout_mask = None
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[(M,)](
            x,
            y,
            weight,
            bias,
            residual,
            x1,
            weight1,
            bias1,
            y1,
            residual_out,
            rowscale,
            seeds,
            dropout_mask,
            mean,
            rstd,
            x.stride(0),
            y.stride(0),
            residual.stride(0) if residual is not None else 0,
            residual_out.stride(0) if residual_out is not None else 0,
            x1.stride(0) if x1 is not None else 0,
            y1.stride(0) if y1 is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            is_rms_norm,
            BLOCK_N,
            residual is not None,
            residual_out is not None,
            bias is not None,
            dropout_p > 0.0,
            dropout_mask is not None,
            rowscale is not None,
        )
    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
    if dropout_mask is not None and x1 is not None:
        dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
    else:
        dropout_mask1 = None
    return (
        y,
        y1,
        mean,
        rstd,
        residual_out if residual_out is not None else x,
        seeds,
        dropout_mask,
        dropout_mask1,
    )
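The "64KB per feature" guard in _layer_norm_fwd (and in _layer_norm_bwd below) reduces to simple arithmetic; here triton.next_power_of_2 is mimicked with math.log2, which is exact for the sizes of interest (sketch only, not part of the deleted file):

import math

def block_n(N, element_size):
    max_fused = 65536 // element_size             # MAX_FUSED_SIZE
    return min(max_fused, 1 << math.ceil(math.log2(N)))

print(block_n(4096, 2))   # fp16: BLOCK_N = 4096, fits
print(block_n(20000, 4))  # fp32: BLOCK_N = 16384 < N, so the wrapper raises RuntimeError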
| 441 |
-
@triton.autotune(
|
| 442 |
-
configs=pruned_configs_autotune,
|
| 443 |
-
key=[
|
| 444 |
-
"N",
|
| 445 |
-
"HAS_DRESIDUAL",
|
| 446 |
-
"STORE_DRESIDUAL",
|
| 447 |
-
"IS_RMS_NORM",
|
| 448 |
-
"HAS_BIAS",
|
| 449 |
-
"HAS_DROPOUT",
|
| 450 |
-
],
|
| 451 |
-
)
|
| 452 |
-
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
| 453 |
-
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
|
| 454 |
-
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
|
| 455 |
-
@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
|
| 456 |
-
@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
|
| 457 |
-
@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
|
| 458 |
-
@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
|
| 459 |
-
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
|
| 460 |
-
@triton.jit
|
| 461 |
-
def _layer_norm_bwd_kernel(
|
| 462 |
-
X, # pointer to the input
|
| 463 |
-
W, # pointer to the weights
|
| 464 |
-
B, # pointer to the biases
|
| 465 |
-
Y, # pointer to the output to be recomputed
|
| 466 |
-
DY, # pointer to the output gradient
|
| 467 |
-
DX, # pointer to the input gradient
|
| 468 |
-
DW, # pointer to the partial sum of weights gradient
|
| 469 |
-
DB, # pointer to the partial sum of biases gradient
|
| 470 |
-
DRESIDUAL,
|
| 471 |
-
W1,
|
| 472 |
-
DY1,
|
| 473 |
-
DX1,
|
| 474 |
-
DW1,
|
| 475 |
-
DB1,
|
| 476 |
-
DRESIDUAL_IN,
|
| 477 |
-
ROWSCALE,
|
| 478 |
-
SEEDS,
|
| 479 |
-
Mean, # pointer to the mean
|
| 480 |
-
Rstd, # pointer to the 1/std
|
| 481 |
-
stride_x_row, # how much to increase the pointer when moving by 1 row
|
| 482 |
-
stride_y_row,
|
| 483 |
-
stride_dy_row,
|
| 484 |
-
stride_dx_row,
|
| 485 |
-
stride_dres_row,
|
| 486 |
-
stride_dy1_row,
|
| 487 |
-
stride_dx1_row,
|
| 488 |
-
stride_dres_in_row,
|
| 489 |
-
M, # number of rows in X
|
| 490 |
-
N, # number of columns in X
|
| 491 |
-
eps, # epsilon to avoid division by zero
|
| 492 |
-
dropout_p,
|
| 493 |
-
rows_per_program,
|
| 494 |
-
IS_RMS_NORM: tl.constexpr,
|
| 495 |
-
BLOCK_N: tl.constexpr,
|
| 496 |
-
HAS_DRESIDUAL: tl.constexpr,
|
| 497 |
-
STORE_DRESIDUAL: tl.constexpr,
|
| 498 |
-
HAS_BIAS: tl.constexpr,
|
| 499 |
-
HAS_DROPOUT: tl.constexpr,
|
| 500 |
-
HAS_ROWSCALE: tl.constexpr,
|
| 501 |
-
HAS_DY1: tl.constexpr,
|
| 502 |
-
HAS_DX1: tl.constexpr,
|
| 503 |
-
HAS_B1: tl.constexpr,
|
| 504 |
-
RECOMPUTE_OUTPUT: tl.constexpr,
|
| 505 |
-
):
|
| 506 |
-
# Map the program id to the elements of X, DX, and DY it should compute.
|
| 507 |
-
row_block_id = tl.program_id(0)
|
| 508 |
-
row_start = row_block_id * rows_per_program
|
| 509 |
-
# Do not early exit if row_start >= M, because we need to write DW and DB
|
| 510 |
-
cols = tl.arange(0, BLOCK_N)
|
| 511 |
-
mask = cols < N
|
| 512 |
-
X += row_start * stride_x_row
|
| 513 |
-
if HAS_DRESIDUAL:
|
| 514 |
-
DRESIDUAL += row_start * stride_dres_row
|
| 515 |
-
if STORE_DRESIDUAL:
|
| 516 |
-
DRESIDUAL_IN += row_start * stride_dres_in_row
|
| 517 |
-
DY += row_start * stride_dy_row
|
| 518 |
-
DX += row_start * stride_dx_row
|
| 519 |
-
if HAS_DY1:
|
| 520 |
-
DY1 += row_start * stride_dy1_row
|
| 521 |
-
if HAS_DX1:
|
| 522 |
-
DX1 += row_start * stride_dx1_row
|
| 523 |
-
if RECOMPUTE_OUTPUT:
|
| 524 |
-
Y += row_start * stride_y_row
|
| 525 |
-
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
| 526 |
-
if RECOMPUTE_OUTPUT and HAS_BIAS:
|
| 527 |
-
b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
|
| 528 |
-
if HAS_DY1:
|
| 529 |
-
w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
|
| 530 |
-
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
| 531 |
-
if HAS_BIAS:
|
| 532 |
-
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
| 533 |
-
if HAS_DY1:
|
| 534 |
-
dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
| 535 |
-
if HAS_B1:
|
| 536 |
-
db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
| 537 |
-
row_end = min((row_block_id + 1) * rows_per_program, M)
|
| 538 |
-
for row in range(row_start, row_end):
|
| 539 |
-
# Load data to SRAM
|
| 540 |
-
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
| 541 |
-
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
|
| 542 |
-
if HAS_DY1:
|
| 543 |
-
dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
|
| 544 |
-
if not IS_RMS_NORM:
|
| 545 |
-
mean = tl.load(Mean + row)
|
| 546 |
-
rstd = tl.load(Rstd + row)
|
| 547 |
-
# Compute dx
|
| 548 |
-
xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
| 549 |
-
xhat = tl.where(mask, xhat, 0.0)
|
| 550 |
-
if RECOMPUTE_OUTPUT:
|
| 551 |
-
y = xhat * w + b if HAS_BIAS else xhat * w
|
| 552 |
-
tl.store(Y + cols, y, mask=mask)
|
| 553 |
-
wdy = w * dy
|
| 554 |
-
dw += dy * xhat
|
| 555 |
-
if HAS_BIAS:
|
| 556 |
-
db += dy
|
| 557 |
-
if HAS_DY1:
|
| 558 |
-
wdy += w1 * dy1
|
| 559 |
-
dw1 += dy1 * xhat
|
| 560 |
-
if HAS_B1:
|
| 561 |
-
db1 += dy1
|
| 562 |
-
if not IS_RMS_NORM:
|
| 563 |
-
c1 = tl.sum(xhat * wdy, axis=0) / N
|
| 564 |
-
c2 = tl.sum(wdy, axis=0) / N
|
| 565 |
-
dx = (wdy - (xhat * c1 + c2)) * rstd
|
| 566 |
-
else:
|
| 567 |
-
c1 = tl.sum(xhat * wdy, axis=0) / N
|
| 568 |
-
dx = (wdy - xhat * c1) * rstd
|
| 569 |
-
if HAS_DRESIDUAL:
|
| 570 |
-
dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
|
| 571 |
-
dx += dres
|
| 572 |
-
# Write dx
|
| 573 |
-
if STORE_DRESIDUAL:
|
| 574 |
-
tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
|
| 575 |
-
if HAS_DX1:
|
| 576 |
-
if HAS_DROPOUT:
|
| 577 |
-
keep_mask = (
|
| 578 |
-
tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
|
| 579 |
-
> dropout_p
|
| 580 |
-
)
|
| 581 |
-
dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
|
| 582 |
-
else:
|
| 583 |
-
dx1 = dx
|
| 584 |
-
tl.store(DX1 + cols, dx1, mask=mask)
|
| 585 |
-
if HAS_DROPOUT:
|
| 586 |
-
keep_mask = (
|
| 587 |
-
tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7)
|
| 588 |
-
> dropout_p
|
| 589 |
-
)
|
| 590 |
-
dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
|
| 591 |
-
if HAS_ROWSCALE:
|
| 592 |
-
rowscale = tl.load(ROWSCALE + row).to(tl.float32)
|
| 593 |
-
dx *= rowscale
|
| 594 |
-
tl.store(DX + cols, dx, mask=mask)
|
| 595 |
-
|
| 596 |
-
X += stride_x_row
|
| 597 |
-
if HAS_DRESIDUAL:
|
| 598 |
-
DRESIDUAL += stride_dres_row
|
| 599 |
-
if STORE_DRESIDUAL:
|
| 600 |
-
DRESIDUAL_IN += stride_dres_in_row
|
| 601 |
-
if RECOMPUTE_OUTPUT:
|
| 602 |
-
Y += stride_y_row
|
| 603 |
-
DY += stride_dy_row
|
| 604 |
-
DX += stride_dx_row
|
| 605 |
-
if HAS_DY1:
|
| 606 |
-
DY1 += stride_dy1_row
|
| 607 |
-
if HAS_DX1:
|
| 608 |
-
DX1 += stride_dx1_row
|
| 609 |
-
tl.store(DW + row_block_id * N + cols, dw, mask=mask)
|
| 610 |
-
if HAS_BIAS:
|
| 611 |
-
tl.store(DB + row_block_id * N + cols, db, mask=mask)
|
| 612 |
-
if HAS_DY1:
|
| 613 |
-
tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
|
| 614 |
-
if HAS_B1:
|
| 615 |
-
tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
|
| 616 |
-
|
| 617 |
-
|
def _layer_norm_bwd(
    dy,
    x,
    weight,
    bias,
    eps,
    mean,
    rstd,
    dresidual=None,
    dy1=None,
    weight1=None,
    bias1=None,
    seeds=None,
    dropout_p=0.0,
    rowscale=None,
    has_residual=False,
    has_x1=False,
    is_rms_norm=False,
    x_dtype=None,
    recompute_output=False,
):
    M, N = x.shape
    assert x.stride(-1) == 1
    assert dy.stride(-1) == 1
    assert dy.shape == (M, N)
    if dresidual is not None:
        assert dresidual.stride(-1) == 1
        assert dresidual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if dy1 is not None:
        assert weight1 is not None
        assert dy1.shape == dy.shape
        assert dy1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if seeds is not None:
        assert seeds.is_contiguous()
        assert seeds.shape == (M if not has_x1 else M * 2,)
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    dx = (
        torch.empty_like(x)
        if x_dtype is None
        else torch.empty(M, N, dtype=x_dtype, device=x.device)
    )
    dresidual_in = (
        torch.empty_like(x)
        if has_residual
        and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
        else None
    )
    dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
    y = (
        torch.empty(M, N, dtype=dy.dtype, device=dy.device)
        if recompute_output
        else None
    )
    if recompute_output:
        assert (
            weight1 is None
        ), "recompute_output is not supported with parallel LayerNorm"

    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
    _db = (
        torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
        if bias is not None
        else None
    )
    _dw1 = torch.empty_like(_dw) if weight1 is not None else None
    _db1 = torch.empty_like(_db) if bias1 is not None else None
    rows_per_program = math.ceil(M / sm_count)
    grid = (sm_count,)
    with torch.cuda.device(x.device.index):
        _layer_norm_bwd_kernel[grid](
            x,
            weight,
            bias,
            y,
            dy,
            dx,
            _dw,
            _db,
            dresidual,
            weight1,
            dy1,
            dx1,
            _dw1,
            _db1,
            dresidual_in,
            rowscale,
            seeds,
            mean,
            rstd,
            x.stride(0),
            0 if not recompute_output else y.stride(0),
            dy.stride(0),
            dx.stride(0),
            dresidual.stride(0) if dresidual is not None else 0,
            dy1.stride(0) if dy1 is not None else 0,
            dx1.stride(0) if dx1 is not None else 0,
            dresidual_in.stride(0) if dresidual_in is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            rows_per_program,
            is_rms_norm,
            BLOCK_N,
            dresidual is not None,
            dresidual_in is not None,
            bias is not None,
            dropout_p > 0.0,
        )
    dw = _dw.sum(0).to(weight.dtype)
    db = _db.sum(0).to(bias.dtype) if bias is not None else None
    dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
    db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
    # Don't need to compute dresidual_in separately in this case
    if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
        dresidual_in = dx
    if has_x1 and dropout_p == 0.0:
        dx1 = dx
    return (
        (dx, dw, db, dresidual_in, dx1, dw1, db1)
        if not recompute_output
        else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
    )

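Note: the wrapper above launches exactly one program per SM (`grid = (sm_count,)`); each program accumulates its weight and bias gradient partials into a private fp32 row of `_dw`/`_db`, and the host finishes the reduction with `_dw.sum(0)`. This avoids atomics and makes the weight gradients deterministic. Below is a minimal eager-mode sketch of the same pattern; the function name, shapes, and the simplified `dy * xhat` gradient are illustrative assumptions, not the kernel's API.

```python
import math
import torch

def split_row_weight_grad(dy, xhat, sm_count=108):
    """Deterministic dw reduction: one private partial row per 'program'."""
    M, N = dy.shape
    _dw = torch.zeros(sm_count, N, dtype=torch.float32)
    rows_per_program = math.ceil(M / sm_count)
    for pid in range(sm_count):  # stands in for the per-SM GPU programs
        sl = slice(pid * rows_per_program, min((pid + 1) * rows_per_program, M))
        _dw[pid] = (dy[sl].float() * xhat[sl].float()).sum(0)  # partial sum, no atomics
    return _dw.sum(0)  # host-side reduction, mirroring `_dw.sum(0)` above
```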
class LayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        residual=None,
        x1=None,
        weight1=None,
        bias1=None,
        eps=1e-6,
        dropout_p=0.0,
        rowscale=None,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
        return_dropout_mask=False,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        if x1 is not None:
            assert x1.shape == x_shape_og
            assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
            x1 = x1.reshape(-1, x1.shape[-1])
            if x1.stride(-1) != 1:
                x1 = x1.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        if weight1 is not None:
            weight1 = weight1.contiguous()
        if bias1 is not None:
            bias1 = bias1.contiguous()
        if rowscale is not None:
            rowscale = rowscale.reshape(-1).contiguous()
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = (
            _layer_norm_fwd(
                x,
                weight,
                bias,
                eps,
                residual,
                x1,
                weight1,
                bias1,
                dropout_p=dropout_p,
                rowscale=rowscale,
                residual_dtype=residual_dtype,
                is_rms_norm=is_rms_norm,
                return_dropout_mask=return_dropout_mask,
            )
        )
        ctx.save_for_backward(
            residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
        )
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.dropout_p = dropout_p
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.has_x1 = x1 is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        y = y.reshape(x_shape_og)
        y1 = y1.reshape(x_shape_og) if y1 is not None else None
        residual_out = (
            residual_out.reshape(x_shape_og) if residual_out is not None else None
        )
        dropout_mask = (
            dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
        )
        dropout_mask1 = (
            dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
        )
        if not return_dropout_mask:
            if weight1 is None:
                return y if not prenorm else (y, residual_out)
            else:
                return (y, y1) if not prenorm else (y, y1, residual_out)
        else:
            if weight1 is None:
                return (
                    (y, dropout_mask, dropout_mask1)
                    if not prenorm
                    else (y, residual_out, dropout_mask, dropout_mask1)
                )
            else:
                return (
                    (y, y1, dropout_mask, dropout_mask1)
                    if not prenorm
                    else (y, y1, residual_out, dropout_mask, dropout_mask1)
                )

    @staticmethod
    def backward(ctx, dy, *args):
        x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
        dy = dy.reshape(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if weight1 is not None:
            dy1, args = args[0], args[1:]
            dy1 = dy1.reshape(-1, dy1.shape[-1])
            if dy1.stride(-1) != 1:
                dy1 = dy1.contiguous()
            assert dy1.shape == x.shape
        else:
            dy1 = None
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            if dresidual.stride(-1) != 1:
                dresidual = dresidual.contiguous()
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
            dy,
            x,
            weight,
            bias,
            ctx.eps,
            mean,
            rstd,
            dresidual,
            dy1,
            weight1,
            bias1,
            seeds,
            ctx.dropout_p,
            rowscale,
            ctx.has_residual,
            ctx.has_x1,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
        )
        return (
            dx.reshape(ctx.x_shape_og),
            dw,
            db,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
            dw1,
            db1,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )

def layer_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
    return_dropout_mask=False,
):
    return LayerNormFn.apply(
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
        return_dropout_mask,
    )

def rms_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    return LayerNormFn.apply(
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        True,
        return_dropout_mask,
    )

class RMSNorm(torch.nn.Module):

    def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if dropout_p > 0.0:
            self.drop = torch.nn.Dropout(dropout_p)
        else:
            self.drop = None
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
        return rms_norm_fn(
            x,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
        )

-
class LayerNormLinearFn(torch.autograd.Function):
|
| 1029 |
-
@staticmethod
|
| 1030 |
-
@custom_fwd
|
| 1031 |
-
def forward(
|
| 1032 |
-
ctx,
|
| 1033 |
-
x,
|
| 1034 |
-
norm_weight,
|
| 1035 |
-
norm_bias,
|
| 1036 |
-
linear_weight,
|
| 1037 |
-
linear_bias,
|
| 1038 |
-
residual=None,
|
| 1039 |
-
eps=1e-6,
|
| 1040 |
-
prenorm=False,
|
| 1041 |
-
residual_in_fp32=False,
|
| 1042 |
-
is_rms_norm=False,
|
| 1043 |
-
):
|
| 1044 |
-
x_shape_og = x.shape
|
| 1045 |
-
# reshape input data into 2D tensor
|
| 1046 |
-
x = x.reshape(-1, x.shape[-1])
|
| 1047 |
-
if x.stride(-1) != 1:
|
| 1048 |
-
x = x.contiguous()
|
| 1049 |
-
if residual is not None:
|
| 1050 |
-
assert residual.shape == x_shape_og
|
| 1051 |
-
residual = residual.reshape(-1, residual.shape[-1])
|
| 1052 |
-
if residual.stride(-1) != 1:
|
| 1053 |
-
residual = residual.contiguous()
|
| 1054 |
-
norm_weight = norm_weight.contiguous()
|
| 1055 |
-
if norm_bias is not None:
|
| 1056 |
-
norm_bias = norm_bias.contiguous()
|
| 1057 |
-
residual_dtype = (
|
| 1058 |
-
residual.dtype
|
| 1059 |
-
if residual is not None
|
| 1060 |
-
else (torch.float32 if residual_in_fp32 else None)
|
| 1061 |
-
)
|
| 1062 |
-
y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
|
| 1063 |
-
x,
|
| 1064 |
-
norm_weight,
|
| 1065 |
-
norm_bias,
|
| 1066 |
-
eps,
|
| 1067 |
-
residual,
|
| 1068 |
-
out_dtype=(
|
| 1069 |
-
None
|
| 1070 |
-
if not torch.is_autocast_enabled()
|
| 1071 |
-
else torch.get_autocast_gpu_dtype()
|
| 1072 |
-
),
|
| 1073 |
-
residual_dtype=residual_dtype,
|
| 1074 |
-
is_rms_norm=is_rms_norm,
|
| 1075 |
-
)
|
| 1076 |
-
y = y.reshape(x_shape_og)
|
| 1077 |
-
dtype = (
|
| 1078 |
-
torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
|
| 1079 |
-
)
|
| 1080 |
-
linear_weight = linear_weight.to(dtype)
|
| 1081 |
-
linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
|
| 1082 |
-
out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
|
| 1083 |
-
# We don't store y, will be recomputed in the backward pass to save memory
|
| 1084 |
-
ctx.save_for_backward(
|
| 1085 |
-
residual_out, norm_weight, norm_bias, linear_weight, mean, rstd
|
| 1086 |
-
)
|
| 1087 |
-
ctx.x_shape_og = x_shape_og
|
| 1088 |
-
ctx.eps = eps
|
| 1089 |
-
ctx.is_rms_norm = is_rms_norm
|
| 1090 |
-
ctx.has_residual = residual is not None
|
| 1091 |
-
ctx.prenorm = prenorm
|
| 1092 |
-
ctx.x_dtype = x.dtype
|
| 1093 |
-
ctx.linear_bias_is_none = linear_bias is None
|
| 1094 |
-
return out if not prenorm else (out, residual_out.reshape(x_shape_og))
|
| 1095 |
-
|
| 1096 |
-
@staticmethod
|
| 1097 |
-
@custom_bwd
|
| 1098 |
-
def backward(ctx, dout, *args):
|
| 1099 |
-
x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
|
| 1100 |
-
dout = dout.reshape(-1, dout.shape[-1])
|
| 1101 |
-
dy = F.linear(dout, linear_weight.t())
|
| 1102 |
-
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
|
| 1103 |
-
if dy.stride(-1) != 1:
|
| 1104 |
-
dy = dy.contiguous()
|
| 1105 |
-
assert dy.shape == x.shape
|
| 1106 |
-
if ctx.prenorm:
|
| 1107 |
-
dresidual = args[0]
|
| 1108 |
-
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
| 1109 |
-
if dresidual.stride(-1) != 1:
|
| 1110 |
-
dresidual = dresidual.contiguous()
|
| 1111 |
-
assert dresidual.shape == x.shape
|
| 1112 |
-
else:
|
| 1113 |
-
dresidual = None
|
| 1114 |
-
dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
|
| 1115 |
-
dy,
|
| 1116 |
-
x,
|
| 1117 |
-
norm_weight,
|
| 1118 |
-
norm_bias,
|
| 1119 |
-
ctx.eps,
|
| 1120 |
-
mean,
|
| 1121 |
-
rstd,
|
| 1122 |
-
dresidual=dresidual,
|
| 1123 |
-
has_residual=ctx.has_residual,
|
| 1124 |
-
is_rms_norm=ctx.is_rms_norm,
|
| 1125 |
-
x_dtype=ctx.x_dtype,
|
| 1126 |
-
recompute_output=True,
|
| 1127 |
-
)
|
| 1128 |
-
dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
|
| 1129 |
-
return (
|
| 1130 |
-
dx.reshape(ctx.x_shape_og),
|
| 1131 |
-
dnorm_weight,
|
| 1132 |
-
dnorm_bias,
|
| 1133 |
-
dlinear_weight,
|
| 1134 |
-
dlinear_bias,
|
| 1135 |
-
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
| 1136 |
-
None,
|
| 1137 |
-
None,
|
| 1138 |
-
None,
|
| 1139 |
-
None,
|
| 1140 |
-
)
|
| 1141 |
-
|
| 1142 |
-
|
def layer_norm_linear_fn(
    x,
    norm_weight,
    norm_bias,
    linear_weight,
    linear_bias,
    residual=None,
    eps=1e-6,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    return LayerNormLinearFn.apply(
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual,
        eps,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
    )
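Note: a hedged usage sketch of the fused-residual API defined above (device, dtype, and shapes are assumptions). With `prenorm=True` the function also returns the updated residual stream, so the residual add and the normalization cost a single kernel launch per block.

```python
import torch

hidden = 2048
x = torch.randn(2, 128, hidden, device="cuda", dtype=torch.bfloat16)
weight = torch.ones(hidden, device="cuda")

# First block: no incoming residual; keep the residual stream in fp32.
y, residual = rms_norm_fn(
    x, weight, None, residual=None, prenorm=True, residual_in_fp32=True
)
# Later blocks: the residual add is fused into the norm kernel.
y2, residual = rms_norm_fn(y, weight, None, residual=residual, prenorm=True)
```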
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/selective_state_update.py
DELETED
@@ -1,389 +0,0 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.

"""We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
"""

import math
import torch
import torch.nn.functional as F

import triton
import triton.language as tl

from einops import rearrange, repeat

from .softplus import softplus


@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics(
    {
        "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
        is not None
    }
)
@triton.heuristics(
    {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
)
@triton.jit
def _selective_scan_update_kernel(
    # Pointers to matrices
    state_ptr,
    x_ptr,
    dt_ptr,
    dt_bias_ptr,
    A_ptr,
    B_ptr,
    C_ptr,
    D_ptr,
    z_ptr,
    out_ptr,
    state_batch_indices_ptr,
    # Matrix dimensions
    batch,
    nheads,
    dim,
    dstate,
    nheads_ngroups_ratio,
    # Strides
    stride_state_batch,
    stride_state_head,
    stride_state_dim,
    stride_state_dstate,
    stride_x_batch,
    stride_x_head,
    stride_x_dim,
    stride_dt_batch,
    stride_dt_head,
    stride_dt_dim,
    stride_dt_bias_head,
    stride_dt_bias_dim,
    stride_A_head,
    stride_A_dim,
    stride_A_dstate,
    stride_B_batch,
    stride_B_group,
    stride_B_dstate,
    stride_C_batch,
    stride_C_group,
    stride_C_dstate,
    stride_D_head,
    stride_D_dim,
    stride_z_batch,
    stride_z_head,
    stride_z_dim,
    stride_out_batch,
    stride_out_head,
    stride_out_dim,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    TIE_HDIM: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    HAS_D: tl.constexpr,
    HAS_Z: tl.constexpr,
    HAS_STATE_BATCH_INDICES: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)

    if HAS_STATE_BATCH_INDICES:
        state_batch_indices_ptr += pid_b
        state_batch_idx = tl.load(state_batch_indices_ptr)
        state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
    else:
        state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head

    x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
    dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
    if HAS_DT_BIAS:
        dt_bias_ptr += pid_h * stride_dt_bias_head
    A_ptr += pid_h * stride_A_head
    B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
    C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
    if HAS_Z:
        z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
    state_ptrs = state_ptr + (
        offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
    )
    x_ptrs = x_ptr + offs_m * stride_x_dim
    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
    if HAS_DT_BIAS:
        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
    if HAS_D:
        D_ptr += pid_h * stride_D_head
    A_ptrs = A_ptr + (
        offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
    )
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate
    if HAS_D:
        D_ptrs = D_ptr + offs_m * stride_D_dim
    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim

    state = tl.load(
        state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
    )
    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if not TIE_HDIM:
        dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if DT_SOFTPLUS:
            dt = tl.where(dt <= 20.0, softplus(dt), dt)
        A = tl.load(
            A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
        ).to(tl.float32)
        dA = tl.exp(A * dt[:, None])
    else:
        dt = tl.load(dt_ptr).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptr).to(tl.float32)
        if DT_SOFTPLUS:
            dt = tl.where(dt <= 20.0, softplus(dt), dt)
        A = tl.load(A_ptr).to(tl.float32)
        dA = tl.exp(A * dt)  # scalar, not a matrix

    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_Z:
        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

    if not TIE_HDIM:
        dB = B[None, :] * dt[:, None]
    else:
        dB = B * dt  # vector of size (dstate,)
    state = state * dA + dB * x[:, None]
    tl.store(
        state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
    )
    out = tl.sum(state * C[None, :], axis=1)
    if HAS_D:
        out += x * D
    if HAS_Z:
        out *= z * tl.sigmoid(z)
    tl.store(out_ptrs, out, mask=offs_m < dim)

def selective_state_update(
    state,
    x,
    dt,
    A,
    B,
    C,
    D=None,
    z=None,
    dt_bias=None,
    dt_softplus=False,
    state_batch_indices=None,
):
    """
    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    """
    has_heads = state.dim() > 3
    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)
    _, nheads, dim, dstate = state.shape
    batch = x.shape[0]
    if x.shape != (batch, nheads, dim):
        print(f"{state.shape} {x.shape} {batch} {nheads} {dim}")
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
    if state_batch_indices is not None:
        assert state_batch_indices.shape == (batch,)
    out = torch.empty_like(x)
    grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
    z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
    # We don't want autotune since it will overwrite the state
    # We instead tune by hand.
    BLOCK_SIZE_M, num_warps = (
        (32, 4)
        if dstate <= 16
        else (
            (16, 4)
            if dstate <= 32
            else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
        )
    )
    tie_hdim = (
        A.stride(-1) == 0
        and A.stride(-2) == 0
        and dt.stride(-1) == 0
        and dt_bias.stride(-1) == 0
    )
    with torch.cuda.device(x.device.index):
        _selective_scan_update_kernel[grid](
            state,
            x,
            dt,
            dt_bias,
            A,
            B,
            C,
            D,
            z,
            out,
            state_batch_indices,
            batch,
            nheads,
            dim,
            dstate,
            nheads // ngroups,
            state.stride(0),
            state.stride(1),
            state.stride(2),
            state.stride(3),
            x.stride(0),
            x.stride(1),
            x.stride(2),
            dt.stride(0),
            dt.stride(1),
            dt.stride(2),
            *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
            A.stride(0),
            A.stride(1),
            A.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            C.stride(0),
            C.stride(1),
            C.stride(2),
            *(D.stride(0), D.stride(1)) if D is not None else 0,
            z_strides[0],
            z_strides[1],
            z_strides[2],
            out.stride(0),
            out.stride(1),
            out.stride(2),
            dt_softplus,
            tie_hdim,
            BLOCK_SIZE_M,
            num_warps=num_warps,
        )
    if not has_heads:
        out = out.squeeze(1)
    return out

def selective_state_update_ref(
    state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
):
    """
    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    """
    has_heads = state.dim() > 3
    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)
    batch, nheads, dim, dstate = state.shape
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
        dt = dt + dt_bias
    dt = F.softplus(dt) if dt_softplus else dt
    dA = torch.exp(
        rearrange(dt, "b h d -> b h d 1") * A
    )  # (batch, nheads, dim, dstate)
    B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups)  # (batch, nheads, dstate)
    C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups)  # (batch, nheads, dstate)
    dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
        B, "b h n -> b h 1 n"
    )  # (batch, nheads, dim, dstate)
    state.copy_(
        state * dA + dB * rearrange(x, "b h d -> b h d 1")
    )  # (batch, nheads, dim, dstate)
    out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
    if D is not None:
        out += (x * D).to(out.dtype)
    out = (out if z is None else out * F.silu(z)).to(x.dtype)
    if not has_heads:
        out = out.squeeze(1)
    return out
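Note: both functions above implement one decoding step of the selective SSM recurrence, `state = state * exp(dt * A) + dt * B * x` followed by `out = sum_n(C * state)` (plus `D * x`, optionally gated by `silu(z)`). A hedged consistency check between the Triton kernel and the eager reference (shapes and tolerances are assumptions; this needs a CUDA device, and the sketch passes `dt_bias` and `D` explicitly because the launcher's starred stride unpacking is only exercised when they are provided):

```python
import torch

batch, nheads, ngroups, dim, dstate = 2, 4, 1, 64, 16
dev = "cuda"

state = torch.randn(batch, nheads, dim, dstate, device=dev)
state_ref = state.clone()
x = torch.randn(batch, nheads, dim, device=dev)
dt = torch.rand(batch, nheads, dim, device=dev)
dt_bias = torch.rand(nheads, dim, device=dev)
A = -torch.rand(nheads, dim, dstate, device=dev)  # negative A keeps the state stable
B = torch.randn(batch, ngroups, dstate, device=dev)
C = torch.randn(batch, ngroups, dstate, device=dev)
D = torch.randn(nheads, dim, device=dev)

out = selective_state_update(
    state, x, dt, A, B, C, D=D, dt_bias=dt_bias, dt_softplus=True
)
out_ref = selective_state_update_ref(
    state_ref, x, dt, A, B, C, D=D, dt_bias=dt_bias, dt_softplus=True
)
assert torch.allclose(out, out_ref, atol=1e-4)
assert torch.allclose(state, state_ref, atol=1e-4)  # both update state in place
```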
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_scan.py
DELETED
The diff for this file is too large to render. See raw diff.
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_state.py
DELETED
@@ -1,2012 +0,0 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.

"""We want triton==2.1.0 or 2.2.0 for this
"""

import math
import torch
import torch.nn.functional as F

import triton
import triton.language as tl

from einops import rearrange, repeat

from .softplus import softplus


def init_to_zero(names):
    return lambda nargs: [
        nargs[name].zero_() for name in names if nargs[name] is not None
    ]

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE_H": 1}),
        triton.Config({"BLOCK_SIZE_H": 2}),
        triton.Config({"BLOCK_SIZE_H": 4}),
        triton.Config({"BLOCK_SIZE_H": 8}),
        triton.Config({"BLOCK_SIZE_H": 16}),
        triton.Config({"BLOCK_SIZE_H": 32}),
        triton.Config({"BLOCK_SIZE_H": 64}),
    ],
    key=["chunk_size", "nheads"],
)
@triton.jit
def _chunk_cumsum_fwd_kernel(
    # Pointers to matrices
    dt_ptr,
    A_ptr,
    dt_bias_ptr,
    dt_out_ptr,
    dA_cumsum_ptr,
    # Matrix dimension
    batch,
    seqlen,
    nheads,
    chunk_size,
    dt_min,
    dt_max,
    # Strides
    stride_dt_batch,
    stride_dt_seqlen,
    stride_dt_head,
    stride_A_head,
    stride_dt_bias_head,
    stride_dt_out_batch,
    stride_dt_out_chunk,
    stride_dt_out_head,
    stride_dt_out_csize,
    stride_dA_cs_batch,
    stride_dA_cs_chunk,
    stride_dA_cs_head,
    stride_dA_cs_csize,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
    BLOCK_SIZE_CHUNK: tl.constexpr,
):
    pid_b = tl.program_id(axis=0)
    pid_c = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
    dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk

    offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
    offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
    dt_ptrs = dt_ptr + (
        offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
    )
    A_ptrs = A_ptr + offs_h * stride_A_head
    dt_out_ptrs = dt_out_ptr + (
        offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize
    )
    dA_cs_ptrs = dA_cumsum_ptr + (
        offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize
    )
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)

    dt = tl.load(
        dt_ptrs,
        mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
        other=0.0,
    ).to(tl.float32)
    if HAS_DT_BIAS:
        dt_bias = tl.load(
            dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
        ).to(tl.float32)
        dt += dt_bias[:, None]
    if DT_SOFTPLUS:
        dt = tl.where(dt <= 20.0, softplus(dt), dt)
    # As of Triton 2.2.0, tl.clamp is not available yet
    # dt = tl.clamp(dt, dt_min, dt_max)
    dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
    dt = tl.where(
        (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
    )
    tl.store(
        dt_out_ptrs,
        dt,
        mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
    )
    A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
    dA = dt * A[:, None]
    dA_cs = tl.cumsum(dA, axis=1)
    tl.store(
        dA_cs_ptrs,
        dA_cs,
        mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
    )

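Note: per chunk, the kernel above produces the processed step sizes `dt` (bias added, optional softplus with the `dt <= 20` overflow guard, then clamped to `[dt_min, dt_max]`) and the within-chunk cumulative sum of `dt * A`. An eager-mode sketch of the same computation (shapes and the function name are assumptions; this sketch requires `seqlen` divisible by `chunk_size`):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def chunk_cumsum_ref(dt, A, chunk_size, dt_bias=None, dt_softplus=True,
                     dt_min=0.0, dt_max=float("inf")):
    # dt: (batch, seqlen, nheads); A: (nheads,)
    if dt_bias is not None:
        dt = dt + dt_bias
    if dt_softplus:
        dt = F.softplus(dt)
    dt = dt.clamp(dt_min, dt_max)
    # split the sequence into chunks and cumsum dt * A within each chunk
    dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
    dA_cumsum = torch.cumsum(dt * rearrange(A, "h -> h 1 1"), dim=-1)
    return dt, dA_cumsum  # both (batch, nheads, nchunks, chunk_size)
```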
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_H": 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
        ),
        triton.Config(
            {"BLOCK_SIZE_H": 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
        ),
        triton.Config(
            {"BLOCK_SIZE_H": 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
        ),
        triton.Config(
            {"BLOCK_SIZE_H": 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
        ),
        triton.Config(
            {"BLOCK_SIZE_H": 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
        ),
        triton.Config(
            {"BLOCK_SIZE_H": 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
        ),
        triton.Config(
            {"BLOCK_SIZE_H": 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
        ),
    ],
    key=["chunk_size", "nheads"],
)
@triton.jit
def _chunk_cumsum_bwd_kernel(
    # Pointers to matrices
    ddA_ptr,
    ddt_out_ptr,
    dt_ptr,
    A_ptr,
    dt_bias_ptr,
    ddt_ptr,
    dA_ptr,
    ddt_bias_ptr,
    # Matrix dimensions
    batch,
    seqlen,
    nheads,
    chunk_size,
    dt_min,
    dt_max,
    # Strides
    stride_ddA_batch,
    stride_ddA_chunk,
    stride_ddA_head,
    stride_ddA_csize,
    stride_ddt_out_batch,
    stride_ddt_out_chunk,
    stride_ddt_out_head,
    stride_ddt_out_csize,
    stride_dt_batch,
    stride_dt_seqlen,
    stride_dt_head,
    stride_A_head,
    stride_dt_bias_head,
    stride_ddt_batch,
    stride_ddt_seqlen,
    stride_ddt_head,
    stride_dA_head,
    stride_ddt_bias_head,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
    BLOCK_SIZE_CHUNK: tl.constexpr,
):
    pid_b = tl.program_id(axis=0)
    pid_c = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk
    ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk
    dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
    ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen

    offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
    offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
    ddt_out_ptrs = ddt_out_ptr + (
        offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize
    )
    ddA_ptrs = ddA_ptr + (
        offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize
    )
    dt_ptrs = dt_ptr + (
        offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
    )
    ddt_ptrs = ddt_ptr + (
        offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen
    )
    A_ptrs = A_ptr + offs_h * stride_A_head
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)

    ddA = tl.load(
        ddA_ptrs,
        mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
        other=0.0,
    ).to(tl.float32)
    ddt_out = tl.load(
        ddt_out_ptrs,
        mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
        other=0.0,
    ).to(tl.float32)
    A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
    ddt = ddA * A[:, None] + ddt_out
    dt = tl.load(
        dt_ptrs,
        mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
        other=0.0,
    ).to(tl.float32)
    if HAS_DT_BIAS:
        dt_bias = tl.load(
            dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
        ).to(tl.float32)
        dt += dt_bias[:, None]
    if DT_SOFTPLUS:
        dt_presoftplus = dt
        dt = tl.where(dt <= 20.0, softplus(dt), dt)
    clamp_mask = (dt < dt_min) | (dt > dt_max)
    # As of Triton 2.2.0, tl.clamp is not available yet
    # dt = tl.clamp(dt, dt_min, dt_max)
    dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
    dt = tl.where(
        (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
    )
    ddt = tl.where(
        (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0
    )
    ddt = tl.where(clamp_mask, 0.0, ddt)
    if DT_SOFTPLUS:
        ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)
    tl.store(
        ddt_ptrs,
        ddt,
        mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
    )
    dA = tl.sum(ddA * dt, axis=1)
    tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)
    if HAS_DT_BIAS:
        ddt_bias = tl.sum(ddt, axis=1)
        tl.atomic_add(
            ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads
        )

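Note: the backward kernel above applies the chain rule through `dA = dt * A`: `ddt = ddA * A + ddt_out`, zeroed wherever the forward clamp was active, then scaled by `sigmoid` where softplus was applied (softplus'(x) = sigmoid(x)). The per-head `dA` and `ddt_bias` totals are accumulated with `tl.atomic_add` across all (batch, chunk) programs, which is why every autotune config carries `pre_hook=init_to_zero([...])`: the accumulators must be re-zeroed before each timing run or the sums would be double-counted. An eager sketch of the dt/A gradients, as a simplification under assumed shapes (it takes `ddA` as the per-element gradient of `dt * A` and omits the `dt <= 20` softplus guard):

```python
import torch

def chunk_cumsum_bwd_ref(ddA, ddt_out, dt_raw, A, dt_min, dt_max):
    # ddA, ddt_out, dt_raw: (batch, nheads, nchunks, chunk_size); A: (nheads,)
    dt_soft = torch.nn.functional.softplus(dt_raw)
    clamp_mask = (dt_soft < dt_min) | (dt_soft > dt_max)
    dt = dt_soft.clamp(dt_min, dt_max)
    ddt = ddA * A[None, :, None, None] + ddt_out  # product rule through dt * A
    ddt = ddt.masked_fill(clamp_mask, 0.0)        # clamp blocks the gradient
    ddt = ddt * torch.sigmoid(dt_raw)             # softplus'(x) = sigmoid(x)
    dA = (ddA * dt).sum(dim=(0, 2, 3))            # kernel: tl.atomic_add across programs
    return ddt, dA
```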
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=2,
        ),
    ],
    key=["hdim", "dstate", "chunk_size"],
)
@triton.jit
def _chunk_state_fwd_kernel(
    # Pointers to matrices
    x_ptr,
    b_ptr,
    states_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    seq_idx_ptr,
    # Matrix dimensions
    hdim,
    dstate,
    chunk_size,
    batch,
    seqlen,
    nheads_ngroups_ratio,
    # Strides
    stride_x_batch,
    stride_x_seqlen,
    stride_x_head,
    stride_x_hdim,
    stride_b_batch,
    stride_b_seqlen,
    stride_b_head,
    stride_b_dstate,
    stride_states_batch,
    stride_states_chunk,
    stride_states_head,
    stride_states_hdim,
    stride_states_dstate,
    stride_dt_batch,
    stride_dt_chunk,
    stride_dt_head,
    stride_dt_csize,
    stride_dA_cs_batch,
    stride_dA_cs_chunk,
    stride_dA_cs_head,
    stride_dA_cs_csize,
    stride_seq_idx_batch,
    stride_seq_idx_seqlen,
    # Meta-parameters
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    b_ptr += (
        pid_b * stride_b_batch
        + pid_c * chunk_size * stride_b_seqlen
        + (pid_h // nheads_ngroups_ratio) * stride_b_head
    )
    x_ptr += (
        pid_b * stride_x_batch
        + pid_c * chunk_size * stride_x_seqlen
        + pid_h * stride_x_head
    )
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += (
        pid_b * stride_dA_cs_batch
        + pid_c * stride_dA_cs_chunk
        + pid_h * stride_dA_cs_head
    )
    if HAS_SEQ_IDX:
        seq_idx_ptr += (
            pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
        )

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    x_ptrs = x_ptr + (
        offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
    )
    b_ptrs = b_ptr + (
        offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
    )
    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
    dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
        tl.float32
    )
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
    if HAS_SEQ_IDX:
        seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen

    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    if HAS_SEQ_IDX:
        seq_idx_last = tl.load(
            seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
        )

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
        x = tl.load(
            x_ptrs,
            mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate),
            other=0.0,
        ).to(tl.float32)
        dA_cs_k = tl.load(
            dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
        ).to(tl.float32)
        if HAS_SEQ_IDX:
            seq_idx_k = tl.load(
                seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1
            )
        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
            tl.float32
        )
        if not HAS_SEQ_IDX:
            scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
        else:
            scale = tl.where(
                seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0
            )
        b *= scale[:, None]
        b = b.to(x_ptr.dtype.element_ty)
        acc += tl.dot(x, b)
        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
        b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
        if HAS_SEQ_IDX:
            seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
    states = acc.to(states_ptr.dtype.element_ty)

    states_ptr += (
        pid_b * stride_states_batch
        + pid_c * stride_states_chunk
        + pid_h * stride_states_head
    )
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    states_ptrs = states_ptr + (
        offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
    )
    c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
    tl.store(states_ptrs, states, mask=c_mask)

| 471 |
-
@triton.autotune(
|
| 472 |
-
configs=[
|
| 473 |
-
triton.Config(
|
| 474 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
| 475 |
-
num_stages=3,
|
| 476 |
-
num_warps=8,
|
| 477 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 478 |
-
),
|
| 479 |
-
triton.Config(
|
| 480 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
| 481 |
-
num_stages=4,
|
| 482 |
-
num_warps=4,
|
| 483 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 484 |
-
),
|
| 485 |
-
triton.Config(
|
| 486 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
| 487 |
-
num_stages=4,
|
| 488 |
-
num_warps=4,
|
| 489 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 490 |
-
),
|
| 491 |
-
triton.Config(
|
| 492 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 493 |
-
num_stages=4,
|
| 494 |
-
num_warps=4,
|
| 495 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 496 |
-
),
|
| 497 |
-
triton.Config(
|
| 498 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
| 499 |
-
num_stages=4,
|
| 500 |
-
num_warps=4,
|
| 501 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 502 |
-
),
|
| 503 |
-
triton.Config(
|
| 504 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
| 505 |
-
num_stages=4,
|
| 506 |
-
num_warps=4,
|
| 507 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 508 |
-
),
|
| 509 |
-
triton.Config(
|
| 510 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
| 511 |
-
num_stages=5,
|
| 512 |
-
num_warps=4,
|
| 513 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 514 |
-
),
|
| 515 |
-
triton.Config(
|
| 516 |
-
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 517 |
-
num_stages=5,
|
| 518 |
-
num_warps=4,
|
| 519 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 520 |
-
),
|
| 521 |
-
triton.Config(
|
| 522 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 523 |
-
num_stages=4,
|
| 524 |
-
num_warps=4,
|
| 525 |
-
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
| 526 |
-
),
|
| 527 |
-
],
|
| 528 |
-
key=["chunk_size", "hdim", "dstate"],
|
| 529 |
-
)
|
| 530 |
-
@triton.jit
|
| 531 |
-
def _chunk_state_bwd_dx_kernel(
|
| 532 |
-
# Pointers to matrices
|
| 533 |
-
x_ptr,
|
| 534 |
-
b_ptr,
|
| 535 |
-
dstates_ptr,
|
| 536 |
-
dt_ptr,
|
| 537 |
-
dA_cumsum_ptr,
|
| 538 |
-
dx_ptr,
|
| 539 |
-
ddt_ptr,
|
| 540 |
-
ddA_cumsum_ptr,
|
| 541 |
-
# Matrix dimensions
|
| 542 |
-
chunk_size,
|
| 543 |
-
hdim,
|
| 544 |
-
dstate,
|
| 545 |
-
batch,
|
| 546 |
-
seqlen,
|
| 547 |
-
nheads_ngroups_ratio,
|
| 548 |
-
# Strides
|
| 549 |
-
stride_x_batch,
|
| 550 |
-
stride_x_seqlen,
|
| 551 |
-
stride_x_head,
|
| 552 |
-
stride_x_hdim,
|
| 553 |
-
stride_b_batch,
|
| 554 |
-
stride_b_seqlen,
|
| 555 |
-
stride_b_head,
|
| 556 |
-
stride_b_dstate,
|
| 557 |
-
stride_dstates_batch,
|
| 558 |
-
stride_dstates_chunk,
|
| 559 |
-
stride_states_head,
|
| 560 |
-
stride_states_hdim,
|
| 561 |
-
stride_states_dstate,
|
| 562 |
-
stride_dt_batch,
|
| 563 |
-
stride_dt_chunk,
|
| 564 |
-
stride_dt_head,
|
| 565 |
-
stride_dt_csize,
|
| 566 |
-
stride_dA_cs_batch,
|
| 567 |
-
stride_dA_cs_chunk,
|
| 568 |
-
stride_dA_cs_head,
|
| 569 |
-
stride_dA_cs_csize,
|
| 570 |
-
stride_dx_batch,
|
| 571 |
-
stride_dx_seqlen,
|
| 572 |
-
stride_dx_head,
|
| 573 |
-
stride_dx_hdim,
|
| 574 |
-
stride_ddt_batch,
|
| 575 |
-
stride_ddt_chunk,
|
| 576 |
-
stride_ddt_head,
|
| 577 |
-
stride_ddt_csize,
|
| 578 |
-
stride_ddA_cs_batch,
|
| 579 |
-
stride_ddA_cs_chunk,
|
| 580 |
-
stride_ddA_cs_head,
|
| 581 |
-
stride_ddA_cs_csize,
|
| 582 |
-
# Meta-parameters
|
| 583 |
-
BLOCK_SIZE_M: tl.constexpr,
|
| 584 |
-
BLOCK_SIZE_N: tl.constexpr,
|
| 585 |
-
BLOCK_SIZE_K: tl.constexpr,
|
| 586 |
-
BLOCK_SIZE_DSTATE: tl.constexpr,
|
| 587 |
-
):
|
| 588 |
-
pid_bc = tl.program_id(axis=1)
|
| 589 |
-
pid_c = pid_bc // batch
|
| 590 |
-
pid_b = pid_bc - pid_c * batch
|
| 591 |
-
pid_h = tl.program_id(axis=2)
|
| 592 |
-
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
| 593 |
-
pid_m = tl.program_id(axis=0) // num_pid_n
|
| 594 |
-
pid_n = tl.program_id(axis=0) % num_pid_n
|
| 595 |
-
x_ptr += (
|
| 596 |
-
pid_b * stride_x_batch
|
| 597 |
-
+ pid_c * chunk_size * stride_x_seqlen
|
| 598 |
-
+ pid_h * stride_x_head
|
| 599 |
-
)
|
| 600 |
-
b_ptr += (
|
| 601 |
-
pid_b * stride_b_batch
|
| 602 |
-
+ pid_c * chunk_size * stride_b_seqlen
|
| 603 |
-
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
| 604 |
-
)
|
| 605 |
-
dstates_ptr += (
|
| 606 |
-
pid_b * stride_dstates_batch
|
| 607 |
-
+ pid_c * stride_dstates_chunk
|
| 608 |
-
+ pid_h * stride_states_head
|
| 609 |
-
)
|
| 610 |
-
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
| 611 |
-
ddt_ptr += (
|
| 612 |
-
pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
|
| 613 |
-
)
|
| 614 |
-
ddA_cumsum_ptr += (
|
| 615 |
-
pid_b * stride_ddA_cs_batch
|
| 616 |
-
+ pid_c * stride_ddA_cs_chunk
|
| 617 |
-
+ pid_h * stride_ddA_cs_head
|
| 618 |
-
)
|
| 619 |
-
dA_cumsum_ptr += (
|
| 620 |
-
pid_b * stride_dA_cs_batch
|
| 621 |
-
+ pid_c * stride_dA_cs_chunk
|
| 622 |
-
+ pid_h * stride_dA_cs_head
|
| 623 |
-
)
|
| 624 |
-
|
| 625 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 626 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 627 |
-
|
| 628 |
-
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
| 629 |
-
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
| 630 |
-
offs_k = tl.arange(
|
| 631 |
-
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
|
| 632 |
-
)
|
| 633 |
-
b_ptrs = b_ptr + (
|
| 634 |
-
offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
|
| 635 |
-
)
|
| 636 |
-
dstates_ptrs = dstates_ptr + (
|
| 637 |
-
offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
|
| 638 |
-
)
|
| 639 |
-
if BLOCK_SIZE_DSTATE <= 128:
|
| 640 |
-
b = tl.load(
|
| 641 |
-
b_ptrs,
|
| 642 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
|
| 643 |
-
other=0.0,
|
| 644 |
-
)
|
| 645 |
-
dstates = tl.load(
|
| 646 |
-
dstates_ptrs,
|
| 647 |
-
mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
|
| 648 |
-
other=0.0,
|
| 649 |
-
)
|
| 650 |
-
dstates = dstates.to(b_ptr.dtype.element_ty)
|
| 651 |
-
acc = tl.dot(b, dstates)
|
| 652 |
-
else:
|
| 653 |
-
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
| 654 |
-
for k in range(0, dstate, BLOCK_SIZE_K):
|
| 655 |
-
b = tl.load(
|
| 656 |
-
b_ptrs,
|
| 657 |
-
mask=(offs_m[:, None] < chunk_size_limit)
|
| 658 |
-
& (offs_k[None, :] < dstate - k),
|
| 659 |
-
other=0.0,
|
| 660 |
-
)
|
| 661 |
-
dstates = tl.load(
|
| 662 |
-
dstates_ptrs,
|
| 663 |
-
mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
|
| 664 |
-
other=0.0,
|
| 665 |
-
)
|
| 666 |
-
dstates = dstates.to(b_ptr.dtype.element_ty)
|
| 667 |
-
acc += tl.dot(b, dstates)
|
| 668 |
-
b_ptrs += BLOCK_SIZE_K * stride_b_dstate
|
| 669 |
-
dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
|
| 670 |
-
|
| 671 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 672 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 673 |
-
|
| 674 |
-
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
|
| 675 |
-
tl.float32
|
| 676 |
-
)
|
| 677 |
-
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
| 678 |
-
dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
|
| 679 |
-
dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
|
| 680 |
-
tl.float32
|
| 681 |
-
)
|
| 682 |
-
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
| 683 |
-
acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
|
| 684 |
-
|
| 685 |
-
x_ptrs = x_ptr + (
|
| 686 |
-
offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
|
| 687 |
-
)
|
| 688 |
-
x = tl.load(
|
| 689 |
-
x_ptrs,
|
| 690 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
| 691 |
-
other=0.0,
|
| 692 |
-
).to(tl.float32)
|
| 693 |
-
ddt = tl.sum(acc * x, axis=1)
|
| 694 |
-
ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
|
| 695 |
-
tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
|
| 696 |
-
ddA_cs = -(ddt * dt_m)
|
| 697 |
-
ddA_cs_last = -tl.sum(ddA_cs)
|
| 698 |
-
ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
|
| 699 |
-
tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
|
| 700 |
-
tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last)
|
| 701 |
-
|
| 702 |
-
dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty)
|
| 703 |
-
dx_ptr += (
|
| 704 |
-
pid_b * stride_dx_batch
|
| 705 |
-
+ pid_c * chunk_size * stride_dx_seqlen
|
| 706 |
-
+ pid_h * stride_dx_head
|
| 707 |
-
)
|
| 708 |
-
dx_ptrs = dx_ptr + (
|
| 709 |
-
offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim
|
| 710 |
-
)
|
| 711 |
-
tl.store(
|
| 712 |
-
dx_ptrs,
|
| 713 |
-
dx,
|
| 714 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
| 715 |
-
)
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
@triton.autotune(
|
| 719 |
-
configs=[
|
| 720 |
-
triton.Config(
|
| 721 |
-
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128},
|
| 722 |
-
num_stages=3,
|
| 723 |
-
num_warps=4,
|
| 724 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 725 |
-
),
|
| 726 |
-
triton.Config(
|
| 727 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32},
|
| 728 |
-
num_stages=3,
|
| 729 |
-
num_warps=4,
|
| 730 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 731 |
-
),
|
| 732 |
-
triton.Config(
|
| 733 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128},
|
| 734 |
-
num_stages=3,
|
| 735 |
-
num_warps=4,
|
| 736 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 737 |
-
),
|
| 738 |
-
triton.Config(
|
| 739 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64},
|
| 740 |
-
num_stages=3,
|
| 741 |
-
num_warps=4,
|
| 742 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 743 |
-
),
|
| 744 |
-
triton.Config(
|
| 745 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64},
|
| 746 |
-
num_stages=3,
|
| 747 |
-
num_warps=4,
|
| 748 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 749 |
-
),
|
| 750 |
-
triton.Config(
|
| 751 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32},
|
| 752 |
-
num_stages=3,
|
| 753 |
-
num_warps=4,
|
| 754 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 755 |
-
),
|
| 756 |
-
triton.Config(
|
| 757 |
-
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64},
|
| 758 |
-
num_stages=3,
|
| 759 |
-
num_warps=4,
|
| 760 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 761 |
-
),
|
| 762 |
-
triton.Config(
|
| 763 |
-
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32},
|
| 764 |
-
num_stages=3,
|
| 765 |
-
num_warps=4,
|
| 766 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 767 |
-
),
|
| 768 |
-
],
|
| 769 |
-
key=["chunk_size", "dstate", "hdim"],
|
| 770 |
-
)
|
| 771 |
-
@triton.jit
|
| 772 |
-
def _chunk_state_bwd_db_kernel(
|
| 773 |
-
# Pointers to matrices
|
| 774 |
-
x_ptr,
|
| 775 |
-
dstates_ptr,
|
| 776 |
-
b_ptr,
|
| 777 |
-
dt_ptr,
|
| 778 |
-
dA_cumsum_ptr,
|
| 779 |
-
seq_idx_ptr,
|
| 780 |
-
db_ptr,
|
| 781 |
-
ddA_cumsum_ptr,
|
| 782 |
-
# Matrix dimensions
|
| 783 |
-
chunk_size,
|
| 784 |
-
dstate,
|
| 785 |
-
hdim,
|
| 786 |
-
batch,
|
| 787 |
-
seqlen,
|
| 788 |
-
nheads,
|
| 789 |
-
nheads_per_program,
|
| 790 |
-
ngroups,
|
| 791 |
-
# Strides
|
| 792 |
-
stride_x_batch,
|
| 793 |
-
stride_x_seqlen,
|
| 794 |
-
stride_x_head,
|
| 795 |
-
stride_x_hdim,
|
| 796 |
-
stride_dstates_batch,
|
| 797 |
-
stride_dstates_chunk,
|
| 798 |
-
stride_states_head,
|
| 799 |
-
stride_states_hdim,
|
| 800 |
-
stride_states_dstate,
|
| 801 |
-
stride_b_batch,
|
| 802 |
-
stride_b_seqlen,
|
| 803 |
-
stride_b_head,
|
| 804 |
-
stride_b_dstate,
|
| 805 |
-
stride_dt_batch,
|
| 806 |
-
stride_dt_chunk,
|
| 807 |
-
stride_dt_head,
|
| 808 |
-
stride_dt_csize,
|
| 809 |
-
stride_dA_cs_batch,
|
| 810 |
-
stride_dA_cs_chunk,
|
| 811 |
-
stride_dA_cs_head,
|
| 812 |
-
stride_dA_cs_csize,
|
| 813 |
-
stride_seq_idx_batch,
|
| 814 |
-
stride_seq_idx_seqlen,
|
| 815 |
-
stride_db_batch,
|
| 816 |
-
stride_db_seqlen,
|
| 817 |
-
stride_db_split,
|
| 818 |
-
stride_db_group,
|
| 819 |
-
stride_db_dstate,
|
| 820 |
-
stride_ddA_cs_batch,
|
| 821 |
-
stride_ddA_cs_chunk,
|
| 822 |
-
stride_ddA_cs_head,
|
| 823 |
-
stride_ddA_cs_csize,
|
| 824 |
-
# Meta-parameters
|
| 825 |
-
HAS_DDA_CS: tl.constexpr,
|
| 826 |
-
HAS_SEQ_IDX: tl.constexpr,
|
| 827 |
-
BLOCK_SIZE_M: tl.constexpr,
|
| 828 |
-
BLOCK_SIZE_N: tl.constexpr,
|
| 829 |
-
BLOCK_SIZE_K: tl.constexpr,
|
| 830 |
-
):
|
| 831 |
-
pid_bc = tl.program_id(axis=1)
|
| 832 |
-
pid_c = pid_bc // batch
|
| 833 |
-
pid_b = pid_bc - pid_c * batch
|
| 834 |
-
pid_sg = tl.program_id(axis=2)
|
| 835 |
-
pid_s = pid_sg // ngroups
|
| 836 |
-
pid_g = pid_sg - pid_s * ngroups
|
| 837 |
-
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
| 838 |
-
pid_m = tl.program_id(axis=0) // num_pid_n
|
| 839 |
-
pid_n = tl.program_id(axis=0) % num_pid_n
|
| 840 |
-
x_ptr += (
|
| 841 |
-
pid_b * stride_x_batch
|
| 842 |
-
+ pid_c * chunk_size * stride_x_seqlen
|
| 843 |
-
+ (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head
|
| 844 |
-
)
|
| 845 |
-
db_ptr += (
|
| 846 |
-
pid_b * stride_db_batch
|
| 847 |
-
+ pid_c * chunk_size * stride_db_seqlen
|
| 848 |
-
+ pid_g * stride_db_group
|
| 849 |
-
+ pid_s * stride_db_split
|
| 850 |
-
)
|
| 851 |
-
dstates_ptr += (
|
| 852 |
-
pid_b * stride_dstates_batch
|
| 853 |
-
+ pid_c * stride_dstates_chunk
|
| 854 |
-
+ (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
|
| 855 |
-
* stride_states_head
|
| 856 |
-
)
|
| 857 |
-
dt_ptr += (
|
| 858 |
-
pid_b * stride_dt_batch
|
| 859 |
-
+ pid_c * stride_dt_chunk
|
| 860 |
-
+ (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head
|
| 861 |
-
)
|
| 862 |
-
dA_cumsum_ptr += (
|
| 863 |
-
pid_b * stride_dA_cs_batch
|
| 864 |
-
+ pid_c * stride_dA_cs_chunk
|
| 865 |
-
+ (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
|
| 866 |
-
)
|
| 867 |
-
if HAS_DDA_CS:
|
| 868 |
-
b_ptr += (
|
| 869 |
-
pid_b * stride_b_batch
|
| 870 |
-
+ pid_c * chunk_size * stride_b_seqlen
|
| 871 |
-
+ pid_g * stride_b_head
|
| 872 |
-
)
|
| 873 |
-
ddA_cumsum_ptr += (
|
| 874 |
-
pid_b * stride_ddA_cs_batch
|
| 875 |
-
+ pid_c * stride_ddA_cs_chunk
|
| 876 |
-
+ (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
|
| 877 |
-
* stride_ddA_cs_head
|
| 878 |
-
)
|
| 879 |
-
if HAS_SEQ_IDX:
|
| 880 |
-
seq_idx_ptr += (
|
| 881 |
-
pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
| 882 |
-
)
|
| 883 |
-
|
| 884 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 885 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 886 |
-
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
| 887 |
-
x_ptrs = x_ptr + (
|
| 888 |
-
offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim
|
| 889 |
-
)
|
| 890 |
-
dstates_ptrs = dstates_ptr + (
|
| 891 |
-
offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim
|
| 892 |
-
)
|
| 893 |
-
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
| 894 |
-
dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
|
| 895 |
-
if HAS_DDA_CS:
|
| 896 |
-
b_ptrs = b_ptr + (
|
| 897 |
-
offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate
|
| 898 |
-
)
|
| 899 |
-
ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
|
| 900 |
-
|
| 901 |
-
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
| 902 |
-
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
| 903 |
-
if HAS_DDA_CS:
|
| 904 |
-
b = tl.load(
|
| 905 |
-
b_ptrs,
|
| 906 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
|
| 907 |
-
other=0.0,
|
| 908 |
-
).to(tl.float32)
|
| 909 |
-
if HAS_SEQ_IDX:
|
| 910 |
-
seq_idx_m = tl.load(
|
| 911 |
-
seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
|
| 912 |
-
mask=offs_m < chunk_size_limit,
|
| 913 |
-
other=-1,
|
| 914 |
-
)
|
| 915 |
-
seq_idx_last = tl.load(
|
| 916 |
-
seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
|
| 917 |
-
)
|
| 918 |
-
nheads_iter = min(
|
| 919 |
-
nheads_per_program, nheads // ngroups - pid_s * nheads_per_program
|
| 920 |
-
)
|
| 921 |
-
for h in range(nheads_iter):
|
| 922 |
-
x = tl.load(
|
| 923 |
-
x_ptrs,
|
| 924 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim),
|
| 925 |
-
other=0.0,
|
| 926 |
-
)
|
| 927 |
-
dstates = tl.load(
|
| 928 |
-
dstates_ptrs,
|
| 929 |
-
mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate),
|
| 930 |
-
other=0.0,
|
| 931 |
-
)
|
| 932 |
-
dstates = dstates.to(x_ptrs.dtype.element_ty)
|
| 933 |
-
db = tl.dot(x, dstates)
|
| 934 |
-
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
|
| 935 |
-
tl.float32
|
| 936 |
-
)
|
| 937 |
-
dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
|
| 938 |
-
tl.float32
|
| 939 |
-
)
|
| 940 |
-
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
| 941 |
-
if not HAS_SEQ_IDX:
|
| 942 |
-
scale = tl.exp(dA_cs_last - dA_cs_m)
|
| 943 |
-
else:
|
| 944 |
-
scale = tl.where(
|
| 945 |
-
seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0
|
| 946 |
-
)
|
| 947 |
-
db *= (scale * dt_m)[:, None]
|
| 948 |
-
if HAS_DDA_CS:
|
| 949 |
-
# This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
|
| 950 |
-
ddA_cs = tl.sum(db * b, axis=1)
|
| 951 |
-
tl.atomic_add(
|
| 952 |
-
ddA_cumsum_ptrs + stride_ddA_cs_csize,
|
| 953 |
-
ddA_cs,
|
| 954 |
-
mask=offs_m < chunk_size - 1,
|
| 955 |
-
)
|
| 956 |
-
acc += db
|
| 957 |
-
x_ptrs += stride_x_head
|
| 958 |
-
dstates_ptrs += stride_states_head
|
| 959 |
-
dt_ptrs += stride_dt_head
|
| 960 |
-
dA_cumsum_ptr += stride_dA_cs_head
|
| 961 |
-
dA_cumsum_ptrs += stride_dA_cs_head
|
| 962 |
-
if HAS_DDA_CS:
|
| 963 |
-
ddA_cumsum_ptrs += stride_ddA_cs_head
|
| 964 |
-
|
| 965 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 966 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 967 |
-
# if HAS_SEQ_IDX:
|
| 968 |
-
# seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
|
| 969 |
-
# seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
|
| 970 |
-
# acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0)
|
| 971 |
-
db_ptrs = db_ptr + (
|
| 972 |
-
offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate
|
| 973 |
-
)
|
| 974 |
-
tl.store(
|
| 975 |
-
db_ptrs,
|
| 976 |
-
acc,
|
| 977 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
|
| 978 |
-
)
|
| 979 |
-
|
| 980 |
-
|
| 981 |
-
@triton.autotune(
|
| 982 |
-
configs=[
|
| 983 |
-
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 984 |
-
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 985 |
-
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 986 |
-
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 987 |
-
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 988 |
-
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 989 |
-
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 990 |
-
# triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 991 |
-
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 992 |
-
triton.Config(
|
| 993 |
-
{"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
|
| 994 |
-
num_stages=3,
|
| 995 |
-
num_warps=4,
|
| 996 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 997 |
-
),
|
| 998 |
-
triton.Config(
|
| 999 |
-
{"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
| 1000 |
-
num_stages=3,
|
| 1001 |
-
num_warps=4,
|
| 1002 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 1003 |
-
),
|
| 1004 |
-
triton.Config(
|
| 1005 |
-
{"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 1006 |
-
num_stages=3,
|
| 1007 |
-
num_warps=4,
|
| 1008 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 1009 |
-
),
|
| 1010 |
-
triton.Config(
|
| 1011 |
-
{"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
| 1012 |
-
num_stages=3,
|
| 1013 |
-
num_warps=4,
|
| 1014 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 1015 |
-
),
|
| 1016 |
-
triton.Config(
|
| 1017 |
-
{"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
|
| 1018 |
-
num_stages=4,
|
| 1019 |
-
num_warps=8,
|
| 1020 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 1021 |
-
),
|
| 1022 |
-
triton.Config(
|
| 1023 |
-
{"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
| 1024 |
-
num_stages=4,
|
| 1025 |
-
num_warps=8,
|
| 1026 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 1027 |
-
),
|
| 1028 |
-
triton.Config(
|
| 1029 |
-
{"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 1030 |
-
num_stages=4,
|
| 1031 |
-
num_warps=8,
|
| 1032 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 1033 |
-
),
|
| 1034 |
-
triton.Config(
|
| 1035 |
-
{"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
| 1036 |
-
num_stages=4,
|
| 1037 |
-
num_warps=8,
|
| 1038 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 1039 |
-
),
|
| 1040 |
-
],
|
| 1041 |
-
key=["chunk_size", "hdim", "dstate"],
|
| 1042 |
-
)
|
| 1043 |
-
@triton.jit
|
| 1044 |
-
def _chunk_state_bwd_ddAcs_stable_kernel(
|
| 1045 |
-
# Pointers to matrices
|
| 1046 |
-
x_ptr,
|
| 1047 |
-
b_ptr,
|
| 1048 |
-
dstates_ptr,
|
| 1049 |
-
dt_ptr,
|
| 1050 |
-
dA_cumsum_ptr,
|
| 1051 |
-
seq_idx_ptr,
|
| 1052 |
-
ddA_cumsum_ptr,
|
| 1053 |
-
# Matrix dimensions
|
| 1054 |
-
chunk_size,
|
| 1055 |
-
hdim,
|
| 1056 |
-
dstate,
|
| 1057 |
-
batch,
|
| 1058 |
-
seqlen,
|
| 1059 |
-
nheads_ngroups_ratio,
|
| 1060 |
-
# Strides
|
| 1061 |
-
stride_x_batch,
|
| 1062 |
-
stride_x_seqlen,
|
| 1063 |
-
stride_x_head,
|
| 1064 |
-
stride_x_hdim,
|
| 1065 |
-
stride_b_batch,
|
| 1066 |
-
stride_b_seqlen,
|
| 1067 |
-
stride_b_head,
|
| 1068 |
-
stride_b_dstate,
|
| 1069 |
-
stride_dstates_batch,
|
| 1070 |
-
stride_dstates_chunk,
|
| 1071 |
-
stride_states_head,
|
| 1072 |
-
stride_states_hdim,
|
| 1073 |
-
stride_states_dstate,
|
| 1074 |
-
stride_dt_batch,
|
| 1075 |
-
stride_dt_chunk,
|
| 1076 |
-
stride_dt_head,
|
| 1077 |
-
stride_dt_csize,
|
| 1078 |
-
stride_dA_cs_batch,
|
| 1079 |
-
stride_dA_cs_chunk,
|
| 1080 |
-
stride_dA_cs_head,
|
| 1081 |
-
stride_dA_cs_csize,
|
| 1082 |
-
stride_seq_idx_batch,
|
| 1083 |
-
stride_seq_idx_seqlen,
|
| 1084 |
-
stride_ddA_cs_batch,
|
| 1085 |
-
stride_ddA_cs_chunk,
|
| 1086 |
-
stride_ddA_cs_head,
|
| 1087 |
-
stride_ddA_cs_csize,
|
| 1088 |
-
# Meta-parameters
|
| 1089 |
-
HAS_SEQ_IDX: tl.constexpr,
|
| 1090 |
-
BLOCK_SIZE_M: tl.constexpr,
|
| 1091 |
-
BLOCK_SIZE_N: tl.constexpr,
|
| 1092 |
-
BLOCK_SIZE_K: tl.constexpr,
|
| 1093 |
-
BLOCK_SIZE_DSTATE: tl.constexpr,
|
| 1094 |
-
):
|
| 1095 |
-
pid_bc = tl.program_id(axis=1)
|
| 1096 |
-
pid_c = pid_bc // batch
|
| 1097 |
-
pid_b = pid_bc - pid_c * batch
|
| 1098 |
-
pid_h = tl.program_id(axis=2)
|
| 1099 |
-
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
| 1100 |
-
pid_m = tl.program_id(axis=0) // num_pid_n
|
| 1101 |
-
pid_n = tl.program_id(axis=0) % num_pid_n
|
| 1102 |
-
x_ptr += (
|
| 1103 |
-
pid_b * stride_x_batch
|
| 1104 |
-
+ pid_c * chunk_size * stride_x_seqlen
|
| 1105 |
-
+ pid_h * stride_x_head
|
| 1106 |
-
)
|
| 1107 |
-
b_ptr += (
|
| 1108 |
-
pid_b * stride_b_batch
|
| 1109 |
-
+ pid_c * chunk_size * stride_b_seqlen
|
| 1110 |
-
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
| 1111 |
-
)
|
| 1112 |
-
dstates_ptr += (
|
| 1113 |
-
pid_b * stride_dstates_batch
|
| 1114 |
-
+ pid_c * stride_dstates_chunk
|
| 1115 |
-
+ pid_h * stride_states_head
|
| 1116 |
-
)
|
| 1117 |
-
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
| 1118 |
-
ddA_cumsum_ptr += (
|
| 1119 |
-
pid_b * stride_ddA_cs_batch
|
| 1120 |
-
+ pid_c * stride_ddA_cs_chunk
|
| 1121 |
-
+ pid_h * stride_ddA_cs_head
|
| 1122 |
-
)
|
| 1123 |
-
dA_cumsum_ptr += (
|
| 1124 |
-
pid_b * stride_dA_cs_batch
|
| 1125 |
-
+ pid_c * stride_dA_cs_chunk
|
| 1126 |
-
+ pid_h * stride_dA_cs_head
|
| 1127 |
-
)
|
| 1128 |
-
if HAS_SEQ_IDX:
|
| 1129 |
-
seq_idx_ptr += (
|
| 1130 |
-
pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
| 1131 |
-
)
|
| 1132 |
-
|
| 1133 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 1134 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 1135 |
-
|
| 1136 |
-
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
| 1137 |
-
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
| 1138 |
-
offs_k = tl.arange(
|
| 1139 |
-
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
|
| 1140 |
-
)
|
| 1141 |
-
b_ptrs = b_ptr + (
|
| 1142 |
-
offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
|
| 1143 |
-
)
|
| 1144 |
-
dstates_ptrs = dstates_ptr + (
|
| 1145 |
-
offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
|
| 1146 |
-
)
|
| 1147 |
-
if BLOCK_SIZE_DSTATE <= 128:
|
| 1148 |
-
b = tl.load(
|
| 1149 |
-
b_ptrs,
|
| 1150 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
|
| 1151 |
-
other=0.0,
|
| 1152 |
-
)
|
| 1153 |
-
dstates = tl.load(
|
| 1154 |
-
dstates_ptrs,
|
| 1155 |
-
mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
|
| 1156 |
-
other=0.0,
|
| 1157 |
-
)
|
| 1158 |
-
dstates = dstates.to(b_ptr.dtype.element_ty)
|
| 1159 |
-
acc = tl.dot(b, dstates)
|
| 1160 |
-
else:
|
| 1161 |
-
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
| 1162 |
-
for k in range(0, dstate, BLOCK_SIZE_K):
|
| 1163 |
-
b = tl.load(
|
| 1164 |
-
b_ptrs,
|
| 1165 |
-
mask=(offs_m[:, None] < chunk_size_limit)
|
| 1166 |
-
& (offs_k[None, :] < dstate - k),
|
| 1167 |
-
other=0.0,
|
| 1168 |
-
)
|
| 1169 |
-
dstates = tl.load(
|
| 1170 |
-
dstates_ptrs,
|
| 1171 |
-
mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
|
| 1172 |
-
other=0.0,
|
| 1173 |
-
)
|
| 1174 |
-
dstates = dstates.to(b_ptr.dtype.element_ty)
|
| 1175 |
-
acc += tl.dot(b, dstates)
|
| 1176 |
-
b_ptrs += BLOCK_SIZE_K * stride_b_dstate
|
| 1177 |
-
dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
|
| 1178 |
-
|
| 1179 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 1180 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 1181 |
-
|
| 1182 |
-
dA_cs_m = tl.load(
|
| 1183 |
-
dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
|
| 1184 |
-
).to(tl.float32)
|
| 1185 |
-
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
|
| 1186 |
-
tl.float32
|
| 1187 |
-
)
|
| 1188 |
-
if not HAS_SEQ_IDX:
|
| 1189 |
-
scale = tl.exp(dA_cs_last - dA_cs_m)
|
| 1190 |
-
else:
|
| 1191 |
-
seq_idx_m = tl.load(
|
| 1192 |
-
seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
|
| 1193 |
-
mask=offs_m < chunk_size_limit,
|
| 1194 |
-
other=-1,
|
| 1195 |
-
)
|
| 1196 |
-
seq_idx_last = tl.load(
|
| 1197 |
-
seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
|
| 1198 |
-
)
|
| 1199 |
-
scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
|
| 1200 |
-
acc *= scale[:, None]
|
| 1201 |
-
|
| 1202 |
-
x_ptrs = x_ptr + (
|
| 1203 |
-
offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
|
| 1204 |
-
)
|
| 1205 |
-
x = tl.load(
|
| 1206 |
-
x_ptrs,
|
| 1207 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
| 1208 |
-
other=0.0,
|
| 1209 |
-
).to(tl.float32)
|
| 1210 |
-
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
| 1211 |
-
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
| 1212 |
-
ddt = tl.sum(acc * x, axis=1)
|
| 1213 |
-
# ddA_cs = -(ddt * dt_m)
|
| 1214 |
-
# Triton 2.2.0 errors if we have the cumsum here, so we just write it out
|
| 1215 |
-
# then call torch.cumsum outside this kernel.
|
| 1216 |
-
# ddA_cs = tl.cumsum(ddt * dt_m)
|
| 1217 |
-
ddA_cs = ddt * dt_m
|
| 1218 |
-
ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
|
| 1219 |
-
# tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
|
| 1220 |
-
tl.atomic_add(
|
| 1221 |
-
ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1
|
| 1222 |
-
)
|
| 1223 |
-
|
| 1224 |
-
|
| 1225 |
-
@triton.autotune(
|
| 1226 |
-
configs=[
|
| 1227 |
-
triton.Config(
|
| 1228 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
| 1229 |
-
num_stages=3,
|
| 1230 |
-
num_warps=8,
|
| 1231 |
-
),
|
| 1232 |
-
triton.Config(
|
| 1233 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
| 1234 |
-
num_stages=4,
|
| 1235 |
-
num_warps=4,
|
| 1236 |
-
),
|
| 1237 |
-
triton.Config(
|
| 1238 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
| 1239 |
-
num_stages=4,
|
| 1240 |
-
num_warps=4,
|
| 1241 |
-
),
|
| 1242 |
-
triton.Config(
|
| 1243 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 1244 |
-
num_stages=4,
|
| 1245 |
-
num_warps=4,
|
| 1246 |
-
),
|
| 1247 |
-
triton.Config(
|
| 1248 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
| 1249 |
-
num_stages=4,
|
| 1250 |
-
num_warps=4,
|
| 1251 |
-
),
|
| 1252 |
-
triton.Config(
|
| 1253 |
-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
| 1254 |
-
num_stages=4,
|
| 1255 |
-
num_warps=4,
|
| 1256 |
-
),
|
| 1257 |
-
triton.Config(
|
| 1258 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
| 1259 |
-
num_stages=5,
|
| 1260 |
-
num_warps=2,
|
| 1261 |
-
),
|
| 1262 |
-
triton.Config(
|
| 1263 |
-
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 1264 |
-
num_stages=5,
|
| 1265 |
-
num_warps=2,
|
| 1266 |
-
),
|
| 1267 |
-
triton.Config(
|
| 1268 |
-
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 1269 |
-
num_stages=4,
|
| 1270 |
-
num_warps=2,
|
| 1271 |
-
),
|
| 1272 |
-
],
|
| 1273 |
-
key=["hdim", "dstate", "chunk_size"],
|
| 1274 |
-
)
|
| 1275 |
-
@triton.jit
|
| 1276 |
-
def _chunk_state_varlen_kernel(
|
| 1277 |
-
# Pointers to matrices
|
| 1278 |
-
x_ptr,
|
| 1279 |
-
b_ptr,
|
| 1280 |
-
dt_ptr,
|
| 1281 |
-
dA_cumsum_ptr,
|
| 1282 |
-
chunk_states_ptr,
|
| 1283 |
-
cu_seqlens_ptr,
|
| 1284 |
-
states_ptr,
|
| 1285 |
-
# Matrix dimensions
|
| 1286 |
-
hdim,
|
| 1287 |
-
dstate,
|
| 1288 |
-
chunk_size,
|
| 1289 |
-
seqlen,
|
| 1290 |
-
nheads_ngroups_ratio,
|
| 1291 |
-
# Strides
|
| 1292 |
-
stride_x_seqlen,
|
| 1293 |
-
stride_x_head,
|
| 1294 |
-
stride_x_hdim,
|
| 1295 |
-
stride_b_seqlen,
|
| 1296 |
-
stride_b_head,
|
| 1297 |
-
stride_b_dstate,
|
| 1298 |
-
stride_dt_chunk,
|
| 1299 |
-
stride_dt_head,
|
| 1300 |
-
stride_dt_csize,
|
| 1301 |
-
stride_dA_cs_chunk,
|
| 1302 |
-
stride_dA_cs_head,
|
| 1303 |
-
stride_dA_cs_csize,
|
| 1304 |
-
stride_chunk_states_chunk,
|
| 1305 |
-
stride_chunk_states_head,
|
| 1306 |
-
stride_chunk_states_hdim,
|
| 1307 |
-
stride_chunk_states_dstate,
|
| 1308 |
-
stride_states_batch,
|
| 1309 |
-
stride_states_head,
|
| 1310 |
-
stride_states_hdim,
|
| 1311 |
-
stride_states_dstate,
|
| 1312 |
-
# Meta-parameters
|
| 1313 |
-
BLOCK_SIZE_M: tl.constexpr,
|
| 1314 |
-
BLOCK_SIZE_N: tl.constexpr,
|
| 1315 |
-
BLOCK_SIZE_K: tl.constexpr,
|
| 1316 |
-
):
|
| 1317 |
-
pid_b = tl.program_id(axis=1)
|
| 1318 |
-
pid_h = tl.program_id(axis=2)
|
| 1319 |
-
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
| 1320 |
-
pid_m = tl.program_id(axis=0) // num_pid_n
|
| 1321 |
-
pid_n = tl.program_id(axis=0) % num_pid_n
|
| 1322 |
-
end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
|
| 1323 |
-
pid_c = (end_idx - 1) // chunk_size
|
| 1324 |
-
b_ptr += (
|
| 1325 |
-
pid_c * chunk_size * stride_b_seqlen
|
| 1326 |
-
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
| 1327 |
-
)
|
| 1328 |
-
x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
|
| 1329 |
-
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
| 1330 |
-
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
| 1331 |
-
chunk_states_ptr += (
|
| 1332 |
-
pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
|
| 1333 |
-
)
|
| 1334 |
-
|
| 1335 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 1336 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 1337 |
-
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
| 1338 |
-
x_ptrs = x_ptr + (
|
| 1339 |
-
offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
|
| 1340 |
-
)
|
| 1341 |
-
b_ptrs = b_ptr + (
|
| 1342 |
-
offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
|
| 1343 |
-
)
|
| 1344 |
-
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
| 1345 |
-
dA_cs_last = tl.load(
|
| 1346 |
-
dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
|
| 1347 |
-
).to(tl.float32)
|
| 1348 |
-
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
| 1349 |
-
|
| 1350 |
-
chunk_size_limit = end_idx - pid_c * chunk_size
|
| 1351 |
-
start_idx = tl.load(cu_seqlens_ptr + pid_b)
|
| 1352 |
-
start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
|
| 1353 |
-
|
| 1354 |
-
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
| 1355 |
-
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
|
| 1356 |
-
x = tl.load(
|
| 1357 |
-
x_ptrs,
|
| 1358 |
-
mask=(offs_m[:, None] < hdim)
|
| 1359 |
-
& (offs_k[None, :] < chunk_size_limit - k)
|
| 1360 |
-
& (offs_k[None, :] >= start_idx_cur - k),
|
| 1361 |
-
other=0.0,
|
| 1362 |
-
)
|
| 1363 |
-
b = tl.load(
|
| 1364 |
-
b_ptrs,
|
| 1365 |
-
mask=(offs_k[:, None] < chunk_size_limit - k)
|
| 1366 |
-
& (offs_n[None, :] < dstate)
|
| 1367 |
-
& (offs_k[:, None] >= start_idx_cur - k),
|
| 1368 |
-
other=0.0,
|
| 1369 |
-
).to(tl.float32)
|
| 1370 |
-
dA_cs_k = tl.load(
|
| 1371 |
-
dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
|
| 1372 |
-
).to(tl.float32)
|
| 1373 |
-
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
|
| 1374 |
-
tl.float32
|
| 1375 |
-
)
|
| 1376 |
-
scale = tl.where(
|
| 1377 |
-
(offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
|
| 1378 |
-
tl.exp((dA_cs_last - dA_cs_k)) * dt_k,
|
| 1379 |
-
0.0,
|
| 1380 |
-
)
|
| 1381 |
-
b *= scale[:, None]
|
| 1382 |
-
b = b.to(x_ptr.dtype.element_ty)
|
| 1383 |
-
acc += tl.dot(x, b)
|
| 1384 |
-
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
| 1385 |
-
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
|
| 1386 |
-
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
| 1387 |
-
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
| 1388 |
-
|
| 1389 |
-
# If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
|
| 1390 |
-
if start_idx < pid_c * chunk_size:
|
| 1391 |
-
chunk_states_ptrs = chunk_states_ptr + (
|
| 1392 |
-
offs_m[:, None] * stride_chunk_states_hdim
|
| 1393 |
-
+ offs_n[None, :] * stride_chunk_states_dstate
|
| 1394 |
-
)
|
| 1395 |
-
chunk_states = tl.load(
|
| 1396 |
-
chunk_states_ptrs,
|
| 1397 |
-
mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),
|
| 1398 |
-
other=0.0,
|
| 1399 |
-
).to(tl.float32)
|
| 1400 |
-
# scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0)
|
| 1401 |
-
scale = tl.exp(dA_cs_last)
|
| 1402 |
-
acc += chunk_states * scale
|
| 1403 |
-
|
| 1404 |
-
states = acc.to(states_ptr.dtype.element_ty)
|
| 1405 |
-
|
| 1406 |
-
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
|
| 1407 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 1408 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 1409 |
-
states_ptrs = states_ptr + (
|
| 1410 |
-
offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
|
| 1411 |
-
)
|
| 1412 |
-
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
|
| 1413 |
-
tl.store(states_ptrs, states, mask=c_mask)
|
| 1414 |
-
|
| 1415 |
-
|
| 1416 |
-
def _chunk_cumsum_fwd(
|
| 1417 |
-
dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))
|
| 1418 |
-
):
|
| 1419 |
-
batch, seqlen, nheads = dt.shape
|
| 1420 |
-
assert A.shape == (nheads,)
|
| 1421 |
-
if dt_bias is not None:
|
| 1422 |
-
assert dt_bias.shape == (nheads,)
|
| 1423 |
-
nchunks = math.ceil(seqlen / chunk_size)
|
| 1424 |
-
dt_out = torch.empty(
|
| 1425 |
-
batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
|
| 1426 |
-
)
|
| 1427 |
-
dA_cumsum = torch.empty(
|
| 1428 |
-
batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
|
| 1429 |
-
)
|
| 1430 |
-
grid_chunk_cs = lambda META: (
|
| 1431 |
-
batch,
|
| 1432 |
-
nchunks,
|
| 1433 |
-
triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
|
| 1434 |
-
)
|
| 1435 |
-
with torch.cuda.device(dt.device.index):
|
| 1436 |
-
_chunk_cumsum_fwd_kernel[grid_chunk_cs](
|
| 1437 |
-
dt,
|
| 1438 |
-
A,
|
| 1439 |
-
dt_bias,
|
| 1440 |
-
dt_out,
|
| 1441 |
-
dA_cumsum,
|
| 1442 |
-
batch,
|
| 1443 |
-
seqlen,
|
| 1444 |
-
nheads,
|
| 1445 |
-
chunk_size,
|
| 1446 |
-
dt_limit[0],
|
| 1447 |
-
dt_limit[1],
|
| 1448 |
-
dt.stride(0),
|
| 1449 |
-
dt.stride(1),
|
| 1450 |
-
dt.stride(2),
|
| 1451 |
-
A.stride(0),
|
| 1452 |
-
dt_bias.stride(0) if dt_bias is not None else 0,
|
| 1453 |
-
dt_out.stride(0),
|
| 1454 |
-
dt_out.stride(2),
|
| 1455 |
-
dt_out.stride(1),
|
| 1456 |
-
dt_out.stride(3),
|
| 1457 |
-
dA_cumsum.stride(0),
|
| 1458 |
-
dA_cumsum.stride(2),
|
| 1459 |
-
dA_cumsum.stride(1),
|
| 1460 |
-
dA_cumsum.stride(3),
|
| 1461 |
-
dt_softplus,
|
| 1462 |
-
HAS_DT_BIAS=dt_bias is not None,
|
| 1463 |
-
BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
|
| 1464 |
-
)
|
| 1465 |
-
return dA_cumsum, dt_out
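

# --- Illustrative reference (not part of the upstream file) ---------------------
# A minimal PyTorch sketch of the math _chunk_cumsum_fwd implements, assuming the
# module-level imports used elsewhere in this file (torch, math, F, rearrange):
# dt is (optionally) bias-shifted, softplus-ed and clamped, and dA_cumsum is the
# running cumulative sum of A * dt within each chunk. The Triton kernel above is
# the source of truth; it additionally guards softplus overflow in-kernel.
def _chunk_cumsum_fwd_ref(
    dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))
):
    batch, seqlen, nheads = dt.shape
    nchunks = math.ceil(seqlen / chunk_size)
    dt = dt.float()
    if dt_bias is not None:
        dt = dt + dt_bias  # (nheads,) broadcasts over batch and seqlen
    if dt_softplus:
        dt = F.softplus(dt)
    if dt_limit != (0.0, float("inf")):
        dt = dt.clamp(min=dt_limit[0], max=dt_limit[1])
    if seqlen < nchunks * chunk_size:
        # pad with zeros so padded positions contribute nothing to the cumsum,
        # matching the kernel's masking of positions beyond seqlen
        dt = F.pad(dt, (0, 0, 0, nchunks * chunk_size - seqlen))
    dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
    dA_cumsum = torch.cumsum(dt * A[None, :, None, None], dim=-1)
    return dA_cumsum, dt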


def _chunk_cumsum_bwd(
    ddA,
    ddt_out,
    dt,
    A,
    dt_bias=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    ddt=None,
):
    batch, seqlen, nheads = dt.shape
    _, _, nchunks, chunk_size = ddA.shape
    assert ddA.shape == (batch, nheads, nchunks, chunk_size)
    assert ddt_out.shape == (batch, nheads, nchunks, chunk_size)
    assert A.shape == (nheads,)
    if dt_bias is not None:
        assert dt_bias.shape == (nheads,)
        ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)
    else:
        ddt_bias = None
    if ddt is not None:
        assert ddt.shape == dt.shape
    else:
        ddt = torch.empty_like(dt)
    dA = torch.empty_like(A, dtype=torch.float32)
    grid_chunk_cs = lambda META: (
        batch,
        nchunks,
        triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
    )
    with torch.cuda.device(dt.device.index):
        _chunk_cumsum_bwd_kernel[grid_chunk_cs](
            ddA,
            ddt_out,
            dt,
            A,
            dt_bias,
            ddt,
            dA,
            ddt_bias,
            batch,
            seqlen,
            nheads,
            chunk_size,
            dt_limit[0],
            dt_limit[1],
            ddA.stride(0),
            ddA.stride(2),
            ddA.stride(1),
            ddA.stride(3),
            ddt_out.stride(0),
            ddt_out.stride(2),
            ddt_out.stride(1),
            ddt_out.stride(3),
            dt.stride(0),
            dt.stride(1),
            dt.stride(2),
            A.stride(0),
            dt_bias.stride(0) if dt_bias is not None else 0,
            ddt.stride(0),
            ddt.stride(1),
            ddt.stride(2),
            dA.stride(0),
            ddt_bias.stride(0) if ddt_bias is not None else 0,
            dt_softplus,
            HAS_DT_BIAS=dt_bias is not None,
            BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
        )
    return ddt, dA, ddt_bias


def _chunk_state_fwd(
    B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True
):
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if states is not None:
        assert states.shape == (batch, nchunks, nheads, headdim, dstate)
    else:
        states_dtype = torch.float32 if states_in_fp32 else B.dtype
        states = torch.empty(
            (batch, nchunks, nheads, headdim, dstate),
            device=x.device,
            dtype=states_dtype,
        )
    grid = lambda META: (
        triton.cdiv(headdim, META["BLOCK_SIZE_M"])
        * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
        batch * nchunks,
        nheads,
    )
    with torch.cuda.device(x.device.index):
        _chunk_state_fwd_kernel[grid](
            x,
            B,
            states,
            dt,
            dA_cumsum,
            seq_idx,
            headdim,
            dstate,
            chunk_size,
            batch,
            seqlen,
            nheads // ngroups,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            x.stride(3),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            B.stride(-1),
            states.stride(0),
            states.stride(1),
            states.stride(2),
            states.stride(3),
            states.stride(4),
            dt.stride(0),
            dt.stride(2),
            dt.stride(1),
            dt.stride(3),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(3),
            *(
                (seq_idx.stride(0), seq_idx.stride(1))
                if seq_idx is not None
                else (0, 0)
            ),
            HAS_SEQ_IDX=seq_idx is not None,
        )
    return states
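

# --- Illustrative usage (not part of the upstream file) -------------------------
# Hypothetical shapes showing how the two forward wrappers above compose. The
# sizes are arbitrary, and a CUDA device is required since these launch Triton
# kernels.
def _example_chunk_state_fwd():
    batch, seqlen, nheads, headdim = 2, 512, 8, 64
    ngroups, dstate, chunk_size = 1, 16, 128
    device = "cuda"
    x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=torch.float16)
    B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=torch.float16)
    dt = torch.rand(batch, seqlen, nheads, device=device)
    A = -torch.rand(nheads, device=device)  # A is typically negative for stability
    dA_cumsum, dt_chunked = _chunk_cumsum_fwd(dt, A, chunk_size, dt_softplus=True)
    # dt_chunked, dA_cumsum: (batch, nheads, nchunks, chunk_size)
    states = _chunk_state_fwd(B, x, dt_chunked, dA_cumsum)
    # states: (batch, nchunks, nheads, headdim, dstate), float32 by default
    return states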


def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None):
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
    if dx is not None:
        assert dx.shape == x.shape
    else:
        dx = torch.empty_like(x)
    ddt = torch.empty(
        batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
    )
    ddA_cumsum = torch.empty(
        batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32
    )
    grid_dx = lambda META: (
        triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
        * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
        batch * nchunks,
        nheads,
    )
    with torch.cuda.device(x.device.index):
        _chunk_state_bwd_dx_kernel[grid_dx](
            x,
            B,
            dstates,
            dt,
            dA_cumsum,
            dx,
            ddt,
            ddA_cumsum,
            chunk_size,
            headdim,
            dstate,
            batch,
            seqlen,
            nheads // ngroups,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            x.stride(3),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            B.stride(-1),
            dstates.stride(0),
            dstates.stride(1),
            dstates.stride(2),
            dstates.stride(3),
            dstates.stride(4),
            dt.stride(0),
            dt.stride(2),
            dt.stride(1),
            dt.stride(3),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(3),
            dx.stride(0),
            dx.stride(1),
            dx.stride(2),
            dx.stride(3),
            ddt.stride(0),
            ddt.stride(2),
            ddt.stride(1),
            ddt.stride(3),
            ddA_cumsum.stride(0),
            ddA_cumsum.stride(2),
            ddA_cumsum.stride(1),
            ddA_cumsum.stride(3),
            BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
        )
    return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype)


def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1):
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    dstate = dstates.shape[-1]
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if B is not None:
        assert B.shape == (batch, seqlen, ngroups, dstate)
        B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3))
        # Use torch.empty since the Triton kernel will call init_to_zero
        ddA_cumsum = torch.empty(
            batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
        )
        ddA_cumsum_strides = (
            ddA_cumsum.stride(0),
            ddA_cumsum.stride(2),
            ddA_cumsum.stride(1),
            ddA_cumsum.stride(3),
        )
    else:
        B_strides = (0, 0, 0, 0)
        ddA_cumsum = None
        ddA_cumsum_strides = (0, 0, 0, 0)
    nheads_ngroups_ratio = nheads // ngroups
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    nheads_per_program = max(
        min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1
    )
    nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
    dB = torch.empty(
        batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32
    )
    grid_db = lambda META: (
        triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
        * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
        batch * nchunks,
        nsplits * ngroups,
    )
    with torch.cuda.device(x.device.index):
        _chunk_state_bwd_db_kernel[grid_db](
            x,
            dstates,
            B,
            dt,
            dA_cumsum,
            seq_idx,
            dB,
            ddA_cumsum,
            chunk_size,
            dstate,
            headdim,
            batch,
            seqlen,
            nheads,
            nheads_per_program,
            ngroups,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            x.stride(3),
            dstates.stride(0),
            dstates.stride(1),
            dstates.stride(2),
            dstates.stride(3),
            dstates.stride(4),
            *B_strides,
            dt.stride(0),
            dt.stride(2),
            dt.stride(1),
            dt.stride(3),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(3),
            *(
                (seq_idx.stride(0), seq_idx.stride(1))
                if seq_idx is not None
                else (0, 0)
            ),
            dB.stride(0),
            dB.stride(1),
            dB.stride(2),
            dB.stride(3),
            dB.stride(4),
            *ddA_cumsum_strides,
            HAS_DDA_CS=ddA_cumsum is not None,
            HAS_SEQ_IDX=seq_idx is not None,
            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
        )
    dB = dB.sum(2)
    if ddA_cumsum is not None:
        # The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute
        # to the state of the chunk.
        # torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
        # But it's easier to just do the cumsum for all elements, the result will be the same.
        torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum)
    return dB if B is None else (dB, ddA_cumsum)


def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    # Use torch.empty since the Triton kernel will call init_to_zero
    ddA_cumsum = torch.empty(
        batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
    )
    grid_ddtcs = lambda META: (
        triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
        * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
        batch * nchunks,
        nheads,
    )
    with torch.cuda.device(x.device.index):
        _chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs](
            x,
            B,
            dstates,
            dt,
            dA_cumsum,
            seq_idx,
            ddA_cumsum,
            chunk_size,
            headdim,
            dstate,
            batch,
            seqlen,
            nheads // ngroups,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            x.stride(3),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            B.stride(-1),
            dstates.stride(0),
            dstates.stride(1),
            dstates.stride(2),
            dstates.stride(3),
            dstates.stride(4),
            dt.stride(0),
            dt.stride(2),
            dt.stride(1),
            dt.stride(3),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(3),
            *(
                (seq_idx.stride(0), seq_idx.stride(1))
                if seq_idx is not None
                else (0, 0)
            ),
            ddA_cumsum.stride(0),
            ddA_cumsum.stride(2),
            ddA_cumsum.stride(1),
            ddA_cumsum.stride(3),
            HAS_SEQ_IDX=seq_idx is not None,
            BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16),
            BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
        )
    torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
    return ddA_cumsum


def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):
    total_seqlen, nheads, headdim = x.shape
    _, nchunks, chunk_size = dt.shape
    _, ngroups, dstate = B.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    assert nheads % ngroups == 0
    assert B.shape == (total_seqlen, ngroups, dstate)
    assert dt.shape == (nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
    states = torch.empty(
        batch,
        nheads,
        headdim,
        dstate,
        dtype=chunk_states.dtype,
        device=chunk_states.device,
    )
    grid = lambda META: (
        triton.cdiv(headdim, META["BLOCK_SIZE_M"])
        * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
        batch,
        nheads,
    )
    with torch.cuda.device(x.device.index):
        _chunk_state_varlen_kernel[grid](
            x,
            B,
            dt,
            dA_cumsum,
            chunk_states,
            cu_seqlens,
            states,
            headdim,
            dstate,
            chunk_size,
            total_seqlen,
            nheads // ngroups,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            dt.stride(1),
            dt.stride(0),
            dt.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            chunk_states.stride(0),
            chunk_states.stride(1),
            chunk_states.stride(2),
            chunk_states.stride(3),
            states.stride(0),
            states.stride(1),
            states.stride(2),
            states.stride(3),
        )
    return states
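

# --- Illustrative usage (not part of the upstream file) -------------------------
# chunk_state_varlen consumes a packed variable-length batch: cu_seqlens holds the
# cumulative sequence boundaries, e.g. three sequences of lengths 5, 7 and 8 packed
# into total_seqlen = 20 are described by cu_seqlens = [0, 5, 12, 20]. A small
# sketch of building that tensor (the lengths here are hypothetical):
def _example_cu_seqlens(seqlens=(5, 7, 8), device="cuda"):
    lens = torch.tensor(seqlens, device=device, dtype=torch.int32)
    zero = torch.zeros(1, device=device, dtype=torch.int32)
    # prepend 0 so cu_seqlens[b] / cu_seqlens[b + 1] bracket sequence b
    return torch.cat([zero, torch.cumsum(lens, dim=0, dtype=torch.int32)])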


class ChunkStateFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True):
        batch, seqlen, nheads, headdim = x.shape
        _, _, nchunks, chunk_size = dt.shape
        assert seqlen <= nchunks * chunk_size
        _, _, ngroups, dstate = B.shape
        assert B.shape == (batch, seqlen, ngroups, dstate)
        assert dt.shape == (batch, nheads, nchunks, chunk_size)
        assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
        if B.stride(-1) != 1:
            B = B.contiguous()
        if (
            x.stride(-1) != 1 and x.stride(1) != 1
        ):  # Either M or K dimension should be contiguous
            x = x.contiguous()
        states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32)
        ctx.save_for_backward(B, x, dt, dA_cumsum)
        return states

    @staticmethod
    def backward(ctx, dstates):
        B, x, dt, dA_cumsum = ctx.saved_tensors
        batch, seqlen, nheads, headdim = x.shape
        _, _, nchunks, chunk_size = dt.shape
        _, _, ngroups, dstate = B.shape
        assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
        if dstates.stride(-1) != 1:
            dstates = dstates.contiguous()
        dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates)
        dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups)
        dB = dB.to(B.dtype)
        return dB, dx, ddt, ddA_cumsum, None


def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True):
    """
    Argument:
        B: (batch, seqlen, ngroups, dstate)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
    Return:
        states: (batch, nchunks, nheads, headdim, dstate)
    """
    return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32)
|
| 1976 |
-
|
| 1977 |
-
|
| 1978 |
-
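A minimal usage sketch for `chunk_state`, following the docstring shapes. All sizes, dtypes, and the decay construction below are arbitrary illustrations; a CUDA device is assumed since the underlying kernel is Triton:

import torch

# Hypothetical sizes, chosen only to satisfy the documented shapes.
batch, seqlen, nheads, headdim = 2, 128, 8, 64
ngroups, dstate, chunk_size = 2, 16, 64
nchunks = seqlen // chunk_size
device = "cuda"

B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=torch.float16)
x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=torch.float16)
dt = torch.rand(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32)
dA_cumsum = torch.cumsum(-dt, dim=-1)  # any cumulative decay of the right shape works for a shape check

states = chunk_state(B, x, dt, dA_cumsum)
assert states.shape == (batch, nchunks, nheads, headdim, dstate)
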
def chunk_state_ref(B, x, dt, dA_cumsum):
    """
    Argument:
        B: (batch, seqlen, ngroups, dstate)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
    Return:
        states: (batch, nchunks, nheads, headdim, dstate)
    """
    # Check constraints.
    batch, seqlen, nheads, headdim = x.shape
    dstate = B.shape[-1]
    _, _, nchunks, chunk_size = dt.shape
    assert seqlen <= nchunks * chunk_size
    assert x.shape == (batch, seqlen, nheads, headdim)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    ngroups = B.shape[2]
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
    if seqlen < nchunks * chunk_size:
        x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
        B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
    x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
    B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
    decay_states = torch.exp(dA_cumsum[:, :, :, -1:] - dA_cumsum)
    return torch.einsum(
        "bclhn,bhcl,bhcl,bclhp->bchpn",
        B.to(x.dtype),
        decay_states.to(x.dtype),
        dt.to(x.dtype),
        x,
    )
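Since `chunk_state_ref` exists as the einsum reference for the Triton-backed `chunk_state` above, a natural way to read it is via a consistency check. A sketch, reusing the hypothetical tensors from the earlier example; the tolerance is loose because the inputs are fp16 while the kernel accumulates in fp32:

ref = chunk_state_ref(B, x, dt, dA_cumsum)
out = chunk_state(B, x, dt, dA_cumsum)
assert torch.allclose(out.float(), ref.float(), rtol=1e-2, atol=1e-2)
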
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/ssd_combined.py
DELETED
@@ -1,1884 +0,0 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.

"""We want triton==2.1.0 or 2.2.0 for this
"""

from typing import Optional

import math
from packaging import version

import torch
import torch.nn.functional as F
from torch import Tensor
from ...utils.torch import custom_bwd, custom_fwd

import triton
import triton.language as tl

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn
    import causal_conv1d_cuda
except ImportError:
    causal_conv1d_fn, causal_conv1d_cuda = None, None

from .ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
from .ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
from .ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
from .ssd_chunk_state import chunk_state, chunk_state_ref
from .ssd_chunk_state import chunk_state_varlen
from .ssd_state_passing import _state_passing_fwd, _state_passing_bwd
from .ssd_state_passing import state_passing, state_passing_ref
from .ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates
from .ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
from .ssd_chunk_scan import chunk_scan, chunk_scan_ref
from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
from .layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
from .k_activations import _swiglu_fwd, _swiglu_bwd

TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")


def init_to_zero(names):
    return lambda nargs: [
        nargs[name].zero_() for name in names if nargs[name] is not None
    ]

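A note on `init_to_zero`: every autotune config below installs it as a `pre_hook` on `ddt_ptr` because the kernel accumulates `ddt` across tile programs with `tl.atomic_add`, which is only correct if the buffer starts at zero. A host-side sketch of the same accumulation pattern (shapes are illustrative):

import torch

ddt = torch.zeros(8)          # what pre_hook=init_to_zero(["ddt_ptr"]) guarantees
partials = torch.randn(4, 8)  # one partial sum per tile program
for tile_sum in partials:     # each program effectively does tl.atomic_add(ddt_ptrs, ...)
    ddt += tile_sum
assert torch.allclose(ddt, partials.sum(dim=0))
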
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
            num_stages=3,
            num_warps=8,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
    ],
    key=["chunk_size", "hdim", "dstate"],
)
@triton.jit
def _chunk_scan_chunk_state_bwd_dx_kernel(
    # Pointers to matrices
    x_ptr,
    cb_ptr,
    dout_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    seq_idx_ptr,
    D_ptr,
    b_ptr,
    dstates_ptr,
    dx_ptr,
    ddt_ptr,
    dD_ptr,
    # Matrix dimensions
    chunk_size,
    hdim,
    dstate,
    batch,
    seqlen,
    nheads_ngroups_ratio,
    # Strides
    stride_x_batch,
    stride_x_seqlen,
    stride_x_head,
    stride_x_hdim,
    stride_cb_batch,
    stride_cb_chunk,
    stride_cb_head,
    stride_cb_csize_m,
    stride_cb_csize_k,
    stride_dout_batch,
    stride_dout_seqlen,
    stride_dout_head,
    stride_dout_hdim,
    stride_dt_batch,
    stride_dt_chunk,
    stride_dt_head,
    stride_dt_csize,
    stride_dA_cs_batch,
    stride_dA_cs_chunk,
    stride_dA_cs_head,
    stride_dA_cs_csize,
    stride_seq_idx_batch,
    stride_seq_idx_seqlen,
    stride_D_head,
    stride_b_batch,
    stride_b_seqlen,
    stride_b_head,
    stride_b_dstate,
    stride_dstates_batch,
    stride_dstates_chunk,
    stride_dstates_head,
    stride_dstates_hdim,
    stride_dstates_dstate,
    stride_dx_batch,
    stride_dx_seqlen,
    stride_dx_head,
    stride_dx_hdim,
    stride_ddt_batch,
    stride_ddt_chunk,
    stride_ddt_head,
    stride_ddt_csize,
    stride_dD_batch,
    stride_dD_chunk,
    stride_dD_head,
    stride_dD_csize,
    stride_dD_hdim,
    # Meta-parameters
    HAS_D: tl.constexpr,
    D_HAS_HDIM: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
    IS_TRITON_22: tl.constexpr,
):
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    x_ptr += (
        pid_b * stride_x_batch
        + pid_c * chunk_size * stride_x_seqlen
        + pid_h * stride_x_head
    )
    cb_ptr += (
        pid_b * stride_cb_batch
        + pid_c * stride_cb_chunk
        + (pid_h // nheads_ngroups_ratio) * stride_cb_head
    )
    dout_ptr += (
        pid_b * stride_dout_batch
        + pid_c * chunk_size * stride_dout_seqlen
        + pid_h * stride_dout_head
    )
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    ddt_ptr += (
        pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
    )
    dA_cumsum_ptr += (
        pid_b * stride_dA_cs_batch
        + pid_c * stride_dA_cs_chunk
        + pid_h * stride_dA_cs_head
    )
    b_ptr += (
        pid_b * stride_b_batch
        + pid_c * chunk_size * stride_b_seqlen
        + (pid_h // nheads_ngroups_ratio) * stride_b_head
    )
    dstates_ptr += (
        pid_b * stride_dstates_batch
        + pid_c * stride_dstates_chunk
        + pid_h * stride_dstates_head
    )
    if HAS_SEQ_IDX:
        seq_idx_ptr += (
            pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
        )

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    dA_cs_m = tl.load(
        dA_cumsum_ptr + offs_m * stride_dA_cs_csize,
        mask=offs_m < chunk_size_limit,
        other=0.0,
    ).to(tl.float32)

    dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
        tl.float32
    )
    if not HAS_SEQ_IDX:
        scale = tl.exp(dA_cs_last - dA_cs_m)
    else:
        seq_idx_m = tl.load(
            seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
            mask=offs_m < chunk_size_limit,
            other=-1,
        )
        seq_idx_last = tl.load(
            seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
        )
        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
    # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
    # However, we're getting error with the Triton compiler 2.1.0 for that code path:
    # Unexpected mma -> mma layout conversion
    # Triton 2.2.0 fixes this
    offs_dstate = tl.arange(
        0,
        (
            BLOCK_SIZE_DSTATE
            if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128
            else BLOCK_SIZE_K
        ),
    )
    b_ptrs = b_ptr + (
        offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate
    )
    dstates_ptrs = dstates_ptr + (
        offs_n[None, :] * stride_dstates_hdim
        + offs_dstate[:, None] * stride_dstates_dstate
    )
    if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:
        b = tl.load(
            b_ptrs,
            mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate),
            other=0.0,
        )
        dstates = tl.load(
            dstates_ptrs,
            mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
            other=0.0,
        )
        dstates = dstates.to(b_ptr.dtype.element_ty)
        acc = tl.dot(b, dstates) * scale[:, None]
    else:
        for k in range(0, dstate, BLOCK_SIZE_K):
            b = tl.load(
                b_ptrs,
                mask=(offs_m[:, None] < chunk_size_limit)
                & (offs_dstate[None, :] < dstate - k),
                other=0.0,
            )
            dstates = tl.load(
                dstates_ptrs,
                mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim),
                other=0.0,
            )
            dstates = dstates.to(b_ptr.dtype.element_ty)
            acc += tl.dot(b, dstates)
            b_ptrs += BLOCK_SIZE_K * stride_b_dstate
            dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate
        acc *= scale[:, None]

    # x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
    # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
    # dt_ptrs = dt_ptr + offs_m * stride_dt_csize
    # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
    # ddt = tl.sum(acc * x, axis=1) * dt_m
    # ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
    # tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)

    offs_k = tl.arange(0, BLOCK_SIZE_K)
    cb_ptrs = cb_ptr + (
        offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k
    )
    dout_ptrs = dout_ptr + (
        offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim
    )
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
    K_MAX = chunk_size_limit
    K_MIN = pid_m * BLOCK_SIZE_M
    cb_ptrs += K_MIN * stride_cb_csize_k
    dout_ptrs += K_MIN * stride_dout_seqlen
    dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize
    for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):
        k = tl.multiple_of(k, BLOCK_SIZE_K)
        # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower
        cb = tl.load(
            cb_ptrs,
            mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k),
            other=0.0,
        )
        dout = tl.load(
            dout_ptrs,
            mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim),
            other=0.0,
        )
        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(
            tl.float32
        )
        cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
        # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
        # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
        # Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
        # This will cause NaN in acc, and hence NaN in dx and ddt.
        mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)
        cb = tl.where(mask, cb, 0.0)
        cb = cb.to(dout_ptr.dtype.element_ty)
        acc += tl.dot(cb, dout)
        cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
        dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dt_ptrs = dt_ptr + offs_m * stride_dt_csize
    dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
    dx = acc * dt_m[:, None]
    dx_ptr += (
        pid_b * stride_dx_batch
        + pid_c * chunk_size * stride_dx_seqlen
        + pid_h * stride_dx_head
    )
    dx_ptrs = dx_ptr + (
        offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim
    )
    if HAS_D:
        dout_res_ptrs = dout_ptr + (
            offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim
        )
        dout_res = tl.load(
            dout_res_ptrs,
            mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
            other=0.0,
        ).to(tl.float32)
        if D_HAS_HDIM:
            D = tl.load(
                D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0
            ).to(tl.float32)
        else:
            D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
        dx += dout_res * D
    tl.store(
        dx_ptrs,
        dx,
        mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
    )

    x_ptrs = x_ptr + (
        offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
    )
    x = tl.load(
        x_ptrs,
        mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
        other=0.0,
    ).to(tl.float32)
    if HAS_D:
        dD_ptr += (
            pid_b * stride_dD_batch
            + pid_c * stride_dD_chunk
            + pid_h * stride_dD_head
            + pid_m * stride_dD_csize
        )
        if D_HAS_HDIM:
            dD_ptrs = dD_ptr + offs_n * stride_dD_hdim
            dD = tl.sum(dout_res * x, axis=0)
            tl.store(dD_ptrs, dD, mask=offs_n < hdim)
        else:
            dD = tl.sum(dout_res * x)
            tl.store(dD_ptr, dD)
    ddt = tl.sum(acc * x, axis=1)
    ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
    tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)

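The kernel above flattens the (m, n) tile coordinates into grid axis 0 and (batch, chunk) into axis 1, then recovers them with the `//` and `%` arithmetic at the top; the wrapper below builds the matching grid. A plain-Python sketch of the same index arithmetic (names and sizes are illustrative):

def unpack(pid0, pid1, num_pid_n, batch):
    pid_m, pid_n = pid0 // num_pid_n, pid0 % num_pid_n  # tile over (chunk_size, hdim)
    pid_c = pid1 // batch
    pid_b = pid1 - pid_c * batch
    return pid_m, pid_n, pid_b, pid_c

num_pid_n, batch, num_pid_m, nchunks = 3, 2, 4, 5
seen = {
    unpack(p0, p1, num_pid_n, batch)
    for p0 in range(num_pid_m * num_pid_n)
    for p1 in range(batch * nchunks)
}
# Every (m, n, b, c) tile is covered exactly once.
assert len(seen) == num_pid_m * num_pid_n * batch * nchunks
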
def _chunk_scan_chunk_state_bwd_dx(
    x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None
):
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert dout.shape == x.shape
    assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
        assert D.stride(-1) == 1
        BLOCK_SIZE_min = 32
        dD = torch.empty(
            triton.cdiv(chunk_size, BLOCK_SIZE_min),
            batch,
            nchunks,
            nheads,
            headdim if D.dim() == 2 else 1,
            device=D.device,
            dtype=torch.float32,
        )
    else:
        dD = None
    dD_strides = (
        (dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
        if D is not None
        else (0, 0, 0, 0, 0)
    )
    if dx is None:
        dx = torch.empty_like(x)
    else:
        assert dx.shape == x.shape
    ddt = torch.empty(
        batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32
    )
    grid_dx = lambda META: (
        triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
        * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
        batch * nchunks,
        nheads,
    )
    with torch.cuda.device(x.device.index):
        _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](
            x,
            CB,
            dout,
            dt,
            dA_cumsum,
            seq_idx,
            D,
            B,
            dstates,
            dx,
            ddt,
            dD,
            chunk_size,
            headdim,
            dstate,
            batch,
            seqlen,
            nheads // ngroups,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            x.stride(3),
            CB.stride(0),
            CB.stride(1),
            CB.stride(2),
            CB.stride(-1),
            CB.stride(-2),
            dout.stride(0),
            dout.stride(1),
            dout.stride(2),
            dout.stride(3),
            dt.stride(0),
            dt.stride(2),
            dt.stride(1),
            dt.stride(3),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(3),
            *(
                (seq_idx.stride(0), seq_idx.stride(1))
                if seq_idx is not None
                else (0, 0)
            ),
            D.stride(0) if D is not None else 0,
            B.stride(0),
            B.stride(1),
            B.stride(2),
            B.stride(3),
            dstates.stride(0),
            dstates.stride(1),
            dstates.stride(2),
            dstates.stride(3),
            dstates.stride(4),
            dx.stride(0),
            dx.stride(1),
            dx.stride(2),
            dx.stride(3),
            ddt.stride(0),
            ddt.stride(2),
            ddt.stride(1),
            ddt.stride(3),
            dD_strides[1],
            dD_strides[2],
            dD_strides[3],
            dD_strides[0],
            dD_strides[4],
            D is not None,
            D.dim() == 2 if D is not None else True,
            HAS_SEQ_IDX=seq_idx is not None,
            BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
            IS_TRITON_22=TRITON_22,
        )
    if D is not None:
        BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs[
            "BLOCK_SIZE_M"
        ]
        n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
        dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
        if D.dim() == 1:
            dD = rearrange(dD, "h 1 -> h")
    return dx, ddt.to(dtype=dt.dtype), dD

def _mamba_chunk_scan_combined_fwd(
    x,
    dt,
    A,
    B,
    C,
    chunk_size,
    D=None,
    z=None,
    dt_bias=None,
    initial_states=None,
    seq_idx=None,
    cu_seqlens=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
):
    batch, seqlen, nheads, headdim = x.shape
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert x.shape == (batch, seqlen, nheads, headdim)
    assert dt.shape == (batch, seqlen, nheads)
    assert A.shape == (nheads,)
    assert C.shape == B.shape
    if z is not None:
        assert z.shape == x.shape
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if (
        x.stride(-1) != 1 and x.stride(1) != 1
    ):  # Either M or K dimension should be contiguous
        x = x.contiguous()
    if (
        z is not None and z.stride(-1) != 1 and z.stride(1) != 1
    ):  # Either M or K dimension should be contiguous
        z = z.contiguous()
    if D is not None and D.stride(-1) != 1:
        D = D.contiguous()
    if initial_states is not None:
        assert initial_states.shape == (batch, nheads, headdim, dstate)
    # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
    # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
    # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
    # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
    dA_cumsum, dt = _chunk_cumsum_fwd(
        dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit
    )
    states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
    # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
    # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
    # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)
    states, final_states = _state_passing_fwd(
        rearrange(states, "... p n -> ... (p n)"),
        dA_cumsum[:, :, :, -1],
        initial_states=(
            rearrange(initial_states, "... p n -> ... (p n)")
            if initial_states is not None
            else None
        ),
        seq_idx=seq_idx,
        chunk_size=chunk_size,
        out_dtype=C.dtype,
    )
    states, final_states = [
        rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]
    ]
    # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
    # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
    CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
    out, out_x = _chunk_scan_fwd(
        CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx
    )
    if cu_seqlens is None:
        return out, out_x, dt, dA_cumsum, states, final_states
    else:
        assert (
            batch == 1
        ), "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
        varlen_states = chunk_state_varlen(
            B.squeeze(0),
            x.squeeze(0),
            dt.squeeze(0),
            dA_cumsum.squeeze(0),
            cu_seqlens,
            states.squeeze(0),
        )
        return out, out_x, dt, dA_cumsum, states, final_states, varlen_states

def _mamba_chunk_scan_combined_bwd(
    dout,
    x,
    dt,
    A,
    B,
    C,
    out,
    chunk_size,
    D=None,
    z=None,
    dt_bias=None,
    initial_states=None,
    dfinal_states=None,
    seq_idx=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    dx=None,
    ddt=None,
    dB=None,
    dC=None,
    dz=None,
    recompute_output=False,
):
    if dout.stride(-1) != 1:
        dout = dout.contiguous()
    batch, seqlen, nheads, headdim = x.shape
    nchunks = math.ceil(seqlen / chunk_size)
    _, _, ngroups, dstate = B.shape
    assert dout.shape == (batch, seqlen, nheads, headdim)
    assert dt.shape == (batch, seqlen, nheads)
    assert A.shape == (nheads,)
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert C.shape == B.shape
    assert out.shape == x.shape
    if initial_states is not None:
        assert initial_states.shape == (batch, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if dx is not None:
        assert dx.shape == x.shape
    if dB is not None:
        assert dB.shape == B.shape
        dB_given = dB
    else:
        dB_given = torch.empty_like(B)
    if dC is not None:
        assert dC.shape == C.shape
        dC_given = dC
    else:
        dC_given = torch.empty_like(C)
    if dz is not None:
        assert z is not None
        assert dz.shape == z.shape
    if ddt is not None:
        assert ddt.shape == dt.shape
        ddt_given = ddt
    else:
        ddt_given = torch.empty_like(dt)
    # TD: For some reason Triton (2.1.0 and 2.2.0) errors with
    # "[CUDA]: invalid device context" (e.g. during the varlen test), and cloning makes it work. Idk why.
    dt_in = dt.clone()
    dA_cumsum, dt = _chunk_cumsum_fwd(
        dt_in,
        A,
        chunk_size,
        dt_bias=dt_bias,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
    )
    CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
    states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
    states, _ = _state_passing_fwd(
        rearrange(states, "... p n -> ... (p n)"),
        dA_cumsum[:, :, :, -1],
        initial_states=(
            rearrange(initial_states, "... p n -> ... (p n)")
            if initial_states is not None
            else None
        ),
        seq_idx=seq_idx,
        chunk_size=chunk_size,
    )
    states = rearrange(states, "... (p n) -> ... p n", n=dstate)
    if z is not None:
        dz, dout, dD, *rest = _chunk_scan_bwd_dz(
            x,
            z,
            out,
            dout,
            chunk_size=chunk_size,
            has_ddAcs=False,
            D=D,
            dz=dz,
            recompute_output=recompute_output,
        )
        outz = rest[0] if recompute_output else out
    else:
        dz = None
        outz = out
    dstates = _chunk_scan_bwd_dstates(
        C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype
    )
    # dstates has length nchunks, containing the gradient to initial states at index 0 and
    # gradient to the states of chunk (nchunks - 2) at index (nchunks - 1)
    # Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states
    # will be used in matmul in the next kernels.
    dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd(
        rearrange(states, "... p n -> ... (p n)"),
        dA_cumsum[:, :, :, -1],
        rearrange(dstates, "... p n -> ... (p n)"),
        dfinal_states=(
            rearrange(dfinal_states, "... p n -> ... (p n)")
            if dfinal_states is not None
            else None
        ),
        seq_idx=seq_idx,
        has_initial_states=initial_states is not None,
        dstates_dtype=x.dtype,
        states_dtype=x.dtype,
        chunk_size=chunk_size,
    )
    # dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and
    # gradient to the final states at index (nchunks - 1)
    # states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1)
    # The final states is not stored.
    states = rearrange(states, "... (p n) -> ... p n", n=dstate)
    dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate)
    dinitial_states = (
        rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate)
        if dinitial_states is not None
        else None
    )
    dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(
        x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx
    )
    # dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups)
    dB, ddA_next = _chunk_state_bwd_db(
        x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups
    )
    # dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
    dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(
        states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups
    )
    # Computing ddA with the dcb kernel is much slower, so we're not using it for now
    dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
    # dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups)
    dCB = dCB.to(CB.dtype)
    _bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given)
    _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given)
    # If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate
    # than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16
    if z is None:
        dD = dD_from_x
    # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
    # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
    # However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might
    # be a lot of underflow.

    # This is already done as part of bwd_dC kernel
    # ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx)
    ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum
    ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1])
    # This is already done as part of bwd_dB kernel
    # ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx)
    # We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j]
    ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB)
    ddA += ddA_next + ddA_prev

    ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(
        ddA,
        ddt,
        dt_in,
        A,
        dt_bias=dt_bias,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
        ddt=ddt_given,
    )

    # These 2 lines are just to test ddt and dA being computed by old code
    # _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z)
    # ddt_given.copy_(ddt)

    return_vals = (
        dx,
        ddt_given,
        dA,
        dB_given,
        dC_given,
        dD,
        dz,
        ddt_bias,
        dinitial_states,
    )
    return return_vals if not recompute_output else (*return_vals, outz)

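One step in the backward pass worth spelling out: `ddA_prev` is formed with `flip / cumsum / flip`, which computes a reverse (suffix) cumulative sum along the last dimension. A tiny sketch:

import torch

t = torch.tensor([1.0, 2.0, 3.0, 4.0])
suffix = t.flip([-1]).cumsum(dim=-1).flip([-1])
assert suffix.tolist() == [10.0, 9.0, 7.0, 4.0]  # suffix[i] = sum(t[i:])
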
def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None):
    """
    Argument:
        dout: (batch, seqlen, nheads, headdim)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size)
        A: (nheads) or (dim, dstate)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    import selective_scan

    batch, seqlen, nheads, headdim = x.shape
    chunk_size = dt.shape[-1]
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    x = rearrange(x, "b l h p -> b (h p) l")
    squeeze_dt = dt.dim() == 4
    if dt.dim() == 4:
        dt = repeat(dt, "b h c l -> b h p c l", p=headdim)
    dt = rearrange(dt, "b h p c l -> b (h p) (c l)", p=headdim)
    squeeze_A = A.dim() == 1
    if A.dim() == 1:
        A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
    else:
        A = A.to(dtype=torch.float32)
    B = rearrange(B, "b l g n -> b g n l")
    C = rearrange(C, "b l g n -> b g n l")
    if D is not None:
        if D.dim() == 2:
            D = rearrange(D, "h p -> (h p)")
        else:
            D = repeat(D, "h -> (h p)", p=headdim)
    if z is not None:
        z = rearrange(z, "b l h p -> b (h p) l")

    if x.stride(-1) != 1:
        x = x.contiguous()
    if dt.stride(-1) != 1:
        dt = dt.contiguous()
    if D is not None:
        D = D.contiguous()
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if z is not None and z.stride(-1) != 1:
        z = z.contiguous()
    _, intermediate, *rest = selective_scan.fwd(
        x, dt.to(dtype=x.dtype), A, B, C, D, z, None, False
    )
    if z is not None:
        out = rest[0]
    else:
        out = None

    dout = rearrange(dout, "b l h p -> b (h p) l")

    if dout.stride(-1) != 1:
        dout = dout.contiguous()
    # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
    # backward of selective_scan with the backward of chunk).
    # Here we just pass in None and dz will be allocated in the C++ code.
    _, ddt, dA, *rest = selective_scan.bwd(
        x,
        dt.to(dtype=x.dtype),
        A,
        B,
        C,
        D,
        z,
        None,
        dout,
        intermediate,
        out,
        None,
        False,
        False,  # option to recompute out_z, not used here
    )
    ddt = rearrange(ddt, "b (h p) (c l) -> b h p c l", p=headdim, l=chunk_size)
    if squeeze_dt:
        ddt = ddt.float().sum(dim=2)
    if squeeze_A:
        dA = rearrange(dA, "(h p) n -> h p n", p=headdim).sum(dim=(1, 2))
    return ddt, dA

class MambaChunkScanCombinedFn(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        x,
        dt,
        A,
        B,
        C,
        chunk_size,
        D=None,
        z=None,
        dt_bias=None,
        initial_states=None,
        seq_idx=None,
        cu_seqlens=None,
        dt_softplus=False,
        dt_limit=(0.0, float("inf")),
        return_final_states=False,
        return_varlen_states=False,
    ):
        ctx.dt_dtype = dt.dtype
        if not return_varlen_states:
            cu_seqlens = None
        else:
            assert (
                cu_seqlens is not None
            ), "cu_seqlens must be provided if return_varlen_states is True"
        out, out_x, dt_out, dA_cumsum, states, final_states, *rest = (
            _mamba_chunk_scan_combined_fwd(
                x,
                dt,
                A,
                B,
                C,
                chunk_size,
                D=D,
                z=z,
                dt_bias=dt_bias,
                initial_states=initial_states,
                seq_idx=seq_idx,
                cu_seqlens=cu_seqlens,
                dt_softplus=dt_softplus,
                dt_limit=dt_limit,
            )
        )
        ctx.save_for_backward(
            out if z is None else out_x,
            x,
            dt,
            dA_cumsum,
            A,
            B,
            C,
            D,
            z,
            dt_bias,
            initial_states,
            seq_idx,
        )
        ctx.dt_softplus = dt_softplus
        ctx.chunk_size = chunk_size
        ctx.dt_limit = dt_limit
        ctx.return_final_states = return_final_states
        ctx.return_varlen_states = return_varlen_states
        if not return_varlen_states:
            return out if not return_final_states else (out, final_states)
        else:
            varlen_states = rest[0]
            return (
                (out, varlen_states)
                if not return_final_states
                else (out, final_states, varlen_states)
            )

    @staticmethod
    def backward(ctx, dout, *args):
        out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = (
            ctx.saved_tensors
        )
        assert (
            not ctx.return_varlen_states
        ), "return_varlen_states is not supported in backward"
        dfinal_states = args[0] if ctx.return_final_states else None
        dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = (
            _mamba_chunk_scan_combined_bwd(
                dout,
                x,
                dt,
                A,
                B,
                C,
                out,
                ctx.chunk_size,
                D=D,
                z=z,
                dt_bias=dt_bias,
                initial_states=initial_states,
                dfinal_states=dfinal_states,
                seq_idx=seq_idx,
                dt_softplus=ctx.dt_softplus,
                dt_limit=ctx.dt_limit,
            )
        )
        return (
            dx,
            ddt,
            dA,
            dB,
            dC,
            None,
            dD,
            dz,
            ddt_bias,
            dinitial_states,
            None,
            None,
            None,
            None,
            None,
            None,
        )

def mamba_chunk_scan_combined(
|
| 1072 |
-
x,
|
| 1073 |
-
dt,
|
| 1074 |
-
A,
|
| 1075 |
-
B,
|
| 1076 |
-
C,
|
| 1077 |
-
chunk_size,
|
| 1078 |
-
D=None,
|
| 1079 |
-
z=None,
|
| 1080 |
-
dt_bias=None,
|
| 1081 |
-
initial_states=None,
|
| 1082 |
-
seq_idx=None,
|
| 1083 |
-
cu_seqlens=None,
|
| 1084 |
-
dt_softplus=False,
|
| 1085 |
-
dt_limit=(0.0, float("inf")),
|
| 1086 |
-
return_final_states=False,
|
| 1087 |
-
return_varlen_states=False,
|
| 1088 |
-
):
|
| 1089 |
-
"""
|
| 1090 |
-
Argument:
|
| 1091 |
-
x: (batch, seqlen, nheads, headdim)
|
| 1092 |
-
dt: (batch, seqlen, nheads)
|
| 1093 |
-
A: (nheads)
|
| 1094 |
-
B: (batch, seqlen, ngroups, dstate)
|
| 1095 |
-
C: (batch, seqlen, ngroups, dstate)
|
| 1096 |
-
chunk_size: int
|
| 1097 |
-
D: (nheads, headdim) or (nheads,)
|
| 1098 |
-
z: (batch, seqlen, nheads, headdim)
|
| 1099 |
-
dt_bias: (nheads,)
|
| 1100 |
-
initial_states: (batch, nheads, headdim, dstate)
|
| 1101 |
-
seq_idx: (batch, seqlen)
|
| 1102 |
-
cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
|
| 1103 |
-
dt_softplus: Whether to apply softplus to dt
|
| 1104 |
-
Return:
|
| 1105 |
-
out: (batch, seqlen, nheads, headdim)
|
| 1106 |
-
"""
|
| 1107 |
-
return MambaChunkScanCombinedFn.apply(
|
| 1108 |
-
x,
|
| 1109 |
-
dt,
|
| 1110 |
-
A,
|
| 1111 |
-
B,
|
| 1112 |
-
C,
|
| 1113 |
-
chunk_size,
|
| 1114 |
-
D,
|
| 1115 |
-
z,
|
| 1116 |
-
dt_bias,
|
| 1117 |
-
initial_states,
|
| 1118 |
-
seq_idx,
|
| 1119 |
-
cu_seqlens,
|
| 1120 |
-
dt_softplus,
|
| 1121 |
-
dt_limit,
|
| 1122 |
-
return_final_states,
|
| 1123 |
-
return_varlen_states,
|
| 1124 |
-
)
|
| 1125 |
-
|
| 1126 |
-
|
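
Since the diff only shows the deleted wrapper, a minimal smoke test may help orient readers. This is an illustrative sketch, not part of the upstream diff: it assumes a CUDA device, a built copy of these Triton kernels, and the import path implied by this build tree; the tensor shapes follow the docstring above.

import torch
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined

# Hypothetical sizes for illustration only.
batch, seqlen, nheads, headdim, ngroups, dstate = 2, 256, 4, 64, 1, 16
device, dtype = "cuda", torch.float16
x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
dt = torch.rand(batch, seqlen, nheads, device=device, dtype=dtype)
A = -torch.rand(nheads, device=device)  # negative values keep the recurrence stable
B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype)
C = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype)

out = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size=64, dt_softplus=True)
assert out.shape == (batch, seqlen, nheads, headdim)
out.sum().backward()  # exercises MambaChunkScanCombinedFn.backward above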


def mamba_chunk_scan(
    x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False
):
    """
    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads)
        A: (nheads)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
        dt_bias: (nheads,)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    batch, seqlen, nheads, headdim = x.shape
    dstate = B.shape[-1]
    if seqlen % chunk_size != 0:
        dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
    dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
    dt = dt.float()  # We want high precision for this before cumsum
    if dt_bias is not None:
        dt = dt + rearrange(dt_bias, "h -> h 1 1")
    if dt_softplus:
        dt = F.softplus(dt)
    dA = dt * rearrange(A, "h -> h 1 1")
    dA_cumsum = torch.cumsum(dA, dim=-1)
    # 1. Compute the state for each chunk
    states = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True)
    # 2. Pass the state to all the chunks by weighted cumsum.
    states = rearrange(
        state_passing(
            rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]
        )[0],
        "... (p n) -> ... p n",
        n=dstate,
    )
    # 3. Compute the output for each chunk
    out = chunk_scan(B, C, x, dt, dA_cumsum, states, D=D, z=z)
    return out


def ssd_chunk_scan_combined_ref(
    x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False
):
    """
    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads)
        A: (nheads)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
        dt_bias: (nheads,)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    batch, seqlen, nheads, headdim = x.shape
    dstate = B.shape[-1]
    if seqlen % chunk_size != 0:
        dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
    dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
    dt = dt.float()  # We want high precision for this before cumsum
    if dt_bias is not None:
        dt = dt + rearrange(dt_bias, "h -> h 1 1")
    if dt_softplus:
        dt = F.softplus(dt)
    dA = dt * rearrange(A, "h -> h 1 1")
    dA_cumsum = torch.cumsum(dA, dim=-1)
    # 1. Compute the state for each chunk
    states = chunk_state_ref(B, x, dt, dA_cumsum)
    states_dtype = states.dtype
    if states.dtype not in [torch.float32, torch.float64]:
        states = states.to(torch.float32)
    # 2. Pass the state to all the chunks by weighted cumsum.
    # state_passing_ref is much less numerically stable
    states = rearrange(
        state_passing_ref(
            rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]
        )[0],
        "... (p n) -> ... p n",
        n=dstate,
    )
    states = states.to(states_dtype)
    # 3. Compute the output for each chunk
    out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)
    return out
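
The `_ref` variant exists for testing the fused path. A hedged sketch of such a check follows; it assumes both functions run on the same CUDA device in float32 and that `seqlen` is a multiple of `chunk_size` (the reference pads `dt` but not `x`, so uneven lengths would mismatch).

import torch
from mamba_ssm.ops.triton.ssd_combined import (
    mamba_chunk_scan_combined,
    ssd_chunk_scan_combined_ref,
)

torch.manual_seed(0)
b, l, h, p, g, n = 1, 128, 2, 32, 1, 16  # illustrative sizes; l % chunk_size == 0
x = torch.randn(b, l, h, p, device="cuda")
dt = torch.rand(b, l, h, device="cuda")
A = -torch.rand(h, device="cuda")
B = torch.randn(b, l, g, n, device="cuda")
C = torch.randn(b, l, g, n, device="cuda")

out = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size=32)
out_ref = ssd_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size=32)
print((out - out_ref).abs().max())  # expect a small kernel-vs-reference tolerance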


def ssd_selective_scan(
    x,
    dt,
    A,
    B,
    C,
    D=None,
    z=None,
    dt_bias=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
):
    """
    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
        A: (nheads) or (dim, dstate)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
        dt_bias: (nheads,) or (nheads, headdim)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    from ..selective_scan_interface import selective_scan_fn

    batch, seqlen, nheads, headdim = x.shape
    _, _, ngroups, dstate = B.shape
    x = rearrange(x, "b l h p -> b (h p) l")
    if dt.dim() == 3:
        dt = repeat(dt, "b l h -> b l h p", p=headdim)
    dt = rearrange(dt, "b l h p -> b (h p) l")
    if A.dim() == 1:
        A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
    else:
        A = A.to(dtype=torch.float32)
    B = rearrange(B, "b l g n -> b g n l")
    C = rearrange(C, "b l g n -> b g n l")
    if D is not None:
        if D.dim() == 2:
            D = rearrange(D, "h p -> (h p)")
        else:
            D = repeat(D, "h -> (h p)", p=headdim)
    if z is not None:
        z = rearrange(z, "b l h p -> b (h p) l")
    if dt_bias is not None:
        if dt_bias.dim() == 1:
            dt_bias = repeat(dt_bias, "h -> h p", p=headdim)
        dt_bias = rearrange(dt_bias, "h p -> (h p)")
    if dt_limit != (0.0, float("inf")):
        if dt_bias is not None:
            dt = dt + rearrange(dt_bias, "d -> d 1")
        if dt_softplus:
            dt = F.softplus(dt)
        dt = dt.clamp(min=dt_limit[0], max=dt_limit[1]).to(x.dtype)
        dt_bias = None
        dt_softplus = None
    out = selective_scan_fn(
        x, dt, A, B, C, D=D, z=z, delta_bias=dt_bias, delta_softplus=dt_softplus
    )
    return rearrange(out, "b (h p) l -> b l h p", p=headdim)
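
One subtlety worth calling out: with a non-default `dt_limit`, the wrapper folds the bias, softplus, and clamp into `dt` up front and then disables them inside `selective_scan_fn`. The equivalent preprocessing in plain PyTorch looks like this; it is a sketch of the lines above with hypothetical sizes, not an extra code path.

import torch
import torch.nn.functional as F

dt = torch.randn(8, 100)         # (dim, seqlen) after the head/dim flattening above
dt_bias = torch.zeros(8)         # illustrative bias
dt_limit = (1e-3, 1e-1)

dt = dt + dt_bias.unsqueeze(-1)  # mirrors rearrange(dt_bias, "d -> d 1")
dt = F.softplus(dt)              # applied only when dt_softplus was requested
dt = dt.clamp(min=dt_limit[0], max=dt_limit[1])
# selective_scan_fn is then called with delta_bias=None, delta_softplus=None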


def mamba_conv1d_scan_ref(
    xBC,
    conv1d_weight,
    conv1d_bias,
    dt,
    A,
    chunk_size,
    D=None,
    z=None,
    dt_bias=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    activation="silu",
    headdim=None,
    ngroups=1,
):
    """
    Argument:
        xBC: (batch, seqlen, dim + 2 * ngroups * dstate) where dim == nheads * headdim
        conv1d_weight: (dim + 2 * ngroups * dstate, width)
        conv1d_bias: (dim + 2 * ngroups * dstate,)
        dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
        A: (nheads)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, dim)
        dt_bias: (nheads) or (nheads, headdim)
        headdim: if D is 1D and z is None, headdim must be passed in
    Return:
        out: (batch, seqlen, dim)
    """
    batch, seqlen, nheads = dt.shape[:3]
    assert nheads % ngroups == 0
    if z is not None:
        dim = z.shape[-1]
        assert dim % nheads == 0
        headdim = dim // nheads
    else:
        if D.dim() == 1:
            assert headdim is not None
        else:
            headdim = D.shape[1]
        dim = nheads * headdim
    xBC = rearrange(
        causal_conv1d_fn(
            rearrange(xBC, "b s d -> b d s"),
            conv1d_weight,
            conv1d_bias,
            activation=activation,
        ),
        "b d s -> b s d",
    )
    dstate = (xBC.shape[-1] - dim) // ngroups // 2
    x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
    x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
    B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
    C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
    z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
    out = ssd_selective_scan(
        x,
        dt.to(x.dtype),
        A,
        B,
        C,
        D=D.float(),
        z=z,
        dt_bias=dt_bias,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
    )
    return rearrange(out, "b s h p -> b s (h p)")
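
The `xBC` argument packs the SSM input and the B/C projections along the channel dimension. A small shape sketch with hypothetical sizes, mirroring the split performed after the causal conv above:

import torch

batch, seqlen, nheads, headdim, ngroups, dstate = 2, 64, 4, 16, 1, 8  # illustrative
dim = nheads * headdim
conv_dim = dim + 2 * ngroups * dstate

xBC = torch.randn(batch, seqlen, conv_dim)
x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
# x: (batch, seqlen, dim); B and C: (batch, seqlen, ngroups * dstate)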


class MambaSplitConv1dScanCombinedFn(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        zxbcdt,
        conv1d_weight,
        conv1d_bias,
        dt_bias,
        A,
        D,
        chunk_size,
        initial_states=None,
        seq_idx=None,
        dt_limit=(0.0, float("inf")),
        return_final_states=False,
        activation="silu",
        rmsnorm_weight=None,
        rmsnorm_eps=1e-6,
        outproj_weight=None,
        outproj_bias=None,
        headdim=None,
        ngroups=1,
        norm_before_gate=True,
    ):
        assert activation in [None, "silu", "swish"]
        if D.dim() == 1:
            assert headdim is not None
            (nheads,) = D.shape
        else:
            nheads, headdim = D.shape
        batch, seqlen, _ = zxbcdt.shape
        dim = nheads * headdim
        assert nheads % ngroups == 0
        dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2
        d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2
        assert d_nonssm >= 0
        assert zxbcdt.shape == (
            batch,
            seqlen,
            2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads,
        )
        assert dt_bias.shape == (nheads,)
        assert A.shape == (nheads,)
        zx0, z, xBC, dt = torch.split(
            zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1
        )
        seq_idx = seq_idx.contiguous() if seq_idx is not None else None
        xBC_conv = rearrange(
            causal_conv1d_cuda.causal_conv1d_fwd(
                rearrange(xBC, "b s d -> b d s"),
                conv1d_weight,
                conv1d_bias,
                seq_idx,
                None,
                None,
                activation in ["silu", "swish"],
            ),
            "b d s -> b s d",
        )
        x, B, C = torch.split(
            xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1
        )
        x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
        B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
        C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
        z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
        if rmsnorm_weight is None:
            out, out_x, dt_out, dA_cumsum, states, final_states = (
                _mamba_chunk_scan_combined_fwd(
                    x,
                    dt,
                    A,
                    B,
                    C,
                    chunk_size=chunk_size,
                    D=D,
                    z=z,
                    dt_bias=dt_bias,
                    initial_states=initial_states,
                    seq_idx=seq_idx,
                    dt_softplus=True,
                    dt_limit=dt_limit,
                )
            )
            out = rearrange(out, "b s h p -> b s (h p)")
            rstd = None
            if d_nonssm > 0:
                out = torch.cat([_swiglu_fwd(zx0), out], dim=-1)
        else:
            out_x, _, dt_out, dA_cumsum, states, final_states = (
                _mamba_chunk_scan_combined_fwd(
                    x,
                    dt,
                    A,
                    B,
                    C,
                    chunk_size=chunk_size,
                    D=D,
                    z=None,
                    dt_bias=dt_bias,
                    initial_states=initial_states,
                    seq_idx=seq_idx,
                    dt_softplus=True,
                    dt_limit=dt_limit,
                )
            )
            # reshape input data into 2D tensor
            x_rms = rearrange(out_x, "b s h p -> (b s) (h p)")
            z_rms = rearrange(z, "b s h p -> (b s) (h p)")
            rmsnorm_weight = rmsnorm_weight.contiguous()
            if d_nonssm == 0:
                out = None
            else:
                out01 = torch.empty(
                    (batch, seqlen, d_nonssm + dim),
                    dtype=x_rms.dtype,
                    device=x_rms.device,
                )
                out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d")
                _swiglu_fwd(zx0, out=out01[..., :d_nonssm])
            out, _, rstd = _layer_norm_fwd(
                x_rms,
                rmsnorm_weight,
                None,
                rmsnorm_eps,
                z_rms,
                out=out,
                group_size=dim // ngroups,
                norm_before_gate=norm_before_gate,
                is_rms_norm=True,
            )
            if d_nonssm == 0:
                out = rearrange(out, "(b s) d -> b s d", b=batch)
            else:
                out = out01
        ctx.outproj_weight_dtype = (
            outproj_weight.dtype if outproj_weight is not None else None
        )
        if outproj_weight is not None:
            if torch.is_autocast_enabled():
                dtype = torch.get_autocast_gpu_dtype()
                out, outproj_weight = out.to(dtype), outproj_weight.to(dtype)
                outproj_bias = (
                    outproj_bias.to(dtype) if outproj_bias is not None else None
                )
            out = F.linear(out, outproj_weight, outproj_bias)
        else:
            assert outproj_bias is None
        ctx.save_for_backward(
            zxbcdt,
            conv1d_weight,
            conv1d_bias,
            out_x,
            A,
            D,
            dt_bias,
            initial_states,
            seq_idx,
            rmsnorm_weight,
            rstd,
            outproj_weight,
            outproj_bias,
        )
        ctx.dt_limit = dt_limit
        ctx.return_final_states = return_final_states
        ctx.activation = activation
        ctx.rmsnorm_eps = rmsnorm_eps
        ctx.norm_before_gate = norm_before_gate
        ctx.chunk_size = chunk_size
        ctx.headdim = headdim
        ctx.ngroups = ngroups
        return out if not return_final_states else (out, final_states)

    @staticmethod
    @custom_bwd
    def backward(ctx, dout, *args):
        (
            zxbcdt,
            conv1d_weight,
            conv1d_bias,
            out,
            A,
            D,
            dt_bias,
            initial_states,
            seq_idx,
            rmsnorm_weight,
            rstd,
            outproj_weight,
            outproj_bias,
        ) = ctx.saved_tensors
        dfinal_states = args[0] if ctx.return_final_states else None
        headdim = ctx.headdim
        nheads = D.shape[0]
        dim = nheads * headdim
        assert nheads % ctx.ngroups == 0
        dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2
        d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2
        assert d_nonssm >= 0
        recompute_output = outproj_weight is not None
        if recompute_output:
            out_recompute = torch.empty(
                *out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype
            )
            out0_recompute, out1_recompute = out_recompute.split(
                [d_nonssm, dim], dim=-1
            )
        zx0, z, xBC, dt = torch.split(
            zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
        )
        # Recompute x, B, C
        xBC_conv = rearrange(
            causal_conv1d_cuda.causal_conv1d_fwd(
                rearrange(xBC, "b s d -> b d s"),
                conv1d_weight,
                conv1d_bias,
                seq_idx,
                None,
                None,
                ctx.activation in ["silu", "swish"],
            ),
            "b d s -> b s d",
        )
        x, B, C = torch.split(
            xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
        )
        x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
        B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups)
        C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups)
        dzxbcdt = torch.empty_like(zxbcdt)
        dzx0, dz, dxBC_given, ddt_given = torch.split(
            dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
        )
        dxBC = torch.empty_like(xBC)
        dx, dB, dC = torch.split(
            dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
        )
        z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
        dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads)
        dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups)
        dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups)
        if outproj_weight is not None:
            dout_og = dout
            dout = F.linear(dout, outproj_weight.t())
        if d_nonssm > 0:
            dout0, dout = dout.split([d_nonssm, dim], dim=-1)
            _swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute)
        dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim)
        if rmsnorm_weight is None:
            dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads)
            dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states, *rest = (
                _mamba_chunk_scan_combined_bwd(
                    dout,
                    x,
                    dt,
                    A,
                    B,
                    C,
                    out,
                    ctx.chunk_size,
                    D=D,
                    z=z,
                    dt_bias=dt_bias,
                    initial_states=initial_states,
                    dfinal_states=dfinal_states,
                    seq_idx=seq_idx,
                    dt_softplus=True,
                    dt_limit=ctx.dt_limit,
                    dx=dx,
                    ddt=ddt_given,
                    dB=dB,
                    dC=dC,
                    dz=dz,
                    recompute_output=recompute_output,
                )
            )
            out_for_linear = (
                rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None
            )
            drmsnorm_weight = None
        else:
            batch = dout.shape[0]
            dy_rms = rearrange(dout, "b s h p -> (b s) (h p)")
            dz = rearrange(dz, "b l d -> (b l) d")
            x_rms = rearrange(out, "b s h p -> (b s) (h p)")
            z_rms = rearrange(z, "b s h p -> (b s) (h p)")
            out1_recompute = (
                rearrange(out1_recompute, "b s d -> (b s) d")
                if recompute_output
                else None
            )
            dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(
                dy_rms,
                x_rms,
                rmsnorm_weight,
                None,
                ctx.rmsnorm_eps,
                None,
                rstd,
                z_rms,
                group_size=dim // ctx.ngroups,
                norm_before_gate=ctx.norm_before_gate,
                is_rms_norm=True,
                recompute_output=recompute_output,
                dz=dz,
                out=out1_recompute if recompute_output else None,
            )
            out_for_linear = out_recompute if recompute_output else None
            dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim)
            dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = (
                _mamba_chunk_scan_combined_bwd(
                    dout,
                    x,
                    dt,
                    A,
                    B,
                    C,
                    out,
                    ctx.chunk_size,
                    D=D,
                    z=None,
                    dt_bias=dt_bias,
                    initial_states=initial_states,
                    dfinal_states=dfinal_states,
                    seq_idx=seq_idx,
                    dt_softplus=True,
                    dt_limit=ctx.dt_limit,
                    dx=dx,
                    ddt=ddt_given,
                    dB=dB,
                    dC=dC,
                )
            )

        if outproj_weight is not None:
            doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear)
            doutproj_bias = (
                dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None
            )
        else:
            doutproj_weight, doutproj_bias = None, None
        dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
        dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
            rearrange(xBC, "b s d -> b d s"),
            conv1d_weight,
            conv1d_bias,
            rearrange(dxBC, "b s d -> b d s"),
            seq_idx,
            None,
            None,
            dxBC_given,
            False,
            ctx.activation in ["silu", "swish"],
        )
        dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
        return (
            dzxbcdt,
            dweight,
            dbias,
            ddt_bias,
            dA,
            dD,
            None,
            dinitial_states,
            None,
            None,
            None,
            None,
            drmsnorm_weight,
            None,
            doutproj_weight,
            doutproj_bias,
            None,
            None,
            None,
        )


def mamba_split_conv1d_scan_combined(
    zxbcdt,
    conv1d_weight,
    conv1d_bias,
    dt_bias,
    A,
    D,
    chunk_size,
    initial_states=None,
    seq_idx=None,
    dt_limit=(0.0, float("inf")),
    return_final_states=False,
    activation="silu",
    rmsnorm_weight=None,
    rmsnorm_eps=1e-6,
    outproj_weight=None,
    outproj_bias=None,
    headdim=None,
    ngroups=1,
    norm_before_gate=True,
):
    """
    Argument:
        zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
        conv1d_weight: (dim + 2 * ngroups * dstate, width)
        conv1d_bias: (dim + 2 * ngroups * dstate,)
        dt_bias: (nheads,)
        A: (nheads)
        D: (nheads, headdim) or (nheads,)
        initial_states: (batch, nheads, headdim, dstate)
        seq_idx: (batch, seqlen), int32
        rmsnorm_weight: (dim,)
        outproj_weight: (out_dim, dim)
        outproj_bias: (out_dim,)
        headdim: if D is 1D, headdim must be passed in
        norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
    Return:
        out: (batch, seqlen, dim)
    """
    return MambaSplitConv1dScanCombinedFn.apply(
        zxbcdt,
        conv1d_weight,
        conv1d_bias,
        dt_bias,
        A,
        D,
        chunk_size,
        initial_states,
        seq_idx,
        dt_limit,
        return_final_states,
        activation,
        rmsnorm_weight,
        rmsnorm_eps,
        outproj_weight,
        outproj_bias,
        headdim,
        ngroups,
        norm_before_gate,
    )
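
The fused entry point expects everything concatenated into `zxbcdt`: per the docstring, its last dimension is 2 * dim + 2 * ngroups * dstate + nheads (plus 2 * d_nonssm when a non-SSM branch is present). A hypothetical packing sketch in plain PyTorch:

import torch

batch, seqlen, nheads, headdim, ngroups, dstate = 2, 64, 4, 16, 1, 8  # illustrative
dim = nheads * headdim

z = torch.randn(batch, seqlen, dim)                           # gate branch
xBC = torch.randn(batch, seqlen, dim + 2 * ngroups * dstate)  # conv input
dt = torch.rand(batch, seqlen, nheads)
zxbcdt = torch.cat([z, xBC, dt], dim=-1)
assert zxbcdt.shape[-1] == 2 * dim + 2 * ngroups * dstate + nheads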


def mamba_split_conv1d_scan_ref(
    zxbcdt,
    conv1d_weight,
    conv1d_bias,
    dt_bias,
    A,
    D,
    chunk_size,
    dt_limit=(0.0, float("inf")),
    activation="silu",
    rmsnorm_weight=None,
    rmsnorm_eps=1e-6,
    outproj_weight=None,
    outproj_bias=None,
    headdim=None,
    ngroups=1,
    norm_before_gate=True,
):
    """
    Argument:
        zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
        conv1d_weight: (dim + 2 * ngroups * dstate, width)
        conv1d_bias: (dim + 2 * ngroups * dstate,)
        dt_bias: (nheads,)
        A: (nheads)
        D: (nheads, headdim) or (nheads,)
        rmsnorm_weight: (dim,)
        outproj_weight: (out_dim, dim)
        outproj_bias: (out_dim,)
        headdim: if D is 1D, headdim must be passed in
        norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
    Return:
        out: (batch, seqlen, dim)
    """
    if D.dim() == 1:
        assert headdim is not None
        (nheads,) = D.shape
    else:
        nheads, headdim = D.shape
    assert nheads % ngroups == 0
    batch, seqlen, _ = zxbcdt.shape
    dim = nheads * headdim
    dstate = (zxbcdt.shape[-1] - 2 * dim - nheads) // ngroups // 2
    assert zxbcdt.shape == (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads)
    assert dt_bias.shape == (nheads,)
    assert A.shape == (nheads,)
    if rmsnorm_weight is not None:
        assert rmsnorm_weight.shape == (dim,)
    z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1)
    xBC = rearrange(
        causal_conv1d_fn(
            rearrange(xBC, "b s d -> b d s"),
            conv1d_weight,
            conv1d_bias,
            activation=activation,
        ),
        "b d s -> b s d",
    )
    x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
    x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
    B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
    C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
    z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
    out = ssd_selective_scan(
        x,
        dt.to(x.dtype),
        A,
        B,
        C,
        D=D.float(),
        z=z if rmsnorm_weight is None else None,
        dt_bias=dt_bias,
        dt_softplus=True,
        dt_limit=dt_limit,
    )
    out = rearrange(out, "b s h p -> b s (h p)")
    if rmsnorm_weight is not None:
        out = rmsnorm_fn(
            out,
            rmsnorm_weight,
            None,
            z=rearrange(z, "b l h p -> b l (h p)"),
            eps=rmsnorm_eps,
            norm_before_gate=norm_before_gate,
        )
    if outproj_weight is not None:
        out = F.linear(out, outproj_weight, outproj_bias)
    return out
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/utils/__init__.py
DELETED
File without changes

build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/__init__.py
DELETED
@@ -1,14 +0,0 @@
__version__ = "2.2.4"

from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from .modules.mamba_simple import Mamba
from .modules.mamba2 import Mamba2
from .models.mixer_seq_simple import MambaLMHeadModel

__all__ = [
    "selective_scan_fn",
    "mamba_inner_fn",
    "Mamba",
    "Mamba2",
    "MambaLMHeadModel",
]
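
For context, these re-exports are the surface that downstream code imports. A minimal sketch of the intended usage follows; the `Mamba` constructor argument shown is illustrative and not taken from this diff, and a CUDA device is assumed.

import torch
from mamba_ssm import Mamba

layer = Mamba(d_model=256).cuda()            # d_model value is hypothetical
x = torch.randn(2, 64, 256, device="cuda")   # (batch, seqlen, d_model)
y = layer(x)
assert y.shape == x.shape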

build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/distributed/__init__.py
DELETED
File without changes

build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py
DELETED
@@ -1,326 +0,0 @@
# Copyright (c) 2024, Tri Dao.
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from ..utils.torch import custom_bwd, custom_fwd

from einops import rearrange

from ..distributed.distributed_utils import (
    all_gather_raw,
    all_reduce,
    all_reduce_raw,
    reduce_scatter,
    reduce_scatter_raw,
)


class ParallelLinearFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
        """
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
            bias = (
                bias.to(dtype=torch.get_autocast_gpu_dtype())
                if bias is not None
                else None
            )
        weight = weight.contiguous()
        if process_group is not None and sequence_parallel:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        output = F.linear(total_x, weight, bias)
        if ctx.compute_weight_gradient:
            ctx.save_for_backward(x, weight)
        else:
            ctx.save_for_backward(weight)
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        if ctx.compute_weight_gradient:
            x, weight = ctx.saved_tensors
            if process_group is not None and sequence_parallel:
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            else:
                total_x = x
        else:
            (weight,) = ctx.saved_tensors
            total_x = None
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            grad_input = F.linear(grad_output, weight.t())
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(
                    grad_input, process_group, async_op=True
                )
        else:
            grad_input = None
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            if process_group is not None and sequence_parallel:
                handle_x.wait()
            grad_weight = torch.einsum(
                "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
            )
        else:
            grad_weight = None
        grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None


def parallel_linear_func(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True,
):
    return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)


class ColumnParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        if out_features % multiple_of:
            raise ValueError(
                f"out_features ({out_features}) must be a multiple of {multiple_of}"
            )
        multiple = out_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
        super().__init__(
            in_features,
            local_multiple * multiple_of,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
        # we do an all_gather of x before doing the matmul.
        # If not, then the input is already gathered.
        return parallel_linear_func(
            x,
            self.weight,
            self.bias,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel,
        )


class RowParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        rank = torch.distributed.get_rank(process_group)
        if in_features % multiple_of:
            raise ValueError(
                f"in_features ({in_features}) must be a multiple of {multiple_of}"
            )
        multiple = in_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
        # Only rank 0 will have bias
        super().__init__(
            local_multiple * multiple_of,
            out_features,
            bias=bias and rank == 0,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        """
        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
        a reduce_scatter of the result.
        """
        out = parallel_linear_func(x, self.weight, self.bias)
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)


class VocabParallelEmbedding(nn.Embedding):
    def __init__(
        self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs
    ):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if num_embeddings % world_size != 0:
                raise ValueError(
                    f"num_embeddings ({num_embeddings}) must be divisible by "
                    f"world_size ({world_size})"
                )
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")
        else:
            world_size = 1
        super().__init__(
            num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        if self.process_group is None:
            return super().forward(input)
        else:
            rank = torch.distributed.get_rank(self.process_group)
            vocab_size = self.num_embeddings
            vocab_start_index, vocab_end_index = (
                rank * vocab_size,
                (rank + 1) * vocab_size,
            )
            # Create a mask of valid vocab ids (1 means it needs to be masked).
            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
            input = input - vocab_start_index
            input[input_ids_mask] = 0
            embeddings = super().forward(input)
            embeddings[input_ids_mask] = 0.0
            return embeddings


class ColumnParallelEmbedding(nn.Embedding):
    def __init__(
        self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs
    ):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if embedding_dim % world_size != 0:
                raise ValueError(
                    f"embedding_dim ({embedding_dim}) must be divisible by "
                    f"world_size ({world_size})"
                )
        else:
            world_size = 1
        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)


class ParallelEmbeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        process_group,
        padding_idx=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size,
            embed_dim,
            padding_idx=padding_idx,
            process_group=process_group,
            **factory_kwargs,
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = ColumnParallelEmbedding(
                max_position_embeddings,
                embed_dim,
                process_group=process_group,
                **factory_kwargs,
            )

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(
                    seqlen, dtype=torch.long, device=input_ids.device
                )
            position_embeddings = self.position_embeddings(position_ids)
            if world_size <= 1:
                embeddings = embeddings + position_embeddings
            else:
                partition_dim = self.position_embeddings.embedding_dim
                rank = torch.distributed.get_rank(self.process_group)
                embeddings[
                    ..., rank * partition_dim : (rank + 1) * partition_dim
                ] += position_embeddings
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, "b s d -> (b s) d")
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return (
            embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
        )
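
The `div`/`mod` arithmetic in the two linear classes implements an uneven shard: the first `mod` ranks each get one extra multiple. A plain-Python illustration with hypothetical sizes, independent of any process group:

# How ColumnParallelLinear would split out_features = 10 multiples across 4 ranks:
multiple, world_size = 10, 4
div, mod = divmod(multiple, world_size)
local = [div + int(rank < mod) for rank in range(world_size)]
print(local)                    # [3, 3, 2, 2] -- the first `mod` ranks get div + 1
assert sum(local) == multiple   # no features lost or duplicated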

build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/models/__init__.py
DELETED
File without changes

build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py
DELETED
@@ -1,338 +0,0 @@
# Copyright (c) 2023, Albert Gu, Tri Dao.

import math
from functools import partial
import json
import os
import copy

from collections import namedtuple

import torch
import torch.nn as nn

from .config_mamba import MambaConfig
from ..modules.mamba_simple import Mamba
from ..modules.mamba2 import Mamba2
from ..modules.mha import MHA
from ..modules.mlp import GatedMLP
from ..modules.block import Block
from ..utils.generation import GenerationMixin
from ..utils.hf import load_config_hf, load_state_dict_hf

try:
    from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


def create_block(
    d_model,
    d_intermediate,
    ssm_cfg=None,
    attn_layer_idx=None,
    attn_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    if ssm_cfg is None:
        ssm_cfg = {}
    if attn_layer_idx is None:
        attn_layer_idx = []
    if attn_cfg is None:
        attn_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    if layer_idx not in attn_layer_idx:
        # Create a copy of the config to modify
        ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
        ssm_layer = ssm_cfg.pop("layer", "Mamba1")
        if ssm_layer not in ["Mamba1", "Mamba2"]:
            raise ValueError(
                f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2"
            )
        mixer_cls = partial(
            Mamba2 if ssm_layer == "Mamba2" else Mamba,
            layer_idx=layer_idx,
            **ssm_cfg,
            **factory_kwargs,
        )
    else:
        mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    if d_intermediate == 0:
        mlp_cls = nn.Identity
    else:
        mlp_cls = partial(
            GatedMLP,
            hidden_features=d_intermediate,
            out_features=d_model,
            **factory_kwargs,
        )
    block = Block(
        d_model,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        # > -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
|
| 113 |
-
if name in ["out_proj.weight", "fc2.weight"]:
|
| 114 |
-
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
| 115 |
-
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
|
| 116 |
-
# We need to reinit p since this code could be called multiple times
|
| 117 |
-
# Having just p *= scale would repeatedly scale it down
|
| 118 |
-
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
|
| 119 |
-
with torch.no_grad():
|
| 120 |
-
p /= math.sqrt(n_residuals_per_layer * n_layer)
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
class MixerModel(nn.Module):
|
| 124 |
-
def __init__(
|
| 125 |
-
self,
|
| 126 |
-
d_model: int,
|
| 127 |
-
n_layer: int,
|
| 128 |
-
d_intermediate: int,
|
| 129 |
-
vocab_size: int,
|
| 130 |
-
ssm_cfg=None,
|
| 131 |
-
attn_layer_idx=None,
|
| 132 |
-
attn_cfg=None,
|
| 133 |
-
norm_epsilon: float = 1e-5,
|
| 134 |
-
rms_norm: bool = False,
|
| 135 |
-
initializer_cfg=None,
|
| 136 |
-
fused_add_norm=False,
|
| 137 |
-
residual_in_fp32=False,
|
| 138 |
-
device=None,
|
| 139 |
-
dtype=None,
|
| 140 |
-
) -> None:
|
| 141 |
-
factory_kwargs = {"device": device, "dtype": dtype}
|
| 142 |
-
super().__init__()
|
| 143 |
-
self.residual_in_fp32 = residual_in_fp32
|
| 144 |
-
|
| 145 |
-
self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
|
| 146 |
-
|
| 147 |
-
# We change the order of residual and layer norm:
|
| 148 |
-
# Instead of LN -> Attn / MLP -> Add, we do:
|
| 149 |
-
# Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
|
| 150 |
-
# the main branch (output of MLP / Mixer). The model definition is unchanged.
|
| 151 |
-
# This is for performance reason: we can fuse add + layer_norm.
|
| 152 |
-
self.fused_add_norm = fused_add_norm
|
| 153 |
-
if self.fused_add_norm:
|
| 154 |
-
if layer_norm_fn is None or rms_norm_fn is None:
|
| 155 |
-
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
|
| 156 |
-
|
| 157 |
-
self.layers = nn.ModuleList(
|
| 158 |
-
[
|
| 159 |
-
create_block(
|
| 160 |
-
d_model,
|
| 161 |
-
d_intermediate=d_intermediate,
|
| 162 |
-
ssm_cfg=ssm_cfg,
|
| 163 |
-
attn_layer_idx=attn_layer_idx,
|
| 164 |
-
attn_cfg=attn_cfg,
|
| 165 |
-
norm_epsilon=norm_epsilon,
|
| 166 |
-
rms_norm=rms_norm,
|
| 167 |
-
residual_in_fp32=residual_in_fp32,
|
| 168 |
-
fused_add_norm=fused_add_norm,
|
| 169 |
-
layer_idx=i,
|
| 170 |
-
**factory_kwargs,
|
| 171 |
-
)
|
| 172 |
-
for i in range(n_layer)
|
| 173 |
-
]
|
| 174 |
-
)
|
| 175 |
-
|
| 176 |
-
self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
|
| 177 |
-
d_model, eps=norm_epsilon, **factory_kwargs
|
| 178 |
-
)
|
| 179 |
-
|
| 180 |
-
self.apply(
|
| 181 |
-
partial(
|
| 182 |
-
_init_weights,
|
| 183 |
-
n_layer=n_layer,
|
| 184 |
-
**(initializer_cfg if initializer_cfg is not None else {}),
|
| 185 |
-
n_residuals_per_layer=(
|
| 186 |
-
1 if d_intermediate == 0 else 2
|
| 187 |
-
), # 2 if we have MLP
|
| 188 |
-
)
|
| 189 |
-
)
|
| 190 |
-
|
| 191 |
-
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 192 |
-
return {
|
| 193 |
-
i: layer.allocate_inference_cache(
|
| 194 |
-
batch_size, max_seqlen, dtype=dtype, **kwargs
|
| 195 |
-
)
|
| 196 |
-
for i, layer in enumerate(self.layers)
|
| 197 |
-
}
|
| 198 |
-
|
| 199 |
-
def forward(self, input_ids, inference_params=None, **mixer_kwargs):
|
| 200 |
-
hidden_states = self.embedding(input_ids)
|
| 201 |
-
residual = None
|
| 202 |
-
for layer in self.layers:
|
| 203 |
-
hidden_states, residual = layer(
|
| 204 |
-
hidden_states,
|
| 205 |
-
residual,
|
| 206 |
-
inference_params=inference_params,
|
| 207 |
-
**mixer_kwargs,
|
| 208 |
-
)
|
| 209 |
-
if not self.fused_add_norm:
|
| 210 |
-
residual = (
|
| 211 |
-
(hidden_states + residual) if residual is not None else hidden_states
|
| 212 |
-
)
|
| 213 |
-
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
|
| 214 |
-
else:
|
| 215 |
-
# Set prenorm=False here since we don't need the residual
|
| 216 |
-
hidden_states = layer_norm_fn(
|
| 217 |
-
hidden_states,
|
| 218 |
-
self.norm_f.weight,
|
| 219 |
-
self.norm_f.bias,
|
| 220 |
-
eps=self.norm_f.eps,
|
| 221 |
-
residual=residual,
|
| 222 |
-
prenorm=False,
|
| 223 |
-
residual_in_fp32=self.residual_in_fp32,
|
| 224 |
-
is_rms_norm=isinstance(self.norm_f, RMSNorm),
|
| 225 |
-
)
|
| 226 |
-
return hidden_states
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
class MambaLMHeadModel(nn.Module, GenerationMixin):
|
| 230 |
-
|
| 231 |
-
def __init__(
|
| 232 |
-
self,
|
| 233 |
-
config: MambaConfig,
|
| 234 |
-
initializer_cfg=None,
|
| 235 |
-
device=None,
|
| 236 |
-
dtype=None,
|
| 237 |
-
) -> None:
|
| 238 |
-
self.config = config
|
| 239 |
-
d_model = config.d_model
|
| 240 |
-
n_layer = config.n_layer
|
| 241 |
-
d_intermediate = config.d_intermediate
|
| 242 |
-
vocab_size = config.vocab_size
|
| 243 |
-
ssm_cfg = config.ssm_cfg
|
| 244 |
-
attn_layer_idx = config.attn_layer_idx
|
| 245 |
-
attn_cfg = config.attn_cfg
|
| 246 |
-
rms_norm = config.rms_norm
|
| 247 |
-
residual_in_fp32 = config.residual_in_fp32
|
| 248 |
-
fused_add_norm = config.fused_add_norm
|
| 249 |
-
pad_vocab_size_multiple = config.pad_vocab_size_multiple
|
| 250 |
-
factory_kwargs = {"device": device, "dtype": dtype}
|
| 251 |
-
|
| 252 |
-
super().__init__()
|
| 253 |
-
if vocab_size % pad_vocab_size_multiple != 0:
|
| 254 |
-
vocab_size += pad_vocab_size_multiple - (
|
| 255 |
-
vocab_size % pad_vocab_size_multiple
|
| 256 |
-
)
|
| 257 |
-
self.backbone = MixerModel(
|
| 258 |
-
d_model=d_model,
|
| 259 |
-
n_layer=n_layer,
|
| 260 |
-
d_intermediate=d_intermediate,
|
| 261 |
-
vocab_size=vocab_size,
|
| 262 |
-
ssm_cfg=ssm_cfg,
|
| 263 |
-
attn_layer_idx=attn_layer_idx,
|
| 264 |
-
attn_cfg=attn_cfg,
|
| 265 |
-
rms_norm=rms_norm,
|
| 266 |
-
initializer_cfg=initializer_cfg,
|
| 267 |
-
fused_add_norm=fused_add_norm,
|
| 268 |
-
residual_in_fp32=residual_in_fp32,
|
| 269 |
-
**factory_kwargs,
|
| 270 |
-
)
|
| 271 |
-
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
|
| 272 |
-
|
| 273 |
-
# Initialize weights and apply final processing
|
| 274 |
-
self.apply(
|
| 275 |
-
partial(
|
| 276 |
-
_init_weights,
|
| 277 |
-
n_layer=n_layer,
|
| 278 |
-
**(initializer_cfg if initializer_cfg is not None else {}),
|
| 279 |
-
)
|
| 280 |
-
)
|
| 281 |
-
self.tie_weights()
|
| 282 |
-
|
| 283 |
-
def tie_weights(self):
|
| 284 |
-
if self.config.tie_embeddings:
|
| 285 |
-
self.lm_head.weight = self.backbone.embedding.weight
|
| 286 |
-
|
| 287 |
-
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 288 |
-
return self.backbone.allocate_inference_cache(
|
| 289 |
-
batch_size, max_seqlen, dtype=dtype, **kwargs
|
| 290 |
-
)
|
| 291 |
-
|
| 292 |
-
def forward(
|
| 293 |
-
self,
|
| 294 |
-
input_ids,
|
| 295 |
-
position_ids=None,
|
| 296 |
-
inference_params=None,
|
| 297 |
-
num_last_tokens=0,
|
| 298 |
-
**mixer_kwargs,
|
| 299 |
-
):
|
| 300 |
-
"""
|
| 301 |
-
"position_ids" is just to be compatible with Transformer generation. We don't use it.
|
| 302 |
-
num_last_tokens: if > 0, only return the logits for the last n tokens
|
| 303 |
-
"""
|
| 304 |
-
hidden_states = self.backbone(
|
| 305 |
-
input_ids, inference_params=inference_params, **mixer_kwargs
|
| 306 |
-
)
|
| 307 |
-
if num_last_tokens > 0:
|
| 308 |
-
hidden_states = hidden_states[:, -num_last_tokens:]
|
| 309 |
-
lm_logits = self.lm_head(hidden_states)
|
| 310 |
-
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
| 311 |
-
return CausalLMOutput(logits=lm_logits)
|
| 312 |
-
|
| 313 |
-
@classmethod
|
| 314 |
-
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
|
| 315 |
-
config_data = load_config_hf(pretrained_model_name)
|
| 316 |
-
config = MambaConfig(**config_data)
|
| 317 |
-
model = cls(config, device=device, dtype=dtype, **kwargs)
|
| 318 |
-
model.load_state_dict(
|
| 319 |
-
load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
|
| 320 |
-
)
|
| 321 |
-
return model
|
| 322 |
-
|
| 323 |
-
def save_pretrained(self, save_directory):
|
| 324 |
-
"""
|
| 325 |
-
Minimal implementation of save_pretrained for MambaLMHeadModel.
|
| 326 |
-
Save the model and its configuration file to a directory.
|
| 327 |
-
"""
|
| 328 |
-
# Ensure save_directory exists
|
| 329 |
-
os.makedirs(save_directory, exist_ok=True)
|
| 330 |
-
|
| 331 |
-
# Save the model's state_dict
|
| 332 |
-
model_path = os.path.join(save_directory, "pytorch_model.bin")
|
| 333 |
-
torch.save(self.state_dict(), model_path)
|
| 334 |
-
|
| 335 |
-
# Save the configuration of the model
|
| 336 |
-
config_path = os.path.join(save_directory, "config.json")
|
| 337 |
-
with open(config_path, "w") as f:
|
| 338 |
-
json.dump(self.config.__dict__, f, indent=4)
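Note: a minimal usage sketch for the MambaLMHeadModel removed above, assuming the upstream mamba_ssm package (with its CUDA kernels) is installed and a GPU is available; the checkpoint name is illustrative.

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m", device="cuda", dtype=torch.float16
)
input_ids = torch.randint(0, model.config.vocab_size, (1, 16), device="cuda")
out = model(input_ids, num_last_tokens=1)  # forward returns CausalLMOutput(logits=...)
print(out.logits.shape)  # (1, 1, vocab_size padded to pad_vocab_size_multiple)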
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/modules/__init__.py
DELETED
File without changes

build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/__init__.py
DELETED
File without changes

build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/selective_scan_interface.py
DELETED
@@ -1,659 +0,0 @@
# Copyright (c) 2023, Tri Dao, Albert Gu.

import torch
import torch.nn.functional as F
from ..utils.torch import custom_fwd, custom_bwd

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn
    import causal_conv1d_cuda
except ImportError:
    causal_conv1d_fn = None
    causal_conv1d_cuda = None

from .triton.layer_norm import _layer_norm_fwd

from .._ops import ops


class SelectiveScanFn(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        u,
        delta,
        A,
        B,
        C,
        D=None,
        z=None,
        delta_bias=None,
        delta_softplus=False,
        return_last_state=False,
    ):
        if u.stride(-1) != 1:
            u = u.contiguous()
        if delta.stride(-1) != 1:
            delta = delta.contiguous()
        if D is not None:
            D = D.contiguous()
        if B.stride(-1) != 1:
            B = B.contiguous()
        if C.stride(-1) != 1:
            C = C.contiguous()
        if z is not None and z.stride(-1) != 1:
            z = z.contiguous()
        if B.dim() == 3:
            B = rearrange(B, "b dstate l -> b 1 dstate l")
            ctx.squeeze_B = True
        if C.dim() == 3:
            C = rearrange(C, "b dstate l -> b 1 dstate l")
            ctx.squeeze_C = True
        out, x, *rest = ops.selective_scan_fwd(
            u, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.has_z = z is not None
        last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
        if not ctx.has_z:
            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
            return out if not return_last_state else (out, last_state)
        else:
            ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
            out_z = rest[0]
            return out_z if not return_last_state else (out_z, last_state)

    @staticmethod
    def backward(ctx, dout, *args):
        if not ctx.has_z:
            u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
            z = None
            out = None
        else:
            u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        # Here we just pass in None and dz will be allocated in the C++ code.
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = ops.selective_scan_bwd(
            u,
            delta,
            A,
            B,
            C,
            D,
            z,
            delta_bias,
            dout,
            x,
            out,
            None,
            ctx.delta_softplus,
            False,  # option to recompute out_z, not used here
        )
        dz = rest[0] if ctx.has_z else None
        dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
        dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
        return (
            du,
            ddelta,
            dA,
            dB,
            dC,
            dD if D is not None else None,
            dz,
            ddelta_bias if delta_bias is not None else None,
            None,
            None,
        )


def rms_norm_forward(
    x,
    weight,
    bias,
    eps=1e-6,
    is_rms_norm=True,
):
    # x (b l) d
    if x.stride(-1) != 1:
        x = x.contiguous()
    weight = weight.contiguous()
    if bias is not None:
        bias = bias.contiguous()
    y = _layer_norm_fwd(
        x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm
    )[0]
    # y (b l) d
    return y


def selective_scan_fn(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    delta_softplus=False,
    return_last_state=False,
):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return SelectiveScanFn.apply(
        u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state
    )


def selective_scan_ref(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    delta_softplus=False,
    return_last_state=False,
):
    """
    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    out: r(B D L)
    last_state (optional): r(B D dstate) or c(B D dstate)
    """
    dtype_in = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    if A.is_complex():
        if is_variable_B:
            B = torch.view_as_complex(
                rearrange(B.float(), "... (L two) -> ... L two", two=2)
            )
        if is_variable_C:
            C = torch.view_as_complex(
                rearrange(C.float(), "... (L two) -> ... L two", two=2)
            )
    else:
        B = B.float()
        C = C.float()
    x = A.new_zeros((batch, dim, dstate))
    ys = []
    deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
    if not is_variable_B:
        deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
    else:
        if B.dim() == 3:
            deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
        else:
            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
            deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
    if is_variable_C and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
    last_state = None
    for i in range(u.shape[2]):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = torch.einsum("bdn,dn->bd", x, C)
        else:
            if C.dim() == 3:
                y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
            else:
                y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
        if i == u.shape[2] - 1:
            last_state = x
        if y.is_complex():
            y = y.real * 2
        ys.append(y)
    y = torch.stack(ys, dim=2)  # (batch dim L)
    out = y if D is None else y + u * rearrange(D, "d -> d 1")
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)
    return out if not return_last_state else (out, last_state)


class MambaInnerFn(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        xz,
        conv1d_weight,
        conv1d_bias,
        x_proj_weight,
        delta_proj_weight,
        out_proj_weight,
        out_proj_bias,
        A,
        B=None,
        C=None,
        D=None,
        delta_bias=None,
        B_proj_bias=None,
        C_proj_bias=None,
        delta_softplus=True,
        checkpoint_lvl=1,
        b_rms_weight=None,
        c_rms_weight=None,
        dt_rms_weight=None,
        b_c_dt_rms_eps=1e-6,
    ):
        """
        xz: (batch, dim, seqlen)
        """
        assert (
            causal_conv1d_cuda is not None
        ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if torch.is_autocast_enabled():
            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            delta_proj_weight = delta_proj_weight.to(
                dtype=torch.get_autocast_gpu_dtype()
            )
            out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_bias = (
                out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
                if out_proj_bias is not None
                else None
            )
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
            x, conv1d_weight, conv1d_bias, None, None, None, True
        )
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(
            rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight
        )  # (bl d)
        delta = rearrange(
            delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
        )
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:  # variable B
            B = x_dbl[:, delta_rank : delta_rank + d_state]  # (bl dstate)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                B = rearrange(
                    B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
                ).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl dstate)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                C = rearrange(
                    C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
                ).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()

        if b_rms_weight is not None:
            B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
            B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
            B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
        if c_rms_weight is not None:
            C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
            C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
            C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
        if dt_rms_weight is not None:
            delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
            delta = rms_norm_forward(
                delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps
            )
            delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()

        out, scan_intermediates, out_z = ops.selective_scan_fwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.out_proj_bias_is_None = out_proj_bias is None
        ctx.checkpoint_lvl = checkpoint_lvl
        ctx.b_rms_weight = b_rms_weight
        ctx.c_rms_weight = c_rms_weight
        ctx.dt_rms_weight = dt_rms_weight
        ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
        if (
            checkpoint_lvl >= 1
        ):  # Will recompute conv1d_out and delta in the backward pass
            conv1d_out, delta = None, None
        ctx.save_for_backward(
            xz,
            conv1d_weight,
            conv1d_bias,
            x_dbl,
            x_proj_weight,
            delta_proj_weight,
            out_proj_weight,
            conv1d_out,
            delta,
            A,
            B,
            C,
            D,
            delta_bias,
            scan_intermediates,
            b_rms_weight,
            c_rms_weight,
            dt_rms_weight,
            out,
        )
        return F.linear(
            rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias
        )

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        # dout: (batch, seqlen, dim)
        assert (
            causal_conv1d_cuda is not None
        ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
        (
            xz,
            conv1d_weight,
            conv1d_bias,
            x_dbl,
            x_proj_weight,
            delta_proj_weight,
            out_proj_weight,
            conv1d_out,
            delta,
            A,
            B,
            C,
            D,
            delta_bias,
            scan_intermediates,
            b_rms_weight,
            c_rms_weight,
            dt_rms_weight,
            out,
        ) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if ctx.checkpoint_lvl == 1:
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
                x, conv1d_weight, conv1d_bias, None, None, None, True
            )
            delta = rearrange(
                delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
            )
            if dt_rms_weight is not None:
                delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
                delta = rms_norm_forward(
                    delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps
                )
                delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
            if b_rms_weight is not None:
                # Recompute & RMSNorm B
                B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
                B = rms_norm_forward(B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps)
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            if c_rms_weight is not None:
                # Recompute & RMSNorm C
                C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
                C = rms_norm_forward(C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps)
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()

        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
        dx, dz = dxz.chunk(2, dim=1)
        dout = rearrange(dout, "b l e -> e (b l)")
        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = (
            ops.selective_scan_bwd(
                conv1d_out,
                delta,
                A,
                B,
                C,
                D,
                z,
                delta_bias,
                dout_y,
                scan_intermediates,
                out,
                dz,
                ctx.delta_softplus,
                True,  # option to recompute out_z
            )
        )
        dout_proj_weight = torch.einsum(
            "eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")
        )
        dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
        dD = dD if D is not None else None
        dx_dbl = torch.empty_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(
                    dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
                ).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank : delta_rank + d_state] = dB  # (bl d)
            dB = None
        dC_proj_bias = None
        if ctx.is_variable_C:
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(
                    dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
                ).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC  # (bl d)
            dC = None
        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum(
            "Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")
        )
        dconv1d_out = torch.addmm(
            dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out
        )
        dconv1d_out = rearrange(
            dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]
        )
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
            x,
            conv1d_weight,
            conv1d_bias,
            dconv1d_out,
            None,
            None,
            None,
            dx,
            False,
            True,
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        return (
            dxz,
            dconv1d_weight,
            dconv1d_bias,
            dx_proj_weight,
            ddelta_proj_weight,
            dout_proj_weight,
            dout_proj_bias,
            dA,
            dB,
            dC,
            dD,
            ddelta_bias if delta_bias is not None else None,
            # 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
            dB_proj_bias,
            dC_proj_bias,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def mamba_inner_fn(
    xz,
    conv1d_weight,
    conv1d_bias,
    x_proj_weight,
    delta_proj_weight,
    out_proj_weight,
    out_proj_bias,
    A,
    B=None,
    C=None,
    D=None,
    delta_bias=None,
    B_proj_bias=None,
    C_proj_bias=None,
    delta_softplus=True,
    checkpoint_lvl=1,
    b_rms_weight=None,
    c_rms_weight=None,
    dt_rms_weight=None,
    b_c_dt_rms_eps=1e-6,
):
    return MambaInnerFn.apply(
        xz,
        conv1d_weight,
        conv1d_bias,
        x_proj_weight,
        delta_proj_weight,
        out_proj_weight,
        out_proj_bias,
        A,
        B,
        C,
        D,
        delta_bias,
        B_proj_bias,
        C_proj_bias,
        delta_softplus,
        checkpoint_lvl,
        b_rms_weight,
        c_rms_weight,
        dt_rms_weight,
        b_c_dt_rms_eps,
    )


def mamba_inner_ref(
    xz,
    conv1d_weight,
    conv1d_bias,
    x_proj_weight,
    delta_proj_weight,
    out_proj_weight,
    out_proj_bias,
    A,
    B=None,
    C=None,
    D=None,
    delta_bias=None,
    B_proj_bias=None,
    C_proj_bias=None,
    delta_softplus=True,
):
    assert (
        causal_conv1d_fn is not None
    ), "causal_conv1d_fn is not available. Please install causal-conv1d."
    L = xz.shape[-1]
    delta_rank = delta_proj_weight.shape[1]
    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
    x, z = xz.chunk(2, dim=1)
    x = causal_conv1d_fn(
        x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu"
    )
    # We're being very careful here about the layout, to avoid extra transposes.
    # We want delta to have d as the slowest moving dimension
    # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
    x_dbl = F.linear(rearrange(x, "b d l -> (b l) d"), x_proj_weight)  # (bl d)
    delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
    delta = rearrange(delta, "d (b l) -> b d l", l=L)
    if B is None:  # variable B
        B = x_dbl[:, delta_rank : delta_rank + d_state]  # (bl d)
        if B_proj_bias is not None:
            B = B + B_proj_bias.to(dtype=B.dtype)
        if not A.is_complex():
            B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            B = rearrange(
                B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
            ).contiguous()
    if C is None:  # variable C
        C = x_dbl[:, -d_state:]  # (bl d)
        if C_proj_bias is not None:
            C = C + C_proj_bias.to(dtype=C.dtype)
        if not A.is_complex():
            C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            C = rearrange(
                C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
            ).contiguous()
    y = selective_scan_fn(
        x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True
    )
    return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
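Note: selective_scan_ref above is the pure-PyTorch reference for the fused CUDA scan and needs no custom kernels itself; a minimal sketch exercising it with illustrative shapes (the import path assumes the upstream mamba_ssm layout, and importing the module pulls in Triton):

import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_ref

batch, dim, dstate, seqlen = 2, 4, 8, 16
u = torch.randn(batch, dim, seqlen)
delta = torch.rand(batch, dim, seqlen)
A = -torch.rand(dim, dstate)             # negative entries keep exp(delta * A) bounded
B = torch.randn(batch, dstate, seqlen)   # variable B: r(B N L)
C = torch.randn(batch, dstate, seqlen)   # variable C: r(B N L)
D = torch.randn(dim)

out, last_state = selective_scan_ref(
    u, delta, A, B, C, D=D, delta_softplus=True, return_last_state=True
)
print(out.shape, last_state.shape)  # torch.Size([2, 4, 16]) torch.Size([2, 4, 8])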
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/__init__.py
DELETED
File without changes

build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/layer_norm.py
DELETED
@@ -1,1166 +0,0 @@
# Copyright (c) 2024, Tri Dao.
# Implement dropout + residual + layer_norm / rms_norm.

# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.

import math
import warnings

import torch
import torch.nn.functional as F
from ...utils.torch import custom_bwd, custom_fwd

import triton
import triton.language as tl


def layer_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = bias.float() if bias is not None else None
        residual = residual.float() if residual is not None else residual
        x1 = x1.float() if x1 is not None else None
        weight1 = weight1.float() if weight1 is not None else None
        bias1 = bias1.float() if bias1 is not None else None
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
        if x1 is not None:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    out = F.layer_norm(
        x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
    ).to(dtype)
    if weight1 is None:
        return out if not prenorm else (out, x)
    else:
        out1 = F.layer_norm(
            x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
        ).to(dtype)
        return (out, out1) if not prenorm else (out, out1, x)


def rms_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = bias.float() if bias is not None else None
        residual = residual.float() if residual is not None else residual
        x1 = x1.float() if x1 is not None else None
        weight1 = weight1.float() if weight1 is not None else None
        bias1 = bias1.float() if bias1 is not None else None
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
        if x1 is not None:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
    out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(
        dtype
    )
    if weight1 is None:
        return out if not prenorm else (out, x)
    else:
        out1 = (
            (x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)
        ).to(dtype)
        return (out, out1) if not prenorm else (out, out1, x)


def config_prune(configs):

    if torch.version.hip:
        try:
            # set warp size based on gcn architecture
            gcn_arch_name = torch.cuda.get_device_properties(0).gcnArchName
            if "gfx10" in gcn_arch_name or "gfx11" in gcn_arch_name:
                # radeon
                warp_size = 32
            else:
                # instinct
                warp_size = 64
        except AttributeError as e:
            # fall back to crude method to set warp size
            device_name = torch.cuda.get_device_properties(0).name
            if "instinct" in device_name.lower():
                warp_size = 64
            else:
                warp_size = 32
            warnings.warn(
                f"{e}, warp size set to {warp_size} based on device name: {device_name}",
                UserWarning,
            )

    else:
        # cuda
        warp_size = 32

    max_block_sz = 1024
    max_num_warps = max_block_sz // warp_size
    pruned_configs = [config for config in configs if config.num_warps <= max_num_warps]
    return pruned_configs


configs_autotune = [
    triton.Config({}, num_warps=1),
    triton.Config({}, num_warps=2),
    triton.Config({}, num_warps=4),
    triton.Config({}, num_warps=8),
    triton.Config({}, num_warps=16),
    triton.Config({}, num_warps=32),
]

pruned_configs_autotune = config_prune(configs_autotune)


@triton.autotune(
    configs=pruned_configs_autotune,
    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    X1,
    W1,
    B1,
    Y1,
    RESIDUAL_OUT,  # pointer to the residual
    ROWSCALE,
    SEEDS,  # Dropout seeds for each row
    DROPOUT_MASK,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    stride_x1_row,
    stride_y1_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,  # Dropout probability
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    STORE_DROPOUT_MASK: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_X1: tl.constexpr,
    HAS_W1: tl.constexpr,
    HAS_B1: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    if HAS_X1:
        X1 += row * stride_x1_row
    if HAS_W1:
        Y1 += row * stride_y1_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_ROWSCALE:
        rowscale = tl.load(ROWSCALE + row).to(tl.float32)
        x *= rowscale
    if HAS_DROPOUT:
        # Compute dropout mask
        # 7 rounds is good enough, and reduces register pressure
        keep_mask = (
            tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
        )
        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
        if STORE_DROPOUT_MASK:
            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
    if HAS_X1:
        x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
            x1 *= rowscale
        if HAS_DROPOUT:
            # Compute dropout mask
            # 7 rounds is good enough, and reduces register pressure
            keep_mask = (
                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
                > dropout_p
            )
            x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
            if STORE_DROPOUT_MASK:
                tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
        x += x1
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
    if STORE_RESIDUAL_OUT:
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    # Write output
    tl.store(Y + cols, y, mask=mask)
    if HAS_W1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
        if HAS_B1:
            b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
        y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
        tl.store(Y1 + cols, y1, mask=mask)


def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    dropout_p=0.0,
    rowscale=None,
    out_dtype=None,
    residual_dtype=None,
    is_rms_norm=False,
    return_dropout_mask=False,
):
    if residual is not None:
        residual_dtype = residual.dtype
    M, N = x.shape
    assert x.stride(-1) == 1
    if residual is not None:
        assert residual.stride(-1) == 1
        assert residual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if x1 is not None:
        assert x1.shape == x.shape
        assert rowscale is None
        assert x1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    assert y.stride(-1) == 1
    if weight1 is not None:
        y1 = torch.empty_like(y)
        assert y1.stride(-1) == 1
    else:
        y1 = None
    if (
        residual is not None
        or (residual_dtype is not None and residual_dtype != x.dtype)
        or dropout_p > 0.0
        or rowscale is not None
        or x1 is not None
    ):
        residual_out = torch.empty(
            M,
            N,
            device=x.device,
            dtype=residual_dtype if residual_dtype is not None else x.dtype,
        )
        assert residual_out.stride(-1) == 1
    else:
        residual_out = None
    mean = (
        torch.empty((M,), dtype=torch.float32, device=x.device)
        if not is_rms_norm
        else None
    )
    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
    if dropout_p > 0.0:
        seeds = torch.randint(
            2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
        )
    else:
        seeds = None
    if return_dropout_mask and dropout_p > 0.0:
        dropout_mask = torch.empty(
            M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool
        )
    else:
        dropout_mask = None
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[(M,)](
            x,
            y,
            weight,
            bias,
            residual,
            x1,
            weight1,
            bias1,
            y1,
            residual_out,
            rowscale,
            seeds,
            dropout_mask,
            mean,
            rstd,
            x.stride(0),
            y.stride(0),
            residual.stride(0) if residual is not None else 0,
            residual_out.stride(0) if residual_out is not None else 0,
            x1.stride(0) if x1 is not None else 0,
            y1.stride(0) if y1 is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            is_rms_norm,
            BLOCK_N,
            residual is not None,
            residual_out is not None,
            bias is not None,
            dropout_p > 0.0,
            dropout_mask is not None,
            rowscale is not None,
        )
    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
    if dropout_mask is not None and x1 is not None:
        dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
    else:
        dropout_mask1 = None
    return (
        y,
        y1,
        mean,
        rstd,
        residual_out if residual_out is not None else x,
        seeds,
        dropout_mask,
        dropout_mask1,
    )


@triton.autotune(
    configs=pruned_configs_autotune,
    key=[
        "N",
        "HAS_DRESIDUAL",
        "STORE_DRESIDUAL",
        "IS_RMS_NORM",
        "HAS_BIAS",
        "HAS_DROPOUT",
    ],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
|
| 459 |
-
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
|
| 460 |
-
@triton.jit
|
| 461 |
-
def _layer_norm_bwd_kernel(
|
| 462 |
-
X, # pointer to the input
|
| 463 |
-
W, # pointer to the weights
|
| 464 |
-
B, # pointer to the biases
|
| 465 |
-
Y, # pointer to the output to be recomputed
|
| 466 |
-
DY, # pointer to the output gradient
|
| 467 |
-
DX, # pointer to the input gradient
|
| 468 |
-
DW, # pointer to the partial sum of weights gradient
|
| 469 |
-
DB, # pointer to the partial sum of biases gradient
|
| 470 |
-
DRESIDUAL,
|
| 471 |
-
W1,
|
| 472 |
-
DY1,
|
| 473 |
-
DX1,
|
| 474 |
-
DW1,
|
| 475 |
-
DB1,
|
| 476 |
-
DRESIDUAL_IN,
|
| 477 |
-
ROWSCALE,
|
| 478 |
-
SEEDS,
|
| 479 |
-
Mean, # pointer to the mean
|
| 480 |
-
Rstd, # pointer to the 1/std
|
| 481 |
-
stride_x_row, # how much to increase the pointer when moving by 1 row
|
| 482 |
-
stride_y_row,
|
| 483 |
-
stride_dy_row,
|
| 484 |
-
stride_dx_row,
|
| 485 |
-
stride_dres_row,
|
| 486 |
-
stride_dy1_row,
|
| 487 |
-
stride_dx1_row,
|
| 488 |
-
stride_dres_in_row,
|
| 489 |
-
M, # number of rows in X
|
| 490 |
-
N, # number of columns in X
|
| 491 |
-
eps, # epsilon to avoid division by zero
|
| 492 |
-
dropout_p,
|
| 493 |
-
rows_per_program,
|
| 494 |
-
IS_RMS_NORM: tl.constexpr,
|
| 495 |
-
BLOCK_N: tl.constexpr,
|
| 496 |
-
HAS_DRESIDUAL: tl.constexpr,
|
| 497 |
-
STORE_DRESIDUAL: tl.constexpr,
|
| 498 |
-
HAS_BIAS: tl.constexpr,
|
| 499 |
-
HAS_DROPOUT: tl.constexpr,
|
| 500 |
-
HAS_ROWSCALE: tl.constexpr,
|
| 501 |
-
HAS_DY1: tl.constexpr,
|
| 502 |
-
HAS_DX1: tl.constexpr,
|
| 503 |
-
HAS_B1: tl.constexpr,
|
| 504 |
-
RECOMPUTE_OUTPUT: tl.constexpr,
|
| 505 |
-
):
|
| 506 |
-
# Map the program id to the elements of X, DX, and DY it should compute.
|
| 507 |
-
row_block_id = tl.program_id(0)
|
| 508 |
-
row_start = row_block_id * rows_per_program
|
| 509 |
-
# Do not early exit if row_start >= M, because we need to write DW and DB
|
| 510 |
-
cols = tl.arange(0, BLOCK_N)
|
| 511 |
-
mask = cols < N
|
| 512 |
-
X += row_start * stride_x_row
|
| 513 |
-
if HAS_DRESIDUAL:
|
| 514 |
-
DRESIDUAL += row_start * stride_dres_row
|
| 515 |
-
if STORE_DRESIDUAL:
|
| 516 |
-
DRESIDUAL_IN += row_start * stride_dres_in_row
|
| 517 |
-
DY += row_start * stride_dy_row
|
| 518 |
-
DX += row_start * stride_dx_row
|
| 519 |
-
if HAS_DY1:
|
| 520 |
-
DY1 += row_start * stride_dy1_row
|
| 521 |
-
if HAS_DX1:
|
| 522 |
-
DX1 += row_start * stride_dx1_row
|
| 523 |
-
if RECOMPUTE_OUTPUT:
|
| 524 |
-
Y += row_start * stride_y_row
|
| 525 |
-
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
| 526 |
-
if RECOMPUTE_OUTPUT and HAS_BIAS:
|
| 527 |
-
b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
|
| 528 |
-
if HAS_DY1:
|
| 529 |
-
w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
|
| 530 |
-
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
| 531 |
-
if HAS_BIAS:
|
| 532 |
-
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
| 533 |
-
if HAS_DY1:
|
| 534 |
-
dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
| 535 |
-
if HAS_B1:
|
| 536 |
-
db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
| 537 |
-
row_end = min((row_block_id + 1) * rows_per_program, M)
|
| 538 |
-
for row in range(row_start, row_end):
|
| 539 |
-
# Load data to SRAM
|
| 540 |
-
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
| 541 |
-
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
|
| 542 |
-
if HAS_DY1:
|
| 543 |
-
dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
|
| 544 |
-
if not IS_RMS_NORM:
|
| 545 |
-
mean = tl.load(Mean + row)
|
| 546 |
-
rstd = tl.load(Rstd + row)
|
| 547 |
-
# Compute dx
|
| 548 |
-
xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
| 549 |
-
xhat = tl.where(mask, xhat, 0.0)
|
| 550 |
-
if RECOMPUTE_OUTPUT:
|
| 551 |
-
y = xhat * w + b if HAS_BIAS else xhat * w
|
| 552 |
-
tl.store(Y + cols, y, mask=mask)
|
| 553 |
-
wdy = w * dy
|
| 554 |
-
dw += dy * xhat
|
| 555 |
-
if HAS_BIAS:
|
| 556 |
-
db += dy
|
| 557 |
-
if HAS_DY1:
|
| 558 |
-
wdy += w1 * dy1
|
| 559 |
-
dw1 += dy1 * xhat
|
| 560 |
-
if HAS_B1:
|
| 561 |
-
db1 += dy1
|
| 562 |
-
if not IS_RMS_NORM:
|
| 563 |
-
c1 = tl.sum(xhat * wdy, axis=0) / N
|
| 564 |
-
c2 = tl.sum(wdy, axis=0) / N
|
| 565 |
-
dx = (wdy - (xhat * c1 + c2)) * rstd
|
| 566 |
-
else:
|
| 567 |
-
c1 = tl.sum(xhat * wdy, axis=0) / N
|
| 568 |
-
dx = (wdy - xhat * c1) * rstd
|
| 569 |
-
if HAS_DRESIDUAL:
|
| 570 |
-
dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
|
| 571 |
-
dx += dres
|
| 572 |
-
# Write dx
|
| 573 |
-
if STORE_DRESIDUAL:
|
| 574 |
-
tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
|
| 575 |
-
if HAS_DX1:
|
| 576 |
-
if HAS_DROPOUT:
|
| 577 |
-
keep_mask = (
|
| 578 |
-
tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
|
| 579 |
-
> dropout_p
|
| 580 |
-
)
|
| 581 |
-
dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
|
| 582 |
-
else:
|
| 583 |
-
dx1 = dx
|
| 584 |
-
tl.store(DX1 + cols, dx1, mask=mask)
|
| 585 |
-
if HAS_DROPOUT:
|
| 586 |
-
keep_mask = (
|
| 587 |
-
tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7)
|
| 588 |
-
> dropout_p
|
| 589 |
-
)
|
| 590 |
-
dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
|
| 591 |
-
if HAS_ROWSCALE:
|
| 592 |
-
rowscale = tl.load(ROWSCALE + row).to(tl.float32)
|
| 593 |
-
dx *= rowscale
|
| 594 |
-
tl.store(DX + cols, dx, mask=mask)
|
| 595 |
-
|
| 596 |
-
X += stride_x_row
|
| 597 |
-
if HAS_DRESIDUAL:
|
| 598 |
-
DRESIDUAL += stride_dres_row
|
| 599 |
-
if STORE_DRESIDUAL:
|
| 600 |
-
DRESIDUAL_IN += stride_dres_in_row
|
| 601 |
-
if RECOMPUTE_OUTPUT:
|
| 602 |
-
Y += stride_y_row
|
| 603 |
-
DY += stride_dy_row
|
| 604 |
-
DX += stride_dx_row
|
| 605 |
-
if HAS_DY1:
|
| 606 |
-
DY1 += stride_dy1_row
|
| 607 |
-
if HAS_DX1:
|
| 608 |
-
DX1 += stride_dx1_row
|
| 609 |
-
tl.store(DW + row_block_id * N + cols, dw, mask=mask)
|
| 610 |
-
if HAS_BIAS:
|
| 611 |
-
tl.store(DB + row_block_id * N + cols, db, mask=mask)
|
| 612 |
-
if HAS_DY1:
|
| 613 |
-
tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
|
| 614 |
-
if HAS_B1:
|
| 615 |
-
tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
def _layer_norm_bwd(
|
| 619 |
-
dy,
|
| 620 |
-
x,
|
| 621 |
-
weight,
|
| 622 |
-
bias,
|
| 623 |
-
eps,
|
| 624 |
-
mean,
|
| 625 |
-
rstd,
|
| 626 |
-
dresidual=None,
|
| 627 |
-
dy1=None,
|
| 628 |
-
weight1=None,
|
| 629 |
-
bias1=None,
|
| 630 |
-
seeds=None,
|
| 631 |
-
dropout_p=0.0,
|
| 632 |
-
rowscale=None,
|
| 633 |
-
has_residual=False,
|
| 634 |
-
has_x1=False,
|
| 635 |
-
is_rms_norm=False,
|
| 636 |
-
x_dtype=None,
|
| 637 |
-
recompute_output=False,
|
| 638 |
-
):
|
| 639 |
-
M, N = x.shape
|
| 640 |
-
assert x.stride(-1) == 1
|
| 641 |
-
assert dy.stride(-1) == 1
|
| 642 |
-
assert dy.shape == (M, N)
|
| 643 |
-
if dresidual is not None:
|
| 644 |
-
assert dresidual.stride(-1) == 1
|
| 645 |
-
assert dresidual.shape == (M, N)
|
| 646 |
-
assert weight.shape == (N,)
|
| 647 |
-
assert weight.stride(-1) == 1
|
| 648 |
-
if bias is not None:
|
| 649 |
-
assert bias.stride(-1) == 1
|
| 650 |
-
assert bias.shape == (N,)
|
| 651 |
-
if dy1 is not None:
|
| 652 |
-
assert weight1 is not None
|
| 653 |
-
assert dy1.shape == dy.shape
|
| 654 |
-
assert dy1.stride(-1) == 1
|
| 655 |
-
if weight1 is not None:
|
| 656 |
-
assert weight1.shape == (N,)
|
| 657 |
-
assert weight1.stride(-1) == 1
|
| 658 |
-
if bias1 is not None:
|
| 659 |
-
assert bias1.shape == (N,)
|
| 660 |
-
assert bias1.stride(-1) == 1
|
| 661 |
-
if seeds is not None:
|
| 662 |
-
assert seeds.is_contiguous()
|
| 663 |
-
assert seeds.shape == (M if not has_x1 else M * 2,)
|
| 664 |
-
if rowscale is not None:
|
| 665 |
-
assert rowscale.is_contiguous()
|
| 666 |
-
assert rowscale.shape == (M,)
|
| 667 |
-
# allocate output
|
| 668 |
-
dx = (
|
| 669 |
-
torch.empty_like(x)
|
| 670 |
-
if x_dtype is None
|
| 671 |
-
else torch.empty(M, N, dtype=x_dtype, device=x.device)
|
| 672 |
-
)
|
| 673 |
-
dresidual_in = (
|
| 674 |
-
torch.empty_like(x)
|
| 675 |
-
if has_residual
|
| 676 |
-
and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
|
| 677 |
-
else None
|
| 678 |
-
)
|
| 679 |
-
dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
|
| 680 |
-
y = (
|
| 681 |
-
torch.empty(M, N, dtype=dy.dtype, device=dy.device)
|
| 682 |
-
if recompute_output
|
| 683 |
-
else None
|
| 684 |
-
)
|
| 685 |
-
if recompute_output:
|
| 686 |
-
assert (
|
| 687 |
-
weight1 is None
|
| 688 |
-
), "recompute_output is not supported with parallel LayerNorm"
|
| 689 |
-
|
| 690 |
-
# Less than 64KB per feature: enqueue fused kernel
|
| 691 |
-
MAX_FUSED_SIZE = 65536 // x.element_size()
|
| 692 |
-
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
| 693 |
-
if N > BLOCK_N:
|
| 694 |
-
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
| 695 |
-
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
|
| 696 |
-
_dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
|
| 697 |
-
_db = (
|
| 698 |
-
torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
|
| 699 |
-
if bias is not None
|
| 700 |
-
else None
|
| 701 |
-
)
|
| 702 |
-
_dw1 = torch.empty_like(_dw) if weight1 is not None else None
|
| 703 |
-
_db1 = torch.empty_like(_db) if bias1 is not None else None
|
| 704 |
-
rows_per_program = math.ceil(M / sm_count)
|
| 705 |
-
grid = (sm_count,)
|
| 706 |
-
with torch.cuda.device(x.device.index):
|
| 707 |
-
_layer_norm_bwd_kernel[grid](
|
| 708 |
-
x,
|
| 709 |
-
weight,
|
| 710 |
-
bias,
|
| 711 |
-
y,
|
| 712 |
-
dy,
|
| 713 |
-
dx,
|
| 714 |
-
_dw,
|
| 715 |
-
_db,
|
| 716 |
-
dresidual,
|
| 717 |
-
weight1,
|
| 718 |
-
dy1,
|
| 719 |
-
dx1,
|
| 720 |
-
_dw1,
|
| 721 |
-
_db1,
|
| 722 |
-
dresidual_in,
|
| 723 |
-
rowscale,
|
| 724 |
-
seeds,
|
| 725 |
-
mean,
|
| 726 |
-
rstd,
|
| 727 |
-
x.stride(0),
|
| 728 |
-
0 if not recompute_output else y.stride(0),
|
| 729 |
-
dy.stride(0),
|
| 730 |
-
dx.stride(0),
|
| 731 |
-
dresidual.stride(0) if dresidual is not None else 0,
|
| 732 |
-
dy1.stride(0) if dy1 is not None else 0,
|
| 733 |
-
dx1.stride(0) if dx1 is not None else 0,
|
| 734 |
-
dresidual_in.stride(0) if dresidual_in is not None else 0,
|
| 735 |
-
M,
|
| 736 |
-
N,
|
| 737 |
-
eps,
|
| 738 |
-
dropout_p,
|
| 739 |
-
rows_per_program,
|
| 740 |
-
is_rms_norm,
|
| 741 |
-
BLOCK_N,
|
| 742 |
-
dresidual is not None,
|
| 743 |
-
dresidual_in is not None,
|
| 744 |
-
bias is not None,
|
| 745 |
-
dropout_p > 0.0,
|
| 746 |
-
)
|
| 747 |
-
dw = _dw.sum(0).to(weight.dtype)
|
| 748 |
-
db = _db.sum(0).to(bias.dtype) if bias is not None else None
|
| 749 |
-
dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
|
| 750 |
-
db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
|
| 751 |
-
# Don't need to compute dresidual_in separately in this case
|
| 752 |
-
if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
|
| 753 |
-
dresidual_in = dx
|
| 754 |
-
if has_x1 and dropout_p == 0.0:
|
| 755 |
-
dx1 = dx
|
| 756 |
-
return (
|
| 757 |
-
(dx, dw, db, dresidual_in, dx1, dw1, db1)
|
| 758 |
-
if not recompute_output
|
| 759 |
-
else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
|
| 760 |
-
)
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
class LayerNormFn(torch.autograd.Function):
|
| 764 |
-
@staticmethod
|
| 765 |
-
def forward(
|
| 766 |
-
ctx,
|
| 767 |
-
x,
|
| 768 |
-
weight,
|
| 769 |
-
bias,
|
| 770 |
-
residual=None,
|
| 771 |
-
x1=None,
|
| 772 |
-
weight1=None,
|
| 773 |
-
bias1=None,
|
| 774 |
-
eps=1e-6,
|
| 775 |
-
dropout_p=0.0,
|
| 776 |
-
rowscale=None,
|
| 777 |
-
prenorm=False,
|
| 778 |
-
residual_in_fp32=False,
|
| 779 |
-
is_rms_norm=False,
|
| 780 |
-
return_dropout_mask=False,
|
| 781 |
-
):
|
| 782 |
-
x_shape_og = x.shape
|
| 783 |
-
# reshape input data into 2D tensor
|
| 784 |
-
x = x.reshape(-1, x.shape[-1])
|
| 785 |
-
if x.stride(-1) != 1:
|
| 786 |
-
x = x.contiguous()
|
| 787 |
-
if residual is not None:
|
| 788 |
-
assert residual.shape == x_shape_og
|
| 789 |
-
residual = residual.reshape(-1, residual.shape[-1])
|
| 790 |
-
if residual.stride(-1) != 1:
|
| 791 |
-
residual = residual.contiguous()
|
| 792 |
-
if x1 is not None:
|
| 793 |
-
assert x1.shape == x_shape_og
|
| 794 |
-
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
|
| 795 |
-
x1 = x1.reshape(-1, x1.shape[-1])
|
| 796 |
-
if x1.stride(-1) != 1:
|
| 797 |
-
x1 = x1.contiguous()
|
| 798 |
-
weight = weight.contiguous()
|
| 799 |
-
if bias is not None:
|
| 800 |
-
bias = bias.contiguous()
|
| 801 |
-
if weight1 is not None:
|
| 802 |
-
weight1 = weight1.contiguous()
|
| 803 |
-
if bias1 is not None:
|
| 804 |
-
bias1 = bias1.contiguous()
|
| 805 |
-
if rowscale is not None:
|
| 806 |
-
rowscale = rowscale.reshape(-1).contiguous()
|
| 807 |
-
residual_dtype = (
|
| 808 |
-
residual.dtype
|
| 809 |
-
if residual is not None
|
| 810 |
-
else (torch.float32 if residual_in_fp32 else None)
|
| 811 |
-
)
|
| 812 |
-
y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = (
|
| 813 |
-
_layer_norm_fwd(
|
| 814 |
-
x,
|
| 815 |
-
weight,
|
| 816 |
-
bias,
|
| 817 |
-
eps,
|
| 818 |
-
residual,
|
| 819 |
-
x1,
|
| 820 |
-
weight1,
|
| 821 |
-
bias1,
|
| 822 |
-
dropout_p=dropout_p,
|
| 823 |
-
rowscale=rowscale,
|
| 824 |
-
residual_dtype=residual_dtype,
|
| 825 |
-
is_rms_norm=is_rms_norm,
|
| 826 |
-
return_dropout_mask=return_dropout_mask,
|
| 827 |
-
)
|
| 828 |
-
)
|
| 829 |
-
ctx.save_for_backward(
|
| 830 |
-
residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
|
| 831 |
-
)
|
| 832 |
-
ctx.x_shape_og = x_shape_og
|
| 833 |
-
ctx.eps = eps
|
| 834 |
-
ctx.dropout_p = dropout_p
|
| 835 |
-
ctx.is_rms_norm = is_rms_norm
|
| 836 |
-
ctx.has_residual = residual is not None
|
| 837 |
-
ctx.has_x1 = x1 is not None
|
| 838 |
-
ctx.prenorm = prenorm
|
| 839 |
-
ctx.x_dtype = x.dtype
|
| 840 |
-
y = y.reshape(x_shape_og)
|
| 841 |
-
y1 = y1.reshape(x_shape_og) if y1 is not None else None
|
| 842 |
-
residual_out = (
|
| 843 |
-
residual_out.reshape(x_shape_og) if residual_out is not None else None
|
| 844 |
-
)
|
| 845 |
-
dropout_mask = (
|
| 846 |
-
dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
|
| 847 |
-
)
|
| 848 |
-
dropout_mask1 = (
|
| 849 |
-
dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
|
| 850 |
-
)
|
| 851 |
-
if not return_dropout_mask:
|
| 852 |
-
if weight1 is None:
|
| 853 |
-
return y if not prenorm else (y, residual_out)
|
| 854 |
-
else:
|
| 855 |
-
return (y, y1) if not prenorm else (y, y1, residual_out)
|
| 856 |
-
else:
|
| 857 |
-
if weight1 is None:
|
| 858 |
-
return (
|
| 859 |
-
(y, dropout_mask, dropout_mask1)
|
| 860 |
-
if not prenorm
|
| 861 |
-
else (y, residual_out, dropout_mask, dropout_mask1)
|
| 862 |
-
)
|
| 863 |
-
else:
|
| 864 |
-
return (
|
| 865 |
-
(y, y1, dropout_mask, dropout_mask1)
|
| 866 |
-
if not prenorm
|
| 867 |
-
else (y, y1, residual_out, dropout_mask, dropout_mask1)
|
| 868 |
-
)
|
| 869 |
-
|
| 870 |
-
@staticmethod
|
| 871 |
-
def backward(ctx, dy, *args):
|
| 872 |
-
x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
|
| 873 |
-
dy = dy.reshape(-1, dy.shape[-1])
|
| 874 |
-
if dy.stride(-1) != 1:
|
| 875 |
-
dy = dy.contiguous()
|
| 876 |
-
assert dy.shape == x.shape
|
| 877 |
-
if weight1 is not None:
|
| 878 |
-
dy1, args = args[0], args[1:]
|
| 879 |
-
dy1 = dy1.reshape(-1, dy1.shape[-1])
|
| 880 |
-
if dy1.stride(-1) != 1:
|
| 881 |
-
dy1 = dy1.contiguous()
|
| 882 |
-
assert dy1.shape == x.shape
|
| 883 |
-
else:
|
| 884 |
-
dy1 = None
|
| 885 |
-
if ctx.prenorm:
|
| 886 |
-
dresidual = args[0]
|
| 887 |
-
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
| 888 |
-
if dresidual.stride(-1) != 1:
|
| 889 |
-
dresidual = dresidual.contiguous()
|
| 890 |
-
assert dresidual.shape == x.shape
|
| 891 |
-
else:
|
| 892 |
-
dresidual = None
|
| 893 |
-
dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
|
| 894 |
-
dy,
|
| 895 |
-
x,
|
| 896 |
-
weight,
|
| 897 |
-
bias,
|
| 898 |
-
ctx.eps,
|
| 899 |
-
mean,
|
| 900 |
-
rstd,
|
| 901 |
-
dresidual,
|
| 902 |
-
dy1,
|
| 903 |
-
weight1,
|
| 904 |
-
bias1,
|
| 905 |
-
seeds,
|
| 906 |
-
ctx.dropout_p,
|
| 907 |
-
rowscale,
|
| 908 |
-
ctx.has_residual,
|
| 909 |
-
ctx.has_x1,
|
| 910 |
-
ctx.is_rms_norm,
|
| 911 |
-
x_dtype=ctx.x_dtype,
|
| 912 |
-
)
|
| 913 |
-
return (
|
| 914 |
-
dx.reshape(ctx.x_shape_og),
|
| 915 |
-
dw,
|
| 916 |
-
db,
|
| 917 |
-
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
| 918 |
-
dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
|
| 919 |
-
dw1,
|
| 920 |
-
db1,
|
| 921 |
-
None,
|
| 922 |
-
None,
|
| 923 |
-
None,
|
| 924 |
-
None,
|
| 925 |
-
None,
|
| 926 |
-
None,
|
| 927 |
-
None,
|
| 928 |
-
)
|
| 929 |
-
|
| 930 |
-
|
| 931 |
-
def layer_norm_fn(
|
| 932 |
-
x,
|
| 933 |
-
weight,
|
| 934 |
-
bias,
|
| 935 |
-
residual=None,
|
| 936 |
-
x1=None,
|
| 937 |
-
weight1=None,
|
| 938 |
-
bias1=None,
|
| 939 |
-
eps=1e-6,
|
| 940 |
-
dropout_p=0.0,
|
| 941 |
-
rowscale=None,
|
| 942 |
-
prenorm=False,
|
| 943 |
-
residual_in_fp32=False,
|
| 944 |
-
is_rms_norm=False,
|
| 945 |
-
return_dropout_mask=False,
|
| 946 |
-
):
|
| 947 |
-
return LayerNormFn.apply(
|
| 948 |
-
x,
|
| 949 |
-
weight,
|
| 950 |
-
bias,
|
| 951 |
-
residual,
|
| 952 |
-
x1,
|
| 953 |
-
weight1,
|
| 954 |
-
bias1,
|
| 955 |
-
eps,
|
| 956 |
-
dropout_p,
|
| 957 |
-
rowscale,
|
| 958 |
-
prenorm,
|
| 959 |
-
residual_in_fp32,
|
| 960 |
-
is_rms_norm,
|
| 961 |
-
return_dropout_mask,
|
| 962 |
-
)
|
| 963 |
-
|
| 964 |
-
|
| 965 |
-
def rms_norm_fn(
|
| 966 |
-
x,
|
| 967 |
-
weight,
|
| 968 |
-
bias,
|
| 969 |
-
residual=None,
|
| 970 |
-
x1=None,
|
| 971 |
-
weight1=None,
|
| 972 |
-
bias1=None,
|
| 973 |
-
eps=1e-6,
|
| 974 |
-
dropout_p=0.0,
|
| 975 |
-
rowscale=None,
|
| 976 |
-
prenorm=False,
|
| 977 |
-
residual_in_fp32=False,
|
| 978 |
-
return_dropout_mask=False,
|
| 979 |
-
):
|
| 980 |
-
return LayerNormFn.apply(
|
| 981 |
-
x,
|
| 982 |
-
weight,
|
| 983 |
-
bias,
|
| 984 |
-
residual,
|
| 985 |
-
x1,
|
| 986 |
-
weight1,
|
| 987 |
-
bias1,
|
| 988 |
-
eps,
|
| 989 |
-
dropout_p,
|
| 990 |
-
rowscale,
|
| 991 |
-
prenorm,
|
| 992 |
-
residual_in_fp32,
|
| 993 |
-
True,
|
| 994 |
-
return_dropout_mask,
|
| 995 |
-
)
|
| 996 |
-
|
| 997 |
-
|
| 998 |
-
class RMSNorm(torch.nn.Module):
|
| 999 |
-
|
| 1000 |
-
def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
|
| 1001 |
-
factory_kwargs = {"device": device, "dtype": dtype}
|
| 1002 |
-
super().__init__()
|
| 1003 |
-
self.eps = eps
|
| 1004 |
-
if dropout_p > 0.0:
|
| 1005 |
-
self.drop = torch.nn.Dropout(dropout_p)
|
| 1006 |
-
else:
|
| 1007 |
-
self.drop = None
|
| 1008 |
-
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
| 1009 |
-
self.register_parameter("bias", None)
|
| 1010 |
-
self.reset_parameters()
|
| 1011 |
-
|
| 1012 |
-
def reset_parameters(self):
|
| 1013 |
-
torch.nn.init.ones_(self.weight)
|
| 1014 |
-
|
| 1015 |
-
def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
|
| 1016 |
-
return rms_norm_fn(
|
| 1017 |
-
x,
|
| 1018 |
-
self.weight,
|
| 1019 |
-
self.bias,
|
| 1020 |
-
residual=residual,
|
| 1021 |
-
eps=self.eps,
|
| 1022 |
-
dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
|
| 1023 |
-
prenorm=prenorm,
|
| 1024 |
-
residual_in_fp32=residual_in_fp32,
|
| 1025 |
-
)
|
| 1026 |
-
|
| 1027 |
-
|
| 1028 |
-
class LayerNormLinearFn(torch.autograd.Function):
|
| 1029 |
-
@staticmethod
|
| 1030 |
-
@custom_fwd
|
| 1031 |
-
def forward(
|
| 1032 |
-
ctx,
|
| 1033 |
-
x,
|
| 1034 |
-
norm_weight,
|
| 1035 |
-
norm_bias,
|
| 1036 |
-
linear_weight,
|
| 1037 |
-
linear_bias,
|
| 1038 |
-
residual=None,
|
| 1039 |
-
eps=1e-6,
|
| 1040 |
-
prenorm=False,
|
| 1041 |
-
residual_in_fp32=False,
|
| 1042 |
-
is_rms_norm=False,
|
| 1043 |
-
):
|
| 1044 |
-
x_shape_og = x.shape
|
| 1045 |
-
# reshape input data into 2D tensor
|
| 1046 |
-
x = x.reshape(-1, x.shape[-1])
|
| 1047 |
-
if x.stride(-1) != 1:
|
| 1048 |
-
x = x.contiguous()
|
| 1049 |
-
if residual is not None:
|
| 1050 |
-
assert residual.shape == x_shape_og
|
| 1051 |
-
residual = residual.reshape(-1, residual.shape[-1])
|
| 1052 |
-
if residual.stride(-1) != 1:
|
| 1053 |
-
residual = residual.contiguous()
|
| 1054 |
-
norm_weight = norm_weight.contiguous()
|
| 1055 |
-
if norm_bias is not None:
|
| 1056 |
-
norm_bias = norm_bias.contiguous()
|
| 1057 |
-
residual_dtype = (
|
| 1058 |
-
residual.dtype
|
| 1059 |
-
if residual is not None
|
| 1060 |
-
else (torch.float32 if residual_in_fp32 else None)
|
| 1061 |
-
)
|
| 1062 |
-
y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
|
| 1063 |
-
x,
|
| 1064 |
-
norm_weight,
|
| 1065 |
-
norm_bias,
|
| 1066 |
-
eps,
|
| 1067 |
-
residual,
|
| 1068 |
-
out_dtype=(
|
| 1069 |
-
None
|
| 1070 |
-
if not torch.is_autocast_enabled()
|
| 1071 |
-
else torch.get_autocast_gpu_dtype()
|
| 1072 |
-
),
|
| 1073 |
-
residual_dtype=residual_dtype,
|
| 1074 |
-
is_rms_norm=is_rms_norm,
|
| 1075 |
-
)
|
| 1076 |
-
y = y.reshape(x_shape_og)
|
| 1077 |
-
dtype = (
|
| 1078 |
-
torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
|
| 1079 |
-
)
|
| 1080 |
-
linear_weight = linear_weight.to(dtype)
|
| 1081 |
-
linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
|
| 1082 |
-
out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
|
| 1083 |
-
# We don't store y, will be recomputed in the backward pass to save memory
|
| 1084 |
-
ctx.save_for_backward(
|
| 1085 |
-
residual_out, norm_weight, norm_bias, linear_weight, mean, rstd
|
| 1086 |
-
)
|
| 1087 |
-
ctx.x_shape_og = x_shape_og
|
| 1088 |
-
ctx.eps = eps
|
| 1089 |
-
ctx.is_rms_norm = is_rms_norm
|
| 1090 |
-
ctx.has_residual = residual is not None
|
| 1091 |
-
ctx.prenorm = prenorm
|
| 1092 |
-
ctx.x_dtype = x.dtype
|
| 1093 |
-
ctx.linear_bias_is_none = linear_bias is None
|
| 1094 |
-
return out if not prenorm else (out, residual_out.reshape(x_shape_og))
|
| 1095 |
-
|
| 1096 |
-
@staticmethod
|
| 1097 |
-
@custom_bwd
|
| 1098 |
-
def backward(ctx, dout, *args):
|
| 1099 |
-
x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
|
| 1100 |
-
dout = dout.reshape(-1, dout.shape[-1])
|
| 1101 |
-
dy = F.linear(dout, linear_weight.t())
|
| 1102 |
-
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
|
| 1103 |
-
if dy.stride(-1) != 1:
|
| 1104 |
-
dy = dy.contiguous()
|
| 1105 |
-
assert dy.shape == x.shape
|
| 1106 |
-
if ctx.prenorm:
|
| 1107 |
-
dresidual = args[0]
|
| 1108 |
-
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
| 1109 |
-
if dresidual.stride(-1) != 1:
|
| 1110 |
-
dresidual = dresidual.contiguous()
|
| 1111 |
-
assert dresidual.shape == x.shape
|
| 1112 |
-
else:
|
| 1113 |
-
dresidual = None
|
| 1114 |
-
dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
|
| 1115 |
-
dy,
|
| 1116 |
-
x,
|
| 1117 |
-
norm_weight,
|
| 1118 |
-
norm_bias,
|
| 1119 |
-
ctx.eps,
|
| 1120 |
-
mean,
|
| 1121 |
-
rstd,
|
| 1122 |
-
dresidual=dresidual,
|
| 1123 |
-
has_residual=ctx.has_residual,
|
| 1124 |
-
is_rms_norm=ctx.is_rms_norm,
|
| 1125 |
-
x_dtype=ctx.x_dtype,
|
| 1126 |
-
recompute_output=True,
|
| 1127 |
-
)
|
| 1128 |
-
dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
|
| 1129 |
-
return (
|
| 1130 |
-
dx.reshape(ctx.x_shape_og),
|
| 1131 |
-
dnorm_weight,
|
| 1132 |
-
dnorm_bias,
|
| 1133 |
-
dlinear_weight,
|
| 1134 |
-
dlinear_bias,
|
| 1135 |
-
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
| 1136 |
-
None,
|
| 1137 |
-
None,
|
| 1138 |
-
None,
|
| 1139 |
-
None,
|
| 1140 |
-
)
|
| 1141 |
-
|
| 1142 |
-
|
| 1143 |
-
def layer_norm_linear_fn(
|
| 1144 |
-
x,
|
| 1145 |
-
norm_weight,
|
| 1146 |
-
norm_bias,
|
| 1147 |
-
linear_weight,
|
| 1148 |
-
linear_bias,
|
| 1149 |
-
residual=None,
|
| 1150 |
-
eps=1e-6,
|
| 1151 |
-
prenorm=False,
|
| 1152 |
-
residual_in_fp32=False,
|
| 1153 |
-
is_rms_norm=False,
|
| 1154 |
-
):
|
| 1155 |
-
return LayerNormLinearFn.apply(
|
| 1156 |
-
x,
|
| 1157 |
-
norm_weight,
|
| 1158 |
-
norm_bias,
|
| 1159 |
-
linear_weight,
|
| 1160 |
-
linear_bias,
|
| 1161 |
-
residual,
|
| 1162 |
-
eps,
|
| 1163 |
-
prenorm,
|
| 1164 |
-
residual_in_fp32,
|
| 1165 |
-
is_rms_norm,
|
| 1166 |
-
)
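
For orientation, a minimal usage sketch of the RMSNorm module and the prenorm residual path deleted above; this is not part of the original file, and the hidden size, batch/sequence shapes, and CUDA half-precision setup are illustrative assumptions:

    # Illustrative sketch only (assumes a CUDA device; all shapes are hypothetical).
    import torch

    hidden_size = 1024  # hypothetical model width
    norm = RMSNorm(hidden_size, eps=1e-5).cuda().half()
    x = torch.randn(2, 128, hidden_size, device="cuda", dtype=torch.float16)
    residual = torch.randn_like(x)

    # prenorm=True returns (normalized output, updated residual stream),
    # matching the (y, residual_out) branch of LayerNormFn.forward above.
    y, new_residual = norm(x, residual=residual, prenorm=True)
    assert y.shape == x.shape and new_residual.shape == x.shape
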
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/selective_state_update.py
DELETED
@@ -1,389 +0,0 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.

"""We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
"""

import math
import torch
import torch.nn.functional as F

import triton
import triton.language as tl

from einops import rearrange, repeat

from .softplus import softplus


@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics(
    {
        "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
        is not None
    }
)
@triton.heuristics(
    {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
)
@triton.jit
def _selective_scan_update_kernel(
    # Pointers to matrices
    state_ptr,
    x_ptr,
    dt_ptr,
    dt_bias_ptr,
    A_ptr,
    B_ptr,
    C_ptr,
    D_ptr,
    z_ptr,
    out_ptr,
    state_batch_indices_ptr,
    # Matrix dimensions
    batch,
    nheads,
    dim,
    dstate,
    nheads_ngroups_ratio,
    # Strides
    stride_state_batch,
    stride_state_head,
    stride_state_dim,
    stride_state_dstate,
    stride_x_batch,
    stride_x_head,
    stride_x_dim,
    stride_dt_batch,
    stride_dt_head,
    stride_dt_dim,
    stride_dt_bias_head,
    stride_dt_bias_dim,
    stride_A_head,
    stride_A_dim,
    stride_A_dstate,
    stride_B_batch,
    stride_B_group,
    stride_B_dstate,
    stride_C_batch,
    stride_C_group,
    stride_C_dstate,
    stride_D_head,
    stride_D_dim,
    stride_z_batch,
    stride_z_head,
    stride_z_dim,
    stride_out_batch,
    stride_out_head,
    stride_out_dim,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    TIE_HDIM: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    HAS_D: tl.constexpr,
    HAS_Z: tl.constexpr,
    HAS_STATE_BATCH_INDICES: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)

    if HAS_STATE_BATCH_INDICES:
        state_batch_indices_ptr += pid_b
        state_batch_idx = tl.load(state_batch_indices_ptr)
        state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
    else:
        state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head

    x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
    dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
    if HAS_DT_BIAS:
        dt_bias_ptr += pid_h * stride_dt_bias_head
    A_ptr += pid_h * stride_A_head
    B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
    C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
    if HAS_Z:
        z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
    state_ptrs = state_ptr + (
        offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
    )
    x_ptrs = x_ptr + offs_m * stride_x_dim
    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
    if HAS_DT_BIAS:
        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
    if HAS_D:
        D_ptr += pid_h * stride_D_head
    A_ptrs = A_ptr + (
        offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
    )
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate
    if HAS_D:
        D_ptrs = D_ptr + offs_m * stride_D_dim
    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim

    state = tl.load(
        state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
    )
    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if not TIE_HDIM:
        dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if DT_SOFTPLUS:
            dt = tl.where(dt <= 20.0, softplus(dt), dt)
        A = tl.load(
            A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
        ).to(tl.float32)
        dA = tl.exp(A * dt[:, None])
    else:
        dt = tl.load(dt_ptr).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptr).to(tl.float32)
        if DT_SOFTPLUS:
            dt = tl.where(dt <= 20.0, softplus(dt), dt)
        A = tl.load(A_ptr).to(tl.float32)
        dA = tl.exp(A * dt)  # scalar, not a matrix

    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_Z:
        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

    if not TIE_HDIM:
        dB = B[None, :] * dt[:, None]
    else:
        dB = B * dt  # vector of size (dstate,)
    state = state * dA + dB * x[:, None]
    tl.store(
        state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
    )
    out = tl.sum(state * C[None, :], axis=1)
    if HAS_D:
        out += x * D
    if HAS_Z:
        out *= z * tl.sigmoid(z)
    tl.store(out_ptrs, out, mask=offs_m < dim)


def selective_state_update(
    state,
    x,
    dt,
    A,
    B,
    C,
    D=None,
    z=None,
    dt_bias=None,
    dt_softplus=False,
    state_batch_indices=None,
):
    """
    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    """
    has_heads = state.dim() > 3
    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)
    _, nheads, dim, dstate = state.shape
    batch = x.shape[0]
    if x.shape != (batch, nheads, dim):
        print(f"{state.shape} {x.shape} {batch} {nheads} {dim}")
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
    if state_batch_indices is not None:
        assert state_batch_indices.shape == (batch,)
    out = torch.empty_like(x)
    grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
    z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
    # We don't want autotune since it will overwrite the state
    # We instead tune by hand.
    BLOCK_SIZE_M, num_warps = (
        (32, 4)
        if dstate <= 16
        else (
            (16, 4)
            if dstate <= 32
            else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
        )
    )
    tie_hdim = (
        A.stride(-1) == 0
        and A.stride(-2) == 0
        and dt.stride(-1) == 0
        and dt_bias.stride(-1) == 0
    )
    with torch.cuda.device(x.device.index):
        _selective_scan_update_kernel[grid](
            state,
            x,
            dt,
            dt_bias,
            A,
            B,
            C,
            D,
            z,
            out,
            state_batch_indices,
            batch,
            nheads,
            dim,
            dstate,
            nheads // ngroups,
            state.stride(0),
            state.stride(1),
            state.stride(2),
            state.stride(3),
            x.stride(0),
            x.stride(1),
            x.stride(2),
            dt.stride(0),
            dt.stride(1),
            dt.stride(2),
            *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
            A.stride(0),
            A.stride(1),
            A.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            C.stride(0),
            C.stride(1),
            C.stride(2),
            *(D.stride(0), D.stride(1)) if D is not None else 0,
            z_strides[0],
            z_strides[1],
            z_strides[2],
            out.stride(0),
            out.stride(1),
            out.stride(2),
            dt_softplus,
            tie_hdim,
            BLOCK_SIZE_M,
            num_warps=num_warps,
        )
    if not has_heads:
        out = out.squeeze(1)
    return out


def selective_state_update_ref(
    state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
):
    """
    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    """
    has_heads = state.dim() > 3
    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)
    batch, nheads, dim, dstate = state.shape
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
        dt = dt + dt_bias
    dt = F.softplus(dt) if dt_softplus else dt
    dA = torch.exp(
        rearrange(dt, "b h d -> b h d 1") * A
    )  # (batch, nheads, dim, dstate)
    B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups)  # (batch, nheads, dstate)
    C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups)  # (batch, nheads, dstate)
    dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
        B, "b h n -> b h 1 n"
    )  # (batch, nheads, dim, dstate)
    state.copy_(
        state * dA + dB * rearrange(x, "b h d -> b h d 1")
    )  # (batch, nheads, dim, dstate)
    out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
    if D is not None:
        out += (x * D).to(out.dtype)
    out = (out if z is None else out * F.silu(z)).to(x.dtype)
    if not has_heads:
        out = out.squeeze(1)
    return out
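
As a quick sanity check — again a sketch, not part of the original file — the fused kernel entry point above can be compared against selective_state_update_ref on small hypothetical shapes; the tolerances are assumptions, since the Triton path accumulates in float32 and swaps softplus for identity above 20.0:

    # Illustrative sketch only (assumes a CUDA device; shapes and tolerances are hypothetical).
    import torch

    batch, nheads, dim, dstate, ngroups = 2, 4, 64, 16, 1
    state = torch.randn(batch, nheads, dim, dstate, device="cuda")
    x = torch.randn(batch, nheads, dim, device="cuda")
    dt = torch.rand(batch, nheads, dim, device="cuda")
    A = -torch.rand(nheads, dim, dstate, device="cuda")  # negative, so the state decays
    B = torch.randn(batch, ngroups, dstate, device="cuda")
    C = torch.randn(batch, ngroups, dstate, device="cuda")

    state_ref = state.clone()  # both paths mutate `state` in place
    out = selective_state_update(state, x, dt, A, B, C, dt_softplus=True)
    out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, dt_softplus=True)
    torch.testing.assert_close(out, out_ref, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(state, state_ref, rtol=1e-3, atol=1e-3)
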
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_scan.py
DELETED
The diff for this file is too large to render.
See raw diff

build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_state.py
DELETED
@@ -1,2012 +0,0 @@
|
|
| 1 |
-
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
| 2 |
-
|
| 3 |
-
"""We want triton==2.1.0 or 2.2.0 for this
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
import math
|
| 7 |
-
import torch
|
| 8 |
-
import torch.nn.functional as F
|
| 9 |
-
|
| 10 |
-
import triton
|
| 11 |
-
import triton.language as tl
|
| 12 |
-
|
| 13 |
-
from einops import rearrange, repeat
|
| 14 |
-
|
| 15 |
-
from .softplus import softplus
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def init_to_zero(names):
|
| 19 |
-
return lambda nargs: [
|
| 20 |
-
nargs[name].zero_() for name in names if nargs[name] is not None
|
| 21 |
-
]
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
@triton.autotune(
|
| 25 |
-
configs=[
|
| 26 |
-
triton.Config({"BLOCK_SIZE_H": 1}),
|
| 27 |
-
triton.Config({"BLOCK_SIZE_H": 2}),
|
| 28 |
-
triton.Config({"BLOCK_SIZE_H": 4}),
|
| 29 |
-
triton.Config({"BLOCK_SIZE_H": 8}),
|
| 30 |
-
triton.Config({"BLOCK_SIZE_H": 16}),
|
| 31 |
-
triton.Config({"BLOCK_SIZE_H": 32}),
|
| 32 |
-
triton.Config({"BLOCK_SIZE_H": 64}),
|
| 33 |
-
],
|
| 34 |
-
key=["chunk_size", "nheads"],
|
| 35 |
-
)
|
| 36 |
-
@triton.jit
|
| 37 |
-
def _chunk_cumsum_fwd_kernel(
|
| 38 |
-
# Pointers to matrices
|
| 39 |
-
dt_ptr,
|
| 40 |
-
A_ptr,
|
| 41 |
-
dt_bias_ptr,
|
| 42 |
-
dt_out_ptr,
|
| 43 |
-
dA_cumsum_ptr,
|
| 44 |
-
# Matrix dimension
|
| 45 |
-
batch,
|
| 46 |
-
seqlen,
|
| 47 |
-
nheads,
|
| 48 |
-
chunk_size,
|
| 49 |
-
dt_min,
|
| 50 |
-
dt_max,
|
| 51 |
-
# Strides
|
| 52 |
-
stride_dt_batch,
|
| 53 |
-
stride_dt_seqlen,
|
| 54 |
-
stride_dt_head,
|
| 55 |
-
stride_A_head,
|
| 56 |
-
stride_dt_bias_head,
|
| 57 |
-
stride_dt_out_batch,
|
| 58 |
-
stride_dt_out_chunk,
|
| 59 |
-
stride_dt_out_head,
|
| 60 |
-
stride_dt_out_csize,
|
| 61 |
-
stride_dA_cs_batch,
|
| 62 |
-
stride_dA_cs_chunk,
|
| 63 |
-
stride_dA_cs_head,
|
| 64 |
-
stride_dA_cs_csize,
|
| 65 |
-
# Meta-parameters
|
| 66 |
-
DT_SOFTPLUS: tl.constexpr,
|
| 67 |
-
HAS_DT_BIAS: tl.constexpr,
|
| 68 |
-
BLOCK_SIZE_H: tl.constexpr,
|
| 69 |
-
BLOCK_SIZE_CHUNK: tl.constexpr,
|
| 70 |
-
):
|
| 71 |
-
pid_b = tl.program_id(axis=0)
|
| 72 |
-
pid_c = tl.program_id(axis=1)
|
| 73 |
-
pid_h = tl.program_id(axis=2)
|
| 74 |
-
dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
|
| 75 |
-
dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
|
| 76 |
-
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk
|
| 77 |
-
|
| 78 |
-
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
|
| 79 |
-
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
|
| 80 |
-
dt_ptrs = dt_ptr + (
|
| 81 |
-
offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
|
| 82 |
-
)
|
| 83 |
-
A_ptrs = A_ptr + offs_h * stride_A_head
|
| 84 |
-
dt_out_ptrs = dt_out_ptr + (
|
| 85 |
-
offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize
|
| 86 |
-
)
|
| 87 |
-
dA_cs_ptrs = dA_cumsum_ptr + (
|
| 88 |
-
offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize
|
| 89 |
-
)
|
| 90 |
-
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
| 91 |
-
|
| 92 |
-
dt = tl.load(
|
| 93 |
-
dt_ptrs,
|
| 94 |
-
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
| 95 |
-
other=0.0,
|
| 96 |
-
).to(tl.float32)
|
| 97 |
-
if HAS_DT_BIAS:
|
| 98 |
-
dt_bias = tl.load(
|
| 99 |
-
dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
|
| 100 |
-
).to(tl.float32)
|
| 101 |
-
dt += dt_bias[:, None]
|
| 102 |
-
if DT_SOFTPLUS:
|
| 103 |
-
dt = tl.where(dt <= 20.0, softplus(dt), dt)
|
| 104 |
-
# As of Triton 2.2.0, tl.clamp is not available yet
|
| 105 |
-
# dt = tl.clamp(dt, dt_min, dt_max)
|
| 106 |
-
dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
|
| 107 |
-
dt = tl.where(
|
| 108 |
-
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
|
| 109 |
-
)
|
| 110 |
-
tl.store(
|
| 111 |
-
dt_out_ptrs,
|
| 112 |
-
dt,
|
| 113 |
-
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
|
| 114 |
-
)
|
| 115 |
-
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
|
| 116 |
-
dA = dt * A[:, None]
|
| 117 |
-
dA_cs = tl.cumsum(dA, axis=1)
|
| 118 |
-
tl.store(
|
| 119 |
-
dA_cs_ptrs,
|
| 120 |
-
dA_cs,
|
| 121 |
-
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
|
| 122 |
-
)
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_H": 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
        ),
        triton.Config(
            {"BLOCK_SIZE_H": 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
        ),
        triton.Config(
            {"BLOCK_SIZE_H": 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
        ),
        triton.Config(
            {"BLOCK_SIZE_H": 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
        ),
        triton.Config(
            {"BLOCK_SIZE_H": 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
        ),
        triton.Config(
            {"BLOCK_SIZE_H": 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
        ),
        triton.Config(
            {"BLOCK_SIZE_H": 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
        ),
    ],
    key=["chunk_size", "nheads"],
)
@triton.jit
def _chunk_cumsum_bwd_kernel(
    # Pointers to matrices
    ddA_ptr,
    ddt_out_ptr,
    dt_ptr,
    A_ptr,
    dt_bias_ptr,
    ddt_ptr,
    dA_ptr,
    ddt_bias_ptr,
    # Matrix dimensions
    batch,
    seqlen,
    nheads,
    chunk_size,
    dt_min,
    dt_max,
    # Strides
    stride_ddA_batch,
    stride_ddA_chunk,
    stride_ddA_head,
    stride_ddA_csize,
    stride_ddt_out_batch,
    stride_ddt_out_chunk,
    stride_ddt_out_head,
    stride_ddt_out_csize,
    stride_dt_batch,
    stride_dt_seqlen,
    stride_dt_head,
    stride_A_head,
    stride_dt_bias_head,
    stride_ddt_batch,
    stride_ddt_seqlen,
    stride_ddt_head,
    stride_dA_head,
    stride_ddt_bias_head,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
    BLOCK_SIZE_CHUNK: tl.constexpr,
):
    pid_b = tl.program_id(axis=0)
    pid_c = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk
    ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk
    dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
    ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen

    offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
    offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
    ddt_out_ptrs = ddt_out_ptr + (
        offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize
    )
    ddA_ptrs = ddA_ptr + (
        offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize
    )
    dt_ptrs = dt_ptr + (
        offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
    )
    ddt_ptrs = ddt_ptr + (
        offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen
    )
    A_ptrs = A_ptr + offs_h * stride_A_head
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)

    ddA = tl.load(
        ddA_ptrs,
        mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
        other=0.0,
    ).to(tl.float32)
    ddt_out = tl.load(
        ddt_out_ptrs,
        mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
        other=0.0,
    ).to(tl.float32)
    A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
    ddt = ddA * A[:, None] + ddt_out
    dt = tl.load(
        dt_ptrs,
        mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
        other=0.0,
    ).to(tl.float32)
    if HAS_DT_BIAS:
        dt_bias = tl.load(
            dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
        ).to(tl.float32)
        dt += dt_bias[:, None]
    if DT_SOFTPLUS:
        dt_presoftplus = dt
        # Where softplus saturates (dt > 20), pass dt through unchanged
        dt = tl.where(dt <= 20.0, softplus(dt), dt)
    clamp_mask = (dt < dt_min) | (dt > dt_max)
    # As of Triton 2.2.0, tl.clamp is not available yet
    # dt = tl.clamp(dt, dt_min, dt_max)
    dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
    dt = tl.where(
        (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
    )
    ddt = tl.where(
        (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0
    )
    ddt = tl.where(clamp_mask, 0.0, ddt)
    if DT_SOFTPLUS:
        ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)
    tl.store(
        ddt_ptrs,
        ddt,
        mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
    )
    dA = tl.sum(ddA * dt, axis=1)
    tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)
    if HAS_DT_BIAS:
        ddt_bias = tl.sum(ddt, axis=1)
        tl.atomic_add(
            ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads
        )

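# Hedged PyTorch sketch of the dt-gradient logic in _chunk_cumsum_bwd_kernel
# above (the helper name and shapes are illustrative assumptions, not upstream
# code): clamping zeroes the gradient outside [dt_min, dt_max], and where the
# softplus branch was taken its derivative sigmoid(dt_presoftplus) is applied,
# since d/dx softplus(x) = sigmoid(x).
def _dt_grad_ref(dt_presoftplus, ddt, dt_min, dt_max, dt_softplus=True):
    # dt_presoftplus: dt after the bias but before softplus; ddt: incoming gradient
    dt = dt_presoftplus
    if dt_softplus:
        dt = torch.where(dt <= 20.0, torch.nn.functional.softplus(dt), dt)
    # Gradient is zero wherever the clamp was active
    ddt = torch.where((dt < dt_min) | (dt > dt_max), torch.zeros_like(ddt), ddt)
    if dt_softplus:
        ddt = torch.where(dt_presoftplus <= 20.0, ddt * torch.sigmoid(dt_presoftplus), ddt)
    return ddt
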
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=2,
        ),
    ],
    key=["hdim", "dstate", "chunk_size"],
)
@triton.jit
def _chunk_state_fwd_kernel(
    # Pointers to matrices
    x_ptr,
    b_ptr,
    states_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    seq_idx_ptr,
    # Matrix dimensions
    hdim,
    dstate,
    chunk_size,
    batch,
    seqlen,
    nheads_ngroups_ratio,
    # Strides
    stride_x_batch,
    stride_x_seqlen,
    stride_x_head,
    stride_x_hdim,
    stride_b_batch,
    stride_b_seqlen,
    stride_b_head,
    stride_b_dstate,
    stride_states_batch,
    stride_states_chunk,
    stride_states_head,
    stride_states_hdim,
    stride_states_dstate,
    stride_dt_batch,
    stride_dt_chunk,
    stride_dt_head,
    stride_dt_csize,
    stride_dA_cs_batch,
    stride_dA_cs_chunk,
    stride_dA_cs_head,
    stride_dA_cs_csize,
    stride_seq_idx_batch,
    stride_seq_idx_seqlen,
    # Meta-parameters
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    b_ptr += (
        pid_b * stride_b_batch
        + pid_c * chunk_size * stride_b_seqlen
        + (pid_h // nheads_ngroups_ratio) * stride_b_head
    )
    x_ptr += (
        pid_b * stride_x_batch
        + pid_c * chunk_size * stride_x_seqlen
        + pid_h * stride_x_head
    )
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += (
        pid_b * stride_dA_cs_batch
        + pid_c * stride_dA_cs_chunk
        + pid_h * stride_dA_cs_head
    )
    if HAS_SEQ_IDX:
        seq_idx_ptr += (
            pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
        )

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    x_ptrs = x_ptr + (
        offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
    )
    b_ptrs = b_ptr + (
        offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
    )
    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
    dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
        tl.float32
    )
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
    if HAS_SEQ_IDX:
        seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen

    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    if HAS_SEQ_IDX:
        seq_idx_last = tl.load(
            seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
        )

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
        x = tl.load(
            x_ptrs,
            mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate),
            other=0.0,
        ).to(tl.float32)
        dA_cs_k = tl.load(
            dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
        ).to(tl.float32)
        if HAS_SEQ_IDX:
            seq_idx_k = tl.load(
                seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1
            )
        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
            tl.float32
        )
        if not HAS_SEQ_IDX:
            scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
        else:
            scale = tl.where(
                seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0
            )
        b *= scale[:, None]
        b = b.to(x_ptr.dtype.element_ty)
        acc += tl.dot(x, b)
        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
        b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
        if HAS_SEQ_IDX:
            seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
    states = acc.to(states_ptr.dtype.element_ty)

    states_ptr += (
        pid_b * stride_states_batch
        + pid_c * stride_states_chunk
        + pid_h * stride_states_head
    )
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    states_ptrs = states_ptr + (
        offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
    )
    c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
    tl.store(states_ptrs, states, mask=c_mask)

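# A PyTorch/einsum reference for _chunk_state_fwd_kernel above, for the
# ngroups == 1, no-seq_idx case. This is a hedged sketch (chunk_state_ref and
# its shapes are assumptions, not upstream code): each chunk's state is a
# decay-weighted sum of outer products over the chunk's positions l,
#     states[b, c, h] = sum_l exp(dA_cs[-1] - dA_cs[l]) * dt[l] * x[l] B[l]^T
def chunk_state_ref(B, x, dt, dA_cumsum):
    # B: (batch, seqlen, dstate); x: (batch, seqlen, nheads, hdim)
    # dt, dA_cumsum: (batch, nheads, nchunks, chunk_size)
    chunk_size = dt.shape[-1]
    b, s, h, p = x.shape
    c = s // chunk_size
    x = x.reshape(b, c, chunk_size, h, p).float()
    B = B.reshape(b, c, chunk_size, -1).float()
    decay = torch.exp(dA_cumsum[..., -1:] - dA_cumsum)  # (b, h, c, l)
    # Sum over positions l within each chunk -> (batch, nchunks, nheads, hdim, dstate)
    return torch.einsum("bcln,bhcl,bclhp->bchpn", B, decay * dt, x)
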
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
            num_stages=3,
            num_warps=8,
            pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
        ),
    ],
    key=["chunk_size", "hdim", "dstate"],
)
@triton.jit
def _chunk_state_bwd_dx_kernel(
    # Pointers to matrices
    x_ptr,
    b_ptr,
    dstates_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    dx_ptr,
    ddt_ptr,
    ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size,
    hdim,
    dstate,
    batch,
    seqlen,
    nheads_ngroups_ratio,
    # Strides
    stride_x_batch,
    stride_x_seqlen,
    stride_x_head,
    stride_x_hdim,
    stride_b_batch,
    stride_b_seqlen,
    stride_b_head,
    stride_b_dstate,
    stride_dstates_batch,
    stride_dstates_chunk,
    stride_states_head,
    stride_states_hdim,
    stride_states_dstate,
    stride_dt_batch,
    stride_dt_chunk,
    stride_dt_head,
    stride_dt_csize,
    stride_dA_cs_batch,
    stride_dA_cs_chunk,
    stride_dA_cs_head,
    stride_dA_cs_csize,
    stride_dx_batch,
    stride_dx_seqlen,
    stride_dx_head,
    stride_dx_hdim,
    stride_ddt_batch,
    stride_ddt_chunk,
    stride_ddt_head,
    stride_ddt_csize,
    stride_ddA_cs_batch,
    stride_ddA_cs_chunk,
    stride_ddA_cs_head,
    stride_ddA_cs_csize,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    x_ptr += (
        pid_b * stride_x_batch
        + pid_c * chunk_size * stride_x_seqlen
        + pid_h * stride_x_head
    )
    b_ptr += (
        pid_b * stride_b_batch
        + pid_c * chunk_size * stride_b_seqlen
        + (pid_h // nheads_ngroups_ratio) * stride_b_head
    )
    dstates_ptr += (
        pid_b * stride_dstates_batch
        + pid_c * stride_dstates_chunk
        + pid_h * stride_states_head
    )
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    ddt_ptr += (
        pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
    )
    ddA_cumsum_ptr += (
        pid_b * stride_ddA_cs_batch
        + pid_c * stride_ddA_cs_chunk
        + pid_h * stride_ddA_cs_head
    )
    dA_cumsum_ptr += (
        pid_b * stride_dA_cs_batch
        + pid_c * stride_dA_cs_chunk
        + pid_h * stride_dA_cs_head
    )

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
    offs_k = tl.arange(
        0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
    )
    b_ptrs = b_ptr + (
        offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
    )
    dstates_ptrs = dstates_ptr + (
        offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
    )
    if BLOCK_SIZE_DSTATE <= 128:
        b = tl.load(
            b_ptrs,
            mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
            other=0.0,
        )
        dstates = tl.load(
            dstates_ptrs,
            mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
            other=0.0,
        )
        dstates = dstates.to(b_ptr.dtype.element_ty)
        acc = tl.dot(b, dstates)
    else:
        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for k in range(0, dstate, BLOCK_SIZE_K):
            b = tl.load(
                b_ptrs,
                mask=(offs_m[:, None] < chunk_size_limit)
                & (offs_k[None, :] < dstate - k),
                other=0.0,
            )
            dstates = tl.load(
                dstates_ptrs,
                mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
                other=0.0,
            )
            dstates = dstates.to(b_ptr.dtype.element_ty)
            acc += tl.dot(b, dstates)
            b_ptrs += BLOCK_SIZE_K * stride_b_dstate
            dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
        tl.float32
    )
    dt_ptrs = dt_ptr + offs_m * stride_dt_csize
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
    dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
        tl.float32
    )
    dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
    acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]

    x_ptrs = x_ptr + (
        offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
    )
    x = tl.load(
        x_ptrs,
        mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
        other=0.0,
    ).to(tl.float32)
    ddt = tl.sum(acc * x, axis=1)
    ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
    tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
    ddA_cs = -(ddt * dt_m)
    ddA_cs_last = -tl.sum(ddA_cs)
    ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
    tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
    tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last)

    dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty)
    dx_ptr += (
        pid_b * stride_dx_batch
        + pid_c * chunk_size * stride_dx_seqlen
        + pid_h * stride_dx_head
    )
    dx_ptrs = dx_ptr + (
        offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim
    )
    tl.store(
        dx_ptrs,
        dx,
        mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
    )

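# Hedged einsum sketch of the dx path in _chunk_state_bwd_dx_kernel above
# (ngroups == 1; the helper name and shapes are illustrative assumptions):
#     dx[l] = exp(dA_cs[-1] - dA_cs[l]) * dt[l] * (B[l] @ dstates^T)
def chunk_state_dx_ref(B, dstates, dt, dA_cumsum):
    # B: (batch, seqlen, dstate); dstates: (batch, nchunks, nheads, hdim, dstate)
    # dt, dA_cumsum: (batch, nheads, nchunks, chunk_size)
    chunk_size = dt.shape[-1]
    b, s, n = B.shape
    B = B.reshape(b, s // chunk_size, chunk_size, n).float()
    decay = torch.exp(dA_cumsum[..., -1:] - dA_cumsum)  # (b, h, c, l)
    dx = torch.einsum("bcln,bchpn,bhcl->bclhp", B, dstates.float(), decay * dt)
    return dx.reshape(b, s, -1, dx.shape[-1])  # back to (batch, seqlen, nheads, hdim)
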
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128},
            num_stages=3,
            num_warps=4,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32},
            num_stages=3,
            num_warps=4,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128},
            num_stages=3,
            num_warps=4,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64},
            num_stages=3,
            num_warps=4,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64},
            num_stages=3,
            num_warps=4,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32},
            num_stages=3,
            num_warps=4,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64},
            num_stages=3,
            num_warps=4,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32},
            num_stages=3,
            num_warps=4,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        ),
    ],
    key=["chunk_size", "dstate", "hdim"],
)
@triton.jit
def _chunk_state_bwd_db_kernel(
    # Pointers to matrices
    x_ptr,
    dstates_ptr,
    b_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    seq_idx_ptr,
    db_ptr,
    ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size,
    dstate,
    hdim,
    batch,
    seqlen,
    nheads,
    nheads_per_program,
    ngroups,
    # Strides
    stride_x_batch,
    stride_x_seqlen,
    stride_x_head,
    stride_x_hdim,
    stride_dstates_batch,
    stride_dstates_chunk,
    stride_states_head,
    stride_states_hdim,
    stride_states_dstate,
    stride_b_batch,
    stride_b_seqlen,
    stride_b_head,
    stride_b_dstate,
    stride_dt_batch,
    stride_dt_chunk,
    stride_dt_head,
    stride_dt_csize,
    stride_dA_cs_batch,
    stride_dA_cs_chunk,
    stride_dA_cs_head,
    stride_dA_cs_csize,
    stride_seq_idx_batch,
    stride_seq_idx_seqlen,
    stride_db_batch,
    stride_db_seqlen,
    stride_db_split,
    stride_db_group,
    stride_db_dstate,
    stride_ddA_cs_batch,
    stride_ddA_cs_chunk,
    stride_ddA_cs_head,
    stride_ddA_cs_csize,
    # Meta-parameters
    HAS_DDA_CS: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_sg = tl.program_id(axis=2)
    pid_s = pid_sg // ngroups
    pid_g = pid_sg - pid_s * ngroups
    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    x_ptr += (
        pid_b * stride_x_batch
        + pid_c * chunk_size * stride_x_seqlen
        + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head
    )
    db_ptr += (
        pid_b * stride_db_batch
        + pid_c * chunk_size * stride_db_seqlen
        + pid_g * stride_db_group
        + pid_s * stride_db_split
    )
    dstates_ptr += (
        pid_b * stride_dstates_batch
        + pid_c * stride_dstates_chunk
        + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
        * stride_states_head
    )
    dt_ptr += (
        pid_b * stride_dt_batch
        + pid_c * stride_dt_chunk
        + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head
    )
    dA_cumsum_ptr += (
        pid_b * stride_dA_cs_batch
        + pid_c * stride_dA_cs_chunk
        + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
    )
    if HAS_DDA_CS:
        b_ptr += (
            pid_b * stride_b_batch
            + pid_c * chunk_size * stride_b_seqlen
            + pid_g * stride_b_head
        )
        ddA_cumsum_ptr += (
            pid_b * stride_ddA_cs_batch
            + pid_c * stride_ddA_cs_chunk
            + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
            * stride_ddA_cs_head
        )
    if HAS_SEQ_IDX:
        seq_idx_ptr += (
            pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
        )

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    x_ptrs = x_ptr + (
        offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim
    )
    dstates_ptrs = dstates_ptr + (
        offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim
    )
    dt_ptrs = dt_ptr + offs_m * stride_dt_csize
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
    if HAS_DDA_CS:
        b_ptrs = b_ptr + (
            offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate
        )
        ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize

    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    if HAS_DDA_CS:
        b = tl.load(
            b_ptrs,
            mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
            other=0.0,
        ).to(tl.float32)
    if HAS_SEQ_IDX:
        seq_idx_m = tl.load(
            seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
            mask=offs_m < chunk_size_limit,
            other=-1,
        )
        seq_idx_last = tl.load(
            seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
        )
    nheads_iter = min(
        nheads_per_program, nheads // ngroups - pid_s * nheads_per_program
    )
    for h in range(nheads_iter):
        x = tl.load(
            x_ptrs,
            mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim),
            other=0.0,
        )
        dstates = tl.load(
            dstates_ptrs,
            mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate),
            other=0.0,
        )
        dstates = dstates.to(x_ptrs.dtype.element_ty)
        db = tl.dot(x, dstates)
        dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
            tl.float32
        )
        dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
            tl.float32
        )
        dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
        if not HAS_SEQ_IDX:
            scale = tl.exp(dA_cs_last - dA_cs_m)
        else:
            scale = tl.where(
                seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0
            )
        db *= (scale * dt_m)[:, None]
        if HAS_DDA_CS:
            # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
            ddA_cs = tl.sum(db * b, axis=1)
            tl.atomic_add(
                ddA_cumsum_ptrs + stride_ddA_cs_csize,
                ddA_cs,
                mask=offs_m < chunk_size - 1,
            )
        acc += db
        x_ptrs += stride_x_head
        dstates_ptrs += stride_states_head
        dt_ptrs += stride_dt_head
        dA_cumsum_ptr += stride_dA_cs_head
        dA_cumsum_ptrs += stride_dA_cs_head
        if HAS_DDA_CS:
            ddA_cumsum_ptrs += stride_ddA_cs_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # if HAS_SEQ_IDX:
    #     seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
    #     seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
    #     acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0)
    db_ptrs = db_ptr + (
        offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate
    )
    tl.store(
        db_ptrs,
        acc,
        mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
    )

|
| 982 |
-
configs=[
|
| 983 |
-
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 984 |
-
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 985 |
-
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 986 |
-
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 987 |
-
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 988 |
-
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 989 |
-
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 990 |
-
# triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 991 |
-
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
| 992 |
-
triton.Config(
|
| 993 |
-
{"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
|
| 994 |
-
num_stages=3,
|
| 995 |
-
num_warps=4,
|
| 996 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 997 |
-
),
|
| 998 |
-
triton.Config(
|
| 999 |
-
{"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
| 1000 |
-
num_stages=3,
|
| 1001 |
-
num_warps=4,
|
| 1002 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 1003 |
-
),
|
| 1004 |
-
triton.Config(
|
| 1005 |
-
{"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 1006 |
-
num_stages=3,
|
| 1007 |
-
num_warps=4,
|
| 1008 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 1009 |
-
),
|
| 1010 |
-
triton.Config(
|
| 1011 |
-
{"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
| 1012 |
-
num_stages=3,
|
| 1013 |
-
num_warps=4,
|
| 1014 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 1015 |
-
),
|
| 1016 |
-
triton.Config(
|
| 1017 |
-
{"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
|
| 1018 |
-
num_stages=4,
|
| 1019 |
-
num_warps=8,
|
| 1020 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 1021 |
-
),
|
| 1022 |
-
triton.Config(
|
| 1023 |
-
{"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
| 1024 |
-
num_stages=4,
|
| 1025 |
-
num_warps=8,
|
| 1026 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 1027 |
-
),
|
| 1028 |
-
triton.Config(
|
| 1029 |
-
{"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
| 1030 |
-
num_stages=4,
|
| 1031 |
-
num_warps=8,
|
| 1032 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 1033 |
-
),
|
| 1034 |
-
triton.Config(
|
| 1035 |
-
{"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
| 1036 |
-
num_stages=4,
|
| 1037 |
-
num_warps=8,
|
| 1038 |
-
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
| 1039 |
-
),
|
| 1040 |
-
],
|
| 1041 |
-
key=["chunk_size", "hdim", "dstate"],
|
| 1042 |
-
)
|
| 1043 |
-
@triton.jit
|
| 1044 |
-
def _chunk_state_bwd_ddAcs_stable_kernel(
|
| 1045 |
-
# Pointers to matrices
|
| 1046 |
-
x_ptr,
|
| 1047 |
-
b_ptr,
|
| 1048 |
-
dstates_ptr,
|
| 1049 |
-
dt_ptr,
|
| 1050 |
-
dA_cumsum_ptr,
|
| 1051 |
-
seq_idx_ptr,
|
| 1052 |
-
ddA_cumsum_ptr,
|
| 1053 |
-
# Matrix dimensions
|
| 1054 |
-
chunk_size,
|
| 1055 |
-
hdim,
|
| 1056 |
-
dstate,
|
| 1057 |
-
batch,
|
| 1058 |
-
seqlen,
|
| 1059 |
-
nheads_ngroups_ratio,
|
| 1060 |
-
# Strides
|
| 1061 |
-
stride_x_batch,
|
| 1062 |
-
stride_x_seqlen,
|
| 1063 |
-
stride_x_head,
|
| 1064 |
-
stride_x_hdim,
|
| 1065 |
-
stride_b_batch,
|
| 1066 |
-
stride_b_seqlen,
|
| 1067 |
-
stride_b_head,
|
| 1068 |
-
stride_b_dstate,
|
| 1069 |
-
stride_dstates_batch,
|
| 1070 |
-
stride_dstates_chunk,
|
| 1071 |
-
stride_states_head,
|
| 1072 |
-
stride_states_hdim,
|
| 1073 |
-
stride_states_dstate,
|
| 1074 |
-
stride_dt_batch,
|
| 1075 |
-
stride_dt_chunk,
|
| 1076 |
-
stride_dt_head,
|
| 1077 |
-
stride_dt_csize,
|
| 1078 |
-
stride_dA_cs_batch,
|
| 1079 |
-
stride_dA_cs_chunk,
|
| 1080 |
-
stride_dA_cs_head,
|
| 1081 |
-
stride_dA_cs_csize,
|
| 1082 |
-
stride_seq_idx_batch,
|
| 1083 |
-
stride_seq_idx_seqlen,
|
| 1084 |
-
stride_ddA_cs_batch,
|
| 1085 |
-
stride_ddA_cs_chunk,
|
| 1086 |
-
stride_ddA_cs_head,
|
| 1087 |
-
stride_ddA_cs_csize,
|
| 1088 |
-
# Meta-parameters
|
| 1089 |
-
HAS_SEQ_IDX: tl.constexpr,
|
| 1090 |
-
BLOCK_SIZE_M: tl.constexpr,
|
| 1091 |
-
BLOCK_SIZE_N: tl.constexpr,
|
| 1092 |
-
BLOCK_SIZE_K: tl.constexpr,
|
| 1093 |
-
BLOCK_SIZE_DSTATE: tl.constexpr,
|
| 1094 |
-
):
|
| 1095 |
-
pid_bc = tl.program_id(axis=1)
|
| 1096 |
-
pid_c = pid_bc // batch
|
| 1097 |
-
pid_b = pid_bc - pid_c * batch
|
| 1098 |
-
pid_h = tl.program_id(axis=2)
|
| 1099 |
-
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
| 1100 |
-
pid_m = tl.program_id(axis=0) // num_pid_n
|
| 1101 |
-
pid_n = tl.program_id(axis=0) % num_pid_n
|
| 1102 |
-
x_ptr += (
|
| 1103 |
-
pid_b * stride_x_batch
|
| 1104 |
-
+ pid_c * chunk_size * stride_x_seqlen
|
| 1105 |
-
+ pid_h * stride_x_head
|
| 1106 |
-
)
|
| 1107 |
-
b_ptr += (
|
| 1108 |
-
pid_b * stride_b_batch
|
| 1109 |
-
+ pid_c * chunk_size * stride_b_seqlen
|
| 1110 |
-
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
| 1111 |
-
)
|
| 1112 |
-
dstates_ptr += (
|
| 1113 |
-
pid_b * stride_dstates_batch
|
| 1114 |
-
+ pid_c * stride_dstates_chunk
|
| 1115 |
-
+ pid_h * stride_states_head
|
| 1116 |
-
)
|
| 1117 |
-
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
| 1118 |
-
ddA_cumsum_ptr += (
|
| 1119 |
-
pid_b * stride_ddA_cs_batch
|
| 1120 |
-
+ pid_c * stride_ddA_cs_chunk
|
| 1121 |
-
+ pid_h * stride_ddA_cs_head
|
| 1122 |
-
)
|
| 1123 |
-
dA_cumsum_ptr += (
|
| 1124 |
-
pid_b * stride_dA_cs_batch
|
| 1125 |
-
+ pid_c * stride_dA_cs_chunk
|
| 1126 |
-
+ pid_h * stride_dA_cs_head
|
| 1127 |
-
)
|
| 1128 |
-
if HAS_SEQ_IDX:
|
| 1129 |
-
seq_idx_ptr += (
|
| 1130 |
-
pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
| 1131 |
-
)
|
| 1132 |
-
|
| 1133 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 1134 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 1135 |
-
|
| 1136 |
-
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
| 1137 |
-
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
| 1138 |
-
offs_k = tl.arange(
|
| 1139 |
-
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
|
| 1140 |
-
)
|
| 1141 |
-
b_ptrs = b_ptr + (
|
| 1142 |
-
offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
|
| 1143 |
-
)
|
| 1144 |
-
dstates_ptrs = dstates_ptr + (
|
| 1145 |
-
offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
|
| 1146 |
-
)
|
| 1147 |
-
if BLOCK_SIZE_DSTATE <= 128:
|
| 1148 |
-
b = tl.load(
|
| 1149 |
-
b_ptrs,
|
| 1150 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
|
| 1151 |
-
other=0.0,
|
| 1152 |
-
)
|
| 1153 |
-
dstates = tl.load(
|
| 1154 |
-
dstates_ptrs,
|
| 1155 |
-
mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
|
| 1156 |
-
other=0.0,
|
| 1157 |
-
)
|
| 1158 |
-
dstates = dstates.to(b_ptr.dtype.element_ty)
|
| 1159 |
-
acc = tl.dot(b, dstates)
|
| 1160 |
-
else:
|
| 1161 |
-
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
| 1162 |
-
for k in range(0, dstate, BLOCK_SIZE_K):
|
| 1163 |
-
b = tl.load(
|
| 1164 |
-
b_ptrs,
|
| 1165 |
-
mask=(offs_m[:, None] < chunk_size_limit)
|
| 1166 |
-
& (offs_k[None, :] < dstate - k),
|
| 1167 |
-
other=0.0,
|
| 1168 |
-
)
|
| 1169 |
-
dstates = tl.load(
|
| 1170 |
-
dstates_ptrs,
|
| 1171 |
-
mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
|
| 1172 |
-
other=0.0,
|
| 1173 |
-
)
|
| 1174 |
-
dstates = dstates.to(b_ptr.dtype.element_ty)
|
| 1175 |
-
acc += tl.dot(b, dstates)
|
| 1176 |
-
b_ptrs += BLOCK_SIZE_K * stride_b_dstate
|
| 1177 |
-
dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
|
| 1178 |
-
|
| 1179 |
-
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 1180 |
-
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 1181 |
-
|
| 1182 |
-
dA_cs_m = tl.load(
|
| 1183 |
-
dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
|
| 1184 |
-
).to(tl.float32)
|
| 1185 |
-
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
|
| 1186 |
-
tl.float32
|
| 1187 |
-
)
|
| 1188 |
-
if not HAS_SEQ_IDX:
|
| 1189 |
-
scale = tl.exp(dA_cs_last - dA_cs_m)
|
| 1190 |
-
else:
|
| 1191 |
-
seq_idx_m = tl.load(
|
| 1192 |
-
seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
|
| 1193 |
-
mask=offs_m < chunk_size_limit,
|
| 1194 |
-
other=-1,
|
| 1195 |
-
)
|
| 1196 |
-
seq_idx_last = tl.load(
|
| 1197 |
-
seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
|
| 1198 |
-
)
|
| 1199 |
-
scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
|
| 1200 |
-
acc *= scale[:, None]
|
| 1201 |
-
|
| 1202 |
-
x_ptrs = x_ptr + (
|
| 1203 |
-
offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
|
| 1204 |
-
)
|
| 1205 |
-
x = tl.load(
|
| 1206 |
-
x_ptrs,
|
| 1207 |
-
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
| 1208 |
-
other=0.0,
|
| 1209 |
-
).to(tl.float32)
|
| 1210 |
-
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
| 1211 |
-
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
| 1212 |
-
ddt = tl.sum(acc * x, axis=1)
|
| 1213 |
-
# ddA_cs = -(ddt * dt_m)
|
| 1214 |
-
# Triton 2.2.0 errors if we have the cumsum here, so we just write it out
|
| 1215 |
-
# then call torch.cumsum outside this kernel.
|
| 1216 |
-
# ddA_cs = tl.cumsum(ddt * dt_m)
|
| 1217 |
-
ddA_cs = ddt * dt_m
|
| 1218 |
-
ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
|
| 1219 |
-
# tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
|
| 1220 |
-
tl.atomic_add(
|
| 1221 |
-
ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1
|
| 1222 |
-
)
|
| 1223 |
-
|
| 1224 |
-
|
| 1225 |
-
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=2,
        ),
    ],
    key=["hdim", "dstate", "chunk_size"],
)
@triton.jit
def _chunk_state_varlen_kernel(
    # Pointers to matrices
    x_ptr,
    b_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    chunk_states_ptr,
    cu_seqlens_ptr,
    states_ptr,
    # Matrix dimensions
    hdim,
    dstate,
    chunk_size,
    seqlen,
    nheads_ngroups_ratio,
    # Strides
    stride_x_seqlen,
    stride_x_head,
    stride_x_hdim,
    stride_b_seqlen,
    stride_b_head,
    stride_b_dstate,
    stride_dt_chunk,
    stride_dt_head,
    stride_dt_csize,
    stride_dA_cs_chunk,
    stride_dA_cs_head,
    stride_dA_cs_csize,
    stride_chunk_states_chunk,
    stride_chunk_states_head,
    stride_chunk_states_hdim,
    stride_chunk_states_dstate,
    stride_states_batch,
    stride_states_head,
    stride_states_hdim,
    stride_states_dstate,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
    pid_c = (end_idx - 1) // chunk_size
    b_ptr += (
        pid_c * chunk_size * stride_b_seqlen
        + (pid_h // nheads_ngroups_ratio) * stride_b_head
    )
    x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    chunk_states_ptr += (
        pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
    )

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    x_ptrs = x_ptr + (
        offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
    )
    b_ptrs = b_ptr + (
        offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
    )
    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
    dA_cs_last = tl.load(
        dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
    ).to(tl.float32)
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize

    chunk_size_limit = end_idx - pid_c * chunk_size
    start_idx = tl.load(cu_seqlens_ptr + pid_b)
    start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
        x = tl.load(
            x_ptrs,
            mask=(offs_m[:, None] < hdim)
            & (offs_k[None, :] < chunk_size_limit - k)
            & (offs_k[None, :] >= start_idx_cur - k),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < chunk_size_limit - k)
            & (offs_n[None, :] < dstate)
            & (offs_k[:, None] >= start_idx_cur - k),
            other=0.0,
        ).to(tl.float32)
        dA_cs_k = tl.load(
            dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
        ).to(tl.float32)
        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
            tl.float32
        )
        scale = tl.where(
            (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
            tl.exp((dA_cs_last - dA_cs_k)) * dt_k,
            0.0,
        )
        b *= scale[:, None]
        b = b.to(x_ptr.dtype.element_ty)
        acc += tl.dot(x, b)
        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
        b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

    # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
    if start_idx < pid_c * chunk_size:
        chunk_states_ptrs = chunk_states_ptr + (
            offs_m[:, None] * stride_chunk_states_hdim
            + offs_n[None, :] * stride_chunk_states_dstate
        )
        chunk_states = tl.load(
            chunk_states_ptrs,
            mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),
            other=0.0,
        ).to(tl.float32)
        # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0)
        scale = tl.exp(dA_cs_last)
        acc += chunk_states * scale

    states = acc.to(states_ptr.dtype.element_ty)

    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    states_ptrs = states_ptr + (
        offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
    )
    c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
    tl.store(states_ptrs, states, mask=c_mask)

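# Hedged illustration of the varlen indexing in the kernel above (this helper
# is an assumption for exposition, not upstream code): for packed sequences
# described by cu_seqlens, each sequence's final state lives in the chunk that
# contains its last token, i.e. pid_c = (end_idx - 1) // chunk_size.
def last_chunk_index(cu_seqlens, chunk_size):
    # cu_seqlens: cumulative boundaries, e.g. [0, 5, 12] for lengths 5 and 7
    return [(end - 1) // chunk_size for end in cu_seqlens[1:]]

# Example: last_chunk_index([0, 5, 12], chunk_size=4) -> [1, 2]
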
def _chunk_cumsum_fwd(
    dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))
):
    batch, seqlen, nheads = dt.shape
    assert A.shape == (nheads,)
    if dt_bias is not None:
        assert dt_bias.shape == (nheads,)
    nchunks = math.ceil(seqlen / chunk_size)
    dt_out = torch.empty(
        batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
    )
    dA_cumsum = torch.empty(
        batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
    )
    grid_chunk_cs = lambda META: (
        batch,
        nchunks,
        triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
    )
    with torch.cuda.device(dt.device.index):
        _chunk_cumsum_fwd_kernel[grid_chunk_cs](
            dt,
            A,
            dt_bias,
            dt_out,
            dA_cumsum,
            batch,
            seqlen,
            nheads,
            chunk_size,
            dt_limit[0],
            dt_limit[1],
            dt.stride(0),
            dt.stride(1),
            dt.stride(2),
            A.stride(0),
            dt_bias.stride(0) if dt_bias is not None else 0,
            dt_out.stride(0),
            dt_out.stride(2),
            dt_out.stride(1),
            dt_out.stride(3),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(3),
            dt_softplus,
            HAS_DT_BIAS=dt_bias is not None,
            BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
        )
    return dA_cumsum, dt_out

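# Hedged usage sketch for _chunk_cumsum_fwd (the shapes and this helper are
# illustrative assumptions; it is not called anywhere and requires a CUDA
# device). dt is (batch, seqlen, nheads) and A is (nheads,); both outputs come
# back as (batch, nheads, nchunks, chunk_size) in float32.
def _example_chunk_cumsum():
    dt = torch.rand(2, 512, 8, device="cuda")
    A = -torch.rand(8, device="cuda")  # negative A keeps exp(dA_cumsum) bounded
    dA_cumsum, dt_out = _chunk_cumsum_fwd(dt, A, chunk_size=128, dt_softplus=True)
    assert dA_cumsum.shape == (2, 8, 4, 128)
    assert dt_out.shape == (2, 8, 4, 128)
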
def _chunk_cumsum_bwd(
    ddA,
    ddt_out,
    dt,
    A,
    dt_bias=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    ddt=None,
):
    batch, seqlen, nheads = dt.shape
    _, _, nchunks, chunk_size = ddA.shape
    assert ddA.shape == (batch, nheads, nchunks, chunk_size)
    assert ddt_out.shape == (batch, nheads, nchunks, chunk_size)
    assert A.shape == (nheads,)
    if dt_bias is not None:
        assert dt_bias.shape == (nheads,)
        ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)
    else:
        ddt_bias = None
    if ddt is not None:
        assert ddt.shape == dt.shape
    else:
        ddt = torch.empty_like(dt)
    dA = torch.empty_like(A, dtype=torch.float32)
    grid_chunk_cs = lambda META: (
        batch,
        nchunks,
        triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
    )
    with torch.cuda.device(dt.device.index):
        _chunk_cumsum_bwd_kernel[grid_chunk_cs](
            ddA,
            ddt_out,
            dt,
            A,
            dt_bias,
            ddt,
            dA,
            ddt_bias,
            batch,
            seqlen,
            nheads,
            chunk_size,
            dt_limit[0],
            dt_limit[1],
            ddA.stride(0),
            ddA.stride(2),
            ddA.stride(1),
            ddA.stride(3),
            ddt_out.stride(0),
            ddt_out.stride(2),
            ddt_out.stride(1),
            ddt_out.stride(3),
            dt.stride(0),
            dt.stride(1),
            dt.stride(2),
            A.stride(0),
            dt_bias.stride(0) if dt_bias is not None else 0,
            ddt.stride(0),
            ddt.stride(1),
            ddt.stride(2),
            dA.stride(0),
            ddt_bias.stride(0) if ddt_bias is not None else 0,
            dt_softplus,
            HAS_DT_BIAS=dt_bias is not None,
            BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
        )
    return ddt, dA, ddt_bias

def _chunk_state_fwd(
    B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True
):
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if states is not None:
        assert states.shape == (batch, nchunks, nheads, headdim, dstate)
    else:
        states_dtype = torch.float32 if states_in_fp32 else B.dtype
        states = torch.empty(
            (batch, nchunks, nheads, headdim, dstate),
            device=x.device,
            dtype=states_dtype,
        )
    grid = lambda META: (
        triton.cdiv(headdim, META["BLOCK_SIZE_M"])
        * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
        batch * nchunks,
        nheads,
    )
    with torch.cuda.device(x.device.index):
        _chunk_state_fwd_kernel[grid](
            x,
            B,
            states,
            dt,
            dA_cumsum,
            seq_idx,
            headdim,
            dstate,
            chunk_size,
            batch,
            seqlen,
            nheads // ngroups,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            x.stride(3),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            B.stride(-1),
            states.stride(0),
            states.stride(1),
            states.stride(2),
            states.stride(3),
            states.stride(4),
            dt.stride(0),
            dt.stride(2),
            dt.stride(1),
            dt.stride(3),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(3),
            *(
                (seq_idx.stride(0), seq_idx.stride(1))
                if seq_idx is not None
                else (0, 0)
            ),
            HAS_SEQ_IDX=seq_idx is not None,
        )
    return states

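# Hedged sanity check tying _chunk_state_fwd to the chunk_state_ref einsum
# sketch defined earlier in this file (shapes, ngroups == 1, and the tolerance
# are assumptions for illustration; this helper is not upstream code and is
# never called).
def _check_chunk_state_fwd():
    batch, seqlen, nheads, headdim, dstate, chunk_size = 2, 256, 4, 64, 16, 64
    x = torch.randn(batch, seqlen, nheads, headdim, device="cuda")
    B = torch.randn(batch, seqlen, 1, dstate, device="cuda")
    dt = torch.rand(batch, seqlen, nheads, device="cuda")
    A = -torch.rand(nheads, device="cuda")
    dA_cumsum, dt_out = _chunk_cumsum_fwd(dt, A, chunk_size)
    states = _chunk_state_fwd(B, x, dt_out, dA_cumsum)
    ref = chunk_state_ref(B[:, :, 0], x, dt_out, dA_cumsum)
    assert torch.allclose(states, ref, atol=1e-3, rtol=1e-3)
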
def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None):
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
    if dx is not None:
        assert dx.shape == x.shape
    else:
        dx = torch.empty_like(x)
    ddt = torch.empty(
        batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
    )
    ddA_cumsum = torch.empty(
        batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32
    )
    grid_dx = lambda META: (
        triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
        * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
        batch * nchunks,
        nheads,
    )
    with torch.cuda.device(x.device.index):
        _chunk_state_bwd_dx_kernel[grid_dx](
            x,
            B,
            dstates,
            dt,
            dA_cumsum,
            dx,
            ddt,
            ddA_cumsum,
            chunk_size,
            headdim,
            dstate,
            batch,
            seqlen,
            nheads // ngroups,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            x.stride(3),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            B.stride(-1),
            dstates.stride(0),
            dstates.stride(1),
            dstates.stride(2),
            dstates.stride(3),
            dstates.stride(4),
            dt.stride(0),
            dt.stride(2),
            dt.stride(1),
            dt.stride(3),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(3),
            dx.stride(0),
            dx.stride(1),
            dx.stride(2),
            dx.stride(3),
            ddt.stride(0),
            ddt.stride(2),
            ddt.stride(1),
            ddt.stride(3),
            ddA_cumsum.stride(0),
            ddA_cumsum.stride(2),
            ddA_cumsum.stride(1),
            ddA_cumsum.stride(3),
            BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
        )
    return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype)

def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1):
|
| 1691 |
-
batch, seqlen, nheads, headdim = x.shape
|
| 1692 |
-
_, _, nchunks, chunk_size = dt.shape
|
| 1693 |
-
dstate = dstates.shape[-1]
|
| 1694 |
-
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
| 1695 |
-
assert dA_cumsum.shape == dt.shape
|
| 1696 |
-
assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
|
| 1697 |
-
if seq_idx is not None:
|
| 1698 |
-
assert seq_idx.shape == (batch, seqlen)
|
| 1699 |
-
if B is not None:
|
| 1700 |
-
assert B.shape == (batch, seqlen, ngroups, dstate)
|
| 1701 |
-
B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3))
|
| 1702 |
-
# Use torch.empty since the Triton kernel will call init_to_zero
|
| 1703 |
-
ddA_cumsum = torch.empty(
|
| 1704 |
-
batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
|
| 1705 |
-
)
|
| 1706 |
-
ddA_cumsum_strides = (
|
| 1707 |
-
ddA_cumsum.stride(0),
|
| 1708 |
-
ddA_cumsum.stride(2),
|
| 1709 |
-
ddA_cumsum.stride(1),
|
| 1710 |
-
ddA_cumsum.stride(3),
|
| 1711 |
-
)
|
| 1712 |
-
else:
|
| 1713 |
-
B_strides = (0, 0, 0, 0)
|
| 1714 |
-
ddA_cumsum = None
|
| 1715 |
-
ddA_cumsum_strides = (0, 0, 0, 0)
|
| 1716 |
-
nheads_ngroups_ratio = nheads // ngroups
|
| 1717 |
-
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
|
| 1718 |
-
nheads_per_program = max(
|
| 1719 |
-
min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1
|
| 1720 |
-
)
|
| 1721 |
-
nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
|
| 1722 |
-
dB = torch.empty(
|
| 1723 |
-
batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32
|
| 1724 |
-
)
|
| 1725 |
-
grid_db = lambda META: (
|
| 1726 |
-
triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
|
| 1727 |
-
* triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
|
| 1728 |
-
batch * nchunks,
|
| 1729 |
-
nsplits * ngroups,
|
| 1730 |
-
)
|
| 1731 |
-
with torch.cuda.device(x.device.index):
|
| 1732 |
-
_chunk_state_bwd_db_kernel[grid_db](
|
| 1733 |
-
x,
|
| 1734 |
-
dstates,
|
| 1735 |
-
B,
|
| 1736 |
-
dt,
|
| 1737 |
-
dA_cumsum,
|
| 1738 |
-
seq_idx,
|
| 1739 |
-
dB,
|
| 1740 |
-
ddA_cumsum,
|
| 1741 |
-
chunk_size,
|
| 1742 |
-
dstate,
|
| 1743 |
-
headdim,
|
| 1744 |
-
batch,
|
| 1745 |
-
seqlen,
|
| 1746 |
-
nheads,
|
| 1747 |
-
nheads_per_program,
|
| 1748 |
-
ngroups,
|
| 1749 |
-
x.stride(0),
|
| 1750 |
-
x.stride(1),
|
| 1751 |
-
x.stride(2),
|
| 1752 |
-
x.stride(3),
|
| 1753 |
-
dstates.stride(0),
|
| 1754 |
-
dstates.stride(1),
|
| 1755 |
-
dstates.stride(2),
|
| 1756 |
-
dstates.stride(3),
|
| 1757 |
-
dstates.stride(4),
|
| 1758 |
-
*B_strides,
|
| 1759 |
-
dt.stride(0),
|
| 1760 |
-
dt.stride(2),
|
| 1761 |
-
dt.stride(1),
|
| 1762 |
-
dt.stride(3),
|
| 1763 |
-
dA_cumsum.stride(0),
|
| 1764 |
-
dA_cumsum.stride(2),
|
| 1765 |
-
dA_cumsum.stride(1),
|
| 1766 |
-
dA_cumsum.stride(3),
|
| 1767 |
-
*(
|
| 1768 |
-
(seq_idx.stride(0), seq_idx.stride(1))
|
| 1769 |
-
if seq_idx is not None
|
| 1770 |
-
else (0, 0)
|
| 1771 |
-
),
|
| 1772 |
-
dB.stride(0),
|
| 1773 |
-
dB.stride(1),
|
| 1774 |
-
dB.stride(2),
|
| 1775 |
-
dB.stride(3),
|
| 1776 |
-
dB.stride(4),
|
| 1777 |
-
*ddA_cumsum_strides,
|
| 1778 |
-
HAS_DDA_CS=ddA_cumsum is not None,
|
| 1779 |
-
HAS_SEQ_IDX=seq_idx is not None,
|
| 1780 |
-
BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
|
| 1781 |
-
)
|
| 1782 |
-
dB = dB.sum(2)
|
| 1783 |
-
if ddA_cumsum is not None:
|
| 1784 |
-
# The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute
|
| 1785 |
-
# to the state of the chunk.
|
| 1786 |
-
# torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
|
| 1787 |
-
# But it's easier to just do the cumsum for all elements, the result will be the same.
|
| 1788 |
-
torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum)
|
| 1789 |
-
return dB if B is None else (dB, ddA_cumsum)
|
| 1790 |
-
|
| 1791 |
-
|
| 1792 |
-
def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):
|
| 1793 |
-
batch, seqlen, nheads, headdim = x.shape
|
| 1794 |
-
_, _, nchunks, chunk_size = dt.shape
|
| 1795 |
-
_, _, ngroups, dstate = B.shape
|
| 1796 |
-
assert nheads % ngroups == 0
|
| 1797 |
-
assert B.shape == (batch, seqlen, ngroups, dstate)
|
| 1798 |
-
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
| 1799 |
-
assert dA_cumsum.shape == dt.shape
|
| 1800 |
-
assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
|
| 1801 |
-
if seq_idx is not None:
|
| 1802 |
-
assert seq_idx.shape == (batch, seqlen)
|
| 1803 |
-
# Use torch.empty since the Triton kernel will call init_to_zero
|
| 1804 |
-
ddA_cumsum = torch.empty(
|
| 1805 |
-
batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
|
| 1806 |
-
)
|
| 1807 |
-
grid_ddtcs = lambda META: (
|
| 1808 |
-
triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
|
| 1809 |
-
* triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
|
| 1810 |
-
batch * nchunks,
|
| 1811 |
-
nheads,
|
| 1812 |
-
)
|
| 1813 |
-
with torch.cuda.device(x.device.index):
|
| 1814 |
-
_chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs](
|
| 1815 |
-
x,
|
| 1816 |
-
B,
|
| 1817 |
-
dstates,
|
| 1818 |
-
dt,
|
| 1819 |
-
dA_cumsum,
|
| 1820 |
-
seq_idx,
|
| 1821 |
-
ddA_cumsum,
|
| 1822 |
-
chunk_size,
|
| 1823 |
-
headdim,
|
| 1824 |
-
dstate,
|
| 1825 |
-
batch,
|
| 1826 |
-
seqlen,
|
| 1827 |
-
nheads // ngroups,
|
| 1828 |
-
x.stride(0),
|
| 1829 |
-
x.stride(1),
|
| 1830 |
-
x.stride(2),
|
| 1831 |
-
x.stride(3),
|
| 1832 |
-
B.stride(0),
|
| 1833 |
-
B.stride(1),
|
| 1834 |
-
B.stride(2),
|
| 1835 |
-
B.stride(-1),
|
| 1836 |
-
dstates.stride(0),
|
| 1837 |
-
dstates.stride(1),
|
| 1838 |
-
dstates.stride(2),
|
| 1839 |
-
dstates.stride(3),
|
| 1840 |
-
dstates.stride(4),
|
| 1841 |
-
dt.stride(0),
|
| 1842 |
-
dt.stride(2),
|
| 1843 |
-
dt.stride(1),
|
| 1844 |
-
dt.stride(3),
|
| 1845 |
-
dA_cumsum.stride(0),
|
| 1846 |
-
dA_cumsum.stride(2),
|
| 1847 |
-
dA_cumsum.stride(1),
|
| 1848 |
-
dA_cumsum.stride(3),
|
| 1849 |
-
*(
|
| 1850 |
-
(seq_idx.stride(0), seq_idx.stride(1))
|
| 1851 |
-
if seq_idx is not None
|
| 1852 |
-
else (0, 0)
|
| 1853 |
-
),
|
| 1854 |
-
ddA_cumsum.stride(0),
|
| 1855 |
-
ddA_cumsum.stride(2),
|
| 1856 |
-
ddA_cumsum.stride(1),
|
| 1857 |
-
ddA_cumsum.stride(3),
|
| 1858 |
-
HAS_SEQ_IDX=seq_idx is not None,
|
| 1859 |
-
BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16),
|
| 1860 |
-
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
|
| 1861 |
-
)
|
| 1862 |
-
torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
|
| 1863 |
-
return ddA_cumsum
|
| 1864 |
-
|
| 1865 |
-
|
| 1866 |
-
def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):
|
| 1867 |
-
total_seqlen, nheads, headdim = x.shape
|
| 1868 |
-
_, nchunks, chunk_size = dt.shape
|
| 1869 |
-
_, ngroups, dstate = B.shape
|
| 1870 |
-
batch = cu_seqlens.shape[0] - 1
|
| 1871 |
-
cu_seqlens = cu_seqlens.contiguous()
|
| 1872 |
-
assert nheads % ngroups == 0
|
| 1873 |
-
assert B.shape == (total_seqlen, ngroups, dstate)
|
| 1874 |
-
assert dt.shape == (nheads, nchunks, chunk_size)
|
| 1875 |
-
assert dA_cumsum.shape == dt.shape
|
| 1876 |
-
assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
|
| 1877 |
-
states = torch.empty(
|
| 1878 |
-
batch,
|
| 1879 |
-
nheads,
|
| 1880 |
-
headdim,
|
| 1881 |
-
dstate,
|
| 1882 |
-
dtype=chunk_states.dtype,
|
| 1883 |
-
device=chunk_states.device,
|
| 1884 |
-
)
|
| 1885 |
-
grid = lambda META: (
|
| 1886 |
-
triton.cdiv(headdim, META["BLOCK_SIZE_M"])
|
| 1887 |
-
* triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
|
| 1888 |
-
batch,
|
| 1889 |
-
nheads,
|
| 1890 |
-
)
|
| 1891 |
-
with torch.cuda.device(x.device.index):
|
| 1892 |
-
_chunk_state_varlen_kernel[grid](
|
| 1893 |
-
x,
|
| 1894 |
-
B,
|
| 1895 |
-
dt,
|
| 1896 |
-
dA_cumsum,
|
| 1897 |
-
chunk_states,
|
| 1898 |
-
cu_seqlens,
|
| 1899 |
-
states,
|
| 1900 |
-
headdim,
|
| 1901 |
-
dstate,
|
| 1902 |
-
chunk_size,
|
| 1903 |
-
total_seqlen,
|
| 1904 |
-
nheads // ngroups,
|
| 1905 |
-
x.stride(0),
|
| 1906 |
-
x.stride(1),
|
| 1907 |
-
x.stride(2),
|
| 1908 |
-
B.stride(0),
|
| 1909 |
-
B.stride(1),
|
| 1910 |
-
B.stride(2),
|
| 1911 |
-
dt.stride(1),
|
| 1912 |
-
dt.stride(0),
|
| 1913 |
-
dt.stride(2),
|
| 1914 |
-
dA_cumsum.stride(1),
|
| 1915 |
-
dA_cumsum.stride(0),
|
| 1916 |
-
dA_cumsum.stride(2),
|
| 1917 |
-
chunk_states.stride(0),
|
| 1918 |
-
chunk_states.stride(1),
|
| 1919 |
-
chunk_states.stride(2),
|
| 1920 |
-
chunk_states.stride(3),
|
| 1921 |
-
states.stride(0),
|
| 1922 |
-
states.stride(1),
|
| 1923 |
-
states.stride(2),
|
| 1924 |
-
states.stride(3),
|
| 1925 |
-
)
|
| 1926 |
-
return states
|
| 1927 |
-
|
| 1928 |
-
|
| 1929 |
-
class ChunkStateFn(torch.autograd.Function):
|
| 1930 |
-
|
| 1931 |
-
@staticmethod
|
| 1932 |
-
def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True):
|
| 1933 |
-
batch, seqlen, nheads, headdim = x.shape
|
| 1934 |
-
_, _, nchunks, chunk_size = dt.shape
|
| 1935 |
-
assert seqlen <= nchunks * chunk_size
|
| 1936 |
-
_, _, ngroups, dstate = B.shape
|
| 1937 |
-
assert B.shape == (batch, seqlen, ngroups, dstate)
|
| 1938 |
-
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
| 1939 |
-
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
|
| 1940 |
-
if B.stride(-1) != 1:
|
| 1941 |
-
B = B.contiguous()
|
| 1942 |
-
if (
|
| 1943 |
-
x.stride(-1) != 1 and x.stride(1) != 1
|
| 1944 |
-
): # Either M or K dimension should be contiguous
|
| 1945 |
-
x = x.contiguous()
|
| 1946 |
-
states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32)
|
| 1947 |
-
ctx.save_for_backward(B, x, dt, dA_cumsum)
|
| 1948 |
-
return states
|
| 1949 |
-
|
| 1950 |
-
@staticmethod
|
| 1951 |
-
def backward(ctx, dstates):
|
| 1952 |
-
B, x, dt, dA_cumsum = ctx.saved_tensors
|
| 1953 |
-
batch, seqlen, nheads, headdim = x.shape
|
| 1954 |
-
_, _, nchunks, chunk_size = dt.shape
|
| 1955 |
-
_, _, ngroups, dstate = B.shape
|
| 1956 |
-
assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
|
| 1957 |
-
if dstates.stride(-1) != 1:
|
| 1958 |
-
dstates = dstates.contiguous()
|
| 1959 |
-
dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates)
|
| 1960 |
-
dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups)
|
| 1961 |
-
dB = dB.to(B.dtype)
|
| 1962 |
-
return dB, dx, ddt, ddA_cumsum, None
|
| 1963 |
-
|
| 1964 |
-
|
| 1965 |
-
def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True):
|
| 1966 |
-
"""
|
| 1967 |
-
Argument:
|
| 1968 |
-
B: (batch, seqlen, ngroups, headdim)
|
| 1969 |
-
x: (batch, seqlen, nheads, headdim)
|
| 1970 |
-
dt: (batch, nheads, nchunks, chunk_size)
|
| 1971 |
-
dA_cumsum: (batch, nheads, nchunks, chunk_size)
|
| 1972 |
-
Return:
|
| 1973 |
-
states: (batch, nchunks, nheads, headdim, dstate)
|
| 1974 |
-
"""
|
| 1975 |
-
return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32)
|
| 1976 |
-
|
| 1977 |
-
|
| 1978 |
-
def chunk_state_ref(B, x, dt, dA_cumsum):
|
| 1979 |
-
"""
|
| 1980 |
-
Argument:
|
| 1981 |
-
B: (batch, seqlen, ngroups, headdim)
|
| 1982 |
-
x: (batch, seqlen, nheads, headdim)
|
| 1983 |
-
dt: (batch, nheads, nchunks, chunk_size)
|
| 1984 |
-
dA_cumsum: (batch, nheads, nchunks, chunk_size)
|
| 1985 |
-
Return:
|
| 1986 |
-
states: (batch, nchunks, nheads, headdim, dstate)
|
| 1987 |
-
"""
|
| 1988 |
-
# Check constraints.
|
| 1989 |
-
batch, seqlen, nheads, headdim = x.shape
|
| 1990 |
-
dstate = B.shape[-1]
|
| 1991 |
-
_, _, nchunks, chunk_size = dt.shape
|
| 1992 |
-
assert seqlen <= nchunks * chunk_size
|
| 1993 |
-
assert x.shape == (batch, seqlen, nheads, headdim)
|
| 1994 |
-
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
| 1995 |
-
ngroups = B.shape[2]
|
| 1996 |
-
assert nheads % ngroups == 0
|
| 1997 |
-
assert B.shape == (batch, seqlen, ngroups, dstate)
|
| 1998 |
-
B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
|
| 1999 |
-
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
|
| 2000 |
-
if seqlen < nchunks * chunk_size:
|
| 2001 |
-
x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
|
| 2002 |
-
B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
|
| 2003 |
-
x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
|
| 2004 |
-
B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
|
| 2005 |
-
decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))
|
| 2006 |
-
return torch.einsum(
|
| 2007 |
-
"bclhn,bhcl,bhcl,bclhp->bchpn",
|
| 2008 |
-
B.to(x.dtype),
|
| 2009 |
-
decay_states.to(x.dtype),
|
| 2010 |
-
dt.to(x.dtype),
|
| 2011 |
-
x,
|
| 2012 |
-
)
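For orientation, a hypothetical check of the fused chunk_state against the einsum reference chunk_state_ref above, using shapes from the docstrings. This is a sketch, not part of the file: it assumes a CUDA device, that this module is importable, and illustrative sizes; the dA_cumsum stand-in is any monotonically decreasing cumsum so the decay factors stay bounded.

    import torch

    batch, seqlen, nheads, headdim = 2, 512, 8, 64
    ngroups, dstate, chunk_size = 1, 16, 256
    nchunks = seqlen // chunk_size
    dev = "cuda"
    B = torch.randn(batch, seqlen, ngroups, dstate, device=dev)
    x = torch.randn(batch, seqlen, nheads, headdim, device=dev)
    dt = torch.rand(batch, nheads, nchunks, chunk_size, device=dev)
    dA_cumsum = torch.cumsum(-dt, dim=-1)  # decreasing, so exp(last - cumsum) <= 1
    states = chunk_state(B, x, dt, dA_cumsum)
    states_ref = chunk_state_ref(B, x, dt, dA_cumsum)
    torch.testing.assert_close(
        states.to(states_ref.dtype), states_ref, rtol=1e-3, atol=1e-3
    )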
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_combined.py DELETED
@@ -1,1884 +0,0 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.

"""We want triton==2.1.0 or 2.2.0 for this
"""

from typing import Optional

import math
from packaging import version

import torch
import torch.nn.functional as F
from torch import Tensor
from ...utils.torch import custom_bwd, custom_fwd

import triton
import triton.language as tl

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn
    import causal_conv1d_cuda
except ImportError:
    causal_conv1d_fn, causal_conv1d_cuda = None, None

from .ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
from .ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
from .ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
from .ssd_chunk_state import chunk_state, chunk_state_ref
from .ssd_chunk_state import chunk_state_varlen
from .ssd_state_passing import _state_passing_fwd, _state_passing_bwd
from .ssd_state_passing import state_passing, state_passing_ref
from .ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates
from .ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
from .ssd_chunk_scan import chunk_scan, chunk_scan_ref
from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
from .layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
from .k_activations import _swiglu_fwd, _swiglu_bwd

TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")


def init_to_zero(names):
    return lambda nargs: [
        nargs[name].zero_() for name in names if nargs[name] is not None
    ]
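A note on init_to_zero: it builds a pre_hook for triton.autotune, so the named kernel arguments are zeroed in place before each candidate config is benchmarked. This matters because the kernel below accumulates into ddt_ptr with tl.atomic_add, and a buffer left over from a previous autotuning trial would otherwise leak partial sums into the next one. A minimal CPU-only sketch of the hook's behavior (the plain dict stands in for Triton's kernel-argument mapping; illustrative only, not part of the file):

    import torch

    def init_to_zero(names):
        return lambda nargs: [
            nargs[name].zero_() for name in names if nargs[name] is not None
        ]

    ddt = torch.randn(4)
    hook = init_to_zero(["ddt_ptr"])
    hook({"ddt_ptr": ddt})   # zeroes the accumulator in place before a trial run
    assert torch.all(ddt == 0)
    hook({"ddt_ptr": None})  # None-valued arguments are skipped safely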
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
            num_stages=3,
            num_warps=8,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
            pre_hook=init_to_zero(["ddt_ptr"]),
        ),
    ],
    key=["chunk_size", "hdim", "dstate"],
)
@triton.jit
def _chunk_scan_chunk_state_bwd_dx_kernel(
    # Pointers to matrices
    x_ptr,
    cb_ptr,
    dout_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    seq_idx_ptr,
    D_ptr,
    b_ptr,
    dstates_ptr,
    dx_ptr,
    ddt_ptr,
    dD_ptr,
    # Matrix dimensions
    chunk_size,
    hdim,
    dstate,
    batch,
    seqlen,
    nheads_ngroups_ratio,
    # Strides
    stride_x_batch,
    stride_x_seqlen,
    stride_x_head,
    stride_x_hdim,
    stride_cb_batch,
    stride_cb_chunk,
    stride_cb_head,
    stride_cb_csize_m,
    stride_cb_csize_k,
    stride_dout_batch,
    stride_dout_seqlen,
    stride_dout_head,
    stride_dout_hdim,
    stride_dt_batch,
    stride_dt_chunk,
    stride_dt_head,
    stride_dt_csize,
    stride_dA_cs_batch,
    stride_dA_cs_chunk,
    stride_dA_cs_head,
    stride_dA_cs_csize,
    stride_seq_idx_batch,
    stride_seq_idx_seqlen,
    stride_D_head,
    stride_b_batch,
    stride_b_seqlen,
    stride_b_head,
    stride_b_dstate,
    stride_dstates_batch,
    stride_dstates_chunk,
    stride_dstates_head,
    stride_dstates_hdim,
    stride_dstates_dstate,
    stride_dx_batch,
    stride_dx_seqlen,
    stride_dx_head,
    stride_dx_hdim,
    stride_ddt_batch,
    stride_ddt_chunk,
    stride_ddt_head,
    stride_ddt_csize,
    stride_dD_batch,
    stride_dD_chunk,
    stride_dD_head,
    stride_dD_csize,
    stride_dD_hdim,
    # Meta-parameters
    HAS_D: tl.constexpr,
    D_HAS_HDIM: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
    IS_TRITON_22: tl.constexpr,
):
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    x_ptr += (
        pid_b * stride_x_batch
        + pid_c * chunk_size * stride_x_seqlen
        + pid_h * stride_x_head
    )
    cb_ptr += (
        pid_b * stride_cb_batch
        + pid_c * stride_cb_chunk
        + (pid_h // nheads_ngroups_ratio) * stride_cb_head
    )
    dout_ptr += (
        pid_b * stride_dout_batch
        + pid_c * chunk_size * stride_dout_seqlen
        + pid_h * stride_dout_head
    )
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    ddt_ptr += (
        pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
    )
    dA_cumsum_ptr += (
        pid_b * stride_dA_cs_batch
        + pid_c * stride_dA_cs_chunk
        + pid_h * stride_dA_cs_head
    )
    b_ptr += (
        pid_b * stride_b_batch
        + pid_c * chunk_size * stride_b_seqlen
        + (pid_h // nheads_ngroups_ratio) * stride_b_head
    )
    dstates_ptr += (
        pid_b * stride_dstates_batch
        + pid_c * stride_dstates_chunk
        + pid_h * stride_dstates_head
    )
    if HAS_SEQ_IDX:
        seq_idx_ptr += (
            pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
        )

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    dA_cs_m = tl.load(
        dA_cumsum_ptr + offs_m * stride_dA_cs_csize,
        mask=offs_m < chunk_size_limit,
        other=0.0,
    ).to(tl.float32)

    dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
        tl.float32
    )
    if not HAS_SEQ_IDX:
        scale = tl.exp(dA_cs_last - dA_cs_m)
    else:
        seq_idx_m = tl.load(
            seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
            mask=offs_m < chunk_size_limit,
            other=-1,
        )
        seq_idx_last = tl.load(
            seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
        )
        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
    # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
    # However, we're getting error with the Triton compiler 2.1.0 for that code path:
    # Unexpected mma -> mma layout conversion
    # Triton 2.2.0 fixes this
    offs_dstate = tl.arange(
        0,
        (
            BLOCK_SIZE_DSTATE
            if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128
            else BLOCK_SIZE_K
        ),
    )
    b_ptrs = b_ptr + (
        offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate
    )
    dstates_ptrs = dstates_ptr + (
        offs_n[None, :] * stride_dstates_hdim
        + offs_dstate[:, None] * stride_dstates_dstate
    )
    if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:
        b = tl.load(
            b_ptrs,
            mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate),
            other=0.0,
        )
        dstates = tl.load(
            dstates_ptrs,
            mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
            other=0.0,
        )
        dstates = dstates.to(b_ptr.dtype.element_ty)
        acc = tl.dot(b, dstates) * scale[:, None]
    else:
        for k in range(0, dstate, BLOCK_SIZE_K):
            b = tl.load(
                b_ptrs,
                mask=(offs_m[:, None] < chunk_size_limit)
                & (offs_dstate[None, :] < dstate - k),
                other=0.0,
            )
            dstates = tl.load(
                dstates_ptrs,
                mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim),
                other=0.0,
            )
            dstates = dstates.to(b_ptr.dtype.element_ty)
            acc += tl.dot(b, dstates)
            b_ptrs += BLOCK_SIZE_K * stride_b_dstate
            dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate
        acc *= scale[:, None]

    # x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
    # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
    # dt_ptrs = dt_ptr + offs_m * stride_dt_csize
    # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
    # ddt = tl.sum(acc * x, axis=1) * dt_m
    # ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
    # tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)

    offs_k = tl.arange(0, BLOCK_SIZE_K)
    cb_ptrs = cb_ptr + (
        offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k
    )
    dout_ptrs = dout_ptr + (
        offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim
    )
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
    K_MAX = chunk_size_limit
    K_MIN = pid_m * BLOCK_SIZE_M
    cb_ptrs += K_MIN * stride_cb_csize_k
    dout_ptrs += K_MIN * stride_dout_seqlen
    dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize
    for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):
        k = tl.multiple_of(k, BLOCK_SIZE_K)
        # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower
        cb = tl.load(
            cb_ptrs,
            mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k),
            other=0.0,
        )
        dout = tl.load(
            dout_ptrs,
            mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim),
            other=0.0,
        )
        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(
            tl.float32
        )
        cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
        # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
        # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
        # Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
        # This will cause NaN in acc, and hence NaN in dx and ddt.
        mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)
        cb = tl.where(mask, cb, 0.0)
        cb = cb.to(dout_ptr.dtype.element_ty)
        acc += tl.dot(cb, dout)
        cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
        dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dt_ptrs = dt_ptr + offs_m * stride_dt_csize
    dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
    dx = acc * dt_m[:, None]
    dx_ptr += (
        pid_b * stride_dx_batch
        + pid_c * chunk_size * stride_dx_seqlen
        + pid_h * stride_dx_head
    )
    dx_ptrs = dx_ptr + (
        offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim
    )
    if HAS_D:
        dout_res_ptrs = dout_ptr + (
            offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim
        )
        dout_res = tl.load(
            dout_res_ptrs,
            mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
            other=0.0,
        ).to(tl.float32)
        if D_HAS_HDIM:
            D = tl.load(
                D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0
            ).to(tl.float32)
        else:
            D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
        dx += dout_res * D
    tl.store(
        dx_ptrs,
        dx,
        mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
    )

    x_ptrs = x_ptr + (
        offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
    )
    x = tl.load(
        x_ptrs,
        mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
        other=0.0,
    ).to(tl.float32)
    if HAS_D:
        dD_ptr += (
            pid_b * stride_dD_batch
            + pid_c * stride_dD_chunk
            + pid_h * stride_dD_head
            + pid_m * stride_dD_csize
        )
        if D_HAS_HDIM:
            dD_ptrs = dD_ptr + offs_n * stride_dD_hdim
            dD = tl.sum(dout_res * x, axis=0)
            tl.store(dD_ptrs, dD, mask=offs_n < hdim)
        else:
            dD = tl.sum(dout_res * x)
            tl.store(dD_ptr, dD)
    ddt = tl.sum(acc * x, axis=1)
    ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
    tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
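The kernel above packs two tile indices into launch-grid axis 0 and the (batch, chunk) pair into axis 1, then unpacks them from the program ids at the top of the body. A small pure-Python sketch of that decoding, with illustrative sizes (not part of the original file):

    import math

    def decode_program_ids(pid0, pid_bc, hdim, BLOCK_SIZE_N, batch):
        # Mirrors the index arithmetic at the top of the kernel.
        num_pid_n = math.ceil(hdim / BLOCK_SIZE_N)
        pid_m = pid0 // num_pid_n          # row tile within the chunk
        pid_n = pid0 % num_pid_n           # column tile within headdim
        pid_c = pid_bc // batch            # chunk index
        pid_b = pid_bc - pid_c * batch     # batch index
        return pid_m, pid_n, pid_b, pid_c

    # e.g. hdim=64, BLOCK_SIZE_N=32 -> num_pid_n=2; pid0=5 decodes to tile (2, 1)
    assert decode_program_ids(5, 7, 64, 32, batch=4)[:2] == (2, 1)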
def _chunk_scan_chunk_state_bwd_dx(
    x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None
):
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert dout.shape == x.shape
    assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
        assert D.stride(-1) == 1
        BLOCK_SIZE_min = 32
        dD = torch.empty(
            triton.cdiv(chunk_size, BLOCK_SIZE_min),
            batch,
            nchunks,
            nheads,
            headdim if D.dim() == 2 else 1,
            device=D.device,
            dtype=torch.float32,
        )
    else:
        dD = None
    dD_strides = (
        (dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
        if D is not None
        else (0, 0, 0, 0, 0)
    )
    if dx is None:
        dx = torch.empty_like(x)
    else:
        assert dx.shape == x.shape
    ddt = torch.empty(
        batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32
    )
    grid_dx = lambda META: (
        triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
        * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
        batch * nchunks,
        nheads,
    )
    with torch.cuda.device(x.device.index):
        _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](
            x,
            CB,
            dout,
            dt,
            dA_cumsum,
            seq_idx,
            D,
            B,
            dstates,
            dx,
            ddt,
            dD,
            chunk_size,
            headdim,
            dstate,
            batch,
            seqlen,
            nheads // ngroups,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            x.stride(3),
            CB.stride(0),
            CB.stride(1),
            CB.stride(2),
            CB.stride(-1),
            CB.stride(-2),
            dout.stride(0),
            dout.stride(1),
            dout.stride(2),
            dout.stride(3),
            dt.stride(0),
            dt.stride(2),
            dt.stride(1),
            dt.stride(3),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(3),
            *(
                (seq_idx.stride(0), seq_idx.stride(1))
                if seq_idx is not None
                else (0, 0)
            ),
            D.stride(0) if D is not None else 0,
            B.stride(0),
            B.stride(1),
            B.stride(2),
            B.stride(3),
            dstates.stride(0),
            dstates.stride(1),
            dstates.stride(2),
            dstates.stride(3),
            dstates.stride(4),
            dx.stride(0),
            dx.stride(1),
            dx.stride(2),
            dx.stride(3),
            ddt.stride(0),
            ddt.stride(2),
            ddt.stride(1),
            ddt.stride(3),
            dD_strides[1],
            dD_strides[2],
            dD_strides[3],
            dD_strides[0],
            dD_strides[4],
            D is not None,
            D.dim() == 2 if D is not None else True,
            HAS_SEQ_IDX=seq_idx is not None,
            BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
            IS_TRITON_22=TRITON_22,
        )
    if D is not None:
        BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs[
            "BLOCK_SIZE_M"
        ]
        n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
        dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
        if D.dim() == 1:
            dD = rearrange(dD, "h 1 -> h")
    return dx, ddt.to(dtype=dt.dtype), dD
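The dD handling above over-allocates one partial-sum slice per BLOCK_SIZE_min-sized block (the smallest BLOCK_SIZE_M among the autotune configs), then keeps only the slices the winning config actually wrote. A small numeric sketch of that reduction, with made-up sizes (illustrative only, not part of the file):

    import math
    import torch

    chunk_size, BLOCK_SIZE_min, BLOCK_SIZE_actual = 256, 32, 64
    # One slice per smallest-possible block; larger configs leave the tail untouched.
    dD_partials = torch.zeros(math.ceil(chunk_size / BLOCK_SIZE_min), 3)
    n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
    dD_partials[:n_valid_blocks] = 1.0  # stand-in for what the kernel stored
    dD = dD_partials[:n_valid_blocks].sum(dim=0)
    print(dD)  # tensor([4., 4., 4.]) -- 256 / 64 = 4 blocks contributed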
def _mamba_chunk_scan_combined_fwd(
    x,
    dt,
    A,
    B,
    C,
    chunk_size,
    D=None,
    z=None,
    dt_bias=None,
    initial_states=None,
    seq_idx=None,
    cu_seqlens=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
):
    batch, seqlen, nheads, headdim = x.shape
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert x.shape == (batch, seqlen, nheads, headdim)
    assert dt.shape == (batch, seqlen, nheads)
    assert A.shape == (nheads,)
    assert C.shape == B.shape
    if z is not None:
        assert z.shape == x.shape
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if (
        x.stride(-1) != 1 and x.stride(1) != 1
    ):  # Either M or K dimension should be contiguous
        x = x.contiguous()
    if (
        z is not None and z.stride(-1) != 1 and z.stride(1) != 1
    ):  # Either M or K dimension should be contiguous
        z = z.contiguous()
    if D is not None and D.stride(-1) != 1:
        D = D.contiguous()
    if initial_states is not None:
        assert initial_states.shape == (batch, nheads, headdim, dstate)
    # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
    # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
    # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
    # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
    dA_cumsum, dt = _chunk_cumsum_fwd(
        dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit
    )
    states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
    # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
    # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
    # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)
    states, final_states = _state_passing_fwd(
        rearrange(states, "... p n -> ... (p n)"),
        dA_cumsum[:, :, :, -1],
        initial_states=(
            rearrange(initial_states, "... p n -> ... (p n)")
            if initial_states is not None
            else None
        ),
        seq_idx=seq_idx,
        chunk_size=chunk_size,
        out_dtype=C.dtype,
    )
    states, final_states = [
        rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]
    ]
    # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
    # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
    CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
    out, out_x = _chunk_scan_fwd(
        CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx
    )
    if cu_seqlens is None:
        return out, out_x, dt, dA_cumsum, states, final_states
    else:
        assert (
            batch == 1
        ), "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
        varlen_states = chunk_state_varlen(
            B.squeeze(0),
            x.squeeze(0),
            dt.squeeze(0),
            dA_cumsum.squeeze(0),
            cu_seqlens,
            states.squeeze(0),
        )
        return out, out_x, dt, dA_cumsum, states, final_states, varlen_states
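For orientation, a hypothetical shape-level call of the forward entry point defined above, with shapes taken from its asserts. This is a sketch, not part of the file: it assumes a CUDA device, that the surrounding module imports resolve, and illustrative sizes and dtypes.

    import torch

    batch, seqlen, nheads, headdim = 2, 512, 8, 64
    ngroups, dstate, chunk_size = 1, 16, 256
    dev = "cuda"
    x = torch.randn(batch, seqlen, nheads, headdim, device=dev)
    dt = torch.rand(batch, seqlen, nheads, device=dev)
    A = -torch.rand(nheads, device=dev)  # negative so states decay
    B = torch.randn(batch, seqlen, ngroups, dstate, device=dev)
    C = torch.randn(batch, seqlen, ngroups, dstate, device=dev)
    out, out_x, dt_c, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(
        x, dt, A, B, C, chunk_size, dt_softplus=True
    )
    # out: (batch, seqlen, nheads, headdim)
    # states: (batch, nchunks, nheads, headdim, dstate), nchunks = seqlen // chunk_size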
def _mamba_chunk_scan_combined_bwd(
    dout,
    x,
    dt,
    A,
    B,
    C,
    out,
    chunk_size,
    D=None,
    z=None,
    dt_bias=None,
    initial_states=None,
    dfinal_states=None,
    seq_idx=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    dx=None,
    ddt=None,
    dB=None,
    dC=None,
    dz=None,
    recompute_output=False,
):
    if dout.stride(-1) != 1:
        dout = dout.contiguous()
    batch, seqlen, nheads, headdim = x.shape
    nchunks = math.ceil(seqlen / chunk_size)
    _, _, ngroups, dstate = B.shape
    assert dout.shape == (batch, seqlen, nheads, headdim)
    assert dt.shape == (batch, seqlen, nheads)
    assert A.shape == (nheads,)
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert C.shape == B.shape
    assert out.shape == x.shape
    if initial_states is not None:
        assert initial_states.shape == (batch, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if dx is not None:
        assert dx.shape == x.shape
    if dB is not None:
        assert dB.shape == B.shape
        dB_given = dB
    else:
        dB_given = torch.empty_like(B)
    if dC is not None:
        assert dC.shape == C.shape
        dC_given = dC
    else:
        dC_given = torch.empty_like(C)
    if dz is not None:
        assert z is not None
        assert dz.shape == z.shape
    if ddt is not None:
        assert ddt.shape == dt.shape
        ddt_given = ddt
    else:
        ddt_given = torch.empty_like(dt)
    # TD: For some reason Triton (2.1.0 and 2.2.0) errors with
    # "[CUDA]: invalid device context" (e.g. during varlen test), and cloning makes it work. Idk why.
    dt_in = dt.clone()
    dA_cumsum, dt = _chunk_cumsum_fwd(
        dt_in,
        A,
        chunk_size,
        dt_bias=dt_bias,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
    )
    CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
    states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
    states, _ = _state_passing_fwd(
        rearrange(states, "... p n -> ... (p n)"),
        dA_cumsum[:, :, :, -1],
        initial_states=(
            rearrange(initial_states, "... p n -> ... (p n)")
            if initial_states is not None
            else None
        ),
        seq_idx=seq_idx,
        chunk_size=chunk_size,
    )
    states = rearrange(states, "... (p n) -> ... p n", n=dstate)
    if z is not None:
        dz, dout, dD, *rest = _chunk_scan_bwd_dz(
            x,
            z,
            out,
            dout,
            chunk_size=chunk_size,
            has_ddAcs=False,
            D=D,
            dz=dz,
            recompute_output=recompute_output,
        )
        outz = rest[0] if recompute_output else out
    else:
        dz = None
        outz = out
    dstates = _chunk_scan_bwd_dstates(
        C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype
    )
    # dstates has length nchunks, containing the gradient to initial states at index 0 and
    # gradient to the states of chunk (nchunks - 2) at index (nchunks - 1)
    # Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states
    # will be used in matmul in the next kernels.
    dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd(
        rearrange(states, "... p n -> ... (p n)"),
        dA_cumsum[:, :, :, -1],
        rearrange(dstates, "... p n -> ... (p n)"),
        dfinal_states=(
            rearrange(dfinal_states, "... p n -> ... (p n)")
            if dfinal_states is not None
            else None
        ),
        seq_idx=seq_idx,
        has_initial_states=initial_states is not None,
        dstates_dtype=x.dtype,
        states_dtype=x.dtype,
        chunk_size=chunk_size,
    )
    # dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and
    # gradient to the final states at index (nchunks - 1)
    # states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1)
    # The final states is not stored.
    states = rearrange(states, "... (p n) -> ... p n", n=dstate)
    dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate)
    dinitial_states = (
        rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate)
        if dinitial_states is not None
        else None
    )
    dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(
        x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx
    )
    # dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups)
    dB, ddA_next = _chunk_state_bwd_db(
        x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups
    )
    # dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
    dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(
        states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups
    )
    # Computing ddA with the dcb kernel is much slower, so we're not using it for now
    dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
    # dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups)
    dCB = dCB.to(CB.dtype)
    _bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given)
    _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given)
    # If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate
    # than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16
    if z is None:
        dD = dD_from_x
    # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
    # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
    # However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might
    # be a lot of underflow.

    # This is already done as part of bwd_dC kernel
    # ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx)
    ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum
    ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1])
    # This is already done as part of bwd_dB kernel
    # ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx)
    # We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j]
    ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB)
    ddA += ddA_next + ddA_prev

    ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(
        ddA,
        ddt,
        dt_in,
        A,
        dt_bias=dt_bias,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
        ddt=ddt_given,
    )

    # These 2 lines are just to test ddt and dA being computed by old code
|
| 838 |
-
# _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z)
|
| 839 |
-
# ddt_given.copy_(ddt)
|
| 840 |
-
|
| 841 |
-
return_vals = (
|
| 842 |
-
dx,
|
| 843 |
-
ddt_given,
|
| 844 |
-
dA,
|
| 845 |
-
dB_given,
|
| 846 |
-
dC_given,
|
| 847 |
-
dD,
|
| 848 |
-
dz,
|
| 849 |
-
ddt_bias,
|
| 850 |
-
dinitial_states,
|
| 851 |
-
)
|
| 852 |
-
return return_vals if not recompute_output else (*return_vals, outz)
|
| 853 |
-
|
| 854 |
-
|
| 855 |
-
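The flip/cumsum/flip idiom above computes a reverse (suffix) cumulative sum without first materializing a global total, which sidesteps the underflow issue called out in the comments. A minimal sketch of the identity it relies on, assuming a small float64 tensor for checking:

import torch

t = torch.randn(4, 8, dtype=torch.float64)
# Reverse cumsum: rev[..., i] == t[..., i:].sum(-1)
rev = t.flip([-1]).cumsum(dim=-1).flip([-1])
naive = torch.stack([t[..., i:].sum(-1) for i in range(t.shape[-1])], dim=-1)
assert torch.allclose(rev, naive)
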
def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None):
    """
    Argument:
        dout: (batch, seqlen, nheads, headdim)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size)
        A: (nheads) or (dim, dstate)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
    Return:
        ddt, dA: gradients with respect to dt and A
    """
    import selective_scan

    batch, seqlen, nheads, headdim = x.shape
    chunk_size = dt.shape[-1]
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    x = rearrange(x, "b l h p -> b (h p) l")
    squeeze_dt = dt.dim() == 4
    if dt.dim() == 4:
        dt = repeat(dt, "b h c l -> b h p c l", p=headdim)
    dt = rearrange(dt, "b h p c l -> b (h p) (c l)", p=headdim)
    squeeze_A = A.dim() == 1
    if A.dim() == 1:
        A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
    else:
        A = A.to(dtype=torch.float32)
    B = rearrange(B, "b l g n -> b g n l")
    C = rearrange(C, "b l g n -> b g n l")
    if D is not None:
        if D.dim() == 2:
            D = rearrange(D, "h p -> (h p)")
        else:
            D = repeat(D, "h -> (h p)", p=headdim)
    if z is not None:
        z = rearrange(z, "b l h p -> b (h p) l")

    if x.stride(-1) != 1:
        x = x.contiguous()
    if dt.stride(-1) != 1:
        dt = dt.contiguous()
    if D is not None:
        D = D.contiguous()
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if z is not None and z.stride(-1) != 1:
        z = z.contiguous()
    _, intermediate, *rest = selective_scan.fwd(
        x, dt.to(dtype=x.dtype), A, B, C, D, z, None, False
    )
    if z is not None:
        out = rest[0]
    else:
        out = None

    dout = rearrange(dout, "b l h p -> b (h p) l")

    if dout.stride(-1) != 1:
        dout = dout.contiguous()
    # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
    # backward of selective_scan with the backward of chunk).
    # Here we just pass in None and dz will be allocated in the C++ code.
    _, ddt, dA, *rest = selective_scan.bwd(
        x,
        dt.to(dtype=x.dtype),
        A,
        B,
        C,
        D,
        z,
        None,
        dout,
        intermediate,
        out,
        None,
        False,
        False,  # option to recompute out_z, not used here
    )
    ddt = rearrange(ddt, "b (h p) (c l) -> b h p c l", p=headdim, l=chunk_size)
    if squeeze_dt:
        ddt = ddt.float().sum(dim=2)
    if squeeze_A:
        dA = rearrange(dA, "(h p) n -> h p n", p=headdim).sum(dim=(1, 2))
    return ddt, dA
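selective_scan_bwd bridges the multi-head (batch, seqlen, nheads, headdim) layout of the SSD kernels and the channel-major (batch, dim, seqlen) layout the original selective-scan kernel expects. The round trip is a pure einops reshape; a small sketch, assuming toy dimensions:

import torch
from einops import rearrange

batch, seqlen, nheads, headdim = 2, 16, 4, 8
x = torch.randn(batch, seqlen, nheads, headdim)
x_flat = rearrange(x, "b l h p -> b (h p) l")  # (2, 32, 16): heads flattened into channels
x_back = rearrange(x_flat, "b (h p) l -> b l h p", p=headdim)
assert torch.equal(x, x_back)
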
class MambaChunkScanCombinedFn(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        x,
        dt,
        A,
        B,
        C,
        chunk_size,
        D=None,
        z=None,
        dt_bias=None,
        initial_states=None,
        seq_idx=None,
        cu_seqlens=None,
        dt_softplus=False,
        dt_limit=(0.0, float("inf")),
        return_final_states=False,
        return_varlen_states=False,
    ):
        ctx.dt_dtype = dt.dtype
        if not return_varlen_states:
            cu_seqlens = None
        else:
            assert (
                cu_seqlens is not None
            ), "cu_seqlens must be provided if return_varlen_states is True"
        out, out_x, dt_out, dA_cumsum, states, final_states, *rest = (
            _mamba_chunk_scan_combined_fwd(
                x,
                dt,
                A,
                B,
                C,
                chunk_size,
                D=D,
                z=z,
                dt_bias=dt_bias,
                initial_states=initial_states,
                seq_idx=seq_idx,
                cu_seqlens=cu_seqlens,
                dt_softplus=dt_softplus,
                dt_limit=dt_limit,
            )
        )
        ctx.save_for_backward(
            out if z is None else out_x,
            x,
            dt,
            dA_cumsum,
            A,
            B,
            C,
            D,
            z,
            dt_bias,
            initial_states,
            seq_idx,
        )
        ctx.dt_softplus = dt_softplus
        ctx.chunk_size = chunk_size
        ctx.dt_limit = dt_limit
        ctx.return_final_states = return_final_states
        ctx.return_varlen_states = return_varlen_states
        if not return_varlen_states:
            return out if not return_final_states else (out, final_states)
        else:
            varlen_states = rest[0]
            return (
                (out, varlen_states)
                if not return_final_states
                else (out, final_states, varlen_states)
            )

    @staticmethod
    def backward(ctx, dout, *args):
        out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = (
            ctx.saved_tensors
        )
        assert (
            not ctx.return_varlen_states
        ), "return_varlen_states is not supported in backward"
        dfinal_states = args[0] if ctx.return_final_states else None
        dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = (
            _mamba_chunk_scan_combined_bwd(
                dout,
                x,
                dt,
                A,
                B,
                C,
                out,
                ctx.chunk_size,
                D=D,
                z=z,
                dt_bias=dt_bias,
                initial_states=initial_states,
                dfinal_states=dfinal_states,
                seq_idx=seq_idx,
                dt_softplus=ctx.dt_softplus,
                dt_limit=ctx.dt_limit,
            )
        )
        return (
            dx,
            ddt,
            dA,
            dB,
            dC,
            None,
            dD,
            dz,
            ddt_bias,
            dinitial_states,
            None,
            None,
            None,
            None,
            None,
            None,
        )
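backward returns one entry per argument of forward (after ctx): tensor inputs get their gradients, and non-differentiable arguments such as chunk_size, seq_idx, or the boolean flags get None. This is the standard torch.autograd.Function contract; a minimal sketch, assuming a toy function with one tensor and one non-tensor input:

import torch

class ScaleFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):  # factor is a plain Python float
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # One return value per forward input: grad for x, None for factor.
        return grad_out * ctx.factor, None

x = torch.randn(3, requires_grad=True)
ScaleFn.apply(x, 2.0).sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 2.0))
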
def mamba_chunk_scan_combined(
    x,
    dt,
    A,
    B,
    C,
    chunk_size,
    D=None,
    z=None,
    dt_bias=None,
    initial_states=None,
    seq_idx=None,
    cu_seqlens=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    return_final_states=False,
    return_varlen_states=False,
):
    """
    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads)
        A: (nheads)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        chunk_size: int
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
        dt_bias: (nheads,)
        initial_states: (batch, nheads, headdim, dstate)
        seq_idx: (batch, seqlen)
        cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
        dt_softplus: Whether to apply softplus to dt
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    return MambaChunkScanCombinedFn.apply(
        x,
        dt,
        A,
        B,
        C,
        chunk_size,
        D,
        z,
        dt_bias,
        initial_states,
        seq_idx,
        cu_seqlens,
        dt_softplus,
        dt_limit,
        return_final_states,
        return_varlen_states,
    )
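For orientation, a typical call looks like the following. This is a sketch only: the fused path needs a CUDA device and the Triton kernels, and the sizes here are hypothetical.

import torch

batch, seqlen, nheads, headdim, ngroups, dstate = 2, 256, 8, 64, 1, 16
device, dtype = "cuda", torch.bfloat16
x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
dt = torch.rand(batch, seqlen, nheads, device=device, dtype=torch.float32)
A = -torch.rand(nheads, device=device, dtype=torch.float32)  # negative for stable decay
B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype)
C = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype)

out = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size=64, dt_softplus=True)
assert out.shape == x.shape
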
def mamba_chunk_scan(
    x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False
):
    """
    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads)
        A: (nheads)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
        dt_bias: (nheads,)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    batch, seqlen, nheads, headdim = x.shape
    dstate = B.shape[-1]
    if seqlen % chunk_size != 0:
        dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
    dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
    dt = dt.float()  # We want high precision for this before cumsum
    if dt_bias is not None:
        dt = dt + rearrange(dt_bias, "h -> h 1 1")
    if dt_softplus:
        dt = F.softplus(dt)
    dA = dt * rearrange(A, "h -> h 1 1")
    dA_cumsum = torch.cumsum(dA, dim=-1)
    # 1. Compute the state for each chunk
    states = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True)
    # 2. Pass the state to all the chunks by weighted cumsum.
    states = rearrange(
        state_passing(
            rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]
        )[0],
        "... (p n) -> ... p n",
        n=dstate,
    )
    # 3. Compute the output for each chunk
    out = chunk_scan(B, C, x, dt, dA_cumsum, states, D=D, z=z)
    return out
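Step 2 above ("pass the state to all the chunks by weighted cumsum") implements the inter-chunk recurrence: the state entering chunk c is the state entering chunk c-1, decayed by that chunk's total exp(dA), plus chunk c-1's local contribution. A sequential sketch of that semantics (state_passing_sequential is a hypothetical helper for illustration; the fused state_passing kernel computes the same thing in parallel):

import torch

def state_passing_sequential(states, dA_chunk_cumsum, initial_states=None):
    """Sketch: states (batch, nchunks, nheads, d); dA_chunk_cumsum (batch, nheads, nchunks)."""
    batch, nchunks, nheads, d = states.shape
    carry = torch.zeros_like(states[:, 0]) if initial_states is None else initial_states
    out = []
    for c in range(nchunks):
        out.append(carry)                                   # state *entering* chunk c
        decay = torch.exp(dA_chunk_cumsum[:, :, c])         # (batch, nheads)
        carry = carry * decay.unsqueeze(-1) + states[:, c]  # decay, then add local state
    return torch.stack(out, dim=1), carry                   # per-chunk inputs, final state
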
def ssd_chunk_scan_combined_ref(
    x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False
):
    """
    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads)
        A: (nheads)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
        dt_bias: (nheads,)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    batch, seqlen, nheads, headdim = x.shape
    dstate = B.shape[-1]
    if seqlen % chunk_size != 0:
        dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
    dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
    dt = dt.float()  # We want high precision for this before cumsum
    if dt_bias is not None:
        dt = dt + rearrange(dt_bias, "h -> h 1 1")
    if dt_softplus:
        dt = F.softplus(dt)
    dA = dt * rearrange(A, "h -> h 1 1")
    dA_cumsum = torch.cumsum(dA, dim=-1)
    # 1. Compute the state for each chunk
    states = chunk_state_ref(B, x, dt, dA_cumsum)
    states_dtype = states.dtype
    if states.dtype not in [torch.float32, torch.float64]:
        states = states.to(torch.float32)
    # 2. Pass the state to all the chunks by weighted cumsum.
    # state_passing_ref is much less numerically stable
    states = rearrange(
        state_passing_ref(
            rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]
        )[0],
        "... (p n) -> ... p n",
        n=dstate,
    )
    states = states.to(states_dtype)
    # 3. Compute the output for each chunk
    out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)
    return out

def ssd_selective_scan(
    x,
    dt,
    A,
    B,
    C,
    D=None,
    z=None,
    dt_bias=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
):
    """
    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
        A: (nheads) or (dim, dstate)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
        dt_bias: (nheads,) or (nheads, headdim)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    from ..selective_scan_interface import selective_scan_fn

    batch, seqlen, nheads, headdim = x.shape
    _, _, ngroups, dstate = B.shape
    x = rearrange(x, "b l h p -> b (h p) l")
    if dt.dim() == 3:
        dt = repeat(dt, "b l h -> b l h p", p=headdim)
    dt = rearrange(dt, "b l h p -> b (h p) l")
    if A.dim() == 1:
        A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
    else:
        A = A.to(dtype=torch.float32)
    B = rearrange(B, "b l g n -> b g n l")
    C = rearrange(C, "b l g n -> b g n l")
    if D is not None:
        if D.dim() == 2:
            D = rearrange(D, "h p -> (h p)")
        else:
            D = repeat(D, "h -> (h p)", p=headdim)
    if z is not None:
        z = rearrange(z, "b l h p -> b (h p) l")
    if dt_bias is not None:
        if dt_bias.dim() == 1:
            dt_bias = repeat(dt_bias, "h -> h p", p=headdim)
        dt_bias = rearrange(dt_bias, "h p -> (h p)")
    if dt_limit != (0.0, float("inf")):
        if dt_bias is not None:
            dt = dt + rearrange(dt_bias, "d -> d 1")
        if dt_softplus:
            dt = F.softplus(dt)
        dt = dt.clamp(min=dt_limit[0], max=dt_limit[1]).to(x.dtype)
        dt_bias = None
        dt_softplus = None
    out = selective_scan_fn(
        x, dt, A, B, C, D=D, z=z, delta_bias=dt_bias, delta_softplus=dt_softplus
    )
    return rearrange(out, "b (h p) l -> b l h p", p=headdim)

def mamba_conv1d_scan_ref(
    xBC,
    conv1d_weight,
    conv1d_bias,
    dt,
    A,
    chunk_size,
    D=None,
    z=None,
    dt_bias=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    activation="silu",
    headdim=None,
    ngroups=1,
):
    """
    Argument:
        xBC: (batch, seqlen, dim + 2 * ngroups * dstate) where dim == nheads * headdim
        conv1d_weight: (dim + 2 * ngroups * dstate, width)
        conv1d_bias: (dim + 2 * ngroups * dstate,)
        dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
        A: (nheads)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, dim)
        dt_bias: (nheads) or (nheads, headdim)
        headdim: if D is 1D and z is None, headdim must be passed in
    Return:
        out: (batch, seqlen, dim)
    """
    batch, seqlen, nheads = dt.shape[:3]
    assert nheads % ngroups == 0
    if z is not None:
        dim = z.shape[-1]
        assert dim % nheads == 0
        headdim = dim // nheads
    else:
        if D.dim() == 1:
            assert headdim is not None
        else:
            headdim = D.shape[1]
        dim = nheads * headdim
    xBC = rearrange(
        causal_conv1d_fn(
            rearrange(xBC, "b s d -> b d s"),
            conv1d_weight,
            conv1d_bias,
            activation=activation,
        ),
        "b d s -> b s d",
    )
    dstate = (xBC.shape[-1] - dim) // ngroups // 2
    x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
    x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
    B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
    C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
    z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
    out = ssd_selective_scan(
        x,
        dt.to(x.dtype),
        A,
        B,
        C,
        D=D.float(),
        z=z,
        dt_bias=dt_bias,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
    )
    return rearrange(out, "b s h p -> b s (h p)")
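The split of xBC relies on dstate being recoverable from the channel count: the conv output carries dim channels of x plus ngroups * dstate channels each for B and C. Concretely, with hypothetical sizes:

nheads, headdim, ngroups, dstate = 8, 64, 1, 16
dim = nheads * headdim                      # 512
conv_channels = dim + 2 * ngroups * dstate  # 544 channels in xBC / conv1d_weight
# Invert the packing, as done above:
assert (conv_channels - dim) // ngroups // 2 == dstate
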
class MambaSplitConv1dScanCombinedFn(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        zxbcdt,
        conv1d_weight,
        conv1d_bias,
        dt_bias,
        A,
        D,
        chunk_size,
        initial_states=None,
        seq_idx=None,
        dt_limit=(0.0, float("inf")),
        return_final_states=False,
        activation="silu",
        rmsnorm_weight=None,
        rmsnorm_eps=1e-6,
        outproj_weight=None,
        outproj_bias=None,
        headdim=None,
        ngroups=1,
        norm_before_gate=True,
    ):
        assert activation in [None, "silu", "swish"]
        if D.dim() == 1:
            assert headdim is not None
            (nheads,) = D.shape
        else:
            nheads, headdim = D.shape
        batch, seqlen, _ = zxbcdt.shape
        dim = nheads * headdim
        assert nheads % ngroups == 0
        dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2
        d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2
        assert d_nonssm >= 0
        assert zxbcdt.shape == (
            batch,
            seqlen,
            2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads,
        )
        assert dt_bias.shape == (nheads,)
        assert A.shape == (nheads,)
        zx0, z, xBC, dt = torch.split(
            zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1
        )
        seq_idx = seq_idx.contiguous() if seq_idx is not None else None
        xBC_conv = rearrange(
            causal_conv1d_cuda.causal_conv1d_fwd(
                rearrange(xBC, "b s d -> b d s"),
                conv1d_weight,
                conv1d_bias,
                seq_idx,
                None,
                None,
                activation in ["silu", "swish"],
            ),
            "b d s -> b s d",
        )
        x, B, C = torch.split(
            xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1
        )
        x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
        B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
        C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
        z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
        if rmsnorm_weight is None:
            out, out_x, dt_out, dA_cumsum, states, final_states = (
                _mamba_chunk_scan_combined_fwd(
                    x,
                    dt,
                    A,
                    B,
                    C,
                    chunk_size=chunk_size,
                    D=D,
                    z=z,
                    dt_bias=dt_bias,
                    initial_states=initial_states,
                    seq_idx=seq_idx,
                    dt_softplus=True,
                    dt_limit=dt_limit,
                )
            )
            out = rearrange(out, "b s h p -> b s (h p)")
            rstd = None
            if d_nonssm > 0:
                out = torch.cat([_swiglu_fwd(zx0), out], dim=-1)
        else:
            out_x, _, dt_out, dA_cumsum, states, final_states = (
                _mamba_chunk_scan_combined_fwd(
                    x,
                    dt,
                    A,
                    B,
                    C,
                    chunk_size=chunk_size,
                    D=D,
                    z=None,
                    dt_bias=dt_bias,
                    initial_states=initial_states,
                    seq_idx=seq_idx,
                    dt_softplus=True,
                    dt_limit=dt_limit,
                )
            )
            # reshape input data into 2D tensor
            x_rms = rearrange(out_x, "b s h p -> (b s) (h p)")
            z_rms = rearrange(z, "b s h p -> (b s) (h p)")
            rmsnorm_weight = rmsnorm_weight.contiguous()
            if d_nonssm == 0:
                out = None
            else:
                out01 = torch.empty(
                    (batch, seqlen, d_nonssm + dim),
                    dtype=x_rms.dtype,
                    device=x_rms.device,
                )
                out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d")
                _swiglu_fwd(zx0, out=out01[..., :d_nonssm])
            out, _, rstd = _layer_norm_fwd(
                x_rms,
                rmsnorm_weight,
                None,
                rmsnorm_eps,
                z_rms,
                out=out,
                group_size=dim // ngroups,
                norm_before_gate=norm_before_gate,
                is_rms_norm=True,
            )
            if d_nonssm == 0:
                out = rearrange(out, "(b s) d -> b s d", b=batch)
            else:
                out = out01
        ctx.outproj_weight_dtype = (
            outproj_weight.dtype if outproj_weight is not None else None
        )
        if outproj_weight is not None:
            if torch.is_autocast_enabled():
                dtype = torch.get_autocast_gpu_dtype()
                out, outproj_weight = out.to(dtype), outproj_weight.to(dtype)
                outproj_bias = (
                    outproj_bias.to(dtype) if outproj_bias is not None else None
                )
            out = F.linear(out, outproj_weight, outproj_bias)
        else:
            assert outproj_bias is None
        ctx.save_for_backward(
            zxbcdt,
            conv1d_weight,
            conv1d_bias,
            out_x,
            A,
            D,
            dt_bias,
            initial_states,
            seq_idx,
            rmsnorm_weight,
            rstd,
            outproj_weight,
            outproj_bias,
        )
        ctx.dt_limit = dt_limit
        ctx.return_final_states = return_final_states
        ctx.activation = activation
        ctx.rmsnorm_eps = rmsnorm_eps
        ctx.norm_before_gate = norm_before_gate
        ctx.chunk_size = chunk_size
        ctx.headdim = headdim
        ctx.ngroups = ngroups
        return out if not return_final_states else (out, final_states)
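Note that forward saves the pre-convolution xBC rather than the conv output; backward then recomputes x, B, C by rerunning the causal conv. Trading a recompute for activation memory is a common pattern; a minimal sketch of the same idea, assuming a cheap elementwise activation in place of the conv:

import torch

class ReluRecompute(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # save the *input*, not the activation output
        return torch.relu(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Recompute the forward mask instead of having stored relu(x).
        return grad_out * (x > 0).to(grad_out.dtype)

x = torch.randn(5, requires_grad=True)
ReluRecompute.apply(x).sum().backward()
assert torch.equal(x.grad, (x > 0).to(x.dtype))
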
    @staticmethod
    @custom_bwd
    def backward(ctx, dout, *args):
        (
            zxbcdt,
            conv1d_weight,
            conv1d_bias,
            out,
            A,
            D,
            dt_bias,
            initial_states,
            seq_idx,
            rmsnorm_weight,
            rstd,
            outproj_weight,
            outproj_bias,
        ) = ctx.saved_tensors
        dfinal_states = args[0] if ctx.return_final_states else None
        headdim = ctx.headdim
        nheads = D.shape[0]
        dim = nheads * headdim
        assert nheads % ctx.ngroups == 0
        dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2
        d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2
        assert d_nonssm >= 0
        recompute_output = outproj_weight is not None
        if recompute_output:
            out_recompute = torch.empty(
                *out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype
            )
            out0_recompute, out1_recompute = out_recompute.split(
                [d_nonssm, dim], dim=-1
            )
        zx0, z, xBC, dt = torch.split(
            zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
        )
        # Recompute x, B, C
        xBC_conv = rearrange(
            causal_conv1d_cuda.causal_conv1d_fwd(
                rearrange(xBC, "b s d -> b d s"),
                conv1d_weight,
                conv1d_bias,
                seq_idx,
                None,
                None,
                ctx.activation in ["silu", "swish"],
            ),
            "b d s -> b s d",
        )
        x, B, C = torch.split(
            xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
        )
        x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
        B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups)
        C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups)
        dzxbcdt = torch.empty_like(zxbcdt)
        dzx0, dz, dxBC_given, ddt_given = torch.split(
            dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
        )
        dxBC = torch.empty_like(xBC)
        dx, dB, dC = torch.split(
            dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
        )
        z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
        dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads)
        dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups)
        dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups)
        if outproj_weight is not None:
            dout_og = dout
            dout = F.linear(dout, outproj_weight.t())
        if d_nonssm > 0:
            dout0, dout = dout.split([d_nonssm, dim], dim=-1)
            _swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute)
        dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim)
        if rmsnorm_weight is None:
            dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads)
            dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states, *rest = (
                _mamba_chunk_scan_combined_bwd(
                    dout,
                    x,
                    dt,
                    A,
                    B,
                    C,
                    out,
                    ctx.chunk_size,
                    D=D,
                    z=z,
                    dt_bias=dt_bias,
                    initial_states=initial_states,
                    dfinal_states=dfinal_states,
                    seq_idx=seq_idx,
                    dt_softplus=True,
                    dt_limit=ctx.dt_limit,
                    dx=dx,
                    ddt=ddt_given,
                    dB=dB,
                    dC=dC,
                    dz=dz,
                    recompute_output=recompute_output,
                )
            )
            out_for_linear = (
                rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None
            )
            drmsnorm_weight = None
        else:
            batch = dout.shape[0]
            dy_rms = rearrange(dout, "b s h p -> (b s) (h p)")
            dz = rearrange(dz, "b l d -> (b l) d")
            x_rms = rearrange(out, "b s h p -> (b s) (h p)")
            z_rms = rearrange(z, "b s h p -> (b s) (h p)")
            out1_recompute = (
                rearrange(out1_recompute, "b s d -> (b s) d")
                if recompute_output
                else None
            )
            dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(
                dy_rms,
                x_rms,
                rmsnorm_weight,
                None,
                ctx.rmsnorm_eps,
                None,
                rstd,
                z_rms,
                group_size=dim // ctx.ngroups,
                norm_before_gate=ctx.norm_before_gate,
                is_rms_norm=True,
                recompute_output=recompute_output,
                dz=dz,
                out=out1_recompute if recompute_output else None,
            )
            out_for_linear = out_recompute if recompute_output else None
            dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim)
            dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = (
                _mamba_chunk_scan_combined_bwd(
                    dout,
                    x,
                    dt,
                    A,
                    B,
                    C,
                    out,
                    ctx.chunk_size,
                    D=D,
                    z=None,
                    dt_bias=dt_bias,
                    initial_states=initial_states,
                    dfinal_states=dfinal_states,
                    seq_idx=seq_idx,
                    dt_softplus=True,
                    dt_limit=ctx.dt_limit,
                    dx=dx,
                    ddt=ddt_given,
                    dB=dB,
                    dC=dC,
                )
            )

        if outproj_weight is not None:
            doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear)
            doutproj_bias = (
                dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None
            )
        else:
            doutproj_weight, doutproj_bias = None, None
        dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
        dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
            rearrange(xBC, "b s d -> b d s"),
            conv1d_weight,
            conv1d_bias,
            rearrange(dxBC, "b s d -> b d s"),
            seq_idx,
            None,
            None,
            dxBC_given,
            False,
            ctx.activation in ["silu", "swish"],
        )
        dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
        return (
            dzxbcdt,
            dweight,
            dbias,
            ddt_bias,
            dA,
            dD,
            None,
            dinitial_states,
            None,
            None,
            None,
            None,
            drmsnorm_weight,
            None,
            doutproj_weight,
            doutproj_bias,
            None,
            None,
            None,
        )

def mamba_split_conv1d_scan_combined(
    zxbcdt,
    conv1d_weight,
    conv1d_bias,
    dt_bias,
    A,
    D,
    chunk_size,
    initial_states=None,
    seq_idx=None,
    dt_limit=(0.0, float("inf")),
    return_final_states=False,
    activation="silu",
    rmsnorm_weight=None,
    rmsnorm_eps=1e-6,
    outproj_weight=None,
    outproj_bias=None,
    headdim=None,
    ngroups=1,
    norm_before_gate=True,
):
    """
    Argument:
        zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
        conv1d_weight: (dim + 2 * ngroups * dstate, width)
        conv1d_bias: (dim + 2 * ngroups * dstate,)
        dt_bias: (nheads,)
        A: (nheads)
        D: (nheads, headdim) or (nheads,)
        initial_states: (batch, nheads, headdim, dstate)
        seq_idx: (batch, seqlen), int32
        rmsnorm_weight: (dim,)
        outproj_weight: (out_dim, dim)
        outproj_bias: (out_dim,)
        headdim: if D is 1D, headdim must be passed in
        norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
    Return:
        out: (batch, seqlen, dim)
    """
    return MambaSplitConv1dScanCombinedFn.apply(
        zxbcdt,
        conv1d_weight,
        conv1d_bias,
        dt_bias,
        A,
        D,
        chunk_size,
        initial_states,
        seq_idx,
        dt_limit,
        return_final_states,
        activation,
        rmsnorm_weight,
        rmsnorm_eps,
        outproj_weight,
        outproj_bias,
        headdim,
        ngroups,
        norm_before_gate,
    )
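The packed zxbcdt layout concatenates z, the conv input xBC, and dt along the channel axis, so the last dimension must equal 2 * dim + 2 * ngroups * dstate + nheads (plus 2 * d_nonssm channels up front when a non-SSM SwiGLU branch is present). A quick shape check with hypothetical sizes:

nheads, headdim, ngroups, dstate, d_nonssm = 8, 64, 1, 16, 0
dim = nheads * headdim
packed = 2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads
# z: dim, xBC: dim + 2*ngroups*dstate, dt: nheads  (zx0 adds 2*d_nonssm up front)
assert packed == dim + (dim + 2 * ngroups * dstate) + nheads + 2 * d_nonssm
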
def mamba_split_conv1d_scan_ref(
    zxbcdt,
    conv1d_weight,
    conv1d_bias,
    dt_bias,
    A,
    D,
    chunk_size,
    dt_limit=(0.0, float("inf")),
    activation="silu",
    rmsnorm_weight=None,
    rmsnorm_eps=1e-6,
    outproj_weight=None,
    outproj_bias=None,
    headdim=None,
    ngroups=1,
    norm_before_gate=True,
):
    """
    Argument:
        zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
        conv1d_weight: (dim + 2 * ngroups * dstate, width)
        conv1d_bias: (dim + 2 * ngroups * dstate,)
        dt_bias: (nheads,)
        A: (nheads)
        D: (nheads, headdim) or (nheads,)
        rmsnorm_weight: (dim,)
        outproj_weight: (out_dim, dim)
        outproj_bias: (out_dim,)
        headdim: if D is 1D, headdim must be passed in
        norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
    Return:
        out: (batch, seqlen, dim)
    """
    if D.dim() == 1:
        assert headdim is not None
        (nheads,) = D.shape
    else:
        nheads, headdim = D.shape
    assert nheads % ngroups == 0
    batch, seqlen, _ = zxbcdt.shape
    dim = nheads * headdim
    dstate = (zxbcdt.shape[-1] - 2 * dim - nheads) // ngroups // 2
    assert zxbcdt.shape == (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads)
    assert dt_bias.shape == (nheads,)
    assert A.shape == (nheads,)
    if rmsnorm_weight is not None:
        assert rmsnorm_weight.shape == (dim,)
    z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1)
    xBC = rearrange(
        causal_conv1d_fn(
            rearrange(xBC, "b s d -> b d s"),
            conv1d_weight,
            conv1d_bias,
            activation=activation,
        ),
        "b d s -> b s d",
    )
    x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
    x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
    B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
    C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
    z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
    out = ssd_selective_scan(
        x,
        dt.to(x.dtype),
        A,
        B,
        C,
        D=D.float(),
        z=z if rmsnorm_weight is None else None,
        dt_bias=dt_bias,
        dt_softplus=True,
        dt_limit=dt_limit,
    )
    out = rearrange(out, "b s h p -> b s (h p)")
    if rmsnorm_weight is not None:
        out = rmsnorm_fn(
            out,
            rmsnorm_weight,
            None,
            z=rearrange(z, "b l h p -> b l (h p)"),
            eps=rmsnorm_eps,
            norm_before_gate=norm_before_gate,
        )
    if outproj_weight is not None:
        out = F.linear(out, outproj_weight, outproj_bias)
    return out
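These *_ref implementations are the natural oracle for the fused kernels. A typical check (a sketch, assuming a CUDA device and inputs built per the docstrings, with seqlen divisible by chunk_size; make_inputs is a hypothetical helper) compares, for example, mamba_chunk_scan_combined against ssd_chunk_scan_combined_ref under a loose tolerance, since the fused path accumulates in mixed precision:

import torch

x, dt, A, B, C = make_inputs()  # hypothetical: tensors shaped as in the docstrings
out_fused = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size=64, dt_softplus=True)
out_ref = ssd_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size=64, dt_softplus=True)
torch.testing.assert_close(out_fused, out_ref, rtol=1e-2, atol=1e-2)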
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/utils/__init__.py DELETED

File without changes

build/torch25-cxx98-cu121-x86_64-linux/mamba_ssm/__init__.py DELETED

@@ -1,14 +0,0 @@
-__version__ = "2.2.4"
-
-from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
-from .modules.mamba_simple import Mamba
-from .modules.mamba2 import Mamba2
-from .models.mixer_seq_simple import MambaLMHeadModel
-
-__all__ = [
-    "selective_scan_fn",
-    "mamba_inner_fn",
-    "Mamba",
-    "Mamba2",
-    "MambaLMHeadModel",
-]