File size: 2,931 Bytes
084fe8e acb3380 084fe8e acb3380 084fe8e acb3380 084fe8e acb3380 084fe8e acb3380 084fe8e acb3380 084fe8e acb3380 084fe8e acb3380 084fe8e acb3380 084fe8e acb3380 084fe8e acb3380 084fe8e acb3380 084fe8e acb3380 084fe8e acb3380 084fe8e acb3380 084fe8e acb3380 084fe8e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Type,
TypeVar,
Union,
)
# Type variable bound to BaseMessenger, for annotating methods that accept or
# return instances of a BaseMessenger subclass (e.g. ``def f(self: T) -> T``).
# The bound is a string forward reference because the class is defined below.
T = TypeVar("T", bound="BaseMessenger")
class BaseMessenger:
    """Registry-backed base class for chat-message containers.

    Concrete messengers register themselves under a string name with
    :meth:`register_messenger`. Calling ``BaseMessenger(name, role, content)``
    looks ``name`` up in the shared registry and returns an instance of the
    registered subclass. Subclasses must implement :meth:`init_messenger`,
    which is expected to create ``self.messages`` (a list); every other helper
    here reads or appends to that list.
    """

    # Maps registered messenger names to their classes. A single dict shared
    # by the base class and all subclasses (subclasses never rebind it).
    _messenger_registry: Dict[str, Type["BaseMessenger"]] = {}

    # Created by the subclass's ``init_messenger``; never assigned here.
    messages: List[Dict[str, Any]]

    @classmethod
    def register_messenger(
        cls, messenger_name: str
    ) -> Callable[[Type["BaseMessenger"]], Type["BaseMessenger"]]:
        """Return a class decorator that registers a subclass under a name.

        Registering a name that already exists replaces the earlier class.

        Args:
            messenger_name: Key later passed to ``BaseMessenger(...)``.

        Returns:
            A decorator that records the class and returns it unchanged.

        Raises:
            TypeError: (raised by the decorator) if the decorated object is
                not a subclass of ``BaseMessenger``.
        """

        def decorator(
            subclass: Type["BaseMessenger"],
        ) -> Type["BaseMessenger"]:
            # An instance of a non-subclass would be returned from __new__
            # without Python running __init__ on it, so reject it up front.
            if not (
                isinstance(subclass, type)
                and issubclass(subclass, BaseMessenger)
            ):
                raise TypeError(
                    "register_messenger expects a BaseMessenger subclass, "
                    f"got {subclass!r}"
                )
            cls._messenger_registry[messenger_name] = subclass
            return subclass

        return decorator

    def __new__(
        cls: Type["BaseMessenger"],
        messenger_name: Optional[str] = None,
        *args: Any,
        **kwargs: Any,
    ) -> "BaseMessenger":
        """Create an instance of the class registered as ``messenger_name``.

        Args:
            messenger_name: Registry key of the concrete messenger. It may be
                omitted only when calling a concrete subclass directly. That
                is also how ``copy``/``deepcopy``/``pickle`` rebuild objects
                (they call ``cls.__new__(cls)`` with no other arguments).
            *args: Ignored here; Python forwards them to ``__init__``.
            **kwargs: Ignored here; Python forwards them to ``__init__``.

        Raises:
            TypeError: If ``messenger_name`` is omitted on ``BaseMessenger``.
            ValueError: If no messenger is registered under the name.
        """
        if messenger_name is None:
            if cls is BaseMessenger:
                raise TypeError(
                    "BaseMessenger() missing required argument: "
                    "'messenger_name'"
                )
            return super(BaseMessenger, cls).__new__(cls)
        try:
            target = cls._messenger_registry[messenger_name]
        except KeyError:
            raise ValueError(
                f"No messenger registered with name '{messenger_name}'"
            ) from None
        return super(BaseMessenger, cls).__new__(target)

    def __init__(
        self,
        messenger_name: Optional[str] = None,
        role: Optional[str] = None,
        content: Optional[Union[str, Dict[str, Any], List[Any]]] = None,
    ) -> None:
        """Initialise the messenger by delegating to ``init_messenger``.

        Python passes ``__init__`` the same arguments it gave ``__new__``, so
        ``messenger_name`` has to be accepted here as well. Otherwise the
        registry name would be bound to ``role``, and a call with all three
        positional arguments would raise ``TypeError``. It is not used here.

        Args:
            messenger_name: Registry key (already consumed by ``__new__``).
            role: Optional role forwarded to ``init_messenger``.
            content: Optional content forwarded to ``init_messenger``.
        """
        self.init_messenger(role, content)

    def init_messenger(
        self,
        role: Optional[str] = None,
        content: Optional[Union[str, Dict[str, Any], List[Any]]] = None,
    ) -> None:
        """Set up per-instance message state; subclasses must override this.

        The other helpers on this class expect the override to create
        ``self.messages`` as a list.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            "The 'init_messenger' method must be implemented in derived classes."
        )

    def update_message(
        self, role: str, content: Union[str, Dict[str, Any], List[Any]]
    ) -> None:
        """Append a ``{"role": role, "content": content}`` entry to the log."""
        self.messages.append({"role": role, "content": content})

    def check_iter_round_num(self) -> int:
        """Return the number of stored messages (entries, not exchange rounds)."""
        return len(self.messages)

    def add_system_message(
        self, message: Union[str, Dict[str, Any], List[Any]]
    ) -> None:
        """Append ``message`` with the ``"system"`` role."""
        self.update_message("system", message)

    def add_assistant_message(
        self, message: Union[str, Dict[str, Any], List[Any]]
    ) -> None:
        """Append ``message`` with the ``"assistant"`` role."""
        self.update_message("assistant", message)

    def add_user_message(
        self, message: Union[str, Dict[str, Any], List[Any]]
    ) -> None:
        """Append ``message`` with the ``"user"`` role."""
        self.update_message("user", message)

    def add_user_image(
        self, image_base64: str, mime_type: str = "image/jpeg"
    ) -> None:
        """Append a base64-encoded image as a ``"user"`` message.

        Args:
            image_base64: Base64 payload (no ``data:`` prefix); it is embedded
                verbatim, not validated.
            mime_type: MIME type written into the data URL. Defaults to
                ``"image/jpeg"``, which the original hard-coded.
        """
        self.update_message(
            "user",
            {
                "type": "image",
                "image_url": f"data:{mime_type};base64,{image_base64}",
            },
        )

    def add_feedback(
        self, feedback: Union[str, Dict[str, Any], List[Any]]
    ) -> None:
        """Append ``feedback`` to the log under the ``"system"`` role."""
        self.update_message("system", feedback)

    def clear(self) -> None:
        """Remove all messages in place.

        Lists previously returned by ``get_messages`` are the same object, so
        they become empty as well.
        """
        self.messages.clear()

    def get_messages(
        self,
    ) -> List[Dict[str, Union[str, Dict[str, Any], List[Any]]]]:
        """Return the live message list (not a copy); mutations are shared."""
        return self.messages
|