Transformers
PyTorch
Inference Endpoints

Mamba Chat 4k

A chat fine-tune of the Mamba 2.8B SlimPajama base model.

Issues:

  • Some answers are given in a different language than the question, likely due to the mixed-language nature of the OpenAssistant dataset. This usually isn't a problem for stronger models.
  • After roughly 3500 tokens of input, the model fails to produce usable output.
  • The model is poor at coding tasks.
  • Passkey retrieval works up to around 3500 tokens; however, the model struggles to respond to anything but short questions/queries. Note that this is NOT an issue with the OpenHermes fine-tune. A passkey-probe sketch follows this list.
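
A quick way to reproduce the context-length behaviour above is a passkey-retrieval probe: hide a random passkey inside filler text and ask the model to repeat it back. A minimal sketch, assuming model and tokenizer are loaded as in the Inference section below (the filler sentence and repeat count are illustrative):

import random

# Hide a 5-digit passkey inside repeated filler text.
filler = "The grass is green. The sky is blue. The sun is yellow. "
passkey = str(random.randint(10000, 99999))

def build_passkey_prompt(n_repeats):
    context = filler * n_repeats
    return f"{context}The passkey is {passkey}. Remember it. {context}What is the passkey?"

# Scale n_repeats to sweep the prompt length past ~3500 tokens.
probe = build_passkey_prompt(150)
print(len(tokenizer(probe).input_ids), "tokens")
# Feed probe through the chat template and generate as in the script below,
# then check whether the passkey appears in the output.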

Chat Fine-tuning Config:

All modules were trained except for the following, which were frozen:

"mixer", "conv1d", "act", "head"

Inference

Install the pinned dependencies:

pip install torch==2.1.0 transformers==4.35.0 causal-conv1d==1.0.0 mamba-ssm==1.0.1

Then run:

import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Load the model and tokenizer onto the GPU in bfloat16
device = "cuda"
model_name = "Trelis/mamba-2.8b-slimpj-chat-4k"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = MambaLMHeadModel.from_pretrained(model_name, dtype=torch.bfloat16, device=device)

# Define the prompt
prompt = "what languages do you speak? answer me in english"

# Build the chat history with a single user turn
messages = [dict(role="user", content=prompt)]

# Preview the formatted prompt string
formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(formatted)

# Prepare the input for the model
input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(device)
max_length = input_ids.shape[1] + 500  # allow up to 500 newly generated tokens

# Generate function for Mamba model
def generate_mamba(input_ids, max_length):
    return model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,  # cache the decoding step with CUDA graphs for speed
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
        temperature=0.01,  # near-greedy: with top_k=1 this is effectively deterministic
        top_k=1,
        top_p=1.0,
        eos_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.0,
    )

# Run the generation
out = generate_mamba(input_ids, max_length)

# Decode and print only the newly generated text (strip the prompt prefix)
prompt_text = tokenizer.decode(input_ids[0], skip_special_tokens=False)
decoded_sequences = tokenizer.batch_decode(out.sequences.tolist())
for sequence in decoded_sequences:
    generated_text = sequence[len(prompt_text):]
    print(generated_text)
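
To continue the conversation, append the assistant reply and the next user turn to messages, then re-apply the chat template. A short sketch reusing generate_mamba from above (the follow-up question is illustrative):

# Extend the chat history and generate a second turn.
messages.append(dict(role="assistant", content=generated_text))
messages.append(dict(role="user", content="Can you summarise that in one sentence?"))

input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(device)
out = generate_mamba(input_ids, input_ids.shape[1] + 200)

# Strip the prompt prefix as before and print only the new reply.
prompt_text = tokenizer.decode(input_ids[0], skip_special_tokens=False)
print(tokenizer.batch_decode(out.sequences.tolist())[0][len(prompt_text):])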