Mamba 2
Overview
The Mamba2 model was proposed in Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality by Tri Dao and Albert Gu. It is a State Space Model similar to Mamba 1, with better performance and a simplified architecture.
The abstract from the paper is the following:
While Transformers have been the main architecture behind deep learning's success in language modeling, state-space models (SSMs) such as Mamba have recently been shown to match or outperform Transformers at small to medium scale. We show that these families of models are actually quite closely related, and develop a rich framework of theoretical connections between SSMs and variants of attention, connected through various decompositions of a well-studied class of structured semiseparable matrices. Our state space duality (SSD) framework allows us to design a new architecture (Mamba-2) whose core layer is a refinement of Mamba's selective SSM that is 2-8X faster, while continuing to be competitive with Transformers on language modeling.
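As a rough illustration of the duality the abstract describes, the sketch below computes the output of a toy scalar-state selective SSM in two ways: as a recurrence, and as a multiplication by a lower-triangular (semiseparable) matrix that plays the role of a masked attention matrix. This is plain Python, not the Mamba-2 layer; the function names, the scalar state, and the sample values are assumptions chosen for brevity.

```python
def ssm_recurrent(a, b, c, x):
    """y_t = c_t * h_t with h_t = a_t * h_{t-1} + b_t * x_t and a zero initial state."""
    h, ys = 0.0, []
    for a_t, b_t, c_t, x_t in zip(a, b, c, x):
        h = a_t * h + b_t * x_t
        ys.append(c_t * h)
    return ys


def ssm_matrix(a, b, c, x):
    """Same map written as y = M x, with M[t][s] = c_t * (a_{s+1} ... a_t) * b_s for s <= t."""
    n = len(x)
    m = [[0.0] * n for _ in range(n)]
    for t in range(n):
        decay = 1.0  # product a_{s+1} ... a_t, extended as s moves left
        for s in range(t, -1, -1):
            m[t][s] = c[t] * decay * b[s]
            decay *= a[s]
    return [sum(m[t][s] * x[s] for s in range(n)) for t in range(n)]


# Illustrative values only (powers of two keep the arithmetic exact).
a = [0.5, 0.25, 0.5, 0.5]
b = [1.0, 2.0, 1.0, 4.0]
c = [1.0, 1.0, 2.0, 0.5]
x = [1.0, 2.0, 3.0, 4.0]

y_rec = ssm_recurrent(a, b, c, x)
y_mat = ssm_matrix(a, b, c, x)
```

Both lists hold the same values: the recurrence walks the sequence step by step, while the matrix form exposes every pairwise interaction at once, much as attention does.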
Tips:
This version should support all implementations of Mamba 2, and in particular Mamba-2 codestral from Mistral AI. Mamba-2 codestral was released with a number of groups equal to 8, which can be thought of intuitively as similar to the number of KV heads in an attention-based model.
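To make the KV-head analogy concrete, here is a minimal sketch of how heads could map onto groups. It assumes heads are split evenly and contiguously across groups, which is an illustrative assumption rather than a statement about the actual kernel layout; head_to_group is a hypothetical helper, not a library function.

```python
num_heads = 128  # Mamba2Config default
n_groups = 8     # value Mamba-2 codestral was released with


def head_to_group(head_idx, num_heads=num_heads, n_groups=n_groups):
    """Group index for a head, assuming an even, contiguous split (hypothetical layout)."""
    if num_heads % n_groups != 0:
        raise ValueError("num_heads must be divisible by n_groups")
    return head_idx // (num_heads // n_groups)


# Number of heads that would share one group's parameters.
heads_per_group = num_heads // n_groups
```

With these values each group serves 16 heads, in the same way that one KV head serves several query heads in grouped-query attention.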
This model has two different forward passes, torch_forward or cuda_kernels_forward. The latter uses the original CUDA kernels if they are found in your environment, and is slower on the prefill, i.e. it requires a "warmup run" due to high CPU overhead, see here and also here. Without compilation, the torch_forward implementation is faster by a factor of 3 to 4.
Further, there are no positional embeddings in this model, but there is an attention_mask and a specific logic to mask out hidden states in two places in the case of batched generation, see here as well. Due to this, in addition to the reimplementation of the Mamba2 kernels, batched generation and cached generation are expected to have slight discrepancies. Further, the results given by the CUDA kernels or the torch forward are expected to be slightly different. The SSM algorithm heavily relies on tensor contractions, which have matmul equivalents but the order of operations is slightly different, making the difference greater at smaller precisions.
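The effect of operation order at low precision can be seen in a small, self-contained sketch. It does not touch the model: it emulates half and single precision with the standard struct module and sums the same three numbers in two orders. The helper names and the sample values are assumptions picked to make the effect visible.

```python
import struct


def round_to(x, fmt):
    """Round a Python float to IEEE 754 half ('e') or single ('f') precision."""
    return struct.unpack(fmt, struct.pack(fmt, x))[0]


def running_sum(values, fmt):
    """Sum left to right, rounding to the target precision after every addition."""
    total = 0.0
    for v in values:
        total = round_to(total + v, fmt)
    return total


values = [2048.0, 1.0, 1.0]

# Difference between summing front-to-back and back-to-front.
half_gap = abs(running_sum(values, "e") - running_sum(values[::-1], "e"))
single_gap = abs(running_sum(values, "f") - running_sum(values[::-1], "f"))
```

In half precision the two orders disagree, because the small terms are lost when added one at a time to the large one. In single precision they agree. Real kernels may accumulate in a higher precision than they store, so treat this as an intuition aid only.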
Another note: the shutdown of hidden states corresponding to padding tokens is done in 2 places and has mostly been tested with left-padding. Right-padding will propagate noise down the line and is not guaranteed to yield satisfactory results. Setting tokenizer.padding_side = "left" ensures you are using the correct padding side.
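For readers unfamiliar with padding sides, the following sketch shows what left-padding produces for a batch of unequal-length sequences. left_pad is a hypothetical helper written for illustration; in practice the tokenizer builds input_ids and attention_mask for you when padding is enabled.

```python
def left_pad(sequences, pad_id):
    """Left-pad lists of token ids to a common length and build the matching mask."""
    width = max(len(seq) for seq in sequences)
    input_ids = [[pad_id] * (width - len(seq)) + list(seq) for seq in sequences]
    attention_mask = [[0] * (width - len(seq)) + [1] * len(seq) for seq in sequences]
    return input_ids, attention_mask


# Two sequences of different lengths; pad_id=0 is an arbitrary illustrative value.
ids, mask = left_pad([[5, 6, 7], [8]], pad_id=0)
```

With left-padding, every row ends on a real token, so the last position of each row is the one generation continues from.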
This model was contributed by Molbap, with tremendous help from Anton Vlasjuk. The original code can be found here.
Usage
A simple generation example:
from transformers import Mamba2ForCausalLM, AutoTokenizer
import torch
model_id = 'mistralai/Mamba-Codestral-7B-v0.1'
# Load the tokenizer and model from the same revision of the checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9')
# Tokenize a prompt and generate up to 10 new tokens
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
Here's a draft script for finetuning:
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, Mamba2ForCausalLM, TrainingArguments
model_id = 'mistralai/Mamba-Codestral-7B-v0.1'
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # enforce left padding
model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9')
dataset = load_dataset("Abirate/english_quotes", split="train")
# Without CUDA kernels, batch size of 2 occupies one 80GB device
# but precision can be reduced.
# Experiments and trials welcome!
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3
)
lora_config = LoraConfig(
    r=8,
    target_modules=["embeddings", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none"
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)
trainer.train()
Mamba2Config
class transformers.Mamba2Config
( num_heads = 128 head_dim = 64 vocab_size = 32768 hidden_size = 4096 state_size = 128 num_hidden_layers = 64 layer_norm_epsilon = 1e-05 pad_token_id = 1 bos_token_id = 0 eos_token_id = 2 expand = 2 conv_kernel = 4 n_groups = 8 use_bias = False use_conv_bias = True hidden_act = 'silu' initializer_range = 0.1 residual_in_fp32 = True time_step_rank = 'auto' time_step_min = 0.001 time_step_max = 0.1 time_step_floor = 0.0001 time_step_limit = (0.0, inf) rescale_prenorm_residual = False use_cache = True rms_norm = True chunk_size = 256 tie_word_embeddings = False **kwargs )
Parameters
- num_heads (int, optional, defaults to 128) — Number of heads for the evolution matrices of mamba 2.
- head_dim (int, optional, defaults to 64) — Dimension of each head.
- vocab_size (int, optional, defaults to 32768) — Vocabulary size of the MAMBA2 model. Defines the number of different tokens that can be represented by the input_ids passed when calling Mamba2Model.
- hidden_size (int, optional, defaults to 4096) — Dimensionality of the embeddings and hidden states.
- state_size (int, optional, defaults to 128) — Shape of the state space latents.
- num_hidden_layers (int, optional, defaults to 64) — Number of hidden layers in the model.
- layer_norm_epsilon (float, optional, defaults to 1e-05) — The epsilon to use in the layer normalization layers.
- pad_token_id (int, optional, defaults to 1) — Padding token id.
- bos_token_id (int, optional, defaults to 0) — The id of the beginning of sentence token in the vocabulary.
- eos_token_id (int, optional, defaults to 2) — The id of the end of sentence token in the vocabulary.
- expand (int, optional, defaults to 2) — Expanding factor used to determine the intermediate size.
- conv_kernel (int, optional, defaults to 4) — Size of the convolution kernel.
- n_groups (int, optional, defaults to 8) — Number of groups for the evolution matrices of mamba 2.
- use_bias (bool, optional, defaults to False) — Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block.
- use_conv_bias (bool, optional, defaults to True) — Whether or not to use bias in the convolution layer of the mixer block.
- hidden_act (str, optional, defaults to "silu") — The non-linear activation function (function or string) in the decoder.
- initializer_range (float, optional, defaults to 0.1) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- residual_in_fp32 (bool, optional, defaults to True) — Whether or not residuals should be in float32. If set to False, residuals will keep the same dtype as the rest of the model.
- time_step_rank (Union[int,str], optional, defaults to "auto") — Rank of the discretization projection matrix. "auto" means that it will default to math.ceil(self.hidden_size / 16).
- time_step_min (float, optional, defaults to 0.001) — Minimum time_step used to bound dt_proj.bias.
- time_step_max (float, optional, defaults to 0.1) — Maximum time_step used to bound dt_proj.bias.
- time_step_floor (float, optional, defaults to 0.0001) — Minimum clamping value of the dt_proj.bias layer initialization.
- time_step_limit (tuple, optional, defaults to (0.0, inf)) — Accepted range of time step values.
- rescale_prenorm_residual (bool, optional, defaults to False) — Whether or not to rescale out_proj weights when initializing.
- use_cache (bool, optional, defaults to True) — Whether or not the cache should be used.
- rms_norm (bool, optional, defaults to True) — Whether to use RMS norm or not.
- chunk_size (int, optional, defaults to 256) — Size of the chunks that will comprise the sequence.
- tie_word_embeddings (bool, optional, defaults to False) — Whether to tie word embeddings or not.
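A few of these defaults are related to each other. The sketch below works through the arithmetic under two assumptions that are inferred from the parameter descriptions and not stated outright in them: that the intermediate size equals expand times hidden_size, and that num_heads times head_dim is meant to match it.

```python
import math

# Defaults copied from the Mamba2Config signature.
hidden_size, expand = 4096, 2
num_heads, head_dim = 128, 64

# Assumed relation: the expanding factor scales the hidden size.
intermediate_size = expand * hidden_size

# What time_step_rank="auto" resolves to, per its description.
time_step_rank = math.ceil(hidden_size / 16)

# Assumed consistency check between the head layout and the intermediate size.
heads_match = num_heads * head_dim == intermediate_size
```

With the defaults, the intermediate size is 8192, the automatic time step rank is 256, and the head layout is consistent with the intermediate size. When changing hidden_size or expand in a custom configuration, it is worth re-checking that the head layout still lines up.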
This is the configuration class to store the configuration of a Mamba2Model. It is used to instantiate a MAMBA2 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the MAMBA2 state-spaces/mamba2-2.8b architecture.
Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
Example:
>>> from transformers import Mamba2Config, Mamba2Model
>>> # Initializing a Mamba2 configuration
>>> configuration = Mamba2Config()
>>> # Initializing a model (with random weights) from the configuration
>>> model = Mamba2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
Mamba2Model
class transformers.Mamba2Model
( config )
Parameters
- config (Mamba2Config) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
The bare MAMBA2 Model transformer outputting raw hidden-states without any specific head on top.
This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads etc.)
This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
forward
( input_ids: Optional = None inputs_embeds: Optional = None cache_params: Optional = None use_cache: Optional = None output_hidden_states: Optional = None return_dict: Optional = None cache_position: Optional = None attention_mask: Optional = None **kwargs ) → transformers.models.mamba2.modeling_mamba2.Mamba2Output or tuple(torch.FloatTensor)
Parameters
- input_ids (torch.LongTensor of shape (batch_size, input_ids_length)) — Indices of input sequence tokens in the vocabulary. If cache_params.seqlen_offset>0, only input_ids that do not have their past calculated should be passed as input_ids. Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details.
- inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional) — Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model's internal embedding lookup matrix.
- cache_params (Mamba2Cache, optional) — If passed along, the model uses the previous state in all the blocks (which will give the output for the input_ids provided as if the model added state_input_ids + input_ids as context).
- use_cache (bool, optional) — If set to True, the cache_params is returned and can be used to quickly generate the next logits.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
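The idea behind cache_params can be illustrated with a toy recurrence that has nothing to do with the real Mamba2Cache class: processing a prefix, keeping its final state, and then feeding only the new token gives the same result as processing the whole sequence. The run function, the decay value, and the inputs are assumptions made for this sketch.

```python
def run(xs, state=0.0, decay=0.5):
    """Toy recurrence h_t = decay * h_{t-1} + x_t; returns all outputs and the final state."""
    outputs = []
    for x in xs:
        state = decay * state + x
        outputs.append(state)
    return outputs, state


tokens = [1.0, 2.0, 3.0, 4.0]

# Reference: process the full sequence in one go.
full_outputs, _ = run(tokens)

# Cached path: process the prefix once, keep its final state, then feed only the new token.
_, cached_state = run(tokens[:3])
next_outputs, _ = run(tokens[3:], state=cached_state)
```

The cached path reproduces the last output of the full run without revisiting the first three inputs. The actual cache additionally stores convolutional states, which this toy omits.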
Returns
transformers.models.mamba2.modeling_mamba2.Mamba2Output or tuple(torch.FloatTensor)
A transformers.models.mamba2.modeling_mamba2.Mamba2Output or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Mamba2Config) and inputs.
- last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)) — Sequence of hidden-states at the output of the last layer of the model.
- cache_params (Mamba2Cache) — The state of the model at the last time step. Can be used in a forward method with the next input_ids to avoid providing the old input_ids. Includes both the state space model state matrices after the selective scan, and the convolutional states.
- hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
The Mamba2Model forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Example:
>>> from transformers import AutoTokenizer, Mamba2Model
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1")
>>> model = Mamba2Model.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
Mamba2ForCausalLM
class transformers.Mamba2ForCausalLM
( config )
Parameters
- config (Mamba2Config) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
The MAMBA2 Model transformer with a language modeling head on top (linear layer with weights not tied to the input embeddings).
This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads etc.)
This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
forward
( input_ids: Optional = None inputs_embeds: Optional = None cache_params: Optional = None labels: Optional = None output_hidden_states: Optional = None return_dict: Optional = None use_cache: Optional = None cache_position: Optional = None attention_mask: Optional = None **kwargs ) → transformers.models.mamba2.modeling_mamba2.Mamba2CausalLMOutput or tuple(torch.FloatTensor)
Parameters
- input_ids (torch.LongTensor of shape (batch_size, input_ids_length)) — Indices of input sequence tokens in the vocabulary. If cache_params.seqlen_offset>0, only input_ids that do not have their past calculated should be passed as input_ids. Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details.
- inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional) — Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model's internal embedding lookup matrix.
- cache_params (Mamba2Cache, optional) — If passed along, the model uses the previous state in all the blocks (which will give the output for the input_ids provided as if the model added state_input_ids + input_ids as context).
- use_cache (bool, optional) — If set to True, the cache_params is returned and can be used to quickly generate the next logits.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
- labels (torch.LongTensor of shape (batch_size, sequence_length), optional) — Labels for language modeling. Note that the labels are shifted inside the model, i.e. you can set labels = input_ids. Indices are selected in [-100, 0, ..., config.vocab_size]. All labels set to -100 are ignored (masked); the loss is only computed for labels in [0, ..., config.vocab_size].
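The label handling described above (shifting by one position and skipping -100) can be sketched in plain Python for a single sequence. This is an illustrative reimplementation under stated assumptions, not the code the model runs; causal_lm_loss and the sample values are hypothetical.

```python
import math

IGNORE_INDEX = -100


def causal_lm_loss(logits, labels):
    """Mean cross-entropy where position t predicts labels[t + 1]; -100 labels are skipped."""
    losses = []
    # Shift: drop the last logits row and the first label so that they line up.
    for step_logits, target in zip(logits[:-1], labels[1:]):
        if target == IGNORE_INDEX:
            continue
        log_norm = math.log(sum(math.exp(v) for v in step_logits))
        losses.append(log_norm - step_logits[target])
    if not losses:
        raise ValueError("every label was masked")
    return sum(losses) / len(losses)


# Three positions over a 4-token vocabulary with uniform scores (illustrative values).
logits = [[0.0, 0.0, 0.0, 0.0] for _ in range(3)]
labels = [2, IGNORE_INDEX, 1]
loss = causal_lm_loss(logits, labels)
```

Because the shift happens inside the function, the caller can pass labels aligned with the inputs. Here only one position contributes to the loss, since the other shifted label is masked.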
Returns
transformers.models.mamba2.modeling_mamba2.Mamba2CausalLMOutput or tuple(torch.FloatTensor)
A transformers.models.mamba2.modeling_mamba2.Mamba2CausalLMOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Mamba2Config) and inputs.
- loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) — Language modeling loss (for next-token prediction).
- logits (torch.FloatTensor of shape (batch_size, sequence_length, config.vocab_size)) — Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- cache_params (Mamba2Cache) — The state of the model at the last time step. Can be used in a forward method with the next input_ids to avoid providing the old input_ids. Includes both the state space model state matrices after the selective scan, and the convolutional states.
- hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
The Mamba2ForCausalLM forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Example:
>>> import torch
>>> from transformers import AutoTokenizer, Mamba2ForCausalLM
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1")
>>> model = Mamba2ForCausalLM.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs, labels=inputs["input_ids"])
>>> loss = outputs.loss
>>> logits = outputs.logits