|
import gradio as gr |
|
from transformers import AutoTokenizer |
|
import json |
|
from functools import partial |
|
|
|
# Any tokenizer that supports `apply_chat_template` works here; this one is
# used only as a host object whose `chat_template` attribute gets overwritten
# with the user's template before rendering.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

# Default conversation shown in the UI: a JSON array of
# {"role": ..., "content": ...} messages (the standard chat format).
demo_conversation = """[
    {"role": "system", "content": "You are a helpful chatbot."},
    {"role": "user", "content": "Hi there!"},
    {"role": "assistant", "content": "Hello, human!"},
    {"role": "user", "content": "Can I ask a question?"}
]"""
|
|
|
# Preset Jinja chat templates the user can load as starting points.
# Keys correspond to the preset buttons built in the UI below. Template
# lines are intentionally flush-left: the "Cleanup template whitespace"
# option strips per-line whitespace before the template is applied.
chat_templates = {
    # ChatML: <|im_start|>role\ncontent<|im_end|> per message, optional
    # empty assistant header as the generation prompt.
    "chatml": """{% for message in messages %}
{{ "<|im_start|>" + message["role"] + "\\n" + message["content"] + "<|im_end|>\\n" }}
{% endfor %}
{% if add_generation_prompt %}
{{ "<|im_start|>assistant\\n" }}
{% endif %}""",
    # Zephyr: <|user|>/<|system|>/<|assistant|> headers, each turn terminated
    # by the tokenizer's eos_token.
    "zephyr": """{% for message in messages %}
{% if message['role'] == 'user' %}
{{ '<|user|>\n' + message['content'] + eos_token }}
{% elif message['role'] == 'system' %}
{{ '<|system|>\n' + message['content'] + eos_token }}
{% elif message['role'] == 'assistant' %}
{{ '<|assistant|>\n' + message['content'] + eos_token }}
{% endif %}
{% if loop.last and add_generation_prompt %}
{{ '<|assistant|>' }}
{% endif %}
{% endfor %}""",
    # LLaMA-2: folds an optional leading system message into the first user
    # turn inside <<SYS>> markers, wraps user turns in [INST] ... [/INST],
    # and raises if user/assistant roles do not strictly alternate.
    "llama": """{% if messages[0]['role'] == 'system' %}
{% set loop_messages = messages[1:] %}
{% set system_message = messages[0]['content'] %}
{% else %}
{% set loop_messages = messages %}
{% set system_message = false %}
{% endif %}
{% for message in loop_messages %}
{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif %}
{% if loop.index0 == 0 and system_message != false %}
{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}
{% else %}
{% set content = message['content'] %}
{% endif %}
{% if message['role'] == 'user' %}
{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}
{% elif message['role'] == 'assistant' %}
{{ ' ' + content.strip() + ' ' + eos_token }}
{% endif %}
{% endfor %}""",
    # Alpaca: plain-text "### Instruction:" / "### Response:" sections.
    "alpaca": """{% for message in messages %}
{% if message['role'] == 'system' %}
{{ message['content'] + '\n\n' }}
{% elif message['role'] == 'user' %}
{{ '### Instruction:\n' + message['content'] + '\n\n' }}
{% elif message['role'] == 'assistant' %}
{{ '### Response:\n' + message['content'] + '\n\n' }}
{% endif %}
{% if loop.last and add_generation_prompt %}
{{ '### Response:\n' }}
{% endif %}
{% endfor %}""",
    # Vicuna: "USER:" / "ASSISTANT:" labelled turns, bare system preamble.
    "vicuna": """{% for message in messages %}
{% if message['role'] == 'system' %}
{{ message['content'] + '\n' }}
{% elif message['role'] == 'user' %}
{{ 'USER:\n' + message['content'] + '\n' }}
{% elif message['role'] == 'assistant' %}
{{ 'ASSISTANT:\n' + message['content'] + '\n' }}
{% endif %}
{% if loop.last and add_generation_prompt %}
{{ 'ASSISTANT:\n' }}
{% endif %}
{% endfor %}""",
    # Falcon: "System:"/"User:"/"Falcon:" prefixed lines separated by newlines.
    "falcon": """{% for message in messages %}
{% if not loop.first %}
{{ '\n' }}
{% endif %}
{% if message['role'] == 'system' %}
{{ 'System: ' }}
{% elif message['role'] == 'user' %}
{{ 'User: ' }}
{% elif message['role'] == 'assistant' %}
{{ 'Falcon: ' }}
{% endif %}
{{ message['content'] }}
{% endfor %}
{% if add_generation_prompt %}
{{ '\n' + 'Falcon:' }}
{% endif %}"""
}
|
# Markdown blurb rendered at the top of the app (runtime string - do not
# reword without intending a UI change).
description_text = """# Chat Template Creator

### This space is a helper app for writing [Chat Templates](https://huggingface.co/docs/transformers/main/en/chat_templating).

### When you're happy with the outputs from your template, you can use the code block at the end to add it to a PR!"""
|
|
|
def apply_chat_template(template, test_conversation, add_generation_prompt, cleanup_whitespace):
    """Render a chat template against a test conversation.

    Args:
        template: Jinja chat-template source string.
        test_conversation: JSON string encoding a list of
            ``{"role": ..., "content": ...}`` messages.
        add_generation_prompt: whether the template should append its
            generation prompt (e.g. an empty assistant header) at the end.
        cleanup_whitespace: strip per-line leading/trailing whitespace and
            newlines from the template first, so templates formatted for
            readability don't leak layout whitespace into the output.

    Returns:
        Tuple of (formatted conversation, Python snippet for opening a PR
        that sets this template on a checkpoint).

    Raises:
        json.JSONDecodeError: if ``test_conversation`` is not valid JSON.
        (The tokenizer may also raise if the template itself is invalid.)
    """
    if cleanup_whitespace:
        # Jinja emits literal whitespace verbatim, so indentation that exists
        # only for readability in the textbox must be removed.
        template = "".join(line.strip() for line in template.split("\n"))
    tokenizer.chat_template = template
    conversation = json.loads(test_conversation)
    # repr() the template so the generated snippet remains valid Python even
    # when the template contains double quotes or newlines (the original
    # f'"{template}"' form broke on every bundled preset).
    pr_snippet = "\n".join((
        "CHECKPOINT = \"big-ai-company/cool-new-model\"",
        "tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)",
        f"tokenizer.chat_template = {template!r}",
        "tokenizer.push_to_hub(CHECKPOINT, create_pr=True)",
    ))
    formatted = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=add_generation_prompt)
    return formatted, pr_snippet
