File size: 7,912 Bytes
447ebeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import sys
import os
import io, asyncio
import json
import pytest
import time
from litellm import mock_completion
from unittest.mock import MagicMock, AsyncMock, patch
sys.path.insert(0, os.path.abspath("../.."))
import litellm
from litellm.proxy.guardrails.guardrail_hooks.presidio import _OPTIONAL_PresidioPIIMasking, PresidioPerRequestConfig
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import StandardLoggingPayload, StandardLoggingGuardrailInformation
from litellm.types.guardrails import GuardrailEventHooks
from typing import Optional


class TestCustomLogger(CustomLogger):
    """Callback that remembers the standard logging payload of the last successful call.

    Tests register an instance in ``litellm.callbacks`` and read
    ``standard_logging_payload`` once the completion has been logged.
    """

    # The class name starts with "Test", so pytest would otherwise try to
    # collect this helper as a test class and warn about its ``__init__``.
    __test__ = False

    def __init__(self, *args, **kwargs):
        # Run the base logger's own initialisation first. Constructor
        # arguments are accepted but intentionally not forwarded, exactly as
        # before, so existing call sites keep working.
        super().__init__()
        # Filled in by ``async_log_success_event``; ``None`` until then.
        self.standard_logging_payload: Optional[StandardLoggingPayload] = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Store ``kwargs["standard_logging_object"]`` (``None`` when the key is absent)."""
        self.standard_logging_payload = kwargs.get("standard_logging_object")

@pytest.mark.asyncio
async def test_standard_logging_payload_includes_guardrail_information():
    """
    Test that the standard logging payload includes the guardrail information when a guardrail is applied

    Steps:
      1. Run the Presidio pre-call hook on a request that contains a phone number.
      2. Send the same request through ``litellm.acompletion`` (uses ``mock_response``).
      3. Inspect the payload captured by ``TestCustomLogger``.

    The Presidio endpoints are read from ``PRESIDIO_ANALYZER_API_BASE`` and
    ``PRESIDIO_ANONYMIZER_API_BASE``.
    NOTE(review): presumably both must point at reachable Presidio services
    for the hook to succeed - confirm against the CI setup.
    """
    test_custom_logger = TestCustomLogger()

    # ``litellm.callbacks`` is module-global state; remember the previous value
    # so this test does not leak its logger into tests that run afterwards.
    previous_callbacks = litellm.callbacks
    litellm.callbacks = [test_custom_logger]
    try:
        presidio_guard = _OPTIONAL_PresidioPIIMasking(
            guardrail_name="presidio_guard",
            event_hook=GuardrailEventHooks.pre_call,
            presidio_analyzer_api_base=os.getenv("PRESIDIO_ANALYZER_API_BASE"),
            presidio_anonymizer_api_base=os.getenv("PRESIDIO_ANONYMIZER_API_BASE"),
        )
        # 1. call the pre call hook with guardrail
        request_data = {
            "model": "gpt-4o",
            "messages": [
                {"role": "user", "content": "Hello, my phone number is +1 412 555 1212"},
            ],
            "mock_response": "Hello",
            "guardrails": ["presidio_guard"],
            "metadata": {},
        }
        await presidio_guard.async_pre_call_hook(
            user_api_key_dict={},
            cache=None,
            data=request_data,
            call_type="acompletion",
        )

        # 2. call litellm.acompletion (the return value itself is not inspected)
        await litellm.acompletion(**request_data)

        # 3. give the success callback time to run before inspecting the logger
        await asyncio.sleep(1)
    finally:
        litellm.callbacks = previous_callbacks

    payload = test_custom_logger.standard_logging_payload
    print("got standard logging payload=", json.dumps(payload, indent=4, default=str))
    assert payload is not None

    guardrail_information = payload["guardrail_information"]
    assert guardrail_information is not None
    assert guardrail_information["guardrail_name"] == "presidio_guard"
    assert guardrail_information["guardrail_mode"] == GuardrailEventHooks.pre_call

    # assert that the guardrail_response is a response from presidio analyze
    presidio_response = guardrail_information["guardrail_response"]
    assert isinstance(presidio_response, list)
    # Without this check the per-item loop below would pass vacuously on an
    # empty list, even though a phone number is expected to be detected.
    assert len(presidio_response) > 0
    expected_item_keys = (
        "analysis_explanation",
        "start",
        "end",
        "score",
        "entity_type",
        "recognition_metadata",
    )
    for response_item in presidio_response:
        for key in expected_item_keys:
            assert key in response_item

    # assert that the duration is not None
    assert guardrail_information["duration"] is not None
    assert guardrail_information["duration"] > 0

    # assert that we get the count of masked entities
    masked_entity_count = guardrail_information["masked_entity_count"]
    assert masked_entity_count is not None
    assert masked_entity_count["PHONE_NUMBER"] == 1




@pytest.mark.asyncio
@pytest.mark.skip(reason="Local only test")
async def test_langfuse_trace_includes_guardrail_information():
    """
    Check that the batch posted to Langfuse contains a guardrail span when a guardrail is applied
    """
    import json
    import httpx
    from unittest.mock import AsyncMock, patch
    from litellm.integrations.langfuse.langfuse_prompt_management import LangfusePromptManagement

    callback = LangfusePromptManagement(flush_interval=3)

    # Canned HTTP reply handed back by the patched ``post``
    fake_http_response = AsyncMock(spec=httpx.Response)
    fake_http_response.status_code = 200
    fake_http_response.json.return_value = {"status": "success"}

    # Stand-in for httpx.Client.post that records every call it receives
    mock_post = AsyncMock()
    mock_post.return_value = fake_http_response

    with patch("httpx.Client.post", mock_post):
        litellm._turn_on_debug()
        litellm.callbacks = [callback]
        presidio_guard = _OPTIONAL_PresidioPIIMasking(
            guardrail_name="presidio_guard",
            event_hook=GuardrailEventHooks.pre_call,
            presidio_analyzer_api_base=os.getenv("PRESIDIO_ANALYZER_API_BASE"),
            presidio_anonymizer_api_base=os.getenv("PRESIDIO_ANONYMIZER_API_BASE"),
        )

        # Step 1: run the guardrail's pre-call hook on the request
        user_message = {"role": "user", "content": "Hello, my phone number is +1 412 555 1212"}
        request_data = {
            "model": "gpt-4o",
            "messages": [user_message],
            "mock_response": "Hello",
            "guardrails": ["presidio_guard"],
            "metadata": {},
        }
        await presidio_guard.async_pre_call_hook(
            user_api_key_dict={},
            cache=None,
            data=request_data,
            call_type="acompletion",
        )

        # Step 2: send the same request through litellm
        await litellm.acompletion(**request_data)

        # Step 3: leave time for the asynchronous logging to complete
        await asyncio.sleep(5)

        # Step 4: pull the payload out of the most recent recorded post call
        assert mock_post.call_count >= 1
        last_call = mock_post.call_args
        posted_url = last_call[0][0]
        raw_body = last_call[1].get("content")

        actual_payload = json.loads(raw_body)
        print("\nLangfuse payload:", json.dumps(actual_payload, indent=2))

        # First "span-create" batch entry named "guardrail", or None if there is none
        guardrail_span = next(
            (
                entry
                for entry in actual_payload["batch"]
                if entry["type"] == "span-create"
                and entry["body"].get("name") == "guardrail"
            ),
            None,
        )
        assert guardrail_span is not None, "No guardrail span found in Langfuse payload"

        # Shape of the guardrail span
        span_body = guardrail_span["body"]
        assert span_body["name"] == "guardrail"
        assert "metadata" in span_body
        span_metadata = span_body["metadata"]
        assert span_metadata["guardrail_name"] == "presidio_guard"
        assert span_metadata["guardrail_mode"] == GuardrailEventHooks.pre_call
        assert "guardrail_masked_entity_count" in span_metadata
        assert span_metadata["guardrail_masked_entity_count"]["PHONE_NUMBER"] == 1

        # The span output must be a non-empty list
        assert "output" in span_body
        span_output = span_body["output"]
        assert isinstance(span_output, list)
        assert len(span_output) > 0

        # Shape of the first output entry
        first_entity = span_output[0]
        assert "entity_type" in first_entity
        assert first_entity["entity_type"] == "PHONE_NUMBER"
        for required_key in ("score", "start", "end", "recognition_metadata"):
            assert required_key in first_entity
        assert "recognizer_name" in first_entity["recognition_metadata"]