File size: 6,553 Bytes
58d33f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""Test CallbackManager."""
from typing import Tuple

import pytest

from langchain.callbacks.base import (
    AsyncCallbackManager,
    BaseCallbackManager,
    CallbackManager,
)
from langchain.callbacks.shared import SharedCallbackManager
from langchain.schema import AgentFinish, LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import (
    BaseFakeCallbackHandler,
    FakeAsyncCallbackHandler,
    FakeCallbackHandler,
)


def _test_callback_manager(
    manager: BaseCallbackManager, *handlers: BaseFakeCallbackHandler
) -> None:
    """Drive a synchronous manager through every callback event once.

    Fires the LLM, chain, tool and agent-finish events without passing a
    ``verbose`` argument, then hands ``handlers`` to ``_check_num_calls`` to
    verify the recorded start/end/error counts.

    Args:
        manager: The callback manager under test.
        *handlers: Fake handlers whose counters are checked afterwards;
            presumably the same ones registered on ``manager``.
    """
    # LLM events.
    manager.on_llm_start({}, [])
    llm_result = LLMResult(generations=[])
    manager.on_llm_end(llm_result)
    llm_failure = Exception()
    manager.on_llm_error(llm_failure)

    # Chain events.
    chain_info = {"name": "foo"}
    manager.on_chain_start(chain_info, {})
    manager.on_chain_end({})
    chain_failure = Exception()
    manager.on_chain_error(chain_failure)

    # Tool events.
    manager.on_tool_start({}, "")
    manager.on_tool_end("")
    tool_failure = Exception()
    manager.on_tool_error(tool_failure)

    # Agent completion.
    finish = AgentFinish(log="", return_values={})
    manager.on_agent_finish(finish)

    _check_num_calls(handlers)


async def _test_callback_manager_async(
    manager: AsyncCallbackManager, *handlers: BaseFakeCallbackHandler
) -> None:
    """Drive an AsyncCallbackManager through every callback event once.

    Awaits the LLM, chain, tool and agent-finish events without passing a
    ``verbose`` argument, then hands ``handlers`` to ``_check_num_calls`` to
    verify the recorded start/end/error counts.

    Args:
        manager: The async callback manager under test.
        *handlers: Fake handlers whose counters are checked afterwards;
            presumably the same ones registered on ``manager``.
    """
    # LLM events.
    await manager.on_llm_start({}, [])
    llm_result = LLMResult(generations=[])
    await manager.on_llm_end(llm_result)
    llm_failure = Exception()
    await manager.on_llm_error(llm_failure)

    # Chain events.
    chain_info = {"name": "foo"}
    await manager.on_chain_start(chain_info, {})
    await manager.on_chain_end({})
    chain_failure = Exception()
    await manager.on_chain_error(chain_failure)

    # Tool events.
    await manager.on_tool_start({}, "")
    await manager.on_tool_end("")
    tool_failure = Exception()
    await manager.on_tool_error(tool_failure)

    # Agent completion.
    finish = AgentFinish(log="", return_values={})
    await manager.on_agent_finish(finish)

    _check_num_calls(handlers)


def _check_num_calls(handlers: Tuple[BaseFakeCallbackHandler, ...]) -> None:
    """Assert each handler's start/end/error counters match its verbosity.

    A handler whose ``always_verbose`` is truthy must have recorded 3 starts,
    4 ends and 3 errors; any other handler must have recorded nothing.

    Args:
        handlers: The fake handlers to inspect.

    Raises:
        AssertionError: If any counter differs from the expected value.
    """
    for fake in handlers:
        # Choose the expected counts once, then check them in a fixed order.
        want_starts, want_ends, want_errors = (
            (3, 4, 3) if fake.always_verbose else (0, 0, 0)
        )
        assert fake.starts == want_starts
        assert fake.ends == want_ends
        assert fake.errors == want_errors


def _test_callback_manager_pass_in_verbose(
    manager: BaseCallbackManager, *handlers: FakeCallbackHandler
) -> None:
    """Fire every callback event once with ``verbose=True`` and check counts.

    Unlike ``_test_callback_manager``, every call passes ``verbose=True`` and
    every handler is then expected to have recorded 3 starts, 4 ends and
    3 errors.

    Args:
        manager: The callback manager under test.
        *handlers: Fake handlers whose counters are checked afterwards;
            presumably the same ones registered on ``manager``.
    """
    # The same keyword is forwarded to every event below.
    loud = {"verbose": True}

    manager.on_llm_start({}, [], **loud)
    manager.on_llm_end(LLMResult(generations=[]), **loud)
    manager.on_llm_error(Exception(), **loud)

    manager.on_chain_start({"name": "foo"}, {}, **loud)
    manager.on_chain_end({}, **loud)
    manager.on_chain_error(Exception(), **loud)

    manager.on_tool_start({}, "", **loud)
    manager.on_tool_end("", **loud)
    manager.on_tool_error(Exception(), **loud)

    manager.on_agent_finish(AgentFinish(log="", return_values={}), **loud)

    for fake in handlers:
        assert fake.starts == 3
        assert fake.ends == 4
        assert fake.errors == 3


def test_callback_manager() -> None:
    """Run a CallbackManager with one always-verbose and one non-verbose handler.

    The shared driver ``_test_callback_manager`` fires the events and checks
    the per-handler counts.
    """
    verbose_handler = FakeCallbackHandler(always_verbose_=True)
    quiet_handler = FakeCallbackHandler(always_verbose_=False)
    registered = [verbose_handler, quiet_handler]
    manager = CallbackManager(registered)
    _test_callback_manager(manager, *registered)


def test_callback_manager_pass_in_verbose() -> None:
    """Run a CallbackManager whose events are all fired with ``verbose=True``.

    Both handlers are built with default arguments; the shared driver
    ``_test_callback_manager_pass_in_verbose`` expects each of them to record
    every event.
    """
    first_handler = FakeCallbackHandler()
    second_handler = FakeCallbackHandler()
    registered = [first_handler, second_handler]
    manager = CallbackManager(registered)
    _test_callback_manager_pass_in_verbose(manager, *registered)


def test_ignore_llm() -> None:
    """Check that a handler built with ``ignore_llm_=True`` skips LLM events.

    Two always-verbose handlers share one manager; only the one without the
    ignore flag is expected to count the LLM start, end and error events.
    """
    muted = FakeCallbackHandler(ignore_llm_=True, always_verbose_=True)
    listening = FakeCallbackHandler(always_verbose_=True)
    manager = CallbackManager(handlers=[muted, listening])

    manager.on_llm_start({}, [], verbose=True)
    manager.on_llm_end(LLMResult(generations=[]), verbose=True)
    manager.on_llm_error(Exception(), verbose=True)

    counters = ("starts", "ends", "errors")
    # The ignoring handler must not have seen anything.
    for counter in counters:
        assert getattr(muted, counter) == 0
    # The other handler must have seen each event exactly once.
    for counter in counters:
        assert getattr(listening, counter) == 1


def test_ignore_chain() -> None:
    """Check that a handler built with ``ignore_chain_=True`` skips chain events.

    Two always-verbose handlers share one manager; only the one without the
    ignore flag is expected to count the chain start, end and error events.
    """
    muted = FakeCallbackHandler(ignore_chain_=True, always_verbose_=True)
    listening = FakeCallbackHandler(always_verbose_=True)
    manager = CallbackManager(handlers=[muted, listening])

    manager.on_chain_start({"name": "foo"}, {}, verbose=True)
    manager.on_chain_end({}, verbose=True)
    manager.on_chain_error(Exception(), verbose=True)

    counters = ("starts", "ends", "errors")
    # The ignoring handler must not have seen anything.
    for counter in counters:
        assert getattr(muted, counter) == 0
    # The other handler must have seen each event exactly once.
    for counter in counters:
        assert getattr(listening, counter) == 1


def test_ignore_agent() -> None:
    """Check that a handler built with ``ignore_agent_=True`` skips agent events.

    Tool start/end/error and agent-finish are fired on a manager holding two
    always-verbose handlers. The ignoring handler must record nothing; the
    other must record one start, two ends and one error.
    """
    muted = FakeCallbackHandler(ignore_agent_=True, always_verbose_=True)
    listening = FakeCallbackHandler(always_verbose_=True)
    manager = CallbackManager(handlers=[muted, listening])

    manager.on_tool_start({}, "", verbose=True)
    manager.on_tool_end("", verbose=True)
    manager.on_tool_error(Exception(), verbose=True)
    finish = AgentFinish({}, "")
    manager.on_agent_finish(finish, verbose=True)

    # Nothing may reach the ignoring handler.
    assert muted.starts == 0
    assert muted.ends == 0
    assert muted.errors == 0

    # Two ends are expected here: one after on_tool_end, one after on_agent_finish.
    assert listening.starts == 1
    assert listening.ends == 2
    assert listening.errors == 1


def test_shared_callback_manager() -> None:
    """Test the SharedCallbackManager.

    Verifies that two constructor calls return the same object, then registers
    one always-verbose and one default handler through the two references and
    runs the shared driver ``_test_callback_manager`` against them.

    The handlers are removed again afterwards so that they do not stay
    attached to the shared instance once this test has finished.
    """
    manager1 = SharedCallbackManager()
    manager2 = SharedCallbackManager()

    assert manager1 is manager2

    handler1 = FakeCallbackHandler(always_verbose_=True)
    handler2 = FakeCallbackHandler()
    manager1.add_handler(handler1)
    manager2.add_handler(handler2)
    try:
        _test_callback_manager(manager1, handler1, handler2)
    finally:
        # Both names refer to one shared object (asserted above), so anything
        # left registered here would keep receiving events from later tests.
        manager1.remove_handler(handler1)
        manager2.remove_handler(handler2)


@pytest.mark.asyncio
async def test_async_callback_manager() -> None:
    """Run an AsyncCallbackManager with two async fake handlers.

    One handler is always verbose and the other is built with defaults; the
    shared driver ``_test_callback_manager_async`` fires the events and checks
    the per-handler counts.
    """
    verbose_handler = FakeAsyncCallbackHandler(always_verbose_=True)
    default_handler = FakeAsyncCallbackHandler()
    registered = [verbose_handler, default_handler]
    manager = AsyncCallbackManager(registered)
    await _test_callback_manager_async(manager, *registered)


@pytest.mark.asyncio
async def test_async_callback_manager_sync_handler() -> None:
    """Run an AsyncCallbackManager that mixes a sync and an async handler.

    The always-verbose handler is a synchronous ``FakeCallbackHandler`` and
    the default one is a ``FakeAsyncCallbackHandler``; the shared driver
    ``_test_callback_manager_async`` fires the events and checks the counts.
    """
    sync_handler = FakeCallbackHandler(always_verbose_=True)
    async_handler = FakeAsyncCallbackHandler()
    registered = [sync_handler, async_handler]
    manager = AsyncCallbackManager(registered)
    await _test_callback_manager_async(manager, *registered)