Update README.md
Browse files
README.md
CHANGED
@@ -5,20 +5,20 @@ This is the state-spaces mamba-2.8b model, fine-tuned using Supervised Fine-tu
|
|
5 |
|
6 |
To run inference on this model, run the following code:
|
7 |
|
8 |
-
```
|
9 |
import torch
|
10 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
11 |
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
|
12 |
|
|
|
|
|
|
|
13 |
device = "cuda"
|
14 |
messages = []
|
15 |
|
16 |
user_message = "[INST] what is a language model? [/INST]"
|
17 |
-
|
18 |
input_ids = tokenizer(user_message, return_tensors="pt").input_ids.to("cuda")
|
19 |
-
|
20 |
out = model.generate(input_ids=input_ids, max_length=500, temperature=0.9, top_p=0.7, eos_token_id=tokenizer.eos_token_id)
|
21 |
-
|
22 |
decoded = tokenizer.batch_decode(out)
|
23 |
|
24 |
|
|
|
5 |
|
6 |
To run inference on this model, run the following code:
|
7 |
|
8 |
+
```python
|
9 |
import torch
|
10 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
11 |
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
|
12 |
|
13 |
+
#Load the model
|
14 |
+
model = MambaLMHeadModel.from_pretrained("walebadr/mamba-2.8b-SFT", dtype=torch.bfloat16, device="cuda")
|
15 |
+
|
16 |
device = "cuda"
|
17 |
messages = []
|
18 |
|
19 |
user_message = "[INST] what is a language model? [/INST]"
|
|
|
20 |
input_ids = tokenizer(user_message, return_tensors="pt").input_ids.to("cuda")
|
|
|
21 |
out = model.generate(input_ids=input_ids, max_length=500, temperature=0.9, top_p=0.7, eos_token_id=tokenizer.eos_token_id)
|
|
|
22 |
decoded = tokenizer.batch_decode(out)
|
23 |
|
24 |
|