File size: 2,323 Bytes
3b0afd8 9f26156 3b0afd8 9f26156 3b0afd8 9f26156 3b0afd8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import json
import logging
from datetime import datetime
from typing import List
from langchain.memory import MongoDBChatMessageHistory
from langchain.schema import AIMessage, BaseMessage, HumanMessage, messages_from_dict, _message_to_dict
from pymongo import errors
# Module-level logger named after this module; MongoDB errors below are logged through it.
logger = logging.getLogger(name=__name__)
class CustomMongoDBChatMessageHistory(MongoDBChatMessageHistory):
    """Chat history that keeps one MongoDB document per session.

    All messages of a session are pushed onto a single document's
    ``messages`` array, matched by ``SessionId``. Each array element is the
    JSON string of ``_message_to_dict(message)``. Values JSON cannot encode
    natively, such as the ``timestamp`` datetime placed in
    ``additional_kwargs``, are serialized with ``str``.
    """

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore
        """Retrieve this session's messages from MongoDB.

        Returns:
            The deserialized messages in the order they were pushed. Returns
            an empty list when any of these holds:

            - the session has no document;
            - the document has no ``messages`` field;
            - the query fails (the error is logged).
        """
        try:
            # ``find_one`` executes the query right here, so a server-side
            # failure is raised inside this ``try``. A lazy ``find`` cursor
            # would only fail later, when it is first read.
            document = self.collection.find_one({"SessionId": self.session_id})
        except errors.OperationFailure as error:
            logger.error(error)
            return []
        if not document:
            return []
        items = document.get("messages", [])
        return messages_from_dict([json.loads(item) for item in items])

    # NOTE(review): datetime.utcnow() is deprecated since Python 3.12 and
    # returns a naive datetime. Switching to datetime.now(timezone.utc) would
    # change the stored timestamp string (it adds "+00:00"), so it is kept as-is.
    def add_user_message(self, message: str) -> None:
        """Append a human message stamped with the current (naive) UTC time."""
        self.append(
            HumanMessage(
                content=message,
                additional_kwargs={"timestamp": datetime.utcnow()},
            )
        )

    def add_ai_message(self, message: str) -> None:
        """Append an AI message stamped with the current (naive) UTC time."""
        self.append(
            AIMessage(
                content=message,
                additional_kwargs={"timestamp": datetime.utcnow()},
            )
        )

    def append(self, message: BaseMessage) -> None:
        """Push ``message`` onto this session's ``messages`` array in MongoDB.

        The message is stored as the JSON string of
        ``_message_to_dict(message)``. ``default=str`` makes values such as
        ``datetime`` timestamps serializable. The session document is created
        on first use (``upsert=True``). Write errors are logged, not raised.
        """
        # This format must stay readable by ``messages``, which runs
        # ``json.loads`` on each array element and feeds the result to
        # ``messages_from_dict``.
        payload = json.dumps(_message_to_dict(message), default=str)
        try:
            self.collection.update_one(
                {"SessionId": self.session_id},
                {"$push": {"messages": payload}},
                upsert=True,
            )
        except errors.WriteError as err:
            logger.error(err)
|