Commit b2f21ff (parent: 73bf11b)
Update README.md

README.md CHANGED
@@ -70,6 +70,12 @@ We evaluate the model on [Passkey Retrieval](https://arxiv.org/abs/2309.12307) t
 
 <img src="data/passkey.png"></img>
 
+# Environment
+```bash
+torch>=2.1.1
+transformers==4.39.3
+```
+
 # Usage
 ```python
 import json
@@ -85,9 +91,9 @@ model = model.cuda().eval()
 
 with torch.no_grad():
     # short context
-
-    inputs = tokenizer(
-    outputs = model.generate(**inputs, max_new_tokens=
+    messages = [{"role": "user", "content": "Tell me about yourself."}]
+    inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda")
+    outputs = model.generate(**inputs, max_new_tokens=50)
     print(f"Input Length: {inputs['input_ids'].shape[1]}")
     print(f"Output: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
 
@@ -97,7 +103,8 @@ with torch.no_grad():
     # long context
     with open("data/infbench.json", encoding="utf-8") as f:
         example = json.load(f)
-
+    messages = [{"role": "user", "content": example["context"]}]
+    inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda")
     outputs = model.generate(**inputs, do_sample=False, top_p=1, temperature=1, max_new_tokens=20)[:, inputs["input_ids"].shape[1]:]
     print("*"*20)
     print(f"Input Length: {inputs['input_ids'].shape[1]}")
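For reference, a runnable sketch of how the updated usage section reads end to end after this change. Only the generation calls, the chat-template calls, and the `model = model.cuda().eval()` line come from the diff; the checkpoint path and the `AutoTokenizer`/`AutoModelForCausalLM` loading calls are assumptions added here for illustration.

```python
# Sketch of the post-commit usage; the model path and loading calls are
# placeholders (assumptions), not taken from this commit.
import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "path/to/checkpoint"  # hypothetical placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)
model = model.cuda().eval()

with torch.no_grad():
    # short context: build a chat-formatted prompt and generate a short reply
    messages = [{"role": "user", "content": "Tell me about yourself."}]
    inputs = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True,
        return_tensors="pt", return_dict=True,
    ).to("cuda")
    outputs = model.generate(**inputs, max_new_tokens=50)
    print(f"Input Length: {inputs['input_ids'].shape[1]}")
    print(f"Output: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")

    # long context: feed an InfiniteBench example through the same chat template
    with open("data/infbench.json", encoding="utf-8") as f:
        example = json.load(f)
    messages = [{"role": "user", "content": example["context"]}]
    inputs = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True,
        return_tensors="pt", return_dict=True,
    ).to("cuda")
    # keep only the newly generated tokens by slicing off the prompt length
    outputs = model.generate(
        **inputs, do_sample=False, top_p=1, temperature=1, max_new_tokens=20
    )[:, inputs["input_ids"].shape[1]:]
    print("*" * 20)
    print(f"Input Length: {inputs['input_ids'].shape[1]}")
```

Replacing the raw `tokenizer(...)` call with `tokenizer.apply_chat_template(...)` formats the prompt with the model's chat template, and `return_dict=True` returns both `input_ids` and `attention_mask`, so the result can be passed directly to `generate`.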