walebadr committed
Commit f7f2a3d
Parent: 38646c7

Update README.md

Files changed (1): README.md (+4, -4)
README.md CHANGED
@@ -5,20 +5,20 @@ This is the state-spaces mamba-2.8b model, fine-tuned using Supervised Fine-tuning
 
 To run inference on this model, run the following code:
 
-```
+```python
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
 
+# Load the model
+model = MambaLMHeadModel.from_pretrained("walebadr/mamba-2.8b-SFT", dtype=torch.bfloat16, device="cuda")
+
 device = "cuda"
 messages = []
 
 user_message = f"[INST] what is a language model? [/INST]"
-
 input_ids = tokenizer(user_message, return_tensors="pt").input_ids.to("cuda")
-
 out = model.generate(input_ids=input_ids, max_length=500, temperature=0.9, top_p=0.7, eos_token_id=tokenizer.eos_token_id)
-
 decoded = tokenizer.batch_decode(out)
 
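Note that even after this commit the README snippet is not runnable as-is: `tokenizer` is used but never created, and the `AutoModelForCausalLM` import is unused. Below is a minimal self-contained sketch of the same inference flow; the tokenizer repo (`EleutherAI/gpt-neox-20b`, the tokenizer the base state-spaces mamba checkpoints use) is an assumption, since the commit never specifies one.

```python
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

device = "cuda"

# Load the fine-tuned model in bfloat16 directly onto the GPU
model = MambaLMHeadModel.from_pretrained("walebadr/mamba-2.8b-SFT", dtype=torch.bfloat16, device=device)

# Assumption: the checkpoint reuses the GPT-NeoX tokenizer like the base
# state-spaces mamba models; substitute the repo's own tokenizer if it ships one.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

user_message = "[INST] what is a language model? [/INST]"
input_ids = tokenizer(user_message, return_tensors="pt").input_ids.to(device)

# Generate a completion with the commit's sampling settings;
# eos_token_id stops generation at end-of-text
out = model.generate(
    input_ids=input_ids,
    max_length=500,
    temperature=0.9,
    top_p=0.7,
    eos_token_id=tokenizer.eos_token_id,
)
decoded = tokenizer.batch_decode(out)
print(decoded[0])
```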