import math
from dataclasses import dataclass
from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from pscan import pscan

""" |
|
|
|
This file closely follows the mamba_simple.py from the official Mamba implementation, and the mamba-minimal by @johnma2006. |
|
The major differences are : |
|
-the convolution is done with torch.nn.Conv1d |
|
-the selective scan is done in PyTorch |
|
|
|
A sequential version of the selective scan is also available for comparison. |
|
|
|
- A Mamba model is composed of several layers, which are ResidualBlock. |
|
- A ResidualBlock is composed of a MambaBlock, a normalization, and a residual connection : ResidualBlock(x) = mamba(norm(x)) + x |
|
- This leaves us with the MambaBlock : its input x is (B, L, D) and its outputs y is also (B, L, D) (B=batch size, L=seq len, D=model dim). |
|
First, we expand x into (B, L, 2*ED) (where E is usually 2) and split it into x and z, each (B, L, ED). |
|
Then, we apply the short 1d conv to x, followed by an activation function (silu), then the SSM. |
|
We then multiply it by silu(z). |
|
See Figure 3 of the paper (page 8) for a visual representation of a MambaBlock. |
|
|
|
""" |

@dataclass
class MambaConfig:
    d_model: int # D
    n_layers: int
    dt_rank: Union[int, str] = 'auto'
    d_state: int = 16 # N
    expand_factor: int = 2 # E
    d_conv: int = 4

    dt_min: float = 0.001
    dt_max: float = 0.1
    dt_init: str = "random" # "random" or "constant"
    dt_scale: float = 1.0
    dt_init_floor: float = 1e-4

    bias: bool = False
    conv_bias: bool = True

    pscan: bool = True # use the parallel scan (pscan) or the sequential scan

    def __post_init__(self):
        self.d_inner = self.expand_factor * self.d_model # E*D (ED in comments)

        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)

class Mamba(nn.Module):
    def __init__(self, config: MambaConfig):
        super().__init__()

        self.config = config

        self.layers = nn.ModuleList([ResidualBlock(config) for _ in range(config.n_layers)])

    def forward(self, x):
        # x : (B, L, D)

        # y : (B, L, D)

        for layer in self.layers:
            x = layer(x)

        return x

    def step(self, x, caches):
        # x : (B, D), a single token
        # caches : [cache(layer) for each layer], cache : (h, inputs)
        #     h : (B, ED, N)
        #     inputs : (B, ED, d_conv-1)

        # y : (B, D)
        # caches : updated caches, same structure as above

        for i, layer in enumerate(self.layers):
            x, caches[i] = layer.step(x, caches[i])

        return x, caches

class ResidualBlock(nn.Module):
    def __init__(self, config: MambaConfig):
        super().__init__()

        self.mixer = MambaBlock(config)
        self.norm = RMSNorm(config.d_model)

    def forward(self, x):
        # x : (B, L, D)

        # output : (B, L, D)

        output = self.mixer(self.norm(x)) + x
        return output

    def step(self, x, cache):
        # x : (B, D)
        # cache : (h, inputs)
        #     h : (B, ED, N)
        #     inputs : (B, ED, d_conv-1)

        # output : (B, D)
        # cache : updated (h, inputs)

        output, cache = self.mixer.step(self.norm(x), cache)
        output = output + x
        return output, cache

