---
license: wtfpl
datasets:
- HuggingFaceH4/CodeAlpaca_20K
pipeline_tag: text-generation
thumbnail:
language:
- en
- code
---

# Mamba-Coder
## MAMBA (2.8B) 🐍 fine-tuned on CodeAlpaca_20K for code generation

<div style="text-align:center;width:250px;height:250px;">
<img src="" alt="mamba-coder logo">
</div>

## Base model info

Mamba is a new state-space model architecture that shows promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers. It builds on the line of work on [structured state space models](https://github.com/state-spaces/s4), with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
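
For intuition, the core computation is a discretized state-space recurrence. Below is a minimal, deliberately unoptimized sketch of such a scan; the names and shapes are illustrative, not the actual `mamba_ssm` internals, which make `A`, `B`, `C` input-dependent ("selective") and fuse the loop into a hardware-aware kernel:

```py
import torch

# Illustrative state-space scan: h_t = A_t * h_{t-1} + B_t * x_t ; y_t = h_t @ C_t.
# In Mamba the per-step A_t, B_t, C_t are computed from the input itself
# ("selection"), and the scan runs in a fused kernel rather than a Python loop.
def ssm_scan(x, A, B, C):
    # x: (seq_len, d_model); A, B: (seq_len, d_model, d_state); C: (seq_len, d_state)
    seq_len, d_model = x.shape
    d_state = A.shape[-1]
    h = torch.zeros(d_model, d_state)
    ys = []
    for t in range(seq_len):
        h = A[t] * h + B[t] * x[t, :, None]  # update hidden state
        ys.append(h @ C[t])                  # read out: (d_model,)
    return torch.stack(ys)                   # (seq_len, d_model)
```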

## Dataset info

[CodeAlpaca_20K](https://huggingface.co/datasets/HuggingFaceH4/CodeAlpaca_20K): 20K instruction-following examples used to fine-tune the Code Alpaca model.
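
To inspect the data yourself, you can load it with the `datasets` library (split and column names are defined by the dataset card, so check the printed structure):

```py
from datasets import load_dataset

# Peek at the fine-tuning data; splits and columns come from the dataset card
ds = load_dataset("HuggingFaceH4/CodeAlpaca_20K")
print(ds)               # available splits and their columns
print(ds["train"][0])   # first training example
```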

## Usage

```sh
pip install torch==2.1.0 transformers==4.35.0 causal-conv1d==1.0.0 mamba-ssm==1.0.1
```

```py
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Reuse Zephyr's chat template so prompts follow the <|user|>/<|assistant|> format
CHAT_TEMPLATE_ID = "HuggingFaceH4/zephyr-7b-beta"

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_name = "mrm8488/mamba-coder"

eos_token = "<|endoftext|>"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.eos_token = eos_token
tokenizer.pad_token = tokenizer.eos_token
tokenizer.chat_template = AutoTokenizer.from_pretrained(CHAT_TEMPLATE_ID).chat_template

model = MambaLMHeadModel.from_pretrained(
    model_name, device=device, dtype=torch.float16
)

messages = []
prompt = "Write a bash script to remove .tmp files"
messages.append(dict(role="user", content=prompt))

# Render the chat history into model input ids
input_ids = tokenizer.apply_chat_template(
    messages, return_tensors="pt", add_generation_prompt=True
).to(device)

out = model.generate(
    input_ids=input_ids,
    max_length=2000,
    temperature=0.9,
    top_p=0.7,
    eos_token_id=tokenizer.eos_token_id,
)

# Keep only the assistant's reply and strip the EOS token
decoded = tokenizer.batch_decode(out)
assistant_message = (
    decoded[0].split("<|assistant|>\n")[-1].replace(eos_token, "")
)

print(assistant_message)
```
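
The snippet above is single-turn, but the same `messages` list supports multi-turn chat: append the assistant's reply and the next user prompt, then re-apply the chat template. A minimal continuation, reusing `model`, `tokenizer`, `messages`, and `assistant_message` from above:

```py
# Record the reply, then ask a follow-up in the same conversation
messages.append(dict(role="assistant", content=assistant_message))
messages.append(dict(role="user", content="Now make the script recurse into subdirectories"))

input_ids = tokenizer.apply_chat_template(
    messages, return_tensors="pt", add_generation_prompt=True
).to(device)

out = model.generate(
    input_ids=input_ids,
    max_length=2000,
    temperature=0.9,
    top_p=0.7,
    eos_token_id=tokenizer.eos_token_id,
)

# The decoded text contains the whole conversation; keep the last assistant turn
print(tokenizer.batch_decode(out)[0].split("<|assistant|>\n")[-1].replace(eos_token, ""))
```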


## Gradio Demo

```sh
git clone https://github.com/mrm8488/mamba-chat.git
cd mamba-chat

pip install -r requirements.txt
pip install -q gradio==4.8.0

python app.py \
--model mrm8488/mamba-coder \
--share
```
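
If you would rather not clone the repo, a self-contained wrapper might look roughly like the sketch below. This is illustrative only: it reuses the generation code from the Usage section with `gr.ChatInterface` from Gradio 4.x, and the actual `app.py` in mamba-chat may differ.

```py
import gradio as gr
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

device = "cuda:0" if torch.cuda.is_available() else "cpu"
eos_token = "<|endoftext|>"

tokenizer = AutoTokenizer.from_pretrained("mrm8488/mamba-coder")
tokenizer.eos_token = eos_token
tokenizer.pad_token = tokenizer.eos_token
tokenizer.chat_template = AutoTokenizer.from_pretrained(
    "HuggingFaceH4/zephyr-7b-beta"
).chat_template

model = MambaLMHeadModel.from_pretrained(
    "mrm8488/mamba-coder", device=device, dtype=torch.float16
)

def chat(message, history):
    # Rebuild the message list from Gradio's (user, assistant) history pairs
    messages = []
    for user_msg, bot_msg in history:
        messages.append(dict(role="user", content=user_msg))
        messages.append(dict(role="assistant", content=bot_msg))
    messages.append(dict(role="user", content=message))

    input_ids = tokenizer.apply_chat_template(
        messages, return_tensors="pt", add_generation_prompt=True
    ).to(device)
    out = model.generate(
        input_ids=input_ids,
        max_length=2000,
        temperature=0.9,
        top_p=0.7,
        eos_token_id=tokenizer.eos_token_id,
    )
    decoded = tokenizer.batch_decode(out)[0]
    return decoded.split("<|assistant|>\n")[-1].replace(eos_token, "")

gr.ChatInterface(chat).launch(share=True)
```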

## Evaluations

Coming soon!

## Acknowledgments

Thanks to [mamba-chat](https://github.com/havenhq/mamba-chat/tree/main) for heavily inspiring our work.