|
from typing import Dict, Any, List |
|
|
|
from langchain.memory import ConversationBufferWindowMemory |
|
from langchain.schema import ( |
|
BaseMessage, |
|
HumanMessage, |
|
AIMessage, |
|
) |
|
|
|
from .registry import BaseParent |
|
|
|
|
|
class ChatGLMConversationBufferWindowMemory(ConversationBufferWindowMemory):
    """Windowed conversation memory rendered in ChatGLM's ``[Round i]`` layout.

    ``load_memory_variables`` formats the history through
    ``self.get_buffer_string``, so a subclass that overrides that static
    method changes the rendered format without re-implementing the windowing.
    """

    human_prefix: str = "问"
    ai_prefix: str = "答"

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return the last ``k`` exchanges keyed by ``self.memory_key``.

        Args:
            inputs: Unused here; kept for signature compatibility with the
                parent class.

        Returns:
            A single-entry dict. The value is the windowed list of messages
            when ``self.return_messages`` is true, otherwise the string
            produced by ``get_buffer_string`` for that window.
        """
        # ``-0`` is simply ``0``, so ``buffer[-0:]`` would hand back the whole
        # history. A non-positive window must yield no messages instead.
        window: Any = self.buffer[-self.k * 2:] if self.k > 0 else []

        if self.return_messages:
            return {self.memory_key: window}

        history = self.get_buffer_string(
            window,
            human_prefix=self.human_prefix,
            ai_prefix=self.ai_prefix,
        )
        return {self.memory_key: history}

    @staticmethod
    def get_buffer_string(
        messages: List[BaseMessage], human_prefix: str = "问", ai_prefix: str = "答"
    ) -> str:
        """Render ``messages`` as numbered ChatGLM rounds.

        Each human message opens a new ``[Round i]`` section; AI messages are
        appended under the current round. The result always ends with a
        ``[Round i]`` marker for the next round (presumably so the caller's
        prompt template can append the new question after it).

        Args:
            messages: Conversation history, oldest first.
            human_prefix: Label written before each human message.
            ai_prefix: Label written before each AI message.

        Returns:
            The formatted history. For an empty list this is ``"\\n[Round 0]"``.

        Raises:
            ValueError: If a message is neither a ``HumanMessage`` nor an
                ``AIMessage``.
        """
        string_messages: List[str] = []
        round_index = 0
        for m in messages:
            if isinstance(m, HumanMessage):
                string_messages.append(
                    f"[Round {round_index}]\n{human_prefix}:{m.content}"
                )
                # Only human turns advance the round counter.
                round_index += 1
            elif isinstance(m, AIMessage):
                string_messages.append(f"{ai_prefix}:{m.content}")
            else:
                raise ValueError(f"Got unsupported message type: {m}")

        return "\n".join(string_messages) + f"\n[Round {round_index}]"
|
|
|
|
|
class ChineseAlpacaConversationBufferWindowMemory(ChatGLMConversationBufferWindowMemory):
    """Conversation memory rendered as ``### Instruction`` / ``### Response`` blocks."""

    human_prefix: str = "### Instruction"
    ai_prefix: str = "### Response"

    @staticmethod
    def get_buffer_string(
        messages: List[BaseMessage],
        human_prefix: str = "### Instruction",
        ai_prefix: str = "### Response",
    ) -> str:
        """Format ``messages`` as header/body blocks separated by blank lines.

        Every message becomes ``"<prefix>:\\n\\n<content>"``; the blocks are
        joined with ``"\\n\\n"``.

        Raises:
            ValueError: If a message is neither a ``HumanMessage`` nor an
                ``AIMessage``.
        """
        blocks = []
        for message in messages:
            # Both speaker kinds share one layout; only the header differs.
            if isinstance(message, HumanMessage):
                header = human_prefix
            elif isinstance(message, AIMessage):
                header = ai_prefix
            else:
                raise ValueError(f"Got unsupported message type: {message}")
            blocks.append(f"{header}:\n\n{message.content}")

        return "\n\n".join(blocks)
|
|
|
|
|
class FireFlyConversationBufferWindowMemory(ChatGLMConversationBufferWindowMemory):
    """Conversation memory rendered with ``<s>`` / ``</s>`` delimiters only."""

    human_prefix: str = "Human"
    ai_prefix: str = "Assistant"

    @staticmethod
    def get_buffer_string(
        messages: List[BaseMessage],
        human_prefix: str = "Human",
        ai_prefix: str = "Assistant",
    ) -> str:
        """Concatenate ``messages`` using sentence delimiters, without role labels.

        Human messages are emitted as ``<s>content</s>`` and AI messages as
        ``</s>content</s>``, joined with no separator. ``human_prefix`` and
        ``ai_prefix`` are accepted but do not appear in the output.

        Raises:
            ValueError: If a message is neither a ``HumanMessage`` nor an
                ``AIMessage``.
        """

        def _wrap(message: BaseMessage) -> str:
            # Pick the delimiter pair based on who spoke.
            if isinstance(message, HumanMessage):
                return f"<s>{message.content}</s>"
            if isinstance(message, AIMessage):
                return f"</s>{message.content}</s>"
            raise ValueError(f"Got unsupported message type: {message}")

        return "".join(_wrap(message) for message in messages)
