File size: 2,323 Bytes
3b0afd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f26156
3b0afd8
 
9f26156
3b0afd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f26156
3b0afd8
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import json
import logging
from datetime import datetime
from typing import List

from langchain.memory import MongoDBChatMessageHistory
from langchain.schema import AIMessage, BaseMessage, HumanMessage, messages_from_dict, _message_to_dict
from pymongo import errors

logger = logging.getLogger(__name__)


class CustomMongoDBChatMessageHistory(MongoDBChatMessageHistory):
    """MongoDB chat history whose user/AI messages carry a UTC timestamp.

    All messages of a session are kept in a single document matched by
    ``SessionId``. Its ``messages`` field is an array of JSON strings, each one
    the ``_message_to_dict`` form of a single message.
    """

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore
        """Retrieve the messages stored for this session from MongoDB.

        Returns:
            The session's messages in the order they were pushed. Returns an
            empty list when no document exists for the session, when the
            document has no ``messages`` field, or when the query fails with
            ``OperationFailure`` (the failure is logged, not raised).
        """
        document = None
        try:
            # find_one runs the query immediately, so a server-side failure is
            # raised (and handled) here. A bare find() returns a lazy cursor
            # that would only fail later, outside this try block.
            document = self.collection.find_one({"SessionId": self.session_id})
        except errors.OperationFailure as error:
            logger.error(error)

        # Tolerate a session document without a "messages" array.
        items = document.get("messages", []) if document else []
        return messages_from_dict([json.loads(item) for item in items])

    def add_user_message(self, message: str) -> None:
        """Store ``message`` as a HumanMessage tagged with the current UTC time."""
        # NOTE(review): datetime.utcnow() is naive (no tz offset); append()
        # stringifies it via json.dumps(default=str) when persisting.
        self.append(HumanMessage(content=message, additional_kwargs={"timestamp": datetime.utcnow()}))

    def add_ai_message(self, message: str) -> None:
        """Store ``message`` as an AIMessage tagged with the current UTC time."""
        self.append(AIMessage(content=message, additional_kwargs={"timestamp": datetime.utcnow()}))

    def append(self, message: BaseMessage) -> None:
        """Push ``message`` onto this session's document in MongoDB.

        The message is converted with ``_message_to_dict`` and stored as a JSON
        string. ``default=str`` turns values JSON cannot encode (such as the
        ``datetime`` timestamp added by ``add_user_message``/``add_ai_message``)
        into strings. The session document is created if it does not exist
        (``upsert=True``). A ``WriteError`` is logged, not raised.
        """
        # Serialize outside the try: only the database write is best-effort.
        serialized = json.dumps(_message_to_dict(message), default=str)
        try:
            self.collection.update_one(
                {"SessionId": self.session_id},
                {"$push": {"messages": serialized}},
                upsert=True,
            )
        except errors.WriteError as err:
            logger.error(err)