"""Check that a local Jinja2 chat template matches FastChat's conversation template.

The same four-turn chat is rendered three ways and printed for inspection:

1. through the tokenizer's built-in ``chat_template``;
2. through the sanitized local ``mistral-7b-openorca.jinja2`` template, with
   and without ``add_generation_prompt``;
3. through FastChat's ``mistral-7b-openorca`` conversation template.

The test asserts that (2, with ``add_generation_prompt=True``) equals (3).
"""

import difflib
import os

from fastchat.conversation import get_conv_template
from transformers import AutoTokenizer

from utils import sanitize_jinja2

# Resolve the template relative to this file rather than the current working
# directory, so the test works regardless of where pytest is launched from.
TEMPLATE_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "..",
    "templates",
    "mistral-7b-openorca.jinja2",
)

MODEL_NAME = "Open-Orca/Mistral-7B-OpenOrca"
FASTCHAT_TEMPLATE_NAME = "mistral-7b-openorca"


def _print_matching_blocks(a, b):
    """Print the substrings shared by ``a`` and ``b`` to help locate a mismatch."""
    matcher = difflib.SequenceMatcher(a=a, b=b)
    print("Matching Sequences:")
    for match in matcher.get_matching_blocks():
        print("Match : {}".format(match))
        print("Matching Sequence : {}".format(a[match.a:match.a + match.size]))


def test_llama2_template():
    """Assert the local OpenOrca Jinja2 template renders like FastChat's template.

    Despite its name (kept so existing callers and pytest selection keep
    working), this test exercises the Mistral-7B-OpenOrca template, not Llama 2.

    Raises:
        AssertionError: if the Jinja2-rendered prompt (with the generation
            prompt appended) differs from FastChat's prompt. The message
            contains a unified diff of the two.
    """
    with open(TEMPLATE_PATH, "r", encoding="utf-8") as f:
        jinja_lines = f.readlines()
    # Sanitize once and reuse the result for both printing and the tokenizer.
    sanitized_template = sanitize_jinja2(jinja_lines)
    print("jinja_lines: ", jinja_lines)
    print("sanitized: ", sanitized_template)

    chat = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
        {"role": "user", "content": "I'd like to show off how chat templating works!"},
    ]

    # NOTE(review): trust_remote_code=True allows the Hub repository to run
    # arbitrary Python on load; confirm this model actually needs it.
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=MODEL_NAME, trust_remote_code=True
    )

    # Expected layout (ChatML):
    # "<|im_start|>system\n{system}<|im_end|>\n<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant"
    transformer_prompt = tokenizer.apply_chat_template(chat, tokenize=False)
    print("default template")
    print(transformer_prompt)

    tokenizer.chat_template = sanitized_template

    transformer_prompt = tokenizer.apply_chat_template(chat, tokenize=False)
    print()
    print("add_generation_prompt False:")
    print(transformer_prompt)

    # This is the rendering compared against FastChat below: FastChat's prompt
    # ends with an open assistant turn, which add_generation_prompt mirrors.
    transformer_prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    print()
    print("add_generation_prompt True:")
    print(transformer_prompt)

    print("Fastchat template: ")
    conv = get_conv_template(FASTCHAT_TEMPLATE_NAME)
    conv.set_system_message(chat[0]["content"])
    conv.append_message(conv.roles[0], chat[1]["content"])
    conv.append_message(conv.roles[1], chat[2]["content"])
    conv.append_message(conv.roles[0], chat[3]["content"])
    # A None message leaves the assistant turn open (generation prompt).
    conv.append_message(conv.roles[1], None)
    fastchat_prompt = conv.get_prompt()
    print(fastchat_prompt)

    _print_matching_blocks(transformer_prompt, fastchat_prompt)

    diff = "".join(
        difflib.unified_diff(
            transformer_prompt.splitlines(keepends=True),
            fastchat_prompt.splitlines(keepends=True),
            fromfile="transformers (jinja2)",
            tofile="fastchat",
        )
    )
    assert transformer_prompt == fastchat_prompt, (
        "Jinja2 template output differs from FastChat prompt:\n" + diff
    )


if __name__ == "__main__":
    test_llama2_template()