Feat: Add sharegpt multirole (#1137)
Browse files
* feat(prompt): support multiple roles for sharegpt
* fix: add handling of empty role back
* feat: rebased and allowed more dynamic roles via config
* fix: variable
* chore: update message
* feat: add vicuna format
* fix: JSON serializable error
* fix: typing
* fix: don't remap for unknown keys
* fix: add roles to pydantic
* feat: add test
* chore: remove leftover print
* chore: remove leftover comment
* chore: remove print
* fix: update test to use chatml
README.md
CHANGED
@@ -651,9 +651,13 @@ datasets:
|
|
651 |
train_on_split: train # Optional[str] name of dataset split to load from
|
652 |
|
653 |
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
654 |
-
conversation:
|
655 |
field_human: # Optional[str]. Human key to use for conversation.
|
656 |
field_model: # Optional[str]. Assistant key to use for conversation.
|
|
|
|
|
|
|
|
|
657 |
|
658 |
# Custom user instruction prompt
|
659 |
- path: repo
|
|
|
651 |
train_on_split: train # Optional[str] name of dataset split to load from
|
652 |
|
653 |
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
654 |
+
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
655 |
field_human: # Optional[str]. Human key to use for conversation.
|
656 |
field_model: # Optional[str]. Assistant key to use for conversation.
|
657 |
+
# Add additional keys from your dataset as input or output roles
|
658 |
+
roles:
|
659 |
+
input: # Optional[List[str]]. These will be masked based on train_on_inputs
|
660 |
+
output: # Optional[List[str]].
|
661 |
|
662 |
# Custom user instruction prompt
|
663 |
- path: repo
|
src/axolotl/prompt_strategies/sharegpt.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
2 |
|
|
|
3 |
from typing import Any, Dict, Optional
|
4 |
|
5 |
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
@@ -11,6 +12,8 @@ from axolotl.utils.tokenization import (
|
|
11 |
merge_consecutive_messages,
|
12 |
)
|
13 |
|
|
|
|
|
14 |
|
15 |
def register_chatml_template(system_message=None):
|
16 |
system_message = system_message or "You are a helpful assistant."
|
@@ -42,11 +45,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|
42 |
)
|
43 |
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
44 |
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
|
|
45 |
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
46 |
ShareGPTPrompterV2(
|
47 |
conversation=conversation,
|
48 |
role_key_model=field_model,
|
49 |
role_key_human=field_human,
|
|
|
50 |
),
|
51 |
tokenizer,
|
52 |
cfg.train_on_inputs,
|
@@ -142,7 +147,12 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
|
142 |
"system": "system",
|
143 |
}
|
144 |
turns = [
|
145 |
-
{
|
|
|
|
|
|
|
|
|
|
|
146 |
for t in conversations
|
147 |
]
|
148 |
return turns
|
|
|
1 |
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
2 |
|
3 |
+
import logging
|
4 |
from typing import Any, Dict, Optional
|
5 |
|
6 |
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
|
|
12 |
merge_consecutive_messages,
|
13 |
)
|
14 |
|
15 |
+
LOG = logging.getLogger("axolotl")
|
16 |
+
|
17 |
|
18 |
def register_chatml_template(system_message=None):
|
19 |
system_message = system_message or "You are a helpful assistant."
|
|
|
45 |
)
|
46 |
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
47 |
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
48 |
+
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
|
49 |
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
50 |
ShareGPTPrompterV2(
|
51 |
conversation=conversation,
|
52 |
role_key_model=field_model,
|
53 |
role_key_human=field_human,
|
54 |
+
roles=roles,
|
55 |
),
|
56 |
tokenizer,
|
57 |
cfg.train_on_inputs,
|
|
|
147 |
"system": "system",
|
148 |
}
|
149 |
turns = [
|
150 |
+
{
|
151 |
+
"from": (
|
152 |
+
role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
|
153 |
+
),
|
154 |
+
"value": t[value_key],
|
155 |
+
}
|
156 |
for t in conversations
|
157 |
]
|
158 |
return turns
|
src/axolotl/prompt_tokenizers.py
CHANGED
@@ -11,7 +11,7 @@ from transformers import BatchEncoding, PreTrainedTokenizer
|
|
11 |
from axolotl.monkeypatch.fastchat_conversation_turns import (
|
12 |
add_get_turns_to_conversation,
|
13 |
)
|
14 |
-
from axolotl.prompters import IGNORE_TOKEN_ID
|
15 |
|
16 |
LOG = logging.getLogger("axolotl")
|
17 |
|
@@ -37,7 +37,7 @@ class PromptTokenizingStrategy(abc.ABC):
|
|
37 |
|
38 |
def __init__(
|
39 |
self,
|
40 |
-
prompter,
|
41 |
tokenizer,
|
42 |
train_on_inputs: bool = False,
|
43 |
sequence_len: int = 2048,
|
@@ -340,6 +340,23 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|
340 |
self.prompter._conversation.copy() # pylint: disable=protected-access
|
341 |
)
|
342 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
343 |
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
|
344 |
role_remap = []
|
345 |
if (
|
@@ -360,19 +377,18 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|
360 |
LOG.warning(f"expected tuple, got {part}")
|
361 |
continue
|
362 |
|
363 |
-
tool_role_label = None
|
364 |
-
if len(conversation.roles) == 3:
|
365 |
-
(
|
366 |
-
user_role_label,
|
367 |
-
assistant_role_label,
|
368 |
-
tool_role_label,
|
369 |
-
) = conversation.roles
|
370 |
-
else:
|
371 |
-
user_role_label, assistant_role_label = conversation.roles
|
372 |
role, content = part
|
373 |
|
374 |
# Uses "in" because role contains extra characters
|
375 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
role = (
|
377 |
role.replace(role_remap[0]["from"], role_remap[0]["to"])
|
378 |
if role_remap
|
@@ -392,7 +408,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|
392 |
else:
|
393 |
# everything from this is masked out from the labels
|
394 |
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
395 |
-
elif
|
396 |
role = (
|
397 |
role.replace(role_remap[1]["from"], role_remap[1]["to"])
|
398 |
if role_remap
|
@@ -423,7 +439,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|
423 |
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
|
424 |
len_role, len(labels)
|
425 |
)
|
426 |
-
elif
|
427 |
turn = content
|
428 |
# this is only ever the first part, should include the bos token and the user query
|
429 |
res = self._tokenize(
|
@@ -434,11 +450,6 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|
434 |
else:
|
435 |
# everything from this is masked out from the labels
|
436 |
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
437 |
-
elif tool_role_label and tool_role_label in role:
|
438 |
-
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
439 |
-
else:
|
440 |
-
LOG.warning(f"unhandled role: {role}")
|
441 |
-
continue
|
442 |
|
443 |
# pylint: disable=duplicate-code
|
444 |
result, current_len = parse_tokenized_to_result(
|
|
|
11 |
from axolotl.monkeypatch.fastchat_conversation_turns import (
|
12 |
add_get_turns_to_conversation,
|
13 |
)
|
14 |
+
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
15 |
|
16 |
LOG = logging.getLogger("axolotl")
|
17 |
|
|
|
37 |
|
38 |
def __init__(
|
39 |
self,
|
40 |
+
prompter: Prompter,
|
41 |
tokenizer,
|
42 |
train_on_inputs: bool = False,
|
43 |
sequence_len: int = 2048,
|
|
|
340 |
self.prompter._conversation.copy() # pylint: disable=protected-access
|
341 |
)
|
342 |
|
343 |
+
input_roles = {conversation.roles[0]}
|
344 |
+
output_roles = {conversation.roles[1]}
|
345 |
+
|
346 |
+
if len(conversation.roles) == 3:
|
347 |
+
tool_role_label = conversation.roles[2]
|
348 |
+
input_roles.add(tool_role_label)
|
349 |
+
|
350 |
+
# Add roles from the config
|
351 |
+
if self.prompter.roles:
|
352 |
+
if "input" in self.prompter.roles and self.prompter.roles["input"]:
|
353 |
+
for role in self.prompter.roles["input"]:
|
354 |
+
input_roles.add(role)
|
355 |
+
|
356 |
+
if "output" in self.prompter.roles and self.prompter.roles["output"]:
|
357 |
+
for role in self.prompter.roles["output"]:
|
358 |
+
output_roles.add(role)
|
359 |
+
|
360 |
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
|
361 |
role_remap = []
|
362 |
if (
|
|
|
377 |
LOG.warning(f"expected tuple, got {part}")
|
378 |
continue
|
379 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
380 |
role, content = part
|
381 |
|
382 |
# Uses "in" because role contains extra characters
|
383 |
+
input_turn = any(r.lower() in role.lower() for r in input_roles)
|
384 |
+
output_turn = any(r.lower() in role.lower() for r in output_roles)
|
385 |
+
empty_role = role.strip() == ""
|
386 |
+
|
387 |
+
if not any([input_turn, output_turn, empty_role]):
|
388 |
+
LOG.warning(f"unhandled role: {role}")
|
389 |
+
continue
|
390 |
+
|
391 |
+
if input_turn:
|
392 |
role = (
|
393 |
role.replace(role_remap[0]["from"], role_remap[0]["to"])
|
394 |
if role_remap
|
|
|
408 |
else:
|
409 |
# everything from this is masked out from the labels
|
410 |
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
411 |
+
elif output_turn:
|
412 |
role = (
|
413 |
role.replace(role_remap[1]["from"], role_remap[1]["to"])
|
414 |
if role_remap
|
|
|
439 |
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
|
440 |
len_role, len(labels)
|
441 |
)
|
442 |
+
elif empty_role:
|
443 |
turn = content
|
444 |
# this is only ever the first part, should include the bos token and the user query
|
445 |
res = self._tokenize(
|
|
|
450 |
else:
|
451 |
# everything from this is masked out from the labels
|
452 |
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
|
|
|
|
|
|
|
|
|
|
453 |
|
454 |
# pylint: disable=duplicate-code
|
455 |
result, current_len = parse_tokenized_to_result(
|
src/axolotl/prompters.py
CHANGED
@@ -259,6 +259,12 @@ SHAREGPT_ASSERTION_FAILED_ROLE = (
|
|
259 |
"Role did not alternate between turns (gpt and human). Please check your data."
|
260 |
)
|
261 |
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
|
263 |
class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
264 |
"""
|
@@ -268,7 +274,9 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
|
268 |
role_key_human = "human"
|
269 |
role_key_model = "gpt"
|
270 |
# Optional, only used for tool usage datasets.
|
271 |
-
role_key_tool = None
|
|
|
|
|
272 |
|
273 |
def __init__(
|
274 |
self,
|
@@ -277,6 +285,7 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
|
277 |
role_key_human: Optional[str] = None,
|
278 |
role_key_model: Optional[str] = None,
|
279 |
role_key_tool: Optional[str] = None,
|
|
|
280 |
):
|
281 |
if conversation:
|
282 |
if isinstance(conversation, Conversation):
|
@@ -291,6 +300,8 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
|
291 |
self.role_key_model = role_key_model
|
292 |
if role_key_tool:
|
293 |
self.role_key_tool = role_key_tool
|
|
|
|
|
294 |
|
295 |
def _build_result(self, source):
|
296 |
if len(source) < 2:
|
@@ -322,11 +333,23 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
|
322 |
|
323 |
conv.messages = []
|
324 |
for _, sentence in enumerate(source):
|
325 |
-
|
326 |
-
if
|
327 |
-
|
328 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
329 |
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
|
|
330 |
conv.append_message(role, sentence["value"])
|
331 |
|
332 |
return conv.get_turns()
|
@@ -354,11 +377,13 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
|
|
354 |
conversation: Optional[Union[str, Conversation]] = None,
|
355 |
role_key_human: Optional[str] = None,
|
356 |
role_key_model: Optional[str] = None,
|
|
|
357 |
):
|
358 |
super().__init__(
|
359 |
conversation=conversation,
|
360 |
role_key_human=role_key_human,
|
361 |
role_key_model=role_key_model,
|
|
|
362 |
)
|
363 |
|
364 |
|
|
|
259 |
"Role did not alternate between turns (gpt and human). Please check your data."
|
260 |
)
|
261 |
|
262 |
+
CONVERSATION_ROLE_FORMAT = {
|
263 |
+
"chatml": "<|im_start|>{ROLE}",
|
264 |
+
"zephyr": "<|{ROLE}|>",
|
265 |
+
"vicuna_v1.1": "{ROLE}",
|
266 |
+
}
|
267 |
+
|
268 |
|
269 |
class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
270 |
"""
|
|
|
274 |
role_key_human = "human"
|
275 |
role_key_model = "gpt"
|
276 |
# Optional, only used for tool usage datasets.
|
277 |
+
role_key_tool: Optional[str] = None
|
278 |
+
# Optional, role input/output mapping
|
279 |
+
roles: Optional[dict] = None
|
280 |
|
281 |
def __init__(
|
282 |
self,
|
|
|
285 |
role_key_human: Optional[str] = None,
|
286 |
role_key_model: Optional[str] = None,
|
287 |
role_key_tool: Optional[str] = None,
|
288 |
+
roles: Optional[dict] = None,
|
289 |
):
|
290 |
if conversation:
|
291 |
if isinstance(conversation, Conversation):
|
|
|
300 |
self.role_key_model = role_key_model
|
301 |
if role_key_tool:
|
302 |
self.role_key_tool = role_key_tool
|
303 |
+
if roles:
|
304 |
+
self.roles = roles
|
305 |
|
306 |
def _build_result(self, source):
|
307 |
if len(source) < 2:
|
|
|
333 |
|
334 |
conv.messages = []
|
335 |
for _, sentence in enumerate(source):
|
336 |
+
from_role = sentence["from"]
|
337 |
+
if from_role in roles:
|
338 |
+
role = roles[from_role]
|
339 |
+
else:
|
340 |
+
if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
|
341 |
+
raise NotImplementedError(
|
342 |
+
f"Role ({role}) not in default roles, and {self._conversation.name} does not support role remapping yet."
|
343 |
+
"Please help us by creating an Issue to add support for this conversation type."
|
344 |
+
)
|
345 |
+
|
346 |
+
role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
|
347 |
+
ROLE=from_role
|
348 |
+
)
|
349 |
+
|
350 |
+
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
|
351 |
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
352 |
+
|
353 |
conv.append_message(role, sentence["value"])
|
354 |
|
355 |
return conv.get_turns()
|
|
|
377 |
conversation: Optional[Union[str, Conversation]] = None,
|
378 |
role_key_human: Optional[str] = None,
|
379 |
role_key_model: Optional[str] = None,
|
380 |
+
roles: Optional[dict] = None,
|
381 |
):
|
382 |
super().__init__(
|
383 |
conversation=conversation,
|
384 |
role_key_human=role_key_human,
|
385 |
role_key_model=role_key_model,
|
386 |
+
roles=roles,
|
387 |
)
|
388 |
|
389 |
|
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
@@ -96,6 +96,8 @@ class SFTDataset(BaseModel):
|
|
96 |
field_human: Optional[str] = None
|
97 |
field_model: Optional[str] = None
|
98 |
|
|
|
|
|
99 |
|
100 |
class UserDefinedDPOType(BaseModel):
|
101 |
"""User defined typing for DPO"""
|
|
|
96 |
field_human: Optional[str] = None
|
97 |
field_model: Optional[str] = None
|
98 |
|
99 |
+
roles: Optional[Dict[str, List[str]]] = None
|
100 |
+
|
101 |
|
102 |
class UserDefinedDPOType(BaseModel):
|
103 |
"""User defined typing for DPO"""
|
tests/prompt_strategies/test_sharegpt.py
CHANGED
@@ -62,6 +62,38 @@ def fixture_sharegpt_glaive_dataset():
|
|
62 |
)
|
63 |
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
@pytest.fixture(name="tokenizer")
|
66 |
def fixture_tokenizer():
|
67 |
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
@@ -196,3 +228,39 @@ class TestSharegpt:
|
|
196 |
32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt
|
197 |
]
|
198 |
# fmt: on
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
)
|
63 |
|
64 |
|
65 |
+
@pytest.fixture(name="multi_role_dataset")
|
66 |
+
def fixture_multi_role_dataset():
|
67 |
+
return Dataset.from_list(
|
68 |
+
[
|
69 |
+
{
|
70 |
+
"conversations": [
|
71 |
+
{
|
72 |
+
"from": "system",
|
73 |
+
"value": "use get_weather(city) to get the weather for a city",
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"from": "human",
|
77 |
+
"value": "hello, what's the weather in New York?",
|
78 |
+
},
|
79 |
+
{
|
80 |
+
"from": "gpt",
|
81 |
+
"value": "let me get that for you",
|
82 |
+
},
|
83 |
+
{
|
84 |
+
"from": "tool",
|
85 |
+
"value": "get_weather(New York)",
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"from": "gpt",
|
89 |
+
"value": "the weather in New York is 70 degrees and sunny",
|
90 |
+
},
|
91 |
+
]
|
92 |
+
}
|
93 |
+
]
|
94 |
+
)
|
95 |
+
|
96 |
+
|
97 |
@pytest.fixture(name="tokenizer")
|
98 |
def fixture_tokenizer():
|
99 |
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
|
|
228 |
32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt
|
229 |
]
|
230 |
# fmt: on
|
231 |
+
|
232 |
+
def test_multi_role_dataset(self, multi_role_dataset, tokenizer):
|
233 |
+
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
234 |
+
ShareGPTPrompterV2(conversation="chatml", roles={"input": ["tool"]}),
|
235 |
+
tokenizer,
|
236 |
+
False, # train_on_inputs
|
237 |
+
2048, # sequence_len
|
238 |
+
)
|
239 |
+
|
240 |
+
dataset_wrapper = TokenizedPromptDataset(
|
241 |
+
strategy, multi_role_dataset, process_count=1
|
242 |
+
)
|
243 |
+
|
244 |
+
input_ids = dataset_wrapper[0]["input_ids"]
|
245 |
+
# fmt: off
|
246 |
+
assert input_ids == [
|
247 |
+
1, # bos
|
248 |
+
32001, 1587, 13, 1730, 625, 28730, 769, 1223, 28732, 18373, 28731, 298, 625, 272, 8086, 354, 264, 2990, 32000, 28705, 13, # system
|
249 |
+
32001, 2188, 13, 21558, 28725, 767, 28742, 28713, 272, 8086, 297, 1450, 2726, 28804, 32000, 28705, 13, # human
|
250 |
+
32001, 13892, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13, # gpt
|
251 |
+
32001, 3921, 13, 527, 28730, 769, 1223, 28732, 2972, 2726, 28731, 32000, 28705, 13, # tool
|
252 |
+
32001, 13892, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13 # gpt
|
253 |
+
]
|
254 |
+
# fmt: on
|
255 |
+
|
256 |
+
labels = dataset_wrapper[0]["labels"]
|
257 |
+
# fmt: off
|
258 |
+
assert labels == [
|
259 |
+
-100, # bos
|
260 |
+
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # system
|
261 |
+
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # human
|
262 |
+
-100, -100, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13, # gpt
|
263 |
+
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool
|
264 |
+
-100, -100, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13 # gpt
|
265 |
+
]
|
266 |
+
# fmt: on
|