class MambaBlock(nn.Module):
    def __init__(self, config: MambaConfig):
        super().__init__()

        self.config = config

        # projects the block input from D to 2*ED (the x and z branches)
        self.in_proj = nn.Linear(config.d_model, 2 * config.d_inner, bias=config.bias)

        # short depthwise 1D convolution over the time dimension
        self.conv1d = nn.Conv1d(in_channels=config.d_inner, out_channels=config.d_inner,
                                kernel_size=config.d_conv, bias=config.conv_bias,
                                groups=config.d_inner,
                                padding=config.d_conv - 1)

        nn.init.kaiming_normal_(self.conv1d.weight, mode='fan_out', nonlinearity='leaky_relu')

        # projects x to the input-dependent delta, B, C
        self.x_proj = nn.Linear(config.d_inner, config.dt_rank + 2 * config.d_state, bias=False)

        # projects delta from dt_rank to d_inner
        self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True)

        # dt weight initialization
        dt_init_std = config.dt_rank**-0.5 * config.dt_scale
        if config.dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif config.dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # dt bias initialization : chosen so that softplus(dt_proj.bias) lies in [dt_min, dt_max]
        dt = torch.exp(
            torch.rand(config.d_inner) * (math.log(config.dt_max) - math.log(config.dt_min)) + math.log(config.dt_min)
        ).clamp(min=config.dt_init_floor)
        inv_dt = dt + torch.log(-torch.expm1(-dt)) # inverse of softplus
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)

        # S4D real initialization of A
        A = torch.arange(1, config.d_state + 1, dtype=torch.float32).repeat(config.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A)) # (ED, N)
        self.D = nn.Parameter(torch.ones(config.d_inner))

        # projects the block output from ED back to D
        self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)

    def forward(self, x):
        # x : (B, L, D)

        # y : (B, L, D)

        _, L, _ = x.shape

        xz = self.in_proj(x) # (B, L, 2*ED)
        x, z = xz.chunk(2, dim=-1) # (B, L, ED), (B, L, ED)

        # x branch
        x = x.transpose(1, 2) # (B, ED, L)
        x = self.conv1d(x)[:, :, :L] # depthwise conv over time, cropped back to L so it stays causal
        x = x.transpose(1, 2) # (B, L, ED)

        x = F.silu(x)
        y = self.ssm(x)

        # z branch
        z = F.silu(z)

        output = y * z
        output = self.out_proj(output) # (B, L, D)

        return output

    def ssm(self, x):
        # x : (B, L, ED)

        # y : (B, L, ED)

        A = -torch.exp(self.A_log.float()) # (ED, N)
        D = self.D.float() # (ED)

        deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N)

        delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, L, dt_rank), (B, L, N), (B, L, N)
        delta = F.softplus(self.dt_proj(delta)) # (B, L, ED)

        if self.config.pscan:
            y = self.selective_scan(x, delta, A, B, C, D)
        else:
            y = self.selective_scan_seq(x, delta, A, B, C, D)

        return y

    def selective_scan(self, x, delta, A, B, C, D):
        # x : (B, L, ED)
        # delta : (B, L, ED)
        # A : (ED, N)
        # B : (B, L, N)
        # C : (B, L, N)
        # D : (ED)

        # y : (B, L, ED)

        deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)

        BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)

        hs = pscan(deltaA, BX) # (B, L, ED, N), all hidden states computed with a parallel scan

        y = (hs @ C.unsqueeze(-1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)

        y = y + D * x

        return y

    def selective_scan_seq(self, x, delta, A, B, C, D):
        # x : (B, L, ED)
        # delta : (B, L, ED)
        # A : (ED, N)
        # B : (B, L, N)
        # C : (B, L, N)
        # D : (ED)

        # y : (B, L, ED)

        _, L, _ = x.shape

        deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)

        BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)

        h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device) # (B, ED, N)
        hs = []

        for t in range(0, L):
            h = deltaA[:, t] * h + BX[:, t]
            hs.append(h)

        hs = torch.stack(hs, dim=1) # (B, L, ED, N)

        y = (hs @ C.unsqueeze(-1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)

        y = y + D * x

        return y

""" |
|
Concerning auto-regressive inference |
|
|
|
The cool part of using Mamba : inference is constant wrt to sequence length |
|
We just have to keep in cache, for each layer, two things : |
|
- the hidden state h (which is (B, ED, N)), as you typically would when doing inference with a RNN |
|
- the last d_conv-1 inputs of the layer, to be able to compute the 1D conv which is a convolution over the time dimension |
|
(d_conv is fixed so this doesn't incur a growing cache as we progress on generating the sequence) |
|
(and d_conv is usually very small, like 4, so we just have to "remember" the last 3 inputs) |
|
|
|
Concretely, these two quantities are put inside a cache tuple, and are named h and inputs respectively. |
|
h is (B, ED, N), and inputs is (B, ED, d_conv-1) |
|
The MambaBlock.step() receives this cache, and, along with outputing the output, alos outputs the updated cache for the next call. |
|
|
|
The cache object is initialized as follows : (None, torch.zeros()). |
|
When h is None, the selective scan function detects it and start with h=0. |
|
The torch.zeros() isn't a problem (it's same as just feeding the input, because the conv1d is padded) |
|
|
|
As we need one such cache variable per layer, we store a caches object, which is simply a list of cache object. (See mamba_lm.py) |
|
""" |

    def step(self, x, cache):
        # x : (B, D)
        # cache : (h, inputs)
        #     h : (B, ED, N)
        #     inputs : (B, ED, d_conv-1)

        # y : (B, D)
        # cache : updated (h, inputs)

        h, inputs = cache

        xz = self.in_proj(x) # (B, 2*ED)
        x, z = xz.chunk(2, dim=1) # (B, ED), (B, ED)

        # x branch
        x_cache = x.unsqueeze(2) # (B, ED, 1)
        x = self.conv1d(torch.cat([inputs, x_cache], dim=2))[:, :, self.config.d_conv-1] # (B, ED), only the output at the current time step is kept

        x = F.silu(x)
        y, h = self.ssm_step(x, h)

        # z branch
        z = F.silu(z)

        output = y * z
        output = self.out_proj(output) # (B, D)

        # prepare the cache for the next call
        inputs = torch.cat([inputs[:, :, 1:], x_cache], dim=2) # (B, ED, d_conv-1)
        cache = (h, inputs)

        return output, cache

    def ssm_step(self, x, h):
        # x : (B, ED)
        # h : (B, ED, N) or None (first step)

        # y : (B, ED)
        # h : (B, ED, N)

        A = -torch.exp(self.A_log.float()) # (ED, N)
        D = self.D.float() # (ED)

        deltaBC = self.x_proj(x) # (B, dt_rank+2*N)

        delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, dt_rank), (B, N), (B, N)
        delta = F.softplus(self.dt_proj(delta)) # (B, ED)

        deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, ED, N)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(1) # (B, ED, N)

        BX = deltaB * (x.unsqueeze(-1)) # (B, ED, N)

        if h is None:
            h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device) # (B, ED, N)

        h = deltaA * h + BX # (B, ED, N)

        y = (h @ C.unsqueeze(-1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1)

        y = y + D * x

        return y, h.squeeze(1)

class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()

        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

        return output