File size: 1,943 Bytes
8e967a7 474437d 8e967a7 474437d 8e967a7 474437d 8e967a7 474437d 8e967a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from typing import Any, Dict, List, Optional
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import BaseMessage, HumanMessage, AIMessage, SystemMessage, ChatMessage
def get_buffer_string(
    messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
    """Render chat messages as a single newline-separated string.

    Each message contributes one entry of the form ``<role><content>``,
    where the role depends on the message class:

    * ``HumanMessage``  -> ``human_prefix + ": "``
    * ``AIMessage``     -> no role prefix (content only)
    * ``SystemMessage`` -> ``"System: "``
    * ``ChatMessage``   -> ``m.role + ": "``

    Args:
        messages: Messages to render, in order.
        human_prefix: Label placed before the content of human messages.
        ai_prefix: Accepted for interface compatibility; not applied to the
            output (see the review note in the body).

    Returns:
        The rendered entries joined with ``"\n"``; an empty string when
        ``messages`` is empty.

    Raises:
        ValueError: If a message is not one of the four handled classes.
    """
    # Local import keeps this function self-contained; tracing goes through
    # logging (DEBUG level) instead of unconditional prints to stdout.
    import logging

    log = logging.getLogger(__name__)

    string_messages = []
    for m in messages:
        if isinstance(m, HumanMessage):
            role = human_prefix + ": "
        elif isinstance(m, AIMessage):
            # NOTE(review): ai_prefix is not applied here, so AI messages are
            # emitted without a role label -- confirm this is intended.
            role = ""
        elif isinstance(m, SystemMessage):
            role = "System: "
        elif isinstance(m, ChatMessage):
            role = m.role + ": "
        else:
            raise ValueError(f"Got unsupported message type: {m}")
        log.debug("%s: %s", type(m).__name__, m.content)
        string_messages.append(role + m.content)
    return "\n".join(string_messages)
class HumenFeedbackBufferMemory(BaseChatMemory):
    """Conversation memory that exposes the stored chat history as a buffer.

    The history is read from ``self.chat_memory.messages`` and published
    under the single key ``memory_key``.
    """

    # Label passed to get_buffer_string for human messages.
    human_prefix: str = "Human"
    # Label passed to get_buffer_string for AI messages.
    ai_prefix: str = "AI"
    memory_key: str = "history"  #: :meta private:

    @property
    def buffer(self) -> Any:
        """Return the history as raw messages or as a rendered string.

        When ``self.return_messages`` is truthy the message list itself is
        returned; otherwise the messages are rendered by
        ``get_buffer_string`` using this instance's prefixes.
        """
        history = self.chat_memory.messages
        if self.return_messages:
            return history
        return get_buffer_string(
            history,
            human_prefix=self.human_prefix,
            ai_prefix=self.ai_prefix,
        )

    @property
    def memory_variables(self) -> List[str]:
        """Return the one-element list holding ``memory_key``.

        :meta private:
        """
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return the current buffer keyed by ``memory_key``.

        ``inputs`` is not consulted.
        """
        variables = {self.memory_key: self.buffer}
        return variables
|