ka1kuk commited on
Commit
4855f7d
1 Parent(s): 66a1c16

Update messagers/message_composer.py

Browse files
Files changed (1) hide show
  1. messagers/message_composer.py +99 -9
messagers/message_composer.py CHANGED
@@ -1,5 +1,6 @@
1
  import re
2
  from pprint import pprint
 
3
 
4
 
5
  class MessageComposer:
@@ -8,6 +9,8 @@ class MessageComposer:
8
  "mixtral-8x7b",
9
  "mistral-7b",
10
  "openchat-3.5",
 
 
11
  ]
12
 
13
  def __init__(self, model: str = None):
@@ -15,8 +18,10 @@ class MessageComposer:
15
  self.model = model
16
  else:
17
  self.model = "mixtral-8x7b"
 
18
  self.inst_roles = ["user", "system", "inst"]
19
- self.answer_roles = ["assistant", "bot", "answer"]
 
20
 
21
  def concat_messages_by_role(self, messages):
22
  def is_same_role(role1, role2):
@@ -48,13 +53,28 @@ class MessageComposer:
48
  def merge(self, messages) -> str:
49
  # Mistral and Mixtral:
50
  # <s> [INST] Instruction [/INST] Model answer </s> [INST] Follow-up instruction [/INST]
 
51
  # OpenChat:
52
  # GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi<|end_of_turn|>GPT4 Correct User: How are you today?<|end_of_turn|>GPT4 Correct Assistant:
53
 
54
- self.messages = self.concat_messages_by_role(messages)
 
 
 
 
 
 
 
 
 
 
 
 
55
  self.merged_str = ""
56
 
 
57
  if self.model in ["mixtral-8x7b", "mistral-7b"]:
 
58
  self.cached_str = ""
59
  for message in self.messages:
60
  role = message["role"]
@@ -68,7 +88,21 @@ class MessageComposer:
68
  self.cached_str = f"[INST] {content} [/INST]"
69
  if self.cached_str:
70
  self.merged_str += f"{self.cached_str}"
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  elif self.model in ["openchat-3.5"]:
 
72
  self.merged_str_list = []
73
  self.end_of_turn = "<|end_of_turn|>"
74
  for message in self.messages:
@@ -88,6 +122,29 @@ class MessageComposer:
88
  )
89
  self.merged_str_list.append(f"GPT4 Correct Assistant:\n")
