---
license: apache-2.0
datasets:
- Trelis/openassistant-falcon
---

# Mamba Chat 4k

A chat fine-tune of the [Mamba SlimPajama model](https://huggingface.co/state-spaces/mamba-2.8b-slimpj).

## Issues:
- Some answers are given in a different language than the question. This is likely due to the mixed-language nature of the OpenAssistant dataset; stronger models usually do not show this problem.
- The model fails after roughly 3,500 tokens of input.
- The model is poor at coding tasks.
- Passkey retrieval works up to around 3,500 tokens; however, the model struggles to respond to anything but short questions/queries. Note that this is NOT an issue with the [openhermes fine-tune](https://huggingface.co/clibrain/mamba-2.8b-instruct-openhermes).

## Chat Fine-tuning Config:
All modules were trained except the following, which were frozen (a sketch of how such a freeze can be applied is given at the end of this card):
```
"mixer", "conv1d", "act", "head"
```

## Inference
Install the dependencies:
```
pip install torch==2.1.0 transformers==4.35.0 causal-conv1d==1.0.0 mamba-ssm==1.0.1
```

Then run generation using the chat template:
```
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Load the model and tokenizer
device = "cuda"
model_name = "Trelis/mamba-2.8b-slimpj-chat-4k"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = MambaLMHeadModel.from_pretrained(model_name, dtype=torch.bfloat16, device=device)

# Build the chat messages
prompt = "what languages do you speak? answer me in english"
messages = [dict(role="user", content=prompt)]

# Print the formatted prompt for inspection
formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(formatted)

# Prepare the input for the model
input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(device)

max_length = input_ids.shape[1] + 500  # generate up to 500 new tokens

# Generate function for the Mamba model
def generate_mamba(input_ids, max_length):
    return model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
        temperature=0.01,
        top_k=1,
        top_p=1.0,
        eos_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.0,
    )

# Run the generation
out = generate_mamba(input_ids, max_length)

# Decode and print only the newly generated text (strip the prompt)
decoded_sequences = tokenizer.batch_decode(out.sequences.tolist())
for sequence in decoded_sequences:
    generated_text = sequence[len(tokenizer.decode(input_ids[0], skip_special_tokens=False)):]
    print(generated_text)
```
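
## Freezing Modules (Sketch)

The frozen-module list in the fine-tuning config above can be applied by matching parameter names. The snippet below is a minimal, hypothetical sketch of that approach, not the actual training script used for this model; it reuses the `model` object loaded in the inference example, and the optimizer choice and learning rate are illustrative assumptions.

```
import torch

# Assumed approach: freeze any parameter whose name contains one of the
# module names listed in the "Chat Fine-tuning Config" section above.
frozen_keywords = ["mixer", "conv1d", "act", "head"]

for name, param in model.named_parameters():
    if any(keyword in name for keyword in frozen_keywords):
        param.requires_grad = False  # exclude from gradient updates

# Pass only the remaining trainable parameters to the optimizer
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=2e-5)  # illustrative learning rate
```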