Mamba Chat 4k
A fine-tune of the Mamba SlimPajama model
Issues:
- Some answers are given in a different language from the question. This is likely due to the mixed-language nature of the OpenAssistant dataset. However, this usually isn't a problem for stronger models.
- After roughly 3500 tokens of input, the model fails.
- The model is poor at coding tasks.
- Passkey retrieval works up to around 3500 tokens; however, the model struggles to respond to anything but short questions or queries. Note that this is NOT an issue with the OpenHermes fine-tune.
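Because of the limits above, it may help to keep chat inputs well under roughly 3500 tokens. The following is a minimal sketch of one way to do that, assuming it is acceptable to drop the oldest turns of a conversation first. `trim_history` and `MAX_INPUT_TOKENS` are hypothetical names that are not part of this model's tooling, and 3500 is only the rough threshold reported above.

```python
from typing import Callable, Dict, List

# Rough input budget taken from the issues listed above (an assumption, not a hard limit).
MAX_INPUT_TOKENS = 3500


def trim_history(
    messages: List[Dict[str, str]],
    count_tokens: Callable[[List[Dict[str, str]]], int],
    budget: int = MAX_INPUT_TOKENS,
) -> List[Dict[str, str]]:
    """Drop the oldest messages until the conversation fits within `budget` tokens.

    The most recent message is always kept, even if it alone exceeds the budget.
    The input list is not modified; a trimmed copy is returned.
    """
    trimmed = list(messages)
    while len(trimmed) > 1 and count_tokens(trimmed) > budget:
        trimmed.pop(0)  # discard the oldest turn first
    return trimmed
```

With the tokenizer from the Inference section, a token counter could be `lambda msgs: len(tokenizer.apply_chat_template(msgs, add_generation_prompt=True))`, because `apply_chat_template` returns a list of token ids when `tokenize` is left at its default.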
Chat Fine-tuning Config:
All modules were trained except the following, which were frozen:
"mixer", "conv1d", "act", "head"
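The training script is not shown, so it is not stated how this freeze list was applied. Here is a minimal sketch, assuming the names were matched as substrings of PyTorch parameter names. Under that assumption, `head` would also match a parameter such as `lm_head.weight`. `freeze_by_name` is a hypothetical helper, and the parameter names in any example are illustrative only.

```python
from typing import Iterable, List, Tuple

# Module names listed as frozen in the config above.
FROZEN_KEYS = ("mixer", "conv1d", "act", "head")


def freeze_by_name(
    named_params: Iterable[Tuple[str, object]],
    frozen_keys: Tuple[str, ...] = FROZEN_KEYS,
) -> List[str]:
    """Set requires_grad=False on every parameter whose name contains a frozen key.

    Accepts (name, parameter) pairs, e.g. from torch's `model.named_parameters()`.
    Returns the names of the parameters that were left trainable.
    """
    trainable = []
    for name, param in named_params:
        frozen = any(key in name for key in frozen_keys)  # substring match (assumption)
        param.requires_grad = not frozen
        if not frozen:
            trainable.append(name)
    return trainable
```

For a PyTorch model, the call would be `freeze_by_name(model.named_parameters())`. Substring matching is simple but broad, so it is worth printing the returned list to confirm that only the intended parameters remain trainable.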
Inference
```shell
pip install torch==2.1.0 transformers==4.35.0 causal-conv1d==1.0.0 mamba-ssm==1.0.1
```

```python
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

device = "cuda"

# Load the model and tokenizer
model_name = "Trelis/mamba-2.8b-slimpj-chat-4k"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = MambaLMHeadModel.from_pretrained(model_name, dtype=torch.bfloat16, device=device)

# Define the prompt and wrap it as a single user message
prompt = "what languages do you speak? answer me in english"
messages = [dict(role="user", content=prompt)]

# Show the prompt exactly as the chat template formats it
formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(formatted)

# Prepare the input for the model
input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(device)
# max_length counts prompt + generated tokens, so this allows up to 500 new tokens
max_length = input_ids.shape[1] + 500

# Generate function for Mamba model
def generate_mamba(input_ids, max_length):
    return model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
        temperature=0.01,
        top_k=1,
        top_p=1.0,
        eos_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.0,
    )

# Run the generation
out = generate_mamba(input_ids, max_length)

# Decode and print only the newly generated tokens (everything after the prompt)
decoded_sequences = tokenizer.batch_decode(out.sequences[:, input_ids.shape[1]:].tolist())
for generated_text in decoded_sequences:
    print(generated_text)
```
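The generation call above sets `top_k=1`, which keeps only the single highest-scoring token at each step. Decoding is therefore effectively greedy, and `temperature=0.01` does not change which token is picked. The following generic, pure-Python sketch of top-k sampling with temperature illustrates this. The name `sample_top_k` is hypothetical, and this is not `mamba_ssm`'s own implementation.

```python
import math
import random
from typing import List, Optional


def sample_top_k(
    logits: List[float],
    top_k: int = 1,
    temperature: float = 1.0,
    rng: Optional[random.Random] = None,
) -> int:
    """Pick a token index from `logits` using top-k sampling with temperature.

    With top_k == 1 the result is always the argmax, so temperature has no effect.
    A top_k of 0 or less is treated as "no top-k filtering".
    """
    if top_k == 1:
        # Greedy decoding: only the highest-scoring token survives the top-k filter.
        return max(range(len(logits)), key=lambda i: logits[i])
    rng = rng or random.Random()
    if top_k <= 0:
        top_k = len(logits)
    # Keep only the k highest-scoring candidates.
    candidates = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
    # Softmax over the temperature-scaled logits (max-subtracted for numerical stability).
    scaled = [logits[i] / temperature for i in candidates]
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]
    return rng.choices(candidates, weights=weights, k=1)[0]
```

With `top_k` left at 1, changing `temperature` should not change the generated text.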