"""Mamba selective state space language model, wrapped as Hugging Face models.

Glossary (dimension names used in the shape comments below):
    b: batch size
    l: sequence length
    d / d_model: width of the residual stream
    d_in / d_inner: expanded width inside a Mamba mixer
    n / d_state: latent state size of the SSM
    dt_rank: rank of the low-rank projection that produces Δ

References:
    [1] Albert Gu, Tri Dao. "Mamba: Linear-Time Sequence Modeling with
        Selective State Spaces", https://arxiv.org/abs/2312.00752
    [2] Sasha Rush, Sidd Karamcheti. "The Annotated S4",
        https://srush.github.io/annotated-s4/
"""

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import einsum, rearrange, repeat
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel

from .configuration_mamba import MambaConfig


class MambaRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale."""

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        """Return ``x / sqrt(mean(x**2, -1) + eps) * weight``.

        The statistics are computed in at least float32 so that half-precision
        inputs cannot overflow in ``x.pow(2)``; the normalized activations are
        cast back to the input dtype before scaling.
        """
        input_dtype = x.dtype
        compute_dtype = torch.promote_types(input_dtype, torch.float32)
        x_wide = x.to(compute_dtype)
        normed = x_wide * torch.rsqrt(x_wide.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed.to(input_dtype) * self.weight


class Mamba(nn.Module):
    """Mamba mixer: the inner computation of a block in Figure 3 (Section 3.4) of [1].

    Pre-normalization and the residual connection are not applied here; they
    are added by ``MambaBlock``, which wraps this module.
    """

    def __init__(self, config: MambaConfig):
        super().__init__()
        self.config = config

        # Projects to 2 * d_in: one half feeds conv + SSM, the other is the gate.
        self.in_proj = nn.Linear(config.d_model, config.d_inner * 2, bias=config.bias)

        # Depthwise conv (groups=d_inner). Padding d_conv - 1 on both sides and
        # keeping only the first l outputs in forward() makes it causal.
        self.conv1d = nn.Conv1d(
            in_channels=config.d_inner,
            out_channels=config.d_inner,
            bias=config.conv_bias,
            kernel_size=config.d_conv,
            groups=config.d_inner,
            padding=config.d_conv - 1,
        )

        # x_proj takes in `x` and outputs the input-specific Δ, B, C
        self.x_proj = nn.Linear(
            config.d_inner, config.dt_rank + config.d_state * 2, bias=False
        )

        # dt_proj projects Δ from dt_rank to d_in
        self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True)

        # A_log[d, k] = log(k + 1); ssm() uses A = -exp(A_log), so every entry of A is negative.
        A = repeat(torch.arange(1, config.d_state + 1), "n -> d n", d=config.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(config.d_inner))
        self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)

    def forward(self, x):
        """Mamba mixer forward. This follows Figure 3 in Section 3.4 of the Mamba paper [1].

        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d)

        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """
        seq_len = x.shape[1]

        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
        (x, res) = x_and_res.split(
            split_size=[self.config.d_inner, self.config.d_inner], dim=-1
        )

        x = rearrange(x, "b l d_in -> b d_in l")
        # Drop the trailing padded positions so output t only sees inputs <= t.
        x = self.conv1d(x)[:, :, :seq_len]
        x = rearrange(x, "b d_in l -> b l d_in")
        x = F.silu(x)

        y = self.ssm(x)
        y = y * F.silu(res)  # multiplicative gate from the second half of in_proj

        return self.out_proj(y)

    def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Args:
            x: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """
        n = self.A_log.shape[1]

        # Compute ∆ A B C D, the state space parameters.
        #     A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
        #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
        #                                  and is why Mamba is called **selective** state spaces)
        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()

        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)
        (delta, B, C) = x_dbl.split(
            split_size=[self.config.dt_rank, n, n], dim=-1
        )  # delta: (b, l, dt_rank). B, C: (b, l, n)
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)

        # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]
        return self.selective_scan(x, delta, A, B, C, D)

    def selective_scan(self, u, delta, A, B, C, D):
        """Does selective scan algorithm. See:
            - Section 2 State Space Models in the Mamba paper [1]
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        This is the classic discrete state space formula:
            x(t + 1) = Ax(t) + Bu(t)
            y(t)     = Cx(t) + Du(t)
        except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t).

        Args:
            u: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
            delta: shape (b, l, d_in)
            A: shape (d_in, n)
            B: shape (b, l, n)
            C: shape (b, l, n)
            D: shape (d_in,)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
            Note: some parts of `selective_scan_ref` were refactored out, so the functionality doesn't match exactly.
        """
        (b, l, d_in) = u.shape
        n = A.shape[1]

        # Discretize continuous parameters (A, B)
        # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
        # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
        #   "A is the more important term and the performance doesn't change much with the simplification on B"
        deltaA = torch.exp(einsum(delta, A, "b l d_in, d_in n -> b d_in l n"))
        deltaB_u = einsum(delta, B, u, "b l d_in, b l n, b l d_in -> b d_in l n")

        # Perform selective scan (see scan_SSM() in The Annotated S4 [2]).
        # The state takes deltaA's dtype so the recurrence does not depend on
        # torch's global default dtype.
        x = torch.zeros((b, d_in, n), device=deltaA.device, dtype=deltaA.dtype)
        ys = []
        for i in range(l):
            x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
            y = einsum(x, C[:, i, :], "b d_in n, b n -> b d_in")
            ys.append(y)
        y = torch.stack(ys, dim=1)  # shape (b, l, d_in)

        return y + u * D


