test3 / tests /proxy_unit_tests /test_custom_callback_input.py
DesertWolf's picture
Upload folder using huggingface_hub
447ebeb verified
### What this tests ####
## This test asserts the type of data passed into each method of the custom callback handler
import asyncio
import inspect
import os
import sys
import time
import traceback
import uuid
from datetime import datetime
import pytest
from pydantic import BaseModel
sys.path.insert(0, os.path.abspath("../.."))
from typing import List, Literal, Optional, Union
from unittest.mock import AsyncMock, MagicMock, patch
import litellm
from litellm import Cache, completion, embedding
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import LiteLLMCommonStrings
# Test Scenarios (test across completion, streaming, embedding)
## 1: Pre-API-Call
## 2: Post-API-Call
## 3: On LiteLLM Call success
## 4: On LiteLLM Call failure
## 5. Caching
# Test models
## 1. OpenAI
## 2. Azure OpenAI
## 3. Non-OpenAI/Azure - e.g. Bedrock
# Test interfaces
## 1. litellm.completion() + litellm.embeddings()
## refer to test_custom_callback_input_router.py for the router + proxy tests
class CompletionCustomHandler(
CustomLogger
): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
"""
The set of expected inputs to a custom handler for a litellm.completion() / litellm.embedding() call.
"""
# Class variables or attributes
def __init__(self):
    """Create the handler with empty bookkeeping lists.

    ``self.states`` collects the name of every hook that fires and
    ``self.errors`` collects the traceback text of every failed check.
    """
    # NOTE(review): the base-class initializer is not invoked here —
    # confirm that this is intentional.

    # Names of the hooks that have been invoked, in call order.
    self.states: List[
        Literal[
            "sync_pre_api_call",
            "async_pre_api_call",
            "post_api_call",
            "sync_stream",
            "async_stream",
            "sync_success",
            "async_success",
            "sync_failure",
            "async_failure",
        ]
    ] = []
    # Formatted tracebacks, one per hook call whose checks raised.
    self.errors = []
