medmekk committed
Commit 85c0263 · 1 Parent(s): 916c2f7

Upstream builds

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. build.toml +2 -1
  2. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/__init__.py +0 -14
  3. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py +0 -326
  4. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py +0 -338
  5. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/selective_scan_interface.py +0 -659
  6. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/layer_norm.py +0 -1166
  7. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/selective_state_update.py +0 -389
  8. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_scan.py +0 -0
  9. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_state.py +0 -2012
  10. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_combined.py +0 -1884
  11. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/__init__.py +0 -14
  12. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py +0 -326
  13. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py +0 -338
  14. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/selective_scan_interface.py +0 -659
  15. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/triton/layer_norm.py +0 -1166
  16. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/triton/selective_state_update.py +0 -389
  17. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_scan.py +0 -0
  18. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_state.py +0 -2012
  19. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/triton/ssd_combined.py +0 -1884
  20. build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/__init__.py +0 -14
  21. build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/distributed/__init__.py +0 -0
  22. build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py +0 -326
  23. build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/models/__init__.py +0 -0
  24. build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py +0 -338
  25. build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/modules/__init__.py +0 -0
  26. build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/__init__.py +0 -0
  27. build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/selective_scan_interface.py +0 -659
  28. build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/__init__.py +0 -0
  29. build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/layer_norm.py +0 -1166
  30. build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/selective_state_update.py +0 -389
  31. build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_scan.py +0 -0
  32. build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_state.py +0 -2012
  33. build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/ssd_combined.py +0 -1884
  34. build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/utils/__init__.py +0 -0
  35. build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/__init__.py +0 -14
  36. build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/distributed/__init__.py +0 -0
  37. build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py +0 -326
  38. build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/models/__init__.py +0 -0
  39. build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py +0 -338
  40. build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/modules/__init__.py +0 -0
  41. build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/__init__.py +0 -0
  42. build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/selective_scan_interface.py +0 -659
  43. build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/__init__.py +0 -0
  44. build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/layer_norm.py +0 -1166
  45. build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/selective_state_update.py +0 -389
  46. build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_scan.py +0 -0
  47. build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_state.py +0 -2012
  48. build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_combined.py +0 -1884
  49. build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/utils/__init__.py +0 -0
  50. build/torch25-cxx98-cu121-x86_64-linux/mamba_ssm/__init__.py +0 -14
build.toml CHANGED
@@ -1,6 +1,7 @@
  [general]
  name = "mamba_ssm"
- universal = false
+ backends = ["cuda"]
+ python-depends = ["einops"]
 
  [torch]
  src = [
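
The hunk above drops the `universal = false` flag in favour of an explicit `backends = ["cuda"]` declaration and lists `einops` under `python-depends`. Purely for orientation (not part of this commit), a kernel published from a build.toml like this is typically consumed through the Hugging Face `kernels` loader; the repository id below is a placeholder assumption:

```python
# Hypothetical consumer-side sketch; the repository id is a placeholder,
# not something defined in this commit.
from kernels import get_kernel

mamba_ssm = get_kernel("<namespace>/mamba-ssm")  # downloads a prebuilt variant
print([name for name in dir(mamba_ssm) if not name.startswith("_")])
```
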
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/__init__.py DELETED
@@ -1,14 +0,0 @@
- __version__ = "2.2.4"
-
- from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
- from .modules.mamba_simple import Mamba
- from .modules.mamba2 import Mamba2
- from .models.mixer_seq_simple import MambaLMHeadModel
-
- __all__ = [
-     "selective_scan_fn",
-     "mamba_inner_fn",
-     "Mamba",
-     "Mamba2",
-     "MambaLMHeadModel",
- ]
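
The deleted `__init__.py` re-exported the package's public entry points (`selective_scan_fn`, `mamba_inner_fn`, `Mamba`, `Mamba2`, `MambaLMHeadModel`). For orientation only, a minimal sketch of that public surface under the old layout; the layer dimensions below are illustrative assumptions:

```python
# Minimal sketch of the public API the removed __init__.py used to expose.
# d_model / d_state / d_conv / expand values are arbitrary illustrative choices.
import torch
from mamba_ssm import Mamba

layer = Mamba(d_model=256, d_state=16, d_conv=4, expand=2).to("cuda")
x = torch.randn(2, 64, 256, device="cuda")  # (batch, seqlen, d_model)
y = layer(x)                                # output has the same shape as x
```
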
@@ -1,326 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao.
2
- # The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
3
- from typing import Optional
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from torch import Tensor
9
- from torch.distributed import ProcessGroup
10
- from ..utils.torch import custom_bwd, custom_fwd
11
-
12
- from einops import rearrange
13
-
14
- from ..distributed.distributed_utils import (
15
- all_gather_raw,
16
- all_reduce,
17
- all_reduce_raw,
18
- reduce_scatter,
19
- reduce_scatter_raw,
20
- )
21
-
22
-
23
- class ParallelLinearFunc(torch.autograd.Function):
24
- @staticmethod
25
- @custom_fwd
26
- def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
27
- """
28
- If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
29
- with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
30
- """
31
- ctx.compute_weight_gradient = weight.requires_grad
32
- ctx.process_group = process_group
33
- ctx.sequence_parallel = sequence_parallel
34
-
35
- if torch.is_autocast_enabled():
36
- x = x.to(dtype=torch.get_autocast_gpu_dtype())
37
- x = x.contiguous()
38
- if process_group is not None and sequence_parallel:
39
- # We want to kick off the all_gather early, before weight dtype conversion
40
- total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
41
- else:
42
- total_x = x
43
-
44
- if torch.is_autocast_enabled():
45
- weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
46
- bias = (
47
- bias.to(dtype=torch.get_autocast_gpu_dtype())
48
- if bias is not None
49
- else None
50
- )
51
- weight = weight.contiguous()
52
- if process_group is not None and sequence_parallel:
53
- handle_x.wait()
54
- batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
55
- batch_dim = batch_shape.numel()
56
- # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
57
- output = F.linear(total_x, weight, bias)
58
- if ctx.compute_weight_gradient:
59
- ctx.save_for_backward(x, weight)
60
- else:
61
- ctx.save_for_backward(weight)
62
- return output
63
-
64
- @staticmethod
65
- @custom_bwd
66
- def backward(ctx, grad_output):
67
- grad_output = grad_output.contiguous()
68
- process_group = ctx.process_group
69
- sequence_parallel = ctx.sequence_parallel
70
- if ctx.compute_weight_gradient:
71
- x, weight = ctx.saved_tensors
72
- if process_group is not None and sequence_parallel:
73
- total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
74
- else:
75
- total_x = x
76
- else:
77
- (weight,) = ctx.saved_tensors
78
- total_x = None
79
- batch_shape = grad_output.shape[:-1]
80
- batch_dim = batch_shape.numel()
81
- grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
82
- if ctx.needs_input_grad[0]:
83
- grad_input = F.linear(grad_output, weight.t())
84
- grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
85
- if process_group is not None:
86
- reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
87
- grad_input, handle_grad_input = reduce_fn(
88
- grad_input, process_group, async_op=True
89
- )
90
- else:
91
- grad_input = None
92
- if ctx.needs_input_grad[1]:
93
- assert ctx.compute_weight_gradient
94
- if process_group is not None and sequence_parallel:
95
- handle_x.wait()
96
- grad_weight = torch.einsum(
97
- "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
98
- )
99
- else:
100
- grad_weight = None
101
- grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
102
- if process_group is not None and ctx.needs_input_grad[0]:
103
- handle_grad_input.wait()
104
- return grad_input, grad_weight, grad_bias, None, None
105
-
106
-
107
- def parallel_linear_func(
108
- x: Tensor,
109
- weight: Tensor,
110
- bias: Optional[Tensor] = None,
111
- process_group: Optional[ProcessGroup] = None,
112
- sequence_parallel: bool = True,
113
- ):
114
- return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)
115
-
116
-
117
- class ColumnParallelLinear(nn.Linear):
118
- def __init__(
119
- self,
120
- in_features: int,
121
- out_features: int,
122
- process_group: ProcessGroup,
123
- bias: bool = True,
124
- sequence_parallel=True,
125
- multiple_of=1,
126
- device=None,
127
- dtype=None,
128
- ) -> None:
129
- world_size = torch.distributed.get_world_size(process_group)
130
- if out_features % multiple_of:
131
- raise ValueError(
132
- f"out_features ({out_features}) must be a multiple of {multiple_of}"
133
- )
134
- multiple = out_features // multiple_of
135
- # We want to split @multiple across world_size, but it could be an uneven split
136
- div = multiple // world_size
137
- mod = multiple % world_size
138
- # The first @mod ranks get @div + 1 copies, the rest get @div copies
139
- local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
140
- super().__init__(
141
- in_features,
142
- local_multiple * multiple_of,
143
- bias=bias,
144
- device=device,
145
- dtype=dtype,
146
- )
147
- self.process_group = process_group
148
- self.sequence_parallel = sequence_parallel
149
-
150
- def forward(self, x):
151
- # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
152
- # we do an all_gather of x before doing the matmul.
153
- # If not, then the input is already gathered.
154
- return parallel_linear_func(
155
- x,
156
- self.weight,
157
- self.bias,
158
- process_group=self.process_group,
159
- sequence_parallel=self.sequence_parallel,
160
- )
161
-
162
-
163
- class RowParallelLinear(nn.Linear):
164
- def __init__(
165
- self,
166
- in_features: int,
167
- out_features: int,
168
- process_group: ProcessGroup,
169
- bias: bool = True,
170
- sequence_parallel=True,
171
- multiple_of=1,
172
- device=None,
173
- dtype=None,
174
- ) -> None:
175
- world_size = torch.distributed.get_world_size(process_group)
176
- rank = torch.distributed.get_rank(process_group)
177
- if in_features % multiple_of:
178
- raise ValueError(
179
- f"in_features ({in_features}) must be a multiple of {multiple_of}"
180
- )
181
- multiple = in_features // multiple_of
182
- # We want to split @multiple across world_size, but it could be an uneven split
183
- div = multiple // world_size
184
- mod = multiple % world_size
185
- # The first @mod ranks get @div + 1 copies, the rest get @div copies
186
- local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
187
- # Only rank 0 will have bias
188
- super().__init__(
189
- local_multiple * multiple_of,
190
- out_features,
191
- bias=bias and rank == 0,
192
- device=device,
193
- dtype=dtype,
194
- )
195
- self.process_group = process_group
196
- self.sequence_parallel = sequence_parallel
197
-
198
- def forward(self, x):
199
- """
200
- We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
201
- a reduce_scatter of the result.
202
- """
203
- out = parallel_linear_func(x, self.weight, self.bias)
204
- reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
205
- return reduce_fn(out, self.process_group)
206
-
207
-
208
- class VocabParallelEmbedding(nn.Embedding):
209
- def __init__(
210
- self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs
211
- ):
212
- self.process_group = process_group
213
- if process_group is not None:
214
- world_size = torch.distributed.get_world_size(process_group)
215
- if num_embeddings % world_size != 0:
216
- raise ValueError(
217
- f"num_embeddings ({num_embeddings}) must be divisible by "
218
- f"world_size ({world_size})"
219
- )
220
- if world_size > 1 and padding_idx is not None:
221
- raise RuntimeError("ParallelEmbedding does not support padding_idx")
222
- else:
223
- world_size = 1
224
- super().__init__(
225
- num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs
226
- )
227
-
228
- def forward(self, input: Tensor) -> Tensor:
229
- if self.process_group is None:
230
- return super().forward(input)
231
- else:
232
- rank = torch.distributed.get_rank(self.process_group)
233
- vocab_size = self.num_embeddings
234
- vocab_start_index, vocab_end_index = (
235
- rank * vocab_size,
236
- (rank + 1) * vocab_size,
237
- )
238
- # Create a mask of valid vocab ids (1 means it needs to be masked).
239
- input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
240
- input = input - vocab_start_index
241
- input[input_ids_mask] = 0
242
- embeddings = super().forward(input)
243
- embeddings[input_ids_mask] = 0.0
244
- return embeddings
245
-
246
-
247
- class ColumnParallelEmbedding(nn.Embedding):
248
- def __init__(
249
- self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs
250
- ):
251
- self.process_group = process_group
252
- if process_group is not None:
253
- world_size = torch.distributed.get_world_size(process_group)
254
- if embedding_dim % world_size != 0:
255
- raise ValueError(
256
- f"embedding_dim ({embedding_dim}) must be divisible by "
257
- f"world_size ({world_size})"
258
- )
259
- else:
260
- world_size = 1
261
- super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
262
-
263
-
264
- class ParallelEmbeddings(nn.Module):
265
- def __init__(
266
- self,
267
- embed_dim,
268
- vocab_size,
269
- max_position_embeddings,
270
- process_group,
271
- padding_idx=None,
272
- sequence_parallel=True,
273
- device=None,
274
- dtype=None,
275
- ):
276
- """
277
- If max_position_embeddings <= 0, there's no position embeddings
278
- """
279
- factory_kwargs = {"device": device, "dtype": dtype}
280
- super().__init__()
281
- self.process_group = process_group
282
- self.sequence_parallel = sequence_parallel
283
- self.word_embeddings = VocabParallelEmbedding(
284
- vocab_size,
285
- embed_dim,
286
- padding_idx=padding_idx,
287
- process_group=process_group,
288
- **factory_kwargs,
289
- )
290
- self.max_position_embeddings = max_position_embeddings
291
- if self.max_position_embeddings > 0:
292
- self.position_embeddings = ColumnParallelEmbedding(
293
- max_position_embeddings,
294
- embed_dim,
295
- process_group=process_group,
296
- **factory_kwargs,
297
- )
298
-
299
- def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
300
- """
301
- input_ids: (batch, seqlen)
302
- position_ids: (batch, seqlen)
303
- """
304
- batch_size, seqlen = input_ids.shape
305
- world_size = torch.distributed.get_world_size(self.process_group)
306
- embeddings = self.word_embeddings(input_ids)
307
- if self.max_position_embeddings > 0:
308
- if position_ids is None:
309
- position_ids = torch.arange(
310
- seqlen, dtype=torch.long, device=input_ids.device
311
- )
312
- position_embeddings = self.position_embeddings(position_ids)
313
- if world_size <= 1:
314
- embeddings = embeddings + position_embeddings
315
- else:
316
- partition_dim = self.position_embeddings.embedding_dim
317
- rank = torch.distributed.get_rank(self.process_group)
318
- embeddings[
319
- ..., rank * partition_dim : (rank + 1) * partition_dim
320
- ] += position_embeddings
321
- if combine_batch_seqlen_dim:
322
- embeddings = rearrange(embeddings, "b s d -> (b s) d")
323
- reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
324
- return (
325
- embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
326
- )
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py DELETED
@@ -1,338 +0,0 @@
1
- # Copyright (c) 2023, Albert Gu, Tri Dao.
2
-
3
- import math
4
- from functools import partial
5
- import json
6
- import os
7
- import copy
8
-
9
- from collections import namedtuple
10
-
11
- import torch
12
- import torch.nn as nn
13
-
14
- from .config_mamba import MambaConfig
15
- from ..modules.mamba_simple import Mamba
16
- from ..modules.mamba2 import Mamba2
17
- from ..modules.mha import MHA
18
- from ..modules.mlp import GatedMLP
19
- from ..modules.block import Block
20
- from ..utils.generation import GenerationMixin
21
- from ..utils.hf import load_config_hf, load_state_dict_hf
22
-
23
- try:
24
- from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
25
- except ImportError:
26
- RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
27
-
28
-
29
- def create_block(
30
- d_model,
31
- d_intermediate,
32
- ssm_cfg=None,
33
- attn_layer_idx=None,
34
- attn_cfg=None,
35
- norm_epsilon=1e-5,
36
- rms_norm=False,
37
- residual_in_fp32=False,
38
- fused_add_norm=False,
39
- layer_idx=None,
40
- device=None,
41
- dtype=None,
42
- ):
43
- if ssm_cfg is None:
44
- ssm_cfg = {}
45
- if attn_layer_idx is None:
46
- attn_layer_idx = []
47
- if attn_cfg is None:
48
- attn_cfg = {}
49
- factory_kwargs = {"device": device, "dtype": dtype}
50
- if layer_idx not in attn_layer_idx:
51
- # Create a copy of the config to modify
52
- ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
53
- ssm_layer = ssm_cfg.pop("layer", "Mamba1")
54
- if ssm_layer not in ["Mamba1", "Mamba2"]:
55
- raise ValueError(
56
- f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2"
57
- )
58
- mixer_cls = partial(
59
- Mamba2 if ssm_layer == "Mamba2" else Mamba,
60
- layer_idx=layer_idx,
61
- **ssm_cfg,
62
- **factory_kwargs,
63
- )
64
- else:
65
- mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
66
- norm_cls = partial(
67
- nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
68
- )
69
- if d_intermediate == 0:
70
- mlp_cls = nn.Identity
71
- else:
72
- mlp_cls = partial(
73
- GatedMLP,
74
- hidden_features=d_intermediate,
75
- out_features=d_model,
76
- **factory_kwargs,
77
- )
78
- block = Block(
79
- d_model,
80
- mixer_cls,
81
- mlp_cls,
82
- norm_cls=norm_cls,
83
- fused_add_norm=fused_add_norm,
84
- residual_in_fp32=residual_in_fp32,
85
- )
86
- block.layer_idx = layer_idx
87
- return block
88
-
89
-
90
- # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
91
- def _init_weights(
92
- module,
93
- n_layer,
94
- initializer_range=0.02, # Now only used for embedding layer.
95
- rescale_prenorm_residual=True,
96
- n_residuals_per_layer=1, # Change to 2 if we have MLP
97
- ):
98
- if isinstance(module, nn.Linear):
99
- if module.bias is not None:
100
- if not getattr(module.bias, "_no_reinit", False):
101
- nn.init.zeros_(module.bias)
102
- elif isinstance(module, nn.Embedding):
103
- nn.init.normal_(module.weight, std=initializer_range)
104
-
105
- if rescale_prenorm_residual:
106
- # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
107
- # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
108
- # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
109
- # > -- GPT-2 :: https://openai.com/blog/better-language-models/
110
- #
111
- # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
112
- for name, p in module.named_parameters():
113
- if name in ["out_proj.weight", "fc2.weight"]:
114
- # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
115
- # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
116
- # We need to reinit p since this code could be called multiple times
117
- # Having just p *= scale would repeatedly scale it down
118
- nn.init.kaiming_uniform_(p, a=math.sqrt(5))
119
- with torch.no_grad():
120
- p /= math.sqrt(n_residuals_per_layer * n_layer)
121
-
122
-
123
- class MixerModel(nn.Module):
124
- def __init__(
125
- self,
126
- d_model: int,
127
- n_layer: int,
128
- d_intermediate: int,
129
- vocab_size: int,
130
- ssm_cfg=None,
131
- attn_layer_idx=None,
132
- attn_cfg=None,
133
- norm_epsilon: float = 1e-5,
134
- rms_norm: bool = False,
135
- initializer_cfg=None,
136
- fused_add_norm=False,
137
- residual_in_fp32=False,
138
- device=None,
139
- dtype=None,
140
- ) -> None:
141
- factory_kwargs = {"device": device, "dtype": dtype}
142
- super().__init__()
143
- self.residual_in_fp32 = residual_in_fp32
144
-
145
- self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
146
-
147
- # We change the order of residual and layer norm:
148
- # Instead of LN -> Attn / MLP -> Add, we do:
149
- # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
150
- # the main branch (output of MLP / Mixer). The model definition is unchanged.
151
- # This is for performance reason: we can fuse add + layer_norm.
152
- self.fused_add_norm = fused_add_norm
153
- if self.fused_add_norm:
154
- if layer_norm_fn is None or rms_norm_fn is None:
155
- raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
156
-
157
- self.layers = nn.ModuleList(
158
- [
159
- create_block(
160
- d_model,
161
- d_intermediate=d_intermediate,
162
- ssm_cfg=ssm_cfg,
163
- attn_layer_idx=attn_layer_idx,
164
- attn_cfg=attn_cfg,
165
- norm_epsilon=norm_epsilon,
166
- rms_norm=rms_norm,
167
- residual_in_fp32=residual_in_fp32,
168
- fused_add_norm=fused_add_norm,
169
- layer_idx=i,
170
- **factory_kwargs,
171
- )
172
- for i in range(n_layer)
173
- ]
174
- )
175
-
176
- self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
177
- d_model, eps=norm_epsilon, **factory_kwargs
178
- )
179
-
180
- self.apply(
181
- partial(
182
- _init_weights,
183
- n_layer=n_layer,
184
- **(initializer_cfg if initializer_cfg is not None else {}),
185
- n_residuals_per_layer=(
186
- 1 if d_intermediate == 0 else 2
187
- ), # 2 if we have MLP
188
- )
189
- )
190
-
191
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
192
- return {
193
- i: layer.allocate_inference_cache(
194
- batch_size, max_seqlen, dtype=dtype, **kwargs
195
- )
196
- for i, layer in enumerate(self.layers)
197
- }
198
-
199
- def forward(self, input_ids, inference_params=None, **mixer_kwargs):
200
- hidden_states = self.embedding(input_ids)
201
- residual = None
202
- for layer in self.layers:
203
- hidden_states, residual = layer(
204
- hidden_states,
205
- residual,
206
- inference_params=inference_params,
207
- **mixer_kwargs,
208
- )
209
- if not self.fused_add_norm:
210
- residual = (
211
- (hidden_states + residual) if residual is not None else hidden_states
212
- )
213
- hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
214
- else:
215
- # Set prenorm=False here since we don't need the residual
216
- hidden_states = layer_norm_fn(
217
- hidden_states,
218
- self.norm_f.weight,
219
- self.norm_f.bias,
220
- eps=self.norm_f.eps,
221
- residual=residual,
222
- prenorm=False,
223
- residual_in_fp32=self.residual_in_fp32,
224
- is_rms_norm=isinstance(self.norm_f, RMSNorm),
225
- )
226
- return hidden_states
227
-
228
-
229
- class MambaLMHeadModel(nn.Module, GenerationMixin):
230
-
231
- def __init__(
232
- self,
233
- config: MambaConfig,
234
- initializer_cfg=None,
235
- device=None,
236
- dtype=None,
237
- ) -> None:
238
- self.config = config
239
- d_model = config.d_model
240
- n_layer = config.n_layer
241
- d_intermediate = config.d_intermediate
242
- vocab_size = config.vocab_size
243
- ssm_cfg = config.ssm_cfg
244
- attn_layer_idx = config.attn_layer_idx
245
- attn_cfg = config.attn_cfg
246
- rms_norm = config.rms_norm
247
- residual_in_fp32 = config.residual_in_fp32
248
- fused_add_norm = config.fused_add_norm
249
- pad_vocab_size_multiple = config.pad_vocab_size_multiple
250
- factory_kwargs = {"device": device, "dtype": dtype}
251
-
252
- super().__init__()
253
- if vocab_size % pad_vocab_size_multiple != 0:
254
- vocab_size += pad_vocab_size_multiple - (
255
- vocab_size % pad_vocab_size_multiple
256
- )
257
- self.backbone = MixerModel(
258
- d_model=d_model,
259
- n_layer=n_layer,
260
- d_intermediate=d_intermediate,
261
- vocab_size=vocab_size,
262
- ssm_cfg=ssm_cfg,
263
- attn_layer_idx=attn_layer_idx,
264
- attn_cfg=attn_cfg,
265
- rms_norm=rms_norm,
266
- initializer_cfg=initializer_cfg,
267
- fused_add_norm=fused_add_norm,
268
- residual_in_fp32=residual_in_fp32,
269
- **factory_kwargs,
270
- )
271
- self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
272
-
273
- # Initialize weights and apply final processing
274
- self.apply(
275
- partial(
276
- _init_weights,
277
- n_layer=n_layer,
278
- **(initializer_cfg if initializer_cfg is not None else {}),
279
- )
280
- )
281
- self.tie_weights()
282
-
283
- def tie_weights(self):
284
- if self.config.tie_embeddings:
285
- self.lm_head.weight = self.backbone.embedding.weight
286
-
287
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
288
- return self.backbone.allocate_inference_cache(
289
- batch_size, max_seqlen, dtype=dtype, **kwargs
290
- )
291
-
292
- def forward(
293
- self,
294
- input_ids,
295
- position_ids=None,
296
- inference_params=None,
297
- num_last_tokens=0,
298
- **mixer_kwargs,
299
- ):
300
- """
301
- "position_ids" is just to be compatible with Transformer generation. We don't use it.
302
- num_last_tokens: if > 0, only return the logits for the last n tokens
303
- """
304
- hidden_states = self.backbone(
305
- input_ids, inference_params=inference_params, **mixer_kwargs
306
- )
307
- if num_last_tokens > 0:
308
- hidden_states = hidden_states[:, -num_last_tokens:]
309
- lm_logits = self.lm_head(hidden_states)
310
- CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
311
- return CausalLMOutput(logits=lm_logits)
312
-
313
- @classmethod
314
- def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
315
- config_data = load_config_hf(pretrained_model_name)
316
- config = MambaConfig(**config_data)
317
- model = cls(config, device=device, dtype=dtype, **kwargs)
318
- model.load_state_dict(
319
- load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
320
- )
321
- return model
322
-
323
- def save_pretrained(self, save_directory):
324
- """
325
- Minimal implementation of save_pretrained for MambaLMHeadModel.
326
- Save the model and its configuration file to a directory.
327
- """
328
- # Ensure save_directory exists
329
- os.makedirs(save_directory, exist_ok=True)
330
-
331
- # Save the model's state_dict
332
- model_path = os.path.join(save_directory, "pytorch_model.bin")
333
- torch.save(self.state_dict(), model_path)
334
-
335
- # Save the configuration of the model
336
- config_path = os.path.join(save_directory, "config.json")
337
- with open(config_path, "w") as f:
338
- json.dump(self.config.__dict__, f, indent=4)
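
For reference, the deleted `mixer_seq_simple.py` above defines `MambaLMHeadModel`, whose `from_pretrained` pulls a config and state dict from the Hub and whose `forward` returns a `CausalLMOutput` namedtuple. A brief illustrative sketch; the checkpoint name is an example and not something referenced by this commit:

```python
# Illustrative use of the MambaLMHeadModel defined in the file above.
# "state-spaces/mamba-130m" is an example checkpoint, not part of this commit.
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m", device="cuda", dtype=torch.float16
)
input_ids = torch.randint(0, model.config.vocab_size, (1, 16), device="cuda")
logits = model(input_ids).logits  # (batch, seqlen, padded vocab size)
```
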
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/selective_scan_interface.py DELETED
@@ -1,659 +0,0 @@
1
- # Copyright (c) 2023, Tri Dao, Albert Gu.
2
-
3
- import torch
4
- import torch.nn.functional as F
5
- from ..utils.torch import custom_fwd, custom_bwd
6
-
7
- from einops import rearrange, repeat
8
-
9
- try:
10
- from causal_conv1d import causal_conv1d_fn
11
- import causal_conv1d_cuda
12
- except ImportError:
13
- causal_conv1d_fn = None
14
- causal_conv1d_cuda = None
15
-
16
- from .triton.layer_norm import _layer_norm_fwd
17
-
18
- from .._ops import ops
19
-
20
-
21
- class SelectiveScanFn(torch.autograd.Function):
22
-
23
- @staticmethod
24
- def forward(
25
- ctx,
26
- u,
27
- delta,
28
- A,
29
- B,
30
- C,
31
- D=None,
32
- z=None,
33
- delta_bias=None,
34
- delta_softplus=False,
35
- return_last_state=False,
36
- ):
37
- if u.stride(-1) != 1:
38
- u = u.contiguous()
39
- if delta.stride(-1) != 1:
40
- delta = delta.contiguous()
41
- if D is not None:
42
- D = D.contiguous()
43
- if B.stride(-1) != 1:
44
- B = B.contiguous()
45
- if C.stride(-1) != 1:
46
- C = C.contiguous()
47
- if z is not None and z.stride(-1) != 1:
48
- z = z.contiguous()
49
- if B.dim() == 3:
50
- B = rearrange(B, "b dstate l -> b 1 dstate l")
51
- ctx.squeeze_B = True
52
- if C.dim() == 3:
53
- C = rearrange(C, "b dstate l -> b 1 dstate l")
54
- ctx.squeeze_C = True
55
- out, x, *rest = ops.selective_scan_fwd(
56
- u, delta, A, B, C, D, z, delta_bias, delta_softplus
57
- )
58
- ctx.delta_softplus = delta_softplus
59
- ctx.has_z = z is not None
60
- last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
61
- if not ctx.has_z:
62
- ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
63
- return out if not return_last_state else (out, last_state)
64
- else:
65
- ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
66
- out_z = rest[0]
67
- return out_z if not return_last_state else (out_z, last_state)
68
-
69
- @staticmethod
70
- def backward(ctx, dout, *args):
71
- if not ctx.has_z:
72
- u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
73
- z = None
74
- out = None
75
- else:
76
- u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
77
- if dout.stride(-1) != 1:
78
- dout = dout.contiguous()
79
- # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
80
- # backward of selective_scan_cuda with the backward of chunk).
81
- # Here we just pass in None and dz will be allocated in the C++ code.
82
- du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = ops.selective_scan_bwd(
83
- u,
84
- delta,
85
- A,
86
- B,
87
- C,
88
- D,
89
- z,
90
- delta_bias,
91
- dout,
92
- x,
93
- out,
94
- None,
95
- ctx.delta_softplus,
96
- False, # option to recompute out_z, not used here
97
- )
98
- dz = rest[0] if ctx.has_z else None
99
- dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
100
- dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
101
- return (
102
- du,
103
- ddelta,
104
- dA,
105
- dB,
106
- dC,
107
- dD if D is not None else None,
108
- dz,
109
- ddelta_bias if delta_bias is not None else None,
110
- None,
111
- None,
112
- )
113
-
114
-
115
- def rms_norm_forward(
116
- x,
117
- weight,
118
- bias,
119
- eps=1e-6,
120
- is_rms_norm=True,
121
- ):
122
- # x (b l) d
123
- if x.stride(-1) != 1:
124
- x = x.contiguous()
125
- weight = weight.contiguous()
126
- if bias is not None:
127
- bias = bias.contiguous()
128
- y = _layer_norm_fwd(
129
- x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm
130
- )[0]
131
- # y (b l) d
132
- return y
133
-
134
-
135
- def selective_scan_fn(
136
- u,
137
- delta,
138
- A,
139
- B,
140
- C,
141
- D=None,
142
- z=None,
143
- delta_bias=None,
144
- delta_softplus=False,
145
- return_last_state=False,
146
- ):
147
- """if return_last_state is True, returns (out, last_state)
148
- last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
149
- not considered in the backward pass.
150
- """
151
- return SelectiveScanFn.apply(
152
- u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state
153
- )
154
-
155
-
156
- def selective_scan_ref(
157
- u,
158
- delta,
159
- A,
160
- B,
161
- C,
162
- D=None,
163
- z=None,
164
- delta_bias=None,
165
- delta_softplus=False,
166
- return_last_state=False,
167
- ):
168
- """
169
- u: r(B D L)
170
- delta: r(B D L)
171
- A: c(D N) or r(D N)
172
- B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
173
- C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
174
- D: r(D)
175
- z: r(B D L)
176
- delta_bias: r(D), fp32
177
-
178
- out: r(B D L)
179
- last_state (optional): r(B D dstate) or c(B D dstate)
180
- """
181
- dtype_in = u.dtype
182
- u = u.float()
183
- delta = delta.float()
184
- if delta_bias is not None:
185
- delta = delta + delta_bias[..., None].float()
186
- if delta_softplus:
187
- delta = F.softplus(delta)
188
- batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
189
- is_variable_B = B.dim() >= 3
190
- is_variable_C = C.dim() >= 3
191
- if A.is_complex():
192
- if is_variable_B:
193
- B = torch.view_as_complex(
194
- rearrange(B.float(), "... (L two) -> ... L two", two=2)
195
- )
196
- if is_variable_C:
197
- C = torch.view_as_complex(
198
- rearrange(C.float(), "... (L two) -> ... L two", two=2)
199
- )
200
- else:
201
- B = B.float()
202
- C = C.float()
203
- x = A.new_zeros((batch, dim, dstate))
204
- ys = []
205
- deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
206
- if not is_variable_B:
207
- deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
208
- else:
209
- if B.dim() == 3:
210
- deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
211
- else:
212
- B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
213
- deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
214
- if is_variable_C and C.dim() == 4:
215
- C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
216
- last_state = None
217
- for i in range(u.shape[2]):
218
- x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
219
- if not is_variable_C:
220
- y = torch.einsum("bdn,dn->bd", x, C)
221
- else:
222
- if C.dim() == 3:
223
- y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
224
- else:
225
- y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
226
- if i == u.shape[2] - 1:
227
- last_state = x
228
- if y.is_complex():
229
- y = y.real * 2
230
- ys.append(y)
231
- y = torch.stack(ys, dim=2) # (batch dim L)
232
- out = y if D is None else y + u * rearrange(D, "d -> d 1")
233
- if z is not None:
234
- out = out * F.silu(z)
235
- out = out.to(dtype=dtype_in)
236
- return out if not return_last_state else (out, last_state)
237
-
238
-
239
- class MambaInnerFn(torch.autograd.Function):
240
-
241
- @staticmethod
242
- @custom_fwd
243
- def forward(
244
- ctx,
245
- xz,
246
- conv1d_weight,
247
- conv1d_bias,
248
- x_proj_weight,
249
- delta_proj_weight,
250
- out_proj_weight,
251
- out_proj_bias,
252
- A,
253
- B=None,
254
- C=None,
255
- D=None,
256
- delta_bias=None,
257
- B_proj_bias=None,
258
- C_proj_bias=None,
259
- delta_softplus=True,
260
- checkpoint_lvl=1,
261
- b_rms_weight=None,
262
- c_rms_weight=None,
263
- dt_rms_weight=None,
264
- b_c_dt_rms_eps=1e-6,
265
- ):
266
- """
267
- xz: (batch, dim, seqlen)
268
- """
269
- assert (
270
- causal_conv1d_cuda is not None
271
- ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
272
- assert checkpoint_lvl in [0, 1]
273
- L = xz.shape[-1]
274
- delta_rank = delta_proj_weight.shape[1]
275
- d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
276
- if torch.is_autocast_enabled():
277
- x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
278
- delta_proj_weight = delta_proj_weight.to(
279
- dtype=torch.get_autocast_gpu_dtype()
280
- )
281
- out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
282
- out_proj_bias = (
283
- out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
284
- if out_proj_bias is not None
285
- else None
286
- )
287
- if xz.stride(-1) != 1:
288
- xz = xz.contiguous()
289
- conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
290
- x, z = xz.chunk(2, dim=1)
291
- conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
292
- conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
293
- x, conv1d_weight, conv1d_bias, None, None, None, True
294
- )
295
- # We're being very careful here about the layout, to avoid extra transposes.
296
- # We want delta to have d as the slowest moving dimension
297
- # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
298
- x_dbl = F.linear(
299
- rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight
300
- ) # (bl d)
301
- delta = rearrange(
302
- delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
303
- )
304
- ctx.is_variable_B = B is None
305
- ctx.is_variable_C = C is None
306
- ctx.B_proj_bias_is_None = B_proj_bias is None
307
- ctx.C_proj_bias_is_None = C_proj_bias is None
308
- if B is None: # variable B
309
- B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl dstate)
310
- if B_proj_bias is not None:
311
- B = B + B_proj_bias.to(dtype=B.dtype)
312
- if not A.is_complex():
313
- # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
314
- B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
315
- else:
316
- B = rearrange(
317
- B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
318
- ).contiguous()
319
- else:
320
- if B.stride(-1) != 1:
321
- B = B.contiguous()
322
- if C is None: # variable C
323
- C = x_dbl[:, -d_state:] # (bl dstate)
324
- if C_proj_bias is not None:
325
- C = C + C_proj_bias.to(dtype=C.dtype)
326
- if not A.is_complex():
327
- # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
328
- C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
329
- else:
330
- C = rearrange(
331
- C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
332
- ).contiguous()
333
- else:
334
- if C.stride(-1) != 1:
335
- C = C.contiguous()
336
- if D is not None:
337
- D = D.contiguous()
338
-
339
- if b_rms_weight is not None:
340
- B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
341
- B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
342
- B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
343
- if c_rms_weight is not None:
344
- C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
345
- C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
346
- C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
347
- if dt_rms_weight is not None:
348
- delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
349
- delta = rms_norm_forward(
350
- delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps
351
- )
352
- delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
353
-
354
- out, scan_intermediates, out_z = ops.selective_scan_fwd(
355
- conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
356
- )
357
- ctx.delta_softplus = delta_softplus
358
- ctx.out_proj_bias_is_None = out_proj_bias is None
359
- ctx.checkpoint_lvl = checkpoint_lvl
360
- ctx.b_rms_weight = b_rms_weight
361
- ctx.c_rms_weight = c_rms_weight
362
- ctx.dt_rms_weight = dt_rms_weight
363
- ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
364
- if (
365
- checkpoint_lvl >= 1
366
- ): # Will recompute conv1d_out and delta in the backward pass
367
- conv1d_out, delta = None, None
368
- ctx.save_for_backward(
369
- xz,
370
- conv1d_weight,
371
- conv1d_bias,
372
- x_dbl,
373
- x_proj_weight,
374
- delta_proj_weight,
375
- out_proj_weight,
376
- conv1d_out,
377
- delta,
378
- A,
379
- B,
380
- C,
381
- D,
382
- delta_bias,
383
- scan_intermediates,
384
- b_rms_weight,
385
- c_rms_weight,
386
- dt_rms_weight,
387
- out,
388
- )
389
- return F.linear(
390
- rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias
391
- )
392
-
393
- @staticmethod
394
- @custom_bwd
395
- def backward(ctx, dout):
396
- # dout: (batch, seqlen, dim)
397
- assert (
398
- causal_conv1d_cuda is not None
399
- ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
400
- (
401
- xz,
402
- conv1d_weight,
403
- conv1d_bias,
404
- x_dbl,
405
- x_proj_weight,
406
- delta_proj_weight,
407
- out_proj_weight,
408
- conv1d_out,
409
- delta,
410
- A,
411
- B,
412
- C,
413
- D,
414
- delta_bias,
415
- scan_intermediates,
416
- b_rms_weight,
417
- c_rms_weight,
418
- dt_rms_weight,
419
- out,
420
- ) = ctx.saved_tensors
421
- L = xz.shape[-1]
422
- delta_rank = delta_proj_weight.shape[1]
423
- d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
424
- x, z = xz.chunk(2, dim=1)
425
- if dout.stride(-1) != 1:
426
- dout = dout.contiguous()
427
- if ctx.checkpoint_lvl == 1:
428
- conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
429
- x, conv1d_weight, conv1d_bias, None, None, None, True
430
- )
431
- delta = rearrange(
432
- delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
433
- )
434
- if dt_rms_weight is not None:
435
- delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
436
- delta = rms_norm_forward(
437
- delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps
438
- )
439
- delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
440
- if b_rms_weight is not None:
441
- # Recompute & RMSNorm B
442
- B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
443
- B = rms_norm_forward(B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps)
444
- B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
445
- if c_rms_weight is not None:
446
- # Recompute & RMSNorm C
447
- C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
448
- C = rms_norm_forward(C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps)
449
- C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
450
-
451
- # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
452
- # backward of selective_scan_cuda with the backward of chunk).
453
- dxz = torch.empty_like(xz) # (batch, dim, seqlen)
454
- dx, dz = dxz.chunk(2, dim=1)
455
- dout = rearrange(dout, "b l e -> e (b l)")
456
- dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
457
- dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = (
458
- ops.selective_scan_bwd(
459
- conv1d_out,
460
- delta,
461
- A,
462
- B,
463
- C,
464
- D,
465
- z,
466
- delta_bias,
467
- dout_y,
468
- scan_intermediates,
469
- out,
470
- dz,
471
- ctx.delta_softplus,
472
- True, # option to recompute out_z
473
- )
474
- )
475
- dout_proj_weight = torch.einsum(
476
- "eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")
477
- )
478
- dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
479
- dD = dD if D is not None else None
480
- dx_dbl = torch.empty_like(x_dbl)
481
- dB_proj_bias = None
482
- if ctx.is_variable_B:
483
- if not A.is_complex():
484
- dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
485
- else:
486
- dB = rearrange(
487
- dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
488
- ).contiguous()
489
- dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
490
- dx_dbl[:, delta_rank : delta_rank + d_state] = dB # (bl d)
491
- dB = None
492
- dC_proj_bias = None
493
- if ctx.is_variable_C:
494
- if not A.is_complex():
495
- dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
496
- else:
497
- dC = rearrange(
498
- dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
499
- ).contiguous()
500
- dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
501
- dx_dbl[:, -d_state:] = dC # (bl d)
502
- dC = None
503
- ddelta = rearrange(ddelta, "b d l -> d (b l)")
504
- ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
505
- dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
506
- dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
507
- dx_proj_weight = torch.einsum(
508
- "Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")
509
- )
510
- dconv1d_out = torch.addmm(
511
- dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out
512
- )
513
- dconv1d_out = rearrange(
514
- dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]
515
- )
516
- # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
517
- # backward of conv1d with the backward of chunk).
518
- dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
519
- x,
520
- conv1d_weight,
521
- conv1d_bias,
522
- dconv1d_out,
523
- None,
524
- None,
525
- None,
526
- dx,
527
- False,
528
- True,
529
- )
530
- dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
531
- dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
532
- return (
533
- dxz,
534
- dconv1d_weight,
535
- dconv1d_bias,
536
- dx_proj_weight,
537
- ddelta_proj_weight,
538
- dout_proj_weight,
539
- dout_proj_bias,
540
- dA,
541
- dB,
542
- dC,
543
- dD,
544
- ddelta_bias if delta_bias is not None else None,
545
- # 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
546
- dB_proj_bias,
547
- dC_proj_bias,
548
- None,
549
- None,
550
- None,
551
- None,
552
- None,
553
- None,
554
- )
555
-
556
-
557
- def mamba_inner_fn(
558
- xz,
559
- conv1d_weight,
560
- conv1d_bias,
561
- x_proj_weight,
562
- delta_proj_weight,
563
- out_proj_weight,
564
- out_proj_bias,
565
- A,
566
- B=None,
567
- C=None,
568
- D=None,
569
- delta_bias=None,
570
- B_proj_bias=None,
571
- C_proj_bias=None,
572
- delta_softplus=True,
573
- checkpoint_lvl=1,
574
- b_rms_weight=None,
575
- c_rms_weight=None,
576
- dt_rms_weight=None,
577
- b_c_dt_rms_eps=1e-6,
578
- ):
579
- return MambaInnerFn.apply(
580
- xz,
581
- conv1d_weight,
582
- conv1d_bias,
583
- x_proj_weight,
584
- delta_proj_weight,
585
- out_proj_weight,
586
- out_proj_bias,
587
- A,
588
- B,
589
- C,
590
- D,
591
- delta_bias,
592
- B_proj_bias,
593
- C_proj_bias,
594
- delta_softplus,
595
- checkpoint_lvl,
596
- b_rms_weight,
597
- c_rms_weight,
598
- dt_rms_weight,
599
- b_c_dt_rms_eps,
600
- )
601
-
602
-
603
- def mamba_inner_ref(
604
- xz,
605
- conv1d_weight,
606
- conv1d_bias,
607
- x_proj_weight,
608
- delta_proj_weight,
609
- out_proj_weight,
610
- out_proj_bias,
611
- A,
612
- B=None,
613
- C=None,
614
- D=None,
615
- delta_bias=None,
616
- B_proj_bias=None,
617
- C_proj_bias=None,
618
- delta_softplus=True,
619
- ):
620
- assert (
621
- causal_conv1d_fn is not None
622
- ), "causal_conv1d_fn is not available. Please install causal-conv1d."
623
- L = xz.shape[-1]
624
- delta_rank = delta_proj_weight.shape[1]
625
- d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
626
- x, z = xz.chunk(2, dim=1)
627
- x = causal_conv1d_fn(
628
- x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu"
629
- )
630
- # We're being very careful here about the layout, to avoid extra transposes.
631
- # We want delta to have d as the slowest moving dimension
632
- # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
633
- x_dbl = F.linear(rearrange(x, "b d l -> (b l) d"), x_proj_weight) # (bl d)
634
- delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
635
- delta = rearrange(delta, "d (b l) -> b d l", l=L)
636
- if B is None: # variable B
637
- B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl d)
638
- if B_proj_bias is not None:
639
- B = B + B_proj_bias.to(dtype=B.dtype)
640
- if not A.is_complex():
641
- B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
642
- else:
643
- B = rearrange(
644
- B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
645
- ).contiguous()
646
- if C is None: # variable B
647
- C = x_dbl[:, -d_state:] # (bl d)
648
- if C_proj_bias is not None:
649
- C = C + C_proj_bias.to(dtype=C.dtype)
650
- if not A.is_complex():
651
- C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
652
- else:
653
- C = rearrange(
654
- C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
655
- ).contiguous()
656
- y = selective_scan_fn(
657
- x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True
658
- )
659
- return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
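
The `selective_scan_ref` docstring above pins down the expected shapes (`u`, `delta`: (B, D, L); `A`: (D, N); `B`, `C`: (B, N, L); `D`: (D)). A small illustrative parity check between the fused CUDA path and the reference path, assuming a CUDA device and arbitrary sizes:

```python
# Illustrative consistency check between selective_scan_fn (CUDA kernel path)
# and selective_scan_ref (pure PyTorch), using the shapes documented above.
import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref

b, d, n, L = 2, 8, 16, 32
u = torch.randn(b, d, L, device="cuda")
delta = torch.rand(b, d, L, device="cuda")
A = -torch.rand(d, n, device="cuda")   # negative real A, as in Mamba
B = torch.randn(b, n, L, device="cuda")
C = torch.randn(b, n, L, device="cuda")
D = torch.randn(d, device="cuda")

out = selective_scan_fn(u, delta, A, B, C, D, delta_softplus=True)
out_ref = selective_scan_ref(u, delta, A, B, C, D, delta_softplus=True)
torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
```
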
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/layer_norm.py DELETED
@@ -1,1166 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao.
2
- # Implement dropout + residual + layer_norm / rms_norm.
3
-
4
- # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
5
- # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
6
- # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
7
- # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
-
9
- import math
10
- import warnings
11
-
12
- import torch
13
- import torch.nn.functional as F
14
- from ...utils.torch import custom_bwd, custom_fwd
15
-
16
- import triton
17
- import triton.language as tl
18
-
19
-
20
- def layer_norm_ref(
21
- x,
22
- weight,
23
- bias,
24
- residual=None,
25
- x1=None,
26
- weight1=None,
27
- bias1=None,
28
- eps=1e-6,
29
- dropout_p=0.0,
30
- rowscale=None,
31
- prenorm=False,
32
- dropout_mask=None,
33
- dropout_mask1=None,
34
- upcast=False,
35
- ):
36
- dtype = x.dtype
37
- if upcast:
38
- x = x.float()
39
- weight = weight.float()
40
- bias = bias.float() if bias is not None else None
41
- residual = residual.float() if residual is not None else residual
42
- x1 = x1.float() if x1 is not None else None
43
- weight1 = weight1.float() if weight1 is not None else None
44
- bias1 = bias1.float() if bias1 is not None else None
45
- if x1 is not None:
46
- assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
47
- if rowscale is not None:
48
- x = x * rowscale[..., None]
49
- if dropout_p > 0.0:
50
- if dropout_mask is not None:
51
- x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
52
- else:
53
- x = F.dropout(x, p=dropout_p)
54
- if x1 is not None:
55
- if dropout_mask1 is not None:
56
- x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
57
- else:
58
- x1 = F.dropout(x1, p=dropout_p)
59
- if x1 is not None:
60
- x = x + x1
61
- if residual is not None:
62
- x = (x + residual).to(x.dtype)
63
- out = F.layer_norm(
64
- x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
65
- ).to(dtype)
66
- if weight1 is None:
67
- return out if not prenorm else (out, x)
68
- else:
69
- out1 = F.layer_norm(
70
- x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
71
- ).to(dtype)
72
- return (out, out1) if not prenorm else (out, out1, x)
73
-
74
-
75
- def rms_norm_ref(
76
- x,
77
- weight,
78
- bias,
79
- residual=None,
80
- x1=None,
81
- weight1=None,
82
- bias1=None,
83
- eps=1e-6,
84
- dropout_p=0.0,
85
- rowscale=None,
86
- prenorm=False,
87
- dropout_mask=None,
88
- dropout_mask1=None,
89
- upcast=False,
90
- ):
91
- dtype = x.dtype
92
- if upcast:
93
- x = x.float()
94
- weight = weight.float()
95
- bias = bias.float() if bias is not None else None
96
- residual = residual.float() if residual is not None else residual
97
- x1 = x1.float() if x1 is not None else None
98
- weight1 = weight1.float() if weight1 is not None else None
99
- bias1 = bias1.float() if bias1 is not None else None
100
- if x1 is not None:
101
- assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
102
- if rowscale is not None:
103
- x = x * rowscale[..., None]
104
- if dropout_p > 0.0:
105
- if dropout_mask is not None:
106
- x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
107
- else:
108
- x = F.dropout(x, p=dropout_p)
109
- if x1 is not None:
110
- if dropout_mask1 is not None:
111
- x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
112
- else:
113
- x1 = F.dropout(x1, p=dropout_p)
114
- if x1 is not None:
115
- x = x + x1
116
- if residual is not None:
117
- x = (x + residual).to(x.dtype)
118
- rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
119
- out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(
120
- dtype
121
- )
122
- if weight1 is None:
123
- return out if not prenorm else (out, x)
124
- else:
125
- out1 = (
126
- (x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)
127
- ).to(dtype)
128
- return (out, out1) if not prenorm else (out, out1, x)
129
-
130
-
131
- def config_prune(configs):
132
-
133
- if torch.version.hip:
134
- try:
135
- # set warp size based on gcn architecure
136
- gcn_arch_name = torch.cuda.get_device_properties(0).gcnArchName
137
- if "gfx10" in gcn_arch_name or "gfx11" in gcn_arch_name:
138
- # radeon
139
- warp_size = 32
140
- else:
141
- # instinct
142
- warp_size = 64
143
- except AttributeError as e:
144
- # fall back to crude method to set warp size
145
- device_name = torch.cuda.get_device_properties(0).name
146
- if "instinct" in device_name.lower():
147
- warp_size = 64
148
- else:
149
- warp_size = 32
150
- warnings.warn(
151
- f"{e}, warp size set to {warp_size} based on device name: {device_name}",
152
- UserWarning,
153
- )
154
-
155
- else:
156
- # cuda
157
- warp_size = 32
158
-
159
- max_block_sz = 1024
160
- max_num_warps = max_block_sz // warp_size
161
- pruned_configs = [config for config in configs if config.num_warps <= max_num_warps]
162
- return pruned_configs
163
-
164
-
165
- configs_autotune = [
166
- triton.Config({}, num_warps=1),
167
- triton.Config({}, num_warps=2),
168
- triton.Config({}, num_warps=4),
169
- triton.Config({}, num_warps=8),
170
- triton.Config({}, num_warps=16),
171
- triton.Config({}, num_warps=32),
172
- ]
173
-
174
- pruned_configs_autotune = config_prune(configs_autotune)
175
-
176
-
177
- @triton.autotune(
178
- configs=pruned_configs_autotune,
179
- key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
180
- )
181
- # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
182
- # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
183
- @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
184
- @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
185
- @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
186
- @triton.jit
187
- def _layer_norm_fwd_1pass_kernel(
188
- X, # pointer to the input
189
- Y, # pointer to the output
190
- W, # pointer to the weights
191
- B, # pointer to the biases
192
- RESIDUAL, # pointer to the residual
193
- X1,
194
- W1,
195
- B1,
196
- Y1,
197
- RESIDUAL_OUT, # pointer to the residual
198
- ROWSCALE,
199
- SEEDS, # Dropout seeds for each row
200
- DROPOUT_MASK,
201
- Mean, # pointer to the mean
202
- Rstd, # pointer to the 1/std
203
- stride_x_row, # how much to increase the pointer when moving by 1 row
204
- stride_y_row,
205
- stride_res_row,
206
- stride_res_out_row,
207
- stride_x1_row,
208
- stride_y1_row,
209
- M, # number of rows in X
210
- N, # number of columns in X
211
- eps, # epsilon to avoid division by zero
212
- dropout_p, # Dropout probability
213
- IS_RMS_NORM: tl.constexpr,
214
- BLOCK_N: tl.constexpr,
215
- HAS_RESIDUAL: tl.constexpr,
216
- STORE_RESIDUAL_OUT: tl.constexpr,
217
- HAS_BIAS: tl.constexpr,
218
- HAS_DROPOUT: tl.constexpr,
219
- STORE_DROPOUT_MASK: tl.constexpr,
220
- HAS_ROWSCALE: tl.constexpr,
221
- HAS_X1: tl.constexpr,
222
- HAS_W1: tl.constexpr,
223
- HAS_B1: tl.constexpr,
224
- ):
225
- # Map the program id to the row of X and Y it should compute.
226
- row = tl.program_id(0)
227
- X += row * stride_x_row
228
- Y += row * stride_y_row
229
- if HAS_RESIDUAL:
230
- RESIDUAL += row * stride_res_row
231
- if STORE_RESIDUAL_OUT:
232
- RESIDUAL_OUT += row * stride_res_out_row
233
- if HAS_X1:
234
- X1 += row * stride_x1_row
235
- if HAS_W1:
236
- Y1 += row * stride_y1_row
237
- # Compute mean and variance
238
- cols = tl.arange(0, BLOCK_N)
239
- x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
240
- if HAS_ROWSCALE:
241
- rowscale = tl.load(ROWSCALE + row).to(tl.float32)
242
- x *= rowscale
243
- if HAS_DROPOUT:
244
- # Compute dropout mask
245
- # 7 rounds is good enough, and reduces register pressure
246
- keep_mask = (
247
- tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
248
- )
249
- x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
250
- if STORE_DROPOUT_MASK:
251
- tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
252
- if HAS_X1:
253
- x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
254
- if HAS_ROWSCALE:
255
- rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
256
- x1 *= rowscale
257
- if HAS_DROPOUT:
258
- # Compute dropout mask
259
- # 7 rounds is good enough, and reduces register pressure
260
- keep_mask = (
261
- tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
262
- > dropout_p
263
- )
264
- x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
265
- if STORE_DROPOUT_MASK:
266
- tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
267
- x += x1
268
- if HAS_RESIDUAL:
269
- residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
270
- x += residual
271
- if STORE_RESIDUAL_OUT:
272
- tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
273
- if not IS_RMS_NORM:
274
- mean = tl.sum(x, axis=0) / N
275
- tl.store(Mean + row, mean)
276
- xbar = tl.where(cols < N, x - mean, 0.0)
277
- var = tl.sum(xbar * xbar, axis=0) / N
278
- else:
279
- xbar = tl.where(cols < N, x, 0.0)
280
- var = tl.sum(xbar * xbar, axis=0) / N
281
- rstd = 1 / tl.sqrt(var + eps)
282
- tl.store(Rstd + row, rstd)
283
- # Normalize and apply linear transformation
284
- mask = cols < N
285
- w = tl.load(W + cols, mask=mask).to(tl.float32)
286
- if HAS_BIAS:
287
- b = tl.load(B + cols, mask=mask).to(tl.float32)
288
- x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
289
- y = x_hat * w + b if HAS_BIAS else x_hat * w
290
- # Write output
291
- tl.store(Y + cols, y, mask=mask)
292
- if HAS_W1:
293
- w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
294
- if HAS_B1:
295
- b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
296
- y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
297
- tl.store(Y1 + cols, y1, mask=mask)
298
-
299
-
300
- def _layer_norm_fwd(
301
- x,
302
- weight,
303
- bias,
304
- eps,
305
- residual=None,
306
- x1=None,
307
- weight1=None,
308
- bias1=None,
309
- dropout_p=0.0,
310
- rowscale=None,
311
- out_dtype=None,
312
- residual_dtype=None,
313
- is_rms_norm=False,
314
- return_dropout_mask=False,
315
- ):
316
- if residual is not None:
317
- residual_dtype = residual.dtype
318
- M, N = x.shape
319
- assert x.stride(-1) == 1
320
- if residual is not None:
321
- assert residual.stride(-1) == 1
322
- assert residual.shape == (M, N)
323
- assert weight.shape == (N,)
324
- assert weight.stride(-1) == 1
325
- if bias is not None:
326
- assert bias.stride(-1) == 1
327
- assert bias.shape == (N,)
328
- if x1 is not None:
329
- assert x1.shape == x.shape
330
- assert rowscale is None
331
- assert x1.stride(-1) == 1
332
- if weight1 is not None:
333
- assert weight1.shape == (N,)
334
- assert weight1.stride(-1) == 1
335
- if bias1 is not None:
336
- assert bias1.shape == (N,)
337
- assert bias1.stride(-1) == 1
338
- if rowscale is not None:
339
- assert rowscale.is_contiguous()
340
- assert rowscale.shape == (M,)
341
- # allocate output
342
- y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
343
- assert y.stride(-1) == 1
344
- if weight1 is not None:
345
- y1 = torch.empty_like(y)
346
- assert y1.stride(-1) == 1
347
- else:
348
- y1 = None
349
- if (
350
- residual is not None
351
- or (residual_dtype is not None and residual_dtype != x.dtype)
352
- or dropout_p > 0.0
353
- or rowscale is not None
354
- or x1 is not None
355
- ):
356
- residual_out = torch.empty(
357
- M,
358
- N,
359
- device=x.device,
360
- dtype=residual_dtype if residual_dtype is not None else x.dtype,
361
- )
362
- assert residual_out.stride(-1) == 1
363
- else:
364
- residual_out = None
365
- mean = (
366
- torch.empty((M,), dtype=torch.float32, device=x.device)
367
- if not is_rms_norm
368
- else None
369
- )
370
- rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
371
- if dropout_p > 0.0:
372
- seeds = torch.randint(
373
- 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
374
- )
375
- else:
376
- seeds = None
377
- if return_dropout_mask and dropout_p > 0.0:
378
- dropout_mask = torch.empty(
379
- M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool
380
- )
381
- else:
382
- dropout_mask = None
383
- # Less than 64KB per feature: enqueue fused kernel
384
- MAX_FUSED_SIZE = 65536 // x.element_size()
385
- BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
386
- if N > BLOCK_N:
387
- raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
388
- with torch.cuda.device(x.device.index):
389
- _layer_norm_fwd_1pass_kernel[(M,)](
390
- x,
391
- y,
392
- weight,
393
- bias,
394
- residual,
395
- x1,
396
- weight1,
397
- bias1,
398
- y1,
399
- residual_out,
400
- rowscale,
401
- seeds,
402
- dropout_mask,
403
- mean,
404
- rstd,
405
- x.stride(0),
406
- y.stride(0),
407
- residual.stride(0) if residual is not None else 0,
408
- residual_out.stride(0) if residual_out is not None else 0,
409
- x1.stride(0) if x1 is not None else 0,
410
- y1.stride(0) if y1 is not None else 0,
411
- M,
412
- N,
413
- eps,
414
- dropout_p,
415
- is_rms_norm,
416
- BLOCK_N,
417
- residual is not None,
418
- residual_out is not None,
419
- bias is not None,
420
- dropout_p > 0.0,
421
- dropout_mask is not None,
422
- rowscale is not None,
423
- )
424
- # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
425
- if dropout_mask is not None and x1 is not None:
426
- dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
427
- else:
428
- dropout_mask1 = None
429
- return (
430
- y,
431
- y1,
432
- mean,
433
- rstd,
434
- residual_out if residual_out is not None else x,
435
- seeds,
436
- dropout_mask,
437
- dropout_mask1,
438
- )
439
-
440
-
441
- @triton.autotune(
442
- configs=pruned_configs_autotune,
443
- key=[
444
- "N",
445
- "HAS_DRESIDUAL",
446
- "STORE_DRESIDUAL",
447
- "IS_RMS_NORM",
448
- "HAS_BIAS",
449
- "HAS_DROPOUT",
450
- ],
451
- )
452
- # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
453
- # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
454
- # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
455
- @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
456
- @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
457
- @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
458
- @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
459
- @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
460
- @triton.jit
461
- def _layer_norm_bwd_kernel(
462
- X, # pointer to the input
463
- W, # pointer to the weights
464
- B, # pointer to the biases
465
- Y, # pointer to the output to be recomputed
466
- DY, # pointer to the output gradient
467
- DX, # pointer to the input gradient
468
- DW, # pointer to the partial sum of weights gradient
469
- DB, # pointer to the partial sum of biases gradient
470
- DRESIDUAL,
471
- W1,
472
- DY1,
473
- DX1,
474
- DW1,
475
- DB1,
476
- DRESIDUAL_IN,
477
- ROWSCALE,
478
- SEEDS,
479
- Mean, # pointer to the mean
480
- Rstd, # pointer to the 1/std
481
- stride_x_row, # how much to increase the pointer when moving by 1 row
482
- stride_y_row,
483
- stride_dy_row,
484
- stride_dx_row,
485
- stride_dres_row,
486
- stride_dy1_row,
487
- stride_dx1_row,
488
- stride_dres_in_row,
489
- M, # number of rows in X
490
- N, # number of columns in X
491
- eps, # epsilon to avoid division by zero
492
- dropout_p,
493
- rows_per_program,
494
- IS_RMS_NORM: tl.constexpr,
495
- BLOCK_N: tl.constexpr,
496
- HAS_DRESIDUAL: tl.constexpr,
497
- STORE_DRESIDUAL: tl.constexpr,
498
- HAS_BIAS: tl.constexpr,
499
- HAS_DROPOUT: tl.constexpr,
500
- HAS_ROWSCALE: tl.constexpr,
501
- HAS_DY1: tl.constexpr,
502
- HAS_DX1: tl.constexpr,
503
- HAS_B1: tl.constexpr,
504
- RECOMPUTE_OUTPUT: tl.constexpr,
505
- ):
506
- # Map the program id to the elements of X, DX, and DY it should compute.
507
- row_block_id = tl.program_id(0)
508
- row_start = row_block_id * rows_per_program
509
- # Do not early exit if row_start >= M, because we need to write DW and DB
510
- cols = tl.arange(0, BLOCK_N)
511
- mask = cols < N
512
- X += row_start * stride_x_row
513
- if HAS_DRESIDUAL:
514
- DRESIDUAL += row_start * stride_dres_row
515
- if STORE_DRESIDUAL:
516
- DRESIDUAL_IN += row_start * stride_dres_in_row
517
- DY += row_start * stride_dy_row
518
- DX += row_start * stride_dx_row
519
- if HAS_DY1:
520
- DY1 += row_start * stride_dy1_row
521
- if HAS_DX1:
522
- DX1 += row_start * stride_dx1_row
523
- if RECOMPUTE_OUTPUT:
524
- Y += row_start * stride_y_row
525
- w = tl.load(W + cols, mask=mask).to(tl.float32)
526
- if RECOMPUTE_OUTPUT and HAS_BIAS:
527
- b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
528
- if HAS_DY1:
529
- w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
530
- dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
531
- if HAS_BIAS:
532
- db = tl.zeros((BLOCK_N,), dtype=tl.float32)
533
- if HAS_DY1:
534
- dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
535
- if HAS_B1:
536
- db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
537
- row_end = min((row_block_id + 1) * rows_per_program, M)
538
- for row in range(row_start, row_end):
539
- # Load data to SRAM
540
- x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
541
- dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
542
- if HAS_DY1:
543
- dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
544
- if not IS_RMS_NORM:
545
- mean = tl.load(Mean + row)
546
- rstd = tl.load(Rstd + row)
547
- # Compute dx
548
- xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
549
- xhat = tl.where(mask, xhat, 0.0)
550
- if RECOMPUTE_OUTPUT:
551
- y = xhat * w + b if HAS_BIAS else xhat * w
552
- tl.store(Y + cols, y, mask=mask)
553
- wdy = w * dy
554
- dw += dy * xhat
555
- if HAS_BIAS:
556
- db += dy
557
- if HAS_DY1:
558
- wdy += w1 * dy1
559
- dw1 += dy1 * xhat
560
- if HAS_B1:
561
- db1 += dy1
562
- if not IS_RMS_NORM:
563
- c1 = tl.sum(xhat * wdy, axis=0) / N
564
- c2 = tl.sum(wdy, axis=0) / N
565
- dx = (wdy - (xhat * c1 + c2)) * rstd
566
- else:
567
- c1 = tl.sum(xhat * wdy, axis=0) / N
568
- dx = (wdy - xhat * c1) * rstd
569
- if HAS_DRESIDUAL:
570
- dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
571
- dx += dres
572
- # Write dx
573
- if STORE_DRESIDUAL:
574
- tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
575
- if HAS_DX1:
576
- if HAS_DROPOUT:
577
- keep_mask = (
578
- tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
579
- > dropout_p
580
- )
581
- dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
582
- else:
583
- dx1 = dx
584
- tl.store(DX1 + cols, dx1, mask=mask)
585
- if HAS_DROPOUT:
586
- keep_mask = (
587
- tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7)
588
- > dropout_p
589
- )
590
- dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
591
- if HAS_ROWSCALE:
592
- rowscale = tl.load(ROWSCALE + row).to(tl.float32)
593
- dx *= rowscale
594
- tl.store(DX + cols, dx, mask=mask)
595
-
596
- X += stride_x_row
597
- if HAS_DRESIDUAL:
598
- DRESIDUAL += stride_dres_row
599
- if STORE_DRESIDUAL:
600
- DRESIDUAL_IN += stride_dres_in_row
601
- if RECOMPUTE_OUTPUT:
602
- Y += stride_y_row
603
- DY += stride_dy_row
604
- DX += stride_dx_row
605
- if HAS_DY1:
606
- DY1 += stride_dy1_row
607
- if HAS_DX1:
608
- DX1 += stride_dx1_row
609
- tl.store(DW + row_block_id * N + cols, dw, mask=mask)
610
- if HAS_BIAS:
611
- tl.store(DB + row_block_id * N + cols, db, mask=mask)
612
- if HAS_DY1:
613
- tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
614
- if HAS_B1:
615
- tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
616
-
617
-
618
- def _layer_norm_bwd(
619
- dy,
620
- x,
621
- weight,
622
- bias,
623
- eps,
624
- mean,
625
- rstd,
626
- dresidual=None,
627
- dy1=None,
628
- weight1=None,
629
- bias1=None,
630
- seeds=None,
631
- dropout_p=0.0,
632
- rowscale=None,
633
- has_residual=False,
634
- has_x1=False,
635
- is_rms_norm=False,
636
- x_dtype=None,
637
- recompute_output=False,
638
- ):
639
- M, N = x.shape
640
- assert x.stride(-1) == 1
641
- assert dy.stride(-1) == 1
642
- assert dy.shape == (M, N)
643
- if dresidual is not None:
644
- assert dresidual.stride(-1) == 1
645
- assert dresidual.shape == (M, N)
646
- assert weight.shape == (N,)
647
- assert weight.stride(-1) == 1
648
- if bias is not None:
649
- assert bias.stride(-1) == 1
650
- assert bias.shape == (N,)
651
- if dy1 is not None:
652
- assert weight1 is not None
653
- assert dy1.shape == dy.shape
654
- assert dy1.stride(-1) == 1
655
- if weight1 is not None:
656
- assert weight1.shape == (N,)
657
- assert weight1.stride(-1) == 1
658
- if bias1 is not None:
659
- assert bias1.shape == (N,)
660
- assert bias1.stride(-1) == 1
661
- if seeds is not None:
662
- assert seeds.is_contiguous()
663
- assert seeds.shape == (M if not has_x1 else M * 2,)
664
- if rowscale is not None:
665
- assert rowscale.is_contiguous()
666
- assert rowscale.shape == (M,)
667
- # allocate output
668
- dx = (
669
- torch.empty_like(x)
670
- if x_dtype is None
671
- else torch.empty(M, N, dtype=x_dtype, device=x.device)
672
- )
673
- dresidual_in = (
674
- torch.empty_like(x)
675
- if has_residual
676
- and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
677
- else None
678
- )
679
- dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
680
- y = (
681
- torch.empty(M, N, dtype=dy.dtype, device=dy.device)
682
- if recompute_output
683
- else None
684
- )
685
- if recompute_output:
686
- assert (
687
- weight1 is None
688
- ), "recompute_output is not supported with parallel LayerNorm"
689
-
690
- # Less than 64KB per feature: enqueue fused kernel
691
- MAX_FUSED_SIZE = 65536 // x.element_size()
692
- BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
693
- if N > BLOCK_N:
694
- raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
695
- sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
696
- _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
697
- _db = (
698
- torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
699
- if bias is not None
700
- else None
701
- )
702
- _dw1 = torch.empty_like(_dw) if weight1 is not None else None
703
- _db1 = torch.empty_like(_db) if bias1 is not None else None
704
- rows_per_program = math.ceil(M / sm_count)
705
- grid = (sm_count,)
706
- with torch.cuda.device(x.device.index):
707
- _layer_norm_bwd_kernel[grid](
708
- x,
709
- weight,
710
- bias,
711
- y,
712
- dy,
713
- dx,
714
- _dw,
715
- _db,
716
- dresidual,
717
- weight1,
718
- dy1,
719
- dx1,
720
- _dw1,
721
- _db1,
722
- dresidual_in,
723
- rowscale,
724
- seeds,
725
- mean,
726
- rstd,
727
- x.stride(0),
728
- 0 if not recompute_output else y.stride(0),
729
- dy.stride(0),
730
- dx.stride(0),
731
- dresidual.stride(0) if dresidual is not None else 0,
732
- dy1.stride(0) if dy1 is not None else 0,
733
- dx1.stride(0) if dx1 is not None else 0,
734
- dresidual_in.stride(0) if dresidual_in is not None else 0,
735
- M,
736
- N,
737
- eps,
738
- dropout_p,
739
- rows_per_program,
740
- is_rms_norm,
741
- BLOCK_N,
742
- dresidual is not None,
743
- dresidual_in is not None,
744
- bias is not None,
745
- dropout_p > 0.0,
746
- )
747
- dw = _dw.sum(0).to(weight.dtype)
748
- db = _db.sum(0).to(bias.dtype) if bias is not None else None
749
- dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
750
- db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
751
- # Don't need to compute dresidual_in separately in this case
752
- if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
753
- dresidual_in = dx
754
- if has_x1 and dropout_p == 0.0:
755
- dx1 = dx
756
- return (
757
- (dx, dw, db, dresidual_in, dx1, dw1, db1)
758
- if not recompute_output
759
- else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
760
- )
761
-
762
-
763
- class LayerNormFn(torch.autograd.Function):
764
- @staticmethod
765
- def forward(
766
- ctx,
767
- x,
768
- weight,
769
- bias,
770
- residual=None,
771
- x1=None,
772
- weight1=None,
773
- bias1=None,
774
- eps=1e-6,
775
- dropout_p=0.0,
776
- rowscale=None,
777
- prenorm=False,
778
- residual_in_fp32=False,
779
- is_rms_norm=False,
780
- return_dropout_mask=False,
781
- ):
782
- x_shape_og = x.shape
783
- # reshape input data into 2D tensor
784
- x = x.reshape(-1, x.shape[-1])
785
- if x.stride(-1) != 1:
786
- x = x.contiguous()
787
- if residual is not None:
788
- assert residual.shape == x_shape_og
789
- residual = residual.reshape(-1, residual.shape[-1])
790
- if residual.stride(-1) != 1:
791
- residual = residual.contiguous()
792
- if x1 is not None:
793
- assert x1.shape == x_shape_og
794
- assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
795
- x1 = x1.reshape(-1, x1.shape[-1])
796
- if x1.stride(-1) != 1:
797
- x1 = x1.contiguous()
798
- weight = weight.contiguous()
799
- if bias is not None:
800
- bias = bias.contiguous()
801
- if weight1 is not None:
802
- weight1 = weight1.contiguous()
803
- if bias1 is not None:
804
- bias1 = bias1.contiguous()
805
- if rowscale is not None:
806
- rowscale = rowscale.reshape(-1).contiguous()
807
- residual_dtype = (
808
- residual.dtype
809
- if residual is not None
810
- else (torch.float32 if residual_in_fp32 else None)
811
- )
812
- y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = (
813
- _layer_norm_fwd(
814
- x,
815
- weight,
816
- bias,
817
- eps,
818
- residual,
819
- x1,
820
- weight1,
821
- bias1,
822
- dropout_p=dropout_p,
823
- rowscale=rowscale,
824
- residual_dtype=residual_dtype,
825
- is_rms_norm=is_rms_norm,
826
- return_dropout_mask=return_dropout_mask,
827
- )
828
- )
829
- ctx.save_for_backward(
830
- residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
831
- )
832
- ctx.x_shape_og = x_shape_og
833
- ctx.eps = eps
834
- ctx.dropout_p = dropout_p
835
- ctx.is_rms_norm = is_rms_norm
836
- ctx.has_residual = residual is not None
837
- ctx.has_x1 = x1 is not None
838
- ctx.prenorm = prenorm
839
- ctx.x_dtype = x.dtype
840
- y = y.reshape(x_shape_og)
841
- y1 = y1.reshape(x_shape_og) if y1 is not None else None
842
- residual_out = (
843
- residual_out.reshape(x_shape_og) if residual_out is not None else None
844
- )
845
- dropout_mask = (
846
- dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
847
- )
848
- dropout_mask1 = (
849
- dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
850
- )
851
- if not return_dropout_mask:
852
- if weight1 is None:
853
- return y if not prenorm else (y, residual_out)
854
- else:
855
- return (y, y1) if not prenorm else (y, y1, residual_out)
856
- else:
857
- if weight1 is None:
858
- return (
859
- (y, dropout_mask, dropout_mask1)
860
- if not prenorm
861
- else (y, residual_out, dropout_mask, dropout_mask1)
862
- )
863
- else:
864
- return (
865
- (y, y1, dropout_mask, dropout_mask1)
866
- if not prenorm
867
- else (y, y1, residual_out, dropout_mask, dropout_mask1)
868
- )
869
-
870
- @staticmethod
871
- def backward(ctx, dy, *args):
872
- x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
873
- dy = dy.reshape(-1, dy.shape[-1])
874
- if dy.stride(-1) != 1:
875
- dy = dy.contiguous()
876
- assert dy.shape == x.shape
877
- if weight1 is not None:
878
- dy1, args = args[0], args[1:]
879
- dy1 = dy1.reshape(-1, dy1.shape[-1])
880
- if dy1.stride(-1) != 1:
881
- dy1 = dy1.contiguous()
882
- assert dy1.shape == x.shape
883
- else:
884
- dy1 = None
885
- if ctx.prenorm:
886
- dresidual = args[0]
887
- dresidual = dresidual.reshape(-1, dresidual.shape[-1])
888
- if dresidual.stride(-1) != 1:
889
- dresidual = dresidual.contiguous()
890
- assert dresidual.shape == x.shape
891
- else:
892
- dresidual = None
893
- dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
894
- dy,
895
- x,
896
- weight,
897
- bias,
898
- ctx.eps,
899
- mean,
900
- rstd,
901
- dresidual,
902
- dy1,
903
- weight1,
904
- bias1,
905
- seeds,
906
- ctx.dropout_p,
907
- rowscale,
908
- ctx.has_residual,
909
- ctx.has_x1,
910
- ctx.is_rms_norm,
911
- x_dtype=ctx.x_dtype,
912
- )
913
- return (
914
- dx.reshape(ctx.x_shape_og),
915
- dw,
916
- db,
917
- dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
918
- dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
919
- dw1,
920
- db1,
921
- None,
922
- None,
923
- None,
924
- None,
925
- None,
926
- None,
927
- None,
928
- )
929
-
930
-
931
- def layer_norm_fn(
932
- x,
933
- weight,
934
- bias,
935
- residual=None,
936
- x1=None,
937
- weight1=None,
938
- bias1=None,
939
- eps=1e-6,
940
- dropout_p=0.0,
941
- rowscale=None,
942
- prenorm=False,
943
- residual_in_fp32=False,
944
- is_rms_norm=False,
945
- return_dropout_mask=False,
946
- ):
947
- return LayerNormFn.apply(
948
- x,
949
- weight,
950
- bias,
951
- residual,
952
- x1,
953
- weight1,
954
- bias1,
955
- eps,
956
- dropout_p,
957
- rowscale,
958
- prenorm,
959
- residual_in_fp32,
960
- is_rms_norm,
961
- return_dropout_mask,
962
- )
963
-
964
-
965
- def rms_norm_fn(
966
- x,
967
- weight,
968
- bias,
969
- residual=None,
970
- x1=None,
971
- weight1=None,
972
- bias1=None,
973
- eps=1e-6,
974
- dropout_p=0.0,
975
- rowscale=None,
976
- prenorm=False,
977
- residual_in_fp32=False,
978
- return_dropout_mask=False,
979
- ):
980
- return LayerNormFn.apply(
981
- x,
982
- weight,
983
- bias,
984
- residual,
985
- x1,
986
- weight1,
987
- bias1,
988
- eps,
989
- dropout_p,
990
- rowscale,
991
- prenorm,
992
- residual_in_fp32,
993
- True,
994
- return_dropout_mask,
995
- )
996
-
997
-
998
- class RMSNorm(torch.nn.Module):
999
-
1000
- def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
1001
- factory_kwargs = {"device": device, "dtype": dtype}
1002
- super().__init__()
1003
- self.eps = eps
1004
- if dropout_p > 0.0:
1005
- self.drop = torch.nn.Dropout(dropout_p)
1006
- else:
1007
- self.drop = None
1008
- self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
1009
- self.register_parameter("bias", None)
1010
- self.reset_parameters()
1011
-
1012
- def reset_parameters(self):
1013
- torch.nn.init.ones_(self.weight)
1014
-
1015
- def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
1016
- return rms_norm_fn(
1017
- x,
1018
- self.weight,
1019
- self.bias,
1020
- residual=residual,
1021
- eps=self.eps,
1022
- dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
1023
- prenorm=prenorm,
1024
- residual_in_fp32=residual_in_fp32,
1025
- )
1026
-
1027
-
1028
- class LayerNormLinearFn(torch.autograd.Function):
1029
- @staticmethod
1030
- @custom_fwd
1031
- def forward(
1032
- ctx,
1033
- x,
1034
- norm_weight,
1035
- norm_bias,
1036
- linear_weight,
1037
- linear_bias,
1038
- residual=None,
1039
- eps=1e-6,
1040
- prenorm=False,
1041
- residual_in_fp32=False,
1042
- is_rms_norm=False,
1043
- ):
1044
- x_shape_og = x.shape
1045
- # reshape input data into 2D tensor
1046
- x = x.reshape(-1, x.shape[-1])
1047
- if x.stride(-1) != 1:
1048
- x = x.contiguous()
1049
- if residual is not None:
1050
- assert residual.shape == x_shape_og
1051
- residual = residual.reshape(-1, residual.shape[-1])
1052
- if residual.stride(-1) != 1:
1053
- residual = residual.contiguous()
1054
- norm_weight = norm_weight.contiguous()
1055
- if norm_bias is not None:
1056
- norm_bias = norm_bias.contiguous()
1057
- residual_dtype = (
1058
- residual.dtype
1059
- if residual is not None
1060
- else (torch.float32 if residual_in_fp32 else None)
1061
- )
1062
- y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
1063
- x,
1064
- norm_weight,
1065
- norm_bias,
1066
- eps,
1067
- residual,
1068
- out_dtype=(
1069
- None
1070
- if not torch.is_autocast_enabled()
1071
- else torch.get_autocast_gpu_dtype()
1072
- ),
1073
- residual_dtype=residual_dtype,
1074
- is_rms_norm=is_rms_norm,
1075
- )
1076
- y = y.reshape(x_shape_og)
1077
- dtype = (
1078
- torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
1079
- )
1080
- linear_weight = linear_weight.to(dtype)
1081
- linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
1082
- out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
1083
- # We don't store y, will be recomputed in the backward pass to save memory
1084
- ctx.save_for_backward(
1085
- residual_out, norm_weight, norm_bias, linear_weight, mean, rstd
1086
- )
1087
- ctx.x_shape_og = x_shape_og
1088
- ctx.eps = eps
1089
- ctx.is_rms_norm = is_rms_norm
1090
- ctx.has_residual = residual is not None
1091
- ctx.prenorm = prenorm
1092
- ctx.x_dtype = x.dtype
1093
- ctx.linear_bias_is_none = linear_bias is None
1094
- return out if not prenorm else (out, residual_out.reshape(x_shape_og))
1095
-
1096
- @staticmethod
1097
- @custom_bwd
1098
- def backward(ctx, dout, *args):
1099
- x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
1100
- dout = dout.reshape(-1, dout.shape[-1])
1101
- dy = F.linear(dout, linear_weight.t())
1102
- dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
1103
- if dy.stride(-1) != 1:
1104
- dy = dy.contiguous()
1105
- assert dy.shape == x.shape
1106
- if ctx.prenorm:
1107
- dresidual = args[0]
1108
- dresidual = dresidual.reshape(-1, dresidual.shape[-1])
1109
- if dresidual.stride(-1) != 1:
1110
- dresidual = dresidual.contiguous()
1111
- assert dresidual.shape == x.shape
1112
- else:
1113
- dresidual = None
1114
- dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
1115
- dy,
1116
- x,
1117
- norm_weight,
1118
- norm_bias,
1119
- ctx.eps,
1120
- mean,
1121
- rstd,
1122
- dresidual=dresidual,
1123
- has_residual=ctx.has_residual,
1124
- is_rms_norm=ctx.is_rms_norm,
1125
- x_dtype=ctx.x_dtype,
1126
- recompute_output=True,
1127
- )
1128
- dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
1129
- return (
1130
- dx.reshape(ctx.x_shape_og),
1131
- dnorm_weight,
1132
- dnorm_bias,
1133
- dlinear_weight,
1134
- dlinear_bias,
1135
- dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
1136
- None,
1137
- None,
1138
- None,
1139
- None,
1140
- )
1141
-
1142
-
1143
- def layer_norm_linear_fn(
1144
- x,
1145
- norm_weight,
1146
- norm_bias,
1147
- linear_weight,
1148
- linear_bias,
1149
- residual=None,
1150
- eps=1e-6,
1151
- prenorm=False,
1152
- residual_in_fp32=False,
1153
- is_rms_norm=False,
1154
- ):
1155
- return LayerNormLinearFn.apply(
1156
- x,
1157
- norm_weight,
1158
- norm_bias,
1159
- linear_weight,
1160
- linear_bias,
1161
- residual,
1162
- eps,
1163
- prenorm,
1164
- residual_in_fp32,
1165
- is_rms_norm,
1166
-    )

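For reference, below is a minimal usage sketch of the fused RMSNorm interface exposed by the deleted layer_norm.py above (the RMSNorm module and the rms_norm_fn functional form). This is only an illustrative sketch, not part of the commit: the tensor shapes are made up, and it assumes a CUDA device plus an installed mamba_ssm build that still ships mamba_ssm.ops.triton.layer_norm.

# Illustrative sketch only (not part of this commit). Assumes a CUDA device and
# that mamba_ssm.ops.triton.layer_norm is importable from the installed package.
import torch
from mamba_ssm.ops.triton.layer_norm import RMSNorm, rms_norm_fn

batch, seqlen, dim = 2, 128, 768  # made-up shapes
x = torch.randn(batch, seqlen, dim, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x)

norm = RMSNorm(dim, eps=1e-5, device="cuda", dtype=torch.float16)

# Pre-norm usage: returns the normalized output and the updated residual
# stream (x + residual), matching the prenorm branch of LayerNormFn above.
y, new_residual = norm(x, residual=residual, prenorm=True)

# Functional form of the same fused kernel; bias is None for RMSNorm.
y2 = rms_norm_fn(x, norm.weight, norm.bias, eps=1e-5)
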
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/selective_state_update.py DELETED
@@ -1,389 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
-
3
- """We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
4
- """
5
-
6
- import math
7
- import torch
8
- import torch.nn.functional as F
9
-
10
- import triton
11
- import triton.language as tl
12
-
13
- from einops import rearrange, repeat
14
-
15
- from .softplus import softplus
16
-
17
-
18
- @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
19
- @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
20
- @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
21
- @triton.heuristics(
22
- {
23
- "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
24
- is not None
25
- }
26
- )
27
- @triton.heuristics(
28
- {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
29
- )
30
- @triton.jit
31
- def _selective_scan_update_kernel(
32
- # Pointers to matrices
33
- state_ptr,
34
- x_ptr,
35
- dt_ptr,
36
- dt_bias_ptr,
37
- A_ptr,
38
- B_ptr,
39
- C_ptr,
40
- D_ptr,
41
- z_ptr,
42
- out_ptr,
43
- state_batch_indices_ptr,
44
- # Matrix dimensions
45
- batch,
46
- nheads,
47
- dim,
48
- dstate,
49
- nheads_ngroups_ratio,
50
- # Strides
51
- stride_state_batch,
52
- stride_state_head,
53
- stride_state_dim,
54
- stride_state_dstate,
55
- stride_x_batch,
56
- stride_x_head,
57
- stride_x_dim,
58
- stride_dt_batch,
59
- stride_dt_head,
60
- stride_dt_dim,
61
- stride_dt_bias_head,
62
- stride_dt_bias_dim,
63
- stride_A_head,
64
- stride_A_dim,
65
- stride_A_dstate,
66
- stride_B_batch,
67
- stride_B_group,
68
- stride_B_dstate,
69
- stride_C_batch,
70
- stride_C_group,
71
- stride_C_dstate,
72
- stride_D_head,
73
- stride_D_dim,
74
- stride_z_batch,
75
- stride_z_head,
76
- stride_z_dim,
77
- stride_out_batch,
78
- stride_out_head,
79
- stride_out_dim,
80
- # Meta-parameters
81
- DT_SOFTPLUS: tl.constexpr,
82
- TIE_HDIM: tl.constexpr,
83
- BLOCK_SIZE_M: tl.constexpr,
84
- HAS_DT_BIAS: tl.constexpr,
85
- HAS_D: tl.constexpr,
86
- HAS_Z: tl.constexpr,
87
- HAS_STATE_BATCH_INDICES: tl.constexpr,
88
- BLOCK_SIZE_DSTATE: tl.constexpr,
89
- ):
90
- pid_m = tl.program_id(axis=0)
91
- pid_b = tl.program_id(axis=1)
92
- pid_h = tl.program_id(axis=2)
93
-
94
- if HAS_STATE_BATCH_INDICES:
95
- state_batch_indices_ptr += pid_b
96
- state_batch_idx = tl.load(state_batch_indices_ptr)
97
- state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
98
- else:
99
- state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
100
-
101
- x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
102
- dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
103
- if HAS_DT_BIAS:
104
- dt_bias_ptr += pid_h * stride_dt_bias_head
105
- A_ptr += pid_h * stride_A_head
106
- B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
107
- C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
108
- if HAS_Z:
109
- z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
110
- out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
111
-
112
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
113
- offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
114
- state_ptrs = state_ptr + (
115
- offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
116
- )
117
- x_ptrs = x_ptr + offs_m * stride_x_dim
118
- dt_ptrs = dt_ptr + offs_m * stride_dt_dim
119
- if HAS_DT_BIAS:
120
- dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
121
- if HAS_D:
122
- D_ptr += pid_h * stride_D_head
123
- A_ptrs = A_ptr + (
124
- offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
125
- )
126
- B_ptrs = B_ptr + offs_n * stride_B_dstate
127
- C_ptrs = C_ptr + offs_n * stride_C_dstate
128
- if HAS_D:
129
- D_ptrs = D_ptr + offs_m * stride_D_dim
130
- if HAS_Z:
131
- z_ptrs = z_ptr + offs_m * stride_z_dim
132
- out_ptrs = out_ptr + offs_m * stride_out_dim
133
-
134
- state = tl.load(
135
- state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
136
- )
137
- x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
138
- if not TIE_HDIM:
139
- dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
140
- if HAS_DT_BIAS:
141
- dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
142
- if DT_SOFTPLUS:
143
- dt = tl.where(dt <= 20.0, softplus(dt), dt)
144
- A = tl.load(
145
- A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
146
- ).to(tl.float32)
147
- dA = tl.exp(A * dt[:, None])
148
- else:
149
- dt = tl.load(dt_ptr).to(tl.float32)
150
- if HAS_DT_BIAS:
151
- dt += tl.load(dt_bias_ptr).to(tl.float32)
152
- if DT_SOFTPLUS:
153
- dt = tl.where(dt <= 20.0, softplus(dt), dt)
154
- A = tl.load(A_ptr).to(tl.float32)
155
- dA = tl.exp(A * dt) # scalar, not a matrix
156
-
157
- B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
158
- C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
159
- if HAS_D:
160
- D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
161
- if HAS_Z:
162
- z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
163
-
164
- if not TIE_HDIM:
165
- dB = B[None, :] * dt[:, None]
166
- else:
167
- dB = B * dt # vector of size (dstate,)
168
- state = state * dA + dB * x[:, None]
169
- tl.store(
170
- state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
171
- )
172
- out = tl.sum(state * C[None, :], axis=1)
173
- if HAS_D:
174
- out += x * D
175
- if HAS_Z:
176
- out *= z * tl.sigmoid(z)
177
- tl.store(out_ptrs, out, mask=offs_m < dim)
178
-
179
-
180
- def selective_state_update(
181
- state,
182
- x,
183
- dt,
184
- A,
185
- B,
186
- C,
187
- D=None,
188
- z=None,
189
- dt_bias=None,
190
- dt_softplus=False,
191
- state_batch_indices=None,
192
- ):
193
- """
194
- Argument:
195
- state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
196
- x: (batch, dim) or (batch, nheads, dim)
197
- dt: (batch, dim) or (batch, nheads, dim)
198
- A: (dim, dstate) or (nheads, dim, dstate)
199
- B: (batch, dstate) or (batch, ngroups, dstate)
200
- C: (batch, dstate) or (batch, ngroups, dstate)
201
- D: (dim,) or (nheads, dim)
202
- z: (batch, dim) or (batch, nheads, dim)
203
- dt_bias: (dim,) or (nheads, dim)
204
- Return:
205
- out: (batch, dim) or (batch, nheads, dim)
206
- """
207
- has_heads = state.dim() > 3
208
- if state.dim() == 3:
209
- state = state.unsqueeze(1)
210
- if x.dim() == 2:
211
- x = x.unsqueeze(1)
212
- if dt.dim() == 2:
213
- dt = dt.unsqueeze(1)
214
- if A.dim() == 2:
215
- A = A.unsqueeze(0)
216
- if B.dim() == 2:
217
- B = B.unsqueeze(1)
218
- if C.dim() == 2:
219
- C = C.unsqueeze(1)
220
- if D is not None and D.dim() == 1:
221
- D = D.unsqueeze(0)
222
- if z is not None and z.dim() == 2:
223
- z = z.unsqueeze(1)
224
- if dt_bias is not None and dt_bias.dim() == 1:
225
- dt_bias = dt_bias.unsqueeze(0)
226
- _, nheads, dim, dstate = state.shape
227
- batch = x.shape[0]
228
- if x.shape != (batch, nheads, dim):
229
- print(f"{state.shape} {x.shape} {batch} {nheads} {dim}")
230
- assert x.shape == (batch, nheads, dim)
231
- assert dt.shape == x.shape
232
- assert A.shape == (nheads, dim, dstate)
233
- ngroups = B.shape[1]
234
- assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
235
- assert B.shape == (batch, ngroups, dstate)
236
- assert C.shape == B.shape
237
- if D is not None:
238
- assert D.shape == (nheads, dim)
239
- if z is not None:
240
- assert z.shape == x.shape
241
- if dt_bias is not None:
242
- assert dt_bias.shape == (nheads, dim)
243
- if state_batch_indices is not None:
244
- assert state_batch_indices.shape == (batch,)
245
- out = torch.empty_like(x)
246
- grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
247
- z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
248
- # We don't want autotune since it will overwrite the state
249
- # We instead tune by hand.
250
- BLOCK_SIZE_M, num_warps = (
251
- (32, 4)
252
- if dstate <= 16
253
- else (
254
- (16, 4)
255
- if dstate <= 32
256
- else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
257
- )
258
- )
259
- tie_hdim = (
260
- A.stride(-1) == 0
261
- and A.stride(-2) == 0
262
- and dt.stride(-1) == 0
263
- and dt_bias.stride(-1) == 0
264
- )
265
- with torch.cuda.device(x.device.index):
266
- _selective_scan_update_kernel[grid](
267
- state,
268
- x,
269
- dt,
270
- dt_bias,
271
- A,
272
- B,
273
- C,
274
- D,
275
- z,
276
- out,
277
- state_batch_indices,
278
- batch,
279
- nheads,
280
- dim,
281
- dstate,
282
- nheads // ngroups,
283
- state.stride(0),
284
- state.stride(1),
285
- state.stride(2),
286
- state.stride(3),
287
- x.stride(0),
288
- x.stride(1),
289
- x.stride(2),
290
- dt.stride(0),
291
- dt.stride(1),
292
- dt.stride(2),
293
- *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
294
- A.stride(0),
295
- A.stride(1),
296
- A.stride(2),
297
- B.stride(0),
298
- B.stride(1),
299
- B.stride(2),
300
- C.stride(0),
301
- C.stride(1),
302
- C.stride(2),
303
- *(D.stride(0), D.stride(1)) if D is not None else 0,
304
- z_strides[0],
305
- z_strides[1],
306
- z_strides[2],
307
- out.stride(0),
308
- out.stride(1),
309
- out.stride(2),
310
- dt_softplus,
311
- tie_hdim,
312
- BLOCK_SIZE_M,
313
- num_warps=num_warps,
314
- )
315
- if not has_heads:
316
- out = out.squeeze(1)
317
- return out
318
-
319
-
320
- def selective_state_update_ref(
321
- state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
322
- ):
323
- """
324
- Argument:
325
- state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
326
- x: (batch, dim) or (batch, nheads, dim)
327
- dt: (batch, dim) or (batch, nheads, dim)
328
- A: (dim, dstate) or (nheads, dim, dstate)
329
- B: (batch, dstate) or (batch, ngroups, dstate)
330
- C: (batch, dstate) or (batch, ngroups, dstate)
331
- D: (dim,) or (nheads, dim)
332
- z: (batch, dim) or (batch, nheads, dim)
333
- dt_bias: (dim,) or (nheads, dim)
334
- Return:
335
- out: (batch, dim) or (batch, nheads, dim)
336
- """
337
- has_heads = state.dim() > 3
338
- if state.dim() == 3:
339
- state = state.unsqueeze(1)
340
- if x.dim() == 2:
341
- x = x.unsqueeze(1)
342
- if dt.dim() == 2:
343
- dt = dt.unsqueeze(1)
344
- if A.dim() == 2:
345
- A = A.unsqueeze(0)
346
- if B.dim() == 2:
347
- B = B.unsqueeze(1)
348
- if C.dim() == 2:
349
- C = C.unsqueeze(1)
350
- if D is not None and D.dim() == 1:
351
- D = D.unsqueeze(0)
352
- if z is not None and z.dim() == 2:
353
- z = z.unsqueeze(1)
354
- if dt_bias is not None and dt_bias.dim() == 1:
355
- dt_bias = dt_bias.unsqueeze(0)
356
- batch, nheads, dim, dstate = state.shape
357
- assert x.shape == (batch, nheads, dim)
358
- assert dt.shape == x.shape
359
- assert A.shape == (nheads, dim, dstate)
360
- ngroups = B.shape[1]
361
- assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
362
- assert B.shape == (batch, ngroups, dstate)
363
- assert C.shape == B.shape
364
- if D is not None:
365
- assert D.shape == (nheads, dim)
366
- if z is not None:
367
- assert z.shape == x.shape
368
- if dt_bias is not None:
369
- assert dt_bias.shape == (nheads, dim)
370
- dt = dt + dt_bias
371
- dt = F.softplus(dt) if dt_softplus else dt
372
- dA = torch.exp(
373
- rearrange(dt, "b h d -> b h d 1") * A
374
- ) # (batch, nheads, dim, dstate)
375
- B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
376
- C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
377
- dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
378
- B, "b h n -> b h 1 n"
379
- ) # (batch, nheads, dim, dstate)
380
- state.copy_(
381
- state * dA + dB * rearrange(x, "b h d -> b h d 1")
382
-     )  # (batch, nheads, dim, dstate)
383
- out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
384
- if D is not None:
385
- out += (x * D).to(out.dtype)
386
- out = (out if z is None else out * F.silu(z)).to(x.dtype)
387
- if not has_heads:
388
- out = out.squeeze(1)
389
-        return out

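For reference, below is a minimal single-token decode sketch of the selective_state_update kernel deleted above, compared against its pure-PyTorch reference selective_state_update_ref. This is only an illustrative sketch, not part of the commit: the shapes are made up, a CUDA device is assumed, and it assumes an installed mamba_ssm build that still ships mamba_ssm.ops.triton.selective_state_update.

# Illustrative sketch only (not part of this commit). Assumes a CUDA device and
# that mamba_ssm.ops.triton.selective_state_update is importable.
import torch
from mamba_ssm.ops.triton.selective_state_update import (
    selective_state_update,
    selective_state_update_ref,
)

batch, nheads, dim, dstate, ngroups = 2, 4, 64, 16, 1  # made-up shapes
device, dtype = "cuda", torch.float32

state = torch.randn(batch, nheads, dim, dstate, device=device, dtype=dtype)
x = torch.randn(batch, nheads, dim, device=device, dtype=dtype)
dt = torch.rand(batch, nheads, dim, device=device, dtype=dtype)
A = -torch.rand(nheads, dim, dstate, device=device, dtype=dtype)
B = torch.randn(batch, ngroups, dstate, device=device, dtype=dtype)
C = torch.randn(batch, ngroups, dstate, device=device, dtype=dtype)
D = torch.randn(nheads, dim, device=device, dtype=dtype)
dt_bias = torch.rand(nheads, dim, device=device, dtype=dtype)

state_ref = state.clone()
# The Triton kernel updates `state` in place and returns the token output.
out = selective_state_update(state, x, dt, A, B, C, D=D, dt_bias=dt_bias, dt_softplus=True)
out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, dt_bias=dt_bias, dt_softplus=True)
print((out - out_ref).abs().max(), (state - state_ref).abs().max())
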
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_scan.py DELETED
The diff for this file is too large to render. See raw diff
 
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_state.py DELETED
@@ -1,2012 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
-
3
- """We want triton==2.1.0 or 2.2.0 for this
4
- """
5
-
6
- import math
7
- import torch
8
- import torch.nn.functional as F
9
-
10
- import triton
11
- import triton.language as tl
12
-
13
- from einops import rearrange, repeat
14
-
15
- from .softplus import softplus
16
-
17
-
18
- def init_to_zero(names):
19
- return lambda nargs: [
20
- nargs[name].zero_() for name in names if nargs[name] is not None
21
- ]
22
-
23
-
24
- @triton.autotune(
25
- configs=[
26
- triton.Config({"BLOCK_SIZE_H": 1}),
27
- triton.Config({"BLOCK_SIZE_H": 2}),
28
- triton.Config({"BLOCK_SIZE_H": 4}),
29
- triton.Config({"BLOCK_SIZE_H": 8}),
30
- triton.Config({"BLOCK_SIZE_H": 16}),
31
- triton.Config({"BLOCK_SIZE_H": 32}),
32
- triton.Config({"BLOCK_SIZE_H": 64}),
33
- ],
34
- key=["chunk_size", "nheads"],
35
- )
36
- @triton.jit
37
- def _chunk_cumsum_fwd_kernel(
38
- # Pointers to matrices
39
- dt_ptr,
40
- A_ptr,
41
- dt_bias_ptr,
42
- dt_out_ptr,
43
- dA_cumsum_ptr,
44
- # Matrix dimension
45
- batch,
46
- seqlen,
47
- nheads,
48
- chunk_size,
49
- dt_min,
50
- dt_max,
51
- # Strides
52
- stride_dt_batch,
53
- stride_dt_seqlen,
54
- stride_dt_head,
55
- stride_A_head,
56
- stride_dt_bias_head,
57
- stride_dt_out_batch,
58
- stride_dt_out_chunk,
59
- stride_dt_out_head,
60
- stride_dt_out_csize,
61
- stride_dA_cs_batch,
62
- stride_dA_cs_chunk,
63
- stride_dA_cs_head,
64
- stride_dA_cs_csize,
65
- # Meta-parameters
66
- DT_SOFTPLUS: tl.constexpr,
67
- HAS_DT_BIAS: tl.constexpr,
68
- BLOCK_SIZE_H: tl.constexpr,
69
- BLOCK_SIZE_CHUNK: tl.constexpr,
70
- ):
71
- pid_b = tl.program_id(axis=0)
72
- pid_c = tl.program_id(axis=1)
73
- pid_h = tl.program_id(axis=2)
74
- dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
75
- dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
76
- dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk
77
-
78
- offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
79
- offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
80
- dt_ptrs = dt_ptr + (
81
- offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
82
- )
83
- A_ptrs = A_ptr + offs_h * stride_A_head
84
- dt_out_ptrs = dt_out_ptr + (
85
- offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize
86
- )
87
- dA_cs_ptrs = dA_cumsum_ptr + (
88
- offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize
89
- )
90
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
91
-
92
- dt = tl.load(
93
- dt_ptrs,
94
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
95
- other=0.0,
96
- ).to(tl.float32)
97
- if HAS_DT_BIAS:
98
- dt_bias = tl.load(
99
- dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
100
- ).to(tl.float32)
101
- dt += dt_bias[:, None]
102
- if DT_SOFTPLUS:
103
- dt = tl.where(dt <= 20.0, softplus(dt), dt)
104
- # As of Triton 2.2.0, tl.clamp is not available yet
105
- # dt = tl.clamp(dt, dt_min, dt_max)
106
- dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
107
- dt = tl.where(
108
- (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
109
- )
110
- tl.store(
111
- dt_out_ptrs,
112
- dt,
113
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
114
- )
115
- A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
116
- dA = dt * A[:, None]
117
- dA_cs = tl.cumsum(dA, axis=1)
118
- tl.store(
119
- dA_cs_ptrs,
120
- dA_cs,
121
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
122
- )
123
-
124
-
125
- @triton.autotune(
126
- configs=[
127
- triton.Config(
128
- {"BLOCK_SIZE_H": 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
129
- ),
130
- triton.Config(
131
- {"BLOCK_SIZE_H": 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
132
- ),
133
- triton.Config(
134
- {"BLOCK_SIZE_H": 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
135
- ),
136
- triton.Config(
137
- {"BLOCK_SIZE_H": 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
138
- ),
139
- triton.Config(
140
- {"BLOCK_SIZE_H": 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
141
- ),
142
- triton.Config(
143
- {"BLOCK_SIZE_H": 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
144
- ),
145
- triton.Config(
146
- {"BLOCK_SIZE_H": 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
147
- ),
148
- ],
149
- key=["chunk_size", "nheads"],
150
- )
151
- @triton.jit
152
- def _chunk_cumsum_bwd_kernel(
153
- # Pointers to matrices
154
- ddA_ptr,
155
- ddt_out_ptr,
156
- dt_ptr,
157
- A_ptr,
158
- dt_bias_ptr,
159
- ddt_ptr,
160
- dA_ptr,
161
- ddt_bias_ptr,
162
- # Matrix dimensions
163
- batch,
164
- seqlen,
165
- nheads,
166
- chunk_size,
167
- dt_min,
168
- dt_max,
169
- # Strides
170
- stride_ddA_batch,
171
- stride_ddA_chunk,
172
- stride_ddA_head,
173
- stride_ddA_csize,
174
- stride_ddt_out_batch,
175
- stride_ddt_out_chunk,
176
- stride_ddt_out_head,
177
- stride_ddt_out_csize,
178
- stride_dt_batch,
179
- stride_dt_seqlen,
180
- stride_dt_head,
181
- stride_A_head,
182
- stride_dt_bias_head,
183
- stride_ddt_batch,
184
- stride_ddt_seqlen,
185
- stride_ddt_head,
186
- stride_dA_head,
187
- stride_ddt_bias_head,
188
- # Meta-parameters
189
- DT_SOFTPLUS: tl.constexpr,
190
- HAS_DT_BIAS: tl.constexpr,
191
- BLOCK_SIZE_H: tl.constexpr,
192
- BLOCK_SIZE_CHUNK: tl.constexpr,
193
- ):
194
- pid_b = tl.program_id(axis=0)
195
- pid_c = tl.program_id(axis=1)
196
- pid_h = tl.program_id(axis=2)
197
- ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk
198
- ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk
199
- dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
200
- ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen
201
-
202
- offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
203
- offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
204
- ddt_out_ptrs = ddt_out_ptr + (
205
- offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize
206
- )
207
- ddA_ptrs = ddA_ptr + (
208
- offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize
209
- )
210
- dt_ptrs = dt_ptr + (
211
- offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
212
- )
213
- ddt_ptrs = ddt_ptr + (
214
- offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen
215
- )
216
- A_ptrs = A_ptr + offs_h * stride_A_head
217
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
218
-
219
- ddA = tl.load(
220
- ddA_ptrs,
221
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
222
- other=0.0,
223
- ).to(tl.float32)
224
- ddt_out = tl.load(
225
- ddt_out_ptrs,
226
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
227
- other=0.0,
228
- ).to(tl.float32)
229
- A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
230
- ddt = ddA * A[:, None] + ddt_out
231
- dt = tl.load(
232
- dt_ptrs,
233
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
234
- other=0.0,
235
- ).to(tl.float32)
236
- if HAS_DT_BIAS:
237
- dt_bias = tl.load(
238
- dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
239
- ).to(tl.float32)
240
- dt += dt_bias[:, None]
241
- if DT_SOFTPLUS:
242
- dt_presoftplus = dt
243
-        dt = tl.where(dt <= 20.0, softplus(dt), dt)
244
- clamp_mask = (dt < dt_min) | (dt > dt_max)
245
- # As of Triton 2.2.0, tl.clamp is not available yet
246
- # dt = tl.clamp(dt, dt_min, dt_max)
247
- dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
248
- dt = tl.where(
249
- (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
250
- )
251
- ddt = tl.where(
252
- (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0
253
- )
254
- ddt = tl.where(clamp_mask, 0.0, ddt)
255
- if DT_SOFTPLUS:
256
- ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)
257
- tl.store(
258
- ddt_ptrs,
259
- ddt,
260
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
261
- )
262
- dA = tl.sum(ddA * dt, axis=1)
263
- tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)
264
- if HAS_DT_BIAS:
265
- ddt_bias = tl.sum(ddt, axis=1)
266
- tl.atomic_add(
267
- ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads
268
- )
269
-
270
-
271
- @triton.autotune(
272
- configs=[
273
- triton.Config(
274
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
275
- num_stages=3,
276
- num_warps=8,
277
- ),
278
- triton.Config(
279
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
280
- num_stages=4,
281
- num_warps=4,
282
- ),
283
- triton.Config(
284
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
285
- num_stages=4,
286
- num_warps=4,
287
- ),
288
- triton.Config(
289
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
290
- num_stages=4,
291
- num_warps=4,
292
- ),
293
- triton.Config(
294
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
295
- num_stages=4,
296
- num_warps=4,
297
- ),
298
- triton.Config(
299
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
300
- num_stages=4,
301
- num_warps=4,
302
- ),
303
- triton.Config(
304
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
305
- num_stages=5,
306
- num_warps=2,
307
- ),
308
- triton.Config(
309
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
310
- num_stages=5,
311
- num_warps=2,
312
- ),
313
- triton.Config(
314
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
315
- num_stages=4,
316
- num_warps=2,
317
- ),
318
- ],
319
- key=["hdim", "dstate", "chunk_size"],
320
- )
321
- @triton.jit
322
- def _chunk_state_fwd_kernel(
323
- # Pointers to matrices
324
- x_ptr,
325
- b_ptr,
326
- states_ptr,
327
- dt_ptr,
328
- dA_cumsum_ptr,
329
- seq_idx_ptr,
330
- # Matrix dimensions
331
- hdim,
332
- dstate,
333
- chunk_size,
334
- batch,
335
- seqlen,
336
- nheads_ngroups_ratio,
337
- # Strides
338
- stride_x_batch,
339
- stride_x_seqlen,
340
- stride_x_head,
341
- stride_x_hdim,
342
- stride_b_batch,
343
- stride_b_seqlen,
344
- stride_b_head,
345
- stride_b_dstate,
346
- stride_states_batch,
347
- stride_states_chunk,
348
- stride_states_head,
349
- stride_states_hdim,
350
- stride_states_dstate,
351
- stride_dt_batch,
352
- stride_dt_chunk,
353
- stride_dt_head,
354
- stride_dt_csize,
355
- stride_dA_cs_batch,
356
- stride_dA_cs_chunk,
357
- stride_dA_cs_head,
358
- stride_dA_cs_csize,
359
- stride_seq_idx_batch,
360
- stride_seq_idx_seqlen,
361
- # Meta-parameters
362
- HAS_SEQ_IDX: tl.constexpr,
363
- BLOCK_SIZE_M: tl.constexpr,
364
- BLOCK_SIZE_N: tl.constexpr,
365
- BLOCK_SIZE_K: tl.constexpr,
366
- ):
367
- pid_bc = tl.program_id(axis=1)
368
- pid_c = pid_bc // batch
369
- pid_b = pid_bc - pid_c * batch
370
- pid_h = tl.program_id(axis=2)
371
- num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
372
- pid_m = tl.program_id(axis=0) // num_pid_n
373
- pid_n = tl.program_id(axis=0) % num_pid_n
374
- b_ptr += (
375
- pid_b * stride_b_batch
376
- + pid_c * chunk_size * stride_b_seqlen
377
- + (pid_h // nheads_ngroups_ratio) * stride_b_head
378
- )
379
- x_ptr += (
380
- pid_b * stride_x_batch
381
- + pid_c * chunk_size * stride_x_seqlen
382
- + pid_h * stride_x_head
383
- )
384
- dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
385
- dA_cumsum_ptr += (
386
- pid_b * stride_dA_cs_batch
387
- + pid_c * stride_dA_cs_chunk
388
- + pid_h * stride_dA_cs_head
389
- )
390
- if HAS_SEQ_IDX:
391
- seq_idx_ptr += (
392
- pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
393
- )
394
-
395
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
396
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
397
- offs_k = tl.arange(0, BLOCK_SIZE_K)
398
- x_ptrs = x_ptr + (
399
- offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
400
- )
401
- b_ptrs = b_ptr + (
402
- offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
403
- )
404
- dt_ptrs = dt_ptr + offs_k * stride_dt_csize
405
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
406
- tl.float32
407
- )
408
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
409
- if HAS_SEQ_IDX:
410
- seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
411
-
412
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
413
- if HAS_SEQ_IDX:
414
- seq_idx_last = tl.load(
415
- seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
416
- )
417
-
418
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
419
- for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
420
- x = tl.load(
421
- x_ptrs,
422
- mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k),
423
- other=0.0,
424
- )
425
- b = tl.load(
426
- b_ptrs,
427
- mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate),
428
- other=0.0,
429
- ).to(tl.float32)
430
- dA_cs_k = tl.load(
431
- dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
432
- ).to(tl.float32)
433
- if HAS_SEQ_IDX:
434
- seq_idx_k = tl.load(
435
- seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1
436
- )
437
- dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
438
- tl.float32
439
- )
440
- if not HAS_SEQ_IDX:
441
- scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
442
- else:
443
- scale = tl.where(
444
- seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0
445
- )
446
- b *= scale[:, None]
447
- b = b.to(x_ptr.dtype.element_ty)
448
- acc += tl.dot(x, b)
449
- x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
450
- b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
451
- dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
452
- dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
453
- if HAS_SEQ_IDX:
454
- seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
455
- states = acc.to(states_ptr.dtype.element_ty)
456
-
457
- states_ptr += (
458
- pid_b * stride_states_batch
459
- + pid_c * stride_states_chunk
460
- + pid_h * stride_states_head
461
- )
462
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
463
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
464
- states_ptrs = states_ptr + (
465
- offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
466
- )
467
- c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
468
- tl.store(states_ptrs, states, mask=c_mask)
469
-
470
-
471
- @triton.autotune(
472
- configs=[
473
- triton.Config(
474
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
475
- num_stages=3,
476
- num_warps=8,
477
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
478
- ),
479
- triton.Config(
480
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
481
- num_stages=4,
482
- num_warps=4,
483
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
484
- ),
485
- triton.Config(
486
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
487
- num_stages=4,
488
- num_warps=4,
489
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
490
- ),
491
- triton.Config(
492
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
493
- num_stages=4,
494
- num_warps=4,
495
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
496
- ),
497
- triton.Config(
498
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
499
- num_stages=4,
500
- num_warps=4,
501
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
502
- ),
503
- triton.Config(
504
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
505
- num_stages=4,
506
- num_warps=4,
507
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
508
- ),
509
- triton.Config(
510
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
511
- num_stages=5,
512
- num_warps=4,
513
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
514
- ),
515
- triton.Config(
516
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
517
- num_stages=5,
518
- num_warps=4,
519
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
520
- ),
521
- triton.Config(
522
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
523
- num_stages=4,
524
- num_warps=4,
525
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
526
- ),
527
- ],
528
- key=["chunk_size", "hdim", "dstate"],
529
- )
530
- @triton.jit
531
- def _chunk_state_bwd_dx_kernel(
532
- # Pointers to matrices
533
- x_ptr,
534
- b_ptr,
535
- dstates_ptr,
536
- dt_ptr,
537
- dA_cumsum_ptr,
538
- dx_ptr,
539
- ddt_ptr,
540
- ddA_cumsum_ptr,
541
- # Matrix dimensions
542
- chunk_size,
543
- hdim,
544
- dstate,
545
- batch,
546
- seqlen,
547
- nheads_ngroups_ratio,
548
- # Strides
549
- stride_x_batch,
550
- stride_x_seqlen,
551
- stride_x_head,
552
- stride_x_hdim,
553
- stride_b_batch,
554
- stride_b_seqlen,
555
- stride_b_head,
556
- stride_b_dstate,
557
- stride_dstates_batch,
558
- stride_dstates_chunk,
559
- stride_states_head,
560
- stride_states_hdim,
561
- stride_states_dstate,
562
- stride_dt_batch,
563
- stride_dt_chunk,
564
- stride_dt_head,
565
- stride_dt_csize,
566
- stride_dA_cs_batch,
567
- stride_dA_cs_chunk,
568
- stride_dA_cs_head,
569
- stride_dA_cs_csize,
570
- stride_dx_batch,
571
- stride_dx_seqlen,
572
- stride_dx_head,
573
- stride_dx_hdim,
574
- stride_ddt_batch,
575
- stride_ddt_chunk,
576
- stride_ddt_head,
577
- stride_ddt_csize,
578
- stride_ddA_cs_batch,
579
- stride_ddA_cs_chunk,
580
- stride_ddA_cs_head,
581
- stride_ddA_cs_csize,
582
- # Meta-parameters
583
- BLOCK_SIZE_M: tl.constexpr,
584
- BLOCK_SIZE_N: tl.constexpr,
585
- BLOCK_SIZE_K: tl.constexpr,
586
- BLOCK_SIZE_DSTATE: tl.constexpr,
587
- ):
588
- pid_bc = tl.program_id(axis=1)
589
- pid_c = pid_bc // batch
590
- pid_b = pid_bc - pid_c * batch
591
- pid_h = tl.program_id(axis=2)
592
- num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
593
- pid_m = tl.program_id(axis=0) // num_pid_n
594
- pid_n = tl.program_id(axis=0) % num_pid_n
595
- x_ptr += (
596
- pid_b * stride_x_batch
597
- + pid_c * chunk_size * stride_x_seqlen
598
- + pid_h * stride_x_head
599
- )
600
- b_ptr += (
601
- pid_b * stride_b_batch
602
- + pid_c * chunk_size * stride_b_seqlen
603
- + (pid_h // nheads_ngroups_ratio) * stride_b_head
604
- )
605
- dstates_ptr += (
606
- pid_b * stride_dstates_batch
607
- + pid_c * stride_dstates_chunk
608
- + pid_h * stride_states_head
609
- )
610
- dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
611
- ddt_ptr += (
612
- pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
613
- )
614
- ddA_cumsum_ptr += (
615
- pid_b * stride_ddA_cs_batch
616
- + pid_c * stride_ddA_cs_chunk
617
- + pid_h * stride_ddA_cs_head
618
- )
619
- dA_cumsum_ptr += (
620
- pid_b * stride_dA_cs_batch
621
- + pid_c * stride_dA_cs_chunk
622
- + pid_h * stride_dA_cs_head
623
- )
624
-
625
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
626
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
627
-
628
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
629
- # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
630
- offs_k = tl.arange(
631
- 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
632
- )
633
- b_ptrs = b_ptr + (
634
- offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
635
- )
636
- dstates_ptrs = dstates_ptr + (
637
- offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
638
- )
639
- if BLOCK_SIZE_DSTATE <= 128:
640
- b = tl.load(
641
- b_ptrs,
642
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
643
- other=0.0,
644
- )
645
- dstates = tl.load(
646
- dstates_ptrs,
647
- mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
648
- other=0.0,
649
- )
650
- dstates = dstates.to(b_ptr.dtype.element_ty)
651
- acc = tl.dot(b, dstates)
652
- else:
653
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
654
- for k in range(0, dstate, BLOCK_SIZE_K):
655
- b = tl.load(
656
- b_ptrs,
657
- mask=(offs_m[:, None] < chunk_size_limit)
658
- & (offs_k[None, :] < dstate - k),
659
- other=0.0,
660
- )
661
- dstates = tl.load(
662
- dstates_ptrs,
663
- mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
664
- other=0.0,
665
- )
666
- dstates = dstates.to(b_ptr.dtype.element_ty)
667
- acc += tl.dot(b, dstates)
668
- b_ptrs += BLOCK_SIZE_K * stride_b_dstate
669
- dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
670
-
671
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
672
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
673
-
674
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
675
- tl.float32
676
- )
677
- dt_ptrs = dt_ptr + offs_m * stride_dt_csize
678
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
679
- dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
680
- tl.float32
681
- )
682
- dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
683
- acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
684
-
685
- x_ptrs = x_ptr + (
686
- offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
687
- )
688
- x = tl.load(
689
- x_ptrs,
690
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
691
- other=0.0,
692
- ).to(tl.float32)
693
- ddt = tl.sum(acc * x, axis=1)
694
- ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
695
- tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
696
- ddA_cs = -(ddt * dt_m)
697
- ddA_cs_last = -tl.sum(ddA_cs)
698
- ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
699
- tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
700
- tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last)
701
-
702
- dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty)
703
- dx_ptr += (
704
- pid_b * stride_dx_batch
705
- + pid_c * chunk_size * stride_dx_seqlen
706
- + pid_h * stride_dx_head
707
- )
708
- dx_ptrs = dx_ptr + (
709
- offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim
710
- )
711
- tl.store(
712
- dx_ptrs,
713
- dx,
714
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
715
- )
716
-
717
-
718
- @triton.autotune(
719
- configs=[
720
- triton.Config(
721
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128},
722
- num_stages=3,
723
- num_warps=4,
724
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
725
- ),
726
- triton.Config(
727
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32},
728
- num_stages=3,
729
- num_warps=4,
730
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
731
- ),
732
- triton.Config(
733
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128},
734
- num_stages=3,
735
- num_warps=4,
736
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
737
- ),
738
- triton.Config(
739
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64},
740
- num_stages=3,
741
- num_warps=4,
742
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
743
- ),
744
- triton.Config(
745
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64},
746
- num_stages=3,
747
- num_warps=4,
748
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
749
- ),
750
- triton.Config(
751
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32},
752
- num_stages=3,
753
- num_warps=4,
754
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
755
- ),
756
- triton.Config(
757
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64},
758
- num_stages=3,
759
- num_warps=4,
760
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
761
- ),
762
- triton.Config(
763
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32},
764
- num_stages=3,
765
- num_warps=4,
766
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
767
- ),
768
- ],
769
- key=["chunk_size", "dstate", "hdim"],
770
- )
771
- @triton.jit
772
- def _chunk_state_bwd_db_kernel(
773
- # Pointers to matrices
774
- x_ptr,
775
- dstates_ptr,
776
- b_ptr,
777
- dt_ptr,
778
- dA_cumsum_ptr,
779
- seq_idx_ptr,
780
- db_ptr,
781
- ddA_cumsum_ptr,
782
- # Matrix dimensions
783
- chunk_size,
784
- dstate,
785
- hdim,
786
- batch,
787
- seqlen,
788
- nheads,
789
- nheads_per_program,
790
- ngroups,
791
- # Strides
792
- stride_x_batch,
793
- stride_x_seqlen,
794
- stride_x_head,
795
- stride_x_hdim,
796
- stride_dstates_batch,
797
- stride_dstates_chunk,
798
- stride_states_head,
799
- stride_states_hdim,
800
- stride_states_dstate,
801
- stride_b_batch,
802
- stride_b_seqlen,
803
- stride_b_head,
804
- stride_b_dstate,
805
- stride_dt_batch,
806
- stride_dt_chunk,
807
- stride_dt_head,
808
- stride_dt_csize,
809
- stride_dA_cs_batch,
810
- stride_dA_cs_chunk,
811
- stride_dA_cs_head,
812
- stride_dA_cs_csize,
813
- stride_seq_idx_batch,
814
- stride_seq_idx_seqlen,
815
- stride_db_batch,
816
- stride_db_seqlen,
817
- stride_db_split,
818
- stride_db_group,
819
- stride_db_dstate,
820
- stride_ddA_cs_batch,
821
- stride_ddA_cs_chunk,
822
- stride_ddA_cs_head,
823
- stride_ddA_cs_csize,
824
- # Meta-parameters
825
- HAS_DDA_CS: tl.constexpr,
826
- HAS_SEQ_IDX: tl.constexpr,
827
- BLOCK_SIZE_M: tl.constexpr,
828
- BLOCK_SIZE_N: tl.constexpr,
829
- BLOCK_SIZE_K: tl.constexpr,
830
- ):
831
- pid_bc = tl.program_id(axis=1)
832
- pid_c = pid_bc // batch
833
- pid_b = pid_bc - pid_c * batch
834
- pid_sg = tl.program_id(axis=2)
835
- pid_s = pid_sg // ngroups
836
- pid_g = pid_sg - pid_s * ngroups
837
- num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
838
- pid_m = tl.program_id(axis=0) // num_pid_n
839
- pid_n = tl.program_id(axis=0) % num_pid_n
840
- x_ptr += (
841
- pid_b * stride_x_batch
842
- + pid_c * chunk_size * stride_x_seqlen
843
- + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head
844
- )
845
- db_ptr += (
846
- pid_b * stride_db_batch
847
- + pid_c * chunk_size * stride_db_seqlen
848
- + pid_g * stride_db_group
849
- + pid_s * stride_db_split
850
- )
851
- dstates_ptr += (
852
- pid_b * stride_dstates_batch
853
- + pid_c * stride_dstates_chunk
854
- + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
855
- * stride_states_head
856
- )
857
- dt_ptr += (
858
- pid_b * stride_dt_batch
859
- + pid_c * stride_dt_chunk
860
- + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head
861
- )
862
- dA_cumsum_ptr += (
863
- pid_b * stride_dA_cs_batch
864
- + pid_c * stride_dA_cs_chunk
865
- + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
866
- )
867
- if HAS_DDA_CS:
868
- b_ptr += (
869
- pid_b * stride_b_batch
870
- + pid_c * chunk_size * stride_b_seqlen
871
- + pid_g * stride_b_head
872
- )
873
- ddA_cumsum_ptr += (
874
- pid_b * stride_ddA_cs_batch
875
- + pid_c * stride_ddA_cs_chunk
876
- + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
877
- * stride_ddA_cs_head
878
- )
879
- if HAS_SEQ_IDX:
880
- seq_idx_ptr += (
881
- pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
882
- )
883
-
884
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
885
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
886
- offs_k = tl.arange(0, BLOCK_SIZE_K)
887
- x_ptrs = x_ptr + (
888
- offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim
889
- )
890
- dstates_ptrs = dstates_ptr + (
891
- offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim
892
- )
893
- dt_ptrs = dt_ptr + offs_m * stride_dt_csize
894
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
895
- if HAS_DDA_CS:
896
- b_ptrs = b_ptr + (
897
- offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate
898
- )
899
- ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
900
-
901
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
902
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
903
- if HAS_DDA_CS:
904
- b = tl.load(
905
- b_ptrs,
906
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
907
- other=0.0,
908
- ).to(tl.float32)
909
- if HAS_SEQ_IDX:
910
- seq_idx_m = tl.load(
911
- seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
912
- mask=offs_m < chunk_size_limit,
913
- other=-1,
914
- )
915
- seq_idx_last = tl.load(
916
- seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
917
- )
918
- nheads_iter = min(
919
- nheads_per_program, nheads // ngroups - pid_s * nheads_per_program
920
- )
921
- for h in range(nheads_iter):
922
- x = tl.load(
923
- x_ptrs,
924
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim),
925
- other=0.0,
926
- )
927
- dstates = tl.load(
928
- dstates_ptrs,
929
- mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate),
930
- other=0.0,
931
- )
932
- dstates = dstates.to(x_ptrs.dtype.element_ty)
933
- db = tl.dot(x, dstates)
934
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
935
- tl.float32
936
- )
937
- dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
938
- tl.float32
939
- )
940
- dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
941
- if not HAS_SEQ_IDX:
942
- scale = tl.exp(dA_cs_last - dA_cs_m)
943
- else:
944
- scale = tl.where(
945
- seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0
946
- )
947
- db *= (scale * dt_m)[:, None]
948
- if HAS_DDA_CS:
949
- # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
950
- ddA_cs = tl.sum(db * b, axis=1)
951
- tl.atomic_add(
952
- ddA_cumsum_ptrs + stride_ddA_cs_csize,
953
- ddA_cs,
954
- mask=offs_m < chunk_size - 1,
955
- )
956
- acc += db
957
- x_ptrs += stride_x_head
958
- dstates_ptrs += stride_states_head
959
- dt_ptrs += stride_dt_head
960
- dA_cumsum_ptr += stride_dA_cs_head
961
- dA_cumsum_ptrs += stride_dA_cs_head
962
- if HAS_DDA_CS:
963
- ddA_cumsum_ptrs += stride_ddA_cs_head
964
-
965
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
966
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
967
- # if HAS_SEQ_IDX:
968
- # seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
969
- # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
970
- # acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0)
971
- db_ptrs = db_ptr + (
972
- offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate
973
- )
974
- tl.store(
975
- db_ptrs,
976
- acc,
977
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
978
- )
979
-
980
-
981
- @triton.autotune(
982
- configs=[
983
- # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
984
- # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
985
- # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
986
- # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
987
- # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
988
- # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
989
- # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
990
- # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
991
- # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
992
- triton.Config(
993
- {"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
994
- num_stages=3,
995
- num_warps=4,
996
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
997
- ),
998
- triton.Config(
999
- {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1000
- num_stages=3,
1001
- num_warps=4,
1002
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1003
- ),
1004
- triton.Config(
1005
- {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1006
- num_stages=3,
1007
- num_warps=4,
1008
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1009
- ),
1010
- triton.Config(
1011
- {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1012
- num_stages=3,
1013
- num_warps=4,
1014
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1015
- ),
1016
- triton.Config(
1017
- {"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
1018
- num_stages=4,
1019
- num_warps=8,
1020
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1021
- ),
1022
- triton.Config(
1023
- {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1024
- num_stages=4,
1025
- num_warps=8,
1026
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1027
- ),
1028
- triton.Config(
1029
- {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1030
- num_stages=4,
1031
- num_warps=8,
1032
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1033
- ),
1034
- triton.Config(
1035
- {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1036
- num_stages=4,
1037
- num_warps=8,
1038
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1039
- ),
1040
- ],
1041
- key=["chunk_size", "hdim", "dstate"],
1042
- )
1043
- @triton.jit
1044
- def _chunk_state_bwd_ddAcs_stable_kernel(
1045
- # Pointers to matrices
1046
- x_ptr,
1047
- b_ptr,
1048
- dstates_ptr,
1049
- dt_ptr,
1050
- dA_cumsum_ptr,
1051
- seq_idx_ptr,
1052
- ddA_cumsum_ptr,
1053
- # Matrix dimensions
1054
- chunk_size,
1055
- hdim,
1056
- dstate,
1057
- batch,
1058
- seqlen,
1059
- nheads_ngroups_ratio,
1060
- # Strides
1061
- stride_x_batch,
1062
- stride_x_seqlen,
1063
- stride_x_head,
1064
- stride_x_hdim,
1065
- stride_b_batch,
1066
- stride_b_seqlen,
1067
- stride_b_head,
1068
- stride_b_dstate,
1069
- stride_dstates_batch,
1070
- stride_dstates_chunk,
1071
- stride_states_head,
1072
- stride_states_hdim,
1073
- stride_states_dstate,
1074
- stride_dt_batch,
1075
- stride_dt_chunk,
1076
- stride_dt_head,
1077
- stride_dt_csize,
1078
- stride_dA_cs_batch,
1079
- stride_dA_cs_chunk,
1080
- stride_dA_cs_head,
1081
- stride_dA_cs_csize,
1082
- stride_seq_idx_batch,
1083
- stride_seq_idx_seqlen,
1084
- stride_ddA_cs_batch,
1085
- stride_ddA_cs_chunk,
1086
- stride_ddA_cs_head,
1087
- stride_ddA_cs_csize,
1088
- # Meta-parameters
1089
- HAS_SEQ_IDX: tl.constexpr,
1090
- BLOCK_SIZE_M: tl.constexpr,
1091
- BLOCK_SIZE_N: tl.constexpr,
1092
- BLOCK_SIZE_K: tl.constexpr,
1093
- BLOCK_SIZE_DSTATE: tl.constexpr,
1094
- ):
1095
- pid_bc = tl.program_id(axis=1)
1096
- pid_c = pid_bc // batch
1097
- pid_b = pid_bc - pid_c * batch
1098
- pid_h = tl.program_id(axis=2)
1099
- num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
1100
- pid_m = tl.program_id(axis=0) // num_pid_n
1101
- pid_n = tl.program_id(axis=0) % num_pid_n
1102
- x_ptr += (
1103
- pid_b * stride_x_batch
1104
- + pid_c * chunk_size * stride_x_seqlen
1105
- + pid_h * stride_x_head
1106
- )
1107
- b_ptr += (
1108
- pid_b * stride_b_batch
1109
- + pid_c * chunk_size * stride_b_seqlen
1110
- + (pid_h // nheads_ngroups_ratio) * stride_b_head
1111
- )
1112
- dstates_ptr += (
1113
- pid_b * stride_dstates_batch
1114
- + pid_c * stride_dstates_chunk
1115
- + pid_h * stride_states_head
1116
- )
1117
- dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
1118
- ddA_cumsum_ptr += (
1119
- pid_b * stride_ddA_cs_batch
1120
- + pid_c * stride_ddA_cs_chunk
1121
- + pid_h * stride_ddA_cs_head
1122
- )
1123
- dA_cumsum_ptr += (
1124
- pid_b * stride_dA_cs_batch
1125
- + pid_c * stride_dA_cs_chunk
1126
- + pid_h * stride_dA_cs_head
1127
- )
1128
- if HAS_SEQ_IDX:
1129
- seq_idx_ptr += (
1130
- pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
1131
- )
1132
-
1133
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1134
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1135
-
1136
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
1137
- # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
1138
- offs_k = tl.arange(
1139
- 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
1140
- )
1141
- b_ptrs = b_ptr + (
1142
- offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
1143
- )
1144
- dstates_ptrs = dstates_ptr + (
1145
- offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
1146
- )
1147
- if BLOCK_SIZE_DSTATE <= 128:
1148
- b = tl.load(
1149
- b_ptrs,
1150
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
1151
- other=0.0,
1152
- )
1153
- dstates = tl.load(
1154
- dstates_ptrs,
1155
- mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
1156
- other=0.0,
1157
- )
1158
- dstates = dstates.to(b_ptr.dtype.element_ty)
1159
- acc = tl.dot(b, dstates)
1160
- else:
1161
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
1162
- for k in range(0, dstate, BLOCK_SIZE_K):
1163
- b = tl.load(
1164
- b_ptrs,
1165
- mask=(offs_m[:, None] < chunk_size_limit)
1166
- & (offs_k[None, :] < dstate - k),
1167
- other=0.0,
1168
- )
1169
- dstates = tl.load(
1170
- dstates_ptrs,
1171
- mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
1172
- other=0.0,
1173
- )
1174
- dstates = dstates.to(b_ptr.dtype.element_ty)
1175
- acc += tl.dot(b, dstates)
1176
- b_ptrs += BLOCK_SIZE_K * stride_b_dstate
1177
- dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
1178
-
1179
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1180
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1181
-
1182
- dA_cs_m = tl.load(
1183
- dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
1184
- ).to(tl.float32)
1185
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
1186
- tl.float32
1187
- )
1188
- if not HAS_SEQ_IDX:
1189
- scale = tl.exp(dA_cs_last - dA_cs_m)
1190
- else:
1191
- seq_idx_m = tl.load(
1192
- seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
1193
- mask=offs_m < chunk_size_limit,
1194
- other=-1,
1195
- )
1196
- seq_idx_last = tl.load(
1197
- seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
1198
- )
1199
- scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
1200
- acc *= scale[:, None]
1201
-
1202
- x_ptrs = x_ptr + (
1203
- offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
1204
- )
1205
- x = tl.load(
1206
- x_ptrs,
1207
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
1208
- other=0.0,
1209
- ).to(tl.float32)
1210
- dt_ptrs = dt_ptr + offs_m * stride_dt_csize
1211
- dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
1212
- ddt = tl.sum(acc * x, axis=1)
1213
- # ddA_cs = -(ddt * dt_m)
1214
- # Triton 2.2.0 errors if we have the cumsum here, so we just write it out
1215
- # then call torch.cumsum outside this kernel.
1216
- # ddA_cs = tl.cumsum(ddt * dt_m)
1217
- ddA_cs = ddt * dt_m
1218
- ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
1219
- # tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
1220
- tl.atomic_add(
1221
- ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1
1222
- )
1223
-
1224
-
1225
- @triton.autotune(
1226
- configs=[
1227
- triton.Config(
1228
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
1229
- num_stages=3,
1230
- num_warps=8,
1231
- ),
1232
- triton.Config(
1233
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
1234
- num_stages=4,
1235
- num_warps=4,
1236
- ),
1237
- triton.Config(
1238
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1239
- num_stages=4,
1240
- num_warps=4,
1241
- ),
1242
- triton.Config(
1243
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1244
- num_stages=4,
1245
- num_warps=4,
1246
- ),
1247
- triton.Config(
1248
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1249
- num_stages=4,
1250
- num_warps=4,
1251
- ),
1252
- triton.Config(
1253
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1254
- num_stages=4,
1255
- num_warps=4,
1256
- ),
1257
- triton.Config(
1258
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1259
- num_stages=5,
1260
- num_warps=2,
1261
- ),
1262
- triton.Config(
1263
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1264
- num_stages=5,
1265
- num_warps=2,
1266
- ),
1267
- triton.Config(
1268
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1269
- num_stages=4,
1270
- num_warps=2,
1271
- ),
1272
- ],
1273
- key=["hdim", "dstate", "chunk_size"],
1274
- )
1275
- @triton.jit
1276
- def _chunk_state_varlen_kernel(
1277
- # Pointers to matrices
1278
- x_ptr,
1279
- b_ptr,
1280
- dt_ptr,
1281
- dA_cumsum_ptr,
1282
- chunk_states_ptr,
1283
- cu_seqlens_ptr,
1284
- states_ptr,
1285
- # Matrix dimensions
1286
- hdim,
1287
- dstate,
1288
- chunk_size,
1289
- seqlen,
1290
- nheads_ngroups_ratio,
1291
- # Strides
1292
- stride_x_seqlen,
1293
- stride_x_head,
1294
- stride_x_hdim,
1295
- stride_b_seqlen,
1296
- stride_b_head,
1297
- stride_b_dstate,
1298
- stride_dt_chunk,
1299
- stride_dt_head,
1300
- stride_dt_csize,
1301
- stride_dA_cs_chunk,
1302
- stride_dA_cs_head,
1303
- stride_dA_cs_csize,
1304
- stride_chunk_states_chunk,
1305
- stride_chunk_states_head,
1306
- stride_chunk_states_hdim,
1307
- stride_chunk_states_dstate,
1308
- stride_states_batch,
1309
- stride_states_head,
1310
- stride_states_hdim,
1311
- stride_states_dstate,
1312
- # Meta-parameters
1313
- BLOCK_SIZE_M: tl.constexpr,
1314
- BLOCK_SIZE_N: tl.constexpr,
1315
- BLOCK_SIZE_K: tl.constexpr,
1316
- ):
1317
- pid_b = tl.program_id(axis=1)
1318
- pid_h = tl.program_id(axis=2)
1319
- num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
1320
- pid_m = tl.program_id(axis=0) // num_pid_n
1321
- pid_n = tl.program_id(axis=0) % num_pid_n
1322
- end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
1323
- pid_c = (end_idx - 1) // chunk_size
1324
- b_ptr += (
1325
- pid_c * chunk_size * stride_b_seqlen
1326
- + (pid_h // nheads_ngroups_ratio) * stride_b_head
1327
- )
1328
- x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
1329
- dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
1330
- dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
1331
- chunk_states_ptr += (
1332
- pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
1333
- )
1334
-
1335
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1336
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1337
- offs_k = tl.arange(0, BLOCK_SIZE_K)
1338
- x_ptrs = x_ptr + (
1339
- offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
1340
- )
1341
- b_ptrs = b_ptr + (
1342
- offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
1343
- )
1344
- dt_ptrs = dt_ptr + offs_k * stride_dt_csize
1345
- dA_cs_last = tl.load(
1346
- dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
1347
- ).to(tl.float32)
1348
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
1349
-
1350
- chunk_size_limit = end_idx - pid_c * chunk_size
1351
- start_idx = tl.load(cu_seqlens_ptr + pid_b)
1352
- start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
1353
-
1354
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
1355
- for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
1356
- x = tl.load(
1357
- x_ptrs,
1358
- mask=(offs_m[:, None] < hdim)
1359
- & (offs_k[None, :] < chunk_size_limit - k)
1360
- & (offs_k[None, :] >= start_idx_cur - k),
1361
- other=0.0,
1362
- )
1363
- b = tl.load(
1364
- b_ptrs,
1365
- mask=(offs_k[:, None] < chunk_size_limit - k)
1366
- & (offs_n[None, :] < dstate)
1367
- & (offs_k[:, None] >= start_idx_cur - k),
1368
- other=0.0,
1369
- ).to(tl.float32)
1370
- dA_cs_k = tl.load(
1371
- dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
1372
- ).to(tl.float32)
1373
- dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
1374
- tl.float32
1375
- )
1376
- scale = tl.where(
1377
- (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
1378
- tl.exp((dA_cs_last - dA_cs_k)) * dt_k,
1379
- 0.0,
1380
- )
1381
- b *= scale[:, None]
1382
- b = b.to(x_ptr.dtype.element_ty)
1383
- acc += tl.dot(x, b)
1384
- x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
1385
- b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
1386
- dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
1387
- dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
1388
-
1389
- # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
1390
- if start_idx < pid_c * chunk_size:
1391
- chunk_states_ptrs = chunk_states_ptr + (
1392
- offs_m[:, None] * stride_chunk_states_hdim
1393
- + offs_n[None, :] * stride_chunk_states_dstate
1394
- )
1395
- chunk_states = tl.load(
1396
- chunk_states_ptrs,
1397
- mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),
1398
- other=0.0,
1399
- ).to(tl.float32)
1400
- # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0)
1401
- scale = tl.exp(dA_cs_last)
1402
- acc += chunk_states * scale
1403
-
1404
- states = acc.to(states_ptr.dtype.element_ty)
1405
-
1406
- states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
1407
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1408
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1409
- states_ptrs = states_ptr + (
1410
- offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
1411
- )
1412
- c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
1413
- tl.store(states_ptrs, states, mask=c_mask)
1414
-
1415
-
1416
- def _chunk_cumsum_fwd(
1417
- dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))
1418
- ):
1419
- batch, seqlen, nheads = dt.shape
1420
- assert A.shape == (nheads,)
1421
- if dt_bias is not None:
1422
- assert dt_bias.shape == (nheads,)
1423
- nchunks = math.ceil(seqlen / chunk_size)
1424
- dt_out = torch.empty(
1425
- batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
1426
- )
1427
- dA_cumsum = torch.empty(
1428
- batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
1429
- )
1430
- grid_chunk_cs = lambda META: (
1431
- batch,
1432
- nchunks,
1433
- triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
1434
- )
1435
- with torch.cuda.device(dt.device.index):
1436
- _chunk_cumsum_fwd_kernel[grid_chunk_cs](
1437
- dt,
1438
- A,
1439
- dt_bias,
1440
- dt_out,
1441
- dA_cumsum,
1442
- batch,
1443
- seqlen,
1444
- nheads,
1445
- chunk_size,
1446
- dt_limit[0],
1447
- dt_limit[1],
1448
- dt.stride(0),
1449
- dt.stride(1),
1450
- dt.stride(2),
1451
- A.stride(0),
1452
- dt_bias.stride(0) if dt_bias is not None else 0,
1453
- dt_out.stride(0),
1454
- dt_out.stride(2),
1455
- dt_out.stride(1),
1456
- dt_out.stride(3),
1457
- dA_cumsum.stride(0),
1458
- dA_cumsum.stride(2),
1459
- dA_cumsum.stride(1),
1460
- dA_cumsum.stride(3),
1461
- dt_softplus,
1462
- HAS_DT_BIAS=dt_bias is not None,
1463
- BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
1464
- )
1465
- return dA_cumsum, dt_out
1466
-
1467
-
1468
- def _chunk_cumsum_bwd(
1469
- ddA,
1470
- ddt_out,
1471
- dt,
1472
- A,
1473
- dt_bias=None,
1474
- dt_softplus=False,
1475
- dt_limit=(0.0, float("inf")),
1476
- ddt=None,
1477
- ):
1478
- batch, seqlen, nheads = dt.shape
1479
- _, _, nchunks, chunk_size = ddA.shape
1480
- assert ddA.shape == (batch, nheads, nchunks, chunk_size)
1481
- assert ddt_out.shape == (batch, nheads, nchunks, chunk_size)
1482
- assert A.shape == (nheads,)
1483
- if dt_bias is not None:
1484
- assert dt_bias.shape == (nheads,)
1485
- ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)
1486
- else:
1487
- ddt_bias = None
1488
- if ddt is not None:
1489
- assert ddt.shape == dt.shape
1490
- else:
1491
- ddt = torch.empty_like(dt)
1492
- dA = torch.empty_like(A, dtype=torch.float32)
1493
- grid_chunk_cs = lambda META: (
1494
- batch,
1495
- nchunks,
1496
- triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
1497
- )
1498
- with torch.cuda.device(dt.device.index):
1499
- _chunk_cumsum_bwd_kernel[grid_chunk_cs](
1500
- ddA,
1501
- ddt_out,
1502
- dt,
1503
- A,
1504
- dt_bias,
1505
- ddt,
1506
- dA,
1507
- ddt_bias,
1508
- batch,
1509
- seqlen,
1510
- nheads,
1511
- chunk_size,
1512
- dt_limit[0],
1513
- dt_limit[1],
1514
- ddA.stride(0),
1515
- ddA.stride(2),
1516
- ddA.stride(1),
1517
- ddA.stride(3),
1518
- ddt_out.stride(0),
1519
- ddt_out.stride(2),
1520
- ddt_out.stride(1),
1521
- ddt_out.stride(3),
1522
- dt.stride(0),
1523
- dt.stride(1),
1524
- dt.stride(2),
1525
- A.stride(0),
1526
- dt_bias.stride(0) if dt_bias is not None else 0,
1527
- ddt.stride(0),
1528
- ddt.stride(1),
1529
- ddt.stride(2),
1530
- dA.stride(0),
1531
- ddt_bias.stride(0) if ddt_bias is not None else 0,
1532
- dt_softplus,
1533
- HAS_DT_BIAS=dt_bias is not None,
1534
- BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
1535
- )
1536
- return ddt, dA, ddt_bias
1537
-
1538
-
1539
- def _chunk_state_fwd(
1540
- B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True
1541
- ):
1542
- batch, seqlen, nheads, headdim = x.shape
1543
- _, _, nchunks, chunk_size = dt.shape
1544
- _, _, ngroups, dstate = B.shape
1545
- assert nheads % ngroups == 0
1546
- assert B.shape == (batch, seqlen, ngroups, dstate)
1547
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1548
- assert dA_cumsum.shape == dt.shape
1549
- if seq_idx is not None:
1550
- assert seq_idx.shape == (batch, seqlen)
1551
- if states is not None:
1552
- assert states.shape == (batch, nchunks, nheads, headdim, dstate)
1553
- else:
1554
- states_dtype = torch.float32 if states_in_fp32 else B.dtype
1555
- states = torch.empty(
1556
- (batch, nchunks, nheads, headdim, dstate),
1557
- device=x.device,
1558
- dtype=states_dtype,
1559
- )
1560
- grid = lambda META: (
1561
- triton.cdiv(headdim, META["BLOCK_SIZE_M"])
1562
- * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
1563
- batch * nchunks,
1564
- nheads,
1565
- )
1566
- with torch.cuda.device(x.device.index):
1567
- _chunk_state_fwd_kernel[grid](
1568
- x,
1569
- B,
1570
- states,
1571
- dt,
1572
- dA_cumsum,
1573
- seq_idx,
1574
- headdim,
1575
- dstate,
1576
- chunk_size,
1577
- batch,
1578
- seqlen,
1579
- nheads // ngroups,
1580
- x.stride(0),
1581
- x.stride(1),
1582
- x.stride(2),
1583
- x.stride(3),
1584
- B.stride(0),
1585
- B.stride(1),
1586
- B.stride(2),
1587
- B.stride(-1),
1588
- states.stride(0),
1589
- states.stride(1),
1590
- states.stride(2),
1591
- states.stride(3),
1592
- states.stride(4),
1593
- dt.stride(0),
1594
- dt.stride(2),
1595
- dt.stride(1),
1596
- dt.stride(3),
1597
- dA_cumsum.stride(0),
1598
- dA_cumsum.stride(2),
1599
- dA_cumsum.stride(1),
1600
- dA_cumsum.stride(3),
1601
- *(
1602
- (seq_idx.stride(0), seq_idx.stride(1))
1603
- if seq_idx is not None
1604
- else (0, 0)
1605
- ),
1606
- HAS_SEQ_IDX=seq_idx is not None,
1607
- )
1608
- return states
1609
-
1610
-
1611
- def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None):
1612
- batch, seqlen, nheads, headdim = x.shape
1613
- _, _, nchunks, chunk_size = dt.shape
1614
- _, _, ngroups, dstate = B.shape
1615
- assert nheads % ngroups == 0
1616
- assert B.shape == (batch, seqlen, ngroups, dstate)
1617
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1618
- assert dA_cumsum.shape == dt.shape
1619
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1620
- if dx is not None:
1621
- assert dx.shape == x.shape
1622
- else:
1623
- dx = torch.empty_like(x)
1624
- ddt = torch.empty(
1625
- batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
1626
- )
1627
- ddA_cumsum = torch.empty(
1628
- batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32
1629
- )
1630
- grid_dx = lambda META: (
1631
- triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
1632
- * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
1633
- batch * nchunks,
1634
- nheads,
1635
- )
1636
- with torch.cuda.device(x.device.index):
1637
- _chunk_state_bwd_dx_kernel[grid_dx](
1638
- x,
1639
- B,
1640
- dstates,
1641
- dt,
1642
- dA_cumsum,
1643
- dx,
1644
- ddt,
1645
- ddA_cumsum,
1646
- chunk_size,
1647
- headdim,
1648
- dstate,
1649
- batch,
1650
- seqlen,
1651
- nheads // ngroups,
1652
- x.stride(0),
1653
- x.stride(1),
1654
- x.stride(2),
1655
- x.stride(3),
1656
- B.stride(0),
1657
- B.stride(1),
1658
- B.stride(2),
1659
- B.stride(-1),
1660
- dstates.stride(0),
1661
- dstates.stride(1),
1662
- dstates.stride(2),
1663
- dstates.stride(3),
1664
- dstates.stride(4),
1665
- dt.stride(0),
1666
- dt.stride(2),
1667
- dt.stride(1),
1668
- dt.stride(3),
1669
- dA_cumsum.stride(0),
1670
- dA_cumsum.stride(2),
1671
- dA_cumsum.stride(1),
1672
- dA_cumsum.stride(3),
1673
- dx.stride(0),
1674
- dx.stride(1),
1675
- dx.stride(2),
1676
- dx.stride(3),
1677
- ddt.stride(0),
1678
- ddt.stride(2),
1679
- ddt.stride(1),
1680
- ddt.stride(3),
1681
- ddA_cumsum.stride(0),
1682
- ddA_cumsum.stride(2),
1683
- ddA_cumsum.stride(1),
1684
- ddA_cumsum.stride(3),
1685
- BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
1686
- )
1687
- return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype)
1688
-
1689
-
1690
- def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1):
1691
- batch, seqlen, nheads, headdim = x.shape
1692
- _, _, nchunks, chunk_size = dt.shape
1693
- dstate = dstates.shape[-1]
1694
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1695
- assert dA_cumsum.shape == dt.shape
1696
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1697
- if seq_idx is not None:
1698
- assert seq_idx.shape == (batch, seqlen)
1699
- if B is not None:
1700
- assert B.shape == (batch, seqlen, ngroups, dstate)
1701
- B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3))
1702
- # Use torch.empty since the Triton kernel will call init_to_zero
1703
- ddA_cumsum = torch.empty(
1704
- batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
1705
- )
1706
- ddA_cumsum_strides = (
1707
- ddA_cumsum.stride(0),
1708
- ddA_cumsum.stride(2),
1709
- ddA_cumsum.stride(1),
1710
- ddA_cumsum.stride(3),
1711
- )
1712
- else:
1713
- B_strides = (0, 0, 0, 0)
1714
- ddA_cumsum = None
1715
- ddA_cumsum_strides = (0, 0, 0, 0)
1716
- nheads_ngroups_ratio = nheads // ngroups
1717
- sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
1718
- nheads_per_program = max(
1719
- min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1
1720
- )
1721
- nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
1722
- dB = torch.empty(
1723
- batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32
1724
- )
1725
- grid_db = lambda META: (
1726
- triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
1727
- * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
1728
- batch * nchunks,
1729
- nsplits * ngroups,
1730
- )
1731
- with torch.cuda.device(x.device.index):
1732
- _chunk_state_bwd_db_kernel[grid_db](
1733
- x,
1734
- dstates,
1735
- B,
1736
- dt,
1737
- dA_cumsum,
1738
- seq_idx,
1739
- dB,
1740
- ddA_cumsum,
1741
- chunk_size,
1742
- dstate,
1743
- headdim,
1744
- batch,
1745
- seqlen,
1746
- nheads,
1747
- nheads_per_program,
1748
- ngroups,
1749
- x.stride(0),
1750
- x.stride(1),
1751
- x.stride(2),
1752
- x.stride(3),
1753
- dstates.stride(0),
1754
- dstates.stride(1),
1755
- dstates.stride(2),
1756
- dstates.stride(3),
1757
- dstates.stride(4),
1758
- *B_strides,
1759
- dt.stride(0),
1760
- dt.stride(2),
1761
- dt.stride(1),
1762
- dt.stride(3),
1763
- dA_cumsum.stride(0),
1764
- dA_cumsum.stride(2),
1765
- dA_cumsum.stride(1),
1766
- dA_cumsum.stride(3),
1767
- *(
1768
- (seq_idx.stride(0), seq_idx.stride(1))
1769
- if seq_idx is not None
1770
- else (0, 0)
1771
- ),
1772
- dB.stride(0),
1773
- dB.stride(1),
1774
- dB.stride(2),
1775
- dB.stride(3),
1776
- dB.stride(4),
1777
- *ddA_cumsum_strides,
1778
- HAS_DDA_CS=ddA_cumsum is not None,
1779
- HAS_SEQ_IDX=seq_idx is not None,
1780
- BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
1781
- )
1782
- dB = dB.sum(2)
1783
- if ddA_cumsum is not None:
1784
- # The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute
1785
- # to the state of the chunk.
1786
- # torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
1787
- # But it's easier to just do the cumsum for all elements, the result will be the same.
1788
- torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum)
1789
- return dB if B is None else (dB, ddA_cumsum)
1790
-
1791
-
1792
- def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):
1793
- batch, seqlen, nheads, headdim = x.shape
1794
- _, _, nchunks, chunk_size = dt.shape
1795
- _, _, ngroups, dstate = B.shape
1796
- assert nheads % ngroups == 0
1797
- assert B.shape == (batch, seqlen, ngroups, dstate)
1798
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1799
- assert dA_cumsum.shape == dt.shape
1800
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1801
- if seq_idx is not None:
1802
- assert seq_idx.shape == (batch, seqlen)
1803
- # Use torch.empty since the Triton kernel will call init_to_zero
1804
- ddA_cumsum = torch.empty(
1805
- batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
1806
- )
1807
- grid_ddtcs = lambda META: (
1808
- triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
1809
- * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
1810
- batch * nchunks,
1811
- nheads,
1812
- )
1813
- with torch.cuda.device(x.device.index):
1814
- _chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs](
1815
- x,
1816
- B,
1817
- dstates,
1818
- dt,
1819
- dA_cumsum,
1820
- seq_idx,
1821
- ddA_cumsum,
1822
- chunk_size,
1823
- headdim,
1824
- dstate,
1825
- batch,
1826
- seqlen,
1827
- nheads // ngroups,
1828
- x.stride(0),
1829
- x.stride(1),
1830
- x.stride(2),
1831
- x.stride(3),
1832
- B.stride(0),
1833
- B.stride(1),
1834
- B.stride(2),
1835
- B.stride(-1),
1836
- dstates.stride(0),
1837
- dstates.stride(1),
1838
- dstates.stride(2),
1839
- dstates.stride(3),
1840
- dstates.stride(4),
1841
- dt.stride(0),
1842
- dt.stride(2),
1843
- dt.stride(1),
1844
- dt.stride(3),
1845
- dA_cumsum.stride(0),
1846
- dA_cumsum.stride(2),
1847
- dA_cumsum.stride(1),
1848
- dA_cumsum.stride(3),
1849
- *(
1850
- (seq_idx.stride(0), seq_idx.stride(1))
1851
- if seq_idx is not None
1852
- else (0, 0)
1853
- ),
1854
- ddA_cumsum.stride(0),
1855
- ddA_cumsum.stride(2),
1856
- ddA_cumsum.stride(1),
1857
- ddA_cumsum.stride(3),
1858
- HAS_SEQ_IDX=seq_idx is not None,
1859
- BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16),
1860
- BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
1861
- )
1862
- torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
1863
- return ddA_cumsum
1864
-
1865
-
1866
- def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):
1867
- total_seqlen, nheads, headdim = x.shape
1868
- _, nchunks, chunk_size = dt.shape
1869
- _, ngroups, dstate = B.shape
1870
- batch = cu_seqlens.shape[0] - 1
1871
- cu_seqlens = cu_seqlens.contiguous()
1872
- assert nheads % ngroups == 0
1873
- assert B.shape == (total_seqlen, ngroups, dstate)
1874
- assert dt.shape == (nheads, nchunks, chunk_size)
1875
- assert dA_cumsum.shape == dt.shape
1876
- assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
1877
- states = torch.empty(
1878
- batch,
1879
- nheads,
1880
- headdim,
1881
- dstate,
1882
- dtype=chunk_states.dtype,
1883
- device=chunk_states.device,
1884
- )
1885
- grid = lambda META: (
1886
- triton.cdiv(headdim, META["BLOCK_SIZE_M"])
1887
- * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
1888
- batch,
1889
- nheads,
1890
- )
1891
- with torch.cuda.device(x.device.index):
1892
- _chunk_state_varlen_kernel[grid](
1893
- x,
1894
- B,
1895
- dt,
1896
- dA_cumsum,
1897
- chunk_states,
1898
- cu_seqlens,
1899
- states,
1900
- headdim,
1901
- dstate,
1902
- chunk_size,
1903
- total_seqlen,
1904
- nheads // ngroups,
1905
- x.stride(0),
1906
- x.stride(1),
1907
- x.stride(2),
1908
- B.stride(0),
1909
- B.stride(1),
1910
- B.stride(2),
1911
- dt.stride(1),
1912
- dt.stride(0),
1913
- dt.stride(2),
1914
- dA_cumsum.stride(1),
1915
- dA_cumsum.stride(0),
1916
- dA_cumsum.stride(2),
1917
- chunk_states.stride(0),
1918
- chunk_states.stride(1),
1919
- chunk_states.stride(2),
1920
- chunk_states.stride(3),
1921
- states.stride(0),
1922
- states.stride(1),
1923
- states.stride(2),
1924
- states.stride(3),
1925
- )
1926
- return states
1927
-
1928
-
1929
- class ChunkStateFn(torch.autograd.Function):
1930
-
1931
- @staticmethod
1932
- def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True):
1933
- batch, seqlen, nheads, headdim = x.shape
1934
- _, _, nchunks, chunk_size = dt.shape
1935
- assert seqlen <= nchunks * chunk_size
1936
- _, _, ngroups, dstate = B.shape
1937
- assert B.shape == (batch, seqlen, ngroups, dstate)
1938
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1939
- assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
1940
- if B.stride(-1) != 1:
1941
- B = B.contiguous()
1942
- if (
1943
- x.stride(-1) != 1 and x.stride(1) != 1
1944
- ): # Either M or K dimension should be contiguous
1945
- x = x.contiguous()
1946
- states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32)
1947
- ctx.save_for_backward(B, x, dt, dA_cumsum)
1948
- return states
1949
-
1950
- @staticmethod
1951
- def backward(ctx, dstates):
1952
- B, x, dt, dA_cumsum = ctx.saved_tensors
1953
- batch, seqlen, nheads, headdim = x.shape
1954
- _, _, nchunks, chunk_size = dt.shape
1955
- _, _, ngroups, dstate = B.shape
1956
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1957
- if dstates.stride(-1) != 1:
1958
- dstates = dstates.contiguous()
1959
- dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates)
1960
- dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups)
1961
- dB = dB.to(B.dtype)
1962
- return dB, dx, ddt, ddA_cumsum, None
1963
-
1964
-
1965
- def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True):
1966
- """
1967
- Argument:
1968
- B: (batch, seqlen, ngroups, dstate)
1969
- x: (batch, seqlen, nheads, headdim)
1970
- dt: (batch, nheads, nchunks, chunk_size)
1971
- dA_cumsum: (batch, nheads, nchunks, chunk_size)
1972
- Return:
1973
- states: (batch, nchunks, nheads, headdim, dstate)
1974
- """
1975
- return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32)
1976
-
1977
-
1978
- def chunk_state_ref(B, x, dt, dA_cumsum):
1979
- """
1980
- Argument:
1981
- B: (batch, seqlen, ngroups, dstate)
1982
- x: (batch, seqlen, nheads, headdim)
1983
- dt: (batch, nheads, nchunks, chunk_size)
1984
- dA_cumsum: (batch, nheads, nchunks, chunk_size)
1985
- Return:
1986
- states: (batch, nchunks, nheads, headdim, dstate)
1987
- """
1988
- # Check constraints.
1989
- batch, seqlen, nheads, headdim = x.shape
1990
- dstate = B.shape[-1]
1991
- _, _, nchunks, chunk_size = dt.shape
1992
- assert seqlen <= nchunks * chunk_size
1993
- assert x.shape == (batch, seqlen, nheads, headdim)
1994
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1995
- ngroups = B.shape[2]
1996
- assert nheads % ngroups == 0
1997
- assert B.shape == (batch, seqlen, ngroups, dstate)
1998
- B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
1999
- assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
2000
- if seqlen < nchunks * chunk_size:
2001
- x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
2002
- B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
2003
- x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
2004
- B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
2005
- decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))
2006
- return torch.einsum(
2007
- "bclhn,bhcl,bhcl,bclhp->bchpn",
2008
- B.to(x.dtype),
2009
- decay_states.to(x.dtype),
2010
- dt.to(x.dtype),
2011
- x,
2012
- )
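A minimal usage sketch for the two functions above, assuming the import path from this build's directory layout (mamba_ssm.ops.triton.ssd_chunk_state); it cross-checks the Triton-backed chunk_state against the einsum-based chunk_state_ref on small random tensors shaped per the docstrings. It needs a CUDA device with triton and einops installed; the concrete sizes, seed, and the A = -0.5 decay used to build dA_cumsum are illustrative choices, not from the upstream code.

    import torch
    # Import path assumed from this build's layout; adjust if the package is installed differently.
    from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref

    torch.manual_seed(0)
    batch, seqlen, nheads, headdim = 1, 128, 4, 64
    ngroups, dstate, chunk_size = 1, 16, 64
    nchunks = (seqlen + chunk_size - 1) // chunk_size  # seqlen <= nchunks * chunk_size

    x = torch.randn(batch, seqlen, nheads, headdim, device="cuda")
    B = torch.randn(batch, seqlen, ngroups, dstate, device="cuda")
    dt = 0.1 * torch.rand(batch, nheads, nchunks, chunk_size, device="cuda")
    # dA_cumsum is a within-chunk cumulative sum of A * dt with A < 0 (here A = -0.5, purely illustrative).
    dA_cumsum = torch.cumsum(-0.5 * dt, dim=-1)

    states = chunk_state(B, x, dt, dA_cumsum)          # Triton kernel path: (batch, nchunks, nheads, headdim, dstate)
    states_ref = chunk_state_ref(B, x, dt, dA_cumsum)  # pure-PyTorch einsum reference
    # The difference should be small; the exact tolerance depends on dtype and TF32 settings.
    print((states - states_ref).abs().max().item())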
 
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_combined.py DELETED
@@ -1,1884 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
-
3
- """We want triton==2.1.0 or 2.2.0 for this
4
- """
5
-
6
- from typing import Optional
7
-
8
- import math
9
- from packaging import version
10
-
11
- import torch
12
- import torch.nn.functional as F
13
- from torch import Tensor
14
- from ...utils.torch import custom_bwd, custom_fwd
15
-
16
- import triton
17
- import triton.language as tl
18
-
19
- from einops import rearrange, repeat
20
-
21
- try:
22
- from causal_conv1d import causal_conv1d_fn
23
- import causal_conv1d_cuda
24
- except ImportError:
25
- causal_conv1d_fn, causal_conv1d_cuda = None, None
26
-
27
- from .ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
28
- from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
29
- from .ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
30
- from .ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
31
- from .ssd_chunk_state import chunk_state, chunk_state_ref
32
- from .ssd_chunk_state import chunk_state_varlen
33
- from .ssd_state_passing import _state_passing_fwd, _state_passing_bwd
34
- from .ssd_state_passing import state_passing, state_passing_ref
35
- from .ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates
36
- from .ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
37
- from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
38
- from .ssd_chunk_scan import chunk_scan, chunk_scan_ref
39
- from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
40
- from .layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
41
- from .k_activations import _swiglu_fwd, _swiglu_bwd
42
-
43
- TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
44
-
45
-
46
- def init_to_zero(names):
47
- return lambda nargs: [
48
- nargs[name].zero_() for name in names if nargs[name] is not None
49
- ]
50
-
51
-
52
- @triton.autotune(
53
- configs=[
54
- triton.Config(
55
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
56
- num_stages=3,
57
- num_warps=8,
58
- pre_hook=init_to_zero(["ddt_ptr"]),
59
- ),
60
- triton.Config(
61
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
62
- num_stages=4,
63
- num_warps=4,
64
- pre_hook=init_to_zero(["ddt_ptr"]),
65
- ),
66
- triton.Config(
67
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
68
- num_stages=4,
69
- num_warps=4,
70
- pre_hook=init_to_zero(["ddt_ptr"]),
71
- ),
72
- triton.Config(
73
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
74
- num_stages=4,
75
- num_warps=4,
76
- pre_hook=init_to_zero(["ddt_ptr"]),
77
- ),
78
- triton.Config(
79
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
80
- num_stages=4,
81
- num_warps=4,
82
- pre_hook=init_to_zero(["ddt_ptr"]),
83
- ),
84
- triton.Config(
85
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
86
- num_stages=4,
87
- num_warps=4,
88
- pre_hook=init_to_zero(["ddt_ptr"]),
89
- ),
90
- triton.Config(
91
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
92
- num_stages=5,
93
- num_warps=4,
94
- pre_hook=init_to_zero(["ddt_ptr"]),
95
- ),
96
- triton.Config(
97
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
98
- num_stages=5,
99
- num_warps=4,
100
- pre_hook=init_to_zero(["ddt_ptr"]),
101
- ),
102
- triton.Config(
103
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
104
- num_stages=4,
105
- num_warps=4,
106
- pre_hook=init_to_zero(["ddt_ptr"]),
107
- ),
108
- ],
109
- key=["chunk_size", "hdim", "dstate"],
110
- )
111
- @triton.jit
112
- def _chunk_scan_chunk_state_bwd_dx_kernel(
113
- # Pointers to matrices
114
- x_ptr,
115
- cb_ptr,
116
- dout_ptr,
117
- dt_ptr,
118
- dA_cumsum_ptr,
119
- seq_idx_ptr,
120
- D_ptr,
121
- b_ptr,
122
- dstates_ptr,
123
- dx_ptr,
124
- ddt_ptr,
125
- dD_ptr,
126
- # Matrix dimensions
127
- chunk_size,
128
- hdim,
129
- dstate,
130
- batch,
131
- seqlen,
132
- nheads_ngroups_ratio,
133
- # Strides
134
- stride_x_batch,
135
- stride_x_seqlen,
136
- stride_x_head,
137
- stride_x_hdim,
138
- stride_cb_batch,
139
- stride_cb_chunk,
140
- stride_cb_head,
141
- stride_cb_csize_m,
142
- stride_cb_csize_k,
143
- stride_dout_batch,
144
- stride_dout_seqlen,
145
- stride_dout_head,
146
- stride_dout_hdim,
147
- stride_dt_batch,
148
- stride_dt_chunk,
149
- stride_dt_head,
150
- stride_dt_csize,
151
- stride_dA_cs_batch,
152
- stride_dA_cs_chunk,
153
- stride_dA_cs_head,
154
- stride_dA_cs_csize,
155
- stride_seq_idx_batch,
156
- stride_seq_idx_seqlen,
157
- stride_D_head,
158
- stride_b_batch,
159
- stride_b_seqlen,
160
- stride_b_head,
161
- stride_b_dstate,
162
- stride_dstates_batch,
163
- stride_dstates_chunk,
164
- stride_dstates_head,
165
- stride_dstates_hdim,
166
- stride_dstates_dstate,
167
- stride_dx_batch,
168
- stride_dx_seqlen,
169
- stride_dx_head,
170
- stride_dx_hdim,
171
- stride_ddt_batch,
172
- stride_ddt_chunk,
173
- stride_ddt_head,
174
- stride_ddt_csize,
175
- stride_dD_batch,
176
- stride_dD_chunk,
177
- stride_dD_head,
178
- stride_dD_csize,
179
- stride_dD_hdim,
180
- # Meta-parameters
181
- HAS_D: tl.constexpr,
182
- D_HAS_HDIM: tl.constexpr,
183
- HAS_SEQ_IDX: tl.constexpr,
184
- BLOCK_SIZE_M: tl.constexpr,
185
- BLOCK_SIZE_N: tl.constexpr,
186
- BLOCK_SIZE_K: tl.constexpr,
187
- BLOCK_SIZE_DSTATE: tl.constexpr,
188
- IS_TRITON_22: tl.constexpr,
189
- ):
190
- pid_bc = tl.program_id(axis=1)
191
- pid_c = pid_bc // batch
192
- pid_b = pid_bc - pid_c * batch
193
- pid_h = tl.program_id(axis=2)
194
- num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
195
- pid_m = tl.program_id(axis=0) // num_pid_n
196
- pid_n = tl.program_id(axis=0) % num_pid_n
197
- x_ptr += (
198
- pid_b * stride_x_batch
199
- + pid_c * chunk_size * stride_x_seqlen
200
- + pid_h * stride_x_head
201
- )
202
- cb_ptr += (
203
- pid_b * stride_cb_batch
204
- + pid_c * stride_cb_chunk
205
- + (pid_h // nheads_ngroups_ratio) * stride_cb_head
206
- )
207
- dout_ptr += (
208
- pid_b * stride_dout_batch
209
- + pid_c * chunk_size * stride_dout_seqlen
210
- + pid_h * stride_dout_head
211
- )
212
- dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
213
- ddt_ptr += (
214
- pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
215
- )
216
- dA_cumsum_ptr += (
217
- pid_b * stride_dA_cs_batch
218
- + pid_c * stride_dA_cs_chunk
219
- + pid_h * stride_dA_cs_head
220
- )
221
- b_ptr += (
222
- pid_b * stride_b_batch
223
- + pid_c * chunk_size * stride_b_seqlen
224
- + (pid_h // nheads_ngroups_ratio) * stride_b_head
225
- )
226
- dstates_ptr += (
227
- pid_b * stride_dstates_batch
228
- + pid_c * stride_dstates_chunk
229
- + pid_h * stride_dstates_head
230
- )
231
- if HAS_SEQ_IDX:
232
- seq_idx_ptr += (
233
- pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
234
- )
235
-
236
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
237
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
238
-
239
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
240
-
241
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
242
-
243
- dA_cs_m = tl.load(
244
- dA_cumsum_ptr + offs_m * stride_dA_cs_csize,
245
- mask=offs_m < chunk_size_limit,
246
- other=0.0,
247
- ).to(tl.float32)
248
-
249
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
250
- tl.float32
251
- )
252
- if not HAS_SEQ_IDX:
253
- scale = tl.exp(dA_cs_last - dA_cs_m)
254
- else:
255
- seq_idx_m = tl.load(
256
- seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
257
- mask=offs_m < chunk_size_limit,
258
- other=-1,
259
- )
260
- seq_idx_last = tl.load(
261
- seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
262
- )
263
- scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
264
- # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
265
- # However, we're getting error with the Triton compiler 2.1.0 for that code path:
266
- # Unexpected mma -> mma layout conversion
267
- # Triton 2.2.0 fixes this
268
- offs_dstate = tl.arange(
269
- 0,
270
- (
271
- BLOCK_SIZE_DSTATE
272
- if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128
273
- else BLOCK_SIZE_K
274
- ),
275
- )
276
- b_ptrs = b_ptr + (
277
- offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate
278
- )
279
- dstates_ptrs = dstates_ptr + (
280
- offs_n[None, :] * stride_dstates_hdim
281
- + offs_dstate[:, None] * stride_dstates_dstate
282
- )
283
- if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:
284
- b = tl.load(
285
- b_ptrs,
286
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate),
287
- other=0.0,
288
- )
289
- dstates = tl.load(
290
- dstates_ptrs,
291
- mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
292
- other=0.0,
293
- )
294
- dstates = dstates.to(b_ptr.dtype.element_ty)
295
- acc = tl.dot(b, dstates) * scale[:, None]
296
- else:
297
- for k in range(0, dstate, BLOCK_SIZE_K):
298
- b = tl.load(
299
- b_ptrs,
300
- mask=(offs_m[:, None] < chunk_size_limit)
301
- & (offs_dstate[None, :] < dstate - k),
302
- other=0.0,
303
- )
304
- dstates = tl.load(
305
- dstates_ptrs,
306
- mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim),
307
- other=0.0,
308
- )
309
- dstates = dstates.to(b_ptr.dtype.element_ty)
310
- acc += tl.dot(b, dstates)
311
- b_ptrs += BLOCK_SIZE_K * stride_b_dstate
312
- dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate
313
- acc *= scale[:, None]
314
-
315
- # x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
316
- # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
317
- # dt_ptrs = dt_ptr + offs_m * stride_dt_csize
318
- # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
319
- # ddt = tl.sum(acc * x, axis=1) * dt_m
320
- # ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
321
- # tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
322
-
323
- offs_k = tl.arange(0, BLOCK_SIZE_K)
324
- cb_ptrs = cb_ptr + (
325
- offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k
326
- )
327
- dout_ptrs = dout_ptr + (
328
- offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim
329
- )
330
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
331
- K_MAX = chunk_size_limit
332
- K_MIN = pid_m * BLOCK_SIZE_M
333
- cb_ptrs += K_MIN * stride_cb_csize_k
334
- dout_ptrs += K_MIN * stride_dout_seqlen
335
- dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize
336
- for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):
337
- k = tl.multiple_of(k, BLOCK_SIZE_K)
338
- # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower
339
- cb = tl.load(
340
- cb_ptrs,
341
- mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k),
342
- other=0.0,
343
- )
344
- dout = tl.load(
345
- dout_ptrs,
346
- mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim),
347
- other=0.0,
348
- )
349
- dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(
350
- tl.float32
351
- )
352
- cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
353
- # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
354
- # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
355
- # Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
356
- # This will cause NaN in acc, and hence NaN in dx and ddt.
357
- mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)
358
- cb = tl.where(mask, cb, 0.0)
359
- cb = cb.to(dout_ptr.dtype.element_ty)
360
- acc += tl.dot(cb, dout)
361
- cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
362
- dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
363
- dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
364
-
365
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
366
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
367
- dt_ptrs = dt_ptr + offs_m * stride_dt_csize
368
- dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
369
- dx = acc * dt_m[:, None]
370
- dx_ptr += (
371
- pid_b * stride_dx_batch
372
- + pid_c * chunk_size * stride_dx_seqlen
373
- + pid_h * stride_dx_head
374
- )
375
- dx_ptrs = dx_ptr + (
376
- offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim
377
- )
378
- if HAS_D:
379
- dout_res_ptrs = dout_ptr + (
380
- offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim
381
- )
382
- dout_res = tl.load(
383
- dout_res_ptrs,
384
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
385
- other=0.0,
386
- ).to(tl.float32)
387
- if D_HAS_HDIM:
388
- D = tl.load(
389
- D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0
390
- ).to(tl.float32)
391
- else:
392
- D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
393
- dx += dout_res * D
394
- tl.store(
395
- dx_ptrs,
396
- dx,
397
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
398
- )
399
-
400
- x_ptrs = x_ptr + (
401
- offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
402
- )
403
- x = tl.load(
404
- x_ptrs,
405
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
406
- other=0.0,
407
- ).to(tl.float32)
408
- if HAS_D:
409
- dD_ptr += (
410
- pid_b * stride_dD_batch
411
- + pid_c * stride_dD_chunk
412
- + pid_h * stride_dD_head
413
- + pid_m * stride_dD_csize
414
- )
415
- if D_HAS_HDIM:
416
- dD_ptrs = dD_ptr + offs_n * stride_dD_hdim
417
- dD = tl.sum(dout_res * x, axis=0)
418
- tl.store(dD_ptrs, dD, mask=offs_n < hdim)
419
- else:
420
- dD = tl.sum(dout_res * x)
421
- tl.store(dD_ptr, dD)
422
- ddt = tl.sum(acc * x, axis=1)
423
- ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
424
- tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
425
-
426
-
427
- def _chunk_scan_chunk_state_bwd_dx(
428
- x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None
429
- ):
430
- batch, seqlen, nheads, headdim = x.shape
431
- _, _, nchunks, chunk_size = dt.shape
432
- _, _, ngroups, dstate = B.shape
433
- assert nheads % ngroups == 0
434
- assert B.shape == (batch, seqlen, ngroups, dstate)
435
- assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
436
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
437
- assert dA_cumsum.shape == dt.shape
438
- assert dout.shape == x.shape
439
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
440
- if seq_idx is not None:
441
- assert seq_idx.shape == (batch, seqlen)
442
- if D is not None:
443
- assert D.shape == (nheads, headdim) or D.shape == (nheads,)
444
- assert D.stride(-1) == 1
445
- BLOCK_SIZE_min = 32
446
- dD = torch.empty(
447
- triton.cdiv(chunk_size, BLOCK_SIZE_min),
448
- batch,
449
- nchunks,
450
- nheads,
451
- headdim if D.dim() == 2 else 1,
452
- device=D.device,
453
- dtype=torch.float32,
454
- )
455
- else:
456
- dD = None
457
- dD_strides = (
458
- (dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
459
- if D is not None
460
- else (0, 0, 0, 0, 0)
461
- )
462
- if dx is None:
463
- dx = torch.empty_like(x)
464
- else:
465
- assert dx.shape == x.shape
466
- ddt = torch.empty(
467
- batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32
468
- )
469
- grid_dx = lambda META: (
470
- triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
471
- * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
472
- batch * nchunks,
473
- nheads,
474
- )
475
- with torch.cuda.device(x.device.index):
476
- _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](
477
- x,
478
- CB,
479
- dout,
480
- dt,
481
- dA_cumsum,
482
- seq_idx,
483
- D,
484
- B,
485
- dstates,
486
- dx,
487
- ddt,
488
- dD,
489
- chunk_size,
490
- headdim,
491
- dstate,
492
- batch,
493
- seqlen,
494
- nheads // ngroups,
495
- x.stride(0),
496
- x.stride(1),
497
- x.stride(2),
498
- x.stride(3),
499
- CB.stride(0),
500
- CB.stride(1),
501
- CB.stride(2),
502
- CB.stride(-1),
503
- CB.stride(-2),
504
- dout.stride(0),
505
- dout.stride(1),
506
- dout.stride(2),
507
- dout.stride(3),
508
- dt.stride(0),
509
- dt.stride(2),
510
- dt.stride(1),
511
- dt.stride(3),
512
- dA_cumsum.stride(0),
513
- dA_cumsum.stride(2),
514
- dA_cumsum.stride(1),
515
- dA_cumsum.stride(3),
516
- *(
517
- (seq_idx.stride(0), seq_idx.stride(1))
518
- if seq_idx is not None
519
- else (0, 0)
520
- ),
521
- D.stride(0) if D is not None else 0,
522
- B.stride(0),
523
- B.stride(1),
524
- B.stride(2),
525
- B.stride(3),
526
- dstates.stride(0),
527
- dstates.stride(1),
528
- dstates.stride(2),
529
- dstates.stride(3),
530
- dstates.stride(4),
531
- dx.stride(0),
532
- dx.stride(1),
533
- dx.stride(2),
534
- dx.stride(3),
535
- ddt.stride(0),
536
- ddt.stride(2),
537
- ddt.stride(1),
538
- ddt.stride(3),
539
- dD_strides[1],
540
- dD_strides[2],
541
- dD_strides[3],
542
- dD_strides[0],
543
- dD_strides[4],
544
- D is not None,
545
- D.dim() == 2 if D is not None else True,
546
- HAS_SEQ_IDX=seq_idx is not None,
547
- BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
548
- IS_TRITON_22=TRITON_22
549
- )
550
- if D is not None:
551
- BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs[
552
- "BLOCK_SIZE_M"
553
- ]
554
- n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
555
- dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
556
- if D.dim() == 1:
557
- dD = rearrange(dD, "h 1 -> h")
558
- return dx, ddt.to(dtype=dt.dtype), dD
559
-
560
-
561
- def _mamba_chunk_scan_combined_fwd(
562
- x,
563
- dt,
564
- A,
565
- B,
566
- C,
567
- chunk_size,
568
- D=None,
569
- z=None,
570
- dt_bias=None,
571
- initial_states=None,
572
- seq_idx=None,
573
- cu_seqlens=None,
574
- dt_softplus=False,
575
- dt_limit=(0.0, float("inf")),
576
- ):
577
- batch, seqlen, nheads, headdim = x.shape
578
- _, _, ngroups, dstate = B.shape
579
- assert nheads % ngroups == 0
580
- assert B.shape == (batch, seqlen, ngroups, dstate)
581
- assert x.shape == (batch, seqlen, nheads, headdim)
582
- assert dt.shape == (batch, seqlen, nheads)
583
- assert A.shape == (nheads,)
584
- assert C.shape == B.shape
585
- if z is not None:
586
- assert z.shape == x.shape
587
- if D is not None:
588
- assert D.shape == (nheads, headdim) or D.shape == (nheads,)
589
- if seq_idx is not None:
590
- assert seq_idx.shape == (batch, seqlen)
591
- if B.stride(-1) != 1:
592
- B = B.contiguous()
593
- if C.stride(-1) != 1:
594
- C = C.contiguous()
595
- if (
596
- x.stride(-1) != 1 and x.stride(1) != 1
597
- ): # Either M or K dimension should be contiguous
598
- x = x.contiguous()
599
- if (
600
- z is not None and z.stride(-1) != 1 and z.stride(1) != 1
601
- ): # Either M or K dimension should be contiguous
602
- z = z.contiguous()
603
- if D is not None and D.stride(-1) != 1:
604
- D = D.contiguous()
605
- if initial_states is not None:
606
- assert initial_states.shape == (batch, nheads, headdim, dstate)
607
- # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
608
- # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
609
- # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
610
- # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
611
- dA_cumsum, dt = _chunk_cumsum_fwd(
612
- dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit
613
- )
614
- states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
615
- # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
616
- # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
617
- # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)
618
- states, final_states = _state_passing_fwd(
619
- rearrange(states, "... p n -> ... (p n)"),
620
- dA_cumsum[:, :, :, -1],
621
- initial_states=(
622
- rearrange(initial_states, "... p n -> ... (p n)")
623
- if initial_states is not None
624
- else None
625
- ),
626
- seq_idx=seq_idx,
627
- chunk_size=chunk_size,
628
- out_dtype=C.dtype,
629
- )
630
- states, final_states = [
631
- rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]
632
- ]
633
- # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
634
- # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
635
- CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
636
- out, out_x = _chunk_scan_fwd(
637
- CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx
638
- )
639
- if cu_seqlens is None:
640
- return out, out_x, dt, dA_cumsum, states, final_states
641
- else:
642
- assert (
643
- batch == 1
644
- ), "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
645
- varlen_states = chunk_state_varlen(
646
- B.squeeze(0),
647
- x.squeeze(0),
648
- dt.squeeze(0),
649
- dA_cumsum.squeeze(0),
650
- cu_seqlens,
651
- states.squeeze(0),
652
- )
653
- return out, out_x, dt, dA_cumsum, states, final_states, varlen_states
654
-
655
-
656
- def _mamba_chunk_scan_combined_bwd(
657
- dout,
658
- x,
659
- dt,
660
- A,
661
- B,
662
- C,
663
- out,
664
- chunk_size,
665
- D=None,
666
- z=None,
667
- dt_bias=None,
668
- initial_states=None,
669
- dfinal_states=None,
670
- seq_idx=None,
671
- dt_softplus=False,
672
- dt_limit=(0.0, float("inf")),
673
- dx=None,
674
- ddt=None,
675
- dB=None,
676
- dC=None,
677
- dz=None,
678
- recompute_output=False,
679
- ):
680
- if dout.stride(-1) != 1:
681
- dout = dout.contiguous()
682
- batch, seqlen, nheads, headdim = x.shape
683
- nchunks = math.ceil(seqlen / chunk_size)
684
- _, _, ngroups, dstate = B.shape
685
- assert dout.shape == (batch, seqlen, nheads, headdim)
686
- assert dt.shape == (batch, seqlen, nheads)
687
- assert A.shape == (nheads,)
688
- assert nheads % ngroups == 0
689
- assert B.shape == (batch, seqlen, ngroups, dstate)
690
- assert C.shape == B.shape
691
- assert out.shape == x.shape
692
- if initial_states is not None:
693
- assert initial_states.shape == (batch, nheads, headdim, dstate)
694
- if seq_idx is not None:
695
- assert seq_idx.shape == (batch, seqlen)
696
- if dx is not None:
697
- assert dx.shape == x.shape
698
- if dB is not None:
699
- assert dB.shape == B.shape
700
- dB_given = dB
701
- else:
702
- dB_given = torch.empty_like(B)
703
- if dC is not None:
704
- assert dC.shape == C.shape
705
- dC_given = dC
706
- else:
707
- dC_given = torch.empty_like(C)
708
- if dz is not None:
709
- assert z is not None
710
- assert dz.shape == z.shape
711
- if ddt is not None:
712
- assert ddt.shape == dt.shape
713
- ddt_given = ddt
714
- else:
715
- ddt_given = torch.empty_like(dt)
716
- # TD: For some reason Triton (2.1.0 and 2.2.0) errors with
717
- # "[CUDA]: invalid device context" (e.g. during varlne test), and cloning makes it work. Idk why.
718
- dt_in = dt.clone()
719
- dA_cumsum, dt = _chunk_cumsum_fwd(
720
- dt_in,
721
- A,
722
- chunk_size,
723
- dt_bias=dt_bias,
724
- dt_softplus=dt_softplus,
725
- dt_limit=dt_limit,
726
- )
727
- CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
728
- states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
729
- states, _ = _state_passing_fwd(
730
- rearrange(states, "... p n -> ... (p n)"),
731
- dA_cumsum[:, :, :, -1],
732
- initial_states=(
733
- rearrange(initial_states, "... p n -> ... (p n)")
734
- if initial_states is not None
735
- else None
736
- ),
737
- seq_idx=seq_idx,
738
- chunk_size=chunk_size,
739
- )
740
- states = rearrange(states, "... (p n) -> ... p n", n=dstate)
741
- if z is not None:
742
- dz, dout, dD, *rest = _chunk_scan_bwd_dz(
743
- x,
744
- z,
745
- out,
746
- dout,
747
- chunk_size=chunk_size,
748
- has_ddAcs=False,
749
- D=D,
750
- dz=dz,
751
- recompute_output=recompute_output,
752
- )
753
- outz = rest[0] if recompute_output else out
754
- else:
755
- dz = None
756
- outz = out
757
- dstates = _chunk_scan_bwd_dstates(
758
- C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype
759
- )
760
- # dstates has length nchunks, containing the gradient to initial states at index 0 and
761
- # gradient to the states of chunk (nchunks - 2) at index (nchunks - 1)
762
- # Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states
763
- # will be used in matmul in the next kernels.
764
- dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd(
765
- rearrange(states, "... p n -> ... (p n)"),
766
- dA_cumsum[:, :, :, -1],
767
- rearrange(dstates, "... p n -> ... (p n)"),
768
- dfinal_states=(
769
- rearrange(dfinal_states, "... p n -> ... (p n)")
770
- if dfinal_states is not None
771
- else None
772
- ),
773
- seq_idx=seq_idx,
774
- has_initial_states=initial_states is not None,
775
- dstates_dtype=x.dtype,
776
- states_dtype=x.dtype,
777
- chunk_size=chunk_size,
778
- )
779
- # dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and
780
- # gradient to the final states at index (nchunks - 1)
781
- # states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1)
782
- # The final states is not stored.
783
- states = rearrange(states, "... (p n) -> ... p n", n=dstate)
784
- dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate)
785
- dinitial_states = (
786
- rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate)
787
- if dinitial_states is not None
788
- else None
789
- )
790
- dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(
791
- x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx
792
- )
793
- # dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups)
794
- dB, ddA_next = _chunk_state_bwd_db(
795
- x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups
796
- )
797
- # dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
798
- dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(
799
- states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups
800
- )
801
- # Computing ddA with the dcb kernel is much slower, so we're not using it for now
802
- dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
803
- # dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups)
804
- dCB = dCB.to(CB.dtype)
805
- _bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given)
806
- _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given)
807
- # If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate
808
- # than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16
809
- if z is None:
810
- dD = dD_from_x
811
- # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
812
- # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
813
- # However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might
814
- # be a lot of underflow.
815
-
816
- # This is already done as part of bwd_dC kernel
817
- # ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx)
818
- ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum
819
- ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1])
820
- # This is already done as part of bwd_dB kernel
821
- # ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx)
822
- # We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j]
823
- ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB)
824
- ddA += ddA_next + ddA_prev
825
-
826
- ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(
827
- ddA,
828
- ddt,
829
- dt_in,
830
- A,
831
- dt_bias=dt_bias,
832
- dt_softplus=dt_softplus,
833
- dt_limit=dt_limit,
834
- ddt=ddt_given,
835
- )
836
-
837
- # These 2 lines are just to test ddt and dA being computed by old code
838
- # _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z)
839
- # ddt_given.copy_(ddt)
840
-
841
- return_vals = (
842
- dx,
843
- ddt_given,
844
- dA,
845
- dB_given,
846
- dC_given,
847
- dD,
848
- dz,
849
- ddt_bias,
850
- dinitial_states,
851
- )
852
- return return_vals if not recompute_output else (*return_vals, outz)
853
-
854
-
855
- def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None):
856
- """
857
- Argument:
858
- dout: (batch, seqlen, nheads, headdim)
859
- x: (batch, seqlen, nheads, headdim)
860
- dt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size)
861
- A: (nheads) or (dim, dstate)
862
- B: (batch, seqlen, ngroups, dstate)
863
- C: (batch, seqlen, ngroups, dstate)
864
- D: (nheads, headdim) or (nheads,)
865
- z: (batch, seqlen, nheads, headdim)
866
- Return:
867
- out: (batch, seqlen, nheads, headdim)
868
- """
869
- import selective_scan
870
-
871
- batch, seqlen, nheads, headdim = x.shape
872
- chunk_size = dt.shape[-1]
873
- _, _, ngroups, dstate = B.shape
874
- assert nheads % ngroups == 0
875
- x = rearrange(x, "b l h p -> b (h p) l")
876
- squeeze_dt = dt.dim() == 4
877
- if dt.dim() == 4:
878
- dt = repeat(dt, "b h c l -> b h p c l", p=headdim)
879
- dt = rearrange(dt, "b h p c l -> b (h p) (c l)", p=headdim)
880
- squeeze_A = A.dim() == 1
881
- if A.dim() == 1:
882
- A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
883
- else:
884
- A = A.to(dtype=torch.float32)
885
- B = rearrange(B, "b l g n -> b g n l")
886
- C = rearrange(C, "b l g n -> b g n l")
887
- if D is not None:
888
- if D.dim() == 2:
889
- D = rearrange(D, "h p -> (h p)")
890
- else:
891
- D = repeat(D, "h -> (h p)", p=headdim)
892
- if z is not None:
893
- z = rearrange(z, "b l h p -> b (h p) l")
894
-
895
- if x.stride(-1) != 1:
896
- x = x.contiguous()
897
- if dt.stride(-1) != 1:
898
- dt = dt.contiguous()
899
- if D is not None:
900
- D = D.contiguous()
901
- if B.stride(-1) != 1:
902
- B = B.contiguous()
903
- if C.stride(-1) != 1:
904
- C = C.contiguous()
905
- if z is not None and z.stride(-1) != 1:
906
- z = z.contiguous()
907
- _, intermediate, *rest = selective_scan.fwd(
908
- x, dt.to(dtype=x.dtype), A, B, C, D, z, None, False
909
- )
910
- if z is not None:
911
- out = rest[0]
912
- else:
913
- out = None
914
-
915
- dout = rearrange(dout, "b l h p -> b (h p) l")
916
-
917
- if dout.stride(-1) != 1:
918
- dout = dout.contiguous()
919
- # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
920
- # backward of selective_scan with the backward of chunk).
921
- # Here we just pass in None and dz will be allocated in the C++ code.
922
- _, ddt, dA, *rest = selective_scan.bwd(
923
- x,
924
- dt.to(dtype=x.dtype),
925
- A,
926
- B,
927
- C,
928
- D,
929
- z,
930
- None,
931
- dout,
932
- intermediate,
933
- out,
934
- None,
935
- False,
936
- False, # option to recompute out_z, not used here
937
- )
938
- ddt = rearrange(ddt, "b (h p) (c l) -> b h p c l", p=headdim, l=chunk_size)
939
- if squeeze_dt:
940
- ddt = ddt.float().sum(dim=2)
941
- if squeeze_A:
942
- dA = rearrange(dA, "(h p) n -> h p n", p=headdim).sum(dim=(1, 2))
943
- return ddt, dA
944
-
945
-
946
- class MambaChunkScanCombinedFn(torch.autograd.Function):
947
-
948
- @staticmethod
949
- def forward(
950
- ctx,
951
- x,
952
- dt,
953
- A,
954
- B,
955
- C,
956
- chunk_size,
957
- D=None,
958
- z=None,
959
- dt_bias=None,
960
- initial_states=None,
961
- seq_idx=None,
962
- cu_seqlens=None,
963
- dt_softplus=False,
964
- dt_limit=(0.0, float("inf")),
965
- return_final_states=False,
966
- return_varlen_states=False,
967
- ):
968
- ctx.dt_dtype = dt.dtype
969
- if not return_varlen_states:
970
- cu_seqlens = None
971
- else:
972
- assert (
973
- cu_seqlens is not None
974
- ), "cu_seqlens must be provided if return_varlen_states is True"
975
- out, out_x, dt_out, dA_cumsum, states, final_states, *rest = (
976
- _mamba_chunk_scan_combined_fwd(
977
- x,
978
- dt,
979
- A,
980
- B,
981
- C,
982
- chunk_size,
983
- D=D,
984
- z=z,
985
- dt_bias=dt_bias,
986
- initial_states=initial_states,
987
- seq_idx=seq_idx,
988
- cu_seqlens=cu_seqlens,
989
- dt_softplus=dt_softplus,
990
- dt_limit=dt_limit,
991
- )
992
- )
993
- ctx.save_for_backward(
994
- out if z is None else out_x,
995
- x,
996
- dt,
997
- dA_cumsum,
998
- A,
999
- B,
1000
- C,
1001
- D,
1002
- z,
1003
- dt_bias,
1004
- initial_states,
1005
- seq_idx,
1006
- )
1007
- ctx.dt_softplus = dt_softplus
1008
- ctx.chunk_size = chunk_size
1009
- ctx.dt_limit = dt_limit
1010
- ctx.return_final_states = return_final_states
1011
- ctx.return_varlen_states = return_varlen_states
1012
- if not return_varlen_states:
1013
- return out if not return_final_states else (out, final_states)
1014
- else:
1015
- varlen_states = rest[0]
1016
- return (
1017
- (out, varlen_states)
1018
- if not return_final_states
1019
- else (out, final_states, varlen_states)
1020
- )
1021
-
1022
- @staticmethod
1023
- def backward(ctx, dout, *args):
1024
- out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = (
1025
- ctx.saved_tensors
1026
- )
1027
- assert (
1028
- not ctx.return_varlen_states
1029
- ), "return_varlen_states is not supported in backward"
1030
- dfinal_states = args[0] if ctx.return_final_states else None
1031
- dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = (
1032
- _mamba_chunk_scan_combined_bwd(
1033
- dout,
1034
- x,
1035
- dt,
1036
- A,
1037
- B,
1038
- C,
1039
- out,
1040
- ctx.chunk_size,
1041
- D=D,
1042
- z=z,
1043
- dt_bias=dt_bias,
1044
- initial_states=initial_states,
1045
- dfinal_states=dfinal_states,
1046
- seq_idx=seq_idx,
1047
- dt_softplus=ctx.dt_softplus,
1048
- dt_limit=ctx.dt_limit,
1049
- )
1050
- )
1051
- return (
1052
- dx,
1053
- ddt,
1054
- dA,
1055
- dB,
1056
- dC,
1057
- None,
1058
- dD,
1059
- dz,
1060
- ddt_bias,
1061
- dinitial_states,
1062
- None,
1063
- None,
1064
- None,
1065
- None,
1066
- None,
1067
- None,
1068
- )
1069
-
1070
-
1071
- def mamba_chunk_scan_combined(
1072
- x,
1073
- dt,
1074
- A,
1075
- B,
1076
- C,
1077
- chunk_size,
1078
- D=None,
1079
- z=None,
1080
- dt_bias=None,
1081
- initial_states=None,
1082
- seq_idx=None,
1083
- cu_seqlens=None,
1084
- dt_softplus=False,
1085
- dt_limit=(0.0, float("inf")),
1086
- return_final_states=False,
1087
- return_varlen_states=False,
1088
- ):
1089
- """
1090
- Argument:
1091
- x: (batch, seqlen, nheads, headdim)
1092
- dt: (batch, seqlen, nheads)
1093
- A: (nheads)
1094
- B: (batch, seqlen, ngroups, dstate)
1095
- C: (batch, seqlen, ngroups, dstate)
1096
- chunk_size: int
1097
- D: (nheads, headdim) or (nheads,)
1098
- z: (batch, seqlen, nheads, headdim)
1099
- dt_bias: (nheads,)
1100
- initial_states: (batch, nheads, headdim, dstate)
1101
- seq_idx: (batch, seqlen)
1102
- cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
1103
- dt_softplus: Whether to apply softplus to dt
1104
- Return:
1105
- out: (batch, seqlen, nheads, headdim)
1106
- """
1107
- return MambaChunkScanCombinedFn.apply(
1108
- x,
1109
- dt,
1110
- A,
1111
- B,
1112
- C,
1113
- chunk_size,
1114
- D,
1115
- z,
1116
- dt_bias,
1117
- initial_states,
1118
- seq_idx,
1119
- cu_seqlens,
1120
- dt_softplus,
1121
- dt_limit,
1122
- return_final_states,
1123
- return_varlen_states,
1124
- )
1125
-
1126
-
1127
- def mamba_chunk_scan(
1128
- x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False
1129
- ):
1130
- """
1131
- Argument:
1132
- x: (batch, seqlen, nheads, headdim)
1133
- dt: (batch, seqlen, nheads)
1134
- A: (nheads)
1135
- B: (batch, seqlen, ngroups, dstate)
1136
- C: (batch, seqlen, ngroups, dstate)
1137
- D: (nheads, headdim) or (nheads,)
1138
- z: (batch, seqlen, nheads, headdim)
1139
- dt_bias: (nheads,)
1140
- Return:
1141
- out: (batch, seqlen, nheads, headdim)
1142
- """
1143
- batch, seqlen, nheads, headdim = x.shape
1144
- dstate = B.shape[-1]
1145
- if seqlen % chunk_size != 0:
1146
- dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
1147
- dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
1148
- dt = dt.float() # We want high precision for this before cumsum
1149
- if dt_bias is not None:
1150
- dt = dt + rearrange(dt_bias, "h -> h 1 1")
1151
- if dt_softplus:
1152
- dt = F.softplus(dt)
1153
- dA = dt * rearrange(A, "h -> h 1 1")
1154
- dA = dt * rearrange(A, "h -> h 1 1")
1155
- dA_cumsum = torch.cumsum(dA, dim=-1)
1156
- # 1. Compute the state for each chunk
1157
- states = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True)
1158
- # 2. Pass the state to all the chunks by weighted cumsum.
1159
- states = rearrange(
1160
- state_passing(
1161
- rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]
1162
- )[0],
1163
- "... (p n) -> ... p n",
1164
- n=dstate,
1165
- )
1166
- # 3. Compute the output for each chunk
1167
- out = chunk_scan(B, C, x, dt, dA_cumsum, states, D=D, z=z)
1168
- return out
1169
-
1170
-
1171
- def ssd_chunk_scan_combined_ref(
1172
- x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False
1173
- ):
1174
- """
1175
- Argument:
1176
- x: (batch, seqlen, nheads, headdim)
1177
- dt: (batch, seqlen, nheads)
1178
- A: (nheads)
1179
- B: (batch, seqlen, ngroups, dstate)
1180
- C: (batch, seqlen, ngroups, dstate)
1181
- D: (nheads, headdim) or (nheads,)
1182
- z: (batch, seqlen, nheads, headdim)
1183
- dt_bias: (nheads,)
1184
- Return:
1185
- out: (batch, seqlen, nheads, headdim)
1186
- """
1187
- batch, seqlen, nheads, headdim = x.shape
1188
- dstate = B.shape[-1]
1189
- if seqlen % chunk_size != 0:
1190
- dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
1191
- dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
1192
- dt = dt.float() # We want high precision for this before cumsum
1193
- if dt_bias is not None:
1194
- dt = dt + rearrange(dt_bias, "h -> h 1 1")
1195
- if dt_softplus:
1196
- dt = F.softplus(dt)
1197
- dA = dt * rearrange(A, "h -> h 1 1")
1198
- dA_cumsum = torch.cumsum(dA, dim=-1)
1199
- # 1. Compute the state for each chunk
1200
- states = chunk_state_ref(B, x, dt, dA_cumsum)
1201
- states_dtype = states.dtype
1202
- if states.dtype not in [torch.float32, torch.float64]:
1203
- states = states.to(torch.float32)
1204
- # 2. Pass the state to all the chunks by weighted cumsum.
1205
- # state_passing_ref is much less numerically stable
1206
- states = rearrange(
1207
- state_passing_ref(
1208
- rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]
1209
- )[0],
1210
- "... (p n) -> ... p n",
1211
- n=dstate,
1212
- )
1213
- states = states.to(states_dtype)
1214
- # 3. Compute the output for each chunk
1215
- out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)
1216
- return out
1217
-
1218
-
1219
- def ssd_selective_scan(
1220
- x,
1221
- dt,
1222
- A,
1223
- B,
1224
- C,
1225
- D=None,
1226
- z=None,
1227
- dt_bias=None,
1228
- dt_softplus=False,
1229
- dt_limit=(0.0, float("inf")),
1230
- ):
1231
- """
1232
- Argument:
1233
- x: (batch, seqlen, nheads, headdim)
1234
- dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
1235
- A: (nheads) or (dim, dstate)
1236
- B: (batch, seqlen, ngroups, dstate)
1237
- C: (batch, seqlen, ngroups, dstate)
1238
- D: (nheads, headdim) or (nheads,)
1239
- z: (batch, seqlen, nheads, headdim)
1240
- dt_bias: (nheads,) or (nheads, headdim)
1241
- Return:
1242
- out: (batch, seqlen, nheads, headdim)
1243
- """
1244
- from ..selective_scan_interface import selective_scan_fn
1245
-
1246
- batch, seqlen, nheads, headdim = x.shape
1247
- _, _, ngroups, dstate = B.shape
1248
- x = rearrange(x, "b l h p -> b (h p) l")
1249
- if dt.dim() == 3:
1250
- dt = repeat(dt, "b l h -> b l h p", p=headdim)
1251
- dt = rearrange(dt, "b l h p -> b (h p) l")
1252
- if A.dim() == 1:
1253
- A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
1254
- else:
1255
- A = A.to(dtype=torch.float32)
1256
- B = rearrange(B, "b l g n -> b g n l")
1257
- C = rearrange(C, "b l g n -> b g n l")
1258
- if D is not None:
1259
- if D.dim() == 2:
1260
- D = rearrange(D, "h p -> (h p)")
1261
- else:
1262
- D = repeat(D, "h -> (h p)", p=headdim)
1263
- if z is not None:
1264
- z = rearrange(z, "b l h p -> b (h p) l")
1265
- if dt_bias is not None:
1266
- if dt_bias.dim() == 1:
1267
- dt_bias = repeat(dt_bias, "h -> h p", p=headdim)
1268
- dt_bias = rearrange(dt_bias, "h p -> (h p)")
1269
- if dt_limit != (0.0, float("inf")):
1270
- if dt_bias is not None:
1271
- dt = dt + rearrange(dt_bias, "d -> d 1")
1272
- if dt_softplus:
1273
- dt = F.softplus(dt)
1274
- dt = dt.clamp(min=dt_limit[0], max=dt_limit[1]).to(x.dtype)
1275
- dt_bias = None
1276
- dt_softplus = None
1277
- out = selective_scan_fn(
1278
- x, dt, A, B, C, D=D, z=z, delta_bias=dt_bias, delta_softplus=dt_softplus
1279
- )
1280
- return rearrange(out, "b (h p) l -> b l h p", p=headdim)
1281
-
1282
-
1283
- def mamba_conv1d_scan_ref(
1284
- xBC,
1285
- conv1d_weight,
1286
- conv1d_bias,
1287
- dt,
1288
- A,
1289
- chunk_size,
1290
- D=None,
1291
- z=None,
1292
- dt_bias=None,
1293
- dt_softplus=False,
1294
- dt_limit=(0.0, float("inf")),
1295
- activation="silu",
1296
- headdim=None,
1297
- ngroups=1,
1298
- ):
1299
- """
1300
- Argument:
1301
- xBC: (batch, seqlen, dim + 2 * ngroups * dstate) where dim == nheads * headdim
1302
- conv1d_weight: (dim + 2 * ngroups * dstate, width)
1303
- conv1d_bias: (dim + 2 * ngroups * dstate,)
1304
- dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
1305
- A: (nheads)
1306
- D: (nheads, headdim) or (nheads,)
1307
- z: (batch, seqlen, dim)
1308
- dt_bias: (nheads) or (nheads, headdim)
1309
- headdim: if D is 1D and z is None, headdim must be passed in
1310
- Return:
1311
- out: (batch, seqlen, dim)
1312
- """
1313
- batch, seqlen, nheads = dt.shape[:3]
1314
- assert nheads % ngroups == 0
1315
- if z is not None:
1316
- dim = z.shape[-1]
1317
- assert dim % nheads == 0
1318
- headdim = dim // nheads
1319
- else:
1320
- if D.dim() == 1:
1321
- assert headdim is not None
1322
- else:
1323
- headdim = D.shape[1]
1324
- dim = nheads * headdim
1325
- xBC = rearrange(
1326
- causal_conv1d_fn(
1327
- rearrange(xBC, "b s d -> b d s"),
1328
- conv1d_weight,
1329
- conv1d_bias,
1330
- activation=activation,
1331
- ),
1332
- "b d s -> b s d",
1333
- )
1334
- dstate = (xBC.shape[-1] - dim) // ngroups // 2
1335
- x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
1336
- x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
1337
- B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
1338
- C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
1339
- z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
1340
- out = ssd_selective_scan(
1341
- x,
1342
- dt.to(x.dtype),
1343
- A,
1344
- B,
1345
- C,
1346
- D=D.float(),
1347
- z=z,
1348
- dt_bias=dt_bias,
1349
- dt_softplus=dt_softplus,
1350
- dt_limit=dt_limit,
1351
- )
1352
- return rearrange(out, "b s h p -> b s (h p)")
1353
-
1354
-
1355
- class MambaSplitConv1dScanCombinedFn(torch.autograd.Function):
1356
-
1357
- @staticmethod
1358
- @custom_fwd
1359
- def forward(
1360
- ctx,
1361
- zxbcdt,
1362
- conv1d_weight,
1363
- conv1d_bias,
1364
- dt_bias,
1365
- A,
1366
- D,
1367
- chunk_size,
1368
- initial_states=None,
1369
- seq_idx=None,
1370
- dt_limit=(0.0, float("inf")),
1371
- return_final_states=False,
1372
- activation="silu",
1373
- rmsnorm_weight=None,
1374
- rmsnorm_eps=1e-6,
1375
- outproj_weight=None,
1376
- outproj_bias=None,
1377
- headdim=None,
1378
- ngroups=1,
1379
- norm_before_gate=True,
1380
- ):
1381
- assert activation in [None, "silu", "swish"]
1382
- if D.dim() == 1:
1383
- assert headdim is not None
1384
- (nheads,) = D.shape
1385
- else:
1386
- nheads, headdim = D.shape
1387
- batch, seqlen, _ = zxbcdt.shape
1388
- dim = nheads * headdim
1389
- assert nheads % ngroups == 0
1390
- dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2
1391
- d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2
1392
- assert d_nonssm >= 0
1393
- assert zxbcdt.shape == (
1394
- batch,
1395
- seqlen,
1396
- 2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads,
1397
- )
1398
- assert dt_bias.shape == (nheads,)
1399
- assert A.shape == (nheads,)
1400
- zx0, z, xBC, dt = torch.split(
1401
- zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1
1402
- )
1403
- seq_idx = seq_idx.contiguous() if seq_idx is not None else None
1404
- xBC_conv = rearrange(
1405
- causal_conv1d_cuda.causal_conv1d_fwd(
1406
- rearrange(xBC, "b s d -> b d s"),
1407
- conv1d_weight,
1408
- conv1d_bias,
1409
- seq_idx,
1410
- None,
1411
- None,
1412
- activation in ["silu", "swish"],
1413
- ),
1414
- "b d s -> b s d",
1415
- )
1416
- x, B, C = torch.split(
1417
- xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1
1418
- )
1419
- x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
1420
- B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
1421
- C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
1422
- z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
1423
- if rmsnorm_weight is None:
1424
- out, out_x, dt_out, dA_cumsum, states, final_states = (
1425
- _mamba_chunk_scan_combined_fwd(
1426
- x,
1427
- dt,
1428
- A,
1429
- B,
1430
- C,
1431
- chunk_size=chunk_size,
1432
- D=D,
1433
- z=z,
1434
- dt_bias=dt_bias,
1435
- initial_states=initial_states,
1436
- seq_idx=seq_idx,
1437
- dt_softplus=True,
1438
- dt_limit=dt_limit,
1439
- )
1440
- )
1441
- out = rearrange(out, "b s h p -> b s (h p)")
1442
- rstd = None
1443
- if d_nonssm > 0:
1444
- out = torch.cat([_swiglu_fwd(zx0), out], dim=-1)
1445
- else:
1446
- out_x, _, dt_out, dA_cumsum, states, final_states = (
1447
- _mamba_chunk_scan_combined_fwd(
1448
- x,
1449
- dt,
1450
- A,
1451
- B,
1452
- C,
1453
- chunk_size=chunk_size,
1454
- D=D,
1455
- z=None,
1456
- dt_bias=dt_bias,
1457
- initial_states=initial_states,
1458
- seq_idx=seq_idx,
1459
- dt_softplus=True,
1460
- dt_limit=dt_limit,
1461
- )
1462
- )
1463
- # reshape input data into 2D tensor
1464
- x_rms = rearrange(out_x, "b s h p -> (b s) (h p)")
1465
- z_rms = rearrange(z, "b s h p -> (b s) (h p)")
1466
- rmsnorm_weight = rmsnorm_weight.contiguous()
1467
- if d_nonssm == 0:
1468
- out = None
1469
- else:
1470
- out01 = torch.empty(
1471
- (batch, seqlen, d_nonssm + dim),
1472
- dtype=x_rms.dtype,
1473
- device=x_rms.device,
1474
- )
1475
- out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d")
1476
- _swiglu_fwd(zx0, out=out01[..., :d_nonssm])
1477
- out, _, rstd = _layer_norm_fwd(
1478
- x_rms,
1479
- rmsnorm_weight,
1480
- None,
1481
- rmsnorm_eps,
1482
- z_rms,
1483
- out=out,
1484
- group_size=dim // ngroups,
1485
- norm_before_gate=norm_before_gate,
1486
- is_rms_norm=True,
1487
- )
1488
- if d_nonssm == 0:
1489
- out = rearrange(out, "(b s) d -> b s d", b=batch)
1490
- else:
1491
- out = out01
1492
- ctx.outproj_weight_dtype = (
1493
- outproj_weight.dtype if outproj_weight is not None else None
1494
- )
1495
- if outproj_weight is not None:
1496
- if torch.is_autocast_enabled():
1497
- dtype = torch.get_autocast_gpu_dtype()
1498
- out, outproj_weight = out.to(dtype), outproj_weight.to(dtype)
1499
- outproj_bias = (
1500
- outproj_bias.to(dtype) if outproj_bias is not None else None
1501
- )
1502
- out = F.linear(out, outproj_weight, outproj_bias)
1503
- else:
1504
- assert outproj_bias is None
1505
- ctx.save_for_backward(
1506
- zxbcdt,
1507
- conv1d_weight,
1508
- conv1d_bias,
1509
- out_x,
1510
- A,
1511
- D,
1512
- dt_bias,
1513
- initial_states,
1514
- seq_idx,
1515
- rmsnorm_weight,
1516
- rstd,
1517
- outproj_weight,
1518
- outproj_bias,
1519
- )
1520
- ctx.dt_limit = dt_limit
1521
- ctx.return_final_states = return_final_states
1522
- ctx.activation = activation
1523
- ctx.rmsnorm_eps = rmsnorm_eps
1524
- ctx.norm_before_gate = norm_before_gate
1525
- ctx.chunk_size = chunk_size
1526
- ctx.headdim = headdim
1527
- ctx.ngroups = ngroups
1528
- return out if not return_final_states else (out, final_states)
1529
-
1530
- @staticmethod
1531
- @custom_bwd
1532
- def backward(ctx, dout, *args):
1533
- (
1534
- zxbcdt,
1535
- conv1d_weight,
1536
- conv1d_bias,
1537
- out,
1538
- A,
1539
- D,
1540
- dt_bias,
1541
- initial_states,
1542
- seq_idx,
1543
- rmsnorm_weight,
1544
- rstd,
1545
- outproj_weight,
1546
- outproj_bias,
1547
- ) = ctx.saved_tensors
1548
- dfinal_states = args[0] if ctx.return_final_states else None
1549
- headdim = ctx.headdim
1550
- nheads = D.shape[0]
1551
- dim = nheads * headdim
1552
- assert nheads % ctx.ngroups == 0
1553
- dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2
1554
- d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2
1555
- assert d_nonssm >= 0
1556
- recompute_output = outproj_weight is not None
1557
- if recompute_output:
1558
- out_recompute = torch.empty(
1559
- *out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype
1560
- )
1561
- out0_recompute, out1_recompute = out_recompute.split(
1562
- [d_nonssm, dim], dim=-1
1563
- )
1564
- zx0, z, xBC, dt = torch.split(
1565
- zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
1566
- )
1567
- # Recompute x, B, C
1568
- xBC_conv = rearrange(
1569
- causal_conv1d_cuda.causal_conv1d_fwd(
1570
- rearrange(xBC, "b s d -> b d s"),
1571
- conv1d_weight,
1572
- conv1d_bias,
1573
- seq_idx,
1574
- None,
1575
- None,
1576
- ctx.activation in ["silu", "swish"],
1577
- ),
1578
- "b d s -> b s d",
1579
- )
1580
- x, B, C = torch.split(
1581
- xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
1582
- )
1583
- x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
1584
- B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups)
1585
- C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups)
1586
- dzxbcdt = torch.empty_like(zxbcdt)
1587
- dzx0, dz, dxBC_given, ddt_given = torch.split(
1588
- dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
1589
- )
1590
- dxBC = torch.empty_like(xBC)
1591
- dx, dB, dC = torch.split(
1592
- dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
1593
- )
1594
- z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
1595
- dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads)
1596
- dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups)
1597
- dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups)
1598
- if outproj_weight is not None:
1599
- dout_og = dout
1600
- dout = F.linear(dout, outproj_weight.t())
1601
- if d_nonssm > 0:
1602
- dout0, dout = dout.split([d_nonssm, dim], dim=-1)
1603
- _swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute)
1604
- dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim)
1605
- if rmsnorm_weight is None:
1606
- dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads)
1607
- dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states, *rest = (
1608
- _mamba_chunk_scan_combined_bwd(
1609
- dout,
1610
- x,
1611
- dt,
1612
- A,
1613
- B,
1614
- C,
1615
- out,
1616
- ctx.chunk_size,
1617
- D=D,
1618
- z=z,
1619
- dt_bias=dt_bias,
1620
- initial_states=initial_states,
1621
- dfinal_states=dfinal_states,
1622
- seq_idx=seq_idx,
1623
- dt_softplus=True,
1624
- dt_limit=ctx.dt_limit,
1625
- dx=dx,
1626
- ddt=ddt_given,
1627
- dB=dB,
1628
- dC=dC,
1629
- dz=dz,
1630
- recompute_output=recompute_output,
1631
- )
1632
- )
1633
- out_for_linear = (
1634
- rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None
1635
- )
1636
- drmsnorm_weight = None
1637
- else:
1638
- batch = dout.shape[0]
1639
- dy_rms = rearrange(dout, "b s h p -> (b s) (h p)")
1640
- dz = rearrange(dz, "b l d -> (b l) d")
1641
- x_rms = rearrange(out, "b s h p -> (b s) (h p)")
1642
- z_rms = rearrange(z, "b s h p -> (b s) (h p)")
1643
- out1_recompute = (
1644
- rearrange(out1_recompute, "b s d -> (b s) d")
1645
- if recompute_output
1646
- else None
1647
- )
1648
- dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(
1649
- dy_rms,
1650
- x_rms,
1651
- rmsnorm_weight,
1652
- None,
1653
- ctx.rmsnorm_eps,
1654
- None,
1655
- rstd,
1656
- z_rms,
1657
- group_size=dim // ctx.ngroups,
1658
- norm_before_gate=ctx.norm_before_gate,
1659
- is_rms_norm=True,
1660
- recompute_output=recompute_output,
1661
- dz=dz,
1662
- out=out1_recompute if recompute_output else None,
1663
- )
1664
- out_for_linear = out_recompute if recompute_output else None
1665
- dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim)
1666
- dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = (
1667
- _mamba_chunk_scan_combined_bwd(
1668
- dout,
1669
- x,
1670
- dt,
1671
- A,
1672
- B,
1673
- C,
1674
- out,
1675
- ctx.chunk_size,
1676
- D=D,
1677
- z=None,
1678
- dt_bias=dt_bias,
1679
- initial_states=initial_states,
1680
- dfinal_states=dfinal_states,
1681
- seq_idx=seq_idx,
1682
- dt_softplus=True,
1683
- dt_limit=ctx.dt_limit,
1684
- dx=dx,
1685
- ddt=ddt_given,
1686
- dB=dB,
1687
- dC=dC,
1688
- )
1689
- )
1690
-
1691
- if outproj_weight is not None:
1692
- doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear)
1693
- doutproj_bias = (
1694
- dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None
1695
- )
1696
- else:
1697
- doutproj_weight, doutproj_bias = None, None
1698
- dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
1699
- dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
1700
- rearrange(xBC, "b s d -> b d s"),
1701
- conv1d_weight,
1702
- conv1d_bias,
1703
- rearrange(dxBC, "b s d -> b d s"),
1704
- seq_idx,
1705
- None,
1706
- None,
1707
- dxBC_given,
1708
- False,
1709
- ctx.activation in ["silu", "swish"],
1710
- )
1711
- dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
1712
- return (
1713
- dzxbcdt,
1714
- dweight,
1715
- dbias,
1716
- ddt_bias,
1717
- dA,
1718
- dD,
1719
- None,
1720
- dinitial_states,
1721
- None,
1722
- None,
1723
- None,
1724
- None,
1725
- drmsnorm_weight,
1726
- None,
1727
- doutproj_weight,
1728
- doutproj_bias,
1729
- None,
1730
- None,
1731
- None,
1732
- )
1733
-
1734
-
1735
- def mamba_split_conv1d_scan_combined(
1736
- zxbcdt,
1737
- conv1d_weight,
1738
- conv1d_bias,
1739
- dt_bias,
1740
- A,
1741
- D,
1742
- chunk_size,
1743
- initial_states=None,
1744
- seq_idx=None,
1745
- dt_limit=(0.0, float("inf")),
1746
- return_final_states=False,
1747
- activation="silu",
1748
- rmsnorm_weight=None,
1749
- rmsnorm_eps=1e-6,
1750
- outproj_weight=None,
1751
- outproj_bias=None,
1752
- headdim=None,
1753
- ngroups=1,
1754
- norm_before_gate=True,
1755
- ):
1756
- """
1757
- Argument:
1758
- zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
1759
- conv1d_weight: (dim + 2 * ngroups * dstate, width)
1760
- conv1d_bias: (dim + 2 * ngroups * dstate,)
1761
- dt_bias: (nheads,)
1762
- A: (nheads)
1763
- D: (nheads, headdim) or (nheads,)
1764
- initial_states: (batch, nheads, headdim, dstate)
1765
- seq_idx: (batch, seqlen), int32
1766
- rmsnorm_weight: (dim,)
1767
- outproj_weight: (out_dim, dim)
1768
- outproj_bias: (out_dim,)
1769
- headdim: if D is 1D, headdim must be passed in
1770
- norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
1771
- Return:
1772
- out: (batch, seqlen, dim)
1773
- """
1774
- return MambaSplitConv1dScanCombinedFn.apply(
1775
- zxbcdt,
1776
- conv1d_weight,
1777
- conv1d_bias,
1778
- dt_bias,
1779
- A,
1780
- D,
1781
- chunk_size,
1782
- initial_states,
1783
- seq_idx,
1784
- dt_limit,
1785
- return_final_states,
1786
- activation,
1787
- rmsnorm_weight,
1788
- rmsnorm_eps,
1789
- outproj_weight,
1790
- outproj_bias,
1791
- headdim,
1792
- ngroups,
1793
- norm_before_gate,
1794
- )
1795
-
1796
-
1797
- def mamba_split_conv1d_scan_ref(
1798
- zxbcdt,
1799
- conv1d_weight,
1800
- conv1d_bias,
1801
- dt_bias,
1802
- A,
1803
- D,
1804
- chunk_size,
1805
- dt_limit=(0.0, float("inf")),
1806
- activation="silu",
1807
- rmsnorm_weight=None,
1808
- rmsnorm_eps=1e-6,
1809
- outproj_weight=None,
1810
- outproj_bias=None,
1811
- headdim=None,
1812
- ngroups=1,
1813
- norm_before_gate=True,
1814
- ):
1815
- """
1816
- Argument:
1817
- zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
1818
- conv1d_weight: (dim + 2 * ngroups * dstate, width)
1819
- conv1d_bias: (dim + 2 * ngroups * dstate,)
1820
- dt_bias: (nheads,)
1821
- A: (nheads)
1822
- D: (nheads, headdim) or (nheads,)
1823
- rmsnorm_weight: (dim,)
1824
- outproj_weight: (out_dim, dim)
1825
- outproj_bias: (out_dim,)
1826
- headdim: if D is 1D, headdim must be passed in
1827
- norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
1828
- Return:
1829
- out: (batch, seqlen, dim)
1830
- """
1831
- if D.dim() == 1:
1832
- assert headdim is not None
1833
- (nheads,) = D.shape
1834
- else:
1835
- nheads, headdim = D.shape
1836
- assert nheads % ngroups == 0
1837
- batch, seqlen, _ = zxbcdt.shape
1838
- dim = nheads * headdim
1839
- dstate = (zxbcdt.shape[-1] - 2 * dim - nheads) // ngroups // 2
1840
- assert zxbcdt.shape == (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads)
1841
- assert dt_bias.shape == (nheads,)
1842
- assert A.shape == (nheads,)
1843
- if rmsnorm_weight is not None:
1844
- assert rmsnorm_weight.shape == (dim,)
1845
- z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1)
1846
- xBC = rearrange(
1847
- causal_conv1d_fn(
1848
- rearrange(xBC, "b s d -> b d s"),
1849
- conv1d_weight,
1850
- conv1d_bias,
1851
- activation=activation,
1852
- ),
1853
- "b d s -> b s d",
1854
- )
1855
- x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
1856
- x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
1857
- B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
1858
- C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
1859
- z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
1860
- out = ssd_selective_scan(
1861
- x,
1862
- dt.to(x.dtype),
1863
- A,
1864
- B,
1865
- C,
1866
- D=D.float(),
1867
- z=z if rmsnorm_weight is None else None,
1868
- dt_bias=dt_bias,
1869
- dt_softplus=True,
1870
- dt_limit=dt_limit,
1871
- )
1872
- out = rearrange(out, "b s h p -> b s (h p)")
1873
- if rmsnorm_weight is not None:
1874
- out = rmsnorm_fn(
1875
- out,
1876
- rmsnorm_weight,
1877
- None,
1878
- z=rearrange(z, "b l h p -> b l (h p)"),
1879
- eps=rmsnorm_eps,
1880
- norm_before_gate=norm_before_gate,
1881
- )
1882
- if outproj_weight is not None:
1883
- out = F.linear(out, outproj_weight, outproj_bias)
1884
- return out
 
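For reference, below is a minimal usage sketch of the deleted `mamba_split_conv1d_scan_combined`, following the tensor shapes given in its docstring above. The sizes, dtypes, and import path are illustrative assumptions, not taken from this diff.

```python
# Hypothetical usage sketch; shapes follow the docstring above, sizes are arbitrary.
import torch
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined

batch, seqlen, nheads, headdim, ngroups, dstate, width, chunk_size = 2, 128, 8, 64, 1, 16, 4, 64
dim = nheads * headdim                      # 512
conv_dim = dim + 2 * ngroups * dstate       # channels seen by the depthwise conv

zxbcdt = torch.randn(batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads,
                     device="cuda", dtype=torch.bfloat16)
conv1d_weight = torch.randn(conv_dim, width, device="cuda", dtype=torch.bfloat16)
conv1d_bias = torch.randn(conv_dim, device="cuda", dtype=torch.bfloat16)
dt_bias = torch.randn(nheads, device="cuda", dtype=torch.float32)
A = -torch.rand(nheads, device="cuda", dtype=torch.float32)   # negative real A, as in Mamba2
D = torch.randn(nheads, device="cuda", dtype=torch.float32)   # 1D D, so headdim must be passed

out = mamba_split_conv1d_scan_combined(
    zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size,
    activation="silu", headdim=headdim, ngroups=ngroups,
)
print(out.shape)  # torch.Size([2, 128, 512]) since outproj_weight is None
```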
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/__init__.py DELETED
@@ -1,14 +0,0 @@
1
- __version__ = "2.2.4"
2
-
3
- from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
4
- from .modules.mamba_simple import Mamba
5
- from .modules.mamba2 import Mamba2
6
- from .models.mixer_seq_simple import MambaLMHeadModel
7
-
8
- __all__ = [
9
- "selective_scan_fn",
10
- "mamba_inner_fn",
11
- "Mamba",
12
- "Mamba2",
13
- "MambaLMHeadModel",
14
- ]
 
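For context, the deleted `__init__.py` defines the package's public surface; assuming the built wheel is installed, the exports would be imported as below (illustrative only).

```python
# Illustrative top-level imports matching the __all__ list in the deleted __init__.py.
from mamba_ssm import Mamba, Mamba2, MambaLMHeadModel, mamba_inner_fn, selective_scan_fn
```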
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py DELETED
@@ -1,326 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao.
2
- # The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
3
- from typing import Optional
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from torch import Tensor
9
- from torch.distributed import ProcessGroup
10
- from ..utils.torch import custom_bwd, custom_fwd
11
-
12
- from einops import rearrange
13
-
14
- from ..distributed.distributed_utils import (
15
- all_gather_raw,
16
- all_reduce,
17
- all_reduce_raw,
18
- reduce_scatter,
19
- reduce_scatter_raw,
20
- )
21
-
22
-
23
- class ParallelLinearFunc(torch.autograd.Function):
24
- @staticmethod
25
- @custom_fwd
26
- def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
27
- """
28
- If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
29
- with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
30
- """
31
- ctx.compute_weight_gradient = weight.requires_grad
32
- ctx.process_group = process_group
33
- ctx.sequence_parallel = sequence_parallel
34
-
35
- if torch.is_autocast_enabled():
36
- x = x.to(dtype=torch.get_autocast_gpu_dtype())
37
- x = x.contiguous()
38
- if process_group is not None and sequence_parallel:
39
- # We want to kick off the all_gather early, before weight dtype conversion
40
- total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
41
- else:
42
- total_x = x
43
-
44
- if torch.is_autocast_enabled():
45
- weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
46
- bias = (
47
- bias.to(dtype=torch.get_autocast_gpu_dtype())
48
- if bias is not None
49
- else None
50
- )
51
- weight = weight.contiguous()
52
- if process_group is not None and sequence_parallel:
53
- handle_x.wait()
54
- batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
55
- batch_dim = batch_shape.numel()
56
- # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
57
- output = F.linear(total_x, weight, bias)
58
- if ctx.compute_weight_gradient:
59
- ctx.save_for_backward(x, weight)
60
- else:
61
- ctx.save_for_backward(weight)
62
- return output
63
-
64
- @staticmethod
65
- @custom_bwd
66
- def backward(ctx, grad_output):
67
- grad_output = grad_output.contiguous()
68
- process_group = ctx.process_group
69
- sequence_parallel = ctx.sequence_parallel
70
- if ctx.compute_weight_gradient:
71
- x, weight = ctx.saved_tensors
72
- if process_group is not None and sequence_parallel:
73
- total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
74
- else:
75
- total_x = x
76
- else:
77
- (weight,) = ctx.saved_tensors
78
- total_x = None
79
- batch_shape = grad_output.shape[:-1]
80
- batch_dim = batch_shape.numel()
81
- grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
82
- if ctx.needs_input_grad[0]:
83
- grad_input = F.linear(grad_output, weight.t())
84
- grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
85
- if process_group is not None:
86
- reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
87
- grad_input, handle_grad_input = reduce_fn(
88
- grad_input, process_group, async_op=True
89
- )
90
- else:
91
- grad_input = None
92
- if ctx.needs_input_grad[1]:
93
- assert ctx.compute_weight_gradient
94
- if process_group is not None and sequence_parallel:
95
- handle_x.wait()
96
- grad_weight = torch.einsum(
97
- "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
98
- )
99
- else:
100
- grad_weight = None
101
- grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
102
- if process_group is not None and ctx.needs_input_grad[0]:
103
- handle_grad_input.wait()
104
- return grad_input, grad_weight, grad_bias, None, None
105
-
106
-
107
- def parallel_linear_func(
108
- x: Tensor,
109
- weight: Tensor,
110
- bias: Optional[Tensor] = None,
111
- process_group: Optional[ProcessGroup] = None,
112
- sequence_parallel: bool = True,
113
- ):
114
- return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)
115
-
116
-
117
- class ColumnParallelLinear(nn.Linear):
118
- def __init__(
119
- self,
120
- in_features: int,
121
- out_features: int,
122
- process_group: ProcessGroup,
123
- bias: bool = True,
124
- sequence_parallel=True,
125
- multiple_of=1,
126
- device=None,
127
- dtype=None,
128
- ) -> None:
129
- world_size = torch.distributed.get_world_size(process_group)
130
- if out_features % multiple_of:
131
- raise ValueError(
132
- f"out_features ({out_features}) must be a multiple of {multiple_of}"
133
- )
134
- multiple = out_features // multiple_of
135
- # We want to split @multiple across world_size, but it could be an uneven split
136
- div = multiple // world_size
137
- mod = multiple % world_size
138
- # The first @mod ranks get @div + 1 copies, the rest get @div copies
139
- local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
140
- super().__init__(
141
- in_features,
142
- local_multiple * multiple_of,
143
- bias=bias,
144
- device=device,
145
- dtype=dtype,
146
- )
147
- self.process_group = process_group
148
- self.sequence_parallel = sequence_parallel
149
-
150
- def forward(self, x):
151
- # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
152
- # we do an all_gather of x before doing the matmul.
153
- # If not, then the input is already gathered.
154
- return parallel_linear_func(
155
- x,
156
- self.weight,
157
- self.bias,
158
- process_group=self.process_group,
159
- sequence_parallel=self.sequence_parallel,
160
- )
161
-
162
-
163
- class RowParallelLinear(nn.Linear):
164
- def __init__(
165
- self,
166
- in_features: int,
167
- out_features: int,
168
- process_group: ProcessGroup,
169
- bias: bool = True,
170
- sequence_parallel=True,
171
- multiple_of=1,
172
- device=None,
173
- dtype=None,
174
- ) -> None:
175
- world_size = torch.distributed.get_world_size(process_group)
176
- rank = torch.distributed.get_rank(process_group)
177
- if in_features % multiple_of:
178
- raise ValueError(
179
- f"in_features ({in_features}) must be a multiple of {multiple_of}"
180
- )
181
- multiple = in_features // multiple_of
182
- # We want to split @multiple across world_size, but it could be an uneven split
183
- div = multiple // world_size
184
- mod = multiple % world_size
185
- # The first @mod ranks get @div + 1 copies, the rest get @div copies
186
- local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
187
- # Only rank 0 will have bias
188
- super().__init__(
189
- local_multiple * multiple_of,
190
- out_features,
191
- bias=bias and rank == 0,
192
- device=device,
193
- dtype=dtype,
194
- )
195
- self.process_group = process_group
196
- self.sequence_parallel = sequence_parallel
197
-
198
- def forward(self, x):
199
- """
200
- We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
201
- a reduce_scatter of the result.
202
- """
203
- out = parallel_linear_func(x, self.weight, self.bias)
204
- reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
205
- return reduce_fn(out, self.process_group)
206
-
207
-
208
- class VocabParallelEmbedding(nn.Embedding):
209
- def __init__(
210
- self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs
211
- ):
212
- self.process_group = process_group
213
- if process_group is not None:
214
- world_size = torch.distributed.get_world_size(process_group)
215
- if num_embeddings % world_size != 0:
216
- raise ValueError(
217
- f"num_embeddings ({num_embeddings}) must be divisible by "
218
- f"world_size ({world_size})"
219
- )
220
- if world_size > 1 and padding_idx is not None:
221
- raise RuntimeError("ParallelEmbedding does not support padding_idx")
222
- else:
223
- world_size = 1
224
- super().__init__(
225
- num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs
226
- )
227
-
228
- def forward(self, input: Tensor) -> Tensor:
229
- if self.process_group is None:
230
- return super().forward(input)
231
- else:
232
- rank = torch.distributed.get_rank(self.process_group)
233
- vocab_size = self.num_embeddings
234
- vocab_start_index, vocab_end_index = (
235
- rank * vocab_size,
236
- (rank + 1) * vocab_size,
237
- )
238
- # Create a mask of out-of-range vocab ids (True means the id needs to be masked).
239
- input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
240
- input = input - vocab_start_index
241
- input[input_ids_mask] = 0
242
- embeddings = super().forward(input)
243
- embeddings[input_ids_mask] = 0.0
244
- return embeddings
245
-
246
-
247
- class ColumnParallelEmbedding(nn.Embedding):
248
- def __init__(
249
- self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs
250
- ):
251
- self.process_group = process_group
252
- if process_group is not None:
253
- world_size = torch.distributed.get_world_size(process_group)
254
- if embedding_dim % world_size != 0:
255
- raise ValueError(
256
- f"embedding_dim ({embedding_dim}) must be divisible by "
257
- f"world_size ({world_size})"
258
- )
259
- else:
260
- world_size = 1
261
- super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
262
-
263
-
264
- class ParallelEmbeddings(nn.Module):
265
- def __init__(
266
- self,
267
- embed_dim,
268
- vocab_size,
269
- max_position_embeddings,
270
- process_group,
271
- padding_idx=None,
272
- sequence_parallel=True,
273
- device=None,
274
- dtype=None,
275
- ):
276
- """
277
- If max_position_embeddings <= 0, there are no position embeddings
278
- """
279
- factory_kwargs = {"device": device, "dtype": dtype}
280
- super().__init__()
281
- self.process_group = process_group
282
- self.sequence_parallel = sequence_parallel
283
- self.word_embeddings = VocabParallelEmbedding(
284
- vocab_size,
285
- embed_dim,
286
- padding_idx=padding_idx,
287
- process_group=process_group,
288
- **factory_kwargs,
289
- )
290
- self.max_position_embeddings = max_position_embeddings
291
- if self.max_position_embeddings > 0:
292
- self.position_embeddings = ColumnParallelEmbedding(
293
- max_position_embeddings,
294
- embed_dim,
295
- process_group=process_group,
296
- **factory_kwargs,
297
- )
298
-
299
- def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
300
- """
301
- input_ids: (batch, seqlen)
302
- position_ids: (batch, seqlen)
303
- """
304
- batch_size, seqlen = input_ids.shape
305
- world_size = torch.distributed.get_world_size(self.process_group)
306
- embeddings = self.word_embeddings(input_ids)
307
- if self.max_position_embeddings > 0:
308
- if position_ids is None:
309
- position_ids = torch.arange(
310
- seqlen, dtype=torch.long, device=input_ids.device
311
- )
312
- position_embeddings = self.position_embeddings(position_ids)
313
- if world_size <= 1:
314
- embeddings = embeddings + position_embeddings
315
- else:
316
- partition_dim = self.position_embeddings.embedding_dim
317
- rank = torch.distributed.get_rank(self.process_group)
318
- embeddings[
319
- ..., rank * partition_dim : (rank + 1) * partition_dim
320
- ] += position_embeddings
321
- if combine_batch_seqlen_dim:
322
- embeddings = rearrange(embeddings, "b s d -> (b s) d")
323
- reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
324
- return (
325
- embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
326
- )
 
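A sketch of how the tensor-parallel layers in the deleted module compose, assuming the job is launched with torchrun and an NCCL process group; the layer sizes and the SiLU MLP are illustrative, not part of this diff.

```python
# Hypothetical tensor-parallel MLP built from the layers above.
import torch
import torch.distributed as dist
import torch.nn.functional as F
from mamba_ssm.distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear

dist.init_process_group("nccl")                      # env:// rendezvous set up by torchrun
device = torch.device("cuda", dist.get_rank() % torch.cuda.device_count())
pg = dist.group.WORLD

# Column-parallel shards out_features across ranks (all-gathering the sequence first
# when sequence_parallel=True); row-parallel shards in_features and reduce-scatters.
fc1 = ColumnParallelLinear(1024, 4096, process_group=pg, sequence_parallel=True, device=device)
fc2 = RowParallelLinear(4096, 1024, process_group=pg, sequence_parallel=True, device=device)

x = torch.randn(512 // dist.get_world_size(), 1024, device=device)  # local shard of the sequence
y = fc2(F.silu(fc1(x)))                                             # (512 // world_size, 1024)
```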
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py DELETED
@@ -1,338 +0,0 @@
1
- # Copyright (c) 2023, Albert Gu, Tri Dao.
2
-
3
- import math
4
- from functools import partial
5
- import json
6
- import os
7
- import copy
8
-
9
- from collections import namedtuple
10
-
11
- import torch
12
- import torch.nn as nn
13
-
14
- from .config_mamba import MambaConfig
15
- from ..modules.mamba_simple import Mamba
16
- from ..modules.mamba2 import Mamba2
17
- from ..modules.mha import MHA
18
- from ..modules.mlp import GatedMLP
19
- from ..modules.block import Block
20
- from ..utils.generation import GenerationMixin
21
- from ..utils.hf import load_config_hf, load_state_dict_hf
22
-
23
- try:
24
- from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
25
- except ImportError:
26
- RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
27
-
28
-
29
- def create_block(
30
- d_model,
31
- d_intermediate,
32
- ssm_cfg=None,
33
- attn_layer_idx=None,
34
- attn_cfg=None,
35
- norm_epsilon=1e-5,
36
- rms_norm=False,
37
- residual_in_fp32=False,
38
- fused_add_norm=False,
39
- layer_idx=None,
40
- device=None,
41
- dtype=None,
42
- ):
43
- if ssm_cfg is None:
44
- ssm_cfg = {}
45
- if attn_layer_idx is None:
46
- attn_layer_idx = []
47
- if attn_cfg is None:
48
- attn_cfg = {}
49
- factory_kwargs = {"device": device, "dtype": dtype}
50
- if layer_idx not in attn_layer_idx:
51
- # Create a copy of the config to modify
52
- ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
53
- ssm_layer = ssm_cfg.pop("layer", "Mamba1")
54
- if ssm_layer not in ["Mamba1", "Mamba2"]:
55
- raise ValueError(
56
- f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2"
57
- )
58
- mixer_cls = partial(
59
- Mamba2 if ssm_layer == "Mamba2" else Mamba,
60
- layer_idx=layer_idx,
61
- **ssm_cfg,
62
- **factory_kwargs,
63
- )
64
- else:
65
- mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
66
- norm_cls = partial(
67
- nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
68
- )
69
- if d_intermediate == 0:
70
- mlp_cls = nn.Identity
71
- else:
72
- mlp_cls = partial(
73
- GatedMLP,
74
- hidden_features=d_intermediate,
75
- out_features=d_model,
76
- **factory_kwargs,
77
- )
78
- block = Block(
79
- d_model,
80
- mixer_cls,
81
- mlp_cls,
82
- norm_cls=norm_cls,
83
- fused_add_norm=fused_add_norm,
84
- residual_in_fp32=residual_in_fp32,
85
- )
86
- block.layer_idx = layer_idx
87
- return block
88
-
89
-
90
- # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
91
- def _init_weights(
92
- module,
93
- n_layer,
94
- initializer_range=0.02, # Now only used for embedding layer.
95
- rescale_prenorm_residual=True,
96
- n_residuals_per_layer=1, # Change to 2 if we have MLP
97
- ):
98
- if isinstance(module, nn.Linear):
99
- if module.bias is not None:
100
- if not getattr(module.bias, "_no_reinit", False):
101
- nn.init.zeros_(module.bias)
102
- elif isinstance(module, nn.Embedding):
103
- nn.init.normal_(module.weight, std=initializer_range)
104
-
105
- if rescale_prenorm_residual:
106
- # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
107
- # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
108
- # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
109
- # > -- GPT-2 :: https://openai.com/blog/better-language-models/
110
- #
111
- # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
112
- for name, p in module.named_parameters():
113
- if name in ["out_proj.weight", "fc2.weight"]:
114
- # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
115
- # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
116
- # We need to reinit p since this code could be called multiple times
117
- # Having just p *= scale would repeatedly scale it down
118
- nn.init.kaiming_uniform_(p, a=math.sqrt(5))
119
- with torch.no_grad():
120
- p /= math.sqrt(n_residuals_per_layer * n_layer)
121
-
122
-
123
- class MixerModel(nn.Module):
124
- def __init__(
125
- self,
126
- d_model: int,
127
- n_layer: int,
128
- d_intermediate: int,
129
- vocab_size: int,
130
- ssm_cfg=None,
131
- attn_layer_idx=None,
132
- attn_cfg=None,
133
- norm_epsilon: float = 1e-5,
134
- rms_norm: bool = False,
135
- initializer_cfg=None,
136
- fused_add_norm=False,
137
- residual_in_fp32=False,
138
- device=None,
139
- dtype=None,
140
- ) -> None:
141
- factory_kwargs = {"device": device, "dtype": dtype}
142
- super().__init__()
143
- self.residual_in_fp32 = residual_in_fp32
144
-
145
- self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
146
-
147
- # We change the order of residual and layer norm:
148
- # Instead of LN -> Attn / MLP -> Add, we do:
149
- # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
150
- # the main branch (output of MLP / Mixer). The model definition is unchanged.
151
- # This is for performance reason: we can fuse add + layer_norm.
152
- self.fused_add_norm = fused_add_norm
153
- if self.fused_add_norm:
154
- if layer_norm_fn is None or rms_norm_fn is None:
155
- raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
156
-
157
- self.layers = nn.ModuleList(
158
- [
159
- create_block(
160
- d_model,
161
- d_intermediate=d_intermediate,
162
- ssm_cfg=ssm_cfg,
163
- attn_layer_idx=attn_layer_idx,
164
- attn_cfg=attn_cfg,
165
- norm_epsilon=norm_epsilon,
166
- rms_norm=rms_norm,
167
- residual_in_fp32=residual_in_fp32,
168
- fused_add_norm=fused_add_norm,
169
- layer_idx=i,
170
- **factory_kwargs,
171
- )
172
- for i in range(n_layer)
173
- ]
174
- )
175
-
176
- self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
177
- d_model, eps=norm_epsilon, **factory_kwargs
178
- )
179
-
180
- self.apply(
181
- partial(
182
- _init_weights,
183
- n_layer=n_layer,
184
- **(initializer_cfg if initializer_cfg is not None else {}),
185
- n_residuals_per_layer=(
186
- 1 if d_intermediate == 0 else 2
187
- ), # 2 if we have MLP
188
- )
189
- )
190
-
191
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
192
- return {
193
- i: layer.allocate_inference_cache(
194
- batch_size, max_seqlen, dtype=dtype, **kwargs
195
- )
196
- for i, layer in enumerate(self.layers)
197
- }
198
-
199
- def forward(self, input_ids, inference_params=None, **mixer_kwargs):
200
- hidden_states = self.embedding(input_ids)
201
- residual = None
202
- for layer in self.layers:
203
- hidden_states, residual = layer(
204
- hidden_states,
205
- residual,
206
- inference_params=inference_params,
207
- **mixer_kwargs,
208
- )
209
- if not self.fused_add_norm:
210
- residual = (
211
- (hidden_states + residual) if residual is not None else hidden_states
212
- )
213
- hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
214
- else:
215
- # Set prenorm=False here since we don't need the residual
216
- hidden_states = layer_norm_fn(
217
- hidden_states,
218
- self.norm_f.weight,
219
- self.norm_f.bias,
220
- eps=self.norm_f.eps,
221
- residual=residual,
222
- prenorm=False,
223
- residual_in_fp32=self.residual_in_fp32,
224
- is_rms_norm=isinstance(self.norm_f, RMSNorm),
225
- )
226
- return hidden_states
227
-
228
-
229
- class MambaLMHeadModel(nn.Module, GenerationMixin):
230
-
231
- def __init__(
232
- self,
233
- config: MambaConfig,
234
- initializer_cfg=None,
235
- device=None,
236
- dtype=None,
237
- ) -> None:
238
- self.config = config
239
- d_model = config.d_model
240
- n_layer = config.n_layer
241
- d_intermediate = config.d_intermediate
242
- vocab_size = config.vocab_size
243
- ssm_cfg = config.ssm_cfg
244
- attn_layer_idx = config.attn_layer_idx
245
- attn_cfg = config.attn_cfg
246
- rms_norm = config.rms_norm
247
- residual_in_fp32 = config.residual_in_fp32
248
- fused_add_norm = config.fused_add_norm
249
- pad_vocab_size_multiple = config.pad_vocab_size_multiple
250
- factory_kwargs = {"device": device, "dtype": dtype}
251
-
252
- super().__init__()
253
- if vocab_size % pad_vocab_size_multiple != 0:
254
- vocab_size += pad_vocab_size_multiple - (
255
- vocab_size % pad_vocab_size_multiple
256
- )
257
- self.backbone = MixerModel(
258
- d_model=d_model,
259
- n_layer=n_layer,
260
- d_intermediate=d_intermediate,
261
- vocab_size=vocab_size,
262
- ssm_cfg=ssm_cfg,
263
- attn_layer_idx=attn_layer_idx,
264
- attn_cfg=attn_cfg,
265
- rms_norm=rms_norm,
266
- initializer_cfg=initializer_cfg,
267
- fused_add_norm=fused_add_norm,
268
- residual_in_fp32=residual_in_fp32,
269
- **factory_kwargs,
270
- )
271
- self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
272
-
273
- # Initialize weights and apply final processing
274
- self.apply(
275
- partial(
276
- _init_weights,
277
- n_layer=n_layer,
278
- **(initializer_cfg if initializer_cfg is not None else {}),
279
- )
280
- )
281
- self.tie_weights()
282
-
283
- def tie_weights(self):
284
- if self.config.tie_embeddings:
285
- self.lm_head.weight = self.backbone.embedding.weight
286
-
287
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
288
- return self.backbone.allocate_inference_cache(
289
- batch_size, max_seqlen, dtype=dtype, **kwargs
290
- )
291
-
292
- def forward(
293
- self,
294
- input_ids,
295
- position_ids=None,
296
- inference_params=None,
297
- num_last_tokens=0,
298
- **mixer_kwargs,
299
- ):
300
- """
301
- "position_ids" is just to be compatible with Transformer generation. We don't use it.
302
- num_last_tokens: if > 0, only return the logits for the last n tokens
303
- """
304
- hidden_states = self.backbone(
305
- input_ids, inference_params=inference_params, **mixer_kwargs
306
- )
307
- if num_last_tokens > 0:
308
- hidden_states = hidden_states[:, -num_last_tokens:]
309
- lm_logits = self.lm_head(hidden_states)
310
- CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
311
- return CausalLMOutput(logits=lm_logits)
312
-
313
- @classmethod
314
- def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
315
- config_data = load_config_hf(pretrained_model_name)
316
- config = MambaConfig(**config_data)
317
- model = cls(config, device=device, dtype=dtype, **kwargs)
318
- model.load_state_dict(
319
- load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
320
- )
321
- return model
322
-
323
- def save_pretrained(self, save_directory):
324
- """
325
- Minimal implementation of save_pretrained for MambaLMHeadModel.
326
- Save the model and its configuration file to a directory.
327
- """
328
- # Ensure save_directory exists
329
- os.makedirs(save_directory, exist_ok=True)
330
-
331
- # Save the model's state_dict
332
- model_path = os.path.join(save_directory, "pytorch_model.bin")
333
- torch.save(self.state_dict(), model_path)
334
-
335
- # Save the configuration of the model
336
- config_path = os.path.join(save_directory, "config.json")
337
- with open(config_path, "w") as f:
338
- json.dump(self.config.__dict__, f, indent=4)
 
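A minimal sketch of loading and querying the deleted `MambaLMHeadModel`; the checkpoint name is illustrative and the model is assumed to fit on a single GPU.

```python
# Hypothetical load / forward / save round-trip for MambaLMHeadModel.
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m",
                                         device="cuda", dtype=torch.float16)
input_ids = torch.randint(0, model.config.vocab_size, (1, 32), device="cuda")
logits = model(input_ids, num_last_tokens=1).logits   # (1, 1, padded_vocab_size)
model.save_pretrained("./mamba-checkpoint")           # writes pytorch_model.bin + config.json
```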
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/selective_scan_interface.py DELETED
@@ -1,659 +0,0 @@
1
- # Copyright (c) 2023, Tri Dao, Albert Gu.
2
-
3
- import torch
4
- import torch.nn.functional as F
5
- from ..utils.torch import custom_fwd, custom_bwd
6
-
7
- from einops import rearrange, repeat
8
-
9
- try:
10
- from causal_conv1d import causal_conv1d_fn
11
- import causal_conv1d_cuda
12
- except ImportError:
13
- causal_conv1d_fn = None
14
- causal_conv1d_cuda = None
15
-
16
- from .triton.layer_norm import _layer_norm_fwd
17
-
18
- from .._ops import ops
19
-
20
-
21
- class SelectiveScanFn(torch.autograd.Function):
22
-
23
- @staticmethod
24
- def forward(
25
- ctx,
26
- u,
27
- delta,
28
- A,
29
- B,
30
- C,
31
- D=None,
32
- z=None,
33
- delta_bias=None,
34
- delta_softplus=False,
35
- return_last_state=False,
36
- ):
37
- if u.stride(-1) != 1:
38
- u = u.contiguous()
39
- if delta.stride(-1) != 1:
40
- delta = delta.contiguous()
41
- if D is not None:
42
- D = D.contiguous()
43
- if B.stride(-1) != 1:
44
- B = B.contiguous()
45
- if C.stride(-1) != 1:
46
- C = C.contiguous()
47
- if z is not None and z.stride(-1) != 1:
48
- z = z.contiguous()
49
- if B.dim() == 3:
50
- B = rearrange(B, "b dstate l -> b 1 dstate l")
51
- ctx.squeeze_B = True
52
- if C.dim() == 3:
53
- C = rearrange(C, "b dstate l -> b 1 dstate l")
54
- ctx.squeeze_C = True
55
- out, x, *rest = ops.selective_scan_fwd(
56
- u, delta, A, B, C, D, z, delta_bias, delta_softplus
57
- )
58
- ctx.delta_softplus = delta_softplus
59
- ctx.has_z = z is not None
60
- last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
61
- if not ctx.has_z:
62
- ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
63
- return out if not return_last_state else (out, last_state)
64
- else:
65
- ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
66
- out_z = rest[0]
67
- return out_z if not return_last_state else (out_z, last_state)
68
-
69
- @staticmethod
70
- def backward(ctx, dout, *args):
71
- if not ctx.has_z:
72
- u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
73
- z = None
74
- out = None
75
- else:
76
- u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
77
- if dout.stride(-1) != 1:
78
- dout = dout.contiguous()
79
- # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
80
- # backward of selective_scan_cuda with the backward of chunk).
81
- # Here we just pass in None and dz will be allocated in the C++ code.
82
- du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = ops.selective_scan_bwd(
83
- u,
84
- delta,
85
- A,
86
- B,
87
- C,
88
- D,
89
- z,
90
- delta_bias,
91
- dout,
92
- x,
93
- out,
94
- None,
95
- ctx.delta_softplus,
96
- False, # option to recompute out_z, not used here
97
- )
98
- dz = rest[0] if ctx.has_z else None
99
- dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
100
- dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
101
- return (
102
- du,
103
- ddelta,
104
- dA,
105
- dB,
106
- dC,
107
- dD if D is not None else None,
108
- dz,
109
- ddelta_bias if delta_bias is not None else None,
110
- None,
111
- None,
112
- )
113
-
114
-
115
- def rms_norm_forward(
116
- x,
117
- weight,
118
- bias,
119
- eps=1e-6,
120
- is_rms_norm=True,
121
- ):
122
- # x (b l) d
123
- if x.stride(-1) != 1:
124
- x = x.contiguous()
125
- weight = weight.contiguous()
126
- if bias is not None:
127
- bias = bias.contiguous()
128
- y = _layer_norm_fwd(
129
- x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm
130
- )[0]
131
- # y (b l) d
132
- return y
133
-
134
-
135
- def selective_scan_fn(
136
- u,
137
- delta,
138
- A,
139
- B,
140
- C,
141
- D=None,
142
- z=None,
143
- delta_bias=None,
144
- delta_softplus=False,
145
- return_last_state=False,
146
- ):
147
- """if return_last_state is True, returns (out, last_state)
148
- last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
149
- not considered in the backward pass.
150
- """
151
- return SelectiveScanFn.apply(
152
- u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state
153
- )
154
-
155
-
156
- def selective_scan_ref(
157
- u,
158
- delta,
159
- A,
160
- B,
161
- C,
162
- D=None,
163
- z=None,
164
- delta_bias=None,
165
- delta_softplus=False,
166
- return_last_state=False,
167
- ):
168
- """
169
- u: r(B D L)
170
- delta: r(B D L)
171
- A: c(D N) or r(D N)
172
- B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
173
- C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
174
- D: r(D)
175
- z: r(B D L)
176
- delta_bias: r(D), fp32
177
-
178
- out: r(B D L)
179
- last_state (optional): r(B D dstate) or c(B D dstate)
180
- """
181
- dtype_in = u.dtype
182
- u = u.float()
183
- delta = delta.float()
184
- if delta_bias is not None:
185
- delta = delta + delta_bias[..., None].float()
186
- if delta_softplus:
187
- delta = F.softplus(delta)
188
- batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
189
- is_variable_B = B.dim() >= 3
190
- is_variable_C = C.dim() >= 3
191
- if A.is_complex():
192
- if is_variable_B:
193
- B = torch.view_as_complex(
194
- rearrange(B.float(), "... (L two) -> ... L two", two=2)
195
- )
196
- if is_variable_C:
197
- C = torch.view_as_complex(
198
- rearrange(C.float(), "... (L two) -> ... L two", two=2)
199
- )
200
- else:
201
- B = B.float()
202
- C = C.float()
203
- x = A.new_zeros((batch, dim, dstate))
204
- ys = []
205
- deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
206
- if not is_variable_B:
207
- deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
208
- else:
209
- if B.dim() == 3:
210
- deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
211
- else:
212
- B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
213
- deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
214
- if is_variable_C and C.dim() == 4:
215
- C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
216
- last_state = None
217
- for i in range(u.shape[2]):
218
- x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
219
- if not is_variable_C:
220
- y = torch.einsum("bdn,dn->bd", x, C)
221
- else:
222
- if C.dim() == 3:
223
- y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
224
- else:
225
- y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
226
- if i == u.shape[2] - 1:
227
- last_state = x
228
- if y.is_complex():
229
- y = y.real * 2
230
- ys.append(y)
231
- y = torch.stack(ys, dim=2) # (batch dim L)
232
- out = y if D is None else y + u * rearrange(D, "d -> d 1")
233
- if z is not None:
234
- out = out * F.silu(z)
235
- out = out.to(dtype=dtype_in)
236
- return out if not return_last_state else (out, last_state)
237
-
238
-
239
- class MambaInnerFn(torch.autograd.Function):
240
-
241
- @staticmethod
242
- @custom_fwd
243
- def forward(
244
- ctx,
245
- xz,
246
- conv1d_weight,
247
- conv1d_bias,
248
- x_proj_weight,
249
- delta_proj_weight,
250
- out_proj_weight,
251
- out_proj_bias,
252
- A,
253
- B=None,
254
- C=None,
255
- D=None,
256
- delta_bias=None,
257
- B_proj_bias=None,
258
- C_proj_bias=None,
259
- delta_softplus=True,
260
- checkpoint_lvl=1,
261
- b_rms_weight=None,
262
- c_rms_weight=None,
263
- dt_rms_weight=None,
264
- b_c_dt_rms_eps=1e-6,
265
- ):
266
- """
267
- xz: (batch, dim, seqlen)
268
- """
269
- assert (
270
- causal_conv1d_cuda is not None
271
- ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
272
- assert checkpoint_lvl in [0, 1]
273
- L = xz.shape[-1]
274
- delta_rank = delta_proj_weight.shape[1]
275
- d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
276
- if torch.is_autocast_enabled():
277
- x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
278
- delta_proj_weight = delta_proj_weight.to(
279
- dtype=torch.get_autocast_gpu_dtype()
280
- )
281
- out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
282
- out_proj_bias = (
283
- out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
284
- if out_proj_bias is not None
285
- else None
286
- )
287
- if xz.stride(-1) != 1:
288
- xz = xz.contiguous()
289
- conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
290
- x, z = xz.chunk(2, dim=1)
291
- conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
292
- conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
293
- x, conv1d_weight, conv1d_bias, None, None, None, True
294
- )
295
- # We're being very careful here about the layout, to avoid extra transposes.
296
- # We want delta to have d as the slowest moving dimension
297
- # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
298
- x_dbl = F.linear(
299
- rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight
300
- ) # (bl d)
301
- delta = rearrange(
302
- delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
303
- )
304
- ctx.is_variable_B = B is None
305
- ctx.is_variable_C = C is None
306
- ctx.B_proj_bias_is_None = B_proj_bias is None
307
- ctx.C_proj_bias_is_None = C_proj_bias is None
308
- if B is None: # variable B
309
- B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl dstate)
310
- if B_proj_bias is not None:
311
- B = B + B_proj_bias.to(dtype=B.dtype)
312
- if not A.is_complex():
313
- # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
314
- B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
315
- else:
316
- B = rearrange(
317
- B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
318
- ).contiguous()
319
- else:
320
- if B.stride(-1) != 1:
321
- B = B.contiguous()
322
- if C is None: # variable C
323
- C = x_dbl[:, -d_state:] # (bl dstate)
324
- if C_proj_bias is not None:
325
- C = C + C_proj_bias.to(dtype=C.dtype)
326
- if not A.is_complex():
327
- # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
328
- C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
329
- else:
330
- C = rearrange(
331
- C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
332
- ).contiguous()
333
- else:
334
- if C.stride(-1) != 1:
335
- C = C.contiguous()
336
- if D is not None:
337
- D = D.contiguous()
338
-
339
- if b_rms_weight is not None:
340
- B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
341
- B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
342
- B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
343
- if c_rms_weight is not None:
344
- C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
345
- C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
346
- C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
347
- if dt_rms_weight is not None:
348
- delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
349
- delta = rms_norm_forward(
350
- delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps
351
- )
352
- delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
353
-
354
- out, scan_intermediates, out_z = ops.selective_scan_fwd(
355
- conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
356
- )
357
- ctx.delta_softplus = delta_softplus
358
- ctx.out_proj_bias_is_None = out_proj_bias is None
359
- ctx.checkpoint_lvl = checkpoint_lvl
360
- ctx.b_rms_weight = b_rms_weight
361
- ctx.c_rms_weight = c_rms_weight
362
- ctx.dt_rms_weight = dt_rms_weight
363
- ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
364
- if (
365
- checkpoint_lvl >= 1
366
- ): # Will recompute conv1d_out and delta in the backward pass
367
- conv1d_out, delta = None, None
368
- ctx.save_for_backward(
369
- xz,
370
- conv1d_weight,
371
- conv1d_bias,
372
- x_dbl,
373
- x_proj_weight,
374
- delta_proj_weight,
375
- out_proj_weight,
376
- conv1d_out,
377
- delta,
378
- A,
379
- B,
380
- C,
381
- D,
382
- delta_bias,
383
- scan_intermediates,
384
- b_rms_weight,
385
- c_rms_weight,
386
- dt_rms_weight,
387
- out,
388
- )
389
- return F.linear(
390
- rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias
391
- )
392
-
393
- @staticmethod
394
- @custom_bwd
395
- def backward(ctx, dout):
396
- # dout: (batch, seqlen, dim)
397
- assert (
398
- causal_conv1d_cuda is not None
399
- ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
400
- (
401
- xz,
402
- conv1d_weight,
403
- conv1d_bias,
404
- x_dbl,
405
- x_proj_weight,
406
- delta_proj_weight,
407
- out_proj_weight,
408
- conv1d_out,
409
- delta,
410
- A,
411
- B,
412
- C,
413
- D,
414
- delta_bias,
415
- scan_intermediates,
416
- b_rms_weight,
417
- c_rms_weight,
418
- dt_rms_weight,
419
- out,
420
- ) = ctx.saved_tensors
421
- L = xz.shape[-1]
422
- delta_rank = delta_proj_weight.shape[1]
423
- d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
424
- x, z = xz.chunk(2, dim=1)
425
- if dout.stride(-1) != 1:
426
- dout = dout.contiguous()
427
- if ctx.checkpoint_lvl == 1:
428
- conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
429
- x, conv1d_weight, conv1d_bias, None, None, None, True
430
- )
431
- delta = rearrange(
432
- delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
433
- )
434
- if dt_rms_weight is not None:
435
- delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
436
- delta = rms_norm_forward(
437
- delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps
438
- )
439
- delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
440
- if b_rms_weight is not None:
441
- # Recompute & RMSNorm B
442
- B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
443
- B = rms_norm_forward(B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps)
444
- B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
445
- if c_rms_weight is not None:
446
- # Recompute & RMSNorm C
447
- C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
448
- C = rms_norm_forward(C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps)
449
- C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
450
-
451
- # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
452
- # backward of selective_scan_cuda with the backward of chunk).
453
- dxz = torch.empty_like(xz) # (batch, dim, seqlen)
454
- dx, dz = dxz.chunk(2, dim=1)
455
- dout = rearrange(dout, "b l e -> e (b l)")
456
- dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
457
- dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = (
458
- ops.selective_scan_bwd(
459
- conv1d_out,
460
- delta,
461
- A,
462
- B,
463
- C,
464
- D,
465
- z,
466
- delta_bias,
467
- dout_y,
468
- scan_intermediates,
469
- out,
470
- dz,
471
- ctx.delta_softplus,
472
- True, # option to recompute out_z
473
- )
474
- )
475
- dout_proj_weight = torch.einsum(
476
- "eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")
477
- )
478
- dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
479
- dD = dD if D is not None else None
480
- dx_dbl = torch.empty_like(x_dbl)
481
- dB_proj_bias = None
482
- if ctx.is_variable_B:
483
- if not A.is_complex():
484
- dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
485
- else:
486
- dB = rearrange(
487
- dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
488
- ).contiguous()
489
- dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
490
- dx_dbl[:, delta_rank : delta_rank + d_state] = dB # (bl d)
491
- dB = None
492
- dC_proj_bias = None
493
- if ctx.is_variable_C:
494
- if not A.is_complex():
495
- dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
496
- else:
497
- dC = rearrange(
498
- dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
499
- ).contiguous()
500
- dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
501
- dx_dbl[:, -d_state:] = dC # (bl d)
502
- dC = None
503
- ddelta = rearrange(ddelta, "b d l -> d (b l)")
504
- ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
505
- dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
506
- dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
507
- dx_proj_weight = torch.einsum(
508
- "Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")
509
- )
510
- dconv1d_out = torch.addmm(
511
- dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out
512
- )
513
- dconv1d_out = rearrange(
514
- dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]
515
- )
516
- # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
517
- # backward of conv1d with the backward of chunk).
518
- dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
519
- x,
520
- conv1d_weight,
521
- conv1d_bias,
522
- dconv1d_out,
523
- None,
524
- None,
525
- None,
526
- dx,
527
- False,
528
- True,
529
- )
530
- dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
531
- dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
532
- return (
533
- dxz,
534
- dconv1d_weight,
535
- dconv1d_bias,
536
- dx_proj_weight,
537
- ddelta_proj_weight,
538
- dout_proj_weight,
539
- dout_proj_bias,
540
- dA,
541
- dB,
542
- dC,
543
- dD,
544
- ddelta_bias if delta_bias is not None else None,
545
- # 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
546
- dB_proj_bias,
547
- dC_proj_bias,
548
- None,
549
- None,
550
- None,
551
- None,
552
- None,
553
- None,
554
- )
555
-
556
-
557
- def mamba_inner_fn(
558
- xz,
559
- conv1d_weight,
560
- conv1d_bias,
561
- x_proj_weight,
562
- delta_proj_weight,
563
- out_proj_weight,
564
- out_proj_bias,
565
- A,
566
- B=None,
567
- C=None,
568
- D=None,
569
- delta_bias=None,
570
- B_proj_bias=None,
571
- C_proj_bias=None,
572
- delta_softplus=True,
573
- checkpoint_lvl=1,
574
- b_rms_weight=None,
575
- c_rms_weight=None,
576
- dt_rms_weight=None,
577
- b_c_dt_rms_eps=1e-6,
578
- ):
579
- return MambaInnerFn.apply(
580
- xz,
581
- conv1d_weight,
582
- conv1d_bias,
583
- x_proj_weight,
584
- delta_proj_weight,
585
- out_proj_weight,
586
- out_proj_bias,
587
- A,
588
- B,
589
- C,
590
- D,
591
- delta_bias,
592
- B_proj_bias,
593
- C_proj_bias,
594
- delta_softplus,
595
- checkpoint_lvl,
596
- b_rms_weight,
597
- c_rms_weight,
598
- dt_rms_weight,
599
- b_c_dt_rms_eps,
600
- )
601
-
602
-
603
- def mamba_inner_ref(
604
- xz,
605
- conv1d_weight,
606
- conv1d_bias,
607
- x_proj_weight,
608
- delta_proj_weight,
609
- out_proj_weight,
610
- out_proj_bias,
611
- A,
612
- B=None,
613
- C=None,
614
- D=None,
615
- delta_bias=None,
616
- B_proj_bias=None,
617
- C_proj_bias=None,
618
- delta_softplus=True,
619
- ):
620
- assert (
621
- causal_conv1d_fn is not None
622
- ), "causal_conv1d_fn is not available. Please install causal-conv1d."
623
- L = xz.shape[-1]
624
- delta_rank = delta_proj_weight.shape[1]
625
- d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
626
- x, z = xz.chunk(2, dim=1)
627
- x = causal_conv1d_fn(
628
- x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu"
629
- )
630
- # We're being very careful here about the layout, to avoid extra transposes.
631
- # We want delta to have d as the slowest moving dimension
632
- # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
633
- x_dbl = F.linear(rearrange(x, "b d l -> (b l) d"), x_proj_weight) # (bl d)
634
- delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
635
- delta = rearrange(delta, "d (b l) -> b d l", l=L)
636
- if B is None: # variable B
637
- B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl d)
638
- if B_proj_bias is not None:
639
- B = B + B_proj_bias.to(dtype=B.dtype)
640
- if not A.is_complex():
641
- B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
642
- else:
643
- B = rearrange(
644
- B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
645
- ).contiguous()
646
- if C is None: # variable C
647
- C = x_dbl[:, -d_state:] # (bl d)
648
- if C_proj_bias is not None:
649
- C = C + C_proj_bias.to(dtype=C.dtype)
650
- if not A.is_complex():
651
- C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
652
- else:
653
- C = rearrange(
654
- C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
655
- ).contiguous()
656
- y = selective_scan_fn(
657
- x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True
658
- )
659
- return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
 
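A usage sketch for the deleted `selective_scan_fn`, using the variable-B/C shapes documented in `selective_scan_ref` above; sizes and dtypes are illustrative assumptions.

```python
# Hypothetical call to selective_scan_fn with per-timestep (variable) B and C.
import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

batch, dim, seqlen, dstate = 2, 64, 128, 16
u = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.float16)
delta = torch.rand(batch, dim, seqlen, device="cuda", dtype=torch.float16)
A = -torch.rand(dim, dstate, device="cuda", dtype=torch.float32)   # real, negative A
B = torch.randn(batch, dstate, seqlen, device="cuda", dtype=torch.float16)
C = torch.randn(batch, dstate, seqlen, device="cuda", dtype=torch.float16)
D = torch.randn(dim, device="cuda", dtype=torch.float32)
z = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.float16)

out, last_state = selective_scan_fn(u, delta, A, B, C, D=D, z=z,
                                    delta_softplus=True, return_last_state=True)
# out: (2, 64, 128); last_state: (2, 64, 16), whose gradient is not tracked
```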
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/triton/layer_norm.py DELETED
@@ -1,1166 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao.
2
- # Implement dropout + residual + layer_norm / rms_norm.
3
-
4
- # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
5
- # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
6
- # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
7
- # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
-
9
- import math
10
- import warnings
11
-
12
- import torch
13
- import torch.nn.functional as F
14
- from ...utils.torch import custom_bwd, custom_fwd
15
-
16
- import triton
17
- import triton.language as tl
18
-
19
-
20
- def layer_norm_ref(
21
- x,
22
- weight,
23
- bias,
24
- residual=None,
25
- x1=None,
26
- weight1=None,
27
- bias1=None,
28
- eps=1e-6,
29
- dropout_p=0.0,
30
- rowscale=None,
31
- prenorm=False,
32
- dropout_mask=None,
33
- dropout_mask1=None,
34
- upcast=False,
35
- ):
36
- dtype = x.dtype
37
- if upcast:
38
- x = x.float()
39
- weight = weight.float()
40
- bias = bias.float() if bias is not None else None
41
- residual = residual.float() if residual is not None else residual
42
- x1 = x1.float() if x1 is not None else None
43
- weight1 = weight1.float() if weight1 is not None else None
44
- bias1 = bias1.float() if bias1 is not None else None
45
- if x1 is not None:
46
- assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
47
- if rowscale is not None:
48
- x = x * rowscale[..., None]
49
- if dropout_p > 0.0:
50
- if dropout_mask is not None:
51
- x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
52
- else:
53
- x = F.dropout(x, p=dropout_p)
54
- if x1 is not None:
55
- if dropout_mask1 is not None:
56
- x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
57
- else:
58
- x1 = F.dropout(x1, p=dropout_p)
59
- if x1 is not None:
60
- x = x + x1
61
- if residual is not None:
62
- x = (x + residual).to(x.dtype)
63
- out = F.layer_norm(
64
- x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
65
- ).to(dtype)
66
- if weight1 is None:
67
- return out if not prenorm else (out, x)
68
- else:
69
- out1 = F.layer_norm(
70
- x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
71
- ).to(dtype)
72
- return (out, out1) if not prenorm else (out, out1, x)
73
-
74
-
75
- def rms_norm_ref(
76
- x,
77
- weight,
78
- bias,
79
- residual=None,
80
- x1=None,
81
- weight1=None,
82
- bias1=None,
83
- eps=1e-6,
84
- dropout_p=0.0,
85
- rowscale=None,
86
- prenorm=False,
87
- dropout_mask=None,
88
- dropout_mask1=None,
89
- upcast=False,
90
- ):
91
- dtype = x.dtype
92
- if upcast:
93
- x = x.float()
94
- weight = weight.float()
95
- bias = bias.float() if bias is not None else None
96
- residual = residual.float() if residual is not None else residual
97
- x1 = x1.float() if x1 is not None else None
98
- weight1 = weight1.float() if weight1 is not None else None
99
- bias1 = bias1.float() if bias1 is not None else None
100
- if x1 is not None:
101
- assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
102
- if rowscale is not None:
103
- x = x * rowscale[..., None]
104
- if dropout_p > 0.0:
105
- if dropout_mask is not None:
106
- x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
107
- else:
108
- x = F.dropout(x, p=dropout_p)
109
- if x1 is not None:
110
- if dropout_mask1 is not None:
111
- x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
112
- else:
113
- x1 = F.dropout(x1, p=dropout_p)
114
- if x1 is not None:
115
- x = x + x1
116
- if residual is not None:
117
- x = (x + residual).to(x.dtype)
118
- rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
119
- out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(
120
- dtype
121
- )
122
- if weight1 is None:
123
- return out if not prenorm else (out, x)
124
- else:
125
- out1 = (
126
- (x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)
127
- ).to(dtype)
128
- return (out, out1) if not prenorm else (out, out1, x)
129
-
130
-
131
- def config_prune(configs):
132
-
133
- if torch.version.hip:
134
- try:
135
- # set warp size based on gcn architecture
136
- gcn_arch_name = torch.cuda.get_device_properties(0).gcnArchName
137
- if "gfx10" in gcn_arch_name or "gfx11" in gcn_arch_name:
138
- # radeon
139
- warp_size = 32
140
- else:
141
- # instinct
142
- warp_size = 64
143
- except AttributeError as e:
144
- # fall back to crude method to set warp size
145
- device_name = torch.cuda.get_device_properties(0).name
146
- if "instinct" in device_name.lower():
147
- warp_size = 64
148
- else:
149
- warp_size = 32
150
- warnings.warn(
151
- f"{e}, warp size set to {warp_size} based on device name: {device_name}",
152
- UserWarning,
153
- )
154
-
155
- else:
156
- # cuda
157
- warp_size = 32
158
-
159
- max_block_sz = 1024
160
- max_num_warps = max_block_sz // warp_size
161
- pruned_configs = [config for config in configs if config.num_warps <= max_num_warps]
162
- return pruned_configs
163
-
164
-
165
- configs_autotune = [
166
- triton.Config({}, num_warps=1),
167
- triton.Config({}, num_warps=2),
168
- triton.Config({}, num_warps=4),
169
- triton.Config({}, num_warps=8),
170
- triton.Config({}, num_warps=16),
171
- triton.Config({}, num_warps=32),
172
- ]
173
-
174
- pruned_configs_autotune = config_prune(configs_autotune)
175
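
For context, the pruning above caps num_warps at max_block_sz // warp_size, so on a 64-wide Instinct wavefront only configs with at most 16 warps survive, while on 32-wide hardware all six candidates pass. A minimal sketch of that arithmetic, assuming only the Triton Config API (the numbers are illustrative, not measured):

# Hedged sketch of the warp-budget arithmetic behind config_prune above.
# Only the triton.Config API is assumed; values are illustrative.
import triton

warp_size = 64                               # Instinct (CDNA) wavefront; Radeon/NVIDIA use 32
max_block_sz = 1024                          # thread-block budget assumed by the kernel
max_num_warps = max_block_sz // warp_size    # -> 16

candidates = [triton.Config({}, num_warps=w) for w in (1, 2, 4, 8, 16, 32)]
pruned = [cfg for cfg in candidates if cfg.num_warps <= max_num_warps]
# warp_size=64 drops the num_warps=32 config; with warp_size=32 all six configs survive.
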
-
176
-
177
- @triton.autotune(
178
- configs=pruned_configs_autotune,
179
- key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
180
- )
181
- # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
182
- # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
183
- @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
184
- @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
185
- @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
186
- @triton.jit
187
- def _layer_norm_fwd_1pass_kernel(
188
- X, # pointer to the input
189
- Y, # pointer to the output
190
- W, # pointer to the weights
191
- B, # pointer to the biases
192
- RESIDUAL, # pointer to the residual
193
- X1,
194
- W1,
195
- B1,
196
- Y1,
197
- RESIDUAL_OUT, # pointer to the residual
198
- ROWSCALE,
199
- SEEDS, # Dropout seeds for each row
200
- DROPOUT_MASK,
201
- Mean, # pointer to the mean
202
- Rstd, # pointer to the 1/std
203
- stride_x_row, # how much to increase the pointer when moving by 1 row
204
- stride_y_row,
205
- stride_res_row,
206
- stride_res_out_row,
207
- stride_x1_row,
208
- stride_y1_row,
209
- M, # number of rows in X
210
- N, # number of columns in X
211
- eps, # epsilon to avoid division by zero
212
- dropout_p, # Dropout probability
213
- IS_RMS_NORM: tl.constexpr,
214
- BLOCK_N: tl.constexpr,
215
- HAS_RESIDUAL: tl.constexpr,
216
- STORE_RESIDUAL_OUT: tl.constexpr,
217
- HAS_BIAS: tl.constexpr,
218
- HAS_DROPOUT: tl.constexpr,
219
- STORE_DROPOUT_MASK: tl.constexpr,
220
- HAS_ROWSCALE: tl.constexpr,
221
- HAS_X1: tl.constexpr,
222
- HAS_W1: tl.constexpr,
223
- HAS_B1: tl.constexpr,
224
- ):
225
- # Map the program id to the row of X and Y it should compute.
226
- row = tl.program_id(0)
227
- X += row * stride_x_row
228
- Y += row * stride_y_row
229
- if HAS_RESIDUAL:
230
- RESIDUAL += row * stride_res_row
231
- if STORE_RESIDUAL_OUT:
232
- RESIDUAL_OUT += row * stride_res_out_row
233
- if HAS_X1:
234
- X1 += row * stride_x1_row
235
- if HAS_W1:
236
- Y1 += row * stride_y1_row
237
- # Compute mean and variance
238
- cols = tl.arange(0, BLOCK_N)
239
- x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
240
- if HAS_ROWSCALE:
241
- rowscale = tl.load(ROWSCALE + row).to(tl.float32)
242
- x *= rowscale
243
- if HAS_DROPOUT:
244
- # Compute dropout mask
245
- # 7 rounds is good enough, and reduces register pressure
246
- keep_mask = (
247
- tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
248
- )
249
- x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
250
- if STORE_DROPOUT_MASK:
251
- tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
252
- if HAS_X1:
253
- x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
254
- if HAS_ROWSCALE:
255
- rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
256
- x1 *= rowscale
257
- if HAS_DROPOUT:
258
- # Compute dropout mask
259
- # 7 rounds is good enough, and reduces register pressure
260
- keep_mask = (
261
- tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
262
- > dropout_p
263
- )
264
- x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
265
- if STORE_DROPOUT_MASK:
266
- tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
267
- x += x1
268
- if HAS_RESIDUAL:
269
- residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
270
- x += residual
271
- if STORE_RESIDUAL_OUT:
272
- tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
273
- if not IS_RMS_NORM:
274
- mean = tl.sum(x, axis=0) / N
275
- tl.store(Mean + row, mean)
276
- xbar = tl.where(cols < N, x - mean, 0.0)
277
- var = tl.sum(xbar * xbar, axis=0) / N
278
- else:
279
- xbar = tl.where(cols < N, x, 0.0)
280
- var = tl.sum(xbar * xbar, axis=0) / N
281
- rstd = 1 / tl.sqrt(var + eps)
282
- tl.store(Rstd + row, rstd)
283
- # Normalize and apply linear transformation
284
- mask = cols < N
285
- w = tl.load(W + cols, mask=mask).to(tl.float32)
286
- if HAS_BIAS:
287
- b = tl.load(B + cols, mask=mask).to(tl.float32)
288
- x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
289
- y = x_hat * w + b if HAS_BIAS else x_hat * w
290
- # Write output
291
- tl.store(Y + cols, y, mask=mask)
292
- if HAS_W1:
293
- w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
294
- if HAS_B1:
295
- b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
296
- y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
297
- tl.store(Y1 + cols, y1, mask=mask)
298
-
299
-
300
- def _layer_norm_fwd(
301
- x,
302
- weight,
303
- bias,
304
- eps,
305
- residual=None,
306
- x1=None,
307
- weight1=None,
308
- bias1=None,
309
- dropout_p=0.0,
310
- rowscale=None,
311
- out_dtype=None,
312
- residual_dtype=None,
313
- is_rms_norm=False,
314
- return_dropout_mask=False,
315
- ):
316
- if residual is not None:
317
- residual_dtype = residual.dtype
318
- M, N = x.shape
319
- assert x.stride(-1) == 1
320
- if residual is not None:
321
- assert residual.stride(-1) == 1
322
- assert residual.shape == (M, N)
323
- assert weight.shape == (N,)
324
- assert weight.stride(-1) == 1
325
- if bias is not None:
326
- assert bias.stride(-1) == 1
327
- assert bias.shape == (N,)
328
- if x1 is not None:
329
- assert x1.shape == x.shape
330
- assert rowscale is None
331
- assert x1.stride(-1) == 1
332
- if weight1 is not None:
333
- assert weight1.shape == (N,)
334
- assert weight1.stride(-1) == 1
335
- if bias1 is not None:
336
- assert bias1.shape == (N,)
337
- assert bias1.stride(-1) == 1
338
- if rowscale is not None:
339
- assert rowscale.is_contiguous()
340
- assert rowscale.shape == (M,)
341
- # allocate output
342
- y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
343
- assert y.stride(-1) == 1
344
- if weight1 is not None:
345
- y1 = torch.empty_like(y)
346
- assert y1.stride(-1) == 1
347
- else:
348
- y1 = None
349
- if (
350
- residual is not None
351
- or (residual_dtype is not None and residual_dtype != x.dtype)
352
- or dropout_p > 0.0
353
- or rowscale is not None
354
- or x1 is not None
355
- ):
356
- residual_out = torch.empty(
357
- M,
358
- N,
359
- device=x.device,
360
- dtype=residual_dtype if residual_dtype is not None else x.dtype,
361
- )
362
- assert residual_out.stride(-1) == 1
363
- else:
364
- residual_out = None
365
- mean = (
366
- torch.empty((M,), dtype=torch.float32, device=x.device)
367
- if not is_rms_norm
368
- else None
369
- )
370
- rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
371
- if dropout_p > 0.0:
372
- seeds = torch.randint(
373
- 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
374
- )
375
- else:
376
- seeds = None
377
- if return_dropout_mask and dropout_p > 0.0:
378
- dropout_mask = torch.empty(
379
- M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool
380
- )
381
- else:
382
- dropout_mask = None
383
- # Less than 64KB per feature: enqueue fused kernel
384
- MAX_FUSED_SIZE = 65536 // x.element_size()
385
- BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
386
- if N > BLOCK_N:
387
- raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
388
- with torch.cuda.device(x.device.index):
389
- _layer_norm_fwd_1pass_kernel[(M,)](
390
- x,
391
- y,
392
- weight,
393
- bias,
394
- residual,
395
- x1,
396
- weight1,
397
- bias1,
398
- y1,
399
- residual_out,
400
- rowscale,
401
- seeds,
402
- dropout_mask,
403
- mean,
404
- rstd,
405
- x.stride(0),
406
- y.stride(0),
407
- residual.stride(0) if residual is not None else 0,
408
- residual_out.stride(0) if residual_out is not None else 0,
409
- x1.stride(0) if x1 is not None else 0,
410
- y1.stride(0) if y1 is not None else 0,
411
- M,
412
- N,
413
- eps,
414
- dropout_p,
415
- is_rms_norm,
416
- BLOCK_N,
417
- residual is not None,
418
- residual_out is not None,
419
- bias is not None,
420
- dropout_p > 0.0,
421
- dropout_mask is not None,
422
- rowscale is not None,
423
- )
424
- # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
425
- if dropout_mask is not None and x1 is not None:
426
- dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
427
- else:
428
- dropout_mask1 = None
429
- return (
430
- y,
431
- y1,
432
- mean,
433
- rstd,
434
- residual_out if residual_out is not None else x,
435
- seeds,
436
- dropout_mask,
437
- dropout_mask1,
438
- )
439
-
440
-
441
- @triton.autotune(
442
- configs=pruned_configs_autotune,
443
- key=[
444
- "N",
445
- "HAS_DRESIDUAL",
446
- "STORE_DRESIDUAL",
447
- "IS_RMS_NORM",
448
- "HAS_BIAS",
449
- "HAS_DROPOUT",
450
- ],
451
- )
452
- # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
453
- # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
454
- # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
455
- @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
456
- @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
457
- @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
458
- @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
459
- @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
460
- @triton.jit
461
- def _layer_norm_bwd_kernel(
462
- X, # pointer to the input
463
- W, # pointer to the weights
464
- B, # pointer to the biases
465
- Y, # pointer to the output to be recomputed
466
- DY, # pointer to the output gradient
467
- DX, # pointer to the input gradient
468
- DW, # pointer to the partial sum of weights gradient
469
- DB, # pointer to the partial sum of biases gradient
470
- DRESIDUAL,
471
- W1,
472
- DY1,
473
- DX1,
474
- DW1,
475
- DB1,
476
- DRESIDUAL_IN,
477
- ROWSCALE,
478
- SEEDS,
479
- Mean, # pointer to the mean
480
- Rstd, # pointer to the 1/std
481
- stride_x_row, # how much to increase the pointer when moving by 1 row
482
- stride_y_row,
483
- stride_dy_row,
484
- stride_dx_row,
485
- stride_dres_row,
486
- stride_dy1_row,
487
- stride_dx1_row,
488
- stride_dres_in_row,
489
- M, # number of rows in X
490
- N, # number of columns in X
491
- eps, # epsilon to avoid division by zero
492
- dropout_p,
493
- rows_per_program,
494
- IS_RMS_NORM: tl.constexpr,
495
- BLOCK_N: tl.constexpr,
496
- HAS_DRESIDUAL: tl.constexpr,
497
- STORE_DRESIDUAL: tl.constexpr,
498
- HAS_BIAS: tl.constexpr,
499
- HAS_DROPOUT: tl.constexpr,
500
- HAS_ROWSCALE: tl.constexpr,
501
- HAS_DY1: tl.constexpr,
502
- HAS_DX1: tl.constexpr,
503
- HAS_B1: tl.constexpr,
504
- RECOMPUTE_OUTPUT: tl.constexpr,
505
- ):
506
- # Map the program id to the elements of X, DX, and DY it should compute.
507
- row_block_id = tl.program_id(0)
508
- row_start = row_block_id * rows_per_program
509
- # Do not early exit if row_start >= M, because we need to write DW and DB
510
- cols = tl.arange(0, BLOCK_N)
511
- mask = cols < N
512
- X += row_start * stride_x_row
513
- if HAS_DRESIDUAL:
514
- DRESIDUAL += row_start * stride_dres_row
515
- if STORE_DRESIDUAL:
516
- DRESIDUAL_IN += row_start * stride_dres_in_row
517
- DY += row_start * stride_dy_row
518
- DX += row_start * stride_dx_row
519
- if HAS_DY1:
520
- DY1 += row_start * stride_dy1_row
521
- if HAS_DX1:
522
- DX1 += row_start * stride_dx1_row
523
- if RECOMPUTE_OUTPUT:
524
- Y += row_start * stride_y_row
525
- w = tl.load(W + cols, mask=mask).to(tl.float32)
526
- if RECOMPUTE_OUTPUT and HAS_BIAS:
527
- b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
528
- if HAS_DY1:
529
- w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
530
- dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
531
- if HAS_BIAS:
532
- db = tl.zeros((BLOCK_N,), dtype=tl.float32)
533
- if HAS_DY1:
534
- dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
535
- if HAS_B1:
536
- db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
537
- row_end = min((row_block_id + 1) * rows_per_program, M)
538
- for row in range(row_start, row_end):
539
- # Load data to SRAM
540
- x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
541
- dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
542
- if HAS_DY1:
543
- dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
544
- if not IS_RMS_NORM:
545
- mean = tl.load(Mean + row)
546
- rstd = tl.load(Rstd + row)
547
- # Compute dx
548
- xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
549
- xhat = tl.where(mask, xhat, 0.0)
550
- if RECOMPUTE_OUTPUT:
551
- y = xhat * w + b if HAS_BIAS else xhat * w
552
- tl.store(Y + cols, y, mask=mask)
553
- wdy = w * dy
554
- dw += dy * xhat
555
- if HAS_BIAS:
556
- db += dy
557
- if HAS_DY1:
558
- wdy += w1 * dy1
559
- dw1 += dy1 * xhat
560
- if HAS_B1:
561
- db1 += dy1
562
- if not IS_RMS_NORM:
563
- c1 = tl.sum(xhat * wdy, axis=0) / N
564
- c2 = tl.sum(wdy, axis=0) / N
565
- dx = (wdy - (xhat * c1 + c2)) * rstd
566
- else:
567
- c1 = tl.sum(xhat * wdy, axis=0) / N
568
- dx = (wdy - xhat * c1) * rstd
569
- if HAS_DRESIDUAL:
570
- dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
571
- dx += dres
572
- # Write dx
573
- if STORE_DRESIDUAL:
574
- tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
575
- if HAS_DX1:
576
- if HAS_DROPOUT:
577
- keep_mask = (
578
- tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
579
- > dropout_p
580
- )
581
- dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
582
- else:
583
- dx1 = dx
584
- tl.store(DX1 + cols, dx1, mask=mask)
585
- if HAS_DROPOUT:
586
- keep_mask = (
587
- tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7)
588
- > dropout_p
589
- )
590
- dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
591
- if HAS_ROWSCALE:
592
- rowscale = tl.load(ROWSCALE + row).to(tl.float32)
593
- dx *= rowscale
594
- tl.store(DX + cols, dx, mask=mask)
595
-
596
- X += stride_x_row
597
- if HAS_DRESIDUAL:
598
- DRESIDUAL += stride_dres_row
599
- if STORE_DRESIDUAL:
600
- DRESIDUAL_IN += stride_dres_in_row
601
- if RECOMPUTE_OUTPUT:
602
- Y += stride_y_row
603
- DY += stride_dy_row
604
- DX += stride_dx_row
605
- if HAS_DY1:
606
- DY1 += stride_dy1_row
607
- if HAS_DX1:
608
- DX1 += stride_dx1_row
609
- tl.store(DW + row_block_id * N + cols, dw, mask=mask)
610
- if HAS_BIAS:
611
- tl.store(DB + row_block_id * N + cols, db, mask=mask)
612
- if HAS_DY1:
613
- tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
614
- if HAS_B1:
615
- tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
616
-
617
-
618
- def _layer_norm_bwd(
619
- dy,
620
- x,
621
- weight,
622
- bias,
623
- eps,
624
- mean,
625
- rstd,
626
- dresidual=None,
627
- dy1=None,
628
- weight1=None,
629
- bias1=None,
630
- seeds=None,
631
- dropout_p=0.0,
632
- rowscale=None,
633
- has_residual=False,
634
- has_x1=False,
635
- is_rms_norm=False,
636
- x_dtype=None,
637
- recompute_output=False,
638
- ):
639
- M, N = x.shape
640
- assert x.stride(-1) == 1
641
- assert dy.stride(-1) == 1
642
- assert dy.shape == (M, N)
643
- if dresidual is not None:
644
- assert dresidual.stride(-1) == 1
645
- assert dresidual.shape == (M, N)
646
- assert weight.shape == (N,)
647
- assert weight.stride(-1) == 1
648
- if bias is not None:
649
- assert bias.stride(-1) == 1
650
- assert bias.shape == (N,)
651
- if dy1 is not None:
652
- assert weight1 is not None
653
- assert dy1.shape == dy.shape
654
- assert dy1.stride(-1) == 1
655
- if weight1 is not None:
656
- assert weight1.shape == (N,)
657
- assert weight1.stride(-1) == 1
658
- if bias1 is not None:
659
- assert bias1.shape == (N,)
660
- assert bias1.stride(-1) == 1
661
- if seeds is not None:
662
- assert seeds.is_contiguous()
663
- assert seeds.shape == (M if not has_x1 else M * 2,)
664
- if rowscale is not None:
665
- assert rowscale.is_contiguous()
666
- assert rowscale.shape == (M,)
667
- # allocate output
668
- dx = (
669
- torch.empty_like(x)
670
- if x_dtype is None
671
- else torch.empty(M, N, dtype=x_dtype, device=x.device)
672
- )
673
- dresidual_in = (
674
- torch.empty_like(x)
675
- if has_residual
676
- and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
677
- else None
678
- )
679
- dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
680
- y = (
681
- torch.empty(M, N, dtype=dy.dtype, device=dy.device)
682
- if recompute_output
683
- else None
684
- )
685
- if recompute_output:
686
- assert (
687
- weight1 is None
688
- ), "recompute_output is not supported with parallel LayerNorm"
689
-
690
- # Less than 64KB per feature: enqueue fused kernel
691
- MAX_FUSED_SIZE = 65536 // x.element_size()
692
- BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
693
- if N > BLOCK_N:
694
- raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
695
- sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
696
- _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
697
- _db = (
698
- torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
699
- if bias is not None
700
- else None
701
- )
702
- _dw1 = torch.empty_like(_dw) if weight1 is not None else None
703
- _db1 = torch.empty_like(_db) if bias1 is not None else None
704
- rows_per_program = math.ceil(M / sm_count)
705
- grid = (sm_count,)
706
- with torch.cuda.device(x.device.index):
707
- _layer_norm_bwd_kernel[grid](
708
- x,
709
- weight,
710
- bias,
711
- y,
712
- dy,
713
- dx,
714
- _dw,
715
- _db,
716
- dresidual,
717
- weight1,
718
- dy1,
719
- dx1,
720
- _dw1,
721
- _db1,
722
- dresidual_in,
723
- rowscale,
724
- seeds,
725
- mean,
726
- rstd,
727
- x.stride(0),
728
- 0 if not recompute_output else y.stride(0),
729
- dy.stride(0),
730
- dx.stride(0),
731
- dresidual.stride(0) if dresidual is not None else 0,
732
- dy1.stride(0) if dy1 is not None else 0,
733
- dx1.stride(0) if dx1 is not None else 0,
734
- dresidual_in.stride(0) if dresidual_in is not None else 0,
735
- M,
736
- N,
737
- eps,
738
- dropout_p,
739
- rows_per_program,
740
- is_rms_norm,
741
- BLOCK_N,
742
- dresidual is not None,
743
- dresidual_in is not None,
744
- bias is not None,
745
- dropout_p > 0.0,
746
- )
747
- dw = _dw.sum(0).to(weight.dtype)
748
- db = _db.sum(0).to(bias.dtype) if bias is not None else None
749
- dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
750
- db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
751
- # Don't need to compute dresidual_in separately in this case
752
- if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
753
- dresidual_in = dx
754
- if has_x1 and dropout_p == 0.0:
755
- dx1 = dx
756
- return (
757
- (dx, dw, db, dresidual_in, dx1, dw1, db1)
758
- if not recompute_output
759
- else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
760
- )
761
-
762
-
763
- class LayerNormFn(torch.autograd.Function):
764
- @staticmethod
765
- def forward(
766
- ctx,
767
- x,
768
- weight,
769
- bias,
770
- residual=None,
771
- x1=None,
772
- weight1=None,
773
- bias1=None,
774
- eps=1e-6,
775
- dropout_p=0.0,
776
- rowscale=None,
777
- prenorm=False,
778
- residual_in_fp32=False,
779
- is_rms_norm=False,
780
- return_dropout_mask=False,
781
- ):
782
- x_shape_og = x.shape
783
- # reshape input data into 2D tensor
784
- x = x.reshape(-1, x.shape[-1])
785
- if x.stride(-1) != 1:
786
- x = x.contiguous()
787
- if residual is not None:
788
- assert residual.shape == x_shape_og
789
- residual = residual.reshape(-1, residual.shape[-1])
790
- if residual.stride(-1) != 1:
791
- residual = residual.contiguous()
792
- if x1 is not None:
793
- assert x1.shape == x_shape_og
794
- assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
795
- x1 = x1.reshape(-1, x1.shape[-1])
796
- if x1.stride(-1) != 1:
797
- x1 = x1.contiguous()
798
- weight = weight.contiguous()
799
- if bias is not None:
800
- bias = bias.contiguous()
801
- if weight1 is not None:
802
- weight1 = weight1.contiguous()
803
- if bias1 is not None:
804
- bias1 = bias1.contiguous()
805
- if rowscale is not None:
806
- rowscale = rowscale.reshape(-1).contiguous()
807
- residual_dtype = (
808
- residual.dtype
809
- if residual is not None
810
- else (torch.float32 if residual_in_fp32 else None)
811
- )
812
- y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = (
813
- _layer_norm_fwd(
814
- x,
815
- weight,
816
- bias,
817
- eps,
818
- residual,
819
- x1,
820
- weight1,
821
- bias1,
822
- dropout_p=dropout_p,
823
- rowscale=rowscale,
824
- residual_dtype=residual_dtype,
825
- is_rms_norm=is_rms_norm,
826
- return_dropout_mask=return_dropout_mask,
827
- )
828
- )
829
- ctx.save_for_backward(
830
- residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
831
- )
832
- ctx.x_shape_og = x_shape_og
833
- ctx.eps = eps
834
- ctx.dropout_p = dropout_p
835
- ctx.is_rms_norm = is_rms_norm
836
- ctx.has_residual = residual is not None
837
- ctx.has_x1 = x1 is not None
838
- ctx.prenorm = prenorm
839
- ctx.x_dtype = x.dtype
840
- y = y.reshape(x_shape_og)
841
- y1 = y1.reshape(x_shape_og) if y1 is not None else None
842
- residual_out = (
843
- residual_out.reshape(x_shape_og) if residual_out is not None else None
844
- )
845
- dropout_mask = (
846
- dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
847
- )
848
- dropout_mask1 = (
849
- dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
850
- )
851
- if not return_dropout_mask:
852
- if weight1 is None:
853
- return y if not prenorm else (y, residual_out)
854
- else:
855
- return (y, y1) if not prenorm else (y, y1, residual_out)
856
- else:
857
- if weight1 is None:
858
- return (
859
- (y, dropout_mask, dropout_mask1)
860
- if not prenorm
861
- else (y, residual_out, dropout_mask, dropout_mask1)
862
- )
863
- else:
864
- return (
865
- (y, y1, dropout_mask, dropout_mask1)
866
- if not prenorm
867
- else (y, y1, residual_out, dropout_mask, dropout_mask1)
868
- )
869
-
870
- @staticmethod
871
- def backward(ctx, dy, *args):
872
- x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
873
- dy = dy.reshape(-1, dy.shape[-1])
874
- if dy.stride(-1) != 1:
875
- dy = dy.contiguous()
876
- assert dy.shape == x.shape
877
- if weight1 is not None:
878
- dy1, args = args[0], args[1:]
879
- dy1 = dy1.reshape(-1, dy1.shape[-1])
880
- if dy1.stride(-1) != 1:
881
- dy1 = dy1.contiguous()
882
- assert dy1.shape == x.shape
883
- else:
884
- dy1 = None
885
- if ctx.prenorm:
886
- dresidual = args[0]
887
- dresidual = dresidual.reshape(-1, dresidual.shape[-1])
888
- if dresidual.stride(-1) != 1:
889
- dresidual = dresidual.contiguous()
890
- assert dresidual.shape == x.shape
891
- else:
892
- dresidual = None
893
- dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
894
- dy,
895
- x,
896
- weight,
897
- bias,
898
- ctx.eps,
899
- mean,
900
- rstd,
901
- dresidual,
902
- dy1,
903
- weight1,
904
- bias1,
905
- seeds,
906
- ctx.dropout_p,
907
- rowscale,
908
- ctx.has_residual,
909
- ctx.has_x1,
910
- ctx.is_rms_norm,
911
- x_dtype=ctx.x_dtype,
912
- )
913
- return (
914
- dx.reshape(ctx.x_shape_og),
915
- dw,
916
- db,
917
- dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
918
- dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
919
- dw1,
920
- db1,
921
- None,
922
- None,
923
- None,
924
- None,
925
- None,
926
- None,
927
- None,
928
- )
929
-
930
-
931
- def layer_norm_fn(
932
- x,
933
- weight,
934
- bias,
935
- residual=None,
936
- x1=None,
937
- weight1=None,
938
- bias1=None,
939
- eps=1e-6,
940
- dropout_p=0.0,
941
- rowscale=None,
942
- prenorm=False,
943
- residual_in_fp32=False,
944
- is_rms_norm=False,
945
- return_dropout_mask=False,
946
- ):
947
- return LayerNormFn.apply(
948
- x,
949
- weight,
950
- bias,
951
- residual,
952
- x1,
953
- weight1,
954
- bias1,
955
- eps,
956
- dropout_p,
957
- rowscale,
958
- prenorm,
959
- residual_in_fp32,
960
- is_rms_norm,
961
- return_dropout_mask,
962
- )
963
-
964
-
965
- def rms_norm_fn(
966
- x,
967
- weight,
968
- bias,
969
- residual=None,
970
- x1=None,
971
- weight1=None,
972
- bias1=None,
973
- eps=1e-6,
974
- dropout_p=0.0,
975
- rowscale=None,
976
- prenorm=False,
977
- residual_in_fp32=False,
978
- return_dropout_mask=False,
979
- ):
980
- return LayerNormFn.apply(
981
- x,
982
- weight,
983
- bias,
984
- residual,
985
- x1,
986
- weight1,
987
- bias1,
988
- eps,
989
- dropout_p,
990
- rowscale,
991
- prenorm,
992
- residual_in_fp32,
993
- True,
994
- return_dropout_mask,
995
- )
996
-
997
-
998
- class RMSNorm(torch.nn.Module):
999
-
1000
- def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
1001
- factory_kwargs = {"device": device, "dtype": dtype}
1002
- super().__init__()
1003
- self.eps = eps
1004
- if dropout_p > 0.0:
1005
- self.drop = torch.nn.Dropout(dropout_p)
1006
- else:
1007
- self.drop = None
1008
- self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
1009
- self.register_parameter("bias", None)
1010
- self.reset_parameters()
1011
-
1012
- def reset_parameters(self):
1013
- torch.nn.init.ones_(self.weight)
1014
-
1015
- def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
1016
- return rms_norm_fn(
1017
- x,
1018
- self.weight,
1019
- self.bias,
1020
- residual=residual,
1021
- eps=self.eps,
1022
- dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
1023
- prenorm=prenorm,
1024
- residual_in_fp32=residual_in_fp32,
1025
- )
1026
-
1027
-
1028
- class LayerNormLinearFn(torch.autograd.Function):
1029
- @staticmethod
1030
- @custom_fwd
1031
- def forward(
1032
- ctx,
1033
- x,
1034
- norm_weight,
1035
- norm_bias,
1036
- linear_weight,
1037
- linear_bias,
1038
- residual=None,
1039
- eps=1e-6,
1040
- prenorm=False,
1041
- residual_in_fp32=False,
1042
- is_rms_norm=False,
1043
- ):
1044
- x_shape_og = x.shape
1045
- # reshape input data into 2D tensor
1046
- x = x.reshape(-1, x.shape[-1])
1047
- if x.stride(-1) != 1:
1048
- x = x.contiguous()
1049
- if residual is not None:
1050
- assert residual.shape == x_shape_og
1051
- residual = residual.reshape(-1, residual.shape[-1])
1052
- if residual.stride(-1) != 1:
1053
- residual = residual.contiguous()
1054
- norm_weight = norm_weight.contiguous()
1055
- if norm_bias is not None:
1056
- norm_bias = norm_bias.contiguous()
1057
- residual_dtype = (
1058
- residual.dtype
1059
- if residual is not None
1060
- else (torch.float32 if residual_in_fp32 else None)
1061
- )
1062
- y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
1063
- x,
1064
- norm_weight,
1065
- norm_bias,
1066
- eps,
1067
- residual,
1068
- out_dtype=(
1069
- None
1070
- if not torch.is_autocast_enabled()
1071
- else torch.get_autocast_gpu_dtype()
1072
- ),
1073
- residual_dtype=residual_dtype,
1074
- is_rms_norm=is_rms_norm,
1075
- )
1076
- y = y.reshape(x_shape_og)
1077
- dtype = (
1078
- torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
1079
- )
1080
- linear_weight = linear_weight.to(dtype)
1081
- linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
1082
- out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
1083
- # We don't store y, will be recomputed in the backward pass to save memory
1084
- ctx.save_for_backward(
1085
- residual_out, norm_weight, norm_bias, linear_weight, mean, rstd
1086
- )
1087
- ctx.x_shape_og = x_shape_og
1088
- ctx.eps = eps
1089
- ctx.is_rms_norm = is_rms_norm
1090
- ctx.has_residual = residual is not None
1091
- ctx.prenorm = prenorm
1092
- ctx.x_dtype = x.dtype
1093
- ctx.linear_bias_is_none = linear_bias is None
1094
- return out if not prenorm else (out, residual_out.reshape(x_shape_og))
1095
-
1096
- @staticmethod
1097
- @custom_bwd
1098
- def backward(ctx, dout, *args):
1099
- x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
1100
- dout = dout.reshape(-1, dout.shape[-1])
1101
- dy = F.linear(dout, linear_weight.t())
1102
- dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
1103
- if dy.stride(-1) != 1:
1104
- dy = dy.contiguous()
1105
- assert dy.shape == x.shape
1106
- if ctx.prenorm:
1107
- dresidual = args[0]
1108
- dresidual = dresidual.reshape(-1, dresidual.shape[-1])
1109
- if dresidual.stride(-1) != 1:
1110
- dresidual = dresidual.contiguous()
1111
- assert dresidual.shape == x.shape
1112
- else:
1113
- dresidual = None
1114
- dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
1115
- dy,
1116
- x,
1117
- norm_weight,
1118
- norm_bias,
1119
- ctx.eps,
1120
- mean,
1121
- rstd,
1122
- dresidual=dresidual,
1123
- has_residual=ctx.has_residual,
1124
- is_rms_norm=ctx.is_rms_norm,
1125
- x_dtype=ctx.x_dtype,
1126
- recompute_output=True,
1127
- )
1128
- dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
1129
- return (
1130
- dx.reshape(ctx.x_shape_og),
1131
- dnorm_weight,
1132
- dnorm_bias,
1133
- dlinear_weight,
1134
- dlinear_bias,
1135
- dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
1136
- None,
1137
- None,
1138
- None,
1139
- None,
1140
- )
1141
-
1142
-
1143
- def layer_norm_linear_fn(
1144
- x,
1145
- norm_weight,
1146
- norm_bias,
1147
- linear_weight,
1148
- linear_bias,
1149
- residual=None,
1150
- eps=1e-6,
1151
- prenorm=False,
1152
- residual_in_fp32=False,
1153
- is_rms_norm=False,
1154
- ):
1155
- return LayerNormLinearFn.apply(
1156
- x,
1157
- norm_weight,
1158
- norm_bias,
1159
- linear_weight,
1160
- linear_bias,
1161
- residual,
1162
- eps,
1163
- prenorm,
1164
- residual_in_fp32,
1165
- is_rms_norm,
1166
- )
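
For orientation, a minimal, hedged usage sketch of the fused RMSNorm path defined in the deleted file above, checked against its pure-PyTorch reference. The import path, tensor sizes, and tolerances are assumptions, not part of the diff:

# Hedged sketch: exercising RMSNorm / rms_norm_fn from the deleted layer_norm.py
# against rms_norm_ref. Requires a CUDA device with Triton installed.
import torch
from mamba_ssm.ops.triton.layer_norm import RMSNorm, rms_norm_fn, rms_norm_ref  # assumed import path

hidden = 1024
x = torch.randn(2, 128, hidden, device="cuda", dtype=torch.float16)
residual = torch.randn(2, 128, hidden, device="cuda", dtype=torch.float32)

norm = RMSNorm(hidden, eps=1e-5, device="cuda", dtype=torch.float16)

# Pre-norm residual pattern: returns the normalized output and the updated fp32 residual stream.
out, new_residual = norm(x, residual=residual, prenorm=True, residual_in_fp32=True)
assert out.shape == x.shape and new_residual.dtype == torch.float32

# Functional form, compared against the reference implementation above.
y_fn = rms_norm_fn(x, norm.weight, None, residual=residual, eps=norm.eps)
y_ref = rms_norm_ref(x, norm.weight, None, residual=residual, eps=norm.eps, upcast=True)
torch.testing.assert_close(y_fn, y_ref, rtol=2e-2, atol=2e-2)
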
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/triton/selective_state_update.py DELETED
@@ -1,389 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
-
3
- """We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
4
- """
5
-
6
- import math
7
- import torch
8
- import torch.nn.functional as F
9
-
10
- import triton
11
- import triton.language as tl
12
-
13
- from einops import rearrange, repeat
14
-
15
- from .softplus import softplus
16
-
17
-
18
- @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
19
- @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
20
- @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
21
- @triton.heuristics(
22
- {
23
- "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
24
- is not None
25
- }
26
- )
27
- @triton.heuristics(
28
- {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
29
- )
30
- @triton.jit
31
- def _selective_scan_update_kernel(
32
- # Pointers to matrices
33
- state_ptr,
34
- x_ptr,
35
- dt_ptr,
36
- dt_bias_ptr,
37
- A_ptr,
38
- B_ptr,
39
- C_ptr,
40
- D_ptr,
41
- z_ptr,
42
- out_ptr,
43
- state_batch_indices_ptr,
44
- # Matrix dimensions
45
- batch,
46
- nheads,
47
- dim,
48
- dstate,
49
- nheads_ngroups_ratio,
50
- # Strides
51
- stride_state_batch,
52
- stride_state_head,
53
- stride_state_dim,
54
- stride_state_dstate,
55
- stride_x_batch,
56
- stride_x_head,
57
- stride_x_dim,
58
- stride_dt_batch,
59
- stride_dt_head,
60
- stride_dt_dim,
61
- stride_dt_bias_head,
62
- stride_dt_bias_dim,
63
- stride_A_head,
64
- stride_A_dim,
65
- stride_A_dstate,
66
- stride_B_batch,
67
- stride_B_group,
68
- stride_B_dstate,
69
- stride_C_batch,
70
- stride_C_group,
71
- stride_C_dstate,
72
- stride_D_head,
73
- stride_D_dim,
74
- stride_z_batch,
75
- stride_z_head,
76
- stride_z_dim,
77
- stride_out_batch,
78
- stride_out_head,
79
- stride_out_dim,
80
- # Meta-parameters
81
- DT_SOFTPLUS: tl.constexpr,
82
- TIE_HDIM: tl.constexpr,
83
- BLOCK_SIZE_M: tl.constexpr,
84
- HAS_DT_BIAS: tl.constexpr,
85
- HAS_D: tl.constexpr,
86
- HAS_Z: tl.constexpr,
87
- HAS_STATE_BATCH_INDICES: tl.constexpr,
88
- BLOCK_SIZE_DSTATE: tl.constexpr,
89
- ):
90
- pid_m = tl.program_id(axis=0)
91
- pid_b = tl.program_id(axis=1)
92
- pid_h = tl.program_id(axis=2)
93
-
94
- if HAS_STATE_BATCH_INDICES:
95
- state_batch_indices_ptr += pid_b
96
- state_batch_idx = tl.load(state_batch_indices_ptr)
97
- state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
98
- else:
99
- state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
100
-
101
- x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
102
- dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
103
- if HAS_DT_BIAS:
104
- dt_bias_ptr += pid_h * stride_dt_bias_head
105
- A_ptr += pid_h * stride_A_head
106
- B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
107
- C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
108
- if HAS_Z:
109
- z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
110
- out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
111
-
112
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
113
- offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
114
- state_ptrs = state_ptr + (
115
- offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
116
- )
117
- x_ptrs = x_ptr + offs_m * stride_x_dim
118
- dt_ptrs = dt_ptr + offs_m * stride_dt_dim
119
- if HAS_DT_BIAS:
120
- dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
121
- if HAS_D:
122
- D_ptr += pid_h * stride_D_head
123
- A_ptrs = A_ptr + (
124
- offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
125
- )
126
- B_ptrs = B_ptr + offs_n * stride_B_dstate
127
- C_ptrs = C_ptr + offs_n * stride_C_dstate
128
- if HAS_D:
129
- D_ptrs = D_ptr + offs_m * stride_D_dim
130
- if HAS_Z:
131
- z_ptrs = z_ptr + offs_m * stride_z_dim
132
- out_ptrs = out_ptr + offs_m * stride_out_dim
133
-
134
- state = tl.load(
135
- state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
136
- )
137
- x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
138
- if not TIE_HDIM:
139
- dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
140
- if HAS_DT_BIAS:
141
- dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
142
- if DT_SOFTPLUS:
143
- dt = tl.where(dt <= 20.0, softplus(dt), dt)
144
- A = tl.load(
145
- A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
146
- ).to(tl.float32)
147
- dA = tl.exp(A * dt[:, None])
148
- else:
149
- dt = tl.load(dt_ptr).to(tl.float32)
150
- if HAS_DT_BIAS:
151
- dt += tl.load(dt_bias_ptr).to(tl.float32)
152
- if DT_SOFTPLUS:
153
- dt = tl.where(dt <= 20.0, softplus(dt), dt)
154
- A = tl.load(A_ptr).to(tl.float32)
155
- dA = tl.exp(A * dt) # scalar, not a matrix
156
-
157
- B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
158
- C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
159
- if HAS_D:
160
- D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
161
- if HAS_Z:
162
- z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
163
-
164
- if not TIE_HDIM:
165
- dB = B[None, :] * dt[:, None]
166
- else:
167
- dB = B * dt # vector of size (dstate,)
168
- state = state * dA + dB * x[:, None]
169
- tl.store(
170
- state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
171
- )
172
- out = tl.sum(state * C[None, :], axis=1)
173
- if HAS_D:
174
- out += x * D
175
- if HAS_Z:
176
- out *= z * tl.sigmoid(z)
177
- tl.store(out_ptrs, out, mask=offs_m < dim)
178
-
179
-
180
- def selective_state_update(
181
- state,
182
- x,
183
- dt,
184
- A,
185
- B,
186
- C,
187
- D=None,
188
- z=None,
189
- dt_bias=None,
190
- dt_softplus=False,
191
- state_batch_indices=None,
192
- ):
193
- """
194
- Argument:
195
- state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
196
- x: (batch, dim) or (batch, nheads, dim)
197
- dt: (batch, dim) or (batch, nheads, dim)
198
- A: (dim, dstate) or (nheads, dim, dstate)
199
- B: (batch, dstate) or (batch, ngroups, dstate)
200
- C: (batch, dstate) or (batch, ngroups, dstate)
201
- D: (dim,) or (nheads, dim)
202
- z: (batch, dim) or (batch, nheads, dim)
203
- dt_bias: (dim,) or (nheads, dim)
204
- Return:
205
- out: (batch, dim) or (batch, nheads, dim)
206
- """
207
- has_heads = state.dim() > 3
208
- if state.dim() == 3:
209
- state = state.unsqueeze(1)
210
- if x.dim() == 2:
211
- x = x.unsqueeze(1)
212
- if dt.dim() == 2:
213
- dt = dt.unsqueeze(1)
214
- if A.dim() == 2:
215
- A = A.unsqueeze(0)
216
- if B.dim() == 2:
217
- B = B.unsqueeze(1)
218
- if C.dim() == 2:
219
- C = C.unsqueeze(1)
220
- if D is not None and D.dim() == 1:
221
- D = D.unsqueeze(0)
222
- if z is not None and z.dim() == 2:
223
- z = z.unsqueeze(1)
224
- if dt_bias is not None and dt_bias.dim() == 1:
225
- dt_bias = dt_bias.unsqueeze(0)
226
- _, nheads, dim, dstate = state.shape
227
- batch = x.shape[0]
228
- if x.shape != (batch, nheads, dim):
229
- print(f"{state.shape} {x.shape} {batch} {nheads} {dim}")
230
- assert x.shape == (batch, nheads, dim)
231
- assert dt.shape == x.shape
232
- assert A.shape == (nheads, dim, dstate)
233
- ngroups = B.shape[1]
234
- assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
235
- assert B.shape == (batch, ngroups, dstate)
236
- assert C.shape == B.shape
237
- if D is not None:
238
- assert D.shape == (nheads, dim)
239
- if z is not None:
240
- assert z.shape == x.shape
241
- if dt_bias is not None:
242
- assert dt_bias.shape == (nheads, dim)
243
- if state_batch_indices is not None:
244
- assert state_batch_indices.shape == (batch,)
245
- out = torch.empty_like(x)
246
- grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
247
- z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
248
- # We don't want autotune since it will overwrite the state
249
- # We instead tune by hand.
250
- BLOCK_SIZE_M, num_warps = (
251
- (32, 4)
252
- if dstate <= 16
253
- else (
254
- (16, 4)
255
- if dstate <= 32
256
- else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
257
- )
258
- )
259
- tie_hdim = (
260
- A.stride(-1) == 0
261
- and A.stride(-2) == 0
262
- and dt.stride(-1) == 0
263
- and dt_bias.stride(-1) == 0
264
- )
265
- with torch.cuda.device(x.device.index):
266
- _selective_scan_update_kernel[grid](
267
- state,
268
- x,
269
- dt,
270
- dt_bias,
271
- A,
272
- B,
273
- C,
274
- D,
275
- z,
276
- out,
277
- state_batch_indices,
278
- batch,
279
- nheads,
280
- dim,
281
- dstate,
282
- nheads // ngroups,
283
- state.stride(0),
284
- state.stride(1),
285
- state.stride(2),
286
- state.stride(3),
287
- x.stride(0),
288
- x.stride(1),
289
- x.stride(2),
290
- dt.stride(0),
291
- dt.stride(1),
292
- dt.stride(2),
293
- *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
294
- A.stride(0),
295
- A.stride(1),
296
- A.stride(2),
297
- B.stride(0),
298
- B.stride(1),
299
- B.stride(2),
300
- C.stride(0),
301
- C.stride(1),
302
- C.stride(2),
303
- *(D.stride(0), D.stride(1)) if D is not None else 0,
304
- z_strides[0],
305
- z_strides[1],
306
- z_strides[2],
307
- out.stride(0),
308
- out.stride(1),
309
- out.stride(2),
310
- dt_softplus,
311
- tie_hdim,
312
- BLOCK_SIZE_M,
313
- num_warps=num_warps,
314
- )
315
- if not has_heads:
316
- out = out.squeeze(1)
317
- return out
318
-
319
-
320
- def selective_state_update_ref(
321
- state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
322
- ):
323
- """
324
- Argument:
325
- state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
326
- x: (batch, dim) or (batch, nheads, dim)
327
- dt: (batch, dim) or (batch, nheads, dim)
328
- A: (dim, dstate) or (nheads, dim, dstate)
329
- B: (batch, dstate) or (batch, ngroups, dstate)
330
- C: (batch, dstate) or (batch, ngroups, dstate)
331
- D: (dim,) or (nheads, dim)
332
- z: (batch, dim) or (batch, nheads, dim)
333
- dt_bias: (dim,) or (nheads, dim)
334
- Return:
335
- out: (batch, dim) or (batch, nheads, dim)
336
- """
337
- has_heads = state.dim() > 3
338
- if state.dim() == 3:
339
- state = state.unsqueeze(1)
340
- if x.dim() == 2:
341
- x = x.unsqueeze(1)
342
- if dt.dim() == 2:
343
- dt = dt.unsqueeze(1)
344
- if A.dim() == 2:
345
- A = A.unsqueeze(0)
346
- if B.dim() == 2:
347
- B = B.unsqueeze(1)
348
- if C.dim() == 2:
349
- C = C.unsqueeze(1)
350
- if D is not None and D.dim() == 1:
351
- D = D.unsqueeze(0)
352
- if z is not None and z.dim() == 2:
353
- z = z.unsqueeze(1)
354
- if dt_bias is not None and dt_bias.dim() == 1:
355
- dt_bias = dt_bias.unsqueeze(0)
356
- batch, nheads, dim, dstate = state.shape
357
- assert x.shape == (batch, nheads, dim)
358
- assert dt.shape == x.shape
359
- assert A.shape == (nheads, dim, dstate)
360
- ngroups = B.shape[1]
361
- assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
362
- assert B.shape == (batch, ngroups, dstate)
363
- assert C.shape == B.shape
364
- if D is not None:
365
- assert D.shape == (nheads, dim)
366
- if z is not None:
367
- assert z.shape == x.shape
368
- if dt_bias is not None:
369
- assert dt_bias.shape == (nheads, dim)
370
- dt = dt + dt_bias
371
- dt = F.softplus(dt) if dt_softplus else dt
372
- dA = torch.exp(
373
- rearrange(dt, "b h d -> b h d 1") * A
374
- ) # (batch, nheads, dim, dstate)
375
- B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
376
- C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
377
- dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
378
- B, "b h n -> b h 1 n"
379
- ) # (batch, nheads, dim, dstate)
380
- state.copy_(
381
- state * dA + dB * rearrange(x, "b h d -> b h d 1")
382
- ) # (batch, nheads, dim, dstate)
383
- out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
384
- if D is not None:
385
- out += (x * D).to(out.dtype)
386
- out = (out if z is None else out * F.silu(z)).to(x.dtype)
387
- if not has_heads:
388
- out = out.squeeze(1)
389
- return out
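
As a quick check on the shapes documented in the docstrings above, a minimal hedged sketch of one decoding-step state update, comparing the Triton kernel against the pure-PyTorch reference. The per-step recurrence is state <- exp(dt*A) * state + dt * B * x, with out = C·state plus the optional D*x skip. Import path and sizes are assumptions:

# Hedged sketch: selective_state_update vs. selective_state_update_ref.
# Requires a CUDA device with Triton installed; tensor sizes are arbitrary.
import torch
from mamba_ssm.ops.triton.selective_state_update import (  # assumed import path
    selective_state_update,
    selective_state_update_ref,
)

batch, nheads, headdim, dstate, ngroups = 2, 8, 64, 16, 1
dev, dtype = "cuda", torch.float32

state = torch.randn(batch, nheads, headdim, dstate, device=dev, dtype=dtype)
x = torch.randn(batch, nheads, headdim, device=dev, dtype=dtype)
dt = torch.rand(batch, nheads, headdim, device=dev, dtype=dtype)
dt_bias = torch.rand(nheads, headdim, device=dev, dtype=dtype)
A = -torch.rand(nheads, headdim, dstate, device=dev, dtype=dtype)
B = torch.randn(batch, ngroups, dstate, device=dev, dtype=dtype)
C = torch.randn(batch, ngroups, dstate, device=dev, dtype=dtype)
D = torch.randn(nheads, headdim, device=dev, dtype=dtype)

state_ref = state.clone()
out = selective_state_update(state, x, dt, A, B, C, D=D, dt_bias=dt_bias, dt_softplus=True)
out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, dt_bias=dt_bias, dt_softplus=True)
torch.testing.assert_close(out, out_ref, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(state, state_ref, rtol=1e-3, atol=1e-3)  # state is updated in place
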
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_scan.py DELETED
The diff for this file is too large to render. See raw diff
 
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_state.py DELETED
@@ -1,2012 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
-
3
- """We want triton==2.1.0 or 2.2.0 for this
4
- """
5
-
6
- import math
7
- import torch
8
- import torch.nn.functional as F
9
-
10
- import triton
11
- import triton.language as tl
12
-
13
- from einops import rearrange, repeat
14
-
15
- from .softplus import softplus
16
-
17
-
18
- def init_to_zero(names):
19
- return lambda nargs: [
20
- nargs[name].zero_() for name in names if nargs[name] is not None
21
- ]
22
-
23
-
24
- @triton.autotune(
25
- configs=[
26
- triton.Config({"BLOCK_SIZE_H": 1}),
27
- triton.Config({"BLOCK_SIZE_H": 2}),
28
- triton.Config({"BLOCK_SIZE_H": 4}),
29
- triton.Config({"BLOCK_SIZE_H": 8}),
30
- triton.Config({"BLOCK_SIZE_H": 16}),
31
- triton.Config({"BLOCK_SIZE_H": 32}),
32
- triton.Config({"BLOCK_SIZE_H": 64}),
33
- ],
34
- key=["chunk_size", "nheads"],
35
- )
36
- @triton.jit
37
- def _chunk_cumsum_fwd_kernel(
38
- # Pointers to matrices
39
- dt_ptr,
40
- A_ptr,
41
- dt_bias_ptr,
42
- dt_out_ptr,
43
- dA_cumsum_ptr,
44
- # Matrix dimension
45
- batch,
46
- seqlen,
47
- nheads,
48
- chunk_size,
49
- dt_min,
50
- dt_max,
51
- # Strides
52
- stride_dt_batch,
53
- stride_dt_seqlen,
54
- stride_dt_head,
55
- stride_A_head,
56
- stride_dt_bias_head,
57
- stride_dt_out_batch,
58
- stride_dt_out_chunk,
59
- stride_dt_out_head,
60
- stride_dt_out_csize,
61
- stride_dA_cs_batch,
62
- stride_dA_cs_chunk,
63
- stride_dA_cs_head,
64
- stride_dA_cs_csize,
65
- # Meta-parameters
66
- DT_SOFTPLUS: tl.constexpr,
67
- HAS_DT_BIAS: tl.constexpr,
68
- BLOCK_SIZE_H: tl.constexpr,
69
- BLOCK_SIZE_CHUNK: tl.constexpr,
70
- ):
71
- pid_b = tl.program_id(axis=0)
72
- pid_c = tl.program_id(axis=1)
73
- pid_h = tl.program_id(axis=2)
74
- dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
75
- dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
76
- dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk
77
-
78
- offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
79
- offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
80
- dt_ptrs = dt_ptr + (
81
- offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
82
- )
83
- A_ptrs = A_ptr + offs_h * stride_A_head
84
- dt_out_ptrs = dt_out_ptr + (
85
- offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize
86
- )
87
- dA_cs_ptrs = dA_cumsum_ptr + (
88
- offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize
89
- )
90
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
91
-
92
- dt = tl.load(
93
- dt_ptrs,
94
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
95
- other=0.0,
96
- ).to(tl.float32)
97
- if HAS_DT_BIAS:
98
- dt_bias = tl.load(
99
- dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
100
- ).to(tl.float32)
101
- dt += dt_bias[:, None]
102
- if DT_SOFTPLUS:
103
- dt = tl.where(dt <= 20.0, softplus(dt), dt)
104
- # As of Triton 2.2.0, tl.clamp is not available yet
105
- # dt = tl.clamp(dt, dt_min, dt_max)
106
- dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
107
- dt = tl.where(
108
- (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
109
- )
110
- tl.store(
111
- dt_out_ptrs,
112
- dt,
113
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
114
- )
115
- A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
116
- dA = dt * A[:, None]
117
- dA_cs = tl.cumsum(dA, axis=1)
118
- tl.store(
119
- dA_cs_ptrs,
120
- dA_cs,
121
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
122
- )
123
-
124
-
125
- @triton.autotune(
126
- configs=[
127
- triton.Config(
128
- {"BLOCK_SIZE_H": 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
129
- ),
130
- triton.Config(
131
- {"BLOCK_SIZE_H": 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
132
- ),
133
- triton.Config(
134
- {"BLOCK_SIZE_H": 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
135
- ),
136
- triton.Config(
137
- {"BLOCK_SIZE_H": 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
138
- ),
139
- triton.Config(
140
- {"BLOCK_SIZE_H": 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
141
- ),
142
- triton.Config(
143
- {"BLOCK_SIZE_H": 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
144
- ),
145
- triton.Config(
146
- {"BLOCK_SIZE_H": 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
147
- ),
148
- ],
149
- key=["chunk_size", "nheads"],
150
- )
151
- @triton.jit
152
- def _chunk_cumsum_bwd_kernel(
153
- # Pointers to matrices
154
- ddA_ptr,
155
- ddt_out_ptr,
156
- dt_ptr,
157
- A_ptr,
158
- dt_bias_ptr,
159
- ddt_ptr,
160
- dA_ptr,
161
- ddt_bias_ptr,
162
- # Matrix dimensions
163
- batch,
164
- seqlen,
165
- nheads,
166
- chunk_size,
167
- dt_min,
168
- dt_max,
169
- # Strides
170
- stride_ddA_batch,
171
- stride_ddA_chunk,
172
- stride_ddA_head,
173
- stride_ddA_csize,
174
- stride_ddt_out_batch,
175
- stride_ddt_out_chunk,
176
- stride_ddt_out_head,
177
- stride_ddt_out_csize,
178
- stride_dt_batch,
179
- stride_dt_seqlen,
180
- stride_dt_head,
181
- stride_A_head,
182
- stride_dt_bias_head,
183
- stride_ddt_batch,
184
- stride_ddt_seqlen,
185
- stride_ddt_head,
186
- stride_dA_head,
187
- stride_ddt_bias_head,
188
- # Meta-parameters
189
- DT_SOFTPLUS: tl.constexpr,
190
- HAS_DT_BIAS: tl.constexpr,
191
- BLOCK_SIZE_H: tl.constexpr,
192
- BLOCK_SIZE_CHUNK: tl.constexpr,
193
- ):
194
- pid_b = tl.program_id(axis=0)
195
- pid_c = tl.program_id(axis=1)
196
- pid_h = tl.program_id(axis=2)
197
- ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk
198
- ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk
199
- dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
200
- ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen
201
-
202
- offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
203
- offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
204
- ddt_out_ptrs = ddt_out_ptr + (
205
- offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize
206
- )
207
- ddA_ptrs = ddA_ptr + (
208
- offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize
209
- )
210
- dt_ptrs = dt_ptr + (
211
- offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
212
- )
213
- ddt_ptrs = ddt_ptr + (
214
- offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen
215
- )
216
- A_ptrs = A_ptr + offs_h * stride_A_head
217
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
218
-
219
- ddA = tl.load(
220
- ddA_ptrs,
221
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
222
- other=0.0,
223
- ).to(tl.float32)
224
- ddt_out = tl.load(
225
- ddt_out_ptrs,
226
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
227
- other=0.0,
228
- ).to(tl.float32)
229
- A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
230
- ddt = ddA * A[:, None] + ddt_out
231
- dt = tl.load(
232
- dt_ptrs,
233
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
234
- other=0.0,
235
- ).to(tl.float32)
236
- if HAS_DT_BIAS:
237
- dt_bias = tl.load(
238
- dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
239
- ).to(tl.float32)
240
- dt += dt_bias[:, None]
241
- if DT_SOFTPLUS:
242
- dt_presoftplus = dt
243
- dt = tl.where(dt <= 20.0, softplus(dt), dt)
244
- clamp_mask = (dt < dt_min) | (dt > dt_max)
245
- # As of Triton 2.2.0, tl.clamp is not available yet
246
- # dt = tl.clamp(dt, dt_min, dt_max)
247
- dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
248
- dt = tl.where(
249
- (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
250
- )
251
- ddt = tl.where(
252
- (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0
253
- )
254
- ddt = tl.where(clamp_mask, 0.0, ddt)
255
- if DT_SOFTPLUS:
256
- ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)
257
- tl.store(
258
- ddt_ptrs,
259
- ddt,
260
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
261
- )
262
- dA = tl.sum(ddA * dt, axis=1)
263
- tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)
264
- if HAS_DT_BIAS:
265
- ddt_bias = tl.sum(ddt, axis=1)
266
- tl.atomic_add(
267
- ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads
268
- )
269
-
270
-
271
- @triton.autotune(
272
- configs=[
273
- triton.Config(
274
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
275
- num_stages=3,
276
- num_warps=8,
277
- ),
278
- triton.Config(
279
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
280
- num_stages=4,
281
- num_warps=4,
282
- ),
283
- triton.Config(
284
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
285
- num_stages=4,
286
- num_warps=4,
287
- ),
288
- triton.Config(
289
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
290
- num_stages=4,
291
- num_warps=4,
292
- ),
293
- triton.Config(
294
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
295
- num_stages=4,
296
- num_warps=4,
297
- ),
298
- triton.Config(
299
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
300
- num_stages=4,
301
- num_warps=4,
302
- ),
303
- triton.Config(
304
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
305
- num_stages=5,
306
- num_warps=2,
307
- ),
308
- triton.Config(
309
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
310
- num_stages=5,
311
- num_warps=2,
312
- ),
313
- triton.Config(
314
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
315
- num_stages=4,
316
- num_warps=2,
317
- ),
318
- ],
319
- key=["hdim", "dstate", "chunk_size"],
320
- )
321
- @triton.jit
322
- def _chunk_state_fwd_kernel(
323
- # Pointers to matrices
324
- x_ptr,
325
- b_ptr,
326
- states_ptr,
327
- dt_ptr,
328
- dA_cumsum_ptr,
329
- seq_idx_ptr,
330
- # Matrix dimensions
331
- hdim,
332
- dstate,
333
- chunk_size,
334
- batch,
335
- seqlen,
336
- nheads_ngroups_ratio,
337
- # Strides
338
- stride_x_batch,
339
- stride_x_seqlen,
340
- stride_x_head,
341
- stride_x_hdim,
342
- stride_b_batch,
343
- stride_b_seqlen,
344
- stride_b_head,
345
- stride_b_dstate,
346
- stride_states_batch,
347
- stride_states_chunk,
348
- stride_states_head,
349
- stride_states_hdim,
350
- stride_states_dstate,
351
- stride_dt_batch,
352
- stride_dt_chunk,
353
- stride_dt_head,
354
- stride_dt_csize,
355
- stride_dA_cs_batch,
356
- stride_dA_cs_chunk,
357
- stride_dA_cs_head,
358
- stride_dA_cs_csize,
359
- stride_seq_idx_batch,
360
- stride_seq_idx_seqlen,
361
- # Meta-parameters
362
- HAS_SEQ_IDX: tl.constexpr,
363
- BLOCK_SIZE_M: tl.constexpr,
364
- BLOCK_SIZE_N: tl.constexpr,
365
- BLOCK_SIZE_K: tl.constexpr,
366
- ):
367
- pid_bc = tl.program_id(axis=1)
368
- pid_c = pid_bc // batch
369
- pid_b = pid_bc - pid_c * batch
370
- pid_h = tl.program_id(axis=2)
371
- num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
372
- pid_m = tl.program_id(axis=0) // num_pid_n
373
- pid_n = tl.program_id(axis=0) % num_pid_n
374
- b_ptr += (
375
- pid_b * stride_b_batch
376
- + pid_c * chunk_size * stride_b_seqlen
377
- + (pid_h // nheads_ngroups_ratio) * stride_b_head
378
- )
379
- x_ptr += (
380
- pid_b * stride_x_batch
381
- + pid_c * chunk_size * stride_x_seqlen
382
- + pid_h * stride_x_head
383
- )
384
- dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
385
- dA_cumsum_ptr += (
386
- pid_b * stride_dA_cs_batch
387
- + pid_c * stride_dA_cs_chunk
388
- + pid_h * stride_dA_cs_head
389
- )
390
- if HAS_SEQ_IDX:
391
- seq_idx_ptr += (
392
- pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
393
- )
394
-
395
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
396
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
397
- offs_k = tl.arange(0, BLOCK_SIZE_K)
398
- x_ptrs = x_ptr + (
399
- offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
400
- )
401
- b_ptrs = b_ptr + (
402
- offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
403
- )
404
- dt_ptrs = dt_ptr + offs_k * stride_dt_csize
405
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
406
- tl.float32
407
- )
408
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
409
- if HAS_SEQ_IDX:
410
- seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
411
-
412
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
413
- if HAS_SEQ_IDX:
414
- seq_idx_last = tl.load(
415
- seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
416
- )
417
-
418
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
419
- for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
420
- x = tl.load(
421
- x_ptrs,
422
- mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k),
423
- other=0.0,
424
- )
425
- b = tl.load(
426
- b_ptrs,
427
- mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate),
428
- other=0.0,
429
- ).to(tl.float32)
430
- dA_cs_k = tl.load(
431
- dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
432
- ).to(tl.float32)
433
- if HAS_SEQ_IDX:
434
- seq_idx_k = tl.load(
435
- seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1
436
- )
437
- dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
438
- tl.float32
439
- )
440
- if not HAS_SEQ_IDX:
441
- scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
442
- else:
443
- scale = tl.where(
444
- seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0
445
- )
446
- b *= scale[:, None]
447
- b = b.to(x_ptr.dtype.element_ty)
448
- acc += tl.dot(x, b)
449
- x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
450
- b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
451
- dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
452
- dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
453
- if HAS_SEQ_IDX:
454
- seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
455
- states = acc.to(states_ptr.dtype.element_ty)
456
-
457
- states_ptr += (
458
- pid_b * stride_states_batch
459
- + pid_c * stride_states_chunk
460
- + pid_h * stride_states_head
461
- )
462
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
463
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
464
- states_ptrs = states_ptr + (
465
- offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
466
- )
467
- c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
468
- tl.store(states_ptrs, states, mask=c_mask)
469
-
470
-
471
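A quick illustration of the grouped-head indexing used by the kernel above (editor's sketch; the sizes are made up, not from the diff): several heads share one B group, which is what `(pid_h // nheads_ngroups_ratio) * stride_b_head` selects.

# Editor's illustration with hypothetical sizes.
nheads, ngroups = 8, 2
nheads_ngroups_ratio = nheads // ngroups
print([pid_h // nheads_ngroups_ratio for pid_h in range(nheads)])  # [0, 0, 0, 0, 1, 1, 1, 1]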
- @triton.autotune(
472
- configs=[
473
- triton.Config(
474
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
475
- num_stages=3,
476
- num_warps=8,
477
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
478
- ),
479
- triton.Config(
480
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
481
- num_stages=4,
482
- num_warps=4,
483
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
484
- ),
485
- triton.Config(
486
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
487
- num_stages=4,
488
- num_warps=4,
489
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
490
- ),
491
- triton.Config(
492
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
493
- num_stages=4,
494
- num_warps=4,
495
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
496
- ),
497
- triton.Config(
498
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
499
- num_stages=4,
500
- num_warps=4,
501
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
502
- ),
503
- triton.Config(
504
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
505
- num_stages=4,
506
- num_warps=4,
507
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
508
- ),
509
- triton.Config(
510
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
511
- num_stages=5,
512
- num_warps=4,
513
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
514
- ),
515
- triton.Config(
516
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
517
- num_stages=5,
518
- num_warps=4,
519
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
520
- ),
521
- triton.Config(
522
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
523
- num_stages=4,
524
- num_warps=4,
525
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
526
- ),
527
- ],
528
- key=["chunk_size", "hdim", "dstate"],
529
- )
530
- @triton.jit
531
- def _chunk_state_bwd_dx_kernel(
532
- # Pointers to matrices
533
- x_ptr,
534
- b_ptr,
535
- dstates_ptr,
536
- dt_ptr,
537
- dA_cumsum_ptr,
538
- dx_ptr,
539
- ddt_ptr,
540
- ddA_cumsum_ptr,
541
- # Matrix dimensions
542
- chunk_size,
543
- hdim,
544
- dstate,
545
- batch,
546
- seqlen,
547
- nheads_ngroups_ratio,
548
- # Strides
549
- stride_x_batch,
550
- stride_x_seqlen,
551
- stride_x_head,
552
- stride_x_hdim,
553
- stride_b_batch,
554
- stride_b_seqlen,
555
- stride_b_head,
556
- stride_b_dstate,
557
- stride_dstates_batch,
558
- stride_dstates_chunk,
559
- stride_states_head,
560
- stride_states_hdim,
561
- stride_states_dstate,
562
- stride_dt_batch,
563
- stride_dt_chunk,
564
- stride_dt_head,
565
- stride_dt_csize,
566
- stride_dA_cs_batch,
567
- stride_dA_cs_chunk,
568
- stride_dA_cs_head,
569
- stride_dA_cs_csize,
570
- stride_dx_batch,
571
- stride_dx_seqlen,
572
- stride_dx_head,
573
- stride_dx_hdim,
574
- stride_ddt_batch,
575
- stride_ddt_chunk,
576
- stride_ddt_head,
577
- stride_ddt_csize,
578
- stride_ddA_cs_batch,
579
- stride_ddA_cs_chunk,
580
- stride_ddA_cs_head,
581
- stride_ddA_cs_csize,
582
- # Meta-parameters
583
- BLOCK_SIZE_M: tl.constexpr,
584
- BLOCK_SIZE_N: tl.constexpr,
585
- BLOCK_SIZE_K: tl.constexpr,
586
- BLOCK_SIZE_DSTATE: tl.constexpr,
587
- ):
588
- pid_bc = tl.program_id(axis=1)
589
- pid_c = pid_bc // batch
590
- pid_b = pid_bc - pid_c * batch
591
- pid_h = tl.program_id(axis=2)
592
- num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
593
- pid_m = tl.program_id(axis=0) // num_pid_n
594
- pid_n = tl.program_id(axis=0) % num_pid_n
595
- x_ptr += (
596
- pid_b * stride_x_batch
597
- + pid_c * chunk_size * stride_x_seqlen
598
- + pid_h * stride_x_head
599
- )
600
- b_ptr += (
601
- pid_b * stride_b_batch
602
- + pid_c * chunk_size * stride_b_seqlen
603
- + (pid_h // nheads_ngroups_ratio) * stride_b_head
604
- )
605
- dstates_ptr += (
606
- pid_b * stride_dstates_batch
607
- + pid_c * stride_dstates_chunk
608
- + pid_h * stride_states_head
609
- )
610
- dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
611
- ddt_ptr += (
612
- pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
613
- )
614
- ddA_cumsum_ptr += (
615
- pid_b * stride_ddA_cs_batch
616
- + pid_c * stride_ddA_cs_chunk
617
- + pid_h * stride_ddA_cs_head
618
- )
619
- dA_cumsum_ptr += (
620
- pid_b * stride_dA_cs_batch
621
- + pid_c * stride_dA_cs_chunk
622
- + pid_h * stride_dA_cs_head
623
- )
624
-
625
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
626
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
627
-
628
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
629
- # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
630
- offs_k = tl.arange(
631
- 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
632
- )
633
- b_ptrs = b_ptr + (
634
- offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
635
- )
636
- dstates_ptrs = dstates_ptr + (
637
- offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
638
- )
639
- if BLOCK_SIZE_DSTATE <= 128:
640
- b = tl.load(
641
- b_ptrs,
642
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
643
- other=0.0,
644
- )
645
- dstates = tl.load(
646
- dstates_ptrs,
647
- mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
648
- other=0.0,
649
- )
650
- dstates = dstates.to(b_ptr.dtype.element_ty)
651
- acc = tl.dot(b, dstates)
652
- else:
653
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
654
- for k in range(0, dstate, BLOCK_SIZE_K):
655
- b = tl.load(
656
- b_ptrs,
657
- mask=(offs_m[:, None] < chunk_size_limit)
658
- & (offs_k[None, :] < dstate - k),
659
- other=0.0,
660
- )
661
- dstates = tl.load(
662
- dstates_ptrs,
663
- mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
664
- other=0.0,
665
- )
666
- dstates = dstates.to(b_ptr.dtype.element_ty)
667
- acc += tl.dot(b, dstates)
668
- b_ptrs += BLOCK_SIZE_K * stride_b_dstate
669
- dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
670
-
671
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
672
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
673
-
674
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
675
- tl.float32
676
- )
677
- dt_ptrs = dt_ptr + offs_m * stride_dt_csize
678
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
679
- dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
680
- tl.float32
681
- )
682
- dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
683
- acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
684
-
685
- x_ptrs = x_ptr + (
686
- offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
687
- )
688
- x = tl.load(
689
- x_ptrs,
690
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
691
- other=0.0,
692
- ).to(tl.float32)
693
- ddt = tl.sum(acc * x, axis=1)
694
- ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
695
- tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
696
- ddA_cs = -(ddt * dt_m)
697
- ddA_cs_last = -tl.sum(ddA_cs)
698
- ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
699
- tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
700
- tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last)
701
-
702
- dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty)
703
- dx_ptr += (
704
- pid_b * stride_dx_batch
705
- + pid_c * chunk_size * stride_dx_seqlen
706
- + pid_h * stride_dx_head
707
- )
708
- dx_ptrs = dx_ptr + (
709
- offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim
710
- )
711
- tl.store(
712
- dx_ptrs,
713
- dx,
714
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
715
- )
716
-
717
-
718
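The `BLOCK_SIZE_DSTATE <= 128` branch above exists because the launcher (`_chunk_state_bwd_dx`, further down) passes `BLOCK_SIZE_DSTATE = max(triton.next_power_of_2(dstate), 16)`; for typical state sizes the whole state dimension fits in a single tl.dot and the K loop is skipped. A small sketch of that value (editor's illustration):

# Editor's sketch; requires triton to be importable, as the rest of this file does.
import triton
for dstate in (16, 64, 128, 256):
    print(dstate, max(triton.next_power_of_2(dstate), 16))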
- @triton.autotune(
719
- configs=[
720
- triton.Config(
721
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128},
722
- num_stages=3,
723
- num_warps=4,
724
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
725
- ),
726
- triton.Config(
727
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32},
728
- num_stages=3,
729
- num_warps=4,
730
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
731
- ),
732
- triton.Config(
733
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128},
734
- num_stages=3,
735
- num_warps=4,
736
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
737
- ),
738
- triton.Config(
739
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64},
740
- num_stages=3,
741
- num_warps=4,
742
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
743
- ),
744
- triton.Config(
745
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64},
746
- num_stages=3,
747
- num_warps=4,
748
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
749
- ),
750
- triton.Config(
751
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32},
752
- num_stages=3,
753
- num_warps=4,
754
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
755
- ),
756
- triton.Config(
757
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64},
758
- num_stages=3,
759
- num_warps=4,
760
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
761
- ),
762
- triton.Config(
763
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32},
764
- num_stages=3,
765
- num_warps=4,
766
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
767
- ),
768
- ],
769
- key=["chunk_size", "dstate", "hdim"],
770
- )
771
- @triton.jit
772
- def _chunk_state_bwd_db_kernel(
773
- # Pointers to matrices
774
- x_ptr,
775
- dstates_ptr,
776
- b_ptr,
777
- dt_ptr,
778
- dA_cumsum_ptr,
779
- seq_idx_ptr,
780
- db_ptr,
781
- ddA_cumsum_ptr,
782
- # Matrix dimensions
783
- chunk_size,
784
- dstate,
785
- hdim,
786
- batch,
787
- seqlen,
788
- nheads,
789
- nheads_per_program,
790
- ngroups,
791
- # Strides
792
- stride_x_batch,
793
- stride_x_seqlen,
794
- stride_x_head,
795
- stride_x_hdim,
796
- stride_dstates_batch,
797
- stride_dstates_chunk,
798
- stride_states_head,
799
- stride_states_hdim,
800
- stride_states_dstate,
801
- stride_b_batch,
802
- stride_b_seqlen,
803
- stride_b_head,
804
- stride_b_dstate,
805
- stride_dt_batch,
806
- stride_dt_chunk,
807
- stride_dt_head,
808
- stride_dt_csize,
809
- stride_dA_cs_batch,
810
- stride_dA_cs_chunk,
811
- stride_dA_cs_head,
812
- stride_dA_cs_csize,
813
- stride_seq_idx_batch,
814
- stride_seq_idx_seqlen,
815
- stride_db_batch,
816
- stride_db_seqlen,
817
- stride_db_split,
818
- stride_db_group,
819
- stride_db_dstate,
820
- stride_ddA_cs_batch,
821
- stride_ddA_cs_chunk,
822
- stride_ddA_cs_head,
823
- stride_ddA_cs_csize,
824
- # Meta-parameters
825
- HAS_DDA_CS: tl.constexpr,
826
- HAS_SEQ_IDX: tl.constexpr,
827
- BLOCK_SIZE_M: tl.constexpr,
828
- BLOCK_SIZE_N: tl.constexpr,
829
- BLOCK_SIZE_K: tl.constexpr,
830
- ):
831
- pid_bc = tl.program_id(axis=1)
832
- pid_c = pid_bc // batch
833
- pid_b = pid_bc - pid_c * batch
834
- pid_sg = tl.program_id(axis=2)
835
- pid_s = pid_sg // ngroups
836
- pid_g = pid_sg - pid_s * ngroups
837
- num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
838
- pid_m = tl.program_id(axis=0) // num_pid_n
839
- pid_n = tl.program_id(axis=0) % num_pid_n
840
- x_ptr += (
841
- pid_b * stride_x_batch
842
- + pid_c * chunk_size * stride_x_seqlen
843
- + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head
844
- )
845
- db_ptr += (
846
- pid_b * stride_db_batch
847
- + pid_c * chunk_size * stride_db_seqlen
848
- + pid_g * stride_db_group
849
- + pid_s * stride_db_split
850
- )
851
- dstates_ptr += (
852
- pid_b * stride_dstates_batch
853
- + pid_c * stride_dstates_chunk
854
- + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
855
- * stride_states_head
856
- )
857
- dt_ptr += (
858
- pid_b * stride_dt_batch
859
- + pid_c * stride_dt_chunk
860
- + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head
861
- )
862
- dA_cumsum_ptr += (
863
- pid_b * stride_dA_cs_batch
864
- + pid_c * stride_dA_cs_chunk
865
- + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
866
- )
867
- if HAS_DDA_CS:
868
- b_ptr += (
869
- pid_b * stride_b_batch
870
- + pid_c * chunk_size * stride_b_seqlen
871
- + pid_g * stride_b_head
872
- )
873
- ddA_cumsum_ptr += (
874
- pid_b * stride_ddA_cs_batch
875
- + pid_c * stride_ddA_cs_chunk
876
- + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
877
- * stride_ddA_cs_head
878
- )
879
- if HAS_SEQ_IDX:
880
- seq_idx_ptr += (
881
- pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
882
- )
883
-
884
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
885
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
886
- offs_k = tl.arange(0, BLOCK_SIZE_K)
887
- x_ptrs = x_ptr + (
888
- offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim
889
- )
890
- dstates_ptrs = dstates_ptr + (
891
- offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim
892
- )
893
- dt_ptrs = dt_ptr + offs_m * stride_dt_csize
894
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
895
- if HAS_DDA_CS:
896
- b_ptrs = b_ptr + (
897
- offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate
898
- )
899
- ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
900
-
901
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
902
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
903
- if HAS_DDA_CS:
904
- b = tl.load(
905
- b_ptrs,
906
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
907
- other=0.0,
908
- ).to(tl.float32)
909
- if HAS_SEQ_IDX:
910
- seq_idx_m = tl.load(
911
- seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
912
- mask=offs_m < chunk_size_limit,
913
- other=-1,
914
- )
915
- seq_idx_last = tl.load(
916
- seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
917
- )
918
- nheads_iter = min(
919
- nheads_per_program, nheads // ngroups - pid_s * nheads_per_program
920
- )
921
- for h in range(nheads_iter):
922
- x = tl.load(
923
- x_ptrs,
924
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim),
925
- other=0.0,
926
- )
927
- dstates = tl.load(
928
- dstates_ptrs,
929
- mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate),
930
- other=0.0,
931
- )
932
- dstates = dstates.to(x_ptrs.dtype.element_ty)
933
- db = tl.dot(x, dstates)
934
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
935
- tl.float32
936
- )
937
- dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
938
- tl.float32
939
- )
940
- dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
941
- if not HAS_SEQ_IDX:
942
- scale = tl.exp(dA_cs_last - dA_cs_m)
943
- else:
944
- scale = tl.where(
945
- seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0
946
- )
947
- db *= (scale * dt_m)[:, None]
948
- if HAS_DDA_CS:
949
- # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
950
- ddA_cs = tl.sum(db * b, axis=1)
951
- tl.atomic_add(
952
- ddA_cumsum_ptrs + stride_ddA_cs_csize,
953
- ddA_cs,
954
- mask=offs_m < chunk_size - 1,
955
- )
956
- acc += db
957
- x_ptrs += stride_x_head
958
- dstates_ptrs += stride_states_head
959
- dt_ptrs += stride_dt_head
960
- dA_cumsum_ptr += stride_dA_cs_head
961
- dA_cumsum_ptrs += stride_dA_cs_head
962
- if HAS_DDA_CS:
963
- ddA_cumsum_ptrs += stride_ddA_cs_head
964
-
965
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
966
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
967
- # if HAS_SEQ_IDX:
968
- # seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
969
- # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
970
- # acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0)
971
- db_ptrs = db_ptr + (
972
- offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate
973
- )
974
- tl.store(
975
- db_ptrs,
976
- acc,
977
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
978
- )
979
-
980
-
981
- @triton.autotune(
982
- configs=[
983
- # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
984
- # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
985
- # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
986
- # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
987
- # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
988
- # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
989
- # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
990
- # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
991
- # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
992
- triton.Config(
993
- {"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
994
- num_stages=3,
995
- num_warps=4,
996
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
997
- ),
998
- triton.Config(
999
- {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1000
- num_stages=3,
1001
- num_warps=4,
1002
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1003
- ),
1004
- triton.Config(
1005
- {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1006
- num_stages=3,
1007
- num_warps=4,
1008
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1009
- ),
1010
- triton.Config(
1011
- {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1012
- num_stages=3,
1013
- num_warps=4,
1014
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1015
- ),
1016
- triton.Config(
1017
- {"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
1018
- num_stages=4,
1019
- num_warps=8,
1020
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1021
- ),
1022
- triton.Config(
1023
- {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1024
- num_stages=4,
1025
- num_warps=8,
1026
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1027
- ),
1028
- triton.Config(
1029
- {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1030
- num_stages=4,
1031
- num_warps=8,
1032
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1033
- ),
1034
- triton.Config(
1035
- {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1036
- num_stages=4,
1037
- num_warps=8,
1038
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1039
- ),
1040
- ],
1041
- key=["chunk_size", "hdim", "dstate"],
1042
- )
1043
- @triton.jit
1044
- def _chunk_state_bwd_ddAcs_stable_kernel(
1045
- # Pointers to matrices
1046
- x_ptr,
1047
- b_ptr,
1048
- dstates_ptr,
1049
- dt_ptr,
1050
- dA_cumsum_ptr,
1051
- seq_idx_ptr,
1052
- ddA_cumsum_ptr,
1053
- # Matrix dimensions
1054
- chunk_size,
1055
- hdim,
1056
- dstate,
1057
- batch,
1058
- seqlen,
1059
- nheads_ngroups_ratio,
1060
- # Strides
1061
- stride_x_batch,
1062
- stride_x_seqlen,
1063
- stride_x_head,
1064
- stride_x_hdim,
1065
- stride_b_batch,
1066
- stride_b_seqlen,
1067
- stride_b_head,
1068
- stride_b_dstate,
1069
- stride_dstates_batch,
1070
- stride_dstates_chunk,
1071
- stride_states_head,
1072
- stride_states_hdim,
1073
- stride_states_dstate,
1074
- stride_dt_batch,
1075
- stride_dt_chunk,
1076
- stride_dt_head,
1077
- stride_dt_csize,
1078
- stride_dA_cs_batch,
1079
- stride_dA_cs_chunk,
1080
- stride_dA_cs_head,
1081
- stride_dA_cs_csize,
1082
- stride_seq_idx_batch,
1083
- stride_seq_idx_seqlen,
1084
- stride_ddA_cs_batch,
1085
- stride_ddA_cs_chunk,
1086
- stride_ddA_cs_head,
1087
- stride_ddA_cs_csize,
1088
- # Meta-parameters
1089
- HAS_SEQ_IDX: tl.constexpr,
1090
- BLOCK_SIZE_M: tl.constexpr,
1091
- BLOCK_SIZE_N: tl.constexpr,
1092
- BLOCK_SIZE_K: tl.constexpr,
1093
- BLOCK_SIZE_DSTATE: tl.constexpr,
1094
- ):
1095
- pid_bc = tl.program_id(axis=1)
1096
- pid_c = pid_bc // batch
1097
- pid_b = pid_bc - pid_c * batch
1098
- pid_h = tl.program_id(axis=2)
1099
- num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
1100
- pid_m = tl.program_id(axis=0) // num_pid_n
1101
- pid_n = tl.program_id(axis=0) % num_pid_n
1102
- x_ptr += (
1103
- pid_b * stride_x_batch
1104
- + pid_c * chunk_size * stride_x_seqlen
1105
- + pid_h * stride_x_head
1106
- )
1107
- b_ptr += (
1108
- pid_b * stride_b_batch
1109
- + pid_c * chunk_size * stride_b_seqlen
1110
- + (pid_h // nheads_ngroups_ratio) * stride_b_head
1111
- )
1112
- dstates_ptr += (
1113
- pid_b * stride_dstates_batch
1114
- + pid_c * stride_dstates_chunk
1115
- + pid_h * stride_states_head
1116
- )
1117
- dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
1118
- ddA_cumsum_ptr += (
1119
- pid_b * stride_ddA_cs_batch
1120
- + pid_c * stride_ddA_cs_chunk
1121
- + pid_h * stride_ddA_cs_head
1122
- )
1123
- dA_cumsum_ptr += (
1124
- pid_b * stride_dA_cs_batch
1125
- + pid_c * stride_dA_cs_chunk
1126
- + pid_h * stride_dA_cs_head
1127
- )
1128
- if HAS_SEQ_IDX:
1129
- seq_idx_ptr += (
1130
- pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
1131
- )
1132
-
1133
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1134
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1135
-
1136
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
1137
- # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
1138
- offs_k = tl.arange(
1139
- 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
1140
- )
1141
- b_ptrs = b_ptr + (
1142
- offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
1143
- )
1144
- dstates_ptrs = dstates_ptr + (
1145
- offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
1146
- )
1147
- if BLOCK_SIZE_DSTATE <= 128:
1148
- b = tl.load(
1149
- b_ptrs,
1150
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
1151
- other=0.0,
1152
- )
1153
- dstates = tl.load(
1154
- dstates_ptrs,
1155
- mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
1156
- other=0.0,
1157
- )
1158
- dstates = dstates.to(b_ptr.dtype.element_ty)
1159
- acc = tl.dot(b, dstates)
1160
- else:
1161
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
1162
- for k in range(0, dstate, BLOCK_SIZE_K):
1163
- b = tl.load(
1164
- b_ptrs,
1165
- mask=(offs_m[:, None] < chunk_size_limit)
1166
- & (offs_k[None, :] < dstate - k),
1167
- other=0.0,
1168
- )
1169
- dstates = tl.load(
1170
- dstates_ptrs,
1171
- mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
1172
- other=0.0,
1173
- )
1174
- dstates = dstates.to(b_ptr.dtype.element_ty)
1175
- acc += tl.dot(b, dstates)
1176
- b_ptrs += BLOCK_SIZE_K * stride_b_dstate
1177
- dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
1178
-
1179
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1180
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1181
-
1182
- dA_cs_m = tl.load(
1183
- dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
1184
- ).to(tl.float32)
1185
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
1186
- tl.float32
1187
- )
1188
- if not HAS_SEQ_IDX:
1189
- scale = tl.exp(dA_cs_last - dA_cs_m)
1190
- else:
1191
- seq_idx_m = tl.load(
1192
- seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
1193
- mask=offs_m < chunk_size_limit,
1194
- other=-1,
1195
- )
1196
- seq_idx_last = tl.load(
1197
- seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
1198
- )
1199
- scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
1200
- acc *= scale[:, None]
1201
-
1202
- x_ptrs = x_ptr + (
1203
- offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
1204
- )
1205
- x = tl.load(
1206
- x_ptrs,
1207
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
1208
- other=0.0,
1209
- ).to(tl.float32)
1210
- dt_ptrs = dt_ptr + offs_m * stride_dt_csize
1211
- dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
1212
- ddt = tl.sum(acc * x, axis=1)
1213
- # ddA_cs = -(ddt * dt_m)
1214
- # Triton 2.2.0 errors if we have the cumsum here, so we just write it out
1215
- # then call torch.cumsum outside this kernel.
1216
- # ddA_cs = tl.cumsum(ddt * dt_m)
1217
- ddA_cs = ddt * dt_m
1218
- ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
1219
- # tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
1220
- tl.atomic_add(
1221
- ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1
1222
- )
1223
-
1224
-
1225
- @triton.autotune(
1226
- configs=[
1227
- triton.Config(
1228
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
1229
- num_stages=3,
1230
- num_warps=8,
1231
- ),
1232
- triton.Config(
1233
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
1234
- num_stages=4,
1235
- num_warps=4,
1236
- ),
1237
- triton.Config(
1238
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1239
- num_stages=4,
1240
- num_warps=4,
1241
- ),
1242
- triton.Config(
1243
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1244
- num_stages=4,
1245
- num_warps=4,
1246
- ),
1247
- triton.Config(
1248
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1249
- num_stages=4,
1250
- num_warps=4,
1251
- ),
1252
- triton.Config(
1253
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1254
- num_stages=4,
1255
- num_warps=4,
1256
- ),
1257
- triton.Config(
1258
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1259
- num_stages=5,
1260
- num_warps=2,
1261
- ),
1262
- triton.Config(
1263
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1264
- num_stages=5,
1265
- num_warps=2,
1266
- ),
1267
- triton.Config(
1268
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1269
- num_stages=4,
1270
- num_warps=2,
1271
- ),
1272
- ],
1273
- key=["hdim", "dstate", "chunk_size"],
1274
- )
1275
- @triton.jit
1276
- def _chunk_state_varlen_kernel(
1277
- # Pointers to matrices
1278
- x_ptr,
1279
- b_ptr,
1280
- dt_ptr,
1281
- dA_cumsum_ptr,
1282
- chunk_states_ptr,
1283
- cu_seqlens_ptr,
1284
- states_ptr,
1285
- # Matrix dimensions
1286
- hdim,
1287
- dstate,
1288
- chunk_size,
1289
- seqlen,
1290
- nheads_ngroups_ratio,
1291
- # Strides
1292
- stride_x_seqlen,
1293
- stride_x_head,
1294
- stride_x_hdim,
1295
- stride_b_seqlen,
1296
- stride_b_head,
1297
- stride_b_dstate,
1298
- stride_dt_chunk,
1299
- stride_dt_head,
1300
- stride_dt_csize,
1301
- stride_dA_cs_chunk,
1302
- stride_dA_cs_head,
1303
- stride_dA_cs_csize,
1304
- stride_chunk_states_chunk,
1305
- stride_chunk_states_head,
1306
- stride_chunk_states_hdim,
1307
- stride_chunk_states_dstate,
1308
- stride_states_batch,
1309
- stride_states_head,
1310
- stride_states_hdim,
1311
- stride_states_dstate,
1312
- # Meta-parameters
1313
- BLOCK_SIZE_M: tl.constexpr,
1314
- BLOCK_SIZE_N: tl.constexpr,
1315
- BLOCK_SIZE_K: tl.constexpr,
1316
- ):
1317
- pid_b = tl.program_id(axis=1)
1318
- pid_h = tl.program_id(axis=2)
1319
- num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
1320
- pid_m = tl.program_id(axis=0) // num_pid_n
1321
- pid_n = tl.program_id(axis=0) % num_pid_n
1322
- end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
1323
- pid_c = (end_idx - 1) // chunk_size
1324
- b_ptr += (
1325
- pid_c * chunk_size * stride_b_seqlen
1326
- + (pid_h // nheads_ngroups_ratio) * stride_b_head
1327
- )
1328
- x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
1329
- dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
1330
- dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
1331
- chunk_states_ptr += (
1332
- pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
1333
- )
1334
-
1335
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1336
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1337
- offs_k = tl.arange(0, BLOCK_SIZE_K)
1338
- x_ptrs = x_ptr + (
1339
- offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
1340
- )
1341
- b_ptrs = b_ptr + (
1342
- offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
1343
- )
1344
- dt_ptrs = dt_ptr + offs_k * stride_dt_csize
1345
- dA_cs_last = tl.load(
1346
- dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
1347
- ).to(tl.float32)
1348
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
1349
-
1350
- chunk_size_limit = end_idx - pid_c * chunk_size
1351
- start_idx = tl.load(cu_seqlens_ptr + pid_b)
1352
- start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
1353
-
1354
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
1355
- for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
1356
- x = tl.load(
1357
- x_ptrs,
1358
- mask=(offs_m[:, None] < hdim)
1359
- & (offs_k[None, :] < chunk_size_limit - k)
1360
- & (offs_k[None, :] >= start_idx_cur - k),
1361
- other=0.0,
1362
- )
1363
- b = tl.load(
1364
- b_ptrs,
1365
- mask=(offs_k[:, None] < chunk_size_limit - k)
1366
- & (offs_n[None, :] < dstate)
1367
- & (offs_k[:, None] >= start_idx_cur - k),
1368
- other=0.0,
1369
- ).to(tl.float32)
1370
- dA_cs_k = tl.load(
1371
- dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
1372
- ).to(tl.float32)
1373
- dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
1374
- tl.float32
1375
- )
1376
- scale = tl.where(
1377
- (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
1378
- tl.exp((dA_cs_last - dA_cs_k)) * dt_k,
1379
- 0.0,
1380
- )
1381
- b *= scale[:, None]
1382
- b = b.to(x_ptr.dtype.element_ty)
1383
- acc += tl.dot(x, b)
1384
- x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
1385
- b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
1386
- dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
1387
- dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
1388
-
1389
- # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
1390
- if start_idx < pid_c * chunk_size:
1391
- chunk_states_ptrs = chunk_states_ptr + (
1392
- offs_m[:, None] * stride_chunk_states_hdim
1393
- + offs_n[None, :] * stride_chunk_states_dstate
1394
- )
1395
- chunk_states = tl.load(
1396
- chunk_states_ptrs,
1397
- mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),
1398
- other=0.0,
1399
- ).to(tl.float32)
1400
- # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0)
1401
- scale = tl.exp(dA_cs_last)
1402
- acc += chunk_states * scale
1403
-
1404
- states = acc.to(states_ptr.dtype.element_ty)
1405
-
1406
- states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
1407
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1408
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1409
- states_ptrs = states_ptr + (
1410
- offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
1411
- )
1412
- c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
1413
- tl.store(states_ptrs, states, mask=c_mask)
1414
-
1415
-
1416
- def _chunk_cumsum_fwd(
1417
- dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))
1418
- ):
1419
- batch, seqlen, nheads = dt.shape
1420
- assert A.shape == (nheads,)
1421
- if dt_bias is not None:
1422
- assert dt_bias.shape == (nheads,)
1423
- nchunks = math.ceil(seqlen / chunk_size)
1424
- dt_out = torch.empty(
1425
- batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
1426
- )
1427
- dA_cumsum = torch.empty(
1428
- batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
1429
- )
1430
- grid_chunk_cs = lambda META: (
1431
- batch,
1432
- nchunks,
1433
- triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
1434
- )
1435
- with torch.cuda.device(dt.device.index):
1436
- _chunk_cumsum_fwd_kernel[grid_chunk_cs](
1437
- dt,
1438
- A,
1439
- dt_bias,
1440
- dt_out,
1441
- dA_cumsum,
1442
- batch,
1443
- seqlen,
1444
- nheads,
1445
- chunk_size,
1446
- dt_limit[0],
1447
- dt_limit[1],
1448
- dt.stride(0),
1449
- dt.stride(1),
1450
- dt.stride(2),
1451
- A.stride(0),
1452
- dt_bias.stride(0) if dt_bias is not None else 0,
1453
- dt_out.stride(0),
1454
- dt_out.stride(2),
1455
- dt_out.stride(1),
1456
- dt_out.stride(3),
1457
- dA_cumsum.stride(0),
1458
- dA_cumsum.stride(2),
1459
- dA_cumsum.stride(1),
1460
- dA_cumsum.stride(3),
1461
- dt_softplus,
1462
- HAS_DT_BIAS=dt_bias is not None,
1463
- BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
1464
- )
1465
- return dA_cumsum, dt_out
1466
-
1467
-
1468
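For readers following the launcher above, here is a rough, non-authoritative PyTorch sketch of what `_chunk_cumsum_fwd` returns. The kernel itself lives earlier in this file; this sketch assumes it applies dt_bias, optional softplus, optional clamping to dt_limit, then a per-chunk cumulative sum of dt * A. The name `chunk_cumsum_fwd_ref` is hypothetical and ragged-final-chunk handling is simplified.

# Editor's reference sketch (not part of the diff).
import math
import torch
import torch.nn.functional as F

def chunk_cumsum_fwd_ref(dt, A, chunk_size, dt_bias=None, dt_softplus=False,
                         dt_limit=(0.0, float("inf"))):
    batch, seqlen, nheads = dt.shape
    nchunks = math.ceil(seqlen / chunk_size)
    dt = F.pad(dt.float(), (0, 0, 0, nchunks * chunk_size - seqlen))
    if dt_bias is not None:
        dt = dt + dt_bias
    if dt_softplus:
        dt = F.softplus(dt)
    if dt_limit != (0.0, float("inf")):
        dt = dt.clamp(min=dt_limit[0], max=dt_limit[1])
    # (batch, nheads, nchunks, chunk_size), matching dt_out / dA_cumsum above
    dt = dt.reshape(batch, nchunks, chunk_size, nheads).permute(0, 3, 1, 2)
    # zero out padded positions of a ragged final chunk before the cumsum
    pos = torch.arange(nchunks * chunk_size, device=dt.device).reshape(nchunks, chunk_size)
    dt = dt.masked_fill(pos[None, None] >= seqlen, 0.0)
    dA_cumsum = torch.cumsum(dt * A.float()[None, :, None, None], dim=-1)
    return dA_cumsum, dt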
- def _chunk_cumsum_bwd(
1469
- ddA,
1470
- ddt_out,
1471
- dt,
1472
- A,
1473
- dt_bias=None,
1474
- dt_softplus=False,
1475
- dt_limit=(0.0, float("inf")),
1476
- ddt=None,
1477
- ):
1478
- batch, seqlen, nheads = dt.shape
1479
- _, _, nchunks, chunk_size = ddA.shape
1480
- assert ddA.shape == (batch, nheads, nchunks, chunk_size)
1481
- assert ddt_out.shape == (batch, nheads, nchunks, chunk_size)
1482
- assert A.shape == (nheads,)
1483
- if dt_bias is not None:
1484
- assert dt_bias.shape == (nheads,)
1485
- ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)
1486
- else:
1487
- ddt_bias = None
1488
- if ddt is not None:
1489
- assert ddt.shape == dt.shape
1490
- else:
1491
- ddt = torch.empty_like(dt)
1492
- dA = torch.empty_like(A, dtype=torch.float32)
1493
- grid_chunk_cs = lambda META: (
1494
- batch,
1495
- nchunks,
1496
- triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
1497
- )
1498
- with torch.cuda.device(dt.device.index):
1499
- _chunk_cumsum_bwd_kernel[grid_chunk_cs](
1500
- ddA,
1501
- ddt_out,
1502
- dt,
1503
- A,
1504
- dt_bias,
1505
- ddt,
1506
- dA,
1507
- ddt_bias,
1508
- batch,
1509
- seqlen,
1510
- nheads,
1511
- chunk_size,
1512
- dt_limit[0],
1513
- dt_limit[1],
1514
- ddA.stride(0),
1515
- ddA.stride(2),
1516
- ddA.stride(1),
1517
- ddA.stride(3),
1518
- ddt_out.stride(0),
1519
- ddt_out.stride(2),
1520
- ddt_out.stride(1),
1521
- ddt_out.stride(3),
1522
- dt.stride(0),
1523
- dt.stride(1),
1524
- dt.stride(2),
1525
- A.stride(0),
1526
- dt_bias.stride(0) if dt_bias is not None else 0,
1527
- ddt.stride(0),
1528
- ddt.stride(1),
1529
- ddt.stride(2),
1530
- dA.stride(0),
1531
- ddt_bias.stride(0) if ddt_bias is not None else 0,
1532
- dt_softplus,
1533
- HAS_DT_BIAS=dt_bias is not None,
1534
- BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
1535
- )
1536
- return ddt, dA, ddt_bias
1537
-
1538
-
1539
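One convention worth noting in these launchers (editor's note): tensors such as dt, dA_cumsum and ddt_out are stored as (batch, nheads, nchunks, chunk_size), but the kernels index them as (batch, chunk, head, csize), so the strides are handed over in permuted order.

# Editor's illustration; the shape below is hypothetical.
import torch
dt_out = torch.empty(2, 8, 4, 256)  # (batch, nheads, nchunks, chunk_size)
# order passed to the kernels: batch stride, chunk stride, head stride, csize stride
print(dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3))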
- def _chunk_state_fwd(
1540
- B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True
1541
- ):
1542
- batch, seqlen, nheads, headdim = x.shape
1543
- _, _, nchunks, chunk_size = dt.shape
1544
- _, _, ngroups, dstate = B.shape
1545
- assert nheads % ngroups == 0
1546
- assert B.shape == (batch, seqlen, ngroups, dstate)
1547
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1548
- assert dA_cumsum.shape == dt.shape
1549
- if seq_idx is not None:
1550
- assert seq_idx.shape == (batch, seqlen)
1551
- if states is not None:
1552
- assert states.shape == (batch, nchunks, nheads, headdim, dstate)
1553
- else:
1554
- states_dtype = torch.float32 if states_in_fp32 else B.dtype
1555
- states = torch.empty(
1556
- (batch, nchunks, nheads, headdim, dstate),
1557
- device=x.device,
1558
- dtype=states_dtype,
1559
- )
1560
- grid = lambda META: (
1561
- triton.cdiv(headdim, META["BLOCK_SIZE_M"])
1562
- * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
1563
- batch * nchunks,
1564
- nheads,
1565
- )
1566
- with torch.cuda.device(x.device.index):
1567
- _chunk_state_fwd_kernel[grid](
1568
- x,
1569
- B,
1570
- states,
1571
- dt,
1572
- dA_cumsum,
1573
- seq_idx,
1574
- headdim,
1575
- dstate,
1576
- chunk_size,
1577
- batch,
1578
- seqlen,
1579
- nheads // ngroups,
1580
- x.stride(0),
1581
- x.stride(1),
1582
- x.stride(2),
1583
- x.stride(3),
1584
- B.stride(0),
1585
- B.stride(1),
1586
- B.stride(2),
1587
- B.stride(-1),
1588
- states.stride(0),
1589
- states.stride(1),
1590
- states.stride(2),
1591
- states.stride(3),
1592
- states.stride(4),
1593
- dt.stride(0),
1594
- dt.stride(2),
1595
- dt.stride(1),
1596
- dt.stride(3),
1597
- dA_cumsum.stride(0),
1598
- dA_cumsum.stride(2),
1599
- dA_cumsum.stride(1),
1600
- dA_cumsum.stride(3),
1601
- *(
1602
- (seq_idx.stride(0), seq_idx.stride(1))
1603
- if seq_idx is not None
1604
- else (0, 0)
1605
- ),
1606
- HAS_SEQ_IDX=seq_idx is not None,
1607
- )
1608
- return states
1609
-
1610
-
1611
- def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None):
1612
- batch, seqlen, nheads, headdim = x.shape
1613
- _, _, nchunks, chunk_size = dt.shape
1614
- _, _, ngroups, dstate = B.shape
1615
- assert nheads % ngroups == 0
1616
- assert B.shape == (batch, seqlen, ngroups, dstate)
1617
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1618
- assert dA_cumsum.shape == dt.shape
1619
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1620
- if dx is not None:
1621
- assert dx.shape == x.shape
1622
- else:
1623
- dx = torch.empty_like(x)
1624
- ddt = torch.empty(
1625
- batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
1626
- )
1627
- ddA_cumsum = torch.empty(
1628
- batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32
1629
- )
1630
- grid_dx = lambda META: (
1631
- triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
1632
- * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
1633
- batch * nchunks,
1634
- nheads,
1635
- )
1636
- with torch.cuda.device(x.device.index):
1637
- _chunk_state_bwd_dx_kernel[grid_dx](
1638
- x,
1639
- B,
1640
- dstates,
1641
- dt,
1642
- dA_cumsum,
1643
- dx,
1644
- ddt,
1645
- ddA_cumsum,
1646
- chunk_size,
1647
- headdim,
1648
- dstate,
1649
- batch,
1650
- seqlen,
1651
- nheads // ngroups,
1652
- x.stride(0),
1653
- x.stride(1),
1654
- x.stride(2),
1655
- x.stride(3),
1656
- B.stride(0),
1657
- B.stride(1),
1658
- B.stride(2),
1659
- B.stride(-1),
1660
- dstates.stride(0),
1661
- dstates.stride(1),
1662
- dstates.stride(2),
1663
- dstates.stride(3),
1664
- dstates.stride(4),
1665
- dt.stride(0),
1666
- dt.stride(2),
1667
- dt.stride(1),
1668
- dt.stride(3),
1669
- dA_cumsum.stride(0),
1670
- dA_cumsum.stride(2),
1671
- dA_cumsum.stride(1),
1672
- dA_cumsum.stride(3),
1673
- dx.stride(0),
1674
- dx.stride(1),
1675
- dx.stride(2),
1676
- dx.stride(3),
1677
- ddt.stride(0),
1678
- ddt.stride(2),
1679
- ddt.stride(1),
1680
- ddt.stride(3),
1681
- ddA_cumsum.stride(0),
1682
- ddA_cumsum.stride(2),
1683
- ddA_cumsum.stride(1),
1684
- ddA_cumsum.stride(3),
1685
- BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
1686
- )
1687
- return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype)
1688
-
1689
-
1690
- def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1):
1691
- batch, seqlen, nheads, headdim = x.shape
1692
- _, _, nchunks, chunk_size = dt.shape
1693
- dstate = dstates.shape[-1]
1694
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1695
- assert dA_cumsum.shape == dt.shape
1696
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1697
- if seq_idx is not None:
1698
- assert seq_idx.shape == (batch, seqlen)
1699
- if B is not None:
1700
- assert B.shape == (batch, seqlen, ngroups, dstate)
1701
- B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3))
1702
- # Use torch.empty since the Triton kernel will call init_to_zero
1703
- ddA_cumsum = torch.empty(
1704
- batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
1705
- )
1706
- ddA_cumsum_strides = (
1707
- ddA_cumsum.stride(0),
1708
- ddA_cumsum.stride(2),
1709
- ddA_cumsum.stride(1),
1710
- ddA_cumsum.stride(3),
1711
- )
1712
- else:
1713
- B_strides = (0, 0, 0, 0)
1714
- ddA_cumsum = None
1715
- ddA_cumsum_strides = (0, 0, 0, 0)
1716
- nheads_ngroups_ratio = nheads // ngroups
1717
- sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
1718
- nheads_per_program = max(
1719
- min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1
1720
- )
1721
- nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
1722
- dB = torch.empty(
1723
- batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32
1724
- )
1725
- grid_db = lambda META: (
1726
- triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
1727
- * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
1728
- batch * nchunks,
1729
- nsplits * ngroups,
1730
- )
1731
- with torch.cuda.device(x.device.index):
1732
- _chunk_state_bwd_db_kernel[grid_db](
1733
- x,
1734
- dstates,
1735
- B,
1736
- dt,
1737
- dA_cumsum,
1738
- seq_idx,
1739
- dB,
1740
- ddA_cumsum,
1741
- chunk_size,
1742
- dstate,
1743
- headdim,
1744
- batch,
1745
- seqlen,
1746
- nheads,
1747
- nheads_per_program,
1748
- ngroups,
1749
- x.stride(0),
1750
- x.stride(1),
1751
- x.stride(2),
1752
- x.stride(3),
1753
- dstates.stride(0),
1754
- dstates.stride(1),
1755
- dstates.stride(2),
1756
- dstates.stride(3),
1757
- dstates.stride(4),
1758
- *B_strides,
1759
- dt.stride(0),
1760
- dt.stride(2),
1761
- dt.stride(1),
1762
- dt.stride(3),
1763
- dA_cumsum.stride(0),
1764
- dA_cumsum.stride(2),
1765
- dA_cumsum.stride(1),
1766
- dA_cumsum.stride(3),
1767
- *(
1768
- (seq_idx.stride(0), seq_idx.stride(1))
1769
- if seq_idx is not None
1770
- else (0, 0)
1771
- ),
1772
- dB.stride(0),
1773
- dB.stride(1),
1774
- dB.stride(2),
1775
- dB.stride(3),
1776
- dB.stride(4),
1777
- *ddA_cumsum_strides,
1778
- HAS_DDA_CS=ddA_cumsum is not None,
1779
- HAS_SEQ_IDX=seq_idx is not None,
1780
- BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
1781
- )
1782
- dB = dB.sum(2)
1783
- if ddA_cumsum is not None:
1784
- # The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute
1785
- # to the state of the chunk.
1786
- # torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
1787
- # But it's easier to just do the cumsum for all elements, the result will be the same.
1788
- torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum)
1789
- return dB if B is None else (dB, ddA_cumsum)
1790
-
1791
-
1792
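The `nheads_per_program` / `nsplits` arithmetic above spreads the per-group head reduction over enough programs to keep the GPU busy, and the partial results are then summed on the host via `dB.sum(2)`. A small sketch of the numbers (editor's illustration; the SM count is made up):

# Editor's sketch with hypothetical sizes.
import math
import triton
batch, nchunks, nheads, ngroups, sm_count = 2, 8, 64, 8, 108
nheads_ngroups_ratio = nheads // ngroups
nheads_per_program = max(
    min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1
)
nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
print(nheads_per_program, nsplits)  # heads reduced per program, and partial dB copies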
- def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):
1793
- batch, seqlen, nheads, headdim = x.shape
1794
- _, _, nchunks, chunk_size = dt.shape
1795
- _, _, ngroups, dstate = B.shape
1796
- assert nheads % ngroups == 0
1797
- assert B.shape == (batch, seqlen, ngroups, dstate)
1798
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1799
- assert dA_cumsum.shape == dt.shape
1800
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1801
- if seq_idx is not None:
1802
- assert seq_idx.shape == (batch, seqlen)
1803
- # Use torch.empty since the Triton kernel will call init_to_zero
1804
- ddA_cumsum = torch.empty(
1805
- batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
1806
- )
1807
- grid_ddtcs = lambda META: (
1808
- triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
1809
- * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
1810
- batch * nchunks,
1811
- nheads,
1812
- )
1813
- with torch.cuda.device(x.device.index):
1814
- _chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs](
1815
- x,
1816
- B,
1817
- dstates,
1818
- dt,
1819
- dA_cumsum,
1820
- seq_idx,
1821
- ddA_cumsum,
1822
- chunk_size,
1823
- headdim,
1824
- dstate,
1825
- batch,
1826
- seqlen,
1827
- nheads // ngroups,
1828
- x.stride(0),
1829
- x.stride(1),
1830
- x.stride(2),
1831
- x.stride(3),
1832
- B.stride(0),
1833
- B.stride(1),
1834
- B.stride(2),
1835
- B.stride(-1),
1836
- dstates.stride(0),
1837
- dstates.stride(1),
1838
- dstates.stride(2),
1839
- dstates.stride(3),
1840
- dstates.stride(4),
1841
- dt.stride(0),
1842
- dt.stride(2),
1843
- dt.stride(1),
1844
- dt.stride(3),
1845
- dA_cumsum.stride(0),
1846
- dA_cumsum.stride(2),
1847
- dA_cumsum.stride(1),
1848
- dA_cumsum.stride(3),
1849
- *(
1850
- (seq_idx.stride(0), seq_idx.stride(1))
1851
- if seq_idx is not None
1852
- else (0, 0)
1853
- ),
1854
- ddA_cumsum.stride(0),
1855
- ddA_cumsum.stride(2),
1856
- ddA_cumsum.stride(1),
1857
- ddA_cumsum.stride(3),
1858
- HAS_SEQ_IDX=seq_idx is not None,
1859
- BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16),
1860
- BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
1861
- )
1862
- torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
1863
- return ddA_cumsum
1864
-
1865
-
1866
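As the comment inside `_chunk_state_bwd_ddAcs_stable_kernel` notes, the kernel only atomically accumulates the raw per-position terms, shifted by one position so that index 0 stays zero (the pre_hook zeroes the buffer), and the cumulative sum is finished here with `torch.cumsum`. A toy check of that equivalence (editor's sketch):

# Editor's toy check of the shift-then-cumsum pattern used above.
import torch
terms = torch.randn(8)          # per-position ddA contributions from the kernel
ddA = torch.zeros(9)            # index 0 stays zero (init_to_zero pre_hook)
ddA[1:] = terms                 # atomic_add at ddA_cumsum_ptrs + stride_ddA_cs_csize
torch.cumsum(ddA[..., 1:], dim=-1, out=ddA[..., 1:])
assert torch.allclose(ddA, torch.cat([torch.zeros(1), torch.cumsum(terms, dim=0)]))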
- def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):
1867
- total_seqlen, nheads, headdim = x.shape
1868
- _, nchunks, chunk_size = dt.shape
1869
- _, ngroups, dstate = B.shape
1870
- batch = cu_seqlens.shape[0] - 1
1871
- cu_seqlens = cu_seqlens.contiguous()
1872
- assert nheads % ngroups == 0
1873
- assert B.shape == (total_seqlen, ngroups, dstate)
1874
- assert dt.shape == (nheads, nchunks, chunk_size)
1875
- assert dA_cumsum.shape == dt.shape
1876
- assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
1877
- states = torch.empty(
1878
- batch,
1879
- nheads,
1880
- headdim,
1881
- dstate,
1882
- dtype=chunk_states.dtype,
1883
- device=chunk_states.device,
1884
- )
1885
- grid = lambda META: (
1886
- triton.cdiv(headdim, META["BLOCK_SIZE_M"])
1887
- * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
1888
- batch,
1889
- nheads,
1890
- )
1891
- with torch.cuda.device(x.device.index):
1892
- _chunk_state_varlen_kernel[grid](
1893
- x,
1894
- B,
1895
- dt,
1896
- dA_cumsum,
1897
- chunk_states,
1898
- cu_seqlens,
1899
- states,
1900
- headdim,
1901
- dstate,
1902
- chunk_size,
1903
- total_seqlen,
1904
- nheads // ngroups,
1905
- x.stride(0),
1906
- x.stride(1),
1907
- x.stride(2),
1908
- B.stride(0),
1909
- B.stride(1),
1910
- B.stride(2),
1911
- dt.stride(1),
1912
- dt.stride(0),
1913
- dt.stride(2),
1914
- dA_cumsum.stride(1),
1915
- dA_cumsum.stride(0),
1916
- dA_cumsum.stride(2),
1917
- chunk_states.stride(0),
1918
- chunk_states.stride(1),
1919
- chunk_states.stride(2),
1920
- chunk_states.stride(3),
1921
- states.stride(0),
1922
- states.stride(1),
1923
- states.stride(2),
1924
- states.stride(3),
1925
- )
1926
- return states
1927
-
1928
-
1929
- class ChunkStateFn(torch.autograd.Function):
1930
-
1931
- @staticmethod
1932
- def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True):
1933
- batch, seqlen, nheads, headdim = x.shape
1934
- _, _, nchunks, chunk_size = dt.shape
1935
- assert seqlen <= nchunks * chunk_size
1936
- _, _, ngroups, dstate = B.shape
1937
- assert B.shape == (batch, seqlen, ngroups, dstate)
1938
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1939
- assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
1940
- if B.stride(-1) != 1:
1941
- B = B.contiguous()
1942
- if (
1943
- x.stride(-1) != 1 and x.stride(1) != 1
1944
- ): # Either M or K dimension should be contiguous
1945
- x = x.contiguous()
1946
- states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32)
1947
- ctx.save_for_backward(B, x, dt, dA_cumsum)
1948
- return states
1949
-
1950
- @staticmethod
1951
- def backward(ctx, dstates):
1952
- B, x, dt, dA_cumsum = ctx.saved_tensors
1953
- batch, seqlen, nheads, headdim = x.shape
1954
- _, _, nchunks, chunk_size = dt.shape
1955
- _, _, ngroups, dstate = B.shape
1956
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1957
- if dstates.stride(-1) != 1:
1958
- dstates = dstates.contiguous()
1959
- dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates)
1960
- dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups)
1961
- dB = dB.to(B.dtype)
1962
- return dB, dx, ddt, ddA_cumsum, None
1963
-
1964
-
1965
- def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True):
1966
- """
1967
- Argument:
1968
-         B: (batch, seqlen, ngroups, dstate)
1969
- x: (batch, seqlen, nheads, headdim)
1970
- dt: (batch, nheads, nchunks, chunk_size)
1971
- dA_cumsum: (batch, nheads, nchunks, chunk_size)
1972
- Return:
1973
- states: (batch, nchunks, nheads, headdim, dstate)
1974
- """
1975
- return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32)
1976
-
1977
-
1978
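A minimal usage sketch of the autograd wrapper above, with hypothetical shapes matching the docstring and the shape asserts (editor's addition; assumes a CUDA device):

# Editor's usage sketch; sizes are illustrative only.
import torch
batch, seqlen, nheads, headdim, ngroups, dstate, chunk_size = 2, 512, 8, 64, 1, 16, 256
nchunks = seqlen // chunk_size
x = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
B = torch.randn(batch, seqlen, ngroups, dstate, device="cuda", dtype=torch.bfloat16)
dt = torch.rand(batch, nheads, nchunks, chunk_size, device="cuda", dtype=torch.float32)
dA_cumsum = torch.cumsum(-dt, dim=-1)
states = chunk_state(B, x, dt, dA_cumsum)
print(states.shape)  # (batch, nchunks, nheads, headdim, dstate)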
- def chunk_state_ref(B, x, dt, dA_cumsum):
1979
- """
1980
- Argument:
1981
-         B: (batch, seqlen, ngroups, dstate)
1982
- x: (batch, seqlen, nheads, headdim)
1983
- dt: (batch, nheads, nchunks, chunk_size)
1984
- dA_cumsum: (batch, nheads, nchunks, chunk_size)
1985
- Return:
1986
- states: (batch, nchunks, nheads, headdim, dstate)
1987
- """
1988
- # Check constraints.
1989
- batch, seqlen, nheads, headdim = x.shape
1990
- dstate = B.shape[-1]
1991
- _, _, nchunks, chunk_size = dt.shape
1992
- assert seqlen <= nchunks * chunk_size
1993
- assert x.shape == (batch, seqlen, nheads, headdim)
1994
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1995
- ngroups = B.shape[2]
1996
- assert nheads % ngroups == 0
1997
- assert B.shape == (batch, seqlen, ngroups, dstate)
1998
- B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
1999
- assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
2000
- if seqlen < nchunks * chunk_size:
2001
- x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
2002
- B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
2003
- x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
2004
- B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
2005
- decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))
2006
- return torch.einsum(
2007
- "bclhn,bhcl,bhcl,bclhp->bchpn",
2008
- B.to(x.dtype),
2009
- decay_states.to(x.dtype),
2010
- dt.to(x.dtype),
2011
- x,
2012
- )
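The reference implementation above doubles as a correctness check for the Triton path; a small self-contained comparison (editor's sketch; tolerances and sizes are illustrative, CUDA assumed):

# Editor's cross-check sketch; all sizes are made up.
import torch
b, l, h, p, g, n, cs = 1, 128, 4, 32, 1, 16, 64
x = torch.randn(b, l, h, p, device="cuda")
B = torch.randn(b, l, g, n, device="cuda")
dt = torch.rand(b, h, l // cs, cs, device="cuda")
dA_cumsum = torch.cumsum(-dt, dim=-1)
out = chunk_state(B, x, dt, dA_cumsum)
ref = chunk_state_ref(B, x, dt, dA_cumsum)
print(torch.allclose(out, ref, atol=1e-3, rtol=1e-3))
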
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/ops/triton/ssd_combined.py DELETED
@@ -1,1884 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
-
3
- """We want triton==2.1.0 or 2.2.0 for this
4
- """
5
-
6
- from typing import Optional
7
-
8
- import math
9
- from packaging import version
10
-
11
- import torch
12
- import torch.nn.functional as F
13
- from torch import Tensor
14
- from ...utils.torch import custom_bwd, custom_fwd
15
-
16
- import triton
17
- import triton.language as tl
18
-
19
- from einops import rearrange, repeat
20
-
21
- try:
22
- from causal_conv1d import causal_conv1d_fn
23
- import causal_conv1d_cuda
24
- except ImportError:
25
- causal_conv1d_fn, causal_conv1d_cuda = None, None
26
-
27
- from .ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
28
- from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
29
- from .ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
30
- from .ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
31
- from .ssd_chunk_state import chunk_state, chunk_state_ref
32
- from .ssd_chunk_state import chunk_state_varlen
33
- from .ssd_state_passing import _state_passing_fwd, _state_passing_bwd
34
- from .ssd_state_passing import state_passing, state_passing_ref
35
- from .ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates
36
- from .ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
37
- from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
38
- from .ssd_chunk_scan import chunk_scan, chunk_scan_ref
39
- from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
40
- from .layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
41
- from .k_activations import _swiglu_fwd, _swiglu_bwd
42
-
43
- TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
44
-
45
-
46
- def init_to_zero(names):
47
- return lambda nargs: [
48
- nargs[name].zero_() for name in names if nargs[name] is not None
49
- ]
50
-
51
-
52
- @triton.autotune(
53
- configs=[
54
- triton.Config(
55
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
56
- num_stages=3,
57
- num_warps=8,
58
- pre_hook=init_to_zero(["ddt_ptr"]),
59
- ),
60
- triton.Config(
61
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
62
- num_stages=4,
63
- num_warps=4,
64
- pre_hook=init_to_zero(["ddt_ptr"]),
65
- ),
66
- triton.Config(
67
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
68
- num_stages=4,
69
- num_warps=4,
70
- pre_hook=init_to_zero(["ddt_ptr"]),
71
- ),
72
- triton.Config(
73
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
74
- num_stages=4,
75
- num_warps=4,
76
- pre_hook=init_to_zero(["ddt_ptr"]),
77
- ),
78
- triton.Config(
79
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
80
- num_stages=4,
81
- num_warps=4,
82
- pre_hook=init_to_zero(["ddt_ptr"]),
83
- ),
84
- triton.Config(
85
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
86
- num_stages=4,
87
- num_warps=4,
88
- pre_hook=init_to_zero(["ddt_ptr"]),
89
- ),
90
- triton.Config(
91
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
92
- num_stages=5,
93
- num_warps=4,
94
- pre_hook=init_to_zero(["ddt_ptr"]),
95
- ),
96
- triton.Config(
97
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
98
- num_stages=5,
99
- num_warps=4,
100
- pre_hook=init_to_zero(["ddt_ptr"]),
101
- ),
102
- triton.Config(
103
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
104
- num_stages=4,
105
- num_warps=4,
106
- pre_hook=init_to_zero(["ddt_ptr"]),
107
- ),
108
- ],
109
- key=["chunk_size", "hdim", "dstate"],
110
- )
111
- @triton.jit
112
- def _chunk_scan_chunk_state_bwd_dx_kernel(
113
- # Pointers to matrices
114
- x_ptr,
115
- cb_ptr,
116
- dout_ptr,
117
- dt_ptr,
118
- dA_cumsum_ptr,
119
- seq_idx_ptr,
120
- D_ptr,
121
- b_ptr,
122
- dstates_ptr,
123
- dx_ptr,
124
- ddt_ptr,
125
- dD_ptr,
126
- # Matrix dimensions
127
- chunk_size,
128
- hdim,
129
- dstate,
130
- batch,
131
- seqlen,
132
- nheads_ngroups_ratio,
133
- # Strides
134
- stride_x_batch,
135
- stride_x_seqlen,
136
- stride_x_head,
137
- stride_x_hdim,
138
- stride_cb_batch,
139
- stride_cb_chunk,
140
- stride_cb_head,
141
- stride_cb_csize_m,
142
- stride_cb_csize_k,
143
- stride_dout_batch,
144
- stride_dout_seqlen,
145
- stride_dout_head,
146
- stride_dout_hdim,
147
- stride_dt_batch,
148
- stride_dt_chunk,
149
- stride_dt_head,
150
- stride_dt_csize,
151
- stride_dA_cs_batch,
152
- stride_dA_cs_chunk,
153
- stride_dA_cs_head,
154
- stride_dA_cs_csize,
155
- stride_seq_idx_batch,
156
- stride_seq_idx_seqlen,
157
- stride_D_head,
158
- stride_b_batch,
159
- stride_b_seqlen,
160
- stride_b_head,
161
- stride_b_dstate,
162
- stride_dstates_batch,
163
- stride_dstates_chunk,
164
- stride_dstates_head,
165
- stride_dstates_hdim,
166
- stride_dstates_dstate,
167
- stride_dx_batch,
168
- stride_dx_seqlen,
169
- stride_dx_head,
170
- stride_dx_hdim,
171
- stride_ddt_batch,
172
- stride_ddt_chunk,
173
- stride_ddt_head,
174
- stride_ddt_csize,
175
- stride_dD_batch,
176
- stride_dD_chunk,
177
- stride_dD_head,
178
- stride_dD_csize,
179
- stride_dD_hdim,
180
- # Meta-parameters
181
- HAS_D: tl.constexpr,
182
- D_HAS_HDIM: tl.constexpr,
183
- HAS_SEQ_IDX: tl.constexpr,
184
- BLOCK_SIZE_M: tl.constexpr,
185
- BLOCK_SIZE_N: tl.constexpr,
186
- BLOCK_SIZE_K: tl.constexpr,
187
- BLOCK_SIZE_DSTATE: tl.constexpr,
188
- IS_TRITON_22: tl.constexpr,
189
- ):
190
- pid_bc = tl.program_id(axis=1)
191
- pid_c = pid_bc // batch
192
- pid_b = pid_bc - pid_c * batch
193
- pid_h = tl.program_id(axis=2)
194
- num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
195
- pid_m = tl.program_id(axis=0) // num_pid_n
196
- pid_n = tl.program_id(axis=0) % num_pid_n
197
- x_ptr += (
198
- pid_b * stride_x_batch
199
- + pid_c * chunk_size * stride_x_seqlen
200
- + pid_h * stride_x_head
201
- )
202
- cb_ptr += (
203
- pid_b * stride_cb_batch
204
- + pid_c * stride_cb_chunk
205
- + (pid_h // nheads_ngroups_ratio) * stride_cb_head
206
- )
207
- dout_ptr += (
208
- pid_b * stride_dout_batch
209
- + pid_c * chunk_size * stride_dout_seqlen
210
- + pid_h * stride_dout_head
211
- )
212
- dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
213
- ddt_ptr += (
214
- pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
215
- )
216
- dA_cumsum_ptr += (
217
- pid_b * stride_dA_cs_batch
218
- + pid_c * stride_dA_cs_chunk
219
- + pid_h * stride_dA_cs_head
220
- )
221
- b_ptr += (
222
- pid_b * stride_b_batch
223
- + pid_c * chunk_size * stride_b_seqlen
224
- + (pid_h // nheads_ngroups_ratio) * stride_b_head
225
- )
226
- dstates_ptr += (
227
- pid_b * stride_dstates_batch
228
- + pid_c * stride_dstates_chunk
229
- + pid_h * stride_dstates_head
230
- )
231
- if HAS_SEQ_IDX:
232
- seq_idx_ptr += (
233
- pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
234
- )
235
-
236
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
237
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
238
-
239
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
240
-
241
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
242
-
243
- dA_cs_m = tl.load(
244
- dA_cumsum_ptr + offs_m * stride_dA_cs_csize,
245
- mask=offs_m < chunk_size_limit,
246
- other=0.0,
247
- ).to(tl.float32)
248
-
249
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
250
- tl.float32
251
- )
252
- if not HAS_SEQ_IDX:
253
- scale = tl.exp(dA_cs_last - dA_cs_m)
254
- else:
255
- seq_idx_m = tl.load(
256
- seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
257
- mask=offs_m < chunk_size_limit,
258
- other=-1,
259
- )
260
- seq_idx_last = tl.load(
261
- seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
262
- )
263
- scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
264
- # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
265
 -    # However, we're getting an error with the Triton compiler 2.1.0 for that code path:
266
- # Unexpected mma -> mma layout conversion
267
- # Triton 2.2.0 fixes this
268
- offs_dstate = tl.arange(
269
- 0,
270
- (
271
- BLOCK_SIZE_DSTATE
272
- if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128
273
- else BLOCK_SIZE_K
274
- ),
275
- )
276
- b_ptrs = b_ptr + (
277
- offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate
278
- )
279
- dstates_ptrs = dstates_ptr + (
280
- offs_n[None, :] * stride_dstates_hdim
281
- + offs_dstate[:, None] * stride_dstates_dstate
282
- )
283
- if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:
284
- b = tl.load(
285
- b_ptrs,
286
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate),
287
- other=0.0,
288
- )
289
- dstates = tl.load(
290
- dstates_ptrs,
291
- mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
292
- other=0.0,
293
- )
294
- dstates = dstates.to(b_ptr.dtype.element_ty)
295
- acc = tl.dot(b, dstates) * scale[:, None]
296
- else:
297
- for k in range(0, dstate, BLOCK_SIZE_K):
298
- b = tl.load(
299
- b_ptrs,
300
- mask=(offs_m[:, None] < chunk_size_limit)
301
- & (offs_dstate[None, :] < dstate - k),
302
- other=0.0,
303
- )
304
- dstates = tl.load(
305
- dstates_ptrs,
306
- mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim),
307
- other=0.0,
308
- )
309
- dstates = dstates.to(b_ptr.dtype.element_ty)
310
- acc += tl.dot(b, dstates)
311
- b_ptrs += BLOCK_SIZE_K * stride_b_dstate
312
- dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate
313
- acc *= scale[:, None]
314
-
315
- # x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
316
- # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
317
- # dt_ptrs = dt_ptr + offs_m * stride_dt_csize
318
- # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
319
- # ddt = tl.sum(acc * x, axis=1) * dt_m
320
- # ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
321
- # tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
322
-
323
- offs_k = tl.arange(0, BLOCK_SIZE_K)
324
- cb_ptrs = cb_ptr + (
325
- offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k
326
- )
327
- dout_ptrs = dout_ptr + (
328
- offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim
329
- )
330
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
331
- K_MAX = chunk_size_limit
332
- K_MIN = pid_m * BLOCK_SIZE_M
333
- cb_ptrs += K_MIN * stride_cb_csize_k
334
- dout_ptrs += K_MIN * stride_dout_seqlen
335
- dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize
336
- for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):
337
- k = tl.multiple_of(k, BLOCK_SIZE_K)
338
- # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower
339
- cb = tl.load(
340
- cb_ptrs,
341
- mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k),
342
- other=0.0,
343
- )
344
- dout = tl.load(
345
- dout_ptrs,
346
- mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim),
347
- other=0.0,
348
- )
349
- dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(
350
- tl.float32
351
- )
352
- cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
353
- # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
354
- # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
355
- # Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
356
- # This will cause NaN in acc, and hence NaN in dx and ddt.
357
- mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)
358
- cb = tl.where(mask, cb, 0.0)
359
- cb = cb.to(dout_ptr.dtype.element_ty)
360
- acc += tl.dot(cb, dout)
361
- cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
362
- dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
363
- dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
364
-
365
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
366
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
367
- dt_ptrs = dt_ptr + offs_m * stride_dt_csize
368
- dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
369
- dx = acc * dt_m[:, None]
370
- dx_ptr += (
371
- pid_b * stride_dx_batch
372
- + pid_c * chunk_size * stride_dx_seqlen
373
- + pid_h * stride_dx_head
374
- )
375
- dx_ptrs = dx_ptr + (
376
- offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim
377
- )
378
- if HAS_D:
379
- dout_res_ptrs = dout_ptr + (
380
- offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim
381
- )
382
- dout_res = tl.load(
383
- dout_res_ptrs,
384
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
385
- other=0.0,
386
- ).to(tl.float32)
387
- if D_HAS_HDIM:
388
- D = tl.load(
389
- D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0
390
- ).to(tl.float32)
391
- else:
392
- D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
393
- dx += dout_res * D
394
- tl.store(
395
- dx_ptrs,
396
- dx,
397
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
398
- )
399
-
400
- x_ptrs = x_ptr + (
401
- offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
402
- )
403
- x = tl.load(
404
- x_ptrs,
405
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
406
- other=0.0,
407
- ).to(tl.float32)
408
- if HAS_D:
409
- dD_ptr += (
410
- pid_b * stride_dD_batch
411
- + pid_c * stride_dD_chunk
412
- + pid_h * stride_dD_head
413
- + pid_m * stride_dD_csize
414
- )
415
- if D_HAS_HDIM:
416
- dD_ptrs = dD_ptr + offs_n * stride_dD_hdim
417
- dD = tl.sum(dout_res * x, axis=0)
418
- tl.store(dD_ptrs, dD, mask=offs_n < hdim)
419
- else:
420
- dD = tl.sum(dout_res * x)
421
- tl.store(dD_ptr, dD)
422
- ddt = tl.sum(acc * x, axis=1)
423
- ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
424
- tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
425
-
426
-
427
- def _chunk_scan_chunk_state_bwd_dx(
428
- x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None
429
- ):
430
- batch, seqlen, nheads, headdim = x.shape
431
- _, _, nchunks, chunk_size = dt.shape
432
- _, _, ngroups, dstate = B.shape
433
- assert nheads % ngroups == 0
434
- assert B.shape == (batch, seqlen, ngroups, dstate)
435
- assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
436
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
437
- assert dA_cumsum.shape == dt.shape
438
- assert dout.shape == x.shape
439
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
440
- if seq_idx is not None:
441
- assert seq_idx.shape == (batch, seqlen)
442
- if D is not None:
443
- assert D.shape == (nheads, headdim) or D.shape == (nheads,)
444
- assert D.stride(-1) == 1
445
- BLOCK_SIZE_min = 32
446
- dD = torch.empty(
447
- triton.cdiv(chunk_size, BLOCK_SIZE_min),
448
- batch,
449
- nchunks,
450
- nheads,
451
- headdim if D.dim() == 2 else 1,
452
- device=D.device,
453
- dtype=torch.float32,
454
- )
455
- else:
456
- dD = None
457
- dD_strides = (
458
- (dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
459
- if D is not None
460
- else (0, 0, 0, 0, 0)
461
- )
462
- if dx is None:
463
- dx = torch.empty_like(x)
464
- else:
465
- assert dx.shape == x.shape
466
- ddt = torch.empty(
467
- batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32
468
- )
469
- grid_dx = lambda META: (
470
- triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
471
- * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
472
- batch * nchunks,
473
- nheads,
474
- )
475
- with torch.cuda.device(x.device.index):
476
- _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](
477
- x,
478
- CB,
479
- dout,
480
- dt,
481
- dA_cumsum,
482
- seq_idx,
483
- D,
484
- B,
485
- dstates,
486
- dx,
487
- ddt,
488
- dD,
489
- chunk_size,
490
- headdim,
491
- dstate,
492
- batch,
493
- seqlen,
494
- nheads // ngroups,
495
- x.stride(0),
496
- x.stride(1),
497
- x.stride(2),
498
- x.stride(3),
499
- CB.stride(0),
500
- CB.stride(1),
501
- CB.stride(2),
502
- CB.stride(-1),
503
- CB.stride(-2),
504
- dout.stride(0),
505
- dout.stride(1),
506
- dout.stride(2),
507
- dout.stride(3),
508
- dt.stride(0),
509
- dt.stride(2),
510
- dt.stride(1),
511
- dt.stride(3),
512
- dA_cumsum.stride(0),
513
- dA_cumsum.stride(2),
514
- dA_cumsum.stride(1),
515
- dA_cumsum.stride(3),
516
- *(
517
- (seq_idx.stride(0), seq_idx.stride(1))
518
- if seq_idx is not None
519
- else (0, 0)
520
- ),
521
- D.stride(0) if D is not None else 0,
522
- B.stride(0),
523
- B.stride(1),
524
- B.stride(2),
525
- B.stride(3),
526
- dstates.stride(0),
527
- dstates.stride(1),
528
- dstates.stride(2),
529
- dstates.stride(3),
530
- dstates.stride(4),
531
- dx.stride(0),
532
- dx.stride(1),
533
- dx.stride(2),
534
- dx.stride(3),
535
- ddt.stride(0),
536
- ddt.stride(2),
537
- ddt.stride(1),
538
- ddt.stride(3),
539
- dD_strides[1],
540
- dD_strides[2],
541
- dD_strides[3],
542
- dD_strides[0],
543
- dD_strides[4],
544
- D is not None,
545
- D.dim() == 2 if D is not None else True,
546
- HAS_SEQ_IDX=seq_idx is not None,
547
- BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
548
- IS_TRITON_22=TRITON_22
549
- )
550
- if D is not None:
551
- BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs[
552
- "BLOCK_SIZE_M"
553
- ]
554
- n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
555
- dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
556
- if D.dim() == 1:
557
- dD = rearrange(dD, "h 1 -> h")
558
- return dx, ddt.to(dtype=dt.dtype), dD
559
-
560
-
561
- def _mamba_chunk_scan_combined_fwd(
562
- x,
563
- dt,
564
- A,
565
- B,
566
- C,
567
- chunk_size,
568
- D=None,
569
- z=None,
570
- dt_bias=None,
571
- initial_states=None,
572
- seq_idx=None,
573
- cu_seqlens=None,
574
- dt_softplus=False,
575
- dt_limit=(0.0, float("inf")),
576
- ):
577
- batch, seqlen, nheads, headdim = x.shape
578
- _, _, ngroups, dstate = B.shape
579
- assert nheads % ngroups == 0
580
- assert B.shape == (batch, seqlen, ngroups, dstate)
581
- assert x.shape == (batch, seqlen, nheads, headdim)
582
- assert dt.shape == (batch, seqlen, nheads)
583
- assert A.shape == (nheads,)
584
- assert C.shape == B.shape
585
- if z is not None:
586
- assert z.shape == x.shape
587
- if D is not None:
588
- assert D.shape == (nheads, headdim) or D.shape == (nheads,)
589
- if seq_idx is not None:
590
- assert seq_idx.shape == (batch, seqlen)
591
- if B.stride(-1) != 1:
592
- B = B.contiguous()
593
- if C.stride(-1) != 1:
594
- C = C.contiguous()
595
- if (
596
- x.stride(-1) != 1 and x.stride(1) != 1
597
- ): # Either M or K dimension should be contiguous
598
- x = x.contiguous()
599
- if (
600
- z is not None and z.stride(-1) != 1 and z.stride(1) != 1
601
- ): # Either M or K dimension should be contiguous
602
- z = z.contiguous()
603
- if D is not None and D.stride(-1) != 1:
604
- D = D.contiguous()
605
- if initial_states is not None:
606
- assert initial_states.shape == (batch, nheads, headdim, dstate)
607
- # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
608
- # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
609
- # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
610
- # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
611
- dA_cumsum, dt = _chunk_cumsum_fwd(
612
- dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit
613
- )
614
- states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
615
- # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
616
- # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
617
- # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)
618
- states, final_states = _state_passing_fwd(
619
- rearrange(states, "... p n -> ... (p n)"),
620
- dA_cumsum[:, :, :, -1],
621
- initial_states=(
622
- rearrange(initial_states, "... p n -> ... (p n)")
623
- if initial_states is not None
624
- else None
625
- ),
626
- seq_idx=seq_idx,
627
- chunk_size=chunk_size,
628
- out_dtype=C.dtype,
629
- )
630
- states, final_states = [
631
- rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]
632
- ]
633
- # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
634
- # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
635
- CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
636
- out, out_x = _chunk_scan_fwd(
637
- CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx
638
- )
639
- if cu_seqlens is None:
640
- return out, out_x, dt, dA_cumsum, states, final_states
641
- else:
642
- assert (
643
- batch == 1
644
- ), "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
645
- varlen_states = chunk_state_varlen(
646
- B.squeeze(0),
647
- x.squeeze(0),
648
- dt.squeeze(0),
649
- dA_cumsum.squeeze(0),
650
- cu_seqlens,
651
- states.squeeze(0),
652
- )
653
- return out, out_x, dt, dA_cumsum, states, final_states, varlen_states
654
-
655
-
656
- def _mamba_chunk_scan_combined_bwd(
657
- dout,
658
- x,
659
- dt,
660
- A,
661
- B,
662
- C,
663
- out,
664
- chunk_size,
665
- D=None,
666
- z=None,
667
- dt_bias=None,
668
- initial_states=None,
669
- dfinal_states=None,
670
- seq_idx=None,
671
- dt_softplus=False,
672
- dt_limit=(0.0, float("inf")),
673
- dx=None,
674
- ddt=None,
675
- dB=None,
676
- dC=None,
677
- dz=None,
678
- recompute_output=False,
679
- ):
680
- if dout.stride(-1) != 1:
681
- dout = dout.contiguous()
682
- batch, seqlen, nheads, headdim = x.shape
683
- nchunks = math.ceil(seqlen / chunk_size)
684
- _, _, ngroups, dstate = B.shape
685
- assert dout.shape == (batch, seqlen, nheads, headdim)
686
- assert dt.shape == (batch, seqlen, nheads)
687
- assert A.shape == (nheads,)
688
- assert nheads % ngroups == 0
689
- assert B.shape == (batch, seqlen, ngroups, dstate)
690
- assert C.shape == B.shape
691
- assert out.shape == x.shape
692
- if initial_states is not None:
693
- assert initial_states.shape == (batch, nheads, headdim, dstate)
694
- if seq_idx is not None:
695
- assert seq_idx.shape == (batch, seqlen)
696
- if dx is not None:
697
- assert dx.shape == x.shape
698
- if dB is not None:
699
- assert dB.shape == B.shape
700
- dB_given = dB
701
- else:
702
- dB_given = torch.empty_like(B)
703
- if dC is not None:
704
- assert dC.shape == C.shape
705
- dC_given = dC
706
- else:
707
- dC_given = torch.empty_like(C)
708
- if dz is not None:
709
- assert z is not None
710
- assert dz.shape == z.shape
711
- if ddt is not None:
712
- assert ddt.shape == dt.shape
713
- ddt_given = ddt
714
- else:
715
- ddt_given = torch.empty_like(dt)
716
- # TD: For some reason Triton (2.1.0 and 2.2.0) errors with
717
- # "[CUDA]: invalid device context" (e.g. during varlne test), and cloning makes it work. Idk why.
718
- dt_in = dt.clone()
719
- dA_cumsum, dt = _chunk_cumsum_fwd(
720
- dt_in,
721
- A,
722
- chunk_size,
723
- dt_bias=dt_bias,
724
- dt_softplus=dt_softplus,
725
- dt_limit=dt_limit,
726
- )
727
- CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
728
- states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
729
- states, _ = _state_passing_fwd(
730
- rearrange(states, "... p n -> ... (p n)"),
731
- dA_cumsum[:, :, :, -1],
732
- initial_states=(
733
- rearrange(initial_states, "... p n -> ... (p n)")
734
- if initial_states is not None
735
- else None
736
- ),
737
- seq_idx=seq_idx,
738
- chunk_size=chunk_size,
739
- )
740
- states = rearrange(states, "... (p n) -> ... p n", n=dstate)
741
- if z is not None:
742
- dz, dout, dD, *rest = _chunk_scan_bwd_dz(
743
- x,
744
- z,
745
- out,
746
- dout,
747
- chunk_size=chunk_size,
748
- has_ddAcs=False,
749
- D=D,
750
- dz=dz,
751
- recompute_output=recompute_output,
752
- )
753
- outz = rest[0] if recompute_output else out
754
- else:
755
- dz = None
756
- outz = out
757
- dstates = _chunk_scan_bwd_dstates(
758
- C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype
759
- )
760
- # dstates has length nchunks, containing the gradient to initial states at index 0 and
761
- # gradient to the states of chunk (nchunks - 2) at index (nchunks - 1)
762
- # Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states
763
- # will be used in matmul in the next kernels.
764
- dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd(
765
- rearrange(states, "... p n -> ... (p n)"),
766
- dA_cumsum[:, :, :, -1],
767
- rearrange(dstates, "... p n -> ... (p n)"),
768
- dfinal_states=(
769
- rearrange(dfinal_states, "... p n -> ... (p n)")
770
- if dfinal_states is not None
771
- else None
772
- ),
773
- seq_idx=seq_idx,
774
- has_initial_states=initial_states is not None,
775
- dstates_dtype=x.dtype,
776
- states_dtype=x.dtype,
777
- chunk_size=chunk_size,
778
- )
779
- # dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and
780
- # gradient to the final states at index (nchunks - 1)
781
- # states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1)
782
- # The final states is not stored.
783
- states = rearrange(states, "... (p n) -> ... p n", n=dstate)
784
- dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate)
785
- dinitial_states = (
786
- rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate)
787
- if dinitial_states is not None
788
- else None
789
- )
790
- dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(
791
- x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx
792
- )
793
- # dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups)
794
- dB, ddA_next = _chunk_state_bwd_db(
795
- x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups
796
- )
797
- # dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
798
- dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(
799
- states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups
800
- )
801
- # Computing ddA with the dcb kernel is much slower, so we're not using it for now
802
- dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
803
- # dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups)
804
- dCB = dCB.to(CB.dtype)
805
- _bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given)
806
- _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given)
807
- # If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate
808
- # than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16
809
- if z is None:
810
- dD = dD_from_x
811
- # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
812
- # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
813
- # However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might
814
- # be a lot of underflow.
815
-
816
- # This is already done as part of bwd_dC kernel
817
- # ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx)
818
- ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum
819
- ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1])
820
- # This is already done as part of bwd_dB kernel
821
- # ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx)
822
- # We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j]
823
- ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB)
824
- ddA += ddA_next + ddA_prev
825
-
826
- ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(
827
- ddA,
828
- ddt,
829
- dt_in,
830
- A,
831
- dt_bias=dt_bias,
832
- dt_softplus=dt_softplus,
833
- dt_limit=dt_limit,
834
- ddt=ddt_given,
835
- )
836
-
837
- # These 2 lines are just to test ddt and dA being computed by old code
838
- # _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z)
839
- # ddt_given.copy_(ddt)
840
-
841
- return_vals = (
842
- dx,
843
- ddt_given,
844
- dA,
845
- dB_given,
846
- dC_given,
847
- dD,
848
- dz,
849
- ddt_bias,
850
- dinitial_states,
851
- )
852
- return return_vals if not recompute_output else (*return_vals, outz)
853
-
854
-
855
- def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None):
856
- """
857
- Argument:
858
- dout: (batch, seqlen, nheads, headdim)
859
- x: (batch, seqlen, nheads, headdim)
860
- dt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size)
861
- A: (nheads) or (dim, dstate)
862
- B: (batch, seqlen, ngroups, dstate)
863
- C: (batch, seqlen, ngroups, dstate)
864
- D: (nheads, headdim) or (nheads,)
865
- z: (batch, seqlen, nheads, headdim)
866
- Return:
867
- out: (batch, seqlen, nheads, headdim)
868
- """
869
- import selective_scan
870
-
871
- batch, seqlen, nheads, headdim = x.shape
872
- chunk_size = dt.shape[-1]
873
- _, _, ngroups, dstate = B.shape
874
- assert nheads % ngroups == 0
875
- x = rearrange(x, "b l h p -> b (h p) l")
876
- squeeze_dt = dt.dim() == 4
877
- if dt.dim() == 4:
878
- dt = repeat(dt, "b h c l -> b h p c l", p=headdim)
879
- dt = rearrange(dt, "b h p c l -> b (h p) (c l)", p=headdim)
880
- squeeze_A = A.dim() == 1
881
- if A.dim() == 1:
882
- A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
883
- else:
884
- A = A.to(dtype=torch.float32)
885
- B = rearrange(B, "b l g n -> b g n l")
886
- C = rearrange(C, "b l g n -> b g n l")
887
- if D is not None:
888
- if D.dim() == 2:
889
- D = rearrange(D, "h p -> (h p)")
890
- else:
891
- D = repeat(D, "h -> (h p)", p=headdim)
892
- if z is not None:
893
- z = rearrange(z, "b l h p -> b (h p) l")
894
-
895
- if x.stride(-1) != 1:
896
- x = x.contiguous()
897
- if dt.stride(-1) != 1:
898
- dt = dt.contiguous()
899
- if D is not None:
900
- D = D.contiguous()
901
- if B.stride(-1) != 1:
902
- B = B.contiguous()
903
- if C.stride(-1) != 1:
904
- C = C.contiguous()
905
- if z is not None and z.stride(-1) != 1:
906
- z = z.contiguous()
907
- _, intermediate, *rest = selective_scan.fwd(
908
- x, dt.to(dtype=x.dtype), A, B, C, D, z, None, False
909
- )
910
- if z is not None:
911
- out = rest[0]
912
- else:
913
- out = None
914
-
915
- dout = rearrange(dout, "b l h p -> b (h p) l")
916
-
917
- if dout.stride(-1) != 1:
918
- dout = dout.contiguous()
919
- # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
920
- # backward of selective_scan with the backward of chunk).
921
- # Here we just pass in None and dz will be allocated in the C++ code.
922
- _, ddt, dA, *rest = selective_scan.bwd(
923
- x,
924
- dt.to(dtype=x.dtype),
925
- A,
926
- B,
927
- C,
928
- D,
929
- z,
930
- None,
931
- dout,
932
- intermediate,
933
- out,
934
- None,
935
- False,
936
- False, # option to recompute out_z, not used here
937
- )
938
- ddt = rearrange(ddt, "b (h p) (c l) -> b h p c l", p=headdim, l=chunk_size)
939
- if squeeze_dt:
940
- ddt = ddt.float().sum(dim=2)
941
- if squeeze_A:
942
- dA = rearrange(dA, "(h p) n -> h p n", p=headdim).sum(dim=(1, 2))
943
- return ddt, dA
944
-
945
-
946
- class MambaChunkScanCombinedFn(torch.autograd.Function):
947
-
948
- @staticmethod
949
- def forward(
950
- ctx,
951
- x,
952
- dt,
953
- A,
954
- B,
955
- C,
956
- chunk_size,
957
- D=None,
958
- z=None,
959
- dt_bias=None,
960
- initial_states=None,
961
- seq_idx=None,
962
- cu_seqlens=None,
963
- dt_softplus=False,
964
- dt_limit=(0.0, float("inf")),
965
- return_final_states=False,
966
- return_varlen_states=False,
967
- ):
968
- ctx.dt_dtype = dt.dtype
969
- if not return_varlen_states:
970
- cu_seqlens = None
971
- else:
972
- assert (
973
- cu_seqlens is not None
974
- ), "cu_seqlens must be provided if return_varlen_states is True"
975
- out, out_x, dt_out, dA_cumsum, states, final_states, *rest = (
976
- _mamba_chunk_scan_combined_fwd(
977
- x,
978
- dt,
979
- A,
980
- B,
981
- C,
982
- chunk_size,
983
- D=D,
984
- z=z,
985
- dt_bias=dt_bias,
986
- initial_states=initial_states,
987
- seq_idx=seq_idx,
988
- cu_seqlens=cu_seqlens,
989
- dt_softplus=dt_softplus,
990
- dt_limit=dt_limit,
991
- )
992
- )
993
- ctx.save_for_backward(
994
- out if z is None else out_x,
995
- x,
996
- dt,
997
- dA_cumsum,
998
- A,
999
- B,
1000
- C,
1001
- D,
1002
- z,
1003
- dt_bias,
1004
- initial_states,
1005
- seq_idx,
1006
- )
1007
- ctx.dt_softplus = dt_softplus
1008
- ctx.chunk_size = chunk_size
1009
- ctx.dt_limit = dt_limit
1010
- ctx.return_final_states = return_final_states
1011
- ctx.return_varlen_states = return_varlen_states
1012
- if not return_varlen_states:
1013
- return out if not return_final_states else (out, final_states)
1014
- else:
1015
- varlen_states = rest[0]
1016
- return (
1017
- (out, varlen_states)
1018
- if not return_final_states
1019
- else (out, final_states, varlen_states)
1020
- )
1021
-
1022
- @staticmethod
1023
- def backward(ctx, dout, *args):
1024
- out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = (
1025
- ctx.saved_tensors
1026
- )
1027
- assert (
1028
- not ctx.return_varlen_states
1029
- ), "return_varlen_states is not supported in backward"
1030
- dfinal_states = args[0] if ctx.return_final_states else None
1031
- dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = (
1032
- _mamba_chunk_scan_combined_bwd(
1033
- dout,
1034
- x,
1035
- dt,
1036
- A,
1037
- B,
1038
- C,
1039
- out,
1040
- ctx.chunk_size,
1041
- D=D,
1042
- z=z,
1043
- dt_bias=dt_bias,
1044
- initial_states=initial_states,
1045
- dfinal_states=dfinal_states,
1046
- seq_idx=seq_idx,
1047
- dt_softplus=ctx.dt_softplus,
1048
- dt_limit=ctx.dt_limit,
1049
- )
1050
- )
1051
- return (
1052
- dx,
1053
- ddt,
1054
- dA,
1055
- dB,
1056
- dC,
1057
- None,
1058
- dD,
1059
- dz,
1060
- ddt_bias,
1061
- dinitial_states,
1062
- None,
1063
- None,
1064
- None,
1065
- None,
1066
- None,
1067
- None,
1068
- )
1069
-
1070
-
1071
- def mamba_chunk_scan_combined(
1072
- x,
1073
- dt,
1074
- A,
1075
- B,
1076
- C,
1077
- chunk_size,
1078
- D=None,
1079
- z=None,
1080
- dt_bias=None,
1081
- initial_states=None,
1082
- seq_idx=None,
1083
- cu_seqlens=None,
1084
- dt_softplus=False,
1085
- dt_limit=(0.0, float("inf")),
1086
- return_final_states=False,
1087
- return_varlen_states=False,
1088
- ):
1089
- """
1090
- Argument:
1091
- x: (batch, seqlen, nheads, headdim)
1092
- dt: (batch, seqlen, nheads)
1093
- A: (nheads)
1094
- B: (batch, seqlen, ngroups, dstate)
1095
- C: (batch, seqlen, ngroups, dstate)
1096
- chunk_size: int
1097
- D: (nheads, headdim) or (nheads,)
1098
- z: (batch, seqlen, nheads, headdim)
1099
- dt_bias: (nheads,)
1100
- initial_states: (batch, nheads, headdim, dstate)
1101
- seq_idx: (batch, seqlen)
1102
- cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
1103
- dt_softplus: Whether to apply softplus to dt
1104
- Return:
1105
- out: (batch, seqlen, nheads, headdim)
1106
- """
1107
- return MambaChunkScanCombinedFn.apply(
1108
- x,
1109
- dt,
1110
- A,
1111
- B,
1112
- C,
1113
- chunk_size,
1114
- D,
1115
- z,
1116
- dt_bias,
1117
- initial_states,
1118
- seq_idx,
1119
- cu_seqlens,
1120
- dt_softplus,
1121
- dt_limit,
1122
- return_final_states,
1123
- return_varlen_states,
1124
- )
1125
-
1126
-
1127
- def mamba_chunk_scan(
1128
- x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False
1129
- ):
1130
- """
1131
- Argument:
1132
- x: (batch, seqlen, nheads, headdim)
1133
- dt: (batch, seqlen, nheads)
1134
- A: (nheads)
1135
- B: (batch, seqlen, ngroups, dstate)
1136
- C: (batch, seqlen, ngroups, dstate)
1137
- D: (nheads, headdim) or (nheads,)
1138
- z: (batch, seqlen, nheads, headdim)
1139
- dt_bias: (nheads,)
1140
- Return:
1141
- out: (batch, seqlen, nheads, headdim)
1142
- """
1143
- batch, seqlen, nheads, headdim = x.shape
1144
- dstate = B.shape[-1]
1145
- if seqlen % chunk_size != 0:
1146
- dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
1147
- dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
1148
- dt = dt.float() # We want high precision for this before cumsum
1149
- if dt_bias is not None:
1150
- dt = dt + rearrange(dt_bias, "h -> h 1 1")
1151
- if dt_softplus:
1152
- dt = F.softplus(dt)
1153
- dA = dt * rearrange(A, "h -> h 1 1")
1154
- dA = dt * rearrange(A, "h -> h 1 1")
1155
- dA_cumsum = torch.cumsum(dA, dim=-1)
1156
- # 1. Compute the state for each chunk
1157
- states = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True)
1158
- # 2. Pass the state to all the chunks by weighted cumsum.
1159
- states = rearrange(
1160
- state_passing(
1161
- rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]
1162
- )[0],
1163
- "... (p n) -> ... p n",
1164
- n=dstate,
1165
- )
1166
- # 3. Compute the output for each chunk
1167
- out = chunk_scan(B, C, x, dt, dA_cumsum, states, D=D, z=z)
1168
- return out
1169
-
1170
-
1171
- def ssd_chunk_scan_combined_ref(
1172
- x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False
1173
- ):
1174
- """
1175
- Argument:
1176
- x: (batch, seqlen, nheads, headdim)
1177
- dt: (batch, seqlen, nheads)
1178
- A: (nheads)
1179
- B: (batch, seqlen, ngroups, dstate)
1180
- C: (batch, seqlen, ngroups, dstate)
1181
- D: (nheads, headdim) or (nheads,)
1182
- z: (batch, seqlen, nheads, headdim)
1183
- dt_bias: (nheads,)
1184
- Return:
1185
- out: (batch, seqlen, nheads, headdim)
1186
- """
1187
- batch, seqlen, nheads, headdim = x.shape
1188
- dstate = B.shape[-1]
1189
- if seqlen % chunk_size != 0:
1190
- dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
1191
- dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
1192
- dt = dt.float() # We want high precision for this before cumsum
1193
- if dt_bias is not None:
1194
- dt = dt + rearrange(dt_bias, "h -> h 1 1")
1195
- if dt_softplus:
1196
- dt = F.softplus(dt)
1197
- dA = dt * rearrange(A, "h -> h 1 1")
1198
- dA_cumsum = torch.cumsum(dA, dim=-1)
1199
- # 1. Compute the state for each chunk
1200
- states = chunk_state_ref(B, x, dt, dA_cumsum)
1201
- states_dtype = states.dtype
1202
- if states.dtype not in [torch.float32, torch.float64]:
1203
- states = states.to(torch.float32)
1204
- # 2. Pass the state to all the chunks by weighted cumsum.
1205
- # state_passing_ref is much less numerically stable
1206
- states = rearrange(
1207
- state_passing_ref(
1208
- rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]
1209
- )[0],
1210
- "... (p n) -> ... p n",
1211
- n=dstate,
1212
- )
1213
- states = states.to(states_dtype)
1214
- # 3. Compute the output for each chunk
1215
- out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)
1216
- return out
1217
-
1218
-
1219
- def ssd_selective_scan(
1220
- x,
1221
- dt,
1222
- A,
1223
- B,
1224
- C,
1225
- D=None,
1226
- z=None,
1227
- dt_bias=None,
1228
- dt_softplus=False,
1229
- dt_limit=(0.0, float("inf")),
1230
- ):
1231
- """
1232
- Argument:
1233
- x: (batch, seqlen, nheads, headdim)
1234
- dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
1235
- A: (nheads) or (dim, dstate)
1236
- B: (batch, seqlen, ngroups, dstate)
1237
- C: (batch, seqlen, ngroups, dstate)
1238
- D: (nheads, headdim) or (nheads,)
1239
- z: (batch, seqlen, nheads, headdim)
1240
- dt_bias: (nheads,) or (nheads, headdim)
1241
- Return:
1242
- out: (batch, seqlen, nheads, headdim)
1243
- """
1244
- from ..selective_scan_interface import selective_scan_fn
1245
-
1246
- batch, seqlen, nheads, headdim = x.shape
1247
- _, _, ngroups, dstate = B.shape
1248
- x = rearrange(x, "b l h p -> b (h p) l")
1249
- if dt.dim() == 3:
1250
- dt = repeat(dt, "b l h -> b l h p", p=headdim)
1251
- dt = rearrange(dt, "b l h p -> b (h p) l")
1252
- if A.dim() == 1:
1253
- A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
1254
- else:
1255
- A = A.to(dtype=torch.float32)
1256
- B = rearrange(B, "b l g n -> b g n l")
1257
- C = rearrange(C, "b l g n -> b g n l")
1258
- if D is not None:
1259
- if D.dim() == 2:
1260
- D = rearrange(D, "h p -> (h p)")
1261
- else:
1262
- D = repeat(D, "h -> (h p)", p=headdim)
1263
- if z is not None:
1264
- z = rearrange(z, "b l h p -> b (h p) l")
1265
- if dt_bias is not None:
1266
- if dt_bias.dim() == 1:
1267
- dt_bias = repeat(dt_bias, "h -> h p", p=headdim)
1268
- dt_bias = rearrange(dt_bias, "h p -> (h p)")
1269
- if dt_limit != (0.0, float("inf")):
1270
- if dt_bias is not None:
1271
- dt = dt + rearrange(dt_bias, "d -> d 1")
1272
- if dt_softplus:
1273
- dt = F.softplus(dt)
1274
- dt = dt.clamp(min=dt_limit[0], max=dt_limit[1]).to(x.dtype)
1275
- dt_bias = None
1276
- dt_softplus = None
1277
- out = selective_scan_fn(
1278
- x, dt, A, B, C, D=D, z=z, delta_bias=dt_bias, delta_softplus=dt_softplus
1279
- )
1280
- return rearrange(out, "b (h p) l -> b l h p", p=headdim)
1281
-
1282
-
1283
- def mamba_conv1d_scan_ref(
1284
- xBC,
1285
- conv1d_weight,
1286
- conv1d_bias,
1287
- dt,
1288
- A,
1289
- chunk_size,
1290
- D=None,
1291
- z=None,
1292
- dt_bias=None,
1293
- dt_softplus=False,
1294
- dt_limit=(0.0, float("inf")),
1295
- activation="silu",
1296
- headdim=None,
1297
- ngroups=1,
1298
- ):
1299
- """
1300
- Argument:
1301
- xBC: (batch, seqlen, dim + 2 * ngroups * dstate) where dim == nheads * headdim
1302
- conv1d_weight: (dim + 2 * ngroups * dstate, width)
1303
- conv1d_bias: (dim + 2 * ngroups * dstate,)
1304
- dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
1305
- A: (nheads)
1306
- D: (nheads, headdim) or (nheads,)
1307
- z: (batch, seqlen, dim)
1308
- dt_bias: (nheads) or (nheads, headdim)
1309
- headdim: if D is 1D and z is None, headdim must be passed in
1310
- Return:
1311
- out: (batch, seqlen, dim)
1312
- """
1313
- batch, seqlen, nheads = dt.shape[:3]
1314
- assert nheads % ngroups == 0
1315
- if z is not None:
1316
- dim = z.shape[-1]
1317
- assert dim % nheads == 0
1318
- headdim = dim // nheads
1319
- else:
1320
- if D.dim() == 1:
1321
- assert headdim is not None
1322
- else:
1323
- headdim = D.shape[1]
1324
- dim = nheads * headdim
1325
- xBC = rearrange(
1326
- causal_conv1d_fn(
1327
- rearrange(xBC, "b s d -> b d s"),
1328
- conv1d_weight,
1329
- conv1d_bias,
1330
- activation=activation,
1331
- ),
1332
- "b d s -> b s d",
1333
- )
1334
- dstate = (xBC.shape[-1] - dim) // ngroups // 2
1335
- x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
1336
- x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
1337
- B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
1338
- C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
1339
- z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
1340
- out = ssd_selective_scan(
1341
- x,
1342
- dt.to(x.dtype),
1343
- A,
1344
- B,
1345
- C,
1346
- D=D.float(),
1347
- z=z,
1348
- dt_bias=dt_bias,
1349
- dt_softplus=dt_softplus,
1350
- dt_limit=dt_limit,
1351
- )
1352
- return rearrange(out, "b s h p -> b s (h p)")
1353
-
1354
-
1355
- class MambaSplitConv1dScanCombinedFn(torch.autograd.Function):
1356
-
1357
- @staticmethod
1358
- @custom_fwd
1359
- def forward(
1360
- ctx,
1361
- zxbcdt,
1362
- conv1d_weight,
1363
- conv1d_bias,
1364
- dt_bias,
1365
- A,
1366
- D,
1367
- chunk_size,
1368
- initial_states=None,
1369
- seq_idx=None,
1370
- dt_limit=(0.0, float("inf")),
1371
- return_final_states=False,
1372
- activation="silu",
1373
- rmsnorm_weight=None,
1374
- rmsnorm_eps=1e-6,
1375
- outproj_weight=None,
1376
- outproj_bias=None,
1377
- headdim=None,
1378
- ngroups=1,
1379
- norm_before_gate=True,
1380
- ):
1381
- assert activation in [None, "silu", "swish"]
1382
- if D.dim() == 1:
1383
- assert headdim is not None
1384
- (nheads,) = D.shape
1385
- else:
1386
- nheads, headdim = D.shape
1387
- batch, seqlen, _ = zxbcdt.shape
1388
- dim = nheads * headdim
1389
- assert nheads % ngroups == 0
1390
- dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2
1391
- d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2
1392
- assert d_nonssm >= 0
1393
- assert zxbcdt.shape == (
1394
- batch,
1395
- seqlen,
1396
- 2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads,
1397
- )
1398
- assert dt_bias.shape == (nheads,)
1399
- assert A.shape == (nheads,)
1400
- zx0, z, xBC, dt = torch.split(
1401
- zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1
1402
- )
1403
- seq_idx = seq_idx.contiguous() if seq_idx is not None else None
1404
- xBC_conv = rearrange(
1405
- causal_conv1d_cuda.causal_conv1d_fwd(
1406
- rearrange(xBC, "b s d -> b d s"),
1407
- conv1d_weight,
1408
- conv1d_bias,
1409
- seq_idx,
1410
- None,
1411
- None,
1412
- activation in ["silu", "swish"],
1413
- ),
1414
- "b d s -> b s d",
1415
- )
1416
- x, B, C = torch.split(
1417
- xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1
1418
- )
1419
- x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
1420
- B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
1421
- C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
1422
- z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
1423
- if rmsnorm_weight is None:
1424
- out, out_x, dt_out, dA_cumsum, states, final_states = (
1425
- _mamba_chunk_scan_combined_fwd(
1426
- x,
1427
- dt,
1428
- A,
1429
- B,
1430
- C,
1431
- chunk_size=chunk_size,
1432
- D=D,
1433
- z=z,
1434
- dt_bias=dt_bias,
1435
- initial_states=initial_states,
1436
- seq_idx=seq_idx,
1437
- dt_softplus=True,
1438
- dt_limit=dt_limit,
1439
- )
1440
- )
1441
- out = rearrange(out, "b s h p -> b s (h p)")
1442
- rstd = None
1443
- if d_nonssm > 0:
1444
- out = torch.cat([_swiglu_fwd(zx0), out], dim=-1)
1445
- else:
1446
- out_x, _, dt_out, dA_cumsum, states, final_states = (
1447
- _mamba_chunk_scan_combined_fwd(
1448
- x,
1449
- dt,
1450
- A,
1451
- B,
1452
- C,
1453
- chunk_size=chunk_size,
1454
- D=D,
1455
- z=None,
1456
- dt_bias=dt_bias,
1457
- initial_states=initial_states,
1458
- seq_idx=seq_idx,
1459
- dt_softplus=True,
1460
- dt_limit=dt_limit,
1461
- )
1462
- )
1463
- # reshape input data into 2D tensor
1464
- x_rms = rearrange(out_x, "b s h p -> (b s) (h p)")
1465
- z_rms = rearrange(z, "b s h p -> (b s) (h p)")
1466
- rmsnorm_weight = rmsnorm_weight.contiguous()
1467
- if d_nonssm == 0:
1468
- out = None
1469
- else:
1470
- out01 = torch.empty(
1471
- (batch, seqlen, d_nonssm + dim),
1472
- dtype=x_rms.dtype,
1473
- device=x_rms.device,
1474
- )
1475
- out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d")
1476
- _swiglu_fwd(zx0, out=out01[..., :d_nonssm])
1477
- out, _, rstd = _layer_norm_fwd(
1478
- x_rms,
1479
- rmsnorm_weight,
1480
- None,
1481
- rmsnorm_eps,
1482
- z_rms,
1483
- out=out,
1484
- group_size=dim // ngroups,
1485
- norm_before_gate=norm_before_gate,
1486
- is_rms_norm=True,
1487
- )
1488
- if d_nonssm == 0:
1489
- out = rearrange(out, "(b s) d -> b s d", b=batch)
1490
- else:
1491
- out = out01
1492
- ctx.outproj_weight_dtype = (
1493
- outproj_weight.dtype if outproj_weight is not None else None
1494
- )
1495
- if outproj_weight is not None:
1496
- if torch.is_autocast_enabled():
1497
- dtype = torch.get_autocast_gpu_dtype()
1498
- out, outproj_weight = out.to(dtype), outproj_weight.to(dtype)
1499
- outproj_bias = (
1500
- outproj_bias.to(dtype) if outproj_bias is not None else None
1501
- )
1502
- out = F.linear(out, outproj_weight, outproj_bias)
1503
- else:
1504
- assert outproj_bias is None
1505
- ctx.save_for_backward(
1506
- zxbcdt,
1507
- conv1d_weight,
1508
- conv1d_bias,
1509
- out_x,
1510
- A,
1511
- D,
1512
- dt_bias,
1513
- initial_states,
1514
- seq_idx,
1515
- rmsnorm_weight,
1516
- rstd,
1517
- outproj_weight,
1518
- outproj_bias,
1519
- )
1520
- ctx.dt_limit = dt_limit
1521
- ctx.return_final_states = return_final_states
1522
- ctx.activation = activation
1523
- ctx.rmsnorm_eps = rmsnorm_eps
1524
- ctx.norm_before_gate = norm_before_gate
1525
- ctx.chunk_size = chunk_size
1526
- ctx.headdim = headdim
1527
- ctx.ngroups = ngroups
1528
- return out if not return_final_states else (out, final_states)
1529
-
1530
- @staticmethod
1531
- @custom_bwd
1532
- def backward(ctx, dout, *args):
1533
- (
1534
- zxbcdt,
1535
- conv1d_weight,
1536
- conv1d_bias,
1537
- out,
1538
- A,
1539
- D,
1540
- dt_bias,
1541
- initial_states,
1542
- seq_idx,
1543
- rmsnorm_weight,
1544
- rstd,
1545
- outproj_weight,
1546
- outproj_bias,
1547
- ) = ctx.saved_tensors
1548
- dfinal_states = args[0] if ctx.return_final_states else None
1549
- headdim = ctx.headdim
1550
- nheads = D.shape[0]
1551
- dim = nheads * headdim
1552
- assert nheads % ctx.ngroups == 0
1553
- dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2
1554
- d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2
1555
- assert d_nonssm >= 0
1556
- recompute_output = outproj_weight is not None
1557
- if recompute_output:
1558
- out_recompute = torch.empty(
1559
- *out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype
1560
- )
1561
- out0_recompute, out1_recompute = out_recompute.split(
1562
- [d_nonssm, dim], dim=-1
1563
- )
1564
- zx0, z, xBC, dt = torch.split(
1565
- zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
1566
- )
1567
- # Recompute x, B, C
1568
- xBC_conv = rearrange(
1569
- causal_conv1d_cuda.causal_conv1d_fwd(
1570
- rearrange(xBC, "b s d -> b d s"),
1571
- conv1d_weight,
1572
- conv1d_bias,
1573
- seq_idx,
1574
- None,
1575
- None,
1576
- ctx.activation in ["silu", "swish"],
1577
- ),
1578
- "b d s -> b s d",
1579
- )
1580
- x, B, C = torch.split(
1581
- xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
1582
- )
1583
- x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
1584
- B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups)
1585
- C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups)
1586
- dzxbcdt = torch.empty_like(zxbcdt)
1587
- dzx0, dz, dxBC_given, ddt_given = torch.split(
1588
- dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
1589
- )
1590
- dxBC = torch.empty_like(xBC)
1591
- dx, dB, dC = torch.split(
1592
- dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
1593
- )
1594
- z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
1595
- dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads)
1596
- dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups)
1597
- dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups)
1598
- if outproj_weight is not None:
1599
- dout_og = dout
1600
- dout = F.linear(dout, outproj_weight.t())
1601
- if d_nonssm > 0:
1602
- dout0, dout = dout.split([d_nonssm, dim], dim=-1)
1603
- _swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute)
1604
- dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim)
1605
- if rmsnorm_weight is None:
1606
- dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads)
1607
- dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states, *rest = (
1608
- _mamba_chunk_scan_combined_bwd(
1609
- dout,
1610
- x,
1611
- dt,
1612
- A,
1613
- B,
1614
- C,
1615
- out,
1616
- ctx.chunk_size,
1617
- D=D,
1618
- z=z,
1619
- dt_bias=dt_bias,
1620
- initial_states=initial_states,
1621
- dfinal_states=dfinal_states,
1622
- seq_idx=seq_idx,
1623
- dt_softplus=True,
1624
- dt_limit=ctx.dt_limit,
1625
- dx=dx,
1626
- ddt=ddt_given,
1627
- dB=dB,
1628
- dC=dC,
1629
- dz=dz,
1630
- recompute_output=recompute_output,
1631
- )
1632
- )
1633
- out_for_linear = (
1634
- rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None
1635
- )
1636
- drmsnorm_weight = None
1637
- else:
1638
- batch = dout.shape[0]
1639
- dy_rms = rearrange(dout, "b s h p -> (b s) (h p)")
1640
- dz = rearrange(dz, "b l d -> (b l) d")
1641
- x_rms = rearrange(out, "b s h p -> (b s) (h p)")
1642
- z_rms = rearrange(z, "b s h p -> (b s) (h p)")
1643
- out1_recompute = (
1644
- rearrange(out1_recompute, "b s d -> (b s) d")
1645
- if recompute_output
1646
- else None
1647
- )
1648
- dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(
1649
- dy_rms,
1650
- x_rms,
1651
- rmsnorm_weight,
1652
- None,
1653
- ctx.rmsnorm_eps,
1654
- None,
1655
- rstd,
1656
- z_rms,
1657
- group_size=dim // ctx.ngroups,
1658
- norm_before_gate=ctx.norm_before_gate,
1659
- is_rms_norm=True,
1660
- recompute_output=recompute_output,
1661
- dz=dz,
1662
- out=out1_recompute if recompute_output else None,
1663
- )
1664
- out_for_linear = out_recompute if recompute_output else None
1665
- dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim)
1666
- dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = (
1667
- _mamba_chunk_scan_combined_bwd(
1668
- dout,
1669
- x,
1670
- dt,
1671
- A,
1672
- B,
1673
- C,
1674
- out,
1675
- ctx.chunk_size,
1676
- D=D,
1677
- z=None,
1678
- dt_bias=dt_bias,
1679
- initial_states=initial_states,
1680
- dfinal_states=dfinal_states,
1681
- seq_idx=seq_idx,
1682
- dt_softplus=True,
1683
- dt_limit=ctx.dt_limit,
1684
- dx=dx,
1685
- ddt=ddt_given,
1686
- dB=dB,
1687
- dC=dC,
1688
- )
1689
- )
1690
-
1691
- if outproj_weight is not None:
1692
- doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear)
1693
- doutproj_bias = (
1694
-                 dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None
-             )
-         else:
-             doutproj_weight, doutproj_bias = None, None
-         dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
-         dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
-             rearrange(xBC, "b s d -> b d s"),
-             conv1d_weight,
-             conv1d_bias,
-             rearrange(dxBC, "b s d -> b d s"),
-             seq_idx,
-             None,
-             None,
-             dxBC_given,
-             False,
-             ctx.activation in ["silu", "swish"],
-         )
-         dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
-         return (
-             dzxbcdt,
-             dweight,
-             dbias,
-             ddt_bias,
-             dA,
-             dD,
-             None,
-             dinitial_states,
-             None,
-             None,
-             None,
-             None,
-             drmsnorm_weight,
-             None,
-             doutproj_weight,
-             doutproj_bias,
-             None,
-             None,
-             None,
-         )
-
-
- def mamba_split_conv1d_scan_combined(
-     zxbcdt,
-     conv1d_weight,
-     conv1d_bias,
-     dt_bias,
-     A,
-     D,
-     chunk_size,
-     initial_states=None,
-     seq_idx=None,
-     dt_limit=(0.0, float("inf")),
-     return_final_states=False,
-     activation="silu",
-     rmsnorm_weight=None,
-     rmsnorm_eps=1e-6,
-     outproj_weight=None,
-     outproj_bias=None,
-     headdim=None,
-     ngroups=1,
-     norm_before_gate=True,
- ):
-     """
-     Argument:
-         zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
-         conv1d_weight: (dim + 2 * ngroups * dstate, width)
-         conv1d_bias: (dim + 2 * ngroups * dstate,)
-         dt_bias: (nheads,)
-         A: (nheads)
-         D: (nheads, headdim) or (nheads,)
-         initial_states: (batch, nheads, headdim, dstate)
-         seq_idx: (batch, seqlen), int32
-         rmsnorm_weight: (dim,)
-         outproj_weight: (out_dim, dim)
-         outproj_bias: (out_dim,)
-         headdim: if D is 1D, headdim must be passed in
-         norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
-     Return:
-         out: (batch, seqlen, dim)
-     """
-     return MambaSplitConv1dScanCombinedFn.apply(
-         zxbcdt,
-         conv1d_weight,
-         conv1d_bias,
-         dt_bias,
-         A,
-         D,
-         chunk_size,
-         initial_states,
-         seq_idx,
-         dt_limit,
-         return_final_states,
-         activation,
-         rmsnorm_weight,
-         rmsnorm_eps,
-         outproj_weight,
-         outproj_bias,
-         headdim,
-         ngroups,
-         norm_before_gate,
-     )
-
-
- def mamba_split_conv1d_scan_ref(
-     zxbcdt,
-     conv1d_weight,
-     conv1d_bias,
-     dt_bias,
-     A,
-     D,
-     chunk_size,
-     dt_limit=(0.0, float("inf")),
-     activation="silu",
-     rmsnorm_weight=None,
-     rmsnorm_eps=1e-6,
-     outproj_weight=None,
-     outproj_bias=None,
-     headdim=None,
-     ngroups=1,
-     norm_before_gate=True,
- ):
-     """
-     Argument:
-         zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
-         conv1d_weight: (dim + 2 * ngroups * dstate, width)
-         conv1d_bias: (dim + 2 * ngroups * dstate,)
-         dt_bias: (nheads,)
-         A: (nheads)
-         D: (nheads, headdim) or (nheads,)
-         rmsnorm_weight: (dim,)
-         outproj_weight: (out_dim, dim)
-         outproj_bias: (out_dim,)
-         headdim: if D is 1D, headdim must be passed in
-         norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
-     Return:
-         out: (batch, seqlen, dim)
-     """
-     if D.dim() == 1:
-         assert headdim is not None
-         (nheads,) = D.shape
-     else:
-         nheads, headdim = D.shape
-     assert nheads % ngroups == 0
-     batch, seqlen, _ = zxbcdt.shape
-     dim = nheads * headdim
-     dstate = (zxbcdt.shape[-1] - 2 * dim - nheads) // ngroups // 2
-     assert zxbcdt.shape == (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads)
-     assert dt_bias.shape == (nheads,)
-     assert A.shape == (nheads,)
-     if rmsnorm_weight is not None:
-         assert rmsnorm_weight.shape == (dim,)
-     z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1)
-     xBC = rearrange(
-         causal_conv1d_fn(
-             rearrange(xBC, "b s d -> b d s"),
-             conv1d_weight,
-             conv1d_bias,
-             activation=activation,
-         ),
-         "b d s -> b s d",
-     )
-     x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
-     x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
-     B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
-     C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
-     z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
-     out = ssd_selective_scan(
-         x,
-         dt.to(x.dtype),
-         A,
-         B,
-         C,
-         D=D.float(),
-         z=z if rmsnorm_weight is None else None,
-         dt_bias=dt_bias,
-         dt_softplus=True,
-         dt_limit=dt_limit,
-     )
-     out = rearrange(out, "b s h p -> b s (h p)")
-     if rmsnorm_weight is not None:
-         out = rmsnorm_fn(
-             out,
-             rmsnorm_weight,
-             None,
-             z=rearrange(z, "b l h p -> b l (h p)"),
-             eps=rmsnorm_eps,
-             norm_before_gate=norm_before_gate,
-         )
-     if outproj_weight is not None:
-         out = F.linear(out, outproj_weight, outproj_bias)
-     return out
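
Note on the packed zxbcdt layout documented above, which is the least obvious part of this API. The following minimal sketch (sizes are illustrative and not taken from this repository) mirrors the torch.split calls that mamba_split_conv1d_scan_ref performs internally, and is a quick way to check that an input tensor is laid out correctly before handing it to the fused kernel:

    import torch

    # Illustrative sizes, chosen only to make the arithmetic concrete.
    batch, seqlen, nheads, headdim, ngroups, dstate = 2, 128, 4, 64, 1, 16
    dim = nheads * headdim
    d_in_proj = 2 * dim + 2 * ngroups * dstate + nheads  # z + xBC + dt

    zxbcdt = torch.randn(batch, seqlen, d_in_proj)
    # Same splits as mamba_split_conv1d_scan_ref:
    z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1)
    x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
    print(z.shape, x.shape, B.shape, C.shape, dt.shape)
    # (2, 128, 256) (2, 128, 256) (2, 128, 16) (2, 128, 16) (2, 128, 4)
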
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/__init__.py DELETED
@@ -1,14 +0,0 @@
- __version__ = "2.2.4"
-
- from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
- from .modules.mamba_simple import Mamba
- from .modules.mamba2 import Mamba2
- from .models.mixer_seq_simple import MambaLMHeadModel
-
- __all__ = [
-     "selective_scan_fn",
-     "mamba_inner_fn",
-     "Mamba",
-     "Mamba2",
-     "MambaLMHeadModel",
- ]
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/distributed/__init__.py DELETED
File without changes
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py DELETED
@@ -1,326 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao.
2
- # The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
3
- from typing import Optional
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from torch import Tensor
9
- from torch.distributed import ProcessGroup
10
- from ..utils.torch import custom_bwd, custom_fwd
11
-
12
- from einops import rearrange
13
-
14
- from ..distributed.distributed_utils import (
15
- all_gather_raw,
16
- all_reduce,
17
- all_reduce_raw,
18
- reduce_scatter,
19
- reduce_scatter_raw,
20
- )
21
-
22
-
23
- class ParallelLinearFunc(torch.autograd.Function):
24
- @staticmethod
25
- @custom_fwd
26
- def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
27
- """
28
- If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
29
- with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
30
- """
31
- ctx.compute_weight_gradient = weight.requires_grad
32
- ctx.process_group = process_group
33
- ctx.sequence_parallel = sequence_parallel
34
-
35
- if torch.is_autocast_enabled():
36
- x = x.to(dtype=torch.get_autocast_gpu_dtype())
37
- x = x.contiguous()
38
- if process_group is not None and sequence_parallel:
39
- # We want to kick off the all_gather early, before weight dtype conversion
40
- total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
41
- else:
42
- total_x = x
43
-
44
- if torch.is_autocast_enabled():
45
- weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
46
- bias = (
47
- bias.to(dtype=torch.get_autocast_gpu_dtype())
48
- if bias is not None
49
- else None
50
- )
51
- weight = weight.contiguous()
52
- if process_group is not None and sequence_parallel:
53
- handle_x.wait()
54
- batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
55
- batch_dim = batch_shape.numel()
56
- # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
57
- output = F.linear(total_x, weight, bias)
58
- if ctx.compute_weight_gradient:
59
- ctx.save_for_backward(x, weight)
60
- else:
61
- ctx.save_for_backward(weight)
62
- return output
63
-
64
- @staticmethod
65
- @custom_bwd
66
- def backward(ctx, grad_output):
67
- grad_output = grad_output.contiguous()
68
- process_group = ctx.process_group
69
- sequence_parallel = ctx.sequence_parallel
70
- if ctx.compute_weight_gradient:
71
- x, weight = ctx.saved_tensors
72
- if process_group is not None and sequence_parallel:
73
- total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
74
- else:
75
- total_x = x
76
- else:
77
- (weight,) = ctx.saved_tensors
78
- total_x = None
79
- batch_shape = grad_output.shape[:-1]
80
- batch_dim = batch_shape.numel()
81
- grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
82
- if ctx.needs_input_grad[0]:
83
- grad_input = F.linear(grad_output, weight.t())
84
- grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
85
- if process_group is not None:
86
- reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
87
- grad_input, handle_grad_input = reduce_fn(
88
- grad_input, process_group, async_op=True
89
- )
90
- else:
91
- grad_input = None
92
- if ctx.needs_input_grad[1]:
93
- assert ctx.compute_weight_gradient
94
- if process_group is not None and sequence_parallel:
95
- handle_x.wait()
96
- grad_weight = torch.einsum(
97
- "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
98
- )
99
- else:
100
- grad_weight = None
101
- grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
102
- if process_group is not None and ctx.needs_input_grad[0]:
103
- handle_grad_input.wait()
104
- return grad_input, grad_weight, grad_bias, None, None
105
-
106
-
107
- def parallel_linear_func(
108
- x: Tensor,
109
- weight: Tensor,
110
- bias: Optional[Tensor] = None,
111
- process_group: Optional[ProcessGroup] = None,
112
- sequence_parallel: bool = True,
113
- ):
114
- return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)
115
-
116
-
117
- class ColumnParallelLinear(nn.Linear):
118
- def __init__(
119
- self,
120
- in_features: int,
121
- out_features: int,
122
- process_group: ProcessGroup,
123
- bias: bool = True,
124
- sequence_parallel=True,
125
- multiple_of=1,
126
- device=None,
127
- dtype=None,
128
- ) -> None:
129
- world_size = torch.distributed.get_world_size(process_group)
130
- if out_features % multiple_of:
131
- raise ValueError(
132
- f"out_features ({out_features}) must be a multiple of {multiple_of}"
133
- )
134
- multiple = out_features // multiple_of
135
- # We want to split @multiple across world_size, but it could be an uneven split
136
- div = multiple // world_size
137
- mod = multiple % world_size
138
- # The first @mod ranks get @div + 1 copies, the rest get @div copies
139
- local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
140
- super().__init__(
141
- in_features,
142
- local_multiple * multiple_of,
143
- bias=bias,
144
- device=device,
145
- dtype=dtype,
146
- )
147
- self.process_group = process_group
148
- self.sequence_parallel = sequence_parallel
149
-
150
- def forward(self, x):
151
- # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
152
- # we do an all_gather of x before doing the matmul.
153
- # If not, then the input is already gathered.
154
- return parallel_linear_func(
155
- x,
156
- self.weight,
157
- self.bias,
158
- process_group=self.process_group,
159
- sequence_parallel=self.sequence_parallel,
160
- )
161
-
162
-
163
- class RowParallelLinear(nn.Linear):
164
- def __init__(
165
- self,
166
- in_features: int,
167
- out_features: int,
168
- process_group: ProcessGroup,
169
- bias: bool = True,
170
- sequence_parallel=True,
171
- multiple_of=1,
172
- device=None,
173
- dtype=None,
174
- ) -> None:
175
- world_size = torch.distributed.get_world_size(process_group)
176
- rank = torch.distributed.get_rank(process_group)
177
- if in_features % multiple_of:
178
- raise ValueError(
179
- f"in_features ({in_features}) must be a multiple of {multiple_of}"
180
- )
181
- multiple = in_features // multiple_of
182
- # We want to split @multiple across world_size, but it could be an uneven split
183
- div = multiple // world_size
184
- mod = multiple % world_size
185
- # The first @mod ranks get @div + 1 copies, the rest get @div copies
186
- local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
187
- # Only rank 0 will have bias
188
- super().__init__(
189
- local_multiple * multiple_of,
190
- out_features,
191
- bias=bias and rank == 0,
192
- device=device,
193
- dtype=dtype,
194
- )
195
- self.process_group = process_group
196
- self.sequence_parallel = sequence_parallel
197
-
198
- def forward(self, x):
199
- """
200
- We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
201
- a reduce_scatter of the result.
202
- """
203
- out = parallel_linear_func(x, self.weight, self.bias)
204
- reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
205
- return reduce_fn(out, self.process_group)
206
-
207
-
208
- class VocabParallelEmbedding(nn.Embedding):
209
- def __init__(
210
- self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs
211
- ):
212
- self.process_group = process_group
213
- if process_group is not None:
214
- world_size = torch.distributed.get_world_size(process_group)
215
- if num_embeddings % world_size != 0:
216
- raise ValueError(
217
- f"num_embeddings ({num_embeddings}) must be divisible by "
218
- f"world_size ({world_size})"
219
- )
220
- if world_size > 1 and padding_idx is not None:
221
- raise RuntimeError("ParallelEmbedding does not support padding_idx")
222
- else:
223
- world_size = 1
224
- super().__init__(
225
- num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs
226
- )
227
-
228
- def forward(self, input: Tensor) -> Tensor:
229
- if self.process_group is None:
230
- return super().forward(input)
231
- else:
232
- rank = torch.distributed.get_rank(self.process_group)
233
- vocab_size = self.num_embeddings
234
- vocab_start_index, vocab_end_index = (
235
- rank * vocab_size,
236
- (rank + 1) * vocab_size,
237
- )
238
- # Create a mask of valid vocab ids (1 means it needs to be masked).
239
- input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
240
- input = input - vocab_start_index
241
- input[input_ids_mask] = 0
242
- embeddings = super().forward(input)
243
- embeddings[input_ids_mask] = 0.0
244
- return embeddings
245
-
246
-
247
- class ColumnParallelEmbedding(nn.Embedding):
248
- def __init__(
249
- self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs
250
- ):
251
- self.process_group = process_group
252
- if process_group is not None:
253
- world_size = torch.distributed.get_world_size(process_group)
254
- if embedding_dim % world_size != 0:
255
- raise ValueError(
256
- f"embedding_dim ({embedding_dim}) must be divisible by "
257
- f"world_size ({world_size})"
258
- )
259
- else:
260
- world_size = 1
261
- super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
262
-
263
-
264
- class ParallelEmbeddings(nn.Module):
265
- def __init__(
266
- self,
267
- embed_dim,
268
- vocab_size,
269
- max_position_embeddings,
270
- process_group,
271
- padding_idx=None,
272
- sequence_parallel=True,
273
- device=None,
274
- dtype=None,
275
- ):
276
- """
277
- If max_position_embeddings <= 0, there's no position embeddings
278
- """
279
- factory_kwargs = {"device": device, "dtype": dtype}
280
- super().__init__()
281
- self.process_group = process_group
282
- self.sequence_parallel = sequence_parallel
283
- self.word_embeddings = VocabParallelEmbedding(
284
- vocab_size,
285
- embed_dim,
286
- padding_idx=padding_idx,
287
- process_group=process_group,
288
- **factory_kwargs,
289
- )
290
- self.max_position_embeddings = max_position_embeddings
291
- if self.max_position_embeddings > 0:
292
- self.position_embeddings = ColumnParallelEmbedding(
293
- max_position_embeddings,
294
- embed_dim,
295
- process_group=process_group,
296
- **factory_kwargs,
297
- )
298
-
299
- def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
300
- """
301
- input_ids: (batch, seqlen)
302
- position_ids: (batch, seqlen)
303
- """
304
- batch_size, seqlen = input_ids.shape
305
- world_size = torch.distributed.get_world_size(self.process_group)
306
- embeddings = self.word_embeddings(input_ids)
307
- if self.max_position_embeddings > 0:
308
- if position_ids is None:
309
- position_ids = torch.arange(
310
- seqlen, dtype=torch.long, device=input_ids.device
311
- )
312
- position_embeddings = self.position_embeddings(position_ids)
313
- if world_size <= 1:
314
- embeddings = embeddings + position_embeddings
315
- else:
316
- partition_dim = self.position_embeddings.embedding_dim
317
- rank = torch.distributed.get_rank(self.process_group)
318
- embeddings[
319
- ..., rank * partition_dim : (rank + 1) * partition_dim
320
- ] += position_embeddings
321
- if combine_batch_seqlen_dim:
322
- embeddings = rearrange(embeddings, "b s d -> (b s) d")
323
- reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
324
- return (
325
- embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
326
- )
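
The sharding rule used by ColumnParallelLinear and RowParallelLinear above (divide out_features // multiple_of units of size multiple_of across world_size ranks, giving the first mod ranks one extra unit) is easy to verify in isolation. A minimal standalone sketch of the same arithmetic, with illustrative sizes:

    def local_out_features(out_features: int, multiple_of: int, world_size: int, rank: int) -> int:
        # Mirrors ColumnParallelLinear.__init__: the first `mod` ranks get div + 1
        # units of size multiple_of, the remaining ranks get div units.
        if out_features % multiple_of:
            raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
        multiple = out_features // multiple_of
        div, mod = divmod(multiple, world_size)
        return (div + int(rank < mod)) * multiple_of

    # 640 features in units of 64 over 4 ranks -> per-rank widths 192, 192, 128, 128.
    print([local_out_features(640, 64, 4, r) for r in range(4)])
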
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/models/__init__.py DELETED
File without changes
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py DELETED
@@ -1,338 +0,0 @@
1
- # Copyright (c) 2023, Albert Gu, Tri Dao.
2
-
3
- import math
4
- from functools import partial
5
- import json
6
- import os
7
- import copy
8
-
9
- from collections import namedtuple
10
-
11
- import torch
12
- import torch.nn as nn
13
-
14
- from .config_mamba import MambaConfig
15
- from ..modules.mamba_simple import Mamba
16
- from ..modules.mamba2 import Mamba2
17
- from ..modules.mha import MHA
18
- from ..modules.mlp import GatedMLP
19
- from ..modules.block import Block
20
- from ..utils.generation import GenerationMixin
21
- from ..utils.hf import load_config_hf, load_state_dict_hf
22
-
23
- try:
24
- from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
25
- except ImportError:
26
- RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
27
-
28
-
29
- def create_block(
30
- d_model,
31
- d_intermediate,
32
- ssm_cfg=None,
33
- attn_layer_idx=None,
34
- attn_cfg=None,
35
- norm_epsilon=1e-5,
36
- rms_norm=False,
37
- residual_in_fp32=False,
38
- fused_add_norm=False,
39
- layer_idx=None,
40
- device=None,
41
- dtype=None,
42
- ):
43
- if ssm_cfg is None:
44
- ssm_cfg = {}
45
- if attn_layer_idx is None:
46
- attn_layer_idx = []
47
- if attn_cfg is None:
48
- attn_cfg = {}
49
- factory_kwargs = {"device": device, "dtype": dtype}
50
- if layer_idx not in attn_layer_idx:
51
- # Create a copy of the config to modify
52
- ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
53
- ssm_layer = ssm_cfg.pop("layer", "Mamba1")
54
- if ssm_layer not in ["Mamba1", "Mamba2"]:
55
- raise ValueError(
56
- f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2"
57
- )
58
- mixer_cls = partial(
59
- Mamba2 if ssm_layer == "Mamba2" else Mamba,
60
- layer_idx=layer_idx,
61
- **ssm_cfg,
62
- **factory_kwargs,
63
- )
64
- else:
65
- mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
66
- norm_cls = partial(
67
- nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
68
- )
69
- if d_intermediate == 0:
70
- mlp_cls = nn.Identity
71
- else:
72
- mlp_cls = partial(
73
- GatedMLP,
74
- hidden_features=d_intermediate,
75
- out_features=d_model,
76
- **factory_kwargs,
77
- )
78
- block = Block(
79
- d_model,
80
- mixer_cls,
81
- mlp_cls,
82
- norm_cls=norm_cls,
83
- fused_add_norm=fused_add_norm,
84
- residual_in_fp32=residual_in_fp32,
85
- )
86
- block.layer_idx = layer_idx
87
- return block
88
-
89
-
90
- # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
91
- def _init_weights(
92
- module,
93
- n_layer,
94
- initializer_range=0.02, # Now only used for embedding layer.
95
- rescale_prenorm_residual=True,
96
- n_residuals_per_layer=1, # Change to 2 if we have MLP
97
- ):
98
- if isinstance(module, nn.Linear):
99
- if module.bias is not None:
100
- if not getattr(module.bias, "_no_reinit", False):
101
- nn.init.zeros_(module.bias)
102
- elif isinstance(module, nn.Embedding):
103
- nn.init.normal_(module.weight, std=initializer_range)
104
-
105
- if rescale_prenorm_residual:
106
- # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
107
- # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
108
- # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
109
- # > -- GPT-2 :: https://openai.com/blog/better-language-models/
110
- #
111
- # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
112
- for name, p in module.named_parameters():
113
- if name in ["out_proj.weight", "fc2.weight"]:
114
- # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
115
- # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
116
- # We need to reinit p since this code could be called multiple times
117
- # Having just p *= scale would repeatedly scale it down
118
- nn.init.kaiming_uniform_(p, a=math.sqrt(5))
119
- with torch.no_grad():
120
- p /= math.sqrt(n_residuals_per_layer * n_layer)
121
-
122
-
123
- class MixerModel(nn.Module):
124
- def __init__(
125
- self,
126
- d_model: int,
127
- n_layer: int,
128
- d_intermediate: int,
129
- vocab_size: int,
130
- ssm_cfg=None,
131
- attn_layer_idx=None,
132
- attn_cfg=None,
133
- norm_epsilon: float = 1e-5,
134
- rms_norm: bool = False,
135
- initializer_cfg=None,
136
- fused_add_norm=False,
137
- residual_in_fp32=False,
138
- device=None,
139
- dtype=None,
140
- ) -> None:
141
- factory_kwargs = {"device": device, "dtype": dtype}
142
- super().__init__()
143
- self.residual_in_fp32 = residual_in_fp32
144
-
145
- self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
146
-
147
- # We change the order of residual and layer norm:
148
- # Instead of LN -> Attn / MLP -> Add, we do:
149
- # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
150
- # the main branch (output of MLP / Mixer). The model definition is unchanged.
151
- # This is for performance reason: we can fuse add + layer_norm.
152
- self.fused_add_norm = fused_add_norm
153
- if self.fused_add_norm:
154
- if layer_norm_fn is None or rms_norm_fn is None:
155
- raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
156
-
157
- self.layers = nn.ModuleList(
158
- [
159
- create_block(
160
- d_model,
161
- d_intermediate=d_intermediate,
162
- ssm_cfg=ssm_cfg,
163
- attn_layer_idx=attn_layer_idx,
164
- attn_cfg=attn_cfg,
165
- norm_epsilon=norm_epsilon,
166
- rms_norm=rms_norm,
167
- residual_in_fp32=residual_in_fp32,
168
- fused_add_norm=fused_add_norm,
169
- layer_idx=i,
170
- **factory_kwargs,
171
- )
172
- for i in range(n_layer)
173
- ]
174
- )
175
-
176
- self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
177
- d_model, eps=norm_epsilon, **factory_kwargs
178
- )
179
-
180
- self.apply(
181
- partial(
182
- _init_weights,
183
- n_layer=n_layer,
184
- **(initializer_cfg if initializer_cfg is not None else {}),
185
- n_residuals_per_layer=(
186
- 1 if d_intermediate == 0 else 2
187
- ), # 2 if we have MLP
188
- )
189
- )
190
-
191
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
192
- return {
193
- i: layer.allocate_inference_cache(
194
- batch_size, max_seqlen, dtype=dtype, **kwargs
195
- )
196
- for i, layer in enumerate(self.layers)
197
- }
198
-
199
- def forward(self, input_ids, inference_params=None, **mixer_kwargs):
200
- hidden_states = self.embedding(input_ids)
201
- residual = None
202
- for layer in self.layers:
203
- hidden_states, residual = layer(
204
- hidden_states,
205
- residual,
206
- inference_params=inference_params,
207
- **mixer_kwargs,
208
- )
209
- if not self.fused_add_norm:
210
- residual = (
211
- (hidden_states + residual) if residual is not None else hidden_states
212
- )
213
- hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
214
- else:
215
- # Set prenorm=False here since we don't need the residual
216
- hidden_states = layer_norm_fn(
217
- hidden_states,
218
- self.norm_f.weight,
219
- self.norm_f.bias,
220
- eps=self.norm_f.eps,
221
- residual=residual,
222
- prenorm=False,
223
- residual_in_fp32=self.residual_in_fp32,
224
- is_rms_norm=isinstance(self.norm_f, RMSNorm),
225
- )
226
- return hidden_states
227
-
228
-
229
- class MambaLMHeadModel(nn.Module, GenerationMixin):
230
-
231
- def __init__(
232
- self,
233
- config: MambaConfig,
234
- initializer_cfg=None,
235
- device=None,
236
- dtype=None,
237
- ) -> None:
238
- self.config = config
239
- d_model = config.d_model
240
- n_layer = config.n_layer
241
- d_intermediate = config.d_intermediate
242
- vocab_size = config.vocab_size
243
- ssm_cfg = config.ssm_cfg
244
- attn_layer_idx = config.attn_layer_idx
245
- attn_cfg = config.attn_cfg
246
- rms_norm = config.rms_norm
247
- residual_in_fp32 = config.residual_in_fp32
248
- fused_add_norm = config.fused_add_norm
249
- pad_vocab_size_multiple = config.pad_vocab_size_multiple
250
- factory_kwargs = {"device": device, "dtype": dtype}
251
-
252
- super().__init__()
253
- if vocab_size % pad_vocab_size_multiple != 0:
254
- vocab_size += pad_vocab_size_multiple - (
255
- vocab_size % pad_vocab_size_multiple
256
- )
257
- self.backbone = MixerModel(
258
- d_model=d_model,
259
- n_layer=n_layer,
260
- d_intermediate=d_intermediate,
261
- vocab_size=vocab_size,
262
- ssm_cfg=ssm_cfg,
263
- attn_layer_idx=attn_layer_idx,
264
- attn_cfg=attn_cfg,
265
- rms_norm=rms_norm,
266
- initializer_cfg=initializer_cfg,
267
- fused_add_norm=fused_add_norm,
268
- residual_in_fp32=residual_in_fp32,
269
- **factory_kwargs,
270
- )
271
- self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
272
-
273
- # Initialize weights and apply final processing
274
- self.apply(
275
- partial(
276
- _init_weights,
277
- n_layer=n_layer,
278
- **(initializer_cfg if initializer_cfg is not None else {}),
279
- )
280
- )
281
- self.tie_weights()
282
-
283
- def tie_weights(self):
284
- if self.config.tie_embeddings:
285
- self.lm_head.weight = self.backbone.embedding.weight
286
-
287
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
288
- return self.backbone.allocate_inference_cache(
289
- batch_size, max_seqlen, dtype=dtype, **kwargs
290
- )
291
-
292
- def forward(
293
- self,
294
- input_ids,
295
- position_ids=None,
296
- inference_params=None,
297
- num_last_tokens=0,
298
- **mixer_kwargs,
299
- ):
300
- """
301
- "position_ids" is just to be compatible with Transformer generation. We don't use it.
302
- num_last_tokens: if > 0, only return the logits for the last n tokens
303
- """
304
- hidden_states = self.backbone(
305
- input_ids, inference_params=inference_params, **mixer_kwargs
306
- )
307
- if num_last_tokens > 0:
308
- hidden_states = hidden_states[:, -num_last_tokens:]
309
- lm_logits = self.lm_head(hidden_states)
310
- CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
311
- return CausalLMOutput(logits=lm_logits)
312
-
313
- @classmethod
314
- def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
315
- config_data = load_config_hf(pretrained_model_name)
316
- config = MambaConfig(**config_data)
317
- model = cls(config, device=device, dtype=dtype, **kwargs)
318
- model.load_state_dict(
319
- load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
320
- )
321
- return model
322
-
323
- def save_pretrained(self, save_directory):
324
- """
325
- Minimal implementation of save_pretrained for MambaLMHeadModel.
326
- Save the model and its configuration file to a directory.
327
- """
328
- # Ensure save_directory exists
329
- os.makedirs(save_directory, exist_ok=True)
330
-
331
- # Save the model's state_dict
332
- model_path = os.path.join(save_directory, "pytorch_model.bin")
333
- torch.save(self.state_dict(), model_path)
334
-
335
- # Save the configuration of the model
336
- config_path = os.path.join(save_directory, "config.json")
337
- with open(config_path, "w") as f:
338
- json.dump(self.config.__dict__, f, indent=4)
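
As a usage note for the model class removed above: from_pretrained, forward, and save_pretrained are the entry points most callers need. A minimal sketch, assuming a CUDA machine with the mamba_ssm kernels installed; the checkpoint name is illustrative and not taken from this diff, and tokenization is skipped in favor of random token ids:

    import torch
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

    # "state-spaces/mamba-130m" is an assumed, illustrative checkpoint name.
    model = MambaLMHeadModel.from_pretrained(
        "state-spaces/mamba-130m", device="cuda", dtype=torch.float16
    )
    input_ids = torch.randint(0, model.config.vocab_size, (1, 16), device="cuda")
    out = model(input_ids, num_last_tokens=1)    # CausalLMOutput namedtuple
    print(out.logits.shape)                      # (1, 1, vocab_size)
    model.save_pretrained("./mamba-checkpoint")  # writes pytorch_model.bin and config.json
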
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/modules/__init__.py DELETED
File without changes
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/__init__.py DELETED
File without changes
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/selective_scan_interface.py DELETED
@@ -1,659 +0,0 @@
1
- # Copyright (c) 2023, Tri Dao, Albert Gu.
2
-
3
- import torch
4
- import torch.nn.functional as F
5
- from ..utils.torch import custom_fwd, custom_bwd
6
-
7
- from einops import rearrange, repeat
8
-
9
- try:
10
- from causal_conv1d import causal_conv1d_fn
11
- import causal_conv1d_cuda
12
- except ImportError:
13
- causal_conv1d_fn = None
14
- causal_conv1d_cuda = None
15
-
16
- from .triton.layer_norm import _layer_norm_fwd
17
-
18
- from .._ops import ops
19
-
20
-
21
- class SelectiveScanFn(torch.autograd.Function):
22
-
23
- @staticmethod
24
- def forward(
25
- ctx,
26
- u,
27
- delta,
28
- A,
29
- B,
30
- C,
31
- D=None,
32
- z=None,
33
- delta_bias=None,
34
- delta_softplus=False,
35
- return_last_state=False,
36
- ):
37
- if u.stride(-1) != 1:
38
- u = u.contiguous()
39
- if delta.stride(-1) != 1:
40
- delta = delta.contiguous()
41
- if D is not None:
42
- D = D.contiguous()
43
- if B.stride(-1) != 1:
44
- B = B.contiguous()
45
- if C.stride(-1) != 1:
46
- C = C.contiguous()
47
- if z is not None and z.stride(-1) != 1:
48
- z = z.contiguous()
49
- if B.dim() == 3:
50
- B = rearrange(B, "b dstate l -> b 1 dstate l")
51
- ctx.squeeze_B = True
52
- if C.dim() == 3:
53
- C = rearrange(C, "b dstate l -> b 1 dstate l")
54
- ctx.squeeze_C = True
55
- out, x, *rest = ops.selective_scan_fwd(
56
- u, delta, A, B, C, D, z, delta_bias, delta_softplus
57
- )
58
- ctx.delta_softplus = delta_softplus
59
- ctx.has_z = z is not None
60
- last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
61
- if not ctx.has_z:
62
- ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
63
- return out if not return_last_state else (out, last_state)
64
- else:
65
- ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
66
- out_z = rest[0]
67
- return out_z if not return_last_state else (out_z, last_state)
68
-
69
- @staticmethod
70
- def backward(ctx, dout, *args):
71
- if not ctx.has_z:
72
- u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
73
- z = None
74
- out = None
75
- else:
76
- u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
77
- if dout.stride(-1) != 1:
78
- dout = dout.contiguous()
79
- # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
80
- # backward of selective_scan_cuda with the backward of chunk).
81
- # Here we just pass in None and dz will be allocated in the C++ code.
82
- du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = ops.selective_scan_bwd(
83
- u,
84
- delta,
85
- A,
86
- B,
87
- C,
88
- D,
89
- z,
90
- delta_bias,
91
- dout,
92
- x,
93
- out,
94
- None,
95
- ctx.delta_softplus,
96
- False, # option to recompute out_z, not used here
97
- )
98
- dz = rest[0] if ctx.has_z else None
99
- dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
100
- dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
101
- return (
102
- du,
103
- ddelta,
104
- dA,
105
- dB,
106
- dC,
107
- dD if D is not None else None,
108
- dz,
109
- ddelta_bias if delta_bias is not None else None,
110
- None,
111
- None,
112
- )
113
-
114
-
115
- def rms_norm_forward(
116
- x,
117
- weight,
118
- bias,
119
- eps=1e-6,
120
- is_rms_norm=True,
121
- ):
122
- # x (b l) d
123
- if x.stride(-1) != 1:
124
- x = x.contiguous()
125
- weight = weight.contiguous()
126
- if bias is not None:
127
- bias = bias.contiguous()
128
- y = _layer_norm_fwd(
129
- x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm
130
- )[0]
131
- # y (b l) d
132
- return y
133
-
134
-
135
- def selective_scan_fn(
136
- u,
137
- delta,
138
- A,
139
- B,
140
- C,
141
- D=None,
142
- z=None,
143
- delta_bias=None,
144
- delta_softplus=False,
145
- return_last_state=False,
146
- ):
147
- """if return_last_state is True, returns (out, last_state)
148
- last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
149
- not considered in the backward pass.
150
- """
151
- return SelectiveScanFn.apply(
152
- u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state
153
- )
154
-
155
-
156
- def selective_scan_ref(
157
- u,
158
- delta,
159
- A,
160
- B,
161
- C,
162
- D=None,
163
- z=None,
164
- delta_bias=None,
165
- delta_softplus=False,
166
- return_last_state=False,
167
- ):
168
- """
169
- u: r(B D L)
170
- delta: r(B D L)
171
- A: c(D N) or r(D N)
172
- B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
173
- C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
174
- D: r(D)
175
- z: r(B D L)
176
- delta_bias: r(D), fp32
177
-
178
- out: r(B D L)
179
- last_state (optional): r(B D dstate) or c(B D dstate)
180
- """
181
- dtype_in = u.dtype
182
- u = u.float()
183
- delta = delta.float()
184
- if delta_bias is not None:
185
- delta = delta + delta_bias[..., None].float()
186
- if delta_softplus:
187
- delta = F.softplus(delta)
188
- batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
189
- is_variable_B = B.dim() >= 3
190
- is_variable_C = C.dim() >= 3
191
- if A.is_complex():
192
- if is_variable_B:
193
- B = torch.view_as_complex(
194
- rearrange(B.float(), "... (L two) -> ... L two", two=2)
195
- )
196
- if is_variable_C:
197
- C = torch.view_as_complex(
198
- rearrange(C.float(), "... (L two) -> ... L two", two=2)
199
- )
200
- else:
201
- B = B.float()
202
- C = C.float()
203
- x = A.new_zeros((batch, dim, dstate))
204
- ys = []
205
- deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
206
- if not is_variable_B:
207
- deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
208
- else:
209
- if B.dim() == 3:
210
- deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
211
- else:
212
- B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
213
- deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
214
- if is_variable_C and C.dim() == 4:
215
- C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
216
- last_state = None
217
- for i in range(u.shape[2]):
218
- x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
219
- if not is_variable_C:
220
- y = torch.einsum("bdn,dn->bd", x, C)
221
- else:
222
- if C.dim() == 3:
223
- y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
224
- else:
225
- y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
226
- if i == u.shape[2] - 1:
227
- last_state = x
228
- if y.is_complex():
229
- y = y.real * 2
230
- ys.append(y)
231
- y = torch.stack(ys, dim=2) # (batch dim L)
232
- out = y if D is None else y + u * rearrange(D, "d -> d 1")
233
- if z is not None:
234
- out = out * F.silu(z)
235
- out = out.to(dtype=dtype_in)
236
- return out if not return_last_state else (out, last_state)
237
-
238
-
239
- class MambaInnerFn(torch.autograd.Function):
240
-
241
- @staticmethod
242
- @custom_fwd
243
- def forward(
244
- ctx,
245
- xz,
246
- conv1d_weight,
247
- conv1d_bias,
248
- x_proj_weight,
249
- delta_proj_weight,
250
- out_proj_weight,
251
- out_proj_bias,
252
- A,
253
- B=None,
254
- C=None,
255
- D=None,
256
- delta_bias=None,
257
- B_proj_bias=None,
258
- C_proj_bias=None,
259
- delta_softplus=True,
260
- checkpoint_lvl=1,
261
- b_rms_weight=None,
262
- c_rms_weight=None,
263
- dt_rms_weight=None,
264
- b_c_dt_rms_eps=1e-6,
265
- ):
266
- """
267
- xz: (batch, dim, seqlen)
268
- """
269
- assert (
270
- causal_conv1d_cuda is not None
271
- ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
272
- assert checkpoint_lvl in [0, 1]
273
- L = xz.shape[-1]
274
- delta_rank = delta_proj_weight.shape[1]
275
- d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
276
- if torch.is_autocast_enabled():
277
- x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
278
- delta_proj_weight = delta_proj_weight.to(
279
- dtype=torch.get_autocast_gpu_dtype()
280
- )
281
- out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
282
- out_proj_bias = (
283
- out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
284
- if out_proj_bias is not None
285
- else None
286
- )
287
- if xz.stride(-1) != 1:
288
- xz = xz.contiguous()
289
- conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
290
- x, z = xz.chunk(2, dim=1)
291
- conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
292
- conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
293
- x, conv1d_weight, conv1d_bias, None, None, None, True
294
- )
295
- # We're being very careful here about the layout, to avoid extra transposes.
296
- # We want delta to have d as the slowest moving dimension
297
- # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
298
- x_dbl = F.linear(
299
- rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight
300
- ) # (bl d)
301
- delta = rearrange(
302
- delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
303
- )
304
- ctx.is_variable_B = B is None
305
- ctx.is_variable_C = C is None
306
- ctx.B_proj_bias_is_None = B_proj_bias is None
307
- ctx.C_proj_bias_is_None = C_proj_bias is None
308
- if B is None: # variable B
309
- B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl dstate)
310
- if B_proj_bias is not None:
311
- B = B + B_proj_bias.to(dtype=B.dtype)
312
- if not A.is_complex():
313
- # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
314
- B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
315
- else:
316
- B = rearrange(
317
- B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
318
- ).contiguous()
319
- else:
320
- if B.stride(-1) != 1:
321
- B = B.contiguous()
322
- if C is None: # variable C
323
- C = x_dbl[:, -d_state:] # (bl dstate)
324
- if C_proj_bias is not None:
325
- C = C + C_proj_bias.to(dtype=C.dtype)
326
- if not A.is_complex():
327
- # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
328
- C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
329
- else:
330
- C = rearrange(
331
- C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
332
- ).contiguous()
333
- else:
334
- if C.stride(-1) != 1:
335
- C = C.contiguous()
336
- if D is not None:
337
- D = D.contiguous()
338
-
339
- if b_rms_weight is not None:
340
- B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
341
- B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
342
- B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
343
- if c_rms_weight is not None:
344
- C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
345
- C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
346
- C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
347
- if dt_rms_weight is not None:
348
- delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
349
- delta = rms_norm_forward(
350
- delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps
351
- )
352
- delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
353
-
354
- out, scan_intermediates, out_z = ops.selective_scan_fwd(
355
- conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
356
- )
357
- ctx.delta_softplus = delta_softplus
358
- ctx.out_proj_bias_is_None = out_proj_bias is None
359
- ctx.checkpoint_lvl = checkpoint_lvl
360
- ctx.b_rms_weight = b_rms_weight
361
- ctx.c_rms_weight = c_rms_weight
362
- ctx.dt_rms_weight = dt_rms_weight
363
- ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
364
- if (
365
- checkpoint_lvl >= 1
366
- ): # Will recompute conv1d_out and delta in the backward pass
367
- conv1d_out, delta = None, None
368
- ctx.save_for_backward(
369
- xz,
370
- conv1d_weight,
371
- conv1d_bias,
372
- x_dbl,
373
- x_proj_weight,
374
- delta_proj_weight,
375
- out_proj_weight,
376
- conv1d_out,
377
- delta,
378
- A,
379
- B,
380
- C,
381
- D,
382
- delta_bias,
383
- scan_intermediates,
384
- b_rms_weight,
385
- c_rms_weight,
386
- dt_rms_weight,
387
- out,
388
- )
389
- return F.linear(
390
- rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias
391
- )
392
-
393
- @staticmethod
394
- @custom_bwd
395
- def backward(ctx, dout):
396
- # dout: (batch, seqlen, dim)
397
- assert (
398
- causal_conv1d_cuda is not None
399
- ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
400
- (
401
- xz,
402
- conv1d_weight,
403
- conv1d_bias,
404
- x_dbl,
405
- x_proj_weight,
406
- delta_proj_weight,
407
- out_proj_weight,
408
- conv1d_out,
409
- delta,
410
- A,
411
- B,
412
- C,
413
- D,
414
- delta_bias,
415
- scan_intermediates,
416
- b_rms_weight,
417
- c_rms_weight,
418
- dt_rms_weight,
419
- out,
420
- ) = ctx.saved_tensors
421
- L = xz.shape[-1]
422
- delta_rank = delta_proj_weight.shape[1]
423
- d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
424
- x, z = xz.chunk(2, dim=1)
425
- if dout.stride(-1) != 1:
426
- dout = dout.contiguous()
427
- if ctx.checkpoint_lvl == 1:
428
- conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
429
- x, conv1d_weight, conv1d_bias, None, None, None, True
430
- )
431
- delta = rearrange(
432
- delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
433
- )
434
- if dt_rms_weight is not None:
435
- delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
436
- delta = rms_norm_forward(
437
- delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps
438
- )
439
- delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
440
- if b_rms_weight is not None:
441
- # Recompute & RMSNorm B
442
- B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
443
- B = rms_norm_forward(B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps)
444
- B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
445
- if c_rms_weight is not None:
446
- # Recompute & RMSNorm C
447
- C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
448
- C = rms_norm_forward(C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps)
449
- C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
450
-
451
- # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
452
- # backward of selective_scan_cuda with the backward of chunk).
453
- dxz = torch.empty_like(xz) # (batch, dim, seqlen)
454
- dx, dz = dxz.chunk(2, dim=1)
455
- dout = rearrange(dout, "b l e -> e (b l)")
456
- dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
457
- dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = (
458
- ops.selective_scan_bwd(
459
- conv1d_out,
460
- delta,
461
- A,
462
- B,
463
- C,
464
- D,
465
- z,
466
- delta_bias,
467
- dout_y,
468
- scan_intermediates,
469
- out,
470
- dz,
471
- ctx.delta_softplus,
472
- True, # option to recompute out_z
473
- )
474
- )
475
- dout_proj_weight = torch.einsum(
476
- "eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")
477
- )
478
- dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
479
- dD = dD if D is not None else None
480
- dx_dbl = torch.empty_like(x_dbl)
481
- dB_proj_bias = None
482
- if ctx.is_variable_B:
483
- if not A.is_complex():
484
- dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
485
- else:
486
- dB = rearrange(
487
- dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
488
- ).contiguous()
489
- dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
490
- dx_dbl[:, delta_rank : delta_rank + d_state] = dB # (bl d)
491
- dB = None
492
- dC_proj_bias = None
493
- if ctx.is_variable_C:
494
- if not A.is_complex():
495
- dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
496
- else:
497
- dC = rearrange(
498
- dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
499
- ).contiguous()
500
- dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
501
- dx_dbl[:, -d_state:] = dC # (bl d)
502
- dC = None
503
- ddelta = rearrange(ddelta, "b d l -> d (b l)")
504
- ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
505
- dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
506
- dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
507
- dx_proj_weight = torch.einsum(
508
- "Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")
509
- )
510
- dconv1d_out = torch.addmm(
511
- dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out
512
- )
513
- dconv1d_out = rearrange(
514
- dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]
515
- )
516
- # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
517
- # backward of conv1d with the backward of chunk).
518
- dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
519
- x,
520
- conv1d_weight,
521
- conv1d_bias,
522
- dconv1d_out,
523
- None,
524
- None,
525
- None,
526
- dx,
527
- False,
528
- True,
529
- )
530
- dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
531
- dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
532
- return (
533
- dxz,
534
- dconv1d_weight,
535
- dconv1d_bias,
536
- dx_proj_weight,
537
- ddelta_proj_weight,
538
- dout_proj_weight,
539
- dout_proj_bias,
540
- dA,
541
- dB,
542
- dC,
543
- dD,
544
- ddelta_bias if delta_bias is not None else None,
545
- # 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
546
- dB_proj_bias,
547
- dC_proj_bias,
548
- None,
549
- None,
550
- None,
551
- None,
552
- None,
553
- None,
554
- )
555
-
556
-
557
- def mamba_inner_fn(
558
- xz,
559
- conv1d_weight,
560
- conv1d_bias,
561
- x_proj_weight,
562
- delta_proj_weight,
563
- out_proj_weight,
564
- out_proj_bias,
565
- A,
566
- B=None,
567
- C=None,
568
- D=None,
569
- delta_bias=None,
570
- B_proj_bias=None,
571
- C_proj_bias=None,
572
- delta_softplus=True,
573
- checkpoint_lvl=1,
574
- b_rms_weight=None,
575
- c_rms_weight=None,
576
- dt_rms_weight=None,
577
- b_c_dt_rms_eps=1e-6,
578
- ):
579
- return MambaInnerFn.apply(
580
- xz,
581
- conv1d_weight,
582
- conv1d_bias,
583
- x_proj_weight,
584
- delta_proj_weight,
585
- out_proj_weight,
586
- out_proj_bias,
587
- A,
588
- B,
589
- C,
590
- D,
591
- delta_bias,
592
- B_proj_bias,
593
- C_proj_bias,
594
- delta_softplus,
595
- checkpoint_lvl,
596
- b_rms_weight,
597
- c_rms_weight,
598
- dt_rms_weight,
599
- b_c_dt_rms_eps,
600
- )
601
-
602
-
603
- def mamba_inner_ref(
604
- xz,
605
- conv1d_weight,
606
- conv1d_bias,
607
- x_proj_weight,
608
- delta_proj_weight,
609
- out_proj_weight,
610
- out_proj_bias,
611
- A,
612
- B=None,
613
- C=None,
614
- D=None,
615
- delta_bias=None,
616
- B_proj_bias=None,
617
- C_proj_bias=None,
618
- delta_softplus=True,
619
- ):
620
- assert (
621
- causal_conv1d_fn is not None
622
- ), "causal_conv1d_fn is not available. Please install causal-conv1d."
623
- L = xz.shape[-1]
624
- delta_rank = delta_proj_weight.shape[1]
625
- d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
626
- x, z = xz.chunk(2, dim=1)
627
- x = causal_conv1d_fn(
628
- x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu"
629
- )
630
- # We're being very careful here about the layout, to avoid extra transposes.
631
- # We want delta to have d as the slowest moving dimension
632
- # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
633
- x_dbl = F.linear(rearrange(x, "b d l -> (b l) d"), x_proj_weight) # (bl d)
634
- delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
635
- delta = rearrange(delta, "d (b l) -> b d l", l=L)
636
- if B is None: # variable B
637
- B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl d)
638
- if B_proj_bias is not None:
639
- B = B + B_proj_bias.to(dtype=B.dtype)
640
- if not A.is_complex():
641
- B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
642
- else:
643
- B = rearrange(
644
- B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
645
- ).contiguous()
646
- if C is None: # variable B
647
- C = x_dbl[:, -d_state:] # (bl d)
648
- if C_proj_bias is not None:
649
- C = C + C_proj_bias.to(dtype=C.dtype)
650
- if not A.is_complex():
651
- C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
652
- else:
653
- C = rearrange(
654
- C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
655
- ).contiguous()
656
- y = selective_scan_fn(
657
- x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True
658
- )
659
- return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
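
Because selective_scan_ref above is pure PyTorch, it is the easiest way to sanity-check shapes for the fused selective_scan_fn. A minimal sketch with illustrative sizes, exercising the variable (input-dependent) B and C path from the docstring; it assumes the package and its compiled ops import cleanly:

    import torch
    from mamba_ssm.ops.selective_scan_interface import selective_scan_ref

    batch, dim, dstate, seqlen = 2, 64, 16, 128   # illustrative sizes
    u = torch.randn(batch, dim, seqlen)
    delta = torch.rand(batch, dim, seqlen)
    A = -torch.rand(dim, dstate)                  # real-valued A: r(D N)
    B = torch.randn(batch, dstate, seqlen)        # variable B: r(B N L)
    C = torch.randn(batch, dstate, seqlen)        # variable C: r(B N L)
    D = torch.randn(dim)
    out, last_state = selective_scan_ref(
        u, delta, A, B, C, D=D, delta_softplus=True, return_last_state=True
    )
    print(out.shape, last_state.shape)            # (2, 64, 128) (2, 64, 16)
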
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/__init__.py DELETED
File without changes
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/layer_norm.py DELETED
@@ -1,1166 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao.
2
- # Implement dropout + residual + layer_norm / rms_norm.
3
-
4
- # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
5
- # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
6
- # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
7
- # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
-
9
- import math
10
- import warnings
11
-
12
- import torch
13
- import torch.nn.functional as F
14
- from ...utils.torch import custom_bwd, custom_fwd
15
-
16
- import triton
17
- import triton.language as tl
18
-
19
-
20
- def layer_norm_ref(
21
- x,
22
- weight,
23
- bias,
24
- residual=None,
25
- x1=None,
26
- weight1=None,
27
- bias1=None,
28
- eps=1e-6,
29
- dropout_p=0.0,
30
- rowscale=None,
31
- prenorm=False,
32
- dropout_mask=None,
33
- dropout_mask1=None,
34
- upcast=False,
35
- ):
36
- dtype = x.dtype
37
- if upcast:
38
- x = x.float()
39
- weight = weight.float()
40
- bias = bias.float() if bias is not None else None
41
- residual = residual.float() if residual is not None else residual
42
- x1 = x1.float() if x1 is not None else None
43
- weight1 = weight1.float() if weight1 is not None else None
44
- bias1 = bias1.float() if bias1 is not None else None
45
- if x1 is not None:
46
- assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
47
- if rowscale is not None:
48
- x = x * rowscale[..., None]
49
- if dropout_p > 0.0:
50
- if dropout_mask is not None:
51
- x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
52
- else:
53
- x = F.dropout(x, p=dropout_p)
54
- if x1 is not None:
55
- if dropout_mask1 is not None:
56
- x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
57
- else:
58
- x1 = F.dropout(x1, p=dropout_p)
59
- if x1 is not None:
60
- x = x + x1
61
- if residual is not None:
62
- x = (x + residual).to(x.dtype)
63
- out = F.layer_norm(
64
- x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
65
- ).to(dtype)
66
- if weight1 is None:
67
- return out if not prenorm else (out, x)
68
- else:
69
- out1 = F.layer_norm(
70
- x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
71
- ).to(dtype)
72
- return (out, out1) if not prenorm else (out, out1, x)
73
-
74
-
75
- def rms_norm_ref(
76
- x,
77
- weight,
78
- bias,
79
- residual=None,
80
- x1=None,
81
- weight1=None,
82
- bias1=None,
83
- eps=1e-6,
84
- dropout_p=0.0,
85
- rowscale=None,
86
- prenorm=False,
87
- dropout_mask=None,
88
- dropout_mask1=None,
89
- upcast=False,
90
- ):
91
- dtype = x.dtype
92
- if upcast:
93
- x = x.float()
94
- weight = weight.float()
95
- bias = bias.float() if bias is not None else None
96
- residual = residual.float() if residual is not None else residual
97
- x1 = x1.float() if x1 is not None else None
98
- weight1 = weight1.float() if weight1 is not None else None
99
- bias1 = bias1.float() if bias1 is not None else None
100
- if x1 is not None:
101
- assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
102
- if rowscale is not None:
103
- x = x * rowscale[..., None]
104
- if dropout_p > 0.0:
105
- if dropout_mask is not None:
106
- x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
107
- else:
108
- x = F.dropout(x, p=dropout_p)
109
- if x1 is not None:
110
- if dropout_mask1 is not None:
111
- x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
112
- else:
113
- x1 = F.dropout(x1, p=dropout_p)
114
- if x1 is not None:
115
- x = x + x1
116
- if residual is not None:
117
- x = (x + residual).to(x.dtype)
118
- rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
119
- out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(
120
- dtype
121
- )
122
- if weight1 is None:
123
- return out if not prenorm else (out, x)
124
- else:
125
- out1 = (
126
- (x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)
127
- ).to(dtype)
128
- return (out, out1) if not prenorm else (out, out1, x)
129
-
130
-
131
- def config_prune(configs):
132
-
133
- if torch.version.hip:
134
- try:
135
- # set warp size based on gcn architecure
136
- gcn_arch_name = torch.cuda.get_device_properties(0).gcnArchName
137
- if "gfx10" in gcn_arch_name or "gfx11" in gcn_arch_name:
138
- # radeon
139
- warp_size = 32
140
- else:
141
- # instinct
142
- warp_size = 64
143
- except AttributeError as e:
144
- # fall back to crude method to set warp size
145
- device_name = torch.cuda.get_device_properties(0).name
146
- if "instinct" in device_name.lower():
147
- warp_size = 64
148
- else:
149
- warp_size = 32
150
- warnings.warn(
151
- f"{e}, warp size set to {warp_size} based on device name: {device_name}",
152
- UserWarning,
153
- )
154
-
155
- else:
156
- # cuda
157
- warp_size = 32
158
-
159
- max_block_sz = 1024
160
- max_num_warps = max_block_sz // warp_size
161
- pruned_configs = [config for config in configs if config.num_warps <= max_num_warps]
162
- return pruned_configs
163
-
164
-
165
- configs_autotune = [
166
- triton.Config({}, num_warps=1),
167
- triton.Config({}, num_warps=2),
168
- triton.Config({}, num_warps=4),
169
- triton.Config({}, num_warps=8),
170
- triton.Config({}, num_warps=16),
171
- triton.Config({}, num_warps=32),
172
- ]
173
-
174
- pruned_configs_autotune = config_prune(configs_autotune)
175
-
176
-
177
- @triton.autotune(
178
- configs=pruned_configs_autotune,
179
- key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
180
- )
181
- # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
182
- # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
183
- @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
184
- @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
185
- @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
186
- @triton.jit
187
- def _layer_norm_fwd_1pass_kernel(
188
- X, # pointer to the input
189
- Y, # pointer to the output
190
- W, # pointer to the weights
191
- B, # pointer to the biases
192
- RESIDUAL, # pointer to the residual
193
- X1,
194
- W1,
195
- B1,
196
- Y1,
197
- RESIDUAL_OUT, # pointer to the residual
198
- ROWSCALE,
199
- SEEDS, # Dropout seeds for each row
200
- DROPOUT_MASK,
201
- Mean, # pointer to the mean
202
- Rstd, # pointer to the 1/std
203
- stride_x_row, # how much to increase the pointer when moving by 1 row
204
- stride_y_row,
205
- stride_res_row,
206
- stride_res_out_row,
207
- stride_x1_row,
208
- stride_y1_row,
209
- M, # number of rows in X
210
- N, # number of columns in X
211
- eps, # epsilon to avoid division by zero
212
- dropout_p, # Dropout probability
213
- IS_RMS_NORM: tl.constexpr,
214
- BLOCK_N: tl.constexpr,
215
- HAS_RESIDUAL: tl.constexpr,
216
- STORE_RESIDUAL_OUT: tl.constexpr,
217
- HAS_BIAS: tl.constexpr,
218
- HAS_DROPOUT: tl.constexpr,
219
- STORE_DROPOUT_MASK: tl.constexpr,
220
- HAS_ROWSCALE: tl.constexpr,
221
- HAS_X1: tl.constexpr,
222
- HAS_W1: tl.constexpr,
223
- HAS_B1: tl.constexpr,
224
- ):
225
- # Map the program id to the row of X and Y it should compute.
226
- row = tl.program_id(0)
227
- X += row * stride_x_row
228
- Y += row * stride_y_row
229
- if HAS_RESIDUAL:
230
- RESIDUAL += row * stride_res_row
231
- if STORE_RESIDUAL_OUT:
232
- RESIDUAL_OUT += row * stride_res_out_row
233
- if HAS_X1:
234
- X1 += row * stride_x1_row
235
- if HAS_W1:
236
- Y1 += row * stride_y1_row
237
- # Compute mean and variance
238
- cols = tl.arange(0, BLOCK_N)
239
- x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
240
- if HAS_ROWSCALE:
241
- rowscale = tl.load(ROWSCALE + row).to(tl.float32)
242
- x *= rowscale
243
- if HAS_DROPOUT:
244
- # Compute dropout mask
245
- # 7 rounds is good enough, and reduces register pressure
246
- keep_mask = (
247
- tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
248
- )
249
- x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
250
- if STORE_DROPOUT_MASK:
251
- tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
252
- if HAS_X1:
253
- x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
254
- if HAS_ROWSCALE:
255
- rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
256
- x1 *= rowscale
257
- if HAS_DROPOUT:
258
- # Compute dropout mask
259
- # 7 rounds is good enough, and reduces register pressure
260
- keep_mask = (
261
- tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
262
- > dropout_p
263
- )
264
- x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
265
- if STORE_DROPOUT_MASK:
266
- tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
267
- x += x1
268
- if HAS_RESIDUAL:
269
- residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
270
- x += residual
271
- if STORE_RESIDUAL_OUT:
272
- tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
273
- if not IS_RMS_NORM:
274
- mean = tl.sum(x, axis=0) / N
275
- tl.store(Mean + row, mean)
276
- xbar = tl.where(cols < N, x - mean, 0.0)
277
- var = tl.sum(xbar * xbar, axis=0) / N
278
- else:
279
- xbar = tl.where(cols < N, x, 0.0)
280
- var = tl.sum(xbar * xbar, axis=0) / N
281
- rstd = 1 / tl.sqrt(var + eps)
282
- tl.store(Rstd + row, rstd)
283
- # Normalize and apply linear transformation
284
- mask = cols < N
285
- w = tl.load(W + cols, mask=mask).to(tl.float32)
286
- if HAS_BIAS:
287
- b = tl.load(B + cols, mask=mask).to(tl.float32)
288
- x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
289
- y = x_hat * w + b if HAS_BIAS else x_hat * w
290
- # Write output
291
- tl.store(Y + cols, y, mask=mask)
292
- if HAS_W1:
293
- w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
294
- if HAS_B1:
295
- b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
296
- y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
297
- tl.store(Y1 + cols, y1, mask=mask)
298
-
299
-
300
- def _layer_norm_fwd(
301
- x,
302
- weight,
303
- bias,
304
- eps,
305
- residual=None,
306
- x1=None,
307
- weight1=None,
308
- bias1=None,
309
- dropout_p=0.0,
310
- rowscale=None,
311
- out_dtype=None,
312
- residual_dtype=None,
313
- is_rms_norm=False,
314
- return_dropout_mask=False,
315
- ):
316
- if residual is not None:
317
- residual_dtype = residual.dtype
318
- M, N = x.shape
319
- assert x.stride(-1) == 1
320
- if residual is not None:
321
- assert residual.stride(-1) == 1
322
- assert residual.shape == (M, N)
323
- assert weight.shape == (N,)
324
- assert weight.stride(-1) == 1
325
- if bias is not None:
326
- assert bias.stride(-1) == 1
327
- assert bias.shape == (N,)
328
- if x1 is not None:
329
- assert x1.shape == x.shape
330
- assert rowscale is None
331
- assert x1.stride(-1) == 1
332
- if weight1 is not None:
333
- assert weight1.shape == (N,)
334
- assert weight1.stride(-1) == 1
335
- if bias1 is not None:
336
- assert bias1.shape == (N,)
337
- assert bias1.stride(-1) == 1
338
- if rowscale is not None:
339
- assert rowscale.is_contiguous()
340
- assert rowscale.shape == (M,)
341
- # allocate output
342
- y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
343
- assert y.stride(-1) == 1
344
- if weight1 is not None:
345
- y1 = torch.empty_like(y)
346
- assert y1.stride(-1) == 1
347
- else:
348
- y1 = None
349
- if (
350
- residual is not None
351
- or (residual_dtype is not None and residual_dtype != x.dtype)
352
- or dropout_p > 0.0
353
- or rowscale is not None
354
- or x1 is not None
355
- ):
356
- residual_out = torch.empty(
357
- M,
358
- N,
359
- device=x.device,
360
- dtype=residual_dtype if residual_dtype is not None else x.dtype,
361
- )
362
- assert residual_out.stride(-1) == 1
363
- else:
364
- residual_out = None
365
- mean = (
366
- torch.empty((M,), dtype=torch.float32, device=x.device)
367
- if not is_rms_norm
368
- else None
369
- )
370
- rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
371
- if dropout_p > 0.0:
372
- seeds = torch.randint(
373
- 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
374
- )
375
- else:
376
- seeds = None
377
- if return_dropout_mask and dropout_p > 0.0:
378
- dropout_mask = torch.empty(
379
- M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool
380
- )
381
- else:
382
- dropout_mask = None
383
- # Less than 64KB per feature: enqueue fused kernel
384
- MAX_FUSED_SIZE = 65536 // x.element_size()
385
- BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
386
- if N > BLOCK_N:
387
- raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
388
- with torch.cuda.device(x.device.index):
389
- _layer_norm_fwd_1pass_kernel[(M,)](
390
- x,
391
- y,
392
- weight,
393
- bias,
394
- residual,
395
- x1,
396
- weight1,
397
- bias1,
398
- y1,
399
- residual_out,
400
- rowscale,
401
- seeds,
402
- dropout_mask,
403
- mean,
404
- rstd,
405
- x.stride(0),
406
- y.stride(0),
407
- residual.stride(0) if residual is not None else 0,
408
- residual_out.stride(0) if residual_out is not None else 0,
409
- x1.stride(0) if x1 is not None else 0,
410
- y1.stride(0) if y1 is not None else 0,
411
- M,
412
- N,
413
- eps,
414
- dropout_p,
415
- is_rms_norm,
416
- BLOCK_N,
417
- residual is not None,
418
- residual_out is not None,
419
- bias is not None,
420
- dropout_p > 0.0,
421
- dropout_mask is not None,
422
- rowscale is not None,
423
- )
424
- # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
425
- if dropout_mask is not None and x1 is not None:
426
- dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
427
- else:
428
- dropout_mask1 = None
429
- return (
430
- y,
431
- y1,
432
- mean,
433
- rstd,
434
- residual_out if residual_out is not None else x,
435
- seeds,
436
- dropout_mask,
437
- dropout_mask1,
438
- )
439
-
440
-
441
- @triton.autotune(
442
- configs=pruned_configs_autotune,
443
- key=[
444
- "N",
445
- "HAS_DRESIDUAL",
446
- "STORE_DRESIDUAL",
447
- "IS_RMS_NORM",
448
- "HAS_BIAS",
449
- "HAS_DROPOUT",
450
- ],
451
- )
452
- # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
453
- # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
454
- # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
455
- @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
456
- @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
457
- @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
458
- @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
459
- @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
460
- @triton.jit
461
- def _layer_norm_bwd_kernel(
462
- X, # pointer to the input
463
- W, # pointer to the weights
464
- B, # pointer to the biases
465
- Y, # pointer to the output to be recomputed
466
- DY, # pointer to the output gradient
467
- DX, # pointer to the input gradient
468
- DW, # pointer to the partial sum of weights gradient
469
- DB, # pointer to the partial sum of biases gradient
470
- DRESIDUAL,
471
- W1,
472
- DY1,
473
- DX1,
474
- DW1,
475
- DB1,
476
- DRESIDUAL_IN,
477
- ROWSCALE,
478
- SEEDS,
479
- Mean, # pointer to the mean
480
- Rstd, # pointer to the 1/std
481
- stride_x_row, # how much to increase the pointer when moving by 1 row
482
- stride_y_row,
483
- stride_dy_row,
484
- stride_dx_row,
485
- stride_dres_row,
486
- stride_dy1_row,
487
- stride_dx1_row,
488
- stride_dres_in_row,
489
- M, # number of rows in X
490
- N, # number of columns in X
491
- eps, # epsilon to avoid division by zero
492
- dropout_p,
493
- rows_per_program,
494
- IS_RMS_NORM: tl.constexpr,
495
- BLOCK_N: tl.constexpr,
496
- HAS_DRESIDUAL: tl.constexpr,
497
- STORE_DRESIDUAL: tl.constexpr,
498
- HAS_BIAS: tl.constexpr,
499
- HAS_DROPOUT: tl.constexpr,
500
- HAS_ROWSCALE: tl.constexpr,
501
- HAS_DY1: tl.constexpr,
502
- HAS_DX1: tl.constexpr,
503
- HAS_B1: tl.constexpr,
504
- RECOMPUTE_OUTPUT: tl.constexpr,
505
- ):
506
- # Map the program id to the elements of X, DX, and DY it should compute.
507
- row_block_id = tl.program_id(0)
508
- row_start = row_block_id * rows_per_program
509
- # Do not early exit if row_start >= M, because we need to write DW and DB
510
- cols = tl.arange(0, BLOCK_N)
511
- mask = cols < N
512
- X += row_start * stride_x_row
513
- if HAS_DRESIDUAL:
514
- DRESIDUAL += row_start * stride_dres_row
515
- if STORE_DRESIDUAL:
516
- DRESIDUAL_IN += row_start * stride_dres_in_row
517
- DY += row_start * stride_dy_row
518
- DX += row_start * stride_dx_row
519
- if HAS_DY1:
520
- DY1 += row_start * stride_dy1_row
521
- if HAS_DX1:
522
- DX1 += row_start * stride_dx1_row
523
- if RECOMPUTE_OUTPUT:
524
- Y += row_start * stride_y_row
525
- w = tl.load(W + cols, mask=mask).to(tl.float32)
526
- if RECOMPUTE_OUTPUT and HAS_BIAS:
527
- b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
528
- if HAS_DY1:
529
- w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
530
- dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
531
- if HAS_BIAS:
532
- db = tl.zeros((BLOCK_N,), dtype=tl.float32)
533
- if HAS_DY1:
534
- dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
535
- if HAS_B1:
536
- db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
537
- row_end = min((row_block_id + 1) * rows_per_program, M)
538
- for row in range(row_start, row_end):
539
- # Load data to SRAM
540
- x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
541
- dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
542
- if HAS_DY1:
543
- dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
544
- if not IS_RMS_NORM:
545
- mean = tl.load(Mean + row)
546
- rstd = tl.load(Rstd + row)
547
- # Compute dx
548
- xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
549
- xhat = tl.where(mask, xhat, 0.0)
550
- if RECOMPUTE_OUTPUT:
551
- y = xhat * w + b if HAS_BIAS else xhat * w
552
- tl.store(Y + cols, y, mask=mask)
553
- wdy = w * dy
554
- dw += dy * xhat
555
- if HAS_BIAS:
556
- db += dy
557
- if HAS_DY1:
558
- wdy += w1 * dy1
559
- dw1 += dy1 * xhat
560
- if HAS_B1:
561
- db1 += dy1
562
- if not IS_RMS_NORM:
563
- c1 = tl.sum(xhat * wdy, axis=0) / N
564
- c2 = tl.sum(wdy, axis=0) / N
565
- dx = (wdy - (xhat * c1 + c2)) * rstd
566
- else:
567
- c1 = tl.sum(xhat * wdy, axis=0) / N
568
- dx = (wdy - xhat * c1) * rstd
569
- if HAS_DRESIDUAL:
570
- dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
571
- dx += dres
572
- # Write dx
573
- if STORE_DRESIDUAL:
574
- tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
575
- if HAS_DX1:
576
- if HAS_DROPOUT:
577
- keep_mask = (
578
- tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
579
- > dropout_p
580
- )
581
- dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
582
- else:
583
- dx1 = dx
584
- tl.store(DX1 + cols, dx1, mask=mask)
585
- if HAS_DROPOUT:
586
- keep_mask = (
587
- tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7)
588
- > dropout_p
589
- )
590
- dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
591
- if HAS_ROWSCALE:
592
- rowscale = tl.load(ROWSCALE + row).to(tl.float32)
593
- dx *= rowscale
594
- tl.store(DX + cols, dx, mask=mask)
595
-
596
- X += stride_x_row
597
- if HAS_DRESIDUAL:
598
- DRESIDUAL += stride_dres_row
599
- if STORE_DRESIDUAL:
600
- DRESIDUAL_IN += stride_dres_in_row
601
- if RECOMPUTE_OUTPUT:
602
- Y += stride_y_row
603
- DY += stride_dy_row
604
- DX += stride_dx_row
605
- if HAS_DY1:
606
- DY1 += stride_dy1_row
607
- if HAS_DX1:
608
- DX1 += stride_dx1_row
609
- tl.store(DW + row_block_id * N + cols, dw, mask=mask)
610
- if HAS_BIAS:
611
- tl.store(DB + row_block_id * N + cols, db, mask=mask)
612
- if HAS_DY1:
613
- tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
614
- if HAS_B1:
615
- tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
616
-
617
-
618
- def _layer_norm_bwd(
619
- dy,
620
- x,
621
- weight,
622
- bias,
623
- eps,
624
- mean,
625
- rstd,
626
- dresidual=None,
627
- dy1=None,
628
- weight1=None,
629
- bias1=None,
630
- seeds=None,
631
- dropout_p=0.0,
632
- rowscale=None,
633
- has_residual=False,
634
- has_x1=False,
635
- is_rms_norm=False,
636
- x_dtype=None,
637
- recompute_output=False,
638
- ):
639
- M, N = x.shape
640
- assert x.stride(-1) == 1
641
- assert dy.stride(-1) == 1
642
- assert dy.shape == (M, N)
643
- if dresidual is not None:
644
- assert dresidual.stride(-1) == 1
645
- assert dresidual.shape == (M, N)
646
- assert weight.shape == (N,)
647
- assert weight.stride(-1) == 1
648
- if bias is not None:
649
- assert bias.stride(-1) == 1
650
- assert bias.shape == (N,)
651
- if dy1 is not None:
652
- assert weight1 is not None
653
- assert dy1.shape == dy.shape
654
- assert dy1.stride(-1) == 1
655
- if weight1 is not None:
656
- assert weight1.shape == (N,)
657
- assert weight1.stride(-1) == 1
658
- if bias1 is not None:
659
- assert bias1.shape == (N,)
660
- assert bias1.stride(-1) == 1
661
- if seeds is not None:
662
- assert seeds.is_contiguous()
663
- assert seeds.shape == (M if not has_x1 else M * 2,)
664
- if rowscale is not None:
665
- assert rowscale.is_contiguous()
666
- assert rowscale.shape == (M,)
667
- # allocate output
668
- dx = (
669
- torch.empty_like(x)
670
- if x_dtype is None
671
- else torch.empty(M, N, dtype=x_dtype, device=x.device)
672
- )
673
- dresidual_in = (
674
- torch.empty_like(x)
675
- if has_residual
676
- and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
677
- else None
678
- )
679
- dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
680
- y = (
681
- torch.empty(M, N, dtype=dy.dtype, device=dy.device)
682
- if recompute_output
683
- else None
684
- )
685
- if recompute_output:
686
- assert (
687
- weight1 is None
688
- ), "recompute_output is not supported with parallel LayerNorm"
689
-
690
- # Less than 64KB per feature: enqueue fused kernel
691
- MAX_FUSED_SIZE = 65536 // x.element_size()
692
- BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
693
- if N > BLOCK_N:
694
- raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
695
- sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
696
- _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
697
- _db = (
698
- torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
699
- if bias is not None
700
- else None
701
- )
702
- _dw1 = torch.empty_like(_dw) if weight1 is not None else None
703
- _db1 = torch.empty_like(_db) if bias1 is not None else None
704
- rows_per_program = math.ceil(M / sm_count)
705
- grid = (sm_count,)
706
- with torch.cuda.device(x.device.index):
707
- _layer_norm_bwd_kernel[grid](
708
- x,
709
- weight,
710
- bias,
711
- y,
712
- dy,
713
- dx,
714
- _dw,
715
- _db,
716
- dresidual,
717
- weight1,
718
- dy1,
719
- dx1,
720
- _dw1,
721
- _db1,
722
- dresidual_in,
723
- rowscale,
724
- seeds,
725
- mean,
726
- rstd,
727
- x.stride(0),
728
- 0 if not recompute_output else y.stride(0),
729
- dy.stride(0),
730
- dx.stride(0),
731
- dresidual.stride(0) if dresidual is not None else 0,
732
- dy1.stride(0) if dy1 is not None else 0,
733
- dx1.stride(0) if dx1 is not None else 0,
734
- dresidual_in.stride(0) if dresidual_in is not None else 0,
735
- M,
736
- N,
737
- eps,
738
- dropout_p,
739
- rows_per_program,
740
- is_rms_norm,
741
- BLOCK_N,
742
- dresidual is not None,
743
- dresidual_in is not None,
744
- bias is not None,
745
- dropout_p > 0.0,
746
- )
747
- dw = _dw.sum(0).to(weight.dtype)
748
- db = _db.sum(0).to(bias.dtype) if bias is not None else None
749
- dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
750
- db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
751
- # Don't need to compute dresidual_in separately in this case
752
- if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
753
- dresidual_in = dx
754
- if has_x1 and dropout_p == 0.0:
755
- dx1 = dx
756
- return (
757
- (dx, dw, db, dresidual_in, dx1, dw1, db1)
758
- if not recompute_output
759
- else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
760
- )
761
-
762
-
763
- class LayerNormFn(torch.autograd.Function):
764
- @staticmethod
765
- def forward(
766
- ctx,
767
- x,
768
- weight,
769
- bias,
770
- residual=None,
771
- x1=None,
772
- weight1=None,
773
- bias1=None,
774
- eps=1e-6,
775
- dropout_p=0.0,
776
- rowscale=None,
777
- prenorm=False,
778
- residual_in_fp32=False,
779
- is_rms_norm=False,
780
- return_dropout_mask=False,
781
- ):
782
- x_shape_og = x.shape
783
- # reshape input data into 2D tensor
784
- x = x.reshape(-1, x.shape[-1])
785
- if x.stride(-1) != 1:
786
- x = x.contiguous()
787
- if residual is not None:
788
- assert residual.shape == x_shape_og
789
- residual = residual.reshape(-1, residual.shape[-1])
790
- if residual.stride(-1) != 1:
791
- residual = residual.contiguous()
792
- if x1 is not None:
793
- assert x1.shape == x_shape_og
794
- assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
795
- x1 = x1.reshape(-1, x1.shape[-1])
796
- if x1.stride(-1) != 1:
797
- x1 = x1.contiguous()
798
- weight = weight.contiguous()
799
- if bias is not None:
800
- bias = bias.contiguous()
801
- if weight1 is not None:
802
- weight1 = weight1.contiguous()
803
- if bias1 is not None:
804
- bias1 = bias1.contiguous()
805
- if rowscale is not None:
806
- rowscale = rowscale.reshape(-1).contiguous()
807
- residual_dtype = (
808
- residual.dtype
809
- if residual is not None
810
- else (torch.float32 if residual_in_fp32 else None)
811
- )
812
- y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = (
813
- _layer_norm_fwd(
814
- x,
815
- weight,
816
- bias,
817
- eps,
818
- residual,
819
- x1,
820
- weight1,
821
- bias1,
822
- dropout_p=dropout_p,
823
- rowscale=rowscale,
824
- residual_dtype=residual_dtype,
825
- is_rms_norm=is_rms_norm,
826
- return_dropout_mask=return_dropout_mask,
827
- )
828
- )
829
- ctx.save_for_backward(
830
- residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
831
- )
832
- ctx.x_shape_og = x_shape_og
833
- ctx.eps = eps
834
- ctx.dropout_p = dropout_p
835
- ctx.is_rms_norm = is_rms_norm
836
- ctx.has_residual = residual is not None
837
- ctx.has_x1 = x1 is not None
838
- ctx.prenorm = prenorm
839
- ctx.x_dtype = x.dtype
840
- y = y.reshape(x_shape_og)
841
- y1 = y1.reshape(x_shape_og) if y1 is not None else None
842
- residual_out = (
843
- residual_out.reshape(x_shape_og) if residual_out is not None else None
844
- )
845
- dropout_mask = (
846
- dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
847
- )
848
- dropout_mask1 = (
849
- dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
850
- )
851
- if not return_dropout_mask:
852
- if weight1 is None:
853
- return y if not prenorm else (y, residual_out)
854
- else:
855
- return (y, y1) if not prenorm else (y, y1, residual_out)
856
- else:
857
- if weight1 is None:
858
- return (
859
- (y, dropout_mask, dropout_mask1)
860
- if not prenorm
861
- else (y, residual_out, dropout_mask, dropout_mask1)
862
- )
863
- else:
864
- return (
865
- (y, y1, dropout_mask, dropout_mask1)
866
- if not prenorm
867
- else (y, y1, residual_out, dropout_mask, dropout_mask1)
868
- )
869
-
870
- @staticmethod
871
- def backward(ctx, dy, *args):
872
- x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
873
- dy = dy.reshape(-1, dy.shape[-1])
874
- if dy.stride(-1) != 1:
875
- dy = dy.contiguous()
876
- assert dy.shape == x.shape
877
- if weight1 is not None:
878
- dy1, args = args[0], args[1:]
879
- dy1 = dy1.reshape(-1, dy1.shape[-1])
880
- if dy1.stride(-1) != 1:
881
- dy1 = dy1.contiguous()
882
- assert dy1.shape == x.shape
883
- else:
884
- dy1 = None
885
- if ctx.prenorm:
886
- dresidual = args[0]
887
- dresidual = dresidual.reshape(-1, dresidual.shape[-1])
888
- if dresidual.stride(-1) != 1:
889
- dresidual = dresidual.contiguous()
890
- assert dresidual.shape == x.shape
891
- else:
892
- dresidual = None
893
- dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
894
- dy,
895
- x,
896
- weight,
897
- bias,
898
- ctx.eps,
899
- mean,
900
- rstd,
901
- dresidual,
902
- dy1,
903
- weight1,
904
- bias1,
905
- seeds,
906
- ctx.dropout_p,
907
- rowscale,
908
- ctx.has_residual,
909
- ctx.has_x1,
910
- ctx.is_rms_norm,
911
- x_dtype=ctx.x_dtype,
912
- )
913
- return (
914
- dx.reshape(ctx.x_shape_og),
915
- dw,
916
- db,
917
- dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
918
- dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
919
- dw1,
920
- db1,
921
- None,
922
- None,
923
- None,
924
- None,
925
- None,
926
- None,
927
- None,
928
- )
929
-
930
-
931
- def layer_norm_fn(
932
- x,
933
- weight,
934
- bias,
935
- residual=None,
936
- x1=None,
937
- weight1=None,
938
- bias1=None,
939
- eps=1e-6,
940
- dropout_p=0.0,
941
- rowscale=None,
942
- prenorm=False,
943
- residual_in_fp32=False,
944
- is_rms_norm=False,
945
- return_dropout_mask=False,
946
- ):
947
- return LayerNormFn.apply(
948
- x,
949
- weight,
950
- bias,
951
- residual,
952
- x1,
953
- weight1,
954
- bias1,
955
- eps,
956
- dropout_p,
957
- rowscale,
958
- prenorm,
959
- residual_in_fp32,
960
- is_rms_norm,
961
- return_dropout_mask,
962
- )
963
-
964
-
965
- def rms_norm_fn(
966
- x,
967
- weight,
968
- bias,
969
- residual=None,
970
- x1=None,
971
- weight1=None,
972
- bias1=None,
973
- eps=1e-6,
974
- dropout_p=0.0,
975
- rowscale=None,
976
- prenorm=False,
977
- residual_in_fp32=False,
978
- return_dropout_mask=False,
979
- ):
980
- return LayerNormFn.apply(
981
- x,
982
- weight,
983
- bias,
984
- residual,
985
- x1,
986
- weight1,
987
- bias1,
988
- eps,
989
- dropout_p,
990
- rowscale,
991
- prenorm,
992
- residual_in_fp32,
993
- True,
994
- return_dropout_mask,
995
- )
996
-
997
-
998
- class RMSNorm(torch.nn.Module):
999
-
1000
- def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
1001
- factory_kwargs = {"device": device, "dtype": dtype}
1002
- super().__init__()
1003
- self.eps = eps
1004
- if dropout_p > 0.0:
1005
- self.drop = torch.nn.Dropout(dropout_p)
1006
- else:
1007
- self.drop = None
1008
- self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
1009
- self.register_parameter("bias", None)
1010
- self.reset_parameters()
1011
-
1012
- def reset_parameters(self):
1013
- torch.nn.init.ones_(self.weight)
1014
-
1015
- def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
1016
- return rms_norm_fn(
1017
- x,
1018
- self.weight,
1019
- self.bias,
1020
- residual=residual,
1021
- eps=self.eps,
1022
- dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
1023
- prenorm=prenorm,
1024
- residual_in_fp32=residual_in_fp32,
1025
- )
1026
-
1027
-
1028
- class LayerNormLinearFn(torch.autograd.Function):
1029
- @staticmethod
1030
- @custom_fwd
1031
- def forward(
1032
- ctx,
1033
- x,
1034
- norm_weight,
1035
- norm_bias,
1036
- linear_weight,
1037
- linear_bias,
1038
- residual=None,
1039
- eps=1e-6,
1040
- prenorm=False,
1041
- residual_in_fp32=False,
1042
- is_rms_norm=False,
1043
- ):
1044
- x_shape_og = x.shape
1045
- # reshape input data into 2D tensor
1046
- x = x.reshape(-1, x.shape[-1])
1047
- if x.stride(-1) != 1:
1048
- x = x.contiguous()
1049
- if residual is not None:
1050
- assert residual.shape == x_shape_og
1051
- residual = residual.reshape(-1, residual.shape[-1])
1052
- if residual.stride(-1) != 1:
1053
- residual = residual.contiguous()
1054
- norm_weight = norm_weight.contiguous()
1055
- if norm_bias is not None:
1056
- norm_bias = norm_bias.contiguous()
1057
- residual_dtype = (
1058
- residual.dtype
1059
- if residual is not None
1060
- else (torch.float32 if residual_in_fp32 else None)
1061
- )
1062
- y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
1063
- x,
1064
- norm_weight,
1065
- norm_bias,
1066
- eps,
1067
- residual,
1068
- out_dtype=(
1069
- None
1070
- if not torch.is_autocast_enabled()
1071
- else torch.get_autocast_gpu_dtype()
1072
- ),
1073
- residual_dtype=residual_dtype,
1074
- is_rms_norm=is_rms_norm,
1075
- )
1076
- y = y.reshape(x_shape_og)
1077
- dtype = (
1078
- torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
1079
- )
1080
- linear_weight = linear_weight.to(dtype)
1081
- linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
1082
- out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
1083
- # We don't store y, will be recomputed in the backward pass to save memory
1084
- ctx.save_for_backward(
1085
- residual_out, norm_weight, norm_bias, linear_weight, mean, rstd
1086
- )
1087
- ctx.x_shape_og = x_shape_og
1088
- ctx.eps = eps
1089
- ctx.is_rms_norm = is_rms_norm
1090
- ctx.has_residual = residual is not None
1091
- ctx.prenorm = prenorm
1092
- ctx.x_dtype = x.dtype
1093
- ctx.linear_bias_is_none = linear_bias is None
1094
- return out if not prenorm else (out, residual_out.reshape(x_shape_og))
1095
-
1096
- @staticmethod
1097
- @custom_bwd
1098
- def backward(ctx, dout, *args):
1099
- x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
1100
- dout = dout.reshape(-1, dout.shape[-1])
1101
- dy = F.linear(dout, linear_weight.t())
1102
- dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
1103
- if dy.stride(-1) != 1:
1104
- dy = dy.contiguous()
1105
- assert dy.shape == x.shape
1106
- if ctx.prenorm:
1107
- dresidual = args[0]
1108
- dresidual = dresidual.reshape(-1, dresidual.shape[-1])
1109
- if dresidual.stride(-1) != 1:
1110
- dresidual = dresidual.contiguous()
1111
- assert dresidual.shape == x.shape
1112
- else:
1113
- dresidual = None
1114
- dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
1115
- dy,
1116
- x,
1117
- norm_weight,
1118
- norm_bias,
1119
- ctx.eps,
1120
- mean,
1121
- rstd,
1122
- dresidual=dresidual,
1123
- has_residual=ctx.has_residual,
1124
- is_rms_norm=ctx.is_rms_norm,
1125
- x_dtype=ctx.x_dtype,
1126
- recompute_output=True,
1127
- )
1128
- dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
1129
- return (
1130
- dx.reshape(ctx.x_shape_og),
1131
- dnorm_weight,
1132
- dnorm_bias,
1133
- dlinear_weight,
1134
- dlinear_bias,
1135
- dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
1136
- None,
1137
- None,
1138
- None,
1139
- None,
1140
- )
1141
-
1142
-
1143
- def layer_norm_linear_fn(
1144
- x,
1145
- norm_weight,
1146
- norm_bias,
1147
- linear_weight,
1148
- linear_bias,
1149
- residual=None,
1150
- eps=1e-6,
1151
- prenorm=False,
1152
- residual_in_fp32=False,
1153
- is_rms_norm=False,
1154
- ):
1155
- return LayerNormLinearFn.apply(
1156
- x,
1157
- norm_weight,
1158
- norm_bias,
1159
- linear_weight,
1160
- linear_bias,
1161
- residual,
1162
- eps,
1163
- prenorm,
1164
- residual_in_fp32,
1165
- is_rms_norm,
1166
- )
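
Usage note for the fused layer-norm kernels removed above: the public entry points of this file were layer_norm_fn, rms_norm_fn, and the RMSNorm module, with layer_norm_ref / rms_norm_ref as the pure-PyTorch references. A minimal call sketch follows, assuming the pre-removal import path (mamba_ssm.ops.triton.layer_norm), a CUDA device, and Triton installed; this is illustrative only and not part of the diff:

    import torch
    from mamba_ssm.ops.triton.layer_norm import RMSNorm  # import path as laid out in this build tree

    hidden = 1024
    x = torch.randn(4, 128, hidden, device="cuda", dtype=torch.float16)
    residual = torch.randn_like(x)
    norm = RMSNorm(hidden, eps=1e-5, device="cuda", dtype=torch.float16)

    # prenorm=True returns (normalized_output, updated_residual); residual_in_fp32 keeps the
    # running residual in float32, matching the pre-norm residual pattern used by the models above
    y, new_residual = norm(x, residual=residual, prenorm=True, residual_in_fp32=True)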
 
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/selective_state_update.py DELETED
@@ -1,389 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
-
3
- """We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
4
- """
5
-
6
- import math
7
- import torch
8
- import torch.nn.functional as F
9
-
10
- import triton
11
- import triton.language as tl
12
-
13
- from einops import rearrange, repeat
14
-
15
- from .softplus import softplus
16
-
17
-
18
- @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
19
- @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
20
- @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
21
- @triton.heuristics(
22
- {
23
- "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
24
- is not None
25
- }
26
- )
27
- @triton.heuristics(
28
- {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
29
- )
30
- @triton.jit
31
- def _selective_scan_update_kernel(
32
- # Pointers to matrices
33
- state_ptr,
34
- x_ptr,
35
- dt_ptr,
36
- dt_bias_ptr,
37
- A_ptr,
38
- B_ptr,
39
- C_ptr,
40
- D_ptr,
41
- z_ptr,
42
- out_ptr,
43
- state_batch_indices_ptr,
44
- # Matrix dimensions
45
- batch,
46
- nheads,
47
- dim,
48
- dstate,
49
- nheads_ngroups_ratio,
50
- # Strides
51
- stride_state_batch,
52
- stride_state_head,
53
- stride_state_dim,
54
- stride_state_dstate,
55
- stride_x_batch,
56
- stride_x_head,
57
- stride_x_dim,
58
- stride_dt_batch,
59
- stride_dt_head,
60
- stride_dt_dim,
61
- stride_dt_bias_head,
62
- stride_dt_bias_dim,
63
- stride_A_head,
64
- stride_A_dim,
65
- stride_A_dstate,
66
- stride_B_batch,
67
- stride_B_group,
68
- stride_B_dstate,
69
- stride_C_batch,
70
- stride_C_group,
71
- stride_C_dstate,
72
- stride_D_head,
73
- stride_D_dim,
74
- stride_z_batch,
75
- stride_z_head,
76
- stride_z_dim,
77
- stride_out_batch,
78
- stride_out_head,
79
- stride_out_dim,
80
- # Meta-parameters
81
- DT_SOFTPLUS: tl.constexpr,
82
- TIE_HDIM: tl.constexpr,
83
- BLOCK_SIZE_M: tl.constexpr,
84
- HAS_DT_BIAS: tl.constexpr,
85
- HAS_D: tl.constexpr,
86
- HAS_Z: tl.constexpr,
87
- HAS_STATE_BATCH_INDICES: tl.constexpr,
88
- BLOCK_SIZE_DSTATE: tl.constexpr,
89
- ):
90
- pid_m = tl.program_id(axis=0)
91
- pid_b = tl.program_id(axis=1)
92
- pid_h = tl.program_id(axis=2)
93
-
94
- if HAS_STATE_BATCH_INDICES:
95
- state_batch_indices_ptr += pid_b
96
- state_batch_idx = tl.load(state_batch_indices_ptr)
97
- state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
98
- else:
99
- state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
100
-
101
- x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
102
- dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
103
- if HAS_DT_BIAS:
104
- dt_bias_ptr += pid_h * stride_dt_bias_head
105
- A_ptr += pid_h * stride_A_head
106
- B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
107
- C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
108
- if HAS_Z:
109
- z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
110
- out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
111
-
112
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
113
- offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
114
- state_ptrs = state_ptr + (
115
- offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
116
- )
117
- x_ptrs = x_ptr + offs_m * stride_x_dim
118
- dt_ptrs = dt_ptr + offs_m * stride_dt_dim
119
- if HAS_DT_BIAS:
120
- dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
121
- if HAS_D:
122
- D_ptr += pid_h * stride_D_head
123
- A_ptrs = A_ptr + (
124
- offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
125
- )
126
- B_ptrs = B_ptr + offs_n * stride_B_dstate
127
- C_ptrs = C_ptr + offs_n * stride_C_dstate
128
- if HAS_D:
129
- D_ptrs = D_ptr + offs_m * stride_D_dim
130
- if HAS_Z:
131
- z_ptrs = z_ptr + offs_m * stride_z_dim
132
- out_ptrs = out_ptr + offs_m * stride_out_dim
133
-
134
- state = tl.load(
135
- state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
136
- )
137
- x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
138
- if not TIE_HDIM:
139
- dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
140
- if HAS_DT_BIAS:
141
- dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
142
- if DT_SOFTPLUS:
143
- dt = tl.where(dt <= 20.0, softplus(dt), dt)
144
- A = tl.load(
145
- A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
146
- ).to(tl.float32)
147
- dA = tl.exp(A * dt[:, None])
148
- else:
149
- dt = tl.load(dt_ptr).to(tl.float32)
150
- if HAS_DT_BIAS:
151
- dt += tl.load(dt_bias_ptr).to(tl.float32)
152
- if DT_SOFTPLUS:
153
- dt = tl.where(dt <= 20.0, softplus(dt), dt)
154
- A = tl.load(A_ptr).to(tl.float32)
155
- dA = tl.exp(A * dt) # scalar, not a matrix
156
-
157
- B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
158
- C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
159
- if HAS_D:
160
- D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
161
- if HAS_Z:
162
- z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
163
-
164
- if not TIE_HDIM:
165
- dB = B[None, :] * dt[:, None]
166
- else:
167
- dB = B * dt # vector of size (dstate,)
168
- state = state * dA + dB * x[:, None]
169
- tl.store(
170
- state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
171
- )
172
- out = tl.sum(state * C[None, :], axis=1)
173
- if HAS_D:
174
- out += x * D
175
- if HAS_Z:
176
- out *= z * tl.sigmoid(z)
177
- tl.store(out_ptrs, out, mask=offs_m < dim)
178
-
179
-
180
- def selective_state_update(
181
- state,
182
- x,
183
- dt,
184
- A,
185
- B,
186
- C,
187
- D=None,
188
- z=None,
189
- dt_bias=None,
190
- dt_softplus=False,
191
- state_batch_indices=None,
192
- ):
193
- """
194
- Argument:
195
- state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
196
- x: (batch, dim) or (batch, nheads, dim)
197
- dt: (batch, dim) or (batch, nheads, dim)
198
- A: (dim, dstate) or (nheads, dim, dstate)
199
- B: (batch, dstate) or (batch, ngroups, dstate)
200
- C: (batch, dstate) or (batch, ngroups, dstate)
201
- D: (dim,) or (nheads, dim)
202
- z: (batch, dim) or (batch, nheads, dim)
203
- dt_bias: (dim,) or (nheads, dim)
204
- Return:
205
- out: (batch, dim) or (batch, nheads, dim)
206
- """
207
- has_heads = state.dim() > 3
208
- if state.dim() == 3:
209
- state = state.unsqueeze(1)
210
- if x.dim() == 2:
211
- x = x.unsqueeze(1)
212
- if dt.dim() == 2:
213
- dt = dt.unsqueeze(1)
214
- if A.dim() == 2:
215
- A = A.unsqueeze(0)
216
- if B.dim() == 2:
217
- B = B.unsqueeze(1)
218
- if C.dim() == 2:
219
- C = C.unsqueeze(1)
220
- if D is not None and D.dim() == 1:
221
- D = D.unsqueeze(0)
222
- if z is not None and z.dim() == 2:
223
- z = z.unsqueeze(1)
224
- if dt_bias is not None and dt_bias.dim() == 1:
225
- dt_bias = dt_bias.unsqueeze(0)
226
- _, nheads, dim, dstate = state.shape
227
- batch = x.shape[0]
228
- if x.shape != (batch, nheads, dim):
229
- print(f"{state.shape} {x.shape} {batch} {nheads} {dim}")
230
- assert x.shape == (batch, nheads, dim)
231
- assert dt.shape == x.shape
232
- assert A.shape == (nheads, dim, dstate)
233
- ngroups = B.shape[1]
234
- assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
235
- assert B.shape == (batch, ngroups, dstate)
236
- assert C.shape == B.shape
237
- if D is not None:
238
- assert D.shape == (nheads, dim)
239
- if z is not None:
240
- assert z.shape == x.shape
241
- if dt_bias is not None:
242
- assert dt_bias.shape == (nheads, dim)
243
- if state_batch_indices is not None:
244
- assert state_batch_indices.shape == (batch,)
245
- out = torch.empty_like(x)
246
- grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
247
- z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
248
- # We don't want autotune since it will overwrite the state
249
- # We instead tune by hand.
250
- BLOCK_SIZE_M, num_warps = (
251
- (32, 4)
252
- if dstate <= 16
253
- else (
254
- (16, 4)
255
- if dstate <= 32
256
- else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
257
- )
258
- )
259
- tie_hdim = (
260
- A.stride(-1) == 0
261
- and A.stride(-2) == 0
262
- and dt.stride(-1) == 0
263
- and dt_bias.stride(-1) == 0
264
- )
265
- with torch.cuda.device(x.device.index):
266
- _selective_scan_update_kernel[grid](
267
- state,
268
- x,
269
- dt,
270
- dt_bias,
271
- A,
272
- B,
273
- C,
274
- D,
275
- z,
276
- out,
277
- state_batch_indices,
278
- batch,
279
- nheads,
280
- dim,
281
- dstate,
282
- nheads // ngroups,
283
- state.stride(0),
284
- state.stride(1),
285
- state.stride(2),
286
- state.stride(3),
287
- x.stride(0),
288
- x.stride(1),
289
- x.stride(2),
290
- dt.stride(0),
291
- dt.stride(1),
292
- dt.stride(2),
293
- *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
294
- A.stride(0),
295
- A.stride(1),
296
- A.stride(2),
297
- B.stride(0),
298
- B.stride(1),
299
- B.stride(2),
300
- C.stride(0),
301
- C.stride(1),
302
- C.stride(2),
303
- *(D.stride(0), D.stride(1)) if D is not None else 0,
304
- z_strides[0],
305
- z_strides[1],
306
- z_strides[2],
307
- out.stride(0),
308
- out.stride(1),
309
- out.stride(2),
310
- dt_softplus,
311
- tie_hdim,
312
- BLOCK_SIZE_M,
313
- num_warps=num_warps,
314
- )
315
- if not has_heads:
316
- out = out.squeeze(1)
317
- return out
318
-
319
-
320
- def selective_state_update_ref(
321
- state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
322
- ):
323
- """
324
- Argument:
325
- state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
326
- x: (batch, dim) or (batch, nheads, dim)
327
- dt: (batch, dim) or (batch, nheads, dim)
328
- A: (dim, dstate) or (nheads, dim, dstate)
329
- B: (batch, dstate) or (batch, ngroups, dstate)
330
- C: (batch, dstate) or (batch, ngroups, dstate)
331
- D: (dim,) or (nheads, dim)
332
- z: (batch, dim) or (batch, nheads, dim)
333
- dt_bias: (dim,) or (nheads, dim)
334
- Return:
335
- out: (batch, dim) or (batch, nheads, dim)
336
- """
337
- has_heads = state.dim() > 3
338
- if state.dim() == 3:
339
- state = state.unsqueeze(1)
340
- if x.dim() == 2:
341
- x = x.unsqueeze(1)
342
- if dt.dim() == 2:
343
- dt = dt.unsqueeze(1)
344
- if A.dim() == 2:
345
- A = A.unsqueeze(0)
346
- if B.dim() == 2:
347
- B = B.unsqueeze(1)
348
- if C.dim() == 2:
349
- C = C.unsqueeze(1)
350
- if D is not None and D.dim() == 1:
351
- D = D.unsqueeze(0)
352
- if z is not None and z.dim() == 2:
353
- z = z.unsqueeze(1)
354
- if dt_bias is not None and dt_bias.dim() == 1:
355
- dt_bias = dt_bias.unsqueeze(0)
356
- batch, nheads, dim, dstate = state.shape
357
- assert x.shape == (batch, nheads, dim)
358
- assert dt.shape == x.shape
359
- assert A.shape == (nheads, dim, dstate)
360
- ngroups = B.shape[1]
361
- assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
362
- assert B.shape == (batch, ngroups, dstate)
363
- assert C.shape == B.shape
364
- if D is not None:
365
- assert D.shape == (nheads, dim)
366
- if z is not None:
367
- assert z.shape == x.shape
368
- if dt_bias is not None:
369
- assert dt_bias.shape == (nheads, dim)
370
- dt = dt + dt_bias
371
- dt = F.softplus(dt) if dt_softplus else dt
372
- dA = torch.exp(
373
- rearrange(dt, "b h d -> b h d 1") * A
374
- ) # (batch, nheads, dim, dstate)
375
- B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
376
- C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
377
- dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
378
- B, "b h n -> b h 1 n"
379
- ) # (batch, nheads, dim, dstate)
380
- state.copy_(
381
- state * dA + dB * rearrange(x, "b h d -> b h d 1")
382
- ) # (batch, dim, dstate
383
- out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
384
- if D is not None:
385
- out += (x * D).to(out.dtype)
386
- out = (out if z is None else out * F.silu(z)).to(x.dtype)
387
- if not has_heads:
388
- out = out.squeeze(1)
389
- return out
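
For context on the kernel removed above: selective_state_update performs a single decoding step of the discretized state-space recurrence, and selective_state_update_ref spells out the same math in plain PyTorch. A short standalone sketch of that step (shapes follow the docstring; ngroups is taken equal to nheads for brevity, and all tensor values are dummy data):

    import torch
    import torch.nn.functional as F

    batch, nheads, dim, dstate = 2, 4, 16, 8
    state = torch.zeros(batch, nheads, dim, dstate)
    x = torch.randn(batch, nheads, dim)
    dt = F.softplus(torch.randn(batch, nheads, dim))   # positive step sizes
    A = -torch.rand(nheads, dim, dstate)               # negative values keep the recurrence stable
    B = torch.randn(batch, nheads, dstate)
    C = torch.randn(batch, nheads, dstate)

    dA = torch.exp(dt.unsqueeze(-1) * A)               # (batch, nheads, dim, dstate)
    dB = dt.unsqueeze(-1) * B.unsqueeze(2)             # (batch, nheads, dim, dstate)
    state = state * dA + dB * x.unsqueeze(-1)          # discretized update: h <- dA * h + dB * x
    out = torch.einsum("bhdn,bhn->bhd", state, C)      # per-head readout y = C . h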
 
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_scan.py DELETED
The diff for this file is too large to render. See raw diff
 
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_state.py DELETED
@@ -1,2012 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
-
3
- """We want triton==2.1.0 or 2.2.0 for this
4
- """
5
-
6
- import math
7
- import torch
8
- import torch.nn.functional as F
9
-
10
- import triton
11
- import triton.language as tl
12
-
13
- from einops import rearrange, repeat
14
-
15
- from .softplus import softplus
16
-
17
-
18
- def init_to_zero(names):
19
- return lambda nargs: [
20
- nargs[name].zero_() for name in names if nargs[name] is not None
21
- ]
22
-
23
-
24
- @triton.autotune(
25
- configs=[
26
- triton.Config({"BLOCK_SIZE_H": 1}),
27
- triton.Config({"BLOCK_SIZE_H": 2}),
28
- triton.Config({"BLOCK_SIZE_H": 4}),
29
- triton.Config({"BLOCK_SIZE_H": 8}),
30
- triton.Config({"BLOCK_SIZE_H": 16}),
31
- triton.Config({"BLOCK_SIZE_H": 32}),
32
- triton.Config({"BLOCK_SIZE_H": 64}),
33
- ],
34
- key=["chunk_size", "nheads"],
35
- )
36
- @triton.jit
37
- def _chunk_cumsum_fwd_kernel(
38
- # Pointers to matrices
39
- dt_ptr,
40
- A_ptr,
41
- dt_bias_ptr,
42
- dt_out_ptr,
43
- dA_cumsum_ptr,
44
- # Matrix dimension
45
- batch,
46
- seqlen,
47
- nheads,
48
- chunk_size,
49
- dt_min,
50
- dt_max,
51
- # Strides
52
- stride_dt_batch,
53
- stride_dt_seqlen,
54
- stride_dt_head,
55
- stride_A_head,
56
- stride_dt_bias_head,
57
- stride_dt_out_batch,
58
- stride_dt_out_chunk,
59
- stride_dt_out_head,
60
- stride_dt_out_csize,
61
- stride_dA_cs_batch,
62
- stride_dA_cs_chunk,
63
- stride_dA_cs_head,
64
- stride_dA_cs_csize,
65
- # Meta-parameters
66
- DT_SOFTPLUS: tl.constexpr,
67
- HAS_DT_BIAS: tl.constexpr,
68
- BLOCK_SIZE_H: tl.constexpr,
69
- BLOCK_SIZE_CHUNK: tl.constexpr,
70
- ):
71
- pid_b = tl.program_id(axis=0)
72
- pid_c = tl.program_id(axis=1)
73
- pid_h = tl.program_id(axis=2)
74
- dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
75
- dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
76
- dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk
77
-
78
- offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
79
- offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
80
- dt_ptrs = dt_ptr + (
81
- offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
82
- )
83
- A_ptrs = A_ptr + offs_h * stride_A_head
84
- dt_out_ptrs = dt_out_ptr + (
85
- offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize
86
- )
87
- dA_cs_ptrs = dA_cumsum_ptr + (
88
- offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize
89
- )
90
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
91
-
92
- dt = tl.load(
93
- dt_ptrs,
94
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
95
- other=0.0,
96
- ).to(tl.float32)
97
- if HAS_DT_BIAS:
98
- dt_bias = tl.load(
99
- dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
100
- ).to(tl.float32)
101
- dt += dt_bias[:, None]
102
- if DT_SOFTPLUS:
103
- dt = tl.where(dt <= 20.0, softplus(dt), dt)
104
- # As of Triton 2.2.0, tl.clamp is not available yet
105
- # dt = tl.clamp(dt, dt_min, dt_max)
106
- dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
107
- dt = tl.where(
108
- (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
109
- )
110
- tl.store(
111
- dt_out_ptrs,
112
- dt,
113
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
114
- )
115
- A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
116
- dA = dt * A[:, None]
117
- dA_cs = tl.cumsum(dA, axis=1)
118
- tl.store(
119
- dA_cs_ptrs,
120
- dA_cs,
121
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
122
- )
123
-
124
-
125
- @triton.autotune(
126
- configs=[
127
- triton.Config(
128
- {"BLOCK_SIZE_H": 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
129
- ),
130
- triton.Config(
131
- {"BLOCK_SIZE_H": 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
132
- ),
133
- triton.Config(
134
- {"BLOCK_SIZE_H": 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
135
- ),
136
- triton.Config(
137
- {"BLOCK_SIZE_H": 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
138
- ),
139
- triton.Config(
140
- {"BLOCK_SIZE_H": 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
141
- ),
142
- triton.Config(
143
- {"BLOCK_SIZE_H": 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
144
- ),
145
- triton.Config(
146
- {"BLOCK_SIZE_H": 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
147
- ),
148
- ],
149
- key=["chunk_size", "nheads"],
150
- )
151
- @triton.jit
152
- def _chunk_cumsum_bwd_kernel(
153
- # Pointers to matrices
154
- ddA_ptr,
155
- ddt_out_ptr,
156
- dt_ptr,
157
- A_ptr,
158
- dt_bias_ptr,
159
- ddt_ptr,
160
- dA_ptr,
161
- ddt_bias_ptr,
162
- # Matrix dimensions
163
- batch,
164
- seqlen,
165
- nheads,
166
- chunk_size,
167
- dt_min,
168
- dt_max,
169
- # Strides
170
- stride_ddA_batch,
171
- stride_ddA_chunk,
172
- stride_ddA_head,
173
- stride_ddA_csize,
174
- stride_ddt_out_batch,
175
- stride_ddt_out_chunk,
176
- stride_ddt_out_head,
177
- stride_ddt_out_csize,
178
- stride_dt_batch,
179
- stride_dt_seqlen,
180
- stride_dt_head,
181
- stride_A_head,
182
- stride_dt_bias_head,
183
- stride_ddt_batch,
184
- stride_ddt_seqlen,
185
- stride_ddt_head,
186
- stride_dA_head,
187
- stride_ddt_bias_head,
188
- # Meta-parameters
189
- DT_SOFTPLUS: tl.constexpr,
190
- HAS_DT_BIAS: tl.constexpr,
191
- BLOCK_SIZE_H: tl.constexpr,
192
- BLOCK_SIZE_CHUNK: tl.constexpr,
193
- ):
194
- pid_b = tl.program_id(axis=0)
195
- pid_c = tl.program_id(axis=1)
196
- pid_h = tl.program_id(axis=2)
197
- ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk
198
- ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk
199
- dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
200
- ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen
201
-
202
- offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
203
- offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
204
- ddt_out_ptrs = ddt_out_ptr + (
205
- offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize
206
- )
207
- ddA_ptrs = ddA_ptr + (
208
- offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize
209
- )
210
- dt_ptrs = dt_ptr + (
211
- offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
212
- )
213
- ddt_ptrs = ddt_ptr + (
214
- offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen
215
- )
216
- A_ptrs = A_ptr + offs_h * stride_A_head
217
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
218
-
219
- ddA = tl.load(
220
- ddA_ptrs,
221
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
222
- other=0.0,
223
- ).to(tl.float32)
224
- ddt_out = tl.load(
225
- ddt_out_ptrs,
226
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
227
- other=0.0,
228
- ).to(tl.float32)
229
- A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
230
- ddt = ddA * A[:, None] + ddt_out
231
- dt = tl.load(
232
- dt_ptrs,
233
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
234
- other=0.0,
235
- ).to(tl.float32)
236
- if HAS_DT_BIAS:
237
- dt_bias = tl.load(
238
- dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
239
- ).to(tl.float32)
240
- dt += dt_bias[:, None]
241
- if DT_SOFTPLUS:
242
- dt_presoftplus = dt
243
- dt = tl.where(dt <= 20.0, softplus(dt), ddt)
244
- clamp_mask = (dt < dt_min) | (dt > dt_max)
245
- # As of Triton 2.2.0, tl.clamp is not available yet
246
- # dt = tl.clamp(dt, dt_min, dt_max)
247
- dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
248
- dt = tl.where(
249
- (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
250
- )
251
- ddt = tl.where(
252
- (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0
253
- )
254
- ddt = tl.where(clamp_mask, 0.0, ddt)
255
- if DT_SOFTPLUS:
256
- ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)
257
- tl.store(
258
- ddt_ptrs,
259
- ddt,
260
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
261
- )
262
- dA = tl.sum(ddA * dt, axis=1)
263
- tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)
264
- if HAS_DT_BIAS:
265
- ddt_bias = tl.sum(ddt, axis=1)
266
- tl.atomic_add(
267
- ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads
268
- )
269
-
270
-
271
- @triton.autotune(
272
- configs=[
273
- triton.Config(
274
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
275
- num_stages=3,
276
- num_warps=8,
277
- ),
278
- triton.Config(
279
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
280
- num_stages=4,
281
- num_warps=4,
282
- ),
283
- triton.Config(
284
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
285
- num_stages=4,
286
- num_warps=4,
287
- ),
288
- triton.Config(
289
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
290
- num_stages=4,
291
- num_warps=4,
292
- ),
293
- triton.Config(
294
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
295
- num_stages=4,
296
- num_warps=4,
297
- ),
298
- triton.Config(
299
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
300
- num_stages=4,
301
- num_warps=4,
302
- ),
303
- triton.Config(
304
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
305
- num_stages=5,
306
- num_warps=2,
307
- ),
308
- triton.Config(
309
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
310
- num_stages=5,
311
- num_warps=2,
312
- ),
313
- triton.Config(
314
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
315
- num_stages=4,
316
- num_warps=2,
317
- ),
318
- ],
319
- key=["hdim", "dstate", "chunk_size"],
320
- )
321
- @triton.jit
322
- def _chunk_state_fwd_kernel(
323
- # Pointers to matrices
324
- x_ptr,
325
- b_ptr,
326
- states_ptr,
327
- dt_ptr,
328
- dA_cumsum_ptr,
329
- seq_idx_ptr,
330
- # Matrix dimensions
331
- hdim,
332
- dstate,
333
- chunk_size,
334
- batch,
335
- seqlen,
336
- nheads_ngroups_ratio,
337
- # Strides
338
- stride_x_batch,
339
- stride_x_seqlen,
340
- stride_x_head,
341
- stride_x_hdim,
342
- stride_b_batch,
343
- stride_b_seqlen,
344
- stride_b_head,
345
- stride_b_dstate,
346
- stride_states_batch,
347
- stride_states_chunk,
348
- stride_states_head,
349
- stride_states_hdim,
350
- stride_states_dstate,
351
- stride_dt_batch,
352
- stride_dt_chunk,
353
- stride_dt_head,
354
- stride_dt_csize,
355
- stride_dA_cs_batch,
356
- stride_dA_cs_chunk,
357
- stride_dA_cs_head,
358
- stride_dA_cs_csize,
359
- stride_seq_idx_batch,
360
- stride_seq_idx_seqlen,
361
- # Meta-parameters
362
- HAS_SEQ_IDX: tl.constexpr,
363
- BLOCK_SIZE_M: tl.constexpr,
364
- BLOCK_SIZE_N: tl.constexpr,
365
- BLOCK_SIZE_K: tl.constexpr,
366
- ):
367
- pid_bc = tl.program_id(axis=1)
368
- pid_c = pid_bc // batch
369
- pid_b = pid_bc - pid_c * batch
370
- pid_h = tl.program_id(axis=2)
371
- num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
372
- pid_m = tl.program_id(axis=0) // num_pid_n
373
- pid_n = tl.program_id(axis=0) % num_pid_n
374
- b_ptr += (
375
- pid_b * stride_b_batch
376
- + pid_c * chunk_size * stride_b_seqlen
377
- + (pid_h // nheads_ngroups_ratio) * stride_b_head
378
- )
379
- x_ptr += (
380
- pid_b * stride_x_batch
381
- + pid_c * chunk_size * stride_x_seqlen
382
- + pid_h * stride_x_head
383
- )
384
- dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
385
- dA_cumsum_ptr += (
386
- pid_b * stride_dA_cs_batch
387
- + pid_c * stride_dA_cs_chunk
388
- + pid_h * stride_dA_cs_head
389
- )
390
- if HAS_SEQ_IDX:
391
- seq_idx_ptr += (
392
- pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
393
- )
394
-
395
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
396
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
397
- offs_k = tl.arange(0, BLOCK_SIZE_K)
398
- x_ptrs = x_ptr + (
399
- offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
400
- )
401
- b_ptrs = b_ptr + (
402
- offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
403
- )
404
- dt_ptrs = dt_ptr + offs_k * stride_dt_csize
405
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
406
- tl.float32
407
- )
408
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
409
- if HAS_SEQ_IDX:
410
- seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
411
-
412
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
413
- if HAS_SEQ_IDX:
414
- seq_idx_last = tl.load(
415
- seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
416
- )
417
-
418
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
419
- for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
420
- x = tl.load(
421
- x_ptrs,
422
- mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k),
423
- other=0.0,
424
- )
425
- b = tl.load(
426
- b_ptrs,
427
- mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate),
428
- other=0.0,
429
- ).to(tl.float32)
430
- dA_cs_k = tl.load(
431
- dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
432
- ).to(tl.float32)
433
- if HAS_SEQ_IDX:
434
- seq_idx_k = tl.load(
435
- seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1
436
- )
437
- dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
438
- tl.float32
439
- )
440
- if not HAS_SEQ_IDX:
441
- scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
442
- else:
443
- scale = tl.where(
444
- seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0
445
- )
446
- b *= scale[:, None]
447
- b = b.to(x_ptr.dtype.element_ty)
448
- acc += tl.dot(x, b)
449
- x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
450
- b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
451
- dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
452
- dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
453
- if HAS_SEQ_IDX:
454
- seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
455
- states = acc.to(states_ptr.dtype.element_ty)
456
-
457
- states_ptr += (
458
- pid_b * stride_states_batch
459
- + pid_c * stride_states_chunk
460
- + pid_h * stride_states_head
461
- )
462
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
463
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
464
- states_ptrs = states_ptr + (
465
- offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
466
- )
467
- c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
468
- tl.store(states_ptrs, states, mask=c_mask)
469
-
470
-
471
- @triton.autotune(
472
- configs=[
473
- triton.Config(
474
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
475
- num_stages=3,
476
- num_warps=8,
477
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
478
- ),
479
- triton.Config(
480
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
481
- num_stages=4,
482
- num_warps=4,
483
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
484
- ),
485
- triton.Config(
486
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
487
- num_stages=4,
488
- num_warps=4,
489
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
490
- ),
491
- triton.Config(
492
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
493
- num_stages=4,
494
- num_warps=4,
495
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
496
- ),
497
- triton.Config(
498
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
499
- num_stages=4,
500
- num_warps=4,
501
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
502
- ),
503
- triton.Config(
504
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
505
- num_stages=4,
506
- num_warps=4,
507
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
508
- ),
509
- triton.Config(
510
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
511
- num_stages=5,
512
- num_warps=4,
513
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
514
- ),
515
- triton.Config(
516
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
517
- num_stages=5,
518
- num_warps=4,
519
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
520
- ),
521
- triton.Config(
522
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
523
- num_stages=4,
524
- num_warps=4,
525
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
526
- ),
527
- ],
528
- key=["chunk_size", "hdim", "dstate"],
529
- )
530
- @triton.jit
531
- def _chunk_state_bwd_dx_kernel(
532
- # Pointers to matrices
533
- x_ptr,
534
- b_ptr,
535
- dstates_ptr,
536
- dt_ptr,
537
- dA_cumsum_ptr,
538
- dx_ptr,
539
- ddt_ptr,
540
- ddA_cumsum_ptr,
541
- # Matrix dimensions
542
- chunk_size,
543
- hdim,
544
- dstate,
545
- batch,
546
- seqlen,
547
- nheads_ngroups_ratio,
548
- # Strides
549
- stride_x_batch,
550
- stride_x_seqlen,
551
- stride_x_head,
552
- stride_x_hdim,
553
- stride_b_batch,
554
- stride_b_seqlen,
555
- stride_b_head,
556
- stride_b_dstate,
557
- stride_dstates_batch,
558
- stride_dstates_chunk,
559
- stride_states_head,
560
- stride_states_hdim,
561
- stride_states_dstate,
562
- stride_dt_batch,
563
- stride_dt_chunk,
564
- stride_dt_head,
565
- stride_dt_csize,
566
- stride_dA_cs_batch,
567
- stride_dA_cs_chunk,
568
- stride_dA_cs_head,
569
- stride_dA_cs_csize,
570
- stride_dx_batch,
571
- stride_dx_seqlen,
572
- stride_dx_head,
573
- stride_dx_hdim,
574
- stride_ddt_batch,
575
- stride_ddt_chunk,
576
- stride_ddt_head,
577
- stride_ddt_csize,
578
- stride_ddA_cs_batch,
579
- stride_ddA_cs_chunk,
580
- stride_ddA_cs_head,
581
- stride_ddA_cs_csize,
582
- # Meta-parameters
583
- BLOCK_SIZE_M: tl.constexpr,
584
- BLOCK_SIZE_N: tl.constexpr,
585
- BLOCK_SIZE_K: tl.constexpr,
586
- BLOCK_SIZE_DSTATE: tl.constexpr,
587
- ):
588
- pid_bc = tl.program_id(axis=1)
589
- pid_c = pid_bc // batch
590
- pid_b = pid_bc - pid_c * batch
591
- pid_h = tl.program_id(axis=2)
592
- num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
593
- pid_m = tl.program_id(axis=0) // num_pid_n
594
- pid_n = tl.program_id(axis=0) % num_pid_n
595
- x_ptr += (
596
- pid_b * stride_x_batch
597
- + pid_c * chunk_size * stride_x_seqlen
598
- + pid_h * stride_x_head
599
- )
600
- b_ptr += (
601
- pid_b * stride_b_batch
602
- + pid_c * chunk_size * stride_b_seqlen
603
- + (pid_h // nheads_ngroups_ratio) * stride_b_head
604
- )
605
- dstates_ptr += (
606
- pid_b * stride_dstates_batch
607
- + pid_c * stride_dstates_chunk
608
- + pid_h * stride_states_head
609
- )
610
- dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
611
- ddt_ptr += (
612
- pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
613
- )
614
- ddA_cumsum_ptr += (
615
- pid_b * stride_ddA_cs_batch
616
- + pid_c * stride_ddA_cs_chunk
617
- + pid_h * stride_ddA_cs_head
618
- )
619
- dA_cumsum_ptr += (
620
- pid_b * stride_dA_cs_batch
621
- + pid_c * stride_dA_cs_chunk
622
- + pid_h * stride_dA_cs_head
623
- )
624
-
625
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
626
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
627
-
628
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
629
- # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
630
- offs_k = tl.arange(
631
- 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
632
- )
633
- b_ptrs = b_ptr + (
634
- offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
635
- )
636
- dstates_ptrs = dstates_ptr + (
637
- offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
638
- )
639
- if BLOCK_SIZE_DSTATE <= 128:
640
- b = tl.load(
641
- b_ptrs,
642
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
643
- other=0.0,
644
- )
645
- dstates = tl.load(
646
- dstates_ptrs,
647
- mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
648
- other=0.0,
649
- )
650
- dstates = dstates.to(b_ptr.dtype.element_ty)
651
- acc = tl.dot(b, dstates)
652
- else:
653
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
654
- for k in range(0, dstate, BLOCK_SIZE_K):
655
- b = tl.load(
656
- b_ptrs,
657
- mask=(offs_m[:, None] < chunk_size_limit)
658
- & (offs_k[None, :] < dstate - k),
659
- other=0.0,
660
- )
661
- dstates = tl.load(
662
- dstates_ptrs,
663
- mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
664
- other=0.0,
665
- )
666
- dstates = dstates.to(b_ptr.dtype.element_ty)
667
- acc += tl.dot(b, dstates)
668
- b_ptrs += BLOCK_SIZE_K * stride_b_dstate
669
- dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
670
-
671
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
672
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
673
-
674
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
675
- tl.float32
676
- )
677
- dt_ptrs = dt_ptr + offs_m * stride_dt_csize
678
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
679
- dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
680
- tl.float32
681
- )
682
- dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
683
- acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
684
-
685
- x_ptrs = x_ptr + (
686
- offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
687
- )
688
- x = tl.load(
689
- x_ptrs,
690
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
691
- other=0.0,
692
- ).to(tl.float32)
693
- ddt = tl.sum(acc * x, axis=1)
694
- ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
695
- tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
696
- ddA_cs = -(ddt * dt_m)
697
- ddA_cs_last = -tl.sum(ddA_cs)
698
- ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
699
- tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
700
- tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last)
701
-
702
- dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty)
703
- dx_ptr += (
704
- pid_b * stride_dx_batch
705
- + pid_c * chunk_size * stride_dx_seqlen
706
- + pid_h * stride_dx_head
707
- )
708
- dx_ptrs = dx_ptr + (
709
- offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim
710
- )
711
- tl.store(
712
- dx_ptrs,
713
- dx,
714
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
715
- )
716
-
717
-
718
- @triton.autotune(
719
- configs=[
720
- triton.Config(
721
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128},
722
- num_stages=3,
723
- num_warps=4,
724
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
725
- ),
726
- triton.Config(
727
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32},
728
- num_stages=3,
729
- num_warps=4,
730
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
731
- ),
732
- triton.Config(
733
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128},
734
- num_stages=3,
735
- num_warps=4,
736
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
737
- ),
738
- triton.Config(
739
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64},
740
- num_stages=3,
741
- num_warps=4,
742
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
743
- ),
744
- triton.Config(
745
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64},
746
- num_stages=3,
747
- num_warps=4,
748
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
749
- ),
750
- triton.Config(
751
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32},
752
- num_stages=3,
753
- num_warps=4,
754
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
755
- ),
756
- triton.Config(
757
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64},
758
- num_stages=3,
759
- num_warps=4,
760
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
761
- ),
762
- triton.Config(
763
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32},
764
- num_stages=3,
765
- num_warps=4,
766
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
767
- ),
768
- ],
769
- key=["chunk_size", "dstate", "hdim"],
770
- )
771
- @triton.jit
772
- def _chunk_state_bwd_db_kernel(
773
- # Pointers to matrices
774
- x_ptr,
775
- dstates_ptr,
776
- b_ptr,
777
- dt_ptr,
778
- dA_cumsum_ptr,
779
- seq_idx_ptr,
780
- db_ptr,
781
- ddA_cumsum_ptr,
782
- # Matrix dimensions
783
- chunk_size,
784
- dstate,
785
- hdim,
786
- batch,
787
- seqlen,
788
- nheads,
789
- nheads_per_program,
790
- ngroups,
791
- # Strides
792
- stride_x_batch,
793
- stride_x_seqlen,
794
- stride_x_head,
795
- stride_x_hdim,
796
- stride_dstates_batch,
797
- stride_dstates_chunk,
798
- stride_states_head,
799
- stride_states_hdim,
800
- stride_states_dstate,
801
- stride_b_batch,
802
- stride_b_seqlen,
803
- stride_b_head,
804
- stride_b_dstate,
805
- stride_dt_batch,
806
- stride_dt_chunk,
807
- stride_dt_head,
808
- stride_dt_csize,
809
- stride_dA_cs_batch,
810
- stride_dA_cs_chunk,
811
- stride_dA_cs_head,
812
- stride_dA_cs_csize,
813
- stride_seq_idx_batch,
814
- stride_seq_idx_seqlen,
815
- stride_db_batch,
816
- stride_db_seqlen,
817
- stride_db_split,
818
- stride_db_group,
819
- stride_db_dstate,
820
- stride_ddA_cs_batch,
821
- stride_ddA_cs_chunk,
822
- stride_ddA_cs_head,
823
- stride_ddA_cs_csize,
824
- # Meta-parameters
825
- HAS_DDA_CS: tl.constexpr,
826
- HAS_SEQ_IDX: tl.constexpr,
827
- BLOCK_SIZE_M: tl.constexpr,
828
- BLOCK_SIZE_N: tl.constexpr,
829
- BLOCK_SIZE_K: tl.constexpr,
830
- ):
831
- pid_bc = tl.program_id(axis=1)
832
- pid_c = pid_bc // batch
833
- pid_b = pid_bc - pid_c * batch
834
- pid_sg = tl.program_id(axis=2)
835
- pid_s = pid_sg // ngroups
836
- pid_g = pid_sg - pid_s * ngroups
837
- num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
838
- pid_m = tl.program_id(axis=0) // num_pid_n
839
- pid_n = tl.program_id(axis=0) % num_pid_n
840
- x_ptr += (
841
- pid_b * stride_x_batch
842
- + pid_c * chunk_size * stride_x_seqlen
843
- + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head
844
- )
845
- db_ptr += (
846
- pid_b * stride_db_batch
847
- + pid_c * chunk_size * stride_db_seqlen
848
- + pid_g * stride_db_group
849
- + pid_s * stride_db_split
850
- )
851
- dstates_ptr += (
852
- pid_b * stride_dstates_batch
853
- + pid_c * stride_dstates_chunk
854
- + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
855
- * stride_states_head
856
- )
857
- dt_ptr += (
858
- pid_b * stride_dt_batch
859
- + pid_c * stride_dt_chunk
860
- + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head
861
- )
862
- dA_cumsum_ptr += (
863
- pid_b * stride_dA_cs_batch
864
- + pid_c * stride_dA_cs_chunk
865
- + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
866
- )
867
- if HAS_DDA_CS:
868
- b_ptr += (
869
- pid_b * stride_b_batch
870
- + pid_c * chunk_size * stride_b_seqlen
871
- + pid_g * stride_b_head
872
- )
873
- ddA_cumsum_ptr += (
874
- pid_b * stride_ddA_cs_batch
875
- + pid_c * stride_ddA_cs_chunk
876
- + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
877
- * stride_ddA_cs_head
878
- )
879
- if HAS_SEQ_IDX:
880
- seq_idx_ptr += (
881
- pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
882
- )
883
-
884
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
885
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
886
- offs_k = tl.arange(0, BLOCK_SIZE_K)
887
- x_ptrs = x_ptr + (
888
- offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim
889
- )
890
- dstates_ptrs = dstates_ptr + (
891
- offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim
892
- )
893
- dt_ptrs = dt_ptr + offs_m * stride_dt_csize
894
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
895
- if HAS_DDA_CS:
896
- b_ptrs = b_ptr + (
897
- offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate
898
- )
899
- ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
900
-
901
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
902
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
903
- if HAS_DDA_CS:
904
- b = tl.load(
905
- b_ptrs,
906
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
907
- other=0.0,
908
- ).to(tl.float32)
909
- if HAS_SEQ_IDX:
910
- seq_idx_m = tl.load(
911
- seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
912
- mask=offs_m < chunk_size_limit,
913
- other=-1,
914
- )
915
- seq_idx_last = tl.load(
916
- seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
917
- )
918
- nheads_iter = min(
919
- nheads_per_program, nheads // ngroups - pid_s * nheads_per_program
920
- )
921
- for h in range(nheads_iter):
922
- x = tl.load(
923
- x_ptrs,
924
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim),
925
- other=0.0,
926
- )
927
- dstates = tl.load(
928
- dstates_ptrs,
929
- mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate),
930
- other=0.0,
931
- )
932
- dstates = dstates.to(x_ptrs.dtype.element_ty)
933
- db = tl.dot(x, dstates)
934
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
935
- tl.float32
936
- )
937
- dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
938
- tl.float32
939
- )
940
- dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
941
- if not HAS_SEQ_IDX:
942
- scale = tl.exp(dA_cs_last - dA_cs_m)
943
- else:
944
- scale = tl.where(
945
- seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0
946
- )
947
- db *= (scale * dt_m)[:, None]
948
- if HAS_DDA_CS:
949
- # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
950
- ddA_cs = tl.sum(db * b, axis=1)
951
- tl.atomic_add(
952
- ddA_cumsum_ptrs + stride_ddA_cs_csize,
953
- ddA_cs,
954
- mask=offs_m < chunk_size - 1,
955
- )
956
- acc += db
957
- x_ptrs += stride_x_head
958
- dstates_ptrs += stride_states_head
959
- dt_ptrs += stride_dt_head
960
- dA_cumsum_ptr += stride_dA_cs_head
961
- dA_cumsum_ptrs += stride_dA_cs_head
962
- if HAS_DDA_CS:
963
- ddA_cumsum_ptrs += stride_ddA_cs_head
964
-
965
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
966
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
967
- # if HAS_SEQ_IDX:
968
- # seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
969
- # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
970
- # acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0)
971
- db_ptrs = db_ptr + (
972
- offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate
973
- )
974
- tl.store(
975
- db_ptrs,
976
- acc,
977
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
978
- )
979
-
980
-
981
- @triton.autotune(
982
- configs=[
983
- # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
984
- # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
985
- # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
986
- # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
987
- # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
988
- # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
989
- # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
990
- # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
991
- # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
992
- triton.Config(
993
- {"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
994
- num_stages=3,
995
- num_warps=4,
996
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
997
- ),
998
- triton.Config(
999
- {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1000
- num_stages=3,
1001
- num_warps=4,
1002
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1003
- ),
1004
- triton.Config(
1005
- {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1006
- num_stages=3,
1007
- num_warps=4,
1008
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1009
- ),
1010
- triton.Config(
1011
- {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1012
- num_stages=3,
1013
- num_warps=4,
1014
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1015
- ),
1016
- triton.Config(
1017
- {"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
1018
- num_stages=4,
1019
- num_warps=8,
1020
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1021
- ),
1022
- triton.Config(
1023
- {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1024
- num_stages=4,
1025
- num_warps=8,
1026
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1027
- ),
1028
- triton.Config(
1029
- {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1030
- num_stages=4,
1031
- num_warps=8,
1032
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1033
- ),
1034
- triton.Config(
1035
- {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1036
- num_stages=4,
1037
- num_warps=8,
1038
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1039
- ),
1040
- ],
1041
- key=["chunk_size", "hdim", "dstate"],
1042
- )
1043
- @triton.jit
1044
- def _chunk_state_bwd_ddAcs_stable_kernel(
1045
- # Pointers to matrices
1046
- x_ptr,
1047
- b_ptr,
1048
- dstates_ptr,
1049
- dt_ptr,
1050
- dA_cumsum_ptr,
1051
- seq_idx_ptr,
1052
- ddA_cumsum_ptr,
1053
- # Matrix dimensions
1054
- chunk_size,
1055
- hdim,
1056
- dstate,
1057
- batch,
1058
- seqlen,
1059
- nheads_ngroups_ratio,
1060
- # Strides
1061
- stride_x_batch,
1062
- stride_x_seqlen,
1063
- stride_x_head,
1064
- stride_x_hdim,
1065
- stride_b_batch,
1066
- stride_b_seqlen,
1067
- stride_b_head,
1068
- stride_b_dstate,
1069
- stride_dstates_batch,
1070
- stride_dstates_chunk,
1071
- stride_states_head,
1072
- stride_states_hdim,
1073
- stride_states_dstate,
1074
- stride_dt_batch,
1075
- stride_dt_chunk,
1076
- stride_dt_head,
1077
- stride_dt_csize,
1078
- stride_dA_cs_batch,
1079
- stride_dA_cs_chunk,
1080
- stride_dA_cs_head,
1081
- stride_dA_cs_csize,
1082
- stride_seq_idx_batch,
1083
- stride_seq_idx_seqlen,
1084
- stride_ddA_cs_batch,
1085
- stride_ddA_cs_chunk,
1086
- stride_ddA_cs_head,
1087
- stride_ddA_cs_csize,
1088
- # Meta-parameters
1089
- HAS_SEQ_IDX: tl.constexpr,
1090
- BLOCK_SIZE_M: tl.constexpr,
1091
- BLOCK_SIZE_N: tl.constexpr,
1092
- BLOCK_SIZE_K: tl.constexpr,
1093
- BLOCK_SIZE_DSTATE: tl.constexpr,
1094
- ):
1095
- pid_bc = tl.program_id(axis=1)
1096
- pid_c = pid_bc // batch
1097
- pid_b = pid_bc - pid_c * batch
1098
- pid_h = tl.program_id(axis=2)
1099
- num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
1100
- pid_m = tl.program_id(axis=0) // num_pid_n
1101
- pid_n = tl.program_id(axis=0) % num_pid_n
1102
- x_ptr += (
1103
- pid_b * stride_x_batch
1104
- + pid_c * chunk_size * stride_x_seqlen
1105
- + pid_h * stride_x_head
1106
- )
1107
- b_ptr += (
1108
- pid_b * stride_b_batch
1109
- + pid_c * chunk_size * stride_b_seqlen
1110
- + (pid_h // nheads_ngroups_ratio) * stride_b_head
1111
- )
1112
- dstates_ptr += (
1113
- pid_b * stride_dstates_batch
1114
- + pid_c * stride_dstates_chunk
1115
- + pid_h * stride_states_head
1116
- )
1117
- dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
1118
- ddA_cumsum_ptr += (
1119
- pid_b * stride_ddA_cs_batch
1120
- + pid_c * stride_ddA_cs_chunk
1121
- + pid_h * stride_ddA_cs_head
1122
- )
1123
- dA_cumsum_ptr += (
1124
- pid_b * stride_dA_cs_batch
1125
- + pid_c * stride_dA_cs_chunk
1126
- + pid_h * stride_dA_cs_head
1127
- )
1128
- if HAS_SEQ_IDX:
1129
- seq_idx_ptr += (
1130
- pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
1131
- )
1132
-
1133
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1134
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1135
-
1136
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
1137
- # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
1138
- offs_k = tl.arange(
1139
- 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
1140
- )
1141
- b_ptrs = b_ptr + (
1142
- offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
1143
- )
1144
- dstates_ptrs = dstates_ptr + (
1145
- offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
1146
- )
1147
- if BLOCK_SIZE_DSTATE <= 128:
1148
- b = tl.load(
1149
- b_ptrs,
1150
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
1151
- other=0.0,
1152
- )
1153
- dstates = tl.load(
1154
- dstates_ptrs,
1155
- mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
1156
- other=0.0,
1157
- )
1158
- dstates = dstates.to(b_ptr.dtype.element_ty)
1159
- acc = tl.dot(b, dstates)
1160
- else:
1161
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
1162
- for k in range(0, dstate, BLOCK_SIZE_K):
1163
- b = tl.load(
1164
- b_ptrs,
1165
- mask=(offs_m[:, None] < chunk_size_limit)
1166
- & (offs_k[None, :] < dstate - k),
1167
- other=0.0,
1168
- )
1169
- dstates = tl.load(
1170
- dstates_ptrs,
1171
- mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
1172
- other=0.0,
1173
- )
1174
- dstates = dstates.to(b_ptr.dtype.element_ty)
1175
- acc += tl.dot(b, dstates)
1176
- b_ptrs += BLOCK_SIZE_K * stride_b_dstate
1177
- dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
1178
-
1179
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1180
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1181
-
1182
- dA_cs_m = tl.load(
1183
- dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
1184
- ).to(tl.float32)
1185
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
1186
- tl.float32
1187
- )
1188
- if not HAS_SEQ_IDX:
1189
- scale = tl.exp(dA_cs_last - dA_cs_m)
1190
- else:
1191
- seq_idx_m = tl.load(
1192
- seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
1193
- mask=offs_m < chunk_size_limit,
1194
- other=-1,
1195
- )
1196
- seq_idx_last = tl.load(
1197
- seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
1198
- )
1199
- scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
1200
- acc *= scale[:, None]
1201
-
1202
- x_ptrs = x_ptr + (
1203
- offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
1204
- )
1205
- x = tl.load(
1206
- x_ptrs,
1207
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
1208
- other=0.0,
1209
- ).to(tl.float32)
1210
- dt_ptrs = dt_ptr + offs_m * stride_dt_csize
1211
- dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
1212
- ddt = tl.sum(acc * x, axis=1)
1213
- # ddA_cs = -(ddt * dt_m)
1214
- # Triton 2.2.0 errors if we have the cumsum here, so we just write it out
1215
- # then call torch.cumsum outside this kernel.
1216
- # ddA_cs = tl.cumsum(ddt * dt_m)
1217
- ddA_cs = ddt * dt_m
1218
- ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
1219
- # tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
1220
- tl.atomic_add(
1221
- ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1
1222
- )
1223
-
1224
-
1225
- @triton.autotune(
1226
- configs=[
1227
- triton.Config(
1228
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
1229
- num_stages=3,
1230
- num_warps=8,
1231
- ),
1232
- triton.Config(
1233
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
1234
- num_stages=4,
1235
- num_warps=4,
1236
- ),
1237
- triton.Config(
1238
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1239
- num_stages=4,
1240
- num_warps=4,
1241
- ),
1242
- triton.Config(
1243
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1244
- num_stages=4,
1245
- num_warps=4,
1246
- ),
1247
- triton.Config(
1248
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1249
- num_stages=4,
1250
- num_warps=4,
1251
- ),
1252
- triton.Config(
1253
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1254
- num_stages=4,
1255
- num_warps=4,
1256
- ),
1257
- triton.Config(
1258
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1259
- num_stages=5,
1260
- num_warps=2,
1261
- ),
1262
- triton.Config(
1263
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1264
- num_stages=5,
1265
- num_warps=2,
1266
- ),
1267
- triton.Config(
1268
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1269
- num_stages=4,
1270
- num_warps=2,
1271
- ),
1272
- ],
1273
- key=["hdim", "dstate", "chunk_size"],
1274
- )
1275
- @triton.jit
1276
- def _chunk_state_varlen_kernel(
1277
- # Pointers to matrices
1278
- x_ptr,
1279
- b_ptr,
1280
- dt_ptr,
1281
- dA_cumsum_ptr,
1282
- chunk_states_ptr,
1283
- cu_seqlens_ptr,
1284
- states_ptr,
1285
- # Matrix dimensions
1286
- hdim,
1287
- dstate,
1288
- chunk_size,
1289
- seqlen,
1290
- nheads_ngroups_ratio,
1291
- # Strides
1292
- stride_x_seqlen,
1293
- stride_x_head,
1294
- stride_x_hdim,
1295
- stride_b_seqlen,
1296
- stride_b_head,
1297
- stride_b_dstate,
1298
- stride_dt_chunk,
1299
- stride_dt_head,
1300
- stride_dt_csize,
1301
- stride_dA_cs_chunk,
1302
- stride_dA_cs_head,
1303
- stride_dA_cs_csize,
1304
- stride_chunk_states_chunk,
1305
- stride_chunk_states_head,
1306
- stride_chunk_states_hdim,
1307
- stride_chunk_states_dstate,
1308
- stride_states_batch,
1309
- stride_states_head,
1310
- stride_states_hdim,
1311
- stride_states_dstate,
1312
- # Meta-parameters
1313
- BLOCK_SIZE_M: tl.constexpr,
1314
- BLOCK_SIZE_N: tl.constexpr,
1315
- BLOCK_SIZE_K: tl.constexpr,
1316
- ):
1317
- pid_b = tl.program_id(axis=1)
1318
- pid_h = tl.program_id(axis=2)
1319
- num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
1320
- pid_m = tl.program_id(axis=0) // num_pid_n
1321
- pid_n = tl.program_id(axis=0) % num_pid_n
1322
- end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
1323
- pid_c = (end_idx - 1) // chunk_size
1324
- b_ptr += (
1325
- pid_c * chunk_size * stride_b_seqlen
1326
- + (pid_h // nheads_ngroups_ratio) * stride_b_head
1327
- )
1328
- x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
1329
- dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
1330
- dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
1331
- chunk_states_ptr += (
1332
- pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
1333
- )
1334
-
1335
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1336
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1337
- offs_k = tl.arange(0, BLOCK_SIZE_K)
1338
- x_ptrs = x_ptr + (
1339
- offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
1340
- )
1341
- b_ptrs = b_ptr + (
1342
- offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
1343
- )
1344
- dt_ptrs = dt_ptr + offs_k * stride_dt_csize
1345
- dA_cs_last = tl.load(
1346
- dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
1347
- ).to(tl.float32)
1348
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
1349
-
1350
- chunk_size_limit = end_idx - pid_c * chunk_size
1351
- start_idx = tl.load(cu_seqlens_ptr + pid_b)
1352
- start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
1353
-
1354
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
1355
- for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
1356
- x = tl.load(
1357
- x_ptrs,
1358
- mask=(offs_m[:, None] < hdim)
1359
- & (offs_k[None, :] < chunk_size_limit - k)
1360
- & (offs_k[None, :] >= start_idx_cur - k),
1361
- other=0.0,
1362
- )
1363
- b = tl.load(
1364
- b_ptrs,
1365
- mask=(offs_k[:, None] < chunk_size_limit - k)
1366
- & (offs_n[None, :] < dstate)
1367
- & (offs_k[:, None] >= start_idx_cur - k),
1368
- other=0.0,
1369
- ).to(tl.float32)
1370
- dA_cs_k = tl.load(
1371
- dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
1372
- ).to(tl.float32)
1373
- dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
1374
- tl.float32
1375
- )
1376
- scale = tl.where(
1377
- (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
1378
- tl.exp((dA_cs_last - dA_cs_k)) * dt_k,
1379
- 0.0,
1380
- )
1381
- b *= scale[:, None]
1382
- b = b.to(x_ptr.dtype.element_ty)
1383
- acc += tl.dot(x, b)
1384
- x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
1385
- b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
1386
- dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
1387
- dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
1388
-
1389
- # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
1390
- if start_idx < pid_c * chunk_size:
1391
- chunk_states_ptrs = chunk_states_ptr + (
1392
- offs_m[:, None] * stride_chunk_states_hdim
1393
- + offs_n[None, :] * stride_chunk_states_dstate
1394
- )
1395
- chunk_states = tl.load(
1396
- chunk_states_ptrs,
1397
- mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),
1398
- other=0.0,
1399
- ).to(tl.float32)
1400
- # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0)
1401
- scale = tl.exp(dA_cs_last)
1402
- acc += chunk_states * scale
1403
-
1404
- states = acc.to(states_ptr.dtype.element_ty)
1405
-
1406
- states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
1407
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1408
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1409
- states_ptrs = states_ptr + (
1410
- offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
1411
- )
1412
- c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
1413
- tl.store(states_ptrs, states, mask=c_mask)
1414
-
1415
-
1416
- def _chunk_cumsum_fwd(
1417
- dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))
1418
- ):
1419
- batch, seqlen, nheads = dt.shape
1420
- assert A.shape == (nheads,)
1421
- if dt_bias is not None:
1422
- assert dt_bias.shape == (nheads,)
1423
- nchunks = math.ceil(seqlen / chunk_size)
1424
- dt_out = torch.empty(
1425
- batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
1426
- )
1427
- dA_cumsum = torch.empty(
1428
- batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
1429
- )
1430
- grid_chunk_cs = lambda META: (
1431
- batch,
1432
- nchunks,
1433
- triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
1434
- )
1435
- with torch.cuda.device(dt.device.index):
1436
- _chunk_cumsum_fwd_kernel[grid_chunk_cs](
1437
- dt,
1438
- A,
1439
- dt_bias,
1440
- dt_out,
1441
- dA_cumsum,
1442
- batch,
1443
- seqlen,
1444
- nheads,
1445
- chunk_size,
1446
- dt_limit[0],
1447
- dt_limit[1],
1448
- dt.stride(0),
1449
- dt.stride(1),
1450
- dt.stride(2),
1451
- A.stride(0),
1452
- dt_bias.stride(0) if dt_bias is not None else 0,
1453
- dt_out.stride(0),
1454
- dt_out.stride(2),
1455
- dt_out.stride(1),
1456
- dt_out.stride(3),
1457
- dA_cumsum.stride(0),
1458
- dA_cumsum.stride(2),
1459
- dA_cumsum.stride(1),
1460
- dA_cumsum.stride(3),
1461
- dt_softplus,
1462
- HAS_DT_BIAS=dt_bias is not None,
1463
- BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
1464
- )
1465
- return dA_cumsum, dt_out
1466
-
1467
-
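For orientation (not part of this diff), a hypothetical call of the wrapper above; the sizes are made up, only the shapes matter, and a CUDA device with Triton is assumed.

import math
import torch

batch, seqlen, nheads, chunk_size = 2, 512, 8, 256
dt = torch.rand(batch, seqlen, nheads, device="cuda")
A = -torch.rand(nheads, device="cuda")                   # per-head decay parameter, shape (nheads,)
dA_cumsum, dt_out = _chunk_cumsum_fwd(dt, A, chunk_size, dt_softplus=True)
nchunks = math.ceil(seqlen / chunk_size)                 # == 2 here
assert dA_cumsum.shape == dt_out.shape == (batch, nheads, nchunks, chunk_size)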
1468
- def _chunk_cumsum_bwd(
1469
- ddA,
1470
- ddt_out,
1471
- dt,
1472
- A,
1473
- dt_bias=None,
1474
- dt_softplus=False,
1475
- dt_limit=(0.0, float("inf")),
1476
- ddt=None,
1477
- ):
1478
- batch, seqlen, nheads = dt.shape
1479
- _, _, nchunks, chunk_size = ddA.shape
1480
- assert ddA.shape == (batch, nheads, nchunks, chunk_size)
1481
- assert ddt_out.shape == (batch, nheads, nchunks, chunk_size)
1482
- assert A.shape == (nheads,)
1483
- if dt_bias is not None:
1484
- assert dt_bias.shape == (nheads,)
1485
- ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)
1486
- else:
1487
- ddt_bias = None
1488
- if ddt is not None:
1489
- assert ddt.shape == dt.shape
1490
- else:
1491
- ddt = torch.empty_like(dt)
1492
- dA = torch.empty_like(A, dtype=torch.float32)
1493
- grid_chunk_cs = lambda META: (
1494
- batch,
1495
- nchunks,
1496
- triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
1497
- )
1498
- with torch.cuda.device(dt.device.index):
1499
- _chunk_cumsum_bwd_kernel[grid_chunk_cs](
1500
- ddA,
1501
- ddt_out,
1502
- dt,
1503
- A,
1504
- dt_bias,
1505
- ddt,
1506
- dA,
1507
- ddt_bias,
1508
- batch,
1509
- seqlen,
1510
- nheads,
1511
- chunk_size,
1512
- dt_limit[0],
1513
- dt_limit[1],
1514
- ddA.stride(0),
1515
- ddA.stride(2),
1516
- ddA.stride(1),
1517
- ddA.stride(3),
1518
- ddt_out.stride(0),
1519
- ddt_out.stride(2),
1520
- ddt_out.stride(1),
1521
- ddt_out.stride(3),
1522
- dt.stride(0),
1523
- dt.stride(1),
1524
- dt.stride(2),
1525
- A.stride(0),
1526
- dt_bias.stride(0) if dt_bias is not None else 0,
1527
- ddt.stride(0),
1528
- ddt.stride(1),
1529
- ddt.stride(2),
1530
- dA.stride(0),
1531
- ddt_bias.stride(0) if ddt_bias is not None else 0,
1532
- dt_softplus,
1533
- HAS_DT_BIAS=dt_bias is not None,
1534
- BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
1535
- )
1536
- return ddt, dA, ddt_bias
1537
-
1538
-
1539
- def _chunk_state_fwd(
1540
- B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True
1541
- ):
1542
- batch, seqlen, nheads, headdim = x.shape
1543
- _, _, nchunks, chunk_size = dt.shape
1544
- _, _, ngroups, dstate = B.shape
1545
- assert nheads % ngroups == 0
1546
- assert B.shape == (batch, seqlen, ngroups, dstate)
1547
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1548
- assert dA_cumsum.shape == dt.shape
1549
- if seq_idx is not None:
1550
- assert seq_idx.shape == (batch, seqlen)
1551
- if states is not None:
1552
- assert states.shape == (batch, nchunks, nheads, headdim, dstate)
1553
- else:
1554
- states_dtype = torch.float32 if states_in_fp32 else B.dtype
1555
- states = torch.empty(
1556
- (batch, nchunks, nheads, headdim, dstate),
1557
- device=x.device,
1558
- dtype=states_dtype,
1559
- )
1560
- grid = lambda META: (
1561
- triton.cdiv(headdim, META["BLOCK_SIZE_M"])
1562
- * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
1563
- batch * nchunks,
1564
- nheads,
1565
- )
1566
- with torch.cuda.device(x.device.index):
1567
- _chunk_state_fwd_kernel[grid](
1568
- x,
1569
- B,
1570
- states,
1571
- dt,
1572
- dA_cumsum,
1573
- seq_idx,
1574
- headdim,
1575
- dstate,
1576
- chunk_size,
1577
- batch,
1578
- seqlen,
1579
- nheads // ngroups,
1580
- x.stride(0),
1581
- x.stride(1),
1582
- x.stride(2),
1583
- x.stride(3),
1584
- B.stride(0),
1585
- B.stride(1),
1586
- B.stride(2),
1587
- B.stride(-1),
1588
- states.stride(0),
1589
- states.stride(1),
1590
- states.stride(2),
1591
- states.stride(3),
1592
- states.stride(4),
1593
- dt.stride(0),
1594
- dt.stride(2),
1595
- dt.stride(1),
1596
- dt.stride(3),
1597
- dA_cumsum.stride(0),
1598
- dA_cumsum.stride(2),
1599
- dA_cumsum.stride(1),
1600
- dA_cumsum.stride(3),
1601
- *(
1602
- (seq_idx.stride(0), seq_idx.stride(1))
1603
- if seq_idx is not None
1604
- else (0, 0)
1605
- ),
1606
- HAS_SEQ_IDX=seq_idx is not None,
1607
- )
1608
- return states
1609
-
1610
-
1611
- def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None):
1612
- batch, seqlen, nheads, headdim = x.shape
1613
- _, _, nchunks, chunk_size = dt.shape
1614
- _, _, ngroups, dstate = B.shape
1615
- assert nheads % ngroups == 0
1616
- assert B.shape == (batch, seqlen, ngroups, dstate)
1617
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1618
- assert dA_cumsum.shape == dt.shape
1619
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1620
- if dx is not None:
1621
- assert dx.shape == x.shape
1622
- else:
1623
- dx = torch.empty_like(x)
1624
- ddt = torch.empty(
1625
- batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
1626
- )
1627
- ddA_cumsum = torch.empty(
1628
- batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32
1629
- )
1630
- grid_dx = lambda META: (
1631
- triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
1632
- * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
1633
- batch * nchunks,
1634
- nheads,
1635
- )
1636
- with torch.cuda.device(x.device.index):
1637
- _chunk_state_bwd_dx_kernel[grid_dx](
1638
- x,
1639
- B,
1640
- dstates,
1641
- dt,
1642
- dA_cumsum,
1643
- dx,
1644
- ddt,
1645
- ddA_cumsum,
1646
- chunk_size,
1647
- headdim,
1648
- dstate,
1649
- batch,
1650
- seqlen,
1651
- nheads // ngroups,
1652
- x.stride(0),
1653
- x.stride(1),
1654
- x.stride(2),
1655
- x.stride(3),
1656
- B.stride(0),
1657
- B.stride(1),
1658
- B.stride(2),
1659
- B.stride(-1),
1660
- dstates.stride(0),
1661
- dstates.stride(1),
1662
- dstates.stride(2),
1663
- dstates.stride(3),
1664
- dstates.stride(4),
1665
- dt.stride(0),
1666
- dt.stride(2),
1667
- dt.stride(1),
1668
- dt.stride(3),
1669
- dA_cumsum.stride(0),
1670
- dA_cumsum.stride(2),
1671
- dA_cumsum.stride(1),
1672
- dA_cumsum.stride(3),
1673
- dx.stride(0),
1674
- dx.stride(1),
1675
- dx.stride(2),
1676
- dx.stride(3),
1677
- ddt.stride(0),
1678
- ddt.stride(2),
1679
- ddt.stride(1),
1680
- ddt.stride(3),
1681
- ddA_cumsum.stride(0),
1682
- ddA_cumsum.stride(2),
1683
- ddA_cumsum.stride(1),
1684
- ddA_cumsum.stride(3),
1685
- BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
1686
- )
1687
- return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype)
1688
-
1689
-
1690
- def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1):
1691
- batch, seqlen, nheads, headdim = x.shape
1692
- _, _, nchunks, chunk_size = dt.shape
1693
- dstate = dstates.shape[-1]
1694
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1695
- assert dA_cumsum.shape == dt.shape
1696
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1697
- if seq_idx is not None:
1698
- assert seq_idx.shape == (batch, seqlen)
1699
- if B is not None:
1700
- assert B.shape == (batch, seqlen, ngroups, dstate)
1701
- B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3))
1702
- # Use torch.empty since the Triton kernel will call init_to_zero
1703
- ddA_cumsum = torch.empty(
1704
- batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
1705
- )
1706
- ddA_cumsum_strides = (
1707
- ddA_cumsum.stride(0),
1708
- ddA_cumsum.stride(2),
1709
- ddA_cumsum.stride(1),
1710
- ddA_cumsum.stride(3),
1711
- )
1712
- else:
1713
- B_strides = (0, 0, 0, 0)
1714
- ddA_cumsum = None
1715
- ddA_cumsum_strides = (0, 0, 0, 0)
1716
- nheads_ngroups_ratio = nheads // ngroups
1717
- sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
1718
- nheads_per_program = max(
1719
- min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1
1720
- )
1721
- nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
1722
- dB = torch.empty(
1723
- batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32
1724
- )
1725
- grid_db = lambda META: (
1726
- triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
1727
- * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
1728
- batch * nchunks,
1729
- nsplits * ngroups,
1730
- )
1731
- with torch.cuda.device(x.device.index):
1732
- _chunk_state_bwd_db_kernel[grid_db](
1733
- x,
1734
- dstates,
1735
- B,
1736
- dt,
1737
- dA_cumsum,
1738
- seq_idx,
1739
- dB,
1740
- ddA_cumsum,
1741
- chunk_size,
1742
- dstate,
1743
- headdim,
1744
- batch,
1745
- seqlen,
1746
- nheads,
1747
- nheads_per_program,
1748
- ngroups,
1749
- x.stride(0),
1750
- x.stride(1),
1751
- x.stride(2),
1752
- x.stride(3),
1753
- dstates.stride(0),
1754
- dstates.stride(1),
1755
- dstates.stride(2),
1756
- dstates.stride(3),
1757
- dstates.stride(4),
1758
- *B_strides,
1759
- dt.stride(0),
1760
- dt.stride(2),
1761
- dt.stride(1),
1762
- dt.stride(3),
1763
- dA_cumsum.stride(0),
1764
- dA_cumsum.stride(2),
1765
- dA_cumsum.stride(1),
1766
- dA_cumsum.stride(3),
1767
- *(
1768
- (seq_idx.stride(0), seq_idx.stride(1))
1769
- if seq_idx is not None
1770
- else (0, 0)
1771
- ),
1772
- dB.stride(0),
1773
- dB.stride(1),
1774
- dB.stride(2),
1775
- dB.stride(3),
1776
- dB.stride(4),
1777
- *ddA_cumsum_strides,
1778
- HAS_DDA_CS=ddA_cumsum is not None,
1779
- HAS_SEQ_IDX=seq_idx is not None,
1780
- BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
1781
- )
1782
- dB = dB.sum(2)
1783
- if ddA_cumsum is not None:
1784
- # The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute
1785
- # to the state of the chunk.
1786
- # torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
1787
- # But it's easier to just do the cumsum for all elements, the result will be the same.
1788
- torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum)
1789
- return dB if B is None else (dB, ddA_cumsum)
1790
-
1791
-
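A tiny numeric illustration (not upstream code) of the cumsum comment above: because position 0 of ddA_cumsum is always zero (init_to_zero plus the shifted atomic_add store), cumsumming the whole last axis gives the same result as cumsumming only elements 1 and onward.

import torch

ddA = torch.tensor([0.0, 0.3, -0.1, 0.2])                # position 0 is always zero
full = torch.cumsum(ddA, dim=-1)
partial = ddA.clone()
partial[1:] = torch.cumsum(ddA[1:], dim=-1)
assert torch.equal(full, partial)                        # identical because ddA[0] == 0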
1792
- def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):
1793
- batch, seqlen, nheads, headdim = x.shape
1794
- _, _, nchunks, chunk_size = dt.shape
1795
- _, _, ngroups, dstate = B.shape
1796
- assert nheads % ngroups == 0
1797
- assert B.shape == (batch, seqlen, ngroups, dstate)
1798
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1799
- assert dA_cumsum.shape == dt.shape
1800
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1801
- if seq_idx is not None:
1802
- assert seq_idx.shape == (batch, seqlen)
1803
- # Use torch.empty since the Triton kernel will call init_to_zero
1804
- ddA_cumsum = torch.empty(
1805
- batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
1806
- )
1807
- grid_ddtcs = lambda META: (
1808
- triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
1809
- * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
1810
- batch * nchunks,
1811
- nheads,
1812
- )
1813
- with torch.cuda.device(x.device.index):
1814
- _chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs](
1815
- x,
1816
- B,
1817
- dstates,
1818
- dt,
1819
- dA_cumsum,
1820
- seq_idx,
1821
- ddA_cumsum,
1822
- chunk_size,
1823
- headdim,
1824
- dstate,
1825
- batch,
1826
- seqlen,
1827
- nheads // ngroups,
1828
- x.stride(0),
1829
- x.stride(1),
1830
- x.stride(2),
1831
- x.stride(3),
1832
- B.stride(0),
1833
- B.stride(1),
1834
- B.stride(2),
1835
- B.stride(-1),
1836
- dstates.stride(0),
1837
- dstates.stride(1),
1838
- dstates.stride(2),
1839
- dstates.stride(3),
1840
- dstates.stride(4),
1841
- dt.stride(0),
1842
- dt.stride(2),
1843
- dt.stride(1),
1844
- dt.stride(3),
1845
- dA_cumsum.stride(0),
1846
- dA_cumsum.stride(2),
1847
- dA_cumsum.stride(1),
1848
- dA_cumsum.stride(3),
1849
- *(
1850
- (seq_idx.stride(0), seq_idx.stride(1))
1851
- if seq_idx is not None
1852
- else (0, 0)
1853
- ),
1854
- ddA_cumsum.stride(0),
1855
- ddA_cumsum.stride(2),
1856
- ddA_cumsum.stride(1),
1857
- ddA_cumsum.stride(3),
1858
- HAS_SEQ_IDX=seq_idx is not None,
1859
- BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16),
1860
- BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
1861
- )
1862
- torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
1863
- return ddA_cumsum
1864
-
1865
-
1866
- def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):
1867
- total_seqlen, nheads, headdim = x.shape
1868
- _, nchunks, chunk_size = dt.shape
1869
- _, ngroups, dstate = B.shape
1870
- batch = cu_seqlens.shape[0] - 1
1871
- cu_seqlens = cu_seqlens.contiguous()
1872
- assert nheads % ngroups == 0
1873
- assert B.shape == (total_seqlen, ngroups, dstate)
1874
- assert dt.shape == (nheads, nchunks, chunk_size)
1875
- assert dA_cumsum.shape == dt.shape
1876
- assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
1877
- states = torch.empty(
1878
- batch,
1879
- nheads,
1880
- headdim,
1881
- dstate,
1882
- dtype=chunk_states.dtype,
1883
- device=chunk_states.device,
1884
- )
1885
- grid = lambda META: (
1886
- triton.cdiv(headdim, META["BLOCK_SIZE_M"])
1887
- * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
1888
- batch,
1889
- nheads,
1890
- )
1891
- with torch.cuda.device(x.device.index):
1892
- _chunk_state_varlen_kernel[grid](
1893
- x,
1894
- B,
1895
- dt,
1896
- dA_cumsum,
1897
- chunk_states,
1898
- cu_seqlens,
1899
- states,
1900
- headdim,
1901
- dstate,
1902
- chunk_size,
1903
- total_seqlen,
1904
- nheads // ngroups,
1905
- x.stride(0),
1906
- x.stride(1),
1907
- x.stride(2),
1908
- B.stride(0),
1909
- B.stride(1),
1910
- B.stride(2),
1911
- dt.stride(1),
1912
- dt.stride(0),
1913
- dt.stride(2),
1914
- dA_cumsum.stride(1),
1915
- dA_cumsum.stride(0),
1916
- dA_cumsum.stride(2),
1917
- chunk_states.stride(0),
1918
- chunk_states.stride(1),
1919
- chunk_states.stride(2),
1920
- chunk_states.stride(3),
1921
- states.stride(0),
1922
- states.stride(1),
1923
- states.stride(2),
1924
- states.stride(3),
1925
- )
1926
- return states
1927
-
1928
-
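A small sketch (outside the diff) of the cu_seqlens convention assumed by chunk_state_varlen above: sequences are packed along one axis and sequence b spans rows [cu_seqlens[b], cu_seqlens[b+1]). The lengths below are hypothetical.

import torch

seqlens = torch.tensor([100, 37, 250])
cu_seqlens = torch.cat([seqlens.new_zeros(1), seqlens.cumsum(0)]).to(torch.int32).cuda()
# cu_seqlens -> tensor([0, 100, 137, 387]); batch = cu_seqlens.numel() - 1 == 3
# total_seqlen (the leading dim of x and B) must equal int(cu_seqlens[-1]) == 387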
1929
- class ChunkStateFn(torch.autograd.Function):
1930
-
1931
- @staticmethod
1932
- def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True):
1933
- batch, seqlen, nheads, headdim = x.shape
1934
- _, _, nchunks, chunk_size = dt.shape
1935
- assert seqlen <= nchunks * chunk_size
1936
- _, _, ngroups, dstate = B.shape
1937
- assert B.shape == (batch, seqlen, ngroups, dstate)
1938
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1939
- assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
1940
- if B.stride(-1) != 1:
1941
- B = B.contiguous()
1942
- if (
1943
- x.stride(-1) != 1 and x.stride(1) != 1
1944
- ): # Either M or K dimension should be contiguous
1945
- x = x.contiguous()
1946
- states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32)
1947
- ctx.save_for_backward(B, x, dt, dA_cumsum)
1948
- return states
1949
-
1950
- @staticmethod
1951
- def backward(ctx, dstates):
1952
- B, x, dt, dA_cumsum = ctx.saved_tensors
1953
- batch, seqlen, nheads, headdim = x.shape
1954
- _, _, nchunks, chunk_size = dt.shape
1955
- _, _, ngroups, dstate = B.shape
1956
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1957
- if dstates.stride(-1) != 1:
1958
- dstates = dstates.contiguous()
1959
- dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates)
1960
- dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups)
1961
- dB = dB.to(B.dtype)
1962
- return dB, dx, ddt, ddA_cumsum, None
1963
-
1964
-
1965
- def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True):
1966
- """
1967
- Argument:
1968
- B: (batch, seqlen, ngroups, dstate)
1969
- x: (batch, seqlen, nheads, headdim)
1970
- dt: (batch, nheads, nchunks, chunk_size)
1971
- dA_cumsum: (batch, nheads, nchunks, chunk_size)
1972
- Return:
1973
- states: (batch, nchunks, nheads, headdim, dstate)
1974
- """
1975
- return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32)
1976
-
1977
-
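For orientation (not part of this diff), a hypothetical call of chunk_state above with shapes matching its docstring; the sizes are illustrative and a CUDA device with Triton is assumed.

import torch

batch, seqlen, nheads, headdim = 2, 512, 8, 64
ngroups, dstate, chunk_size = 1, 128, 256
nchunks = (seqlen + chunk_size - 1) // chunk_size
x = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
B = torch.randn(batch, seqlen, ngroups, dstate, device="cuda", dtype=torch.bfloat16)
dt = torch.rand(batch, nheads, nchunks, chunk_size, device="cuda")
dA_cumsum = (-dt).cumsum(dim=-1)                         # any per-chunk cumulative decay works for a smoke test
states = chunk_state(B, x, dt, dA_cumsum)
assert states.shape == (batch, nchunks, nheads, headdim, dstate)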
1978
- def chunk_state_ref(B, x, dt, dA_cumsum):
1979
- """
1980
- Argument:
1981
- B: (batch, seqlen, ngroups, dstate)
1982
- x: (batch, seqlen, nheads, headdim)
1983
- dt: (batch, nheads, nchunks, chunk_size)
1984
- dA_cumsum: (batch, nheads, nchunks, chunk_size)
1985
- Return:
1986
- states: (batch, nchunks, nheads, headdim, dstate)
1987
- """
1988
- # Check constraints.
1989
- batch, seqlen, nheads, headdim = x.shape
1990
- dstate = B.shape[-1]
1991
- _, _, nchunks, chunk_size = dt.shape
1992
- assert seqlen <= nchunks * chunk_size
1993
- assert x.shape == (batch, seqlen, nheads, headdim)
1994
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1995
- ngroups = B.shape[2]
1996
- assert nheads % ngroups == 0
1997
- assert B.shape == (batch, seqlen, ngroups, dstate)
1998
- B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
1999
- assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
2000
- if seqlen < nchunks * chunk_size:
2001
- x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
2002
- B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
2003
- x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
2004
- B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
2005
- decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))
2006
- return torch.einsum(
2007
- "bclhn,bhcl,bhcl,bclhp->bchpn",
2008
- B.to(x.dtype),
2009
- decay_states.to(x.dtype),
2010
- dt.to(x.dtype),
2011
- x,
2012
- )
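A hedged sanity-check sketch (not upstream code), reusing the tensors from the snippet above chunk_state_ref: the Triton path and the einsum reference compute the same contraction, so they should agree to within bfloat16-level tolerance.

out_triton = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True)   # float32 output
out_ref = chunk_state_ref(B, x, dt, dA_cumsum)                       # einsum reference in x.dtype
torch.testing.assert_close(out_triton, out_ref.float(), rtol=1e-2, atol=1e-2)
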
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/ops/triton/ssd_combined.py DELETED
@@ -1,1884 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
-
3
- """We want triton==2.1.0 or 2.2.0 for this
4
- """
5
-
6
- from typing import Optional
7
-
8
- import math
9
- from packaging import version
10
-
11
- import torch
12
- import torch.nn.functional as F
13
- from torch import Tensor
14
- from ...utils.torch import custom_bwd, custom_fwd
15
-
16
- import triton
17
- import triton.language as tl
18
-
19
- from einops import rearrange, repeat
20
-
21
- try:
22
- from causal_conv1d import causal_conv1d_fn
23
- import causal_conv1d_cuda
24
- except ImportError:
25
- causal_conv1d_fn, causal_conv1d_cuda = None, None
26
-
27
- from .ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
28
- from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
29
- from .ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
30
- from .ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
31
- from .ssd_chunk_state import chunk_state, chunk_state_ref
32
- from .ssd_chunk_state import chunk_state_varlen
33
- from .ssd_state_passing import _state_passing_fwd, _state_passing_bwd
34
- from .ssd_state_passing import state_passing, state_passing_ref
35
- from .ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates
36
- from .ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
37
- from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
38
- from .ssd_chunk_scan import chunk_scan, chunk_scan_ref
39
- from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
40
- from .layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
41
- from .k_activations import _swiglu_fwd, _swiglu_bwd
42
-
43
- TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
44
-
45
-
46
- def init_to_zero(names):
47
- return lambda nargs: [
48
- nargs[name].zero_() for name in names if nargs[name] is not None
49
- ]
50
-
51
-
52
- @triton.autotune(
53
- configs=[
54
- triton.Config(
55
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
56
- num_stages=3,
57
- num_warps=8,
58
- pre_hook=init_to_zero(["ddt_ptr"]),
59
- ),
60
- triton.Config(
61
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
62
- num_stages=4,
63
- num_warps=4,
64
- pre_hook=init_to_zero(["ddt_ptr"]),
65
- ),
66
- triton.Config(
67
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
68
- num_stages=4,
69
- num_warps=4,
70
- pre_hook=init_to_zero(["ddt_ptr"]),
71
- ),
72
- triton.Config(
73
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
74
- num_stages=4,
75
- num_warps=4,
76
- pre_hook=init_to_zero(["ddt_ptr"]),
77
- ),
78
- triton.Config(
79
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
80
- num_stages=4,
81
- num_warps=4,
82
- pre_hook=init_to_zero(["ddt_ptr"]),
83
- ),
84
- triton.Config(
85
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
86
- num_stages=4,
87
- num_warps=4,
88
- pre_hook=init_to_zero(["ddt_ptr"]),
89
- ),
90
- triton.Config(
91
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
92
- num_stages=5,
93
- num_warps=4,
94
- pre_hook=init_to_zero(["ddt_ptr"]),
95
- ),
96
- triton.Config(
97
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
98
- num_stages=5,
99
- num_warps=4,
100
- pre_hook=init_to_zero(["ddt_ptr"]),
101
- ),
102
- triton.Config(
103
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
104
- num_stages=4,
105
- num_warps=4,
106
- pre_hook=init_to_zero(["ddt_ptr"]),
107
- ),
108
- ],
109
- key=["chunk_size", "hdim", "dstate"],
110
- )
111
- @triton.jit
112
- def _chunk_scan_chunk_state_bwd_dx_kernel(
113
- # Pointers to matrices
114
- x_ptr,
115
- cb_ptr,
116
- dout_ptr,
117
- dt_ptr,
118
- dA_cumsum_ptr,
119
- seq_idx_ptr,
120
- D_ptr,
121
- b_ptr,
122
- dstates_ptr,
123
- dx_ptr,
124
- ddt_ptr,
125
- dD_ptr,
126
- # Matrix dimensions
127
- chunk_size,
128
- hdim,
129
- dstate,
130
- batch,
131
- seqlen,
132
- nheads_ngroups_ratio,
133
- # Strides
134
- stride_x_batch,
135
- stride_x_seqlen,
136
- stride_x_head,
137
- stride_x_hdim,
138
- stride_cb_batch,
139
- stride_cb_chunk,
140
- stride_cb_head,
141
- stride_cb_csize_m,
142
- stride_cb_csize_k,
143
- stride_dout_batch,
144
- stride_dout_seqlen,
145
- stride_dout_head,
146
- stride_dout_hdim,
147
- stride_dt_batch,
148
- stride_dt_chunk,
149
- stride_dt_head,
150
- stride_dt_csize,
151
- stride_dA_cs_batch,
152
- stride_dA_cs_chunk,
153
- stride_dA_cs_head,
154
- stride_dA_cs_csize,
155
- stride_seq_idx_batch,
156
- stride_seq_idx_seqlen,
157
- stride_D_head,
158
- stride_b_batch,
159
- stride_b_seqlen,
160
- stride_b_head,
161
- stride_b_dstate,
162
- stride_dstates_batch,
163
- stride_dstates_chunk,
164
- stride_dstates_head,
165
- stride_dstates_hdim,
166
- stride_dstates_dstate,
167
- stride_dx_batch,
168
- stride_dx_seqlen,
169
- stride_dx_head,
170
- stride_dx_hdim,
171
- stride_ddt_batch,
172
- stride_ddt_chunk,
173
- stride_ddt_head,
174
- stride_ddt_csize,
175
- stride_dD_batch,
176
- stride_dD_chunk,
177
- stride_dD_head,
178
- stride_dD_csize,
179
- stride_dD_hdim,
180
- # Meta-parameters
181
- HAS_D: tl.constexpr,
182
- D_HAS_HDIM: tl.constexpr,
183
- HAS_SEQ_IDX: tl.constexpr,
184
- BLOCK_SIZE_M: tl.constexpr,
185
- BLOCK_SIZE_N: tl.constexpr,
186
- BLOCK_SIZE_K: tl.constexpr,
187
- BLOCK_SIZE_DSTATE: tl.constexpr,
188
- IS_TRITON_22: tl.constexpr,
189
- ):
190
- pid_bc = tl.program_id(axis=1)
191
- pid_c = pid_bc // batch
192
- pid_b = pid_bc - pid_c * batch
193
- pid_h = tl.program_id(axis=2)
194
- num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
195
- pid_m = tl.program_id(axis=0) // num_pid_n
196
- pid_n = tl.program_id(axis=0) % num_pid_n
197
- x_ptr += (
198
- pid_b * stride_x_batch
199
- + pid_c * chunk_size * stride_x_seqlen
200
- + pid_h * stride_x_head
201
- )
202
- cb_ptr += (
203
- pid_b * stride_cb_batch
204
- + pid_c * stride_cb_chunk
205
- + (pid_h // nheads_ngroups_ratio) * stride_cb_head
206
- )
207
- dout_ptr += (
208
- pid_b * stride_dout_batch
209
- + pid_c * chunk_size * stride_dout_seqlen
210
- + pid_h * stride_dout_head
211
- )
212
- dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
213
- ddt_ptr += (
214
- pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
215
- )
216
- dA_cumsum_ptr += (
217
- pid_b * stride_dA_cs_batch
218
- + pid_c * stride_dA_cs_chunk
219
- + pid_h * stride_dA_cs_head
220
- )
221
- b_ptr += (
222
- pid_b * stride_b_batch
223
- + pid_c * chunk_size * stride_b_seqlen
224
- + (pid_h // nheads_ngroups_ratio) * stride_b_head
225
- )
226
- dstates_ptr += (
227
- pid_b * stride_dstates_batch
228
- + pid_c * stride_dstates_chunk
229
- + pid_h * stride_dstates_head
230
- )
231
- if HAS_SEQ_IDX:
232
- seq_idx_ptr += (
233
- pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
234
- )
235
-
236
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
237
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
238
-
239
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
240
-
241
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
242
-
243
- dA_cs_m = tl.load(
244
- dA_cumsum_ptr + offs_m * stride_dA_cs_csize,
245
- mask=offs_m < chunk_size_limit,
246
- other=0.0,
247
- ).to(tl.float32)
248
-
249
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
250
- tl.float32
251
- )
252
- if not HAS_SEQ_IDX:
253
- scale = tl.exp(dA_cs_last - dA_cs_m)
254
- else:
255
- seq_idx_m = tl.load(
256
- seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
257
- mask=offs_m < chunk_size_limit,
258
- other=-1,
259
- )
260
- seq_idx_last = tl.load(
261
- seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
262
- )
263
- scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
264
- # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
265
- # However, we're getting an error with the Triton compiler 2.1.0 for that code path:
266
- # Unexpected mma -> mma layout conversion
267
- # Triton 2.2.0 fixes this
268
- offs_dstate = tl.arange(
269
- 0,
270
- (
271
- BLOCK_SIZE_DSTATE
272
- if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128
273
- else BLOCK_SIZE_K
274
- ),
275
- )
276
- b_ptrs = b_ptr + (
277
- offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate
278
- )
279
- dstates_ptrs = dstates_ptr + (
280
- offs_n[None, :] * stride_dstates_hdim
281
- + offs_dstate[:, None] * stride_dstates_dstate
282
- )
283
- if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:
284
- b = tl.load(
285
- b_ptrs,
286
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate),
287
- other=0.0,
288
- )
289
- dstates = tl.load(
290
- dstates_ptrs,
291
- mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
292
- other=0.0,
293
- )
294
- dstates = dstates.to(b_ptr.dtype.element_ty)
295
- acc = tl.dot(b, dstates) * scale[:, None]
296
- else:
297
- for k in range(0, dstate, BLOCK_SIZE_K):
298
- b = tl.load(
299
- b_ptrs,
300
- mask=(offs_m[:, None] < chunk_size_limit)
301
- & (offs_dstate[None, :] < dstate - k),
302
- other=0.0,
303
- )
304
- dstates = tl.load(
305
- dstates_ptrs,
306
- mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim),
307
- other=0.0,
308
- )
309
- dstates = dstates.to(b_ptr.dtype.element_ty)
310
- acc += tl.dot(b, dstates)
311
- b_ptrs += BLOCK_SIZE_K * stride_b_dstate
312
- dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate
313
- acc *= scale[:, None]
314
-
315
- # x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
316
- # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
317
- # dt_ptrs = dt_ptr + offs_m * stride_dt_csize
318
- # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
319
- # ddt = tl.sum(acc * x, axis=1) * dt_m
320
- # ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
321
- # tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
322
-
323
- offs_k = tl.arange(0, BLOCK_SIZE_K)
324
- cb_ptrs = cb_ptr + (
325
- offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k
326
- )
327
- dout_ptrs = dout_ptr + (
328
- offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim
329
- )
330
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
331
- K_MAX = chunk_size_limit
332
- K_MIN = pid_m * BLOCK_SIZE_M
333
- cb_ptrs += K_MIN * stride_cb_csize_k
334
- dout_ptrs += K_MIN * stride_dout_seqlen
335
- dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize
336
- for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):
337
- k = tl.multiple_of(k, BLOCK_SIZE_K)
338
- # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower
339
- cb = tl.load(
340
- cb_ptrs,
341
- mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k),
342
- other=0.0,
343
- )
344
- dout = tl.load(
345
- dout_ptrs,
346
- mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim),
347
- other=0.0,
348
- )
349
- dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(
350
- tl.float32
351
- )
352
- cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
353
- # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
354
- # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
355
- # Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
356
- # This will cause NaN in acc, and hence NaN in dx and ddt.
357
- mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)
358
- cb = tl.where(mask, cb, 0.0)
359
- cb = cb.to(dout_ptr.dtype.element_ty)
360
- acc += tl.dot(cb, dout)
361
- cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
362
- dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
363
- dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
364
-
365
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
366
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
367
- dt_ptrs = dt_ptr + offs_m * stride_dt_csize
368
- dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
369
- dx = acc * dt_m[:, None]
370
- dx_ptr += (
371
- pid_b * stride_dx_batch
372
- + pid_c * chunk_size * stride_dx_seqlen
373
- + pid_h * stride_dx_head
374
- )
375
- dx_ptrs = dx_ptr + (
376
- offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim
377
- )
378
- if HAS_D:
379
- dout_res_ptrs = dout_ptr + (
380
- offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim
381
- )
382
- dout_res = tl.load(
383
- dout_res_ptrs,
384
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
385
- other=0.0,
386
- ).to(tl.float32)
387
- if D_HAS_HDIM:
388
- D = tl.load(
389
- D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0
390
- ).to(tl.float32)
391
- else:
392
- D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
393
- dx += dout_res * D
394
- tl.store(
395
- dx_ptrs,
396
- dx,
397
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
398
- )
399
-
400
- x_ptrs = x_ptr + (
401
- offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
402
- )
403
- x = tl.load(
404
- x_ptrs,
405
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
406
- other=0.0,
407
- ).to(tl.float32)
408
- if HAS_D:
409
- dD_ptr += (
410
- pid_b * stride_dD_batch
411
- + pid_c * stride_dD_chunk
412
- + pid_h * stride_dD_head
413
- + pid_m * stride_dD_csize
414
- )
415
- if D_HAS_HDIM:
416
- dD_ptrs = dD_ptr + offs_n * stride_dD_hdim
417
- dD = tl.sum(dout_res * x, axis=0)
418
- tl.store(dD_ptrs, dD, mask=offs_n < hdim)
419
- else:
420
- dD = tl.sum(dout_res * x)
421
- tl.store(dD_ptr, dD)
422
- ddt = tl.sum(acc * x, axis=1)
423
- ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
424
- tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
425
-
426
-
427
- def _chunk_scan_chunk_state_bwd_dx(
428
- x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None
429
- ):
430
- batch, seqlen, nheads, headdim = x.shape
431
- _, _, nchunks, chunk_size = dt.shape
432
- _, _, ngroups, dstate = B.shape
433
- assert nheads % ngroups == 0
434
- assert B.shape == (batch, seqlen, ngroups, dstate)
435
- assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
436
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
437
- assert dA_cumsum.shape == dt.shape
438
- assert dout.shape == x.shape
439
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
440
- if seq_idx is not None:
441
- assert seq_idx.shape == (batch, seqlen)
442
- if D is not None:
443
- assert D.shape == (nheads, headdim) or D.shape == (nheads,)
444
- assert D.stride(-1) == 1
445
- BLOCK_SIZE_min = 32
446
- dD = torch.empty(
447
- triton.cdiv(chunk_size, BLOCK_SIZE_min),
448
- batch,
449
- nchunks,
450
- nheads,
451
- headdim if D.dim() == 2 else 1,
452
- device=D.device,
453
- dtype=torch.float32,
454
- )
455
- else:
456
- dD = None
457
- dD_strides = (
458
- (dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
459
- if D is not None
460
- else (0, 0, 0, 0, 0)
461
- )
462
- if dx is None:
463
- dx = torch.empty_like(x)
464
- else:
465
- assert dx.shape == x.shape
466
- ddt = torch.empty(
467
- batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32
468
- )
469
- grid_dx = lambda META: (
470
- triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
471
- * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
472
- batch * nchunks,
473
- nheads,
474
- )
475
- with torch.cuda.device(x.device.index):
476
- _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](
477
- x,
478
- CB,
479
- dout,
480
- dt,
481
- dA_cumsum,
482
- seq_idx,
483
- D,
484
- B,
485
- dstates,
486
- dx,
487
- ddt,
488
- dD,
489
- chunk_size,
490
- headdim,
491
- dstate,
492
- batch,
493
- seqlen,
494
- nheads // ngroups,
495
- x.stride(0),
496
- x.stride(1),
497
- x.stride(2),
498
- x.stride(3),
499
- CB.stride(0),
500
- CB.stride(1),
501
- CB.stride(2),
502
- CB.stride(-1),
503
- CB.stride(-2),
504
- dout.stride(0),
505
- dout.stride(1),
506
- dout.stride(2),
507
- dout.stride(3),
508
- dt.stride(0),
509
- dt.stride(2),
510
- dt.stride(1),
511
- dt.stride(3),
512
- dA_cumsum.stride(0),
513
- dA_cumsum.stride(2),
514
- dA_cumsum.stride(1),
515
- dA_cumsum.stride(3),
516
- *(
517
- (seq_idx.stride(0), seq_idx.stride(1))
518
- if seq_idx is not None
519
- else (0, 0)
520
- ),
521
- D.stride(0) if D is not None else 0,
522
- B.stride(0),
523
- B.stride(1),
524
- B.stride(2),
525
- B.stride(3),
526
- dstates.stride(0),
527
- dstates.stride(1),
528
- dstates.stride(2),
529
- dstates.stride(3),
530
- dstates.stride(4),
531
- dx.stride(0),
532
- dx.stride(1),
533
- dx.stride(2),
534
- dx.stride(3),
535
- ddt.stride(0),
536
- ddt.stride(2),
537
- ddt.stride(1),
538
- ddt.stride(3),
539
- dD_strides[1],
540
- dD_strides[2],
541
- dD_strides[3],
542
- dD_strides[0],
543
- dD_strides[4],
544
- D is not None,
545
- D.dim() == 2 if D is not None else True,
546
- HAS_SEQ_IDX=seq_idx is not None,
547
- BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
548
- IS_TRITON_22=TRITON_22
549
- )
550
- if D is not None:
551
- BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs[
552
- "BLOCK_SIZE_M"
553
- ]
554
- n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
555
- dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
556
- if D.dim() == 1:
557
- dD = rearrange(dD, "h 1 -> h")
558
- return dx, ddt.to(dtype=dt.dtype), dD
559
-
560
-
561
- def _mamba_chunk_scan_combined_fwd(
562
- x,
563
- dt,
564
- A,
565
- B,
566
- C,
567
- chunk_size,
568
- D=None,
569
- z=None,
570
- dt_bias=None,
571
- initial_states=None,
572
- seq_idx=None,
573
- cu_seqlens=None,
574
- dt_softplus=False,
575
- dt_limit=(0.0, float("inf")),
576
- ):
577
- batch, seqlen, nheads, headdim = x.shape
578
- _, _, ngroups, dstate = B.shape
579
- assert nheads % ngroups == 0
580
- assert B.shape == (batch, seqlen, ngroups, dstate)
581
- assert x.shape == (batch, seqlen, nheads, headdim)
582
- assert dt.shape == (batch, seqlen, nheads)
583
- assert A.shape == (nheads,)
584
- assert C.shape == B.shape
585
- if z is not None:
586
- assert z.shape == x.shape
587
- if D is not None:
588
- assert D.shape == (nheads, headdim) or D.shape == (nheads,)
589
- if seq_idx is not None:
590
- assert seq_idx.shape == (batch, seqlen)
591
- if B.stride(-1) != 1:
592
- B = B.contiguous()
593
- if C.stride(-1) != 1:
594
- C = C.contiguous()
595
- if (
596
- x.stride(-1) != 1 and x.stride(1) != 1
597
- ): # Either M or K dimension should be contiguous
598
- x = x.contiguous()
599
- if (
600
- z is not None and z.stride(-1) != 1 and z.stride(1) != 1
601
- ): # Either M or K dimension should be contiguous
602
- z = z.contiguous()
603
- if D is not None and D.stride(-1) != 1:
604
- D = D.contiguous()
605
- if initial_states is not None:
606
- assert initial_states.shape == (batch, nheads, headdim, dstate)
607
- # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
608
- # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
609
- # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
610
- # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
611
- dA_cumsum, dt = _chunk_cumsum_fwd(
612
- dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit
613
- )
614
- states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
615
- # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
616
- # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
617
- # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)
618
- states, final_states = _state_passing_fwd(
619
- rearrange(states, "... p n -> ... (p n)"),
620
- dA_cumsum[:, :, :, -1],
621
- initial_states=(
622
- rearrange(initial_states, "... p n -> ... (p n)")
623
- if initial_states is not None
624
- else None
625
- ),
626
- seq_idx=seq_idx,
627
- chunk_size=chunk_size,
628
- out_dtype=C.dtype,
629
- )
630
- states, final_states = [
631
- rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]
632
- ]
633
- # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
634
- # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
635
- CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
636
- out, out_x = _chunk_scan_fwd(
637
- CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx
638
- )
639
- if cu_seqlens is None:
640
- return out, out_x, dt, dA_cumsum, states, final_states
641
- else:
642
- assert (
643
- batch == 1
644
- ), "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
645
- varlen_states = chunk_state_varlen(
646
- B.squeeze(0),
647
- x.squeeze(0),
648
- dt.squeeze(0),
649
- dA_cumsum.squeeze(0),
650
- cu_seqlens,
651
- states.squeeze(0),
652
- )
653
- return out, out_x, dt, dA_cumsum, states, final_states, varlen_states
654
-
655
-
656
- def _mamba_chunk_scan_combined_bwd(
657
- dout,
658
- x,
659
- dt,
660
- A,
661
- B,
662
- C,
663
- out,
664
- chunk_size,
665
- D=None,
666
- z=None,
667
- dt_bias=None,
668
- initial_states=None,
669
- dfinal_states=None,
670
- seq_idx=None,
671
- dt_softplus=False,
672
- dt_limit=(0.0, float("inf")),
673
- dx=None,
674
- ddt=None,
675
- dB=None,
676
- dC=None,
677
- dz=None,
678
- recompute_output=False,
679
- ):
680
- if dout.stride(-1) != 1:
681
- dout = dout.contiguous()
682
- batch, seqlen, nheads, headdim = x.shape
683
- nchunks = math.ceil(seqlen / chunk_size)
684
- _, _, ngroups, dstate = B.shape
685
- assert dout.shape == (batch, seqlen, nheads, headdim)
686
- assert dt.shape == (batch, seqlen, nheads)
687
- assert A.shape == (nheads,)
688
- assert nheads % ngroups == 0
689
- assert B.shape == (batch, seqlen, ngroups, dstate)
690
- assert C.shape == B.shape
691
- assert out.shape == x.shape
692
- if initial_states is not None:
693
- assert initial_states.shape == (batch, nheads, headdim, dstate)
694
- if seq_idx is not None:
695
- assert seq_idx.shape == (batch, seqlen)
696
- if dx is not None:
697
- assert dx.shape == x.shape
698
- if dB is not None:
699
- assert dB.shape == B.shape
700
- dB_given = dB
701
- else:
702
- dB_given = torch.empty_like(B)
703
- if dC is not None:
704
- assert dC.shape == C.shape
705
- dC_given = dC
706
- else:
707
- dC_given = torch.empty_like(C)
708
- if dz is not None:
709
- assert z is not None
710
- assert dz.shape == z.shape
711
- if ddt is not None:
712
- assert ddt.shape == dt.shape
713
- ddt_given = ddt
714
- else:
715
- ddt_given = torch.empty_like(dt)
716
- # TD: For some reason Triton (2.1.0 and 2.2.0) errors with
717
- # "[CUDA]: invalid device context" (e.g. during varlne test), and cloning makes it work. Idk why.
718
- dt_in = dt.clone()
719
- dA_cumsum, dt = _chunk_cumsum_fwd(
720
- dt_in,
721
- A,
722
- chunk_size,
723
- dt_bias=dt_bias,
724
- dt_softplus=dt_softplus,
725
- dt_limit=dt_limit,
726
- )
727
- CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
728
- states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
729
- states, _ = _state_passing_fwd(
730
- rearrange(states, "... p n -> ... (p n)"),
731
- dA_cumsum[:, :, :, -1],
732
- initial_states=(
733
- rearrange(initial_states, "... p n -> ... (p n)")
734
- if initial_states is not None
735
- else None
736
- ),
737
- seq_idx=seq_idx,
738
- chunk_size=chunk_size,
739
- )
740
- states = rearrange(states, "... (p n) -> ... p n", n=dstate)
741
- if z is not None:
742
- dz, dout, dD, *rest = _chunk_scan_bwd_dz(
743
- x,
744
- z,
745
- out,
746
- dout,
747
- chunk_size=chunk_size,
748
- has_ddAcs=False,
749
- D=D,
750
- dz=dz,
751
- recompute_output=recompute_output,
752
- )
753
- outz = rest[0] if recompute_output else out
754
- else:
755
- dz = None
756
- outz = out
757
- dstates = _chunk_scan_bwd_dstates(
758
- C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype
759
- )
760
- # dstates has length nchunks, containing the gradient to initial states at index 0 and
761
- # gradient to the states of chunk (nchunks - 2) at index (nchunks - 1)
762
- # Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states
763
- # will be used in matmul in the next kernels.
764
- dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd(
765
- rearrange(states, "... p n -> ... (p n)"),
766
- dA_cumsum[:, :, :, -1],
767
- rearrange(dstates, "... p n -> ... (p n)"),
768
- dfinal_states=(
769
- rearrange(dfinal_states, "... p n -> ... (p n)")
770
- if dfinal_states is not None
771
- else None
772
- ),
773
- seq_idx=seq_idx,
774
- has_initial_states=initial_states is not None,
775
- dstates_dtype=x.dtype,
776
- states_dtype=x.dtype,
777
- chunk_size=chunk_size,
778
- )
779
- # dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and
780
- # gradient to the final states at index (nchunks - 1)
781
- # states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1)
782
- # The final states is not stored.
783
- states = rearrange(states, "... (p n) -> ... p n", n=dstate)
784
- dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate)
785
- dinitial_states = (
786
- rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate)
787
- if dinitial_states is not None
788
- else None
789
- )
790
- dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(
791
- x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx
792
- )
793
- # dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups)
794
- dB, ddA_next = _chunk_state_bwd_db(
795
- x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups
796
- )
797
- # dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
798
- dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(
799
- states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups
800
- )
801
- # Computing ddA with the dcb kernel is much slower, so we're not using it for now
802
- dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
803
- # dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups)
804
- dCB = dCB.to(CB.dtype)
805
- _bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given)
806
- _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given)
807
- # If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate
808
- # than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16
809
- if z is None:
810
- dD = dD_from_x
811
- # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
812
- # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
813
- # However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might
814
- # be a lot of underflow.
815
-
816
- # This is already done as part of bwd_dC kernel
817
- # ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx)
818
- ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum
819
- ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1])
820
- # This is already done as part of bwd_dB kernel
821
- # ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx)
822
- # We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j]
823
- ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB)
824
- ddA += ddA_next + ddA_prev
825
-
826
- ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(
827
- ddA,
828
- ddt,
829
- dt_in,
830
- A,
831
- dt_bias=dt_bias,
832
- dt_softplus=dt_softplus,
833
- dt_limit=dt_limit,
834
- ddt=ddt_given,
835
- )
836
-
837
- # These 2 lines are just to test ddt and dA being computed by old code
838
- # _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z)
839
- # ddt_given.copy_(ddt)
840
-
841
- return_vals = (
842
- dx,
843
- ddt_given,
844
- dA,
845
- dB_given,
846
- dC_given,
847
- dD,
848
- dz,
849
- ddt_bias,
850
- dinitial_states,
851
- )
852
- return return_vals if not recompute_output else (*return_vals, outz)
853
-
854
-
855
- def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None):
856
- """
857
- Argument:
858
- dout: (batch, seqlen, nheads, headdim)
859
- x: (batch, seqlen, nheads, headdim)
860
- dt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size)
861
- A: (nheads) or (dim, dstate)
862
- B: (batch, seqlen, ngroups, dstate)
863
- C: (batch, seqlen, ngroups, dstate)
864
- D: (nheads, headdim) or (nheads,)
865
- z: (batch, seqlen, nheads, headdim)
866
- Return:
867
- out: (batch, seqlen, nheads, headdim)
868
- """
869
- import selective_scan
870
-
871
- batch, seqlen, nheads, headdim = x.shape
872
- chunk_size = dt.shape[-1]
873
- _, _, ngroups, dstate = B.shape
874
- assert nheads % ngroups == 0
875
- x = rearrange(x, "b l h p -> b (h p) l")
876
- squeeze_dt = dt.dim() == 4
877
- if dt.dim() == 4:
878
- dt = repeat(dt, "b h c l -> b h p c l", p=headdim)
879
- dt = rearrange(dt, "b h p c l -> b (h p) (c l)", p=headdim)
880
- squeeze_A = A.dim() == 1
881
- if A.dim() == 1:
882
- A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
883
- else:
884
- A = A.to(dtype=torch.float32)
885
- B = rearrange(B, "b l g n -> b g n l")
886
- C = rearrange(C, "b l g n -> b g n l")
887
- if D is not None:
888
- if D.dim() == 2:
889
- D = rearrange(D, "h p -> (h p)")
890
- else:
891
- D = repeat(D, "h -> (h p)", p=headdim)
892
- if z is not None:
893
- z = rearrange(z, "b l h p -> b (h p) l")
894
-
895
- if x.stride(-1) != 1:
896
- x = x.contiguous()
897
- if dt.stride(-1) != 1:
898
- dt = dt.contiguous()
899
- if D is not None:
900
- D = D.contiguous()
901
- if B.stride(-1) != 1:
902
- B = B.contiguous()
903
- if C.stride(-1) != 1:
904
- C = C.contiguous()
905
- if z is not None and z.stride(-1) != 1:
906
- z = z.contiguous()
907
- _, intermediate, *rest = selective_scan.fwd(
908
- x, dt.to(dtype=x.dtype), A, B, C, D, z, None, False
909
- )
910
- if z is not None:
911
- out = rest[0]
912
- else:
913
- out = None
914
-
915
- dout = rearrange(dout, "b l h p -> b (h p) l")
916
-
917
- if dout.stride(-1) != 1:
918
- dout = dout.contiguous()
919
- # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
920
- # backward of selective_scan with the backward of chunk).
921
- # Here we just pass in None and dz will be allocated in the C++ code.
922
- _, ddt, dA, *rest = selective_scan.bwd(
923
- x,
924
- dt.to(dtype=x.dtype),
925
- A,
926
- B,
927
- C,
928
- D,
929
- z,
930
- None,
931
- dout,
932
- intermediate,
933
- out,
934
- None,
935
- False,
936
- False, # option to recompute out_z, not used here
937
- )
938
- ddt = rearrange(ddt, "b (h p) (c l) -> b h p c l", p=headdim, l=chunk_size)
939
- if squeeze_dt:
940
- ddt = ddt.float().sum(dim=2)
941
- if squeeze_A:
942
- dA = rearrange(dA, "(h p) n -> h p n", p=headdim).sum(dim=(1, 2))
943
- return ddt, dA
944
-
945
-
946
- class MambaChunkScanCombinedFn(torch.autograd.Function):
947
-
948
- @staticmethod
949
- def forward(
950
- ctx,
951
- x,
952
- dt,
953
- A,
954
- B,
955
- C,
956
- chunk_size,
957
- D=None,
958
- z=None,
959
- dt_bias=None,
960
- initial_states=None,
961
- seq_idx=None,
962
- cu_seqlens=None,
963
- dt_softplus=False,
964
- dt_limit=(0.0, float("inf")),
965
- return_final_states=False,
966
- return_varlen_states=False,
967
- ):
968
- ctx.dt_dtype = dt.dtype
969
- if not return_varlen_states:
970
- cu_seqlens = None
971
- else:
972
- assert (
973
- cu_seqlens is not None
974
- ), "cu_seqlens must be provided if return_varlen_states is True"
975
- out, out_x, dt_out, dA_cumsum, states, final_states, *rest = (
976
- _mamba_chunk_scan_combined_fwd(
977
- x,
978
- dt,
979
- A,
980
- B,
981
- C,
982
- chunk_size,
983
- D=D,
984
- z=z,
985
- dt_bias=dt_bias,
986
- initial_states=initial_states,
987
- seq_idx=seq_idx,
988
- cu_seqlens=cu_seqlens,
989
- dt_softplus=dt_softplus,
990
- dt_limit=dt_limit,
991
- )
992
- )
993
- ctx.save_for_backward(
994
- out if z is None else out_x,
995
- x,
996
- dt,
997
- dA_cumsum,
998
- A,
999
- B,
1000
- C,
1001
- D,
1002
- z,
1003
- dt_bias,
1004
- initial_states,
1005
- seq_idx,
1006
- )
1007
- ctx.dt_softplus = dt_softplus
1008
- ctx.chunk_size = chunk_size
1009
- ctx.dt_limit = dt_limit
1010
- ctx.return_final_states = return_final_states
1011
- ctx.return_varlen_states = return_varlen_states
1012
- if not return_varlen_states:
1013
- return out if not return_final_states else (out, final_states)
1014
- else:
1015
- varlen_states = rest[0]
1016
- return (
1017
- (out, varlen_states)
1018
- if not return_final_states
1019
- else (out, final_states, varlen_states)
1020
- )
1021
-
1022
- @staticmethod
1023
- def backward(ctx, dout, *args):
1024
- out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = (
1025
- ctx.saved_tensors
1026
- )
1027
- assert (
1028
- not ctx.return_varlen_states
1029
- ), "return_varlen_states is not supported in backward"
1030
- dfinal_states = args[0] if ctx.return_final_states else None
1031
- dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = (
1032
- _mamba_chunk_scan_combined_bwd(
1033
- dout,
1034
- x,
1035
- dt,
1036
- A,
1037
- B,
1038
- C,
1039
- out,
1040
- ctx.chunk_size,
1041
- D=D,
1042
- z=z,
1043
- dt_bias=dt_bias,
1044
- initial_states=initial_states,
1045
- dfinal_states=dfinal_states,
1046
- seq_idx=seq_idx,
1047
- dt_softplus=ctx.dt_softplus,
1048
- dt_limit=ctx.dt_limit,
1049
- )
1050
- )
1051
- return (
1052
- dx,
1053
- ddt,
1054
- dA,
1055
- dB,
1056
- dC,
1057
- None,
1058
- dD,
1059
- dz,
1060
- ddt_bias,
1061
- dinitial_states,
1062
- None,
1063
- None,
1064
- None,
1065
- None,
1066
- None,
1067
- None,
1068
- )
1069
-
1070
-
1071
- def mamba_chunk_scan_combined(
1072
- x,
1073
- dt,
1074
- A,
1075
- B,
1076
- C,
1077
- chunk_size,
1078
- D=None,
1079
- z=None,
1080
- dt_bias=None,
1081
- initial_states=None,
1082
- seq_idx=None,
1083
- cu_seqlens=None,
1084
- dt_softplus=False,
1085
- dt_limit=(0.0, float("inf")),
1086
- return_final_states=False,
1087
- return_varlen_states=False,
1088
- ):
1089
- """
1090
- Argument:
1091
- x: (batch, seqlen, nheads, headdim)
1092
- dt: (batch, seqlen, nheads)
1093
- A: (nheads)
1094
- B: (batch, seqlen, ngroups, dstate)
1095
- C: (batch, seqlen, ngroups, dstate)
1096
- chunk_size: int
1097
- D: (nheads, headdim) or (nheads,)
1098
- z: (batch, seqlen, nheads, headdim)
1099
- dt_bias: (nheads,)
1100
- initial_states: (batch, nheads, headdim, dstate)
1101
- seq_idx: (batch, seqlen)
1102
- cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
1103
- dt_softplus: Whether to apply softplus to dt
1104
- Return:
1105
- out: (batch, seqlen, nheads, headdim)
1106
- """
1107
- return MambaChunkScanCombinedFn.apply(
1108
- x,
1109
- dt,
1110
- A,
1111
- B,
1112
- C,
1113
- chunk_size,
1114
- D,
1115
- z,
1116
- dt_bias,
1117
- initial_states,
1118
- seq_idx,
1119
- cu_seqlens,
1120
- dt_softplus,
1121
- dt_limit,
1122
- return_final_states,
1123
- return_varlen_states,
1124
- )
1125
-
1126
-
1127
- def mamba_chunk_scan(
1128
- x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False
1129
- ):
1130
- """
1131
- Argument:
1132
- x: (batch, seqlen, nheads, headdim)
1133
- dt: (batch, seqlen, nheads)
1134
- A: (nheads)
1135
- B: (batch, seqlen, ngroups, dstate)
1136
- C: (batch, seqlen, ngroups, dstate)
1137
- D: (nheads, headdim) or (nheads,)
1138
- z: (batch, seqlen, nheads, headdim)
1139
- dt_bias: (nheads,)
1140
- Return:
1141
- out: (batch, seqlen, nheads, headdim)
1142
- """
1143
- batch, seqlen, nheads, headdim = x.shape
1144
- dstate = B.shape[-1]
1145
- if seqlen % chunk_size != 0:
1146
- dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
1147
- dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
1148
- dt = dt.float() # We want high precision for this before cumsum
1149
- if dt_bias is not None:
1150
- dt = dt + rearrange(dt_bias, "h -> h 1 1")
1151
- if dt_softplus:
1152
- dt = F.softplus(dt)
1153
- dA = dt * rearrange(A, "h -> h 1 1")
1154
- dA = dt * rearrange(A, "h -> h 1 1")
1155
- dA_cumsum = torch.cumsum(dA, dim=-1)
1156
- # 1. Compute the state for each chunk
1157
- states = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True)
1158
- # 2. Pass the state to all the chunks by weighted cumsum.
1159
- states = rearrange(
1160
- state_passing(
1161
- rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]
1162
- )[0],
1163
- "... (p n) -> ... p n",
1164
- n=dstate,
1165
- )
1166
- # 3. Compute the output for each chunk
1167
- out = chunk_scan(B, C, x, dt, dA_cumsum, states, D=D, z=z)
1168
- return out
1169
-
1170
-
1171
- def ssd_chunk_scan_combined_ref(
1172
- x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False
1173
- ):
1174
- """
1175
- Argument:
1176
- x: (batch, seqlen, nheads, headdim)
1177
- dt: (batch, seqlen, nheads)
1178
- A: (nheads)
1179
- B: (batch, seqlen, ngroups, dstate)
1180
- C: (batch, seqlen, ngroups, dstate)
1181
- D: (nheads, headdim) or (nheads,)
1182
- z: (batch, seqlen, nheads, headdim)
1183
- dt_bias: (nheads,)
1184
- Return:
1185
- out: (batch, seqlen, nheads, headdim)
1186
- """
1187
- batch, seqlen, nheads, headdim = x.shape
1188
- dstate = B.shape[-1]
1189
- if seqlen % chunk_size != 0:
1190
- dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
1191
- dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
1192
- dt = dt.float() # We want high precision for this before cumsum
1193
- if dt_bias is not None:
1194
- dt = dt + rearrange(dt_bias, "h -> h 1 1")
1195
- if dt_softplus:
1196
- dt = F.softplus(dt)
1197
- dA = dt * rearrange(A, "h -> h 1 1")
1198
- dA_cumsum = torch.cumsum(dA, dim=-1)
1199
- # 1. Compute the state for each chunk
1200
- states = chunk_state_ref(B, x, dt, dA_cumsum)
1201
- states_dtype = states.dtype
1202
- if states.dtype not in [torch.float32, torch.float64]:
1203
- states = states.to(torch.float32)
1204
- # 2. Pass the state to all the chunks by weighted cumsum.
1205
- # state_passing_ref is much less numerically stable
1206
- states = rearrange(
1207
- state_passing_ref(
1208
- rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]
1209
- )[0],
1210
- "... (p n) -> ... p n",
1211
- n=dstate,
1212
- )
1213
- states = states.to(states_dtype)
1214
- # 3. Compute the output for each chunk
1215
- out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)
1216
- return out
1217
-
1218
-
1219
- def ssd_selective_scan(
1220
- x,
1221
- dt,
1222
- A,
1223
- B,
1224
- C,
1225
- D=None,
1226
- z=None,
1227
- dt_bias=None,
1228
- dt_softplus=False,
1229
- dt_limit=(0.0, float("inf")),
1230
- ):
1231
- """
1232
- Argument:
1233
- x: (batch, seqlen, nheads, headdim)
1234
- dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
1235
- A: (nheads) or (dim, dstate)
1236
- B: (batch, seqlen, ngroups, dstate)
1237
- C: (batch, seqlen, ngroups, dstate)
1238
- D: (nheads, headdim) or (nheads,)
1239
- z: (batch, seqlen, nheads, headdim)
1240
- dt_bias: (nheads,) or (nheads, headdim)
1241
- Return:
1242
- out: (batch, seqlen, nheads, headdim)
1243
- """
1244
- from ..selective_scan_interface import selective_scan_fn
1245
-
1246
- batch, seqlen, nheads, headdim = x.shape
1247
- _, _, ngroups, dstate = B.shape
1248
- x = rearrange(x, "b l h p -> b (h p) l")
1249
- if dt.dim() == 3:
1250
- dt = repeat(dt, "b l h -> b l h p", p=headdim)
1251
- dt = rearrange(dt, "b l h p -> b (h p) l")
1252
- if A.dim() == 1:
1253
- A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
1254
- else:
1255
- A = A.to(dtype=torch.float32)
1256
- B = rearrange(B, "b l g n -> b g n l")
1257
- C = rearrange(C, "b l g n -> b g n l")
1258
- if D is not None:
1259
- if D.dim() == 2:
1260
- D = rearrange(D, "h p -> (h p)")
1261
- else:
1262
- D = repeat(D, "h -> (h p)", p=headdim)
1263
- if z is not None:
1264
- z = rearrange(z, "b l h p -> b (h p) l")
1265
- if dt_bias is not None:
1266
- if dt_bias.dim() == 1:
1267
- dt_bias = repeat(dt_bias, "h -> h p", p=headdim)
1268
- dt_bias = rearrange(dt_bias, "h p -> (h p)")
1269
- if dt_limit != (0.0, float("inf")):
1270
- if dt_bias is not None:
1271
- dt = dt + rearrange(dt_bias, "d -> d 1")
1272
- if dt_softplus:
1273
- dt = F.softplus(dt)
1274
- dt = dt.clamp(min=dt_limit[0], max=dt_limit[1]).to(x.dtype)
1275
- dt_bias = None
1276
- dt_softplus = None
1277
- out = selective_scan_fn(
1278
- x, dt, A, B, C, D=D, z=z, delta_bias=dt_bias, delta_softplus=dt_softplus
1279
- )
1280
- return rearrange(out, "b (h p) l -> b l h p", p=headdim)
1281
-
1282
-
1283
- def mamba_conv1d_scan_ref(
1284
- xBC,
1285
- conv1d_weight,
1286
- conv1d_bias,
1287
- dt,
1288
- A,
1289
- chunk_size,
1290
- D=None,
1291
- z=None,
1292
- dt_bias=None,
1293
- dt_softplus=False,
1294
- dt_limit=(0.0, float("inf")),
1295
- activation="silu",
1296
- headdim=None,
1297
- ngroups=1,
1298
- ):
1299
- """
1300
- Argument:
1301
- xBC: (batch, seqlen, dim + 2 * ngroups * dstate) where dim == nheads * headdim
1302
- conv1d_weight: (dim + 2 * ngroups * dstate, width)
1303
- conv1d_bias: (dim + 2 * ngroups * dstate,)
1304
- dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
1305
- A: (nheads)
1306
- D: (nheads, headdim) or (nheads,)
1307
- z: (batch, seqlen, dim)
1308
- dt_bias: (nheads) or (nheads, headdim)
1309
- headdim: if D is 1D and z is None, headdim must be passed in
1310
- Return:
1311
- out: (batch, seqlen, dim)
1312
- """
1313
- batch, seqlen, nheads = dt.shape[:3]
1314
- assert nheads % ngroups == 0
1315
- if z is not None:
1316
- dim = z.shape[-1]
1317
- assert dim % nheads == 0
1318
- headdim = dim // nheads
1319
- else:
1320
- if D.dim() == 1:
1321
- assert headdim is not None
1322
- else:
1323
- headdim = D.shape[1]
1324
- dim = nheads * headdim
1325
- xBC = rearrange(
1326
- causal_conv1d_fn(
1327
- rearrange(xBC, "b s d -> b d s"),
1328
- conv1d_weight,
1329
- conv1d_bias,
1330
- activation=activation,
1331
- ),
1332
- "b d s -> b s d",
1333
- )
1334
- dstate = (xBC.shape[-1] - dim) // ngroups // 2
1335
- x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
1336
- x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
1337
- B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
1338
- C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
1339
- z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
1340
- out = ssd_selective_scan(
1341
- x,
1342
- dt.to(x.dtype),
1343
- A,
1344
- B,
1345
- C,
1346
- D=D.float(),
1347
- z=z,
1348
- dt_bias=dt_bias,
1349
- dt_softplus=dt_softplus,
1350
- dt_limit=dt_limit,
1351
- )
1352
- return rearrange(out, "b s h p -> b s (h p)")
1353
-
1354
-
1355
- class MambaSplitConv1dScanCombinedFn(torch.autograd.Function):
1356
-
1357
- @staticmethod
1358
- @custom_fwd
1359
- def forward(
1360
- ctx,
1361
- zxbcdt,
1362
- conv1d_weight,
1363
- conv1d_bias,
1364
- dt_bias,
1365
- A,
1366
- D,
1367
- chunk_size,
1368
- initial_states=None,
1369
- seq_idx=None,
1370
- dt_limit=(0.0, float("inf")),
1371
- return_final_states=False,
1372
- activation="silu",
1373
- rmsnorm_weight=None,
1374
- rmsnorm_eps=1e-6,
1375
- outproj_weight=None,
1376
- outproj_bias=None,
1377
- headdim=None,
1378
- ngroups=1,
1379
- norm_before_gate=True,
1380
- ):
1381
- assert activation in [None, "silu", "swish"]
1382
- if D.dim() == 1:
1383
- assert headdim is not None
1384
- (nheads,) = D.shape
1385
- else:
1386
- nheads, headdim = D.shape
1387
- batch, seqlen, _ = zxbcdt.shape
1388
- dim = nheads * headdim
1389
- assert nheads % ngroups == 0
1390
- dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2
1391
- d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2
1392
- assert d_nonssm >= 0
1393
- assert zxbcdt.shape == (
1394
- batch,
1395
- seqlen,
1396
- 2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads,
1397
- )
1398
- assert dt_bias.shape == (nheads,)
1399
- assert A.shape == (nheads,)
1400
- zx0, z, xBC, dt = torch.split(
1401
- zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1
1402
- )
1403
- seq_idx = seq_idx.contiguous() if seq_idx is not None else None
1404
- xBC_conv = rearrange(
1405
- causal_conv1d_cuda.causal_conv1d_fwd(
1406
- rearrange(xBC, "b s d -> b d s"),
1407
- conv1d_weight,
1408
- conv1d_bias,
1409
- seq_idx,
1410
- None,
1411
- None,
1412
- activation in ["silu", "swish"],
1413
- ),
1414
- "b d s -> b s d",
1415
- )
1416
- x, B, C = torch.split(
1417
- xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1
1418
- )
1419
- x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
1420
- B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
1421
- C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
1422
- z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
1423
- if rmsnorm_weight is None:
1424
- out, out_x, dt_out, dA_cumsum, states, final_states = (
1425
- _mamba_chunk_scan_combined_fwd(
1426
- x,
1427
- dt,
1428
- A,
1429
- B,
1430
- C,
1431
- chunk_size=chunk_size,
1432
- D=D,
1433
- z=z,
1434
- dt_bias=dt_bias,
1435
- initial_states=initial_states,
1436
- seq_idx=seq_idx,
1437
- dt_softplus=True,
1438
- dt_limit=dt_limit,
1439
- )
1440
- )
1441
- out = rearrange(out, "b s h p -> b s (h p)")
1442
- rstd = None
1443
- if d_nonssm > 0:
1444
- out = torch.cat([_swiglu_fwd(zx0), out], dim=-1)
1445
- else:
1446
- out_x, _, dt_out, dA_cumsum, states, final_states = (
1447
- _mamba_chunk_scan_combined_fwd(
1448
- x,
1449
- dt,
1450
- A,
1451
- B,
1452
- C,
1453
- chunk_size=chunk_size,
1454
- D=D,
1455
- z=None,
1456
- dt_bias=dt_bias,
1457
- initial_states=initial_states,
1458
- seq_idx=seq_idx,
1459
- dt_softplus=True,
1460
- dt_limit=dt_limit,
1461
- )
1462
- )
1463
- # reshape input data into 2D tensor
1464
- x_rms = rearrange(out_x, "b s h p -> (b s) (h p)")
1465
- z_rms = rearrange(z, "b s h p -> (b s) (h p)")
1466
- rmsnorm_weight = rmsnorm_weight.contiguous()
1467
- if d_nonssm == 0:
1468
- out = None
1469
- else:
1470
- out01 = torch.empty(
1471
- (batch, seqlen, d_nonssm + dim),
1472
- dtype=x_rms.dtype,
1473
- device=x_rms.device,
1474
- )
1475
- out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d")
1476
- _swiglu_fwd(zx0, out=out01[..., :d_nonssm])
1477
- out, _, rstd = _layer_norm_fwd(
1478
- x_rms,
1479
- rmsnorm_weight,
1480
- None,
1481
- rmsnorm_eps,
1482
- z_rms,
1483
- out=out,
1484
- group_size=dim // ngroups,
1485
- norm_before_gate=norm_before_gate,
1486
- is_rms_norm=True,
1487
- )
1488
- if d_nonssm == 0:
1489
- out = rearrange(out, "(b s) d -> b s d", b=batch)
1490
- else:
1491
- out = out01
1492
- ctx.outproj_weight_dtype = (
1493
- outproj_weight.dtype if outproj_weight is not None else None
1494
- )
1495
- if outproj_weight is not None:
1496
- if torch.is_autocast_enabled():
1497
- dtype = torch.get_autocast_gpu_dtype()
1498
- out, outproj_weight = out.to(dtype), outproj_weight.to(dtype)
1499
- outproj_bias = (
1500
- outproj_bias.to(dtype) if outproj_bias is not None else None
1501
- )
1502
- out = F.linear(out, outproj_weight, outproj_bias)
1503
- else:
1504
- assert outproj_bias is None
1505
- ctx.save_for_backward(
1506
- zxbcdt,
1507
- conv1d_weight,
1508
- conv1d_bias,
1509
- out_x,
1510
- A,
1511
- D,
1512
- dt_bias,
1513
- initial_states,
1514
- seq_idx,
1515
- rmsnorm_weight,
1516
- rstd,
1517
- outproj_weight,
1518
- outproj_bias,
1519
- )
1520
- ctx.dt_limit = dt_limit
1521
- ctx.return_final_states = return_final_states
1522
- ctx.activation = activation
1523
- ctx.rmsnorm_eps = rmsnorm_eps
1524
- ctx.norm_before_gate = norm_before_gate
1525
- ctx.chunk_size = chunk_size
1526
- ctx.headdim = headdim
1527
- ctx.ngroups = ngroups
1528
- return out if not return_final_states else (out, final_states)
1529
-
1530
- @staticmethod
1531
- @custom_bwd
1532
- def backward(ctx, dout, *args):
1533
- (
1534
- zxbcdt,
1535
- conv1d_weight,
1536
- conv1d_bias,
1537
- out,
1538
- A,
1539
- D,
1540
- dt_bias,
1541
- initial_states,
1542
- seq_idx,
1543
- rmsnorm_weight,
1544
- rstd,
1545
- outproj_weight,
1546
- outproj_bias,
1547
- ) = ctx.saved_tensors
1548
- dfinal_states = args[0] if ctx.return_final_states else None
1549
- headdim = ctx.headdim
1550
- nheads = D.shape[0]
1551
- dim = nheads * headdim
1552
- assert nheads % ctx.ngroups == 0
1553
- dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2
1554
- d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2
1555
- assert d_nonssm >= 0
1556
- recompute_output = outproj_weight is not None
1557
- if recompute_output:
1558
- out_recompute = torch.empty(
1559
- *out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype
1560
- )
1561
- out0_recompute, out1_recompute = out_recompute.split(
1562
- [d_nonssm, dim], dim=-1
1563
- )
1564
- zx0, z, xBC, dt = torch.split(
1565
- zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
1566
- )
1567
- # Recompute x, B, C
1568
- xBC_conv = rearrange(
1569
- causal_conv1d_cuda.causal_conv1d_fwd(
1570
- rearrange(xBC, "b s d -> b d s"),
1571
- conv1d_weight,
1572
- conv1d_bias,
1573
- seq_idx,
1574
- None,
1575
- None,
1576
- ctx.activation in ["silu", "swish"],
1577
- ),
1578
- "b d s -> b s d",
1579
- )
1580
- x, B, C = torch.split(
1581
- xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
1582
- )
1583
- x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
1584
- B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups)
1585
- C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups)
1586
- dzxbcdt = torch.empty_like(zxbcdt)
1587
- dzx0, dz, dxBC_given, ddt_given = torch.split(
1588
- dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
1589
- )
1590
- dxBC = torch.empty_like(xBC)
1591
- dx, dB, dC = torch.split(
1592
- dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
1593
- )
1594
- z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
1595
- dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads)
1596
- dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups)
1597
- dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups)
1598
- if outproj_weight is not None:
1599
- dout_og = dout
1600
- dout = F.linear(dout, outproj_weight.t())
1601
- if d_nonssm > 0:
1602
- dout0, dout = dout.split([d_nonssm, dim], dim=-1)
1603
- _swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute)
1604
- dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim)
1605
- if rmsnorm_weight is None:
1606
- dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads)
1607
- dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states, *rest = (
1608
- _mamba_chunk_scan_combined_bwd(
1609
- dout,
1610
- x,
1611
- dt,
1612
- A,
1613
- B,
1614
- C,
1615
- out,
1616
- ctx.chunk_size,
1617
- D=D,
1618
- z=z,
1619
- dt_bias=dt_bias,
1620
- initial_states=initial_states,
1621
- dfinal_states=dfinal_states,
1622
- seq_idx=seq_idx,
1623
- dt_softplus=True,
1624
- dt_limit=ctx.dt_limit,
1625
- dx=dx,
1626
- ddt=ddt_given,
1627
- dB=dB,
1628
- dC=dC,
1629
- dz=dz,
1630
- recompute_output=recompute_output,
1631
- )
1632
- )
1633
- out_for_linear = (
1634
- rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None
1635
- )
1636
- drmsnorm_weight = None
1637
- else:
1638
- batch = dout.shape[0]
1639
- dy_rms = rearrange(dout, "b s h p -> (b s) (h p)")
1640
- dz = rearrange(dz, "b l d -> (b l) d")
1641
- x_rms = rearrange(out, "b s h p -> (b s) (h p)")
1642
- z_rms = rearrange(z, "b s h p -> (b s) (h p)")
1643
- out1_recompute = (
1644
- rearrange(out1_recompute, "b s d -> (b s) d")
1645
- if recompute_output
1646
- else None
1647
- )
1648
- dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(
1649
- dy_rms,
1650
- x_rms,
1651
- rmsnorm_weight,
1652
- None,
1653
- ctx.rmsnorm_eps,
1654
- None,
1655
- rstd,
1656
- z_rms,
1657
- group_size=dim // ctx.ngroups,
1658
- norm_before_gate=ctx.norm_before_gate,
1659
- is_rms_norm=True,
1660
- recompute_output=recompute_output,
1661
- dz=dz,
1662
- out=out1_recompute if recompute_output else None,
1663
- )
1664
- out_for_linear = out_recompute if recompute_output else None
1665
- dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim)
1666
- dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = (
1667
- _mamba_chunk_scan_combined_bwd(
1668
- dout,
1669
- x,
1670
- dt,
1671
- A,
1672
- B,
1673
- C,
1674
- out,
1675
- ctx.chunk_size,
1676
- D=D,
1677
- z=None,
1678
- dt_bias=dt_bias,
1679
- initial_states=initial_states,
1680
- dfinal_states=dfinal_states,
1681
- seq_idx=seq_idx,
1682
- dt_softplus=True,
1683
- dt_limit=ctx.dt_limit,
1684
- dx=dx,
1685
- ddt=ddt_given,
1686
- dB=dB,
1687
- dC=dC,
1688
- )
1689
- )
1690
-
1691
- if outproj_weight is not None:
1692
- doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear)
1693
- doutproj_bias = (
1694
- dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None
1695
- )
1696
- else:
1697
- doutproj_weight, doutproj_bias = None, None
1698
- dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
1699
- dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
1700
- rearrange(xBC, "b s d -> b d s"),
1701
- conv1d_weight,
1702
- conv1d_bias,
1703
- rearrange(dxBC, "b s d -> b d s"),
1704
- seq_idx,
1705
- None,
1706
- None,
1707
- dxBC_given,
1708
- False,
1709
- ctx.activation in ["silu", "swish"],
1710
- )
1711
- dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
1712
- return (
1713
- dzxbcdt,
1714
- dweight,
1715
- dbias,
1716
- ddt_bias,
1717
- dA,
1718
- dD,
1719
- None,
1720
- dinitial_states,
1721
- None,
1722
- None,
1723
- None,
1724
- None,
1725
- drmsnorm_weight,
1726
- None,
1727
- doutproj_weight,
1728
- doutproj_bias,
1729
- None,
1730
- None,
1731
- None,
1732
- )
1733
-
1734
-
1735
- def mamba_split_conv1d_scan_combined(
1736
- zxbcdt,
1737
- conv1d_weight,
1738
- conv1d_bias,
1739
- dt_bias,
1740
- A,
1741
- D,
1742
- chunk_size,
1743
- initial_states=None,
1744
- seq_idx=None,
1745
- dt_limit=(0.0, float("inf")),
1746
- return_final_states=False,
1747
- activation="silu",
1748
- rmsnorm_weight=None,
1749
- rmsnorm_eps=1e-6,
1750
- outproj_weight=None,
1751
- outproj_bias=None,
1752
- headdim=None,
1753
- ngroups=1,
1754
- norm_before_gate=True,
1755
- ):
1756
- """
1757
- Argument:
1758
- zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
1759
- conv1d_weight: (dim + 2 * ngroups * dstate, width)
1760
- conv1d_bias: (dim + 2 * ngroups * dstate,)
1761
- dt_bias: (nheads,)
1762
- A: (nheads)
1763
- D: (nheads, headdim) or (nheads,)
1764
- initial_states: (batch, nheads, headdim, dstate)
1765
- seq_idx: (batch, seqlen), int32
1766
- rmsnorm_weight: (dim,)
1767
- outproj_weight: (out_dim, dim)
1768
- outproj_bias: (out_dim,)
1769
- headdim: if D is 1D, headdim must be passed in
1770
- norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
1771
- Return:
1772
- out: (batch, seqlen, dim)
1773
- """
1774
- return MambaSplitConv1dScanCombinedFn.apply(
1775
- zxbcdt,
1776
- conv1d_weight,
1777
- conv1d_bias,
1778
- dt_bias,
1779
- A,
1780
- D,
1781
- chunk_size,
1782
- initial_states,
1783
- seq_idx,
1784
- dt_limit,
1785
- return_final_states,
1786
- activation,
1787
- rmsnorm_weight,
1788
- rmsnorm_eps,
1789
- outproj_weight,
1790
- outproj_bias,
1791
- headdim,
1792
- ngroups,
1793
- norm_before_gate,
1794
- )
1795
-
1796
-
1797
- def mamba_split_conv1d_scan_ref(
1798
- zxbcdt,
1799
- conv1d_weight,
1800
- conv1d_bias,
1801
- dt_bias,
1802
- A,
1803
- D,
1804
- chunk_size,
1805
- dt_limit=(0.0, float("inf")),
1806
- activation="silu",
1807
- rmsnorm_weight=None,
1808
- rmsnorm_eps=1e-6,
1809
- outproj_weight=None,
1810
- outproj_bias=None,
1811
- headdim=None,
1812
- ngroups=1,
1813
- norm_before_gate=True,
1814
- ):
1815
- """
1816
- Argument:
1817
- zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
1818
- conv1d_weight: (dim + 2 * ngroups * dstate, width)
1819
- conv1d_bias: (dim + 2 * ngroups * dstate,)
1820
- dt_bias: (nheads,)
1821
- A: (nheads)
1822
- D: (nheads, headdim) or (nheads,)
1823
- rmsnorm_weight: (dim,)
1824
- outproj_weight: (out_dim, dim)
1825
- outproj_bias: (out_dim,)
1826
- headdim: if D is 1D, headdim must be passed in
1827
- norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
1828
- Return:
1829
- out: (batch, seqlen, dim)
1830
- """
1831
- if D.dim() == 1:
1832
- assert headdim is not None
1833
- (nheads,) = D.shape
1834
- else:
1835
- nheads, headdim = D.shape
1836
- assert nheads % ngroups == 0
1837
- batch, seqlen, _ = zxbcdt.shape
1838
- dim = nheads * headdim
1839
- dstate = (zxbcdt.shape[-1] - 2 * dim - nheads) // ngroups // 2
1840
- assert zxbcdt.shape == (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads)
1841
- assert dt_bias.shape == (nheads,)
1842
- assert A.shape == (nheads,)
1843
- if rmsnorm_weight is not None:
1844
- assert rmsnorm_weight.shape == (dim,)
1845
- z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1)
1846
- xBC = rearrange(
1847
- causal_conv1d_fn(
1848
- rearrange(xBC, "b s d -> b d s"),
1849
- conv1d_weight,
1850
- conv1d_bias,
1851
- activation=activation,
1852
- ),
1853
- "b d s -> b s d",
1854
- )
1855
- x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
1856
- x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
1857
- B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
1858
- C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
1859
- z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
1860
- out = ssd_selective_scan(
1861
- x,
1862
- dt.to(x.dtype),
1863
- A,
1864
- B,
1865
- C,
1866
- D=D.float(),
1867
- z=z if rmsnorm_weight is None else None,
1868
- dt_bias=dt_bias,
1869
- dt_softplus=True,
1870
- dt_limit=dt_limit,
1871
- )
1872
- out = rearrange(out, "b s h p -> b s (h p)")
1873
- if rmsnorm_weight is not None:
1874
- out = rmsnorm_fn(
1875
- out,
1876
- rmsnorm_weight,
1877
- None,
1878
- z=rearrange(z, "b l h p -> b l (h p)"),
1879
- eps=rmsnorm_eps,
1880
- norm_before_gate=norm_before_gate,
1881
- )
1882
- if outproj_weight is not None:
1883
- out = F.linear(out, outproj_weight, outproj_bias)
1884
- return out
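For reference, the fused entry point deleted above can be exercised with tensors shaped exactly as its docstring describes. The sketch below is illustrative only: it assumes a CUDA device, the compiled mamba_ssm Triton/CUDA kernels plus causal-conv1d, and that the function is importable from `mamba_ssm.ops.triton.ssd_combined` as in upstream mamba-ssm; all sizes are arbitrary.

```python
import torch
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined

# Arbitrary sizes; dim == nheads * headdim and the conv acts on dim + 2*ngroups*dstate channels.
batch, seqlen, nheads, headdim, ngroups, dstate, width = 2, 128, 8, 64, 1, 16, 4
dim = nheads * headdim
conv_dim = dim + 2 * ngroups * dstate
dev = "cuda"

zxbcdt = torch.randn(batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads, device=dev)
conv1d_weight = torch.randn(conv_dim, width, device=dev)
conv1d_bias = torch.randn(conv_dim, device=dev)
dt_bias = torch.randn(nheads, device=dev)
A = -torch.rand(nheads, device=dev)   # negative real A, as used by Mamba2
D = torch.randn(nheads, device=dev)   # 1D D, so headdim must be passed explicitly

out = mamba_split_conv1d_scan_combined(
    zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D,
    chunk_size=64, activation="silu", headdim=headdim, ngroups=ngroups,
)
print(out.shape)  # (batch, seqlen, dim) when outproj_weight is None
```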
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/utils/__init__.py DELETED
File without changes
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/__init__.py DELETED
@@ -1,14 +0,0 @@
- __version__ = "2.2.4"
-
- from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
- from .modules.mamba_simple import Mamba
- from .modules.mamba2 import Mamba2
- from .models.mixer_seq_simple import MambaLMHeadModel
-
- __all__ = [
-     "selective_scan_fn",
-     "mamba_inner_fn",
-     "Mamba",
-     "Mamba2",
-     "MambaLMHeadModel",
- ]
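These re-exports make the main building blocks importable from the package root. A minimal sketch, assuming the package from this repo is installed together with its CUDA/Triton kernels and a GPU is available; the `Mamba` constructor defaults (d_state, d_conv, expand) are taken from upstream mamba-ssm.

```python
import torch
from mamba_ssm import Mamba

layer = Mamba(d_model=256).to("cuda")        # d_state/d_conv/expand keep their defaults
x = torch.randn(2, 64, 256, device="cuda")   # (batch, seqlen, d_model)
y = layer(x)
print(y.shape)                               # same shape as the input
```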
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/distributed/__init__.py DELETED
File without changes
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py DELETED
@@ -1,326 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao.
2
- # The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
3
- from typing import Optional
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from torch import Tensor
9
- from torch.distributed import ProcessGroup
10
- from ..utils.torch import custom_bwd, custom_fwd
11
-
12
- from einops import rearrange
13
-
14
- from ..distributed.distributed_utils import (
15
- all_gather_raw,
16
- all_reduce,
17
- all_reduce_raw,
18
- reduce_scatter,
19
- reduce_scatter_raw,
20
- )
21
-
22
-
23
- class ParallelLinearFunc(torch.autograd.Function):
24
- @staticmethod
25
- @custom_fwd
26
- def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
27
- """
28
- If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
29
- with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
30
- """
31
- ctx.compute_weight_gradient = weight.requires_grad
32
- ctx.process_group = process_group
33
- ctx.sequence_parallel = sequence_parallel
34
-
35
- if torch.is_autocast_enabled():
36
- x = x.to(dtype=torch.get_autocast_gpu_dtype())
37
- x = x.contiguous()
38
- if process_group is not None and sequence_parallel:
39
- # We want to kick off the all_gather early, before weight dtype conversion
40
- total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
41
- else:
42
- total_x = x
43
-
44
- if torch.is_autocast_enabled():
45
- weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
46
- bias = (
47
- bias.to(dtype=torch.get_autocast_gpu_dtype())
48
- if bias is not None
49
- else None
50
- )
51
- weight = weight.contiguous()
52
- if process_group is not None and sequence_parallel:
53
- handle_x.wait()
54
- batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
55
- batch_dim = batch_shape.numel()
56
- # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
57
- output = F.linear(total_x, weight, bias)
58
- if ctx.compute_weight_gradient:
59
- ctx.save_for_backward(x, weight)
60
- else:
61
- ctx.save_for_backward(weight)
62
- return output
63
-
64
- @staticmethod
65
- @custom_bwd
66
- def backward(ctx, grad_output):
67
- grad_output = grad_output.contiguous()
68
- process_group = ctx.process_group
69
- sequence_parallel = ctx.sequence_parallel
70
- if ctx.compute_weight_gradient:
71
- x, weight = ctx.saved_tensors
72
- if process_group is not None and sequence_parallel:
73
- total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
74
- else:
75
- total_x = x
76
- else:
77
- (weight,) = ctx.saved_tensors
78
- total_x = None
79
- batch_shape = grad_output.shape[:-1]
80
- batch_dim = batch_shape.numel()
81
- grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
82
- if ctx.needs_input_grad[0]:
83
- grad_input = F.linear(grad_output, weight.t())
84
- grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
85
- if process_group is not None:
86
- reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
87
- grad_input, handle_grad_input = reduce_fn(
88
- grad_input, process_group, async_op=True
89
- )
90
- else:
91
- grad_input = None
92
- if ctx.needs_input_grad[1]:
93
- assert ctx.compute_weight_gradient
94
- if process_group is not None and sequence_parallel:
95
- handle_x.wait()
96
- grad_weight = torch.einsum(
97
- "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
98
- )
99
- else:
100
- grad_weight = None
101
- grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
102
- if process_group is not None and ctx.needs_input_grad[0]:
103
- handle_grad_input.wait()
104
- return grad_input, grad_weight, grad_bias, None, None
105
-
106
-
107
- def parallel_linear_func(
108
- x: Tensor,
109
- weight: Tensor,
110
- bias: Optional[Tensor] = None,
111
- process_group: Optional[ProcessGroup] = None,
112
- sequence_parallel: bool = True,
113
- ):
114
- return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)
115
-
116
-
117
- class ColumnParallelLinear(nn.Linear):
118
- def __init__(
119
- self,
120
- in_features: int,
121
- out_features: int,
122
- process_group: ProcessGroup,
123
- bias: bool = True,
124
- sequence_parallel=True,
125
- multiple_of=1,
126
- device=None,
127
- dtype=None,
128
- ) -> None:
129
- world_size = torch.distributed.get_world_size(process_group)
130
- if out_features % multiple_of:
131
- raise ValueError(
132
- f"out_features ({out_features}) must be a multiple of {multiple_of}"
133
- )
134
- multiple = out_features // multiple_of
135
- # We want to split @multiple across world_size, but it could be an uneven split
136
- div = multiple // world_size
137
- mod = multiple % world_size
138
- # The first @mod ranks get @div + 1 copies, the rest get @div copies
139
- local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
140
- super().__init__(
141
- in_features,
142
- local_multiple * multiple_of,
143
- bias=bias,
144
- device=device,
145
- dtype=dtype,
146
- )
147
- self.process_group = process_group
148
- self.sequence_parallel = sequence_parallel
149
-
150
- def forward(self, x):
151
- # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
152
- # we do an all_gather of x before doing the matmul.
153
- # If not, then the input is already gathered.
154
- return parallel_linear_func(
155
- x,
156
- self.weight,
157
- self.bias,
158
- process_group=self.process_group,
159
- sequence_parallel=self.sequence_parallel,
160
- )
161
-
162
-
163
- class RowParallelLinear(nn.Linear):
164
- def __init__(
165
- self,
166
- in_features: int,
167
- out_features: int,
168
- process_group: ProcessGroup,
169
- bias: bool = True,
170
- sequence_parallel=True,
171
- multiple_of=1,
172
- device=None,
173
- dtype=None,
174
- ) -> None:
175
- world_size = torch.distributed.get_world_size(process_group)
176
- rank = torch.distributed.get_rank(process_group)
177
- if in_features % multiple_of:
178
- raise ValueError(
179
- f"in_features ({in_features}) must be a multiple of {multiple_of}"
180
- )
181
- multiple = in_features // multiple_of
182
- # We want to split @multiple across world_size, but it could be an uneven split
183
- div = multiple // world_size
184
- mod = multiple % world_size
185
- # The first @mod ranks get @div + 1 copies, the rest get @div copies
186
- local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
187
- # Only rank 0 will have bias
188
- super().__init__(
189
- local_multiple * multiple_of,
190
- out_features,
191
- bias=bias and rank == 0,
192
- device=device,
193
- dtype=dtype,
194
- )
195
- self.process_group = process_group
196
- self.sequence_parallel = sequence_parallel
197
-
198
- def forward(self, x):
199
- """
200
- We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
201
- a reduce_scatter of the result.
202
- """
203
- out = parallel_linear_func(x, self.weight, self.bias)
204
- reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
205
- return reduce_fn(out, self.process_group)
206
-
207
-
208
- class VocabParallelEmbedding(nn.Embedding):
209
- def __init__(
210
- self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs
211
- ):
212
- self.process_group = process_group
213
- if process_group is not None:
214
- world_size = torch.distributed.get_world_size(process_group)
215
- if num_embeddings % world_size != 0:
216
- raise ValueError(
217
- f"num_embeddings ({num_embeddings}) must be divisible by "
218
- f"world_size ({world_size})"
219
- )
220
- if world_size > 1 and padding_idx is not None:
221
- raise RuntimeError("ParallelEmbedding does not support padding_idx")
222
- else:
223
- world_size = 1
224
- super().__init__(
225
- num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs
226
- )
227
-
228
- def forward(self, input: Tensor) -> Tensor:
229
- if self.process_group is None:
230
- return super().forward(input)
231
- else:
232
- rank = torch.distributed.get_rank(self.process_group)
233
- vocab_size = self.num_embeddings
234
- vocab_start_index, vocab_end_index = (
235
- rank * vocab_size,
236
- (rank + 1) * vocab_size,
237
- )
238
- # Create a mask of valid vocab ids (1 means it needs to be masked).
239
- input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
240
- input = input - vocab_start_index
241
- input[input_ids_mask] = 0
242
- embeddings = super().forward(input)
243
- embeddings[input_ids_mask] = 0.0
244
- return embeddings
245
-
246
-
247
- class ColumnParallelEmbedding(nn.Embedding):
248
- def __init__(
249
- self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs
250
- ):
251
- self.process_group = process_group
252
- if process_group is not None:
253
- world_size = torch.distributed.get_world_size(process_group)
254
- if embedding_dim % world_size != 0:
255
- raise ValueError(
256
- f"embedding_dim ({embedding_dim}) must be divisible by "
257
- f"world_size ({world_size})"
258
- )
259
- else:
260
- world_size = 1
261
- super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
262
-
263
-
264
- class ParallelEmbeddings(nn.Module):
265
- def __init__(
266
- self,
267
- embed_dim,
268
- vocab_size,
269
- max_position_embeddings,
270
- process_group,
271
- padding_idx=None,
272
- sequence_parallel=True,
273
- device=None,
274
- dtype=None,
275
- ):
276
- """
277
- If max_position_embeddings <= 0, there's no position embeddings
278
- """
279
- factory_kwargs = {"device": device, "dtype": dtype}
280
- super().__init__()
281
- self.process_group = process_group
282
- self.sequence_parallel = sequence_parallel
283
- self.word_embeddings = VocabParallelEmbedding(
284
- vocab_size,
285
- embed_dim,
286
- padding_idx=padding_idx,
287
- process_group=process_group,
288
- **factory_kwargs,
289
- )
290
- self.max_position_embeddings = max_position_embeddings
291
- if self.max_position_embeddings > 0:
292
- self.position_embeddings = ColumnParallelEmbedding(
293
- max_position_embeddings,
294
- embed_dim,
295
- process_group=process_group,
296
- **factory_kwargs,
297
- )
298
-
299
- def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
300
- """
301
- input_ids: (batch, seqlen)
302
- position_ids: (batch, seqlen)
303
- """
304
- batch_size, seqlen = input_ids.shape
305
- world_size = torch.distributed.get_world_size(self.process_group)
306
- embeddings = self.word_embeddings(input_ids)
307
- if self.max_position_embeddings > 0:
308
- if position_ids is None:
309
- position_ids = torch.arange(
310
- seqlen, dtype=torch.long, device=input_ids.device
311
- )
312
- position_embeddings = self.position_embeddings(position_ids)
313
- if world_size <= 1:
314
- embeddings = embeddings + position_embeddings
315
- else:
316
- partition_dim = self.position_embeddings.embedding_dim
317
- rank = torch.distributed.get_rank(self.process_group)
318
- embeddings[
319
- ..., rank * partition_dim : (rank + 1) * partition_dim
320
- ] += position_embeddings
321
- if combine_batch_seqlen_dim:
322
- embeddings = rearrange(embeddings, "b s d -> (b s) d")
323
- reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
324
- return (
325
- embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
326
- )
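The linear layers above shard their weights across the ranks of a process group; with world_size == 1 each rank simply holds the full matrix, which makes the sharding arithmetic easy to inspect without GPUs. A single-process sketch, assuming the package and its distributed_utils/einops dependencies are importable; backend, addresses, and sizes are placeholders.

```python
import os
import torch.distributed as dist
from mamba_ssm.distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear

# Single-process "cluster" just to create a process group for the constructors.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)
group = dist.group.WORLD

col = ColumnParallelLinear(512, 2048, process_group=group, bias=True)
row = RowParallelLinear(2048, 512, process_group=group, bias=True)
print(col.weight.shape)  # (2048 // world_size, 512): output features are sharded
print(row.weight.shape)  # (512, 2048 // world_size): input features are sharded
dist.destroy_process_group()
```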
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/models/__init__.py DELETED
File without changes
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py DELETED
@@ -1,338 +0,0 @@
1
- # Copyright (c) 2023, Albert Gu, Tri Dao.
2
-
3
- import math
4
- from functools import partial
5
- import json
6
- import os
7
- import copy
8
-
9
- from collections import namedtuple
10
-
11
- import torch
12
- import torch.nn as nn
13
-
14
- from .config_mamba import MambaConfig
15
- from ..modules.mamba_simple import Mamba
16
- from ..modules.mamba2 import Mamba2
17
- from ..modules.mha import MHA
18
- from ..modules.mlp import GatedMLP
19
- from ..modules.block import Block
20
- from ..utils.generation import GenerationMixin
21
- from ..utils.hf import load_config_hf, load_state_dict_hf
22
-
23
- try:
24
- from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
25
- except ImportError:
26
- RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
27
-
28
-
29
- def create_block(
30
- d_model,
31
- d_intermediate,
32
- ssm_cfg=None,
33
- attn_layer_idx=None,
34
- attn_cfg=None,
35
- norm_epsilon=1e-5,
36
- rms_norm=False,
37
- residual_in_fp32=False,
38
- fused_add_norm=False,
39
- layer_idx=None,
40
- device=None,
41
- dtype=None,
42
- ):
43
- if ssm_cfg is None:
44
- ssm_cfg = {}
45
- if attn_layer_idx is None:
46
- attn_layer_idx = []
47
- if attn_cfg is None:
48
- attn_cfg = {}
49
- factory_kwargs = {"device": device, "dtype": dtype}
50
- if layer_idx not in attn_layer_idx:
51
- # Create a copy of the config to modify
52
- ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
53
- ssm_layer = ssm_cfg.pop("layer", "Mamba1")
54
- if ssm_layer not in ["Mamba1", "Mamba2"]:
55
- raise ValueError(
56
- f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2"
57
- )
58
- mixer_cls = partial(
59
- Mamba2 if ssm_layer == "Mamba2" else Mamba,
60
- layer_idx=layer_idx,
61
- **ssm_cfg,
62
- **factory_kwargs,
63
- )
64
- else:
65
- mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
66
- norm_cls = partial(
67
- nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
68
- )
69
- if d_intermediate == 0:
70
- mlp_cls = nn.Identity
71
- else:
72
- mlp_cls = partial(
73
- GatedMLP,
74
- hidden_features=d_intermediate,
75
- out_features=d_model,
76
- **factory_kwargs,
77
- )
78
- block = Block(
79
- d_model,
80
- mixer_cls,
81
- mlp_cls,
82
- norm_cls=norm_cls,
83
- fused_add_norm=fused_add_norm,
84
- residual_in_fp32=residual_in_fp32,
85
- )
86
- block.layer_idx = layer_idx
87
- return block
88
-
89
-
90
- # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
91
- def _init_weights(
92
- module,
93
- n_layer,
94
- initializer_range=0.02, # Now only used for embedding layer.
95
- rescale_prenorm_residual=True,
96
- n_residuals_per_layer=1, # Change to 2 if we have MLP
97
- ):
98
- if isinstance(module, nn.Linear):
99
- if module.bias is not None:
100
- if not getattr(module.bias, "_no_reinit", False):
101
- nn.init.zeros_(module.bias)
102
- elif isinstance(module, nn.Embedding):
103
- nn.init.normal_(module.weight, std=initializer_range)
104
-
105
- if rescale_prenorm_residual:
106
- # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
107
- # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
108
- # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
109
- # > -- GPT-2 :: https://openai.com/blog/better-language-models/
110
- #
111
- # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
112
- for name, p in module.named_parameters():
113
- if name in ["out_proj.weight", "fc2.weight"]:
114
- # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
115
- # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
116
- # We need to reinit p since this code could be called multiple times
117
- # Having just p *= scale would repeatedly scale it down
118
- nn.init.kaiming_uniform_(p, a=math.sqrt(5))
119
- with torch.no_grad():
120
- p /= math.sqrt(n_residuals_per_layer * n_layer)
121
-
122
-
123
- class MixerModel(nn.Module):
124
- def __init__(
125
- self,
126
- d_model: int,
127
- n_layer: int,
128
- d_intermediate: int,
129
- vocab_size: int,
130
- ssm_cfg=None,
131
- attn_layer_idx=None,
132
- attn_cfg=None,
133
- norm_epsilon: float = 1e-5,
134
- rms_norm: bool = False,
135
- initializer_cfg=None,
136
- fused_add_norm=False,
137
- residual_in_fp32=False,
138
- device=None,
139
- dtype=None,
140
- ) -> None:
141
- factory_kwargs = {"device": device, "dtype": dtype}
142
- super().__init__()
143
- self.residual_in_fp32 = residual_in_fp32
144
-
145
- self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
146
-
147
- # We change the order of residual and layer norm:
148
- # Instead of LN -> Attn / MLP -> Add, we do:
149
- # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
150
- # the main branch (output of MLP / Mixer). The model definition is unchanged.
151
- # This is for performance reason: we can fuse add + layer_norm.
152
- self.fused_add_norm = fused_add_norm
153
- if self.fused_add_norm:
154
- if layer_norm_fn is None or rms_norm_fn is None:
155
- raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
156
-
157
- self.layers = nn.ModuleList(
158
- [
159
- create_block(
160
- d_model,
161
- d_intermediate=d_intermediate,
162
- ssm_cfg=ssm_cfg,
163
- attn_layer_idx=attn_layer_idx,
164
- attn_cfg=attn_cfg,
165
- norm_epsilon=norm_epsilon,
166
- rms_norm=rms_norm,
167
- residual_in_fp32=residual_in_fp32,
168
- fused_add_norm=fused_add_norm,
169
- layer_idx=i,
170
- **factory_kwargs,
171
- )
172
- for i in range(n_layer)
173
- ]
174
- )
175
-
176
- self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
177
- d_model, eps=norm_epsilon, **factory_kwargs
178
- )
179
-
180
- self.apply(
181
- partial(
182
- _init_weights,
183
- n_layer=n_layer,
184
- **(initializer_cfg if initializer_cfg is not None else {}),
185
- n_residuals_per_layer=(
186
- 1 if d_intermediate == 0 else 2
187
- ), # 2 if we have MLP
188
- )
189
- )
190
-
191
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
192
- return {
193
- i: layer.allocate_inference_cache(
194
- batch_size, max_seqlen, dtype=dtype, **kwargs
195
- )
196
- for i, layer in enumerate(self.layers)
197
- }
198
-
199
- def forward(self, input_ids, inference_params=None, **mixer_kwargs):
200
- hidden_states = self.embedding(input_ids)
201
- residual = None
202
- for layer in self.layers:
203
- hidden_states, residual = layer(
204
- hidden_states,
205
- residual,
206
- inference_params=inference_params,
207
- **mixer_kwargs,
208
- )
209
- if not self.fused_add_norm:
210
- residual = (
211
- (hidden_states + residual) if residual is not None else hidden_states
212
- )
213
- hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
214
- else:
215
- # Set prenorm=False here since we don't need the residual
216
- hidden_states = layer_norm_fn(
217
- hidden_states,
218
- self.norm_f.weight,
219
- self.norm_f.bias,
220
- eps=self.norm_f.eps,
221
- residual=residual,
222
- prenorm=False,
223
- residual_in_fp32=self.residual_in_fp32,
224
- is_rms_norm=isinstance(self.norm_f, RMSNorm),
225
- )
226
- return hidden_states
227
-
228
-
229
- class MambaLMHeadModel(nn.Module, GenerationMixin):
230
-
231
- def __init__(
232
- self,
233
- config: MambaConfig,
234
- initializer_cfg=None,
235
- device=None,
236
- dtype=None,
237
- ) -> None:
238
- self.config = config
239
- d_model = config.d_model
240
- n_layer = config.n_layer
241
- d_intermediate = config.d_intermediate
242
- vocab_size = config.vocab_size
243
- ssm_cfg = config.ssm_cfg
244
- attn_layer_idx = config.attn_layer_idx
245
- attn_cfg = config.attn_cfg
246
- rms_norm = config.rms_norm
247
- residual_in_fp32 = config.residual_in_fp32
248
- fused_add_norm = config.fused_add_norm
249
- pad_vocab_size_multiple = config.pad_vocab_size_multiple
250
- factory_kwargs = {"device": device, "dtype": dtype}
251
-
252
- super().__init__()
253
- if vocab_size % pad_vocab_size_multiple != 0:
254
- vocab_size += pad_vocab_size_multiple - (
255
- vocab_size % pad_vocab_size_multiple
256
- )
257
- self.backbone = MixerModel(
258
- d_model=d_model,
259
- n_layer=n_layer,
260
- d_intermediate=d_intermediate,
261
- vocab_size=vocab_size,
262
- ssm_cfg=ssm_cfg,
263
- attn_layer_idx=attn_layer_idx,
264
- attn_cfg=attn_cfg,
265
- rms_norm=rms_norm,
266
- initializer_cfg=initializer_cfg,
267
- fused_add_norm=fused_add_norm,
268
- residual_in_fp32=residual_in_fp32,
269
- **factory_kwargs,
270
- )
271
- self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
272
-
273
- # Initialize weights and apply final processing
274
- self.apply(
275
- partial(
276
- _init_weights,
277
- n_layer=n_layer,
278
- **(initializer_cfg if initializer_cfg is not None else {}),
279
- )
280
- )
281
- self.tie_weights()
282
-
283
- def tie_weights(self):
284
- if self.config.tie_embeddings:
285
- self.lm_head.weight = self.backbone.embedding.weight
286
-
287
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
288
- return self.backbone.allocate_inference_cache(
289
- batch_size, max_seqlen, dtype=dtype, **kwargs
290
- )
291
-
292
- def forward(
293
- self,
294
- input_ids,
295
- position_ids=None,
296
- inference_params=None,
297
- num_last_tokens=0,
298
- **mixer_kwargs,
299
- ):
300
- """
301
- "position_ids" is just to be compatible with Transformer generation. We don't use it.
302
- num_last_tokens: if > 0, only return the logits for the last n tokens
303
- """
304
- hidden_states = self.backbone(
305
- input_ids, inference_params=inference_params, **mixer_kwargs
306
- )
307
- if num_last_tokens > 0:
308
- hidden_states = hidden_states[:, -num_last_tokens:]
309
- lm_logits = self.lm_head(hidden_states)
310
- CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
311
- return CausalLMOutput(logits=lm_logits)
312
-
313
- @classmethod
314
- def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
315
- config_data = load_config_hf(pretrained_model_name)
316
- config = MambaConfig(**config_data)
317
- model = cls(config, device=device, dtype=dtype, **kwargs)
318
- model.load_state_dict(
319
- load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
320
- )
321
- return model
322
-
323
- def save_pretrained(self, save_directory):
324
- """
325
- Minimal implementation of save_pretrained for MambaLMHeadModel.
326
- Save the model and its configuration file to a directory.
327
- """
328
- # Ensure save_directory exists
329
- os.makedirs(save_directory, exist_ok=True)
330
-
331
- # Save the model's state_dict
332
- model_path = os.path.join(save_directory, "pytorch_model.bin")
333
- torch.save(self.state_dict(), model_path)
334
-
335
- # Save the configuration of the model
336
- config_path = os.path.join(save_directory, "config.json")
337
- with open(config_path, "w") as f:
338
- json.dump(self.config.__dict__, f, indent=4)
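Putting the classes above together, the LM head model is instantiated from a `MambaConfig` and queried for logits. A sketch assuming a CUDA device with the Triton layer-norm kernels available (rms_norm and fused_add_norm default to True in upstream configs) and `MambaConfig` field names/defaults as in upstream mamba-ssm; the sizes are toy values.

```python
import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

config = MambaConfig(d_model=256, n_layer=4, vocab_size=1000)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.float32)

input_ids = torch.randint(0, 1000, (2, 32), device="cuda")
logits = model(input_ids).logits          # (2, 32, 1000); 1000 is already a multiple of 8
model.save_pretrained("/tmp/tiny-mamba")  # writes pytorch_model.bin and config.json
```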
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/modules/__init__.py DELETED
File without changes
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/__init__.py DELETED
File without changes
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/selective_scan_interface.py DELETED
@@ -1,659 +0,0 @@
1
- # Copyright (c) 2023, Tri Dao, Albert Gu.
2
-
3
- import torch
4
- import torch.nn.functional as F
5
- from ..utils.torch import custom_fwd, custom_bwd
6
-
7
- from einops import rearrange, repeat
8
-
9
- try:
10
- from causal_conv1d import causal_conv1d_fn
11
- import causal_conv1d_cuda
12
- except ImportError:
13
- causal_conv1d_fn = None
14
- causal_conv1d_cuda = None
15
-
16
- from .triton.layer_norm import _layer_norm_fwd
17
-
18
- from .._ops import ops
19
-
20
-
21
- class SelectiveScanFn(torch.autograd.Function):
22
-
23
- @staticmethod
24
- def forward(
25
- ctx,
26
- u,
27
- delta,
28
- A,
29
- B,
30
- C,
31
- D=None,
32
- z=None,
33
- delta_bias=None,
34
- delta_softplus=False,
35
- return_last_state=False,
36
- ):
37
- if u.stride(-1) != 1:
38
- u = u.contiguous()
39
- if delta.stride(-1) != 1:
40
- delta = delta.contiguous()
41
- if D is not None:
42
- D = D.contiguous()
43
- if B.stride(-1) != 1:
44
- B = B.contiguous()
45
- if C.stride(-1) != 1:
46
- C = C.contiguous()
47
- if z is not None and z.stride(-1) != 1:
48
- z = z.contiguous()
49
- if B.dim() == 3:
50
- B = rearrange(B, "b dstate l -> b 1 dstate l")
51
- ctx.squeeze_B = True
52
- if C.dim() == 3:
53
- C = rearrange(C, "b dstate l -> b 1 dstate l")
54
- ctx.squeeze_C = True
55
- out, x, *rest = ops.selective_scan_fwd(
56
- u, delta, A, B, C, D, z, delta_bias, delta_softplus
57
- )
58
- ctx.delta_softplus = delta_softplus
59
- ctx.has_z = z is not None
60
- last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
61
- if not ctx.has_z:
62
- ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
63
- return out if not return_last_state else (out, last_state)
64
- else:
65
- ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
66
- out_z = rest[0]
67
- return out_z if not return_last_state else (out_z, last_state)
68
-
69
- @staticmethod
70
- def backward(ctx, dout, *args):
71
- if not ctx.has_z:
72
- u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
73
- z = None
74
- out = None
75
- else:
76
- u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
77
- if dout.stride(-1) != 1:
78
- dout = dout.contiguous()
79
- # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
80
- # backward of selective_scan_cuda with the backward of chunk).
81
- # Here we just pass in None and dz will be allocated in the C++ code.
82
- du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = ops.selective_scan_bwd(
83
- u,
84
- delta,
85
- A,
86
- B,
87
- C,
88
- D,
89
- z,
90
- delta_bias,
91
- dout,
92
- x,
93
- out,
94
- None,
95
- ctx.delta_softplus,
96
- False, # option to recompute out_z, not used here
97
- )
98
- dz = rest[0] if ctx.has_z else None
99
- dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
100
- dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
101
- return (
102
- du,
103
- ddelta,
104
- dA,
105
- dB,
106
- dC,
107
- dD if D is not None else None,
108
- dz,
109
- ddelta_bias if delta_bias is not None else None,
110
- None,
111
- None,
112
- )
113
-
114
-
115
- def rms_norm_forward(
116
- x,
117
- weight,
118
- bias,
119
- eps=1e-6,
120
- is_rms_norm=True,
121
- ):
122
- # x (b l) d
123
- if x.stride(-1) != 1:
124
- x = x.contiguous()
125
- weight = weight.contiguous()
126
- if bias is not None:
127
- bias = bias.contiguous()
128
- y = _layer_norm_fwd(
129
- x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm
130
- )[0]
131
- # y (b l) d
132
- return y
133
-
134
-
135
- def selective_scan_fn(
136
- u,
137
- delta,
138
- A,
139
- B,
140
- C,
141
- D=None,
142
- z=None,
143
- delta_bias=None,
144
- delta_softplus=False,
145
- return_last_state=False,
146
- ):
147
- """if return_last_state is True, returns (out, last_state)
148
- last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
149
- not considered in the backward pass.
150
- """
151
- return SelectiveScanFn.apply(
152
- u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state
153
- )
154
-
155
-
156
- def selective_scan_ref(
157
- u,
158
- delta,
159
- A,
160
- B,
161
- C,
162
- D=None,
163
- z=None,
164
- delta_bias=None,
165
- delta_softplus=False,
166
- return_last_state=False,
167
- ):
168
- """
169
- u: r(B D L)
170
- delta: r(B D L)
171
- A: c(D N) or r(D N)
172
- B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
173
- C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
174
- D: r(D)
175
- z: r(B D L)
176
- delta_bias: r(D), fp32
177
-
178
- out: r(B D L)
179
- last_state (optional): r(B D dstate) or c(B D dstate)
180
- """
181
- dtype_in = u.dtype
182
- u = u.float()
183
- delta = delta.float()
184
- if delta_bias is not None:
185
- delta = delta + delta_bias[..., None].float()
186
- if delta_softplus:
187
- delta = F.softplus(delta)
188
- batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
189
- is_variable_B = B.dim() >= 3
190
- is_variable_C = C.dim() >= 3
191
- if A.is_complex():
192
- if is_variable_B:
193
- B = torch.view_as_complex(
194
- rearrange(B.float(), "... (L two) -> ... L two", two=2)
195
- )
196
- if is_variable_C:
197
- C = torch.view_as_complex(
198
- rearrange(C.float(), "... (L two) -> ... L two", two=2)
199
- )
200
- else:
201
- B = B.float()
202
- C = C.float()
203
- x = A.new_zeros((batch, dim, dstate))
204
- ys = []
205
- deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
206
- if not is_variable_B:
207
- deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
208
- else:
209
- if B.dim() == 3:
210
- deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
211
- else:
212
- B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
213
- deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
214
- if is_variable_C and C.dim() == 4:
215
- C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
216
- last_state = None
217
- for i in range(u.shape[2]):
218
- x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
219
- if not is_variable_C:
220
- y = torch.einsum("bdn,dn->bd", x, C)
221
- else:
222
- if C.dim() == 3:
223
- y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
224
- else:
225
- y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
226
- if i == u.shape[2] - 1:
227
- last_state = x
228
- if y.is_complex():
229
- y = y.real * 2
230
- ys.append(y)
231
- y = torch.stack(ys, dim=2) # (batch dim L)
232
- out = y if D is None else y + u * rearrange(D, "d -> d 1")
233
- if z is not None:
234
- out = out * F.silu(z)
235
- out = out.to(dtype=dtype_in)
236
- return out if not return_last_state else (out, last_state)
237
-
238
-
239
- class MambaInnerFn(torch.autograd.Function):
240
-
241
- @staticmethod
242
- @custom_fwd
243
- def forward(
244
- ctx,
245
- xz,
246
- conv1d_weight,
247
- conv1d_bias,
248
- x_proj_weight,
249
- delta_proj_weight,
250
- out_proj_weight,
251
- out_proj_bias,
252
- A,
253
- B=None,
254
- C=None,
255
- D=None,
256
- delta_bias=None,
257
- B_proj_bias=None,
258
- C_proj_bias=None,
259
- delta_softplus=True,
260
- checkpoint_lvl=1,
261
- b_rms_weight=None,
262
- c_rms_weight=None,
263
- dt_rms_weight=None,
264
- b_c_dt_rms_eps=1e-6,
265
- ):
266
- """
267
- xz: (batch, dim, seqlen)
268
- """
269
- assert (
270
- causal_conv1d_cuda is not None
271
- ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
272
- assert checkpoint_lvl in [0, 1]
273
- L = xz.shape[-1]
274
- delta_rank = delta_proj_weight.shape[1]
275
- d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
276
- if torch.is_autocast_enabled():
277
- x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
278
- delta_proj_weight = delta_proj_weight.to(
279
- dtype=torch.get_autocast_gpu_dtype()
280
- )
281
- out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
282
- out_proj_bias = (
283
- out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
284
- if out_proj_bias is not None
285
- else None
286
- )
287
- if xz.stride(-1) != 1:
288
- xz = xz.contiguous()
289
- conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
290
- x, z = xz.chunk(2, dim=1)
291
- conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
292
- conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
293
- x, conv1d_weight, conv1d_bias, None, None, None, True
294
- )
295
- # We're being very careful here about the layout, to avoid extra transposes.
296
- # We want delta to have d as the slowest moving dimension
297
- # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
298
- x_dbl = F.linear(
299
- rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight
300
- ) # (bl d)
301
- delta = rearrange(
302
- delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
303
- )
304
- ctx.is_variable_B = B is None
305
- ctx.is_variable_C = C is None
306
- ctx.B_proj_bias_is_None = B_proj_bias is None
307
- ctx.C_proj_bias_is_None = C_proj_bias is None
308
- if B is None: # variable B
309
- B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl dstate)
310
- if B_proj_bias is not None:
311
- B = B + B_proj_bias.to(dtype=B.dtype)
312
- if not A.is_complex():
313
- # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
314
- B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
315
- else:
316
- B = rearrange(
317
- B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
318
- ).contiguous()
319
- else:
320
- if B.stride(-1) != 1:
321
- B = B.contiguous()
322
- if C is None: # variable C
323
- C = x_dbl[:, -d_state:] # (bl dstate)
324
- if C_proj_bias is not None:
325
- C = C + C_proj_bias.to(dtype=C.dtype)
326
- if not A.is_complex():
327
- # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
328
- C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
329
- else:
330
- C = rearrange(
331
- C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
332
- ).contiguous()
333
- else:
334
- if C.stride(-1) != 1:
335
- C = C.contiguous()
336
- if D is not None:
337
- D = D.contiguous()
338
-
339
- if b_rms_weight is not None:
340
- B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
341
- B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
342
- B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
343
- if c_rms_weight is not None:
344
- C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
345
- C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
346
- C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
347
- if dt_rms_weight is not None:
348
- delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
349
- delta = rms_norm_forward(
350
- delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps
351
- )
352
- delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
353
-
354
- out, scan_intermediates, out_z = ops.selective_scan_fwd(
355
- conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
356
- )
357
- ctx.delta_softplus = delta_softplus
358
- ctx.out_proj_bias_is_None = out_proj_bias is None
359
- ctx.checkpoint_lvl = checkpoint_lvl
360
- ctx.b_rms_weight = b_rms_weight
361
- ctx.c_rms_weight = c_rms_weight
362
- ctx.dt_rms_weight = dt_rms_weight
363
- ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
364
- if (
365
- checkpoint_lvl >= 1
366
- ): # Will recompute conv1d_out and delta in the backward pass
367
- conv1d_out, delta = None, None
368
- ctx.save_for_backward(
369
- xz,
370
- conv1d_weight,
371
- conv1d_bias,
372
- x_dbl,
373
- x_proj_weight,
374
- delta_proj_weight,
375
- out_proj_weight,
376
- conv1d_out,
377
- delta,
378
- A,
379
- B,
380
- C,
381
- D,
382
- delta_bias,
383
- scan_intermediates,
384
- b_rms_weight,
385
- c_rms_weight,
386
- dt_rms_weight,
387
- out,
388
- )
389
- return F.linear(
390
- rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias
391
- )
392
-
393
- @staticmethod
394
- @custom_bwd
395
- def backward(ctx, dout):
396
- # dout: (batch, seqlen, dim)
397
- assert (
398
- causal_conv1d_cuda is not None
399
- ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
400
- (
401
- xz,
402
- conv1d_weight,
403
- conv1d_bias,
404
- x_dbl,
405
- x_proj_weight,
406
- delta_proj_weight,
407
- out_proj_weight,
408
- conv1d_out,
409
- delta,
410
- A,
411
- B,
412
- C,
413
- D,
414
- delta_bias,
415
- scan_intermediates,
416
- b_rms_weight,
417
- c_rms_weight,
418
- dt_rms_weight,
419
- out,
420
- ) = ctx.saved_tensors
421
- L = xz.shape[-1]
422
- delta_rank = delta_proj_weight.shape[1]
423
- d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
424
- x, z = xz.chunk(2, dim=1)
425
- if dout.stride(-1) != 1:
426
- dout = dout.contiguous()
427
- if ctx.checkpoint_lvl == 1:
428
- conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
429
- x, conv1d_weight, conv1d_bias, None, None, None, True
430
- )
431
- delta = rearrange(
432
- delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
433
- )
434
- if dt_rms_weight is not None:
435
- delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
436
- delta = rms_norm_forward(
437
- delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps
438
- )
439
- delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
440
- if b_rms_weight is not None:
441
- # Recompute & RMSNorm B
442
- B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
443
- B = rms_norm_forward(B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps)
444
- B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
445
- if c_rms_weight is not None:
446
- # Recompute & RMSNorm C
447
- C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
448
- C = rms_norm_forward(C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps)
449
- C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
450
-
451
- # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
452
- # backward of selective_scan_cuda with the backward of chunk).
453
- dxz = torch.empty_like(xz) # (batch, dim, seqlen)
454
- dx, dz = dxz.chunk(2, dim=1)
455
- dout = rearrange(dout, "b l e -> e (b l)")
456
- dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
457
- dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = (
458
- ops.selective_scan_bwd(
459
- conv1d_out,
460
- delta,
461
- A,
462
- B,
463
- C,
464
- D,
465
- z,
466
- delta_bias,
467
- dout_y,
468
- scan_intermediates,
469
- out,
470
- dz,
471
- ctx.delta_softplus,
472
- True, # option to recompute out_z
473
- )
474
- )
475
- dout_proj_weight = torch.einsum(
476
- "eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")
477
- )
478
- dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
479
- dD = dD if D is not None else None
480
- dx_dbl = torch.empty_like(x_dbl)
481
- dB_proj_bias = None
482
- if ctx.is_variable_B:
483
- if not A.is_complex():
484
- dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
485
- else:
486
- dB = rearrange(
487
- dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
488
- ).contiguous()
489
- dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
490
- dx_dbl[:, delta_rank : delta_rank + d_state] = dB # (bl d)
491
- dB = None
492
- dC_proj_bias = None
493
- if ctx.is_variable_C:
494
- if not A.is_complex():
495
- dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
496
- else:
497
- dC = rearrange(
498
- dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
499
- ).contiguous()
500
- dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
501
- dx_dbl[:, -d_state:] = dC # (bl d)
502
- dC = None
503
- ddelta = rearrange(ddelta, "b d l -> d (b l)")
504
- ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
505
- dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
506
- dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
507
- dx_proj_weight = torch.einsum(
508
- "Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")
509
- )
510
- dconv1d_out = torch.addmm(
511
- dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out
512
- )
513
- dconv1d_out = rearrange(
514
- dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]
515
- )
516
- # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
517
- # backward of conv1d with the backward of chunk).
518
- dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
519
- x,
520
- conv1d_weight,
521
- conv1d_bias,
522
- dconv1d_out,
523
- None,
524
- None,
525
- None,
526
- dx,
527
- False,
528
- True,
529
- )
530
- dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
531
- dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
532
- return (
533
- dxz,
534
- dconv1d_weight,
535
- dconv1d_bias,
536
- dx_proj_weight,
537
- ddelta_proj_weight,
538
- dout_proj_weight,
539
- dout_proj_bias,
540
- dA,
541
- dB,
542
- dC,
543
- dD,
544
- ddelta_bias if delta_bias is not None else None,
545
- # 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
546
- dB_proj_bias,
547
- dC_proj_bias,
548
- None,
549
- None,
550
- None,
551
- None,
552
- None,
553
- None,
554
- )
555
-
556
-
557
- def mamba_inner_fn(
558
- xz,
559
- conv1d_weight,
560
- conv1d_bias,
561
- x_proj_weight,
562
- delta_proj_weight,
563
- out_proj_weight,
564
- out_proj_bias,
565
- A,
566
- B=None,
567
- C=None,
568
- D=None,
569
- delta_bias=None,
570
- B_proj_bias=None,
571
- C_proj_bias=None,
572
- delta_softplus=True,
573
- checkpoint_lvl=1,
574
- b_rms_weight=None,
575
- c_rms_weight=None,
576
- dt_rms_weight=None,
577
- b_c_dt_rms_eps=1e-6,
578
- ):
579
- return MambaInnerFn.apply(
580
- xz,
581
- conv1d_weight,
582
- conv1d_bias,
583
- x_proj_weight,
584
- delta_proj_weight,
585
- out_proj_weight,
586
- out_proj_bias,
587
- A,
588
- B,
589
- C,
590
- D,
591
- delta_bias,
592
- B_proj_bias,
593
- C_proj_bias,
594
- delta_softplus,
595
- checkpoint_lvl,
596
- b_rms_weight,
597
- c_rms_weight,
598
- dt_rms_weight,
599
- b_c_dt_rms_eps,
600
- )
601
-
602
-
603
- def mamba_inner_ref(
604
- xz,
605
- conv1d_weight,
606
- conv1d_bias,
607
- x_proj_weight,
608
- delta_proj_weight,
609
- out_proj_weight,
610
- out_proj_bias,
611
- A,
612
- B=None,
613
- C=None,
614
- D=None,
615
- delta_bias=None,
616
- B_proj_bias=None,
617
- C_proj_bias=None,
618
- delta_softplus=True,
619
- ):
620
- assert (
621
- causal_conv1d_fn is not None
622
- ), "causal_conv1d_fn is not available. Please install causal-conv1d."
623
- L = xz.shape[-1]
624
- delta_rank = delta_proj_weight.shape[1]
625
- d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
626
- x, z = xz.chunk(2, dim=1)
627
- x = causal_conv1d_fn(
628
- x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu"
629
- )
630
- # We're being very careful here about the layout, to avoid extra transposes.
631
- # We want delta to have d as the slowest moving dimension
632
- # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
633
- x_dbl = F.linear(rearrange(x, "b d l -> (b l) d"), x_proj_weight) # (bl d)
634
- delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
635
- delta = rearrange(delta, "d (b l) -> b d l", l=L)
636
- if B is None: # variable B
637
- B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl d)
638
- if B_proj_bias is not None:
639
- B = B + B_proj_bias.to(dtype=B.dtype)
640
- if not A.is_complex():
641
- B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
642
- else:
643
- B = rearrange(
644
- B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
645
- ).contiguous()
646
- if C is None: # variable B
647
- C = x_dbl[:, -d_state:] # (bl d)
648
- if C_proj_bias is not None:
649
- C = C + C_proj_bias.to(dtype=C.dtype)
650
- if not A.is_complex():
651
- C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
652
- else:
653
- C = rearrange(
654
- C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
655
- ).contiguous()
656
- y = selective_scan_fn(
657
- x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True
658
- )
659
- return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
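Since `selective_scan_ref` above is a pure-PyTorch reference for the fused kernel, the two paths can be compared directly on random inputs. A sketch assuming the compiled selective-scan CUDA extension is available; shapes follow the docstring conventions (u, delta: (B, D, L); A: (D, N); variable B, C: (B, N, L)) and the values are arbitrary.

```python
import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref

batch, dim, dstate, seqlen = 2, 16, 8, 64
dev = "cuda"

u = torch.randn(batch, dim, seqlen, device=dev)
delta = torch.rand(batch, dim, seqlen, device=dev)
A = -torch.rand(dim, dstate, device=dev)
B = torch.randn(batch, dstate, seqlen, device=dev)   # variable (input-dependent) B
C = torch.randn(batch, dstate, seqlen, device=dev)   # variable (input-dependent) C
D = torch.randn(dim, device=dev)

out_kernel = selective_scan_fn(u, delta, A, B, C, D, delta_softplus=True)
out_ref = selective_scan_ref(u, delta, A, B, C, D, delta_softplus=True)
print((out_kernel - out_ref).abs().max())  # expected to be small (fp32 kernel vs. reference)
```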
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/__init__.py DELETED
File without changes
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/layer_norm.py DELETED
@@ -1,1166 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao.
2
- # Implement dropout + residual + layer_norm / rms_norm.
3
-
4
- # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
5
- # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
6
- # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
7
- # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
-
9
- import math
10
- import warnings
11
-
12
- import torch
13
- import torch.nn.functional as F
14
- from ...utils.torch import custom_bwd, custom_fwd
15
-
16
- import triton
17
- import triton.language as tl
18
-
19
-
20
- def layer_norm_ref(
21
- x,
22
- weight,
23
- bias,
24
- residual=None,
25
- x1=None,
26
- weight1=None,
27
- bias1=None,
28
- eps=1e-6,
29
- dropout_p=0.0,
30
- rowscale=None,
31
- prenorm=False,
32
- dropout_mask=None,
33
- dropout_mask1=None,
34
- upcast=False,
35
- ):
36
- dtype = x.dtype
37
- if upcast:
38
- x = x.float()
39
- weight = weight.float()
40
- bias = bias.float() if bias is not None else None
41
- residual = residual.float() if residual is not None else residual
42
- x1 = x1.float() if x1 is not None else None
43
- weight1 = weight1.float() if weight1 is not None else None
44
- bias1 = bias1.float() if bias1 is not None else None
45
- if x1 is not None:
46
- assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
47
- if rowscale is not None:
48
- x = x * rowscale[..., None]
49
- if dropout_p > 0.0:
50
- if dropout_mask is not None:
51
- x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
52
- else:
53
- x = F.dropout(x, p=dropout_p)
54
- if x1 is not None:
55
- if dropout_mask1 is not None:
56
- x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
57
- else:
58
- x1 = F.dropout(x1, p=dropout_p)
59
- if x1 is not None:
60
- x = x + x1
61
- if residual is not None:
62
- x = (x + residual).to(x.dtype)
63
- out = F.layer_norm(
64
- x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
65
- ).to(dtype)
66
- if weight1 is None:
67
- return out if not prenorm else (out, x)
68
- else:
69
- out1 = F.layer_norm(
70
- x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
71
- ).to(dtype)
72
- return (out, out1) if not prenorm else (out, out1, x)
73
-
74
-
75
- def rms_norm_ref(
76
- x,
77
- weight,
78
- bias,
79
- residual=None,
80
- x1=None,
81
- weight1=None,
82
- bias1=None,
83
- eps=1e-6,
84
- dropout_p=0.0,
85
- rowscale=None,
86
- prenorm=False,
87
- dropout_mask=None,
88
- dropout_mask1=None,
89
- upcast=False,
90
- ):
91
- dtype = x.dtype
92
- if upcast:
93
- x = x.float()
94
- weight = weight.float()
95
- bias = bias.float() if bias is not None else None
96
- residual = residual.float() if residual is not None else residual
97
- x1 = x1.float() if x1 is not None else None
98
- weight1 = weight1.float() if weight1 is not None else None
99
- bias1 = bias1.float() if bias1 is not None else None
100
- if x1 is not None:
101
- assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
102
- if rowscale is not None:
103
- x = x * rowscale[..., None]
104
- if dropout_p > 0.0:
105
- if dropout_mask is not None:
106
- x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
107
- else:
108
- x = F.dropout(x, p=dropout_p)
109
- if x1 is not None:
110
- if dropout_mask1 is not None:
111
- x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
112
- else:
113
- x1 = F.dropout(x1, p=dropout_p)
114
- if x1 is not None:
115
- x = x + x1
116
- if residual is not None:
117
- x = (x + residual).to(x.dtype)
118
- rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
119
- out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(
120
- dtype
121
- )
122
- if weight1 is None:
123
- return out if not prenorm else (out, x)
124
- else:
125
- out1 = (
126
- (x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)
127
- ).to(dtype)
128
- return (out, out1) if not prenorm else (out, out1, x)
129
-
130
-
131
- def config_prune(configs):
132
-
133
- if torch.version.hip:
134
- try:
135
- # set warp size based on gcn architecture
136
- gcn_arch_name = torch.cuda.get_device_properties(0).gcnArchName
137
- if "gfx10" in gcn_arch_name or "gfx11" in gcn_arch_name:
138
- # radeon
139
- warp_size = 32
140
- else:
141
- # instinct
142
- warp_size = 64
143
- except AttributeError as e:
144
- # fall back to crude method to set warp size
145
- device_name = torch.cuda.get_device_properties(0).name
146
- if "instinct" in device_name.lower():
147
- warp_size = 64
148
- else:
149
- warp_size = 32
150
- warnings.warn(
151
- f"{e}, warp size set to {warp_size} based on device name: {device_name}",
152
- UserWarning,
153
- )
154
-
155
- else:
156
- # cuda
157
- warp_size = 32
158
-
159
- max_block_sz = 1024
160
- max_num_warps = max_block_sz // warp_size
161
- pruned_configs = [config for config in configs if config.num_warps <= max_num_warps]
162
- return pruned_configs
163
-
164
-
165
- configs_autotune = [
166
- triton.Config({}, num_warps=1),
167
- triton.Config({}, num_warps=2),
168
- triton.Config({}, num_warps=4),
169
- triton.Config({}, num_warps=8),
170
- triton.Config({}, num_warps=16),
171
- triton.Config({}, num_warps=32),
172
- ]
173
-
174
- pruned_configs_autotune = config_prune(configs_autotune)
175
-
176
-
177
- @triton.autotune(
178
- configs=pruned_configs_autotune,
179
- key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
180
- )
181
- # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
182
- # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
183
- @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
184
- @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
185
- @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
186
- @triton.jit
187
- def _layer_norm_fwd_1pass_kernel(
188
- X, # pointer to the input
189
- Y, # pointer to the output
190
- W, # pointer to the weights
191
- B, # pointer to the biases
192
- RESIDUAL, # pointer to the residual
193
- X1,
194
- W1,
195
- B1,
196
- Y1,
197
- RESIDUAL_OUT, # pointer to the residual
198
- ROWSCALE,
199
- SEEDS, # Dropout seeds for each row
200
- DROPOUT_MASK,
201
- Mean, # pointer to the mean
202
- Rstd, # pointer to the 1/std
203
- stride_x_row, # how much to increase the pointer when moving by 1 row
204
- stride_y_row,
205
- stride_res_row,
206
- stride_res_out_row,
207
- stride_x1_row,
208
- stride_y1_row,
209
- M, # number of rows in X
210
- N, # number of columns in X
211
- eps, # epsilon to avoid division by zero
212
- dropout_p, # Dropout probability
213
- IS_RMS_NORM: tl.constexpr,
214
- BLOCK_N: tl.constexpr,
215
- HAS_RESIDUAL: tl.constexpr,
216
- STORE_RESIDUAL_OUT: tl.constexpr,
217
- HAS_BIAS: tl.constexpr,
218
- HAS_DROPOUT: tl.constexpr,
219
- STORE_DROPOUT_MASK: tl.constexpr,
220
- HAS_ROWSCALE: tl.constexpr,
221
- HAS_X1: tl.constexpr,
222
- HAS_W1: tl.constexpr,
223
- HAS_B1: tl.constexpr,
224
- ):
225
- # Map the program id to the row of X and Y it should compute.
226
- row = tl.program_id(0)
227
- X += row * stride_x_row
228
- Y += row * stride_y_row
229
- if HAS_RESIDUAL:
230
- RESIDUAL += row * stride_res_row
231
- if STORE_RESIDUAL_OUT:
232
- RESIDUAL_OUT += row * stride_res_out_row
233
- if HAS_X1:
234
- X1 += row * stride_x1_row
235
- if HAS_W1:
236
- Y1 += row * stride_y1_row
237
- # Compute mean and variance
238
- cols = tl.arange(0, BLOCK_N)
239
- x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
240
- if HAS_ROWSCALE:
241
- rowscale = tl.load(ROWSCALE + row).to(tl.float32)
242
- x *= rowscale
243
- if HAS_DROPOUT:
244
- # Compute dropout mask
245
- # 7 rounds is good enough, and reduces register pressure
246
- keep_mask = (
247
- tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
248
- )
249
- x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
250
- if STORE_DROPOUT_MASK:
251
- tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
252
- if HAS_X1:
253
- x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
254
- if HAS_ROWSCALE:
255
- rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
256
- x1 *= rowscale
257
- if HAS_DROPOUT:
258
- # Compute dropout mask
259
- # 7 rounds is good enough, and reduces register pressure
260
- keep_mask = (
261
- tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
262
- > dropout_p
263
- )
264
- x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
265
- if STORE_DROPOUT_MASK:
266
- tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
267
- x += x1
268
- if HAS_RESIDUAL:
269
- residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
270
- x += residual
271
- if STORE_RESIDUAL_OUT:
272
- tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
273
- if not IS_RMS_NORM:
274
- mean = tl.sum(x, axis=0) / N
275
- tl.store(Mean + row, mean)
276
- xbar = tl.where(cols < N, x - mean, 0.0)
277
- var = tl.sum(xbar * xbar, axis=0) / N
278
- else:
279
- xbar = tl.where(cols < N, x, 0.0)
280
- var = tl.sum(xbar * xbar, axis=0) / N
281
- rstd = 1 / tl.sqrt(var + eps)
282
- tl.store(Rstd + row, rstd)
283
- # Normalize and apply linear transformation
284
- mask = cols < N
285
- w = tl.load(W + cols, mask=mask).to(tl.float32)
286
- if HAS_BIAS:
287
- b = tl.load(B + cols, mask=mask).to(tl.float32)
288
- x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
289
- y = x_hat * w + b if HAS_BIAS else x_hat * w
290
- # Write output
291
- tl.store(Y + cols, y, mask=mask)
292
- if HAS_W1:
293
- w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
294
- if HAS_B1:
295
- b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
296
- y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
297
- tl.store(Y1 + cols, y1, mask=mask)
298
-
299
-
300
- def _layer_norm_fwd(
301
- x,
302
- weight,
303
- bias,
304
- eps,
305
- residual=None,
306
- x1=None,
307
- weight1=None,
308
- bias1=None,
309
- dropout_p=0.0,
310
- rowscale=None,
311
- out_dtype=None,
312
- residual_dtype=None,
313
- is_rms_norm=False,
314
- return_dropout_mask=False,
315
- ):
316
- if residual is not None:
317
- residual_dtype = residual.dtype
318
- M, N = x.shape
319
- assert x.stride(-1) == 1
320
- if residual is not None:
321
- assert residual.stride(-1) == 1
322
- assert residual.shape == (M, N)
323
- assert weight.shape == (N,)
324
- assert weight.stride(-1) == 1
325
- if bias is not None:
326
- assert bias.stride(-1) == 1
327
- assert bias.shape == (N,)
328
- if x1 is not None:
329
- assert x1.shape == x.shape
330
- assert rowscale is None
331
- assert x1.stride(-1) == 1
332
- if weight1 is not None:
333
- assert weight1.shape == (N,)
334
- assert weight1.stride(-1) == 1
335
- if bias1 is not None:
336
- assert bias1.shape == (N,)
337
- assert bias1.stride(-1) == 1
338
- if rowscale is not None:
339
- assert rowscale.is_contiguous()
340
- assert rowscale.shape == (M,)
341
- # allocate output
342
- y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
343
- assert y.stride(-1) == 1
344
- if weight1 is not None:
345
- y1 = torch.empty_like(y)
346
- assert y1.stride(-1) == 1
347
- else:
348
- y1 = None
349
- if (
350
- residual is not None
351
- or (residual_dtype is not None and residual_dtype != x.dtype)
352
- or dropout_p > 0.0
353
- or rowscale is not None
354
- or x1 is not None
355
- ):
356
- residual_out = torch.empty(
357
- M,
358
- N,
359
- device=x.device,
360
- dtype=residual_dtype if residual_dtype is not None else x.dtype,
361
- )
362
- assert residual_out.stride(-1) == 1
363
- else:
364
- residual_out = None
365
- mean = (
366
- torch.empty((M,), dtype=torch.float32, device=x.device)
367
- if not is_rms_norm
368
- else None
369
- )
370
- rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
371
- if dropout_p > 0.0:
372
- seeds = torch.randint(
373
- 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
374
- )
375
- else:
376
- seeds = None
377
- if return_dropout_mask and dropout_p > 0.0:
378
- dropout_mask = torch.empty(
379
- M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool
380
- )
381
- else:
382
- dropout_mask = None
383
- # Less than 64KB per feature: enqueue fused kernel
384
- MAX_FUSED_SIZE = 65536 // x.element_size()
385
- BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
386
- if N > BLOCK_N:
387
- raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
388
- with torch.cuda.device(x.device.index):
389
- _layer_norm_fwd_1pass_kernel[(M,)](
390
- x,
391
- y,
392
- weight,
393
- bias,
394
- residual,
395
- x1,
396
- weight1,
397
- bias1,
398
- y1,
399
- residual_out,
400
- rowscale,
401
- seeds,
402
- dropout_mask,
403
- mean,
404
- rstd,
405
- x.stride(0),
406
- y.stride(0),
407
- residual.stride(0) if residual is not None else 0,
408
- residual_out.stride(0) if residual_out is not None else 0,
409
- x1.stride(0) if x1 is not None else 0,
410
- y1.stride(0) if y1 is not None else 0,
411
- M,
412
- N,
413
- eps,
414
- dropout_p,
415
- is_rms_norm,
416
- BLOCK_N,
417
- residual is not None,
418
- residual_out is not None,
419
- bias is not None,
420
- dropout_p > 0.0,
421
- dropout_mask is not None,
422
- rowscale is not None,
423
- )
424
- # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
425
- if dropout_mask is not None and x1 is not None:
426
- dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
427
- else:
428
- dropout_mask1 = None
429
- return (
430
- y,
431
- y1,
432
- mean,
433
- rstd,
434
- residual_out if residual_out is not None else x,
435
- seeds,
436
- dropout_mask,
437
- dropout_mask1,
438
- )
439
-
440
-
441
- @triton.autotune(
442
- configs=pruned_configs_autotune,
443
- key=[
444
- "N",
445
- "HAS_DRESIDUAL",
446
- "STORE_DRESIDUAL",
447
- "IS_RMS_NORM",
448
- "HAS_BIAS",
449
- "HAS_DROPOUT",
450
- ],
451
- )
452
- # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
453
- # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
454
- # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
455
- @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
456
- @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
457
- @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
458
- @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
459
- @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
460
- @triton.jit
461
- def _layer_norm_bwd_kernel(
462
- X, # pointer to the input
463
- W, # pointer to the weights
464
- B, # pointer to the biases
465
- Y, # pointer to the output to be recomputed
466
- DY, # pointer to the output gradient
467
- DX, # pointer to the input gradient
468
- DW, # pointer to the partial sum of weights gradient
469
- DB, # pointer to the partial sum of biases gradient
470
- DRESIDUAL,
471
- W1,
472
- DY1,
473
- DX1,
474
- DW1,
475
- DB1,
476
- DRESIDUAL_IN,
477
- ROWSCALE,
478
- SEEDS,
479
- Mean, # pointer to the mean
480
- Rstd, # pointer to the 1/std
481
- stride_x_row, # how much to increase the pointer when moving by 1 row
482
- stride_y_row,
483
- stride_dy_row,
484
- stride_dx_row,
485
- stride_dres_row,
486
- stride_dy1_row,
487
- stride_dx1_row,
488
- stride_dres_in_row,
489
- M, # number of rows in X
490
- N, # number of columns in X
491
- eps, # epsilon to avoid division by zero
492
- dropout_p,
493
- rows_per_program,
494
- IS_RMS_NORM: tl.constexpr,
495
- BLOCK_N: tl.constexpr,
496
- HAS_DRESIDUAL: tl.constexpr,
497
- STORE_DRESIDUAL: tl.constexpr,
498
- HAS_BIAS: tl.constexpr,
499
- HAS_DROPOUT: tl.constexpr,
500
- HAS_ROWSCALE: tl.constexpr,
501
- HAS_DY1: tl.constexpr,
502
- HAS_DX1: tl.constexpr,
503
- HAS_B1: tl.constexpr,
504
- RECOMPUTE_OUTPUT: tl.constexpr,
505
- ):
506
- # Map the program id to the elements of X, DX, and DY it should compute.
507
- row_block_id = tl.program_id(0)
508
- row_start = row_block_id * rows_per_program
509
- # Do not early exit if row_start >= M, because we need to write DW and DB
510
- cols = tl.arange(0, BLOCK_N)
511
- mask = cols < N
512
- X += row_start * stride_x_row
513
- if HAS_DRESIDUAL:
514
- DRESIDUAL += row_start * stride_dres_row
515
- if STORE_DRESIDUAL:
516
- DRESIDUAL_IN += row_start * stride_dres_in_row
517
- DY += row_start * stride_dy_row
518
- DX += row_start * stride_dx_row
519
- if HAS_DY1:
520
- DY1 += row_start * stride_dy1_row
521
- if HAS_DX1:
522
- DX1 += row_start * stride_dx1_row
523
- if RECOMPUTE_OUTPUT:
524
- Y += row_start * stride_y_row
525
- w = tl.load(W + cols, mask=mask).to(tl.float32)
526
- if RECOMPUTE_OUTPUT and HAS_BIAS:
527
- b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
528
- if HAS_DY1:
529
- w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
530
- dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
531
- if HAS_BIAS:
532
- db = tl.zeros((BLOCK_N,), dtype=tl.float32)
533
- if HAS_DY1:
534
- dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
535
- if HAS_B1:
536
- db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
537
- row_end = min((row_block_id + 1) * rows_per_program, M)
538
- for row in range(row_start, row_end):
539
- # Load data to SRAM
540
- x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
541
- dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
542
- if HAS_DY1:
543
- dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
544
- if not IS_RMS_NORM:
545
- mean = tl.load(Mean + row)
546
- rstd = tl.load(Rstd + row)
547
- # Compute dx
548
- xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
549
- xhat = tl.where(mask, xhat, 0.0)
550
- if RECOMPUTE_OUTPUT:
551
- y = xhat * w + b if HAS_BIAS else xhat * w
552
- tl.store(Y + cols, y, mask=mask)
553
- wdy = w * dy
554
- dw += dy * xhat
555
- if HAS_BIAS:
556
- db += dy
557
- if HAS_DY1:
558
- wdy += w1 * dy1
559
- dw1 += dy1 * xhat
560
- if HAS_B1:
561
- db1 += dy1
562
- if not IS_RMS_NORM:
563
- c1 = tl.sum(xhat * wdy, axis=0) / N
564
- c2 = tl.sum(wdy, axis=0) / N
565
- dx = (wdy - (xhat * c1 + c2)) * rstd
566
- else:
567
- c1 = tl.sum(xhat * wdy, axis=0) / N
568
- dx = (wdy - xhat * c1) * rstd
569
- if HAS_DRESIDUAL:
570
- dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
571
- dx += dres
572
- # Write dx
573
- if STORE_DRESIDUAL:
574
- tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
575
- if HAS_DX1:
576
- if HAS_DROPOUT:
577
- keep_mask = (
578
- tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
579
- > dropout_p
580
- )
581
- dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
582
- else:
583
- dx1 = dx
584
- tl.store(DX1 + cols, dx1, mask=mask)
585
- if HAS_DROPOUT:
586
- keep_mask = (
587
- tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7)
588
- > dropout_p
589
- )
590
- dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
591
- if HAS_ROWSCALE:
592
- rowscale = tl.load(ROWSCALE + row).to(tl.float32)
593
- dx *= rowscale
594
- tl.store(DX + cols, dx, mask=mask)
595
-
596
- X += stride_x_row
597
- if HAS_DRESIDUAL:
598
- DRESIDUAL += stride_dres_row
599
- if STORE_DRESIDUAL:
600
- DRESIDUAL_IN += stride_dres_in_row
601
- if RECOMPUTE_OUTPUT:
602
- Y += stride_y_row
603
- DY += stride_dy_row
604
- DX += stride_dx_row
605
- if HAS_DY1:
606
- DY1 += stride_dy1_row
607
- if HAS_DX1:
608
- DX1 += stride_dx1_row
609
- tl.store(DW + row_block_id * N + cols, dw, mask=mask)
610
- if HAS_BIAS:
611
- tl.store(DB + row_block_id * N + cols, db, mask=mask)
612
- if HAS_DY1:
613
- tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
614
- if HAS_B1:
615
- tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
616
-
617
-
618
- def _layer_norm_bwd(
619
- dy,
620
- x,
621
- weight,
622
- bias,
623
- eps,
624
- mean,
625
- rstd,
626
- dresidual=None,
627
- dy1=None,
628
- weight1=None,
629
- bias1=None,
630
- seeds=None,
631
- dropout_p=0.0,
632
- rowscale=None,
633
- has_residual=False,
634
- has_x1=False,
635
- is_rms_norm=False,
636
- x_dtype=None,
637
- recompute_output=False,
638
- ):
639
- M, N = x.shape
640
- assert x.stride(-1) == 1
641
- assert dy.stride(-1) == 1
642
- assert dy.shape == (M, N)
643
- if dresidual is not None:
644
- assert dresidual.stride(-1) == 1
645
- assert dresidual.shape == (M, N)
646
- assert weight.shape == (N,)
647
- assert weight.stride(-1) == 1
648
- if bias is not None:
649
- assert bias.stride(-1) == 1
650
- assert bias.shape == (N,)
651
- if dy1 is not None:
652
- assert weight1 is not None
653
- assert dy1.shape == dy.shape
654
- assert dy1.stride(-1) == 1
655
- if weight1 is not None:
656
- assert weight1.shape == (N,)
657
- assert weight1.stride(-1) == 1
658
- if bias1 is not None:
659
- assert bias1.shape == (N,)
660
- assert bias1.stride(-1) == 1
661
- if seeds is not None:
662
- assert seeds.is_contiguous()
663
- assert seeds.shape == (M if not has_x1 else M * 2,)
664
- if rowscale is not None:
665
- assert rowscale.is_contiguous()
666
- assert rowscale.shape == (M,)
667
- # allocate output
668
- dx = (
669
- torch.empty_like(x)
670
- if x_dtype is None
671
- else torch.empty(M, N, dtype=x_dtype, device=x.device)
672
- )
673
- dresidual_in = (
674
- torch.empty_like(x)
675
- if has_residual
676
- and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
677
- else None
678
- )
679
- dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
680
- y = (
681
- torch.empty(M, N, dtype=dy.dtype, device=dy.device)
682
- if recompute_output
683
- else None
684
- )
685
- if recompute_output:
686
- assert (
687
- weight1 is None
688
- ), "recompute_output is not supported with parallel LayerNorm"
689
-
690
- # Less than 64KB per feature: enqueue fused kernel
691
- MAX_FUSED_SIZE = 65536 // x.element_size()
692
- BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
693
- if N > BLOCK_N:
694
- raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
695
- sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
696
- _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
697
- _db = (
698
- torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
699
- if bias is not None
700
- else None
701
- )
702
- _dw1 = torch.empty_like(_dw) if weight1 is not None else None
703
- _db1 = torch.empty_like(_db) if bias1 is not None else None
704
- rows_per_program = math.ceil(M / sm_count)
705
- grid = (sm_count,)
706
- with torch.cuda.device(x.device.index):
707
- _layer_norm_bwd_kernel[grid](
708
- x,
709
- weight,
710
- bias,
711
- y,
712
- dy,
713
- dx,
714
- _dw,
715
- _db,
716
- dresidual,
717
- weight1,
718
- dy1,
719
- dx1,
720
- _dw1,
721
- _db1,
722
- dresidual_in,
723
- rowscale,
724
- seeds,
725
- mean,
726
- rstd,
727
- x.stride(0),
728
- 0 if not recompute_output else y.stride(0),
729
- dy.stride(0),
730
- dx.stride(0),
731
- dresidual.stride(0) if dresidual is not None else 0,
732
- dy1.stride(0) if dy1 is not None else 0,
733
- dx1.stride(0) if dx1 is not None else 0,
734
- dresidual_in.stride(0) if dresidual_in is not None else 0,
735
- M,
736
- N,
737
- eps,
738
- dropout_p,
739
- rows_per_program,
740
- is_rms_norm,
741
- BLOCK_N,
742
- dresidual is not None,
743
- dresidual_in is not None,
744
- bias is not None,
745
- dropout_p > 0.0,
746
- )
747
- dw = _dw.sum(0).to(weight.dtype)
748
- db = _db.sum(0).to(bias.dtype) if bias is not None else None
749
- dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
750
- db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
751
- # Don't need to compute dresidual_in separately in this case
752
- if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
753
- dresidual_in = dx
754
- if has_x1 and dropout_p == 0.0:
755
- dx1 = dx
756
- return (
757
- (dx, dw, db, dresidual_in, dx1, dw1, db1)
758
- if not recompute_output
759
- else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
760
- )
761
-
762
-
763
- class LayerNormFn(torch.autograd.Function):
764
- @staticmethod
765
- def forward(
766
- ctx,
767
- x,
768
- weight,
769
- bias,
770
- residual=None,
771
- x1=None,
772
- weight1=None,
773
- bias1=None,
774
- eps=1e-6,
775
- dropout_p=0.0,
776
- rowscale=None,
777
- prenorm=False,
778
- residual_in_fp32=False,
779
- is_rms_norm=False,
780
- return_dropout_mask=False,
781
- ):
782
- x_shape_og = x.shape
783
- # reshape input data into 2D tensor
784
- x = x.reshape(-1, x.shape[-1])
785
- if x.stride(-1) != 1:
786
- x = x.contiguous()
787
- if residual is not None:
788
- assert residual.shape == x_shape_og
789
- residual = residual.reshape(-1, residual.shape[-1])
790
- if residual.stride(-1) != 1:
791
- residual = residual.contiguous()
792
- if x1 is not None:
793
- assert x1.shape == x_shape_og
794
- assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
795
- x1 = x1.reshape(-1, x1.shape[-1])
796
- if x1.stride(-1) != 1:
797
- x1 = x1.contiguous()
798
- weight = weight.contiguous()
799
- if bias is not None:
800
- bias = bias.contiguous()
801
- if weight1 is not None:
802
- weight1 = weight1.contiguous()
803
- if bias1 is not None:
804
- bias1 = bias1.contiguous()
805
- if rowscale is not None:
806
- rowscale = rowscale.reshape(-1).contiguous()
807
- residual_dtype = (
808
- residual.dtype
809
- if residual is not None
810
- else (torch.float32 if residual_in_fp32 else None)
811
- )
812
- y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = (
813
- _layer_norm_fwd(
814
- x,
815
- weight,
816
- bias,
817
- eps,
818
- residual,
819
- x1,
820
- weight1,
821
- bias1,
822
- dropout_p=dropout_p,
823
- rowscale=rowscale,
824
- residual_dtype=residual_dtype,
825
- is_rms_norm=is_rms_norm,
826
- return_dropout_mask=return_dropout_mask,
827
- )
828
- )
829
- ctx.save_for_backward(
830
- residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
831
- )
832
- ctx.x_shape_og = x_shape_og
833
- ctx.eps = eps
834
- ctx.dropout_p = dropout_p
835
- ctx.is_rms_norm = is_rms_norm
836
- ctx.has_residual = residual is not None
837
- ctx.has_x1 = x1 is not None
838
- ctx.prenorm = prenorm
839
- ctx.x_dtype = x.dtype
840
- y = y.reshape(x_shape_og)
841
- y1 = y1.reshape(x_shape_og) if y1 is not None else None
842
- residual_out = (
843
- residual_out.reshape(x_shape_og) if residual_out is not None else None
844
- )
845
- dropout_mask = (
846
- dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
847
- )
848
- dropout_mask1 = (
849
- dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
850
- )
851
- if not return_dropout_mask:
852
- if weight1 is None:
853
- return y if not prenorm else (y, residual_out)
854
- else:
855
- return (y, y1) if not prenorm else (y, y1, residual_out)
856
- else:
857
- if weight1 is None:
858
- return (
859
- (y, dropout_mask, dropout_mask1)
860
- if not prenorm
861
- else (y, residual_out, dropout_mask, dropout_mask1)
862
- )
863
- else:
864
- return (
865
- (y, y1, dropout_mask, dropout_mask1)
866
- if not prenorm
867
- else (y, y1, residual_out, dropout_mask, dropout_mask1)
868
- )
869
-
870
- @staticmethod
871
- def backward(ctx, dy, *args):
872
- x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
873
- dy = dy.reshape(-1, dy.shape[-1])
874
- if dy.stride(-1) != 1:
875
- dy = dy.contiguous()
876
- assert dy.shape == x.shape
877
- if weight1 is not None:
878
- dy1, args = args[0], args[1:]
879
- dy1 = dy1.reshape(-1, dy1.shape[-1])
880
- if dy1.stride(-1) != 1:
881
- dy1 = dy1.contiguous()
882
- assert dy1.shape == x.shape
883
- else:
884
- dy1 = None
885
- if ctx.prenorm:
886
- dresidual = args[0]
887
- dresidual = dresidual.reshape(-1, dresidual.shape[-1])
888
- if dresidual.stride(-1) != 1:
889
- dresidual = dresidual.contiguous()
890
- assert dresidual.shape == x.shape
891
- else:
892
- dresidual = None
893
- dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
894
- dy,
895
- x,
896
- weight,
897
- bias,
898
- ctx.eps,
899
- mean,
900
- rstd,
901
- dresidual,
902
- dy1,
903
- weight1,
904
- bias1,
905
- seeds,
906
- ctx.dropout_p,
907
- rowscale,
908
- ctx.has_residual,
909
- ctx.has_x1,
910
- ctx.is_rms_norm,
911
- x_dtype=ctx.x_dtype,
912
- )
913
- return (
914
- dx.reshape(ctx.x_shape_og),
915
- dw,
916
- db,
917
- dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
918
- dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
919
- dw1,
920
- db1,
921
- None,
922
- None,
923
- None,
924
- None,
925
- None,
926
- None,
927
- None,
928
- )
929
-
930
-
931
- def layer_norm_fn(
932
- x,
933
- weight,
934
- bias,
935
- residual=None,
936
- x1=None,
937
- weight1=None,
938
- bias1=None,
939
- eps=1e-6,
940
- dropout_p=0.0,
941
- rowscale=None,
942
- prenorm=False,
943
- residual_in_fp32=False,
944
- is_rms_norm=False,
945
- return_dropout_mask=False,
946
- ):
947
- return LayerNormFn.apply(
948
- x,
949
- weight,
950
- bias,
951
- residual,
952
- x1,
953
- weight1,
954
- bias1,
955
- eps,
956
- dropout_p,
957
- rowscale,
958
- prenorm,
959
- residual_in_fp32,
960
- is_rms_norm,
961
- return_dropout_mask,
962
- )
963
-
964
-
965
- def rms_norm_fn(
966
- x,
967
- weight,
968
- bias,
969
- residual=None,
970
- x1=None,
971
- weight1=None,
972
- bias1=None,
973
- eps=1e-6,
974
- dropout_p=0.0,
975
- rowscale=None,
976
- prenorm=False,
977
- residual_in_fp32=False,
978
- return_dropout_mask=False,
979
- ):
980
- return LayerNormFn.apply(
981
- x,
982
- weight,
983
- bias,
984
- residual,
985
- x1,
986
- weight1,
987
- bias1,
988
- eps,
989
- dropout_p,
990
- rowscale,
991
- prenorm,
992
- residual_in_fp32,
993
- True,
994
- return_dropout_mask,
995
- )
996
-
997
-
998
- class RMSNorm(torch.nn.Module):
999
-
1000
- def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
1001
- factory_kwargs = {"device": device, "dtype": dtype}
1002
- super().__init__()
1003
- self.eps = eps
1004
- if dropout_p > 0.0:
1005
- self.drop = torch.nn.Dropout(dropout_p)
1006
- else:
1007
- self.drop = None
1008
- self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
1009
- self.register_parameter("bias", None)
1010
- self.reset_parameters()
1011
-
1012
- def reset_parameters(self):
1013
- torch.nn.init.ones_(self.weight)
1014
-
1015
- def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
1016
- return rms_norm_fn(
1017
- x,
1018
- self.weight,
1019
- self.bias,
1020
- residual=residual,
1021
- eps=self.eps,
1022
- dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
1023
- prenorm=prenorm,
1024
- residual_in_fp32=residual_in_fp32,
1025
- )
1026
-
1027
-
1028
- class LayerNormLinearFn(torch.autograd.Function):
1029
- @staticmethod
1030
- @custom_fwd
1031
- def forward(
1032
- ctx,
1033
- x,
1034
- norm_weight,
1035
- norm_bias,
1036
- linear_weight,
1037
- linear_bias,
1038
- residual=None,
1039
- eps=1e-6,
1040
- prenorm=False,
1041
- residual_in_fp32=False,
1042
- is_rms_norm=False,
1043
- ):
1044
- x_shape_og = x.shape
1045
- # reshape input data into 2D tensor
1046
- x = x.reshape(-1, x.shape[-1])
1047
- if x.stride(-1) != 1:
1048
- x = x.contiguous()
1049
- if residual is not None:
1050
- assert residual.shape == x_shape_og
1051
- residual = residual.reshape(-1, residual.shape[-1])
1052
- if residual.stride(-1) != 1:
1053
- residual = residual.contiguous()
1054
- norm_weight = norm_weight.contiguous()
1055
- if norm_bias is not None:
1056
- norm_bias = norm_bias.contiguous()
1057
- residual_dtype = (
1058
- residual.dtype
1059
- if residual is not None
1060
- else (torch.float32 if residual_in_fp32 else None)
1061
- )
1062
- y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
1063
- x,
1064
- norm_weight,
1065
- norm_bias,
1066
- eps,
1067
- residual,
1068
- out_dtype=(
1069
- None
1070
- if not torch.is_autocast_enabled()
1071
- else torch.get_autocast_gpu_dtype()
1072
- ),
1073
- residual_dtype=residual_dtype,
1074
- is_rms_norm=is_rms_norm,
1075
- )
1076
- y = y.reshape(x_shape_og)
1077
- dtype = (
1078
- torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
1079
- )
1080
- linear_weight = linear_weight.to(dtype)
1081
- linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
1082
- out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
1083
- # We don't store y, will be recomputed in the backward pass to save memory
1084
- ctx.save_for_backward(
1085
- residual_out, norm_weight, norm_bias, linear_weight, mean, rstd
1086
- )
1087
- ctx.x_shape_og = x_shape_og
1088
- ctx.eps = eps
1089
- ctx.is_rms_norm = is_rms_norm
1090
- ctx.has_residual = residual is not None
1091
- ctx.prenorm = prenorm
1092
- ctx.x_dtype = x.dtype
1093
- ctx.linear_bias_is_none = linear_bias is None
1094
- return out if not prenorm else (out, residual_out.reshape(x_shape_og))
1095
-
1096
- @staticmethod
1097
- @custom_bwd
1098
- def backward(ctx, dout, *args):
1099
- x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
1100
- dout = dout.reshape(-1, dout.shape[-1])
1101
- dy = F.linear(dout, linear_weight.t())
1102
- dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
1103
- if dy.stride(-1) != 1:
1104
- dy = dy.contiguous()
1105
- assert dy.shape == x.shape
1106
- if ctx.prenorm:
1107
- dresidual = args[0]
1108
- dresidual = dresidual.reshape(-1, dresidual.shape[-1])
1109
- if dresidual.stride(-1) != 1:
1110
- dresidual = dresidual.contiguous()
1111
- assert dresidual.shape == x.shape
1112
- else:
1113
- dresidual = None
1114
- dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
1115
- dy,
1116
- x,
1117
- norm_weight,
1118
- norm_bias,
1119
- ctx.eps,
1120
- mean,
1121
- rstd,
1122
- dresidual=dresidual,
1123
- has_residual=ctx.has_residual,
1124
- is_rms_norm=ctx.is_rms_norm,
1125
- x_dtype=ctx.x_dtype,
1126
- recompute_output=True,
1127
- )
1128
- dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
1129
- return (
1130
- dx.reshape(ctx.x_shape_og),
1131
- dnorm_weight,
1132
- dnorm_bias,
1133
- dlinear_weight,
1134
- dlinear_bias,
1135
- dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
1136
- None,
1137
- None,
1138
- None,
1139
- None,
1140
- )
1141
-
1142
-
1143
- def layer_norm_linear_fn(
1144
- x,
1145
- norm_weight,
1146
- norm_bias,
1147
- linear_weight,
1148
- linear_bias,
1149
- residual=None,
1150
- eps=1e-6,
1151
- prenorm=False,
1152
- residual_in_fp32=False,
1153
- is_rms_norm=False,
1154
- ):
1155
- return LayerNormLinearFn.apply(
1156
- x,
1157
- norm_weight,
1158
- norm_bias,
1159
- linear_weight,
1160
- linear_bias,
1161
- residual,
1162
- eps,
1163
- prenorm,
1164
- residual_in_fp32,
1165
- is_rms_norm,
1166
- )
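
The file above is the fused dropout + residual-add + LayerNorm/RMSNorm kernel that this commit removes from the prebuilt tree; its public entry points are layer_norm_fn, rms_norm_fn, the RMSNorm module, and LayerNormLinearFn. A minimal usage sketch follows. It is illustrative only: it assumes a CUDA device and that the same functions remain importable from the upstream mamba_ssm package at the path mirrored here; the tensor sizes are arbitrary.

import torch
# Assumed import path, matching the deleted file's location in the package tree.
from mamba_ssm.ops.triton.layer_norm import layer_norm_fn, layer_norm_ref

batch, seqlen, dim = 2, 128, 1024
x = torch.randn(batch, seqlen, dim, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x)
weight = torch.ones(dim, device="cuda", dtype=torch.float16)
bias = torch.zeros(dim, device="cuda", dtype=torch.float16)

# Fused residual-add + LayerNorm; with prenorm=True the updated residual
# stream (x + residual) is returned alongside the normalized output.
out, new_residual = layer_norm_fn(
    x, weight, bias, residual=residual, eps=1e-6, prenorm=True
)

# The pure-PyTorch reference defined in the same file can sanity-check the kernel.
out_ref, _ = layer_norm_ref(
    x, weight, bias, residual=residual, eps=1e-6, prenorm=True, upcast=True
)
print((out.float() - out_ref.float()).abs().max())

Passing is_rms_norm=True (or calling rms_norm_fn) switches the same kernel to RMSNorm, and a nonzero dropout_p enables the fused dropout described in the header comment of the deleted file.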
 
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/selective_state_update.py DELETED
@@ -1,389 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
-
3
- """We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
4
- """
5
-
6
- import math
7
- import torch
8
- import torch.nn.functional as F
9
-
10
- import triton
11
- import triton.language as tl
12
-
13
- from einops import rearrange, repeat
14
-
15
- from .softplus import softplus
16
-
17
-
18
- @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
19
- @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
20
- @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
21
- @triton.heuristics(
22
- {
23
- "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
24
- is not None
25
- }
26
- )
27
- @triton.heuristics(
28
- {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
29
- )
30
- @triton.jit
31
- def _selective_scan_update_kernel(
32
- # Pointers to matrices
33
- state_ptr,
34
- x_ptr,
35
- dt_ptr,
36
- dt_bias_ptr,
37
- A_ptr,
38
- B_ptr,
39
- C_ptr,
40
- D_ptr,
41
- z_ptr,
42
- out_ptr,
43
- state_batch_indices_ptr,
44
- # Matrix dimensions
45
- batch,
46
- nheads,
47
- dim,
48
- dstate,
49
- nheads_ngroups_ratio,
50
- # Strides
51
- stride_state_batch,
52
- stride_state_head,
53
- stride_state_dim,
54
- stride_state_dstate,
55
- stride_x_batch,
56
- stride_x_head,
57
- stride_x_dim,
58
- stride_dt_batch,
59
- stride_dt_head,
60
- stride_dt_dim,
61
- stride_dt_bias_head,
62
- stride_dt_bias_dim,
63
- stride_A_head,
64
- stride_A_dim,
65
- stride_A_dstate,
66
- stride_B_batch,
67
- stride_B_group,
68
- stride_B_dstate,
69
- stride_C_batch,
70
- stride_C_group,
71
- stride_C_dstate,
72
- stride_D_head,
73
- stride_D_dim,
74
- stride_z_batch,
75
- stride_z_head,
76
- stride_z_dim,
77
- stride_out_batch,
78
- stride_out_head,
79
- stride_out_dim,
80
- # Meta-parameters
81
- DT_SOFTPLUS: tl.constexpr,
82
- TIE_HDIM: tl.constexpr,
83
- BLOCK_SIZE_M: tl.constexpr,
84
- HAS_DT_BIAS: tl.constexpr,
85
- HAS_D: tl.constexpr,
86
- HAS_Z: tl.constexpr,
87
- HAS_STATE_BATCH_INDICES: tl.constexpr,
88
- BLOCK_SIZE_DSTATE: tl.constexpr,
89
- ):
90
- pid_m = tl.program_id(axis=0)
91
- pid_b = tl.program_id(axis=1)
92
- pid_h = tl.program_id(axis=2)
93
-
94
- if HAS_STATE_BATCH_INDICES:
95
- state_batch_indices_ptr += pid_b
96
- state_batch_idx = tl.load(state_batch_indices_ptr)
97
- state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
98
- else:
99
- state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
100
-
101
- x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
102
- dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
103
- if HAS_DT_BIAS:
104
- dt_bias_ptr += pid_h * stride_dt_bias_head
105
- A_ptr += pid_h * stride_A_head
106
- B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
107
- C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
108
- if HAS_Z:
109
- z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
110
- out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
111
-
112
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
113
- offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
114
- state_ptrs = state_ptr + (
115
- offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
116
- )
117
- x_ptrs = x_ptr + offs_m * stride_x_dim
118
- dt_ptrs = dt_ptr + offs_m * stride_dt_dim
119
- if HAS_DT_BIAS:
120
- dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
121
- if HAS_D:
122
- D_ptr += pid_h * stride_D_head
123
- A_ptrs = A_ptr + (
124
- offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
125
- )
126
- B_ptrs = B_ptr + offs_n * stride_B_dstate
127
- C_ptrs = C_ptr + offs_n * stride_C_dstate
128
- if HAS_D:
129
- D_ptrs = D_ptr + offs_m * stride_D_dim
130
- if HAS_Z:
131
- z_ptrs = z_ptr + offs_m * stride_z_dim
132
- out_ptrs = out_ptr + offs_m * stride_out_dim
133
-
134
- state = tl.load(
135
- state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
136
- )
137
- x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
138
- if not TIE_HDIM:
139
- dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
140
- if HAS_DT_BIAS:
141
- dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
142
- if DT_SOFTPLUS:
143
- dt = tl.where(dt <= 20.0, softplus(dt), dt)
144
- A = tl.load(
145
- A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
146
- ).to(tl.float32)
147
- dA = tl.exp(A * dt[:, None])
148
- else:
149
- dt = tl.load(dt_ptr).to(tl.float32)
150
- if HAS_DT_BIAS:
151
- dt += tl.load(dt_bias_ptr).to(tl.float32)
152
- if DT_SOFTPLUS:
153
- dt = tl.where(dt <= 20.0, softplus(dt), dt)
154
- A = tl.load(A_ptr).to(tl.float32)
155
- dA = tl.exp(A * dt) # scalar, not a matrix
156
-
157
- B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
158
- C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
159
- if HAS_D:
160
- D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
161
- if HAS_Z:
162
- z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
163
-
164
- if not TIE_HDIM:
165
- dB = B[None, :] * dt[:, None]
166
- else:
167
- dB = B * dt # vector of size (dstate,)
168
- state = state * dA + dB * x[:, None]
169
- tl.store(
170
- state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
171
- )
172
- out = tl.sum(state * C[None, :], axis=1)
173
- if HAS_D:
174
- out += x * D
175
- if HAS_Z:
176
- out *= z * tl.sigmoid(z)
177
- tl.store(out_ptrs, out, mask=offs_m < dim)
178
-
179
-
180
- def selective_state_update(
181
- state,
182
- x,
183
- dt,
184
- A,
185
- B,
186
- C,
187
- D=None,
188
- z=None,
189
- dt_bias=None,
190
- dt_softplus=False,
191
- state_batch_indices=None,
192
- ):
193
- """
194
- Argument:
195
- state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
196
- x: (batch, dim) or (batch, nheads, dim)
197
- dt: (batch, dim) or (batch, nheads, dim)
198
- A: (dim, dstate) or (nheads, dim, dstate)
199
- B: (batch, dstate) or (batch, ngroups, dstate)
200
- C: (batch, dstate) or (batch, ngroups, dstate)
201
- D: (dim,) or (nheads, dim)
202
- z: (batch, dim) or (batch, nheads, dim)
203
- dt_bias: (dim,) or (nheads, dim)
204
- Return:
205
- out: (batch, dim) or (batch, nheads, dim)
206
- """
207
- has_heads = state.dim() > 3
208
- if state.dim() == 3:
209
- state = state.unsqueeze(1)
210
- if x.dim() == 2:
211
- x = x.unsqueeze(1)
212
- if dt.dim() == 2:
213
- dt = dt.unsqueeze(1)
214
- if A.dim() == 2:
215
- A = A.unsqueeze(0)
216
- if B.dim() == 2:
217
- B = B.unsqueeze(1)
218
- if C.dim() == 2:
219
- C = C.unsqueeze(1)
220
- if D is not None and D.dim() == 1:
221
- D = D.unsqueeze(0)
222
- if z is not None and z.dim() == 2:
223
- z = z.unsqueeze(1)
224
- if dt_bias is not None and dt_bias.dim() == 1:
225
- dt_bias = dt_bias.unsqueeze(0)
226
- _, nheads, dim, dstate = state.shape
227
- batch = x.shape[0]
228
- if x.shape != (batch, nheads, dim):
229
- print(f"{state.shape} {x.shape} {batch} {nheads} {dim}")
230
- assert x.shape == (batch, nheads, dim)
231
- assert dt.shape == x.shape
232
- assert A.shape == (nheads, dim, dstate)
233
- ngroups = B.shape[1]
234
- assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
235
- assert B.shape == (batch, ngroups, dstate)
236
- assert C.shape == B.shape
237
- if D is not None:
238
- assert D.shape == (nheads, dim)
239
- if z is not None:
240
- assert z.shape == x.shape
241
- if dt_bias is not None:
242
- assert dt_bias.shape == (nheads, dim)
243
- if state_batch_indices is not None:
244
- assert state_batch_indices.shape == (batch,)
245
- out = torch.empty_like(x)
246
- grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
247
- z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
248
- # We don't want autotune since it will overwrite the state
249
- # We instead tune by hand.
250
- BLOCK_SIZE_M, num_warps = (
251
- (32, 4)
252
- if dstate <= 16
253
- else (
254
- (16, 4)
255
- if dstate <= 32
256
- else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
257
- )
258
- )
259
- tie_hdim = (
260
- A.stride(-1) == 0
261
- and A.stride(-2) == 0
262
- and dt.stride(-1) == 0
263
- and dt_bias.stride(-1) == 0
264
- )
265
- with torch.cuda.device(x.device.index):
266
- _selective_scan_update_kernel[grid](
267
- state,
268
- x,
269
- dt,
270
- dt_bias,
271
- A,
272
- B,
273
- C,
274
- D,
275
- z,
276
- out,
277
- state_batch_indices,
278
- batch,
279
- nheads,
280
- dim,
281
- dstate,
282
- nheads // ngroups,
283
- state.stride(0),
284
- state.stride(1),
285
- state.stride(2),
286
- state.stride(3),
287
- x.stride(0),
288
- x.stride(1),
289
- x.stride(2),
290
- dt.stride(0),
291
- dt.stride(1),
292
- dt.stride(2),
293
- *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
294
- A.stride(0),
295
- A.stride(1),
296
- A.stride(2),
297
- B.stride(0),
298
- B.stride(1),
299
- B.stride(2),
300
- C.stride(0),
301
- C.stride(1),
302
- C.stride(2),
303
- *(D.stride(0), D.stride(1)) if D is not None else 0,
304
- z_strides[0],
305
- z_strides[1],
306
- z_strides[2],
307
- out.stride(0),
308
- out.stride(1),
309
- out.stride(2),
310
- dt_softplus,
311
- tie_hdim,
312
- BLOCK_SIZE_M,
313
- num_warps=num_warps,
314
- )
315
- if not has_heads:
316
- out = out.squeeze(1)
317
- return out
318
-
319
-
320
- def selective_state_update_ref(
321
- state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
322
- ):
323
- """
324
- Argument:
325
- state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
326
- x: (batch, dim) or (batch, nheads, dim)
327
- dt: (batch, dim) or (batch, nheads, dim)
328
- A: (dim, dstate) or (nheads, dim, dstate)
329
- B: (batch, dstate) or (batch, ngroups, dstate)
330
- C: (batch, dstate) or (batch, ngroups, dstate)
331
- D: (dim,) or (nheads, dim)
332
- z: (batch, dim) or (batch, nheads, dim)
333
- dt_bias: (dim,) or (nheads, dim)
334
- Return:
335
- out: (batch, dim) or (batch, nheads, dim)
336
- """
337
- has_heads = state.dim() > 3
338
- if state.dim() == 3:
339
- state = state.unsqueeze(1)
340
- if x.dim() == 2:
341
- x = x.unsqueeze(1)
342
- if dt.dim() == 2:
343
- dt = dt.unsqueeze(1)
344
- if A.dim() == 2:
345
- A = A.unsqueeze(0)
346
- if B.dim() == 2:
347
- B = B.unsqueeze(1)
348
- if C.dim() == 2:
349
- C = C.unsqueeze(1)
350
- if D is not None and D.dim() == 1:
351
- D = D.unsqueeze(0)
352
- if z is not None and z.dim() == 2:
353
- z = z.unsqueeze(1)
354
- if dt_bias is not None and dt_bias.dim() == 1:
355
- dt_bias = dt_bias.unsqueeze(0)
356
- batch, nheads, dim, dstate = state.shape
357
- assert x.shape == (batch, nheads, dim)
358
- assert dt.shape == x.shape
359
- assert A.shape == (nheads, dim, dstate)
360
- ngroups = B.shape[1]
361
- assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
362
- assert B.shape == (batch, ngroups, dstate)
363
- assert C.shape == B.shape
364
- if D is not None:
365
- assert D.shape == (nheads, dim)
366
- if z is not None:
367
- assert z.shape == x.shape
368
- if dt_bias is not None:
369
- assert dt_bias.shape == (nheads, dim)
370
- dt = dt + dt_bias
371
- dt = F.softplus(dt) if dt_softplus else dt
372
- dA = torch.exp(
373
- rearrange(dt, "b h d -> b h d 1") * A
374
- ) # (batch, nheads, dim, dstate)
375
- B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
376
- C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
377
- dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
378
- B, "b h n -> b h 1 n"
379
- ) # (batch, nheads, dim, dstate)
380
- state.copy_(
381
- state * dA + dB * rearrange(x, "b h d -> b h d 1")
382
- ) # (batch, dim, dstate
383
- out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
384
- if D is not None:
385
- out += (x * D).to(out.dtype)
386
- out = (out if z is None else out * F.silu(z)).to(x.dtype)
387
- if not has_heads:
388
- out = out.squeeze(1)
389
- return out
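
The file above provides the single-timestep SSM update used during incremental decoding: the recurrence state = state * exp(dt * A) + dt * B * x followed by out = C . state (plus D * x and an optional z-gate), fused into one Triton kernel that mutates the state in place. A hedged sketch of one decode step follows; the shapes mirror the docstring above, while the import path and CUDA availability are assumptions.

import torch
# Assumed import path, matching the deleted file's location in the package tree.
from mamba_ssm.ops.triton.selective_state_update import (
    selective_state_update,
    selective_state_update_ref,
)

batch, nheads, dim, dstate, ngroups = 2, 4, 64, 16, 1
state = torch.randn(batch, nheads, dim, dstate, device="cuda")
x = torch.randn(batch, nheads, dim, device="cuda")
dt = torch.rand(batch, nheads, dim, device="cuda")
A = -torch.rand(nheads, dim, dstate, device="cuda")     # negative => decaying state
B = torch.randn(batch, ngroups, dstate, device="cuda")  # groups are shared across heads
C = torch.randn(batch, ngroups, dstate, device="cuda")
D = torch.randn(nheads, dim, device="cuda")
dt_bias = torch.rand(nheads, dim, device="cuda")

state_ref = state.clone()
# Triton kernel: updates `state` in place and returns out of shape (batch, nheads, dim).
out = selective_state_update(state, x, dt, A, B, C, D=D, dt_bias=dt_bias, dt_softplus=True)
# Eager reference from the same file, for a numerical cross-check.
out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, dt_bias=dt_bias, dt_softplus=True)
print((out - out_ref).abs().max().item(), (state - state_ref).abs().max().item())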
 
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_scan.py DELETED
The diff for this file is too large to render. See raw diff
 
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_state.py DELETED
@@ -1,2012 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
-
3
- """We want triton==2.1.0 or 2.2.0 for this
4
- """
5
-
6
- import math
7
- import torch
8
- import torch.nn.functional as F
9
-
10
- import triton
11
- import triton.language as tl
12
-
13
- from einops import rearrange, repeat
14
-
15
- from .softplus import softplus
16
-
17
-
18
- def init_to_zero(names):
19
- return lambda nargs: [
20
- nargs[name].zero_() for name in names if nargs[name] is not None
21
- ]
22
-
23
-
24
- @triton.autotune(
25
- configs=[
26
- triton.Config({"BLOCK_SIZE_H": 1}),
27
- triton.Config({"BLOCK_SIZE_H": 2}),
28
- triton.Config({"BLOCK_SIZE_H": 4}),
29
- triton.Config({"BLOCK_SIZE_H": 8}),
30
- triton.Config({"BLOCK_SIZE_H": 16}),
31
- triton.Config({"BLOCK_SIZE_H": 32}),
32
- triton.Config({"BLOCK_SIZE_H": 64}),
33
- ],
34
- key=["chunk_size", "nheads"],
35
- )
36
- @triton.jit
37
- def _chunk_cumsum_fwd_kernel(
38
- # Pointers to matrices
39
- dt_ptr,
40
- A_ptr,
41
- dt_bias_ptr,
42
- dt_out_ptr,
43
- dA_cumsum_ptr,
44
- # Matrix dimension
45
- batch,
46
- seqlen,
47
- nheads,
48
- chunk_size,
49
- dt_min,
50
- dt_max,
51
- # Strides
52
- stride_dt_batch,
53
- stride_dt_seqlen,
54
- stride_dt_head,
55
- stride_A_head,
56
- stride_dt_bias_head,
57
- stride_dt_out_batch,
58
- stride_dt_out_chunk,
59
- stride_dt_out_head,
60
- stride_dt_out_csize,
61
- stride_dA_cs_batch,
62
- stride_dA_cs_chunk,
63
- stride_dA_cs_head,
64
- stride_dA_cs_csize,
65
- # Meta-parameters
66
- DT_SOFTPLUS: tl.constexpr,
67
- HAS_DT_BIAS: tl.constexpr,
68
- BLOCK_SIZE_H: tl.constexpr,
69
- BLOCK_SIZE_CHUNK: tl.constexpr,
70
- ):
71
- pid_b = tl.program_id(axis=0)
72
- pid_c = tl.program_id(axis=1)
73
- pid_h = tl.program_id(axis=2)
74
- dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
75
- dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
76
- dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk
77
-
78
- offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
79
- offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
80
- dt_ptrs = dt_ptr + (
81
- offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
82
- )
83
- A_ptrs = A_ptr + offs_h * stride_A_head
84
- dt_out_ptrs = dt_out_ptr + (
85
- offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize
86
- )
87
- dA_cs_ptrs = dA_cumsum_ptr + (
88
- offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize
89
- )
90
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
91
-
92
- dt = tl.load(
93
- dt_ptrs,
94
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
95
- other=0.0,
96
- ).to(tl.float32)
97
- if HAS_DT_BIAS:
98
- dt_bias = tl.load(
99
- dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
100
- ).to(tl.float32)
101
- dt += dt_bias[:, None]
102
- if DT_SOFTPLUS:
103
- dt = tl.where(dt <= 20.0, softplus(dt), dt)
104
- # As of Triton 2.2.0, tl.clamp is not available yet
105
- # dt = tl.clamp(dt, dt_min, dt_max)
106
- dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
107
- dt = tl.where(
108
- (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
109
- )
110
- tl.store(
111
- dt_out_ptrs,
112
- dt,
113
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
114
- )
115
- A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
116
- dA = dt * A[:, None]
117
- dA_cs = tl.cumsum(dA, axis=1)
118
- tl.store(
119
- dA_cs_ptrs,
120
- dA_cs,
121
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
122
- )
123
-
124
-
125
- @triton.autotune(
126
- configs=[
127
- triton.Config(
128
- {"BLOCK_SIZE_H": 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
129
- ),
130
- triton.Config(
131
- {"BLOCK_SIZE_H": 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
132
- ),
133
- triton.Config(
134
- {"BLOCK_SIZE_H": 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
135
- ),
136
- triton.Config(
137
- {"BLOCK_SIZE_H": 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
138
- ),
139
- triton.Config(
140
- {"BLOCK_SIZE_H": 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
141
- ),
142
- triton.Config(
143
- {"BLOCK_SIZE_H": 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
144
- ),
145
- triton.Config(
146
- {"BLOCK_SIZE_H": 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
147
- ),
148
- ],
149
- key=["chunk_size", "nheads"],
150
- )
151
- @triton.jit
152
- def _chunk_cumsum_bwd_kernel(
153
- # Pointers to matrices
154
- ddA_ptr,
155
- ddt_out_ptr,
156
- dt_ptr,
157
- A_ptr,
158
- dt_bias_ptr,
159
- ddt_ptr,
160
- dA_ptr,
161
- ddt_bias_ptr,
162
- # Matrix dimensions
163
- batch,
164
- seqlen,
165
- nheads,
166
- chunk_size,
167
- dt_min,
168
- dt_max,
169
- # Strides
170
- stride_ddA_batch,
171
- stride_ddA_chunk,
172
- stride_ddA_head,
173
- stride_ddA_csize,
174
- stride_ddt_out_batch,
175
- stride_ddt_out_chunk,
176
- stride_ddt_out_head,
177
- stride_ddt_out_csize,
178
- stride_dt_batch,
179
- stride_dt_seqlen,
180
- stride_dt_head,
181
- stride_A_head,
182
- stride_dt_bias_head,
183
- stride_ddt_batch,
184
- stride_ddt_seqlen,
185
- stride_ddt_head,
186
- stride_dA_head,
187
- stride_ddt_bias_head,
188
- # Meta-parameters
189
- DT_SOFTPLUS: tl.constexpr,
190
- HAS_DT_BIAS: tl.constexpr,
191
- BLOCK_SIZE_H: tl.constexpr,
192
- BLOCK_SIZE_CHUNK: tl.constexpr,
193
- ):
194
- pid_b = tl.program_id(axis=0)
195
- pid_c = tl.program_id(axis=1)
196
- pid_h = tl.program_id(axis=2)
197
- ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk
198
- ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk
199
- dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
200
- ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen
201
-
202
- offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
203
- offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
204
- ddt_out_ptrs = ddt_out_ptr + (
205
- offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize
206
- )
207
- ddA_ptrs = ddA_ptr + (
208
- offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize
209
- )
210
- dt_ptrs = dt_ptr + (
211
- offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
212
- )
213
- ddt_ptrs = ddt_ptr + (
214
- offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen
215
- )
216
- A_ptrs = A_ptr + offs_h * stride_A_head
217
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
218
-
219
- ddA = tl.load(
220
- ddA_ptrs,
221
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
222
- other=0.0,
223
- ).to(tl.float32)
224
- ddt_out = tl.load(
225
- ddt_out_ptrs,
226
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
227
- other=0.0,
228
- ).to(tl.float32)
229
- A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
230
- ddt = ddA * A[:, None] + ddt_out
231
- dt = tl.load(
232
- dt_ptrs,
233
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
234
- other=0.0,
235
- ).to(tl.float32)
236
- if HAS_DT_BIAS:
237
- dt_bias = tl.load(
238
- dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
239
- ).to(tl.float32)
240
- dt += dt_bias[:, None]
241
- if DT_SOFTPLUS:
242
- dt_presoftplus = dt
243
- dt = tl.where(dt <= 20.0, softplus(dt), ddt)
244
- clamp_mask = (dt < dt_min) | (dt > dt_max)
245
- # As of Triton 2.2.0, tl.clamp is not available yet
246
- # dt = tl.clamp(dt, dt_min, dt_max)
247
- dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
248
- dt = tl.where(
249
- (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
250
- )
251
- ddt = tl.where(
252
- (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0
253
- )
254
- ddt = tl.where(clamp_mask, 0.0, ddt)
255
- if DT_SOFTPLUS:
256
- ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)
257
- tl.store(
258
- ddt_ptrs,
259
- ddt,
260
- mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
261
- )
262
- dA = tl.sum(ddA * dt, axis=1)
263
- tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)
264
- if HAS_DT_BIAS:
265
- ddt_bias = tl.sum(ddt, axis=1)
266
- tl.atomic_add(
267
- ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads
268
- )
269
-
270
-
271
- @triton.autotune(
272
- configs=[
273
- triton.Config(
274
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
275
- num_stages=3,
276
- num_warps=8,
277
- ),
278
- triton.Config(
279
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
280
- num_stages=4,
281
- num_warps=4,
282
- ),
283
- triton.Config(
284
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
285
- num_stages=4,
286
- num_warps=4,
287
- ),
288
- triton.Config(
289
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
290
- num_stages=4,
291
- num_warps=4,
292
- ),
293
- triton.Config(
294
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
295
- num_stages=4,
296
- num_warps=4,
297
- ),
298
- triton.Config(
299
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
300
- num_stages=4,
301
- num_warps=4,
302
- ),
303
- triton.Config(
304
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
305
- num_stages=5,
306
- num_warps=2,
307
- ),
308
- triton.Config(
309
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
310
- num_stages=5,
311
- num_warps=2,
312
- ),
313
- triton.Config(
314
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
315
- num_stages=4,
316
- num_warps=2,
317
- ),
318
- ],
319
- key=["hdim", "dstate", "chunk_size"],
320
- )
321
- @triton.jit
322
- def _chunk_state_fwd_kernel(
323
- # Pointers to matrices
324
- x_ptr,
325
- b_ptr,
326
- states_ptr,
327
- dt_ptr,
328
- dA_cumsum_ptr,
329
- seq_idx_ptr,
330
- # Matrix dimensions
331
- hdim,
332
- dstate,
333
- chunk_size,
334
- batch,
335
- seqlen,
336
- nheads_ngroups_ratio,
337
- # Strides
338
- stride_x_batch,
339
- stride_x_seqlen,
340
- stride_x_head,
341
- stride_x_hdim,
342
- stride_b_batch,
343
- stride_b_seqlen,
344
- stride_b_head,
345
- stride_b_dstate,
346
- stride_states_batch,
347
- stride_states_chunk,
348
- stride_states_head,
349
- stride_states_hdim,
350
- stride_states_dstate,
351
- stride_dt_batch,
352
- stride_dt_chunk,
353
- stride_dt_head,
354
- stride_dt_csize,
355
- stride_dA_cs_batch,
356
- stride_dA_cs_chunk,
357
- stride_dA_cs_head,
358
- stride_dA_cs_csize,
359
- stride_seq_idx_batch,
360
- stride_seq_idx_seqlen,
361
- # Meta-parameters
362
- HAS_SEQ_IDX: tl.constexpr,
363
- BLOCK_SIZE_M: tl.constexpr,
364
- BLOCK_SIZE_N: tl.constexpr,
365
- BLOCK_SIZE_K: tl.constexpr,
366
- ):
367
- pid_bc = tl.program_id(axis=1)
368
- pid_c = pid_bc // batch
369
- pid_b = pid_bc - pid_c * batch
370
- pid_h = tl.program_id(axis=2)
371
- num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
372
- pid_m = tl.program_id(axis=0) // num_pid_n
373
- pid_n = tl.program_id(axis=0) % num_pid_n
374
- b_ptr += (
375
- pid_b * stride_b_batch
376
- + pid_c * chunk_size * stride_b_seqlen
377
- + (pid_h // nheads_ngroups_ratio) * stride_b_head
378
- )
379
- x_ptr += (
380
- pid_b * stride_x_batch
381
- + pid_c * chunk_size * stride_x_seqlen
382
- + pid_h * stride_x_head
383
- )
384
- dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
385
- dA_cumsum_ptr += (
386
- pid_b * stride_dA_cs_batch
387
- + pid_c * stride_dA_cs_chunk
388
- + pid_h * stride_dA_cs_head
389
- )
390
- if HAS_SEQ_IDX:
391
- seq_idx_ptr += (
392
- pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
393
- )
394
-
395
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
396
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
397
- offs_k = tl.arange(0, BLOCK_SIZE_K)
398
- x_ptrs = x_ptr + (
399
- offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
400
- )
401
- b_ptrs = b_ptr + (
402
- offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
403
- )
404
- dt_ptrs = dt_ptr + offs_k * stride_dt_csize
405
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
406
- tl.float32
407
- )
408
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
409
- if HAS_SEQ_IDX:
410
- seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
411
-
412
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
413
- if HAS_SEQ_IDX:
414
- seq_idx_last = tl.load(
415
- seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
416
- )
417
-
418
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
419
- for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
420
- x = tl.load(
421
- x_ptrs,
422
- mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k),
423
- other=0.0,
424
- )
425
- b = tl.load(
426
- b_ptrs,
427
- mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate),
428
- other=0.0,
429
- ).to(tl.float32)
430
- dA_cs_k = tl.load(
431
- dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
432
- ).to(tl.float32)
433
- if HAS_SEQ_IDX:
434
- seq_idx_k = tl.load(
435
- seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1
436
- )
437
- dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
438
- tl.float32
439
- )
440
- if not HAS_SEQ_IDX:
441
- scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
442
- else:
443
- scale = tl.where(
444
- seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0
445
- )
446
- b *= scale[:, None]
447
- b = b.to(x_ptr.dtype.element_ty)
448
- acc += tl.dot(x, b)
449
- x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
450
- b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
451
- dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
452
- dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
453
- if HAS_SEQ_IDX:
454
- seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
455
- states = acc.to(states_ptr.dtype.element_ty)
456
-
457
- states_ptr += (
458
- pid_b * stride_states_batch
459
- + pid_c * stride_states_chunk
460
- + pid_h * stride_states_head
461
- )
462
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
463
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
464
- states_ptrs = states_ptr + (
465
- offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
466
- )
467
- c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
468
- tl.store(states_ptrs, states, mask=c_mask)
469
-
470
-
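For a single (batch, chunk, head) slice, the accumulation loop in `_chunk_state_fwd_kernel` reduces to the dense PyTorch expression below (a sketch with made-up sizes, ignoring the `seq_idx` masking path; it mirrors the einsum in `chunk_state_ref` further down):

import torch

chunk_size, headdim, dstate = 256, 64, 16
x_c = torch.randn(chunk_size, headdim)                   # this chunk's x rows for one head
b_c = torch.randn(chunk_size, dstate)                    # this chunk's B rows for that head's group
dt_c = torch.rand(chunk_size)                            # per-step dt within the chunk
dA_cs_c = torch.cumsum(-torch.rand(chunk_size), dim=0)   # cumulative sum of dt * A (A < 0)

scale = torch.exp(dA_cs_c[-1] - dA_cs_c) * dt_c          # decay to the chunk end, times dt
state_c = x_c.t() @ (b_c * scale[:, None])               # (headdim, dstate), the kernel's `acc`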
471
- @triton.autotune(
472
- configs=[
473
- triton.Config(
474
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
475
- num_stages=3,
476
- num_warps=8,
477
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
478
- ),
479
- triton.Config(
480
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
481
- num_stages=4,
482
- num_warps=4,
483
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
484
- ),
485
- triton.Config(
486
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
487
- num_stages=4,
488
- num_warps=4,
489
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
490
- ),
491
- triton.Config(
492
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
493
- num_stages=4,
494
- num_warps=4,
495
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
496
- ),
497
- triton.Config(
498
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
499
- num_stages=4,
500
- num_warps=4,
501
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
502
- ),
503
- triton.Config(
504
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
505
- num_stages=4,
506
- num_warps=4,
507
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
508
- ),
509
- triton.Config(
510
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
511
- num_stages=5,
512
- num_warps=4,
513
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
514
- ),
515
- triton.Config(
516
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
517
- num_stages=5,
518
- num_warps=4,
519
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
520
- ),
521
- triton.Config(
522
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
523
- num_stages=4,
524
- num_warps=4,
525
- pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
526
- ),
527
- ],
528
- key=["chunk_size", "hdim", "dstate"],
529
- )
530
- @triton.jit
531
- def _chunk_state_bwd_dx_kernel(
532
- # Pointers to matrices
533
- x_ptr,
534
- b_ptr,
535
- dstates_ptr,
536
- dt_ptr,
537
- dA_cumsum_ptr,
538
- dx_ptr,
539
- ddt_ptr,
540
- ddA_cumsum_ptr,
541
- # Matrix dimensions
542
- chunk_size,
543
- hdim,
544
- dstate,
545
- batch,
546
- seqlen,
547
- nheads_ngroups_ratio,
548
- # Strides
549
- stride_x_batch,
550
- stride_x_seqlen,
551
- stride_x_head,
552
- stride_x_hdim,
553
- stride_b_batch,
554
- stride_b_seqlen,
555
- stride_b_head,
556
- stride_b_dstate,
557
- stride_dstates_batch,
558
- stride_dstates_chunk,
559
- stride_states_head,
560
- stride_states_hdim,
561
- stride_states_dstate,
562
- stride_dt_batch,
563
- stride_dt_chunk,
564
- stride_dt_head,
565
- stride_dt_csize,
566
- stride_dA_cs_batch,
567
- stride_dA_cs_chunk,
568
- stride_dA_cs_head,
569
- stride_dA_cs_csize,
570
- stride_dx_batch,
571
- stride_dx_seqlen,
572
- stride_dx_head,
573
- stride_dx_hdim,
574
- stride_ddt_batch,
575
- stride_ddt_chunk,
576
- stride_ddt_head,
577
- stride_ddt_csize,
578
- stride_ddA_cs_batch,
579
- stride_ddA_cs_chunk,
580
- stride_ddA_cs_head,
581
- stride_ddA_cs_csize,
582
- # Meta-parameters
583
- BLOCK_SIZE_M: tl.constexpr,
584
- BLOCK_SIZE_N: tl.constexpr,
585
- BLOCK_SIZE_K: tl.constexpr,
586
- BLOCK_SIZE_DSTATE: tl.constexpr,
587
- ):
588
- pid_bc = tl.program_id(axis=1)
589
- pid_c = pid_bc // batch
590
- pid_b = pid_bc - pid_c * batch
591
- pid_h = tl.program_id(axis=2)
592
- num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
593
- pid_m = tl.program_id(axis=0) // num_pid_n
594
- pid_n = tl.program_id(axis=0) % num_pid_n
595
- x_ptr += (
596
- pid_b * stride_x_batch
597
- + pid_c * chunk_size * stride_x_seqlen
598
- + pid_h * stride_x_head
599
- )
600
- b_ptr += (
601
- pid_b * stride_b_batch
602
- + pid_c * chunk_size * stride_b_seqlen
603
- + (pid_h // nheads_ngroups_ratio) * stride_b_head
604
- )
605
- dstates_ptr += (
606
- pid_b * stride_dstates_batch
607
- + pid_c * stride_dstates_chunk
608
- + pid_h * stride_states_head
609
- )
610
- dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
611
- ddt_ptr += (
612
- pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
613
- )
614
- ddA_cumsum_ptr += (
615
- pid_b * stride_ddA_cs_batch
616
- + pid_c * stride_ddA_cs_chunk
617
- + pid_h * stride_ddA_cs_head
618
- )
619
- dA_cumsum_ptr += (
620
- pid_b * stride_dA_cs_batch
621
- + pid_c * stride_dA_cs_chunk
622
- + pid_h * stride_dA_cs_head
623
- )
624
-
625
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
626
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
627
-
628
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
629
- # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
630
- offs_k = tl.arange(
631
- 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
632
- )
633
- b_ptrs = b_ptr + (
634
- offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
635
- )
636
- dstates_ptrs = dstates_ptr + (
637
- offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
638
- )
639
- if BLOCK_SIZE_DSTATE <= 128:
640
- b = tl.load(
641
- b_ptrs,
642
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
643
- other=0.0,
644
- )
645
- dstates = tl.load(
646
- dstates_ptrs,
647
- mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
648
- other=0.0,
649
- )
650
- dstates = dstates.to(b_ptr.dtype.element_ty)
651
- acc = tl.dot(b, dstates)
652
- else:
653
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
654
- for k in range(0, dstate, BLOCK_SIZE_K):
655
- b = tl.load(
656
- b_ptrs,
657
- mask=(offs_m[:, None] < chunk_size_limit)
658
- & (offs_k[None, :] < dstate - k),
659
- other=0.0,
660
- )
661
- dstates = tl.load(
662
- dstates_ptrs,
663
- mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
664
- other=0.0,
665
- )
666
- dstates = dstates.to(b_ptr.dtype.element_ty)
667
- acc += tl.dot(b, dstates)
668
- b_ptrs += BLOCK_SIZE_K * stride_b_dstate
669
- dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
670
-
671
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
672
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
673
-
674
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
675
- tl.float32
676
- )
677
- dt_ptrs = dt_ptr + offs_m * stride_dt_csize
678
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
679
- dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
680
- tl.float32
681
- )
682
- dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
683
- acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
684
-
685
- x_ptrs = x_ptr + (
686
- offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
687
- )
688
- x = tl.load(
689
- x_ptrs,
690
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
691
- other=0.0,
692
- ).to(tl.float32)
693
- ddt = tl.sum(acc * x, axis=1)
694
- ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
695
- tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
696
- ddA_cs = -(ddt * dt_m)
697
- ddA_cs_last = -tl.sum(ddA_cs)
698
- ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
699
- tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
700
- tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last)
701
-
702
- dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty)
703
- dx_ptr += (
704
- pid_b * stride_dx_batch
705
- + pid_c * chunk_size * stride_dx_seqlen
706
- + pid_h * stride_dx_head
707
- )
708
- dx_ptrs = dx_ptr + (
709
- offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim
710
- )
711
- tl.store(
712
- dx_ptrs,
713
- dx,
714
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
715
- )
716
-
717
-
718
- @triton.autotune(
719
- configs=[
720
- triton.Config(
721
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128},
722
- num_stages=3,
723
- num_warps=4,
724
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
725
- ),
726
- triton.Config(
727
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32},
728
- num_stages=3,
729
- num_warps=4,
730
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
731
- ),
732
- triton.Config(
733
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128},
734
- num_stages=3,
735
- num_warps=4,
736
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
737
- ),
738
- triton.Config(
739
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64},
740
- num_stages=3,
741
- num_warps=4,
742
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
743
- ),
744
- triton.Config(
745
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64},
746
- num_stages=3,
747
- num_warps=4,
748
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
749
- ),
750
- triton.Config(
751
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32},
752
- num_stages=3,
753
- num_warps=4,
754
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
755
- ),
756
- triton.Config(
757
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64},
758
- num_stages=3,
759
- num_warps=4,
760
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
761
- ),
762
- triton.Config(
763
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32},
764
- num_stages=3,
765
- num_warps=4,
766
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
767
- ),
768
- ],
769
- key=["chunk_size", "dstate", "hdim"],
770
- )
771
- @triton.jit
772
- def _chunk_state_bwd_db_kernel(
773
- # Pointers to matrices
774
- x_ptr,
775
- dstates_ptr,
776
- b_ptr,
777
- dt_ptr,
778
- dA_cumsum_ptr,
779
- seq_idx_ptr,
780
- db_ptr,
781
- ddA_cumsum_ptr,
782
- # Matrix dimensions
783
- chunk_size,
784
- dstate,
785
- hdim,
786
- batch,
787
- seqlen,
788
- nheads,
789
- nheads_per_program,
790
- ngroups,
791
- # Strides
792
- stride_x_batch,
793
- stride_x_seqlen,
794
- stride_x_head,
795
- stride_x_hdim,
796
- stride_dstates_batch,
797
- stride_dstates_chunk,
798
- stride_states_head,
799
- stride_states_hdim,
800
- stride_states_dstate,
801
- stride_b_batch,
802
- stride_b_seqlen,
803
- stride_b_head,
804
- stride_b_dstate,
805
- stride_dt_batch,
806
- stride_dt_chunk,
807
- stride_dt_head,
808
- stride_dt_csize,
809
- stride_dA_cs_batch,
810
- stride_dA_cs_chunk,
811
- stride_dA_cs_head,
812
- stride_dA_cs_csize,
813
- stride_seq_idx_batch,
814
- stride_seq_idx_seqlen,
815
- stride_db_batch,
816
- stride_db_seqlen,
817
- stride_db_split,
818
- stride_db_group,
819
- stride_db_dstate,
820
- stride_ddA_cs_batch,
821
- stride_ddA_cs_chunk,
822
- stride_ddA_cs_head,
823
- stride_ddA_cs_csize,
824
- # Meta-parameters
825
- HAS_DDA_CS: tl.constexpr,
826
- HAS_SEQ_IDX: tl.constexpr,
827
- BLOCK_SIZE_M: tl.constexpr,
828
- BLOCK_SIZE_N: tl.constexpr,
829
- BLOCK_SIZE_K: tl.constexpr,
830
- ):
831
- pid_bc = tl.program_id(axis=1)
832
- pid_c = pid_bc // batch
833
- pid_b = pid_bc - pid_c * batch
834
- pid_sg = tl.program_id(axis=2)
835
- pid_s = pid_sg // ngroups
836
- pid_g = pid_sg - pid_s * ngroups
837
- num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
838
- pid_m = tl.program_id(axis=0) // num_pid_n
839
- pid_n = tl.program_id(axis=0) % num_pid_n
840
- x_ptr += (
841
- pid_b * stride_x_batch
842
- + pid_c * chunk_size * stride_x_seqlen
843
- + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head
844
- )
845
- db_ptr += (
846
- pid_b * stride_db_batch
847
- + pid_c * chunk_size * stride_db_seqlen
848
- + pid_g * stride_db_group
849
- + pid_s * stride_db_split
850
- )
851
- dstates_ptr += (
852
- pid_b * stride_dstates_batch
853
- + pid_c * stride_dstates_chunk
854
- + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
855
- * stride_states_head
856
- )
857
- dt_ptr += (
858
- pid_b * stride_dt_batch
859
- + pid_c * stride_dt_chunk
860
- + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head
861
- )
862
- dA_cumsum_ptr += (
863
- pid_b * stride_dA_cs_batch
864
- + pid_c * stride_dA_cs_chunk
865
- + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
866
- )
867
- if HAS_DDA_CS:
868
- b_ptr += (
869
- pid_b * stride_b_batch
870
- + pid_c * chunk_size * stride_b_seqlen
871
- + pid_g * stride_b_head
872
- )
873
- ddA_cumsum_ptr += (
874
- pid_b * stride_ddA_cs_batch
875
- + pid_c * stride_ddA_cs_chunk
876
- + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
877
- * stride_ddA_cs_head
878
- )
879
- if HAS_SEQ_IDX:
880
- seq_idx_ptr += (
881
- pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
882
- )
883
-
884
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
885
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
886
- offs_k = tl.arange(0, BLOCK_SIZE_K)
887
- x_ptrs = x_ptr + (
888
- offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim
889
- )
890
- dstates_ptrs = dstates_ptr + (
891
- offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim
892
- )
893
- dt_ptrs = dt_ptr + offs_m * stride_dt_csize
894
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
895
- if HAS_DDA_CS:
896
- b_ptrs = b_ptr + (
897
- offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate
898
- )
899
- ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
900
-
901
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
902
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
903
- if HAS_DDA_CS:
904
- b = tl.load(
905
- b_ptrs,
906
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
907
- other=0.0,
908
- ).to(tl.float32)
909
- if HAS_SEQ_IDX:
910
- seq_idx_m = tl.load(
911
- seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
912
- mask=offs_m < chunk_size_limit,
913
- other=-1,
914
- )
915
- seq_idx_last = tl.load(
916
- seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
917
- )
918
- nheads_iter = min(
919
- nheads_per_program, nheads // ngroups - pid_s * nheads_per_program
920
- )
921
- for h in range(nheads_iter):
922
- x = tl.load(
923
- x_ptrs,
924
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim),
925
- other=0.0,
926
- )
927
- dstates = tl.load(
928
- dstates_ptrs,
929
- mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate),
930
- other=0.0,
931
- )
932
- dstates = dstates.to(x_ptrs.dtype.element_ty)
933
- db = tl.dot(x, dstates)
934
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
935
- tl.float32
936
- )
937
- dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
938
- tl.float32
939
- )
940
- dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
941
- if not HAS_SEQ_IDX:
942
- scale = tl.exp(dA_cs_last - dA_cs_m)
943
- else:
944
- scale = tl.where(
945
- seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0
946
- )
947
- db *= (scale * dt_m)[:, None]
948
- if HAS_DDA_CS:
949
- # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
950
- ddA_cs = tl.sum(db * b, axis=1)
951
- tl.atomic_add(
952
- ddA_cumsum_ptrs + stride_ddA_cs_csize,
953
- ddA_cs,
954
- mask=offs_m < chunk_size - 1,
955
- )
956
- acc += db
957
- x_ptrs += stride_x_head
958
- dstates_ptrs += stride_states_head
959
- dt_ptrs += stride_dt_head
960
- dA_cumsum_ptr += stride_dA_cs_head
961
- dA_cumsum_ptrs += stride_dA_cs_head
962
- if HAS_DDA_CS:
963
- ddA_cumsum_ptrs += stride_ddA_cs_head
964
-
965
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
966
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
967
- # if HAS_SEQ_IDX:
968
- # seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
969
- # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
970
- # acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0)
971
- db_ptrs = db_ptr + (
972
- offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate
973
- )
974
- tl.store(
975
- db_ptrs,
976
- acc,
977
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
978
- )
979
-
980
-
981
- @triton.autotune(
982
- configs=[
983
- # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
984
- # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
985
- # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
986
- # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
987
- # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
988
- # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
989
- # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
990
- # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
991
- # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
992
- triton.Config(
993
- {"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
994
- num_stages=3,
995
- num_warps=4,
996
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
997
- ),
998
- triton.Config(
999
- {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1000
- num_stages=3,
1001
- num_warps=4,
1002
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1003
- ),
1004
- triton.Config(
1005
- {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1006
- num_stages=3,
1007
- num_warps=4,
1008
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1009
- ),
1010
- triton.Config(
1011
- {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1012
- num_stages=3,
1013
- num_warps=4,
1014
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1015
- ),
1016
- triton.Config(
1017
- {"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
1018
- num_stages=4,
1019
- num_warps=8,
1020
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1021
- ),
1022
- triton.Config(
1023
- {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1024
- num_stages=4,
1025
- num_warps=8,
1026
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1027
- ),
1028
- triton.Config(
1029
- {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1030
- num_stages=4,
1031
- num_warps=8,
1032
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1033
- ),
1034
- triton.Config(
1035
- {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1036
- num_stages=4,
1037
- num_warps=8,
1038
- pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1039
- ),
1040
- ],
1041
- key=["chunk_size", "hdim", "dstate"],
1042
- )
1043
- @triton.jit
1044
- def _chunk_state_bwd_ddAcs_stable_kernel(
1045
- # Pointers to matrices
1046
- x_ptr,
1047
- b_ptr,
1048
- dstates_ptr,
1049
- dt_ptr,
1050
- dA_cumsum_ptr,
1051
- seq_idx_ptr,
1052
- ddA_cumsum_ptr,
1053
- # Matrix dimensions
1054
- chunk_size,
1055
- hdim,
1056
- dstate,
1057
- batch,
1058
- seqlen,
1059
- nheads_ngroups_ratio,
1060
- # Strides
1061
- stride_x_batch,
1062
- stride_x_seqlen,
1063
- stride_x_head,
1064
- stride_x_hdim,
1065
- stride_b_batch,
1066
- stride_b_seqlen,
1067
- stride_b_head,
1068
- stride_b_dstate,
1069
- stride_dstates_batch,
1070
- stride_dstates_chunk,
1071
- stride_states_head,
1072
- stride_states_hdim,
1073
- stride_states_dstate,
1074
- stride_dt_batch,
1075
- stride_dt_chunk,
1076
- stride_dt_head,
1077
- stride_dt_csize,
1078
- stride_dA_cs_batch,
1079
- stride_dA_cs_chunk,
1080
- stride_dA_cs_head,
1081
- stride_dA_cs_csize,
1082
- stride_seq_idx_batch,
1083
- stride_seq_idx_seqlen,
1084
- stride_ddA_cs_batch,
1085
- stride_ddA_cs_chunk,
1086
- stride_ddA_cs_head,
1087
- stride_ddA_cs_csize,
1088
- # Meta-parameters
1089
- HAS_SEQ_IDX: tl.constexpr,
1090
- BLOCK_SIZE_M: tl.constexpr,
1091
- BLOCK_SIZE_N: tl.constexpr,
1092
- BLOCK_SIZE_K: tl.constexpr,
1093
- BLOCK_SIZE_DSTATE: tl.constexpr,
1094
- ):
1095
- pid_bc = tl.program_id(axis=1)
1096
- pid_c = pid_bc // batch
1097
- pid_b = pid_bc - pid_c * batch
1098
- pid_h = tl.program_id(axis=2)
1099
- num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
1100
- pid_m = tl.program_id(axis=0) // num_pid_n
1101
- pid_n = tl.program_id(axis=0) % num_pid_n
1102
- x_ptr += (
1103
- pid_b * stride_x_batch
1104
- + pid_c * chunk_size * stride_x_seqlen
1105
- + pid_h * stride_x_head
1106
- )
1107
- b_ptr += (
1108
- pid_b * stride_b_batch
1109
- + pid_c * chunk_size * stride_b_seqlen
1110
- + (pid_h // nheads_ngroups_ratio) * stride_b_head
1111
- )
1112
- dstates_ptr += (
1113
- pid_b * stride_dstates_batch
1114
- + pid_c * stride_dstates_chunk
1115
- + pid_h * stride_states_head
1116
- )
1117
- dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
1118
- ddA_cumsum_ptr += (
1119
- pid_b * stride_ddA_cs_batch
1120
- + pid_c * stride_ddA_cs_chunk
1121
- + pid_h * stride_ddA_cs_head
1122
- )
1123
- dA_cumsum_ptr += (
1124
- pid_b * stride_dA_cs_batch
1125
- + pid_c * stride_dA_cs_chunk
1126
- + pid_h * stride_dA_cs_head
1127
- )
1128
- if HAS_SEQ_IDX:
1129
- seq_idx_ptr += (
1130
- pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
1131
- )
1132
-
1133
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1134
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1135
-
1136
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
1137
- # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
1138
- offs_k = tl.arange(
1139
- 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
1140
- )
1141
- b_ptrs = b_ptr + (
1142
- offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
1143
- )
1144
- dstates_ptrs = dstates_ptr + (
1145
- offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
1146
- )
1147
- if BLOCK_SIZE_DSTATE <= 128:
1148
- b = tl.load(
1149
- b_ptrs,
1150
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
1151
- other=0.0,
1152
- )
1153
- dstates = tl.load(
1154
- dstates_ptrs,
1155
- mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
1156
- other=0.0,
1157
- )
1158
- dstates = dstates.to(b_ptr.dtype.element_ty)
1159
- acc = tl.dot(b, dstates)
1160
- else:
1161
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
1162
- for k in range(0, dstate, BLOCK_SIZE_K):
1163
- b = tl.load(
1164
- b_ptrs,
1165
- mask=(offs_m[:, None] < chunk_size_limit)
1166
- & (offs_k[None, :] < dstate - k),
1167
- other=0.0,
1168
- )
1169
- dstates = tl.load(
1170
- dstates_ptrs,
1171
- mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
1172
- other=0.0,
1173
- )
1174
- dstates = dstates.to(b_ptr.dtype.element_ty)
1175
- acc += tl.dot(b, dstates)
1176
- b_ptrs += BLOCK_SIZE_K * stride_b_dstate
1177
- dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
1178
-
1179
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1180
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1181
-
1182
- dA_cs_m = tl.load(
1183
- dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
1184
- ).to(tl.float32)
1185
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
1186
- tl.float32
1187
- )
1188
- if not HAS_SEQ_IDX:
1189
- scale = tl.exp(dA_cs_last - dA_cs_m)
1190
- else:
1191
- seq_idx_m = tl.load(
1192
- seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
1193
- mask=offs_m < chunk_size_limit,
1194
- other=-1,
1195
- )
1196
- seq_idx_last = tl.load(
1197
- seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
1198
- )
1199
- scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
1200
- acc *= scale[:, None]
1201
-
1202
- x_ptrs = x_ptr + (
1203
- offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
1204
- )
1205
- x = tl.load(
1206
- x_ptrs,
1207
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
1208
- other=0.0,
1209
- ).to(tl.float32)
1210
- dt_ptrs = dt_ptr + offs_m * stride_dt_csize
1211
- dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
1212
- ddt = tl.sum(acc * x, axis=1)
1213
- # ddA_cs = -(ddt * dt_m)
1214
- # Triton 2.2.0 errors if we have the cumsum here, so we just write it out
1215
- # then call torch.cumsum outside this kernel.
1216
- # ddA_cs = tl.cumsum(ddt * dt_m)
1217
- ddA_cs = ddt * dt_m
1218
- ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
1219
- # tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
1220
- tl.atomic_add(
1221
- ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1
1222
- )
1223
-
1224
-
1225
- @triton.autotune(
1226
- configs=[
1227
- triton.Config(
1228
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
1229
- num_stages=3,
1230
- num_warps=8,
1231
- ),
1232
- triton.Config(
1233
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
1234
- num_stages=4,
1235
- num_warps=4,
1236
- ),
1237
- triton.Config(
1238
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1239
- num_stages=4,
1240
- num_warps=4,
1241
- ),
1242
- triton.Config(
1243
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1244
- num_stages=4,
1245
- num_warps=4,
1246
- ),
1247
- triton.Config(
1248
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1249
- num_stages=4,
1250
- num_warps=4,
1251
- ),
1252
- triton.Config(
1253
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1254
- num_stages=4,
1255
- num_warps=4,
1256
- ),
1257
- triton.Config(
1258
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1259
- num_stages=5,
1260
- num_warps=2,
1261
- ),
1262
- triton.Config(
1263
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1264
- num_stages=5,
1265
- num_warps=2,
1266
- ),
1267
- triton.Config(
1268
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1269
- num_stages=4,
1270
- num_warps=2,
1271
- ),
1272
- ],
1273
- key=["hdim", "dstate", "chunk_size"],
1274
- )
1275
- @triton.jit
1276
- def _chunk_state_varlen_kernel(
1277
- # Pointers to matrices
1278
- x_ptr,
1279
- b_ptr,
1280
- dt_ptr,
1281
- dA_cumsum_ptr,
1282
- chunk_states_ptr,
1283
- cu_seqlens_ptr,
1284
- states_ptr,
1285
- # Matrix dimensions
1286
- hdim,
1287
- dstate,
1288
- chunk_size,
1289
- seqlen,
1290
- nheads_ngroups_ratio,
1291
- # Strides
1292
- stride_x_seqlen,
1293
- stride_x_head,
1294
- stride_x_hdim,
1295
- stride_b_seqlen,
1296
- stride_b_head,
1297
- stride_b_dstate,
1298
- stride_dt_chunk,
1299
- stride_dt_head,
1300
- stride_dt_csize,
1301
- stride_dA_cs_chunk,
1302
- stride_dA_cs_head,
1303
- stride_dA_cs_csize,
1304
- stride_chunk_states_chunk,
1305
- stride_chunk_states_head,
1306
- stride_chunk_states_hdim,
1307
- stride_chunk_states_dstate,
1308
- stride_states_batch,
1309
- stride_states_head,
1310
- stride_states_hdim,
1311
- stride_states_dstate,
1312
- # Meta-parameters
1313
- BLOCK_SIZE_M: tl.constexpr,
1314
- BLOCK_SIZE_N: tl.constexpr,
1315
- BLOCK_SIZE_K: tl.constexpr,
1316
- ):
1317
- pid_b = tl.program_id(axis=1)
1318
- pid_h = tl.program_id(axis=2)
1319
- num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
1320
- pid_m = tl.program_id(axis=0) // num_pid_n
1321
- pid_n = tl.program_id(axis=0) % num_pid_n
1322
- end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
1323
- pid_c = (end_idx - 1) // chunk_size
1324
- b_ptr += (
1325
- pid_c * chunk_size * stride_b_seqlen
1326
- + (pid_h // nheads_ngroups_ratio) * stride_b_head
1327
- )
1328
- x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
1329
- dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
1330
- dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
1331
- chunk_states_ptr += (
1332
- pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
1333
- )
1334
-
1335
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1336
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1337
- offs_k = tl.arange(0, BLOCK_SIZE_K)
1338
- x_ptrs = x_ptr + (
1339
- offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
1340
- )
1341
- b_ptrs = b_ptr + (
1342
- offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
1343
- )
1344
- dt_ptrs = dt_ptr + offs_k * stride_dt_csize
1345
- dA_cs_last = tl.load(
1346
- dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
1347
- ).to(tl.float32)
1348
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
1349
-
1350
- chunk_size_limit = end_idx - pid_c * chunk_size
1351
- start_idx = tl.load(cu_seqlens_ptr + pid_b)
1352
- start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
1353
-
1354
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
1355
- for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
1356
- x = tl.load(
1357
- x_ptrs,
1358
- mask=(offs_m[:, None] < hdim)
1359
- & (offs_k[None, :] < chunk_size_limit - k)
1360
- & (offs_k[None, :] >= start_idx_cur - k),
1361
- other=0.0,
1362
- )
1363
- b = tl.load(
1364
- b_ptrs,
1365
- mask=(offs_k[:, None] < chunk_size_limit - k)
1366
- & (offs_n[None, :] < dstate)
1367
- & (offs_k[:, None] >= start_idx_cur - k),
1368
- other=0.0,
1369
- ).to(tl.float32)
1370
- dA_cs_k = tl.load(
1371
- dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
1372
- ).to(tl.float32)
1373
- dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
1374
- tl.float32
1375
- )
1376
- scale = tl.where(
1377
- (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
1378
- tl.exp((dA_cs_last - dA_cs_k)) * dt_k,
1379
- 0.0,
1380
- )
1381
- b *= scale[:, None]
1382
- b = b.to(x_ptr.dtype.element_ty)
1383
- acc += tl.dot(x, b)
1384
- x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
1385
- b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
1386
- dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
1387
- dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
1388
-
1389
- # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
1390
- if start_idx < pid_c * chunk_size:
1391
- chunk_states_ptrs = chunk_states_ptr + (
1392
- offs_m[:, None] * stride_chunk_states_hdim
1393
- + offs_n[None, :] * stride_chunk_states_dstate
1394
- )
1395
- chunk_states = tl.load(
1396
- chunk_states_ptrs,
1397
- mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),
1398
- other=0.0,
1399
- ).to(tl.float32)
1400
- # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0)
1401
- scale = tl.exp(dA_cs_last)
1402
- acc += chunk_states * scale
1403
-
1404
- states = acc.to(states_ptr.dtype.element_ty)
1405
-
1406
- states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
1407
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1408
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1409
- states_ptrs = states_ptr + (
1410
- offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
1411
- )
1412
- c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
1413
- tl.store(states_ptrs, states, mask=c_mask)
1414
-
1415
-
1416
- def _chunk_cumsum_fwd(
1417
- dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))
1418
- ):
1419
- batch, seqlen, nheads = dt.shape
1420
- assert A.shape == (nheads,)
1421
- if dt_bias is not None:
1422
- assert dt_bias.shape == (nheads,)
1423
- nchunks = math.ceil(seqlen / chunk_size)
1424
- dt_out = torch.empty(
1425
- batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
1426
- )
1427
- dA_cumsum = torch.empty(
1428
- batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
1429
- )
1430
- grid_chunk_cs = lambda META: (
1431
- batch,
1432
- nchunks,
1433
- triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
1434
- )
1435
- with torch.cuda.device(dt.device.index):
1436
- _chunk_cumsum_fwd_kernel[grid_chunk_cs](
1437
- dt,
1438
- A,
1439
- dt_bias,
1440
- dt_out,
1441
- dA_cumsum,
1442
- batch,
1443
- seqlen,
1444
- nheads,
1445
- chunk_size,
1446
- dt_limit[0],
1447
- dt_limit[1],
1448
- dt.stride(0),
1449
- dt.stride(1),
1450
- dt.stride(2),
1451
- A.stride(0),
1452
- dt_bias.stride(0) if dt_bias is not None else 0,
1453
- dt_out.stride(0),
1454
- dt_out.stride(2),
1455
- dt_out.stride(1),
1456
- dt_out.stride(3),
1457
- dA_cumsum.stride(0),
1458
- dA_cumsum.stride(2),
1459
- dA_cumsum.stride(1),
1460
- dA_cumsum.stride(3),
1461
- dt_softplus,
1462
- HAS_DT_BIAS=dt_bias is not None,
1463
- BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
1464
- )
1465
- return dA_cumsum, dt_out
1466
-
1467
-
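As a rough PyTorch reference for what `_chunk_cumsum_fwd` returns (two float32 tensors of shape (batch, nheads, nchunks, chunk_size)), inferred from the wrapper arguments and the backward kernel above; `chunk_cumsum_fwd_ref` is a hypothetical helper name:

import math
import torch
import torch.nn.functional as F

def chunk_cumsum_fwd_ref(dt, A, chunk_size, dt_bias=None, dt_softplus=False,
                         dt_limit=(0.0, float("inf"))):
    # dt: (batch, seqlen, nheads), A: (nheads,)
    batch, seqlen, nheads = dt.shape
    nchunks = math.ceil(seqlen / chunk_size)
    dt = dt.float()
    if dt_bias is not None:
        dt = dt + dt_bias.float()
    if dt_softplus:
        dt = F.softplus(dt)   # identity for inputs > 20, like the kernel's softplus guard
    dt = dt.clamp(min=dt_limit[0], max=dt_limit[1])
    # pad to a whole number of chunks, then lay out as (batch, nheads, nchunks, chunk_size)
    dt = F.pad(dt, (0, 0, 0, nchunks * chunk_size - seqlen))
    dt_out = dt.reshape(batch, nchunks, chunk_size, nheads).permute(0, 3, 1, 2).contiguous()
    dA_cumsum = torch.cumsum(dt_out * A.float()[None, :, None, None], dim=-1)
    return dA_cumsum, dt_out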
1468
- def _chunk_cumsum_bwd(
1469
- ddA,
1470
- ddt_out,
1471
- dt,
1472
- A,
1473
- dt_bias=None,
1474
- dt_softplus=False,
1475
- dt_limit=(0.0, float("inf")),
1476
- ddt=None,
1477
- ):
1478
- batch, seqlen, nheads = dt.shape
1479
- _, _, nchunks, chunk_size = ddA.shape
1480
- assert ddA.shape == (batch, nheads, nchunks, chunk_size)
1481
- assert ddt_out.shape == (batch, nheads, nchunks, chunk_size)
1482
- assert A.shape == (nheads,)
1483
- if dt_bias is not None:
1484
- assert dt_bias.shape == (nheads,)
1485
- ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)
1486
- else:
1487
- ddt_bias = None
1488
- if ddt is not None:
1489
- assert ddt.shape == dt.shape
1490
- else:
1491
- ddt = torch.empty_like(dt)
1492
- dA = torch.empty_like(A, dtype=torch.float32)
1493
- grid_chunk_cs = lambda META: (
1494
- batch,
1495
- nchunks,
1496
- triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
1497
- )
1498
- with torch.cuda.device(dt.device.index):
1499
- _chunk_cumsum_bwd_kernel[grid_chunk_cs](
1500
- ddA,
1501
- ddt_out,
1502
- dt,
1503
- A,
1504
- dt_bias,
1505
- ddt,
1506
- dA,
1507
- ddt_bias,
1508
- batch,
1509
- seqlen,
1510
- nheads,
1511
- chunk_size,
1512
- dt_limit[0],
1513
- dt_limit[1],
1514
- ddA.stride(0),
1515
- ddA.stride(2),
1516
- ddA.stride(1),
1517
- ddA.stride(3),
1518
- ddt_out.stride(0),
1519
- ddt_out.stride(2),
1520
- ddt_out.stride(1),
1521
- ddt_out.stride(3),
1522
- dt.stride(0),
1523
- dt.stride(1),
1524
- dt.stride(2),
1525
- A.stride(0),
1526
- dt_bias.stride(0) if dt_bias is not None else 0,
1527
- ddt.stride(0),
1528
- ddt.stride(1),
1529
- ddt.stride(2),
1530
- dA.stride(0),
1531
- ddt_bias.stride(0) if ddt_bias is not None else 0,
1532
- dt_softplus,
1533
- HAS_DT_BIAS=dt_bias is not None,
1534
- BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
1535
- )
1536
- return ddt, dA, ddt_bias
1537
-
1538
-
1539
- def _chunk_state_fwd(
1540
- B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True
1541
- ):
1542
- batch, seqlen, nheads, headdim = x.shape
1543
- _, _, nchunks, chunk_size = dt.shape
1544
- _, _, ngroups, dstate = B.shape
1545
- assert nheads % ngroups == 0
1546
- assert B.shape == (batch, seqlen, ngroups, dstate)
1547
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1548
- assert dA_cumsum.shape == dt.shape
1549
- if seq_idx is not None:
1550
- assert seq_idx.shape == (batch, seqlen)
1551
- if states is not None:
1552
- assert states.shape == (batch, nchunks, nheads, headdim, dstate)
1553
- else:
1554
- states_dtype = torch.float32 if states_in_fp32 else B.dtype
1555
- states = torch.empty(
1556
- (batch, nchunks, nheads, headdim, dstate),
1557
- device=x.device,
1558
- dtype=states_dtype,
1559
- )
1560
- grid = lambda META: (
1561
- triton.cdiv(headdim, META["BLOCK_SIZE_M"])
1562
- * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
1563
- batch * nchunks,
1564
- nheads,
1565
- )
1566
- with torch.cuda.device(x.device.index):
1567
- _chunk_state_fwd_kernel[grid](
1568
- x,
1569
- B,
1570
- states,
1571
- dt,
1572
- dA_cumsum,
1573
- seq_idx,
1574
- headdim,
1575
- dstate,
1576
- chunk_size,
1577
- batch,
1578
- seqlen,
1579
- nheads // ngroups,
1580
- x.stride(0),
1581
- x.stride(1),
1582
- x.stride(2),
1583
- x.stride(3),
1584
- B.stride(0),
1585
- B.stride(1),
1586
- B.stride(2),
1587
- B.stride(-1),
1588
- states.stride(0),
1589
- states.stride(1),
1590
- states.stride(2),
1591
- states.stride(3),
1592
- states.stride(4),
1593
- dt.stride(0),
1594
- dt.stride(2),
1595
- dt.stride(1),
1596
- dt.stride(3),
1597
- dA_cumsum.stride(0),
1598
- dA_cumsum.stride(2),
1599
- dA_cumsum.stride(1),
1600
- dA_cumsum.stride(3),
1601
- *(
1602
- (seq_idx.stride(0), seq_idx.stride(1))
1603
- if seq_idx is not None
1604
- else (0, 0)
1605
- ),
1606
- HAS_SEQ_IDX=seq_idx is not None,
1607
- )
1608
- return states
1609
-
1610
-
1611
- def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None):
1612
- batch, seqlen, nheads, headdim = x.shape
1613
- _, _, nchunks, chunk_size = dt.shape
1614
- _, _, ngroups, dstate = B.shape
1615
- assert nheads % ngroups == 0
1616
- assert B.shape == (batch, seqlen, ngroups, dstate)
1617
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1618
- assert dA_cumsum.shape == dt.shape
1619
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1620
- if dx is not None:
1621
- assert dx.shape == x.shape
1622
- else:
1623
- dx = torch.empty_like(x)
1624
- ddt = torch.empty(
1625
- batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
1626
- )
1627
- ddA_cumsum = torch.empty(
1628
- batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32
1629
- )
1630
- grid_dx = lambda META: (
1631
- triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
1632
- * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
1633
- batch * nchunks,
1634
- nheads,
1635
- )
1636
- with torch.cuda.device(x.device.index):
1637
- _chunk_state_bwd_dx_kernel[grid_dx](
1638
- x,
1639
- B,
1640
- dstates,
1641
- dt,
1642
- dA_cumsum,
1643
- dx,
1644
- ddt,
1645
- ddA_cumsum,
1646
- chunk_size,
1647
- headdim,
1648
- dstate,
1649
- batch,
1650
- seqlen,
1651
- nheads // ngroups,
1652
- x.stride(0),
1653
- x.stride(1),
1654
- x.stride(2),
1655
- x.stride(3),
1656
- B.stride(0),
1657
- B.stride(1),
1658
- B.stride(2),
1659
- B.stride(-1),
1660
- dstates.stride(0),
1661
- dstates.stride(1),
1662
- dstates.stride(2),
1663
- dstates.stride(3),
1664
- dstates.stride(4),
1665
- dt.stride(0),
1666
- dt.stride(2),
1667
- dt.stride(1),
1668
- dt.stride(3),
1669
- dA_cumsum.stride(0),
1670
- dA_cumsum.stride(2),
1671
- dA_cumsum.stride(1),
1672
- dA_cumsum.stride(3),
1673
- dx.stride(0),
1674
- dx.stride(1),
1675
- dx.stride(2),
1676
- dx.stride(3),
1677
- ddt.stride(0),
1678
- ddt.stride(2),
1679
- ddt.stride(1),
1680
- ddt.stride(3),
1681
- ddA_cumsum.stride(0),
1682
- ddA_cumsum.stride(2),
1683
- ddA_cumsum.stride(1),
1684
- ddA_cumsum.stride(3),
1685
- BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
1686
- )
1687
- return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype)
1688
-
1689
-
1690
- def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1):
1691
- batch, seqlen, nheads, headdim = x.shape
1692
- _, _, nchunks, chunk_size = dt.shape
1693
- dstate = dstates.shape[-1]
1694
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1695
- assert dA_cumsum.shape == dt.shape
1696
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1697
- if seq_idx is not None:
1698
- assert seq_idx.shape == (batch, seqlen)
1699
- if B is not None:
1700
- assert B.shape == (batch, seqlen, ngroups, dstate)
1701
- B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3))
1702
- # Use torch.empty since the Triton kernel will call init_to_zero
1703
- ddA_cumsum = torch.empty(
1704
- batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
1705
- )
1706
- ddA_cumsum_strides = (
1707
- ddA_cumsum.stride(0),
1708
- ddA_cumsum.stride(2),
1709
- ddA_cumsum.stride(1),
1710
- ddA_cumsum.stride(3),
1711
- )
1712
- else:
1713
- B_strides = (0, 0, 0, 0)
1714
- ddA_cumsum = None
1715
- ddA_cumsum_strides = (0, 0, 0, 0)
1716
- nheads_ngroups_ratio = nheads // ngroups
1717
- sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
1718
- nheads_per_program = max(
1719
- min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1
1720
- )
1721
- nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
1722
- dB = torch.empty(
1723
- batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32
1724
- )
1725
- grid_db = lambda META: (
1726
- triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
1727
- * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
1728
- batch * nchunks,
1729
- nsplits * ngroups,
1730
- )
1731
- with torch.cuda.device(x.device.index):
1732
- _chunk_state_bwd_db_kernel[grid_db](
1733
- x,
1734
- dstates,
1735
- B,
1736
- dt,
1737
- dA_cumsum,
1738
- seq_idx,
1739
- dB,
1740
- ddA_cumsum,
1741
- chunk_size,
1742
- dstate,
1743
- headdim,
1744
- batch,
1745
- seqlen,
1746
- nheads,
1747
- nheads_per_program,
1748
- ngroups,
1749
- x.stride(0),
1750
- x.stride(1),
1751
- x.stride(2),
1752
- x.stride(3),
1753
- dstates.stride(0),
1754
- dstates.stride(1),
1755
- dstates.stride(2),
1756
- dstates.stride(3),
1757
- dstates.stride(4),
1758
- *B_strides,
1759
- dt.stride(0),
1760
- dt.stride(2),
1761
- dt.stride(1),
1762
- dt.stride(3),
1763
- dA_cumsum.stride(0),
1764
- dA_cumsum.stride(2),
1765
- dA_cumsum.stride(1),
1766
- dA_cumsum.stride(3),
1767
- *(
1768
- (seq_idx.stride(0), seq_idx.stride(1))
1769
- if seq_idx is not None
1770
- else (0, 0)
1771
- ),
1772
- dB.stride(0),
1773
- dB.stride(1),
1774
- dB.stride(2),
1775
- dB.stride(3),
1776
- dB.stride(4),
1777
- *ddA_cumsum_strides,
1778
- HAS_DDA_CS=ddA_cumsum is not None,
1779
- HAS_SEQ_IDX=seq_idx is not None,
1780
- BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
1781
- )
1782
- dB = dB.sum(2)
1783
- if ddA_cumsum is not None:
1784
- # The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute
1785
- # to the state of the chunk.
1786
- # torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
1787
- # But it's easier to just do the cumsum for all elements, the result will be the same.
1788
- torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum)
1789
- return dB if B is None else (dB, ddA_cumsum)
1790
-
1791
-
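A worked example of the `nheads_per_program` / `nsplits` arithmetic above, with made-up sizes: the launch tries to keep every SM busy while never letting one program cross a B-group boundary.

import math
import triton

batch, nchunks, nheads, ngroups = 2, 8, 64, 8
sm_count = 108                              # e.g. an A100
nheads_ngroups_ratio = nheads // ngroups    # 8 heads share each group's B
nheads_per_program = max(
    min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1
)                                           # ceil(1024 / 108) = 10, capped at 8
nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)   # -> 1
# dB is produced as (batch, seqlen, nsplits, ngroups, dstate) and reduced with dB.sum(2).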
1792
- def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):
1793
- batch, seqlen, nheads, headdim = x.shape
1794
- _, _, nchunks, chunk_size = dt.shape
1795
- _, _, ngroups, dstate = B.shape
1796
- assert nheads % ngroups == 0
1797
- assert B.shape == (batch, seqlen, ngroups, dstate)
1798
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1799
- assert dA_cumsum.shape == dt.shape
1800
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1801
- if seq_idx is not None:
1802
- assert seq_idx.shape == (batch, seqlen)
1803
- # Use torch.empty since the Triton kernel will call init_to_zero
1804
- ddA_cumsum = torch.empty(
1805
- batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
1806
- )
1807
- grid_ddtcs = lambda META: (
1808
- triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
1809
- * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
1810
- batch * nchunks,
1811
- nheads,
1812
- )
1813
- with torch.cuda.device(x.device.index):
1814
- _chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs](
1815
- x,
1816
- B,
1817
- dstates,
1818
- dt,
1819
- dA_cumsum,
1820
- seq_idx,
1821
- ddA_cumsum,
1822
- chunk_size,
1823
- headdim,
1824
- dstate,
1825
- batch,
1826
- seqlen,
1827
- nheads // ngroups,
1828
- x.stride(0),
1829
- x.stride(1),
1830
- x.stride(2),
1831
- x.stride(3),
1832
- B.stride(0),
1833
- B.stride(1),
1834
- B.stride(2),
1835
- B.stride(-1),
1836
- dstates.stride(0),
1837
- dstates.stride(1),
1838
- dstates.stride(2),
1839
- dstates.stride(3),
1840
- dstates.stride(4),
1841
- dt.stride(0),
1842
- dt.stride(2),
1843
- dt.stride(1),
1844
- dt.stride(3),
1845
- dA_cumsum.stride(0),
1846
- dA_cumsum.stride(2),
1847
- dA_cumsum.stride(1),
1848
- dA_cumsum.stride(3),
1849
- *(
1850
- (seq_idx.stride(0), seq_idx.stride(1))
1851
- if seq_idx is not None
1852
- else (0, 0)
1853
- ),
1854
- ddA_cumsum.stride(0),
1855
- ddA_cumsum.stride(2),
1856
- ddA_cumsum.stride(1),
1857
- ddA_cumsum.stride(3),
1858
- HAS_SEQ_IDX=seq_idx is not None,
1859
- BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16),
1860
- BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
1861
- )
1862
- torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
1863
- return ddA_cumsum
1864
-
1865
-
1866
- def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):
1867
- total_seqlen, nheads, headdim = x.shape
1868
- _, nchunks, chunk_size = dt.shape
1869
- _, ngroups, dstate = B.shape
1870
- batch = cu_seqlens.shape[0] - 1
1871
- cu_seqlens = cu_seqlens.contiguous()
1872
- assert nheads % ngroups == 0
1873
- assert B.shape == (total_seqlen, ngroups, dstate)
1874
- assert dt.shape == (nheads, nchunks, chunk_size)
1875
- assert dA_cumsum.shape == dt.shape
1876
- assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
1877
- states = torch.empty(
1878
- batch,
1879
- nheads,
1880
- headdim,
1881
- dstate,
1882
- dtype=chunk_states.dtype,
1883
- device=chunk_states.device,
1884
- )
1885
- grid = lambda META: (
1886
- triton.cdiv(headdim, META["BLOCK_SIZE_M"])
1887
- * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
1888
- batch,
1889
- nheads,
1890
- )
1891
- with torch.cuda.device(x.device.index):
1892
- _chunk_state_varlen_kernel[grid](
1893
- x,
1894
- B,
1895
- dt,
1896
- dA_cumsum,
1897
- chunk_states,
1898
- cu_seqlens,
1899
- states,
1900
- headdim,
1901
- dstate,
1902
- chunk_size,
1903
- total_seqlen,
1904
- nheads // ngroups,
1905
- x.stride(0),
1906
- x.stride(1),
1907
- x.stride(2),
1908
- B.stride(0),
1909
- B.stride(1),
1910
- B.stride(2),
1911
- dt.stride(1),
1912
- dt.stride(0),
1913
- dt.stride(2),
1914
- dA_cumsum.stride(1),
1915
- dA_cumsum.stride(0),
1916
- dA_cumsum.stride(2),
1917
- chunk_states.stride(0),
1918
- chunk_states.stride(1),
1919
- chunk_states.stride(2),
1920
- chunk_states.stride(3),
1921
- states.stride(0),
1922
- states.stride(1),
1923
- states.stride(2),
1924
- states.stride(3),
1925
- )
1926
- return states
1927
-
1928
-
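A usage sketch for `chunk_state_varlen` with two packed sequences (all sizes made up); it recomputes the final SSM state of each sequence in the packed batch from the per-chunk quantities:

import torch

total_seqlen, nheads, headdim, dstate, chunk_size = 16, 4, 64, 16, 8
ngroups = 1
nchunks = (total_seqlen + chunk_size - 1) // chunk_size
device = "cuda"

x = torch.randn(total_seqlen, nheads, headdim, device=device, dtype=torch.float16)
B = torch.randn(total_seqlen, ngroups, dstate, device=device, dtype=torch.float16)
dt = torch.rand(nheads, nchunks, chunk_size, device=device)
dA_cumsum = torch.cumsum(-torch.rand(nheads, nchunks, chunk_size, device=device), dim=-1)
chunk_states = torch.randn(nchunks, nheads, headdim, dstate, device=device, dtype=torch.float16)
cu_seqlens = torch.tensor([0, 5, 16], device=device, dtype=torch.int32)   # lengths 5 and 11

states = chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states)
# states: (batch=2, nheads, headdim, dstate), one final state per packed sequence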
1929
- class ChunkStateFn(torch.autograd.Function):
1930
-
1931
- @staticmethod
1932
- def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True):
1933
- batch, seqlen, nheads, headdim = x.shape
1934
- _, _, nchunks, chunk_size = dt.shape
1935
- assert seqlen <= nchunks * chunk_size
1936
- _, _, ngroups, dstate = B.shape
1937
- assert B.shape == (batch, seqlen, ngroups, dstate)
1938
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1939
- assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
1940
- if B.stride(-1) != 1:
1941
- B = B.contiguous()
1942
- if (
1943
- x.stride(-1) != 1 and x.stride(1) != 1
1944
- ): # Either M or K dimension should be contiguous
1945
- x = x.contiguous()
1946
- states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32)
1947
- ctx.save_for_backward(B, x, dt, dA_cumsum)
1948
- return states
1949
-
1950
- @staticmethod
1951
- def backward(ctx, dstates):
1952
- B, x, dt, dA_cumsum = ctx.saved_tensors
1953
- batch, seqlen, nheads, headdim = x.shape
1954
- _, _, nchunks, chunk_size = dt.shape
1955
- _, _, ngroups, dstate = B.shape
1956
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1957
- if dstates.stride(-1) != 1:
1958
- dstates = dstates.contiguous()
1959
- dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates)
1960
- dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups)
1961
- dB = dB.to(B.dtype)
1962
- return dB, dx, ddt, ddA_cumsum, None
1963
-
1964
-
1965
- def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True):
1966
- """
1967
- Argument:
1968
- B: (batch, seqlen, ngroups, dstate)
1969
- x: (batch, seqlen, nheads, headdim)
1970
- dt: (batch, nheads, nchunks, chunk_size)
1971
- dA_cumsum: (batch, nheads, nchunks, chunk_size)
1972
- Return:
1973
- states: (batch, nchunks, nheads, headdim, dstate)
1974
- """
1975
- return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32)
1976
-
1977
-
1978
- def chunk_state_ref(B, x, dt, dA_cumsum):
1979
- """
1980
- Argument:
1981
- B: (batch, seqlen, ngroups, dstate)
1982
- x: (batch, seqlen, nheads, headdim)
1983
- dt: (batch, nheads, nchunks, chunk_size)
1984
- dA_cumsum: (batch, nheads, nchunks, chunk_size)
1985
- Return:
1986
- states: (batch, nchunks, nheads, headdim, dstate)
1987
- """
1988
- # Check constraints.
1989
- batch, seqlen, nheads, headdim = x.shape
1990
- dstate = B.shape[-1]
1991
- _, _, nchunks, chunk_size = dt.shape
1992
- assert seqlen <= nchunks * chunk_size
1993
- assert x.shape == (batch, seqlen, nheads, headdim)
1994
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
1995
- ngroups = B.shape[2]
1996
- assert nheads % ngroups == 0
1997
- assert B.shape == (batch, seqlen, ngroups, dstate)
1998
- B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
1999
- assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
2000
- if seqlen < nchunks * chunk_size:
2001
- x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
2002
- B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
2003
- x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
2004
- B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
2005
- decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))
2006
- return torch.einsum(
2007
- "bclhn,bhcl,bhcl,bclhp->bchpn",
2008
- B.to(x.dtype),
2009
- decay_states.to(x.dtype),
2010
- dt.to(x.dtype),
2011
- x,
2012
- )
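A quick numerical check of the Triton path against the `chunk_state_ref` einsum above, with made-up sizes (needs a CUDA device):

import torch

batch, seqlen, nheads, headdim = 2, 512, 8, 64
ngroups, dstate, chunk_size = 1, 16, 128
nchunks = seqlen // chunk_size
device = "cuda"

x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16)
B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=torch.bfloat16)
dt = torch.rand(batch, nheads, nchunks, chunk_size, device=device)
dA_cumsum = torch.cumsum(-torch.rand_like(dt), dim=-1)

states = chunk_state(B, x, dt, dA_cumsum)           # Triton kernel, fp32 states
states_ref = chunk_state_ref(B, x, dt, dA_cumsum)   # pure-PyTorch einsum in bf16
print((states - states_ref.float()).abs().max())    # small, dominated by bf16 rounding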
 
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_combined.py DELETED
@@ -1,1884 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao, Albert Gu.
2
-
3
- """We want triton==2.1.0 or 2.2.0 for this
4
- """
5
-
6
- from typing import Optional
7
-
8
- import math
9
- from packaging import version
10
-
11
- import torch
12
- import torch.nn.functional as F
13
- from torch import Tensor
14
- from ...utils.torch import custom_bwd, custom_fwd
15
-
16
- import triton
17
- import triton.language as tl
18
-
19
- from einops import rearrange, repeat
20
-
21
- try:
22
- from causal_conv1d import causal_conv1d_fn
23
- import causal_conv1d_cuda
24
- except ImportError:
25
- causal_conv1d_fn, causal_conv1d_cuda = None, None
26
-
27
- from .ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
28
- from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
29
- from .ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
30
- from .ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
31
- from .ssd_chunk_state import chunk_state, chunk_state_ref
32
- from .ssd_chunk_state import chunk_state_varlen
33
- from .ssd_state_passing import _state_passing_fwd, _state_passing_bwd
34
- from .ssd_state_passing import state_passing, state_passing_ref
35
- from .ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates
36
- from .ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
37
- from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
38
- from .ssd_chunk_scan import chunk_scan, chunk_scan_ref
39
- from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
40
- from .layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
41
- from .k_activations import _swiglu_fwd, _swiglu_bwd
42
-
43
- TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
44
-
45
-
46
- def init_to_zero(names):
47
- return lambda nargs: [
48
- nargs[name].zero_() for name in names if nargs[name] is not None
49
- ]
50
-
51
-
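The pre_hook built by init_to_zero matters because the kernel below accumulates its ddt output with tl.atomic_add, so every autotuning trial has to start from a zeroed buffer. A small self-contained sketch of what the hook does (the helper is re-declared here so the snippet runs on its own):

# Hedged sketch: the zeroing pre_hook used by the autotune configs below.
import torch

def init_to_zero(names):
    # same helper as above: zero the named kernel arguments before each trial
    return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]

hook = init_to_zero(["ddt_ptr"])
nargs = {"ddt_ptr": torch.ones(4)}        # stand-in for the kernel's argument dict
hook(nargs)
print(nargs["ddt_ptr"])                   # tensor([0., 0., 0., 0.])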
52
- @triton.autotune(
53
- configs=[
54
- triton.Config(
55
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
56
- num_stages=3,
57
- num_warps=8,
58
- pre_hook=init_to_zero(["ddt_ptr"]),
59
- ),
60
- triton.Config(
61
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
62
- num_stages=4,
63
- num_warps=4,
64
- pre_hook=init_to_zero(["ddt_ptr"]),
65
- ),
66
- triton.Config(
67
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
68
- num_stages=4,
69
- num_warps=4,
70
- pre_hook=init_to_zero(["ddt_ptr"]),
71
- ),
72
- triton.Config(
73
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
74
- num_stages=4,
75
- num_warps=4,
76
- pre_hook=init_to_zero(["ddt_ptr"]),
77
- ),
78
- triton.Config(
79
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
80
- num_stages=4,
81
- num_warps=4,
82
- pre_hook=init_to_zero(["ddt_ptr"]),
83
- ),
84
- triton.Config(
85
- {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
86
- num_stages=4,
87
- num_warps=4,
88
- pre_hook=init_to_zero(["ddt_ptr"]),
89
- ),
90
- triton.Config(
91
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
92
- num_stages=5,
93
- num_warps=4,
94
- pre_hook=init_to_zero(["ddt_ptr"]),
95
- ),
96
- triton.Config(
97
- {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
98
- num_stages=5,
99
- num_warps=4,
100
- pre_hook=init_to_zero(["ddt_ptr"]),
101
- ),
102
- triton.Config(
103
- {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
104
- num_stages=4,
105
- num_warps=4,
106
- pre_hook=init_to_zero(["ddt_ptr"]),
107
- ),
108
- ],
109
- key=["chunk_size", "hdim", "dstate"],
110
- )
111
- @triton.jit
112
- def _chunk_scan_chunk_state_bwd_dx_kernel(
113
- # Pointers to matrices
114
- x_ptr,
115
- cb_ptr,
116
- dout_ptr,
117
- dt_ptr,
118
- dA_cumsum_ptr,
119
- seq_idx_ptr,
120
- D_ptr,
121
- b_ptr,
122
- dstates_ptr,
123
- dx_ptr,
124
- ddt_ptr,
125
- dD_ptr,
126
- # Matrix dimensions
127
- chunk_size,
128
- hdim,
129
- dstate,
130
- batch,
131
- seqlen,
132
- nheads_ngroups_ratio,
133
- # Strides
134
- stride_x_batch,
135
- stride_x_seqlen,
136
- stride_x_head,
137
- stride_x_hdim,
138
- stride_cb_batch,
139
- stride_cb_chunk,
140
- stride_cb_head,
141
- stride_cb_csize_m,
142
- stride_cb_csize_k,
143
- stride_dout_batch,
144
- stride_dout_seqlen,
145
- stride_dout_head,
146
- stride_dout_hdim,
147
- stride_dt_batch,
148
- stride_dt_chunk,
149
- stride_dt_head,
150
- stride_dt_csize,
151
- stride_dA_cs_batch,
152
- stride_dA_cs_chunk,
153
- stride_dA_cs_head,
154
- stride_dA_cs_csize,
155
- stride_seq_idx_batch,
156
- stride_seq_idx_seqlen,
157
- stride_D_head,
158
- stride_b_batch,
159
- stride_b_seqlen,
160
- stride_b_head,
161
- stride_b_dstate,
162
- stride_dstates_batch,
163
- stride_dstates_chunk,
164
- stride_dstates_head,
165
- stride_dstates_hdim,
166
- stride_dstates_dstate,
167
- stride_dx_batch,
168
- stride_dx_seqlen,
169
- stride_dx_head,
170
- stride_dx_hdim,
171
- stride_ddt_batch,
172
- stride_ddt_chunk,
173
- stride_ddt_head,
174
- stride_ddt_csize,
175
- stride_dD_batch,
176
- stride_dD_chunk,
177
- stride_dD_head,
178
- stride_dD_csize,
179
- stride_dD_hdim,
180
- # Meta-parameters
181
- HAS_D: tl.constexpr,
182
- D_HAS_HDIM: tl.constexpr,
183
- HAS_SEQ_IDX: tl.constexpr,
184
- BLOCK_SIZE_M: tl.constexpr,
185
- BLOCK_SIZE_N: tl.constexpr,
186
- BLOCK_SIZE_K: tl.constexpr,
187
- BLOCK_SIZE_DSTATE: tl.constexpr,
188
- IS_TRITON_22: tl.constexpr,
189
- ):
190
- pid_bc = tl.program_id(axis=1)
191
- pid_c = pid_bc // batch
192
- pid_b = pid_bc - pid_c * batch
193
- pid_h = tl.program_id(axis=2)
194
- num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
195
- pid_m = tl.program_id(axis=0) // num_pid_n
196
- pid_n = tl.program_id(axis=0) % num_pid_n
197
- x_ptr += (
198
- pid_b * stride_x_batch
199
- + pid_c * chunk_size * stride_x_seqlen
200
- + pid_h * stride_x_head
201
- )
202
- cb_ptr += (
203
- pid_b * stride_cb_batch
204
- + pid_c * stride_cb_chunk
205
- + (pid_h // nheads_ngroups_ratio) * stride_cb_head
206
- )
207
- dout_ptr += (
208
- pid_b * stride_dout_batch
209
- + pid_c * chunk_size * stride_dout_seqlen
210
- + pid_h * stride_dout_head
211
- )
212
- dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
213
- ddt_ptr += (
214
- pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
215
- )
216
- dA_cumsum_ptr += (
217
- pid_b * stride_dA_cs_batch
218
- + pid_c * stride_dA_cs_chunk
219
- + pid_h * stride_dA_cs_head
220
- )
221
- b_ptr += (
222
- pid_b * stride_b_batch
223
- + pid_c * chunk_size * stride_b_seqlen
224
- + (pid_h // nheads_ngroups_ratio) * stride_b_head
225
- )
226
- dstates_ptr += (
227
- pid_b * stride_dstates_batch
228
- + pid_c * stride_dstates_chunk
229
- + pid_h * stride_dstates_head
230
- )
231
- if HAS_SEQ_IDX:
232
- seq_idx_ptr += (
233
- pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
234
- )
235
-
236
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
237
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
238
-
239
- chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
240
-
241
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
242
-
243
- dA_cs_m = tl.load(
244
- dA_cumsum_ptr + offs_m * stride_dA_cs_csize,
245
- mask=offs_m < chunk_size_limit,
246
- other=0.0,
247
- ).to(tl.float32)
248
-
249
- dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
250
- tl.float32
251
- )
252
- if not HAS_SEQ_IDX:
253
- scale = tl.exp(dA_cs_last - dA_cs_m)
254
- else:
255
- seq_idx_m = tl.load(
256
- seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
257
- mask=offs_m < chunk_size_limit,
258
- other=-1,
259
- )
260
- seq_idx_last = tl.load(
261
- seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
262
- )
263
- scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
264
- # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
265
- # However, we're getting error with the Triton compiler 2.1.0 for that code path:
266
- # Unexpected mma -> mma layout conversion
267
- # Triton 2.2.0 fixes this
268
- offs_dstate = tl.arange(
269
- 0,
270
- (
271
- BLOCK_SIZE_DSTATE
272
- if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128
273
- else BLOCK_SIZE_K
274
- ),
275
- )
276
- b_ptrs = b_ptr + (
277
- offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate
278
- )
279
- dstates_ptrs = dstates_ptr + (
280
- offs_n[None, :] * stride_dstates_hdim
281
- + offs_dstate[:, None] * stride_dstates_dstate
282
- )
283
- if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:
284
- b = tl.load(
285
- b_ptrs,
286
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate),
287
- other=0.0,
288
- )
289
- dstates = tl.load(
290
- dstates_ptrs,
291
- mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
292
- other=0.0,
293
- )
294
- dstates = dstates.to(b_ptr.dtype.element_ty)
295
- acc = tl.dot(b, dstates) * scale[:, None]
296
- else:
297
- for k in range(0, dstate, BLOCK_SIZE_K):
298
- b = tl.load(
299
- b_ptrs,
300
- mask=(offs_m[:, None] < chunk_size_limit)
301
- & (offs_dstate[None, :] < dstate - k),
302
- other=0.0,
303
- )
304
- dstates = tl.load(
305
- dstates_ptrs,
306
- mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim),
307
- other=0.0,
308
- )
309
- dstates = dstates.to(b_ptr.dtype.element_ty)
310
- acc += tl.dot(b, dstates)
311
- b_ptrs += BLOCK_SIZE_K * stride_b_dstate
312
- dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate
313
- acc *= scale[:, None]
314
-
315
- # x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
316
- # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
317
- # dt_ptrs = dt_ptr + offs_m * stride_dt_csize
318
- # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
319
- # ddt = tl.sum(acc * x, axis=1) * dt_m
320
- # ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
321
- # tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
322
-
323
- offs_k = tl.arange(0, BLOCK_SIZE_K)
324
- cb_ptrs = cb_ptr + (
325
- offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k
326
- )
327
- dout_ptrs = dout_ptr + (
328
- offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim
329
- )
330
- dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
331
- K_MAX = chunk_size_limit
332
- K_MIN = pid_m * BLOCK_SIZE_M
333
- cb_ptrs += K_MIN * stride_cb_csize_k
334
- dout_ptrs += K_MIN * stride_dout_seqlen
335
- dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize
336
- for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):
337
- k = tl.multiple_of(k, BLOCK_SIZE_K)
338
- # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower
339
- cb = tl.load(
340
- cb_ptrs,
341
- mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k),
342
- other=0.0,
343
- )
344
- dout = tl.load(
345
- dout_ptrs,
346
- mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim),
347
- other=0.0,
348
- )
349
- dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(
350
- tl.float32
351
- )
352
- cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
353
- # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
354
- # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
355
- # Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
356
- # This will cause NaN in acc, and hence NaN in dx and ddt.
357
- mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)
358
- cb = tl.where(mask, cb, 0.0)
359
- cb = cb.to(dout_ptr.dtype.element_ty)
360
- acc += tl.dot(cb, dout)
361
- cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
362
- dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
363
- dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
364
-
365
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
366
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
367
- dt_ptrs = dt_ptr + offs_m * stride_dt_csize
368
- dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
369
- dx = acc * dt_m[:, None]
370
- dx_ptr += (
371
- pid_b * stride_dx_batch
372
- + pid_c * chunk_size * stride_dx_seqlen
373
- + pid_h * stride_dx_head
374
- )
375
- dx_ptrs = dx_ptr + (
376
- offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim
377
- )
378
- if HAS_D:
379
- dout_res_ptrs = dout_ptr + (
380
- offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim
381
- )
382
- dout_res = tl.load(
383
- dout_res_ptrs,
384
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
385
- other=0.0,
386
- ).to(tl.float32)
387
- if D_HAS_HDIM:
388
- D = tl.load(
389
- D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0
390
- ).to(tl.float32)
391
- else:
392
- D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
393
- dx += dout_res * D
394
- tl.store(
395
- dx_ptrs,
396
- dx,
397
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
398
- )
399
-
400
- x_ptrs = x_ptr + (
401
- offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
402
- )
403
- x = tl.load(
404
- x_ptrs,
405
- mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
406
- other=0.0,
407
- ).to(tl.float32)
408
- if HAS_D:
409
- dD_ptr += (
410
- pid_b * stride_dD_batch
411
- + pid_c * stride_dD_chunk
412
- + pid_h * stride_dD_head
413
- + pid_m * stride_dD_csize
414
- )
415
- if D_HAS_HDIM:
416
- dD_ptrs = dD_ptr + offs_n * stride_dD_hdim
417
- dD = tl.sum(dout_res * x, axis=0)
418
- tl.store(dD_ptrs, dD, mask=offs_n < hdim)
419
- else:
420
- dD = tl.sum(dout_res * x)
421
- tl.store(dD_ptr, dD)
422
- ddt = tl.sum(acc * x, axis=1)
423
- ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
424
- tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
425
-
426
-
427
- def _chunk_scan_chunk_state_bwd_dx(
428
- x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None
429
- ):
430
- batch, seqlen, nheads, headdim = x.shape
431
- _, _, nchunks, chunk_size = dt.shape
432
- _, _, ngroups, dstate = B.shape
433
- assert nheads % ngroups == 0
434
- assert B.shape == (batch, seqlen, ngroups, dstate)
435
- assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
436
- assert dt.shape == (batch, nheads, nchunks, chunk_size)
437
- assert dA_cumsum.shape == dt.shape
438
- assert dout.shape == x.shape
439
- assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
440
- if seq_idx is not None:
441
- assert seq_idx.shape == (batch, seqlen)
442
- if D is not None:
443
- assert D.shape == (nheads, headdim) or D.shape == (nheads,)
444
- assert D.stride(-1) == 1
445
- BLOCK_SIZE_min = 32
446
- dD = torch.empty(
447
- triton.cdiv(chunk_size, BLOCK_SIZE_min),
448
- batch,
449
- nchunks,
450
- nheads,
451
- headdim if D.dim() == 2 else 1,
452
- device=D.device,
453
- dtype=torch.float32,
454
- )
455
- else:
456
- dD = None
457
- dD_strides = (
458
- (dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
459
- if D is not None
460
- else (0, 0, 0, 0, 0)
461
- )
462
- if dx is None:
463
- dx = torch.empty_like(x)
464
- else:
465
- assert dx.shape == x.shape
466
- ddt = torch.empty(
467
- batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32
468
- )
469
- grid_dx = lambda META: (
470
- triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
471
- * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
472
- batch * nchunks,
473
- nheads,
474
- )
475
- with torch.cuda.device(x.device.index):
476
- _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](
477
- x,
478
- CB,
479
- dout,
480
- dt,
481
- dA_cumsum,
482
- seq_idx,
483
- D,
484
- B,
485
- dstates,
486
- dx,
487
- ddt,
488
- dD,
489
- chunk_size,
490
- headdim,
491
- dstate,
492
- batch,
493
- seqlen,
494
- nheads // ngroups,
495
- x.stride(0),
496
- x.stride(1),
497
- x.stride(2),
498
- x.stride(3),
499
- CB.stride(0),
500
- CB.stride(1),
501
- CB.stride(2),
502
- CB.stride(-1),
503
- CB.stride(-2),
504
- dout.stride(0),
505
- dout.stride(1),
506
- dout.stride(2),
507
- dout.stride(3),
508
- dt.stride(0),
509
- dt.stride(2),
510
- dt.stride(1),
511
- dt.stride(3),
512
- dA_cumsum.stride(0),
513
- dA_cumsum.stride(2),
514
- dA_cumsum.stride(1),
515
- dA_cumsum.stride(3),
516
- *(
517
- (seq_idx.stride(0), seq_idx.stride(1))
518
- if seq_idx is not None
519
- else (0, 0)
520
- ),
521
- D.stride(0) if D is not None else 0,
522
- B.stride(0),
523
- B.stride(1),
524
- B.stride(2),
525
- B.stride(3),
526
- dstates.stride(0),
527
- dstates.stride(1),
528
- dstates.stride(2),
529
- dstates.stride(3),
530
- dstates.stride(4),
531
- dx.stride(0),
532
- dx.stride(1),
533
- dx.stride(2),
534
- dx.stride(3),
535
- ddt.stride(0),
536
- ddt.stride(2),
537
- ddt.stride(1),
538
- ddt.stride(3),
539
- dD_strides[1],
540
- dD_strides[2],
541
- dD_strides[3],
542
- dD_strides[0],
543
- dD_strides[4],
544
- D is not None,
545
- D.dim() == 2 if D is not None else True,
546
- HAS_SEQ_IDX=seq_idx is not None,
547
- BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
548
- IS_TRITON_22=TRITON_22
549
- )
550
- if D is not None:
551
- BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs[
552
- "BLOCK_SIZE_M"
553
- ]
554
- n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
555
- dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
556
- if D.dim() == 1:
557
- dD = rearrange(dD, "h 1 -> h")
558
- return dx, ddt.to(dtype=dt.dtype), dD
559
-
560
-
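grid_dx above launches one program per (BLOCK_SIZE_M x BLOCK_SIZE_N) tile of a chunk, per (batch, chunk) pair, and per head. A small sketch of the resulting grid for assumed example sizes (the dimensions and the chosen config below are illustrative, not taken from this repository):

# Hedged sketch: how the 3-D launch grid for the dx kernel is sized.
import math

chunk_size, headdim = 256, 64
batch, nchunks, nheads = 2, 8, 24
META = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64}   # M/N from one of the autotune configs above

grid = (
    math.ceil(chunk_size / META["BLOCK_SIZE_M"]) * math.ceil(headdim / META["BLOCK_SIZE_N"]),
    batch * nchunks,
    nheads,
)
print(grid)   # (4, 16, 24): 4 tiles per chunk, one program per (batch, chunk) and per head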
561
- def _mamba_chunk_scan_combined_fwd(
562
- x,
563
- dt,
564
- A,
565
- B,
566
- C,
567
- chunk_size,
568
- D=None,
569
- z=None,
570
- dt_bias=None,
571
- initial_states=None,
572
- seq_idx=None,
573
- cu_seqlens=None,
574
- dt_softplus=False,
575
- dt_limit=(0.0, float("inf")),
576
- ):
577
- batch, seqlen, nheads, headdim = x.shape
578
- _, _, ngroups, dstate = B.shape
579
- assert nheads % ngroups == 0
580
- assert B.shape == (batch, seqlen, ngroups, dstate)
581
- assert x.shape == (batch, seqlen, nheads, headdim)
582
- assert dt.shape == (batch, seqlen, nheads)
583
- assert A.shape == (nheads,)
584
- assert C.shape == B.shape
585
- if z is not None:
586
- assert z.shape == x.shape
587
- if D is not None:
588
- assert D.shape == (nheads, headdim) or D.shape == (nheads,)
589
- if seq_idx is not None:
590
- assert seq_idx.shape == (batch, seqlen)
591
- if B.stride(-1) != 1:
592
- B = B.contiguous()
593
- if C.stride(-1) != 1:
594
- C = C.contiguous()
595
- if (
596
- x.stride(-1) != 1 and x.stride(1) != 1
597
- ): # Either M or K dimension should be contiguous
598
- x = x.contiguous()
599
- if (
600
- z is not None and z.stride(-1) != 1 and z.stride(1) != 1
601
- ): # Either M or K dimension should be contiguous
602
- z = z.contiguous()
603
- if D is not None and D.stride(-1) != 1:
604
- D = D.contiguous()
605
- if initial_states is not None:
606
- assert initial_states.shape == (batch, nheads, headdim, dstate)
607
- # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
608
- # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
609
- # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
610
- # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
611
- dA_cumsum, dt = _chunk_cumsum_fwd(
612
- dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit
613
- )
614
- states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
615
- # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
616
- # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
617
- # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)
618
- states, final_states = _state_passing_fwd(
619
- rearrange(states, "... p n -> ... (p n)"),
620
- dA_cumsum[:, :, :, -1],
621
- initial_states=(
622
- rearrange(initial_states, "... p n -> ... (p n)")
623
- if initial_states is not None
624
- else None
625
- ),
626
- seq_idx=seq_idx,
627
- chunk_size=chunk_size,
628
- out_dtype=C.dtype,
629
- )
630
- states, final_states = [
631
- rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]
632
- ]
633
- # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
634
- # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
635
- CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
636
- out, out_x = _chunk_scan_fwd(
637
- CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx
638
- )
639
- if cu_seqlens is None:
640
- return out, out_x, dt, dA_cumsum, states, final_states
641
- else:
642
- assert (
643
- batch == 1
644
- ), "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
645
- varlen_states = chunk_state_varlen(
646
- B.squeeze(0),
647
- x.squeeze(0),
648
- dt.squeeze(0),
649
- dA_cumsum.squeeze(0),
650
- cu_seqlens,
651
- states.squeeze(0),
652
- )
653
- return out, out_x, dt, dA_cumsum, states, final_states, varlen_states
654
-
655
-
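A shape walkthrough of the forward pipeline above, inferred from the assertions in this file, for assumed example sizes:

# Hedged shape walkthrough of _mamba_chunk_scan_combined_fwd (illustrative sizes only).
batch, seqlen, nheads, headdim = 2, 512, 8, 64
ngroups, dstate, chunk_size = 1, 128, 256
nchunks = -(-seqlen // chunk_size)        # ceil division -> 2

shapes = {
    "x": (batch, seqlen, nheads, headdim),                                # (2, 512, 8, 64)
    "dt, dA_cumsum after _chunk_cumsum_fwd": (batch, nheads, nchunks, chunk_size),
    "states after _chunk_state_fwd": (batch, nchunks, nheads, headdim, dstate),
    "CB after _bmm_chunk_fwd": (batch, nchunks, ngroups, chunk_size, chunk_size),
    "out after _chunk_scan_fwd": (batch, seqlen, nheads, headdim),
    "final_states": (batch, nheads, headdim, dstate),
}
for name, shape in shapes.items():
    print(f"{name}: {shape}")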
656
- def _mamba_chunk_scan_combined_bwd(
657
- dout,
658
- x,
659
- dt,
660
- A,
661
- B,
662
- C,
663
- out,
664
- chunk_size,
665
- D=None,
666
- z=None,
667
- dt_bias=None,
668
- initial_states=None,
669
- dfinal_states=None,
670
- seq_idx=None,
671
- dt_softplus=False,
672
- dt_limit=(0.0, float("inf")),
673
- dx=None,
674
- ddt=None,
675
- dB=None,
676
- dC=None,
677
- dz=None,
678
- recompute_output=False,
679
- ):
680
- if dout.stride(-1) != 1:
681
- dout = dout.contiguous()
682
- batch, seqlen, nheads, headdim = x.shape
683
- nchunks = math.ceil(seqlen / chunk_size)
684
- _, _, ngroups, dstate = B.shape
685
- assert dout.shape == (batch, seqlen, nheads, headdim)
686
- assert dt.shape == (batch, seqlen, nheads)
687
- assert A.shape == (nheads,)
688
- assert nheads % ngroups == 0
689
- assert B.shape == (batch, seqlen, ngroups, dstate)
690
- assert C.shape == B.shape
691
- assert out.shape == x.shape
692
- if initial_states is not None:
693
- assert initial_states.shape == (batch, nheads, headdim, dstate)
694
- if seq_idx is not None:
695
- assert seq_idx.shape == (batch, seqlen)
696
- if dx is not None:
697
- assert dx.shape == x.shape
698
- if dB is not None:
699
- assert dB.shape == B.shape
700
- dB_given = dB
701
- else:
702
- dB_given = torch.empty_like(B)
703
- if dC is not None:
704
- assert dC.shape == C.shape
705
- dC_given = dC
706
- else:
707
- dC_given = torch.empty_like(C)
708
- if dz is not None:
709
- assert z is not None
710
- assert dz.shape == z.shape
711
- if ddt is not None:
712
- assert ddt.shape == dt.shape
713
- ddt_given = ddt
714
- else:
715
- ddt_given = torch.empty_like(dt)
716
- # TD: For some reason Triton (2.1.0 and 2.2.0) errors with
717
- # "[CUDA]: invalid device context" (e.g. during varlne test), and cloning makes it work. Idk why.
718
- dt_in = dt.clone()
719
- dA_cumsum, dt = _chunk_cumsum_fwd(
720
- dt_in,
721
- A,
722
- chunk_size,
723
- dt_bias=dt_bias,
724
- dt_softplus=dt_softplus,
725
- dt_limit=dt_limit,
726
- )
727
- CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
728
- states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
729
- states, _ = _state_passing_fwd(
730
- rearrange(states, "... p n -> ... (p n)"),
731
- dA_cumsum[:, :, :, -1],
732
- initial_states=(
733
- rearrange(initial_states, "... p n -> ... (p n)")
734
- if initial_states is not None
735
- else None
736
- ),
737
- seq_idx=seq_idx,
738
- chunk_size=chunk_size,
739
- )
740
- states = rearrange(states, "... (p n) -> ... p n", n=dstate)
741
- if z is not None:
742
- dz, dout, dD, *rest = _chunk_scan_bwd_dz(
743
- x,
744
- z,
745
- out,
746
- dout,
747
- chunk_size=chunk_size,
748
- has_ddAcs=False,
749
- D=D,
750
- dz=dz,
751
- recompute_output=recompute_output,
752
- )
753
- outz = rest[0] if recompute_output else out
754
- else:
755
- dz = None
756
- outz = out
757
- dstates = _chunk_scan_bwd_dstates(
758
- C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype
759
- )
760
- # dstates has length nchunks, containing the gradient to initial states at index 0 and
761
- # gradient to the states of chunk (nchunks - 2) at index (nchunks - 1)
762
- # Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states
763
- # will be used in matmul in the next kernels.
764
- dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd(
765
- rearrange(states, "... p n -> ... (p n)"),
766
- dA_cumsum[:, :, :, -1],
767
- rearrange(dstates, "... p n -> ... (p n)"),
768
- dfinal_states=(
769
- rearrange(dfinal_states, "... p n -> ... (p n)")
770
- if dfinal_states is not None
771
- else None
772
- ),
773
- seq_idx=seq_idx,
774
- has_initial_states=initial_states is not None,
775
- dstates_dtype=x.dtype,
776
- states_dtype=x.dtype,
777
- chunk_size=chunk_size,
778
- )
779
- # dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and
780
- # gradient to the final states at index (nchunks - 1)
781
- # states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1)
782
- # The final states is not stored.
783
- states = rearrange(states, "... (p n) -> ... p n", n=dstate)
784
- dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate)
785
- dinitial_states = (
786
- rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate)
787
- if dinitial_states is not None
788
- else None
789
- )
790
- dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(
791
- x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx
792
- )
793
- # dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups)
794
- dB, ddA_next = _chunk_state_bwd_db(
795
- x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups
796
- )
797
- # dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
798
- dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(
799
- states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups
800
- )
801
- # Computing ddA with the dcb kernel is much slower, so we're not using it for now
802
- dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
803
- # dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups)
804
- dCB = dCB.to(CB.dtype)
805
- _bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given)
806
- _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given)
807
- # If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate
808
- # than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16
809
- if z is None:
810
- dD = dD_from_x
811
- # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
812
- # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
813
- # However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might
814
- # be a lot of underflow.
815
-
816
- # This is already done as part of bwd_dC kernel
817
- # ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx)
818
- ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum
819
- ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1])
820
- # This is already done as part of bwd_dB kernel
821
- # ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx)
822
- # We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j]
823
- ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB)
824
- ddA += ddA_next + ddA_prev
825
-
826
- ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(
827
- ddA,
828
- ddt,
829
- dt_in,
830
- A,
831
- dt_bias=dt_bias,
832
- dt_softplus=dt_softplus,
833
- dt_limit=dt_limit,
834
- ddt=ddt_given,
835
- )
836
-
837
- # These 2 lines are just to test ddt and dA being computed by old code
838
- # _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z)
839
- # ddt_given.copy_(ddt)
840
-
841
- return_vals = (
842
- dx,
843
- ddt_given,
844
- dA,
845
- dB_given,
846
- dC_given,
847
- dD,
848
- dz,
849
- ddt_bias,
850
- dinitial_states,
851
- )
852
- return return_vals if not recompute_output else (*return_vals, outz)
853
-
854
-
855
- def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None):
856
- """
857
- Argument:
858
- dout: (batch, seqlen, nheads, headdim)
859
- x: (batch, seqlen, nheads, headdim)
860
- dt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size)
861
- A: (nheads) or (dim, dstate)
862
- B: (batch, seqlen, ngroups, dstate)
863
- C: (batch, seqlen, ngroups, dstate)
864
- D: (nheads, headdim) or (nheads,)
865
- z: (batch, seqlen, nheads, headdim)
866
- Return:
867
-        ddt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size); dA: (nheads) or (dim, dstate)
868
- """
869
- import selective_scan
870
-
871
- batch, seqlen, nheads, headdim = x.shape
872
- chunk_size = dt.shape[-1]
873
- _, _, ngroups, dstate = B.shape
874
- assert nheads % ngroups == 0
875
- x = rearrange(x, "b l h p -> b (h p) l")
876
- squeeze_dt = dt.dim() == 4
877
- if dt.dim() == 4:
878
- dt = repeat(dt, "b h c l -> b h p c l", p=headdim)
879
- dt = rearrange(dt, "b h p c l -> b (h p) (c l)", p=headdim)
880
- squeeze_A = A.dim() == 1
881
- if A.dim() == 1:
882
- A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
883
- else:
884
- A = A.to(dtype=torch.float32)
885
- B = rearrange(B, "b l g n -> b g n l")
886
- C = rearrange(C, "b l g n -> b g n l")
887
- if D is not None:
888
- if D.dim() == 2:
889
- D = rearrange(D, "h p -> (h p)")
890
- else:
891
- D = repeat(D, "h -> (h p)", p=headdim)
892
- if z is not None:
893
- z = rearrange(z, "b l h p -> b (h p) l")
894
-
895
- if x.stride(-1) != 1:
896
- x = x.contiguous()
897
- if dt.stride(-1) != 1:
898
- dt = dt.contiguous()
899
- if D is not None:
900
- D = D.contiguous()
901
- if B.stride(-1) != 1:
902
- B = B.contiguous()
903
- if C.stride(-1) != 1:
904
- C = C.contiguous()
905
- if z is not None and z.stride(-1) != 1:
906
- z = z.contiguous()
907
- _, intermediate, *rest = selective_scan.fwd(
908
- x, dt.to(dtype=x.dtype), A, B, C, D, z, None, False
909
- )
910
- if z is not None:
911
- out = rest[0]
912
- else:
913
- out = None
914
-
915
- dout = rearrange(dout, "b l h p -> b (h p) l")
916
-
917
- if dout.stride(-1) != 1:
918
- dout = dout.contiguous()
919
- # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
920
- # backward of selective_scan with the backward of chunk).
921
- # Here we just pass in None and dz will be allocated in the C++ code.
922
- _, ddt, dA, *rest = selective_scan.bwd(
923
- x,
924
- dt.to(dtype=x.dtype),
925
- A,
926
- B,
927
- C,
928
- D,
929
- z,
930
- None,
931
- dout,
932
- intermediate,
933
- out,
934
- None,
935
- False,
936
- False, # option to recompute out_z, not used here
937
- )
938
- ddt = rearrange(ddt, "b (h p) (c l) -> b h p c l", p=headdim, l=chunk_size)
939
- if squeeze_dt:
940
- ddt = ddt.float().sum(dim=2)
941
- if squeeze_A:
942
- dA = rearrange(dA, "(h p) n -> h p n", p=headdim).sum(dim=(1, 2))
943
- return ddt, dA
944
-
945
-
946
- class MambaChunkScanCombinedFn(torch.autograd.Function):
947
-
948
- @staticmethod
949
- def forward(
950
- ctx,
951
- x,
952
- dt,
953
- A,
954
- B,
955
- C,
956
- chunk_size,
957
- D=None,
958
- z=None,
959
- dt_bias=None,
960
- initial_states=None,
961
- seq_idx=None,
962
- cu_seqlens=None,
963
- dt_softplus=False,
964
- dt_limit=(0.0, float("inf")),
965
- return_final_states=False,
966
- return_varlen_states=False,
967
- ):
968
- ctx.dt_dtype = dt.dtype
969
- if not return_varlen_states:
970
- cu_seqlens = None
971
- else:
972
- assert (
973
- cu_seqlens is not None
974
- ), "cu_seqlens must be provided if return_varlen_states is True"
975
- out, out_x, dt_out, dA_cumsum, states, final_states, *rest = (
976
- _mamba_chunk_scan_combined_fwd(
977
- x,
978
- dt,
979
- A,
980
- B,
981
- C,
982
- chunk_size,
983
- D=D,
984
- z=z,
985
- dt_bias=dt_bias,
986
- initial_states=initial_states,
987
- seq_idx=seq_idx,
988
- cu_seqlens=cu_seqlens,
989
- dt_softplus=dt_softplus,
990
- dt_limit=dt_limit,
991
- )
992
- )
993
- ctx.save_for_backward(
994
- out if z is None else out_x,
995
- x,
996
- dt,
997
- dA_cumsum,
998
- A,
999
- B,
1000
- C,
1001
- D,
1002
- z,
1003
- dt_bias,
1004
- initial_states,
1005
- seq_idx,
1006
- )
1007
- ctx.dt_softplus = dt_softplus
1008
- ctx.chunk_size = chunk_size
1009
- ctx.dt_limit = dt_limit
1010
- ctx.return_final_states = return_final_states
1011
- ctx.return_varlen_states = return_varlen_states
1012
- if not return_varlen_states:
1013
- return out if not return_final_states else (out, final_states)
1014
- else:
1015
- varlen_states = rest[0]
1016
- return (
1017
- (out, varlen_states)
1018
- if not return_final_states
1019
- else (out, final_states, varlen_states)
1020
- )
1021
-
1022
- @staticmethod
1023
- def backward(ctx, dout, *args):
1024
- out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = (
1025
- ctx.saved_tensors
1026
- )
1027
- assert (
1028
- not ctx.return_varlen_states
1029
- ), "return_varlen_states is not supported in backward"
1030
- dfinal_states = args[0] if ctx.return_final_states else None
1031
- dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = (
1032
- _mamba_chunk_scan_combined_bwd(
1033
- dout,
1034
- x,
1035
- dt,
1036
- A,
1037
- B,
1038
- C,
1039
- out,
1040
- ctx.chunk_size,
1041
- D=D,
1042
- z=z,
1043
- dt_bias=dt_bias,
1044
- initial_states=initial_states,
1045
- dfinal_states=dfinal_states,
1046
- seq_idx=seq_idx,
1047
- dt_softplus=ctx.dt_softplus,
1048
- dt_limit=ctx.dt_limit,
1049
- )
1050
- )
1051
- return (
1052
- dx,
1053
- ddt,
1054
- dA,
1055
- dB,
1056
- dC,
1057
- None,
1058
- dD,
1059
- dz,
1060
- ddt_bias,
1061
- dinitial_states,
1062
- None,
1063
- None,
1064
- None,
1065
- None,
1066
- None,
1067
- None,
1068
- )
1069
-
1070
-
1071
- def mamba_chunk_scan_combined(
1072
- x,
1073
- dt,
1074
- A,
1075
- B,
1076
- C,
1077
- chunk_size,
1078
- D=None,
1079
- z=None,
1080
- dt_bias=None,
1081
- initial_states=None,
1082
- seq_idx=None,
1083
- cu_seqlens=None,
1084
- dt_softplus=False,
1085
- dt_limit=(0.0, float("inf")),
1086
- return_final_states=False,
1087
- return_varlen_states=False,
1088
- ):
1089
- """
1090
- Argument:
1091
- x: (batch, seqlen, nheads, headdim)
1092
- dt: (batch, seqlen, nheads)
1093
- A: (nheads)
1094
- B: (batch, seqlen, ngroups, dstate)
1095
- C: (batch, seqlen, ngroups, dstate)
1096
- chunk_size: int
1097
- D: (nheads, headdim) or (nheads,)
1098
- z: (batch, seqlen, nheads, headdim)
1099
- dt_bias: (nheads,)
1100
- initial_states: (batch, nheads, headdim, dstate)
1101
- seq_idx: (batch, seqlen)
1102
- cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
1103
- dt_softplus: Whether to apply softplus to dt
1104
- Return:
1105
- out: (batch, seqlen, nheads, headdim)
1106
- """
1107
- return MambaChunkScanCombinedFn.apply(
1108
- x,
1109
- dt,
1110
- A,
1111
- B,
1112
- C,
1113
- chunk_size,
1114
- D,
1115
- z,
1116
- dt_bias,
1117
- initial_states,
1118
- seq_idx,
1119
- cu_seqlens,
1120
- dt_softplus,
1121
- dt_limit,
1122
- return_final_states,
1123
- return_varlen_states,
1124
- )
1125
-
1126
-
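A minimal usage sketch for mamba_chunk_scan_combined, with shapes following the docstring above. The sizes, dtypes, and random inputs are assumptions for illustration; running it requires a CUDA device plus the Triton kernels shipped with this package:

# Hedged usage sketch for the public entry point defined above.
import torch
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined  # import path per this repo's layout

batch, seqlen, nheads, headdim = 2, 512, 8, 64
ngroups, dstate, chunk_size = 1, 128, 256
dev, dtype = "cuda", torch.bfloat16

x = torch.randn(batch, seqlen, nheads, headdim, device=dev, dtype=dtype)
dt = torch.rand(batch, seqlen, nheads, device=dev, dtype=torch.float32)
A = -torch.rand(nheads, device=dev, dtype=torch.float32)               # negative for decay
B = torch.randn(batch, seqlen, ngroups, dstate, device=dev, dtype=dtype)
C = torch.randn(batch, seqlen, ngroups, dstate, device=dev, dtype=dtype)
D = torch.randn(nheads, device=dev, dtype=torch.float32)

out = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=D, dt_softplus=True)
print(out.shape)   # (2, 512, 8, 64)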
1127
- def mamba_chunk_scan(
1128
- x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False
1129
- ):
1130
- """
1131
- Argument:
1132
- x: (batch, seqlen, nheads, headdim)
1133
- dt: (batch, seqlen, nheads)
1134
- A: (nheads)
1135
- B: (batch, seqlen, ngroups, dstate)
1136
- C: (batch, seqlen, ngroups, dstate)
1137
- D: (nheads, headdim) or (nheads,)
1138
- z: (batch, seqlen, nheads, headdim)
1139
- dt_bias: (nheads,)
1140
- Return:
1141
- out: (batch, seqlen, nheads, headdim)
1142
- """
1143
- batch, seqlen, nheads, headdim = x.shape
1144
- dstate = B.shape[-1]
1145
- if seqlen % chunk_size != 0:
1146
- dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
1147
- dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
1148
- dt = dt.float() # We want high precision for this before cumsum
1149
- if dt_bias is not None:
1150
- dt = dt + rearrange(dt_bias, "h -> h 1 1")
1151
- if dt_softplus:
1152
- dt = F.softplus(dt)
1153
- dA = dt * rearrange(A, "h -> h 1 1")
1155
- dA_cumsum = torch.cumsum(dA, dim=-1)
1156
- # 1. Compute the state for each chunk
1157
- states = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True)
1158
- # 2. Pass the state to all the chunks by weighted cumsum.
1159
- states = rearrange(
1160
- state_passing(
1161
- rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]
1162
- )[0],
1163
- "... (p n) -> ... p n",
1164
- n=dstate,
1165
- )
1166
- # 3. Compute the output for each chunk
1167
- out = chunk_scan(B, C, x, dt, dA_cumsum, states, D=D, z=z)
1168
- return out
1169
-
1170
-
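The three numbered steps in mamba_chunk_scan above are the chunked, parallel form of a simple per-timestep recurrence. The sketch below is a slow reference of that recurrence, not code from this file: z gating, dt bias/softplus, and initial states are omitted, and dt is assumed to already be the transformed step size.

# Hedged per-timestep reference of the scan that steps 1-3 above compute chunk-wise.
import torch

def naive_ssd_scan(x, dt, A, B, C, D=None):
    # x: (batch, seqlen, nheads, headdim), dt: (batch, seqlen, nheads), A: (nheads,)
    # B, C: (batch, seqlen, ngroups, dstate)
    batch, seqlen, nheads, headdim = x.shape
    ngroups, dstate = B.shape[2], B.shape[3]
    rep = nheads // ngroups
    h = torch.zeros(batch, nheads, headdim, dstate, dtype=torch.float32)
    ys = []
    for t in range(seqlen):
        Bt = B[:, t].repeat_interleave(rep, dim=1).float()   # (batch, nheads, dstate)
        Ct = C[:, t].repeat_interleave(rep, dim=1).float()
        decay = torch.exp(dt[:, t] * A)                      # (batch, nheads)
        h = decay[..., None, None] * h + torch.einsum(
            "bh,bhp,bhn->bhpn", dt[:, t].float(), x[:, t].float(), Bt
        )
        y = torch.einsum("bhpn,bhn->bhp", h, Ct)
        if D is not None:
            y = y + (D[:, None] if D.dim() == 1 else D) * x[:, t].float()
        ys.append(y)
    return torch.stack(ys, dim=1).to(x.dtype)                # (batch, seqlen, nheads, headdim)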
1171
- def ssd_chunk_scan_combined_ref(
1172
- x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False
1173
- ):
1174
- """
1175
- Argument:
1176
- x: (batch, seqlen, nheads, headdim)
1177
- dt: (batch, seqlen, nheads)
1178
- A: (nheads)
1179
- B: (batch, seqlen, ngroups, dstate)
1180
- C: (batch, seqlen, ngroups, dstate)
1181
- D: (nheads, headdim) or (nheads,)
1182
- z: (batch, seqlen, nheads, headdim)
1183
- dt_bias: (nheads,)
1184
- Return:
1185
- out: (batch, seqlen, nheads, headdim)
1186
- """
1187
- batch, seqlen, nheads, headdim = x.shape
1188
- dstate = B.shape[-1]
1189
- if seqlen % chunk_size != 0:
1190
- dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
1191
- dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
1192
- dt = dt.float() # We want high precision for this before cumsum
1193
- if dt_bias is not None:
1194
- dt = dt + rearrange(dt_bias, "h -> h 1 1")
1195
- if dt_softplus:
1196
- dt = F.softplus(dt)
1197
- dA = dt * rearrange(A, "h -> h 1 1")
1198
- dA_cumsum = torch.cumsum(dA, dim=-1)
1199
- # 1. Compute the state for each chunk
1200
- states = chunk_state_ref(B, x, dt, dA_cumsum)
1201
- states_dtype = states.dtype
1202
- if states.dtype not in [torch.float32, torch.float64]:
1203
- states = states.to(torch.float32)
1204
- # 2. Pass the state to all the chunks by weighted cumsum.
1205
- # state_passing_ref is much less numerically stable
1206
- states = rearrange(
1207
- state_passing_ref(
1208
- rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]
1209
- )[0],
1210
- "... (p n) -> ... p n",
1211
- n=dstate,
1212
- )
1213
- states = states.to(states_dtype)
1214
- # 3. Compute the output for each chunk
1215
- out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)
1216
- return out
1217
-
1218
-
1219
- def ssd_selective_scan(
1220
- x,
1221
- dt,
1222
- A,
1223
- B,
1224
- C,
1225
- D=None,
1226
- z=None,
1227
- dt_bias=None,
1228
- dt_softplus=False,
1229
- dt_limit=(0.0, float("inf")),
1230
- ):
1231
- """
1232
- Argument:
1233
- x: (batch, seqlen, nheads, headdim)
1234
- dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
1235
- A: (nheads) or (dim, dstate)
1236
- B: (batch, seqlen, ngroups, dstate)
1237
- C: (batch, seqlen, ngroups, dstate)
1238
- D: (nheads, headdim) or (nheads,)
1239
- z: (batch, seqlen, nheads, headdim)
1240
- dt_bias: (nheads,) or (nheads, headdim)
1241
- Return:
1242
- out: (batch, seqlen, nheads, headdim)
1243
- """
1244
- from ..selective_scan_interface import selective_scan_fn
1245
-
1246
- batch, seqlen, nheads, headdim = x.shape
1247
- _, _, ngroups, dstate = B.shape
1248
- x = rearrange(x, "b l h p -> b (h p) l")
1249
- if dt.dim() == 3:
1250
- dt = repeat(dt, "b l h -> b l h p", p=headdim)
1251
- dt = rearrange(dt, "b l h p -> b (h p) l")
1252
- if A.dim() == 1:
1253
- A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
1254
- else:
1255
- A = A.to(dtype=torch.float32)
1256
- B = rearrange(B, "b l g n -> b g n l")
1257
- C = rearrange(C, "b l g n -> b g n l")
1258
- if D is not None:
1259
- if D.dim() == 2:
1260
- D = rearrange(D, "h p -> (h p)")
1261
- else:
1262
- D = repeat(D, "h -> (h p)", p=headdim)
1263
- if z is not None:
1264
- z = rearrange(z, "b l h p -> b (h p) l")
1265
- if dt_bias is not None:
1266
- if dt_bias.dim() == 1:
1267
- dt_bias = repeat(dt_bias, "h -> h p", p=headdim)
1268
- dt_bias = rearrange(dt_bias, "h p -> (h p)")
1269
- if dt_limit != (0.0, float("inf")):
1270
- if dt_bias is not None:
1271
- dt = dt + rearrange(dt_bias, "d -> d 1")
1272
- if dt_softplus:
1273
- dt = F.softplus(dt)
1274
- dt = dt.clamp(min=dt_limit[0], max=dt_limit[1]).to(x.dtype)
1275
- dt_bias = None
1276
- dt_softplus = None
1277
- out = selective_scan_fn(
1278
- x, dt, A, B, C, D=D, z=z, delta_bias=dt_bias, delta_softplus=dt_softplus
1279
- )
1280
- return rearrange(out, "b (h p) l -> b l h p", p=headdim)
1281
-
1282
-
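ssd_selective_scan above is mostly a layout bridge: the multi-head (b, l, h, p) tensors used in this file are flattened to the (b, d, l) layout that selective_scan_fn expects, then reshaped back. A small sketch of that round trip (pure reshaping, no kernels):

# Hedged sketch of the layout round trip performed above.
import torch
from einops import rearrange

batch, seqlen, nheads, headdim = 2, 16, 4, 8
x = torch.randn(batch, seqlen, nheads, headdim)

flat = rearrange(x, "b l h p -> b (h p) l")        # layout expected by selective_scan_fn
back = rearrange(flat, "b (h p) l -> b l h p", p=headdim)

print(flat.shape)                                  # torch.Size([2, 32, 16])
print(torch.equal(back, x))                        # True: the round trip is lossless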
1283
- def mamba_conv1d_scan_ref(
1284
- xBC,
1285
- conv1d_weight,
1286
- conv1d_bias,
1287
- dt,
1288
- A,
1289
- chunk_size,
1290
- D=None,
1291
- z=None,
1292
- dt_bias=None,
1293
- dt_softplus=False,
1294
- dt_limit=(0.0, float("inf")),
1295
- activation="silu",
1296
- headdim=None,
1297
- ngroups=1,
1298
- ):
1299
- """
1300
- Argument:
1301
- xBC: (batch, seqlen, dim + 2 * ngroups * dstate) where dim == nheads * headdim
1302
- conv1d_weight: (dim + 2 * ngroups * dstate, width)
1303
- conv1d_bias: (dim + 2 * ngroups * dstate,)
1304
- dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
1305
- A: (nheads)
1306
- D: (nheads, headdim) or (nheads,)
1307
- z: (batch, seqlen, dim)
1308
- dt_bias: (nheads) or (nheads, headdim)
1309
- headdim: if D is 1D and z is None, headdim must be passed in
1310
- Return:
1311
- out: (batch, seqlen, dim)
1312
- """
1313
- batch, seqlen, nheads = dt.shape[:3]
1314
- assert nheads % ngroups == 0
1315
- if z is not None:
1316
- dim = z.shape[-1]
1317
- assert dim % nheads == 0
1318
- headdim = dim // nheads
1319
- else:
1320
- if D.dim() == 1:
1321
- assert headdim is not None
1322
- else:
1323
- headdim = D.shape[1]
1324
- dim = nheads * headdim
1325
- xBC = rearrange(
1326
- causal_conv1d_fn(
1327
- rearrange(xBC, "b s d -> b d s"),
1328
- conv1d_weight,
1329
- conv1d_bias,
1330
- activation=activation,
1331
- ),
1332
- "b d s -> b s d",
1333
- )
1334
- dstate = (xBC.shape[-1] - dim) // ngroups // 2
1335
- x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
1336
- x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
1337
- B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
1338
- C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
1339
- z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
1340
- out = ssd_selective_scan(
1341
- x,
1342
- dt.to(x.dtype),
1343
- A,
1344
- B,
1345
- C,
1346
- D=D.float(),
1347
- z=z,
1348
- dt_bias=dt_bias,
1349
- dt_softplus=dt_softplus,
1350
- dt_limit=dt_limit,
1351
- )
1352
- return rearrange(out, "b s h p -> b s (h p)")
1353
-
1354
-
1355
- class MambaSplitConv1dScanCombinedFn(torch.autograd.Function):
1356
-
1357
- @staticmethod
1358
- @custom_fwd
1359
- def forward(
1360
- ctx,
1361
- zxbcdt,
1362
- conv1d_weight,
1363
- conv1d_bias,
1364
- dt_bias,
1365
- A,
1366
- D,
1367
- chunk_size,
1368
- initial_states=None,
1369
- seq_idx=None,
1370
- dt_limit=(0.0, float("inf")),
1371
- return_final_states=False,
1372
- activation="silu",
1373
- rmsnorm_weight=None,
1374
- rmsnorm_eps=1e-6,
1375
- outproj_weight=None,
1376
- outproj_bias=None,
1377
- headdim=None,
1378
- ngroups=1,
1379
- norm_before_gate=True,
1380
- ):
1381
- assert activation in [None, "silu", "swish"]
1382
- if D.dim() == 1:
1383
- assert headdim is not None
1384
- (nheads,) = D.shape
1385
- else:
1386
- nheads, headdim = D.shape
1387
- batch, seqlen, _ = zxbcdt.shape
1388
- dim = nheads * headdim
1389
- assert nheads % ngroups == 0
1390
- dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2
1391
- d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2
1392
- assert d_nonssm >= 0
1393
- assert zxbcdt.shape == (
1394
- batch,
1395
- seqlen,
1396
- 2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads,
1397
- )
1398
- assert dt_bias.shape == (nheads,)
1399
- assert A.shape == (nheads,)
1400
- zx0, z, xBC, dt = torch.split(
1401
- zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1
1402
- )
1403
- seq_idx = seq_idx.contiguous() if seq_idx is not None else None
1404
- xBC_conv = rearrange(
1405
- causal_conv1d_cuda.causal_conv1d_fwd(
1406
- rearrange(xBC, "b s d -> b d s"),
1407
- conv1d_weight,
1408
- conv1d_bias,
1409
- seq_idx,
1410
- None,
1411
- None,
1412
- activation in ["silu", "swish"],
1413
- ),
1414
- "b d s -> b s d",
1415
- )
1416
- x, B, C = torch.split(
1417
- xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1
1418
- )
1419
- x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
1420
- B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
1421
- C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
1422
- z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
1423
- if rmsnorm_weight is None:
1424
- out, out_x, dt_out, dA_cumsum, states, final_states = (
1425
- _mamba_chunk_scan_combined_fwd(
1426
- x,
1427
- dt,
1428
- A,
1429
- B,
1430
- C,
1431
- chunk_size=chunk_size,
1432
- D=D,
1433
- z=z,
1434
- dt_bias=dt_bias,
1435
- initial_states=initial_states,
1436
- seq_idx=seq_idx,
1437
- dt_softplus=True,
1438
- dt_limit=dt_limit,
1439
- )
1440
- )
1441
- out = rearrange(out, "b s h p -> b s (h p)")
1442
- rstd = None
1443
- if d_nonssm > 0:
1444
- out = torch.cat([_swiglu_fwd(zx0), out], dim=-1)
1445
- else:
1446
- out_x, _, dt_out, dA_cumsum, states, final_states = (
1447
- _mamba_chunk_scan_combined_fwd(
1448
- x,
1449
- dt,
1450
- A,
1451
- B,
1452
- C,
1453
- chunk_size=chunk_size,
1454
- D=D,
1455
- z=None,
1456
- dt_bias=dt_bias,
1457
- initial_states=initial_states,
1458
- seq_idx=seq_idx,
1459
- dt_softplus=True,
1460
- dt_limit=dt_limit,
1461
- )
1462
- )
1463
- # reshape input data into 2D tensor
1464
- x_rms = rearrange(out_x, "b s h p -> (b s) (h p)")
1465
- z_rms = rearrange(z, "b s h p -> (b s) (h p)")
1466
- rmsnorm_weight = rmsnorm_weight.contiguous()
1467
- if d_nonssm == 0:
1468
- out = None
1469
- else:
1470
- out01 = torch.empty(
1471
- (batch, seqlen, d_nonssm + dim),
1472
- dtype=x_rms.dtype,
1473
- device=x_rms.device,
1474
- )
1475
- out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d")
1476
- _swiglu_fwd(zx0, out=out01[..., :d_nonssm])
1477
- out, _, rstd = _layer_norm_fwd(
1478
- x_rms,
1479
- rmsnorm_weight,
1480
- None,
1481
- rmsnorm_eps,
1482
- z_rms,
1483
- out=out,
1484
- group_size=dim // ngroups,
1485
- norm_before_gate=norm_before_gate,
1486
- is_rms_norm=True,
1487
- )
1488
- if d_nonssm == 0:
1489
- out = rearrange(out, "(b s) d -> b s d", b=batch)
1490
- else:
1491
- out = out01
1492
- ctx.outproj_weight_dtype = (
1493
- outproj_weight.dtype if outproj_weight is not None else None
1494
- )
1495
- if outproj_weight is not None:
1496
- if torch.is_autocast_enabled():
1497
- dtype = torch.get_autocast_gpu_dtype()
1498
- out, outproj_weight = out.to(dtype), outproj_weight.to(dtype)
1499
- outproj_bias = (
1500
- outproj_bias.to(dtype) if outproj_bias is not None else None
1501
- )
1502
- out = F.linear(out, outproj_weight, outproj_bias)
1503
- else:
1504
- assert outproj_bias is None
1505
- ctx.save_for_backward(
1506
- zxbcdt,
1507
- conv1d_weight,
1508
- conv1d_bias,
1509
- out_x,
1510
- A,
1511
- D,
1512
- dt_bias,
1513
- initial_states,
1514
- seq_idx,
1515
- rmsnorm_weight,
1516
- rstd,
1517
- outproj_weight,
1518
- outproj_bias,
1519
- )
1520
- ctx.dt_limit = dt_limit
1521
- ctx.return_final_states = return_final_states
1522
- ctx.activation = activation
1523
- ctx.rmsnorm_eps = rmsnorm_eps
1524
- ctx.norm_before_gate = norm_before_gate
1525
- ctx.chunk_size = chunk_size
1526
- ctx.headdim = headdim
1527
- ctx.ngroups = ngroups
1528
- return out if not return_final_states else (out, final_states)
1529
-
1530
- @staticmethod
1531
- @custom_bwd
1532
- def backward(ctx, dout, *args):
1533
- (
1534
- zxbcdt,
1535
- conv1d_weight,
1536
- conv1d_bias,
1537
- out,
1538
- A,
1539
- D,
1540
- dt_bias,
1541
- initial_states,
1542
- seq_idx,
1543
- rmsnorm_weight,
1544
- rstd,
1545
- outproj_weight,
1546
- outproj_bias,
1547
- ) = ctx.saved_tensors
1548
- dfinal_states = args[0] if ctx.return_final_states else None
1549
- headdim = ctx.headdim
1550
- nheads = D.shape[0]
1551
- dim = nheads * headdim
1552
- assert nheads % ctx.ngroups == 0
1553
- dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2
1554
- d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2
1555
- assert d_nonssm >= 0
1556
- recompute_output = outproj_weight is not None
1557
- if recompute_output:
1558
- out_recompute = torch.empty(
1559
- *out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype
1560
- )
1561
- out0_recompute, out1_recompute = out_recompute.split(
1562
- [d_nonssm, dim], dim=-1
1563
- )
1564
- zx0, z, xBC, dt = torch.split(
1565
- zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
1566
- )
1567
- # Recompute x, B, C
1568
- xBC_conv = rearrange(
1569
- causal_conv1d_cuda.causal_conv1d_fwd(
1570
- rearrange(xBC, "b s d -> b d s"),
1571
- conv1d_weight,
1572
- conv1d_bias,
1573
- seq_idx,
1574
- None,
1575
- None,
1576
- ctx.activation in ["silu", "swish"],
1577
- ),
1578
- "b d s -> b s d",
1579
- )
1580
- x, B, C = torch.split(
1581
- xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
1582
- )
1583
- x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
1584
- B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups)
1585
- C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups)
1586
- dzxbcdt = torch.empty_like(zxbcdt)
1587
- dzx0, dz, dxBC_given, ddt_given = torch.split(
1588
- dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
1589
- )
1590
- dxBC = torch.empty_like(xBC)
1591
- dx, dB, dC = torch.split(
1592
- dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
1593
- )
1594
- z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
1595
- dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads)
1596
- dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups)
1597
- dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups)
1598
- if outproj_weight is not None:
1599
- dout_og = dout
1600
- dout = F.linear(dout, outproj_weight.t())
1601
- if d_nonssm > 0:
1602
- dout0, dout = dout.split([d_nonssm, dim], dim=-1)
1603
- _swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute)
1604
- dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim)
1605
- if rmsnorm_weight is None:
1606
- dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads)
1607
- dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states, *rest = (
1608
- _mamba_chunk_scan_combined_bwd(
1609
- dout,
1610
- x,
1611
- dt,
1612
- A,
1613
- B,
1614
- C,
1615
- out,
1616
- ctx.chunk_size,
1617
- D=D,
1618
- z=z,
1619
- dt_bias=dt_bias,
1620
- initial_states=initial_states,
- dfinal_states=dfinal_states,
- seq_idx=seq_idx,
- dt_softplus=True,
- dt_limit=ctx.dt_limit,
- dx=dx,
- ddt=ddt_given,
- dB=dB,
- dC=dC,
- dz=dz,
- recompute_output=recompute_output,
- )
- )
- out_for_linear = (
- rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None
- )
- drmsnorm_weight = None
- else:
- batch = dout.shape[0]
- dy_rms = rearrange(dout, "b s h p -> (b s) (h p)")
- dz = rearrange(dz, "b l d -> (b l) d")
- x_rms = rearrange(out, "b s h p -> (b s) (h p)")
- z_rms = rearrange(z, "b s h p -> (b s) (h p)")
- out1_recompute = (
- rearrange(out1_recompute, "b s d -> (b s) d")
- if recompute_output
- else None
- )
- dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(
- dy_rms,
- x_rms,
- rmsnorm_weight,
- None,
- ctx.rmsnorm_eps,
- None,
- rstd,
- z_rms,
- group_size=dim // ctx.ngroups,
- norm_before_gate=ctx.norm_before_gate,
- is_rms_norm=True,
- recompute_output=recompute_output,
- dz=dz,
- out=out1_recompute if recompute_output else None,
- )
- out_for_linear = out_recompute if recompute_output else None
- dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim)
- dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = (
- _mamba_chunk_scan_combined_bwd(
- dout,
- x,
- dt,
- A,
- B,
- C,
- out,
- ctx.chunk_size,
- D=D,
- z=None,
- dt_bias=dt_bias,
- initial_states=initial_states,
- dfinal_states=dfinal_states,
- seq_idx=seq_idx,
- dt_softplus=True,
- dt_limit=ctx.dt_limit,
- dx=dx,
- ddt=ddt_given,
- dB=dB,
- dC=dC,
- )
- )
-
- if outproj_weight is not None:
- doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear)
- doutproj_bias = (
- dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None
- )
- else:
- doutproj_weight, doutproj_bias = None, None
- dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
- dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
- rearrange(xBC, "b s d -> b d s"),
- conv1d_weight,
- conv1d_bias,
- rearrange(dxBC, "b s d -> b d s"),
- seq_idx,
- None,
- None,
- dxBC_given,
- False,
- ctx.activation in ["silu", "swish"],
- )
- dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
- return (
- dzxbcdt,
- dweight,
- dbias,
- ddt_bias,
- dA,
- dD,
- None,
- dinitial_states,
- None,
- None,
- None,
- None,
- drmsnorm_weight,
- None,
- doutproj_weight,
- doutproj_bias,
- None,
- None,
- None,
- )
-
-
- def mamba_split_conv1d_scan_combined(
- zxbcdt,
- conv1d_weight,
- conv1d_bias,
- dt_bias,
- A,
- D,
- chunk_size,
- initial_states=None,
- seq_idx=None,
- dt_limit=(0.0, float("inf")),
- return_final_states=False,
- activation="silu",
- rmsnorm_weight=None,
- rmsnorm_eps=1e-6,
- outproj_weight=None,
- outproj_bias=None,
- headdim=None,
- ngroups=1,
- norm_before_gate=True,
- ):
- """
- Argument:
- zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
- conv1d_weight: (dim + 2 * ngroups * dstate, width)
- conv1d_bias: (dim + 2 * ngroups * dstate,)
- dt_bias: (nheads,)
- A: (nheads)
- D: (nheads, headdim) or (nheads,)
- initial_states: (batch, nheads, headdim, dstate)
- seq_idx: (batch, seqlen), int32
- rmsnorm_weight: (dim,)
- outproj_weight: (out_dim, dim)
- outproj_bias: (out_dim,)
- headdim: if D is 1D, headdim must be passed in
- norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
- Return:
- out: (batch, seqlen, dim)
- """
- return MambaSplitConv1dScanCombinedFn.apply(
- zxbcdt,
- conv1d_weight,
- conv1d_bias,
- dt_bias,
- A,
- D,
- chunk_size,
- initial_states,
- seq_idx,
- dt_limit,
- return_final_states,
- activation,
- rmsnorm_weight,
- rmsnorm_eps,
- outproj_weight,
- outproj_bias,
- headdim,
- ngroups,
- norm_before_gate,
- )
-
-
- def mamba_split_conv1d_scan_ref(
- zxbcdt,
- conv1d_weight,
- conv1d_bias,
- dt_bias,
- A,
- D,
- chunk_size,
- dt_limit=(0.0, float("inf")),
- activation="silu",
- rmsnorm_weight=None,
- rmsnorm_eps=1e-6,
- outproj_weight=None,
- outproj_bias=None,
- headdim=None,
- ngroups=1,
- norm_before_gate=True,
- ):
- """
- Argument:
- zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
- conv1d_weight: (dim + 2 * ngroups * dstate, width)
- conv1d_bias: (dim + 2 * ngroups * dstate,)
- dt_bias: (nheads,)
- A: (nheads)
- D: (nheads, headdim) or (nheads,)
- rmsnorm_weight: (dim,)
- outproj_weight: (out_dim, dim)
- outproj_bias: (out_dim,)
- headdim: if D is 1D, headdim must be passed in
- norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
- Return:
- out: (batch, seqlen, dim)
- """
- if D.dim() == 1:
- assert headdim is not None
- (nheads,) = D.shape
- else:
- nheads, headdim = D.shape
- assert nheads % ngroups == 0
- batch, seqlen, _ = zxbcdt.shape
- dim = nheads * headdim
- dstate = (zxbcdt.shape[-1] - 2 * dim - nheads) // ngroups // 2
- assert zxbcdt.shape == (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads)
- assert dt_bias.shape == (nheads,)
- assert A.shape == (nheads,)
- if rmsnorm_weight is not None:
- assert rmsnorm_weight.shape == (dim,)
- z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1)
- xBC = rearrange(
- causal_conv1d_fn(
- rearrange(xBC, "b s d -> b d s"),
- conv1d_weight,
- conv1d_bias,
- activation=activation,
- ),
- "b d s -> b s d",
- )
- x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
- x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
- B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
- C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
- z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
- out = ssd_selective_scan(
- x,
- dt.to(x.dtype),
- A,
- B,
- C,
- D=D.float(),
- z=z if rmsnorm_weight is None else None,
- dt_bias=dt_bias,
- dt_softplus=True,
- dt_limit=dt_limit,
- )
- out = rearrange(out, "b s h p -> b s (h p)")
- if rmsnorm_weight is not None:
- out = rmsnorm_fn(
- out,
- rmsnorm_weight,
- None,
- z=rearrange(z, "b l h p -> b l (h p)"),
- eps=rmsnorm_eps,
- norm_before_gate=norm_before_gate,
- )
- if outproj_weight is not None:
- out = F.linear(out, outproj_weight, outproj_bias)
- return out
 
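The hunk above removes `mamba_split_conv1d_scan_combined` and its reference implementation from this build directory. For readers following the deleted API, here is a minimal, hedged shape sketch (not part of the repository) of how the fused entry point is called according to its docstring; all sizes are illustrative, and a CUDA-enabled install of `mamba_ssm` with its Triton and causal-conv1d kernels is assumed.

```python
# Hedged usage sketch based on the docstring in the hunk above.
# Shapes follow the documented layout; concrete sizes are illustrative only.
import torch
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined

batch, seqlen = 2, 128
nheads, headdim, ngroups, dstate, width = 8, 64, 1, 16, 4
dim = nheads * headdim  # dim == nheads * headdim, per the docstring
device = "cuda"

# zxbcdt packs [z | x, B, C | dt] along the feature dimension.
zxbcdt = torch.randn(batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads, device=device)
conv1d_weight = torch.randn(dim + 2 * ngroups * dstate, width, device=device)
conv1d_bias = torch.randn(dim + 2 * ngroups * dstate, device=device)
dt_bias = torch.randn(nheads, device=device)
A = -torch.rand(nheads, device=device)   # A is typically negative (-exp(A_log)) in Mamba-2
D = torch.randn(nheads, device=device)   # D is 1D here, so headdim must be passed explicitly

out = mamba_split_conv1d_scan_combined(
    zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D,
    chunk_size=64, headdim=headdim, ngroups=ngroups,
)
assert out.shape == (batch, seqlen, dim)
```

With `rmsnorm_weight` and `outproj_weight` left as `None`, the gate `z` is applied inside the scan and the output keeps the `(batch, seqlen, dim)` shape given in the docstring.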
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/utils/__init__.py DELETED
File without changes
build/torch25-cxx98-cu121-x86_64-linux/mamba_ssm/__init__.py DELETED
@@ -1,14 +0,0 @@
- __version__ = "2.2.4"
-
- from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
- from .modules.mamba_simple import Mamba
- from .modules.mamba2 import Mamba2
- from .models.mixer_seq_simple import MambaLMHeadModel
-
- __all__ = [
- "selective_scan_fn",
- "mamba_inner_fn",
- "Mamba",
- "Mamba2",
- "MambaLMHeadModel",
- ]
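For context, the deleted `__init__.py` above re-exports the package's main entry points. A minimal, hedged usage sketch of one of them follows; the layer width and input shape are illustrative, and a CUDA device with the compiled kernels is assumed.

```python
# Hedged sketch of the public API re-exported by the __init__.py shown above.
import torch
from mamba_ssm import Mamba2  # one of the names listed in __all__

layer = Mamba2(d_model=256).to("cuda")        # d_model chosen for illustration
x = torch.randn(2, 128, 256, device="cuda")   # (batch, seqlen, d_model)
y = layer(x)                                  # output has the same shape as x
```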