Fix Llama 3.1 Chat Template to Properly Handle add_generation_prompt
Problem
The current chat template for Llama 3.1 always appends an <|eot_id|> token to the last message, even when add_generation_prompt is set to False. Because <|eot_id|> closes the turn, the model cannot continue the final message (for example, a partially written assistant reply) when that is the desired behavior.
Solution
This PR modifies the chat template to conditionally add the <|eot_id|> token based on whether it's the last message and the value of add_generation_prompt.
Changes
Modified the handling of regular (non-tool) messages so that <|eot_id|> is added only if the message is not the last one or add_generation_prompt is true.
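To make the intended behavior concrete, here is a minimal pure-Python sketch of the patched rule for regular messages. It is a simplified re-implementation for illustration only, assuming no BOS token, system prompt, or tools; `header` and `render_messages` are hypothetical helpers, not part of transformers.

```python
EOT = "<|eot_id|>"


def header(role: str) -> str:
    # Llama 3.1 turn header: role name wrapped in header tokens, then a blank line
    return f"<|start_header_id|>{role}<|end_header_id|>\n\n"


def render_messages(messages: list[dict], add_generation_prompt: bool = False) -> str:
    """Hypothetical, simplified version of the patched regular-message branch."""
    out = []
    for i, message in enumerate(messages):
        is_last = i == len(messages) - 1
        out.append(header(message["role"]) + message["content"].strip())
        # Patched rule: close the turn unless this is the final message
        # and no generation prompt was requested.
        if not is_last or add_generation_prompt:
            out.append(EOT)
    if add_generation_prompt:
        # Open an empty assistant turn for the model to fill in
        out.append(header("assistant"))
    return "".join(out)


messages = [
    {"role": "user", "content": "tell me a common saying"},
    {"role": "assistant", "content": "Here is a common saying about apple. An apple a day, keeps"},
]
print(repr(render_messages(messages, add_generation_prompt=False)))
print(repr(render_messages(messages, add_generation_prompt=True)))
```

With add_generation_prompt=False the rendered text ends inside the assistant's partial reply, so generation can pick up where it left off; with True, both turns are closed and a fresh assistant header is opened.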
Test
Here is a test case showing the result:
```python
from transformers import AutoTokenizer

# Define the custom chat template (Llama 3.1 template with the patched <|eot_id|> logic).
# The string is split into adjacent literals, which Python concatenates into one.
custom_chat_template = (
    "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' "
    "}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' "
    "}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim }}\n {%- if not loop.last or add_generation_prompt %}\n {{- '<|eot_id|>' }}\n {%- endif %}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
)


def apply_custom_chat_template(messages, add_generation_prompt=False):
    # Requires access to the gated meta-llama repository on the Hugging Face Hub
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-70B-Instruct")
    # Apply the custom chat template
    chat_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=add_generation_prompt,
        chat_template=custom_chat_template,
    )
    return chat_text


def test_custom_chat_template():
    messages = [
        {"role": "user", "content": "tell me a common saying"},
        {"role": "assistant", "content": "Here is a common saying about apple. An apple a day, keeps"},
    ]

    # Test with add_generation_prompt=False
    result_false = apply_custom_chat_template(messages, add_generation_prompt=False)
    print("Result with add_generation_prompt=False:")
    print(result_false)
    print("\n" + "=" * 50 + "\n")

    # Test with add_generation_prompt=True
    result_true = apply_custom_chat_template(messages, add_generation_prompt=True)
    print("Result with add_generation_prompt=True:")
    print(result_true)

    # Check for the absence of <|eot_id|> at the end when add_generation_prompt is False
    if not result_false.strip().endswith("<|eot_id|>"):
        print("\nSUCCESS: <|eot_id|> is correctly absent at the end when add_generation_prompt is False.")
    else:
        print("\nERROR: <|eot_id|> is present at the end when add_generation_prompt is False.")

    # Check for an empty assistant turn when add_generation_prompt is True
    # (strip() removes the trailing "\n\n" after the assistant header)
    if result_true.strip().endswith("<|start_header_id|>assistant<|end_header_id|>"):
        print("SUCCESS: An empty assistant turn is correctly added when add_generation_prompt is True.")
    else:
        print("ERROR: No empty assistant turn is added when add_generation_prompt is True.")


if __name__ == "__main__":
    test_custom_chat_template()
```
I'm not entirely sure I follow. Are you expecting the model to continue generating within the last conversational turn? Could you clarify why it should be able to do that?
Yes, that is exactly the use case.
Here is the documentation for the chat template generation prompt: https://huggingface.co/docs/transformers/main/en/chat_templating#what-are-generation-prompts
There are plenty of open-source UIs (LibreChat, OpenWebUI, etc.) that work with OpenAI chat/completions-compatible endpoints, and many of them support advanced editing features. For example, you can pause generation, edit the assistant's response to fix any issues or ideas you didn't like, and then resume generation.
Here is an example of OpenWebUI:
The response is stopped, and I begin editing it:
The response has been edited in a new direction:
The edit was saved, and now we can continue the response (if the UI software, chat template, and inference server supported it):
You could also use this feature to simulate the behavior of the /completions endpoint.
This would allow you to do things like train a model to complete the user's request (for example, autocomplete for the text entry box).
Here are a couple of examples using curl, the modified chat template, and vLLM's OpenAI-compatible chat/completions endpoint:
Example 1:
echo: true
add_generation_prompt: false

Example 2:
echo: false
add_generation_prompt: false

Example 3:
echo: false (default if left off)
add_generation_prompt: true (default if left off)
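As a rough Python equivalent of those curl requests, the sketch below builds the JSON bodies for the three configurations using only the standard library. It assumes a vLLM OpenAI-compatible server at http://localhost:8000 (a hypothetical local address) started with the modified chat template, and that the server accepts the vLLM-specific `echo` and `add_generation_prompt` request fields; check the docs for your vLLM version. `build_payload` and `post_chat_completion` are hypothetical helpers.

```python
import json
import urllib.request

# Assumption: a local vLLM OpenAI-compatible server; adjust host/port as needed.
VLLM_URL = "http://localhost:8000/v1/chat/completions"
MODEL = "meta-llama/Meta-Llama-3.1-70B-Instruct"


def build_payload(messages, echo=False, add_generation_prompt=True):
    """Build a chat/completions request body.

    `echo` and `add_generation_prompt` are vLLM extra parameters (assumed to be
    supported by the server); the defaults mirror the ones noted above.
    """
    return {
        "model": MODEL,
        "messages": messages,
        "echo": echo,
        "add_generation_prompt": add_generation_prompt,
        "max_tokens": 64,  # arbitrary illustrative value
    }


def post_chat_completion(payload, url=VLLM_URL):
    """Send the request and return the decoded JSON response (needs a running server)."""
    request = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read().decode("utf-8"))


messages = [
    {"role": "user", "content": "tell me a common saying"},
    {"role": "assistant", "content": "Here is a common saying about apple. An apple a day, keeps"},
]

# The three configurations listed above.
configs = [
    {"echo": True, "add_generation_prompt": False},   # continue the edited reply; prepend it to the output
    {"echo": False, "add_generation_prompt": False},  # continue the edited reply; return new text only
    {"echo": False, "add_generation_prompt": True},   # defaults: start a fresh assistant turn
]
payloads = [build_payload(messages, **cfg) for cfg in configs]
print(json.dumps(payloads[0], indent=2))
# To actually send one: post_chat_completion(payloads[0])
```

Only the first two configurations rely on the template leaving the final assistant turn open; the third behaves the same with either template.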
Understood, I see the use case now.
It might be better to customize this within the application itself at inference time. Removing the final <|eot_id|> token, as you suggested, could disrupt completion-only fine-tuning.
@tanliboy I don't quite follow how it would cause issues with training.
Unless the training code passes add_generation_prompt=False to the tokenizer, there should be absolutely zero difference with the template. Or am I missing something I should be seeing?
Edit: I also wrote a bunch of training code to deal with chat templates for the axolotl framework: https://github.com/axolotl-ai-cloud/axolotl/pull/1756
So I have at least spent a good bit of time thinking about this area.
The add_generation_prompt option is set to False by default in the transformers library, and it's typically enabled only during inference. During fine-tuning, such as when using the DPOTrainer in trl or run_sft in the Alignment Handbook, the training samples are tokenized with the chat template applied. Without the closing <|eot_id|> token in those samples, the model may struggle to learn when to stop generating, potentially hurting its performance.
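A toy illustration of that concern: completion-only fine-tuning typically computes the loss only on the assistant's text. The sketch below uses hand-written renderings of a two-turn chat (toy data, not produced by a real tokenizer) and a hypothetical helper, `assistant_target`, that extracts the text after the last assistant header. Only the original template's output gives a target ending in <|eot_id|>.

```python
EOT = "<|eot_id|>"
ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>\n\n"
USER_TURN = "<|start_header_id|>user<|end_header_id|>\n\ntell me a common saying" + EOT
REPLY = "An apple a day keeps the doctor away."  # toy assistant reply

# Hand-written renderings of the same chat (BOS and system prompt omitted).
original_render = USER_TURN + ASSISTANT_HEADER + REPLY + EOT  # original template
patched_render = USER_TURN + ASSISTANT_HEADER + REPLY         # patched, add_generation_prompt=False


def assistant_target(rendered: str) -> str:
    """Return the text after the last assistant header, i.e. the span a
    completion-only trainer would typically supervise."""
    start = rendered.rindex(ASSISTANT_HEADER) + len(ASSISTANT_HEADER)
    return rendered[start:]


for name, rendered in [("original", original_render), ("patched", patched_render)]:
    target = assistant_target(rendered)
    print(f"{name}: target={target!r}, ends with EOT: {target.endswith(EOT)}")
```

If the patched rendering were used as training data, the model would never see <|eot_id|> after its reply and so would not learn to emit it.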
Yeah, that would be a problem even in the code I wrote for axolotl. Thank you very much for bringing that up and getting me up to speed.
When I said:
echo: false (default if left off)
add_generation_prompt: true (default if left off)
I was thinking about the inference code I wrote for vLLM's OpenAI-compatible API, not the Hugging Face transformers default. Thanks for correcting my thinking there.
So, good point. There seems to be a disconnect between what is needed during inference and during training, and apply_chat_template doesn't yet seem to have the right configuration option to handle it.
I'll have to think on this a bit more.
I opened an issue with the transformers project to see if I could make some traction there: https://github.com/huggingface/transformers/issues/33096
I'm going to close the pull requests I opened for the 8B and 405B models for now, and leave just this 70B one open until I get a fix into transformers. Then I can get back to this.
@tanliboy Looks like there is a solution that requires no changes to the existing chat templates: https://github.com/huggingface/transformers/pull/33198
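For reference, the linked PR adds a `continue_final_message` argument to `apply_chat_template` (assuming transformers 4.45 or later; verify against your installed version). Rather than editing the template, it renders the chat normally and then cuts the output off right after the final message's content, so the trailing <|eot_id|> is dropped only when requested. The sketch below re-implements that trimming idea in plain Python; `continue_final_message_trim` is a hypothetical helper, and the library's actual implementation may differ in details.

```python
EOT = "<|eot_id|>"


def continue_final_message_trim(rendered: str, final_content: str) -> str:
    """Cut the rendered chat right after the final message's content.

    Mirrors the idea behind `continue_final_message`: the template is left
    untouched, and anything it appends after the last message (such as
    <|eot_id|>) is removed.
    """
    final_content = final_content.strip()
    # rindex: the same text may also appear earlier in the conversation
    end = rendered.rindex(final_content) + len(final_content)
    return rendered[:end]


final_reply = "Here is a common saying about apple. An apple a day, keeps"
rendered = (
    "<|start_header_id|>user<|end_header_id|>\n\ntell me a common saying" + EOT
    + "<|start_header_id|>assistant<|end_header_id|>\n\n" + final_reply + EOT
)
print(repr(continue_final_message_trim(rendered, final_reply)))

# With a real tokenizer (assumes transformers >= 4.45 and access to the model):
# tokenizer.apply_chat_template(messages, tokenize=False, continue_final_message=True)
```

Because the trimming happens after rendering, the same unmodified template serves both training (full turns with <|eot_id|>) and inference-time continuation.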
Thank you for the help here.