Nanobit committed
Commit 40a88e8
1 Parent(s): 43bdc5d

Feat: Add sharegpt multirole (#1137)


* feat(prompt): support multiple roles for sharegpt
* fix: add handling of empty role back
* feat: rebased and allowed more dynamic roles via config
* fix: variable
* chore: update message
* feat: add vicuna format
* fix: JSON serializable error
* fix: typing
* fix: don't remap for unknown keys
* fix: add roles to pydantic
* feat: add test
* chore: remove leftover print
* chore: remove leftover comment
* chore: remove print
* fix: update test to use chatml

README.md CHANGED
@@ -651,9 +651,13 @@ datasets:
     train_on_split: train # Optional[str] name of dataset split to load from

     # Optional[str] fastchat conversation type, only used with type: sharegpt
     conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
     field_human: # Optional[str]. Human key to use for conversation.
     field_model: # Optional[str]. Assistant key to use for conversation.
+    # Add additional keys from your dataset as input or output roles
+    roles:
+      input: # Optional[List[str]]. These will be masked based on train_on_input
+      output: # Optional[List[str]].

   # Custom user instruction prompt
   - path: repo
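For context: the parsed form of this `roles` block is what `load()` in sharegpt.py (next diff) reads from `ds_cfg`, and what the new pydantic field below validates. A minimal sketch of a parsed dataset entry, with a hypothetical dataset path and a `tool` role as the extra input key (at runtime axolotl wraps config mappings in a dict-like object, which appears to be why `load()` calls `.to_dict()`):

```python
# Sketch of a parsed `datasets` entry (values are hypothetical).
ds_cfg = {
    "path": "my-org/tool-use-sharegpt",  # hypothetical dataset
    "type": "sharegpt",
    "conversation": "chatml",
    "roles": {
        "input": ["tool"],  # masked according to train_on_input
        "output": [],       # trained on, like the assistant role
    },
}
```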
src/axolotl/prompt_strategies/sharegpt.py CHANGED
@@ -1,5 +1,6 @@
 """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""

+import logging
 from typing import Any, Dict, Optional

 from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
@@ -11,6 +12,8 @@ from axolotl.utils.tokenization import (
     merge_consecutive_messages,
 )

+LOG = logging.getLogger("axolotl")
+

 def register_chatml_template(system_message=None):
     system_message = system_message or "You are a helpful assistant."
@@ -42,11 +45,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     )
     field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
     field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
+    roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
     strategy = SimpleShareGPTPromptTokenizingStrategy(
         ShareGPTPrompterV2(
             conversation=conversation,
             role_key_model=field_model,
             role_key_human=field_human,
+            roles=roles,
         ),
         tokenizer,
         cfg.train_on_inputs,
@@ -142,7 +147,12 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
             "system": "system",
         }
         turns = [
-            {"from": role_map[t[role_key]], "value": t[value_key]}
+            {
+                "from": (
+                    role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
+                ),
+                "value": t[value_key],
+            }
             for t in conversations
         ]
         return turns
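The `get_conversation_thread` change above makes unknown `from` values pass through unmapped instead of raising a `KeyError`, which is what lets extra roles like `tool` reach the prompter. A standalone sketch of that remap-or-passthrough step (the full `role_map` here is illustrative; only the `system` entry is visible in the hunk):

```python
# Remap known aliases, pass anything else through untouched.
role_map = {"human": "human", "gpt": "gpt", "system": "system"}  # illustrative

conversations = [
    {"from": "human", "value": "hi"},
    {"from": "tool", "value": "get_weather(New York)"},  # unknown key
]

turns = [
    {
        "from": role_map[t["from"]] if t["from"] in role_map else t["from"],
        "value": t["value"],
    }
    for t in conversations
]
assert [t["from"] for t in turns] == ["human", "tool"]
```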
src/axolotl/prompt_tokenizers.py CHANGED
@@ -11,7 +11,7 @@ from transformers import BatchEncoding, PreTrainedTokenizer
 from axolotl.monkeypatch.fastchat_conversation_turns import (
     add_get_turns_to_conversation,
 )
-from axolotl.prompters import IGNORE_TOKEN_ID
+from axolotl.prompters import IGNORE_TOKEN_ID, Prompter

 LOG = logging.getLogger("axolotl")

@@ -37,7 +37,7 @@ class PromptTokenizingStrategy(abc.ABC):

     def __init__(
         self,
-        prompter,
+        prompter: Prompter,
         tokenizer,
         train_on_inputs: bool = False,
         sequence_len: int = 2048,
@@ -340,6 +340,23 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
            self.prompter._conversation.copy()  # pylint: disable=protected-access
        )

+        input_roles = {conversation.roles[0]}
+        output_roles = {conversation.roles[1]}
+
+        if len(conversation.roles) == 3:
+            tool_role_label = conversation.roles[2]
+            input_roles.add(tool_role_label)
+
+        # Add roles from the config
+        if self.prompter.roles:
+            if "input" in self.prompter.roles and self.prompter.roles["input"]:
+                for role in self.prompter.roles["input"]:
+                    input_roles.add(role)
+
+            if "output" in self.prompter.roles and self.prompter.roles["output"]:
+                for role in self.prompter.roles["output"]:
+                    output_roles.add(role)
+
         # support for custom roles from the dataset, only useful for vicuna style prompts/roles
         role_remap = []
         if (
@@ -360,19 +377,18 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                 LOG.warning(f"expected tuple, got {part}")
                 continue

-            tool_role_label = None
-            if len(conversation.roles) == 3:
-                (
-                    user_role_label,
-                    assistant_role_label,
-                    tool_role_label,
-                ) = conversation.roles
-            else:
-                user_role_label, assistant_role_label = conversation.roles
             role, content = part

             # Uses "in" because role contains extra characters
-            if user_role_label in role:
+            input_turn = any(r.lower() in role.lower() for r in input_roles)
+            output_turn = any(r.lower() in role.lower() for r in output_roles)
+            empty_role = role.strip() == ""
+
+            if not any([input_turn, output_turn, empty_role]):
+                LOG.warning(f"unhandled role: {role}")
+                continue
+
+            if input_turn:
                 role = (
                     role.replace(role_remap[0]["from"], role_remap[0]["to"])
                     if role_remap
@@ -392,7 +408,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                 else:
                     # everything from this is masked out from the labels
                     labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
-            elif assistant_role_label in role:
+            elif output_turn:
                 role = (
                     role.replace(role_remap[1]["from"], role_remap[1]["to"])
                     if role_remap
@@ -423,7 +439,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                     labels[:len_role] = [IGNORE_TOKEN_ID] * min(
                         len_role, len(labels)
                     )
-            elif role == "":
+            elif empty_role:
                 turn = content
                 # this is only ever the first part, should include the bos token and the user query
                 res = self._tokenize(
@@ -434,11 +450,6 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                 else:
                     # everything from this is masked out from the labels
                     labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
-            elif tool_role_label and tool_role_label in role:
-                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
-            else:
-                LOG.warning(f"unhandled role: {role}")
-                continue

             # pylint: disable=duplicate-code
             result, current_len = parse_tokenized_to_result(
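The rewrite above computes the input/output role sets once per conversation (defaults from `conversation.roles`, extended by the config `roles`), then classifies each turn by case-insensitive substring match, since fastchat role labels carry template decoration around the bare role name. A self-contained sketch of that classification, assuming chatml-style labels:

```python
# Sketch of the set-based turn classification (labels are chatml-style examples).
conversation_roles = ("<|im_start|>user", "<|im_start|>assistant")
config_roles = {"input": ["tool"], "output": []}  # from the dataset config

input_roles = {conversation_roles[0], *(config_roles.get("input") or [])}
output_roles = {conversation_roles[1], *(config_roles.get("output") or [])}

def classify(role: str) -> str:
    # Uses "in" because the role string contains extra characters.
    if any(r.lower() in role.lower() for r in input_roles):
        return "input"      # masked depending on train_on_inputs
    if any(r.lower() in role.lower() for r in output_roles):
        return "output"     # contributes to the training labels
    if role.strip() == "":
        return "empty"      # first part: bos token plus the user query
    return "unhandled"      # logged with LOG.warning and skipped

assert classify("<|im_start|>user\n") == "input"
assert classify("<|im_start|>tool\n") == "input"       # added via config
assert classify("<|im_start|>assistant\n") == "output"
```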
src/axolotl/prompters.py CHANGED
@@ -259,6 +259,12 @@ SHAREGPT_ASSERTION_FAILED_ROLE = (
     "Role did not alternate between turns (gpt and human). Please check your data."
 )

+CONVERSATION_ROLE_FORMAT = {
+    "chatml": "<|im_start|>{ROLE}",
+    "zephyr": "<|{ROLE}|>",
+    "vicuna_v1.1": "{ROLE}",
+}
+

 class ShareGPTPrompter(Prompter):  # pylint: disable=too-few-public-methods
     """
@@ -268,7 +274,9 @@ class ShareGPTPrompter(Prompter):  # pylint: disable=too-few-public-methods
     role_key_human = "human"
     role_key_model = "gpt"
     # Optional, only used for tool usage datasets.
-    role_key_tool = None
+    role_key_tool: Optional[str] = None
+    # Optional, role input/output mapping
+    roles: Optional[dict] = None

     def __init__(
         self,
@@ -277,6 +285,7 @@ class ShareGPTPrompter(Prompter):  # pylint: disable=too-few-public-methods
         role_key_human: Optional[str] = None,
         role_key_model: Optional[str] = None,
         role_key_tool: Optional[str] = None,
+        roles: Optional[dict] = None,
     ):
         if conversation:
             if isinstance(conversation, Conversation):
@@ -291,6 +300,8 @@ class ShareGPTPrompter(Prompter):  # pylint: disable=too-few-public-methods
         self.role_key_model = role_key_model
         if role_key_tool:
             self.role_key_tool = role_key_tool
+        if roles:
+            self.roles = roles

     def _build_result(self, source):
         if len(source) < 2:
@@ -322,11 +333,23 @@ class ShareGPTPrompter(Prompter):  # pylint: disable=too-few-public-methods

         conv.messages = []
         for _, sentence in enumerate(source):
-            role = roles[sentence["from"]]
-            if len(conv.messages) > 0 and (
-                (role == conv.messages[-1][0]) or (role not in conv.roles)
-            ):
+            from_role = sentence["from"]
+            if from_role in roles:
+                role = roles[from_role]
+            else:
+                if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
+                    raise NotImplementedError(
+                        f"Role ({from_role}) not in default roles, and {self._conversation.name} does not support role remapping yet. "
+                        "Please help us by creating an Issue to add support for this conversation type."
+                    )
+
+                role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
+                    ROLE=from_role
+                )
+
+            if len(conv.messages) > 0 and (role == conv.messages[-1][0]):
                 LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
+
             conv.append_message(role, sentence["value"])

         return conv.get_turns()
@@ -354,11 +377,13 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
         conversation: Optional[Union[str, Conversation]] = None,
         role_key_human: Optional[str] = None,
         role_key_model: Optional[str] = None,
+        roles: Optional[dict] = None,
     ):
         super().__init__(
             conversation=conversation,
             role_key_human=role_key_human,
             role_key_model=role_key_model,
+            roles=roles,
         )
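In `_build_result`, a `from` value outside the prompter's known role keys now falls back to formatting the raw role with the conversation's role token (previously it raised a `KeyError` on `roles[sentence["from"]]`), and conversation types missing from `CONVERSATION_ROLE_FORMAT` fail loudly. A small sketch of the fallback path (the `known_roles` mapping is illustrative):

```python
# Sketch of the role-resolution fallback in _build_result.
CONVERSATION_ROLE_FORMAT = {
    "chatml": "<|im_start|>{ROLE}",
    "zephyr": "<|{ROLE}|>",
    "vicuna_v1.1": "{ROLE}",
}

def resolve_role(conversation_name: str, known_roles: dict, from_role: str) -> str:
    if from_role in known_roles:
        return known_roles[from_role]  # e.g. "human" -> the template's user label
    if conversation_name not in CONVERSATION_ROLE_FORMAT:
        raise NotImplementedError(
            f"Role ({from_role}) not in default roles, and {conversation_name} "
            "does not support role remapping yet."
        )
    return CONVERSATION_ROLE_FORMAT[conversation_name].format(ROLE=from_role)

# Illustrative chatml-style mapping for the two default roles.
known_roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
assert resolve_role("chatml", known_roles, "human") == "<|im_start|>user"
assert resolve_role("chatml", known_roles, "tool") == "<|im_start|>tool"
```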
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -96,6 +96,8 @@ class SFTDataset(BaseModel):
     field_human: Optional[str] = None
     field_model: Optional[str] = None

+    roles: Optional[Dict[str, List[str]]] = None
+

 class UserDefinedDPOType(BaseModel):
     """User defined typing for DPO"""
tests/prompt_strategies/test_sharegpt.py CHANGED
@@ -62,6 +62,38 @@ def fixture_sharegpt_glaive_dataset():
     )


+@pytest.fixture(name="multi_role_dataset")
+def fixture_multi_role_dataset():
+    return Dataset.from_list(
+        [
+            {
+                "conversations": [
+                    {
+                        "from": "system",
+                        "value": "use get_weather(city) to get the weather for a city",
+                    },
+                    {
+                        "from": "human",
+                        "value": "hello, what's the weather in New York?",
+                    },
+                    {
+                        "from": "gpt",
+                        "value": "let me get that for you",
+                    },
+                    {
+                        "from": "tool",
+                        "value": "get_weather(New York)",
+                    },
+                    {
+                        "from": "gpt",
+                        "value": "the weather in New York is 70 degrees and sunny",
+                    },
+                ]
+            }
+        ]
+    )
+
+
 @pytest.fixture(name="tokenizer")
 def fixture_tokenizer():
     tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
@@ -196,3 +228,39 @@ class TestSharegpt:
             32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13  # gpt
         ]
         # fmt: on
+
+    def test_multi_role_dataset(self, multi_role_dataset, tokenizer):
+        strategy = SimpleShareGPTPromptTokenizingStrategy(
+            ShareGPTPrompterV2(conversation="chatml", roles={"input": ["tool"]}),
+            tokenizer,
+            False,  # train_on_inputs
+            2048,  # sequence_len
+        )
+
+        dataset_wrapper = TokenizedPromptDataset(
+            strategy, multi_role_dataset, process_count=1
+        )
+
+        input_ids = dataset_wrapper[0]["input_ids"]
+        # fmt: off
+        assert input_ids == [
+            1,  # bos
+            32001, 1587, 13, 1730, 625, 28730, 769, 1223, 28732, 18373, 28731, 298, 625, 272, 8086, 354, 264, 2990, 32000, 28705, 13,  # system
+            32001, 2188, 13, 21558, 28725, 767, 28742, 28713, 272, 8086, 297, 1450, 2726, 28804, 32000, 28705, 13,  # human
+            32001, 13892, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13,  # gpt
+            32001, 3921, 13, 527, 28730, 769, 1223, 28732, 2972, 2726, 28731, 32000, 28705, 13,  # tool
+            32001, 13892, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13  # gpt
+        ]
+        # fmt: on
+
+        labels = dataset_wrapper[0]["labels"]
+        # fmt: off
+        assert labels == [
+            -100,  # bos
+            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # system
+            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # human
+            -100, -100, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13,  # gpt
+            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # tool
+            -100, -100, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13  # gpt
+        ]
+        # fmt: on