import pytest

from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage


# Fixture for creating an instance of ChatHistory
@pytest.fixture
def chat_memory_instance():
    return ChatHistory()


def test_add_chat_memory(chat_memory_instance):
    app_id = "test_app"
    session_id = "test_session"
    human_message = "Hello, how are you?"
    ai_message = "I'm fine, thank you!"

    chat_message = ChatMessage()
    chat_message.add_user_message(human_message)
    chat_message.add_ai_message(ai_message)

    chat_memory_instance.add(app_id, session_id, chat_message)

    assert chat_memory_instance.count(app_id, session_id) == 1
    chat_memory_instance.delete(app_id, session_id)


def test_get(chat_memory_instance):
    app_id = "test_app"
    session_id = "test_session"

    for i in range(1, 7):
        human_message = f"Question {i}"
        ai_message = f"Answer {i}"

        chat_message = ChatMessage()
        chat_message.add_user_message(human_message)
        chat_message.add_ai_message(ai_message)

        chat_memory_instance.add(app_id, session_id, chat_message)

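    # num_rounds caps how many rounds are returned for this session
    recent_memories = chat_memory_instance.get(app_id, session_id, num_rounds=5)

    assert len(recent_memories) == 5

    # fetch_all=True returns every round stored for the app, across sessions
    all_memories = chat_memory_instance.get(app_id, fetch_all=True)

    assert len(all_memories) == 6
    # Clean up so later tests start from an empty history
    chat_memory_instance.delete(app_id, session_id)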


def test_delete_chat_history(chat_memory_instance):
    app_id = "test_app"
    session_id = "test_session"

    for i in range(1, 6):
        human_message = f"Question {i}"
        ai_message = f"Answer {i}"

        chat_message = ChatMessage()
        chat_message.add_user_message(human_message)
        chat_message.add_ai_message(ai_message)

        chat_memory_instance.add(app_id, session_id, chat_message)

    session_id_2 = "test_session_2"

    for i in range(1, 6):
        human_message = f"Question {i}"
        ai_message = f"Answer {i}"

        chat_message = ChatMessage()
        chat_message.add_user_message(human_message)
        chat_message.add_ai_message(ai_message)

        chat_memory_instance.add(app_id, session_id_2, chat_message)

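    # Deleting one session leaves the other session's five rounds intact
    chat_memory_instance.delete(app_id, session_id)

    assert chat_memory_instance.count(app_id, session_id) == 0
    assert chat_memory_instance.count(app_id) == 5

    # Omitting session_id deletes the history of every session for the app
    chat_memory_instance.delete(app_id)

    assert chat_memory_instance.count(app_id) == 0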


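@pytest.fixture(autouse=True)
def close_connection(chat_memory_instance):
    # No test requests this fixture by name, so without autouse=True the
    # teardown below never runs and the connection is never closed.
    yield
    chat_memory_instance.close_connection()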