def log_pre_api_call(self, model, messages, kwargs):
    """Record and validate the inputs of the synchronous pre-API-call hook.

    Appends ``"sync_pre_api_call"`` to ``self.states`` and then asserts the
    types of ``model``, ``messages`` and the expected ``kwargs`` entries.
    Nothing propagates to the caller: any exception raised by a check
    (failed assertion, missing key, ...) is printed and its traceback is
    appended to ``self.errors``.

    Args:
        model: Expected to be a ``str``.
        messages: Expected to be a ``list``.
        kwargs: Logging payload; must provide ``model``, ``messages``,
            ``optional_params``, ``litellm_params``, ``start_time``,
            ``stream`` and ``user``.
    """
    try:
        self.states.append("sync_pre_api_call")
        ## MODEL
        assert isinstance(model, str)
        ## MESSAGES
        assert isinstance(messages, list)
        ## KWARGS
        assert isinstance(kwargs["model"], str)
        assert isinstance(kwargs["messages"], list)
        assert isinstance(kwargs["optional_params"], dict)
        assert isinstance(kwargs["litellm_params"], dict)
        assert isinstance(kwargs["start_time"], (datetime, type(None)))
        assert isinstance(kwargs["stream"], bool)
        assert isinstance(kwargs["user"], (str, type(None)))
        ### METADATA
        metadata_value = kwargs["litellm_params"].get("metadata")
        assert metadata_value is None or isinstance(metadata_value, dict)
        if metadata_value is not None:
            if litellm.turn_off_message_logging is True:
                # Compare by value: two equal strings are not guaranteed to
                # be the same object, so an identity check (`is`) could fail
                # even though the request was correctly redacted.
                assert (
                    metadata_value["raw_request"]
                    == LiteLLMCommonStrings.redacted_by_litellm.value
                )
            else:
                assert "raw_request" not in metadata_value or isinstance(
                    metadata_value["raw_request"], str
                )
    except Exception:
        # Collect the failure instead of raising so the caller is never
        # interrupted by a failed check.
        print(f"Assertion Error: {traceback.format_exc()}")
        self.errors.append(traceback.format_exc())
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
    """Record and validate the inputs of the post-API-call hook.

    Appends ``"post_api_call"`` to ``self.states`` and asserts that
    ``start_time`` is a ``datetime``, that ``end_time`` and ``response_obj``
    are ``None``, and that the expected ``kwargs`` entries have the expected
    types. Any exception raised by a check is printed and its traceback is
    appended to ``self.errors``; nothing propagates to the caller.

    Args:
        kwargs: Logging payload (see the individual assertions for keys).
        response_obj: Must be ``None`` for this hook.
        start_time: Must be a ``datetime``.
        end_time: Must be ``None`` for this hook.
    """
    try:
        self.states.append("post_api_call")
        ## START TIME
        assert isinstance(start_time, datetime)
        ## END TIME
        # Identity check: `== None` would also accept any object whose
        # __eq__ happens to return True.
        assert end_time is None
        ## RESPONSE OBJECT
        assert response_obj is None
        ## KWARGS
        assert isinstance(kwargs["model"], str)
        assert isinstance(kwargs["messages"], list)
        assert isinstance(kwargs["optional_params"], dict)
        assert isinstance(kwargs["litellm_params"], dict)
        assert isinstance(kwargs["start_time"], (datetime, type(None)))
        assert isinstance(kwargs["stream"], bool)
        assert isinstance(kwargs["user"], (str, type(None)))
        assert isinstance(kwargs["input"], (list, dict, str))
        assert isinstance(kwargs["api_key"], (str, type(None)))
        # The raw response may be a concrete object, a coroutine or an
        # async generator.
        assert (
            isinstance(
                kwargs["original_response"],
                (str, litellm.CustomStreamWrapper, BaseModel),
            )
            or inspect.iscoroutine(kwargs["original_response"])
            or inspect.isasyncgen(kwargs["original_response"])
        )
        assert isinstance(kwargs["additional_args"], (dict, type(None)))
        assert isinstance(kwargs["log_event_type"], str)
    except Exception:
        # Collect the failure instead of raising.
        print(f"Assertion Error: {traceback.format_exc()}")
        self.errors.append(traceback.format_exc())
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
    """Record and validate the inputs of the async streaming hook.

    Appends ``"async_stream"`` to ``self.states`` and then checks, in order,
    the timestamps, the response object and the expected ``kwargs`` entries.
    Any exception raised by a check is printed and its traceback is appended
    to ``self.errors``; nothing propagates to the caller.
    """
    none_type = type(None)
    try:
        self.states.append("async_stream")
        # Both timestamps must be set for this hook.
        assert isinstance(start_time, datetime)
        assert isinstance(end_time, datetime)
        # The response object must be a litellm ModelResponse.
        assert isinstance(response_obj, litellm.ModelResponse)
        # Logging payload.
        assert isinstance(kwargs["model"], str)
        payload_messages = kwargs["messages"]
        assert isinstance(payload_messages, list) and isinstance(
            payload_messages[0], dict
        )
        for key, allowed in (
            ("optional_params", dict),
            ("litellm_params", dict),
            ("start_time", (datetime, none_type)),
            ("stream", bool),
            ("user", (str, none_type)),
        ):
            assert isinstance(kwargs[key], allowed)
        payload_input = kwargs["input"]
        if isinstance(payload_input, list):
            # A list input must start with a dict.
            assert isinstance(payload_input[0], dict)
        else:
            assert isinstance(payload_input, (dict, str))
        assert isinstance(kwargs["api_key"], (str, none_type))
        raw_response = kwargs["original_response"]
        assert (
            isinstance(raw_response, (str, litellm.CustomStreamWrapper))
            or inspect.isasyncgen(raw_response)
            or inspect.iscoroutine(raw_response)
        )
        assert isinstance(kwargs["additional_args"], (dict, none_type))
        assert isinstance(kwargs["log_event_type"], str)
    except Exception:
        # Collect the failure instead of raising.
        failure = traceback.format_exc()
        print(f"Assertion Error: {failure}")
        self.errors.append(traceback.format_exc())
