Husnain committed
Commit 99c68b6
1 Parent(s): fcdcbfd

💥 [Fix] Revert chat templates with manual version as nous-mixtral-8x…

Files changed (1)
  1. messagers/message_composer.py +26 -129
messagers/message_composer.py CHANGED
@@ -1,23 +1,19 @@
  import re
  from pprint import pprint
  from utils.logger import logger


  class MessageComposer:
-     # LINK - apis/chat_api.py#available-models
-     AVALAIBLE_MODELS = [
-         "mixtral-8x7b",
-         "mistral-7b",
-         "openchat-3.5",
-         "nous-mixtral-8x7b",
-         "gemma-7b",
-     ]
-
      def __init__(self, model: str = None):
-         if model in self.AVALAIBLE_MODELS:
              self.model = model
          else:
              self.model = "mixtral-8x7b"
          self.system_roles = ["system"]
          self.inst_roles = ["user", "system", "inst"]
          self.answer_roles = ["assistant", "bot", "answer", "model"]
@@ -51,12 +47,16 @@ class MessageComposer:
          return concat_messages

      def merge(self, messages) -> str:
          # Mistral and Mixtral:
          # <s> [INST] Instruction [/INST] Model answer </s> [INST] Follow-up instruction [/INST]

-         # OpenChat:
-         # GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi<|end_of_turn|>GPT4 Correct User: How are you today?<|end_of_turn|>GPT4 Correct Assistant:
-
          # Nous Mixtral:
          # <|im_start|>system
          # You are "Hermes 2".<|im_end|>
@@ -64,6 +64,9 @@ class MessageComposer:
          # Hello, who are you?<|im_end|>
          # <|im_start|>assistant

          # Google Gemma-it
          # <start_of_turn>user
          # How does the brain work?<end_of_turn>
@@ -145,127 +148,23 @@ class MessageComposer:
                      )
              self.merged_str_list.append(f"{self.start_of_turn}model\n")
              self.merged_str = "\n".join(self.merged_str_list)
          else:
-             self.merged_str = "\n".join(
-                 [
-                     f'`{message["role"]}`:\n{message["content"]}\n'
-                     for message in self.messages
-                 ]
              )

          return self.merged_str

-     def convert_pair_matches_to_messages(self, pair_matches_list):
-         messages = []
-         if len(pair_matches_list) <= 0:
-             messages = [
-                 {
-                     "role": "user",
-                     "content": self.merged_str,
-                 }
-             ]
-         else:
-             for match in pair_matches_list:
-                 inst = match.group("inst")
-                 answer = match.group("answer")
-                 messages.extend(
-                     [
-                         {"role": "user", "content": inst.strip()},
-                         {"role": "assistant", "content": answer.strip()},
-                     ]
-                 )
-         return messages
-
-     def append_last_instruction_to_messages(self, inst_matches_list, pair_matches_list):
-         if len(inst_matches_list) > len(pair_matches_list):
-             self.messages.extend(
-                 [
-                     {
-                         "role": "user",
-                         "content": inst_matches_list[-1].group("inst").strip(),
-                     }
-                 ]
-             )
-
-     def split(self, merged_str) -> list:
-         self.merged_str = merged_str
-         self.messages = []
-
-         if self.model in ["mixtral-8x7b", "mistral-7b"]:
-             pair_pattern = (
-                 r"<s>\s*\[INST\](?P<inst>[\s\S]*?)\[/INST\](?P<answer>[\s\S]*?)</s>"
-             )
-             pair_matches = re.finditer(pair_pattern, self.merged_str, re.MULTILINE)
-             pair_matches_list = list(pair_matches)
-
-             self.messages = self.convert_pair_matches_to_messages(pair_matches_list)
-
-             inst_pattern = r"\[INST\](?P<inst>[\s\S]*?)\[/INST\]"
-             inst_matches = re.finditer(inst_pattern, self.merged_str, re.MULTILINE)
-             inst_matches_list = list(inst_matches)
-
-             self.append_last_instruction_to_messages(
-                 inst_matches_list, pair_matches_list
-             )
-         elif self.model in ["nous-mixtral-8x7b"]:
-             # https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO#prompt-format
-             # message_pattern = r"<\|im_start\|>(?P<role>system|user|assistant)[\s\n]*(?P<content>[\s\S]*?)<\|im_end\|>"
-             message_pattern = r"<\|im_start\|>(?P<role>system|user|assistant)[\s\n]*(?P<content>[\s\S]*?)<\|im_end\|>"
-             message_matches = re.finditer(
-                 message_pattern, self.merged_str, flags=re.MULTILINE | re.IGNORECASE
-             )
-             message_matches_list = list(message_matches)
-             logger.note(f"message_matches_list: {message_matches_list}")
-             for match in message_matches_list:
-                 role = match.group("role")
-                 content = match.group("content")
-                 self.messages.append({"role": role, "content": content.strip()})
-         elif self.model in ["openchat-3.5"]:
-             pair_pattern = r"GPT4 Correct User:(?P<inst>[\s\S]*?)<\|end_of_turn\|>\s*GPT4 Correct Assistant:(?P<answer>[\s\S]*?)<\|end_of_turn\|>"
-             pair_matches = re.finditer(
-                 pair_pattern, self.merged_str, flags=re.MULTILINE | re.IGNORECASE
-             )
-             pair_matches_list = list(pair_matches)
-             self.messages = self.convert_pair_matches_to_messages(pair_matches_list)
-             inst_pattern = r"GPT4 Correct User:(?P<inst>[\s\S]*?)<\|end_of_turn\|>"
-             inst_matches = re.finditer(
-                 inst_pattern, self.merged_str, flags=re.MULTILINE | re.IGNORECASE
-             )
-             inst_matches_list = list(inst_matches)
-             self.append_last_instruction_to_messages(
-                 inst_matches_list, pair_matches_list
-             )
-         # https://huggingface.co/google/gemma-7b-it#chat-template
-         elif self.model in ["gemma-7b"]:
-             pair_pattern = r"<start_of_turn>user[\s\n]*(?P<inst>[\s\S]*?)<end_of_turn>[\s\n]*<start_of_turn>model(?P<answer>[\s\S]*?)<end_of_turn>"
-             pair_matches = re.finditer(
-                 pair_pattern, self.merged_str, flags=re.MULTILINE | re.IGNORECASE
-             )
-             pair_matches_list = list(pair_matches)
-             self.messages = self.convert_pair_matches_to_messages(pair_matches_list)
-             inst_pattern = r"<start_of_turn>user\n(?P<inst>[\s\S]*?)<end_of_turn>"
-             inst_matches = re.finditer(
-                 inst_pattern, self.merged_str, flags=re.MULTILINE | re.IGNORECASE
-             )
-             inst_matches_list = list(inst_matches)
-             self.append_last_instruction_to_messages(
-                 inst_matches_list, pair_matches_list
-             )
-         else:
-             self.messages = [
-                 {
-                     "role": "user",
-                     "content": self.merged_str,
-                 }
-             ]
-
-         return self.messages
-

  if __name__ == "__main__":
      # model = "mixtral-8x7b"
      # model = "nous-mixtral-8x7b"
-     model = "gemma-7b"
      composer = MessageComposer(model)
      messages = [
          {
@@ -287,7 +186,5 @@ if __name__ == "__main__":
      merged_str = composer.merge(messages)
      logger.note("merged_str:")
      logger.mesg(merged_str)
-     logger.note("splitted messages:")
-     pprint(composer.split(merged_str))
-     # logger.note("merged merged_str:")
-     # logger.mesg(composer.merge(composer.split(merged_str)))
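The removed split() above reverses merge() by pulling messages back out of a merged prompt with regex named groups. A minimal standalone sketch of that pair-matching technique, reusing the Mistral/Mixtral patterns from the deleted code (illustrative only, not part of this commit):

import re

# Patterns copied from the removed mixtral-8x7b / mistral-7b branch of split().
pair_pattern = r"<s>\s*\[INST\](?P<inst>[\s\S]*?)\[/INST\](?P<answer>[\s\S]*?)</s>"
inst_pattern = r"\[INST\](?P<inst>[\s\S]*?)\[/INST\]"

merged_str = "<s> [INST] Hello, who are you? [/INST] I am Mixtral. </s> [INST] How are you today? [/INST]"

# Completed rounds come from <s> [INST] ... [/INST] ... </s> pairs.
pair_matches = list(re.finditer(pair_pattern, merged_str, re.MULTILINE))
messages = []
for match in pair_matches:
    messages.append({"role": "user", "content": match.group("inst").strip()})
    messages.append({"role": "assistant", "content": match.group("answer").strip()})

# A trailing [INST] ... [/INST] without an answer becomes the last user turn.
inst_matches = list(re.finditer(inst_pattern, merged_str, re.MULTILINE))
if len(inst_matches) > len(pair_matches):
    messages.append({"role": "user", "content": inst_matches[-1].group("inst").strip()})

print(messages)
# [{'role': 'user', 'content': 'Hello, who are you?'},
#  {'role': 'assistant', 'content': 'I am Mixtral.'},
#  {'role': 'user', 'content': 'How are you today?'}]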
 
  import re
  from pprint import pprint
+
+ from transformers import AutoTokenizer
+
+ from constants.models import AVAILABLE_MODELS, MODEL_MAP
  from utils.logger import logger


  class MessageComposer:
      def __init__(self, model: str = None):
+         if model in AVAILABLE_MODELS:
              self.model = model
          else:
              self.model = "mixtral-8x7b"
+         self.model_fullname = MODEL_MAP[self.model]
          self.system_roles = ["system"]
          self.inst_roles = ["user", "system", "inst"]
          self.answer_roles = ["assistant", "bot", "answer", "model"]
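The new __init__ imports AVAILABLE_MODELS and MODEL_MAP from constants.models, a file that is not part of this diff. A purely hypothetical sketch of the shape those constants might take, inferred from the short names in the removed AVALAIBLE_MODELS list and the repo ids referenced in the template comments below (the mistral-7b mapping in particular is a guess):

# constants/models.py -- hypothetical illustration, not the committed file.
AVAILABLE_MODELS = [
    "mixtral-8x7b",
    "mistral-7b",
    "openchat-3.5",
    "nous-mixtral-8x7b",
    "gemma-7b",
]

# Maps the short API names above to full Hugging Face repo ids (assumed values).
MODEL_MAP = {
    "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.2",  # assumed revision
    "openchat-3.5": "openchat/openchat-3.5-0106",
    "nous-mixtral-8x7b": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
    "gemma-7b": "google/gemma-7b-it",
}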
 
          return concat_messages

      def merge(self, messages) -> str:
+         # Templates for Chat Models
+         # - https://huggingface.co/docs/transformers/main/en/chat_templating
+         # - https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1#instruction-format
+         # - https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO#prompt-format
+         # - https://huggingface.co/openchat/openchat-3.5-0106
+         # - https://huggingface.co/google/gemma-7b-it#chat-template
+
          # Mistral and Mixtral:
          # <s> [INST] Instruction [/INST] Model answer </s> [INST] Follow-up instruction [/INST]

          # Nous Mixtral:
          # <|im_start|>system
          # You are "Hermes 2".<|im_end|>

          # Hello, who are you?<|im_end|>
          # <|im_start|>assistant

+         # OpenChat:
+         # GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi<|end_of_turn|>GPT4 Correct User: How are you today?<|end_of_turn|>GPT4 Correct Assistant:
+
          # Google Gemma-it
          # <start_of_turn>user
          # How does the brain work?<end_of_turn>
 
                      )
              self.merged_str_list.append(f"{self.start_of_turn}model\n")
              self.merged_str = "\n".join(self.merged_str_list)
+         # https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO#prompt-format
+         # https://huggingface.co/openchat/openchat-3.5-0106
+         # elif self.model in ["openchat-3.5", "nous-mixtral-8x7b"]:
          else:
+             tokenizer = AutoTokenizer.from_pretrained("openchat/openchat-3.5-0106")
+             self.merged_str = tokenizer.apply_chat_template(
+                 messages, tokenize=False, add_generation_prompt=True
              )

          return self.merged_str
  if __name__ == "__main__":
      # model = "mixtral-8x7b"
      # model = "nous-mixtral-8x7b"
+     # model = "gemma-7b"
+     model = "openchat-3.5"
      composer = MessageComposer(model)
      messages = [
          {
 
      merged_str = composer.merge(messages)
      logger.note("merged_str:")
      logger.mesg(merged_str)
+
+     # python -m messagers.message_composer
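A minimal standalone sketch of the tokenizer-based fallback that the new merge() relies on, assuming transformers is installed and the openchat/openchat-3.5-0106 tokenizer is downloadable or cached locally (illustrative, not part of the committed file):

from transformers import AutoTokenizer

# Load the chat template bundled with the OpenChat tokenizer, as the new
# merge() fallback does for models without a hand-written branch.
tokenizer = AutoTokenizer.from_pretrained("openchat/openchat-3.5-0106")

messages = [
    {"role": "user", "content": "Hello, who are you?"},
    {"role": "assistant", "content": "I am OpenChat."},
    {"role": "user", "content": "How are you today?"},
]

# tokenize=False returns the rendered prompt as a string instead of token ids;
# add_generation_prompt=True appends the assistant prefix for the next turn.
merged_str = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(merged_str)
# Expected shape per the OpenChat template, e.g.:
# GPT4 Correct User: Hello, who are you?<|end_of_turn|>GPT4 Correct Assistant: I am OpenChat.<|end_of_turn|>GPT4 Correct User: How are you today?<|end_of_turn|>GPT4 Correct Assistant:

Passing tokenize=False keeps the rendered prompt as a string, and add_generation_prompt=True appends the assistant prefix so the model continues from the last user turn.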