mrm8488 committed on
Commit
dc6001b
1 Parent(s): 09a0ce7

Update README.md

Files changed (1)
  1. README.md +5 -7
README.md CHANGED
@@ -48,9 +48,7 @@ No Robots is a high-quality dataset of 10,000 instructions and demonstrations cr
  ## Usage

  ```sh
- pip install transformers
- pip install causal-conv1d<=1.0.2
- pip install mamba-ssm
+ pip install torch==2.1.0 transformers==4.35.0 causal-conv1d==1.0.0 mamba-ssm==1.0.1
  ```

  ```py
@@ -60,7 +58,7 @@ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

  CHAT_TEMPLATE_ID = "HuggingFaceH4/zephyr-7b-beta"

- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
  model_name = "clibrain/mamba-2.8b-chat-no_robots"

  eos_token = "<|endoftext|>"
@@ -72,12 +70,12 @@ tokenizer.chat_template = AutoTokenizer.from_pretrained(CHAT_TEMPLATE_ID).chat_t
  model = MambaLMHeadModel.from_pretrained(
      model_name, device=device, dtype=torch.float16)

- history_dict: list[dict[str, str]] = []
+ messages = []
  prompt = "Tell me 5 sites to visit in Spain"
- history_dict.append(dict(role="user", content=prompt))
+ messages.append(dict(role="user", content=prompt))

  input_ids = tokenizer.apply_chat_template(
-     history_dict, return_tensors="pt", add_generation_prompt=True
+     messages, return_tensors="pt", add_generation_prompt=True
  ).to(device)

  out = model.generate(
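
For reference, here is a minimal sketch of how the updated snippet reads end to end after this commit. The tokenizer setup between the hunks is reconstructed from the hunk context, and the generation arguments (`max_length`, `temperature`, `top_p`, `eos_token_id`) as well as the final `batch_decode` call are illustrative assumptions, since the diff cuts off at `out = model.generate(`.

```py
# Sketch of the post-commit usage flow. Lines marked "assumed" are not shown
# in this diff and may differ from the README.
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

CHAT_TEMPLATE_ID = "HuggingFaceH4/zephyr-7b-beta"

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_name = "clibrain/mamba-2.8b-chat-no_robots"

eos_token = "<|endoftext|>"
tokenizer = AutoTokenizer.from_pretrained(model_name)  # assumed: load the model's tokenizer
tokenizer.eos_token = eos_token                        # assumed: reuse <|endoftext|> as EOS/pad
tokenizer.pad_token = tokenizer.eos_token
# Borrow Zephyr's chat template so apply_chat_template can format the turns.
tokenizer.chat_template = AutoTokenizer.from_pretrained(CHAT_TEMPLATE_ID).chat_template

model = MambaLMHeadModel.from_pretrained(
    model_name, device=device, dtype=torch.float16)

messages = []
prompt = "Tell me 5 sites to visit in Spain"
messages.append(dict(role="user", content=prompt))

# Format the chat history with the borrowed template and move it to the device.
input_ids = tokenizer.apply_chat_template(
    messages, return_tensors="pt", add_generation_prompt=True
).to(device)

# mamba-ssm's generate() differs from the transformers API; these kwargs are
# assumed example values, not the ones in the README.
out = model.generate(
    input_ids=input_ids,
    max_length=1024,
    temperature=0.7,
    top_p=0.9,
    eos_token_id=tokenizer.eos_token_id,
)
print(tokenizer.batch_decode(out)[0])
```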