
Mamba-370M

mamba-hf

Mamba models with Hugging Face Transformers integration.

For the modeling code, see: mamba-hf

Usage:

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained('Q-bert/Mamba-370M', trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained('Q-bert/Mamba-370M')

text = "Hi"

input_ids = tokenizer.encode(text, return_tensors="pt")

output = model.generate(input_ids, max_length=20, num_beams=5, no_repeat_ngram_size=2)

generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

print(generated_text)

Example output:

Hi, I'm looking for a new job. I've been working at a company for about a year now.
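The generate call above passes no_repeat_ngram_size=2, which stops any 2-gram from appearing twice in the output. Transformers applies this internally during logits processing. The sketch below shows the banning rule in plain Python, assuming integer token ids. `banned_next_tokens` is a hypothetical helper for illustration, not part of the transformers API.

```python
from collections import defaultdict
from typing import List, Set


def banned_next_tokens(tokens: List[int], ngram_size: int) -> Set[int]:
    """Return token ids that would complete an n-gram already present in `tokens`.

    Hypothetical helper illustrating the no_repeat_ngram_size rule.
    """
    # Too short to contain a full n-gram once the next token is added.
    if len(tokens) + 1 < ngram_size:
        return set()
    # Map each (n-1)-token prefix to every token that has followed it so far.
    seen = defaultdict(set)
    for i in range(len(tokens) - ngram_size + 1):
        prefix = tuple(tokens[i : i + ngram_size - 1])
        seen[prefix].add(tokens[i + ngram_size - 1])
    # The prefix for the next token is the last (n-1) tokens of the sequence.
    current = tuple(tokens[len(tokens) - ngram_size + 1 :])
    return set(seen.get(current, set()))


# With n=2 and sequence [5, 7, 5], emitting 7 next would repeat the 2-gram (5, 7).
print(banned_next_tokens([5, 7, 5], 2))
```

During beam search the same check runs for every beam at every step, and banned tokens get their scores set to negative infinity so they cannot be chosen.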

Training:

from transformers import Trainer, TrainingArguments
import torch


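class MambaTrainer(Trainer):
    # Causal-LM loss: each position is trained to predict the next token.
    # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that newer transformers versions pass.
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        input_ids = inputs.pop("input_ids")
        outputs = model(input_ids)
        lm_logits = outputs[0]

        # Shift so that the logits at position t are scored against the token at t+1.
        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

        # Honor return_outputs, which the Trainer may request during evaluation.
        return (lm_loss, outputs) if return_outputs else lm_loss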

This trainer class is required for training, and fp16 must be set to False in TrainingArguments.
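The shift in compute_loss means the logits at position t are scored against the token at position t+1, and CrossEntropyLoss averages over all remaining positions. The sketch below reproduces that computation in plain Python for a toy batch, assuming nested lists in place of tensors and no padding or ignored labels. `shifted_lm_loss` is a hypothetical helper for illustration, not part of the model code.

```python
import math
from typing import List


def shifted_lm_loss(logits: List[List[List[float]]], input_ids: List[List[int]]) -> float:
    """Mean next-token cross-entropy, mirroring the shift in MambaTrainer.compute_loss.

    Hypothetical helper. logits: [batch][seq_len][vocab] raw scores;
    input_ids: [batch][seq_len] token ids. Assumes no padding or ignored labels.
    """
    total, count = 0.0, 0
    for seq_logits, seq_ids in zip(logits, input_ids):
        # Position t predicts token t+1: drop the last logit row and the first label.
        for step_logits, target in zip(seq_logits[:-1], seq_ids[1:]):
            # Numerically stable log-softmax: -log p(target) = logsumexp(x) - x[target].
            m = max(step_logits)
            log_z = m + math.log(sum(math.exp(x - m) for x in step_logits))
            total += log_z - step_logits[target]
            count += 1
    # Mean over all predicted positions, like CrossEntropyLoss's default reduction.
    return total / count


# Uniform scores over a 4-token vocabulary: every position costs log 4.
uniform = [[[0.0, 0.0, 0.0, 0.0] for _ in range(3)]]
print(shifted_lm_loss(uniform, [[0, 1, 2]]))
```

With uniform logits every token has probability 1/4, so each of the two predicted positions contributes log 4 ≈ 1.386 and the mean loss is also log 4.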

Credits:

https://huggingface.co/state-spaces

Special thanks to Albert Gu and Tri Dao for their paper, Mamba: Linear-Time Sequence Modeling with Selective State Spaces (https://arxiv.org/abs/2312.00752).
