mrm8488 committed
Commit ce7cf93
1 Parent(s): cc77574

Update README.md

Files changed (1):
README.md (+6 -3)
README.md CHANGED
@@ -7,16 +7,19 @@ pipeline_tag: text-generation
 
 # MAMBA (2.8B) 🐍 fine-tuned on H4/no_robots dataset for chat / instruction
 
-TBD
+Model Card is still WIP!
 
 ## Usage
 
 ```py
+import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
 
 CHAT_TEMPLATE_ID = "HuggingFaceH4/zephyr-7b-beta"
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 eos_token = "<|endoftext|>"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 tokenizer.eos_token = eos_token
@@ -24,7 +27,7 @@ tokenizer.pad_token = tokenizer.eos_token
 tokenizer.chat_template = AutoTokenizer.from_pretrained(CHAT_TEMPLATE_ID).chat_template
 
 model = MambaLMHeadModel.from_pretrained(
-    model_name, device="cuda", dtype=torch.float16)
+    model_name, device=device, dtype=torch.float16)
 
 history_dict: list[dict[str, str]] = []
 prompt = "Tell me 5 sites to visit in Spain"
@@ -32,7 +35,7 @@ history_dict.append(dict(role="user", content=prompt))
 
 input_ids = tokenizer.apply_chat_template(
     history_dict, return_tensors="pt", add_generation_prompt=True
-).to(device)
+).to(device)
 
 out = model.generate(
     input_ids=input_ids,
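For readers who want to run the updated snippet directly, below is the post-commit code assembled into one self-contained script. Two gaps are filled with labeled assumptions: the README never defines `model_name` (the repo id below is a placeholder), and the diff truncates `model.generate(...)` after `input_ids=input_ids,`, so the remaining generation arguments and the decoding step are illustrative rather than the author's exact settings.

```py
# Minimal end-to-end sketch of the post-commit README snippet.
# Lines marked ASSUMPTION are not in the diff above.
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

CHAT_TEMPLATE_ID = "HuggingFaceH4/zephyr-7b-beta"
# ASSUMPTION: `model_name` is never defined in the README; substitute this
# model's actual Hub repo id for the placeholder below.
model_name = "mrm8488/mamba-2.8b-chat-no_robots"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

eos_token = "<|endoftext|>"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.eos_token = eos_token
tokenizer.pad_token = tokenizer.eos_token
# Borrow zephyr's chat template, since the base tokenizer ships without one.
tokenizer.chat_template = AutoTokenizer.from_pretrained(CHAT_TEMPLATE_ID).chat_template

model = MambaLMHeadModel.from_pretrained(
    model_name, device=device, dtype=torch.float16)

history_dict: list[dict[str, str]] = []
prompt = "Tell me 5 sites to visit in Spain"
history_dict.append(dict(role="user", content=prompt))

input_ids = tokenizer.apply_chat_template(
    history_dict, return_tensors="pt", add_generation_prompt=True
).to(device)

# ASSUMPTION: the diff cuts off after `input_ids=input_ids,`; these sampling
# settings are illustrative, not the author's published configuration.
out = model.generate(
    input_ids=input_ids,
    max_length=500,
    temperature=0.9,
    top_p=0.7,
    eos_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```

Note that the `torch.cuda.is_available()` guard only selects a device; `mamba_ssm`'s fused selective-scan kernels target CUDA, so in practice a GPU is still effectively required to run this model.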