Spaces:
Paused
Paused
| import base64 | |
| import json | |
| import os | |
| from typing import Any, Dict | |
| import requests | |
| from wandb_mcp_server.weave_api.query_builder import QueryBuilder | |
| from wandb_mcp_server.mcp_tools.tools_utils import get_retry_session | |
| from wandb_mcp_server.utils import get_rich_logger | |
| from wandb_mcp_server.api_client import WandBApiManager | |
# Module-level logger obtained from the project's rich-logging helper.
logger = get_rich_logger(__name__)

# Description text for the Weave trace-count tool. This string is runtime data
# (presumably passed as the MCP tool description where the tool is registered --
# confirm in the server module), so its wording is intentionally left verbatim.
# NOTE(review): parts of the text describe more than `count_traces` below returns.
# It mentions a total storage size in bytes and a separate root-trace count, while
# `count_traces` returns a single int. Presumably the tool wrapper combines several
# calls -- confirm, and align the text if it does not.
# NOTE(review): the text labels `query_weave_traces_tool` as "(this tool)" and refers
# to `query_trace_tool`. Both look copied from the query tool's description; verify
# the registered tool names before editing.
COUNT_WEAVE_TRACES_TOOL_DESCRIPTION = """count Weave traces and return the total storage \
size in bytes for the given filters.
Use this tool to query data from Weights & Biases Weave, an observability product for
tracing and evaluating LLMs and GenAI apps.
This tool only provides COUNT information and STORAGE SIZE (bytes) about traces, \
not actual logged traces data, metrics or run data.
<tool_choice_guidance>
<wandb_vs_weave_product_distinction>
**IMPORTANT PRODUCT DISTINCTION:**
W&B offers two distinct products with different purposes:
1. W&B Models: A system for ML experiment tracking, hyperparameter optimization, and model
lifecycle management. Use `query_wandb_tool` for questions about:
- Experiment runs, metrics, and performance comparisons
- Artifact management and model registry
- Hyperparameter optimization and sweeps
- Project dashboards and reports
2. W&B Weave: A toolkit for LLM and GenAI application observability and evaluation. Use
`query_weave_traces_tool` (this tool) for questions about:
- Execution traces and paths of LLM operations
- LLM inputs, outputs, and intermediate results
- Chain of thought visualization and debugging
- LLM evaluation results and feedback
</wandb_vs_weave_product_distinction>
<use_case_selector>
**USE CASE SELECTOR - READ FIRST:**
- For runs, metrics, experiments, artifacts, sweeps etc → use query_wandb_tool
- For traces, LLM calls, chain-of-thought, LLM evaluations, AI agent traces, AI apps etc → use query_weave_traces_tool
=====================================================================
⚠️ TOOL SELECTION WARNING ⚠️
This tool is ONLY for WEAVE TRACES (LLM operations), NOT for run metrics or experiments!
=====================================================================
**KEYWORD GUIDE:**
If user question contains:
- "runs", "experiments", "metrics" → Use query_wandb_tool
- "traces", "LLM calls" etc → Use this tool
**COMMON MISUSE CASES:**
❌ "Looking at metrics of my latest runs" - Do NOT use this tool, use query_wandb_tool instead
❌ "Compare performance across experiments" - Do NOT use this tool, use query_wandb_tool instead
</use_case_selector>
</tool_choice_guidance>
Returns the total number of traces in a project and the number of root
(i.e. "parent" or top-level) traces.
This is more efficient than query_trace_tool when you only need the count.
This can be useful to understand how many traces are in a project before
querying for them as query_trace_tool can return a lot of data.
Parameters
----------
entity_name : str
The Weights & Biases entity name (team or username).
project_name : str
The Weights & Biases project name.
filters : Dict[str, Any], optional
Dict of filter conditions, supporting:
- display_name: Filter by display name (string or regex pattern)
- op_name_contains: Filter for op_name containing a substring. Not a good idea to use in conjunction with trace_roots_only.
- trace_id: Filter by specific trace ID
- status: Filter by trace status ('success', 'error', etc.)
- time_range: Dict with "start" and "end" datetime strings
- latency: Filter by latency in milliseconds (summary.weave.latency_ms).
Use a nested dict with operators: $gt, $lt, $eq, $gte, $lte.
($lt and $lte are implemented via logical negation on the backend).
e.g., {"latency": {"$gt": 5000}}
- attributes: Dict of attribute path and value/operator to match.
Supports nested paths (e.g., "metadata.model_name") via dot notation.
Value can be literal for equality or a dict with operator ($gt, $lt, $eq, $gte, $lte) for comparison
(e.g., {"token_count": {"$gt": 100}}).
- has_exception: Boolean to filter traces with/without exceptions
- trace_roots_only: Boolean to filter for only top-level (aka parent) traces
Returns
-------
int
The number of traces matching the query parameters.
Examples
--------
# Count failed traces
count = count_traces(
entity_name="my-team",
project_name="my-project",
filters={"status": "error"}
)
# Count traces faster than 500ms
count = count_traces(
entity_name="my-team",
project_name="my-project",
filters={"latency": {"$lt": 500}}
)
"""
# Keys that belong in the top-level "filter" object of the /calls/query_stats
# request body, as per
# https://weave-docs.wandb.ai/reference/service-api/calls-query-stats-calls-query_stats-post
# The singular aliases "op_name" and "trace_id" are folded into "op_names" and
# "trace_ids". Every other filter key is routed to the query expression.
_DIRECT_FILTER_KEYS = frozenset(
    {
        "op_names",
        "op_name",
        "input_refs",
        "output_refs",
        "parent_ids",
        "trace_ids",
        "trace_id",
        "call_ids",
        "trace_roots_only",
        "wb_user_ids",
        "wb_run_ids",
    }
)

