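Mamba2
Overview
The Mamba2 model was proposed in "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality" by Tri Dao and Albert Gu. Mamba2 is a state space model similar to Mamba1, and it performs better with a simplified architecture.
The abstract from the paper is the following:
While Transformers have been the main architecture behind deep learning's success in language modeling, state-space models (SSMs) such as Mamba have recently been shown to match or outperform Transformers at small to medium scale. We show that these families of models are actually quite closely related, and develop a rich framework of theoretical connections between SSMs and variants of attention, connected through various decompositions of a well-studied class of structured semiseparable matrices. Our state space duality (SSD) framework allows us to design a new architecture (Mamba-2) whose core layer is a refinement of Mamba's selective SSM that is 2-8X faster, while continuing to be competitive with Transformers on language modeling.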
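The duality described in the abstract can be checked numerically on a toy example. This is a minimal sketch under stated assumptions: a single head with a scalar decay a_t per time step (the restriction SSD places on the selective SSM), a state of size 2, and made-up values. It compares the linear-time recurrent scan with the quadratic "attention-like" form y = M x, where M[t, s] = (C_t · B_s) · a_{s+1} ⋯ a_t for s ≤ t and 0 otherwise (a causally masked, semiseparable matrix). It illustrates the equivalence only and is not the paper's chunked algorithm.

```python
# Toy check of state space duality for one head with scalar decay a[t].
def recurrent(xs, a, B, C):
    # Linear-time scan: h_t = a_t * h_{t-1} + B_t * x_t, y_t = C_t . h_t
    h, ys = [0.0] * len(B[0]), []
    for t, x in enumerate(xs):
        h = [a[t] * hi + bi * x for hi, bi in zip(h, B[t])]
        ys.append(sum(ci * hi for ci, hi in zip(C[t], h)))
    return ys

def quadratic(xs, a, B, C):
    # Quadratic form: y = M x with a causally masked semiseparable matrix M.
    ys = []
    for t in range(len(xs)):
        y = 0.0
        for s in range(t + 1):  # causal mask: only s <= t contributes
            decay = 1.0
            for k in range(s + 1, t + 1):
                decay *= a[k]
            # M[t, s] = (C_t . B_s) * a_{s+1} * ... * a_t
            y += sum(ci * bi for ci, bi in zip(C[t], B[s])) * decay * xs[s]
        ys.append(y)
    return ys

xs = [1.0, -2.0, 0.5, 3.0]
a = [0.9, 0.5, 0.8, 0.7]
B = [[1.0, 0.0], [0.5, 1.0], [-1.0, 2.0], [0.3, 0.3]]
C = [[0.2, 0.1], [1.0, -1.0], [0.5, 0.5], [-0.3, 0.8]]
y_scan = recurrent(xs, a, B, C)
y_matrix = quadratic(xs, a, B, C)
# The two forms agree up to floating-point rounding.
print(max(abs(p - q) for p, q in zip(y_scan, y_matrix)))
```

The recurrent form costs O(T) per sequence and the matrix form O(T^2), which is the same trade-off as recurrent decoding versus parallel attention.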
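Tips:
This implementation should support all implementations of Mamba2, and in particular Mamba-2 codestral from Mistral AI. Notably, Mamba-2 codestral was released with the number of groups (n_groups) set to 8, which can be thought of intuitively as similar to the number of KV heads in an attention-based model.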
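To make the KV-head analogy concrete, the sketch below assumes the Codestral-style defaults (num_heads=128, n_groups=8) and assumes that consecutive heads share one group's B and C projections, the way grouped-query attention shares KV heads. The helper head_to_group is hypothetical and is not part of transformers.

```python
# Hypothetical helper: which B/C group each SSM head reads from.
def head_to_group(num_heads: int, n_groups: int) -> list[int]:
    if num_heads % n_groups != 0:
        raise ValueError("num_heads must be divisible by n_groups")
    heads_per_group = num_heads // n_groups
    # Consecutive heads share a group, like repeat_interleave on the group axis.
    return [h // heads_per_group for h in range(num_heads)]

mapping = head_to_group(num_heads=128, n_groups=8)
print(mapping[:17])  # heads 0-15 map to group 0, head 16 maps to group 1
```

With 8 groups, B and C are computed once per group and reused by 16 heads each, which reduces the projection size in the same way fewer KV heads do in attention.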
This model has two different forward passes, torch_forward and cuda_kernels_forward. cuda_kernels_forward uses the CUDA kernels if they are found in your environment, and it is slower on prefill, i.e. it requires a "warmup run" due to high CPU overhead. See here and here for details.
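The dispatch between the two paths can be pictured as below. This is a hedged sketch, not the exact transformers logic: it assumes the fast path depends on the optional mamba_ssm and causal_conv1d packages and is only taken when the weights live on a CUDA device. Both function names are hypothetical.

```python
import importlib.util

# Assumption: the kernel path needs the optional mamba_ssm and causal_conv1d packages.
def cuda_kernels_installed() -> bool:
    return all(
        importlib.util.find_spec(pkg) is not None
        for pkg in ("mamba_ssm", "causal_conv1d")
    )

def choose_forward(kernels_available: bool, device_type: str) -> str:
    # The kernel path only applies when the model weights are on a CUDA device.
    if kernels_available and "cuda" in device_type:
        return "cuda_kernels_forward"
    return "torch_forward"

print(choose_forward(cuda_kernels_installed(), "cpu"))  # torch_forward
```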
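Without compilation, the torch_forward implementation is 3 to 4 times faster. Furthermore, this model has no positional embeddings, but it does take an attention_mask, and there is specific logic to mask out the hidden states in two places in the case of batched generation; see here as well.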
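Masking matters here because the SSM recurrence carries its state forward: any nonzero activation at a padded position would leak into the state of every later token. A minimal sketch of the idea, assuming mask value 1 marks a real token and 0 marks padding, is to multiply each hidden vector by its mask value. The helper below is hypothetical and written in plain Python for clarity.

```python
# Hypothetical helper: zero out hidden states at padded positions.
def mask_padding_states(hidden, attention_mask):
    # hidden: [batch][seq_len][hidden_size], attention_mask: [batch][seq_len]
    return [
        [[value * m for value in vector] for vector, m in zip(seq, mask)]
        for seq, mask in zip(hidden, attention_mask)
    ]

hidden = [[[0.5, -1.0], [2.0, 3.0], [1.5, 0.25]]]
mask = [[0, 1, 1]]  # left-padded: the first position is padding
out = mask_padding_states(hidden, mask)
print(out)
```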
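As a result, together with the reimplementation of the Mamba2 kernels, slight discrepancies are expected between batched and cached generation. The results of the CUDA kernels and of torch_forward are also expected to differ slightly. The SSM algorithm relies heavily on tensor contractions, which have matmul equivalents but perform the operations in a slightly different order, and this makes the differences larger at lower precision.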
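The effect of operation order on low-precision results can be reproduced without any GPU. This sketch rounds every partial sum to IEEE half precision (via Python's struct "e" format) and sums the same numbers in two different orders; it is an analogy for how kernels that reorder a contraction can diverge in fp16, not a model of the actual Mamba2 kernels.

```python
import struct

def to_fp16(x: float) -> float:
    # Round a Python float to the nearest IEEE half-precision value.
    return struct.unpack("<e", struct.pack("<e", x))[0]

def accumulate(values, round_fn=lambda x: x):
    # Sequential sum where every partial result is rounded by round_fn.
    acc = round_fn(0.0)
    for v in values:
        acc = round_fn(acc + round_fn(v))
    return acc

values = [1.0] + [1e-4] * 1000
fp16_gap = abs(accumulate(values, to_fp16) - accumulate(values[::-1], to_fp16))
fp64_gap = abs(accumulate(values) - accumulate(values[::-1]))
print(f"fp16 gap: {fp16_gap:.4f}, fp64 gap: {fp64_gap:.2e}")
```

Adding the small terms one by one to 1.0 loses each of them in fp16, while summing the small terms first keeps their total, so the two orders disagree noticeably in fp16 but agree almost exactly in float64.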
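Another note: zeroing out the hidden states that correspond to padding tokens is done in two places and has mostly been tested with left padding. Right padding propagates noise down the sequence and is not guaranteed to yield satisfactory results. Setting tokenizer.padding_side = "left" ensures you are using the correct padding side.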
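What left padding produces can be sketched in a few lines. The helper left_pad is hypothetical (the tokenizer does this for you when padding_side = "left"), and pad_id=2 is an arbitrary example value. With left padding, the last position of every row is a real token, so generation continues directly after the prompt and no padding sits between the prompt and the new tokens.

```python
# Hypothetical helper: left-pad token-id lists and build the attention mask.
def left_pad(batch, pad_id):
    max_len = max(len(ids) for ids in batch)
    input_ids, attention_mask = [], []
    for ids in batch:
        n_pad = max_len - len(ids)
        input_ids.append([pad_id] * n_pad + list(ids))
        attention_mask.append([0] * n_pad + [1] * len(ids))  # 1 = real token
    return input_ids, attention_mask

ids, mask = left_pad([[5, 6, 7], [8]], pad_id=2)
print(ids)   # [[5, 6, 7], [2, 2, 8]]
print(mask)  # [[1, 1, 1], [0, 0, 1]]
```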
This model was contributed by Molbap, with tremendous help from Anton Vlasjuk. The original code can be found here.
Usage
A simple generation example:
from transformers import Mamba2ForCausalLM, AutoTokenizer
model_id = 'mistralai/Mamba-Codestral-7B-v0.1'
# Load the tokenizer and model from the same Hub revision
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9')
# Tokenize the prompt and generate up to 10 new tokens
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
Here is a draft script for fine-tuning:
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, Mamba2ForCausalLM, TrainingArguments
model_id = 'mistralai/Mamba-Codestral-7B-v0.1'
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # use left padding
model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9')
dataset = load_dataset("Abirate/english_quotes", split="train")
# Without CUDA kernels, a batch size of 2 occupies one 80GB device,
# but you can lower the precision (e.g. bf16) to reduce memory use.
# Experiments and trials are welcome!
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3
)
# Apply LoRA to the embeddings and to the mixer input/output projections
lora_config = LoraConfig(
    r=8,
    target_modules=["embeddings", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none"
)
# Note: tokenizer= and dataset_text_field= follow an older TRL SFTTrainer signature;
# newer TRL releases expect processing_class= and move dataset_text_field into SFTConfig.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)
trainer.train()
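A back-of-the-envelope count of the trainable parameters added by the LoraConfig above can help when sizing a run. This sketch rests on assumptions that are not stated in this document: in_proj maps hidden_size (4096) to 18560 outputs and out_proj maps 8192 back to 4096 (a layout derived from the Codestral-style defaults), and each adapted module adds r * (in + out) weights (an A matrix of r x in and a B matrix of out x r). The helper name is hypothetical.

```python
# Rough LoRA parameter count under the assumptions stated above.
def lora_params(r: int, in_features: int, out_features: int) -> int:
    # LoRA adds A (r x in) and B (out x r) next to the frozen weight.
    return r * (in_features + out_features)

r, hidden, vocab, layers = 8, 4096, 32768, 64
per_layer = lora_params(r, hidden, 18560) + lora_params(r, 8192, hidden)
embeddings = lora_params(r, vocab, hidden)  # embedding adapter: vocab -> hidden
total = layers * per_layer + embeddings
print(f"{total:,} trainable LoRA parameters")  # 18,186,240 trainable LoRA parameters
```

Under these assumptions the adapters amount to roughly 18M parameters, a small fraction of the 7B base model.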
Mamba2Config
class transformers.Mamba2Config(num_heads=128, head_dim=64, vocab_size=32768, hidden_size=4096, state_size=128, num_hidden_layers=64, layer_norm_epsilon=1e-05, pad_token_id=1, bos_token_id=0, eos_token_id=2, expand=2, conv_kernel=4, n_groups=8, use_bias=False, use_conv_bias=True, hidden_act='silu', initializer_range=0.1, residual_in_fp32=True, time_step_rank='auto', time_step_min=0.001, time_step_max=0.1, time_step_floor=0.0001, time_step_limit=(0.0, inf), rescale_prenorm_residual=False, use_cache=True, rms_norm=True, chunk_size=256, tie_word_embeddings=False, **kwargs)
Parameters
- num_heads (int, optional, defaults to 128): Number of heads for the evolution matrices of Mamba2.
- head_dim (int, optional, defaults to 64): Dimension of each head.
- vocab_size (int, optional, defaults to 32768): Vocabulary size of the MAMBA2 model. Defines the number of different tokens that can be represented by the input_ids passed when calling Mamba2Model.
- hidden_size (int, optional, defaults to 4096): Dimensionality of the embeddings and hidden states.
- state_size (int, optional, defaults to 128): Shape of the state space latents.
- num_hidden_layers (int, optional, defaults to 64): Number of hidden layers in the model.
- layer_norm_epsilon (float, optional, defaults to 1e-05): The epsilon to use in the layer normalization layers.
- pad_token_id (int, optional, defaults to 1): Padding token id.
- bos_token_id (int, optional, defaults to 0): The id of the beginning of sentence token in the vocabulary.
- eos_token_id (int, optional, defaults to 2): The id of the end of sentence token in the vocabulary.
- expand (int, optional, defaults to 2): Expanding factor used to determine the intermediate size.
- conv_kernel (int, optional, defaults to 4): Size of the convolution kernel.
- n_groups (int, optional, defaults to 8): Number of groups for the evolution matrices of Mamba2.
- use_bias (bool, optional, defaults to False): Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block.
- use_conv_bias (bool, optional, defaults to True): Whether or not to use bias in the convolution layer of the mixer block.
- hidden_act (str, optional, defaults to "silu"): The non-linear activation function (function or string) in the decoder.
- initializer_range (float, optional, defaults to 0.1): The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- residual_in_fp32 (bool, optional, defaults to True): Whether or not residuals should be in float32. If set to False, residuals will keep the same dtype as the rest of the model.
- time_step_rank (Union[int, str], optional, defaults to "auto"): Rank of the discretization projection matrix. "auto" means that it will default to math.ceil(self.hidden_size / 16).
- time_step_min (float, optional, defaults to 0.001): Minimum time_step used to bound dt_proj.bias.
- time_step_max (float, optional, defaults to 0.1): Maximum time_step used to bound dt_proj.bias.
- time_step_floor (float, optional, defaults to 0.0001): Minimum clamping value of the dt_proj.bias layer initialization.
- time_step_limit (tuple, optional, defaults to (0.0, inf)): Accepted range of time step values.
- rescale_prenorm_residual (bool, optional, defaults to False): Whether or not to rescale out_proj weights when initializing.
- use_cache (bool, optional, defaults to True): Whether or not the cache should be used.
- rms_norm (bool, optional, defaults to True): Whether to use RMS norm or not.
- chunk_size (int, optional, defaults to 256): Size of the chunks that will comprise the sequence.
- tie_word_embeddings (bool, optional, defaults to False): Whether to tie word embeddings or not.
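The time_step_min, time_step_max, time_step_floor and time_step_limit parameters are easiest to read as an initialization and clamping recipe. The sketch below assumes the recipe from the original Mamba code, which transformers may follow with small differences: sample dt log-uniformly between time_step_min and time_step_max, raise it to at least time_step_floor, and store the inverse softplus of dt as the bias so that softplus(bias) recovers dt. At run time the resulting time step is assumed to be clamped to time_step_limit. Both function names are hypothetical.

```python
import math
import random

def init_dt_bias(time_step_min=0.001, time_step_max=0.1, time_step_floor=1e-4, rng=random):
    # Sample dt log-uniformly between time_step_min and time_step_max.
    log_dt = rng.uniform(math.log(time_step_min), math.log(time_step_max))
    dt = max(math.exp(log_dt), time_step_floor)
    # Inverse softplus: log(1 + exp(bias)) == dt.
    return dt + math.log(-math.expm1(-dt))

def effective_dt(bias, raw=0.0, time_step_limit=(0.0, math.inf)):
    # At run time: dt = softplus(raw projection + bias), clamped to the limit.
    dt = math.log1p(math.exp(raw + bias))
    lo, hi = time_step_limit
    return min(max(dt, lo), hi)

bias = init_dt_bias(rng=random.Random(0))
print(effective_dt(bias))  # roughly between 0.001 and 0.1
```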
This is the configuration class to store the configuration of a Mamba2Model. It is used to instantiate a MAMBA2 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the MAMBA2 state-spaces/mamba2-2.8b architecture.
Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
Example:
>>> from transformers import Mamba2Config, Mamba2Model
>>> # Initializing a Mamba2 configuration
>>> configuration = Mamba2Config()
>>> # Initializing a model (with random weights) from the configuration
>>> model = Mamba2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
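Several sizes are derived from these configuration values rather than set directly. This sketch uses a hypothetical dataclass that mirrors a few Mamba2Config defaults (it is not the transformers class). intermediate_size = expand * hidden_size and the "auto" time_step_rank follow from the parameter descriptions above; the conv_dim and in_proj layouts are assumptions about how the mixer splits its projection into the gate, the convolved x, B and C, and one dt per head.

```python
import math
from dataclasses import dataclass

@dataclass
class Mamba2Dims:
    # Hypothetical stand-in mirroring a few Mamba2Config defaults.
    hidden_size: int = 4096
    expand: int = 2
    num_heads: int = 128
    head_dim: int = 64
    n_groups: int = 8
    state_size: int = 128
    chunk_size: int = 256

    @property
    def intermediate_size(self) -> int:
        return self.expand * self.hidden_size

    @property
    def time_step_rank(self) -> int:
        # What time_step_rank="auto" resolves to, per the parameter docs.
        return math.ceil(self.hidden_size / 16)

    @property
    def conv_dim(self) -> int:
        # Assumption: x, B and C all pass through the depthwise convolution.
        return self.intermediate_size + 2 * self.n_groups * self.state_size

    @property
    def in_proj_size(self) -> int:
        # Assumption: in_proj emits the gate z, the conv input (x, B, C) and dt per head.
        return self.intermediate_size + self.conv_dim + self.num_heads

    def num_chunks(self, seq_len: int) -> int:
        # Number of chunk_size blocks needed to cover a sequence.
        return math.ceil(seq_len / self.chunk_size)

dims = Mamba2Dims()
assert dims.num_heads * dims.head_dim == dims.intermediate_size
print(dims.intermediate_size, dims.time_step_rank, dims.in_proj_size, dims.num_chunks(1000))
# 8192 256 18560 4
```

The assert shows why num_heads and head_dim are not independent: with the defaults, 128 heads of size 64 exactly tile the 8192-wide intermediate dimension.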
Mamba2Model
class transformers.Mamba2Model(config)
Parameters
- config (Mamba2Config) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
The bare MAMBA2 Model transformer outputting raw hidden-states without any specific head on top.
This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.)
This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
forward(input_ids: typing.Optional[torch.LongTensor] = None, inputs_embeds: typing.Optional[torch.LongTensor] = None, cache_params: typing.Optional[transformers.models.mamba2.modeling_mamba2.Mamba2Cache] = None, use_cache: typing.Optional[bool] = None, output_hidden_states: typing.Optional[bool] = None, return_dict: typing.Optional[bool] = None, cache_position: typing.Optional[torch.LongTensor] = None, attention_mask: typing.Optional[torch.Tensor] = None, **kwargs) → transformers.models.mamba2.modeling_mamba2.Mamba2Output or tuple(torch.FloatTensor)
Parameters
- input_ids (torch.LongTensor of shape (batch_size, input_ids_length)): Indices of input sequence tokens in the vocabulary. If cache_params.seqlen_offset>0, only input_ids that do not have their past calculated should be passed as input_ids. Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details.
- inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional): Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model's internal embedding lookup matrix.
- cache_params (Mamba2Cache, optional): If passed along, the model uses the previous state in all the blocks (which will give the output for the input_ids provided as if the model had state_input_ids + input_ids as context).
- use_cache (bool, optional): If set to True, the cache_params is returned and can be used to quickly generate the next logits.
- output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
- cache_position (torch.LongTensor of shape (batch_size,), optional): The position of the current input in the cache. This is used to ensure that the cache is correctly updated. If cache_params is passed, cache_position should also be passed.
- attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional): Mask to ignore padding token indices (Mamba2 has no attention; the mask is used to zero out the hidden states at padded positions). Mask values selected in [0, 1]:
  - 1 for tokens that are not masked,
  - 0 for tokens that are masked.
Returns
transformers.models.mamba2.modeling_mamba2.Mamba2Output or tuple(torch.FloatTensor)
A transformers.models.mamba2.modeling_mamba2.Mamba2Output or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Mamba2Config) and inputs.
- last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)): Sequence of hidden-states at the output of the last layer of the model.
- cache_params (Mamba2Cache): The state of the model at the last time step. Can be used in a forward method with the next input_ids to avoid providing the old input_ids. Includes both the state space model state matrices after the selective scan and the convolutional states.
- hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True): Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
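Why cache_params lets you pass only the new input_ids can be shown on a toy model. This is a simplified sketch, not the transformers implementation: one channel, a causal convolution of width K followed by an SSM with a fixed scalar decay (the real model makes the SSM parameters input-dependent). The cache holds the last K convolution inputs (the convolutional state) and the SSM state, and feeding tokens one at a time through it reproduces the full-sequence output.

```python
# Toy single-channel model: causal conv of width K, then a scalar-decay SSM.
def full_forward(xs, conv_w, a, b, c):
    K = len(conv_w)
    padded = [0.0] * (K - 1) + list(xs)
    h, ys = [0.0] * len(b), []
    for t in range(len(xs)):
        u = sum(w * v for w, v in zip(conv_w, padded[t:t + K]))
        h = [a * hi + bi * u for hi, bi in zip(h, b)]
        ys.append(sum(ci * hi for ci, hi in zip(c, h)))
    return ys

def step(x, cache, conv_w, a, b, c):
    # cache = (conv_state, ssm_state): the last K conv inputs and the SSM state.
    conv_state, h = cache
    conv_state = conv_state[1:] + [x]  # roll the window left, append the new input
    u = sum(w * v for w, v in zip(conv_w, conv_state))
    h = [a * hi + bi * u for hi, bi in zip(h, b)]
    return sum(ci * hi for ci, hi in zip(c, h)), (conv_state, h)

conv_w, a, b, c = [0.2, -0.5, 1.0], 0.9, [1.0, 0.5], [0.3, -0.7]
xs = [1.0, 2.0, -1.0, 0.5, 3.0]
cache = ([0.0] * len(conv_w), [0.0] * len(b))
stepped = []
for x in xs:  # feed only the new token at each step, like cached decoding
    y, cache = step(x, cache, conv_w, a, b, c)
    stepped.append(y)
full = full_forward(xs, conv_w, a, b, c)
print(max(abs(p - q) for p, q in zip(full, stepped)))  # 0.0
```

Because the state has a fixed size, each cached decoding step costs the same regardless of how many tokens came before.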
The Mamba2Model forward method overrides the __call__ special method.
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.
Example:
>>> from transformers import AutoTokenizer, Mamba2Model
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/mamba-codestral-7B-v0.1")
>>> model = Mamba2Model.from_pretrained("mistralai/mamba-codestral-7B-v0.1")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
Mamba2ForCausalLM
class transformers.Mamba2ForCausalLM(config)
Parameters
- config (Mamba2Config) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
The MAMBA2 Model transformer with a language modeling head on top (linear layer with weights not tied to the input embeddings).
This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.)
This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
forward(input_ids: typing.Optional[torch.LongTensor] = None, inputs_embeds: typing.Optional[torch.FloatTensor] = None, cache_params: typing.Optional[transformers.models.mamba2.modeling_mamba2.Mamba2Cache] = None, labels: typing.Optional[torch.LongTensor] = None, output_hidden_states: typing.Optional[bool] = None, return_dict: typing.Optional[bool] = None, use_cache: typing.Optional[bool] = None, cache_position: typing.Optional[torch.Tensor] = None, attention_mask: typing.Optional[torch.Tensor] = None, **kwargs) → transformers.models.mamba2.modeling_mamba2.Mamba2CausalLMOutput or tuple(torch.FloatTensor)
Parameters
- input_ids (torch.LongTensor of shape (batch_size, input_ids_length)): Indices of input sequence tokens in the vocabulary. If cache_params.seqlen_offset>0, only input_ids that do not have their past calculated should be passed as input_ids. Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details.
- inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional): Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model's internal embedding lookup matrix.
- cache_params (Mamba2Cache, optional): If passed along, the model uses the previous state in all the blocks (which will give the output for the input_ids provided as if the model had state_input_ids + input_ids as context).
- use_cache (bool, optional): If set to True, the cache_params is returned and can be used to quickly generate the next logits.
- output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
- cache_position (torch.LongTensor of shape (batch_size,), optional): The position of the current input in the cache. This is used to ensure that the cache is correctly updated. If cache_params is passed, cache_position should also be passed.
- attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional): Mask to ignore padding token indices (Mamba2 has no attention; the mask is used to zero out the hidden states at padded positions). Mask values selected in [0, 1]:
  - 1 for tokens that are not masked,
  - 0 for tokens that are masked.
- labels (torch.LongTensor of shape (batch_size, sequence_length), optional): Labels for language modeling. Note that the labels are shifted inside the model, i.e. you can set labels = input_ids. Indices are selected in [-100, 0, ..., config.vocab_size]. All labels set to -100 are ignored (masked); the loss is only computed for labels in [0, ..., config.vocab_size].
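The label shifting and the -100 convention can be illustrated for a single sequence. This is a minimal sketch of the usual next-token cross-entropy, not the transformers loss function: the logit at position t is scored against the label at position t + 1, positions labelled -100 are skipped, and the loss is averaged over the remaining targets. The helper name is hypothetical.

```python
import math

def shifted_lm_loss(logits, labels, ignore_index=-100):
    # logits: [seq_len][vocab] for one sequence, labels: [seq_len].
    # Position t predicts token t + 1: drop the last logit and the first label.
    total, count = 0.0, 0
    for row, target in zip(logits[:-1], labels[1:]):
        if target == ignore_index:
            continue  # masked positions do not contribute to the loss
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))  # stable log-sum-exp
        total += log_z - row[target]
        count += 1
    return total / count if count else float("nan")

vocab = 4
logits = [[0.0] * vocab for _ in range(3)]  # uniform predictions
print(shifted_lm_loss(logits, [1, 2, -100]))  # log(4) ≈ 1.386; only one target counts
```

This is why passing labels = input_ids is enough: the shift happens inside the loss computation, and the first label is never scored.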
Returns
transformers.models.mamba2.modeling_mamba2.Mamba2CausalLMOutput or tuple(torch.FloatTensor)
A transformers.models.mamba2.modeling_mamba2.Mamba2CausalLMOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Mamba2Config) and inputs.
- loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided): Language modeling loss (for next-token prediction).
- logits (torch.FloatTensor of shape (batch_size, sequence_length, config.vocab_size)): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- cache_params (Mamba2Cache): The state of the model at the last time step. Can be used in a forward method with the next input_ids to avoid providing the old input_ids. Includes both the state space model state matrices after the selective scan and the convolutional states.
- hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True): Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
The Mamba2ForCausalLM forward method overrides the __call__ special method.
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.
Example:
>>> import torch
>>> from transformers import AutoTokenizer, Mamba2ForCausalLM
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/mamba-codestral-7B-v0.1")
>>> model = Mamba2ForCausalLM.from_pretrained("mistralai/mamba-codestral-7B-v0.1")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs, labels=inputs["input_ids"])
>>> loss = outputs.loss
>>> logits = outputs.logits