Can `MambaForCausalLM` be used directly for training instead of `AutoModelForCausalLM`?

by TimeSpeaker - opened


I'm using the transformers library to train a causal language model, and I would like to use the MambaForCausalLM class. The usual training examples load the model with AutoModelForCausalLM, so I'm wondering whether it is possible, and recommended, to load it with MambaForCausalLM directly instead.

Here is the code snippet I'm referring to:

from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments

# Model loading for training
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-370m-hf")
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-370m-hf")

For inference, I already use MambaForCausalLM successfully, as follows:

from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/home/SLLaMA/mamba-370m-hf")
model = MambaForCausalLM.from_pretrained("/home/SLLaMA/mamba-370m-hf")

Could you clarify whether training with MambaForCausalLM is supported, and whether it requires any additional configuration?

Thank you for your assistance.
