Husnain commited on
Commit
7f1e89a
1 Parent(s): 402861a

HF-LLM-API

Browse files
Files changed (1) hide show
  1. messagers/message_composer.py +48 -2
messagers/message_composer.py CHANGED
@@ -10,6 +10,7 @@ class MessageComposer:
10
  "mistral-7b",
11
  "openchat-3.5",
12
  "nous-mixtral-8x7b",
 
13
  ]
14
 
15
  def __init__(self, model: str = None):
@@ -19,7 +20,7 @@ class MessageComposer:
19
  self.model = "mixtral-8x7b"
20
  self.system_roles = ["system"]
21
  self.inst_roles = ["user", "system", "inst"]
22
- self.answer_roles = ["assistant", "bot", "answer"]
23
  self.default_role = "user"
24
 
25
  def concat_messages_by_role(self, messages):
@@ -63,6 +64,11 @@ class MessageComposer:
63
  # Hello, who are you?<|im_end|>
64
  # <|im_start|>assistant
65
 
 
 
 
 
 
66
  self.messages = messages
67
  self.merged_str = ""
68
 
@@ -116,6 +122,29 @@ class MessageComposer:
116
  )
117
  self.merged_str_list.append(f"GPT4 Correct Assistant:\n")
118
  self.merged_str = "\n".join(self.merged_str_list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  else:
120
  self.merged_str = "\n".join(
121
  [
@@ -206,6 +235,22 @@ class MessageComposer:
206
  self.append_last_instruction_to_messages(
207
  inst_matches_list, pair_matches_list
208
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  else:
210
  self.messages = [
211
  {
@@ -218,8 +263,9 @@ class MessageComposer:
218
 
219
 
220
  if __name__ == "__main__":
221
- model = "mixtral-8x7b"
222
  # model = "nous-mixtral-8x7b"
 
223
  composer = MessageComposer(model)
224
  messages = [
225
  {
 
10
  "mistral-7b",
11
  "openchat-3.5",
12
  "nous-mixtral-8x7b",
13
+ "gemma-7b",
14
  ]
15
 
16
  def __init__(self, model: str = None):
 
20
  self.model = "mixtral-8x7b"
21
  self.system_roles = ["system"]
22
  self.inst_roles = ["user", "system", "inst"]
23
+ self.answer_roles = ["assistant", "bot", "answer", "model"]
24
  self.default_role = "user"
25
 
26
  def concat_messages_by_role(self, messages):
 
64
  # Hello, who are you?<|im_end|>
65
  # <|im_start|>assistant
66
 
67
+ # Google Gemma-it
68
+ # <start_of_turn>user
69
+ # How does the brain work?<end_of_turn>
70
+ # <start_of_turn>model
71
+
72
  self.messages = messages
73
  self.merged_str = ""
74
 
 
122
  )
123
  self.merged_str_list.append(f"GPT4 Correct Assistant:\n")
124
  self.merged_str = "\n".join(self.merged_str_list)
125
+ # https://huggingface.co/google/gemma-7b-it#chat-template
126
+ elif self.model in ["gemma-7b"]:
127
+ self.messages = self.concat_messages_by_role(messages)
128
+ self.merged_str_list = []
129
+ self.end_of_turn = "<end_of_turn>"
130
+ self.start_of_turn = "<start_of_turn>"
131
+ for message in self.messages:
132
+ role = message["role"]
133
+ content = message["content"]
134
+ if role in self.inst_roles:
135
+ self.merged_str_list.append(
136
+ f"{self.start_of_turn}user\n{content}{self.end_of_turn}"
137
+ )
138
+ elif role in self.answer_roles:
139
+ self.merged_str_list.append(
140
+ f"{self.start_of_turn}model\n{content}{self.end_of_turn}"
141
+ )
142
+ else:
143
+ self.merged_str_list.append(
144
+ f"{self.start_of_turn}user\n{content}{self.end_of_turn}"
145
+ )
146
+ self.merged_str_list.append(f"{self.start_of_turn}model\n")
147
+ self.merged_str = "\n".join(self.merged_str_list)
148
  else:
149
  self.merged_str = "\n".join(
150
  [
 
235
  self.append_last_instruction_to_messages(
236
  inst_matches_list, pair_matches_list
237
  )
238
+ # https://huggingface.co/google/gemma-7b-it#chat-template
239
+ elif self.model in ["gemma-7b"]:
240
+ pair_pattern = r"<start_of_turn>user[\s\n]*(?P<inst>[\s\S]*?)<end_of_turn>[\s\n]*<start_of_turn>model(?P<answer>[\s\S]*?)<end_of_turn>"
241
+ pair_matches = re.finditer(
242
+ pair_pattern, self.merged_str, flags=re.MULTILINE | re.IGNORECASE
243
+ )
244
+ pair_matches_list = list(pair_matches)
245
+ self.messages = self.convert_pair_matches_to_messages(pair_matches_list)
246
+ inst_pattern = r"<start_of_turn>user\n(?P<inst>[\s\S]*?)<end_of_turn>"
247
+ inst_matches = re.finditer(
248
+ inst_pattern, self.merged_str, flags=re.MULTILINE | re.IGNORECASE
249
+ )
250
+ inst_matches_list = list(inst_matches)
251
+ self.append_last_instruction_to_messages(
252
+ inst_matches_list, pair_matches_list
253
+ )
254
  else:
255
  self.messages = [
256
  {
 
263
 
264
 
265
  if __name__ == "__main__":
266
+ # model = "mixtral-8x7b"
267
  # model = "nous-mixtral-8x7b"
268
+ model = "gemma-7b"
269
  composer = MessageComposer(model)
270
  messages = [
271
  {