Update README.md
Browse files
README.md
CHANGED
@@ -42,6 +42,112 @@ There was decent amount of empty assistant response on original dataset. So, I d
|
|
42 |
|
43 |
---
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
## Statistics
|
46 |
|
47 |
### Context length
|
|
|
42 |
|
43 |
---
|
44 |
|
45 |
+
## Example Code
|
46 |
+
|
47 |
+
**The code below is modified from the** [**PairRM-hf Repo**](https://huggingface.co/llm-blender/PairRM-hf)
|
48 |
+
|
49 |
+
```python
|
50 |
+
import os
|
51 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
52 |
+
from llm_blender.pair_ranker.pairrm import DebertaV2PairRM
|
53 |
+
from transformers import AutoTokenizer
|
54 |
+
from typing import List
|
55 |
+
pairrm = DebertaV2PairRM.from_pretrained("maywell/Better-PairRM", device_map="cuda:0").eval()
|
56 |
+
tokenizer = AutoTokenizer.from_pretrained("maywell/Better-PairRM")
|
57 |
+
source_prefix = "<|source|>"
|
58 |
+
cand1_prefix = "<|candidate1|>"
|
59 |
+
cand2_prefix = "<|candidate2|>"
|
60 |
+
inputs = ["hello!", "I love you!"]
|
61 |
+
candidates_A = ["hi!", "I hate you!"]
|
62 |
+
candidates_B = ["f**k off!", "I love you, too!"]
|
63 |
+
def tokenize_pair(sources:List[str], candidate1s:List[str], candidate2s:List[str], source_max_length=2030, candidate_max_length=670):
|
64 |
+
ids = []
|
65 |
+
assert len(sources) == len(candidate1s) == len(candidate2s)
|
66 |
+
max_length = source_max_length + 2 * candidate_max_length
|
67 |
+
for i in range(len(sources)):
|
68 |
+
source_ids = tokenizer.encode(source_prefix + sources[i], max_length=source_max_length, truncation=True)
|
69 |
+
candidate_max_length = (max_length - len(source_ids)) // 2
|
70 |
+
candidate1_ids = tokenizer.encode(cand1_prefix + candidate1s[i], max_length=candidate_max_length, truncation=True)
|
71 |
+
candidate2_ids = tokenizer.encode(cand2_prefix + candidate2s[i], max_length=candidate_max_length, truncation=True)
|
72 |
+
ids.append(source_ids + candidate1_ids + candidate2_ids)
|
73 |
+
encodings = tokenizer.pad({"input_ids": ids}, return_tensors="pt", padding="max_length", max_length=max_length)
|
74 |
+
return encodings
|
75 |
+
|
76 |
+
encodings = tokenize_pair(inputs, candidates_A, candidates_B)
|
77 |
+
encodings = {k:v.to(pairrm.device) for k,v in encodings.items()}
|
78 |
+
outputs = pairrm(**encodings)
|
79 |
+
logits = outputs.logits.tolist()
|
80 |
+
comparison_results = outputs.logits > 0
|
81 |
+
print(logits)
|
82 |
+
print(comparison_results)
|
83 |
+
```
|
84 |
+
|
85 |
+
You can also easily compare two conversations as follows:
|
86 |
+
```python
|
87 |
+
import jinja2
from typing import List
from transformers import AutoTokenizer
|
88 |
+
|
89 |
+
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large")
|
90 |
+
|
91 |
+
def truncate_texts(text, max_length, truncate_side):
|
92 |
+
tokenizer.truncation_side = truncate_side
|
93 |
+
tokens = tokenizer.encode(text, add_special_tokens=False, max_length=max_length)
|
94 |
+
truncated_text = tokenizer.decode(tokens, skip_special_tokens=True)
|
95 |
+
return truncated_text
|
96 |
+
|
97 |
+
MY_JINJA_TEMPLATE = """{% for message in messages -%}
|
98 |
+
{% if message['role'] == 'user' -%}
|
99 |
+
USER: {{ message['content']|trim -}}
|
100 |
+
{% if not loop.last -%}
|
101 |
+
|
102 |
+
|
103 |
+
{% endif %}
|
104 |
+
{% elif message['role'] == 'assistant' -%}
|
105 |
+
ASSISTANT: {{ message['content']|trim -}}
|
106 |
+
{% if not loop.last -%}
|
107 |
+
|
108 |
+
|
109 |
+
{% endif %}
|
110 |
+
{% elif message['role'] == 'user_context' -%}
|
111 |
+
USER: {{ message['content']|trim -}}
|
112 |
+
{% if not loop.last -%}
|
113 |
+
|
114 |
+
|
115 |
+
{% endif %}
|
116 |
+
{% elif message['role'] == 'system' -%}
|
117 |
+
SYSTEM MESSAGE: {{ message['content']|trim -}}
|
118 |
+
{% if not loop.last -%}
|
119 |
+
|
120 |
+
|
121 |
+
{% endif %}
|
122 |
+
{% endif %}
|
123 |
+
{% endfor -%}
|
124 |
+
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' -%}
|
125 |
+
ASSISTANT: {% endif -%}"""
|
126 |
+
|
127 |
+
my_jinja2_env = jinja2.Environment()
|
128 |
+
my_jinja2_template = my_jinja2_env.from_string(MY_JINJA_TEMPLATE)
|
129 |
+
|
130 |
+
def tokenize_conv_pair(convAs: List[str], convBs: List[str]):
|
131 |
+
|
132 |
+
# check conversations correctness
|
133 |
+
assert len(convAs) == len(convBs), "Number of conversations must be the same"
|
134 |
+
for c_a, c_b in zip(convAs, convBs):
|
135 |
+
assert len(c_a) == len(c_b), "Number of turns in each conversation must be the same"
|
136 |
+
assert all([c_a[i]['content'] == c_b[i]['content'] for i in range(0, len(c_a), 2)]), "USER turns must be the same"
|
137 |
+
|
138 |
+
inputs = [
|
139 |
+
truncate_texts(my_jinja2_template.render(messages=x[:-1], add_generation_prompt=True), 2030, "left") for x in convAs
|
140 |
+
]
|
141 |
+
cand1_texts = [
|
142 |
+
truncate_texts(x[-1]['content'], 670, "right") for x in convAs
|
143 |
+
]
|
144 |
+
cand2_texts = [
|
145 |
+
truncate_texts(x[-1]['content'], 670, "right") for x in convBs
|
146 |
+
]
|
147 |
+
encodings = tokenize_pair(inputs, cand1_texts, cand2_texts)
|
148 |
+
return encodings
|
149 |
+
```
|
150 |
+
|
151 |
## Statistics
|
152 |
|
153 |
### Context length
|