| | from typing import List, Optional, Tuple |
| | import torch |
| |
|
| | from mamba_ssm.ops.triton.ssd_combined import _mamba_chunk_scan_combined_fwd, _mamba_chunk_scan_combined_bwd |
| |
|
| |
|
@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def _compiled_mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=None):
    """torch.compile'd pass-through to the raw Triton forward kernel.

    CUDA-graph capture is enabled via the compile options; the wrapper adds no
    logic of its own.
    """
    optional_kwargs = dict(
        D=D,
        z=z,
        dt_bias=dt_bias,
        initial_states=initial_states,
        seq_idx=seq_idx,
        cu_seqlens=cu_seqlens,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
    )
    return _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, **optional_kwargs)
| |
|
@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def _compiled_mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, dfinal_states=None, seq_idx=None, dt_softplus=False, dt_limit=None):
    """torch.compile'd pass-through to the raw Triton backward kernel.

    Mirrors `_compiled_mamba_chunk_scan_combined_fwd`: no extra logic, just a
    compiled entry point with CUDA-graph capture enabled.
    """
    optional_kwargs = dict(
        D=D,
        z=z,
        dt_bias=dt_bias,
        initial_states=initial_states,
        dfinal_states=dfinal_states,
        seq_idx=seq_idx,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
    )
    return _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, **optional_kwargs)