# Direct filter keys forwarded as lists; a scalar value is wrapped in a list.
_LIST_FILTER_KEYS = (
    "input_refs",
    "output_refs",
    "parent_ids",
    "call_ids",
    "wb_user_ids",
    "wb_run_ids",
)


def _merge_singular_and_plural(
    filters: Dict[str, Any], singular: str, plural: str
) -> list:
    """Merge ``filters[singular]`` and ``filters[plural]`` into one de-duplicated list.

    The plural value may be a list or a single scalar. Duplicates are removed
    while keeping first-seen order, so the request body is deterministic
    (``list(set(...))`` would produce an arbitrary order).
    """
    values = []
    if singular in filters:
        values.append(filters[singular])
    if plural in filters:
        plural_value = filters[plural]
        if isinstance(plural_value, list):
            values.extend(plural_value)
        else:
            values.append(plural_value)
    return list(dict.fromkeys(values))


def _build_count_request_body(project_id: str, filters: dict) -> Dict[str, Any]:
    """Build the JSON body for ``POST /calls/query_stats``.

    ``filters`` may be ``None`` or empty, in which case only ``project_id`` is sent.
    Keys in ``_DIRECT_FILTER_KEYS`` go into ``body["filter"]``. All remaining keys
    are handed to ``QueryBuilder.build_query_expression``. When that yields a
    non-empty ``$expr``, the dumped expression is stored under ``body["query"]``.
    """
    request_body: Dict[str, Any] = {"project_id": project_id}
    if not filters:
        return request_body

    filter_payload: Dict[str, Any] = {}
    op_names = _merge_singular_and_plural(filters, "op_name", "op_names")
    if op_names:
        filter_payload["op_names"] = op_names
    trace_ids = _merge_singular_and_plural(filters, "trace_id", "trace_ids")
    if trace_ids:
        filter_payload["trace_ids"] = trace_ids
    for key in _LIST_FILTER_KEYS:
        if key in filters:
            value = filters[key]
            filter_payload[key] = value if isinstance(value, list) else [value]
    # Per the API docs, trace_roots_only is a boolean, not a list. When it is
    # absent it is omitted, so the API default (false) should apply.
    if "trace_roots_only" in filters:
        filter_payload["trace_roots_only"] = filters["trace_roots_only"]
    if filter_payload:
        request_body["filter"] = filter_payload

    complex_filters = {
        key: value for key, value in filters.items() if key not in _DIRECT_FILTER_KEYS
    }
    if complex_filters:
        query_expr_obj = QueryBuilder.build_query_expression(complex_filters)
        if query_expr_obj:
            dumped_query = query_expr_obj.model_dump(by_alias=True, exclude_none=True)
            if dumped_query and dumped_query.get("$expr"):
                request_body["query"] = dumped_query
    return request_body


