Fix Llama 3.1 Chat Template to Properly Handle add_generation_prompt

#24
by Tostino - opened

Problem

The current chat template for Llama 3.1 always appends an <|eot_id|> token to the end of the last message, even when add_generation_prompt is set to false. Because <|eot_id|> closes the turn, the model cannot continue a partially written final message (for example, a prefilled assistant response) when it should be able to do so.

Solution

This PR modifies the chat template to conditionally add the <|eot_id|> token based on whether it's the last message and the value of add_generation_prompt.

Changes

Modified the regular message handling section of the template to only add <|eot_id|> if it's not the last message or if add_generation_prompt is true.

Test

Here is a test case showing the result:

from functools import lru_cache

from transformers import AutoTokenizer

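# Custom Llama 3.1 chat template. Compared with the stock template, a regular
# message is followed by <|eot_id|> only when it is not the last message or
# when add_generation_prompt is true.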
custom_chat_template = "{{- bos_token }}\n{%- if custom_tools is defined %}\n    {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n    {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n    {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n    {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n    {%- set system_message = messages[0]['content']|trim %}\n    {%- set messages = messages[1:] %}\n{%- else %}\n    {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n    {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n    {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n    {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n    {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n    {{- \"Do not use variables.\\n\\n\" }}\n    {%- for t in tools %}\n        {{- t | tojson(indent=4) }}\n        {{- \"\\n\\n\" }}\n    {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n    {#- Extract the first user message so we can plug it in here #}\n    {%- if messages | length != 0 %}\n        {%- set first_user_message = messages[0]['content']|trim %}\n        {%- set messages = messages[1:] %}\n    {%- else %}\n        {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n    {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n    {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n    {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n    {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n    {{- \"Do not use variables.\\n\\n\" }}\n    {%- for t in tools %}\n        {{- t | tojson(indent=4) }}\n        {{- \"\\n\\n\" }}\n    {%- endfor %}\n    {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n    {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n        {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim }}\n        {%- if not loop.last or add_generation_prompt %}\n            {{- '<|eot_id|>' }}\n        {%- endif %}\n    {%- elif 'tool_calls' in message %}\n        {%- if not message.tool_calls|length == 1 %}\n            {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n        {%- endif %}\n        {%- set tool_call = message.tool_calls[0].function %}\n        {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n            {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n            {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n            {%- for arg_name, arg_val in tool_call.arguments | items %}\n                {{- arg_name + '=\"' + arg_val + '\"' }}\n                {%- if not loop.last %}\n                    {{- \", \" }}\n                {%- endif %}\n                {%- endfor %}\n            {{- \")\" }}\n        {%- else  %}\n            {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n            {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n            {{- '\"parameters\": ' }}\n            {{- tool_call.arguments | tojson }}\n            {{- \"}\" }}\n        {%- endif %}\n        {%- if builtin_tools is defined %}\n            {#- This means we're in ipython mode #}\n            {{- \"<|eom_id|>\" }}\n        {%- else %}\n            {{- \"<|eot_id|>\" }}\n        {%- endif %}\n    {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n        {{- 
\"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n        {%- if message.content is mapping or message.content is iterable %}\n            {{- message.content | tojson }}\n        {%- else %}\n            {{- message.content }}\n        {%- endif %}\n        {{- \"<|eot_id|>\" }}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"

@lru_cache(maxsize=4)
def _load_tokenizer(model_name):
    """Load the tokenizer for *model_name* once and reuse it on later calls."""
    return AutoTokenizer.from_pretrained(model_name)


def apply_custom_chat_template(
    messages,
    add_generation_prompt=False,
    *,
    model_name="meta-llama/Meta-Llama-3.1-405B-Instruct",
):
    """Render *messages* to prompt text using the module's custom chat template.

    Args:
        messages: Conversation to render, as a list of dicts with "role" and
            "content" keys (the shape used by the caller in this script).
        add_generation_prompt: Forwarded to ``apply_chat_template``; the
            template uses it to decide whether to close the last message and
            append an empty assistant header.
        model_name: Hub id of the tokenizer to load. Defaults to the
            Llama 3.1 405B Instruct tokenizer used by the original script.

    Returns:
        The rendered chat text (``tokenize=False`` is passed, so no token ids
        are produced).
    """
    # The tokenizer is cached so repeated renders do not reload it each time.
    tokenizer = _load_tokenizer(model_name)

    # Passing chat_template explicitly overrides whatever template ships with
    # the tokenizer, so the output reflects the template under test.
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=add_generation_prompt,
        chat_template=custom_chat_template,
    )

def test_custom_chat_template():
    """Render a sample conversation with and without a generation prompt.

    The conversation ends on an unfinished assistant message. Both renderings
    are printed, followed by a SUCCESS/ERROR line for each expectation:
    no trailing <|eot_id|> when add_generation_prompt is False, and a trailing
    empty assistant header when it is True. Results are only printed; nothing
    is asserted or returned.
    """
    conversation = [
        {"role": "user", "content": "tell me a common saying"},
        {"role": "assistant", "content": "Here is a common saying about apple. An apple a day, keeps"},
    ]
    divider = "\n" + "=" * 50 + "\n"

    # Render once per flag value, keeping the outputs for the checks below.
    rendered = {}
    for flag in (False, True):
        if rendered:
            # Separate the second rendering from the first.
            print(divider)
        rendered[flag] = apply_custom_chat_template(conversation, add_generation_prompt=flag)
        print(f"Result with add_generation_prompt={flag}:")
        print(rendered[flag])

    # Without a generation prompt the last message must stay open.
    ends_with_eot = rendered[False].strip().endswith("<|eot_id|>")
    if ends_with_eot:
        print("\nERROR: <|eot_id|> is present at the end when add_generation_prompt is False.")
    else:
        print("\nSUCCESS: <|eot_id|> is correctly absent at the end when add_generation_prompt is False.")

    # With a generation prompt the output must end on an empty assistant header.
    ends_with_assistant_header = rendered[True].strip().endswith(
        "<|start_header_id|>assistant<|end_header_id|>"
    )
    if not ends_with_assistant_header:
        print("ERROR: No empty assistant turn is added when add_generation_prompt is True.")
    else:
        print("SUCCESS: An empty assistant turn is correctly added when add_generation_prompt is True.")

# Run the demonstration only when this file is executed as a script, so that
# importing it does not trigger a tokenizer load.
if __name__ == "__main__":
    test_custom_chat_template()

There is a discussion on the 70B version of this request (https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct/discussions/26), which means I have more work to do before this can be merged. Closing for now.

Tostino changed pull request status to closed

Sign up or log in to comment