File size: 3,621 Bytes
2279d36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import re
from transformers import AutoTokenizer

def extract_separators(template):
    """
    Extract the expressions a chat template places before each message's content.

    Every ``{{ <expr> + message['content']`` occurrence in ``template`` is
    located and ``<expr>`` is captured (surrounding whitespace removed).

    If any captured expression references ``message['role']``, the first
    expression that does so is expanded once for each of the roles
    "system", "user" and "assistant": the `` + message['role'] + `` fragment
    is replaced by the role name and all single quotes are removed. The
    three expanded strings are returned in that role order.

    Otherwise the captured expressions are returned unchanged (any Jinja
    quotes are still present). An empty list is returned when nothing matches.
    """
    # Capture whatever sits between '{{' and "+ message['content']".
    pattern = r"\{\{\s*([^{}]+?)\s*\+ message\['content'\]"
    separators = [match.strip() for match in re.findall(pattern, template)]

    # Pick the expression that really contains the role placeholder; the
    # first capture is not guaranteed to be the role-dependent one.
    role_expression = next(
        (candidate for candidate in separators if "message['role']" in candidate),
        None,
    )
    if role_expression is None:
        return separators

    return [
        role_expression.replace(" + message['role'] + ", role).replace("'", "")
        for role in ("system", "user", "assistant")
    ]

def detect_eos_token(jinja_template, tokenizer):
    """
    Pick the end-of-message marker to split a rendered chat on.

    Checks, in order:
      1. the literal ``<|im_end|>`` appearing in ``jinja_template``;
      2. the literal ``</s>`` appearing in ``jinja_template``;
      3. the name ``eos_token`` appearing in ``jinja_template``, in which
         case ``tokenizer.eos_token`` is used if it is set and non-empty.

    Falls back to ``"<|endoftext|>"`` when none of the above yields a
    usable string, so the result is always a non-empty ``str``.
    ``tokenizer`` is only accessed in case 3.
    """
    for literal_marker in ("<|im_end|>", "</s>"):
        if literal_marker in jinja_template:
            return literal_marker

    # A tokenizer may have no EOS token configured; returning None (or "")
    # here would make a later str.split() misbehave, so use the default.
    if "eos_token" in jinja_template and tokenizer.eos_token:
        return tokenizer.eos_token

    return "<|endoftext|>"

def recover_messages(formatted_message, separators, eos_token):
    """
    Recover a list of ``{"role": ..., "content": ...}`` dicts from a rendered chat.

    ``formatted_message`` is split on ``eos_token``; a trailing chunk that is
    empty or whitespace-only is discarded. For every remaining chunk:

    * The role is taken from the separator the chunk starts with, provided
      that separator names exactly one of "system", "user" or "assistant"
      (case-insensitive). This keeps roles correct when the chat has no
      system message or when turns do not strictly alternate.
    * If no such separator identifies the role, the role is assigned by
      position: "system" for the first chunk, then "user" / "assistant"
      alternating.
    * The content is the chunk with surrounding whitespace removed and the
      first occurrence of each separator (with leading/trailing single
      quotes stripped) deleted.

    Returns the list of recovered message dicts in chunk order.
    """
    known_roles = ("system", "user", "assistant")
    positional_roles = ("user", "assistant")

    # Separators may still carry the quotes of the Jinja string literal.
    markers = [separator.strip("'") for separator in separators]

    chunks = formatted_message.split(eos_token)
    # A trailing end-of-message token leaves an empty last piece.
    if chunks and chunks[-1].strip() == '':
        chunks.pop()

    recovered_messages = []
    for index, chunk in enumerate(chunks):
        content = chunk.strip()

        # Prefer the role announced by the chunk's own prefix.
        role = None
        for marker in markers:
            prefix = marker.strip()
            if not prefix or not content.startswith(prefix):
                continue
            named = [name for name in known_roles if name in prefix.lower()]
            # Only trust an unambiguous match; otherwise keep looking.
            if len(named) == 1:
                role = named[0]
                break

        # Positional fallback for templates whose separators carry no role.
        if role is None:
            role = "system" if index == 0 else positional_roles[(index - 1) % 2]

        for marker in markers:
            content = content.replace(marker, '', 1).strip()

        recovered_messages.append({"role": role, "content": content})

    return recovered_messages

def recover_chat_messages(tokenized_chat, tokenizer):
    """
    Turn a rendered chat string back into a list of message dictionaries.

    The tokenizer's ``chat_template`` is used to derive both the separators
    and the end-of-message token, which are then applied to ``tokenized_chat``.
    """
    template = tokenizer.chat_template
    return recover_messages(
        tokenized_chat,
        extract_separators(template),
        detect_eos_token(template, tokenizer),
    )

# Example usage: render a short chat with a real tokenizer, then parse the
# rendered string back into messages and print both.
if __name__ == "__main__":
    checkpoint = "Qwen/Qwen1.5-0.5B"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    system_prompt = "You are a friendly chatbot who always responds in the style of a pirate"
    user_question = "How many helicopters can a human eat in one sitting?"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_question},
    ]

    # tokenize=False returns the rendered text rather than token ids.
    tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=False)
    print(tokenized_chat)

    recovered_messages = recover_chat_messages(tokenized_chat, tokenizer)
    print(recovered_messages)