---
license: wtfpl
datasets:
- HuggingFaceH4/CodeAlpaca_20K
pipeline_tag: text-generation
thumbnail: https://huggingface.co/mrm8488/mamba-coder/resolve/main/mamba-coder-no-bg.png
language:
- en
- code
---

# Mamba-Coder

## MAMBA (2.8B) 🐍 fine-tuned on CodeAlpaca_20K for code generation
![mamba-coder logo](https://huggingface.co/mrm8488/mamba-coder/resolve/main/mamba-coder-no-bg.png)
## Base model info

Mamba is a new state-space model architecture that shows promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers. It builds on the line of work on [structured state space models](https://github.com/state-spaces/s4), with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention). A toy sketch of the underlying state-space recurrence is included in the appendix at the end of this card.

## Dataset info

[CodeAlpaca_20K](https://huggingface.co/datasets/HuggingFaceH4/CodeAlpaca_20K) contains 20K instruction-following examples, originally used to fine-tune the Code Alpaca model (a snippet for inspecting it is included in the appendix).

## Usage

```sh
pip install torch==2.1.0 transformers==4.35.0 causal-conv1d==1.0.0 mamba-ssm==1.0.1
```

```py
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

CHAT_TEMPLATE_ID = "HuggingFaceH4/zephyr-7b-beta"

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_name = "mrm8488/mamba-coder"

eos_token = "<|endoftext|>"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.eos_token = eos_token
tokenizer.pad_token = tokenizer.eos_token
# Borrow Zephyr's chat template to format the conversation
tokenizer.chat_template = AutoTokenizer.from_pretrained(CHAT_TEMPLATE_ID).chat_template

# Load fp16 weights with the mamba_ssm reference implementation
model = MambaLMHeadModel.from_pretrained(
    model_name, device=device, dtype=torch.float16)

messages = []
prompt = "Write a bash script to remove .tmp files"
messages.append(dict(role="user", content=prompt))

input_ids = tokenizer.apply_chat_template(
    messages, return_tensors="pt", add_generation_prompt=True
).to(device)

out = model.generate(
    input_ids=input_ids,
    max_length=2000,
    temperature=0.9,
    top_p=0.7,
    eos_token_id=tokenizer.eos_token_id,
)

# Keep only the assistant's turn and strip the EOS token
decoded = tokenizer.batch_decode(out)
assistant_message = (
    decoded[0].split("<|assistant|>\n")[-1].replace(eos_token, "")
)

print(assistant_message)
```

A multi-turn continuation of this example is sketched in the appendix.

## Gradio Demo

```sh
git clone https://github.com/mrm8488/mamba-chat.git
cd mamba-chat

pip install -r requirements.txt
pip install -q gradio==4.8.0

python app.py \
--model mrm8488/mamba-coder \
--share
```

## Evaluations

Coming soon!

## Citation

```bibtex
@misc{manuel_romero_2024,
  author    = {Manuel Romero},
  title     = {mamba-coder (Revision 214a13a)},
  year      = 2024,
  url       = {https://huggingface.co/mrm8488/mamba-coder},
  doi       = {10.57967/hf/1673},
  publisher = {Hugging Face}
}
```

## Acknowledgments

Thanks to [mamba-chat](https://github.com/havenhq/mamba-chat/tree/main), which heavily inspired this work.
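## Appendix: illustrative snippets

To make the state-space idea from the base-model section concrete, here is a toy, non-selective SSM recurrence in plain PyTorch. All dimensions and parameter values here are made up for illustration; Mamba's actual layer makes the parameters input-dependent ("selection") and computes the scan with a fused hardware-aware kernel.

```py
import torch

# Toy discretized state-space recurrence:
#   h_t = A h_{t-1} + B x_t,   y_t = C h_t
# Purely illustrative: dimensions and parameters are arbitrary, and real
# Mamba layers use input-dependent parameters and a fused parallel scan.
d_state, seq_len = 16, 10
A = 0.9 * torch.eye(d_state)   # stable state transition (hypothetical values)
B = torch.randn(d_state, 1)    # input projection
C = torch.randn(1, d_state)    # output projection

x = torch.randn(seq_len, 1)    # one input channel over seq_len timesteps
h = torch.zeros(d_state, 1)    # hidden state carried across timesteps
ys = []
for t in range(seq_len):       # O(seq_len) recurrent scan
    h = A @ h + B * x[t]       # fold the new input into the state
    ys.append((C @ h).item())  # read out a scalar output
print(ys)                      # one output per timestep
```

Because the state has a fixed size, the cost per generated token stays constant, which is the property that lets Mamba sidestep the quadratic attention cost of Transformers.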
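The fine-tuning data from the dataset section can be inspected directly with the `datasets` library. This is a minimal sketch: it prints the splits and one record rather than assuming a schema, so check the dataset card for the exact column names.

```py
from datasets import load_dataset

# Download CodeAlpaca_20K and peek at its structure
ds = load_dataset("HuggingFaceH4/CodeAlpaca_20K")
print(ds)                          # available splits and row counts
first_split = next(iter(ds.values()))
print(first_split[0])              # one instruction-following example
```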
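Finally, the Usage example above can be extended to a multi-turn chat by appending the model's reply to `messages` and generating again with the same chat template. This sketch reuses the variables from the Usage snippet; the follow-up prompt is just an illustration.

```py
# Continue the conversation from the Usage example above
messages.append(dict(role="assistant", content=assistant_message))
messages.append(dict(role="user", content="Now only delete .tmp files older than 7 days"))

input_ids = tokenizer.apply_chat_template(
    messages, return_tensors="pt", add_generation_prompt=True
).to(device)

out = model.generate(
    input_ids=input_ids,
    max_length=2000,
    temperature=0.9,
    top_p=0.7,
    eos_token_id=tokenizer.eos_token_id,
)

# As before, keep only the latest assistant turn
reply = tokenizer.batch_decode(out)[0].split("<|assistant|>\n")[-1].replace(eos_token, "")
print(reply)
```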