Transformers documentation

Mamba

You are viewing the main version, which requires installation from source. If you'd like a regular pip install, check out the latest stable version (v4.46.3).

Overview

The Mamba model was proposed in Mamba: Linear-Time Sequence Modeling with Selective State Spaces by Albert Gu and Tri Dao.

Mamba is a new architecture paradigm based on state space models (SSMs). You can read more about the intuition behind these here.

The abstract from the paper is the following:

Foundation models, now powering most of the exciting applications in deep learning, are almost universally based on the Transformer architecture and its core attention module. Many subquadratic-time architectures such as linear attention, gated convolution and recurrent models, and structured state space models (SSMs) have been developed to address Transformers’ computational inefficiency on long sequences, but they have not performed as well as attention on important modalities such as language. We identify that a key weakness of such models is their inability to perform content-based reasoning, and make several improvements. First, simply letting the SSM parameters be functions of the input addresses their weakness with discrete modalities, allowing the model to selectively propagate or forget information along the sequence length dimension depending on the current token. Second, even though this change prevents the use of efficient convolutions, we design a hardware-aware parallel algorithm in recurrent mode. We integrate these selective SSMs into a simplified end-to-end neural network architecture without attention or even MLP blocks (Mamba). Mamba enjoys fast inference (5× higher throughput than Transformers) and linear scaling in sequence length, and its performance improves on real data up to million-length sequences. As a general sequence model backbone, Mamba achieves state-of-the-art performance across several modalities such as language, audio, and genomics. On language modeling, our Mamba-3B model outperforms Transformers of the same size and matches Transformers twice its size, both in pretraining and downstream evaluation.
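To make the idea of input-dependent SSM parameters concrete, here is a minimal sketch of a selective recurrence with a single scalar state. It is an illustration only, not the hardware-aware parallel algorithm from the paper: the function names and the weights `a`, `w_dt`, `b_dt`, `w_b` and `w_c` are hypothetical, and the discretization shown (exponential for the state transition, a plain product for the input term) is an assumption chosen for brevity.

```python
import math


def softplus(z):
    # Numerically stable log(1 + exp(z)); keeps the step size positive.
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def selective_scan(xs, a=-1.0, w_dt=1.0, b_dt=0.0, w_b=1.0, w_c=1.0):
    """Run a scalar selective SSM over the sequence xs, one step per token.

    All weights are hypothetical. The step size dt and the projections
    b and c are recomputed from every input, which is what makes the
    recurrence "selective".
    """
    h = 0.0
    ys = []
    for x in xs:
        dt = softplus(w_dt * x + b_dt)  # input-dependent step size
        b = w_b * x                     # input-dependent input projection
        c = w_c * x                     # input-dependent output projection
        # A large dt shrinks exp(dt * a) (with a < 0) and forgets the old
        # state; a small dt carries the old state forward almost unchanged.
        h = math.exp(dt * a) * h + dt * b * x
        ys.append(c * h)
    return ys


print(selective_scan([1.0, -2.0, 0.5]))
```

The loop touches each token once and keeps a fixed-size state, which is where the linear scaling in sequence length comes from.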

Tips:

  • Mamba is a new state space model architecture that rivals classic Transformers. It builds on the line of work on structured state space models, with an efficient hardware-aware design and implementation in the spirit of FlashAttention.
  • Mamba stacks mixer layers, which are the equivalent of attention layers. The core logic of Mamba is held in the MambaMixer class.
  • Two implementations coexist: one is optimized and uses fast CUDA kernels, while the other is naive but can run on any device.
  • The current implementation leverages the original CUDA kernels: the Mamba equivalents of FlashAttention are hosted in the mamba-ssm and causal_conv1d repositories. Make sure to install them if your hardware supports them.
  • Contributions to make the naive path faster are welcome 🤗
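A quick way to see which path is likely to be used is to check whether the two optional kernel packages are importable. The sketch below relies only on the standard library; the helper name `fast_path_available` is hypothetical, and the import names `mamba_ssm` and `causal_conv1d` are assumed to match the installed packages. It checks installation only, not whether a compatible GPU is present.

```python
import importlib.util


def fast_path_available():
    """Return True when both optional kernel packages can be located."""
    required = ("mamba_ssm", "causal_conv1d")
    return all(importlib.util.find_spec(name) is not None for name in required)


print("fast CUDA kernels installed:", fast_path_available())
```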

This model was contributed by ArthurZ. The original code can be found here.

Usage

A simple generation example:

from transformers import AutoTokenizer, MambaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]

# Generate up to 10 new tokens and decode the full sequence back to text.
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))

PEFT fine-tuning

The slow (naive) implementation is not very stable for training, and the fast one requires float32.

from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
model_id = "state-spaces/mamba-130m-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3
)
lora_config = LoraConfig(
    r=8,
    target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none"
)
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)
trainer.train()
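To get a feel for how small the LoRA update in this recipe is, the sketch below counts the extra trainable parameters a rank-`r` adapter adds to one layer, using the usual LoRA factorization into an `r x in_features` and an `out_features x r` matrix. The helper `lora_extra_params` is hypothetical, and the layer shapes in the usage lines are assumptions about the 130M checkpoint (hidden size 768, intermediate size 1536, vocabulary size 50280), not values read from the model.

```python
def lora_extra_params(in_features, out_features, r=8):
    """Trainable parameters added by one rank-r LoRA adapter."""
    # One r x in_features matrix plus one out_features x r matrix.
    return r * (in_features + out_features)


# Assumed shapes for illustration only.
hidden_size, intermediate_size, vocab_size = 768, 1536, 50280

print("in_proj:   ", lora_extra_params(hidden_size, 2 * intermediate_size))
print("out_proj:  ", lora_extra_params(intermediate_size, hidden_size))
print("embeddings:", lora_extra_params(vocab_size, hidden_size))
```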

MambaConfig

class transformers.MambaConfig

( vocab_size = 50280, hidden_size = 768, state_size = 16, num_hidden_layers = 32, layer_norm_epsilon = 1e-05, pad_token_id = 0, bos_token_id = 0, eos_token_id = 0, expand = 2, conv_kernel = 4, use_bias = False, use_conv_bias = True, hidden_act = 'silu', initializer_range = 0.1, residual_in_fp32 = True, time_step_rank = 'auto', time_step_scale = 1.0, time_step_min = 0.001, time_step_max = 0.1, time_step_init_scheme = 'random', time_step_floor = 0.0001, rescale_prenorm_residual = False, use_cache = True, use_mambapy = False, **kwargs )

Parameters

  • vocab_size (int, optional, defaults to 50280) — Vocabulary size of the MAMBA model. Defines the number of different tokens that can be represented by the input_ids passed when calling MambaModel.
  • hidden_size (int, optional, defaults to 768) — Dimensionality of the embeddings and hidden states.
  • state_size (int, optional, defaults to 16) — Shape of the state space latents.
  • num_hidden_layers (int, optional, defaults to 32) — Number of hidden layers in the model.
  • layer_norm_epsilon (float, optional, defaults to 1e-05) — The epsilon to use in the layer normalization layers.
  • pad_token_id (int, optional, defaults to 0) — Padding token id.
  • bos_token_id (int, optional, defaults to 0) — The id of the beginning of sentence token in the vocabulary.
  • eos_token_id (int, optional, defaults to 0) — The id of the end of sentence token in the vocabulary.
  • expand (int, optional, defaults to 2) — Expanding factor used to determine the intermediate size.
  • conv_kernel (int, optional, defaults to 4) — Size of the convolution kernel.
  • use_bias (bool, optional, defaults to False) — Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block.
  • use_conv_bias (bool, optional, defaults to True) — Whether or not to use bias in the convolution layer of the mixer block.
  • hidden_act (str, optional, defaults to "silu") — The non-linear activation function (function or string) in the decoder.
  • initializer_range (float, optional, defaults to 0.1) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  • residual_in_fp32 (bool, optional, defaults to True) — Whether or not residuals should be in float32. If set to False, residuals will keep the same dtype as the rest of the model.
  • time_step_rank (Union[int,str], optional, defaults to "auto") — Rank of the discretization projection matrix. "auto" means that it will default to math.ceil(self.hidden_size / 16).
  • time_step_scale (float, optional, defaults to 1.0) — Scale used to scale dt_proj.bias.
  • time_step_min (float, optional, defaults to 0.001) — Minimum time_step used to bound dt_proj.bias.
  • time_step_max (float, optional, defaults to 0.1) — Maximum time_step used to bound dt_proj.bias.
  • time_step_init_scheme (str, optional, defaults to "random") — Init scheme used for dt_proj.weight. Should be one of ["random","uniform"].
  • time_step_floor (float, optional, defaults to 0.0001) — Minimum clamping value of the dt_proj.bias layer initialization.
  • rescale_prenorm_residual (bool, optional, defaults to False) — Whether or not to rescale out_proj weights when initializing.
  • use_cache (bool, optional, defaults to True) — Whether or not the cache should be used.
  • use_mambapy (bool, optional, defaults to False) — Determines the fallback strategy during training if the CUDA-based official implementation of Mamba is not available. If True, the mamba.py implementation is used. If False, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
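Two sizes are derived from these parameters rather than set directly: the intermediate size (from expand and hidden_size) and the resolved value of time_step_rank when it is left at "auto". The sketch below reproduces that arithmetic in plain Python. The helper `derived_sizes` is hypothetical, and computing the intermediate size as `expand * hidden_size` is an assumption based on the description of expand above.

```python
import math


def derived_sizes(hidden_size=768, expand=2, time_step_rank="auto"):
    """Return (intermediate_size, time_step_rank) for the given settings."""
    # Assumption: the expanding factor multiplies the hidden size.
    intermediate_size = int(expand * hidden_size)
    if time_step_rank == "auto":
        # Documented rule: math.ceil(hidden_size / 16).
        time_step_rank = math.ceil(hidden_size / 16)
    return intermediate_size, time_step_rank


print(derived_sizes())  # (1536, 48)
```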

This is the configuration class to store the configuration of a MambaModel. It is used to instantiate a MAMBA model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the MAMBA state-spaces/mamba-2.8b architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

>>> from transformers import MambaConfig, MambaModel

>>> # Initializing a Mamba configuration
>>> configuration = MambaConfig()

>>> # Initializing a model (with random weights) from the configuration
>>> model = MambaModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config

MambaModel

class transformers.MambaModel

( config )

Parameters

  • config (MambaConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

The bare MAMBA model outputting raw hidden states without any specific head on top.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( input_ids: typing.Optional[torch.LongTensor] = None, inputs_embeds: typing.Optional[torch.LongTensor] = None, cache_params: typing.Optional[transformers.cache_utils.MambaCache] = None, use_cache: typing.Optional[bool] = None, output_hidden_states: typing.Optional[bool] = None, return_dict: typing.Optional[bool] = None, cache_position: typing.Optional[torch.LongTensor] = None, attention_mask: typing.Optional[torch.LongTensor] = None ) → transformers.models.mamba.modeling_mamba.MambaOutput or tuple(torch.FloatTensor)

Parameters

  • input_ids (torch.LongTensor of shape (batch_size, input_ids_length)) — Indices of input sequence tokens in the vocabulary.

    If cache_params.seqlen_offset>0, only input_ids that do not have their past calculated should be passed as input_ids.

    Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details.

    What are input IDs?

  • inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional) — Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model’s internal embedding lookup matrix.
  • cache_params (MambaCache, optional) — If passed along, the model uses the previous state in all the blocks (which will give the output for the provided input_ids as if the model had received state_input_ids + input_ids as context).
  • use_cache (bool, optional) — If set to True, the cache_params is returned and can be used to quickly generate the next logits.
  • output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
  • return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
  • cache_position (torch.LongTensor of shape (sequence_length), optional) — Indices depicting the position of the input sequence tokens in the sequence. Contrary to position_ids, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length.

Returns

transformers.models.mamba.modeling_mamba.MambaOutput or tuple(torch.FloatTensor)

A transformers.models.mamba.modeling_mamba.MambaOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (MambaConfig) and inputs.

  • last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)) — Sequence of hidden-states at the output of the last layer of the model.

  • cache_params (MambaCache) — The state of the model at the last time step. Can be used in a forward method with the next input_ids to avoid providing the old input_ids.

    Includes both the state space model state matrices after the selective scan and the convolutional states.

  • hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
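Because cache_params holds only the SSM states and the convolutional states, its size does not grow with the number of tokens already processed. The sketch below estimates the number of cached values under an assumed layout: one convolutional state of shape (batch_size, intermediate_size, conv_kernel) and one SSM state of shape (batch_size, intermediate_size, state_size) per layer. Both the layout and the helper `cache_elements` are assumptions for illustration, not taken from the MambaCache implementation.

```python
def cache_elements(num_hidden_layers=32, hidden_size=768, expand=2,
                   conv_kernel=4, state_size=16, batch_size=1):
    """Count cached values under the assumed per-layer state shapes."""
    intermediate_size = expand * hidden_size
    conv_state = batch_size * intermediate_size * conv_kernel
    ssm_state = batch_size * intermediate_size * state_size
    # Sequence length does not appear anywhere: the cache is constant-size.
    return num_hidden_layers * (conv_state + ssm_state)


print(cache_elements())  # 983040
```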

The MambaModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of calling forward directly, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

>>> from transformers import AutoTokenizer, MambaModel
>>> import torch

>>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
>>> model = MambaModel.from_pretrained("state-spaces/mamba-130m-hf")

>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)

>>> last_hidden_states = outputs.last_hidden_state

MambaForCausalLM

class transformers.MambaForCausalLM

( config )

Parameters

  • config (MambaConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

The MAMBA model with a language modeling head on top (linear layer with weights tied to the input embeddings).

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( input_ids: typing.Optional[torch.LongTensor] = None, attention_mask: typing.Optional[torch.LongTensor] = None, inputs_embeds: typing.Optional[torch.FloatTensor] = None, cache_params: typing.Optional[transformers.cache_utils.MambaCache] = None, labels: typing.Optional[torch.LongTensor] = None, output_hidden_states: typing.Optional[bool] = None, return_dict: typing.Optional[bool] = None, use_cache: typing.Optional[bool] = None, cache_position: typing.Optional[torch.Tensor] = None, **kwargs ) → transformers.models.mamba.modeling_mamba.MambaCausalLMOutput or tuple(torch.FloatTensor)

Parameters

  • input_ids (torch.LongTensor of shape (batch_size, input_ids_length)) — Indices of input sequence tokens in the vocabulary.

    If cache_params.seqlen_offset>0, only input_ids that do not have their past calculated should be passed as input_ids.

    Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details.

    What are input IDs?

  • inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional) — Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model’s internal embedding lookup matrix.
  • cache_params (MambaCache, optional) — If passed along, the model uses the previous state in all the blocks (which will give the output for the provided input_ids as if the model had received state_input_ids + input_ids as context).
  • use_cache (bool, optional) — If set to True, the cache_params is returned and can be used to quickly generate the next logits.
  • output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
  • return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
  • cache_position (torch.LongTensor of shape (sequence_length), optional) — Indices depicting the position of the input sequence tokens in the sequence. Contrary to position_ids, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length.
  • labels (torch.LongTensor of shape (batch_size, sequence_length), optional) — Labels for language modeling. Note that the labels are shifted inside the model, i.e. you can set labels = input_ids. Indices are selected in [-100, 0, ..., config.vocab_size]. All labels set to -100 are ignored (masked); the loss is only computed for labels in [0, ..., config.vocab_size].
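The label handling described above (shift by one position, skip entries equal to -100) can be sketched in plain Python for a single sequence. This is an illustration of the rule only, not the model's loss code: the helper `shifted_lm_loss` is hypothetical, and averaging over the non-ignored positions is an assumption.

```python
import math

IGNORE_INDEX = -100


def shifted_lm_loss(logits, labels):
    """Mean next-token negative log-likelihood for one sequence.

    logits: one list of vocabulary scores per position.
    labels: one token id per position; IGNORE_INDEX entries are skipped.
    """
    total, count = 0.0, 0
    # Position t predicts token t + 1: drop the last logits row and the
    # first label so the two sequences line up.
    for scores, target in zip(logits[:-1], labels[1:]):
        if target == IGNORE_INDEX:
            continue
        top = max(scores)
        log_z = top + math.log(sum(math.exp(s - top) for s in scores))
        total += log_z - scores[target]
        count += 1
    return total / count if count else float("nan")


uniform = [[0.0, 0.0, 0.0, 0.0] for _ in range(4)]
print(shifted_lm_loss(uniform, [0, 1, IGNORE_INDEX, 2]))
```

With uniform scores over four tokens, every counted position contributes log(4), and the masked position does not affect the mean.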

Returns

transformers.models.mamba.modeling_mamba.MambaCausalLMOutput or tuple(torch.FloatTensor)

A transformers.models.mamba.modeling_mamba.MambaCausalLMOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (MambaConfig) and inputs.

  • loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) — Language modeling loss (for next-token prediction).

  • logits (torch.FloatTensor of shape (batch_size, sequence_length, config.vocab_size)) — Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).

  • cache_params (MambaCache) — The state of the model at the last time step. Can be used in a forward method with the next input_ids to avoid providing the old input_ids.

    Includes both the state space model state matrices after the selective scan and the convolutional states.

  • hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.

The MambaForCausalLM forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of calling forward directly, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

>>> import torch
>>> from transformers import AutoTokenizer, MambaForCausalLM

>>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
>>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")

>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs, labels=inputs["input_ids"])
>>> loss = outputs.loss
>>> logits = outputs.logits