namespace-Pt committed on
Commit
b2f21ff
1 Parent(s): 73bf11b

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +11 -4
README.md CHANGED
@@ -70,6 +70,12 @@ We evaluate the model on [Passkey Retrieval](https://arxiv.org/abs/2309.12307) t
70
 
71
  <img src="data/passkey.png"></img>
72
 
 
 
 
 
 
 
73
  # Usage
74
  ```python
75
  import json
@@ -85,9 +91,9 @@ model = model.cuda().eval()
85
 
86
  with torch.no_grad():
87
  # short context
88
- text = "Tell me about yourself."
89
- inputs = tokenizer(text, return_tensors="pt").to("cuda")
90
- outputs = model.generate(**inputs, max_new_tokens=20)
91
  print(f"Input Length: {inputs['input_ids'].shape[1]}")
92
  print(f"Output: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
93
 
@@ -97,7 +103,8 @@ with torch.no_grad():
97
  # long context
98
  with open("data/infbench.json", encoding="utf-8") as f:
99
  example = json.load(f)
100
- inputs = tokenizer(example["context"], return_tensors="pt").to("cuda")
 
101
  outputs = model.generate(**inputs, do_sample=False, top_p=1, temperature=1, max_new_tokens=20)[:, inputs["input_ids"].shape[1]:]
102
  print("*"*20)
103
  print(f"Input Length: {inputs['input_ids'].shape[1]}")
 
70
 
71
  <img src="data/passkey.png"></img>
72
 
73
+ # Environment
74
+ ```bash
75
+ torch>=2.1.1
76
+ transformers==4.39.3
77
+ ```
78
+
79
  # Usage
80
  ```python
81
  import json
 
91
 
92
  with torch.no_grad():
93
  # short context
94
+ messages = [{"role": "user", "content": "Tell me about yourself."}]
95
+ inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda")
96
+ outputs = model.generate(**inputs, max_new_tokens=50)
97
  print(f"Input Length: {inputs['input_ids'].shape[1]}")
98
  print(f"Output: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
99
 
 
103
  # long context
104
  with open("data/infbench.json", encoding="utf-8") as f:
105
  example = json.load(f)
106
+ messages = [{"role": "user", "content": example["context"]}]
107
+ inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda")
108
  outputs = model.generate(**inputs, do_sample=False, top_p=1, temperature=1, max_new_tokens=20)[:, inputs["input_ids"].shape[1]:]
109
  print("*"*20)
110
  print(f"Input Length: {inputs['input_ids'].shape[1]}")