Edit model card

MambaHermes-3B

mamba-hf

Mamba models with Hugging Face `transformers` integration.

For the modeling code, see: mamba-hf

Usage:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

CHAT_TEMPLATE_ID = "HuggingFaceH4/zephyr-7b-beta"

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_name = "Q-bert/MambaHermes-3B"

# Set the end-of-text token explicitly; padding reuses the same token.
eos_token = "<|endoftext|>"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.eos_token = eos_token
tokenizer.pad_token = tokenizer.eos_token
# Borrow the Zephyr chat template to format the conversation.
tokenizer.chat_template = AutoTokenizer.from_pretrained(CHAT_TEMPLATE_ID).chat_template

# trust_remote_code=True lets transformers load the custom Mamba modeling
# code (see mamba-hf).
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map=device, trust_remote_code=True
)

prompt = "Tell me 5 sites to visit in Spain"
messages = [dict(role="user", content=prompt)]

input_ids = tokenizer.apply_chat_template(
    messages, return_tensors="pt", add_generation_prompt=True
).to(device)

# do_sample=True is required for temperature/top_p to take effect;
# without it generate() decodes greedily and ignores them (with a warning).
out = model.generate(
    input_ids=input_ids,
    max_length=2000,  # total length, prompt tokens included
    do_sample=True,
    temperature=0.9,
    top_p=0.7,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

# generate() returns prompt + continuation; decode only the new tokens.
# skip_special_tokens drops the trailing end-of-text marker.
assistant_message = tokenizer.decode(
    out[0, input_ids.shape[-1]:], skip_special_tokens=True
)

print(assistant_message)

For training:

from transformers import Trainer ,TrainingArguments
import torch
import os


class MambaTrainer(Trainer):
    """Trainer whose loss is next-token cross-entropy over ``input_ids``.

    The model is called with ``input_ids`` only; the targets are the same
    ids shifted left by one position.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the causal language-modeling loss for one batch.

        Args:
            model: The model being trained. ``model(input_ids)[0]`` must be
                the logits, assumed (batch, seq_len, vocab) -- TODO confirm
                against the mamba-hf modeling code.
            inputs: Batch dict from the data collator. ``input_ids`` is
                popped (the dict is mutated). If an ``attention_mask`` is
                present, masked-out (padding) positions are excluded from
                the loss.
            return_outputs: When True, return ``(loss, outputs)``, which is
                what ``Trainer.prediction_step`` unpacks during evaluation.
            num_items_in_batch: Accepted for compatibility with newer
                ``transformers`` releases that pass it; unused here.

        Returns:
            The scalar loss, or ``(loss, outputs)`` if ``return_outputs``.
        """
        input_ids = inputs.pop("input_ids")
        outputs = model(input_ids)
        lm_logits = outputs[0]

        labels = input_ids.to(lm_logits.device)
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            # Exclude padding from the loss; -100 is CrossEntropyLoss's
            # default ignore_index.
            labels = labels.masked_fill(attention_mask.to(labels.device) == 0, -100)

        # Position t predicts token t+1: drop the last logit and the first
        # label so the two sequences line up.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return (lm_loss, outputs) if return_outputs else lm_loss

You must use this class for training, and `fp16` must be set to `False` in `TrainingArguments`.

Credits:

https://huggingface.co/state-spaces

https://huggingface.co/clibrain/mamba-2.8b-instruct-openhermes

Special thanks to Albert Gu and Tri Dao for their paper (https://arxiv.org/abs/2312.00752).

Downloads last month
14
This model does not have enough activity to be deployed to the Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.

Collection including Q-bert/MambaHermes-3B