|
import ast
import re

from transformers import AutoTokenizer
|
|
|
def extract_separators(template):
    """Extract the literal text a chat template renders in front of each message's content.

    The template source is scanned for Jinja output expressions of the form
    ``{{ <prefix> + message['content'] ... }}``. ``<prefix>`` is a ``+``-joined
    chain of operands:

    * string literals are decoded the way Jinja renders them (escape sequences
      such as backslash-n become the real character, escaped quotes are kept);
    * ``message['role']`` (or ``message.role``) is a placeholder for the role name;
    * any other operand (e.g. ``bos_token``) cannot be resolved from the template
      alone and is dropped.

    Args:
        template: Jinja chat-template source. ``None`` or ``""`` yields ``[]``.

    Returns:
        If a matched prefix references the message role: three separators, for
        the roles "system", "user" and "assistant" in that order, built from the
        first such prefix. Otherwise one separator per matched prefix, in
        template order.
    """
    if not template:
        return []

    single_quoted = r"'(?:[^'\\]|\\.)*'"
    double_quoted = r'"(?:[^"\\]|\\.)*"'
    string_literal = re.compile(f"{single_quoted}|{double_quoted}")
    # One "+"-separated operand; quoted strings are consumed whole so that a
    # "+" inside a literal does not split it.
    operand = re.compile(f"(?:[^+'\"]|{single_quoted}|{double_quoted})+")
    role_ref = re.compile(r"""message\s*(?:\[\s*['"]role['"]\s*\]|\.role)""")
    # Captures everything between "{{" (optionally "{{-") and "+ message['content']".
    content_expr = re.compile(
        r"""\{\{-?\s*([^{}]+?)\s*\+\s*message\s*(?:\[\s*['"]content['"]\s*\]|\.content)"""
    )

    def resolve(expression):
        # Decoded literals become str; each role reference becomes None.
        parts = []
        for raw in operand.findall(expression):
            token = raw.strip()
            if string_literal.fullmatch(token):
                try:
                    parts.append(ast.literal_eval(token))
                except (ValueError, SyntaxError):
                    # Escape Python cannot parse: keep the text between the quotes.
                    parts.append(token[1:-1])
            elif role_ref.fullmatch(token):
                parts.append(None)
        return parts

    prefixes = [resolve(match) for match in content_expr.findall(template)]

    for parts in prefixes:
        if None in parts:
            return [
                "".join(role if part is None else part for part in parts)
                for role in ("system", "user", "assistant")
            ]

    return ["".join(parts) for parts in prefixes]
|
|
|
def detect_eos_token(jinja_template, tokenizer):
    """Guess the string that terminates each message in a rendered chat.

    Known end-of-turn markers are searched for literally in the template
    source, in priority order: ``<|im_end|>``, then ``</s>``. If the template
    instead refers to the ``eos_token`` variable, the tokenizer's own EOS
    string is used.

    Args:
        jinja_template: Chat-template source; ``None`` is treated as empty.
        tokenizer: Object exposing an ``eos_token`` attribute (may be ``None``).

    Returns:
        The terminator to split rendered chats on. This is ``"<|endoftext|>"``
        when the template writes that marker literally, or when nothing else
        can be determined.
    """
    template = jinja_template or ""
    fallback = "<|endoftext|>"

    # Markers written literally into the template appear verbatim in the
    # rendered text, so return them as-is (the tokenizer's EOS may differ).
    for marker in ("<|im_end|>", "</s>"):
        if marker in template:
            return marker

    if "eos_token" in template:
        # A missing EOS must not leak out: str.split(None) splits on whitespace.
        return getattr(tokenizer, "eos_token", None) or fallback

    # Covers both a literal "<|endoftext|>" in the template and the unknown case.
    return fallback
|
|
|
def _role_from_marker(marker):
    """Return the single role name mentioned in *marker*, or None if zero or several are."""
    lowered = marker.lower()
    found = [role for role in ("system", "user", "assistant") if role in lowered]
    return found[0] if len(found) == 1 else None


def recover_messages(formatted_message, separators, eos_token):
    """Recover the original messages from a rendered chat string.

    The text is split on *eos_token* into one segment per message; a trailing
    whitespace-only segment is dropped. For each segment:

    * if it starts with one of the separators (compared with surrounding
      whitespace ignored, longest first), only that leading marker is removed,
      and if the marker names exactly one of "system"/"user"/"assistant"
      (case-insensitively) that becomes the message's role;
    * otherwise the first occurrence of every separator is removed wherever
      it appears (e.g. when a BOS token precedes the marker).

    Roles that cannot be read from a marker are assigned by position: the
    first segment is "system", later ones alternate "user"/"assistant".

    Args:
        formatted_message: The rendered chat text.
        separators: Per-message prefixes, e.g. from extract_separators();
            surrounding single quotes are ignored.
        eos_token: The string that terminates each message. If empty or None,
            the whole text is treated as a single message.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts, in order.
    """
    if eos_token:
        segments = formatted_message.split(eos_token)
    else:
        # str.split("") raises and str.split(None) splits on whitespace, so an
        # unknown terminator means "one message".
        segments = [formatted_message]

    # Text after the final terminator is usually just a newline.
    if segments and not segments[-1].strip():
        segments.pop()

    markers = [separator.strip("'") for separator in separators]
    # Longest first, so a marker that is a prefix of another cannot shadow it.
    prefix_markers = sorted(
        dict.fromkeys(marker.strip() for marker in markers if marker.strip()),
        key=len,
        reverse=True,
    )
    positional_roles = ["user", "assistant"]

    recovered_messages = []
    for index, segment in enumerate(segments):
        content = segment.strip()
        role = None

        leading = next((m for m in prefix_markers if content.startswith(m)), None)
        if leading is not None:
            # Strip only the leading marker so identical text in the body survives.
            content = content[len(leading):].strip()
            role = _role_from_marker(leading)
        else:
            for marker in markers:
                content = content.replace(marker, "", 1).strip()

        if role is None:
            role = "system" if index == 0 else positional_roles[(index - 1) % 2]

        recovered_messages.append({"role": role, "content": content})

    return recovered_messages
|
|
|
def recover_chat_messages(tokenized_chat, tokenizer):
    """Recover the message dicts from a chat rendered with *tokenizer*'s chat template.

    Separators and the end-of-message marker are inferred from the template
    source, then used to split *tokenized_chat* back into messages.

    Args:
        tokenized_chat: The rendered chat string, e.g. the result of
            ``tokenizer.apply_chat_template(messages, tokenize=False)``.
        tokenizer: The tokenizer whose ``chat_template`` produced *tokenized_chat*.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts.

    Raises:
        ValueError: If the tokenizer has no chat template to analyse.
    """
    jinja_template = getattr(tokenizer, "chat_template", None)
    if jinja_template is None:
        raise ValueError("tokenizer has no chat_template; cannot infer message separators")

    separators = extract_separators(jinja_template)
    eos_token = detect_eos_token(jinja_template, tokenizer)
    return recover_messages(tokenized_chat, separators, eos_token)
|
|
|
|
|
if __name__ == "__main__":
    # Round-trip demo: render a short conversation with the model's chat
    # template, print it, then print the messages recovered from that text.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B")

    system_prompt = "You are a friendly chatbot who always responds in the style of a pirate"
    user_prompt = "How many helicopters can a human eat in one sitting?"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=False)
    print(tokenized_chat)
    print(recover_chat_messages(tokenized_chat, tokenizer))
|
|