import json
import threading
from typing import Optional

from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger


class MlflowLogger(CustomLogger):
    def __init__(self):
        from mlflow.tracking import MlflowClient

        self._client = MlflowClient()

        self._stream_id_to_span = {}
        self._lock = threading.Lock()  # lock for _stream_id_to_span

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        self._handle_success(kwargs, response_obj, start_time, end_time)

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        self._handle_success(kwargs, response_obj, start_time, end_time)

    def _handle_success(self, kwargs, response_obj, start_time, end_time):
        """
        Log the success event as an MLflow span.
        Note that this method is called asynchronously in the background thread.
        """
        from mlflow.entities import SpanStatusCode

        try:
            verbose_logger.debug("MLflow logging start for success event")

            if kwargs.get("stream"):
                self._handle_stream_event(kwargs, response_obj, start_time, end_time)
            else:
                span = self._start_span_or_trace(kwargs, start_time)
                end_time_ns = int(end_time.timestamp() * 1e9)
                self._extract_and_set_chat_attributes(span, kwargs, response_obj)
                self._end_span_or_trace(
                    span=span,
                    outputs=response_obj,
                    status=SpanStatusCode.OK,
                    end_time_ns=end_time_ns,
                )
        except Exception:
            verbose_logger.debug("MLflow Logging Error", stack_info=True)

    def _extract_and_set_chat_attributes(self, span, kwargs, response_obj):
        try:
            from mlflow.tracing.utils import set_span_chat_messages  # type: ignore
            from mlflow.tracing.utils import set_span_chat_tools  # type: ignore
        except ImportError:
            return

        inputs = self._construct_input(kwargs)
        input_messages = inputs.get("messages", [])
        output_messages = [
            c.message.model_dump(exclude_none=True)
            for c in getattr(response_obj, "choices", [])
        ]
        if messages := [*input_messages, *output_messages]:
            set_span_chat_messages(span, messages)
        if tools := inputs.get("tools"):
            set_span_chat_tools(span, tools)
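
    # Illustrative shape of the merged list handed to set_span_chat_messages
    # (assumes the OpenAI-style chat format that litellm normalizes to):
    #   [{"role": "user", "content": "Hi"},
    #    {"role": "assistant", "content": "Hello! How can I help?"}]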

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        self._handle_failure(kwargs, response_obj, start_time, end_time)

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        self._handle_failure(kwargs, response_obj, start_time, end_time)

    def _handle_failure(self, kwargs, response_obj, start_time, end_time):
        """
        Log the failure event as an MLflow span.
        Note that this method is called *synchronously* unlike the success handler.
        """
        from mlflow.entities import SpanEvent, SpanStatusCode

        try:
            span = self._start_span_or_trace(kwargs, start_time)
            end_time_ns = int(end_time.timestamp() * 1e9)

            # Record exception info as event
            if exception := kwargs.get("exception"):
                span.add_event(SpanEvent.from_exception(exception))  # type: ignore

            self._extract_and_set_chat_attributes(span, kwargs, response_obj)
            self._end_span_or_trace(
                span=span,
                outputs=response_obj,
                status=SpanStatusCode.ERROR,
                end_time_ns=end_time_ns,
            )
        except Exception as e:
            verbose_logger.debug(f"MLflow Logging Error - {e}", stack_info=True)

    def _handle_stream_event(self, kwargs, response_obj, start_time, end_time):
        """
        Handle the success event for a streaming response. For streaming calls,
        the log_success_event handler is triggered for every chunk of the stream.
        We create a single span for the entire stream request as follows:

        1. For the first chunk, start a new span and store it in the map.
        2. For subsequent chunks, add the chunk as an event to the span.
        3. For the final chunk, end the span and remove it from the map.
        """
        from mlflow.entities import SpanStatusCode

        litellm_call_id = kwargs.get("litellm_call_id")

        if litellm_call_id not in self._stream_id_to_span:
            with self._lock:
                # Check again after acquiring the lock
                if litellm_call_id not in self._stream_id_to_span:
                    # Start a new span for the first chunk of the stream
                    span = self._start_span_or_trace(kwargs, start_time)
                    self._stream_id_to_span[litellm_call_id] = span

        # Add the chunk as an event to the span
        span = self._stream_id_to_span[litellm_call_id]
        self._add_chunk_events(span, response_obj)

        # If this is the final chunk, end the span. The final chunk
        # has complete_streaming_response that gathers the full response.
        if final_response := kwargs.get("complete_streaming_response"):
            end_time_ns = int(end_time.timestamp() * 1e9)

            self._extract_and_set_chat_attributes(span, kwargs, final_response)
            self._end_span_or_trace(
                span=span,
                outputs=final_response,
                status=SpanStatusCode.OK,
                end_time_ns=end_time_ns,
            )

            # Remove the stream_id from the map
            with self._lock:
                self._stream_id_to_span.pop(litellm_call_id)
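
    # Illustrative chunk sequence for one streaming call (payload shapes are
    # assumptions for the sketch, not exact litellm internals):
    #   chunk 1:     kwargs = {"litellm_call_id": "abc", "stream": True, ...}
    #                -> span started and cached in _stream_id_to_span
    #   chunks 2..n: each delta recorded as a "streaming_chunk" span event
    #   last chunk:  kwargs also carries "complete_streaming_response"
    #                -> span ended with the aggregated response, cache entry removed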

    def _add_chunk_events(self, span, response_obj):
        from mlflow.entities import SpanEvent

        try:
            for choice in response_obj.choices:
                span.add_event(
                    SpanEvent(
                        name="streaming_chunk",
                        attributes={"delta": json.dumps(choice.delta.model_dump())},
                    )
                )
        except Exception:
            verbose_logger.debug("Error adding chunk events to span", stack_info=True)

    def _construct_input(self, kwargs):
        """Construct span inputs with optional parameters."""
        inputs = {"messages": kwargs.get("messages")}
        if tools := kwargs.get("tools"):
            inputs["tools"] = tools

        for key in ["functions", "tools", "stream", "tool_choice", "user"]:
            # NB: pop() mutates optional_params, removing each key once it is
            # recorded; falsy values (e.g. stream=False) are skipped.
            if value := kwargs.get("optional_params", {}).pop(key, None):
                inputs[key] = value
        return inputs
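
    # Illustrative return value for a streaming chat call with tools
    # (hypothetical payload following the OpenAI-style request format):
    #   {"messages": [{"role": "user", "content": "Hi"}],
    #    "tools": [{"type": "function", "function": {...}}],
    #    "stream": True}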

    def _extract_attributes(self, kwargs):
        """
        Extract span attributes from kwargs.

        With the latest version of litellm, the standard_logging_object contains
        canonical information for logging. If it is not present, we extract a
        subset of attributes from other kwargs.
        """
        attributes = {
            "litellm_call_id": kwargs.get("litellm_call_id"),
            "call_type": kwargs.get("call_type"),
            "model": kwargs.get("model"),
        }

        standard_obj = kwargs.get("standard_logging_object")
        if standard_obj:
            attributes.update(
                {
                    "api_base": standard_obj.get("api_base"),
                    "cache_hit": standard_obj.get("cache_hit"),
                    "usage": {
                        "completion_tokens": standard_obj.get("completion_tokens"),
                        "prompt_tokens": standard_obj.get("prompt_tokens"),
                        "total_tokens": standard_obj.get("total_tokens"),
                    },
                    "raw_llm_response": standard_obj.get("response"),
                    "response_cost": standard_obj.get("response_cost"),
                    "saved_cache_cost": standard_obj.get("saved_cache_cost"),
                }
            )
        else:
            litellm_params = kwargs.get("litellm_params", {})
            attributes.update(
                {
                    # "model" is already set above; only add the extra fields here
                    "cache_hit": kwargs.get("cache_hit"),
                    "custom_llm_provider": kwargs.get("custom_llm_provider"),
                    "api_base": litellm_params.get("api_base"),
                    "response_cost": kwargs.get("response_cost"),
                }
            )
        return attributes
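
    # Illustrative attributes dict when standard_logging_object is present
    # (all values hypothetical):
    #   {"litellm_call_id": "...", "call_type": "completion", "model": "gpt-4o",
    #    "api_base": "https://api.openai.com/v1", "cache_hit": False,
    #    "usage": {"completion_tokens": 9, "prompt_tokens": 5, "total_tokens": 14},
    #    "raw_llm_response": {...}, "response_cost": 0.00021, "saved_cache_cost": 0.0}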

    def _get_span_type(self, call_type: Optional[str]) -> str:
        from mlflow.entities import SpanType

        if call_type in ["completion", "acompletion"]:
            return SpanType.LLM
        elif call_type == "embeddings":
            return SpanType.EMBEDDING
        else:
            return SpanType.LLM

    def _start_span_or_trace(self, kwargs, start_time):
        """
        Start an MLflow span or a trace.

        If there is an active span, we start a new span as a child of
        that span. Otherwise, we start a new trace.
        """
        import mlflow

        call_type = kwargs.get("call_type", "completion")
        span_name = f"litellm-{call_type}"
        span_type = self._get_span_type(call_type)
        start_time_ns = int(start_time.timestamp() * 1e9)

        inputs = self._construct_input(kwargs)
        attributes = self._extract_attributes(kwargs)

        if active_span := mlflow.get_current_active_span():  # type: ignore
            return self._client.start_span(
                name=span_name,
                request_id=active_span.request_id,
                parent_id=active_span.span_id,
                span_type=span_type,
                inputs=inputs,
                attributes=attributes,
                start_time_ns=start_time_ns,
            )
        else:
            return self._client.start_trace(
                name=span_name,
                span_type=span_type,
                inputs=inputs,
                attributes=attributes,
                start_time_ns=start_time_ns,
            )
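
    # E.g. when litellm is called inside `with mlflow.start_span("my-app"):`,
    # the span above attaches as a child of "my-app"; with no active span it
    # becomes the root of a fresh trace.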

    def _end_span_or_trace(self, span, outputs, end_time_ns, status):
        """End an MLflow span or a trace."""
        if span.parent_id is None:
            self._client.end_trace(
                request_id=span.request_id,
                outputs=outputs,
                status=status,
                end_time_ns=end_time_ns,
            )
        else:
            self._client.end_span(
                request_id=span.request_id,
                span_id=span.span_id,
                outputs=outputs,
                status=status,
                end_time_ns=end_time_ns,
            )
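

# Minimal usage sketch (illustrative, not part of the module): register the
# logger with litellm and trace one call. Assumes mlflow is installed and a
# tracking URI is configured; the model and message below are placeholders.
#
#   import litellm
#   from litellm.integrations.mlflow import MlflowLogger
#
#   litellm.callbacks = [MlflowLogger()]
#   litellm.completion(
#       model="gpt-4o-mini",
#       messages=[{"role": "user", "content": "Hello"}],
#   )
#   # The call then appears as a trace in the MLflow UI.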