def log_success_event(self, kwargs, response_obj, start_time, end_time):
    """Record and validate the inputs of the synchronous success hook.

    Prints ``kwargs`` (raw and JSON-serialized), appends ``"sync_success"``
    to ``self.states`` and asserts the types of the timestamps, the response
    object and the expected ``kwargs`` entries. Any exception raised along
    the way is printed and its traceback is appended to ``self.errors``;
    nothing propagates to the caller.

    Args:
        kwargs: Logging payload (see the individual assertions for keys).
        response_obj: Must be a litellm ``ModelResponse``,
            ``EmbeddingResponse`` or ``ImageResponse``.
        start_time: Must be a ``datetime``.
        end_time: Must be a ``datetime``.
    """
    # Imported locally because this block needs `json` for the
    # serialization check below; without it every call recorded a NameError.
    import json

    try:
        print(f"\n\nkwargs={kwargs}\n\n")
        print(
            json.dumps(kwargs, default=str)
        )  # this is a test to confirm no circular references are in the logging object
        self.states.append("sync_success")
        ## START TIME
        assert isinstance(start_time, datetime)
        ## END TIME
        assert isinstance(end_time, datetime)
        ## RESPONSE OBJECT
        assert isinstance(
            response_obj,
            (
                litellm.ModelResponse,
                litellm.EmbeddingResponse,
                litellm.ImageResponse,
            ),
        )
        ## KWARGS
        assert isinstance(kwargs["model"], str)
        assert isinstance(kwargs["messages"], list) and isinstance(
            kwargs["messages"][0], dict
        )
        assert isinstance(kwargs["optional_params"], dict)
        assert isinstance(kwargs["litellm_params"], dict)
        assert isinstance(kwargs["litellm_params"]["api_base"], str)
        assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
        assert isinstance(kwargs["start_time"], (datetime, type(None)))
        assert isinstance(kwargs["stream"], bool)
        assert isinstance(kwargs["user"], (str, type(None)))
        # A list input must start with a dict or a str; otherwise the input
        # itself must be a dict or a str.
        assert (
            isinstance(kwargs["input"], list)
            and (
                isinstance(kwargs["input"][0], dict)
                or isinstance(kwargs["input"][0], str)
            )
        ) or isinstance(kwargs["input"], (dict, str))
        assert isinstance(kwargs["api_key"], (str, type(None)))
        assert isinstance(
            kwargs["original_response"],
            (str, litellm.CustomStreamWrapper, BaseModel),
        ), "Original Response={}. Allowed types=[str, litellm.CustomStreamWrapper, BaseModel]".format(
            kwargs["original_response"]
        )
        assert isinstance(kwargs["additional_args"], (dict, type(None)))
        assert isinstance(kwargs["log_event_type"], str)
        assert isinstance(kwargs["response_cost"], (float, type(None)))
    except Exception:
        # Collect the failure instead of raising.
        print(f"Assertion Error: {traceback.format_exc()}")
        self.errors.append(traceback.format_exc())
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
    """Record and validate the inputs of the synchronous failure hook.

    Prints ``kwargs``, appends ``"sync_failure"`` to ``self.states`` and
    asserts that both timestamps are ``datetime`` objects, that
    ``response_obj`` is ``None`` and that the expected ``kwargs`` entries
    have the expected types. Any exception raised by a check is printed and
    its traceback is appended to ``self.errors``; nothing propagates to the
    caller.

    Args:
        kwargs: Logging payload (see the individual assertions for keys).
        response_obj: Must be ``None`` for this hook.
        start_time: Must be a ``datetime``.
        end_time: Must be a ``datetime``.
    """
    try:
        print(f"kwargs: {kwargs}")
        self.states.append("sync_failure")
        ## START TIME
        assert isinstance(start_time, datetime)
        ## END TIME
        assert isinstance(end_time, datetime)
        ## RESPONSE OBJECT
        assert response_obj is None
        ## KWARGS
        assert isinstance(kwargs["model"], str)
        assert isinstance(kwargs["messages"], list) and isinstance(
            kwargs["messages"][0], dict
        )
        assert isinstance(kwargs["optional_params"], dict)
        assert isinstance(kwargs["litellm_params"], dict)
        # A plain tuple of classes is used instead of `Optional[dict]`:
        # typing constructs are rejected by isinstance() on Python <= 3.9,
        # and every other optional check in this handler uses this form.
        assert isinstance(
            kwargs["litellm_params"]["metadata"], (dict, type(None))
        )
        assert isinstance(kwargs["start_time"], (datetime, type(None)))
        assert isinstance(kwargs["stream"], bool)
        assert isinstance(kwargs["user"], (str, type(None)))
        # A list input must start with a dict; otherwise the input itself
        # must be a dict or a str.
        assert (
            isinstance(kwargs["input"], list)
            and isinstance(kwargs["input"][0], dict)
        ) or isinstance(kwargs["input"], (dict, str))
        assert isinstance(kwargs["api_key"], (str, type(None)))
        assert (
            isinstance(
                kwargs["original_response"], (str, litellm.CustomStreamWrapper)
            )
            or kwargs["original_response"] is None
        )
        assert isinstance(kwargs["additional_args"], (dict, type(None)))
        assert isinstance(kwargs["log_event_type"], str)
    except Exception:
        # Collect the failure instead of raising.
        print(f"Assertion Error: {traceback.format_exc()}")
        self.errors.append(traceback.format_exc())
