#### What this does ####
# This file contains the LiteralAILogger class, which logs LiteLLM call
# steps to the Literal AI observability platform.
import asyncio
import os
import uuid
from typing import List, Optional

import httpx

from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.llms.custom_httpx.http_handler import (
    HTTPHandler,
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.types.utils import StandardLoggingPayload


class LiteralAILogger(CustomBatchLogger):
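    """Batch logger that ships LiteLLM request logs to Literal AI as LLM steps.

    Events are buffered in ``log_queue`` and flushed in a single GraphQL
    ``ingestStep`` mutation once the queue reaches ``batch_size``.
    """
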
    def __init__(
        self,
        literalai_api_key=None,
        literalai_api_url="https://cloud.getliteral.ai",
        env=None,
        **kwargs,
    ):
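        # Note the asymmetric precedence below: LITERAL_API_URL (env) overrides
        # the constructor URL, while an explicit literalai_api_key overrides
        # the LITERAL_API_KEY env var.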
        self.literalai_api_url = os.getenv("LITERAL_API_URL") or literalai_api_url
        self.headers = {
            "Content-Type": "application/json",
            "x-api-key": literalai_api_key or os.getenv("LITERAL_API_KEY"),
            "x-client-name": "litellm",
        }
        if env:
            self.headers["x-env"] = env
        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        self.sync_http_handler = HTTPHandler()
        batch_size = os.getenv("LITERAL_BATCH_SIZE", None)
        self.flush_lock = asyncio.Lock()
        super().__init__(
            **kwargs,
            flush_lock=self.flush_lock,
            batch_size=int(batch_size) if batch_size else None,
        )

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
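        """Sync success hook: queue the event and flush once the batch is full."""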
        try:
            verbose_logger.debug(
                "Literal AI Layer Logging - kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            self.log_queue.append(data)
            verbose_logger.debug(
                "Literal AI logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                self._send_batch()
        except Exception:
            verbose_logger.exception(
                "Literal AI Layer Error - error logging success event."
            )

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
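        """Sync failure hook: queue the event and flush once the batch is full."""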
        verbose_logger.info("Literal AI Failure Event Logging!")
        try:
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            self.log_queue.append(data)
            verbose_logger.debug(
                "Literal AI logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                self._send_batch()
        except Exception:
            verbose_logger.exception(
                "Literal AI Layer Error - error logging failure event."
            )

    def _send_batch(self):
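        """POST all queued steps to Literal AI as one GraphQL mutation (sync path)."""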
        if not self.log_queue:
            return
        url = f"{self.literalai_api_url}/api/graphql"
        query = self._steps_query_builder(self.log_queue)
        variables = self._steps_variables_builder(self.log_queue)
        try:
            response = self.sync_http_handler.post(
                url=url,
                json={
                    "query": query,
                    "variables": variables,
                },
                headers=self.headers,
            )
            if response.status_code >= 300:
                verbose_logger.error(
                    f"Literal AI Error: {response.status_code} - {response.text}"
                )
            else:
                verbose_logger.debug(
                    f"Batch of {len(self.log_queue)} runs successfully created"
                )
        except Exception:
            verbose_logger.exception("Literal AI Layer Error")
        # Unlike the async path, where flush_queue() clears the queue after
        # async_send_batch(), nothing else empties the queue here; clear it so
        # the same events are not re-sent on the next flush.
        self.log_queue.clear()

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
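        """Async success hook: queue the event and flush once the batch is full."""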
        try:
            verbose_logger.debug(
                "Literal AI Async Layer Logging - kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            self.log_queue.append(data)
            verbose_logger.debug(
                "Literal AI logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception:
            verbose_logger.exception(
                "Literal AI Layer Error - error logging async success event."
            )

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
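        """Async failure hook: queue the event and flush once the batch is full."""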
        verbose_logger.info("Literal AI Failure Event Logging!")
        try:
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            self.log_queue.append(data)
            verbose_logger.debug(
                "Literal AI logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception:
            verbose_logger.exception(
                "Literal AI Layer Error - error logging async failure event."
            )

    async def async_send_batch(self):
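        """POST all queued steps to Literal AI in one GraphQL mutation (async path).

        Invoked via `flush_queue()` on the base class, which also clears the queue.
        """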
        if not self.log_queue:
            return
        url = f"{self.literalai_api_url}/api/graphql"
        query = self._steps_query_builder(self.log_queue)
        variables = self._steps_variables_builder(self.log_queue)
        try:
            response = await self.async_httpx_client.post(
                url=url,
                json={
                    "query": query,
                    "variables": variables,
                },
                headers=self.headers,
            )
            if response.status_code >= 300:
                verbose_logger.error(
                    f"Literal AI Error: {response.status_code} - {response.text}"
                )
            else:
                verbose_logger.debug(
                    f"Batch of {len(self.log_queue)} runs successfully created"
                )
        except httpx.HTTPStatusError as e:
            verbose_logger.exception(
                f"Literal AI HTTP Error: {e.response.status_code} - {e.response.text}"
            )
        except Exception:
            verbose_logger.exception("Literal AI Layer Error")

    def _prepare_log_data(self, kwargs, response_obj, start_time, end_time) -> dict:
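        """Convert a LiteLLM event into the Literal AI `step` dict.

        Pulls everything from the `standard_logging_object` LiteLLM attaches
        to kwargs; raises ValueError when that payload is missing.
        """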
        logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )
        if logging_payload is None:
            raise ValueError("standard_logging_object not found in kwargs")
        clean_metadata = logging_payload["metadata"]
        metadata = kwargs.get("litellm_params", {}).get("metadata", {})
        settings = logging_payload["model_parameters"]
        messages = logging_payload["messages"]
        response = logging_payload["response"]
        choices: List = []
        if isinstance(response, dict) and "choices" in response:
            choices = response["choices"]
        message_completion = choices[0]["message"] if choices else None
        prompt_id = None
        variables = None
        if messages and isinstance(messages, list) and isinstance(messages[0], dict):
            for message in messages:
                # Templated prompt messages expose __literal_prompt__ as an
                # attribute; plain dict messages fall through to the default.
                if literal_prompt := getattr(message, "__literal_prompt__", None):
                    prompt_id = literal_prompt.get("prompt_id")
                    variables = literal_prompt.get("variables")
                    message["uuid"] = literal_prompt.get("uuid")
                    message["templated"] = True
        tools = settings.pop("tools", None)
        step = {
            "id": metadata.get("step_id", str(uuid.uuid4())),
            "error": logging_payload["error_str"],
            "name": kwargs.get("model", ""),
            "threadId": metadata.get("literalai_thread_id", None),
            "parentId": metadata.get("literalai_parent_id", None),
            "rootRunId": metadata.get("literalai_root_run_id", None),
            "input": None,
            "output": None,
            "type": "llm",
            "tags": metadata.get("tags", metadata.get("literalai_tags", None)),
            "startTime": str(start_time),
            "endTime": str(end_time),
            "metadata": clean_metadata,
            "generation": {
                "inputTokenCount": logging_payload["prompt_tokens"],
                "outputTokenCount": logging_payload["completion_tokens"],
                "tokenCount": logging_payload["total_tokens"],
                "promptId": prompt_id,
                "variables": variables,
                "provider": kwargs.get("custom_llm_provider", "litellm"),
                "model": kwargs.get("model", ""),
                "duration": (end_time - start_time).total_seconds(),
                "settings": settings,
                "messages": messages,
                "messageCompletion": message_completion,
                "tools": tools,
            },
        }
        return step

    def _steps_query_variables_builder(self, steps):
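        """Declare one suffixed set of GraphQL variables per queued step.

        Step 0 yields `$id_0: String!`, `$threadId_0: String`, and so on, so
        each step in the batch gets uniquely named variables within a single
        mutation.
        """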
        generated = ""
        for id in range(len(steps)):
            generated += f"""$id_{id}: String!
        $threadId_{id}: String
        $rootRunId_{id}: String
        $type_{id}: StepType
        $startTime_{id}: DateTime
        $endTime_{id}: DateTime
        $error_{id}: String
        $input_{id}: Json
        $output_{id}: Json
        $metadata_{id}: Json
        $parentId_{id}: String
        $name_{id}: String
        $tags_{id}: [String!]
        $generation_{id}: GenerationPayloadInput
        $scores_{id}: [ScorePayloadInput!]
        $attachments_{id}: [AttachmentPayloadInput!]
        """
        return generated

    def _steps_ingest_steps_builder(self, steps):
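        """Emit one aliased `ingestStep` field per step (`step0`, `step1`, ...),
        each wired to its matching suffixed variables.
        """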
        generated = ""
        for id in range(len(steps)):
            generated += f"""
        step{id}: ingestStep(
            id: $id_{id}
            threadId: $threadId_{id}
            rootRunId: $rootRunId_{id}
            startTime: $startTime_{id}
            endTime: $endTime_{id}
            type: $type_{id}
            error: $error_{id}
            input: $input_{id}
            output: $output_{id}
            metadata: $metadata_{id}
            parentId: $parentId_{id}
            name: $name_{id}
            tags: $tags_{id}
            generation: $generation_{id}
            scores: $scores_{id}
            attachments: $attachments_{id}
        ) {{
            ok
            message
        }}
        """
        return generated

    def _steps_query_builder(self, steps):
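        """Assemble the batched mutation. For a single step it looks like:

        mutation AddStep($id_0: String! ...) {
            step0: ingestStep(id: $id_0 ...) { ok message }
        }
        """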
        return f"""
        mutation AddStep({self._steps_query_variables_builder(steps)}) {{
            {self._steps_ingest_steps_builder(steps)}
        }}
        """

    def _steps_variables_builder(self, steps):
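        """Flatten the steps into a single `{key}_{index}` variables dict,
        dropping None values so they do not override server-side defaults.
        """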
        def serialize_step(event, id):
            result = {}
            for key, value in event.items():
                # Only keep the keys that are not None to avoid overriding existing values
                if value is not None:
                    result[f"{key}_{id}"] = value
            return result

        variables = {}
        for i in range(len(steps)):
            step = steps[i]
            variables.update(serialize_step(step, i))
        return variables
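

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the logger): one way to wire this
# class into LiteLLM. CustomBatchLogger extends CustomLogger, so instances
# can be registered via `litellm.callbacks`; the model name and API key
# below are placeholders.
#
#   import litellm
#
#   litellm.callbacks = [LiteralAILogger(literalai_api_key="lsk-...", env="dev")]
#   litellm.completion(
#       model="gpt-4o-mini",
#       messages=[{"role": "user", "content": "Hello"}],
#   )
# ---------------------------------------------------------------------------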