90
  self.merged_str = "\n".join(self.merged_str_list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  else:
92
  self.merged_str = "\n".join(
93
  [
@@ -150,10 +207,21 @@ class MessageComposer:
150
  self.append_last_instruction_to_messages(
151
  inst_matches_list, pair_matches_list
152
  )
153
-
 
 
 
 
 
 
 
 
 
 
 
 
154
  elif self.model in ["openchat-3.5"]:
155
  pair_pattern = r"GPT4 Correct User:(?P<inst>[\s\S]*?)<\|end_of_turn\|>\s*GPT4 Correct Assistant:(?P<answer>[\s\S]*?)<\|end_of_turn\|>"
156
- # ignore case
157
  pair_matches = re.finditer(
158
  pair_pattern, self.merged_str, flags=re.MULTILINE | re.IGNORECASE
159
  )
@@ -167,6 +235,22 @@ class MessageComposer:
167
  self.append_last_instruction_to_messages(
168
  inst_matches_list, pair_matches_list
169
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  else:
171
  self.messages = [
172
  {
@@ -179,11 +263,14 @@ class MessageComposer:
179
 
180
 
181
  if __name__ == "__main__":
182
- composer = MessageComposer(model="openchat-3.5")
 
 
 
183
  messages = [
184
  {
185
  "role": "system",
186
- "content": "You are a LLM developed by OpenAI. Your name is GPT-4.",
187
  },
188
  {"role": "user", "content": "Hello, who are you?"},
189
  {"role": "assistant", "content": "I am a bot."},
@@ -196,8 +283,11 @@ if __name__ == "__main__":
196
  # "content": "How many questions have I asked? Please list them.",
197
  # },
198
  ]
199
- print("model:", composer.model)
200
  merged_str = composer.merge(messages)
201
- print(merged_str)
 
 
202
  pprint(composer.split(merged_str))
203
- # print(composer.merge(composer.split(merged_str)))
 
 
1
  import re
2
  from pprint import pprint
3
+ from utils.logger import logger
4
 
5
 
6
  class MessageComposer:
 
9
  "mixtral-8x7b",
10
  "mistral-7b",
11
  "openchat-3.5",
12
+ "nous-mixtral-8x7b",
13
+ "gemma-7b",
14
  ]
15
 
16
  def __init__(self, model: str = None):
 
18
  self.model = model
19
  else:
20
  self.model = "mixtral-8x7b"
21
+ self.system_roles = ["system"]
22
  self.inst_roles = ["user", "system", "inst"]
23
+ self.answer_roles = ["assistant", "bot", "answer", "model"]
24
+ self.default_role = "user"
25
 
26
  def concat_messages_by_role(self, messages):
27
  def is_same_role(role1, role2):
 
53
  def merge(self, messages) -> str:
54
  # Mistral and Mixtral:
55
  # <s> [INST] Instruction [/INST] Model answer </s> [INST] Follow-up instruction [/INST]
56
+
57
  # OpenChat:
58
  # GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi<|end_of_turn|>GPT4 Correct User: How are you today?<|end_of_turn|>GPT4 Correct Assistant:
59
 
60
+ # Nous Mixtral:
61
+ # <|im_start|>system
62
+ # You are "Hermes 2".<|im_end|>
63
+ # <|im_start|>user
64
+ # Hello, who are you?<|im_end|>
65
+ # <|im_start|>assistant
66
+
67
+ # Google Gemma-it
68
+ # <start_of_turn>user
69
+ # How does the brain work?<end_of_turn>
70
+ # <start_of_turn>model
71
+
72
+ self.messages = messages
73
  self.merged_str = ""
74
 
75
+ # https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1#instruction-format
76
  if self.model in ["mixtral-8x7b", "mistral-7b"]:
77
+ self.messages = self.concat_messages_by_role(messages)
78
  self.cached_str = ""
79
  for message in self.messages:
80
  role = message["role"]
 
88
  self.cached_str = f"[INST] {content} [/INST]"
89
  if self.cached_str:
90
  self.merged_str += f"{self.cached_str}"
91
+ # https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO#prompt-format
92
+ elif self.model in ["nous-mixtral-8x7b"]:
93
+ self.merged_str_list = []
94
+ for message in self.messages:
95
+ role = message["role"]
96
+ content = message["content"]
97
+ if role not in ["system", "user", "assistant"]:
98
+ role = self.default_role
99
+ message_line = f"<|im_start|>{role}\n{content}<|im_end|>"
100
+ self.merged_str_list.append(message_line)
101
+ self.merged_str_list.append("<|im_start|>assistant")
102
+ self.merged_str = "\n".join(self.merged_str_list)
103
+ # https://huggingface.co/openchat/openchat-3.5-0106
104
  elif self.model in ["openchat-3.5"]:
105
+ self.messages = self.concat_messages_by_role(messages)
106
  self.merged_str_list = []
107
  self.end_of_turn = "<|end_of_turn|>"
108
  for message in self.messages:
 
122
  )
123
  self.merged_str_list.append(f"GPT4 Correct Assistant:\n")
124
  self.merged_str = "\n".join(self.merged_str_list)
125
+ # https://huggingface.co/google/gemma-7b-it#chat-template
126
+ elif self.model in ["gemma-7b"]:
127
+ self.messages = self.concat_messages_by_role(messages)
128
+ self.merged_str_list = []
129
+ self.end_of_turn = "<end_of_turn>"
130
+ self.start_of_turn = "<start_of_turn>"
131
+ for message in self.messages:
132
+ role = message["role"]
133
+ content = message["content"]
134
+ if role in self.inst_roles:
135
+ self.merged_str_list.append(
136
+ f"{self.start_of_turn}user\n{content}{self.end_of_turn}"
137
+ )
138
+ elif role in self.answer_roles:
139
+ self.merged_str_list.append(
140
+ f"{self.start_of_turn}model\n{content}{self.end_of_turn}"
141
+ )
142
+ else:
143
+ self.merged_str_list.append(
144
+ f"{self.start_of_turn}user\n{content}{self.end_of_turn}"
145
+ )
146
+ self.merged_str_list.append(f"{self.start_of_turn}model\n")
147
+ self.merged_str = "\n".join(self.merged_str_list)
148
  else:
149
  self.merged_str = "\n".join(
150
  [
 
207
  self.append_last_instruction_to_messages(
208
  inst_matches_list, pair_matches_list
209
  )
210
+ elif self.model in ["nous-mixtral-8x7b"]:
211
+ # https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO#prompt-format
212
+ # message_pattern = r"<\|im_start\|>(?P<role>system|user|assistant)[\s\n]*(?P<content>[\s\S]*?)<\|im_end\|>"
213
+ message_pattern = r"<\|im_start\|>(?P<role>system|user|assistant)[\s\n]*(?P<content>[\s\S]*?)<\|im_end\|>"
214
+ message_matches = re.finditer(
215
+ message_pattern, self.merged_str, flags=re.MULTILINE | re.IGNORECASE
216
+ )
217
+ message_matches_list = list(message_matches)
218
+ logger.note(f"message_matches_list: {message_matches_list}")
219
+ for match in message_matches_list:
220
+ role = match.group("role")
221
+ content = match.group("content")
222
+ self.messages.append({"role": role, "content": content.strip()})
223
  elif self.model in ["openchat-3.5"]:
224
  pair_pattern = r"GPT4 Correct User:(?P<inst>[\s\S]*?)<\|end_of_turn\|>\s*GPT4 Correct Assistant:(?P<answer>[\s\S]*?)<\|end_of_turn\|>"
 
225
  pair_matches = re.finditer(
226
  pair_pattern, self.merged_str, flags=re.MULTILINE | re.IGNORECASE
227
  )
 
235
  self.append_last_instruction_to_messages(
236
  inst_matches_list, pair_matches_list
237
  )
238
+ # https://huggingface.co/google/gemma-7b-it#chat-template
239
+ elif self.model in ["gemma-7b"]:
240
+ pair_pattern = r"<start_of_turn>user[\s\n]*(?P<inst>[\s\S]*?)<end_of_turn>[\s\n]*<start_of_turn>model(?P<answer>[\s\S]*?)<end_of_turn>"
241
+ pair_matches = re.finditer(
242
+ pair_pattern, self.merged_str, flags=re.MULTILINE | re.IGNORECASE
243
+ )
244
+ pair_matches_list = list(pair_matches)
245
+ self.messages = self.convert_pair_matches_to_messages(pair_matches_list)
246
+ inst_pattern = r"<start_of_turn>user\n(?P<inst>[\s\S]*?)<end_of_turn>"
247
+ inst_matches = re.finditer(
248
+ inst_pattern, self.merged_str, flags=re.MULTILINE | re.IGNORECASE
249
+ )
250
+ inst_matches_list = list(inst_matches)
251
+ self.append_last_instruction_to_messages(
252
+ inst_matches_list, pair_matches_list
253
+ )
254
  else:
255
  self.messages = [
256
  {
 
263
 
264
 
265
  if __name__ == "__main__":
266
+ # model = "mixtral-8x7b"
267
+ # model = "nous-mixtral-8x7b"
268
+ model = "gemma-7b"
269
+ composer = MessageComposer(model)
270
  messages = [
271
  {
272
  "role": "system",
273
+ "content": "You are a LLM developed by OpenAI.\nYour name is GPT-4.",
274
  },
275
  {"role": "user", "content": "Hello, who are you?"},
276
  {"role": "assistant", "content": "I am a bot."},
 
283
  # "content": "How many questions have I asked? Please list them.",
284
  # },
285
  ]
286
+ logger.note(f"model: {composer.model}")
287
  merged_str = composer.merge(messages)
288
+ logger.note("merged_str:")
289
+ logger.mesg(merged_str)
290
+ logger.note("splitted messages:")
291
  pprint(composer.split(merged_str))
292
+ # logger.note("merged merged_str:")
293
+ # logger.mesg(composer.merge(composer.split(merged_str)))