Deploy claude-code-nvidia proxy to Hugging Face Spaces

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This view is limited to 50 files because it contains too many changes.
- .env.example +107 -0
- Dockerfile +26 -0
- api/__init__.py +17 -0
- api/__pycache__/__init__.cpython-314.pyc +0 -0
- api/__pycache__/app.cpython-314.pyc +0 -0
- api/__pycache__/command_utils.cpython-314.pyc +0 -0
- api/__pycache__/dependencies.cpython-314.pyc +0 -0
- api/__pycache__/detection.cpython-314.pyc +0 -0
- api/__pycache__/gateway_model_ids.cpython-314.pyc +0 -0
- api/__pycache__/model_router.cpython-314.pyc +0 -0
- api/__pycache__/optimization_handlers.cpython-314.pyc +0 -0
- api/__pycache__/routes.cpython-314.pyc +0 -0
- api/__pycache__/runtime.cpython-314.pyc +0 -0
- api/__pycache__/services.cpython-314.pyc +0 -0
- api/__pycache__/validation_log.cpython-314.pyc +0 -0
- api/app.py +175 -0
- api/command_utils.py +164 -0
- api/dependencies.py +144 -0
- api/detection.py +136 -0
- api/gateway_model_ids.py +54 -0
- api/model_router.py +261 -0
- api/models/__init__.py +45 -0
- api/models/__pycache__/__init__.cpython-314.pyc +0 -0
- api/models/__pycache__/anthropic.cpython-314.pyc +0 -0
- api/models/__pycache__/responses.cpython-314.pyc +0 -0
- api/models/anthropic.py +163 -0
- api/models/responses.py +56 -0
- api/optimization_handlers.py +154 -0
- api/routes.py +271 -0
- api/runtime.py +338 -0
- api/services.py +305 -0
- api/validation_log.py +48 -0
- api/web_server_tools.py +22 -0
- api/web_tools/__init__.py +17 -0
- api/web_tools/__pycache__/__init__.cpython-314.pyc +0 -0
- api/web_tools/__pycache__/constants.cpython-314.pyc +0 -0
- api/web_tools/__pycache__/egress.cpython-314.pyc +0 -0
- api/web_tools/__pycache__/parsers.cpython-314.pyc +0 -0
- api/web_tools/__pycache__/request.cpython-314.pyc +0 -0
- api/web_tools/__pycache__/streaming.cpython-314.pyc +0 -0
- api/web_tools/constants.py +15 -0
- api/web_tools/egress.py +99 -0
- api/web_tools/outbound.py +278 -0
- api/web_tools/parsers.py +104 -0
- api/web_tools/request.py +86 -0
- api/web_tools/streaming.py +206 -0
- cli/__init__.py +6 -0
- cli/entrypoints.py +60 -0
- cli/manager.py +163 -0
- cli/process_registry.py +74 -0
.env.example
ADDED
@@ -0,0 +1,107 @@
# NVIDIA NIM Config
NVIDIA_NIM_API_KEY=""


# All Claude model requests are mapped to these models, plain model is fallback
# Format: provider_type/model/name
# Valid provider: "nvidia_nim"
MODEL_OPUS=
MODEL_SONNET=
MODEL_HAIKU=
MODEL="nvidia_nim/z-ai/glm4.7"


# Optional live smoke model overrides. Smoke runs for NVIDIA NIM.
FCC_SMOKE_MODEL_NVIDIA_NIM=


# Thinking output
# Per-Claude-model switches for provider reasoning requests and Claude thinking blocks.
# Blank per-model switches inherit ENABLE_MODEL_THINKING.
ENABLE_OPUS_THINKING=
ENABLE_SONNET_THINKING=
ENABLE_HAIKU_THINKING=
ENABLE_MODEL_THINKING=true


# Provider config
# Per-provider proxy support: http and socks5, example: "http://username:password@host:port"
NVIDIA_NIM_PROXY=""

PROVIDER_RATE_LIMIT=1
PROVIDER_RATE_WINDOW=3
PROVIDER_MAX_CONCURRENCY=5


# HTTP client timeouts (seconds) for provider API requests
HTTP_READ_TIMEOUT=300
HTTP_WRITE_TIMEOUT=10
HTTP_CONNECT_TIMEOUT=10


# Optional server API key (Anthropic-style)
ANTHROPIC_AUTH_TOKEN="freecc"


# Messaging Platform: "telegram" | "discord" | "none"
MESSAGING_PLATFORM="discord"
MESSAGING_RATE_LIMIT=1
MESSAGING_RATE_WINDOW=1


# Voice Note Transcription
VOICE_NOTE_ENABLED=false
# WHISPER_DEVICE: "cpu" | "cuda" | "nvidia_nim"
# - "cpu"/"cuda": Hugging Face transformers Whisper (offline, free; install with: uv sync --extra voice_local)
# - "nvidia_nim": NVIDIA NIM Whisper via Riva gRPC (requires NVIDIA_NIM_API_KEY; install with: uv sync --extra voice)
# (Independent of MODEL=nvidia_nim/...: that selects the *chat* provider; this selects voice STT only.)
WHISPER_DEVICE="nvidia_nim"
# WHISPER_MODEL:
# - For cpu/cuda: Hugging Face ID or short name (tiny, base, small, medium, large-v2, large-v3, large-v3-turbo)
# - For nvidia_nim: NVIDIA NIM model (e.g., "nvidia/parakeet-ctc-1.1b-asr", "openai/whisper-large-v3")
# - For nvidia_nim, default to "openai/whisper-large-v3" for best performance
WHISPER_MODEL="openai/whisper-large-v3"
HF_TOKEN=""


# Telegram Config
TELEGRAM_BOT_TOKEN=""
ALLOWED_TELEGRAM_USER_ID=""


# Discord Config
DISCORD_BOT_TOKEN=""
ALLOWED_DISCORD_CHANNELS=""


# Agent Config
CLAUDE_WORKSPACE="./agent_workspace"
ALLOWED_DIR=""
CLAUDE_CLI_BIN="claude"
FAST_PREFIX_DETECTION=true
ENABLE_NETWORK_PROBE_MOCK=true
ENABLE_TITLE_GENERATION_SKIP=true
ENABLE_SUGGESTION_MODE_SKIP=true
ENABLE_FILEPATH_EXTRACTION_MOCK=true


# Local Anthropic web_search / web_fetch handling (performs outbound HTTP; on by default)
ENABLE_WEB_SERVER_TOOLS=true
WEB_FETCH_ALLOWED_SCHEMES=http,https
WEB_FETCH_ALLOW_PRIVATE_NETWORKS=false


# Verbose diagnostics (avoid logging raw prompts / SSE bodies in production)
DEBUG_PLATFORM_EDITS=false
DEBUG_SUBAGENT_STACK=false
# When true, also allows DEBUG-level httpx/httpcore/telegram log noise (not just payload logging).
LOG_RAW_API_PAYLOADS=false
LOG_RAW_SSE_EVENTS=false
# When true, log full exception text and tracebacks for unhandled errors (may leak request-derived data).
LOG_API_ERROR_TRACEBACKS=false
# When true, log message/transcription text previews in messaging adapters (may leak user content).
LOG_RAW_MESSAGING_CONTENT=false
# When true, log full Claude CLI stderr, non-JSON stdout lines, and parser error text.
LOG_RAW_CLI_DIAGNOSTICS=false
# When true, log full exception and CLI error message strings in messaging (may leak user content).
LOG_MESSAGING_ERROR_DETAILS=false

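Note on the MODEL values above: each is a single `provider_type/model` ref, and the model part may itself contain a slash (e.g. `z-ai/glm4.7`). A minimal sketch of how such a ref splits on the first slash, mirroring the partition-based parsing in api/model_router.py further down; the helper name below is hypothetical and not part of this commit:

# Hypothetical helper, for illustration only (not part of this commit):
# split "nvidia_nim/z-ai/glm4.7" into ("nvidia_nim", "z-ai/glm4.7").
def split_model_ref(ref: str) -> tuple[str, str]:
    provider, sep, model = ref.partition("/")
    if not sep:
        raise ValueError(f"expected 'provider/model', got {ref!r}")
    return provider, model

assert split_model_ref("nvidia_nim/z-ai/glm4.7") == ("nvidia_nim", "z-ai/glm4.7")
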
Dockerfile
ADDED
@@ -0,0 +1,26 @@
FROM python:3.14-slim

WORKDIR /app

# Install uv
RUN pip install uv

# Copy project files
COPY pyproject.toml uv.lock ./
COPY api/ ./api/
COPY cli/ ./cli/
COPY config/ ./config/
COPY core/ ./core/
COPY messaging/ ./messaging/
COPY providers/ ./providers/
COPY server.py ./
COPY .env.example ./

# Install dependencies
RUN uv sync --frozen --no-dev

# Expose port (HF Spaces default)
EXPOSE 7860

# Run server
CMD ["uv", "run", "uvicorn", "server:app", "--host", "0.0.0.0", "--port", "7860"]

api/__init__.py
ADDED
@@ -0,0 +1,17 @@
"""API layer for Claude Code Proxy."""

from .app import create_app
from .models import (
    MessagesRequest,
    MessagesResponse,
    TokenCountRequest,
    TokenCountResponse,
)

__all__ = [
    "MessagesRequest",
    "MessagesResponse",
    "TokenCountRequest",
    "TokenCountResponse",
    "create_app",
]

api/__pycache__/__init__.cpython-314.pyc
ADDED
Binary file (431 Bytes).

api/__pycache__/app.cpython-314.pyc
ADDED
Binary file (10.7 kB).

api/__pycache__/command_utils.cpython-314.pyc
ADDED
Binary file (5.83 kB).

api/__pycache__/dependencies.cpython-314.pyc
ADDED
Binary file (7.56 kB).

api/__pycache__/detection.cpython-314.pyc
ADDED
Binary file (6.71 kB).

api/__pycache__/gateway_model_ids.cpython-314.pyc
ADDED
Binary file (2.55 kB).

api/__pycache__/model_router.cpython-314.pyc
ADDED
Binary file (12.6 kB).

api/__pycache__/optimization_handlers.cpython-314.pyc
ADDED
Binary file (5.7 kB).

api/__pycache__/routes.cpython-314.pyc
ADDED
Binary file (14.1 kB).

api/__pycache__/runtime.cpython-314.pyc
ADDED
Binary file (20.1 kB).

api/__pycache__/services.cpython-314.pyc
ADDED
Binary file (14 kB).

api/__pycache__/validation_log.cpython-314.pyc
ADDED
Binary file (2.95 kB).

api/app.py
ADDED
@@ -0,0 +1,175 @@
"""FastAPI application factory and configuration."""

import traceback
from contextlib import asynccontextmanager
from typing import Any

from fastapi import FastAPI, Request
from fastapi.exception_handlers import request_validation_exception_handler
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from loguru import logger
from starlette.types import Receive, Scope, Send

from config.logging_config import configure_logging
from config.settings import get_settings
from providers.exceptions import ProviderError

from .routes import router
from .runtime import AppRuntime, startup_failure_message
from .validation_log import summarize_request_validation_body


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager."""
    runtime = AppRuntime.for_app(app, settings=get_settings())
    await runtime.startup()

    yield

    await runtime.shutdown()


class GracefulLifespanApp:
    """ASGI wrapper that reports startup failures without Starlette tracebacks."""

    def __init__(self, app: FastAPI):
        self.app = app

    def __getattr__(self, name: str) -> Any:
        return getattr(self.app, name)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "lifespan":
            await self.app(scope, receive, send)
            return
        await self._lifespan(receive, send)

    async def _lifespan(self, receive: Receive, send: Send) -> None:
        settings = get_settings()
        runtime = AppRuntime.for_app(self.app, settings=settings)
        startup_complete = False
        while True:
            message = await receive()
            if message["type"] == "lifespan.startup":
                try:
                    await runtime.startup()
                except Exception as exc:
                    await send(
                        {
                            "type": "lifespan.startup.failed",
                            "message": startup_failure_message(settings, exc),
                        }
                    )
                    return
                startup_complete = True
                await send({"type": "lifespan.startup.complete"})
                continue

            if message["type"] == "lifespan.shutdown":
                if startup_complete:
                    try:
                        await runtime.shutdown()
                    except Exception as exc:
                        logger.error("Shutdown failed: exc_type={}", type(exc).__name__)
                        await send({"type": "lifespan.shutdown.failed", "message": ""})
                        return
                await send({"type": "lifespan.shutdown.complete"})
                return


def create_app(*, lifespan_enabled: bool = True) -> FastAPI:
    """Create and configure the FastAPI application."""
    settings = get_settings()
    configure_logging(
        settings.log_file, verbose_third_party=settings.log_raw_api_payloads
    )

    app_kwargs: dict[str, Any] = {
        "title": "Claude Code Proxy",
        "version": "2.0.0",
    }
    if lifespan_enabled:
        app_kwargs["lifespan"] = lifespan
    app = FastAPI(**app_kwargs)

    # Register routes
    app.include_router(router)

    # Exception handlers
    @app.exception_handler(RequestValidationError)
    async def validation_error_handler(request: Request, exc: RequestValidationError):
        """Log request shape for 422 debugging without content values."""
        body: Any
        try:
            body = await request.json()
        except Exception as e:
            body = {"_json_error": type(e).__name__}

        message_summary, tool_names = summarize_request_validation_body(body)

        logger.debug(
            "Request validation failed: path={} query={} error_locs={} error_types={} message_summary={} tool_names={}",
            request.url.path,
            str(request.url.query),
            [list(error.get("loc", ())) for error in exc.errors()],
            [str(error.get("type", "")) for error in exc.errors()],
            message_summary,
            tool_names,
        )
        return await request_validation_exception_handler(request, exc)

    @app.exception_handler(ProviderError)
    async def provider_error_handler(request: Request, exc: ProviderError):
        """Handle provider-specific errors and return Anthropic format."""
        err_settings = get_settings()
        if err_settings.log_api_error_tracebacks:
            logger.error(
                "Provider Error: error_type={} status_code={} message={}",
                exc.error_type,
                exc.status_code,
                exc.message,
            )
        else:
            logger.error(
                "Provider Error: error_type={} status_code={}",
                exc.error_type,
                exc.status_code,
            )
        return JSONResponse(
            status_code=exc.status_code,
            content=exc.to_anthropic_format(),
        )

    @app.exception_handler(Exception)
    async def general_error_handler(request: Request, exc: Exception):
        """Handle general errors and return Anthropic format."""
        settings = get_settings()
        if settings.log_api_error_tracebacks:
            logger.error("General Error: {}", exc)
            logger.error(traceback.format_exc())
        else:
            logger.error(
                "General Error: path={} method={} exc_type={}",
                request.url.path,
                request.method,
                type(exc).__name__,
            )
        return JSONResponse(
            status_code=500,
            content={
                "type": "error",
                "error": {
                    "type": "api_error",
                    "message": "An unexpected error occurred.",
                },
            },
        )

    return app


def create_asgi_app() -> GracefulLifespanApp:
    """Create the server ASGI app with graceful lifespan failure reporting."""
    return GracefulLifespanApp(create_app(lifespan_enabled=False))

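The Dockerfile above launches the service with `uvicorn server:app`, but server.py itself is not included in this 50-file view. Assuming it simply exposes the ASGI object built by create_asgi_app(), a minimal sketch of what such an entry point could look like (the module wiring is an assumption, not confirmed by this diff):

# Hypothetical server.py sketch (assumption; the real server.py is not shown in this view).
import uvicorn

from api.app import create_asgi_app

# GracefulLifespanApp reports startup failures without Starlette tracebacks.
app = create_asgi_app()

if __name__ == "__main__":
    # Port 7860 matches the EXPOSE/CMD in the Dockerfile (HF Spaces default).
    uvicorn.run(app, host="0.0.0.0", port=7860)
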
api/command_utils.py
ADDED
@@ -0,0 +1,164 @@
"""Command parsing utilities for API optimizations."""

import re
import shlex

_ENV_ASSIGNMENT_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*=.*$")


def _is_env_assignment(part: str) -> bool:
    """Return True when a token is a shell-style env assignment."""
    return bool(_ENV_ASSIGNMENT_RE.match(part))


def _strip_env_assignments(parts: list[str]) -> list[str]:
    """Return command parts after leading shell-style env assignments."""
    cmd_start = 0
    for i, part in enumerate(parts):
        if _is_env_assignment(part):
            cmd_start = i + 1
        else:
            break
    return parts[cmd_start:]


def extract_command_prefix(command: str) -> str:
    """Extract the command prefix for fast prefix detection.

    Parses a shell command safely, handling environment variables and
    command injection attempts. Returns the command prefix suitable
    for quick identification.

    Returns:
        Command prefix (e.g., "git", "git commit", "npm install")
        or "none" if no valid command found
    """
    if "`" in command or "$(" in command:
        return "command_injection_detected"

    try:
        parts = shlex.split(command, posix=False)
        if not parts:
            return "none"

        env_prefix = []
        cmd_start = 0
        for i, part in enumerate(parts):
            if _is_env_assignment(part):
                env_prefix.append(part)
                cmd_start = i + 1
            else:
                break

        if cmd_start >= len(parts):
            return "none"

        cmd_parts = parts[cmd_start:]
        if not cmd_parts:
            return "none"

        first_word = cmd_parts[0]
        two_word_commands = {
            "git",
            "npm",
            "docker",
            "kubectl",
            "cargo",
            "go",
            "pip",
            "yarn",
        }

        if first_word in two_word_commands and len(cmd_parts) > 1:
            second_word = cmd_parts[1]
            if not second_word.startswith("-"):
                return f"{first_word} {second_word}"
            return first_word
        return first_word if not env_prefix else " ".join(env_prefix) + " " + first_word

    except ValueError:
        parts = command.split()
        if not parts:
            return "none"
        cmd_parts = _strip_env_assignments(parts)
        return cmd_parts[0] if cmd_parts else "none"


def extract_filepaths_from_command(command: str, output: str) -> str:
    """Extract file paths from a command locally without API call.

    Determines if the command reads file contents and extracts paths accordingly.
    Commands like ls/dir/find just list files, so return empty.
    Commands like cat/head/tail actually read contents, so extract the file path.

    Returns:
        Filepath extraction result in <filepaths> format
    """
    listing_commands = {
        "ls",
        "dir",
        "find",
        "tree",
        "pwd",
        "cd",
        "mkdir",
        "rmdir",
        "rm",
    }

    reading_commands = {"cat", "head", "tail", "less", "more", "bat", "type"}

    try:
        parts = shlex.split(command, posix=False)
        if not parts:
            return "<filepaths>\n</filepaths>"

        cmd_parts = _strip_env_assignments(parts)
        if not cmd_parts:
            return "<filepaths>\n</filepaths>"

        base_cmd = cmd_parts[0].split("/")[-1].split("\\")[-1].lower()

        if base_cmd in listing_commands:
            return "<filepaths>\n</filepaths>"

        if base_cmd in reading_commands:
            filepaths = []
            for part in cmd_parts[1:]:
                if part.startswith("-"):
                    continue
                filepaths.append(part)

            if filepaths:
                paths_str = "\n".join(filepaths)
                return f"<filepaths>\n{paths_str}\n</filepaths>"
            return "<filepaths>\n</filepaths>"

        if base_cmd == "grep":
            flags_with_args = {"-e", "-f", "-m", "-A", "-B", "-C"}
            pattern_provided_via_flag = False
            positional = []

            skip_next = False
            for part in cmd_parts[1:]:
                if skip_next:
                    skip_next = False
                    continue
                if part.startswith("-"):
                    if part in flags_with_args:
                        if part in {"-e", "-f"}:
                            pattern_provided_via_flag = True
                        skip_next = True
                    continue
                positional.append(part)

            filepaths = positional if pattern_provided_via_flag else positional[1:]
            if filepaths:
                paths_str = "\n".join(filepaths)
                return f"<filepaths>\n{paths_str}\n</filepaths>"
            return "<filepaths>\n</filepaths>"

        return "<filepaths>\n</filepaths>"

    except ValueError:
        return "<filepaths>\n</filepaths>"

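A quick illustration of the two helpers above; the expected outputs follow from the code in this commit (import path assumes the repo's package layout):

from api.command_utils import extract_command_prefix, extract_filepaths_from_command

# Two-word prefixes are kept for common tools; backticks trigger the injection guard.
print(extract_command_prefix("git commit -m 'msg'"))   # "git commit"
print(extract_command_prefix("FOO=1 ls -la"))           # "FOO=1 ls"
print(extract_command_prefix("echo `whoami`"))          # "command_injection_detected"

# Reading commands yield their positional paths; listing commands yield an empty block.
print(extract_filepaths_from_command("cat src/main.py", ""))  # <filepaths> with src/main.py
print(extract_filepaths_from_command("ls -la", ""))           # empty <filepaths> block
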
api/dependencies.py
ADDED
@@ -0,0 +1,144 @@
"""Dependency injection for FastAPI."""

import secrets

from fastapi import Depends, HTTPException, Request
from loguru import logger
from starlette.applications import Starlette

from config.settings import Settings
from config.settings import get_settings as _get_settings
from core.anthropic import get_user_facing_error_message
from providers.base import BaseProvider
from providers.exceptions import (
    AuthenticationError,
    ServiceUnavailableError,
    UnknownProviderTypeError,
)
from providers.registry import PROVIDER_DESCRIPTORS, ProviderRegistry

# Process-level cache: only for :func:`get_provider_for_type` / :func:`get_provider`
# when there is no ``Request``/``app`` (unit tests, scripts). HTTP handlers must pass
# ``app`` to :func:`resolve_provider` so the app-scoped registry is used.
_providers: dict[str, BaseProvider] = {}


def get_settings() -> Settings:
    """Return cached :class:`~config.settings.Settings` (FastAPI-friendly alias)."""
    return _get_settings()


def resolve_provider(
    provider_type: str,
    *,
    app: Starlette | None,
    settings: Settings,
) -> BaseProvider:
    """Resolve a provider using the app-scoped registry when ``app`` is set.

    When ``app`` is not ``None``, the app-owned :attr:`app.state.provider_registry`
    must exist (installed by :class:`~api.runtime.AppRuntime` during startup).
    Callers that construct a bare ``FastAPI`` without lifespan must set
    ``app.state.provider_registry`` explicitly.

    When ``app`` is ``None`` (no HTTP context), uses the process-level
    :data:`_providers` cache only.
    """
    if app is not None:
        reg = getattr(app.state, "provider_registry", None)
        if reg is None:
            raise ServiceUnavailableError(
                "Provider registry is not configured. Ensure AppRuntime startup ran "
                "or assign app.state.provider_registry for test apps."
            )
        return _resolve_with_registry(reg, provider_type, settings)
    return _resolve_with_registry(ProviderRegistry(_providers), provider_type, settings)


def _resolve_with_registry(
    registry: ProviderRegistry, provider_type: str, settings: Settings
) -> BaseProvider:
    should_log_init = not registry.is_cached(provider_type)
    try:
        provider = registry.get(provider_type, settings)
    except AuthenticationError as e:
        # Provider :class:`~providers.exceptions.AuthenticationError` messages are
        # curated configuration hints (env var names, docs links), not upstream noise.
        detail = str(e).strip() or get_user_facing_error_message(e)
        raise HTTPException(status_code=503, detail=detail) from e
    except UnknownProviderTypeError:
        logger.error(
            "Unknown provider_type: '{}'. Supported: {}",
            provider_type,
            ", ".join(f"'{key}'" for key in PROVIDER_DESCRIPTORS),
        )
        raise
    if should_log_init:
        logger.info("Provider initialized: {}", provider_type)
    return provider


def get_provider_for_type(provider_type: str) -> BaseProvider:
    """Get or create a provider in the process-level cache (no ``app``/Request).

    HTTP route handlers should call :func:`resolve_provider` with the active
    :attr:`request.app` (via :class:`~api.runtime.AppRuntime`) instead of this
    process-wide cache.
    """
    return resolve_provider(provider_type, app=None, settings=get_settings())


def require_api_key(
    request: Request, settings: Settings = Depends(get_settings)
) -> None:
    """Require a server API key (Anthropic-style).

    Checks `x-api-key` header or `Authorization: Bearer ...` against
    `Settings.anthropic_auth_token`. If `ANTHROPIC_AUTH_TOKEN` is empty, this is a no-op.
    """
    anthropic_auth_token = settings.anthropic_auth_token
    if not anthropic_auth_token:
        # No API key configured -> allow
        return

    header = (
        request.headers.get("x-api-key")
        or request.headers.get("authorization")
        or request.headers.get("anthropic-auth-token")
    )
    if not header:
        raise HTTPException(status_code=401, detail="Missing API key")

    # Support both raw key in X-API-Key and Bearer token in Authorization
    token = header
    if header.lower().startswith("bearer "):
        token = header.split(" ", 1)[1]

    # Strip anything after the first colon to handle tokens with appended model names
    if token and ":" in token:
        token = token.split(":", 1)[0]

    # Constant-time comparison to avoid leaking the configured token via
    # response-time differences on a per-byte mismatch (CWE-208).
    if not secrets.compare_digest(
        token.encode("utf-8"), anthropic_auth_token.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid API key")


def get_provider() -> BaseProvider:
    """Get or create the default provider (``MODEL`` / ``provider_type``).

    Process-cache helper for scripts, unit tests, and non-FastAPI callers. HTTP
    handlers must use :func:`resolve_provider` with :attr:`request.app` so the
    app-scoped :class:`~providers.registry.ProviderRegistry` is used.
    """
    return get_provider_for_type(get_settings().provider_type)


async def cleanup_provider():
    """Cleanup all provider resources."""
    global _providers
    await ProviderRegistry(_providers).cleanup()
    _providers = {}
    logger.debug("Provider cleanup completed")

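Given require_api_key above, a client can authenticate with either header form. A hedged sketch using httpx; the /v1/messages path and payload shape are assumptions here, since routes.py is not shown in this view:

# Sketch only: endpoint path and payload shape are assumed, not confirmed by this diff.
import httpx

BASE_URL = "http://localhost:7860"
TOKEN = "freecc"  # matches ANTHROPIC_AUTH_TOKEN in .env.example

payload = {
    "model": "auto",
    "max_tokens": 128,
    "messages": [{"role": "user", "content": "hello"}],
}

# Either x-api-key or an Authorization bearer token is accepted by require_api_key.
for headers in ({"x-api-key": TOKEN}, {"Authorization": f"Bearer {TOKEN}"}):
    resp = httpx.post(f"{BASE_URL}/v1/messages", json=payload, headers=headers)
    print(resp.status_code)
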
api/detection.py
ADDED
@@ -0,0 +1,136 @@
"""Request detection utilities for API optimizations.

Detects quota checks, title generation, prefix detection, suggestion mode,
and filepath extraction requests to enable fast-path responses.
"""

from core.anthropic import extract_text_from_content

from .models.anthropic import MessagesRequest


def is_quota_check_request(request_data: MessagesRequest) -> bool:
    """Check if this is a quota probe request.

    Quota checks are typically simple requests with max_tokens=1
    and a single message containing the word "quota".
    """
    if (
        request_data.max_tokens == 1
        and len(request_data.messages) == 1
        and request_data.messages[0].role == "user"
    ):
        text = extract_text_from_content(request_data.messages[0].content)
        if "quota" in text.lower():
            return True
    return False


def is_title_generation_request(request_data: MessagesRequest) -> bool:
    """Check if this is a conversation title generation request.

    Title generation requests are detected by a system prompt containing
    title extraction instructions, no tools, and a single user message.

    Matches Claude Code session title prompts (sentence-case title, JSON
    "title" field, etc.).
    """
    if not request_data.system or request_data.tools:
        return False
    system_text = extract_text_from_content(request_data.system).lower()
    if "title" not in system_text:
        return False
    return "sentence-case title" in system_text or (
        "return json" in system_text
        and "field" in system_text
        and ("coding session" in system_text or "this session" in system_text)
    )


def is_prefix_detection_request(request_data: MessagesRequest) -> tuple[bool, str]:
    """Check if this is a fast prefix detection request.

    Prefix detection requests contain a policy_spec block and
    a Command: section for extracting shell command prefixes.

    Returns:
        Tuple of (is_prefix_request, command_string)
    """
    if len(request_data.messages) != 1 or request_data.messages[0].role != "user":
        return False, ""

    content = extract_text_from_content(request_data.messages[0].content)

    if "<policy_spec>" in content and "Command:" in content:
        try:
            cmd_start = content.rfind("Command:") + len("Command:")
            return True, content[cmd_start:].strip()
        except TypeError:
            return False, ""

    return False, ""


def is_suggestion_mode_request(request_data: MessagesRequest) -> bool:
    """Check if this is a suggestion mode request.

    Suggestion mode requests contain "[SUGGESTION MODE:" in the user's message,
    used for auto-suggesting what the user might type next.
    """
    for msg in request_data.messages:
        if msg.role == "user":
            text = extract_text_from_content(msg.content)
            if "[SUGGESTION MODE:" in text:
                return True
    return False


def is_filepath_extraction_request(
    request_data: MessagesRequest,
) -> tuple[bool, str, str]:
    """Check if this is a filepath extraction request.

    Filepath extraction requests have a single user message with
    "Command:" and "Output:" sections, asking to extract file paths
    from command output.

    Returns:
        Tuple of (is_filepath_request, command, output)
    """
    if len(request_data.messages) != 1 or request_data.messages[0].role != "user":
        return False, "", ""
    if request_data.tools:
        return False, "", ""

    content = extract_text_from_content(request_data.messages[0].content)

    if "Command:" not in content or "Output:" not in content:
        return False, "", ""

    # Match if user content OR system block indicates filepath extraction
    user_has_filepaths = (
        "filepaths" in content.lower() or "<filepaths>" in content.lower()
    )
    system_text = (
        extract_text_from_content(request_data.system) if request_data.system else ""
    )
    system_has_extract = (
        "extract any file paths" in system_text.lower()
        or "file paths that this command" in system_text.lower()
    )
    if not user_has_filepaths and not system_has_extract:
        return False, "", ""

    cmd_start = content.find("Command:") + len("Command:")
    output_marker = content.find("Output:", cmd_start)
    if output_marker == -1:
        return False, "", ""

    command = content[cmd_start:output_marker].strip()
    output = content[output_marker + len("Output:") :].strip()

    for marker in ["<", "\n\n"]:
        if marker in output:
            output = output.split(marker)[0].strip()

    return True, command, output

ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Gateway-safe model id encoding for Claude Code model discovery."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
GATEWAY_MODEL_ID_PREFIX = "anthropic"
|
| 8 |
+
|
| 9 |
+
# Claude Code currently treats any model id containing ``claude-3-`` as not
|
| 10 |
+
# supporting thinking. This intentionally uses that client-side capability
|
| 11 |
+
# heuristic while keeping the real provider/model ref reversible for routing.
|
| 12 |
+
NO_THINKING_GATEWAY_MODEL_ID_PREFIX = "claude-3-freecc-no-thinking"
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass(frozen=True, slots=True)
|
| 16 |
+
class DecodedGatewayModelId:
|
| 17 |
+
provider_id: str
|
| 18 |
+
provider_model: str
|
| 19 |
+
force_thinking_enabled: bool | None = None
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def gateway_model_id(provider_model_ref: str) -> str:
|
| 23 |
+
"""Return the normal Claude Code-discoverable id for a provider/model ref."""
|
| 24 |
+
return f"{GATEWAY_MODEL_ID_PREFIX}/{provider_model_ref}"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def no_thinking_gateway_model_id(provider_model_ref: str) -> str:
|
| 28 |
+
"""Return a Claude Code-discoverable id that disables client thinking."""
|
| 29 |
+
return f"{NO_THINKING_GATEWAY_MODEL_ID_PREFIX}/{provider_model_ref}"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def decode_gateway_model_id(model_name: str) -> DecodedGatewayModelId | None:
|
| 33 |
+
"""Decode a model id advertised by this gateway, if it is one."""
|
| 34 |
+
prefix, separator, remainder = model_name.partition("/")
|
| 35 |
+
if not separator:
|
| 36 |
+
return None
|
| 37 |
+
|
| 38 |
+
force_thinking_enabled: bool | None
|
| 39 |
+
if prefix == GATEWAY_MODEL_ID_PREFIX:
|
| 40 |
+
force_thinking_enabled = None
|
| 41 |
+
elif prefix == NO_THINKING_GATEWAY_MODEL_ID_PREFIX:
|
| 42 |
+
force_thinking_enabled = False
|
| 43 |
+
else:
|
| 44 |
+
return None
|
| 45 |
+
|
| 46 |
+
provider_id, provider_separator, provider_model = remainder.partition("/")
|
| 47 |
+
if not provider_separator or not provider_model:
|
| 48 |
+
return None
|
| 49 |
+
|
| 50 |
+
return DecodedGatewayModelId(
|
| 51 |
+
provider_id=provider_id,
|
| 52 |
+
provider_model=provider_model,
|
| 53 |
+
force_thinking_enabled=force_thinking_enabled,
|
| 54 |
+
)
|
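The encode/decode pair above is reversible, which is what lets the router recover the real provider/model from an advertised id. A short round-trip check (imports assume this repo's package layout):

from api.gateway_model_ids import (
    decode_gateway_model_id,
    gateway_model_id,
    no_thinking_gateway_model_id,
)

ref = "nvidia_nim/z-ai/glm4.7"

# Normal id: "anthropic/nvidia_nim/z-ai/glm4.7" decodes back to the original ref.
decoded = decode_gateway_model_id(gateway_model_id(ref))
assert decoded is not None
assert (decoded.provider_id, decoded.provider_model) == ("nvidia_nim", "z-ai/glm4.7")
assert decoded.force_thinking_enabled is None

# No-thinking variant forces thinking off while staying reversible.
decoded_nt = decode_gateway_model_id(no_thinking_gateway_model_id(ref))
assert decoded_nt is not None and decoded_nt.force_thinking_enabled is False

# Anything that is not one of this gateway's ids decodes to None.
assert decode_gateway_model_id("claude-sonnet-4") is None
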
api/model_router.py
ADDED
@@ -0,0 +1,261 @@
"""Model routing for Claude-compatible requests."""

from __future__ import annotations

from dataclasses import dataclass

from loguru import logger

from config.provider_ids import SUPPORTED_PROVIDER_IDS
from config.settings import Settings

from .gateway_model_ids import decode_gateway_model_id
from .models.anthropic import MessagesRequest, TokenCountRequest
from providers.rate_limit import GlobalRateLimiter


@dataclass(frozen=True, slots=True)
class ResolvedModel:
    original_model: str
    provider_id: str
    provider_model: str
    provider_model_ref: str
    thinking_enabled: bool


@dataclass(frozen=True, slots=True)
class RoutedMessagesRequest:
    request: MessagesRequest
    resolved: ResolvedModel


@dataclass(frozen=True, slots=True)
class RoutedTokenCountRequest:
    request: TokenCountRequest
    resolved: ResolvedModel


class ModelRouter:
    """Resolve incoming Claude model names to configured provider/model pairs."""

    def __init__(self, settings: Settings):
        self._settings = settings

    def _is_auto(self, model_name: str) -> bool:
        """Return whether the model name refers to the virtual 'auto' model."""
        name_lower = model_name.lower()
        return name_lower == "auto" or name_lower == "anthropic/auto"

    def _normalize_candidate_ref(self, raw_ref: str) -> str | None:
        """Normalize auto candidate refs to ``provider/model`` when possible."""
        candidate = (raw_ref or "").strip()
        if not candidate:
            return None

        provider_id, separator, remainder = candidate.partition("/")
        if separator and provider_id in SUPPORTED_PROVIDER_IDS and remainder:
            return f"{provider_id}/{remainder}"

        # Treat bare model ids and vendor/model ids as NVIDIA NIM models.
        return f"nvidia_nim/{candidate}"

    def resolve(self, claude_model_name: str) -> ResolvedModel:
        # Special virtual model 'auto' maps to the configured default MODEL and
        # enables provider-side fallbacks. Resolve it to the configured model
        # while preserving the original requested name.
        if self._is_auto(claude_model_name):
            # If the user configured an explicit AUTO_MODEL_ORDER, try each
            # provider/model pair in order and pick the first provider that is
            # plausibly configured. Fall back to the single configured MODEL.
            order_csv = (self._settings.auto_model_order or "").strip()
            if order_csv:
                for cand in [c.strip() for c in order_csv.split(",") if c.strip()]:
                    if "/" not in cand:
                        # assume vendor-prefixed entries; skip malformed
                        continue
                    provider_id = Settings.parse_provider_type(cand)
                    provider_model = Settings.parse_model_name(cand)
                    if self._settings.provider_is_configured(provider_id):
                        thinking_enabled = self._settings.resolve_thinking(claude_model_name)
                        return ResolvedModel(
                            original_model=claude_model_name,
                            provider_id=provider_id,
                            provider_model=provider_model,
                            provider_model_ref=cand,
                            thinking_enabled=thinking_enabled,
                        )
            # No explicit order matched or none configured — fall back to default MODEL
            provider_model_ref = self._settings.model
            provider_id = Settings.parse_provider_type(provider_model_ref)
            provider_model = Settings.parse_model_name(provider_model_ref)
            thinking_enabled = self._settings.resolve_thinking(claude_model_name)
            return ResolvedModel(
                original_model=claude_model_name,
                provider_id=provider_id,
                provider_model=provider_model,
                provider_model_ref=provider_model_ref,
                thinking_enabled=thinking_enabled,
            )

        (
            direct_provider_id,
            direct_provider_model,
            force_thinking_enabled,
        ) = self._direct_provider_model(claude_model_name)
        if direct_provider_id is not None and direct_provider_model is not None:
            thinking_enabled = (
                force_thinking_enabled
                if force_thinking_enabled is not None
                else self._settings.resolve_thinking(direct_provider_model)
            )
            logger.debug(
                "MODEL DIRECT: '{}' -> provider='{}' model='{}' thinking={}",
                claude_model_name,
                direct_provider_id,
                direct_provider_model,
                thinking_enabled,
            )
            return ResolvedModel(
                original_model=claude_model_name,
                provider_id=direct_provider_id,
                provider_model=direct_provider_model,
                provider_model_ref=claude_model_name,
                thinking_enabled=thinking_enabled,
            )

        provider_model_ref = self._settings.resolve_model(claude_model_name)
        thinking_enabled = self._settings.resolve_thinking(claude_model_name)
        provider_id = Settings.parse_provider_type(provider_model_ref)
        provider_model = Settings.parse_model_name(provider_model_ref)
        if provider_model != claude_model_name:
            logger.debug(
                "MODEL MAPPING: '{}' -> '{}'", claude_model_name, provider_model
            )
        return ResolvedModel(
            original_model=claude_model_name,
            provider_id=provider_id,
            provider_model=provider_model,
            provider_model_ref=provider_model_ref,
            thinking_enabled=thinking_enabled,
        )

    def resolve_candidates(self, claude_model_name: str) -> list[ResolvedModel]:
        """Resolve a model name to a prioritized list of candidates.

        Used by the 'auto' routing logic to implement provider-side failover.
        """
        if not self._is_auto(claude_model_name):
            return [self.resolve(claude_model_name)]

        healthy_candidates: list[ResolvedModel] = []
        blocked_candidates: list[ResolvedModel] = []
        seen: set[str] = set()

        def add_candidate(ref: str | None, source: str) -> None:
            normalized_ref = self._normalize_candidate_ref(ref or "")
            if normalized_ref is None or normalized_ref in seen:
                return
            provider_id = Settings.parse_provider_type(normalized_ref)
            provider_model = Settings.parse_model_name(normalized_ref)
            if self._settings.provider_is_configured(provider_id):
                seen.add(normalized_ref)
                resolved = ResolvedModel(
                    original_model=claude_model_name,
                    provider_id=provider_id,
                    provider_model=provider_model,
                    provider_model_ref=normalized_ref,
                    thinking_enabled=self._settings.resolve_thinking(claude_model_name),
                )

                limiter = GlobalRateLimiter.get_scoped_instance(provider_id)
                if limiter.is_blocked():
                    logger.debug(
                        "Routing: candidate '{}' (from {}) is BLOCKED",
                        normalized_ref,
                        source,
                    )
                    blocked_candidates.append(resolved)
                else:
                    logger.debug(
                        "Routing: added candidate '{}' (from {})",
                        normalized_ref,
                        source,
                    )
                    healthy_candidates.append(resolved)
            else:
                logger.debug(
                    "Routing: candidate '{}' (from {}) is NOT CONFIGURED",
                    normalized_ref,
                    source,
                )

        # 1. Preferred order (AUTO_MODEL_ORDER)
        order_csv = (self._settings.auto_model_order or "").strip()
        if order_csv:
            for cand in [c.strip() for c in order_csv.split(",") if c.strip()]:
                add_candidate(cand, "AUTO_MODEL_PRIORITY")

        # 2. Main MODEL
        add_candidate(self._settings.model, "MODEL")

        # 3. NVIDIA Fallbacks
        nim_csv = (self._settings.nvidia_nim_fallback_models or "").strip()
        if nim_csv:
            for cand in [c.strip() for c in nim_csv.split(",") if c.strip()]:
                add_candidate(cand, "NVIDIA_NIM_FALLBACK_MODELS")

        # 4. Model-specific overrides
        add_candidate(self._settings.model_opus, "MODEL_OPUS")
        add_candidate(self._settings.model_sonnet, "MODEL_SONNET")
        add_candidate(self._settings.model_haiku, "MODEL_HAIKU")

        all_candidates = healthy_candidates + blocked_candidates
        logger.info(
            "Routing: resolved '{}' to {} candidates: {}",
            claude_model_name,
            len(all_candidates),
            ", ".join(c.provider_model_ref for c in all_candidates),
        )
        return all_candidates

    def _direct_provider_model(
        self, model_name: str
    ) -> tuple[str | None, str | None, bool | None]:
        decoded = decode_gateway_model_id(model_name)
        if decoded is not None:
            if decoded.provider_id not in SUPPORTED_PROVIDER_IDS:
                return None, None, None
            return (
                decoded.provider_id,
                decoded.provider_model,
                decoded.force_thinking_enabled,
            )

        provider_id, separator, provider_model = model_name.partition("/")
        if not separator:
            return None, None, None
        if provider_id not in SUPPORTED_PROVIDER_IDS:
            return None, None, None
        if not provider_model:
            return None, None, None
        return provider_id, provider_model, None

    def resolve_messages_request(
        self, request: MessagesRequest
    ) -> RoutedMessagesRequest:
        """Return an internal routed request context."""
        resolved = self.resolve(request.model)
        routed = request.model_copy(deep=True)
        routed.model = resolved.provider_model
        return RoutedMessagesRequest(request=routed, resolved=resolved)

    def resolve_token_count_request(
        self, request: TokenCountRequest
    ) -> RoutedTokenCountRequest:
        """Return an internal token-count request context."""
        resolved = self.resolve(request.model)
        routed = request.model_copy(
            update={"model": resolved.provider_model}, deep=True
        )
        return RoutedTokenCountRequest(request=routed, resolved=resolved)

api/models/__init__.py
ADDED
@@ -0,0 +1,45 @@
"""API models exports."""

from .anthropic import (
    ContentBlockImage,
    ContentBlockRedactedThinking,
    ContentBlockText,
    ContentBlockThinking,
    ContentBlockToolResult,
    ContentBlockToolUse,
    Message,
    MessagesRequest,
    Role,
    SystemContent,
    ThinkingConfig,
    TokenCountRequest,
    Tool,
)
from .responses import (
    MessagesResponse,
    ModelResponse,
    ModelsListResponse,
    TokenCountResponse,
    Usage,
)

__all__ = [
    "ContentBlockImage",
    "ContentBlockRedactedThinking",
    "ContentBlockText",
    "ContentBlockThinking",
    "ContentBlockToolResult",
    "ContentBlockToolUse",
    "Message",
    "MessagesRequest",
    "MessagesResponse",
    "ModelResponse",
    "ModelsListResponse",
    "Role",
    "SystemContent",
    "ThinkingConfig",
    "TokenCountRequest",
    "TokenCountResponse",
    "Tool",
    "Usage",
]

api/models/__pycache__/__init__.cpython-314.pyc
ADDED
Binary file (849 Bytes).

api/models/__pycache__/anthropic.cpython-314.pyc
ADDED
Binary file (11.6 kB).

api/models/__pycache__/responses.cpython-314.pyc
ADDED
Binary file (3.69 kB).

api/models/anthropic.py
ADDED
@@ -0,0 +1,163 @@
"""Pydantic models for Anthropic-compatible requests."""

from enum import StrEnum
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field


# =============================================================================
# Content Block Types
# =============================================================================
class Role(StrEnum):
    user = "user"
    assistant = "assistant"
    system = "system"


class _AnthropicBlockBase(BaseModel):
    """Pass through provider fields (e.g. ``cache_control``) for native transports."""

    model_config = ConfigDict(extra="allow")


class ContentBlockText(_AnthropicBlockBase):
    type: Literal["text"]
    text: str


class ContentBlockImage(_AnthropicBlockBase):
    type: Literal["image"]
    source: dict[str, Any]


class ContentBlockToolUse(_AnthropicBlockBase):
    type: Literal["tool_use"]
    id: str
    name: str
    input: dict[str, Any]


class ContentBlockToolResult(_AnthropicBlockBase):
    type: Literal["tool_result"]
    tool_use_id: str
    content: str | list[Any] | dict[str, Any]


class ContentBlockThinking(_AnthropicBlockBase):
    type: Literal["thinking"]
    thinking: str
    signature: str | None = None


class ContentBlockRedactedThinking(_AnthropicBlockBase):
    type: Literal["redacted_thinking"]
    data: str


class ContentBlockServerToolUse(_AnthropicBlockBase):
    """Anthropic server-side tool invocation (e.g. ``web_search``, ``web_fetch``)."""

    type: Literal["server_tool_use"]
    id: str
    name: str
    input: dict[str, Any]


class ContentBlockWebSearchToolResult(_AnthropicBlockBase):
    type: Literal["web_search_tool_result"]
    tool_use_id: str
    content: Any


class ContentBlockWebFetchToolResult(_AnthropicBlockBase):
    type: Literal["web_fetch_tool_result"]
    tool_use_id: str
    content: Any


class SystemContent(_AnthropicBlockBase):
    type: Literal["text"]
    text: str


# =============================================================================
# Message Types
# =============================================================================
class Message(BaseModel):
    role: Literal["user", "assistant"]
    content: (
        str
        | list[
            ContentBlockText
            | ContentBlockImage
            | ContentBlockToolUse
            | ContentBlockToolResult
            | ContentBlockThinking
            | ContentBlockRedactedThinking
            | ContentBlockServerToolUse
            | ContentBlockWebSearchToolResult
            | ContentBlockWebFetchToolResult
        ]
    )
    reasoning_content: str | None = None


class Tool(_AnthropicBlockBase):
    name: str
    # Anthropic server tools (e.g. web_search beta tools) include a ``type`` and
    # may omit ``input_schema`` because the provider owns the schema.
    type: str | None = None
    description: str | None = None
    input_schema: dict[str, Any] | None = None


class ThinkingConfig(BaseModel):
    enabled: bool | None = True
    type: str | None = None
    budget_tokens: int | None = None


# =============================================================================
# Request Models
# =============================================================================
class MessagesRequest(BaseModel):
    model_config = ConfigDict(extra="allow")

    model: str
    # Internal routing / debug: accepted on parse but not serialized to providers.
    original_model: str | None = Field(default=None, exclude=True)
    resolved_provider_model: str | None = Field(default=None, exclude=True)
    max_tokens: int | None = None
    messages: list[Message]
    system: str | list[SystemContent] | None = None
    stop_sequences: list[str] | None = None
    stream: bool | None = True
    temperature: float | None = None
    top_p: float | None = None
    top_k: int | None = None
    metadata: dict[str, Any] | None = None
    tools: list[Tool] | None = None
    tool_choice: dict[str, Any] | None = None
    thinking: ThinkingConfig | None = None
    # Native Anthropic / SDK client hints: ignored (not forwarded) for OpenAI Chat conversion.
    context_management: dict[str, Any] | None = None
    output_config: dict[str, Any] | None = None
    mcp_servers: list[dict[str, Any]] | None = None
    extra_body: dict[str, Any] | None = None


class TokenCountRequest(BaseModel):
    model_config = ConfigDict(extra="allow")

    model: str
    original_model: str | None = Field(default=None, exclude=True)
    resolved_provider_model: str | None = Field(default=None, exclude=True)
    messages: list[Message]
    system: str | list[SystemContent] | None = None
    tools: list[Tool] | None = None
    thinking: ThinkingConfig | None = None
    tool_choice: dict[str, Any] | None = None
    context_management: dict[str, Any] | None = None
    output_config: dict[str, Any] | None = None
    mcp_servers: list[dict[str, Any]] | None = None
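
An aside on how these request models behave in practice: because every block inherits extra="allow", provider-specific fields such as cache_control survive parsing and serialization. The sketch below is illustrative only (not part of the diff); the payload is an invented example of a Claude Code style request and it assumes pydantic v2.

    from api.models.anthropic import MessagesRequest

    payload = {
        "model": "claude-sonnet-4-5",  # hypothetical client-side model id
        "max_tokens": 1024,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "hello",
                     "cache_control": {"type": "ephemeral"}},  # unknown field, kept by extra="allow"
                ],
            },
        ],
        "metadata": {"user_id": "abc"},
    }

    req = MessagesRequest.model_validate(payload)
    # cache_control is still present when the block is dumped back out.
    print(req.messages[0].content[0].model_dump())
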
api/models/responses.py
ADDED
@@ -0,0 +1,56 @@
"""Pydantic models for API responses."""

from typing import Any, Literal

from pydantic import BaseModel

from .anthropic import (
    ContentBlockRedactedThinking,
    ContentBlockText,
    ContentBlockThinking,
    ContentBlockToolUse,
)


class TokenCountResponse(BaseModel):
    input_tokens: int


class ModelResponse(BaseModel):
    created_at: str
    display_name: str
    id: str
    type: Literal["model"] = "model"


class ModelsListResponse(BaseModel):
    data: list[ModelResponse]
    first_id: str | None
    has_more: bool
    last_id: str | None


class Usage(BaseModel):
    input_tokens: int
    output_tokens: int
    cache_creation_input_tokens: int = 0
    cache_read_input_tokens: int = 0


class MessagesResponse(BaseModel):
    id: str
    model: str
    role: Literal["assistant"] = "assistant"
    content: list[
        ContentBlockText
        | ContentBlockToolUse
        | ContentBlockThinking
        | ContentBlockRedactedThinking
        | dict[str, Any]
    ]
    type: Literal["message"] = "message"
    stop_reason: (
        Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"] | None
    ) = None
    stop_sequence: str | None = None
    usage: Usage
api/optimization_handlers.py
ADDED
@@ -0,0 +1,154 @@
"""Optimization handlers for fast-path API responses.

Each handler returns a MessagesResponse if the request matches and the
optimization is enabled, otherwise None.
"""

import uuid

from loguru import logger

from config.settings import Settings

from .command_utils import extract_command_prefix, extract_filepaths_from_command
from .detection import (
    is_filepath_extraction_request,
    is_prefix_detection_request,
    is_quota_check_request,
    is_suggestion_mode_request,
    is_title_generation_request,
)
from .models.anthropic import MessagesRequest
from .models.responses import MessagesResponse, Usage


def _text_response(
    request_data: MessagesRequest,
    text: str,
    *,
    input_tokens: int,
    output_tokens: int,
) -> MessagesResponse:
    return MessagesResponse(
        id=f"msg_{uuid.uuid4()}",
        model=request_data.model,
        content=[{"type": "text", "text": text}],
        stop_reason="end_turn",
        usage=Usage(input_tokens=input_tokens, output_tokens=output_tokens),
    )


def try_prefix_detection(
    request_data: MessagesRequest, settings: Settings
) -> MessagesResponse | None:
    """Fast prefix detection - return command prefix without API call."""
    if not settings.fast_prefix_detection:
        return None

    is_prefix_req, command = is_prefix_detection_request(request_data)
    if not is_prefix_req:
        return None

    logger.info("Optimization: Fast prefix detection request")
    return _text_response(
        request_data,
        extract_command_prefix(command),
        input_tokens=100,
        output_tokens=5,
    )


def try_quota_mock(
    request_data: MessagesRequest, settings: Settings
) -> MessagesResponse | None:
    """Mock quota probe requests."""
    if not settings.enable_network_probe_mock:
        return None
    if not is_quota_check_request(request_data):
        return None

    logger.info("Optimization: Intercepted and mocked quota probe")
    return _text_response(
        request_data,
        "Quota check passed.",
        input_tokens=10,
        output_tokens=5,
    )


def try_title_skip(
    request_data: MessagesRequest, settings: Settings
) -> MessagesResponse | None:
    """Skip title generation requests."""
    if not settings.enable_title_generation_skip:
        return None
    if not is_title_generation_request(request_data):
        return None

    logger.info("Optimization: Skipped title generation request")
    return _text_response(
        request_data,
        "Conversation",
        input_tokens=100,
        output_tokens=5,
    )


def try_suggestion_skip(
    request_data: MessagesRequest, settings: Settings
) -> MessagesResponse | None:
    """Skip suggestion mode requests."""
    if not settings.enable_suggestion_mode_skip:
        return None
    if not is_suggestion_mode_request(request_data):
        return None

    logger.info("Optimization: Skipped suggestion mode request")
    return _text_response(
        request_data,
        "",
        input_tokens=100,
        output_tokens=1,
    )


def try_filepath_mock(
    request_data: MessagesRequest, settings: Settings
) -> MessagesResponse | None:
    """Mock filepath extraction requests."""
    if not settings.enable_filepath_extraction_mock:
        return None

    is_fp, cmd, output = is_filepath_extraction_request(request_data)
    if not is_fp:
        return None

    filepaths = extract_filepaths_from_command(cmd, output)
    logger.info("Optimization: Mocked filepath extraction")
    return _text_response(
        request_data,
        filepaths,
        input_tokens=100,
        output_tokens=10,
    )


# Cheapest/most common optimizations first for faster short-circuit.
OPTIMIZATION_HANDLERS = [
    try_quota_mock,
    try_prefix_detection,
    try_title_skip,
    try_suggestion_skip,
    try_filepath_mock,
]


def try_optimizations(
    request_data: MessagesRequest, settings: Settings
) -> MessagesResponse | None:
    """Run optimization handlers in order. Returns first match or None."""
    for handler in OPTIMIZATION_HANDLERS:
        result = handler(request_data, settings)
        if result is not None:
            return result
    return None
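
The handler chain here is a plain short-circuit: the first handler that returns a response wins, and a None from every handler means the request continues to the real provider. A standalone sketch of the same pattern, with illustrative names that are not from this repo:

    from collections.abc import Callable

    Handler = Callable[[dict], str | None]

    def mock_ping(req: dict) -> str | None:
        # Cheapest check first: answer trivial probe requests locally.
        return "pong" if req.get("kind") == "ping" else None

    def mock_title(req: dict) -> str | None:
        return "Conversation" if req.get("kind") == "title" else None

    HANDLERS: list[Handler] = [mock_ping, mock_title]

    def try_handlers(req: dict) -> str | None:
        for handler in HANDLERS:
            result = handler(req)
            if result is not None:  # first match wins, remaining handlers are skipped
                return result
        return None                 # fall through to the real provider call

    assert try_handlers({"kind": "ping"}) == "pong"
    assert try_handlers({"kind": "chat"}) is None
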
api/routes.py
ADDED
@@ -0,0 +1,271 @@
"""FastAPI route handlers."""

from fastapi import APIRouter, Depends, HTTPException, Request, Response
from loguru import logger

from config.settings import Settings
from core.anthropic import get_token_count
from providers.registry import ProviderRegistry

from . import dependencies
from .dependencies import get_settings, require_api_key
from .gateway_model_ids import gateway_model_id, no_thinking_gateway_model_id
from .models.anthropic import MessagesRequest, TokenCountRequest
from .models.responses import ModelResponse, ModelsListResponse
from .services import ClaudeProxyService
from providers.nvidia_nim import metrics as nvidia_nim_metrics

router = APIRouter()

DISCOVERED_MODEL_CREATED_AT = "1970-01-01T00:00:00Z"


# The proxy advertises a curated set of provider-backed models. Replace
# the previous hardcoded Claude model list with the requested NVIDIA-
# compatible models so clients only see those options.
REQUESTED_PROVIDER_MODELS = [
    "nvidia_nim/qwen/qwen3-coder-480b-a35b-instruct",
    "nvidia_nim/mistralai/mistral-large-3-675b-instruct-2512",
    "nvidia_nim/abacusai/dracarys-llama-3.1-70b-instruct",
    "nvidia_nim/z-ai/glm4.7",
    "nvidia_nim/stepfun-ai/step-3.5-flash",
    "nvidia_nim/bytedance/seed-oss-36b-instruct",
    "nvidia_nim/mistralai/mistral-nemotron",
    "groq/openai/gpt-oss-120b",
    "groq/openai/gpt-oss-20b",
    "groq/llama-3.3-70b-versatile",
    "groq/meta-llama/llama-4-scout-17b-16e-instruct",
    "groq/qwen/qwen3-32b",
    "cerebras/gpt-oss-120b",
    "cerebras/qwen-3-235b-a22b-instruct-2507",
    "cerebras/zai-glm-4.7",
    "cerebras/llama3.1-8b",
]


def get_proxy_service(
    request: Request,
    settings: Settings = Depends(get_settings),
) -> ClaudeProxyService:
    """Build the request service for route handlers."""
    return ClaudeProxyService(
        settings,
        provider_getter=lambda provider_type: dependencies.resolve_provider(
            provider_type, app=request.app, settings=settings
        ),
        token_counter=get_token_count,
    )


def _probe_response(allow: str) -> Response:
    """Return an empty success response for compatibility probes."""
    return Response(status_code=204, headers={"Allow": allow})


def _discovered_model_response(model_id: str, *, display_name: str) -> ModelResponse:
    return ModelResponse(
        id=model_id,
        display_name=display_name,
        created_at=DISCOVERED_MODEL_CREATED_AT,
    )


def _append_unique_model(
    models: list[ModelResponse], seen: set[str], model: ModelResponse
) -> None:
    if model.id in seen:
        return
    seen.add(model.id)
    models.append(model)


def _append_provider_model_variants(
    models: list[ModelResponse],
    seen: set[str],
    provider_model_ref: str,
    *,
    supports_thinking: bool | None = None,
) -> None:
    if supports_thinking is not False:
        _append_unique_model(
            models,
            seen,
            _discovered_model_response(
                gateway_model_id(provider_model_ref),
                display_name=provider_model_ref,
            ),
        )
    _append_unique_model(
        models,
        seen,
        _discovered_model_response(
            no_thinking_gateway_model_id(provider_model_ref),
            display_name=f"{provider_model_ref} (no thinking)",
        ),
    )


def _build_models_list_response(
    settings: Settings, provider_registry: ProviderRegistry | None
) -> ModelsListResponse:
    models: list[ModelResponse] = []
    seen: set[str] = set()

    # Advertise only the requested provider models (no Claude models, no registry auto-discovery).
    # Each ref is added with both thinking and no-thinking variants.
    for provider_ref in REQUESTED_PROVIDER_MODELS:
        # If the ref already contains a provider prefix, use it as-is;
        # otherwise assume it belongs to the NVIDIA NIM provider.
        ref = provider_ref if "/" in provider_ref else f"nvidia_nim/{provider_ref}"
        supports_thinking = None
        if provider_registry is not None:
            # model_id for registry lookups should be provider-prefixed
            provider, model_id = ref.split("/", 1) if "/" in ref else ("nvidia_nim", ref)
            supports_thinking = provider_registry.cached_model_supports_thinking(provider, model_id)
        _append_provider_model_variants(models, seen, ref, supports_thinking=supports_thinking)

    # Add a virtual `auto` model that maps to the configured MODEL and enables
    # automatic fallback behavior when used by clients.
    _append_unique_model(
        models,
        seen,
        ModelResponse(
            id=gateway_model_id("auto"),
            display_name="auto (use configured fallbacks)",
            created_at=DISCOVERED_MODEL_CREATED_AT,
        ),
    )

    # Filter out any residual Claude-branded models so the proxy advertises
    # only the provider-backed models requested by the user.
    filtered = [
        m
        for m in models
        if "claude" not in (m.id or "").lower() and "claude" not in (m.display_name or "").lower()
    ]
    # Ensure `auto` model remains available even if filtering removed others.
    if not any(m.id == gateway_model_id("auto") for m in filtered):
        filtered.append(
            ModelResponse(
                id=gateway_model_id("auto"),
                display_name="auto (use configured fallbacks)",
                created_at=DISCOVERED_MODEL_CREATED_AT,
            )
        )

    return ModelsListResponse(
        data=filtered,
        first_id=filtered[0].id if filtered else None,
        has_more=False,
        last_id=filtered[-1].id if filtered else None,
    )


# =============================================================================
# Routes
# =============================================================================
@router.post("/v1/messages")
async def create_message(
    request_data: MessagesRequest,
    service: ClaudeProxyService = Depends(get_proxy_service),
    _auth=Depends(require_api_key),
):
    """Create a message (always streaming)."""
    return service.create_message(request_data)


@router.api_route("/v1/messages", methods=["HEAD", "OPTIONS"])
async def probe_messages(_auth=Depends(require_api_key)):
    """Respond to Claude compatibility probes for the messages endpoint."""
    return _probe_response("POST, HEAD, OPTIONS")


@router.post("/v1/messages/count_tokens")
async def count_tokens(
    request_data: TokenCountRequest,
    service: ClaudeProxyService = Depends(get_proxy_service),
    _auth=Depends(require_api_key),
):
    """Count tokens for a request."""
    return service.count_tokens(request_data)


@router.api_route("/v1/messages/count_tokens", methods=["HEAD", "OPTIONS"])
async def probe_count_tokens(_auth=Depends(require_api_key)):
    """Respond to Claude compatibility probes for the token count endpoint."""
    return _probe_response("POST, HEAD, OPTIONS")


@router.get("/")
async def root(
    settings: Settings = Depends(get_settings), _auth=Depends(require_api_key)
):
    """Root endpoint."""
    return {
        "status": "ok",
        "provider": settings.provider_type,
        "model": settings.model,
    }


@router.api_route("/", methods=["HEAD", "OPTIONS"])
async def probe_root(_auth=Depends(require_api_key)):
    """Respond to compatibility probes for the root endpoint."""
    return _probe_response("GET, HEAD, OPTIONS")


@router.get("/health")
async def health():
    """Health check endpoint."""
    return {"status": "healthy"}


@router.api_route("/health", methods=["HEAD", "OPTIONS"])
async def probe_health():
    """Respond to compatibility probes for the health endpoint."""
    return _probe_response("GET, HEAD, OPTIONS")


@router.get("/v1/models", response_model=ModelsListResponse)
async def list_models(
    request: Request,
    settings: Settings = Depends(get_settings),
    _auth=Depends(require_api_key),
):
    """List the model ids this proxy advertises to Claude-compatible clients."""
    registry = getattr(request.app.state, "provider_registry", None)
    provider_registry = registry if isinstance(registry, ProviderRegistry) else None
    return _build_models_list_response(settings, provider_registry)


@router.post("/stop")
async def stop_cli(request: Request, _auth=Depends(require_api_key)):
    """Stop all CLI sessions and pending tasks."""
    handler = getattr(request.app.state, "message_handler", None)
    if not handler:
        # Fallback if messaging not initialized
        cli_manager = getattr(request.app.state, "cli_manager", None)
        if cli_manager:
            await cli_manager.stop_all()
            logger.info("STOP_CLI: source=cli_manager cancelled_count=N/A")
            return {"status": "stopped", "source": "cli_manager"}
        raise HTTPException(status_code=503, detail="Messaging system not initialized")

    count = await handler.stop_all_tasks()
    logger.info("STOP_CLI: source=handler cancelled_count={}", count)
    return {"status": "stopped", "cancelled_count": count}


@router.get("/admin/fallbacks")
async def admin_fallbacks(_auth=Depends(require_api_key)):
    """Admin endpoint exposing NVIDIA NIM fallback metrics.

    Protected by the same API key as other endpoints.
    """
    try:
        data = nvidia_nim_metrics.snapshot()
    except Exception as e:
        logger.warning("ADMIN_FALLBACKS: failed to read metrics: {}", e)
        raise HTTPException(status_code=500, detail="failed to read metrics")
    return {"provider": "nvidia_nim", "fallbacks": data}
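
With these routes in place, a Claude-compatible client only needs the proxy's base URL and the shared token. A hedged sketch of a direct call follows; the port, the x-api-key header, and the token value are assumptions that depend on how require_api_key and the deployment are actually configured, so adjust them to your setup.

    import httpx

    BASE_URL = "http://localhost:8080"   # assumption: local deployment and port
    TOKEN = "your-ANTHROPIC_AUTH_TOKEN"  # assumption: the proxy's configured auth token

    headers = {"x-api-key": TOKEN, "anthropic-version": "2023-06-01"}

    # List the advertised provider-backed models (including the "(no thinking)" variants).
    models = httpx.get(f"{BASE_URL}/v1/models", headers=headers).json()
    print([m["id"] for m in models["data"]])

    # Stream a message through one of them.
    body = {
        "model": models["data"][0]["id"],
        "max_tokens": 256,
        "stream": True,
        "messages": [{"role": "user", "content": "Say hello."}],
    }
    with httpx.stream("POST", f"{BASE_URL}/v1/messages", headers=headers, json=body) as r:
        for line in r.iter_lines():
            print(line)
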
api/runtime.py
ADDED
@@ -0,0 +1,338 @@
"""Application runtime composition and lifecycle ownership."""

from __future__ import annotations

import asyncio
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

from fastapi import FastAPI
from loguru import logger

from config.settings import Settings, get_settings
from providers.exceptions import ServiceUnavailableError
from providers.registry import ProviderRegistry

if TYPE_CHECKING:
    from cli.manager import CLISessionManager
    from messaging.handler import ClaudeMessageHandler
    from messaging.platforms.base import MessagingPlatform
    from messaging.session import SessionStore

_SHUTDOWN_TIMEOUT_S = 5.0


async def best_effort(
    name: str,
    awaitable: Any,
    timeout_s: float = _SHUTDOWN_TIMEOUT_S,
    *,
    log_verbose_errors: bool = False,
) -> None:
    """Run a shutdown step with timeout; never raise to callers."""
    try:
        await asyncio.wait_for(awaitable, timeout=timeout_s)
    except TimeoutError:
        logger.warning("Shutdown step timed out: {} ({}s)", name, timeout_s)
    except Exception as e:
        if log_verbose_errors:
            logger.warning(
                "Shutdown step failed: {}: {}: {}",
                name,
                type(e).__name__,
                e,
            )
        else:
            logger.warning(
                "Shutdown step failed: {}: exc_type={}",
                name,
                type(e).__name__,
            )


def warn_if_process_auth_token(settings: Settings) -> None:
    """Warn when server auth was implicitly inherited from the shell."""
    if settings.uses_process_anthropic_auth_token():
        logger.warning(
            "ANTHROPIC_AUTH_TOKEN is set in the process environment but not in "
            "a configured .env file. The proxy will require that token. Add "
            "ANTHROPIC_AUTH_TOKEN= to .env to disable proxy auth, or set the "
            "same token in .env to make server auth explicit."
        )


def log_startup_failure(settings: Settings, exc: Exception) -> None:
    """Log startup failures without traceback noise unless verbose diagnostics are enabled."""
    message = startup_failure_message(settings, exc)
    logger.error("Startup failed:\n{}", message)


def startup_failure_message(settings: Settings, exc: Exception) -> str:
    """Return a concise startup failure message for logs and ASGI lifespan failure."""
    if isinstance(exc, ServiceUnavailableError):
        return exc.message.strip() or "Server startup failed."

    if settings.log_api_error_tracebacks:
        return f"{type(exc).__name__}: {exc}"

    return f"Server startup failed: exc_type={type(exc).__name__}"


def _should_continue_after_model_validation_failure(exc: Exception) -> bool:
    """Return whether a model-validation failure should be downgraded to a warning.

    Provider discovery can fail transiently or due to local environment issues
    (for example, a missing runtime dependency in the provider's process path).
    We keep startup alive in those cases so the configured proxy can still serve
    requests and advertise the models that are already known from settings.
    """
    if not isinstance(exc, ServiceUnavailableError):
        return False

    message = (exc.message or str(exc)).lower()
    return "problem=query failure:" in message


@dataclass(slots=True)
class AppRuntime:
    """Own optional messaging, CLI, session, and provider runtime resources."""

    app: FastAPI
    settings: Settings
    _provider_registry: ProviderRegistry | None = field(default=None, init=False)
    messaging_platform: MessagingPlatform | None = None
    message_handler: ClaudeMessageHandler | None = None
    cli_manager: CLISessionManager | None = None

    @classmethod
    def for_app(
        cls,
        app: FastAPI,
        settings: Settings | None = None,
    ) -> AppRuntime:
        return cls(app=app, settings=settings or get_settings())

    async def startup(self) -> None:
        logger.info("Starting Claude Code Proxy...")
        self._provider_registry = ProviderRegistry()
        self.app.state.provider_registry = self._provider_registry
        try:
            warn_if_process_auth_token(self.settings)
            try:
                # Use a reasonable timeout for startup validation to prevent blocking healthy checks.
                await asyncio.wait_for(
                    self._provider_registry.validate_configured_models(self.settings),
                    timeout=15.0,
                )
            except Exception as exc:
                logger.warning(
                    "Startup model validation skipped or timed out: continuing in lazy mode. "
                    "Reason: {}",
                    str(exc) or type(exc).__name__,
                )
            self._provider_registry.start_model_list_refresh(self.settings)
            await self._start_messaging_if_configured()
            self._publish_state()
        except Exception as exc:
            log_startup_failure(self.settings, exc)
            await best_effort(
                "provider_registry.cleanup",
                self._provider_registry.cleanup(),
                log_verbose_errors=self.settings.log_api_error_tracebacks,
            )
            raise

    async def shutdown(self) -> None:
        verbose = self.settings.log_api_error_tracebacks
        if self.message_handler is not None:
            try:
                self.message_handler.session_store.flush_pending_save()
            except Exception as e:
                if verbose:
                    logger.warning("Session store flush on shutdown: {}", e)
                else:
                    logger.warning(
                        "Session store flush on shutdown: exc_type={}",
                        type(e).__name__,
                    )

        logger.info("Shutdown requested, cleaning up...")
        if self.messaging_platform:
            await best_effort(
                "messaging_platform.stop",
                self.messaging_platform.stop(),
                log_verbose_errors=verbose,
            )
        if self.cli_manager:
            await best_effort(
                "cli_manager.stop_all",
                self.cli_manager.stop_all(),
                log_verbose_errors=verbose,
            )
        if self._provider_registry is not None:
            await best_effort(
                "provider_registry.cleanup",
                self._provider_registry.cleanup(),
                log_verbose_errors=verbose,
            )
        await self._shutdown_limiter()
        logger.info("Server shut down cleanly")

    async def _start_messaging_if_configured(self) -> None:
        try:
            from messaging.platforms.factory import (
                MessagingPlatformOptions,
                create_messaging_platform,
            )

            self.messaging_platform = create_messaging_platform(
                self.settings.messaging_platform,
                MessagingPlatformOptions(
                    telegram_bot_token=self.settings.telegram_bot_token,
                    allowed_telegram_user_id=self.settings.allowed_telegram_user_id,
                    discord_bot_token=self.settings.discord_bot_token,
                    allowed_discord_channels=self.settings.allowed_discord_channels,
                    voice_note_enabled=self.settings.voice_note_enabled,
                    whisper_model=self.settings.whisper_model,
                    whisper_device=self.settings.whisper_device,
                    hf_token=self.settings.hf_token,
                    nvidia_nim_api_key=self.settings.nvidia_nim_api_key_qwen,
                    messaging_rate_limit=self.settings.messaging_rate_limit,
                    messaging_rate_window=self.settings.messaging_rate_window,
                    log_raw_messaging_content=self.settings.log_raw_messaging_content,
                    log_api_error_tracebacks=self.settings.log_api_error_tracebacks,
                ),
            )

            if self.messaging_platform:
                await self._start_message_handler()

        except ImportError as e:
            if self.settings.log_api_error_tracebacks:
                logger.warning("Messaging module import error: {}", e)
            else:
                logger.warning(
                    "Messaging module import error: exc_type={}",
                    type(e).__name__,
                )
        except Exception as e:
            if self.settings.log_api_error_tracebacks:
                logger.error("Failed to start messaging platform: {}", e)
                import traceback

                logger.error(traceback.format_exc())
            else:
                logger.error(
                    "Failed to start messaging platform: exc_type={}",
                    type(e).__name__,
                )

    async def _start_message_handler(self) -> None:
        from cli.manager import CLISessionManager
        from messaging.handler import ClaudeMessageHandler
        from messaging.session import SessionStore

        workspace = (
            os.path.abspath(self.settings.allowed_dir)
            if self.settings.allowed_dir
            else os.getcwd()
        )
        os.makedirs(workspace, exist_ok=True)

        data_path = os.path.abspath(self.settings.claude_workspace)
        os.makedirs(data_path, exist_ok=True)

        api_url = f"http://{self.settings.host}:{self.settings.port}/v1"
        allowed_dirs = [workspace] if self.settings.allowed_dir else []
        plans_dir_abs = os.path.abspath(
            os.path.join(self.settings.claude_workspace, "plans")
        )
        plans_directory = os.path.relpath(plans_dir_abs, workspace)
        self.cli_manager = CLISessionManager(
            workspace_path=workspace,
            api_url=api_url,
            allowed_dirs=allowed_dirs,
            plans_directory=plans_directory,
            claude_bin=self.settings.claude_cli_bin,
            log_raw_cli_diagnostics=self.settings.log_raw_cli_diagnostics,
            log_messaging_error_details=self.settings.log_messaging_error_details,
        )

        session_store = SessionStore(
            storage_path=os.path.join(data_path, "sessions.json"),
            message_log_cap=self.settings.max_message_log_entries_per_chat,
        )
        platform = self.messaging_platform
        assert platform is not None
        self.message_handler = ClaudeMessageHandler(
            platform=platform,
            cli_manager=self.cli_manager,
            session_store=session_store,
            debug_platform_edits=self.settings.debug_platform_edits,
            debug_subagent_stack=self.settings.debug_subagent_stack,
            log_raw_messaging_content=self.settings.log_raw_messaging_content,
            log_raw_cli_diagnostics=self.settings.log_raw_cli_diagnostics,
            log_messaging_error_details=self.settings.log_messaging_error_details,
        )
        self._restore_tree_state(session_store)

        platform.on_message(self.message_handler.handle_message)
        await platform.start()
        logger.info(f"{platform.name} platform started with message handler")

    def _restore_tree_state(self, session_store: SessionStore) -> None:
        saved_trees = session_store.get_all_trees()
        if not saved_trees:
            return
        if self.message_handler is None:
            return

        logger.info(f"Restoring {len(saved_trees)} conversation trees...")
        from messaging.trees.queue_manager import TreeQueueManager

        self.message_handler.replace_tree_queue(
            TreeQueueManager.from_dict(
                {
                    "trees": saved_trees,
                    "node_to_tree": session_store.get_node_mapping(),
                },
                queue_update_callback=self.message_handler.update_queue_positions,
                node_started_callback=self.message_handler.mark_node_processing,
            )
        )
        if self.message_handler.tree_queue.cleanup_stale_nodes() > 0:
            tree_data = self.message_handler.tree_queue.to_dict()
            session_store.sync_from_tree_data(
                tree_data["trees"], tree_data["node_to_tree"]
            )

    def _publish_state(self) -> None:
        self.app.state.messaging_platform = self.messaging_platform
        self.app.state.message_handler = self.message_handler
        self.app.state.cli_manager = self.cli_manager

    async def _shutdown_limiter(self) -> None:
        verbose = self.settings.log_api_error_tracebacks
        try:
            from messaging.limiter import MessagingRateLimiter
        except Exception as e:
            if verbose:
                logger.debug(
                    "Rate limiter shutdown skipped (import failed): {}: {}",
                    type(e).__name__,
                    e,
                )
            else:
                logger.debug(
                    "Rate limiter shutdown skipped (import failed): exc_type={}",
                    type(e).__name__,
                )
            return

        await best_effort(
            "MessagingRateLimiter.shutdown_instance",
            MessagingRateLimiter.shutdown_instance(),
            timeout_s=2.0,
            log_verbose_errors=verbose,
        )
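
AppRuntime is written to be owned by the ASGI lifespan, with startup() and shutdown() as the only entry points. A minimal sketch of how it could be wired into a FastAPI app; the real wiring lives in api/app.py, so treat this as an approximation rather than the exact code.

    from contextlib import asynccontextmanager

    from fastapi import FastAPI

    from api.runtime import AppRuntime

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        runtime = AppRuntime.for_app(app)  # resolves Settings via get_settings()
        await runtime.startup()            # provider registry, optional messaging, state publish
        try:
            yield
        finally:
            await runtime.shutdown()       # best-effort, timeout-bounded cleanup

    app = FastAPI(lifespan=lifespan)
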
api/services.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Application services for the Claude-compatible API."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import traceback
|
| 6 |
+
import uuid
|
| 7 |
+
from collections.abc import AsyncIterator, Callable
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
from fastapi import HTTPException
|
| 11 |
+
from fastapi.responses import StreamingResponse
|
| 12 |
+
from loguru import logger
|
| 13 |
+
|
| 14 |
+
from config.settings import Settings
|
| 15 |
+
from core.anthropic import get_token_count, get_user_facing_error_message
|
| 16 |
+
from core.anthropic.sse import ANTHROPIC_SSE_RESPONSE_HEADERS, format_sse_event
|
| 17 |
+
from providers.base import BaseProvider
|
| 18 |
+
from providers.exceptions import (
|
| 19 |
+
InvalidRequestError,
|
| 20 |
+
OverloadedError,
|
| 21 |
+
ProviderError,
|
| 22 |
+
RateLimitError,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
from .model_router import ModelRouter
|
| 26 |
+
from .models.anthropic import MessagesRequest, TokenCountRequest
|
| 27 |
+
from .models.responses import TokenCountResponse
|
| 28 |
+
from .optimization_handlers import try_optimizations
|
| 29 |
+
from .web_tools.egress import WebFetchEgressPolicy
|
| 30 |
+
from .web_tools.request import (
|
| 31 |
+
is_web_server_tool_request,
|
| 32 |
+
openai_chat_upstream_server_tool_error,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
TokenCounter = Callable[[list[Any], str | list[Any] | None, list[Any] | None], int]
|
| 36 |
+
|
| 37 |
+
ProviderGetter = Callable[[str], BaseProvider]
|
| 38 |
+
|
| 39 |
+
# Providers that use ``/chat/completions`` + Anthropic-to-OpenAI conversion (not native Messages).
|
| 40 |
+
_OPENAI_CHAT_UPSTREAM_IDS = frozenset({"nvidia_nim", "groq", "cerebras"})
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def anthropic_sse_streaming_response(
|
| 44 |
+
body: AsyncIterator[str],
|
| 45 |
+
) -> StreamingResponse:
|
| 46 |
+
"""Return a :class:`StreamingResponse` for Anthropic-style SSE streams."""
|
| 47 |
+
return StreamingResponse(
|
| 48 |
+
body,
|
| 49 |
+
media_type="text/event-stream",
|
| 50 |
+
headers=ANTHROPIC_SSE_RESPONSE_HEADERS,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _http_status_for_unexpected_service_exception(_exc: BaseException) -> int:
|
| 55 |
+
"""HTTP status for uncaught non-provider failures (stable client contract)."""
|
| 56 |
+
return 500
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _log_unexpected_service_exception(
|
| 60 |
+
settings: Settings,
|
| 61 |
+
exc: BaseException,
|
| 62 |
+
*,
|
| 63 |
+
context: str,
|
| 64 |
+
request_id: str | None = None,
|
| 65 |
+
) -> None:
|
| 66 |
+
"""Log service-layer failures without echoing exception text unless opted in."""
|
| 67 |
+
if settings.log_api_error_tracebacks:
|
| 68 |
+
if request_id is not None:
|
| 69 |
+
logger.error("{} request_id={}: {}", context, request_id, exc)
|
| 70 |
+
else:
|
| 71 |
+
logger.error("{}: {}", context, exc)
|
| 72 |
+
logger.error(traceback.format_exc())
|
| 73 |
+
return
|
| 74 |
+
if request_id is not None:
|
| 75 |
+
logger.error(
|
| 76 |
+
"{} request_id={} exc_type={}",
|
| 77 |
+
context,
|
| 78 |
+
request_id,
|
| 79 |
+
type(exc).__name__,
|
| 80 |
+
)
|
| 81 |
+
else:
|
| 82 |
+
logger.error("{} exc_type={}", context, type(exc).__name__)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _require_non_empty_messages(messages: list[Any]) -> None:
|
| 86 |
+
if not messages:
|
| 87 |
+
raise InvalidRequestError("messages cannot be empty")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class ClaudeProxyService:
|
| 91 |
+
"""Coordinate request optimization, model routing, token count, and providers."""
|
| 92 |
+
|
| 93 |
+
def __init__(
|
| 94 |
+
self,
|
| 95 |
+
settings: Settings,
|
| 96 |
+
provider_getter: ProviderGetter,
|
| 97 |
+
model_router: ModelRouter | None = None,
|
| 98 |
+
token_counter: TokenCounter = get_token_count,
|
| 99 |
+
):
|
| 100 |
+
self._settings = settings
|
| 101 |
+
self._provider_getter = provider_getter
|
| 102 |
+
self._model_router = model_router or ModelRouter(settings)
|
| 103 |
+
self._token_counter = token_counter
|
| 104 |
+
|
| 105 |
+
def create_message(self, request_data: MessagesRequest) -> object:
|
| 106 |
+
"""Create a message response or streaming response with optional failover."""
|
| 107 |
+
from .web_tools.streaming import stream_web_server_tool_response
|
| 108 |
+
try:
|
| 109 |
+
_require_non_empty_messages(request_data.messages)
|
| 110 |
+
|
| 111 |
+
candidates = self._model_router.resolve_candidates(request_data.model)
|
| 112 |
+
if not candidates:
|
| 113 |
+
raise InvalidRequestError(f"No configured models available for '{request_data.model}'")
|
| 114 |
+
|
| 115 |
+
# For 'auto' requests with multiple candidates, we wrap the stream in a failover loop.
|
| 116 |
+
if len(candidates) > 1:
|
| 117 |
+
return anthropic_sse_streaming_response(
|
| 118 |
+
self._stream_with_fallbacks(candidates, request_data)
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
# Standard path for single-model requests
|
| 122 |
+
return self._create_single_message(candidates[0], request_data)
|
| 123 |
+
|
| 124 |
+
except ProviderError:
|
| 125 |
+
raise
|
| 126 |
+
except Exception as e:
|
| 127 |
+
_log_unexpected_service_exception(
|
| 128 |
+
self._settings, e, context="CREATE_MESSAGE_ERROR"
|
| 129 |
+
)
|
| 130 |
+
raise HTTPException(
|
| 131 |
+
status_code=_http_status_for_unexpected_service_exception(e),
|
| 132 |
+
detail=get_user_facing_error_message(e),
|
| 133 |
+
) from e
|
| 134 |
+
|
| 135 |
+
def _create_single_message(
|
| 136 |
+
self, resolved: ResolvedModel, request_data: MessagesRequest
|
| 137 |
+
) -> object:
|
| 138 |
+
"""Create a single message response from a resolved model."""
|
| 139 |
+
routed_request = request_data.model_copy(deep=True)
|
| 140 |
+
routed_request.model = resolved.provider_model
|
| 141 |
+
|
| 142 |
+
if resolved.provider_id in _OPENAI_CHAT_UPSTREAM_IDS:
|
| 143 |
+
tool_err = openai_chat_upstream_server_tool_error(
|
| 144 |
+
routed_request,
|
| 145 |
+
web_tools_enabled=self._settings.enable_web_server_tools,
|
| 146 |
+
)
|
| 147 |
+
if tool_err is not None:
|
| 148 |
+
raise InvalidRequestError(tool_err)
|
| 149 |
+
|
| 150 |
+
if self._settings.enable_web_server_tools and is_web_server_tool_request(
|
| 151 |
+
routed_request
|
| 152 |
+
):
|
| 153 |
+
input_tokens = self._token_counter(
|
| 154 |
+
routed_request.messages, routed_request.system, routed_request.tools
|
| 155 |
+
)
|
| 156 |
+
logger.info("Optimization: Handling Anthropic web server tool")
|
| 157 |
+
egress = WebFetchEgressPolicy(
|
| 158 |
+
allow_private_network_targets=self._settings.web_fetch_allow_private_networks,
|
| 159 |
+
allowed_schemes=self._settings.web_fetch_allowed_scheme_set(),
|
| 160 |
+
)
|
| 161 |
+
return anthropic_sse_streaming_response(
|
| 162 |
+
stream_web_server_tool_response(
|
| 163 |
+
routed_request,
|
| 164 |
+
input_tokens=input_tokens,
|
| 165 |
+
web_fetch_egress=egress,
|
| 166 |
+
verbose_client_errors=self._settings.log_api_error_tracebacks,
|
| 167 |
+
),
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
optimized = try_optimizations(routed_request, self._settings)
|
| 171 |
+
if optimized is not None:
|
| 172 |
+
return optimized
|
| 173 |
+
|
| 174 |
+
provider = self._provider_getter(resolved.provider_id)
|
| 175 |
+
provider.preflight_stream(
|
| 176 |
+
routed_request,
|
| 177 |
+
thinking_enabled=resolved.thinking_enabled,
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
request_id = f"req_{uuid.uuid4().hex[:12]}"
|
| 181 |
+
logger.info(
|
| 182 |
+
"API_REQUEST: request_id={} model={} messages={}",
|
| 183 |
+
request_id,
|
                routed_request.model,
                len(routed_request.messages),
            )

            input_tokens = self._token_counter(
                routed_request.messages, routed_request.system, routed_request.tools
            )
            return anthropic_sse_streaming_response(
                provider.stream_response(
                    routed_request,
                    input_tokens=input_tokens,
                    request_id=request_id,
                    thinking_enabled=resolved.thinking_enabled,
                ),
            )

    async def _stream_with_fallbacks(
        self, candidates: list[ResolvedModel], request_data: MessagesRequest
    ) -> AsyncIterator[str]:
        """Iterate through candidates until one succeeds or all fail."""
        last_exc: Exception | None = None

        for i, resolved in enumerate(candidates):
            try:
                provider = self._provider_getter(resolved.provider_id)
                routed_request = request_data.model_copy(deep=True)
                routed_request.model = resolved.provider_model

                provider.preflight_stream(
                    routed_request,
                    thinking_enabled=resolved.thinking_enabled,
                )

                request_id = f"req_{uuid.uuid4().hex[:12]}"
                logger.info(
                    "API_REQUEST (auto fallback {}/{}): request_id={} provider={} model={}",
                    i + 1,
                    len(candidates),
                    request_id,
                    resolved.provider_id,
                    resolved.provider_model,
                )

                input_tokens = self._token_counter(
                    routed_request.messages, routed_request.system, routed_request.tools
                )

                # Attempt to stream from this provider.
                async for event in provider.stream_response(
                    routed_request,
                    input_tokens=input_tokens,
                    request_id=request_id,
                    thinking_enabled=resolved.thinking_enabled,
                ):
                    yield event
                # CRITICAL: If we have yielded even one event, we have committed to this provider.
                # We must not fallback to another candidate mid-stream.
                return  # Success, exit the fallback loop.

            except (RateLimitError, OverloadedError) as e:
                logger.warning(
                    "Provider '{}' is rate limited or overloaded ({}). Trying next candidate...",
                    resolved.provider_id,
                    e.status_code,
                )
                last_exc = e
                continue
            except Exception as e:
                logger.error(
                    "Provider '{}' failed with unexpected error: {}. Trying next candidate...",
                    resolved.provider_id,
                    e,
                )
                last_exc = e
                continue

        err_msg = str(last_exc) if last_exc else "No candidates succeeded"
        yield format_sse_event(
            "error",
            {
                "type": "error",
                "error": {
                    "type": "api_error",
                    "message": f"All fallback candidates failed: {err_msg}",
                },
            },
        )
        if last_exc:
            raise last_exc
        raise InvalidRequestError("No candidates succeeded")

    def count_tokens(self, request_data: TokenCountRequest) -> TokenCountResponse:
        """Count tokens for a request after applying configured model routing."""
        request_id = f"req_{uuid.uuid4().hex[:12]}"
        with logger.contextualize(request_id=request_id):
            try:
                _require_non_empty_messages(request_data.messages)
                routed = self._model_router.resolve_token_count_request(request_data)
                tokens = self._token_counter(
                    routed.request.messages, routed.request.system, routed.request.tools
                )
                logger.info(
                    "COUNT_TOKENS: request_id={} model={} messages={} input_tokens={}",
                    request_id,
                    routed.request.model,
                    len(routed.request.messages),
                    tokens,
                )
                return TokenCountResponse(input_tokens=tokens)
            except ProviderError:
                raise
            except Exception as e:
                _log_unexpected_service_exception(
                    self._settings,
                    e,
                    context="COUNT_TOKENS_ERROR",
                    request_id=request_id,
                )
                raise HTTPException(
                    status_code=_http_status_for_unexpected_service_exception(e),
                    detail=get_user_facing_error_message(e),
                ) from e
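The fallback loop above commits to a provider as soon as it has yielded a single SSE event. A minimal standalone sketch of that rule, with hypothetical provider objects exposing an async stream(), in case the control flow is easier to read outside the service class:

from collections.abc import AsyncIterator


async def stream_with_fallbacks(providers: list) -> AsyncIterator[str]:
    last_exc: Exception | None = None
    for provider in providers:
        committed = False
        try:
            async for event in provider.stream():
                committed = True  # at least one event reached the client
                yield event
            return  # stream finished cleanly; stop trying candidates
        except Exception as exc:
            if committed:
                raise  # never switch providers mid-stream
            last_exc = exc  # nothing sent yet, safe to try the next candidate
    if last_exc:
        raise last_exc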
api/validation_log.py ADDED
@@ -0,0 +1,48 @@
"""Safe metadata summaries for HTTP 422 validation logging (no raw text content)."""

from __future__ import annotations

from typing import Any


def summarize_request_validation_body(
    body: Any,
) -> tuple[list[dict[str, Any]], list[str]]:
    """Return message shape summary and tool name list for debug logs."""
    messages = body.get("messages") if isinstance(body, dict) else None
    message_summary: list[dict[str, Any]] = []
    if isinstance(messages, list):
        for msg in messages:
            if not isinstance(msg, dict):
                message_summary.append({"message_kind": type(msg).__name__})
                continue
            content = msg.get("content")
            item: dict[str, Any] = {
                "role": msg.get("role"),
                "content_kind": type(content).__name__,
            }
            if isinstance(content, list):
                item["block_types"] = [
                    block.get("type", "dict")
                    if isinstance(block, dict)
                    else type(block).__name__
                    for block in content[:12]
                ]
                item["block_keys"] = [
                    sorted(str(key) for key in block)[:12]
                    for block in content[:5]
                    if isinstance(block, dict)
                ]
            elif isinstance(content, str):
                item["content_length"] = len(content)
            message_summary.append(item)

    tool_names: list[str] = []
    if isinstance(body, dict) and isinstance(body.get("tools"), list):
        tool_names = [
            str(tool.get("name", ""))
            for tool in body["tools"]
            if isinstance(tool, dict)
        ]

    return message_summary, tool_names
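A hedged usage sketch with a hypothetical request body, showing that only shapes and lengths reach the logs, never raw text:

body = {
    "messages": [
        {"role": "user", "content": "hello"},
        {"role": "user", "content": [{"type": "text", "text": "hi"}]},
    ],
    "tools": [{"name": "web_search"}],
}
summary, tool_names = summarize_request_validation_body(body)
# summary[0] -> {"role": "user", "content_kind": "str", "content_length": 5}
# summary[1] -> {"role": "user", "content_kind": "list",
#                "block_types": ["text"], "block_keys": [["text", "type"]]}
# tool_names -> ["web_search"]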
api/web_server_tools.py ADDED
@@ -0,0 +1,22 @@
"""Compatibility re-exports for :mod:`api.web_tools` (web_search / web_fetch)."""

from __future__ import annotations

import httpx

from api.web_tools.egress import (
    WebFetchEgressPolicy,
    WebFetchEgressViolation,
    enforce_web_fetch_egress,
)
from api.web_tools.request import is_web_server_tool_request
from api.web_tools.streaming import stream_web_server_tool_response

__all__ = [
    "WebFetchEgressPolicy",
    "WebFetchEgressViolation",
    "enforce_web_fetch_egress",
    "httpx",
    "is_web_server_tool_request",
    "stream_web_server_tool_response",
]
api/web_tools/__init__.py ADDED
@@ -0,0 +1,17 @@
"""Submodules for Anthropic web server tool handling (search/fetch, egress, streaming)."""

from .egress import (
    WebFetchEgressPolicy,
    WebFetchEgressViolation,
    enforce_web_fetch_egress,
)
from .request import is_web_server_tool_request
from .streaming import stream_web_server_tool_response

__all__ = [
    "WebFetchEgressPolicy",
    "WebFetchEgressViolation",
    "enforce_web_fetch_egress",
    "is_web_server_tool_request",
    "stream_web_server_tool_response",
]
api/web_tools/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (571 Bytes)
api/web_tools/__pycache__/constants.cpython-314.pyc ADDED
Binary file (680 Bytes)
api/web_tools/__pycache__/egress.cpython-314.pyc ADDED
Binary file (5.32 kB)
api/web_tools/__pycache__/parsers.cpython-314.pyc ADDED
Binary file (8.64 kB)
api/web_tools/__pycache__/request.cpython-314.pyc ADDED
Binary file (6.48 kB)
api/web_tools/__pycache__/streaming.cpython-314.pyc ADDED
Binary file (6.6 kB)
api/web_tools/constants.py ADDED
@@ -0,0 +1,15 @@
"""Limits and defaults for outbound web server tool HTTP."""

_REQUEST_TIMEOUT_S = 20.0
_MAX_SEARCH_RESULTS = 10
_MAX_FETCH_CHARS = 24_000
# Hard cap on raw bytes read from HTTP responses before decode / HTML parse (memory bound).
_MAX_WEB_FETCH_RESPONSE_BYTES = 2 * 1024 * 1024
# Drain at most this many bytes from redirect responses before following Location.
_REDIRECT_RESPONSE_BODY_CAP_BYTES = 65_536
_MAX_WEB_FETCH_REDIRECTS = 10
_WEB_FETCH_REDIRECT_STATUSES = frozenset({301, 302, 303, 307, 308})

_WEB_TOOL_HTTP_HEADERS = {
    "User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0",
}
api/web_tools/egress.py ADDED
@@ -0,0 +1,99 @@
"""Egress policy for user-controlled web_fetch URLs (SSRF guard)."""

from __future__ import annotations

import ipaddress
import socket
from dataclasses import dataclass
from urllib.parse import urlparse


@dataclass(frozen=True, slots=True)
class WebFetchEgressPolicy:
    """Egress rules for user-influenced web_fetch URLs."""

    allow_private_network_targets: bool
    allowed_schemes: frozenset[str]


class WebFetchEgressViolation(ValueError):
    """Raised when a web_fetch URL is rejected by egress policy (SSRF guard)."""


def _port_for_url(parsed) -> int:
    if parsed.port is not None:
        return parsed.port
    return 443 if (parsed.scheme or "").lower() == "https" else 80


def _stream_getaddrinfo_or_raise(host: str, port: int) -> list[tuple]:
    try:
        return socket.getaddrinfo(
            host, port, type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP
        )
    except OSError as exc:
        raise WebFetchEgressViolation(
            f"Could not resolve host {host!r}: {exc}"
        ) from exc


def get_validated_stream_addrinfos_for_egress(
    url: str, policy: WebFetchEgressPolicy
) -> list[tuple]:
    """Resolve and validate a URL for web_fetch, returning getaddrinfo rows for pinning.

    Each HTTP connect pins to only these `getaddrinfo` results so a malicious DNS
    server cannot rebind to a disallowed address between resolution and the TCP
    connect (used by :func:`api.web_tools.outbound._run_web_fetch`).
    """
    parsed = urlparse(url)
    scheme = (parsed.scheme or "").lower()
    if scheme not in policy.allowed_schemes:
        raise WebFetchEgressViolation(
            f"URL scheme {scheme!r} is not allowed for web_fetch"
        )

    host = parsed.hostname
    if host is None or host == "":
        raise WebFetchEgressViolation("web_fetch URL must include a host")

    port = _port_for_url(parsed)

    if policy.allow_private_network_targets:
        return _stream_getaddrinfo_or_raise(host, port)

    host_lower = host.lower()
    if host_lower == "localhost" or host_lower.endswith(".localhost"):
        raise WebFetchEgressViolation("localhost targets are not allowed for web_fetch")
    if host_lower.endswith(".local"):
        raise WebFetchEgressViolation(".local hostnames are not allowed for web_fetch")

    try:
        parsed_ip = ipaddress.ip_address(host)
    except ValueError:
        parsed_ip = None

    if parsed_ip is not None:
        if not parsed_ip.is_global:
            raise WebFetchEgressViolation(
                f"Non-public IP host {host!r} is not allowed for web_fetch"
            )
        return _stream_getaddrinfo_or_raise(host, port)

    infos = _stream_getaddrinfo_or_raise(host, port)
    for *_, sockaddr in infos:
        addr = sockaddr[0]
        try:
            resolved = ipaddress.ip_address(addr)
        except ValueError:
            continue
        if not resolved.is_global:
            raise WebFetchEgressViolation(
                f"Host {host!r} resolves to a non-public address ({resolved})"
            )
    return infos


def enforce_web_fetch_egress(url: str, policy: WebFetchEgressPolicy) -> None:
    """Validate ``url`` (scheme, host, and resolved addresses) for web_fetch."""
    get_validated_stream_addrinfos_for_egress(url, policy)
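A hedged example of the policy in use; the scheme set shown here is an assumption, not a value taken from the project settings:

policy = WebFetchEgressPolicy(
    allow_private_network_targets=False,
    allowed_schemes=frozenset({"http", "https"}),
)
enforce_web_fetch_egress("https://example.com/", policy)  # resolves and passes

try:
    enforce_web_fetch_egress("http://127.0.0.1:8080/", policy)
except WebFetchEgressViolation as exc:
    print(exc)  # Non-public IP host '127.0.0.1' is not allowed for web_fetch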
api/web_tools/outbound.py ADDED
@@ -0,0 +1,278 @@
"""Outbound HTTP for web_search / web_fetch (client, body caps, logging)."""

from __future__ import annotations

import asyncio
import socket
from collections.abc import AsyncIterator
from urllib.parse import urljoin, urlparse

import aiohttp
import httpx
from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aiohttp.abc import AbstractResolver, ResolveResult
from loguru import logger

from . import constants
from .constants import (
    _MAX_FETCH_CHARS,
    _MAX_SEARCH_RESULTS,
    _REDIRECT_RESPONSE_BODY_CAP_BYTES,
    _REQUEST_TIMEOUT_S,
    _WEB_FETCH_REDIRECT_STATUSES,
    _WEB_TOOL_HTTP_HEADERS,
)
from .egress import (
    WebFetchEgressPolicy,
    WebFetchEgressViolation,
    get_validated_stream_addrinfos_for_egress,
)
from .parsers import HTMLTextParser, SearchResultParser


def _safe_public_host_for_logs(url: str) -> str:
    host = urlparse(url).hostname or ""
    return host[:253]


def _log_web_tool_failure(
    tool_name: str,
    error: BaseException,
    *,
    fetch_url: str | None = None,
) -> None:
    exc_type = type(error).__name__
    if isinstance(error, WebFetchEgressViolation):
        host = _safe_public_host_for_logs(fetch_url) if fetch_url else ""
        logger.warning(
            "web_tool_egress_rejected tool={} exc_type={} host={!r}",
            tool_name,
            exc_type,
            host,
        )
        return
    if tool_name == "web_fetch" and fetch_url:
        logger.warning(
            "web_tool_failure tool={} exc_type={} host={!r}",
            tool_name,
            exc_type,
            _safe_public_host_for_logs(fetch_url),
        )
    else:
        logger.warning("web_tool_failure tool={} exc_type={}", tool_name, exc_type)


def _web_tool_client_error_summary(
    tool_name: str,
    error: BaseException,
    *,
    verbose: bool,
) -> str:
    if verbose:
        return f"{tool_name} failed: {type(error).__name__}"
    return "Web tool request failed."


async def _iter_response_body_under_cap(
    response: httpx.Response, max_bytes: int
) -> AsyncIterator[bytes]:
    if max_bytes <= 0:
        return
    received = 0
    async for chunk in response.aiter_bytes(chunk_size=65_536):
        if received >= max_bytes:
            break
        remaining = max_bytes - received
        if len(chunk) <= remaining:
            received += len(chunk)
            yield chunk
            if received >= max_bytes:
                break
        else:
            yield chunk[:remaining]
            break


async def _drain_response_body_capped(response: httpx.Response, max_bytes: int) -> None:
    async for _ in _iter_response_body_under_cap(response, max_bytes):
        pass


async def _read_response_body_capped(response: httpx.Response, max_bytes: int) -> bytes:
    return b"".join(
        [piece async for piece in _iter_response_body_under_cap(response, max_bytes)]
    )


_NUMERIC_RESOLVE_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV
_NAME_RESOLVE_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV


def getaddrinfo_rows_to_resolve_results(
    host: str, addrinfos: list[tuple]
) -> list[ResolveResult]:
    """Map :func:`socket.getaddrinfo` rows to aiohttp :class:`ResolveResult` (ThreadedResolver logic)."""
    out: list[ResolveResult] = []
    for family, _type, proto, _canon, sockaddr in addrinfos:
        if family == socket.AF_INET6:
            if len(sockaddr) < 3:
                continue
            if sockaddr[3]:
                resolved_host, port = socket.getnameinfo(sockaddr, _NAME_RESOLVE_FLAGS)
            else:
                resolved_host, port = sockaddr[:2]
        else:
            assert family == socket.AF_INET, family
            resolved_host, port = sockaddr[0], sockaddr[1]
        resolved_host = str(resolved_host)
        port = int(port)
        out.append(
            ResolveResult(
                hostname=host,
                host=resolved_host,
                port=int(port),
                family=family,
                proto=proto,
                flags=_NUMERIC_RESOLVE_FLAGS,
            )
        )
    return out


class _PinnedEgressStaticResolver(AbstractResolver):
    """Return only pre-validated :class:`ResolveResult` for the outbound request."""

    def __init__(self, results: list[ResolveResult]) -> None:
        self._results = results

    async def resolve(
        self, host: str, port: int = 0, family: int = socket.AF_INET
    ) -> list[ResolveResult]:
        return self._results

    async def close(self) -> None:  # pragma: no cover - aiohttp contract
        return


async def _read_aiohttp_body_capped(
    response: aiohttp.ClientResponse, max_bytes: int
) -> bytes:
    received = 0
    parts: list[bytes] = []
    async for chunk in response.content.iter_chunked(65_536):
        if received >= max_bytes:
            break
        remaining = max_bytes - received
        if len(chunk) <= remaining:
            received += len(chunk)
            parts.append(chunk)
        else:
            parts.append(chunk[:remaining])
            break
    return b"".join(parts)


async def _drain_aiohttp_body_capped(
    response: aiohttp.ClientResponse, max_bytes: int
) -> None:
    if max_bytes <= 0:
        return
    received = 0
    async for chunk in response.content.iter_chunked(65_536):
        received += len(chunk)
        if received >= max_bytes:
            break


async def _run_web_search(query: str) -> list[dict[str, str]]:
    async with (
        httpx.AsyncClient(
            timeout=_REQUEST_TIMEOUT_S,
            follow_redirects=True,
            headers=_WEB_TOOL_HTTP_HEADERS,
        ) as client,
        client.stream(
            "GET",
            "https://lite.duckduckgo.com/lite/",
            params={"q": query},
        ) as response,
    ):
        response.raise_for_status()
        body_bytes = await _read_response_body_capped(
            response, constants._MAX_WEB_FETCH_RESPONSE_BYTES
        )
        text = body_bytes.decode("utf-8", errors="replace")
        parser = SearchResultParser()
        parser.feed(text)
        return parser.results[:_MAX_SEARCH_RESULTS]


async def _run_web_fetch(url: str, egress: WebFetchEgressPolicy) -> dict[str, str]:
    """Fetch URL with manual redirects; each hop is DNS-pinned to validated addresses."""
    current_url = url
    redirect_hops = 0
    timeout = ClientTimeout(total=_REQUEST_TIMEOUT_S)

    while True:
        addr_infos = await asyncio.to_thread(
            get_validated_stream_addrinfos_for_egress, current_url, egress
        )
        host = urlparse(current_url).hostname or ""
        results = getaddrinfo_rows_to_resolve_results(host, addr_infos)
        resolver = _PinnedEgressStaticResolver(results)
        connector = TCPConnector(
            resolver=resolver,
            force_close=True,
        )
        try:
            async with (
                ClientSession(
                    timeout=timeout,
                    headers=_WEB_TOOL_HTTP_HEADERS,
                    connector=connector,
                ) as session,
                session.get(current_url, allow_redirects=False) as response,
            ):
                if response.status in _WEB_FETCH_REDIRECT_STATUSES:
                    await _drain_aiohttp_body_capped(
                        response, _REDIRECT_RESPONSE_BODY_CAP_BYTES
                    )
                    if redirect_hops >= constants._MAX_WEB_FETCH_REDIRECTS:
                        raise WebFetchEgressViolation(
                            "web_fetch exceeded maximum redirects "
                            f"({constants._MAX_WEB_FETCH_REDIRECTS})"
                        )
                    location = response.headers.get("location")
                    if not location or not location.strip():
                        raise WebFetchEgressViolation(
                            "web_fetch redirect response missing Location header"
                        )
                    current_url = urljoin(str(response.url), location.strip())
                    redirect_hops += 1
                    continue
                response.raise_for_status()
                content_type = response.headers.get("content-type", "text/plain")
                final_url = str(response.url)
                encoding = response.get_encoding() or "utf-8"
                body_bytes = await _read_aiohttp_body_capped(
                    response, constants._MAX_WEB_FETCH_RESPONSE_BYTES
                )
        finally:
            await connector.close()

        break

    text = body_bytes.decode(encoding, errors="replace")
    title = final_url
    data = text
    if "html" in content_type.lower():
        parser = HTMLTextParser()
        parser.feed(text)
        title = parser.title or final_url
        data = "\n".join(parser.text_parts)
    return {
        "url": final_url,
        "title": title,
        "media_type": "text/plain",
        "data": data[:_MAX_FETCH_CHARS],
    }
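A hedged end-to-end sketch of the pinned fetch path; it needs outbound network access, and the policy values are assumptions rather than project defaults:

import asyncio

policy = WebFetchEgressPolicy(
    allow_private_network_targets=False,
    allowed_schemes=frozenset({"http", "https"}),
)
page = asyncio.run(_run_web_fetch("https://example.com/", policy))
print(page["title"], len(page["data"]))  # title parsed from the HTML, capped text body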
api/web_tools/parsers.py ADDED
@@ -0,0 +1,104 @@
"""HTML parsing for web_search / web_fetch."""

from __future__ import annotations

import html
import re
from html.parser import HTMLParser
from typing import Any
from urllib.parse import parse_qs, unquote, urlparse


class SearchResultParser(HTMLParser):
    """DuckDuckGo lite HTML: extract result links and titles."""

    def __init__(self) -> None:
        super().__init__()
        self.results: list[dict[str, str]] = []
        self._href: str | None = None
        self._title_parts: list[str] = []

    def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
        if tag != "a":
            return
        href = dict(attrs).get("href")
        if not href or "uddg=" not in href:
            return
        parsed = urlparse(href)
        query = parse_qs(parsed.query)
        uddg = query.get("uddg", [""])[0]
        if not uddg:
            return
        self._href = unquote(uddg)
        self._title_parts = []

    def handle_data(self, data: str) -> None:
        if self._href is not None:
            self._title_parts.append(data)

    def handle_endtag(self, tag: str) -> None:
        if tag != "a" or self._href is None:
            return
        title = " ".join("".join(self._title_parts).split())
        if title and not any(result["url"] == self._href for result in self.results):
            self.results.append({"title": html.unescape(title), "url": self._href})
        self._href = None
        self._title_parts = []


class HTMLTextParser(HTMLParser):
    """Strip scripts/styles and collect visible text + title for fetch previews."""

    def __init__(self) -> None:
        super().__init__()
        self.title = ""
        self.text_parts: list[str] = []
        self._in_title = False
        self._skip_depth = 0

    def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
        if tag in {"script", "style", "noscript"}:
            self._skip_depth += 1
        elif tag == "title":
            self._in_title = True

    def handle_endtag(self, tag: str) -> None:
        if tag in {"script", "style", "noscript"} and self._skip_depth:
            self._skip_depth -= 1
        elif tag == "title":
            self._in_title = False

    def handle_data(self, data: str) -> None:
        text = " ".join(data.split())
        if not text:
            return
        if self._in_title:
            self.title = f"{self.title} {text}".strip()
        elif not self._skip_depth:
            self.text_parts.append(text)


def content_text(content: Any) -> str:
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for item in content:
            if isinstance(item, dict):
                parts.append(str(item.get("text", "")))
            else:
                parts.append(str(getattr(item, "text", "")))
        return "\n".join(part for part in parts if part)
    return str(content)


def extract_query(text: str) -> str:
    match = re.search(r"query:\s*(.+)", text, flags=re.IGNORECASE | re.DOTALL)
    if match:
        return match.group(1).strip().strip("\"'")
    return text.strip()


def extract_url(text: str) -> str:
    match = re.search(r"https?://\S+", text)
    return match.group(0).rstrip(").,]") if match else text.strip()
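A hedged sketch of the fetch-preview parser and the input extractors on toy inputs:

parser = HTMLTextParser()
parser.feed(
    "<html><head><title>Demo</title><style>p{}</style></head>"
    "<body><p>Visible text</p></body></html>"
)
print(parser.title)       # Demo
print(parser.text_parts)  # ['Visible text']

print(extract_query("query: latest aiohttp release"))         # latest aiohttp release
print(extract_url("Please fetch https://example.com/docs."))  # https://example.com/docs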
api/web_tools/request.py ADDED
@@ -0,0 +1,86 @@
"""Detect forced Anthropic web server tool requests."""

from __future__ import annotations

from api.models.anthropic import MessagesRequest, Tool


def request_text(request: MessagesRequest) -> str:
    """Join all user/assistant message content into one string for tool input parsing."""
    from .parsers import content_text

    return "\n".join(content_text(message.content) for message in request.messages)


def forced_tool_turn_text(request: MessagesRequest) -> str:
    """Text for parsing forced server-tool inputs: latest user turn only (avoids stale history)."""
    if not request.messages:
        return ""

    from .parsers import content_text

    for message in reversed(request.messages):
        if message.role == "user":
            return content_text(message.content)
    return ""


def forced_server_tool_name(request: MessagesRequest) -> str | None:
    """Return web_search or web_fetch only when tool_choice forces that server tool."""
    tc = request.tool_choice
    if not isinstance(tc, dict):
        return None
    if tc.get("type") != "tool":
        return None
    name = tc.get("name")
    if name in {"web_search", "web_fetch"}:
        return str(name)
    return None


def has_tool_named(request: MessagesRequest, name: str) -> bool:
    return any(tool.name == name for tool in request.tools or [])


def is_web_server_tool_request(request: MessagesRequest) -> bool:
    """True when the client forces a web server tool via tool_choice (not merely listed)."""
    forced = forced_server_tool_name(request)
    if forced is None:
        return False
    return has_tool_named(request, forced)


def is_anthropic_server_tool_definition(tool: Tool) -> bool:
    """Whether ``tool`` refers to an Anthropic server tool (web_search / web_fetch family)."""
    name = (tool.name or "").strip()
    if name in ("web_search", "web_fetch"):
        return True
    typ = tool.type
    if isinstance(typ, str):
        return typ.startswith("web_search") or typ.startswith("web_fetch")
    return False


def has_listed_anthropic_server_tools(request: MessagesRequest) -> bool:
    """True when tools include web_search / web_fetch-style entries (listed, forced or not)."""
    return any(is_anthropic_server_tool_definition(t) for t in (request.tools or []))


def openai_chat_upstream_server_tool_error(
    request: MessagesRequest, *, web_tools_enabled: bool
) -> str | None:
    """Return a user-facing error when OpenAI Chat upstream cannot satisfy server-tool semantics."""
    forced = forced_server_tool_name(request)
    if forced and not web_tools_enabled:
        return (
            f"tool_choice forces Anthropic server tool {forced!r}, but local web server tools are "
            "disabled (ENABLE_WEB_SERVER_TOOLS=false). Enable them to use this tool."
        )
    if not forced and has_listed_anthropic_server_tools(request):
        return (
            "OpenAI Chat upstreams (NVIDIA NIM) cannot use listed Anthropic server tools "
            "(web_search / web_fetch) without the local web server tool handler. "
            "Set ENABLE_WEB_SERVER_TOOLS=true and force the tool with "
            "tool_choice, or remove these tools from the request."
        )
    return None
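For reference, a hedged sketch of the payload shape these helpers look for: the server tool must be both listed in tools and forced via tool_choice before is_web_server_tool_request returns True. The payload below is hypothetical.

payload = {
    "messages": [{"role": "user", "content": "Fetch https://example.com/ please"}],
    "tools": [{"name": "web_fetch"}],
    "tool_choice": {"type": "tool", "name": "web_fetch"},
}
# Parsed into a MessagesRequest, forced_server_tool_name(...) returns "web_fetch"
# and is_web_server_tool_request(...) returns True; dropping tool_choice makes both
# fall back to None / False.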
api/web_tools/streaming.py ADDED
@@ -0,0 +1,206 @@
"""SSE streaming for local web_search / web_fetch server tool results."""

from __future__ import annotations

import uuid
from collections.abc import AsyncIterator
from datetime import UTC, datetime
from typing import Any

from api.models.anthropic import MessagesRequest
from core.anthropic.server_tool_sse import (
    SERVER_TOOL_USE,
    WEB_FETCH_TOOL_ERROR,
    WEB_FETCH_TOOL_RESULT,
    WEB_SEARCH_TOOL_RESULT,
    WEB_SEARCH_TOOL_RESULT_ERROR,
)
from core.anthropic.sse import format_sse_event

from .constants import _MAX_FETCH_CHARS
from .egress import WebFetchEgressPolicy
from .parsers import extract_query, extract_url
from .request import (
    forced_server_tool_name,
    forced_tool_turn_text,
    has_tool_named,
)


def _search_summary(query: str, results: list[dict[str, str]]) -> str:
    if not results:
        return f"No web search results found for: {query}"
    lines = [f"Search results for: {query}"]
    for index, result in enumerate(results, start=1):
        lines.append(f"{index}. {result['title']}\n{result['url']}")
    return "\n\n".join(lines)


async def stream_web_server_tool_response(
    request: MessagesRequest,
    input_tokens: int,
    *,
    web_fetch_egress: WebFetchEgressPolicy,
    verbose_client_errors: bool = False,
) -> AsyncIterator[str]:
    """Stream a minimal Anthropic-shaped turn for forced `web_search` / `web_fetch` (local fallback).

    When `ENABLE_WEB_SERVER_TOOLS` is on, this is a proxy-side execution path, not a full
    hosted Anthropic citation or encrypted-content pipeline.
    """
    from . import outbound

    tool_name = forced_server_tool_name(request)
    if tool_name is None or not has_tool_named(request, tool_name):
        return

    text = forced_tool_turn_text(request)
    message_id = f"msg_{uuid.uuid4()}"
    tool_id = f"srvtoolu_{uuid.uuid4().hex}"
    usage_key = (
        "web_search_requests" if tool_name == "web_search" else "web_fetch_requests"
    )
    tool_input = (
        {"query": extract_query(text)}
        if tool_name == "web_search"
        else {"url": extract_url(text)}
    )
    _result_block_for_tool = {
        "web_search": WEB_SEARCH_TOOL_RESULT,
        "web_fetch": WEB_FETCH_TOOL_RESULT,
    }
    _error_payload_type_for_tool = {
        "web_search": WEB_SEARCH_TOOL_RESULT_ERROR,
        "web_fetch": WEB_FETCH_TOOL_ERROR,
    }

    yield format_sse_event(
        "message_start",
        {
            "type": "message_start",
            "message": {
                "id": message_id,
                "type": "message",
                "role": "assistant",
                "content": [],
                "model": request.model,
                "stop_reason": None,
                "stop_sequence": None,
                "usage": {"input_tokens": input_tokens, "output_tokens": 1},
            },
        },
    )
    yield format_sse_event(
        "content_block_start",
        {
            "type": "content_block_start",
            "index": 0,
            "content_block": {
                "type": SERVER_TOOL_USE,
                "id": tool_id,
                "name": tool_name,
                "input": tool_input,
            },
        },
    )
    yield format_sse_event(
        "content_block_stop", {"type": "content_block_stop", "index": 0}
    )

    try:
        if tool_name == "web_search":
            query = str(tool_input["query"])
            results = await outbound._run_web_search(query)
            result_content: Any = [
                {
                    "type": "web_search_result",
                    "title": result["title"],
                    "url": result["url"],
                }
                for result in results
            ]
            summary = _search_summary(query, results)
            result_block_type = WEB_SEARCH_TOOL_RESULT
        else:
            fetched = await outbound._run_web_fetch(
                str(tool_input["url"]), web_fetch_egress
            )
            result_content = {
                "type": "web_fetch_result",
                "url": fetched["url"],
                "content": {
                    "type": "document",
                    "source": {
                        "type": "text",
                        "media_type": fetched["media_type"],
                        "data": fetched["data"],
                    },
                    "title": fetched["title"],
                    "citations": {"enabled": True},
                },
                "retrieved_at": datetime.now(UTC).isoformat(),
            }
            summary = fetched["data"][:_MAX_FETCH_CHARS]
            result_block_type = WEB_FETCH_TOOL_RESULT
    except Exception as error:
        fetch_url = str(tool_input["url"]) if tool_name == "web_fetch" else None
        outbound._log_web_tool_failure(tool_name, error, fetch_url=fetch_url)
        result_block_type = _result_block_for_tool[tool_name]
        result_content = {
            "type": _error_payload_type_for_tool[tool_name],
            "error_code": "unavailable",
        }
        summary = outbound._web_tool_client_error_summary(
            tool_name, error, verbose=verbose_client_errors
        )

    output_tokens = max(1, len(summary) // 4)

    yield format_sse_event(
        "content_block_start",
        {
            "type": "content_block_start",
            "index": 1,
            "content_block": {
                "type": result_block_type,
                "tool_use_id": tool_id,
                "content": result_content,
            },
        },
    )
    yield format_sse_event(
        "content_block_stop", {"type": "content_block_stop", "index": 1}
    )
    # Model-facing summary: stream as normal text deltas (CLI/transcript code reads `text_delta`,
    # not eager `text` on `content_block_start`).
    yield format_sse_event(
        "content_block_start",
        {
            "type": "content_block_start",
            "index": 2,
            "content_block": {"type": "text", "text": ""},
        },
    )
    yield format_sse_event(
        "content_block_delta",
        {
            "type": "content_block_delta",
            "index": 2,
            "delta": {"type": "text_delta", "text": summary},
        },
    )
    yield format_sse_event(
        "content_block_stop", {"type": "content_block_stop", "index": 2}
    )
    yield format_sse_event(
        "message_delta",
        {
            "type": "message_delta",
            "delta": {"stop_reason": "end_turn", "stop_sequence": None},
            "usage": {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "server_tool_use": {usage_key: 1},
            },
        },
    )
    yield format_sse_event("message_stop", {"type": "message_stop"})
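A hedged consumption sketch; the generator yields ready-to-send SSE strings, so callers typically just forward them into the HTTP response:

async def collect_events(request, input_tokens, policy):
    events = []
    async for sse in stream_web_server_tool_response(
        request, input_tokens, web_fetch_egress=policy
    ):
        events.append(sse)
    return events
# Expected order: message_start, the server_tool_use block, its result block,
# a text block carrying the summary, message_delta with usage, then message_stop.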
cli/__init__.py ADDED
@@ -0,0 +1,6 @@
"""CLI integration for Claude Code."""

from .manager import CLISessionManager
from .session import CLISession

__all__ = ["CLISession", "CLISessionManager"]
cli/entrypoints.py ADDED
@@ -0,0 +1,60 @@
"""CLI entry points for the installed package."""

from __future__ import annotations

from pathlib import Path


def _load_env_template() -> str:
    """Load the canonical root env template from package resources or source."""
    import importlib.resources

    packaged = importlib.resources.files("cli").joinpath("env.example")
    if packaged.is_file():
        return packaged.read_text("utf-8")

    source_template = Path(__file__).resolve().parents[1] / ".env.example"
    if source_template.is_file():
        return source_template.read_text(encoding="utf-8")

    raise FileNotFoundError("Could not find bundled or source .env.example template.")


def serve() -> None:
    """Start the FastAPI server (registered as `free-claude-code` script)."""
    import uvicorn

    from cli.process_registry import kill_all_best_effort
    from config.settings import get_settings

    settings = get_settings()
    try:
        uvicorn.run(
            "api.app:create_asgi_app",
            factory=True,
            host=settings.host,
            port=settings.port,
            log_level="debug",
            timeout_graceful_shutdown=5,
        )
    finally:
        kill_all_best_effort()


def init() -> None:
    """Scaffold config at ~/.config/free-claude-code/.env (registered as `fcc-init`)."""
    config_dir = Path.home() / ".config" / "free-claude-code"
    env_file = config_dir / ".env"

    if env_file.exists():
        print(f"Config already exists at {env_file}")
        print("Delete it first if you want to reset to defaults.")
        return

    config_dir.mkdir(parents=True, exist_ok=True)
    template = _load_env_template()
    env_file.write_text(template, encoding="utf-8")
    print(f"Config created at {env_file}")
    print(
        "Edit it to set your API keys and model preferences, then run: free-claude-code"
    )
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CLI Session Manager for Multi-Instance Claude CLI Support
|
| 3 |
+
|
| 4 |
+
Manages a pool of CLISession instances, each handling one conversation.
|
| 5 |
+
This enables true parallel processing where multiple conversations run
|
| 6 |
+
simultaneously in separate CLI processes.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import asyncio
|
| 10 |
+
import uuid
|
| 11 |
+
|
| 12 |
+
from loguru import logger
|
| 13 |
+
|
| 14 |
+
from .session import CLISession
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class CLISessionManager:
|
| 18 |
+
"""
|
| 19 |
+
Manages multiple CLISession instances for parallel conversation processing.
|
| 20 |
+
|
| 21 |
+
Each new conversation gets its own CLISession with its own subprocess.
|
| 22 |
+
Replies to existing conversations reuse the same CLISession instance.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(
|
| 26 |
+
self,
|
| 27 |
+
workspace_path: str,
|
| 28 |
+
api_url: str,
|
| 29 |
+
allowed_dirs: list[str] | None = None,
|
| 30 |
+
plans_directory: str | None = None,
|
| 31 |
+
claude_bin: str = "claude",
|
| 32 |
+
*,
|
| 33 |
+
log_raw_cli_diagnostics: bool = False,
|
| 34 |
+
log_messaging_error_details: bool = False,
|
| 35 |
+
):
|
| 36 |
+
"""
|
| 37 |
+
Initialize the session manager.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
workspace_path: Working directory for CLI processes
|
| 41 |
+
api_url: API URL for the proxy
|
| 42 |
+
allowed_dirs: Directories the CLI is allowed to access
|
| 43 |
+
plans_directory: Directory for Claude Code CLI plan files (passed via --settings)
|
| 44 |
+
"""
|
| 45 |
+
self.workspace = workspace_path
|
| 46 |
+
self.api_url = api_url
|
| 47 |
+
self.allowed_dirs = allowed_dirs or []
|
| 48 |
+
self.plans_directory = plans_directory
|
| 49 |
+
self.claude_bin = claude_bin
|
| 50 |
+
self._log_raw_cli_diagnostics = log_raw_cli_diagnostics
|
| 51 |
+
self._log_messaging_error_details = log_messaging_error_details
|
| 52 |
+
|
| 53 |
+
self._sessions: dict[str, CLISession] = {}
|
| 54 |
+
self._pending_sessions: dict[str, CLISession] = {}
|
| 55 |
+
self._temp_to_real: dict[str, str] = {}
|
| 56 |
+
self._real_to_temp: dict[str, str] = {}
|
| 57 |
+
self._lock = asyncio.Lock()
|
| 58 |
+
|
| 59 |
+
logger.info("CLISessionManager initialized")
|
| 60 |
+
|
| 61 |
+
async def get_or_create_session(
|
| 62 |
+
self, session_id: str | None = None
|
| 63 |
+
) -> tuple[CLISession, str, bool]:
|
| 64 |
+
"""
|
| 65 |
+
Get an existing session or create a new one.
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
Tuple of (CLISession instance, session_id, is_new_session)
|
| 69 |
+
"""
|
| 70 |
+
async with self._lock:
|
| 71 |
+
if session_id:
|
| 72 |
+
lookup_id = self._temp_to_real.get(session_id, session_id)
|
| 73 |
+
|
| 74 |
+
if lookup_id in self._sessions:
|
| 75 |
+
return self._sessions[lookup_id], lookup_id, False
|
| 76 |
+
if lookup_id in self._pending_sessions:
|
| 77 |
+
return self._pending_sessions[lookup_id], lookup_id, False
|
| 78 |
+
|
| 79 |
+
temp_id = session_id if session_id else f"pending_{uuid.uuid4().hex[:8]}"
|
| 80 |
+
|
| 81 |
+
new_session = CLISession(
|
| 82 |
+
workspace_path=self.workspace,
|
| 83 |
+
api_url=self.api_url,
|
| 84 |
+
allowed_dirs=self.allowed_dirs,
|
| 85 |
+
plans_directory=self.plans_directory,
|
| 86 |
+
claude_bin=self.claude_bin,
|
| 87 |
+
log_raw_cli_diagnostics=self._log_raw_cli_diagnostics,
|
| 88 |
+
)
|
| 89 |
+
self._pending_sessions[temp_id] = new_session
|
| 90 |
+
logger.info(f"Created new session: {temp_id}")
|
| 91 |
+
|
| 92 |
+
return new_session, temp_id, True
|
| 93 |
+
|
| 94 |
+
async def register_real_session_id(
|
| 95 |
+
self, temp_id: str, real_session_id: str
|
| 96 |
+
) -> bool:
|
| 97 |
+
"""Register the real session ID from CLI output."""
|
| 98 |
+
async with self._lock:
|
| 99 |
+
if temp_id not in self._pending_sessions:
|
| 100 |
+
logger.warning(f"Temp session {temp_id} not found")
|
| 101 |
+
return False
|
| 102 |
+
|
| 103 |
+
session = self._pending_sessions.pop(temp_id)
|
| 104 |
+
self._sessions[real_session_id] = session
|
| 105 |
+
self._temp_to_real[temp_id] = real_session_id
|
| 106 |
+
self._real_to_temp[real_session_id] = temp_id
|
| 107 |
+
|
| 108 |
+
logger.info(f"Registered session: {temp_id} -> {real_session_id}")
|
| 109 |
+
return True
|
| 110 |
+
|
| 111 |
+
async def remove_session(self, session_id: str) -> bool:
|
| 112 |
+
"""Remove a session from the manager."""
|
| 113 |
+
async with self._lock:
|
| 114 |
+
if session_id in self._pending_sessions:
|
| 115 |
+
session = self._pending_sessions.pop(session_id)
|
| 116 |
+
await session.stop()
|
| 117 |
+
return True
|
| 118 |
+
|
| 119 |
+
if session_id in self._sessions:
|
| 120 |
+
session = self._sessions.pop(session_id)
|
| 121 |
+
await session.stop()
|
| 122 |
+
temp_id = self._real_to_temp.pop(session_id, None)
|
| 123 |
+
if temp_id is not None:
|
| 124 |
+
self._temp_to_real.pop(temp_id, None)
|
| 125 |
+
return True
|
| 126 |
+
|
| 127 |
+
return False
|
| 128 |
+
|
| 129 |
+
async def stop_all(self):
|
| 130 |
+
"""Stop all sessions."""
|
| 131 |
+
async with self._lock:
|
| 132 |
+
all_sessions = list(self._sessions.values()) + list(
|
| 133 |
+
self._pending_sessions.values()
|
| 134 |
+
)
|
| 135 |
+
for session in all_sessions:
|
| 136 |
+
try:
|
| 137 |
+
await session.stop()
|
| 138 |
+
except Exception as e:
|
| 139 |
+
if self._log_messaging_error_details:
|
| 140 |
+
logger.error(
|
| 141 |
+
"Error stopping session: {}: {}",
|
| 142 |
+
type(e).__name__,
|
| 143 |
+
e,
|
| 144 |
+
)
|
| 145 |
+
else:
|
| 146 |
+
logger.error(
|
| 147 |
+
"Error stopping session: exc_type={}",
|
| 148 |
+
type(e).__name__,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
self._sessions.clear()
|
| 152 |
+
self._pending_sessions.clear()
|
| 153 |
+
self._temp_to_real.clear()
|
| 154 |
+
self._real_to_temp.clear()
|
| 155 |
+
logger.info("All sessions stopped")
|
| 156 |
+
|
| 157 |
+
def get_stats(self) -> dict:
|
| 158 |
+
"""Get session statistics."""
|
| 159 |
+
return {
|
| 160 |
+
"active_sessions": len(self._sessions),
|
| 161 |
+
"pending_sessions": len(self._pending_sessions),
|
| 162 |
+
"busy_count": sum(1 for s in self._sessions.values() if s.is_busy),
|
| 163 |
+
}
|
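A hedged usage sketch (the workspace path and proxy URL are placeholders): a pending ID is handed out first, then mapped to the real session ID reported by the CLI.

manager = CLISessionManager(workspace_path="/tmp/ws", api_url="http://127.0.0.1:8082")

async def demo() -> None:
    session, temp_id, is_new = await manager.get_or_create_session()
    # ...start the session and read the real ID from the CLI's first output...
    await manager.register_real_session_id(temp_id, "real-session-id-from-cli")
    same, _, is_new_again = await manager.get_or_create_session(temp_id)
    assert same is session and is_new_again is False
    await manager.stop_all()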
cli/process_registry.py ADDED
@@ -0,0 +1,74 @@
"""Track and clean up spawned CLI subprocesses.

This is a safety net for cases where the server is interrupted (Ctrl+C) and the
FastAPI lifespan cleanup doesn't run to completion. We only track processes we
spawn so we don't accidentally kill unrelated system processes.
"""

from __future__ import annotations

import atexit
import os
import subprocess
import threading

from loguru import logger

_lock = threading.Lock()
_pids: set[int] = set()
_atexit_registered = False


def ensure_atexit_registered() -> None:
    global _atexit_registered
    with _lock:
        if _atexit_registered:
            return
        atexit.register(kill_all_best_effort)
        _atexit_registered = True


def register_pid(pid: int) -> None:
    if not pid:
        return
    ensure_atexit_registered()
    with _lock:
        _pids.add(int(pid))


def unregister_pid(pid: int) -> None:
    if not pid:
        return
    with _lock:
        _pids.discard(int(pid))


def kill_all_best_effort() -> None:
    """Kill any still-running registered pids (best-effort)."""
    with _lock:
        pids = list(_pids)
        _pids.clear()

    if not pids:
        return

    if os.name == "nt":
        for pid in pids:
            try:
                # /T kills child processes, /F forces termination.
                subprocess.run(
                    ["taskkill", "/PID", str(pid), "/T", "/F"],
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL,
                    check=False,
                )
            except Exception as e:
                # loguru uses brace-style formatting, not printf-style placeholders.
                logger.debug("process_registry: taskkill failed pid={}: {}", pid, e)
        return

    # Best-effort fallback for non-Windows.
    for pid in pids:
        try:
            os.kill(pid, 9)
        except Exception as e:
            logger.debug("process_registry: kill failed pid={}: {}", pid, e)
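A hedged usage sketch: register every spawned CLI pid so an interrupted shutdown still reaps it (the command below is a placeholder, not how the session code actually spawns the CLI).

import subprocess

proc = subprocess.Popen(["claude", "--help"])
register_pid(proc.pid)
try:
    proc.wait(timeout=30)
finally:
    unregister_pid(proc.pid)
# Anything still registered at interpreter exit is killed via the atexit hook.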