from typing import Any, Dict, List

from langchain.memory import ConversationBufferWindowMemory
from langchain.schema import (
    AIMessage,
    BaseMessage,
    HumanMessage,
)

from .registry import BaseParent


class ChatGLMConversationBufferWindowMemory(ConversationBufferWindowMemory):
    """Windowed conversation memory that renders history in ChatGLM's round format.

    The model-specific subclasses below reuse ``load_memory_variables`` and only
    override the prefixes and ``get_buffer_string`` to change how the windowed
    history is rendered into a prompt string.
    """

    human_prefix: str = "问"
    ai_prefix: str = "答"

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return the last ``k`` exchanges of history under ``memory_key``.

        Args:
            inputs: Chain inputs (unused; required by the memory interface).

        Returns:
            ``{memory_key: history}``, where history is the list of windowed
            messages when ``return_messages`` is true, otherwise the string
            produced by this class's ``get_buffer_string``.
        """
        # Read the raw message list rather than ``self.buffer``: in newer
        # langchain releases ``buffer`` may already be a rendered string when
        # ``return_messages`` is false, and slicing that would be wrong.
        messages = self.chat_memory.messages
        # One exchange is a human + AI message pair, hence ``k * 2``.
        # Guard k <= 0: ``messages[-0:]`` would return the whole history.
        window: List[BaseMessage] = messages[-self.k * 2:] if self.k > 0 else []
        if self.return_messages:
            return {self.memory_key: window}
        # ``get_buffer_string`` is a staticmethod looked up on the instance, so
        # each subclass's override is used for rendering.
        return {
            self.memory_key: self.get_buffer_string(
                window,
                human_prefix=self.human_prefix,
                ai_prefix=self.ai_prefix,
            )
        }

    @staticmethod
    def get_buffer_string(
        messages: List[BaseMessage], human_prefix: str = "问", ai_prefix: str = "答"
    ) -> str:
        """Render messages as ChatGLM rounds, ending with the next round's header.

        Each human message opens a new ``[Round i]`` block (0-based). The result
        always ends with the ``[Round i]`` header of the round the next user
        input will fill.

        Raises:
            ValueError: if a message is neither a HumanMessage nor an AIMessage.
        """
        lines: List[str] = []
        round_idx = 0
        for m in messages:
            if isinstance(m, HumanMessage):
                lines.append(f"[Round {round_idx}]\n{human_prefix}:{m.content}")
                round_idx += 1
            elif isinstance(m, AIMessage):
                lines.append(f"{ai_prefix}:{m.content}")
            else:
                raise ValueError(f"Got unsupported message type: {m}")
        # Add the trailing header as a list item (instead of concatenating
        # "\n[Round i]") so an empty history does not start with a newline.
        lines.append(f"[Round {round_idx}]")
        return "\n".join(lines)


class ChineseAlpacaConversationBufferWindowMemory(ChatGLMConversationBufferWindowMemory):
    """Windowed memory rendering history as Alpaca-style instruction/response blocks."""

    human_prefix: str = "### Instruction"
    ai_prefix: str = "### Response"

    @staticmethod
    def get_buffer_string(
        messages: List[BaseMessage],
        human_prefix: str = "### Instruction",
        ai_prefix: str = "### Response",
    ) -> str:
        """Render each message as ``"<prefix>:\\n\\n<content>"``, joined by blank lines.

        Raises:
            ValueError: if a message is neither a HumanMessage nor an AIMessage.
        """
        string_messages = []
        for m in messages:
            if isinstance(m, HumanMessage):
                role = human_prefix
            elif isinstance(m, AIMessage):
                role = ai_prefix
            else:
                raise ValueError(f"Got unsupported message type: {m}")
            string_messages.append(f"{role}:\n\n{m.content}")
        return "\n\n".join(string_messages)


class FireFlyConversationBufferWindowMemory(ChatGLMConversationBufferWindowMemory):
    """Windowed memory for Firefly models; concatenates raw message contents."""

    human_prefix: str = "Human"
    ai_prefix: str = "Assistant"

    @staticmethod
    def get_buffer_string(
        messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "Assistant"
    ) -> str:
        """Concatenate message contents with no role prefix or separator.

        ``human_prefix`` and ``ai_prefix`` are accepted for interface
        compatibility but are not used.

        Raises:
            ValueError: if a message is neither a HumanMessage nor an AIMessage.
        """
        # NOTE(review): no separator/special tokens between turns, so the
        # model cannot see turn boundaries. Confirm against the Firefly prompt
        # format (which may expect <s>/</s> markers) before relying on this.
        string_messages = []
        for m in messages:
            if isinstance(m, (HumanMessage, AIMessage)):
                string_messages.append(f"{m.content}")
            else:
                raise ValueError(f"Got unsupported message type: {m}")
        return "".join(string_messages)


class PhoenixConversationBufferWindowMemory(ChatGLMConversationBufferWindowMemory):
    """Windowed memory rendering history as ``"Human: ...Assistant: ..."`` text."""

    human_prefix: str = "Human"
    ai_prefix: str = "Assistant"

    @staticmethod
    def get_buffer_string(
        messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "Assistant"
    ) -> str:
        """Render each message as ``"<prefix>: <content>"`` and concatenate them.

        Raises:
            ValueError: if a message is neither a HumanMessage nor an AIMessage.
        """
        # NOTE(review): turns are joined with "" and carry no terminator, so
        # content runs straight into the next prefix. Confirm against the
        # Phoenix prompt format (possibly <s>...</s> around each message).
        string_messages = []
        for m in messages:
            if isinstance(m, HumanMessage):
                role = human_prefix
            elif isinstance(m, AIMessage):
                role = ai_prefix
            else:
                raise ValueError(f"Got unsupported message type: {m}")
            string_messages.append(f"{role}: {m.content}")
        return "".join(string_messages)


class MossConversationBufferWindowMemory(ChatGLMConversationBufferWindowMemory):
    """Windowed memory rendering history with MOSS ``<|Human|>``/``<|MOSS|>`` roles."""

    human_prefix: str = "<|Human|>"
    ai_prefix: str = "<|MOSS|>"

    @staticmethod
    def get_buffer_string(
        messages: List[BaseMessage], human_prefix: str = "<|Human|>", ai_prefix: str = "<|MOSS|>"
    ) -> str:
        """Render each message as ``"<prefix>: <content>"``, one per line.

        Raises:
            ValueError: if a message is neither a HumanMessage nor an AIMessage.
        """
        # NOTE(review): MOSS prompts may also expect end-of-turn markers
        # (e.g. <eoh>/<eom>); none are emitted here — confirm with the prompt.
        string_messages = []
        for m in messages:
            if isinstance(m, HumanMessage):
                string_messages.append(f"{human_prefix}: {m.content}")
            elif isinstance(m, AIMessage):
                string_messages.append(f"{ai_prefix}: {m.content}")
            else:
                raise ValueError(f"Got unsupported message type: {m}")
        return "\n".join(string_messages)


class GuanacoConversationBufferWindowMemory(ChatGLMConversationBufferWindowMemory):
    """Windowed memory rendering history with ``### Human``/``### Assistant`` roles."""

    human_prefix: str = "### Human"
    ai_prefix: str = "### Assistant"

    @staticmethod
    def get_buffer_string(
        messages: List[BaseMessage], human_prefix: str = "### Human", ai_prefix: str = "### Assistant"
    ) -> str:
        """Render each message as ``"<prefix>: <content>"``, one per line.

        Raises:
            ValueError: if a message is neither a HumanMessage nor an AIMessage.
        """
        string_messages = []
        for m in messages:
            if isinstance(m, HumanMessage):
                string_messages.append(f"{human_prefix}: {m.content}")
            elif isinstance(m, AIMessage):
                string_messages.append(f"{ai_prefix}: {m.content}")
            else:
                raise ValueError(f"Got unsupported message type: {m}")
        return "\n".join(string_messages)


class CustomConversationBufferWindowMemory(BaseParent):
    """Registry mapping model names to their windowed-memory classes."""

    # Defined on this subclass so its entries are kept separate from any other
    # BaseParent registry (assumes BaseParent.add_to_registry writes to
    # ``cls.registry`` — verify in .registry).
    registry = {}


# Model name -> memory class used to render conversation history for it.
CustomConversationBufferWindowMemory.add_to_registry("gpt-3.5-turbo", ConversationBufferWindowMemory)
CustomConversationBufferWindowMemory.add_to_registry("chatglm", ChatGLMConversationBufferWindowMemory)
CustomConversationBufferWindowMemory.add_to_registry("chinese-llama-alpaca", ChineseAlpacaConversationBufferWindowMemory)
CustomConversationBufferWindowMemory.add_to_registry("firefly", FireFlyConversationBufferWindowMemory)
CustomConversationBufferWindowMemory.add_to_registry("phoenix", PhoenixConversationBufferWindowMemory)
CustomConversationBufferWindowMemory.add_to_registry("moss", MossConversationBufferWindowMemory)
CustomConversationBufferWindowMemory.add_to_registry("guanaco", GuanacoConversationBufferWindowMemory)