Nanobit commited on
Commit
e37d935
1 Parent(s): b521206

Fix(message): Improve error message for bad format (#365)

Browse files
src/axolotl/prompt_strategies/llama2_chat.py CHANGED
@@ -29,7 +29,7 @@ from dataclasses import dataclass, field
29
  from typing import Generator, List, Sequence
30
 
31
  from axolotl.prompt_tokenizers import PromptTokenizingStrategy
32
- from axolotl.prompters import IGNORE_TOKEN_ID
33
 
34
 
35
  @dataclass
@@ -190,7 +190,7 @@ class Llama2ChatPrompter: # pylint: disable=too-few-public-methods
190
  conv.messages = [] # pylint: disable=R0801
191
  for j, sentence in enumerate(source):
192
  role = roles[sentence["from"]]
193
- assert role == conv.roles[j % 2]
194
  if sentence["value"]:
195
  conv.append_message(role, sentence["value"])
196
  yield conv
 
29
  from typing import Generator, List, Sequence
30
 
31
  from axolotl.prompt_tokenizers import PromptTokenizingStrategy
32
+ from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
33
 
34
 
35
  @dataclass
 
190
  conv.messages = [] # pylint: disable=R0801
191
  for j, sentence in enumerate(source):
192
  role = roles[sentence["from"]]
193
+ assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
194
  if sentence["value"]:
195
  conv.append_message(role, sentence["value"])
196
  yield conv
src/axolotl/prompters.py CHANGED
@@ -260,6 +260,11 @@ class Conversation:
260
  self.messages.append([role, message])
261
 
262
 
 
 
 
 
 
263
  class ShareGPTPrompter: # pylint: disable=too-few-public-methods
264
  """
265
  A prompter that generates prompts for the ShareGPT
@@ -316,7 +321,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
316
  conv.messages = []
317
  for j, sentence in enumerate(source):
318
  role = roles[sentence["from"]]
319
- assert role == conv.roles[j % 2]
320
  conv.append_message(role, sentence["value"])
321
 
322
  for part in conv.get_prompt():
 
260
  self.messages.append([role, message])
261
 
262
 
263
+ SHAREGPT_ASSERTION_FAILED_ROLE = (
264
+ "Role did not alternate between turns (gpt and human). Please check your data."
265
+ )
266
+
267
+
268
  class ShareGPTPrompter: # pylint: disable=too-few-public-methods
269
  """
270
  A prompter that generates prompts for the ShareGPT
 
321
  conv.messages = []
322
  for j, sentence in enumerate(source):
323
  role = roles[sentence["from"]]
324
+ assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
325
  conv.append_message(role, sentence["value"])
326
 
327
  for part in conv.get_prompt():