File size: 4,515 Bytes
6ec7105
 
 
 
 
3f6ce08
 
 
 
6ec7105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f6ce08
6ec7105
3f6ce08
6ec7105
3f6ce08
 
 
 
 
 
 
 
 
 
 
 
 
6ec7105
 
 
3f6ce08
 
 
 
 
 
 
 
 
6ec7105
3f6ce08
6ec7105
 
 
 
 
 
 
 
 
3f6ce08
6ec7105
 
 
 
 
 
3f6ce08
 
 
 
6ec7105
 
 
 
 
 
 
 
 
 
 
 
 
 
3f6ce08
6ec7105
 
 
3f6ce08
6ec7105
 
 
3f6ce08
 
6ec7105
 
 
 
 
 
 
 
3f6ce08
6ec7105
3f6ce08
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import re
from pprint import pprint


class MessageComposer:
    """
    Convert chat message lists to and from an instruct-style prompt string.

    models:
    - mixtral-8x7b (mistralai/Mixtral-8x7B-Instruct-v0.1)

    Prompt layout produced by `merge` and parsed by `split`:
        <s> [INST] Instruction [/INST] Model answer </s> [INST] Follow-up [/INST]
    """

    # A complete exchange: instruction followed by an answer, wrapped in <s>...</s>.
    _PAIR_PATTERN = re.compile(
        r"<s>\s*\[INST\](?P<inst>[\s\S]*?)\[/INST\](?P<answer>[\s\S]*?)</s>"
    )
    # Any instruction block, answered or not.
    _INST_PATTERN = re.compile(r"\[INST\](?P<inst>[\s\S]*?)\[/INST\]")

    def __init__(self, model: str = None):
        """
        Args:
            model: Optional model name. It is only stored on the instance;
                nothing in this class branches on it.
        """
        self.model = model
        self.inst_roles = ["user", "system", "inst"]
        self.answer_roles = ["assistant", "bot", "answer"]

    def _normalize_role(self, role):
        """Map a role name to "inst" or "answer"; unknown roles count as "inst"."""
        if role in self.inst_roles:
            return "inst"
        if role in self.answer_roles:
            return "answer"
        return "inst"

    def concat_messages_by_role(self, messages):
        """
        Collapse consecutive messages of the same kind into one message.

        Roles are first normalized to "inst" or "answer"; neighbouring messages
        with the same normalized role have their contents joined by a newline.

        Args:
            messages: Iterable of dicts with at least "role" and "content" keys.

        Returns:
            A new list of new dicts. The input dicts are never modified.
        """
        concat_messages = []
        for message in messages:
            # Normalize before comparing, so an unknown role next to an
            # instruction is merged into it instead of becoming a second
            # "inst" entry that `merge` would later overwrite.
            role = self._normalize_role(message["role"])
            content = message["content"]
            if concat_messages and concat_messages[-1]["role"] == role:
                concat_messages[-1]["content"] += "\n" + content
            else:
                # Copy so the caller's message dicts stay untouched.
                concat_messages.append({**message, "role": role})
        return concat_messages

    def merge(self, messages) -> str:
        """
        Render a list of chat messages as a single prompt string.

        Each instruction/answer pair becomes "<s> [INST] .. [/INST] .. </s>\\n";
        a trailing unanswered instruction is appended as "[INST] .. [/INST]".

        Side effects: sets `self.messages` (the concatenated messages),
        `self.merged_str` (the result) and `self.cached_str` (the trailing
        unanswered instruction block, or "").

        Args:
            messages: Iterable of dicts with "role" and "content" keys.

        Returns:
            The merged prompt string.
        """
        self.messages = self.concat_messages_by_role(messages)
        parts = []
        pending_inst = ""
        for message in self.messages:
            content = message["content"]
            if message["role"] == "answer":
                parts.append(f"<s> {pending_inst} {content} </s>\n")
                pending_inst = ""
            else:
                pending_inst = f"[INST] {content} [/INST]"
        # An instruction without an answer stays open at the end of the prompt.
        parts.append(pending_inst)

        self.cached_str = pending_inst
        self.merged_str = "".join(parts)
        return self.merged_str

    def split(self, merged_str) -> list:
        """
        Parse a prompt string back into a list of chat messages.

        Every "<s> [INST] .. [/INST] .. </s>" pair yields a "user" and an
        "assistant" message. If there are more "[INST]" blocks than complete
        pairs, the last block is added as a final "user" message. If the
        string contains no "[INST]" block at all, the whole string is returned
        as one "user" message.

        Side effects: sets `self.merged_str` and `self.messages`.

        Args:
            merged_str: Prompt string, typically produced by `merge`.

        Returns:
            List of {"role": ..., "content": ...} dicts.
        """
        self.merged_str = merged_str
        self.messages = []

        pair_matches = list(self._PAIR_PATTERN.finditer(merged_str))
        inst_matches = list(self._INST_PATTERN.finditer(merged_str))

        for match in pair_matches:
            self.messages.append(
                {"role": "user", "content": match.group("inst").strip()}
            )
            self.messages.append(
                {"role": "assistant", "content": match.group("answer").strip()}
            )

        if len(inst_matches) > len(pair_matches):
            self.messages.append(
                {"role": "user", "content": inst_matches[-1].group("inst").strip()}
            )

        # Fall back to the raw text only when nothing could be parsed, so a
        # lone "[INST] .. [/INST]" is not reported twice.
        if not self.messages:
            self.messages.append({"role": "user", "content": merged_str})

        return self.messages


if __name__ == "__main__":
    # Small manual demo: build a prompt from a conversation, print it, then
    # parse the prompt back into messages and pretty-print the result.
    system_message = {
        "role": "system",
        "content": "You are a LLM developed by OpenAI. Your name is GPT-4.",
    }
    first_question = {"role": "user", "content": "Hello, who are you?"}
    # Two assistant turns in a row: merge() joins them into a single answer.
    first_answer = {"role": "assistant", "content": "I am a bot."}
    second_answer = {"role": "assistant", "content": "My name is Bing."}

    messages = [system_message, first_question, first_answer, second_answer]

    composer = MessageComposer()
    merged_str = composer.merge(messages)
    print(merged_str)

    recovered_messages = composer.split(merged_str)
    pprint(recovered_messages)