|
|
|
def load_template(template_name):
    """Return the preset chat template named *template_name*.

    Intended as a Gradio click callback: wire the return value to the
    template textbox via ``outputs=``. (The previous implementation assigned
    ``template_in.value``, which does not update an already-rendered
    component - Gradio components are updated through event outputs.)

    Args:
        template_name: key into ``chat_templates`` ("chatml", "zephyr", ...).

    Returns:
        The Jinja template string for the preset.

    Raises:
        KeyError: if ``template_name`` is not a known preset.
    """
    return chat_templates[template_name]
|
|
|
with gr.Blocks() as demo:
    gr.Markdown(description_text)

    # Preset-loader buttons, two rows of three.
    with gr.Row():
        gr.Markdown("### Pick an existing template to start:")
    with gr.Row():
        load_chatml = gr.Button("ChatML")
        load_zephyr = gr.Button("Zephyr")
        load_llama = gr.Button("LLaMA")
    with gr.Row():
        load_alpaca = gr.Button("Alpaca")
        load_vicuna = gr.Button("Vicuna")
        load_falcon = gr.Button("Falcon")

    # Left column: inputs; right column: rendered output + PR code snippet.
    with gr.Row():
        with gr.Column():
            template_in = gr.TextArea(value=chat_templates["chatml"], lines=10, max_lines=30, label="Chat Template")
            conversation_in = gr.TextArea(value=demo_conversation, lines=6, label="Conversation")
            generation_prompt_check = gr.Checkbox(value=False, label="Add generation prompt")
            cleanup_whitespace_check = gr.Checkbox(value=True, label="Cleanup template whitespace")
            submit = gr.Button("Apply template", variant="primary")
        with gr.Column():
            formatted_out = gr.TextArea(label="Formatted conversation")
            code_snippet_out = gr.TextArea(label="Code snippet to create PR", lines=3, show_label=True, show_copy_button=True)

    submit.click(
        fn=apply_chat_template,
        inputs=[template_in, conversation_in, generation_prompt_check, cleanup_whitespace_check],
        outputs=[formatted_out, code_snippet_out],
    )

    # Each preset button must declare `outputs=` so Gradio writes the callback's
    # return value into the textbox; without it the result is discarded and the
    # buttons do nothing. `name=name` pins the loop variable at definition time
    # (closures bind late, so plain lambdas would all load the last preset).
    preset_buttons = [
        (load_chatml, "chatml"),
        (load_zephyr, "zephyr"),
        (load_llama, "llama"),
        (load_alpaca, "alpaca"),
        (load_vicuna, "vicuna"),
        (load_falcon, "falcon"),
    ]
    for button, name in preset_buttons:
        button.click(fn=lambda name=name: chat_templates[name], outputs=template_in)


demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|