mrm8488 committed
Commit cc77574
1 Parent(s): 7f612cd

Update README.md

Files changed (1): README.md (+46 -1)
README.md CHANGED
@@ -7,4 +7,49 @@ pipeline_tag: text-generation
 
 # MAMBA (2.8B) 🐍 fine-tuned on H4/no_robots dataset for chat / instruction
 
-TBD
+TBD
+
+## Usage
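+
+The example below assumes a CUDA GPU and the [`mamba-ssm`](https://github.com/state-spaces/mamba) package installed alongside `transformers` (together with its `causal-conv1d` kernels, e.g. `pip install causal-conv1d mamba-ssm`).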
+
+```py
+import torch
+from transformers import AutoTokenizer
+from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+
+device = "cuda"
+model_name = "mrm8488/mamba-2.8b-chat-no_robots"  # assumed: this card's repo id
+CHAT_TEMPLATE_ID = "HuggingFaceH4/zephyr-7b-beta"
+
+eos_token = "<|endoftext|>"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+tokenizer.eos_token = eos_token
+tokenizer.pad_token = tokenizer.eos_token
+# Mamba ships without a chat template, so reuse Zephyr's
+tokenizer.chat_template = AutoTokenizer.from_pretrained(CHAT_TEMPLATE_ID).chat_template
+
+# mamba_ssm's language-model head, loaded in fp16 on the GPU
+model = MambaLMHeadModel.from_pretrained(
+    model_name, device=device, dtype=torch.float16
+)
+
+history_dict: list[dict[str, str]] = []
+prompt = "Tell me 5 sites to visit in Spain"
+history_dict.append(dict(role="user", content=prompt))
+
+input_ids = tokenizer.apply_chat_template(
+    history_dict, return_tensors="pt", add_generation_prompt=True
+).to(device)
+
+out = model.generate(
+    input_ids=input_ids,
+    max_length=2000,
+    temperature=0.9,
+    top_p=0.7,
+    eos_token_id=tokenizer.eos_token_id,
+)
+
+decoded = tokenizer.batch_decode(out)
+# Keep only the last assistant turn and strip the EOS marker
+assistant_message = (
+    decoded[0].split("<|assistant|>\n")[-1].replace(eos_token, "")
+)
+
+print(assistant_message)
+```
+
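+Multi-turn chat works the same way: append the assistant's reply and the next user turn to the history, then regenerate. A minimal sketch (it reuses `tokenizer`, `model`, `device`, `eos_token`, `history_dict`, and `assistant_message` from the snippet above; the follow-up question is only an example):
+
+```py
+# Feed the assistant's reply back into the history plus a new user turn
+history_dict.append(dict(role="assistant", content=assistant_message))
+history_dict.append(dict(role="user", content="Which of them is best to visit in winter?"))
+
+input_ids = tokenizer.apply_chat_template(
+    history_dict, return_tensors="pt", add_generation_prompt=True
+).to(device)
+
+out = model.generate(
+    input_ids=input_ids,
+    max_length=2000,
+    temperature=0.9,
+    top_p=0.7,
+    eos_token_id=tokenizer.eos_token_id,
+)
+print(tokenizer.batch_decode(out)[0].split("<|assistant|>\n")[-1].replace(eos_token, ""))
+```
+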
+## Evaluations
+
+Coming soon!