|
|
|
|
|
class PhoenixConversationBufferWindowMemory(ChatGLMConversationBufferWindowMemory):
    """Conversation memory rendered as ``Role: <s>content</s>`` segments."""

    human_prefix: str = "Human"
    ai_prefix: str = "Assistant"

    @staticmethod
    def get_buffer_string(
        messages: List[BaseMessage],
        human_prefix: str = "Human",
        ai_prefix: str = "Assistant",
    ) -> str:
        """Format each message as ``"<prefix>: <s><content></s>"`` and concatenate.

        Segments are joined with no separator between them.

        Raises:
            ValueError: If a message is neither a ``HumanMessage`` nor an
                ``AIMessage``.
        """

        def _speaker(turn: BaseMessage) -> str:
            # Resolve the role label, rejecting unknown message kinds.
            if isinstance(turn, HumanMessage):
                return human_prefix
            if isinstance(turn, AIMessage):
                return ai_prefix
            raise ValueError(f"Got unsupported message type: {turn}")

        return "".join(
            f"{_speaker(turn)}: <s>{turn.content}</s>" for turn in messages
        )
|
|
|
|
|
class MossConversationBufferWindowMemory(ChatGLMConversationBufferWindowMemory):
    """Conversation memory rendered with ``<eoh>`` / ``<eom>`` end markers."""

    human_prefix: str = "<|Human|>"
    ai_prefix: str = "<|MOSS|>"

    @staticmethod
    def get_buffer_string(
        messages: List[BaseMessage],
        human_prefix: str = "<|Human|>",
        ai_prefix: str = "<|MOSS|>",
    ) -> str:
        """Format ``messages`` one per line, each closed by an end marker.

        Human messages end with ``<eoh>`` and AI messages with ``<eom>``;
        lines are joined with ``"\\n"``.

        Raises:
            ValueError: If a message is neither a ``HumanMessage`` nor an
                ``AIMessage``.
        """
        segments: List[str] = []
        for turn in messages:
            # Guard clause: reject anything that is not a human or AI turn.
            if not isinstance(turn, (HumanMessage, AIMessage)):
                raise ValueError(f"Got unsupported message type: {turn}")

            if isinstance(turn, HumanMessage):
                segments.append(f"{human_prefix}: {turn.content}<eoh>")
            else:
                segments.append(f"{ai_prefix}: {turn.content}<eom>")

        return "\n".join(segments)
|
|
|
|
|
class GuanacoConversationBufferWindowMemory(ChatGLMConversationBufferWindowMemory):
    """Conversation memory rendered as ``### Human`` / ``### Assistant`` lines."""

    human_prefix: str = "### Human"
    ai_prefix: str = "### Assistant"

    @staticmethod
    def get_buffer_string(
        messages: List[BaseMessage],
        human_prefix: str = "### Human",
        ai_prefix: str = "### Assistant",
    ) -> str:
        """Format each message as ``"<prefix>: <content>"``, one per line.

        Raises:
            ValueError: If a message is neither a ``HumanMessage`` nor an
                ``AIMessage``.
        """
        # Checked in order, so HumanMessage takes precedence over AIMessage.
        prefix_by_type = (
            (HumanMessage, human_prefix),
            (AIMessage, ai_prefix),
        )

        lines = []
        for entry in messages:
            for message_type, speaker in prefix_by_type:
                if isinstance(entry, message_type):
                    lines.append(f"{speaker}: {entry.content}")
                    break
            else:
                # No known message type matched this entry.
                raise ValueError(f"Got unsupported message type: {entry}")

        return "\n".join(lines)
|
|
|
|
|
class CustomConversationBufferWindowMemory(BaseParent):
    """Registry holder for the conversation-memory classes in this module.

    Entries are added through ``add_to_registry``, which this class does not
    define itself and therefore inherits from ``BaseParent``.
    """

    # Starts empty; filled by the module-level ``add_to_registry`` calls.
    registry = dict()
|
|
|
|
|
# Register one memory class per supported model name. "gpt-3.5-turbo" uses the
# stock LangChain memory; every other name gets a model-specific formatter.
for _model_name, _memory_cls in (
    ("gpt-3.5-turbo", ConversationBufferWindowMemory),
    ("chatglm", ChatGLMConversationBufferWindowMemory),
    ("chinese-llama-alpaca", ChineseAlpacaConversationBufferWindowMemory),
    ("firefly", FireFlyConversationBufferWindowMemory),
    ("phoenix", PhoenixConversationBufferWindowMemory),
    ("moss", MossConversationBufferWindowMemory),
    ("guanaco", GuanacoConversationBufferWindowMemory),
):
    CustomConversationBufferWindowMemory.add_to_registry(_model_name, _memory_cls)

# Drop the loop variables so the module namespace gains no extra names.
del _model_name, _memory_cls
|
|