Zack Zitting Bradshaw
Upload folder using huggingface_hub
4962437
raw
history blame contribute delete
No virus
7.48 kB
from __future__ import annotations
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Sequence
from pydantic import Field
from swarms.utils.serializable import Serializable
if TYPE_CHECKING:
from langchain.prompts.chat import ChatPromptTemplate
def get_buffer_string(
messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
"""Convert sequence of Messages to strings and concatenate them into one string.
Args:
messages: Messages to be converted to strings.
human_prefix: The prefix to prepend to contents of HumanMessages.
ai_prefix: THe prefix to prepend to contents of AIMessages.
Returns:
A single string concatenation of all input messages.
Example:
.. code-block:: python
from langchain.schema import AIMessage, HumanMessage
messages = [
HumanMessage(content="Hi, how are you?"),
AIMessage(content="Good, how are you?"),
]
get_buffer_string(messages)
# -> "Human: Hi, how are you?\nAI: Good, how are you?"
"""
string_messages = []
for m in messages:
if isinstance(m, HumanMessage):
role = human_prefix
elif isinstance(m, AIMessage):
role = ai_prefix
elif isinstance(m, SystemMessage):
role = "System"
elif isinstance(m, FunctionMessage):
role = "Function"
elif isinstance(m, ChatMessage):
role = m.role
else:
raise ValueError(f"Got unsupported message type: {m}")
message = f"{role}: {m.content}"
if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
message += f"{m.additional_kwargs['function_call']}"
string_messages.append(message)
return "\n".join(string_messages)
class BaseMessage(Serializable):
"""The base abstract Message class.
Messages are the inputs and outputs of ChatModels.
"""
content: str
"""The string contents of the message."""
additional_kwargs: dict = Field(default_factory=dict)
"""Any additional information."""
@property
@abstractmethod
def type(self) -> str:
"""Type of the Message, used for serialization."""
@property
def lc_serializable(self) -> bool:
"""Whether this class is LangChain serializable."""
return True
def __add__(self, other: Any) -> ChatPromptTemplate:
from langchain.prompts.chat import ChatPromptTemplate
prompt = ChatPromptTemplate(messages=[self])
return prompt + other
class BaseMessageChunk(BaseMessage):
def _merge_kwargs_dict(
self, left: Dict[str, Any], right: Dict[str, Any]
) -> Dict[str, Any]:
"""Merge additional_kwargs from another BaseMessageChunk into this one."""
merged = left.copy()
for k, v in right.items():
if k not in merged:
merged[k] = v
elif type(merged[k]) != type(v):
raise ValueError(
f'additional_kwargs["{k}"] already exists in this message,'
" but with a different type."
)
elif isinstance(merged[k], str):
merged[k] += v
elif isinstance(merged[k], dict):
merged[k] = self._merge_kwargs_dict(merged[k], v)
else:
raise ValueError(
f"Additional kwargs key {k} already exists in this message."
)
return merged
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, BaseMessageChunk):
# If both are (subclasses of) BaseMessageChunk,
# concat into a single BaseMessageChunk
return self.__class__(
content=self.content + other.content,
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
else:
raise TypeError(
'unsupported operand type(s) for +: "'
f"{self.__class__.__name__}"
f'" and "{other.__class__.__name__}"'
)
class HumanMessage(BaseMessage):
"""A Message from a human."""
example: bool = False
"""Whether this Message is being passed in to the model as part of an example
conversation.
"""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "human"
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
pass
class AIMessage(BaseMessage):
"""A Message from an AI."""
example: bool = False
"""Whether this Message is being passed in to the model as part of an example
conversation.
"""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "ai"
class AIMessageChunk(AIMessage, BaseMessageChunk):
pass
class SystemMessage(BaseMessage):
"""A Message for priming AI behavior, usually passed in as the first of a sequence
of input messages.
"""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "system"
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
pass
class FunctionMessage(BaseMessage):
"""A Message for passing the result of executing a function back to a model."""
name: str
"""The name of the function that was executed."""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "function"
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
pass
class ChatMessage(BaseMessage):
"""A Message that can be assigned an arbitrary speaker (i.e. role)."""
role: str
"""The speaker / role of the Message."""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "chat"
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
pass
def _message_to_dict(message: BaseMessage) -> dict:
return {"type": message.type, "data": message.dict()}
def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
"""Convert a sequence of Messages to a list of dictionaries.
Args:
messages: Sequence of messages (as BaseMessages) to convert.
Returns:
List of messages as dicts.
"""
return [_message_to_dict(m) for m in messages]
def _message_from_dict(message: dict) -> BaseMessage:
_type = message["type"]
if _type == "human":
return HumanMessage(**message["data"])
elif _type == "ai":
return AIMessage(**message["data"])
elif _type == "system":
return SystemMessage(**message["data"])
elif _type == "chat":
return ChatMessage(**message["data"])
elif _type == "function":
return FunctionMessage(**message["data"])
else:
raise ValueError(f"Got unexpected message type: {_type}")
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
"""Convert a sequence of messages from dicts to Message objects.
Args:
messages: Sequence of messages (as dicts) to convert.
Returns:
List of messages (BaseMessages).
"""
return [_message_from_dict(m) for m in messages]