| |
|
| |
|
@torch.library.custom_op(
    "mamba_ssm::ssm_chunk_scan_combined_fwd",
    mutates_args=(),
    device_types="cuda",
)
def ssm_chunk_scan_combined_fwd(
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    chunk_size: int,
    D: Optional[torch.Tensor] = None,
    z: Optional[torch.Tensor] = None,
    dt_bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    dt_softplus: bool = False,
    dt_limit: Optional[List[float]] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Custom-op wrapper over the Triton forward kernel.

    Returns a 3-tuple ``(out, out_x, varlen_states)``.  Custom ops must return
    tensors rather than ``None``, so the second element is an empty tensor when
    ``z`` is not given, and the third is an empty tensor when ``cu_seqlens`` is
    not given.
    """
    results = _mamba_chunk_scan_combined_fwd(
        x, dt, A, B, C, chunk_size,
        D=D, z=z, dt_bias=dt_bias, initial_states=initial_states,
        seq_idx=seq_idx, cu_seqlens=cu_seqlens,
        dt_softplus=dt_softplus, dt_limit=dt_limit,
    )
    out = results[0]
    out_x = results[1]
    # The kernel only yields out_x when z was supplied; substitute a placeholder.
    aux_out = out_x if out_x is not None else out.new_empty(0)
    # varlen states are appended as an extra result only in the varlen path.
    varlen_states = results[6] if cu_seqlens is not None else out.new_empty(0)
    return out, aux_out, varlen_states
| |
|
@ssm_chunk_scan_combined_fwd.register_fake
def _ssm_chunk_scan_combined_fwd_fake(
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    chunk_size: int,
    D: Optional[torch.Tensor] = None,
    z: Optional[torch.Tensor] = None,
    dt_bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    dt_softplus: bool = False,
    dt_limit: Optional[List[float]] = None
):
    """Fake (meta) implementation used for tracing/compile.

    Must mirror the real op's output structure: the real op never returns
    ``None`` (it substitutes ``new_empty(0)`` placeholders), and the declared
    schema is ``Tuple[Tensor, Tensor, Tensor]`` — so the fake also returns
    tensors in every slot.  Previously this returned ``None`` for the optional
    slots, which diverged from both the schema and the eager implementation.
    """
    _, _, n_heads, head_dim = x.shape
    # dstate is the LAST dim of B (batch, seqlen, ngroups, dstate); the old
    # code used B.size(0), which is the batch size.
    d_state = B.size(-1)
    return (
        torch.empty_like(x),
        torch.empty_like(x) if z is not None else x.new_empty(0),
        x.new_empty((cu_seqlens.size(0) - 1, n_heads, head_dim, d_state))
        if cu_seqlens is not None else x.new_empty(0),
    )
| |
|
@torch.library.custom_op(
    "mamba_ssm::ssm_chunk_scan_combined_bwd",
    mutates_args=(),
    device_types="cuda",
)
def ssm_chunk_scan_combined_bwd(
    dout: torch.Tensor,
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    out: torch.Tensor,
    chunk_size: int,
    D: Optional[torch.Tensor] = None,
    z: Optional[torch.Tensor] = None,
    dt_bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    dt_softplus: bool = False,
    dt_limit: Optional[List[float]] = None
)-> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Custom-op wrapper over the Triton backward kernel.

    Returns the nine gradients ``(dx, ddt, dA, dB, dC, dD, dz, ddt_bias,
    dinitial_states)``.  Gradients for absent optional inputs come back as
    empty tensors, since a custom op may not return ``None``.
    """
    grads = _mamba_chunk_scan_combined_bwd(
        dout, x, dt, A, B, C, out, chunk_size,
        D=D, z=z, dt_bias=dt_bias, initial_states=initial_states,
        dfinal_states=None, seq_idx=seq_idx,
        dt_softplus=dt_softplus, dt_limit=dt_limit,
    )
    dx = grads[0]
    # Pass real gradients through unchanged; replace each missing optional
    # gradient with a fresh empty placeholder.
    return tuple(g if g is not None else dx.new_empty(0) for g in grads)
| |
|
@ssm_chunk_scan_combined_bwd.register_fake
def _ssm_chunk_scan_combined_bwd_fake(
    dout: torch.Tensor,
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    out: torch.Tensor,
    chunk_size: int,
    D: Optional[torch.Tensor] = None,
    z: Optional[torch.Tensor] = None,
    dt_bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    dt_softplus: bool = False,
    dt_limit: Optional[List[float]] = None
):
    """Fake (meta) implementation of the backward op.

    The real op declares nine tensor outputs and substitutes ``new_empty(0)``
    for missing optional gradients, so the fake must do the same — returning
    ``None`` here (as the previous version did) contradicts the registered
    schema and the eager implementation.
    """
    return (
        torch.empty_like(x),
        torch.empty_like(dt),
        torch.empty_like(A),
        torch.empty_like(B),
        torch.empty_like(C),
        torch.empty_like(D) if D is not None else x.new_empty(0),
        torch.empty_like(z) if z is not None else x.new_empty(0),
        torch.empty_like(dt_bias) if dt_bias is not None else x.new_empty(0),
        torch.empty_like(initial_states) if initial_states is not None else x.new_empty(0),
    )
| |
|
| |
|
def ssm_chunk_scan_combined_setup_context(ctx, inputs, output):
    """Stash forward tensors and non-tensor settings on ctx for backward."""
    (x, dt, A, B, C, chunk_size, D, z, dt_bias,
     initial_states, seq_idx, cu_seqlens, dt_softplus, dt_limit) = inputs
    out = output[0]
    # When z was supplied, backward needs the auxiliary output instead of out.
    saved_out = output[1] if z is not None else out
    ctx.save_for_backward(saved_out, x, dt, A, B, C, D, z,
                          dt_bias, initial_states, seq_idx)
    ctx.chunk_size = chunk_size
    ctx.dt_softplus = dt_softplus
    ctx.dt_limit = dt_limit
| |
|
def ssm_chunk_scan_combined_bridge(ctx, dout, dout_x, dout_state_varlen):
    """Autograd backward: map output grads to grads of the 14 forward inputs.

    Grads w.r.t. the auxiliary output and varlen states are not propagated
    (the backward kernel is invoked with dfinal_states=None inside the custom
    op); non-tensor inputs and absent optionals get None.
    """
    out, x, dt, A, B, C, D, z, dt_bias, initial_states, seq_idx = ctx.saved_tensors

    dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = ssm_chunk_scan_combined_bwd(
        dout, x, dt, A, B, C, out, ctx.chunk_size,
        D=D, z=z, dt_bias=dt_bias, initial_states=initial_states,
        seq_idx=seq_idx, dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit,
    )

    # Order mirrors the forward signature:
    # x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states,
    # seq_idx, cu_seqlens, dt_softplus, dt_limit
    grads = [dx, ddt, dA, dB, dC, None]
    grads.append(dD if D is not None else None)
    grads.append(dz if z is not None else None)
    grads.append(ddt_bias if dt_bias is not None else None)
    grads.append(dinitial_states if initial_states is not None else None)
    grads.extend([None, None, None, None])
    return tuple(grads)
| |
|
| | |
# Wire the custom forward op into autograd: setup_context saves the tensors
# the backward pass needs, and the bridge maps output grads to input grads.
torch.library.register_autograd(
    "mamba_ssm::ssm_chunk_scan_combined_fwd",
    ssm_chunk_scan_combined_bridge,
    setup_context=ssm_chunk_scan_combined_setup_context,
)
| |
|
def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
    """
    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads)
        A: (nheads)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        chunk_size: int
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
        dt_bias: (nheads,)
        initial_states: (batch, nheads, headdim, dstate)
        seq_idx: (batch, seqlen)
        cu_seqlens: (num_sequences + 1) or None
        dt_softplus: Whether to apply softplus to dt
    Return:
        out: (batch, seqlen, nheads, headdim)
        varlen_states: per-sequence final states, returned (as a second value)
            only when cu_seqlens is provided
    """
    out, _, varlen_states = ssm_chunk_scan_combined_fwd(
        x, dt, A, B, C, chunk_size,
        D=D, z=z, dt_bias=dt_bias, initial_states=initial_states,
        seq_idx=seq_idx, cu_seqlens=cu_seqlens,
        dt_softplus=dt_softplus, dt_limit=dt_limit,
    )
    if cu_seqlens is None:
        return out
    return out, varlen_states
| |
|
if __name__ == "__main__":
    # Smoke test: run the custom-op wrapper eagerly and under torch.compile,
    # then the reference implementation, and print summary statistics of each
    # output for manual comparison.  Requires a CUDA device.
    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined as mamba_chunk_scan_combined_ref

    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Tiny random inputs: batch=2, seqlen=3, nheads=4, headdim=5 (dstate=5).
    x = torch.randn(2, 3, 4, 5).cuda()
    dt = torch.randn(2, 3, 4).cuda()
    A = torch.randn(4).cuda()
    B = torch.randn(2, 3, 4, 5).cuda()
    C = torch.randn(2, 3, 4, 5).cuda()
    chunk_size = 2
    D = torch.randn(4, 5).cuda()
    z = torch.randn(2, 3, 4, 5).cuda()
    dt_bias = torch.randn(4).cuda()

    out = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias)

    print(out.min(), out.max(), out.mean(), out.std())

    # Same call through torch.compile to confirm the custom op traces cleanly.
    compiled_mamba_chunk_scan_combined = torch.compile(mamba_chunk_scan_combined)
    out = compiled_mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias)

    print(out.min(), out.max(), out.mean(), out.std())

    # Reference implementation output for comparison of the printed stats.
    out_ref = mamba_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias)

    print(out_ref.min(), out_ref.max(), out_ref.mean(), out_ref.std())
| |
|
| |
|