# What is this?
## Log success + failure events to Braintrust
import copy
import os
from datetime import datetime
from typing import Dict, Optional

import httpx
from pydantic import BaseModel

import litellm
from litellm import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import (
    HTTPHandler,
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.utils import print_verbose

global_braintrust_http_handler = get_async_httpx_client(
    llm_provider=httpxSpecialProvider.LoggingCallback
)
global_braintrust_sync_http_handler = HTTPHandler()

API_BASE = "https://api.braintrustdata.com/v1"


def get_utc_datetime():
    """Return the current UTC time, preferring the timezone-aware API on Python 3.11+."""
    import datetime as dt
    from datetime import datetime

    if hasattr(dt, "UTC"):
        return datetime.now(dt.UTC)  # type: ignore
    else:
        return datetime.utcnow()  # type: ignore


class BraintrustLogger(CustomLogger):
    def __init__(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None
    ) -> None:
        super().__init__()
        self.validate_environment(api_key=api_key)
        self.api_base = api_base or API_BASE
        self.default_project_id = None
        self.api_key: str = api_key or os.getenv("BRAINTRUST_API_KEY")  # type: ignore
        self.headers = {
            "Authorization": "Bearer " + self.api_key,
            "Content-Type": "application/json",
        }
        self._project_id_cache: Dict[str, str] = {}  # Cache mapping project names to IDs

    def validate_environment(self, api_key: Optional[str]):
        """
        Expects
        BRAINTRUST_API_KEY

        in the environment
        """
        missing_keys = []
        if api_key is None and os.getenv("BRAINTRUST_API_KEY", None) is None:
            missing_keys.append("BRAINTRUST_API_KEY")

        if len(missing_keys) > 0:
            raise Exception("Missing keys={} in environment.".format(missing_keys))

    def get_project_id_sync(self, project_name: str) -> str:
        """
        Get project ID from name, using cache if available.
        If project doesn't exist, creates it.
        """
        if project_name in self._project_id_cache:
            return self._project_id_cache[project_name]

        try:
            response = global_braintrust_sync_http_handler.post(
                f"{self.api_base}/project",
                headers=self.headers,
                json={"name": project_name},
            )
            project_dict = response.json()
            project_id = project_dict["id"]
            self._project_id_cache[project_name] = project_id
            return project_id
        except httpx.HTTPStatusError as e:
            raise Exception(f"Failed to register project: {e.response.text}")

    async def get_project_id_async(self, project_name: str) -> str:
        """
        Async version of get_project_id_sync
        """
        if project_name in self._project_id_cache:
            return self._project_id_cache[project_name]

        try:
            # POST /project matches the sync path above: the v1 API creates the
            # project by name, or returns the existing one unmodified.
            response = await global_braintrust_http_handler.post(
                f"{self.api_base}/project",
                headers=self.headers,
                json={"name": project_name},
            )
            project_dict = response.json()
            project_id = project_dict["id"]
            self._project_id_cache[project_name] = project_id
            return project_id
        except httpx.HTTPStatusError as e:
            raise Exception(f"Failed to register project: {e.response.text}")
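
    # Usage sketch (illustrative names; assumes a running event loop):
    #
    #   logger = BraintrustLogger(api_key="...")
    #   project_id = await logger.get_project_id_async("my-project")
    #   # Later lookups for "my-project" are served from self._project_id_cache,
    #   # so only the first call per project name hits the Braintrust API.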

    @staticmethod
    def add_metadata_from_header(litellm_params: dict, metadata: dict) -> dict:
        """
        Adds metadata from proxy request headers to Braintrust logging if keys start
        with "braintrust", and overwrites litellm_params.metadata if already included.

        For example, to pass a field through to Braintrust metadata, send
        `headers: { ..., braintrust_my_field: my-value }` via the proxy request.
        """
        if litellm_params is None:
            return metadata

        if litellm_params.get("proxy_server_request") is None:
            return metadata

        if metadata is None:
            metadata = {}

        proxy_headers = (
            litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}
        )

        for metadata_param_key in proxy_headers:
            if metadata_param_key.startswith("braintrust"):
                trace_param_key = metadata_param_key.replace("braintrust", "", 1)
                if trace_param_key in metadata:
                    verbose_logger.warning(
                        f"Overwriting Braintrust `{trace_param_key}` from request header"
                    )
                else:
                    verbose_logger.debug(
                        f"Found Braintrust `{trace_param_key}` in request header"
                    )
                metadata[trace_param_key] = proxy_headers.get(metadata_param_key)

        return metadata
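
    # Illustration (sketch, not executed): given a proxy request such as
    #
    #   litellm_params = {
    #       "proxy_server_request": {"headers": {"braintrust_env": "staging"}}
    #   }
    #
    # only the literal "braintrust" prefix is stripped, so
    # add_metadata_from_header(litellm_params, {}) returns {"_env": "staging"}.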

    async def create_default_project_and_experiment(self):
        project = await global_braintrust_http_handler.post(
            f"{self.api_base}/project", headers=self.headers, json={"name": "litellm"}
        )

        project_dict = project.json()

        self.default_project_id = project_dict["id"]

    def create_sync_default_project_and_experiment(self):
        project = global_braintrust_sync_http_handler.post(
            f"{self.api_base}/project", headers=self.headers, json={"name": "litellm"}
        )

        project_dict = project.json()

        self.default_project_id = project_dict["id"]

    def log_success_event(  # noqa: PLR0915
        self, kwargs, response_obj, start_time, end_time
    ):
        verbose_logger.debug("REACHES BRAINTRUST SUCCESS")
        try:
            litellm_call_id = kwargs.get("litellm_call_id")
            prompt = {"messages": kwargs.get("messages")}
            output = None
            choices = []
            if response_obj is not None and (
                kwargs.get("call_type", None) == "embedding"
                or isinstance(response_obj, litellm.EmbeddingResponse)
            ):
                output = None
            elif response_obj is not None and isinstance(
                response_obj, litellm.ModelResponse
            ):
                output = response_obj["choices"][0]["message"].json()
                choices = response_obj["choices"]
            elif response_obj is not None and isinstance(
                response_obj, litellm.TextCompletionResponse
            ):
                output = response_obj.choices[0].text
                choices = response_obj.choices
            elif response_obj is not None and isinstance(
                response_obj, litellm.ImageResponse
            ):
                output = response_obj["data"]

            litellm_params = kwargs.get("litellm_params", {})
            metadata = (
                litellm_params.get("metadata", {}) or {}
            )  # if litellm_params['metadata'] == None
            metadata = self.add_metadata_from_header(litellm_params, metadata)
            clean_metadata = {}
            try:
                metadata = copy.deepcopy(
                    metadata
                )  # Avoid modifying the original metadata
            except Exception:
                new_metadata = {}
                for key, value in metadata.items():
                    if (
                        isinstance(value, list)
                        or isinstance(value, dict)
                        or isinstance(value, str)
                        or isinstance(value, int)
                        or isinstance(value, float)
                    ):
                        new_metadata[key] = copy.deepcopy(value)
                metadata = new_metadata

            # Get project_id from metadata or create default if needed
            project_id = metadata.get("project_id")
            if project_id is None:
                project_name = metadata.get("project_name")
                project_id = (
                    self.get_project_id_sync(project_name) if project_name else None
                )

            if project_id is None:
                if self.default_project_id is None:
                    self.create_sync_default_project_and_experiment()
                project_id = self.default_project_id

            tags = []
            if isinstance(metadata, dict):
                for key, value in metadata.items():
                    # generate tags from litellm.langfuse_default_tags (the proxy's
                    # default-tag config, shared with the Langfuse integration)
                    if (
                        litellm.langfuse_default_tags is not None
                        and isinstance(litellm.langfuse_default_tags, list)
                        and key in litellm.langfuse_default_tags
                    ):
                        tags.append(f"{key}:{value}")

                    # clean litellm metadata before logging
                    if key in [
                        "headers",
                        "endpoint",
                        "caching_groups",
                        "previous_models",
                    ]:
                        continue
                    else:
                        clean_metadata[key] = value

            cost = kwargs.get("response_cost", None)
            if cost is not None:
                clean_metadata["litellm_response_cost"] = cost

            metrics: Optional[dict] = None
            usage_obj = getattr(response_obj, "usage", None)
            if usage_obj and isinstance(usage_obj, litellm.Usage):
                litellm.utils.get_logging_id(start_time, response_obj)
                metrics = {
                    "prompt_tokens": usage_obj.prompt_tokens,
                    "completion_tokens": usage_obj.completion_tokens,
                    "total_tokens": usage_obj.total_tokens,
                    "total_cost": cost,
                    # The sync path has no streaming timestamps, so time_to_first_token
                    # is approximated by the total request latency here.
                    "time_to_first_token": end_time.timestamp()
                    - start_time.timestamp(),
                    "start": start_time.timestamp(),
                    "end": end_time.timestamp(),
                }

            request_data = {
                "id": litellm_call_id,
                "input": prompt["messages"],
                "metadata": clean_metadata,
                "tags": tags,
                "span_attributes": {"name": "Chat Completion", "type": "llm"},
            }
            # `choices` is initialized to [] and is never None, so check truthiness
            # and fall back to `output` (e.g. image data) when no choices were extracted.
            if choices:
                request_data["output"] = [choice.dict() for choice in choices]
            else:
                request_data["output"] = output

            if metrics is not None:
                request_data["metrics"] = metrics

            try:
                print_verbose(
                    f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}"
                )
                global_braintrust_sync_http_handler.post(
                    url=f"{self.api_base}/project_logs/{project_id}/insert",
                    json={"events": [request_data]},
                    headers=self.headers,
                )
            except httpx.HTTPStatusError as e:
                raise Exception(e.response.text)
        except Exception as e:
            raise e  # don't use verbose_logger.exception, if exception is raised
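
    # Payload shape (derived from the code above): each success event is posted to
    # POST {api_base}/project_logs/{project_id}/insert as, for example:
    #
    #   {
    #       "events": [
    #           {
    #               "id": "<litellm_call_id>",
    #               "input": [...],            # the request messages
    #               "output": [...],           # serialized choices, or raw output
    #               "metadata": {...},         # cleaned metadata + litellm_response_cost
    #               "tags": ["key:value", ...],
    #               "span_attributes": {"name": "Chat Completion", "type": "llm"},
    #               "metrics": {...}           # token counts and timing, when available
    #           }
    #       ]
    #   }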

    async def async_log_success_event(  # noqa: PLR0915
        self, kwargs, response_obj, start_time, end_time
    ):
        verbose_logger.debug("REACHES BRAINTRUST SUCCESS")
        try:
            litellm_call_id = kwargs.get("litellm_call_id")
            prompt = {"messages": kwargs.get("messages")}
            output = None
            choices = []
            if response_obj is not None and (
                kwargs.get("call_type", None) == "embedding"
                or isinstance(response_obj, litellm.EmbeddingResponse)
            ):
                output = None
            elif response_obj is not None and isinstance(
                response_obj, litellm.ModelResponse
            ):
                output = response_obj["choices"][0]["message"].json()
                choices = response_obj["choices"]
            elif response_obj is not None and isinstance(
                response_obj, litellm.TextCompletionResponse
            ):
                output = response_obj.choices[0].text
                choices = response_obj.choices
            elif response_obj is not None and isinstance(
                response_obj, litellm.ImageResponse
            ):
                output = response_obj["data"]

            litellm_params = kwargs.get("litellm_params", {})
            metadata = (
                litellm_params.get("metadata", {}) or {}
            )  # if litellm_params['metadata'] == None
            metadata = self.add_metadata_from_header(litellm_params, metadata)
            clean_metadata = {}
            new_metadata = {}
            for key, value in metadata.items():
                if (
                    isinstance(value, list)
                    or isinstance(value, str)
                    or isinstance(value, int)
                    or isinstance(value, float)
                ):
                    new_metadata[key] = value
                elif isinstance(value, BaseModel):
                    new_metadata[key] = value.model_dump_json()
                elif isinstance(value, dict):
                    for k, v in value.items():
                        if isinstance(v, datetime):
                            value[k] = v.isoformat()
                    new_metadata[key] = value
            # Use the sanitized copy; without this assignment the loop above has no effect.
            metadata = new_metadata

            # Get project_id from metadata or create default if needed
            project_id = metadata.get("project_id")
            if project_id is None:
                project_name = metadata.get("project_name")
                project_id = (
                    await self.get_project_id_async(project_name)
                    if project_name
                    else None
                )

            if project_id is None:
                if self.default_project_id is None:
                    await self.create_default_project_and_experiment()
                project_id = self.default_project_id

            tags = []
            if isinstance(metadata, dict):
                for key, value in metadata.items():
                    # generate tags from litellm.langfuse_default_tags (the proxy's
                    # default-tag config, shared with the Langfuse integration)
                    if (
                        litellm.langfuse_default_tags is not None
                        and isinstance(litellm.langfuse_default_tags, list)
                        and key in litellm.langfuse_default_tags
                    ):
                        tags.append(f"{key}:{value}")

                    # clean litellm metadata before logging
                    if key in [
                        "headers",
                        "endpoint",
                        "caching_groups",
                        "previous_models",
                    ]:
                        continue
                    else:
                        clean_metadata[key] = value

            cost = kwargs.get("response_cost", None)
            if cost is not None:
                clean_metadata["litellm_response_cost"] = cost

            metrics: Optional[dict] = None
            usage_obj = getattr(response_obj, "usage", None)
            if usage_obj and isinstance(usage_obj, litellm.Usage):
                litellm.utils.get_logging_id(start_time, response_obj)
                metrics = {
                    "prompt_tokens": usage_obj.prompt_tokens,
                    "completion_tokens": usage_obj.completion_tokens,
                    "total_tokens": usage_obj.total_tokens,
                    "total_cost": cost,
                    "start": start_time.timestamp(),
                    "end": end_time.timestamp(),
                }

            # Streaming timestamps give a real time-to-first-token in the async path;
            # guard on `metrics` so this never indexes into None.
            api_call_start_time = kwargs.get("api_call_start_time")
            completion_start_time = kwargs.get("completion_start_time")

            if (
                metrics is not None
                and api_call_start_time is not None
                and completion_start_time is not None
            ):
                metrics["time_to_first_token"] = (
                    completion_start_time.timestamp()
                    - api_call_start_time.timestamp()
                )

            request_data = {
                "id": litellm_call_id,
                "input": prompt["messages"],
                "metadata": clean_metadata,
                "tags": tags,
                "span_attributes": {"name": "Chat Completion", "type": "llm"},
            }
            # `choices` is initialized to [] and is never None, so check truthiness
            # and fall back to `output` (e.g. image data) when no choices were extracted.
            if choices:
                request_data["output"] = [choice.dict() for choice in choices]
            else:
                request_data["output"] = output

            if metrics is not None:
                request_data["metrics"] = metrics

            try:
                await global_braintrust_http_handler.post(
                    url=f"{self.api_base}/project_logs/{project_id}/insert",
                    json={"events": [request_data]},
                    headers=self.headers,
                )
            except httpx.HTTPStatusError as e:
                raise Exception(e.response.text)
        except Exception as e:
            raise e  # don't use verbose_logger.exception, if exception is raised

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        return super().log_failure_event(kwargs, response_obj, start_time, end_time)
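

# Usage sketch (illustrative; registering CustomLogger instances via
# `litellm.callbacks` is standard LiteLLM usage, but the model name and
# project name below are assumptions):
#
#   import os
#   import litellm
#
#   litellm.callbacks = [BraintrustLogger(api_key=os.getenv("BRAINTRUST_API_KEY"))]
#   litellm.completion(
#       model="gpt-4o-mini",
#       messages=[{"role": "user", "content": "hi"}],
#       metadata={"project_name": "my-braintrust-project"},
#   )
#
# With `project_name` set, events land in that Braintrust project; otherwise a
# default "litellm" project is created on first use.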