File size: 2,931 Bytes
084fe8e
 
 
 
 
 
 
 
 
 
acb3380
084fe8e
 
acb3380
084fe8e
 
 
acb3380
 
084fe8e
 
 
 
 
 
acb3380
 
 
 
 
084fe8e
 
 
 
 
 
acb3380
 
 
 
 
 
 
 
084fe8e
 
 
 
 
acb3380
 
084fe8e
 
 
 
 
 
 
 
acb3380
084fe8e
 
 
 
acb3380
084fe8e
 
acb3380
084fe8e
 
 
acb3380
 
084fe8e
 
 
acb3380
 
084fe8e
 
 
acb3380
 
084fe8e
 
acb3380
 
084fe8e
acb3380
 
 
 
084fe8e
 
 
 
acb3380
084fe8e
 
acb3380
084fe8e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Type,
    TypeVar,
    Union,
)

# Unconstrained generic type variable.
# NOTE(review): no definition visible in this section references ``T``;
# presumably it is intended for signatures that should preserve the concrete
# BaseMessenger subclass type — confirm it is used elsewhere, or remove it.
T = TypeVar("T")


class BaseMessenger:
    """Registry-backed base class for role/content message containers.

    Concrete messengers register themselves under a name with
    :meth:`register_messenger`. Calling ``BaseMessenger(name, ...)`` then
    returns an instance of the class registered under ``name`` rather than
    an instance of ``BaseMessenger`` itself.

    Subclasses are expected to implement :meth:`init_messenger`; the base
    implementation raises ``NotImplementedError``.
    """

    # Maps a registered name to the messenger class stored for that name.
    # The dict is defined once here, so registrations made through any
    # subclass land in this same mapping unless a subclass rebinds it.
    _messenger_registry: Dict[str, Type["BaseMessenger"]] = {}

    # Message history: one ``{"role": ..., "content": ...}`` dict per entry.
    messages: List[Dict[str, Union[str, Dict[str, Any], List[Any]]]]

    @classmethod
    def register_messenger(
        cls, messenger_name: str
    ) -> Callable[[Type["BaseMessenger"]], Type["BaseMessenger"]]:
        """Return a class decorator that registers a messenger class.

        Args:
            messenger_name: Key under which the decorated class is stored.
                An existing entry with the same key is overwritten.

        Returns:
            A decorator that records the class in the registry and returns
            the class unchanged.
        """

        def decorator(
            subclass: Type["BaseMessenger"],
        ) -> Type["BaseMessenger"]:
            cls._messenger_registry[messenger_name] = subclass
            return subclass

        return decorator

    def __new__(
        cls: Type["BaseMessenger"],
        messenger_name: str,
        *args: Any,
        **kwargs: Any,
    ) -> "BaseMessenger":
        """Allocate an instance of the class registered as ``messenger_name``.

        Args:
            messenger_name: Registry key selecting the concrete class.
            *args: Ignored here; Python forwards them to ``__init__``.
            **kwargs: Ignored here; Python forwards them to ``__init__``.

        Raises:
            ValueError: If no class is registered under ``messenger_name``.
        """
        try:
            target = cls._messenger_registry[messenger_name]
        except KeyError:
            raise ValueError(
                f"No messenger registered with name '{messenger_name}'"
            ) from None
        return super(BaseMessenger, cls).__new__(target)

    def __init__(
        self,
        messenger_name: str,
        role: Optional[str] = None,
        content: Optional[Union[str, Dict[str, Any], List[Any]]] = None,
    ) -> None:
        """Initialise the message history and run the subclass hook.

        Args:
            messenger_name: Registry key already consumed by ``__new__``.
                It must be accepted here as well because Python passes the
                same constructor arguments to both ``__new__`` and
                ``__init__``; without it the name would be bound to
                ``role``.
            role: Optional role forwarded to :meth:`init_messenger`.
            content: Optional content forwarded to :meth:`init_messenger`.
        """
        # Start from an empty history so the accessors below work even when
        # the subclass hook does not create ``messages`` itself; the hook is
        # free to rebind it.
        self.messages = []
        self.init_messenger(role, content)

    def init_messenger(
        self,
        role: Optional[str] = None,
        content: Optional[Union[str, Dict[str, Any], List[Any]]] = None,
    ) -> None:
        """Subclass hook invoked by ``__init__``.

        Args:
            role: Role passed to the constructor, if any.
            content: Content passed to the constructor, if any.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            "The 'init_messenger' method must be implemented in derived classes."
        )

    def update_message(
        self, role: str, content: Union[str, Dict[str, Any], List[Any]]
    ) -> None:
        """Append a ``{"role": role, "content": content}`` entry."""
        self.messages.append({"role": role, "content": content})

    def check_iter_round_num(self) -> int:
        """Return the number of entries currently stored in ``messages``."""
        return len(self.messages)

    def add_system_message(
        self, message: Union[str, Dict[str, Any], List[Any]]
    ) -> None:
        """Append ``message`` with the ``"system"`` role."""
        self.update_message("system", message)

    def add_assistant_message(
        self, message: Union[str, Dict[str, Any], List[Any]]
    ) -> None:
        """Append ``message`` with the ``"assistant"`` role."""
        self.update_message("assistant", message)

    def add_user_message(
        self, message: Union[str, Dict[str, Any], List[Any]]
    ) -> None:
        """Append ``message`` with the ``"user"`` role."""
        self.update_message("user", message)

    def add_user_image(
        self, image_base64: str, mime_type: str = "image/jpeg"
    ) -> None:
        """Append a ``"user"`` entry holding an image as a data URL.

        Args:
            image_base64: Image payload, inserted verbatim after
                ``;base64,`` in the data URL (no validation is performed).
            mime_type: Media type written into the data URL. Defaults to
                ``"image/jpeg"``, the value that was previously hard-coded.
        """
        self.update_message(
            "user",
            {
                "type": "image",
                "image_url": f"data:{mime_type};base64,{image_base64}",
            },
        )

    def add_feedback(
        self, feedback: Union[str, Dict[str, Any], List[Any]]
    ) -> None:
        """Append ``feedback``; it is stored under the ``"system"`` role."""
        self.update_message("system", feedback)

    def clear(self) -> None:
        """Remove every entry from ``messages`` in place."""
        self.messages.clear()

    def get_messages(
        self,
    ) -> List[Dict[str, Union[str, Dict[str, Any], List[Any]]]]:
        """Return the ``messages`` list itself (not a copy)."""
        return self.messages