File size: 3,093 Bytes
360d784
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2023/5/8 22:12
@Author  : alexanderwu
@File    : schema.py
@Desc    : mashenquan, 2023/8/22. Add tags to enable custom message classification.
"""
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from typing import Type, TypedDict, Set, Optional, List

from pydantic import BaseModel

from metagpt.logs import logger


class MessageTag(Enum):
    """Predefined tag values for custom message classification.

    Presumably intended to be stored in ``Message.tags`` — confirm against callers.
    """

    Prerequisite = "prerequisite"


class RawMessage(TypedDict):
    """Plain-dict shape of a message: exactly a ``content`` and a ``role`` string."""

    # Message text.
    content: str
    # Speaker role string (the rest of this file uses e.g. 'user', 'system', 'assistant').
    role: str


@dataclass
class Message:
    """A single message rendered as ``<role>: <content>``.

    Besides the ``role``/``content`` pair, a message can carry structured
    ``instruct_content``, a ``cause_by`` reference (annotated as an ``Action``
    type — presumably the action that produced it), routing fields
    ``sent_from``/``send_to``, and an optional set of ``tags`` used for custom
    message classification.
    """
    content: str
    instruct_content: BaseModel = None
    role: str = 'user'  # system / user / assistant
    cause_by: Type["Action"] = ""
    sent_from: str = ""
    send_to: str = ""
    tags: Optional[Set] = None  # created lazily by add_tag()

    def __str__(self):
        """Render as ``"<role>: <content>"``."""
        return "{}: {}".format(self.role, self.content)

    def __repr__(self):
        """Use the same compact form as ``str()``."""
        return str(self)

    def to_dict(self) -> dict:
        """Return only the ``role`` and ``content`` of this message, in that key order."""
        payload = {"role": self.role}
        payload["content"] = self.content
        return payload

    def add_tag(self, tag):
        """Add ``tag`` to ``self.tags``, creating the set on first use."""
        current = self.tags
        if current is None:
            # Assign the fresh set before adding, so it persists even if add() raises.
            current = self.tags = set()
        current.add(tag)

    def remove_tag(self, tag):
        """Remove ``tag`` if present; a missing tag or missing tag set is a no-op."""
        tag_set = self.tags
        if tag_set is not None and tag in tag_set:
            tag_set.remove(tag)

    def is_contain_tags(self, tags: list) -> bool:
        """Return True if at least one element of ``tags`` is among this message's tags.

        An empty ``tags`` argument or an empty/unset ``self.tags`` yields False.
        """
        if tags and self.tags:
            return bool(set(tags) & self.tags)
        return False

    def is_contain(self, tag):
        """Return True if this message carries ``tag``."""
        return self.is_contain_tags((tag,))

    def dict(self):
        """Pydantic-like ``dict()``: ``content`` plus every optional field whose value is truthy.

        NOTE(review): ``role`` and ``cause_by`` are not emitted — confirm this is intended.
        """
        optional_fields = (
            ("instruct_content", self.instruct_content),
            ("sent_from", self.sent_from),
            ("send_to", self.send_to),
            ("tags", self.tags),
        )
        result = {"content": self.content}
        # Falsy values (None, "", empty set) are omitted.
        result.update((key, value) for key, value in optional_fields if value)
        return result


@dataclass(repr=False)
class UserMessage(Message):
    """Message whose role is ``'user'``; eases support for OpenAI-style messages.

    ``repr=False`` keeps the inherited ``Message.__repr__`` (``"<role>: <content>"``);
    without it the decorator would generate a field-by-field repr for this subclass.
    """

    def __init__(self, content: str):
        # Pass by keyword: positionally, Message's second dataclass field is
        # ``instruct_content``, not ``role``.
        super().__init__(content=content, role='user')


@dataclass(repr=False)
class SystemMessage(Message):
    """Message whose role is ``'system'``; eases support for OpenAI-style messages.

    ``repr=False`` keeps the inherited ``Message.__repr__`` (``"<role>: <content>"``);
    without it the decorator would generate a field-by-field repr for this subclass.
    """

    def __init__(self, content: str):
        # Pass by keyword: positionally, Message's second dataclass field is
        # ``instruct_content``, so 'system' used to land there while role stayed 'user'.
        super().__init__(content=content, role='system')


@dataclass(repr=False)
class AIMessage(Message):
    """Message whose role is ``'assistant'``; eases support for OpenAI-style messages.

    ``repr=False`` keeps the inherited ``Message.__repr__`` (``"<role>: <content>"``);
    without it the decorator would generate a field-by-field repr for this subclass.
    """

    def __init__(self, content: str):
        # Pass by keyword: positionally, Message's second dataclass field is
        # ``instruct_content``, so 'assistant' used to land there while role stayed 'user'.
        super().__init__(content=content, role='assistant')


if __name__ == '__main__':
    # Smoke test: build one message of each kind from the same text and log them together.
    sample_text = 'test_message'
    role_specific = [
        message_type(sample_text)
        for message_type in (UserMessage, SystemMessage, AIMessage)
    ]
    logger.info(role_specific + [Message(sample_text, role='QA')])