maywell commited on
Commit
0319c87
1 Parent(s): c9f7a5d

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +106 -0
README.md CHANGED
@@ -42,6 +42,112 @@ There was a decent amount of empty assistant responses in the original dataset. So, I d
42
 
43
  ---
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  ## Statistics
46
 
47
  ### Context length
 
42
 
43
  ---
44
 
45
+ ## Example Code
46
+
47
+ **The code below is modified from the** [**PairRM-hf Repo**](https://huggingface.co/llm-blender/PairRM-hf)
48
+
49
+ ```python
50
+ import os
51
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
52
+ from llm_blender.pair_ranker.pairrm import DebertaV2PairRM
53
+ from transformers import AutoTokenizer
54
+ from typing import List
55
+ pairrm = DebertaV2PairRM.from_pretrained("maywell/Better-PairRM", device_map="cuda:0").eval()
56
+ tokenizer = AutoTokenizer.from_pretrained("maywell/Better-PairRM")
57
+ source_prefix = "<|source|>"
58
+ cand1_prefix = "<|candidate1|>"
59
+ cand2_prefix = "<|candidate2|>"
60
+ inputs = ["hello!", "I love you!"]
61
+ candidates_A = ["hi!", "I hate you!"]
62
+ candidates_B = ["f**k off!", "I love you, too!"]
63
+ def tokenize_pair(sources:List[str], candidate1s:List[str], candidate2s:List[str], source_max_length=2030, candidate_max_length=670):
64
+ ids = []
65
+ assert len(sources) == len(candidate1s) == len(candidate2s)
66
+ max_length = source_max_length + 2 * candidate_max_length
67
+ for i in range(len(sources)):
68
+ source_ids = tokenizer.encode(source_prefix + sources[i], max_length=source_max_length, truncation=True)
69
+ candidate_max_length = (max_length - len(source_ids)) // 2
70
+ candidate1_ids = tokenizer.encode(cand1_prefix + candidate1s[i], max_length=candidate_max_length, truncation=True)
71
+ candidate2_ids = tokenizer.encode(cand2_prefix + candidate2s[i], max_length=candidate_max_length, truncation=True)
72
+ ids.append(source_ids + candidate1_ids + candidate2_ids)
73
+ encodings = tokenizer.pad({"input_ids": ids}, return_tensors="pt", padding="max_length", max_length=max_length)
74
+ return encodings
75
+
76
+ encodings = tokenize_pair(inputs, candidates_A, candidates_B)
77
+ encodings = {k:v.to(pairrm.device) for k,v in encodings.items()}
78
+ outputs = pairrm(**encodings)
79
+ logits = outputs.logits.tolist()
80
+ comparison_results = outputs.logits > 0
81
+ print(logits)
82
+ print(comparison_results)
83
+ ```
84
+
85
+ You can also easily compare two conversations like the following:
86
+ ```python
87
+ import jinja2
+ from typing import List
+ from transformers import AutoTokenizer
88
+
89
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large")
90
+
91
+ def truncate_texts(text, max_length, truncate_side):
92
+ tokenizer.truncation_side = truncate_side
93
+ tokens = tokenizer.encode(text, add_special_tokens=False, max_length=max_length)
94
+ truncated_text = tokenizer.decode(tokens, skip_special_tokens=True)
95
+ return truncated_text
96
+
97
+ MY_JINJA_TEMPLATE = """{% for message in messages -%}
98
+ {% if message['role'] == 'user' -%}
99
+ USER: {{ message['content']|trim -}}
100
+ {% if not loop.last -%}
101
+
102
+
103
+ {% endif %}
104
+ {% elif message['role'] == 'assistant' -%}
105
+ ASSISTANT: {{ message['content']|trim -}}
106
+ {% if not loop.last -%}
107
+
108
+
109
+ {% endif %}
110
+ {% elif message['role'] == 'user_context' -%}
111
+ USER: {{ message['content']|trim -}}
112
+ {% if not loop.last -%}
113
+
114
+
115
+ {% endif %}
116
+ {% elif message['role'] == 'system' -%}
117
+ SYSTEM MESSAGE: {{ message['content']|trim -}}
118
+ {% if not loop.last -%}
119
+
120
+
121
+ {% endif %}
122
+ {% endif %}
123
+ {% endfor -%}
124
+ {% if add_generation_prompt and messages[-1]['role'] != 'assistant' -%}
125
+ ASSISTANT: {% endif -%}"""
126
+
127
+ my_jinja2_env = jinja2.Environment()
128
+ my_jinja2_template = my_jinja2_env.from_string(MY_JINJA_TEMPLATE)
129
+
130
+ def tokenize_conv_pair(convAs: List[str], convBs: List[str]):
131
+
132
+ # check conversations correctness
133
+ assert len(convAs) == len(convBs), "Number of conversations must be the same"
134
+ for c_a, c_b in zip(convAs, convBs):
135
+ assert len(c_a) == len(c_b), "Number of turns in each conversation must be the same"
136
+ assert all([c_a[i]['content'] == c_b[i]['content'] for i in range(0, len(c_a), 2)]), "USER turns must be the same"
137
+
138
+ inputs = [
139
+ truncate_texts(my_jinja2_template.render(messages=x[:-1], add_generation_prompt=True), 2030, "left") for x in convAs
140
+ ]
141
+ cand1_texts = [
142
+ truncate_texts(x[-1]['content'], 670, "right") for x in convAs
143
+ ]
144
+ cand2_texts = [
145
+ truncate_texts(x[-1]['content'], 670, "right") for x in convBs
146
+ ]
147
+ encodings = tokenize_pair(inputs, cand1_texts, cand2_texts)
148
+ return encodings
149
+ ```
150
+
151
  ## Statistics
152
 
153
  ### Context length