hamel committed on
Commit
7bbaac9
1 Parent(s): 161bcb6

fix mistral prompt assembly (#982)

Browse files

* fix mistral prompts

* fix spacing

* remove elif

src/axolotl/monkeypatch/fastchat_conversation_turns.py CHANGED
@@ -82,7 +82,7 @@ def get_turns( # pylint: disable=too-many-return-statements
82
  else:
83
  yield role + ":", ""
84
  return
85
- if self.sep_style == SeparatorStyle.LLAMA2:
86
  if self.system_message:
87
  if self.messages:
88
  # For llama, the system message is incorporated into the first human instruction
@@ -101,6 +101,28 @@ def get_turns( # pylint: disable=too-many-return-statements
101
  else:
102
  yield role, ""
103
  return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  if self.sep_style == SeparatorStyle.CHATGLM:
105
  # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
106
  # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
 
82
  else:
83
  yield role + ":", ""
84
  return
85
+ if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral":
86
  if self.system_message:
87
  if self.messages:
88
  # For llama, the system message is incorporated into the first human instruction
 
101
  else:
102
  yield role, ""
103
  return
104
+ if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral":
105
+ contains_sys_msg = False
106
+ if self.system_message:
107
+ contains_sys_msg = True
108
+ if self.messages:
109
+ # There is no clear guidance on how to handle system messages in Mistral, so we just prepend it to the first human instruction, separated by a newline
110
+ first_role, first_msg = self.messages[0]
111
+ if first_role == self.roles[0]:
112
+ system_prompt = self.system_template.format(
113
+ system_message=" " + self.system_message
114
+ )
115
+ system_prompt += first_msg
116
+ self.messages.pop(0)
117
+ yield "", system_prompt
118
+ for i, (role, message) in enumerate(self.messages):
119
+ if message and i == 0 and not contains_sys_msg:
120
+ yield "", system_prompt.strip() + " " + message # if there is no system message, we need to make sure there is a `<s> [INST]` at the beginning of the first instruction.
121
+ elif message:
122
+ yield role + " ", message
123
+ else:
124
+ yield role, ""
125
+ return
126
  if self.sep_style == SeparatorStyle.CHATGLM:
127
  # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
128
  # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
tests/test_prompt_tokenizers.py CHANGED
@@ -2,6 +2,7 @@
2
  import json
3
  import logging
4
  import unittest
 
5
  from pathlib import Path
6
  from typing import Optional
7
 
@@ -25,6 +26,50 @@ from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
25
 
26
  LOG = logging.getLogger("axolotl")
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  class TestPromptTokenizationStrategies(unittest.TestCase):
30
  """
@@ -116,74 +161,68 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
116
 
117
  def test_sharegpt_llama(self):
118
  "Make sure the sharegpt/llama is tokenized and formatted correctly."
119
- prompter = ShareGPTPrompterV2(conversation="llama-2")
120
- strat = ShareGPTPromptTokenizingStrategy(
121
- prompter,
122
- self.tokenizer,
123
- False,
124
- 2048,
125
- )
126
 
127
  def tokenize(conv):
128
- return strat.tokenize_prompt(conv)["input_ids"]
129
 
130
  def decode(ids):
131
  return strat.tokenizer.decode(ids)
132
 
133
- # Multi-turn conversations
134
- multi_turn_conv = {
135
- "conversations": [
136
- {"from": "system", "value": "lorem"},
137
- {"from": "human", "value": "abc"},
138
- {"from": "gpt", "value": "ipsum"},
139
- {"from": "human", "value": "123"},
140
- {"from": "gpt", "value": "sit"},
141
- ]
142
- }
143
  # fmt: off
144
- mt_ids = tokenize(multi_turn_conv)
 
145
  assert decode(mt_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
146
  assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
147
 
148
- # Single-turn conversations
149
- single_turn_conv = {
150
- "conversations": [
151
- {"from": "system", "value": "lorem"},
152
- {"from": "human", "value": "abc"},
153
- {"from": "gpt", "value": "ipsum"},
154
- ]
155
- }
156
-
157
- st_ids = tokenize(single_turn_conv)
158
  assert decode(st_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s>'
159
  assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
160
 
161
  # No system message, single-turn
162
- no_sys_conv = {
163
- "conversations": [
164
- {"from": "human", "value": "abc"},
165
- {"from": "gpt", "value": "ipsum"},
166
- ]
167
- }
168
-
169
- ns_ids = tokenize(no_sys_conv)
170
  assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
171
  assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
172
 
173
  # No system message, multi-turn
174
- no_sys_mt_conv = {
175
- "conversations": [
176
- {"from": "human", "value": "abc"},
177
- {"from": "gpt", "value": "ipsum"},
178
- {"from": "human", "value": "123"},
179
- {"from": "gpt", "value": "sit"},
180
- ]
181
- }
182
- ns_mt_ids = tokenize(no_sys_mt_conv)
183
  assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
184
  assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
185
  # fmt: on
186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  def test_sharegpt_changes_roles(self):
188
  conversation = {
189
  "roles": ["USER", "CHARACTER"],
 
2
  import json
3
  import logging
4
  import unittest
5
+ from copy import deepcopy
6
  from pathlib import Path
7
  from typing import Optional
8
 
 
26
 
27
  LOG = logging.getLogger("axolotl")
28
 
29
+ test_data = {
30
+ "multi_turn_sys": {
31
+ "conversations": [
32
+ {"from": "system", "value": "lorem"},
33
+ {"from": "human", "value": "abc"},
34
+ {"from": "gpt", "value": "ipsum"},
35
+ {"from": "human", "value": "123"},
36
+ {"from": "gpt", "value": "sit"},
37
+ ]
38
+ },
39
+ "single_turn_sys": {
40
+ "conversations": [
41
+ {"from": "system", "value": "lorem"},
42
+ {"from": "human", "value": "abc"},
43
+ {"from": "gpt", "value": "ipsum"},
44
+ ]
45
+ },
46
+ "single_turn_no_sys": {
47
+ "conversations": [
48
+ {"from": "human", "value": "abc"},
49
+ {"from": "gpt", "value": "ipsum"},
50
+ ]
51
+ },
52
+ "multi_turn_no_sys": {
53
+ "conversations": [
54
+ {"from": "human", "value": "abc"},
55
+ {"from": "gpt", "value": "ipsum"},
56
+ {"from": "human", "value": "123"},
57
+ {"from": "gpt", "value": "sit"},
58
+ ]
59
+ },
60
+ }
61
+
62
+
63
+ def prompt_strat(conversation, tokenizer):
64
+ "Helper function to create a prompt strategy for testing."
65
+ prompter = ShareGPTPrompterV2(conversation=conversation)
66
+ return ShareGPTPromptTokenizingStrategy(
67
+ prompter,
68
+ tokenizer,
69
+ False,
70
+ 2048,
71
+ )
72
+
73
 
74
  class TestPromptTokenizationStrategies(unittest.TestCase):
75
  """
 
161
 
162
  def test_sharegpt_llama(self):
163
  "Make sure the sharegpt/llama is tokenized and formatted correctly."
164
+ strat = prompt_strat("llama-2", self.tokenizer)
 
 
 
 
 
 
165
 
166
  def tokenize(conv):
167
+ return strat.tokenize_prompt(deepcopy(conv))["input_ids"]
168
 
169
  def decode(ids):
170
  return strat.tokenizer.decode(ids)
171
 
 
 
 
 
 
 
 
 
 
 
172
  # fmt: off
173
+ # System message, multi-turn conversations
174
+ mt_ids = tokenize(test_data['multi_turn_sys'])
175
  assert decode(mt_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
176
  assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
177
 
178
+ # System message, single-turn conversations
179
+ st_ids = tokenize(test_data['single_turn_sys'])
 
 
 
 
 
 
 
 
180
  assert decode(st_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s>'
181
  assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
182
 
183
  # No system message, single-turn
184
+ ns_ids = tokenize(test_data['single_turn_no_sys'])
 
 
 
 
 
 
 
185
  assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
186
  assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
187
 
188
  # No system message, multi-turn
189
+ ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
 
 
 
 
 
 
 
 
190
  assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
191
  assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
192
  # fmt: on
193
 
194
+ def test_sharegpt_mistral(self):
195
+ "Make sure the sharegpt/mistral is tokenized and formatted correctly."
196
+ strat = prompt_strat("mistral", self.tokenizer)
197
+
198
+ def tokenize(conv):
199
+ return strat.tokenize_prompt(deepcopy(conv))["input_ids"]
200
+
201
+ def decode(ids):
202
+ return strat.tokenizer.decode(ids)
203
+
204
+ # fmt: off
205
+ # System message, multi-turn conversations
206
+ mt_ids = tokenize(test_data['multi_turn_sys'])
207
+ assert decode(mt_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
208
+ assert mt_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
209
+
210
+ # System message, single-turn conversations
211
+ st_ids = tokenize(test_data['single_turn_sys'])
212
+ assert decode(st_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s>'
213
+ assert st_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
214
+
215
+ # No system message, single-turn
216
+ ns_ids = tokenize(test_data['single_turn_no_sys'])
217
+ assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
218
+ assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
219
+
220
+ # No system message, multi-turn
221
+ ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
222
+ assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
223
+ assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
224
+ # fmt: on
225
+
226
  def test_sharegpt_changes_roles(self):
227
  conversation = {
228
  "roles": ["USER", "CHARACTER"],