def count_traces(
    entity_name: str,
    project_name: str,
    filters: dict = None,
    request_timeout: int = 30,
) -> int:
    """Count the number of traces matching the given filters.

    Counts without retrieving the full trace data, making it more efficient
    than `query_traces` when only the count is needed.

    Parameters
    ----------
    entity_name : str
        The Weights & Biases entity name (team or username).
    project_name : str
        The Weights & Biases project name.
    filters : Dict[str, Any], optional
        Dict of filter conditions, supporting:
        - display_name: Filter by display name (string or regex pattern)
        - op_name_contains: Filter for op_name containing a substring
        - trace_id: Filter by specific trace ID
        - status: Filter by trace status ('success', 'error', etc.)
        - latency: Filter by latency in milliseconds (summary.weave.latency_ms).
            Use a nested dict with operators: $gt, $lt, $eq, $gte, $lte.
            Note: $lt and $lte are implemented via logical negation.
            e.g., {"latency": {"$gt": 5000}}
        - time_range: Dict with "start" and "end" datetime strings
        - attributes: Dict of attribute path and value/operator to match.
            Supports nested paths (e.g., "metadata.model_name") via dot notation.
            Value can be literal for equality or a dict with operator
            ($gt, $lt, $eq, $gte, $lte) for comparison
            (e.g., {"token_count": {"$gt": 100}}).
        - has_exception: Boolean to filter traces with/without exceptions
        - trace_roots_only: Boolean to filter for only top-level (aka parent) traces
        The following keys are forwarded in the request's ``filter`` object:
        op_name / op_names, trace_id / trace_ids, input_refs, output_refs,
        parent_ids, call_ids, wb_user_ids, wb_run_ids (scalars are wrapped in
        a list). All other keys are translated by ``QueryBuilder``.
    request_timeout : int, optional
        Timeout for the HTTP request in seconds. Defaults to 30.

    Returns
    -------
    int
        The number of traces matching the query parameters (0 if the response
        carries no "count" field).

    Raises
    ------
    ValueError
        If no W&B API key is available.
    Exception
        If the HTTP request fails, the server answers with a non-200 status,
        or the response body is not a JSON object.

    Examples
    --------
    >>> # Count failed traces
    >>> count = count_traces(
    ...     entity_name="my-team",
    ...     project_name="my-project",
    ...     filters={"status": "error"}
    ... )
    >>> # Count traces matching an attribute and latency > 1s
    >>> count = count_traces(
    ...     entity_name="my-team",
    ...     project_name="my-project",
    ...     filters={
    ...         "attributes": {"metadata.environment": "production"},
    ...         "latency": {"$gt": 1000}
    ...     }
    ... )
    """
    project_id = f"{entity_name}/{project_name}"

    # Get API key from context (set by auth middleware) or environment.
    api_key = WandBApiManager.get_api_key()
    if not api_key:
        logger.error("W&B API key not found in context or environment variables.")
        raise ValueError("W&B API key is required to query Weave traces count.")
    # Only the key length is logged: even a prefix/suffix of a secret must not
    # end up in log files.
    logger.debug(f"Using W&B API key: length={len(api_key)}")

    request_body = _build_count_request_body(project_id, filters)
    # Serialise once; the same string is sent and reused in every log line.
    payload = json.dumps(request_body)

    weave_server_url = os.environ.get(
        "WEAVE_TRACE_SERVER_URL", "https://trace.wandb.ai"
    ).rstrip("/")  # avoid "//calls/..." when the env var ends with a slash
    url = f"{weave_server_url}/calls/query_stats"
    auth_token = base64.b64encode(f":{api_key}".encode()).decode()
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",  # /calls/query_stats returns application/json
        "Authorization": f"Basic {auth_token}",
    }

    session = get_retry_session()
    logger.debug(f"Posting to {url} with body: {payload}")
    # Only the network call sits inside this try. Response parsing is handled
    # separately below, so decode errors are not misreported as network errors.
    try:
        response = session.post(
            url,
            headers=headers,
            data=payload,
            timeout=request_timeout,
        )
    except requests.exceptions.RequestException as e:
        logger.error(f"HTTP Request failed for project {project_id}: {e}")
        if isinstance(e, requests.exceptions.RetryError):
            # requests raises RetryError while handling urllib3's MaxRetryError
            # without `from`, so the underlying error is usually on __context__
            # rather than __cause__.
            underlying = e.__cause__ or e.__context__
            reason = getattr(underlying, "reason", None)
            if reason:
                logger.error(f"Specific reason for retry exhaustion: {reason}")
        logger.debug(
            f"Failed request body during exception for {project_id}: {payload}"
        )
        raise Exception(
            f"Failed to query Weave trace count for {project_id} due to network error: {e}"
        ) from e

    if response.status_code != 200:
        error_msg = f"Error querying Weave trace count: {response.status_code} - {response.text}"
        logger.error(error_msg)
        # Key diagnostics are limited to its length; the key itself is never logged.
        logger.error(
            f"API key info: length={len(api_key)}, is_40_chars={len(api_key) == 40}"
        )
        if "40 characters" in response.text:
            logger.error(
                f"W&B requires exactly 40 character API keys. Current key has {len(api_key)} characters."
            )
        # Log request body for easier debugging on error.
        logger.debug(f"Failed request body: {payload}")
        raise Exception(error_msg)

    try:
        response_json = response.json()
    except ValueError as e:
        # ValueError covers json.JSONDecodeError and requests' own
        # JSONDecodeError (requests >= 2.27), which also subclasses
        # RequestException. That is why this is not part of the try above.
        logger.error(
            f"Failed to decode JSON response for {project_id}: {e}. Response text: {response.text}"
        )
        raise Exception(
            f"Failed to parse Weave API response for {project_id}: {e}"
        ) from e

    if not isinstance(response_json, dict):
        logger.error(
            f"Unexpected Weave API response for {project_id}: {response_json!r}"
        )
        raise Exception(
            f"Unexpected Weave API response for {project_id}: {response_json!r}"
        )
    return response_json.get("count", 0)  # Default to 0 if count is not in response