class MambaBlock(nn.Module):
    """Pre-norm residual block: ``x + Mamba(RMSNorm(x))``."""

    def __init__(self, config: MambaConfig, layer_idx: int = 0):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx  # position in MambaModel.layers (informational only)
        self.mixer = Mamba(config)
        self.norm = MambaRMSNorm(config.d_model)

    def forward(self, x):
        return self.mixer(self.norm(x)) + x


class MambaPreTrainedModel(PreTrainedModel):
    """Shared Hugging Face plumbing (config class, weight init) for the Mamba models."""

    config_class = MambaConfig
    base_model_prefix = "backbone"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MambaBlock"]

    def _init_weights(self, module):
        """Initialize Linear/Conv1d/Embedding weights from N(0, 0.02^2) and zero their biases.

        Parameters owned by other module types (``Mamba.A_log``, ``Mamba.D``,
        the RMSNorm scales) are left at their constructor values.
        NOTE(review): this also zeroes ``dt_proj.bias``; confirm that is the
        intended initialization for the Δ projection.
        """
        std = 0.02
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class MambaModel(MambaPreTrainedModel):
    """Mamba decoder: token embedding, ``config.n_layer`` MambaBlocks, final RMSNorm."""

    def __init__(self, config: MambaConfig):
        """Full Mamba model.

        Mamba model decoder consisting of *config.n_layer* layers. Each layer is a [`MambaBlock`].

        Args:
            config: MambaConfig
        """
        super().__init__(config)
        self.embedding = nn.Embedding(self.config.vocab_size, self.config.d_model)
        self.layers = nn.ModuleList(
            [MambaBlock(self.config, layer_idx) for layer_idx in range(self.config.n_layer)]
        )
        self.norm_f = MambaRMSNorm(self.config.d_model)

        # Toggled by PreTrainedModel.gradient_checkpointing_enable(); honored in _run_layer().
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding layer."""
        return self.embedding

    def set_input_embeddings(self, value):
        """Replace the token embedding layer (e.g. after resizing the vocabulary)."""
        self.embedding = value

    def _run_layer(self, layer, hidden_state):
        """Apply one block, using activation checkpointing when enabled and training."""
        if self.gradient_checkpointing and self.training:
            # Prefer the checkpoint function configured by Transformers (it carries
            # the user's gradient_checkpointing_kwargs); otherwise fall back to torch.
            checkpoint_fn = getattr(self, "_gradient_checkpointing_func", None)
            if checkpoint_fn is not None:
                return checkpoint_fn(layer.__call__, hidden_state)
            return torch.utils.checkpoint.checkpoint(layer, hidden_state, use_reentrant=False)
        return layer(hidden_state)

    @staticmethod
    def _check_hidden_shape(hidden_state, expected_shape):
        """Internal invariant: every stage keeps the (b, l, d_model) shape."""
        assert hidden_state.shape == expected_shape, f"{hidden_state.shape} != {expected_shape}"

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> BaseModelOutputWithPast:
        """Run the decoder over ``input_ids`` of shape (b, l).

        Args:
            input_ids: token ids, shape (b, l).
            output_hidden_states: whether to return all intermediate states;
                ``None`` falls back to ``config.output_hidden_states``.
            return_dict: accepted for API compatibility and ignored; a
                ``BaseModelOutputWithPast`` is always returned.
            **kwargs: ignored.

        Returns:
            ``BaseModelOutputWithPast`` with ``last_hidden_state`` of shape
            (b, l, d_model). When requested, ``hidden_states`` holds
            ``n_layer + 2`` tensors: the embedding output, each block's
            output, and the output of the final norm.
        """
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        batch_size = input_ids.shape[0]
        sequence_length = input_ids.shape[1]
        expected_shape = (batch_size, sequence_length, self.config.d_model)

        last_hidden_state = self.embedding(input_ids)
        self._check_hidden_shape(last_hidden_state, expected_shape)
        hidden_states = (last_hidden_state,)

        for layer in self.layers:
            last_hidden_state = self._run_layer(layer, last_hidden_state)
            self._check_hidden_shape(last_hidden_state, expected_shape)
            hidden_states += (last_hidden_state,)

        last_hidden_state = self.norm_f(last_hidden_state)
        self._check_hidden_shape(last_hidden_state, expected_shape)
        hidden_states += (last_hidden_state,)

        assert (
            len(hidden_states) == self.config.n_layer + 2
        ), f"{len(hidden_states)} != {self.config.n_layer + 2}"

        return BaseModelOutputWithPast(
            hidden_states=hidden_states if output_hidden_states else None,
            last_hidden_state=last_hidden_state,
        )


class MambaModelForCausalLM(MambaPreTrainedModel):
    """Mamba backbone plus a bias-free language-modeling head tied to the input embedding."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.backbone = MambaModel(config=self.config)
        self.lm_head = nn.Linear(
            in_features=self.config.d_model,
            out_features=self.config.vocab_size,
            bias=False,
        )
        self.post_init()

    def _tie_weights(self):
        # The output projection shares its weight tensor with the input embedding.
        self.lm_head.weight = self.backbone.embedding.weight

    def forward(
        self,
        input_ids,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Compute next-token logits and, when ``labels`` are given, the LM loss.

        Args:
            input_ids: token ids, shape (b, l).
            labels: optional target ids, shape (b, l). They are shifted
                internally, so pass the same sequence as ``input_ids``;
                positions set to -100 are ignored by ``CrossEntropyLoss``.
            output_hidden_states: ``None`` falls back to ``config.output_hidden_states``.
            **kwargs: ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` of shape (b, l, vocab_size).
        """
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        vocab_size = self.config.vocab_size

        outputs = self.backbone(
            input_ids=input_ids,
            output_hidden_states=output_hidden_states,
        )
        logits = self.lm_head(outputs.last_hidden_state)  # (b, l, vocab_size)

        loss = None
        if labels is not None:
            # Logit at position t predicts token t + 1: drop the last logit and the first label.
            shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
            shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
            loss = CrossEntropyLoss()(shift_logits, shift_labels)

        return CausalLMOutputWithPast(
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            logits=logits,
            loss=loss,
        )

    def prepare_inputs_for_generation(
        self, input_ids, attention_mask=None, **model_kwargs
    ):
        """Feed the full ``input_ids`` every step.

        No recurrent state is cached, so each generation step re-runs the
        model over the whole prefix; ``attention_mask`` and other kwargs are ignored.
        """
        return {
            "input_ids": input_ids,
        }


# NOTE: disabled draft of a sequence-classification head, kept for reference.
# class MambaModelForSequenceClassification(MambaModelForCausalLM):
#     def __init__(
#         self,
#         config,
#         id2label={0: "NEGATIVE", 1: "POSITIVE"},
#         label2id={"NEGATIVE": 0, "POSITIVE": 1},
#         num_labels=2,
#         **kwargs,
#     ):
#         super().__init__(
#             config,
#             **kwargs,
#         )
#         self.id2label = id2label
#         self.label2id = label2id
#         self.num_labels = num_labels  # TODO: config.num_labels
#         self.score = nn.Linear(
#             in_features=self.config.vocab_size,
#             out_features=self.num_labels,
#             bias=False,
#         )
#
#     def forward(
#         self,
#         input_ids: Optional[torch.Tensor] = None,
#         labels: Optional[torch.Tensor] = None,
#         output_hidden_states=False,
#         **kwargs,
#     ) -> SequenceClassifierOutputWithPast:
#         batch_size = input_ids.shape[0]
#         hidden_size = self.config.vocab_size
#         hidden_states: Tuple[
#             torch.Tensor[(batch_size, sequence_length, hidden_size)]
#         ] = ()
#         num_labels = self.num_labels  # TODO: config.num_labels
#         sequence_length = input_ids.shape[1]
#         vocab_size = self.config.vocab_size
#         output_hidden_states = output_hidden_states or self.config.output_hidden_states
#         outputs = super().forward(
#             input_ids=input_ids,
#             labels=None,
#             output_hidden_states=output_hidden_states,
#             **kwargs,
#         )
#         last_hidden_state = outputs.logits
#         assert last_hidden_state.shape == (
#             batch_size,
#             sequence_length,
#             hidden_size,
#         ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
#         hidden_states += (last_hidden_state,)
#         logits: torch.FloatTensor[batch_size, num_labels] = self.score(
#             last_hidden_state[:, -1, :]  # TODO: Check if this makes sense
#         )
#         if labels is not None:
#             loss_fct = CrossEntropyLoss()
#             loss = loss_fct(logits, labels)
#         else:
#             loss = None
#         return SequenceClassifierOutputWithPast(
#             loss=loss,
#             logits=logits,
#             hidden_states=hidden_states,
#         )