| """
|
| VortexSSM: Selective State-Space Layer
|
| Simplified Mamba-style SSM with input-dependent selection.
|
| Provides O(n) complexity for long sequences, ideal for scientific documents.
|
| """
|
|
|
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
|
|
class VortexSSM(nn.Module):
    """
    Selective state-space layer. Linear complexity O(n) vs attention's O(n²).

    Handles long scientific documents efficiently with input-dependent selection.

    Architecture based on Mamba but simplified for scientific reasoning tasks.
    The recurrence is evaluated with a per-timestep Python loop (a reference
    implementation, not a fused scan kernel), so asymptotic complexity is O(n)
    but wall-clock speed is modest for very long sequences.
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
        dt_rank: Optional[int] = None,
    ):
        """
        Initialize VortexSSM.

        Args:
            d_model: Model dimension.
            d_state: State dimension (default 16 for 7B, 32 for 13B).
            d_conv: Convolution kernel size for local context.
            expand: Expansion factor for inner dimension.
            dt_rank: Rank for delta projection (if None, uses
                max(1, d_model // 16)).
        """
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = d_model * expand

        # Low-rank bottleneck for the input-dependent step size (delta).
        self.dt_rank = max(1, d_model // 16) if dt_rank is None else dt_rank

        # Joint projection producing both the SSM input x and the gate z.
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # Depthwise conv for local context. padding=d_conv-1 pads both sides;
        # forward() trims the output back to seq_len to keep it causal.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,
            bias=False,
        )

        # Input-dependent (selective) parameters: delta (low-rank), B and C.
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + 2 * self.d_state, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # State-transition log-coefficients and learned skip scale.
        # NOTE(review): canonical Mamba uses A = -exp(A_log) so the discrete
        # transition exp(delta * A) is always < 1. This simplified variant
        # uses exp(A_log * delta) directly, so stability relies on A_log
        # staying negative (see _initialize_weights) — confirm intended.
        self.A_log = nn.Parameter(torch.randn(self.d_inner, self.d_state))
        self.D = nn.Parameter(torch.randn(self.d_inner))

        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Initialize parameters.

        A_log is drawn around -4 so that exp(A_log * delta) < 1 at init,
        keeping the recurrence contractive; all projections (including the
        depthwise conv) get a small normal init.
        """
        nn.init.normal_(self.A_log, mean=-4.0, std=0.5)
        nn.init.normal_(self.D, mean=0.0, std=0.1)

        for module in [self.in_proj, self.x_proj, self.dt_proj, self.conv1d, self.out_proj]:
            if hasattr(module, 'weight'):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[torch.Tensor] = None,
        return_state: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass through the SSM.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            state: Previous hidden state (batch, d_inner, d_state);
                zeros are used when None.
            return_state: If True, return (output, state).

        Returns:
            Output tensor (batch, seq_len, d_model), or the tuple
            (output, final_state) when return_state is True.
        """
        batch, seq_len, _ = x.shape
        device = x.device
        dtype = x.dtype

        d_inner = self.d_inner

        # Project to inner dim: x drives the SSM, z gates the output.
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)

        # Depthwise conv over time; trim the right padding so the result
        # stays length seq_len and causal.
        x_conv = x.transpose(1, 2)
        x_conv = self.conv1d(x_conv)[..., :seq_len]
        x = x_conv.transpose(1, 2)

        # Input-dependent SSM parameters (the "selection" mechanism).
        x_dbl = self.x_proj(x)
        delta, B, C = torch.split(
            x_dbl,
            [self.dt_rank, self.d_state, self.d_state],
            dim=-1,
        )
        # softplus keeps the step sizes strictly positive.
        delta = F.softplus(self.dt_proj(delta))

        if state is None:
            state = torch.zeros(batch, d_inner, self.d_state, device=device, dtype=dtype)

        # Sequential selective scan (reference implementation, not fused).
        outputs = []
        for t in range(seq_len):
            x_t = x[:, t]          # (batch, d_inner)
            delta_t = delta[:, t]  # (batch, d_inner)
            B_t = B[:, t]          # (batch, d_state)
            C_t = C[:, t]          # (batch, d_state)

            # Discretized transition, (batch, d_inner, d_state); < 1 while
            # A_log stays negative.
            A_delta = torch.exp(self.A_log * delta_t.unsqueeze(-1))
            state = A_delta * state + B_t.unsqueeze(1) * x_t.unsqueeze(-1)

            # Readout plus learned skip connection D.
            y = (C_t.unsqueeze(1) * state).sum(dim=-1) + self.D * x_t
            outputs.append(y)

        output = torch.stack(outputs, dim=1)

        # Gated output, then project back to model dimension.
        output = output * F.silu(z)
        output = self.out_proj(output)

        if return_state:
            return output, state
        return output

    def step(
        self,
        x: torch.Tensor,
        state: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Single-step inference for autoregressive decoding.

        NOTE(review): unlike forward(), this path skips conv1d entirely — a
        faithful per-token replay of forward() would need a rolling conv
        input cache. Confirm this simplification is intended.

        Args:
            x: Input at current step (batch, d_model).
            state: Previous state (batch, d_inner, d_state).

        Returns:
            output: (batch, d_model).
            new_state: Updated state (batch, d_inner, d_state).
        """
        # nn.Linear acts on the last dim, so no sequence axis is needed.
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)

        # Input-dependent parameters for this single step.
        x_dbl = self.x_proj(x)
        delta, B, C = torch.split(
            x_dbl,
            [self.dt_rank, self.d_state, self.d_state],
            dim=-1,
        )
        delta = F.softplus(self.dt_proj(delta))

        # One step of the recurrence (same update as forward's loop body).
        A_delta = torch.exp(self.A_log * delta.unsqueeze(-1))
        state = A_delta * state + B.unsqueeze(1) * x.unsqueeze(-1)
        y = (C.unsqueeze(1) * state).sum(dim=-1) + self.D * x

        output = self.out_proj(y * F.silu(z))
        return output, state
|
|
|
|
|
def test_vortex_ssm():
    """Smoke-test VortexSSM: plain forward, stateful forward, and step.

    Uses small dimensions on purpose: forward() runs a per-timestep Python
    loop, so the original d_model=4096 / seq_len=128 made this smoke test
    needlessly slow and memory-heavy without testing anything extra.
    """
    batch_size = 2
    seq_len = 32
    d_model = 64
    d_state = 16

    ssm = VortexSSM(d_model, d_state=d_state)
    x = torch.randn(batch_size, seq_len, d_model)

    # Stateless forward pass.
    output = ssm(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    assert output.shape == x.shape, f"Expected {x.shape}, got {output.shape}"

    # Stateful forward pass: explicitly pass a zero state and get it back.
    state = torch.zeros(batch_size, ssm.d_inner, d_state)
    output2, new_state = ssm(x, state=state, return_state=True)
    print(f"Stateful output shape: {output2.shape}")
    print(f"State shape: {new_state.shape}")
    assert output2.shape == x.shape
    assert new_state.shape == (batch_size, ssm.d_inner, d_state)

    # Single-step autoregressive decode.
    x_step = torch.randn(batch_size, d_model)
    output_step, state_step = ssm.step(x_step, state)
    print(f"Step output shape: {output_step.shape}")
    print(f"Step state shape: {state_step.shape}")
    assert output_step.shape == (batch_size, d_model)
    assert state_step.shape == (batch_size, ssm.d_inner, d_state)

    print("VortexSSM test passed!")
|
# Allow running this module directly as a quick smoke test.
if __name__ == "__main__":

    test_vortex_ssm()
|
|
|