async def async_log_pre_api_call(self, model, messages, kwargs):
    """Record and validate the inputs of the async pre-API-call hook.

    Appends ``"async_pre_api_call"`` to ``self.states`` and asserts the
    types of ``model``, ``messages`` (a list whose first item is a dict) and
    the expected ``kwargs`` entries. Any exception raised by a check —
    including the ``IndexError`` from an empty ``messages`` list — is
    printed and its traceback is appended to ``self.errors``; nothing
    propagates to the caller.

    Args:
        model: Expected to be a ``str``.
        messages: Expected to be a non-empty ``list`` starting with a dict.
        kwargs: Logging payload; must provide ``model``, ``messages``,
            ``optional_params``, ``litellm_params``, ``start_time``,
            ``stream`` and ``user``.
    """
    try:
        self.states.append("async_pre_api_call")
        ## MODEL
        assert isinstance(model, str)
        ## MESSAGES
        assert isinstance(messages, list) and isinstance(messages[0], dict)
        ## KWARGS
        assert isinstance(kwargs["model"], str)
        assert isinstance(kwargs["messages"], list) and isinstance(
            kwargs["messages"][0], dict
        )
        assert isinstance(kwargs["optional_params"], dict)
        assert isinstance(kwargs["litellm_params"], dict)
        assert isinstance(kwargs["start_time"], (datetime, type(None)))
        assert isinstance(kwargs["stream"], bool)
        assert isinstance(kwargs["user"], (str, type(None)))
    except Exception:
        # The exception object itself is not needed; the traceback text is
        # taken from traceback.format_exc(), as in the sibling handlers.
        print(f"Assertion Error: {traceback.format_exc()}")
        self.errors.append(traceback.format_exc())
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
    """Record and validate the inputs of the async success hook.

    Prints its arguments, appends ``"async_success"`` to ``self.states`` and
    asserts the types of the timestamps, the response object and the
    expected ``kwargs`` entries. Any exception raised by a check is printed
    and its traceback is appended to ``self.errors``; nothing propagates to
    the caller.

    Args:
        kwargs: Logging payload (see the individual assertions for keys).
        response_obj: Must be a litellm ``ModelResponse``,
            ``EmbeddingResponse`` or ``TextCompletionResponse``.
        start_time: Must be a ``datetime``.
        end_time: Must be a ``datetime``.
    """
    try:
        print(
            "in async_log_success_event", kwargs, response_obj, start_time, end_time
        )
        self.states.append("async_success")
        ## START TIME
        assert isinstance(start_time, datetime)
        ## END TIME
        assert isinstance(end_time, datetime)
        ## RESPONSE OBJECT
        assert isinstance(
            response_obj,
            (
                litellm.ModelResponse,
                litellm.EmbeddingResponse,
                litellm.TextCompletionResponse,
            ),
        )
        ## KWARGS
        assert isinstance(kwargs["model"], str)
        assert isinstance(kwargs["messages"], list)
        assert isinstance(kwargs["optional_params"], dict)
        assert isinstance(kwargs["litellm_params"], dict)
        assert isinstance(kwargs["litellm_params"]["api_base"], str)
        assert isinstance(kwargs["start_time"], (datetime, type(None)))
        assert isinstance(kwargs["stream"], bool)
        assert isinstance(kwargs["completion_start_time"], datetime)
        # Checked once here; a second identical check further down was
        # redundant and has been removed.
        assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
        assert isinstance(kwargs["user"], (str, type(None)))
        assert isinstance(kwargs["input"], (list, dict, str))
        assert isinstance(kwargs["api_key"], (str, type(None)))
        # The raw response may be a concrete object, an async generator or
        # a coroutine.
        assert (
            isinstance(
                kwargs["original_response"], (str, litellm.CustomStreamWrapper)
            )
            or inspect.isasyncgen(kwargs["original_response"])
            or inspect.iscoroutine(kwargs["original_response"])
        )
        assert isinstance(kwargs["additional_args"], (dict, type(None)))
        assert isinstance(kwargs["log_event_type"], str)
        assert isinstance(kwargs["response_cost"], (float, type(None)))
    except Exception:
        # Collect the failure instead of raising.
        print(f"Assertion Error: {traceback.format_exc()}")
        self.errors.append(traceback.format_exc())
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
    """Record and validate the inputs of the async failure hook.

    Appends ``"async_failure"`` to ``self.states`` and asserts that both
    timestamps are ``datetime`` objects, that ``response_obj`` is ``None``
    and that the expected ``kwargs`` entries have the expected types. Any
    exception raised by a check is printed and its traceback is appended to
    ``self.errors``; nothing propagates to the caller.

    Args:
        kwargs: Logging payload (see the individual assertions for keys).
        response_obj: Must be ``None`` for this hook.
        start_time: Must be a ``datetime``.
        end_time: Must be a ``datetime``.
    """
    try:
        self.states.append("async_failure")
        ## START TIME
        assert isinstance(start_time, datetime)
        ## END TIME
        assert isinstance(end_time, datetime)
        ## RESPONSE OBJECT
        # Identity check: `== None` would also accept any object whose
        # __eq__ happens to return True.
        assert response_obj is None
        ## KWARGS
        assert isinstance(kwargs["model"], str)
        assert isinstance(kwargs["messages"], list)
        assert isinstance(kwargs["optional_params"], dict)
        assert isinstance(kwargs["litellm_params"], dict)
        assert isinstance(kwargs["start_time"], (datetime, type(None)))
        assert isinstance(kwargs["stream"], bool)
        assert isinstance(kwargs["user"], (str, type(None)))
        assert isinstance(kwargs["input"], (list, str, dict))
        assert isinstance(kwargs["api_key"], (str, type(None)))
        # The raw response may be a concrete object, an async generator, a
        # coroutine, or absent (None).
        assert (
            isinstance(
                kwargs["original_response"], (str, litellm.CustomStreamWrapper)
            )
            or inspect.isasyncgen(kwargs["original_response"])
            or inspect.iscoroutine(kwargs["original_response"])
            or kwargs["original_response"] is None
        )
        assert isinstance(kwargs["additional_args"], (dict, type(None)))
        assert isinstance(kwargs["log_event_type"], str)
    except Exception:
        # Collect the failure instead of raising.
        print(f"Assertion Error: {traceback.format_exc()}")
        self.errors.append(traceback.format_exc())