| """
|
| Collection of Inspect AI solvers used by the rubric task.
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import asyncio
|
| import json
|
| import os
|
| import tempfile
|
| from typing import Callable, Dict, List, Sequence
|
|
|
| import litellm
|
| from inspect_ai.model import ChatMessageAssistant, ModelOutput
|
| from inspect_ai.solver import Solver, solver
|
| from inspect_ai.solver._task_state import TaskState
|
| from lmnr import Laminar, LaminarLiteLLMCallback
|
|
|
| from eval.hf_agent_connector import AgentResponseGenerator
|
|
|
|
|
async def _run_subprocess(command: Sequence[str]) -> str:
    """Run *command* without a shell and return its stripped standard output.

    Args:
        command: Program name followed by its arguments; passed verbatim to
            ``asyncio.create_subprocess_exec``.

    Returns:
        The captured stdout, decoded and with surrounding whitespace removed.

    Raises:
        RuntimeError: If the process exits with a non-zero return code. The
            message contains the command, the exit code and the captured
            stderr.
    """
    process = await asyncio.create_subprocess_exec(
        *command,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    try:
        stdout, stderr = await process.communicate()
    except asyncio.CancelledError:
        # The awaiting task was cancelled: do not leave the child running
        # in the background once nobody is waiting for it any more.
        if process.returncode is None:
            try:
                process.kill()
            except ProcessLookupError:
                # The child exited between the check and the kill.
                pass
            await process.wait()
        raise

    # Decode leniently: a stray non-UTF-8 byte in the child's output must not
    # replace the real result (or the real failure) with a UnicodeDecodeError.
    if process.returncode != 0:
        error_output = stderr.decode(errors="replace").strip()
        raise RuntimeError(
            f"Command {' '.join(command)} failed with code {process.returncode}:\n"
            f"{error_output}"
        )
    return stdout.decode(errors="replace").strip()
|
|
|
|
|
@solver(name="hf_agent")
def hf_agent(
    config_path: str = "agent/config_mcp_example.json",
    max_iterations: int = 10,
) -> Solver:
    """Create a solver that answers each sample via ``AgentResponseGenerator``.

    Building the solver has side effects: Laminar is initialised with the
    value of the ``LMNR_API_KEY`` environment variable, ``litellm.callbacks``
    is replaced by a single ``LaminarLiteLLMCallback`` and a confirmation
    line is printed.

    Args:
        config_path: Forwarded to ``AgentResponseGenerator``.
        max_iterations: Forwarded to ``AgentResponseGenerator``.

    Returns:
        A solver that appends the agent's reply to the task state as an
        assistant message and marks the state as completed.
    """
    api_key = os.environ.get("LMNR_API_KEY")
    Laminar.initialize(project_api_key=api_key)
    litellm.callbacks = [LaminarLiteLLMCallback()]
    print("✅ Laminar initialized")

    agent = AgentResponseGenerator(
        config_path=config_path,
        max_iterations=max_iterations,
    )

    async def solve(state: TaskState, generate) -> TaskState:
        # `generate` is part of the solver signature but is not used here.
        answer = await agent.run(state.input_text)
        reply = ChatMessageAssistant(
            content=answer,
            model=agent.model_name,
            source="generate",
        )
        state.messages.append(reply)
        state.output = ModelOutput.from_message(reply)
        state.completed = True
        return state

    return solve
|
|
|
|
|
def _parse_claude_output(stdout: str) -> tuple[str, str | None]:
    """Extract the reply text and session id from JSON-formatted CLI output.

    Only the last line of *stdout* is parsed. Whenever that line is missing,
    is not valid JSON, or is not a JSON object, the raw *stdout* is returned
    as the reply text and the session id is ``None``.

    Args:
        stdout: Everything the CLI wrote to standard output.

    Returns:
        ``(response_text, session_id)``.
    """
    lines = stdout.strip().splitlines()
    if not lines:
        # The CLI printed nothing; indexing [-1] here would raise IndexError.
        return stdout, None

    try:
        payload = json.loads(lines[-1])
    except json.JSONDecodeError:
        return stdout, None

    if not isinstance(payload, dict):
        # e.g. a JSON array or scalar: there are no fields to read.
        return stdout, None

    response_text = payload.get("result") or payload.get("message", "") or stdout
    return response_text, payload.get("session_id")


@solver(name="claude_code")
def claude_code(
    output_format: str = "json",
    mcp_config: str | None = None,
) -> Solver:
    """Create a solver that answers each sample by invoking the ``claude`` CLI.

    Args:
        output_format: Value passed to ``--output-format``; one of ``text``,
            ``json`` or ``stream-json``.
        mcp_config: Optional value passed to ``--mcp-config``. The flag is
            omitted when this is ``None`` or empty.

    Returns:
        A solver that runs ``claude -p <prompt>``, appends the reply to the
        task state as an assistant message (with the session id in its
        metadata when one was reported) and marks the state as completed.

    Raises:
        ValueError: If *output_format* is not one of the supported values.
    """
    if output_format not in {"text", "json", "stream-json"}:
        raise ValueError("output_format must be one of: text, json, stream-json")

    async def solve(state: TaskState, generate) -> TaskState:
        # `generate` is part of the solver signature but is not used here.
        prompt = state.input_text

        cmd: List[str] = ["claude", "-p", prompt, "--output-format", output_format]
        if mcp_config:
            cmd += ["--mcp-config", mcp_config]

        stdout = await _run_subprocess(cmd)
        response_text = stdout
        session_id = None

        if output_format in {"json", "stream-json"}:
            response_text, session_id = _parse_claude_output(stdout)

        assistant_message = ChatMessageAssistant(
            content=response_text,
            model="claude-code",
            source="generate",
            metadata={"session_id": session_id} if session_id else None,
        )
        state.messages.append(assistant_message)
        state.output = ModelOutput.from_message(assistant_message)
        state.completed = True
        return state

    return solve
|
|
|
|
|
@solver(name="claude_code+hf_mcp")
def claude_code_hf_mcp(
    output_format: str = "json",
    hf_token: str | None = None,
) -> Solver:
    """Create a Claude Code solver wired to the Hugging Face MCP server.

    The token is taken from *hf_token* or, failing that, from the
    ``HF_TOKEN`` environment variable.

    Args:
        output_format: Forwarded to ``claude_code``.
        hf_token: Hugging Face token sent as a bearer ``Authorization``
            header to the MCP server.

    Returns:
        A solver that writes the MCP configuration to a temporary JSON file,
        delegates to ``claude_code`` with that file and removes the file
        afterwards.

    Raises:
        ValueError: If no token is supplied and ``HF_TOKEN`` is not set.
    """
    bearer = hf_token if hf_token else os.environ.get("HF_TOKEN")
    if not bearer:
        raise ValueError(
            "HF_TOKEN not found. Please set HF_TOKEN env var or pass it to the solver."
        )

    huggingface_server = {
        "type": "http",
        "url": "https://huggingface.co/mcp",
        "headers": {"Authorization": f"Bearer {bearer}"},
    }
    mcp_config = {"mcpServers": {"huggingface": huggingface_server}}

    async def solve(state: TaskState, generate) -> TaskState:
        # The config is passed to the CLI by path, so it has to exist on disk
        # for the duration of the delegated call.
        handle = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False)
        with handle:
            handle.write(json.dumps(mcp_config, indent=2))
            config_file = handle.name

        try:
            inner = claude_code(output_format=output_format, mcp_config=config_file)
            result = await inner(state, generate)
            return result
        finally:
            # The file was created with delete=False, so remove it ourselves.
            if os.path.exists(config_file):
                os.remove(config_file)

    return solve
|
|
|
|
|
# Maps each solver name to the factory that builds it; `get_solver` looks
# names up here.
SOLVER_REGISTRY: Dict[str, Callable[..., Solver]] = {
    "hf_agent": hf_agent,
    "claude_code": claude_code,
    "claude_code+hf_mcp": claude_code_hf_mcp,
}
|
|
|
|
|
def get_solver(name: str, **kwargs) -> Solver:
    """Build the solver registered under *name*.

    Args:
        name: A key of ``SOLVER_REGISTRY``.
        **kwargs: Passed unchanged to the registered factory.

    Returns:
        Whatever the registered factory returns for ``**kwargs``.

    Raises:
        ValueError: If *name* is not registered; the message lists the
            registered names in sorted order.
    """
    try:
        build = SOLVER_REGISTRY[name]
    except KeyError as missing:
        known = ", ".join(sorted(SOLVER_REGISTRY))
        raise ValueError(f"Unknown solver '{name}'. Available: {known}") from missing
    else:
        return build(**kwargs)
|
|
|