|
from typing import Any, Dict, List, Optional |
|
from langchain.memory.chat_memory import BaseChatMemory |
|
from langchain.schema import BaseMessage, HumanMessage, AIMessage, SystemMessage, ChatMessage |
|
|
|
|
|
def get_buffer_string(
    messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
    """Render a list of chat messages as a single newline-separated string.

    Each message becomes one entry of the form ``<role prefix><content>``:

    * ``HumanMessage``  -> ``"<human_prefix>: <content>"``
    * ``AIMessage``     -> ``"<content>"`` (no prefix is added)
    * ``SystemMessage`` -> ``"System: <content>"``
    * ``ChatMessage``   -> ``"<message.role>: <content>"``

    Args:
        messages: Messages to render, in the order they should appear.
        human_prefix: Label placed before the content of human messages.
        ai_prefix: Accepted so callers can keep passing it, but not used:
            AI messages are emitted without any prefix.

    Returns:
        The rendered entries joined with ``"\\n"``; an empty string when
        ``messages`` is empty.

    Raises:
        ValueError: If a message is not one of the four supported types.
    """
    rendered = []
    for message in messages:
        # The isinstance checks are kept in this exact order so a message
        # matching several types resolves the same way as before.
        if isinstance(message, HumanMessage):
            role = human_prefix + ": "
        elif isinstance(message, AIMessage):
            # AI output is appended verbatim, without a role label.
            role = ""
        elif isinstance(message, SystemMessage):
            role = "System: "
        elif isinstance(message, ChatMessage):
            role = message.role + ": "
        else:
            raise ValueError(f"Got unsupported message type: {message}")

        rendered.append(role + message.content)

    return "\n".join(rendered)
|
|
|
|
|
class HumenFeedbackBufferMemory(BaseChatMemory):
    """Conversation memory that exposes the full stored history under one key.

    Depending on ``return_messages`` the history is handed back either as
    the raw message list or as a string built by ``get_buffer_string``.
    ``return_messages`` and ``chat_memory`` are not declared here; they are
    presumably provided by ``BaseChatMemory`` -- confirm against the base class.
    """

    # Prefixes forwarded to get_buffer_string when rendering a string.
    human_prefix: str = "Human"
    ai_prefix: str = "AI"
    # Name of the single variable this memory contributes.
    memory_key: str = "history"

    @property
    def buffer(self) -> Any:
        """Current history: the message list or its string rendering."""
        if not self.return_messages:
            return get_buffer_string(
                self.chat_memory.messages,
                human_prefix=self.human_prefix,
                ai_prefix=self.ai_prefix,
            )
        # Caller asked for message objects: return them untouched.
        return self.chat_memory.messages

    @property
    def memory_variables(self) -> List[str]:
        """Names of the variables this memory provides (always one entry).

        :meta private:
        """
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return the history buffer keyed by ``memory_key``.

        ``inputs`` is accepted but not consulted.
        """
        key = self.memory_key
        return {key: self.buffer}
|
|