# mcp-server/tests/test_query_wandbot.py
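"""End-to-end tests for the query_wandbot_api MCP tool.

An Anthropic model is asked a natural-language W&B question with the
query_wandbot_api tool available, the emitted tool call is executed against
WandBot, and a second (cheaper) model judges the answer against an
expected-output hint, with one retry on a failed judgement.
"""
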
from __future__ import annotations

import json
import os
from typing import Any, Dict, List

import pytest

from wandb_mcp_server.utils import get_rich_logger
from tests.anthropic_test_utils import (
    call_anthropic,
    check_correctness_tool,
    extract_anthropic_tool_use,
    get_anthropic_tool_result_message,
)
from wandb_mcp_server.mcp_tools.query_wandbot import (
    WANDBOT_TOOL_DESCRIPTION,
    query_wandbot_api,
)
from wandb_mcp_server.mcp_tools.tools_utils import generate_anthropic_tool_schema

# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = get_rich_logger(__name__)

# -----------------------------------------------------------------------------
# Environment guards
# -----------------------------------------------------------------------------
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
WANDBOT_BASE_URL = os.getenv(
    "WANDBOT_TEST_URL", "https://morg--wandbot-api-wandbotapi-serve.modal.run"
)

if not ANTHROPIC_API_KEY:
    pytest.skip(
        "ANTHROPIC_API_KEY environment variable not set; skipping Anthropic tests.",
        allow_module_level=True,
    )

# -----------------------------------------------------------------------------
# Static test context
# -----------------------------------------------------------------------------
MODEL_NAME = "claude-3-7-sonnet-20250219"
CORRECTNESS_MODEL_NAME = "claude-3-5-haiku-20241022"
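# MODEL_NAME drives the tool-use call; CORRECTNESS_MODEL_NAME acts as a cheaper LLM judge
# that scores the WandBot answer via check_correctness_tool.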

# -----------------------------------------------------------------------------
# Build tool schema for Anthropic
# -----------------------------------------------------------------------------
available_tools: Dict[str, Dict[str, Any]] = {
    "query_wandbot_api": {
        "function": query_wandbot_api,
        "schema": generate_anthropic_tool_schema(
            func=query_wandbot_api,  # Pass the function itself
            description=WANDBOT_TOOL_DESCRIPTION,  # Use the imported description
        ),
    }
}
tools: List[Dict[str, Any]] = [available_tools["query_wandbot_api"]["schema"]]
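
# Note: generate_anthropic_tool_schema is assumed to return an Anthropic-style tool
# definition; the only field this test relies on directly is the schema's "name"
# (referenced in the retry prompt below).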

# -----------------------------------------------------------------------------
# Natural-language queries to test
# -----------------------------------------------------------------------------
test_queries = [
    {
        "question": "What kinds of scorers does weave support?",
        "expected_output": "There are 2 types of scorers in weave, Function-based and Class-based.",
    },
    # Add more test cases here later
]
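
# NOTE: "expected_output" is only a grading hint for the LLM correctness check below,
# not an exact string the WandBot answer has to match.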

# -----------------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
    "sample",
    test_queries,
    ids=[f"sample_{i}" for i, _ in enumerate(test_queries)],
)
def test_query_wandbot(sample):
    """End-to-end test: NL question → Anthropic → tool_use → result validation with correctness check."""
    query_text = sample["question"]
    expected_output = sample["expected_output"]  # Expected output hint for the correctness check

    logger.info("\n==============================")
    logger.info("QUERY: %s", query_text)

    # --- Retry Logic Setup ---
    max_retries = 1
    last_reasoning = "No correctness check performed yet."
    last_is_correct = False
    first_call_assistant_response = None  # Response from the first model call (reused when building retry messages)
    tool_result = None  # Store the result of executing the tool
    tool_use_id = None  # Initialize tool_use_id *before* the loop

    # Initial messages for the first attempt
    messages_first_call = [{"role": "user", "content": query_text}]

    for attempt in range(max_retries + 1):
        logger.info(f"\n--- Attempt {attempt + 1} / {max_retries + 1} ---")
        current_messages = messages_first_call  # Start with the base messages

        if attempt > 0:
            # Retry logic: add the previous assistant response, the tool result, and user feedback
            retry_messages = []
            if first_call_assistant_response:
                # 1. Add the previous assistant message (contains the tool use)
                retry_messages.append(
                    {
                        "role": first_call_assistant_response.role,
                        "content": first_call_assistant_response.content,
                    }
                )
                # 2. Add the result from executing the tool in the previous attempt
                if tool_result is not None and tool_use_id is not None:
                    tool_result_message = get_anthropic_tool_result_message(
                        tool_result, tool_use_id
                    )
                    retry_messages.append(tool_result_message)
                else:
                    logger.warning(
                        f"Attempt {attempt + 1}: Cannot add tool result message, tool_result or tool_use_id missing."
                    )
                # 3. Add the user message asking for a retry
                retry_user_message_content = f"""
Executing the previous tool call resulted in:
```json
{json.dumps(tool_result, indent=2)}
```
A separate check determined this result was incorrect for the original query.
The reasoning provided was: "{last_reasoning}".
Please re-analyze the original query ("{query_text}") and the result from your previous attempt, then try generating the '{available_tools["query_wandbot_api"]["schema"]["name"]}' tool call again.
"""
                retry_messages.append(
                    {"role": "user", "content": retry_user_message_content}
                )
                current_messages = messages_first_call[:1] + retry_messages  # Rebuild the message list for the retry
            else:
                logger.warning(
                    "Attempting retry, but no previous assistant response or tool_use_id found."
                )
                # If a retry is needed but we lack context, we should probably either fail or
                # stick with the original messages. For now, proceed with the original
                # messages, though this may not be ideal.
                current_messages = messages_first_call

        # --- First Call: Get the query_wandbot_api tool use ---
        try:
            response = call_anthropic(
                model_name=MODEL_NAME,
                messages=current_messages,  # Use the potentially updated message list
                tools=tools,
            )
            first_call_assistant_response = response  # Store for a potential *next* retry
        except Exception as e:
            pytest.fail(f"Attempt {attempt + 1}: Anthropic API call failed: {e}")

        try:
            # Extract tool_use_id here
            _, tool_name, tool_input, tool_use_id = extract_anthropic_tool_use(response)
            if tool_use_id is None:
                logger.warning(
                    f"Attempt {attempt + 1}: Model did not return a tool use block."
                )
                # Decide how to handle this - maybe fail, maybe retry without tool use?
                # For now, continue to execution, it might fail gracefully or correctness check will catch it.
        except ValueError as e:
            logger.error(
                f"Attempt {attempt + 1}: Failed to extract tool use from response: {response}"
            )
            pytest.fail(f"Attempt {attempt + 1}: Could not extract tool use: {e}")

        logger.info(f"Attempt {attempt + 1}: Tool emitted by model: {tool_name}")
        logger.info(
            f"Attempt {attempt + 1}: Tool input: {json.dumps(tool_input, indent=2)}"
        )

        assert tool_name == "query_wandbot_api", (
            f"Attempt {attempt + 1}: Expected 'query_wandbot_api', got '{tool_name}'"
        )
        assert "question" in tool_input, (
            f"Attempt {attempt + 1}: Tool input missing 'question'"
        )

        # --- Execute the WandBot tool ---
        try:
            # Ensure only args expected by the *current* function signature are passed
            # (assuming the function now only takes 'question').
            if "question" not in tool_input:
                pytest.fail(
                    f"Attempt {attempt + 1}: Tool input missing required 'question' argument."
                )
            actual_args = {"question": tool_input["question"]}
            tool_result = available_tools[tool_name]["function"](**actual_args)
            logger.info(
                f"Attempt {attempt + 1}: Tool result: {json.dumps(tool_result, indent=2)}"
            )  # Log the full result

            # Basic structure check before the correctness check
            assert isinstance(tool_result, dict), "Tool result should be a dictionary"
            assert isinstance(tool_result.get("answer"), str), (
                "'answer' should be a string"
            )
            assert isinstance(tool_result.get("sources"), list), (
                "'sources' should be a list"
            )
        except Exception as e:
            logger.error(
                f"Attempt {attempt + 1}: Error executing or validating tool '{tool_name}' with input {actual_args}: {e}",
                exc_info=True,
            )
            pytest.fail(
                f"Attempt {attempt + 1}: Tool execution or basic validation failed: {e}"
            )

        # --- Second Call: Perform Correctness Check ---
        logger.info(f"\n--- Starting Correctness Check for Attempt {attempt + 1} ---")
        try:
            correctness_prompt = f"""
Please evaluate if the provided 'Actual Tool Result' provides a helpful and relevant answer to the 'Original User Query'.
The 'Expected Output Hint' gives guidance on what a good answer should contain.
Use the 'check_correctness_tool' to provide your reasoning and conclusion.

Original User Query:
"{query_text}"

Expected Output Hint:
"{expected_output}"

Actual Tool Result from '{tool_name}':
```json
{json.dumps(tool_result, indent=2)}
```
"""
            messages_check_call = [{"role": "user", "content": correctness_prompt}]
            correctness_response = call_anthropic(
                model_name=CORRECTNESS_MODEL_NAME,
                messages=messages_check_call,
                check_correctness_tool=check_correctness_tool,  # Pass the imported tool schema
            )
            logger.info(
                f"Attempt {attempt + 1}: Correctness check response:\n{correctness_response}\n\n"
            )

            _, check_tool_name, check_tool_input, _ = extract_anthropic_tool_use(
                correctness_response
            )
            assert check_tool_name == "check_correctness_tool", (
                f"Attempt {attempt + 1}: Expected correctness tool, got {check_tool_name}"
            )
            assert "reasoning" in check_tool_input, (
                f"Attempt {attempt + 1}: Correctness tool missing 'reasoning'"
            )
            assert "is_correct" in check_tool_input, (
                f"Attempt {attempt + 1}: Correctness tool missing 'is_correct'"
            )

            last_reasoning = check_tool_input["reasoning"]
            last_is_correct = check_tool_input["is_correct"]
            logger.info(
                f"Attempt {attempt + 1}: Correctness Reasoning: {last_reasoning}"
            )
            logger.info(
                f"Attempt {attempt + 1}: Is Correct according to LLM: {last_is_correct}"
            )

            if last_is_correct:
                logger.info(
                    f"--- Correctness check passed on attempt {attempt + 1}. ---"
                )
                break  # Exit the loop successfully
        except KeyError as e:
            logger.error(
                f"Attempt {attempt + 1}: Missing expected key in correctness tool input: {e}"
            )
            logger.error(
                f"Attempt {attempt + 1}: Full input received: {check_tool_input}"
            )
            last_is_correct = False
            last_reasoning = f"Correctness tool response missing key: {e}"
            # Continue the loop if retries remain; otherwise the check after the loop fails the test.
        except Exception as e:
            logger.error(
                f"Attempt {attempt + 1}: Error during correctness check for query '{query_text}': {e}",
                exc_info=True,
            )
            last_is_correct = False
            last_reasoning = f"Correctness check failed with exception: {e}"
            # Continue the loop if retries remain; otherwise the check after the loop fails the test.

    # --- After the loop, fail the test if the last attempt wasn't correct ---
    if not last_is_correct:
        pytest.fail(
            f"LLM evaluation failed after {max_retries + 1} attempts. "
            f"Final is_correct flag is `{last_is_correct}`. "
            f"Final Reasoning: '{last_reasoning}'"
        )

    # If we reach here, the correctness check passed within the allowed attempts.
    logger.info("--- Test passed within allowed attempts. ---")