|
from transformers import AutoTokenizer |
|
from fastchat.conversation import get_conv_template |
|
import os |
|
from utils import sanitize_jinja2 |
|
import difflib |
|
|
|
def test_llama2_template():
    """Check that the airoboros_v1 Jinja2 chat template matches FastChat.

    Reads ``airoboros_v1.jinja2`` from the ``templates`` directory one level
    above this file's directory, installs its sanitized form as the
    tokenizer's chat template, renders a fixed four-message conversation with
    ``add_generation_prompt=True`` and asserts that the result is identical
    to the prompt FastChat builds from its ``airoboros_v1`` conversation
    template.  Intermediate renderings and the matching blocks between the
    two prompts are printed to help diagnose a mismatch.

    Raises:
        AssertionError: If the transformers prompt and the FastChat prompt
            differ.
    """
    # Resolve the template relative to this file rather than the current
    # working directory, so the test works wherever it is launched from.
    template_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        os.pardir,
        "templates",
        "airoboros_v1.jinja2",
    )
    with open(template_path, "r", encoding="utf-8") as f:
        jinja_lines = f.readlines()

    print("jinja_lines: ", jinja_lines)

    # Sanitize once and reuse the result for both the debug print and the
    # tokenizer's chat template.
    chat_template = sanitize_jinja2(jinja_lines)
    print("sanitized: ", chat_template)

    chat = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
        {"role": "user", "content": "I'd like to show off how chat templating works!"},
    ]

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path="jondurbin/airoboros-l2-7b-gpt4-2.0",
        trust_remote_code=True,
    )

    # Render with whatever template the tokenizer ships with, for reference.
    transformer_prompt = tokenizer.apply_chat_template(chat, tokenize=False)
    print("default template")
    print(transformer_prompt)

    # Switch to the template under test.
    tokenizer.chat_template = chat_template

    transformer_prompt = tokenizer.apply_chat_template(chat, tokenize=False)
    print()
    print("add_generation_prompt False:")
    print(transformer_prompt)

    # This rendering (with the generation prompt) is the one compared against
    # FastChat below.
    transformer_prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    print()
    print("add_generation_prompt True:")
    print(transformer_prompt)

    print("Fastchat template: ")
    conv = get_conv_template("airoboros_v1")
    conv.set_system_message(chat[0]["content"])
    conv.append_message(conv.roles[0], chat[1]["content"])
    conv.append_message(conv.roles[1], chat[2]["content"])
    conv.append_message(conv.roles[0], chat[3]["content"])
    # A trailing assistant message of None is the FastChat counterpart of
    # add_generation_prompt=True in the rendering above.
    conv.append_message(conv.roles[1], None)

    # Build the FastChat prompt once and reuse it for printing, diffing and
    # the final comparison.
    fastchat_prompt = conv.get_prompt()
    print(fastchat_prompt)

    matcher = difflib.SequenceMatcher(a=transformer_prompt, b=fastchat_prompt)
    print("Matching Sequences:")
    for match in matcher.get_matching_blocks():
        print(f"Match : {match}")
        print(f"Matching Sequence : {transformer_prompt[match.a:match.a + match.size]}")

    assert transformer_prompt == fastchat_prompt
|
|
|
# Run the template comparison directly when this file is executed as a
# script (as opposed to being imported).
if __name__ == "__main__":
    test_llama2_template()